-
Notifications
You must be signed in to change notification settings - Fork 230
/
subject.py
144 lines (123 loc) · 4.73 KB
/
subject.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import copy
import pprint
from typing import Any, Dict, List, Tuple
from ..torchio import TYPE, INTENSITY
from .image import Image
class Subject(dict):
    """Class to store information about the images corresponding to a subject.

    Args:
        *args: If provided, a dictionary of items.
        **kwargs: Items that will be added to the subject sample.

    Example:
        >>> import torchio
        >>> from torchio import Image, Subject
        >>> # One way:
        >>> subject = Subject(
        ...     one_image=Image('path_to_image.nii.gz', type=torchio.INTENSITY),
        ...     a_segmentation=Image('path_to_seg.nii.gz', type=torchio.LABEL),
        ...     age=45,
        ...     name='John Doe',
        ...     hospital='Hospital Juan Negrín',
        ... )
        >>> # If you want to create the mapping before, or have spaces in the keys:
        >>> subject_dict = {
        ...     'one image': Image('path_to_image.nii.gz', type=torchio.INTENSITY),
        ...     'a segmentation': Image('path_to_seg.nii.gz', type=torchio.LABEL),
        ...     'age': 45,
        ...     'name': 'John Doe',
        ...     'hospital': 'Hospital Juan Negrín',
        ... }
        >>> Subject(subject_dict)
    """

    def __init__(self, *args, **kwargs: Any):
        if args:
            if len(args) == 1 and isinstance(args[0], dict):
                # Entries of the positional dict take precedence over keyword
                # arguments that use the same key.
                kwargs.update(args[0])
            else:
                message = (
                    'Only one dictionary as positional argument is allowed')
                raise ValueError(message)
        super().__init__(**kwargs)
        self.images = [
            (k, v) for (k, v) in self.items()
            if isinstance(v, Image)
        ]
        self._parse_images(self.images)
        # Mirror every item as an attribute so that e.g. ``subject.t1`` works.
        # NOTE(review): an item keyed 'images' overwrites the list built above,
        # and an item keyed 'history' is shadowed by the attribute set below;
        # attributes are also not kept in sync if items are changed later.
        self.__dict__.update(self)
        self.history = []

    def __repr__(self):
        string = (
            f'{self.__class__.__name__}'
            f'(Keys: {tuple(self.keys())}; images: {len(self.images)})'
        )
        return string

    @staticmethod
    def _parse_images(images: List[Tuple[str, Image]]) -> None:
        """Validate the ``(name, image)`` pairs found at construction.

        Raises:
            ValueError: If ``images`` is empty.
        """
        if not images:
            raise ValueError('A subject without images cannot be created')

    @property
    def shape(self):
        """Return shape of first image in subject.

        Consistency of shapes across images in the subject is checked first.
        """
        self.check_consistent_shape()
        image = self.get_images(intensity_only=False)[0]
        return image.shape

    @property
    def spatial_shape(self):
        """Return spatial shape of first image in subject.

        This is :attr:`shape` without its first dimension. Consistency of
        shapes across images in the subject is checked first.
        """
        return self.shape[1:]

    @property
    def spacing(self):
        """Return spacing of first image in subject.

        Consistency of shapes (not spacings) across images in the subject is
        checked first.
        """
        self.check_consistent_shape()
        image = self.get_images(intensity_only=False)[0]
        return image.spacing

    def get_images_dict(self, intensity_only=True) -> Dict[str, Image]:
        """Return the :class:`Image` items of the subject, keyed by name.

        Args:
            intensity_only: If ``True``, keep only images whose ``TYPE``
                entry equals ``INTENSITY``; otherwise keep every image.
        """
        return {
            image_name: image
            for image_name, image in self.items()
            if isinstance(image, Image)
            and (not intensity_only or image[TYPE] == INTENSITY)
        }

    def get_images(self, intensity_only=True) -> List[Image]:
        """Return the images of :meth:`get_images_dict` as a list."""
        images_dict = self.get_images_dict(intensity_only=intensity_only)
        return list(images_dict.values())

    def check_consistent_shape(self) -> None:
        """Check that all images in the subject have the same shape.

        Raises:
            ValueError: If more than one distinct shape is found. The message
                lists the shape of every image.
        """
        shapes_dict = {
            image_name: image.shape
            for image_name, image
            in self.get_images_dict(intensity_only=False).items()
        }
        num_unique_shapes = len(set(shapes_dict.values()))
        if num_unique_shapes > 1:
            message = (
                'Images in sample have inconsistent shapes:'
                f'\n{pprint.pformat(shapes_dict)}'
            )
            raise ValueError(message)

    def add_transform(
            self,
            transform: 'Transform',
            parameters_dict: dict,
            ) -> None:
        """Append ``(transform.name, parameters_dict)`` to :attr:`history`."""
        self.history.append((transform.name, parameters_dict))

    def load(self) -> None:
        """Call ``load()`` on every image of the subject, of any type."""
        for image in self.get_images(intensity_only=False):
            image.load()

    def crop(self, index_ini, index_fin):
        """Return a new subject in which every image has been cropped.

        Each image is replaced by ``image.crop(index_ini, index_fin)``. Every
        other value is deep-copied. The transform history of this subject is
        carried over to the new one.
        """
        result_dict = {}
        for key, value in self.items():
            if isinstance(value, Image):
                value = value.crop(index_ini, index_fin)
            else:
                value = copy.deepcopy(value)
            result_dict[key] = value
        new = Subject(result_dict)
        # Subject.__init__ resets history to []. Copy the list so the cropped
        # subject keeps the transforms applied so far, and so later appends to
        # either subject do not affect the other.
        new.history = list(self.history)
        return new