Skip to content

Commit

Permalink
Add some visualization support
Browse files Browse the repository at this point in the history
  • Loading branch information
fepegar committed Oct 15, 2020
1 parent 66b7a26 commit f04b343
Show file tree
Hide file tree
Showing 5 changed files with 281 additions and 8 deletions.
3 changes: 3 additions & 0 deletions setup.py
Expand Up @@ -46,6 +46,9 @@
'torchio-transform=torchio.cli:apply_transform',
],
},
extras_require={
'plot': ['matplotlib', 'seaborn'],
},
install_requires=requirements,
license='MIT license',
long_description=readme + '\n\n' + history,
Expand Down
28 changes: 20 additions & 8 deletions torchio/data/image.py
Expand Up @@ -173,19 +173,19 @@ def __copy__(self):
return self.__class__(**kwargs)

@property
def data(self):
def data(self) -> torch.Tensor:
return self[DATA]

@property
def tensor(self):
def tensor(self) -> torch.Tensor:
return self.data

@property
def affine(self):
def affine(self) -> np.ndarray:
return self[AFFINE]

@property
def type(self):
def type(self) -> str:
return self[TYPE]

@property
Expand All @@ -196,7 +196,7 @@ def shape(self) -> Tuple[int, int, int, int]:
def spatial_shape(self) -> TypeTripletInt:
return self.shape[1:]

def check_is_2d(self):
def check_is_2d(self) -> None:
if not self.is_2d():
message = f'Image is not 2D. Spatial shape: {self.spatial_shape}'
raise RuntimeError(message)
Expand All @@ -212,18 +212,26 @@ def width(self) -> int:
return self.spatial_shape[0]

@property
def orientation(self):
def orientation(self) -> Tuple[str, str, str]:
return nib.aff2axcodes(self.affine)

@property
def spacing(self):
def spacing(self) -> Tuple[float, float, float]:
_, spacing = get_rotation_and_spacing_from_affine(self.affine)
return tuple(spacing)

@property
def memory(self):
def memory(self) -> float:
return np.prod(self.shape) * 4 # float32, i.e. 4 bytes per voxel

@property
def bounds(self) -> np.ndarray:
ini = 0, 0, 0
fin = np.array(self.spatial_shape) - 1
point_ini = nib.affines.apply_affine(self.affine, ini)
point_fin = nib.affines.apply_affine(self.affine, fin)
return np.array((point_ini, point_fin))

