Skip to content

Commit

Permalink
Modify cmap_discretisation() and colour_bar_index()
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeqfu committed Oct 11, 2022
1 parent fa35748 commit dc2f70c
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 12 deletions.
36 changes: 29 additions & 7 deletions pyhelpers/ops.py
Expand Up @@ -20,6 +20,7 @@
import sys
import urllib.parse
import urllib.request
import warnings

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -1498,9 +1499,21 @@ def cmap_discretisation(cmap, n_colours):
>>> plt.close()
"""

matplotlib_cm, matplotlib_colors = map(_check_dependency, ['matplotlib.cm', 'matplotlib.colors'])
try: # new version
mpl_cm, mpl_colors = map(_check_dependency, ['matplotlib.colormaps', 'matplotlib.colors'])
except ModuleNotFoundError:
with warnings.catch_warnings():
warnings.simplefilter('ignore')
mpl_cm, mpl_colors = map(_check_dependency, ['matplotlib.cm', 'matplotlib.colors'])

cmap_ = matplotlib_cm.get_cmap(cmap) if isinstance(cmap, str) else copy.copy(cmap)
if isinstance(cmap, str):
try: # new version
cmap_ = mpl_cm[cmap]
except TypeError:
cmap_ = mpl_cm.get_cmap(cmap)
else:
assert isinstance(cmap, mpl_colors.ListedColormap)
cmap_ = copy.copy(cmap)

colours_i = np.concatenate((np.linspace(0, 1., n_colours), (0., 0., 0., 0.)))
colours_rgba = cmap_(colours_i)
Expand All @@ -1512,7 +1525,7 @@ def cmap_discretisation(cmap, n_colours):
(indices[x], colours_rgba[x - 1, ki], colours_rgba[x, ki]) for x in range(n_colours + 1)
]

colour_map = matplotlib_colors.LinearSegmentedColormap(cmap.name + '_%d' % n_colours, c_dict, 1024)
colour_map = mpl_colors.LinearSegmentedColormap(cmap.name + '_%d' % n_colours, c_dict, 1024)

return colour_map

Expand Down Expand Up @@ -1545,9 +1558,12 @@ def colour_bar_index(cmap, n_colours, labels=None, **kwargs):
**Examples**::
>>> from pyhelpers.ops import colour_bar_index
>>> import matplotlib
>>> import matplotlib.pyplot as plt
>>> import matplotlib.cm
>>> matplotlib.use('TkAgg')
>>> plt.figure(figsize=(2, 6))
>>> cbar = colour_bar_index(cmap=matplotlib.cm.get_cmap('Accent'), n_colours=5)
Expand Down Expand Up @@ -1600,15 +1616,21 @@ def colour_bar_index(cmap, n_colours, labels=None, **kwargs):
>>> plt.close(fig='all')
"""

matplotlib_cm, matplotlib_pyplot = map(_check_dependency, ['matplotlib.cm', 'matplotlib.pyplot'])
try: # new version
mpl_cm, mpl_plt = map(_check_dependency, ['matplotlib.colormaps', 'matplotlib.pyplot'])
except ModuleNotFoundError:
with warnings.catch_warnings():
warnings.simplefilter('ignore')
mpl_cm, mpl_plt = map(_check_dependency, ['matplotlib.cm', 'matplotlib.pyplot'])

cmap = cmap_discretisation(cmap, n_colours)
# assert isinstance(cmap, mpl_cm.ListedColormap)
cmap_ = cmap_discretisation(cmap, n_colours)

mappable = matplotlib_cm.ScalarMappable(cmap=cmap)
mappable = mpl_cm.ScalarMappable(cmap=cmap_)
mappable.set_array(np.array([]))
mappable.set_clim(-0.5, n_colours + 0.5)

colour_bar = matplotlib_pyplot.colorbar(mappable=mappable, **kwargs)
colour_bar = mpl_plt.colorbar(mappable=mappable, ax=mpl_plt.gca(), **kwargs)
colour_bar.set_ticks(np.linspace(0, n_colours, n_colours))
colour_bar.set_ticklabels(range(n_colours))

Expand Down
18 changes: 13 additions & 5 deletions tests/test_ops.py
Expand Up @@ -3,6 +3,7 @@
import datetime
import os
import typing
import warnings

import matplotlib.cm
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -360,20 +361,27 @@ def test_find_closest_date():
def test_cmap_discretisation():
from pyhelpers.ops import cmap_discretisation

cm_accent = cmap_discretisation(matplotlib.cm.get_cmap('Accent'), n_colours=5)
assert cm_accent.name == 'Accent_5'
with warnings.catch_warnings():
warnings.simplefilter('ignore')

cm_accent = cmap_discretisation(matplotlib.cm.get_cmap('Accent'), n_colours=5)
assert cm_accent.name == 'Accent_5'


def test_colour_bar_index():
from pyhelpers.ops import colour_bar_index

matplotlib.use('TkAgg')

plt.figure(figsize=(2, 6))

cbar = colour_bar_index(cmap=matplotlib.cm.get_cmap('Accent'), n_colours=5, labels=list('abcde'))
with warnings.catch_warnings():
warnings.simplefilter('ignore')

cbar.ax.tick_params(labelsize=14)
cbar = colour_bar_index(cmap=matplotlib.cm.get_cmap('Accent'), n_colours=5, labels=list('abcde'))
cbar.ax.tick_params(labelsize=14)

plt.close(fig='all')
plt.close(fig='all')


def test_is_network_connected():
Expand Down

0 comments on commit dc2f70c

Please sign in to comment.