Skip to content

Commit

Permalink
Isolated matplotlib.pyplot dependencies; fixes #690
Browse files Browse the repository at this point in the history
  • Loading branch information
bmcfee committed Apr 30, 2018
1 parent 440cf77 commit c4fa01f
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions librosa/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
import warnings

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
from matplotlib.axes import Axes
from matplotlib.ticker import Formatter, ScalarFormatter
from matplotlib.ticker import LogLocator, FixedLocator, MaxNLocator
from matplotlib.ticker import SymmetricalLogLocator
Expand Down Expand Up @@ -303,7 +304,7 @@ def cmap(data, robust=True, cmap_seq='magma', cmap_bool='gray_r', cmap_div='cool
data = np.atleast_1d(data)

if data.dtype == 'bool':
return plt.get_cmap(cmap_bool)
return get_cmap(cmap_bool)

data = data[np.isfinite(data)]

Expand All @@ -316,9 +317,9 @@ def cmap(data, robust=True, cmap_seq='magma', cmap_bool='gray_r', cmap_div='cool
min_val = np.percentile(data, min_p)

if min_val >= 0 or max_val <= 0:
return plt.get_cmap(cmap_seq)
return get_cmap(cmap_seq)

return plt.get_cmap(cmap_div)
return get_cmap(cmap_div)


def __envelope(x, hop):
Expand Down Expand Up @@ -683,8 +684,7 @@ def specshow(data, x_coords=None, y_coords=None,

axes = __check_axes(ax)
out = axes.pcolormesh(x_coords, y_coords, data, **kwargs)
if ax is None:
plt.sci(out)
__set_current_image(ax, out)

axes.set_xlim(x_coords.min(), x_coords.max())
axes.set_ylim(y_coords.min(), y_coords.max())
Expand All @@ -700,6 +700,18 @@ def specshow(data, x_coords=None, y_coords=None,
return axes


def __set_current_image(ax, img):
'''Helper to set the current image in pyplot mode.
If the provided `ax` is not `None`, then we assume that the user is using the object API.
In this case, the pyplot current image is not set.
'''

if ax is None:
import matplotlib.pyplot as plt
plt.sci(img)


def __mesh_coords(ax_type, coords, n, **kwargs):
'''Compute axis coordinates'''

Expand Down Expand Up @@ -734,10 +746,11 @@ def __mesh_coords(ax_type, coords, n, **kwargs):
def __check_axes(axes):
'''Check if "axes" is an instance of an axis object. If not, use `gca`.'''
if axes is None:
import matplotlib.pyplot as plt
axes = plt.gca()
if not isinstance(axes, plt.Axes):
raise ValueError("`axes` must be an instance of plt.Axes. "
"Found type {}".format(type(axes)))
elif not isinstance(axes, Axes):
raise ValueError("`axes` must be an instance of matplotlib.axes.Axes. "
"Found type(axes)={}".format(type(axes)))
return axes


Expand Down

0 comments on commit c4fa01f

Please sign in to comment.