diff --git a/.gitignore b/.gitignore index 037696de4..ddf008e4c 100644 --- a/.gitignore +++ b/.gitignore @@ -48,4 +48,7 @@ nosetests.xml # vim temp files *~ +# vscode +.vscode + *.nfs* diff --git a/cortex/quickflat/composite.py b/cortex/quickflat/composite.py index e2a969801..8fc57b5a9 100644 --- a/cortex/quickflat/composite.py +++ b/cortex/quickflat/composite.py @@ -3,12 +3,9 @@ from .. import dataset from ..database import db from ..options import config -from ..svgoverlay import get_overlay -from ..utils import get_shared_voxels, get_mapper from .utils import _get_height, _get_extents, _convert_svg_kwargs, _has_cmap, _get_images, _parse_defaults -from .utils import make_flatmap_image, _make_hatch_image, _return_pixel_pairs, get_flatmask, get_flatcache +from .utils import make_flatmap_image, _make_hatch_image, _get_fig_and_ax, get_flatmask, get_flatcache -import time """ --- Individual compositing functions --- """ @@ -20,7 +17,7 @@ def add_curvature(fig, dataview, extents=None, height=None, threshold=True, cont Parameters ---------- - fig : figure + fig : figure or ax figure into which to plot image of curvature dataview : cortex.Dataview object dataview containing data to be plotted, subject (surface identifier), and transform. @@ -112,7 +109,7 @@ def add_curvature(fig, dataview, extents=None, height=None, threshold=True, cont curv_im = (curv_im - 0.5) * contrast + brightness if extents is None: extents = _get_extents(fig) - ax = fig.gca() + _, ax = _get_fig_and_ax(fig) cvimg = ax.imshow(curv_im, aspect='equal', extent=extents, @@ -129,7 +126,7 @@ def add_data(fig, braindata, height=1024, thick=32, depth=0.5, pixelwise=True, Parameters ---------- - fig : figure + fig : figure or ax Figure into which to plot image of curvature braindata : one of: {cortex.Volume, cortex.Vertex, cortex.Dataview) Object containing containing data to be plotted, subject (surface identifier), @@ -161,11 +158,11 @@ def add_data(fig, braindata, height=1024, thick=32, depth=0.5, pixelwise=True, raise TypeError('Please provide a Dataview, not a Dataset') # Generate image (2D array, maybe 3D array) im, extents = make_flatmap_image(dataview, recache=recache, pixelwise=pixelwise, sampler=sampler, - height=height, thick=thick, depth=depth) + height=height, thick=thick, depth=depth) # Check whether dataview has a cmap instance cmapdict = _has_cmap(dataview) # Plot - ax = fig.gca() + _, ax = _get_fig_and_ax(fig) img = ax.imshow(im, aspect='equal', extent=extents, @@ -181,7 +178,7 @@ def add_rois(fig, dataview, extents=None, height=None, with_labels=True, roi_lis Parameters ---------- - fig : figure + fig : figure or ax figure into which to plot image of curvature dataview : cortex.Dataview object dataview containing data to be plotted, subject (surface identifier), and transform. @@ -201,7 +198,6 @@ def add_rois(fig, dataview, extents=None, height=None, with_labels=True, roi_lis img : matplotlib.image.AxesImage matplotlib axes image object for plotted data """ - from ..svgoverlay import get_overlay if extents is None: extents = _get_extents(fig) if height is None: @@ -211,7 +207,7 @@ def add_rois(fig, dataview, extents=None, height=None, with_labels=True, roi_lis layer_kws = _parse_defaults('rois_paths') layer_kws.update(svg_kws) im = svgobject.get_texture('rois', height, labels=with_labels, shape_list=roi_list, **layer_kws) - ax = fig.gca() + _, ax = _get_fig_and_ax(fig) img = ax.imshow(im, aspect='equal', interpolation='bicubic', @@ -226,7 +222,7 @@ def add_sulci(fig, dataview, extents=None, height=None, with_labels=True, **kwar Parameters ---------- - fig : figure + fig : figure or ax figure into which to plot image of curvature dataview : cortex.Dataview object dataview containing data to be plotted, subject (surface identifier), and transform. @@ -249,7 +245,6 @@ def add_sulci(fig, dataview, extents=None, height=None, with_labels=True, **kwar img : matplotlib.image.AxesImage matplotlib axes image object for plotted data """ - from ..svgoverlay import get_overlay svgobject = db.get_overlay(dataview.subject) svg_kws = _convert_svg_kwargs(kwargs) layer_kws = _parse_defaults('sulci_paths') @@ -257,7 +252,7 @@ def add_sulci(fig, dataview, extents=None, height=None, with_labels=True, **kwar sulc = svgobject.get_texture('sulci', height, labels=with_labels, **layer_kws) if extents is None: extents = _get_extents(fig) - ax = fig.gca() + _, ax = _get_fig_and_ax(fig) img = ax.imshow(sulc, aspect='equal', interpolation='bicubic', @@ -268,7 +263,7 @@ def add_sulci(fig, dataview, extents=None, height=None, with_labels=True, **kwar def add_hatch(fig, hatch_data, extents=None, height=None, hatch_space=4, - hatch_color=(0, 0, 0), sampler='nearest', recache=False): + hatch_color=(0, 0, 0), sampler='nearest', recache=False): """Add hatching to figure at locations specified in hatch_data Parameters @@ -303,18 +298,17 @@ def add_hatch(fig, hatch_data, extents=None, height=None, hatch_space=4, ----- Possibly to add: add hatch_width, hatch_offset arguments. """ - from ..svgoverlay import get_overlay if extents is None: extents = _get_extents(fig) if height is None: height = _get_height(fig) hatchim = _make_hatch_image(hatch_data, height, sampler, recache=recache, - hatch_space=hatch_space) + hatch_space=hatch_space) hatchim[:,:,0] = hatch_color[0] hatchim[:,:,1] = hatch_color[1] hatchim[:,:,2] = hatch_color[2] - ax = fig.gca() + _, ax = _get_fig_and_ax(fig) img = ax.imshow(hatchim, aspect="equal", interpolation="bicubic", @@ -325,7 +319,7 @@ def add_hatch(fig, hatch_data, extents=None, height=None, hatch_space=4, def add_colorbar(fig, cimg, colorbar_ticks=None, colorbar_location=(0.4, 0.07, 0.2, 0.04), - orientation='horizontal'): + orientation='horizontal'): """Add a colorbar to a flatmap plot Parameters @@ -343,13 +337,14 @@ def add_colorbar(fig, cimg, colorbar_ticks=None, colorbar_location=(0.4, 0.07, 0 orientation : string 'vertical' or 'horizontal' """ + fig, _ = _get_fig_and_ax(fig) cbar = fig.add_axes(colorbar_location) fig.colorbar(cimg, cax=cbar, orientation=orientation, ticks=colorbar_ticks) return cbar def add_colorbar_2d(fig, cmap_name, colorbar_ticks, - colorbar_location=(0.425, 0.02, 0.15, 0.15), fontsize=12): + colorbar_location=(0.425, 0.02, 0.15, 0.15), fontsize=12): """Add a 2D colorbar to a flatmap plot Parameters @@ -371,6 +366,7 @@ def add_colorbar_2d(fig, cmap_name, colorbar_ticks, import os cmap_dir = config.get('webgl', 'colormaps') cim = plt.imread(os.path.join(cmap_dir, cmap_name + '.png')) + fig, _ = _get_fig_and_ax(fig) fig.add_axes(colorbar_location) cbar = plt.imshow(cim, extent=colorbar_ticks, interpolation='bilinear') cbar.axes.set_xticks(colorbar_ticks[:2]) @@ -381,7 +377,7 @@ def add_colorbar_2d(fig, cmap_name, colorbar_ticks, return cbar def add_custom(fig, dataview, svgfile, layer, extents=None, height=None, with_labels=False, - shape_list=None, **kwargs): + shape_list=None, **kwargs): """Add a custom data layer Parameters @@ -435,7 +431,7 @@ def add_custom(fig, dataview, svgfile, layer, extents=None, height=None, with_la labels=with_labels, shape_list=shape_list, **layer_kws) - ax = fig.gca() + _, ax = _get_fig_and_ax(fig) img = ax.imshow(im, aspect="equal", interpolation="nearest", @@ -530,13 +526,13 @@ def add_connected_vertices(fig, dataview, exclude_border_width=None, # Add line collection # (This is the most time consuming step, as it draws many lines) # print('plotting lines...') + fig, ax = _get_fig_and_ax(fig) lc = LineCollection(pix_array_scaled, transform=fig.transFigure, figure=fig, colors=color, alpha=alpha, linewidths=linewidth) - ax = fig.gca() lc_object = ax.add_collection(lc) return lc_object @@ -545,7 +541,7 @@ def add_cutout(fig, name, dataview, layers=None, height=None, extents=None): Parameters ---------- - fig : figure + fig : figure or ax figure to which to add cutouts name : str name of cutout shape within cutouts layer to use to crop the rest of the figure @@ -562,13 +558,12 @@ def add_cutout(fig, name, dataview, layers=None, height=None, extents=None): extents of figure. None defaults to previously specified extents. [unclear if it's worth it to keep this input.] """ - from ..svgoverlay import get_overlay if layers is None: layers = _get_images(fig) if height is None: height = _get_height(fig) if extents is None: - extents = _get_extents(fig) + extents = _get_extents(fig) svgobject = db.get_overlay(dataview.subject) # Set other cutouts to be invisible for co_name, co_shape in svgobject.cutouts.shapes.items(): @@ -634,7 +629,7 @@ def add_cutout(fig, name, dataview, layers=None, height=None, extents=None): imsize = (np.abs(np.diff(iy))[0], np.abs(np.diff(ix))[0]) # Re-set figure limits - ax = fig.gca() + fig, ax = _get_fig_and_ax(fig) ax.set_xlim(LL, RR) ax.set_ylim(BB, TT) inch_size = np.array(imsize)[::-1] / float(fig.dpi) diff --git a/cortex/quickflat/utils.py b/cortex/quickflat/utils.py index 75f06c001..1135f66a9 100644 --- a/cortex/quickflat/utils.py +++ b/cortex/quickflat/utils.py @@ -237,10 +237,23 @@ def _parse_defaults(section): defaults[k] = '{}, {}'.format(*defaults[k]) return defaults +def _get_fig_and_ax(fig): + """Get figure and current ax. Input can be either a figure or an ax.""" + import matplotlib.pyplot as plt + if isinstance(fig, plt.Axes): + ax = fig + fig = ax.figure + elif isinstance(fig, plt.Figure): + ax = fig.gca() + else: + raise ValueError("fig should be a matplotlib Figure or Axes instance.") + + return fig, ax + def _get_images(fig): """Get all images in a given matplotlib axis""" from matplotlib.image import AxesImage - ax = fig.gca() + _, ax = _get_fig_and_ax(fig) images = dict((x.get_label(), x) for x in ax.get_children() if isinstance(x, AxesImage)) return images diff --git a/cortex/quickflat/view.py b/cortex/quickflat/view.py index d291f7d17..bcdbd80d0 100644 --- a/cortex/quickflat/view.py +++ b/cortex/quickflat/view.py @@ -6,7 +6,6 @@ from .. import utils from .. import dataset -from . import utils as qutils from .utils import make_flatmap_image from . import composite @@ -88,6 +87,8 @@ def make_figure(braindata, recache=False, pixelwise=True, thick=32, sampler='nea Optional extra crosshatch-textured layer, given as (DataView, [r, g, b]) tuple. colorbar_location : tuple, optional Location of the colorbar! Not sure of what the numbers actually mean. Left, bottom, width, height, maybe? + fig : figure or ax + figure into which to plot flatmap """ from matplotlib import pyplot as plt @@ -97,13 +98,19 @@ def make_figure(braindata, recache=False, pixelwise=True, thick=32, sampler='nea if fig is None: fig_resize = True fig = plt.figure() - else: + ax = fig.add_axes((0, 0, 1, 1)) + elif isinstance(fig, plt.Figure): fig_resize = False fig = plt.figure(fig.number) - ax = fig.add_axes((0, 0, 1, 1)) + ax = fig.add_axes((0, 0, 1, 1)) + elif isinstance(fig, plt.Axes): + fig_resize = False + ax = fig + fig = ax.figure + # Add data - data_im, extents = composite.add_data(fig, dataview, pixelwise=pixelwise, thick=thick, sampler=sampler, - height=height, depth=depth, recache=recache) + data_im, extents = composite.add_data(ax, dataview, pixelwise=pixelwise, thick=thick, sampler=sampler, + height=height, depth=depth, recache=recache) layers = dict(data=data_im) # Add curvature @@ -112,8 +119,8 @@ def make_figure(braindata, recache=False, pixelwise=True, thick=32, sampler='nea if any([x in kwargs for x in ['cvmin', 'cvmax', 'cvthr']]): import warnings warnings.warn(("Use of `cvmin`, `cvmax`, and `cvthr` is deprecated! Please use \n" - "`curvature_brightness`, `curvature_contrast`, and `curvature_threshold`\n" - "to set appearance of background curvature.")) + "`curvature_brightness`, `curvature_contrast`, and `curvature_threshold`\n" + "to set appearance of background curvature.")) legacy_mode = True if ('cvmin' in kwargs) and ('cvmax' in kwargs): # Assumes that if one is specified, both are; weird case where only one is @@ -126,7 +133,7 @@ def make_figure(braindata, recache=False, pixelwise=True, thick=32, sampler='nea else: curvature_lims = 0.5 legacy_mode = False - curv_im = composite.add_curvature(fig, dataview, extents, + curv_im = composite.add_curvature(ax, dataview, extents, brightness=curvature_brightness, contrast=curvature_contrast, threshold=curvature_threshold, @@ -143,37 +150,38 @@ def make_figure(braindata, recache=False, pixelwise=True, thick=32, sampler='nea dropout_power = 20 if with_dropout is True else with_dropout if hatch_data is None: hatch_data = utils.get_dropout(dataview.subject, dataview.xfmname, - power=dropout_power) + power=dropout_power) - drop_im = composite.add_hatch(fig, hatch_data, extents=extents, height=height, - sampler=sampler) + drop_im = composite.add_hatch(ax, hatch_data, extents=extents, height=height, + sampler=sampler) layers['dropout'] = drop_im # Add extra hatching if extra_hatch is not None: hatch_data2, hatch_color = extra_hatch - hatch_im = composite.add_hatch(fig, hatch_data2, extents=extents, height=height, - sampler=sampler) + hatch_im = composite.add_hatch(ax, hatch_data2, extents=extents, height=height, + sampler=sampler) layers['hatch'] = hatch_im # Add rois if with_rois: - roi_im = composite.add_rois(fig, dataview, extents=extents, height=height, linewidth=linewidth, linecolor=linecolor, - roifill=roifill, shadow=shadow, labelsize=labelsize, labelcolor=labelcolor, with_labels=with_labels) + roi_im = composite.add_rois(ax, dataview, extents=extents, height=height, linewidth=linewidth, linecolor=linecolor, + roifill=roifill, shadow=shadow, labelsize=labelsize, labelcolor=labelcolor, + with_labels=with_labels) layers['rois'] = roi_im # Add sulci if with_sulci: - sulc_im = composite.add_sulci(fig, dataview, extents=extents, height=height, linewidth=linewidth, linecolor=linecolor, - shadow=shadow, labelsize=labelsize, labelcolor=labelcolor, with_labels=with_labels) + sulc_im = composite.add_sulci(ax, dataview, extents=extents, height=height, linewidth=linewidth, linecolor=linecolor, + shadow=shadow, labelsize=labelsize, labelcolor=labelcolor, with_labels=with_labels) layers['sulci'] = sulc_im # Add custom if extra_disp is not None: svgfile, layer = extra_disp - custom_im = composite.add_custom(fig, dataview, svgfile, layer, height=height, extents=extents, - linewidth=linewidth, linecolor=linecolor, shadow=shadow, labelsize=labelsize, labelcolor=labelcolor, - with_labels=with_labels) + custom_im = composite.add_custom(ax, dataview, svgfile, layer, height=height, extents=extents, + linewidth=linewidth, linecolor=linecolor, shadow=shadow, labelsize=labelsize, + labelcolor=labelcolor, with_labels=with_labels) layers['custom'] = custom_im # Add connector lines btw connected vertices if with_connected_vertices: - vertex_lines = composite.add_connected_vertices(fig, dataview) + vertex_lines = composite.add_connected_vertices(ax, dataview) ax.axis('off') ax.set_xlim(extents[0], extents[1]) @@ -185,17 +193,17 @@ def make_figure(braindata, recache=False, pixelwise=True, thick=32, sampler='nea # Add (apply) cutout of flatmap if cutout is not None: - extents = composite.add_cutout(fig, cutout, dataview, layers) + extents = composite.add_cutout(ax, cutout, dataview, layers) if with_colorbar: # Allow 2D colorbars: if isinstance(dataview, dataset.view2D.Dataview2D): - colorbar = composite.add_colorbar_2d(fig, dataview.cmap, - [dataview.vmin, dataview.vmax, dataview.vmin2, dataview.vmax2]) + colorbar = composite.add_colorbar_2d(ax, dataview.cmap, + [dataview.vmin, dataview.vmax, dataview.vmin2, dataview.vmax2]) else: - colorbar = composite.add_colorbar(fig, data_im) + colorbar = composite.add_colorbar(ax, data_im) # Reset axis to main figure axis - plt.axes(ax) + plt.sca(ax) return fig