Skip to content

Commit

Permalink
FIX: Better unlinking
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner authored and matthew-brett committed Nov 20, 2014
1 parent 8333d0a commit 85248cd
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 5 deletions.
3 changes: 2 additions & 1 deletion nibabel/spatialimages.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,4 +760,5 @@ def plot(self):
consider using viewer.show() (equivalently plt.show()) to show
the figure.
"""
return OrthoSlicer3D(self.get_data(), self.get_affine())
return OrthoSlicer3D(self.get_data(), self.get_affine(),
title=self.get_filename())
4 changes: 3 additions & 1 deletion nibabel/tests/test_viewers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from numpy.testing.decorators import skipif
from numpy.testing import assert_array_equal

from nose.tools import assert_raises
from nose.tools import assert_raises, assert_true

matplotlib, has_mpl = optional_package('matplotlib')[:2]
needs_mpl = skipif(not has_mpl, 'These tests need matplotlib')
Expand All @@ -35,6 +35,7 @@ def test_viewer():
data = data * np.array([1., 2.]) # give it a # of volumes > 1
v = OrthoSlicer3D(data)
assert_array_equal(v.position, (0, 0, 0))
assert_true('OrthoSlicer3D' in repr(v))

# fake some events, inside and outside axes
v._on_scroll(nt('event', 'button inaxes key')('up', None, None))
Expand All @@ -49,6 +50,7 @@ def test_viewer():
v.set_volume_idx(1)
v.set_volume_idx(1) # should just pass
v.close()
v._draw() # should be safe

# non-multi-volume
v = OrthoSlicer3D(data[:, :, :, 0])
Expand Down
21 changes: 18 additions & 3 deletions nibabel/viewers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class OrthoSlicer3D(object):
"""
# Skip doctest above b/c not all systems have mpl installed
def __init__(self, data, affine=None, axes=None, cmap='gray',
pcnt_range=(1., 99.), figsize=(8, 8)):
pcnt_range=(1., 99.), figsize=(8, 8), title=None):
"""
Parameters
----------
Expand Down Expand Up @@ -60,6 +60,8 @@ def __init__(self, data, affine=None, axes=None, cmap='gray',
plt, _, _ = optional_package('matplotlib.pyplot')
mpl_img, _, _ = optional_package('matplotlib.image')
mpl_patch, _, _ = optional_package('matplotlib.patches')
self._title = title
self._closed = False

data = np.asanyarray(data)
if data.ndim < 3:
Expand Down Expand Up @@ -107,6 +109,8 @@ def __init__(self, data, affine=None, axes=None, cmap='gray',
if self.n_volumes <= 1:
fig.delaxes(self._axes[3])
self._axes.pop(-1)
if self._title is not None:
fig.canvas.set_window_title(str(title))
else:
self._axes = [axes[0], axes[1], axes[2]]
if len(axes) > 3:
Expand Down Expand Up @@ -196,6 +200,14 @@ def __init__(self, data, affine=None, axes=None, cmap='gray',
self._set_position(0., 0., 0.)
self._draw()

def __repr__(self):
title = '' if self._title is None else ('%s ' % self._title)
vol = '' if self.n_volumes <= 1 else (', %s' % self.n_volumes)
r = ('<%s: %s(%s, %s, %s%s)>'
% (self.__class__.__name__, title, self._sizes[0], self._sizes[1],
self._sizes[2], vol))
return r

# User-level functions ###################################################
def show(self):
"""Show the slicer in blocking mode; convenience for ``plt.show()``
Expand All @@ -213,8 +225,9 @@ def close(self):

def _cleanup(self):
"""Clean up before closing"""
for link in self._links:
link()._unlink(self)
self._closed = True
for link in list(self._links): # make a copy before iterating
self._unlink(link())

@property
def n_volumes(self):
Expand Down Expand Up @@ -423,6 +436,8 @@ def _on_keypress(self, event):

def _draw(self):
"""Update all four (or three) plots"""
if self._closed: # make sure we don't draw when we shouldn't
return
for ii in range(3):
ax = self._axes[ii]
ax.draw_artist(self._ims[ii])
Expand Down

0 comments on commit 85248cd

Please sign in to comment.