def axis_name_to_index(self, axis: str):
"""Convert an axis name to an axis index.
Expand Down Expand Up @@ -439,6 +447,10 @@ def get_center(self, lps: bool = False) -> TypeTripletFloat:
def set_check_nans(self, check_nans: bool):
self.check_nans = check_nans

def plot(self, **kwargs) -> None:
from ..visualization import plot_image # avoid circular import
plot_image(self, **kwargs)

def crop(self, index_ini: TypeTripletInt, index_fin: TypeTripletInt):
new_origin = nib.affines.apply_affine(self.affine, index_ini)
new_affine = self.affine.copy()
Expand Down
7 changes: 7 additions & 0 deletions torchio/data/subject.py
Expand Up @@ -69,6 +69,9 @@ def __copy__(self):
new.history = self.history[:]
return new

def __len__(self):
return len(self.get_images(intensity_only=False))

@staticmethod
def _parse_images(images: List[Tuple[str, Image]]) -> None:
# Check that it's not empty
Expand Down Expand Up @@ -174,3 +177,7 @@ def add_image(self, image: Image, image_name: str) -> None:

def remove_image(self, image_name: str) -> None:
del self[image_name]

def plot(self, **kwargs) -> None:
from ..visualization import plot_subject # avoid circular import
plot_subject(self, **kwargs)
166 changes: 166 additions & 0 deletions torchio/datasets/fpg.py
Expand Up @@ -40,3 +40,169 @@ def __init__(self):
),
}
super().__init__(subject_dict)
self.gif_colors = GIF_COLORS


GIF_COLORS = {
0: (0, 0, 0),
1: (0, 0, 0),
5: (127, 255, 212),
12: (240, 230, 140),
16: (176, 48, 96),
24: (48, 176, 96),
31: (48, 176, 96),
32: (103, 255, 255),
33: (103, 255, 255),
35: (238, 186, 243),
36: (119, 159, 176),
37: (122, 186, 220),
38: (122, 186, 220),
39: (96, 204, 96),
40: (96, 204, 96),
41: (220, 247, 164),
42: (220, 247, 164),
43: (205, 62, 78),
44: (205, 62, 78),
45: (225, 225, 225),
46: (225, 225, 225),
47: (60, 60, 60),
48: (220, 216, 20),
49: (220, 216, 20),
50: (196, 58, 250),
51: (196, 58, 250),
52: (120, 18, 134),
53: (120, 18, 134),
54: (255, 165, 0),
55: (255, 165, 0),
56: (12, 48, 255),
57: (12, 48, 225),
58: (236, 13, 176),
59: (236, 13, 176),
60: (0, 118, 14),
61: (0, 118, 14),
62: (165, 42, 42),
63: (165, 42, 42),
64: (160, 32, 240),
65: (160, 32, 240),
66: (56, 192, 255),
67: (56, 192, 255),
70: (255, 225, 225),
72: (184, 237, 194),
73: (180, 231, 250),
74: (225, 183, 231),
76: (180, 180, 180),
77: (180, 180, 180),
81: (245, 255, 200),
82: (255, 230, 255),
83: (245, 245, 245),
84: (220, 255, 220),
85: (220, 220, 220),
86: (200, 255, 255),
87: (250, 220, 200),
89: (245, 255, 200),
90: (255, 230, 255),
91: (245, 245, 245),
92: (220, 255, 220),
93: (220, 220, 220),
94: (200, 255, 255),
96: (140, 125, 255),
97: (140, 125, 255),
101: (255, 62, 150),
102: (255, 62, 150),
103: (160, 82, 45),
104: (160, 82, 45),
105: (165, 42, 42),
106: (165, 42, 42),
107: (205, 91, 69),
108: (205, 91, 69),
109: (100, 149, 237),
110: (100, 149, 237),
113: (135, 206, 235),
114: (135, 206, 235),
115: (250, 128, 114),
116: (250, 128, 114),
117: (255, 255, 0),
118: (255, 255, 0),
119: (221, 160, 221),
120: (221, 160, 221),
121: (0, 238, 0),
122: (0, 238, 0),
123: (205, 92, 92),
124: (205, 92, 92),
125: (176, 48, 96),
126: (176, 48, 96),
129: (152, 251, 152),
130: (152, 251, 152),
133: (50, 205, 50),
134: (50, 205, 50),
135: (0, 100, 0),
136: (0, 100, 0),
137: (173, 216, 230),
138: (173, 216, 230),
139: (153, 50, 204),
140: (153, 50, 204),
141: (160, 32, 240),
142: (160, 32, 240),
143: (0, 206, 208),
144: (0, 206, 208),
145: (51, 50, 135),
146: (51, 50, 135),
147: (135, 50, 74),
148: (135, 50, 74),
149: (218, 112, 214),
150: (218, 112, 214),
151: (240, 230, 140),
152: (240, 230, 140),
153: (255, 255, 0),
154: (255, 255, 0),
155: (255, 110, 180),
156: (255, 110, 180),
157: (0, 255, 255),
158: (0, 255, 255),
161: (100, 50, 100),
162: (100, 50, 100),
163: (178, 34, 34),
164: (178, 34, 34),
165: (255, 0, 255),
166: (255, 0, 255),
167: (39, 64, 139),
168: (39, 64, 139),
169: (255, 99, 71),
170: (255, 99, 71),
171: (255, 69, 0),
172: (255, 69, 0),
173: (210, 180, 140),
174: (210, 180, 140),
175: (0, 255, 127),
176: (0, 255, 127),
177: (74, 155, 60),
178: (74, 155, 60),
179: (255, 215, 0),
180: (255, 215, 0),
181: (238, 0, 0),
182: (238, 0, 0),
183: (46, 139, 87),
184: (46, 139, 87),
185: (238, 201, 0),
186: (238, 201, 0),
187: (102, 205, 170),
188: (102, 205, 170),
191: (255, 218, 185),
192: (255, 218, 185),
193: (238, 130, 238),
194: (238, 130, 238),
195: (255, 165, 0),
196: (255, 165, 0),
197: (255, 192, 203),
198: (255, 192, 203),
199: (244, 222, 179),
200: (244, 222, 179),
201: (208, 32, 144),
202: (208, 32, 144),
203: (34, 139, 34),
204: (34, 139, 34),
205: (125, 255, 212),
206: (127, 255, 212),
207: (0, 0, 128),
208: (0, 0, 128),
}
85 changes: 85 additions & 0 deletions torchio/visualization.py
@@ -0,0 +1,85 @@
import numpy as np

from .data.image import Image, LabelMap
from .data.subject import Subject
from .transforms.preprocessing.spatial.to_canonical import ToCanonical


def import_pyplot():
try:
import matplotlib.pyplot as plt
except ImportError as e:
raise ImportError('Install matplotlib for plotting support') from e
return plt


def rotate(image):
return np.rot90(image)


def plot_image(
image: Image,
channel=0,
axes=None,
show=True,
cmap=None,
):
plt = import_pyplot()
if axes is None:
_, axes = plt.subplots(1, 3)
image = ToCanonical()(image)
data = image.data[channel]
indices = np.array(data.shape) // 2
i, j, k = indices
slice_x = rotate(data[i, :, :])
slice_y = rotate(data[:, j, :])
slice_z = rotate(data[:, :, k])
kwargs = {}
is_label = isinstance(image, LabelMap)
if isinstance(cmap, dict):
slices = slice_x, slice_y, slice_z
slice_x, slice_y, slice_z = color_labels(slices, cmap)
else:
if cmap is None:
cmap = 'inferno' if is_label else 'gray'
kwargs['cmap'] = cmap
if is_label:
kwargs['interpolation'] = 'none'
x_extent, y_extent, z_extent = [tuple(b) for b in image.bounds.T]
axes[0].imshow(slice_x, extent=y_extent + z_extent, **kwargs)
axes[1].imshow(slice_y, extent=x_extent + z_extent, **kwargs)
axes[2].imshow(slice_z, extent=x_extent + y_extent, **kwargs)
plt.tight_layout()
if show:
plt.show()


def plot_subject(
subject: Subject,
cmap_dict=None,
):
plt = import_pyplot()
_, axes = plt.subplots(len(subject), 3)
iterable = enumerate(subject.get_images_dict(intensity_only=False).items())
axes_names = 'sagittal', 'coronal', 'axial'
for row, (name, image) in iterable:
row_axes = axes[row]
cmap = None
if cmap_dict is not None and name in cmap_dict:
cmap = cmap_dict[name]
plot_image(image, axes=row_axes, show=False, cmap=cmap)
for axis, axis_name in zip(row_axes, axes_names):
axis.set_title(f'{name} ({axis_name})')
plt.tight_layout()
plt.show()


def color_labels(arrays, cmap_dict):
results = []
for array in arrays:
si, sj = array.shape
rgb = np.zeros((si, sj, 3), dtype=np.uint8)
for label, value in cmap_dict.items():
rgb[array == label] = value
results.append(rgb)
return results

0 comments on commit f04b343

Please sign in to comment.