Skip to content

Commit

Permalink
Merge pull request #687 from pfriesch/axis-instance-plotting
Browse files Browse the repository at this point in the history
Specifying axis instances in plotting functions
  • Loading branch information
bmcfee committed Mar 20, 2018
2 parents 6d79124 + 701f24b commit 27b099b
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 15 deletions.
16 changes: 6 additions & 10 deletions docs/examples/plot_music_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,23 +97,19 @@
# Alternative Visualization in the Time Domain
# --------------------------------------------
#
# We can also visualize the wariping path directly on our time domain signals.
# We can also visualize the warping path directly on our time domain signals.
# Red lines connect corresponding time positions in the input signals.
# (Thanks to F. Zalkow for the nice visualization.)

fig = plt.figure(figsize=(16, 8))
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 8))

# Plot x_1
plt.subplot(2, 1, 1)
librosa.display.waveplot(x_1, sr=fs)
plt.title('Slower Version $X_1$')
ax1 = plt.gca()
librosa.display.waveplot(x_1, sr=fs, ax=ax1)
ax1.set(title='Slower Version $X_1$')

# Plot x_2
plt.subplot(2, 1, 2)
librosa.display.waveplot(x_2, sr=fs)
plt.title('Slower Version $X_2$')
ax2 = plt.gca()
librosa.display.waveplot(x_2, sr=fs, ax=ax2)
ax2.set(title='Slower Version $X_2$')

plt.tight_layout()

Expand Down
28 changes: 23 additions & 5 deletions librosa/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,8 @@ def __envelope(x, hop):
return util.frame(x, hop_length=hop, frame_length=hop).max(axis=0)


def waveplot(y, sr=22050, max_points=5e4, x_axis='time', offset=0.0, max_sr=1000,
**kwargs):
def waveplot(y, sr=22050, max_points=5e4, x_axis='time', offset=0.0,
max_sr=1000, ax=None, **kwargs):
'''Plot the amplitude envelope of a waveform.
If `y` is monophonic, a filled curve is drawn between `[-abs(y), abs(y)]`.
Expand Down Expand Up @@ -356,6 +356,9 @@ def waveplot(y, sr=22050, max_points=5e4, x_axis='time', offset=0.0, max_sr=1000
x_axis : str {'time', 'off', 'none'} or None
If 'time', the x-axis is given time tick-marks.
ax : matplotlib.axes.Axes or None
Axes to plot on instead of the default `plt.gca()`.
offset : float
Horizontal offset (in time) to start the waveform plot
Expand Down Expand Up @@ -435,7 +438,7 @@ def waveplot(y, sr=22050, max_points=5e4, x_axis='time', offset=0.0, max_sr=1000
y_top = y
y_bottom = -y

axes = plt.gca()
axes = __check_axes(ax)

kwargs.setdefault('color', next(axes._get_lines.prop_cycler)['color'])

Expand All @@ -461,6 +464,7 @@ def specshow(data, x_coords=None, y_coords=None,
sr=22050, hop_length=512,
fmin=None, fmax=None,
bins_per_octave=12,
ax=None,
**kwargs):
'''Display a spectrogram/chromagram/cqt/etc.
Expand Down Expand Up @@ -542,6 +546,9 @@ def specshow(data, x_coords=None, y_coords=None,
bins_per_octave : int > 0 [scalar]
Number of bins per octave. Used for CQT frequency scale.
ax : matplotlib.axes.Axes or None
Axes to plot on instead of the default `plt.gca()`.
kwargs : additional keyword arguments
Arguments passed through to `matplotlib.pyplot.pcolormesh`.
Expand Down Expand Up @@ -674,9 +681,10 @@ def specshow(data, x_coords=None, y_coords=None,
y_coords = __mesh_coords(y_axis, y_coords, data.shape[0], **all_params)
x_coords = __mesh_coords(x_axis, x_coords, data.shape[1], **all_params)

axes = plt.gca()
axes = __check_axes(ax)
out = axes.pcolormesh(x_coords, y_coords, data, **kwargs)
plt.sci(out)
if ax is None:
plt.sci(out)

axes.set_xlim(x_coords.min(), x_coords.max())
axes.set_ylim(y_coords.min(), y_coords.max())
Expand Down Expand Up @@ -723,6 +731,16 @@ def __mesh_coords(ax_type, coords, n, **kwargs):
return coord_map[ax_type](n, **kwargs)


def __check_axes(axes):
'''Check if "axes" is an instance of an axis object. If not, use `gca`.'''
if axes is None:
axes = plt.gca()
if not isinstance(axes, plt.Axes):
raise ValueError("`axes` must be an instance of plt.Axes. "
"Found type {}".format(type(axes)))
return axes


def __scale_axes(axes, ax_type, which):
'''Set the axis scaling'''

Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
22 changes: 22 additions & 0 deletions tests/test_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,17 @@ def test_xaxis_none_yaxis_linear():
librosa.display.specshow(S_bin, y_axis='linear')


@image_comparison(baseline_images=['specshow_ext_axes'], extensions=['png'])
def test_specshow_ext_axes():
plt.figure()
ax_left = plt.subplot(1, 2, 1)
ax_right = plt.subplot(1, 2, 2)

# implicitly ax_right
librosa.display.specshow(S_abs, cmap='gray')
librosa.display.specshow(S_abs, cmap='magma', ax=ax_left)


@image_comparison(baseline_images=['x_none_y_log'], extensions=['png'])
def test_xaxis_none_yaxis_log():
plt.figure()
Expand Down Expand Up @@ -278,6 +289,17 @@ def test_waveplot_mono():
librosa.display.waveplot(y, sr=sr, x_axis='time')


@image_comparison(baseline_images=['waveplot_ext_axes'], extensions=['png'])
def test_waveplot_ext_axes():
plt.figure()
ax_left = plt.subplot(1, 2, 1)
ax_right = plt.subplot(1, 2, 2)

# implicitly ax_right
librosa.display.waveplot(y, color='blue')
librosa.display.waveplot(y, color='red', ax=ax_left)


@image_comparison(baseline_images=['waveplot_stereo'], extensions=['png'])
def test_waveplot_stereo():

Expand Down

0 comments on commit 27b099b

Please sign in to comment.