Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Colorbar ticks at threshold values #2887

Merged
merged 31 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
bce4b22
Set colorbar ticks at threshold values
NicolasGensollen Jun 25, 2021
c7ada75
Add tests
NicolasGensollen Jun 25, 2021
2f2f99c
Fix PEP8 issues
NicolasGensollen Jun 25, 2021
767163d
[circle full] request full build
NicolasGensollen Jun 25, 2021
86eee55
[circle full] threshold non symmetrical colorbars and refactor code
NicolasGensollen Jun 28, 2021
b669d39
[circle full] Fix PEP8
NicolasGensollen Jun 28, 2021
fad61b1
Fixes and refactoring
NicolasGensollen Jun 29, 2021
a43ec87
Add tests
NicolasGensollen Jun 29, 2021
4ea410f
Fix PEP8 issues
NicolasGensollen Jun 30, 2021
3469345
More refactoring
NicolasGensollen Jun 30, 2021
a522edb
[circle full] extend to plot_surf
NicolasGensollen Jun 30, 2021
6dd669a
[circle full] Fix
NicolasGensollen Jun 30, 2021
fcd2637
Add a whatsnew entry
NicolasGensollen Jul 1, 2021
09b15c0
Merge branch 'main' into colorbar-ticks-at-threshold
Remi-Gau Aug 4, 2023
5bcad37
rm colorbar
Remi-Gau Aug 4, 2023
59f8122
Apply suggestions from code review
Remi-Gau Aug 4, 2023
51b0fb8
rm extra code
Remi-Gau Aug 4, 2023
35215c5
rm extra code
Remi-Gau Aug 4, 2023
5f7bf3e
bring back code for assymetric colorbars
Remi-Gau Aug 4, 2023
16a0059
tests pass
Remi-Gau Aug 4, 2023
476ce8c
semantic line break
Remi-Gau Aug 5, 2023
ac1c326
fix tests
Remi-Gau Aug 5, 2023
c202298
Merge branch 'main' into colorbar-ticks-at-threshold
Remi-Gau Aug 28, 2023
42eba50
isort
Remi-Gau Aug 28, 2023
fead0f4
Merge branch 'main' into colorbar-ticks-at-threshold
Remi-Gau Sep 19, 2023
73e7b1b
add threshold
Remi-Gau Sep 20, 2023
ae4a92f
Fix test
Nov 24, 2023
3a18f5c
Allow with plot_img_on_surf
Nov 24, 2023
1b4c8f7
Merge branch 'main' into colorbar-ticks-at-threshold
ymzayek Nov 24, 2023
df7470b
Fix test 2
Nov 27, 2023
91ea227
Merge branch 'main' into colorbar-ticks-at-threshold
ymzayek Nov 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ Enhancements
parameters or common lists of options for example. The standard parts are defined
in a single location (`nilearn._utils.docs.py`) which makes them easier to
maintain and update. (See `#2875 <https://github.com/nilearn/nilearn/pull/2875>`_).
- When plotting thresholded statistical maps with a colorbar, the threshold
value(s) will now be displayed as tick labels on the colorbar.
See issue `#2833 <https://github.com/nilearn/nilearn/issues/2833>`_.

Changes
-------
Expand Down
35 changes: 30 additions & 5 deletions nilearn/plotting/displays.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,11 +953,7 @@ def _show_colorbar(self, cmap, norm, threshold=None):
self._colorbar_ax.set_axis_bgcolor('w')

our_cmap = mpl_cm.get_cmap(cmap)
# edge case where the data has a single value
# yields a cryptic matplotlib error message
# when trying to plot the color bar
nb_ticks = 5 if norm.vmin != norm.vmax else 1
ticks = np.linspace(norm.vmin, norm.vmax, nb_ticks)
ticks = _get_cbar_ticks(norm.vmin, norm.vmax, offset, nb_ticks=5)
bounds = np.linspace(norm.vmin, norm.vmax, our_cmap.N)

# some colormap hacking
Expand Down Expand Up @@ -1191,6 +1187,35 @@ def savefig(self, filename, dpi=None):
edgecolor=edgecolor)


def _get_cbar_ticks(vmin, vmax, offset, nb_ticks=5):
"""Helper function for BaseSlicer."""
# edge case where the data has a single value yields a cryptic
# matplotlib error message when trying to plot the color bar
if vmin == vmax:
return np.linspace(vmin, vmax, 1)

