Skip to content

Commit

Permalink
Added function to plot dtcwt coefficients
Browse files Browse the repository at this point in the history
  • Loading branch information
fbcotter committed Jun 4, 2018
1 parent 12c3941 commit 7c413a1
Showing 1 changed file with 95 additions and 1 deletion.
96 changes: 95 additions & 1 deletion plotters.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import scipy.stats as stats

__author__ = "Fergal Cotter"
__version__ = "0.0.6"
__version__ = "0.0.7"
__version_info__ = tuple([int(d) for d in __version__.split(".")]) # noqa


Expand Down Expand Up @@ -565,3 +567,95 @@ def plot_axgrid(h, w, top=1, **kwargs):
gridspec_kw={'hspace': space, 'wspace': space, 'left': space,
'bottom': space, 'top': top - space, 'right': 1 - space})
return fig, axes


def plot_dtcwt(yl, yh, fig=None, f=np.abs, top=1, fmt='chw', imshow_kwargs={}):
""" Plot the dtcwt coefficients of an image on a single figure
Parameters
----------
yl : ndarray
Lowpass output
yh : list(ndarray)
Complex bandpass outputs. Can be (h, w, 6) or (6, h, w)
fig : None or matplotlib.Figure
Figure to plot to (will create one if None)
f : callable
Function to apply to highpasses to convert their outputs to real numbers
top : float
Top of the figure in relative coordinates. I.e. between 0 and 1. Set to
1 by default so plots take up full height, but can reduce this if you
want to add a title.
fmt : str
Either 'chw' or 'hwc' depending on the format of yh
"""
J = len(yh)

space = 0.02
if fig is None:
fig = plt.figure(facecolor='k')

if 'cmap' not in imshow_kwargs.keys():
imshow_kwargs['cmap'] = 'viridis'

widths = [0.5, 0.1, 1, 1, 1, 1, 1, 1]
gs = gridspec.GridSpec(J+1, 8, hspace=space*2, wspace=space,
left=space+0.05, bottom=space, top=top-space,
right=1-space, width_ratios=widths)

gradient = np.linspace(0, 1, 256)
gradient = np.vstack((gradient, gradient))
gradient = gradient.T[::-1]

# Preprocess the data
yh_disp = [f(scale) for scale in yh]

for j in range(J):
vmin = yh_disp[j].min()
vmax = yh_disp[j].max()
hist = fig.add_subplot(gs[j,0], xticks=[])
cmap = fig.add_subplot(gs[j,1], xticks=[], yticks=[])

# Plot the histogram of data
if f == np.abs:
x = np.geomspace(1, vmax+1, 50) - 1
else:
x = np.linspace(vmin, vmax, 50)
# Fit a kernel to it
density = stats.gaussian_kde(yh_disp[j].ravel())
y = density(x)
# Plot it vertically and then make it right-aligned
hist.set_xlim(y.max(), 0)
# Fill in the space between the curve and the axis
hist.fill_between(y, 0, x)

# Plot the colourmap
cmap.imshow(gradient, cmap=imshow_kwargs['cmap'], aspect='auto')

for i in range(6):
ax = fig.add_subplot(gs[j,i+2], xticks=[], yticks=[])
if fmt.lower() == 'chw':
ax.imshow(yh_disp[j][i], vmin=vmin, vmax=vmax, **imshow_kwargs)
else:
ax.imshow(yh_disp[j][:,:,i], vmin=vmin, vmax=vmax,
**imshow_kwargs)

# Plot the lowpass
vmin = yl.min()
vmax = yl.max()
hist = fig.add_subplot(gs[J,0], xticks=[])
cmap = fig.add_subplot(gs[J,1], xticks=[], yticks=[])

x = np.linspace(vmin, vmax, 50)
# Fit a kernel to it
density = stats.gaussian_kde(yl.ravel())
y = density(x)
# Plot it vertically and then make it right-aligned
hist.set_xlim(y.max(), 0)
# Fill in the space between the curve and the axis
hist.fill_between(y, 0, x)

# Plot the colourmap
cmap.imshow(gradient, cmap=imshow_kwargs['cmap'], aspect='auto')
ax = fig.add_subplot(gs[J,2], xticks=[], yticks=[])
ax.imshow(yl, **imshow_kwargs)

0 comments on commit 7c413a1

Please sign in to comment.