From 65dc92d27cb8fb58a7604343abf0fad652b2e502 Mon Sep 17 00:00:00 2001
From: Fernando
Date: Thu, 15 Oct 2020 10:41:39 +0100
Subject: [PATCH 1/2] Add some visualization support

Update docstrings
Add plotting test
Add matplotlib to tox dependencies
---
 requirements_dev.txt       |   1 +
 setup.py                   |   3 +
 tests/data/test_subject.py |  10 +++
 torchio/data/image.py      |  54 ++++++++----
 torchio/data/subject.py    |  12 +++
 torchio/datasets/fpg.py    | 166 +++++++++++++++++++++++++++++++++++++
 torchio/visualization.py   |  99 ++++++++++++++++++++++
 tox.ini                    |   1 +
 8 files changed, 331 insertions(+), 15 deletions(-)
 create mode 100644 torchio/visualization.py

diff --git a/requirements_dev.txt b/requirements_dev.txt
index 409765798..79df3c8f8 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -4,6 +4,7 @@ codecov
 coverage
 coveralls
 flake8
+matplotlib
 mypy
 pip
 pre-commit
diff --git a/setup.py b/setup.py
index 3af8bbcbd..97712ffcb 100644
--- a/setup.py
+++ b/setup.py
@@ -46,6 +46,9 @@
             'torchio-transform=torchio.cli:apply_transform',
         ],
     },
+    extras_require={
+        'plot': ['matplotlib', 'seaborn'],
+    },
     install_requires=requirements,
     license='MIT license',
     long_description=readme + '\n\n' + history,
diff --git a/tests/data/test_subject.py b/tests/data/test_subject.py
index 72adad4a6..cf071c515 100644
--- a/tests/data/test_subject.py
+++ b/tests/data/test_subject.py
@@ -48,3 +48,13 @@ def test_inconsistent_spatial_shape(self):
         )
         with self.assertRaises(RuntimeError):
             subject.spatial_shape
+
+    def test_plot(self):
+        self.sample.plot(
+            show=False,
+            output_path=self.dir / 'figure.png',
+            cmap_dict=dict(
+                t2='viridis',
+                label={0: 'yellow', 1: 'blue'},
+            ),
+        )
diff --git a/torchio/data/image.py b/torchio/data/image.py
index 06df1b7a0..d62e70403 100644
--- a/torchio/data/image.py
+++ b/torchio/data/image.py
@@ -117,8 +117,8 @@
             raise ValueError('A value for path or tensor must be given')
         self._loaded = False
 
-        tensor = self.parse_tensor(tensor)
-        affine = self.parse_affine(affine)
+        tensor = self._parse_tensor(tensor)
+        affine = self._parse_affine(affine)
         if tensor is not None:
             self[DATA] = tensor
             self[AFFINE] = affine
@@ -173,65 +173,84 @@ def __copy__(self):
         return self.__class__(**kwargs)
 
     @property
-    def data(self):
+    def data(self) -> torch.Tensor:
+        """Tensor data. Same as :py:attr:`Image.tensor`."""
         return self[DATA]
 
     @property
-    def tensor(self):
+    def tensor(self) -> torch.Tensor:
+        """Tensor data. Same as :py:attr:`Image.data`."""
         return self.data
 
     @property
-    def affine(self):
+    def affine(self) -> np.ndarray:
+        """Affine matrix to transform voxel indices into world coordinates."""
         return self[AFFINE]
 
     @property
-    def type(self):
+    def type(self) -> str:
         return self[TYPE]
 
     @property
     def shape(self) -> Tuple[int, int, int, int]:
+        """Tensor shape as :math:`(C, W, H, D)`."""
         return tuple(self.data.shape)
 
     @property
     def spatial_shape(self) -> TypeTripletInt:
+        """Tensor spatial shape as :math:`(W, H, D)`."""
         return self.shape[1:]
 
-    def check_is_2d(self):
+    def check_is_2d(self) -> None:
         if not self.is_2d():
             message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
             raise RuntimeError(message)
 
     @property
     def height(self) -> int:
+        """Image height, if 2D."""
         self.check_is_2d()
         return self.spatial_shape[1]
 
     @property
     def width(self) -> int:
+        """Image width, if 2D."""
         self.check_is_2d()
         return self.spatial_shape[0]
 
     @property
-    def orientation(self):
+    def orientation(self) -> Tuple[str, str, str]:
+        """Orientation codes."""
         return nib.aff2axcodes(self.affine)
 
     @property
-    def spacing(self):
+    def spacing(self) -> Tuple[float, float, float]:
+        """Voxel spacing in mm."""
         _, spacing = get_rotation_and_spacing_from_affine(self.affine)
         return tuple(spacing)
 
     @property
-    def memory(self):
+    def memory(self) -> float:
+        """Number of bytes that the tensor takes in RAM."""
         return np.prod(self.shape) * 4  # float32, i.e. 4 bytes per voxel
 
+    @property
+    def bounds(self) -> np.ndarray:
+        """Positions of the centers of the voxels with smallest and largest indices."""
+        ini = 0, 0, 0
+        fin = np.array(self.spatial_shape) - 1
+        point_ini = nib.affines.apply_affine(self.affine, ini)
+        point_fin = nib.affines.apply_affine(self.affine, fin)
+        return np.array((point_ini, point_fin))
+
     def axis_name_to_index(self, axis: str):
         """Convert an axis name to an axis index.
 
         Args:
             axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
-                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case versions
-                and first letters are also valid, as only the first letter will be
-                used.
+                ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
+                versions and first letters are also valid, as only the first
+                letter will be used.
 
         .. note:: If you are working with animals, you should probably use
             ``'Superior'``, ``'Inferior'``, ``'Anterior'`` and ``'Posterior'``
@@ -323,7 +342,7 @@ def _parse_path(
         else:
             return [self._parse_single_path(p) for p in path]
 
-    def parse_tensor(self, tensor: TypeData) -> torch.Tensor:
+    def _parse_tensor(self, tensor: TypeData) -> torch.Tensor:
         if tensor is None:
             return None
         if isinstance(tensor, np.ndarray):
@@ -340,7 +359,7 @@ def parse_tensor_shape(self, tensor: torch.Tensor) -> torch.Tensor:
         return ensure_4d(tensor)
 
     @staticmethod
-    def parse_affine(affine: np.ndarray) -> np.ndarray:
+    def _parse_affine(affine: np.ndarray) -> np.ndarray:
         if affine is None:
             return np.eye(4)
         if not isinstance(affine, np.ndarray):
@@ -439,6 +458,11 @@ def get_center(self, lps: bool = False) -> TypeTripletFloat:
     def set_check_nans(self, check_nans: bool):
         self.check_nans = check_nans
 
+    def plot(self, **kwargs) -> None:
+        """Plot central slices of the image."""
+        from ..visualization import plot_volume  # avoid circular import
+        plot_volume(self, **kwargs)
+
     def crop(self, index_ini: TypeTripletInt, index_fin: TypeTripletInt):
         new_origin = nib.affines.apply_affine(self.affine, index_ini)
         new_affine = self.affine.copy()
diff --git a/torchio/data/subject.py b/torchio/data/subject.py
index 2ec8f3e1c..464c9c4e0 100644
--- a/torchio/data/subject.py
+++ b/torchio/data/subject.py
@@ -69,6 +69,9 @@ def __copy__(self):
         new.history = self.history[:]
         return new
 
+    def __len__(self):
+        return len(self.get_images(intensity_only=False))
+
     @staticmethod
     def _parse_images(images: List[Tuple[str, Image]]) -> None:
         # Check that it's not empty
@@ -148,10 +151,12 @@ def add_transform(
         self.history.append((transform.name, parameters_dict))
 
     def load(self):
+        """Load images in subject."""
         for image in self.get_images(intensity_only=False):
             image.load()
 
     def crop(self, index_ini, index_fin):
+        """Make a copy of the subject with a reduced field of view (patch)."""
         result_dict = {}
         for key, value in self.items():
             if isinstance(value, Image):
@@ -169,8 +174,15 @@ def update_attributes(self):
         self.__dict__.update(self)
 
     def add_image(self, image: Image, image_name: str) -> None:
+        """Add an image."""
         self[image_name] = image
         self.update_attributes()
 
     def remove_image(self, image_name: str) -> None:
+        """Remove an image."""
         del self[image_name]
+
+    def plot(self, **kwargs) -> None:
+        """Plot images."""
+        from ..visualization import plot_subject  # avoid circular import
+        plot_subject(self, **kwargs)
diff --git a/torchio/datasets/fpg.py b/torchio/datasets/fpg.py
index a122de25d..40e51ff86 100644
--- a/torchio/datasets/fpg.py
+++ b/torchio/datasets/fpg.py
@@ -40,3 +40,169 @@ def __init__(self):
             ),
         }
         super().__init__(subject_dict)
+        self.gif_colors = GIF_COLORS
+
+
+GIF_COLORS = {
+    0: (0, 0, 0),
+    1: (0, 0, 0),
+    5: (127, 255, 212),
+    12: (240, 230, 140),
+    16: (176, 48, 96),
+    24: (48, 176, 96),
+    31: (48, 176, 96),
+    32: (103, 255, 255),
+    33: (103, 255, 255),
+    35: (238, 186, 243),
+    36: (119, 159, 176),
+    37: (122, 186, 220),
+    38: (122, 186, 220),
+    39: (96, 204, 96),
+    40: (96, 204, 96),
+    41: (220, 247, 164),
+    42: (220, 247, 164),
+    43: (205, 62, 78),
+    44: (205, 62, 78),
+    45: (225, 225, 225),
+    46: (225, 225, 225),
+    47: (60, 60, 60),
+    48: (220, 216, 20),
+    49: (220, 216, 20),
+    50: (196, 58, 250),
+    51: (196, 58, 250),
+    52: (120, 18, 134),
+    53: (120, 18, 134),
+    54: (255, 165, 0),
+    55: (255, 165, 0),
+    56: (12, 48, 255),
+    57: (12, 48, 225),
+    58: (236, 13, 176),
+    59: (236, 13, 176),
+    60: (0, 118, 14),
+    61: (0, 118, 14),
+    62: (165, 42, 42),
+    63: (165, 42, 42),
+    64: (160, 32, 240),
+    65: (160, 32, 240),
+    66: (56, 192, 255),
+    67: (56, 192, 255),
+    70: (255, 225, 225),
+    72: (184, 237, 194),
+    73: (180, 231, 250),
+    74: (225, 183, 231),
+    76: (180, 180, 180),
+    77: (180, 180, 180),
+    81: (245, 255, 200),
+    82: (255, 230, 255),
+    83: (245, 245, 245),
+    84: (220, 255, 220),
+    85: (220, 220, 220),
+    86: (200, 255, 255),
+    87: (250, 220, 200),
+    89: (245, 255, 200),
+    90: (255, 230, 255),
+    91: (245, 245, 245),
+    92: (220, 255, 220),
+    93: (220, 220, 220),
+    94: (200, 255, 255),
+    96: (140, 125, 255),
+    97: (140, 125, 255),
+    101: (255, 62, 150),
+    102: (255, 62, 150),
+    103: (160, 82, 45),
+    104: (160, 82, 45),
+    105: (165, 42, 42),
+    106: (165, 42, 42),
+    107: (205, 91, 69),
+    108: (205, 91, 69),
+    109: (100, 149, 237),
+    110: (100, 149, 237),
+    113: (135, 206, 235),
+    114: (135, 206, 235),
+    115: (250, 128, 114),
+    116: (250, 128, 114),
+    117: (255, 255, 0),
+    118: (255, 255, 0),
+    119: (221, 160, 221),
+    120: (221, 160, 221),
+    121: (0, 238, 0),
+    122: (0, 238, 0),
+    123: (205, 92, 92),
+    124: (205, 92, 92),
+    125: (176, 48, 96),
+    126: (176, 48, 96),
+    129: (152, 251, 152),
+    130: (152, 251, 152),
+    133: (50, 205, 50),
+    134: (50, 205, 50),
+    135: (0, 100, 0),
+    136: (0, 100, 0),
+    137: (173, 216, 230),
+    138: (173, 216, 230),
+    139: (153, 50, 204),
+    140: (153, 50, 204),
+    141: (160, 32, 240),
+    142: (160, 32, 240),
+    143: (0, 206, 208),
+    144: (0, 206, 208),
+    145: (51, 50, 135),
+    146: (51, 50, 135),
+    147: (135, 50, 74),
+    148: (135, 50, 74),
+    149: (218, 112, 214),
+    150: (218, 112, 214),
+    151: (240, 230, 140),
+    152: (240, 230, 140),
+    153: (255, 255, 0),
+    154: (255, 255, 0),
+    155: (255, 110, 180),
+    156: (255, 110, 180),
+    157: (0, 255, 255),
+    158: (0, 255, 255),
+    161: (100, 50, 100),
+    162: (100, 50, 100),
+    163: (178, 34, 34),
+    164: (178, 34, 34),
+    165: (255, 0, 255),
+    166: (255, 0, 255),
+    167: (39, 64, 139),
+    168: (39, 64, 139),
+    169: (255, 99, 71),
+    170: (255, 99, 71),
+    171: (255, 69, 0),
+    172: (255, 69, 0),
+    173: (210, 180, 140),
+    174: (210, 180, 140),
+    175: (0, 255, 127),
+    176: (0, 255, 127),
+    177: (74, 155, 60),
+    178: (74, 155, 60),
+    179: (255, 215, 0),
+    180: (255, 215, 0),
+    181: (238, 0, 0),
+    182: (238, 0, 0),
+    183: (46, 139, 87),
+    184: (46, 139, 87),
+    185: (238, 201, 0),
+    186: (238, 201, 0),
+    187: (102, 205, 170),
+    188: (102, 205, 170),
+    191: (255, 218, 185),
+    192: (255, 218, 185),
+    193: (238, 130, 238),
+    194: (238, 130, 238),
+    195: (255, 165, 0),
+    196: (255, 165, 0),
+    197: (255, 192, 203),
+    198: (255, 192, 203),
+    199: (244, 222, 179),
+    200: (244, 222, 179),
+    201: (208, 32, 144),
+    202: (208, 32, 144),
+    203: (34, 139, 34),
+    204: (34, 139, 34),
+    205: (125, 255, 212),
+    206: (127, 255, 212),
+    207: (0, 0, 128),
+    208: (0, 0, 128),
+}
diff --git a/torchio/visualization.py b/torchio/visualization.py
new file mode 100644
index 000000000..1e96d0635
--- /dev/null
+++ b/torchio/visualization.py
@@ -0,0 +1,99 @@
+import numpy as np
+
+from .data.image import Image, LabelMap
+from .data.subject import Subject
+from .transforms.preprocessing.spatial.to_canonical import ToCanonical
+
+
+def import_mpl_plt():
+    try:
+        import matplotlib as mpl
+        import matplotlib.pyplot as plt
+    except ImportError as e:
+        raise ImportError('Install matplotlib for plotting support') from e
+    return mpl, plt
+
+
+def rotate(image):
+    return np.rot90(image)
+
+
+def plot_volume(
+        image: Image,
+        channel=0,
+        axes=None,
+        cmap=None,
+        output_path=None,
+        show=True,
+        ):
+    _, plt = import_mpl_plt()
+    fig = None
+    if axes is None:
+        fig, axes = plt.subplots(1, 3)
+    image = ToCanonical()(image)
+    data = image.data[channel]
+    indices = np.array(data.shape) // 2
+    i, j, k = indices
+    slice_x = rotate(data[i, :, :])
+    slice_y = rotate(data[:, j, :])
+    slice_z = rotate(data[:, :, k])
+    kwargs = {}
+    is_label = isinstance(image, LabelMap)
+    if isinstance(cmap, dict):
+        slices = slice_x, slice_y, slice_z
+        slice_x, slice_y, slice_z = color_labels(slices, cmap)
+    else:
+        if cmap is None:
+            cmap = 'inferno' if is_label else 'gray'
+        kwargs['cmap'] = cmap
+    if is_label:
+        kwargs['interpolation'] = 'none'
+    x_extent, y_extent, z_extent = [tuple(b) for b in image.bounds.T]
+    axes[0].imshow(slice_x, extent=y_extent + z_extent, **kwargs)
+    axes[1].imshow(slice_y, extent=x_extent + z_extent, **kwargs)
+    axes[2].imshow(slice_z, extent=x_extent + y_extent, **kwargs)
+    plt.tight_layout()
+    if output_path is not None and fig is not None:
+        fig.savefig(output_path)
+    if show:
+        plt.show()
+
+
+def plot_subject(
+        subject: Subject,
+        cmap_dict=None,
+        show=True,
+        output_path=None,
+        ):
+    _, plt = import_mpl_plt()
+    fig, axes = plt.subplots(len(subject), 3)
+    iterable = enumerate(subject.get_images_dict(intensity_only=False).items())
+    axes_names = 'sagittal', 'coronal', 'axial'
+    for row, (name, image) in iterable:
+        row_axes = axes[row]
+        cmap = None
+        if cmap_dict is not None and name in cmap_dict:
+            cmap = cmap_dict[name]
+        plot_volume(image, axes=row_axes, show=False, cmap=cmap)
+        for axis, axis_name in zip(row_axes, axes_names):
+            axis.set_title(f'{name} ({axis_name})')
+    plt.tight_layout()
+    if output_path is not None:
+        fig.savefig(output_path)
+    if show:
+        plt.show()
+
+
+def color_labels(arrays, cmap_dict):
+    results = []
+    for array in arrays:
+        si, sj = array.shape
+        rgb = np.zeros((si, sj, 3), dtype=np.uint8)
+        for label, color in cmap_dict.items():
+            if isinstance(color, str):
+                mpl, _ = import_mpl_plt()
+                color = mpl.colors.to_rgb(color)
+                color = [255 * n for n in color]
+            rgb[array == label] = color
+        results.append(rgb)
+    return results
diff --git a/tox.ini b/tox.ini
index ca3126a9f..e371986e6 100644
--- a/tox.ini
+++ b/tox.ini
@@ -6,6 +6,7 @@ deps =
     pytest
     pytest-cov
     coveralls
+    matplotlib
 commands =
     pytest --cov=. --cov-report xml
     coverage run --source=torchio -m pytest

From d85d12050128d3add4a03935f2f95923316c6f1c Mon Sep 17 00:00:00 2001
From: Fernando
Date: Thu, 15 Oct 2020 14:06:58 +0100
Subject: [PATCH 2/2] Add test for image plotting

---
 tests/data/test_image.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tests/data/test_image.py b/tests/data/test_image.py
index d16df1f33..f8603b4a0 100644
--- a/tests/data/test_image.py
+++ b/tests/data/test_image.py
@@ -160,3 +160,7 @@ def test_axis_name_2d(self):
         width_idx = image.axis_name_to_index('l')
         self.assertEqual(image.height, image.shape[height_idx])
         self.assertEqual(image.width, image.shape[width_idx])
+
+    def test_plot(self):
+        image = self.sample.t1
+        image.plot(show=False, output_path=self.dir / 'image.png')
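A minimal usage sketch of the plotting API introduced by these patches, mirroring the calls exercised in the new tests. It assumes matplotlib is installed (for example through the new ``plot`` extra, ``pip install torchio[plot]``); the image paths and label values below are placeholders, not files shipped with the patch:

    import torchio

    # Build a subject with an intensity image and a label map (paths are examples)
    subject = torchio.Subject(
        t1=torchio.ScalarImage('t1.nii.gz'),
        label=torchio.LabelMap('seg.nii.gz'),
    )

    # Plot the central sagittal, coronal and axial slices of a single image
    subject.t1.plot(show=False, output_path='t1.png')

    # Plot every image in the subject; strings are matplotlib colormap names,
    # dicts map label values to colors (as in test_subject.py::test_plot)
    subject.plot(
        cmap_dict=dict(
            t1='gray',
            label={0: 'black', 1: 'red'},
        ),
        output_path='subject.png',  # save the figure to disk
        show=False,                 # skip plt.show(), e.g. on a headless CI machine
    )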