# If a threshold is specified, we want two of the ticks to
# correspond to -thresold and +threshold on the colorbar.
# If the threshold is very small compared to vmax, we use
# a simple linspace as the result would be very difficult to see.
ticks = np.linspace(vmin, vmax, nb_ticks)
if offset is not None and offset / vmax > 0.12:
diff = [abs(abs(tick) - offset) for tick in ticks]
# Edge case where the thresholds are exactly at
# the same distance to 4 ticks
if diff.count(min(diff)) == 4:
idx_closest = np.sort(np.argpartition(diff, 4)[:4])
idx_closest = np.in1d(ticks, np.sort(ticks[idx_closest])[1:3])
else:
# Find the closest 2 ticks
idx_closest = np.sort(np.argpartition(diff, 2)[:2])
if 0 in ticks[idx_closest]:
idx_closest = np.sort(np.argpartition(diff, 3)[:3])
idx_closest = idx_closest[[0, 2]]
ticks[idx_closest] = [-offset, offset]
return ticks


###############################################################################
# class OrthoSlicer
###############################################################################
Expand Down
43 changes: 37 additions & 6 deletions nilearn/plotting/img_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,15 +215,44 @@ def _plot_img_with_bg(img, bg_img=None, cut_coords=None,
display.title(title)
if hasattr(display, '_cbar'):
cbar = display._cbar
_crop_colorbar(cbar, cbar_vmin, cbar_vmax)
_crop_colorbar(cbar, cbar_vmin, cbar_vmax, threshold)
if output_file is not None:
display.savefig(output_file)
display.close()
display = None
return display


def _crop_colorbar(cbar, cbar_vmin, cbar_vmax):
def _get_cropped_cbar_ticks(cbar_vmin, cbar_vmax,
threshold=None, n_ticks=5):
"""Helper function for _crop_colobar.
Returns ticks for cropped colorbars.
"""
new_tick_locs = np.linspace(cbar_vmin, cbar_vmax, n_ticks)
if threshold is not None:
# Case where cbar is either all positive or all negative
if 0 <= cbar_vmin <= cbar_vmax or cbar_vmin <= cbar_vmax <= 0:
idx_closest = np.argmin([abs(abs(new_tick_locs) - threshold)
for tick in new_tick_locs])
new_tick_locs[idx_closest] = threshold
else:
# Case where we do a symmetric thresholding within an
# asymmetric cbar and both threshold values are within bounds
if cbar_vmin <= -threshold <= threshold <= cbar_vmax:
from .displays import _get_cbar_ticks
new_tick_locs = _get_cbar_ticks(
cbar_vmin, cbar_vmax, threshold,
nb_ticks=len(new_tick_locs))
# Case where one of the threshold values is out of bounds
else:
idx_closest = np.argmin([abs(new_tick_locs - threshold)
for tick in new_tick_locs])
new_tick_locs[idx_closest] = (
-threshold if threshold > cbar_vmax else threshold)
return new_tick_locs


def _crop_colorbar(cbar, cbar_vmin, cbar_vmax, threshold=None):
"""Crop a colorbar to show from cbar_vmin to cbar_vmax.
Used when symmetric_cbar=False is used.

Expand All @@ -232,11 +261,13 @@ def _crop_colorbar(cbar, cbar_vmin, cbar_vmax):
return
cbar_tick_locs = cbar.locator.locs
if cbar_vmax is None:
cbar_vmax = cbar_tick_locs.max()
cbar_vmax = cbar.norm.vmax
Remi-Gau marked this conversation as resolved.
Show resolved Hide resolved
if cbar_vmin is None:
cbar_vmin = cbar_tick_locs.min()
new_tick_locs = np.linspace(cbar_vmin, cbar_vmax,
len(cbar_tick_locs))
cbar_vmin = cbar.norm.vmin

new_tick_locs = _get_cropped_cbar_ticks(
cbar_vmin, cbar_vmax, threshold,
n_ticks=len(cbar_tick_locs))

# matplotlib >= 3.2.0 no longer normalizes axes between 0 and 1
# See https://matplotlib.org/3.2.1/api/prev_api_changes/api_changes_3.2.0.html
Expand Down
8 changes: 7 additions & 1 deletion nilearn/plotting/surf_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import matplotlib.pyplot as plt
import numpy as np
import warnings

from matplotlib import gridspec
from matplotlib.colorbar import make_axes
Expand Down Expand Up @@ -359,9 +360,14 @@ def custom_function(vertices):
ticks = np.arange(vmin, vmax + 1)
nb_ticks = len(ticks)
else:
ticks = np.linspace(vmin, vmax, nb_ticks)
from nilearn.plotting.displays import _get_cbar_ticks
ticks = _get_cbar_ticks(vmin, vmax, threshold, nb_ticks)
bounds = np.linspace(vmin, vmax, our_cmap.N)
if threshold is not None:
if cbar_tick_format == "%i" and int(threshold) != threshold:
warnings.warn("You provided a non integer threshold "
"but configured the colorbar to use "
"integer formatting.")
cmaplist = [our_cmap(i) for i in range(our_cmap.N)]
# set colors to grey for absolute values < threshold
istart = int(norm(-threshold, clip=True) * (our_cmap.N - 1))
Expand Down
62 changes: 62 additions & 0 deletions nilearn/plotting/tests/test_img_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,68 @@ def test_plot_glass_brain(testdata_3d, tmpdir):
plt.close()


functions = [plotting.plot_stat_map,
plotting.plot_img]
EXPECTED = [(i, ['-10', '-5', '0', '5', '10'])
for i in [0, 0.1, 0.9, 1]]
EXPECTED += [(i, ['-10', f'-{i}', '0', f'{i}', '10'])
for i in [1.3, 2.5, 3, 4.9, 7.5]]
EXPECTED += [(i, [f'-{i}', '-5', '0', '5', f'{i}'])
for i in [7.6, 8, 9.9]]


@pytest.mark.parametrize("plot_func, threshold, expected_ticks",
[(f, e[0], e[1])
for e in EXPECTED for f in functions])
def test_plot_symmetric_colorbar_threshold(tmp_path,
plot_func,
threshold,
expected_ticks):
img_data = np.zeros((10, 10, 10))
img_data[4:6, 2:4, 4:6] = -10
img_data[5:7, 3:7, 3:6] = 10
img = nibabel.Nifti1Image(img_data, affine=np.eye(4))
display = plot_func(img, threshold=threshold, colorbar=True)
plt.savefig(tmp_path / 'test.png')
assert([tick.get_text()
for tick in display._cbar.ax.get_yticklabels()]
== expected_ticks)
plt.close()


functions = [plotting.plot_stat_map]
EXPECTED2 = [(0, ['0', '2.5', '5', '7.5', '10'])]
EXPECTED2 += [(i, [f'{i}', '2.5', '5', '7.5', '10'])
for i in [0.1, 0.3, 1.2]]
EXPECTED2 += [(i, ['0', f'{i}', '5', '7.5', '10'])
for i in [1.3, 1.9, 2.5, 3, 3.7]]
EXPECTED2 += [(i, ['0', '2.5', f'{i}', '7.5', '10'])
for i in [3.8, 4, 5, 6.2]]
EXPECTED2 += [(i, ['0', '2.5', '5', f'{i}', '10'])
for i in [6.3, 7.5, 8, 8.7]]
EXPECTED2 += [(i, ['0', '2.5', '5', '7.5', f'{i}'])
for i in [8.8, 9, 9.9]]


@pytest.mark.parametrize("plot_func, threshold, expected_ticks",
[(f, e[0], e[1])
for e in EXPECTED2 for f in functions])
def test_plot_asymmetric_colorbar_threshold(tmp_path,
plot_func,
threshold,
expected_ticks):
img_data = np.zeros((10, 10, 10))
img_data[4:6, 2:4, 4:6] = 5
img_data[5:7, 3:7, 3:6] = 10
img = nibabel.Nifti1Image(img_data, affine=np.eye(4))
display = plot_func(img, threshold=threshold, colorbar=True)
plt.savefig(tmp_path / 'test.png')
assert([tick.get_text()
for tick in display._cbar.ax.get_yticklabels()]
== expected_ticks)
plt.close()


def test_plot_stat_map(testdata_3d):
img = testdata_3d['img']

Expand Down