Skip to content

Commit

Permalink
Enable passing an ax in quick flatmap: quickshow(fig=ax) (#325)
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDLT committed Apr 5, 2019
1 parent 345123a commit 6224efd
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 55 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Expand Up @@ -48,4 +48,7 @@ nosetests.xml
# vim temp files
*~

# vscode
.vscode

*.nfs*
51 changes: 23 additions & 28 deletions cortex/quickflat/composite.py
Expand Up @@ -3,12 +3,9 @@
from .. import dataset
from ..database import db
from ..options import config
from ..svgoverlay import get_overlay
from ..utils import get_shared_voxels, get_mapper
from .utils import _get_height, _get_extents, _convert_svg_kwargs, _has_cmap, _get_images, _parse_defaults
from .utils import make_flatmap_image, _make_hatch_image, _return_pixel_pairs, get_flatmask, get_flatcache
from .utils import make_flatmap_image, _make_hatch_image, _get_fig_and_ax, get_flatmask, get_flatcache

import time

""" --- Individual compositing functions --- """

Expand All @@ -20,7 +17,7 @@ def add_curvature(fig, dataview, extents=None, height=None, threshold=True, cont
Parameters
----------
fig : figure
fig : figure or ax
figure into which to plot image of curvature
dataview : cortex.Dataview object
dataview containing data to be plotted, subject (surface identifier), and transform.
Expand Down Expand Up @@ -112,7 +109,7 @@ def add_curvature(fig, dataview, extents=None, height=None, threshold=True, cont
curv_im = (curv_im - 0.5) * contrast + brightness
if extents is None:
extents = _get_extents(fig)
ax = fig.gca()
_, ax = _get_fig_and_ax(fig)
cvimg = ax.imshow(curv_im,
aspect='equal',
extent=extents,
Expand All @@ -129,7 +126,7 @@ def add_data(fig, braindata, height=1024, thick=32, depth=0.5, pixelwise=True,
Parameters
----------
fig : figure
fig : figure or ax
Figure into which to plot image of curvature
braindata : one of: {cortex.Volume, cortex.Vertex, cortex.Dataview)
Object containing containing data to be plotted, subject (surface identifier),
Expand Down Expand Up @@ -161,11 +158,11 @@ def add_data(fig, braindata, height=1024, thick=32, depth=0.5, pixelwise=True,
raise TypeError('Please provide a Dataview, not a Dataset')
# Generate image (2D array, maybe 3D array)
im, extents = make_flatmap_image(dataview, recache=recache, pixelwise=pixelwise, sampler=sampler,
height=height, thick=thick, depth=depth)
height=height, thick=thick, depth=depth)
# Check whether dataview has a cmap instance
cmapdict = _has_cmap(dataview)
# Plot
ax = fig.gca()
_, ax = _get_fig_and_ax(fig)
img = ax.imshow(im,
aspect='equal',
extent=extents,
Expand All @@ -181,7 +178,7 @@ def add_rois(fig, dataview, extents=None, height=None, with_labels=True, roi_lis
Parameters
----------
fig : figure
fig : figure or ax
figure into which to plot image of curvature
dataview : cortex.Dataview object
dataview containing data to be plotted, subject (surface identifier), and transform.
Expand All @@ -201,7 +198,6 @@ def add_rois(fig, dataview, extents=None, height=None, with_labels=True, roi_lis
img : matplotlib.image.AxesImage
matplotlib axes image object for plotted data
"""
from ..svgoverlay import get_overlay
if extents is None:
extents = _get_extents(fig)
if height is None:
Expand All @@ -211,7 +207,7 @@ def add_rois(fig, dataview, extents=None, height=None, with_labels=True, roi_lis
layer_kws = _parse_defaults('rois_paths')
layer_kws.update(svg_kws)
im = svgobject.get_texture('rois', height, labels=with_labels, shape_list=roi_list, **layer_kws)
ax = fig.gca()
_, ax = _get_fig_and_ax(fig)
img = ax.imshow(im,
aspect='equal',
interpolation='bicubic',
Expand All @@ -226,7 +222,7 @@ def add_sulci(fig, dataview, extents=None, height=None, with_labels=True, **kwar
Parameters
----------
fig : figure
fig : figure or ax
figure into which to plot image of curvature
dataview : cortex.Dataview object
dataview containing data to be plotted, subject (surface identifier), and transform.
Expand All @@ -249,15 +245,14 @@ def add_sulci(fig, dataview, extents=None, height=None, with_labels=True, **kwar
img : matplotlib.image.AxesImage
matplotlib axes image object for plotted data
"""
from ..svgoverlay import get_overlay
svgobject = db.get_overlay(dataview.subject)
svg_kws = _convert_svg_kwargs(kwargs)
layer_kws = _parse_defaults('sulci_paths')
layer_kws.update(svg_kws)
sulc = svgobject.get_texture('sulci', height, labels=with_labels, **layer_kws)
if extents is None:
extents = _get_extents(fig)
ax = fig.gca()
_, ax = _get_fig_and_ax(fig)
img = ax.imshow(sulc,
aspect='equal',
interpolation='bicubic',
Expand All @@ -268,7 +263,7 @@ def add_sulci(fig, dataview, extents=None, height=None, with_labels=True, **kwar


def add_hatch(fig, hatch_data, extents=None, height=None, hatch_space=4,
hatch_color=(0, 0, 0), sampler='nearest', recache=False):
hatch_color=(0, 0, 0), sampler='nearest', recache=False):
"""Add hatching to figure at locations specified in hatch_data
Parameters
Expand Down Expand Up @@ -303,18 +298,17 @@ def add_hatch(fig, hatch_data, extents=None, height=None, hatch_space=4,
-----
Possibly to add: add hatch_width, hatch_offset arguments.
"""
from ..svgoverlay import get_overlay
if extents is None:
extents = _get_extents(fig)
if height is None:
height = _get_height(fig)
hatchim = _make_hatch_image(hatch_data, height, sampler, recache=recache,
hatch_space=hatch_space)
hatch_space=hatch_space)
hatchim[:,:,0] = hatch_color[0]
hatchim[:,:,1] = hatch_color[1]
hatchim[:,:,2] = hatch_color[2]

ax = fig.gca()
_, ax = _get_fig_and_ax(fig)
img = ax.imshow(hatchim,
aspect="equal",
interpolation="bicubic",
Expand All @@ -325,7 +319,7 @@ def add_hatch(fig, hatch_data, extents=None, height=None, hatch_space=4,


def add_colorbar(fig, cimg, colorbar_ticks=None, colorbar_location=(0.4, 0.07, 0.2, 0.04),
orientation='horizontal'):
orientation='horizontal'):
"""Add a colorbar to a flatmap plot
Parameters
Expand All @@ -343,13 +337,14 @@ def add_colorbar(fig, cimg, colorbar_ticks=None, colorbar_location=(0.4, 0.07, 0
orientation : string
'vertical' or 'horizontal'
"""
fig, _ = _get_fig_and_ax(fig)
cbar = fig.add_axes(colorbar_location)
fig.colorbar(cimg, cax=cbar, orientation=orientation, ticks=colorbar_ticks)
return cbar


def add_colorbar_2d(fig, cmap_name, colorbar_ticks,
colorbar_location=(0.425, 0.02, 0.15, 0.15), fontsize=12):
colorbar_location=(0.425, 0.02, 0.15, 0.15), fontsize=12):
"""Add a 2D colorbar to a flatmap plot
Parameters
Expand All @@ -371,6 +366,7 @@ def add_colorbar_2d(fig, cmap_name, colorbar_ticks,
import os
cmap_dir = config.get('webgl', 'colormaps')
cim = plt.imread(os.path.join(cmap_dir, cmap_name + '.png'))
fig, _ = _get_fig_and_ax(fig)
fig.add_axes(colorbar_location)
cbar = plt.imshow(cim, extent=colorbar_ticks, interpolation='bilinear')
cbar.axes.set_xticks(colorbar_ticks[:2])
Expand All @@ -381,7 +377,7 @@ def add_colorbar_2d(fig, cmap_name, colorbar_ticks,
return cbar

def add_custom(fig, dataview, svgfile, layer, extents=None, height=None, with_labels=False,
shape_list=None, **kwargs):
shape_list=None, **kwargs):
"""Add a custom data layer
Parameters
Expand Down Expand Up @@ -435,7 +431,7 @@ def add_custom(fig, dataview, svgfile, layer, extents=None, height=None, with_la
labels=with_labels,
shape_list=shape_list,
**layer_kws)
ax = fig.gca()
_, ax = _get_fig_and_ax(fig)
img = ax.imshow(im,
aspect="equal",
interpolation="nearest",
Expand Down Expand Up @@ -530,13 +526,13 @@ def add_connected_vertices(fig, dataview, exclude_border_width=None,
# Add line collection
# (This is the most time consuming step, as it draws many lines)
# print('plotting lines...')
fig, ax = _get_fig_and_ax(fig)
lc = LineCollection(pix_array_scaled,
transform=fig.transFigure,
figure=fig,
colors=color,
alpha=alpha,
linewidths=linewidth)
ax = fig.gca()
lc_object = ax.add_collection(lc)
return lc_object

Expand All @@ -545,7 +541,7 @@ def add_cutout(fig, name, dataview, layers=None, height=None, extents=None):
Parameters
----------
fig : figure
fig : figure or ax
figure to which to add cutouts
name : str
name of cutout shape within cutouts layer to use to crop the rest of the figure
Expand All @@ -562,13 +558,12 @@ def add_cutout(fig, name, dataview, layers=None, height=None, extents=None):
extents of figure. None defaults to previously specified extents.
[unclear if it's worth it to keep this input.]
"""
from ..svgoverlay import get_overlay
if layers is None:
layers = _get_images(fig)
if height is None:
height = _get_height(fig)
if extents is None:
extents = _get_extents(fig)
extents = _get_extents(fig)
svgobject = db.get_overlay(dataview.subject)
# Set other cutouts to be invisible
for co_name, co_shape in svgobject.cutouts.shapes.items():
Expand Down Expand Up @@ -634,7 +629,7 @@ def add_cutout(fig, name, dataview, layers=None, height=None, extents=None):
imsize = (np.abs(np.diff(iy))[0], np.abs(np.diff(ix))[0])

# Re-set figure limits
ax = fig.gca()
fig, ax = _get_fig_and_ax(fig)
ax.set_xlim(LL, RR)
ax.set_ylim(BB, TT)
inch_size = np.array(imsize)[::-1] / float(fig.dpi)
Expand Down
15 changes: 14 additions & 1 deletion cortex/quickflat/utils.py
Expand Up @@ -237,10 +237,23 @@ def _parse_defaults(section):
defaults[k] = '{}, {}'.format(*defaults[k])
return defaults

def _get_fig_and_ax(fig):
"""Get figure and current ax. Input can be either a figure or an ax."""
import matplotlib.pyplot as plt
if isinstance(fig, plt.Axes):
ax = fig
fig = ax.figure
elif isinstance(fig, plt.Figure):
ax = fig.gca()
else:
raise ValueError("fig should be a matplotlib Figure or Axes instance.")

return fig, ax

def _get_images(fig):
"""Get all images in a given matplotlib axis"""
from matplotlib.image import AxesImage
ax = fig.gca()
_, ax = _get_fig_and_ax(fig)
images = dict((x.get_label(), x) for x in ax.get_children() if isinstance(x, AxesImage))
return images

Expand Down

0 comments on commit 6224efd

Please sign in to comment.