Skip to content

Commit

Permalink
Visualization (#332)
Browse files Browse the repository at this point in the history
* Add some visualization support

Update docstrings

Add plotting test

Add matplotlib to tox dependencies

* Add test for image plotting

* Add plot method for 2D images
  • Loading branch information
fepegar committed Oct 15, 2020
1 parent bc7283c commit a8a78b1
Show file tree
Hide file tree
Showing 9 changed files with 344 additions and 15 deletions.
1 change: 1 addition & 0 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ codecov
coverage
coveralls
flake8
matplotlib
mypy
pip
pre-commit
Expand Down
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@
'torchio-transform=torchio.cli:apply_transform',
],
},
extras_require={
'plot': ['matplotlib', 'seaborn'],
},
install_requires=requirements,
license='MIT license',
long_description=readme + '\n\n' + history,
Expand Down
4 changes: 4 additions & 0 deletions tests/data/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,7 @@ def test_axis_name_2d(self):
width_idx = image.axis_name_to_index('l')
self.assertEqual(image.height, image.shape[height_idx])
self.assertEqual(image.width, image.shape[width_idx])

def test_plot(self):
image = self.sample.t1
image.plot(show=False, output_path=self.dir / 'image.png')
10 changes: 10 additions & 0 deletions tests/data/test_subject.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,13 @@ def test_inconsistent_spatial_shape(self):
)
with self.assertRaises(RuntimeError):
subject.spatial_shape

def test_plot(self):
self.sample.plot(
show=False,
output_path=self.dir / 'figure.png',
cmap_dict=dict(
t2='viridis',
label={0: 'yellow', 1: 'blue'},
),
)
63 changes: 48 additions & 15 deletions torchio/data/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
import humanize
import numpy as np
from PIL import Image as ImagePIL
import nibabel as nib
import SimpleITK as sitk

