diff --git a/setup.py b/setup.py index 3af8bbcbd..97712ffcb 100644 --- a/setup.py +++ b/setup.py @@ -46,6 +46,9 @@ 'torchio-transform=torchio.cli:apply_transform', ], }, + extras_require={ + 'plot': ['matplotlib', 'seaborn'], + }, install_requires=requirements, license='MIT license', long_description=readme + '\n\n' + history, diff --git a/torchio/data/image.py b/torchio/data/image.py index 06df1b7a0..2bb12d97c 100644 --- a/torchio/data/image.py +++ b/torchio/data/image.py @@ -173,19 +173,19 @@ def __copy__(self): return self.__class__(**kwargs) @property - def data(self): + def data(self) -> torch.Tensor: return self[DATA] @property - def tensor(self): + def tensor(self) -> torch.Tensor: return self.data @property - def affine(self): + def affine(self) -> np.ndarray: return self[AFFINE] @property - def type(self): + def type(self) -> str: return self[TYPE] @property @@ -196,7 +196,7 @@ def shape(self) -> Tuple[int, int, int, int]: def spatial_shape(self) -> TypeTripletInt: return self.shape[1:] - def check_is_2d(self): + def check_is_2d(self) -> None: if not self.is_2d(): message = f'Image is not 2D. Spatial shape: {self.spatial_shape}' raise RuntimeError(message) @@ -212,18 +212,26 @@ def width(self) -> int: return self.spatial_shape[0] @property - def orientation(self): + def orientation(self) -> Tuple[str, str, str]: return nib.aff2axcodes(self.affine) @property - def spacing(self): + def spacing(self) -> Tuple[float, float, float]: _, spacing = get_rotation_and_spacing_from_affine(self.affine) return tuple(spacing) @property - def memory(self): + def memory(self) -> float: return np.prod(self.shape) * 4 # float32, i.e. 4 bytes per voxel + @property + def bounds(self) -> np.ndarray: + ini = 0, 0, 0 + fin = np.array(self.spatial_shape) - 1 + point_ini = nib.affines.apply_affine(self.affine, ini) + point_fin = nib.affines.apply_affine(self.affine, fin) + return np.array((point_ini, point_fin)) + def axis_name_to_index(self, axis: str): """Convert an axis name to an axis index. @@ -439,6 +447,10 @@ def get_center(self, lps: bool = False) -> TypeTripletFloat: def set_check_nans(self, check_nans: bool): self.check_nans = check_nans + def plot(self, **kwargs) -> None: + from ..visualization import plot_image # avoid circular import + plot_image(self, **kwargs) + def crop(self, index_ini: TypeTripletInt, index_fin: TypeTripletInt): new_origin = nib.affines.apply_affine(self.affine, index_ini) new_affine = self.affine.copy() diff --git a/torchio/data/subject.py b/torchio/data/subject.py index 2ec8f3e1c..01392e3ff 100644 --- a/torchio/data/subject.py +++ b/torchio/data/subject.py @@ -69,6 +69,9 @@ def __copy__(self): new.history = self.history[:] return new + def __len__(self): + return len(self.get_images(intensity_only=False)) + @staticmethod def _parse_images(images: List[Tuple[str, Image]]) -> None: # Check that it's not empty @@ -174,3 +177,7 @@ def add_image(self, image: Image, image_name: str) -> None: def remove_image(self, image_name: str) -> None: del self[image_name] + + def plot(self, **kwargs) -> None: + from ..visualization import plot_subject # avoid circular import + plot_subject(self, **kwargs) diff --git a/torchio/datasets/fpg.py b/torchio/datasets/fpg.py index a122de25d..40e51ff86 100644 --- a/torchio/datasets/fpg.py +++ b/torchio/datasets/fpg.py @@ -40,3 +40,169 @@ def __init__(self): ), } super().__init__(subject_dict) + self.gif_colors = GIF_COLORS + + +GIF_COLORS = { + 0: (0, 0, 0), + 1: (0, 0, 0), + 5: (127, 255, 212), + 12: (240, 230, 140), + 16: (176, 48, 96), + 24: (48, 176, 96), + 31: (48, 176, 96), + 32: (103, 255, 255), + 33: (103, 255, 255), + 35: (238, 186, 243), + 36: (119, 159, 176), + 37: (122, 186, 220), + 38: (122, 186, 220), + 39: (96, 204, 96), + 40: (96, 204, 96), + 41: (220, 247, 164), + 42: (220, 247, 164), + 43: (205, 62, 78), + 44: (205, 62, 78), + 45: (225, 225, 225), + 46: (225, 225, 225), + 47: (60, 60, 60), + 48: (220, 216, 20), + 49: (220, 216, 20), + 50: (196, 58, 250), + 51: (196, 58, 250), + 52: (120, 18, 134), + 53: (120, 18, 134), + 54: (255, 165, 0), + 55: (255, 165, 0), + 56: (12, 48, 255), + 57: (12, 48, 225), + 58: (236, 13, 176), + 59: (236, 13, 176), + 60: (0, 118, 14), + 61: (0, 118, 14), + 62: (165, 42, 42), + 63: (165, 42, 42), + 64: (160, 32, 240), + 65: (160, 32, 240), + 66: (56, 192, 255), + 67: (56, 192, 255), + 70: (255, 225, 225), + 72: (184, 237, 194), + 73: (180, 231, 250), + 74: (225, 183, 231), + 76: (180, 180, 180), + 77: (180, 180, 180), + 81: (245, 255, 200), + 82: (255, 230, 255), + 83: (245, 245, 245), + 84: (220, 255, 220), + 85: (220, 220, 220), + 86: (200, 255, 255), + 87: (250, 220, 200), + 89: (245, 255, 200), + 90: (255, 230, 255), + 91: (245, 245, 245), + 92: (220, 255, 220), + 93: (220, 220, 220), + 94: (200, 255, 255), + 96: (140, 125, 255), + 97: (140, 125, 255), + 101: (255, 62, 150), + 102: (255, 62, 150), + 103: (160, 82, 45), + 104: (160, 82, 45), + 105: (165, 42, 42), + 106: (165, 42, 42), + 107: (205, 91, 69), + 108: (205, 91, 69), + 109: (100, 149, 237), + 110: (100, 149, 237), + 113: (135, 206, 235), + 114: (135, 206, 235), + 115: (250, 128, 114), + 116: (250, 128, 114), + 117: (255, 255, 0), + 118: (255, 255, 0), + 119: (221, 160, 221), + 120: (221, 160, 221), + 121: (0, 238, 0), + 122: (0, 238, 0), + 123: (205, 92, 92), + 124: (205, 92, 92), + 125: (176, 48, 96), + 126: (176, 48, 96), + 129: (152, 251, 152), + 130: (152, 251, 152), + 133: (50, 205, 50), + 134: (50, 205, 50), + 135: (0, 100, 0), + 136: (0, 100, 0), + 137: (173, 216, 230), + 138: (173, 216, 230), + 139: (153, 50, 204), + 140: (153, 50, 204), + 141: (160, 32, 240), + 142: (160, 32, 240), + 143: (0, 206, 208), + 144: (0, 206, 208), + 145: (51, 50, 135), + 146: (51, 50, 135), + 147: (135, 50, 74), + 148: (135, 50, 74), + 149: (218, 112, 214), + 150: (218, 112, 214), + 151: (240, 230, 140), + 152: (240, 230, 140), + 153: (255, 255, 0), + 154: (255, 255, 0), + 155: (255, 110, 180), + 156: (255, 110, 180), + 157: (0, 255, 255), + 158: (0, 255, 255), + 161: (100, 50, 100), + 162: (100, 50, 100), + 163: (178, 34, 34), + 164: (178, 34, 34), + 165: (255, 0, 255), + 166: (255, 0, 255), + 167: (39, 64, 139), + 168: (39, 64, 139), + 169: (255, 99, 71), + 170: (255, 99, 71), + 171: (255, 69, 0), + 172: (255, 69, 0), + 173: (210, 180, 140), + 174: (210, 180, 140), + 175: (0, 255, 127), + 176: (0, 255, 127), + 177: (74, 155, 60), + 178: (74, 155, 60), + 179: (255, 215, 0), + 180: (255, 215, 0), + 181: (238, 0, 0), + 182: (238, 0, 0), + 183: (46, 139, 87), + 184: (46, 139, 87), + 185: (238, 201, 0), + 186: (238, 201, 0), + 187: (102, 205, 170), + 188: (102, 205, 170), + 191: (255, 218, 185), + 192: (255, 218, 185), + 193: (238, 130, 238), + 194: (238, 130, 238), + 195: (255, 165, 0), + 196: (255, 165, 0), + 197: (255, 192, 203), + 198: (255, 192, 203), + 199: (244, 222, 179), + 200: (244, 222, 179), + 201: (208, 32, 144), + 202: (208, 32, 144), + 203: (34, 139, 34), + 204: (34, 139, 34), + 205: (125, 255, 212), + 206: (127, 255, 212), + 207: (0, 0, 128), + 208: (0, 0, 128), +} diff --git a/torchio/visualization.py b/torchio/visualization.py new file mode 100644 index 000000000..962161e65 --- /dev/null +++ b/torchio/visualization.py @@ -0,0 +1,85 @@ +import numpy as np + +from .data.image import Image, LabelMap +from .data.subject import Subject +from .transforms.preprocessing.spatial.to_canonical import ToCanonical + + +def import_pyplot(): + try: + import matplotlib.pyplot as plt + except ImportError as e: + raise ImportError('Install matplotlib for plotting support') from e + return plt + + +def rotate(image): + return np.rot90(image) + + +def plot_image( + image: Image, + channel=0, + axes=None, + show=True, + cmap=None, + ): + plt = import_pyplot() + if axes is None: + _, axes = plt.subplots(1, 3) + image = ToCanonical()(image) + data = image.data[channel] + indices = np.array(data.shape) // 2 + i, j, k = indices + slice_x = rotate(data[i, :, :]) + slice_y = rotate(data[:, j, :]) + slice_z = rotate(data[:, :, k]) + kwargs = {} + is_label = isinstance(image, LabelMap) + if isinstance(cmap, dict): + slices = slice_x, slice_y, slice_z + slice_x, slice_y, slice_z = color_labels(slices, cmap) + else: + if cmap is None: + cmap = 'inferno' if is_label else 'gray' + kwargs['cmap'] = cmap + if is_label: + kwargs['interpolation'] = 'none' + x_extent, y_extent, z_extent = [tuple(b) for b in image.bounds.T] + axes[0].imshow(slice_x, extent=y_extent + z_extent, **kwargs) + axes[1].imshow(slice_y, extent=x_extent + z_extent, **kwargs) + axes[2].imshow(slice_z, extent=x_extent + y_extent, **kwargs) + plt.tight_layout() + if show: + plt.show() + + +def plot_subject( + subject: Subject, + cmap_dict=None, + ): + plt = import_pyplot() + _, axes = plt.subplots(len(subject), 3) + iterable = enumerate(subject.get_images_dict(intensity_only=False).items()) + axes_names = 'sagittal', 'coronal', 'axial' + for row, (name, image) in iterable: + row_axes = axes[row] + cmap = None + if cmap_dict is not None and name in cmap_dict: + cmap = cmap_dict[name] + plot_image(image, axes=row_axes, show=False, cmap=cmap) + for axis, axis_name in zip(row_axes, axes_names): + axis.set_title(f'{name} ({axis_name})') + plt.tight_layout() + plt.show() + + +def color_labels(arrays, cmap_dict): + results = [] + for array in arrays: + si, sj = array.shape + rgb = np.zeros((si, sj, 3), dtype=np.uint8) + for label, value in cmap_dict.items(): + rgb[array == label] = value + results.append(rgb) + return results