Skip to content

Commit

Permalink
Set default mode for plot to real if real object else magnitude
Browse files Browse the repository at this point in the history
  • Loading branch information
frankong committed Aug 21, 2019
1 parent 229ede1 commit 593656f
Showing 1 changed file with 28 additions and 17 deletions.
45 changes: 28 additions & 17 deletions sigpy/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ def __init__(
z=None,
c=None,
hide_axes=False,
mode='m',
mode=None,
title='',
interpolation='lanczos',
interpolation='nearest',
save_basename='Figure',
fps=10):
if im.ndim < 2:
Expand Down Expand Up @@ -394,6 +394,12 @@ def update_image(self):
imv = np.transpose(imv, np.argsort(np.argsort(imv_dims)))
imv = array_to_image(imv, color=self.c is not None)

if self.mode is None:
if np.isrealobj(imv):
self.mode = 'r'
else:
self.mode = 'm'

if self.mode == 'm':
imv = np.abs(imv)
elif self.mode == 'p':
Expand Down Expand Up @@ -457,7 +463,8 @@ def update_axes(self):

if (self.flips[i] == -1 and (i == self.x or
i == self.y or
i == self.z)):
i == self.z or
i == self.c)):
caption += '-'

if i == self.x:
Expand Down Expand Up @@ -510,8 +517,11 @@ def array_to_image(arr, color=False):
"""
Flattens all dimensions except the last two
Args:
arr (array): shape [z, x, y, c] if color, else [z, x, y]
"""
if color:
if color and not (arr.max() == 0 and arr.min() == 0):
arr = arr / np.abs(arr).max()

if arr.ndim == 2:
Expand All @@ -520,28 +530,29 @@ def array_to_image(arr, color=False):
return arr

if color:
ndim = 3
img_shape = arr.shape[-3:]
batch = sp.prod(arr.shape[:-3])
mshape = mosaic_shape(batch)
else:
ndim = 2

shape = arr.shape
batch = sp.prod(shape[:-ndim])
mshape = mosaic_shape(batch)
img_shape = arr.shape[-2:]
batch = sp.prod(arr.shape[:-2])
mshape = mosaic_shape(batch)

if sp.prod(mshape) == batch:
img = arr.reshape((batch, ) + shape[-ndim:])
img = arr.reshape((batch, ) + img_shape)
else:
img = np.zeros((sp.prod(mshape), ) + shape[-ndim:], dtype=arr.dtype)
img[:batch, ...] = arr.reshape((batch, ) + shape[-ndim:])
img = np.zeros((sp.prod(mshape), ) + img_shape, dtype=arr.dtype)
img[:batch, ...] = arr.reshape((batch, ) + img_shape)

img = img.reshape(mshape + shape[-ndim:])
img = img.reshape(mshape + img_shape)
if color:
img = np.transpose(img, (0, 2, 1, 3, 4))
img = img.reshape(
(shape[-3] * mshape[-2], shape[-2] * mshape[-1], shape[-1]))
img = img.reshape((img_shape[0] * mshape[0],
img_shape[1] * mshape[1], 3))
else:
img = np.transpose(img, (0, 2, 1, 3))
img = img.reshape((shape[-2] * mshape[-2], shape[-1] * mshape[-1]))
img = img.reshape((img_shape[0] * mshape[0],
img_shape[1] * mshape[1]))

return img

Expand Down

0 comments on commit 593656f

Please sign in to comment.