Expand Down Expand Up @@ -117,8 +118,8 @@ def __init__(
raise ValueError('A value for path or tensor must be given')
self._loaded = False

tensor = self.parse_tensor(tensor)
affine = self.parse_affine(affine)
tensor = self._parse_tensor(tensor)
affine = self._parse_affine(affine)
if tensor is not None:
self[DATA] = tensor
self[AFFINE] = affine
Expand Down Expand Up @@ -173,65 +174,84 @@ def __copy__(self):
return self.__class__(**kwargs)

@property
def data(self):
def data(self) -> torch.Tensor:
"""Tensor data. Same as :py:class:`Image.tensor`."""
return self[DATA]

@property
def tensor(self):
def tensor(self) -> torch.Tensor:
"""Tensor data. Same as :py:class:`Image.data`."""
return self.data

@property
def affine(self):
def affine(self) -> np.ndarray:
"""Affine matrix to transform voxel indices into world coordinates."""
return self[AFFINE]

@property
def type(self):
def type(self) -> str:
return self[TYPE]

@property
def shape(self) -> Tuple[int, int, int, int]:
"""Tensor shape as :math:`(C, W, H, D)`."""
return tuple(self.data.shape)

@property
def spatial_shape(self) -> TypeTripletInt:
"""Tensor spatial shape as :math:`(W, H, D)`."""
return self.shape[1:]

def check_is_2d(self):
def check_is_2d(self) -> None:
if not self.is_2d():
message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
raise RuntimeError(message)

@property
def height(self) -> int:
"""Image height, if 2D."""
self.check_is_2d()
return self.spatial_shape[1]

@property
def width(self) -> int:
"""Image width, if 2D."""
self.check_is_2d()
return self.spatial_shape[0]

@property
def orientation(self):
def orientation(self) -> Tuple[str, str, str]:
"""Orientation codes."""
return nib.aff2axcodes(self.affine)

@property
def spacing(self):
def spacing(self) -> Tuple[float, float, float]:
"""Voxel spacing in mm."""
_, spacing = get_rotation_and_spacing_from_affine(self.affine)
return tuple(spacing)

@property
def memory(self):
def memory(self) -> float:
"""Number of Bytes that the tensor takes in the RAM."""
return np.prod(self.shape) * 4 # float32, i.e. 4 bytes per voxel

@property
def bounds(self) -> np.ndarray:
"""Position of centers of voxels in smallest and largest coordinates."""
ini = 0, 0, 0
fin = np.array(self.spatial_shape) - 1
point_ini = nib.affines.apply_affine(self.affine, ini)
point_fin = nib.affines.apply_affine(self.affine, fin)
return np.array((point_ini, point_fin))

def axis_name_to_index(self, axis: str):
"""Convert an axis name to an axis index.
Args:
axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case versions
and first letters are also valid, as only the first letter will be
used.
``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
versions and first letters are also valid, as only the first
letter will be used.
.. note:: If you are working with animals, you should probably use
``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
Expand Down Expand Up @@ -323,7 +343,7 @@ def _parse_path(
else:
return [self._parse_single_path(p) for p in path]

def parse_tensor(self, tensor: TypeData) -> torch.Tensor:
def _parse_tensor(self, tensor: TypeData) -> torch.Tensor:
if tensor is None:
return None
if isinstance(tensor, np.ndarray):
Expand All @@ -340,7 +360,7 @@ def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
return ensure_4d(tensor)

@staticmethod
def parse_affine(affine: np.ndarray) -> np.ndarray:
def _parse_affine(affine: np.ndarray) -> np.ndarray:
if affine is None:
return np.eye(4)
if not isinstance(affine, np.ndarray):
Expand Down Expand Up @@ -420,6 +440,11 @@ def as_sitk(self, **kwargs) -> sitk.Image:
"""Get the image as an instance of :py:class:`sitk.Image`."""
return nib_to_sitk(self[DATA], self[AFFINE], **kwargs)

def as_pil(self):
"""Get the image as an instance of :py:class:`PIL.Image`."""
self.check_is_2d()
return ImagePIL.open(self.path)

def get_center(self, lps: bool = False) -> TypeTripletFloat:
"""Get image center in RAS+ or LPS+ coordinates.
Expand All @@ -439,6 +464,14 @@ def get_center(self, lps: bool = False) -> TypeTripletFloat:
def set_check_nans(self, check_nans: bool):
self.check_nans = check_nans

def plot(self, **kwargs) -> None:
"""Plot image."""
if self.is_2d():
self.as_pil().show()
else:
from ..visualization import plot_volume # avoid circular import
plot_volume(self, **kwargs)

def crop(self, index_ini: TypeTripletInt, index_fin: TypeTripletInt):
new_origin = nib.affines.apply_affine(self.affine, index_ini)
new_affine = self.affine.copy()
Expand Down
12 changes: 12 additions & 0 deletions torchio/data/subject.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def __copy__(self):
new.history = self.history[:]
return new

def __len__(self):
return len(self.get_images(intensity_only=False))

@staticmethod
def _parse_images(images: List[Tuple[str, Image]]) -> None:
# Check that it's not empty
Expand Down Expand Up @@ -148,10 +151,12 @@ def add_transform(
self.history.append((transform.name, parameters_dict))

def load(self):
"""Load images in subject."""
for image in self.get_images(intensity_only=False):
image.load()

def crop(self, index_ini, index_fin):
"""Make a copy of the subject with a reduced field of view (patch)."""
result_dict = {}
for key, value in self.items():
if isinstance(value, Image):
Expand All @@ -169,8 +174,15 @@ def update_attributes(self):
self.__dict__.update(self)

def add_image(self, image: Image, image_name: str) -> None:
"""Add an image."""
self[image_name] = image
self.update_attributes()

def remove_image(self, image_name: str) -> None:
"""Remove an image."""
del self[image_name]

def plot(self, **kwargs) -> None:
"""Plot images."""
from ..visualization import plot_subject # avoid circular import
plot_subject(self, **kwargs)
Loading

0 comments on commit a8a78b1

Please sign in to comment.