Skip to content
Permalink
Browse files

Enable passing an ax in quick flatmap: quickshow(fig=ax) (#325)

  • Loading branch information...
TomDLT committed Apr 5, 2019
1 parent 345123a commit 6224efda4cbd5ba668723dacf883d7c60376da66
Showing with 74 additions and 55 deletions.
  1. +3 −0 .gitignore
  2. +23 −28 cortex/quickflat/composite.py
  3. +14 −1 cortex/quickflat/utils.py
  4. +34 −26 cortex/quickflat/view.py
@@ -48,4 +48,7 @@ nosetests.xml
# vim temp files
*~

# vscode
.vscode

*.nfs*
@@ -3,12 +3,9 @@
from .. import dataset
from ..database import db
from ..options import config
from ..svgoverlay import get_overlay
from ..utils import get_shared_voxels, get_mapper
from .utils import _get_height, _get_extents, _convert_svg_kwargs, _has_cmap, _get_images, _parse_defaults
from .utils import make_flatmap_image, _make_hatch_image, _return_pixel_pairs, get_flatmask, get_flatcache
from .utils import make_flatmap_image, _make_hatch_image, _get_fig_and_ax, get_flatmask, get_flatcache

import time

""" --- Individual compositing functions --- """

@@ -20,7 +17,7 @@ def add_curvature(fig, dataview, extents=None, height=None, threshold=True, cont
Parameters
----------
fig : figure
fig : figure or ax
figure into which to plot image of curvature
dataview : cortex.Dataview object
dataview containing data to be plotted, subject (surface identifier), and transform.
@@ -112,7 +109,7 @@ def add_curvature(fig, dataview, extents=None, height=None, threshold=True, cont
curv_im = (curv_im - 0.5) * contrast + brightness
if extents is None:
extents = _get_extents(fig)
ax = fig.gca()
_, ax = _get_fig_and_ax(fig)
cvimg = ax.imshow(curv_im,
aspect='equal',
extent=extents,
@@ -129,7 +126,7 @@ def add_data(fig, braindata, height=1024, thick=32, depth=0.5, pixelwise=True,
Parameters
----------
fig : figure
fig : figure or ax
Figure into which to plot image of curvature
braindata : one of: {cortex.Volume, cortex.Vertex, cortex.Dataview)
Object containing containing data to be plotted, subject (surface identifier),
@@ -161,11 +158,11 @@ def add_data(fig, braindata, height=1024, thick=32, depth=0.5, pixelwise=True,
raise TypeError('Please provide a Dataview, not a Dataset')
# Generate image (2D array, maybe 3D array)
im, extents = make_flatmap_image(dataview, recache=recache, pixelwise=pixelwise, sampler=sampler,
height=height, thick=thick, depth=depth)
height=height, thick=thick, depth=depth)
# Check whether dataview has a cmap instance
cmapdict = _has_cmap(dataview)
# Plot
ax = fig.gca()
_, ax = _get_fig_and_ax(fig)
img = ax.imshow(im,
aspect='equal',
extent=extents,
@@ -181,7 +178,7 @@ def add_rois(fig, dataview, extents=None, height=None, with_labels=True, roi_lis
Parameters
----------
fig : figure
fig : figure or ax
figure into which to plot image of curvature
dataview : cortex.Dataview object
dataview containing data to be plotted, subject (surface identifier), and transform.
@@ -201,7 +198,6 @@ def add_rois(fig, dataview, extents=None, height=None, with_labels=True, roi_lis
img : matplotlib.image.AxesImage
matplotlib axes image object for plotted data
"""
from ..svgoverlay import get_overlay
if extents is None:
extents = _get_extents(fig)
if height is None:
@@ -211,7 +207,7 @@ def add_rois(fig, dataview, extents=None, height=None, with_labels=True, roi_lis
layer_kws = _parse_defaults('rois_paths')
layer_kws.update(svg_kws)
im = svgobject.get_texture('rois', height, labels=with_labels, shape_list=roi_list, **layer_kws)
ax = fig.gca()
_, ax = _get_fig_and_ax(fig)
img = ax.imshow(im,
aspect='equal',
interpolation='bicubic',
@@ -226,7 +222,7 @@ def add_sulci(fig, dataview, extents=None, height=None, with_labels=True, **kwar
Parameters
----------
fig : figure
fig : figure or ax
figure into which to plot image of curvature
dataview : cortex.Dataview object
dataview containing data to be plotted, subject (surface identifier), and transform.
@@ -249,15 +245,14 @@ def add_sulci(fig, dataview, extents=None, height=None, with_labels=True, **kwar
img : matplotlib.image.AxesImage
matplotlib axes image object for plotted data
"""
from ..svgoverlay import get_overlay
svgobject = db.get_overlay(dataview.subject)
svg_kws = _convert_svg_kwargs(kwargs)
layer_kws = _parse_defaults('sulci_paths')
layer_kws.update(svg_kws)
sulc = svgobject.get_texture('sulci', height, labels=with_labels, **layer_kws)
if extents is None:
extents = _get_extents(fig)
ax = fig.gca()
_, ax = _get_fig_and_ax(fig)
img = ax.imshow(sulc,
aspect='equal',
interpolation='bicubic',
@@ -268,7 +263,7 @@ def add_sulci(fig, dataview, extents=None, height=None, with_labels=True, **kwar


def add_hatch(fig, hatch_data, extents=None, height=None, hatch_space=4,
hatch_color=(0, 0, 0), sampler='nearest', recache=False):
hatch_color=(0, 0, 0), sampler='nearest', recache=False):
"""Add hatching to figure at locations specified in hatch_data
Parameters
@@ -303,18 +298,17 @@ def add_hatch(fig, hatch_data, extents=None, height=None, hatch_space=4,
-----
Possibly to add: add hatch_width, hatch_offset arguments.
"""
from ..svgoverlay import get_overlay
if extents is None:
extents = _get_extents(fig)
if height is None:
height = _get_height(fig)
hatchim = _make_hatch_image(hatch_data, height, sampler, recache=recache,
hatch_space=hatch_space)
hatch_space=hatch_space)
hatchim[:,:,0] = hatch_color[0]
hatchim[:,:,1] = hatch_color[1]
hatchim[:,:,2] = hatch_color[2]

ax = fig.gca()
_, ax = _get_fig_and_ax(fig)
img = ax.imshow(hatchim,
aspect="equal",
interpolation="bicubic",
@@ -325,7 +319,7 @@ def add_hatch(fig, hatch_data, extents=None, height=None, hatch_space=4,


def add_colorbar(fig, cimg, colorbar_ticks=None, colorbar_location=(0.4, 0.07, 0.2, 0.04),
orientation='horizontal'):
orientation='horizontal'):
"""Add a colorbar to a flatmap plot
Parameters
@@ -343,13 +337,14 @@ def add_colorbar(fig, cimg, colorbar_ticks=None, colorbar_location=(0.4, 0.07, 0
orientation : string
'vertical' or 'horizontal'
"""
fig, _ = _get_fig_and_ax(fig)
cbar = fig.add_axes(colorbar_location)
fig.colorbar(cimg, cax=cbar, orientation=orientation, ticks=colorbar_ticks)
return cbar


def add_colorbar_2d(fig, cmap_name, colorbar_ticks,
colorbar_location=(0.425, 0.02, 0.15, 0.15), fontsize=12):
colorbar_location=(0.425, 0.02, 0.15, 0.15), fontsize=12):
"""Add a 2D colorbar to a flatmap plot
Parameters
@@ -371,6 +366,7 @@ def add_colorbar_2d(fig, cmap_name, colorbar_ticks,
import os
cmap_dir = config.get('webgl', 'colormaps')
cim = plt.imread(os.path.join(cmap_dir, cmap_name + '.png'))
fig, _ = _get_fig_and_ax(fig)
fig.add_axes(colorbar_location)
cbar = plt.imshow(cim, extent=colorbar_ticks, interpolation='bilinear')
cbar.axes.set_xticks(colorbar_ticks[:2])
@@ -381,7 +377,7 @@ def add_colorbar_2d(fig, cmap_name, colorbar_ticks,
return cbar

def add_custom(fig, dataview, svgfile, layer, extents=None, height=None, with_labels=False,
shape_list=None, **kwargs):
shape_list=None, **kwargs):
"""Add a custom data layer
Parameters
@@ -435,7 +431,7 @@ def add_custom(fig, dataview, svgfile, layer, extents=None, height=None, with_la
labels=with_labels,
shape_list=shape_list,
**layer_kws)
ax = fig.gca()
_, ax = _get_fig_and_ax(fig)
img = ax.imshow(im,
aspect="equal",
interpolation="nearest",
@@ -530,13 +526,13 @@ def add_connected_vertices(fig, dataview, exclude_border_width=None,
# Add line collection
# (This is the most time consuming step, as it draws many lines)
# print('plotting lines...')
fig, ax = _get_fig_and_ax(fig)
lc = LineCollection(pix_array_scaled,
transform=fig.transFigure,
figure=fig,
colors=color,
alpha=alpha,
linewidths=linewidth)
ax = fig.gca()
lc_object = ax.add_collection(lc)
return lc_object

@@ -545,7 +541,7 @@ def add_cutout(fig, name, dataview, layers=None, height=None, extents=None):
Parameters
----------
fig : figure
fig : figure or ax
figure to which to add cutouts
name : str
name of cutout shape within cutouts layer to use to crop the rest of the figure
@@ -562,13 +558,12 @@ def add_cutout(fig, name, dataview, layers=None, height=None, extents=None):
extents of figure. None defaults to previously specified extents.
[unclear if it's worth it to keep this input.]
"""
from ..svgoverlay import get_overlay
if layers is None:
layers = _get_images(fig)
if height is None:
height = _get_height(fig)
if extents is None:
extents = _get_extents(fig)
extents = _get_extents(fig)
svgobject = db.get_overlay(dataview.subject)
# Set other cutouts to be invisible
for co_name, co_shape in svgobject.cutouts.shapes.items():
@@ -634,7 +629,7 @@ def add_cutout(fig, name, dataview, layers=None, height=None, extents=None):
imsize = (np.abs(np.diff(iy))[0], np.abs(np.diff(ix))[0])

# Re-set figure limits
ax = fig.gca()
fig, ax = _get_fig_and_ax(fig)
ax.set_xlim(LL, RR)
ax.set_ylim(BB, TT)
inch_size = np.array(imsize)[::-1] / float(fig.dpi)
@@ -237,10 +237,23 @@ def _parse_defaults(section):
defaults[k] = '{}, {}'.format(*defaults[k])
return defaults

def _get_fig_and_ax(fig):
"""Get figure and current ax. Input can be either a figure or an ax."""
import matplotlib.pyplot as plt
if isinstance(fig, plt.Axes):
ax = fig
fig = ax.figure
elif isinstance(fig, plt.Figure):
ax = fig.gca()
else:
raise ValueError("fig should be a matplotlib Figure or Axes instance.")

return fig, ax

def _get_images(fig):
"""Get all images in a given matplotlib axis"""
from matplotlib.image import AxesImage
ax = fig.gca()
_, ax = _get_fig_and_ax(fig)
images = dict((x.get_label(), x) for x in ax.get_children() if isinstance(x, AxesImage))
return images

Oops, something went wrong.

0 comments on commit 6224efd

Please sign in to comment.
You can’t perform that action at this time.