Skip to content

Commit

Permalink
Replace constant with attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
fepegar committed Dec 14, 2020
1 parent ea441f8 commit 48e8a7d
Show file tree
Hide file tree
Showing 12 changed files with 48 additions and 53 deletions.
45 changes: 22 additions & 23 deletions tests/transforms/augmentation/test_random_labels_to_image.py
@@ -1,5 +1,4 @@
from torchio.transforms import RandomLabelsToImage
from torchio import DATA
from ...utils import TorchioTestCase


Expand All @@ -23,12 +22,12 @@ def test_deterministic_simulation(self):
)
transformed = transform(self.sample_subject)
self.assertTensorEqual(
transformed['image_from_labels'][DATA] == 0.5,
self.sample_subject['label'][DATA] == 0
transformed['image_from_labels'].data == 0.5,
self.sample_subject['label'].data == 0
)
self.assertTensorEqual(
transformed['image_from_labels'][DATA] == 2,
self.sample_subject['label'][DATA] == 1
transformed['image_from_labels'].data == 2,
self.sample_subject['label'].data == 1
)

def test_deterministic_simulation_with_discretized_label_map(self):
Expand All @@ -43,12 +42,12 @@ def test_deterministic_simulation_with_discretized_label_map(self):
)
transformed = transform(self.sample_subject)
self.assertTensorEqual(
transformed['image_from_labels'][DATA] == 0.5,
self.sample_subject['label'][DATA] == 0
transformed['image_from_labels'].data == 0.5,
self.sample_subject['label'].data == 0
)
self.assertTensorEqual(
transformed['image_from_labels'][DATA] == 2,
self.sample_subject['label'][DATA] == 1
transformed['image_from_labels'].data == 2,
self.sample_subject['label'].data == 1
)

def test_deterministic_simulation_with_pv_map(self):
Expand All @@ -62,11 +61,11 @@ def test_deterministic_simulation_with_pv_map(self):
)
transformed = transform(subject)
self.assertTensorAlmostEqual(
transformed['image_from_labels'][DATA][0],
subject['label'][DATA][0] * 0.5 + subject['label'][DATA][1] * 1
transformed['image_from_labels'].data[0],
subject['label'].data[0] * 0.5 + subject['label'].data[1] * 1
)
self.assertEqual(
transformed['image_from_labels'][DATA].shape,
transformed['image_from_labels'].data.shape,
(1, 10, 20, 30)
)

Expand All @@ -83,8 +82,8 @@ def test_deterministic_simulation_with_discretized_pv_map(self):
)
transformed = transform(subject)
self.assertTensorAlmostEqual(
transformed['image_from_labels'][DATA],
(subject['label'][DATA] > 0) * 0.5
transformed['image_from_labels'].data,
(subject['label'].data > 0) * 0.5
)

def test_filling(self):
Expand All @@ -96,11 +95,11 @@ def test_filling(self):
image_key='t1',
used_labels=[1]
)
t1_indices = self.sample_subject['label'][DATA] == 0
t1_indices = self.sample_subject['label'].data == 0
transformed = transform(self.sample_subject)
self.assertTensorAlmostEqual(
transformed['t1'][DATA][t1_indices],
self.sample_subject['t1'][DATA][t1_indices]
transformed['t1'].data[t1_indices],
self.sample_subject['t1'].data[t1_indices]
)

def test_filling_with_discretized_label_map(self):
Expand All @@ -113,11 +112,11 @@ def test_filling_with_discretized_label_map(self):
discretize=True,
used_labels=[1]
)
t1_indices = self.sample_subject['label'][DATA] < 0.5
t1_indices = self.sample_subject['label'].data < 0.5
transformed = transform(self.sample_subject)
self.assertTensorAlmostEqual(
transformed['t1'][DATA][t1_indices],
self.sample_subject['t1'][DATA][t1_indices]
transformed['t1'].data[t1_indices],
self.sample_subject['t1'].data[t1_indices]
)

def test_filling_with_discretized_pv_label_map(self):
Expand All @@ -131,11 +130,11 @@ def test_filling_with_discretized_pv_label_map(self):
discretize=True,
used_labels=[1]
)
t1_indices = subject['label'][DATA].argmax(dim=0) == 0
t1_indices = subject['label'].data.argmax(dim=0) == 0
transformed = transform(subject)
self.assertTensorAlmostEqual(
transformed['t1'][DATA][0][t1_indices],
subject['t1'][DATA][0][t1_indices]
transformed['t1'].data[0][t1_indices],
subject['t1'].data[0][t1_indices]
)

def test_filling_without_any_hole(self):
Expand Down
14 changes: 7 additions & 7 deletions tests/transforms/test_lambda_transform.py
@@ -1,5 +1,5 @@
import torch
from torchio import DATA, LABEL
from torchio import LABEL
from torchio.transforms import Lambda
from ..utils import TorchioTestCase

Expand All @@ -26,18 +26,18 @@ def test_lambda(self):
transform = Lambda(lambda x: x + 1)
transformed = transform(self.sample_subject)
assert torch.all(torch.eq(
transformed['t1'][DATA], self.sample_subject['t1'][DATA] + 1))
transformed.t1.data, self.sample_subject.t1.data + 1))
assert torch.all(torch.eq(
transformed['t2'][DATA], self.sample_subject['t2'][DATA] + 1))
transformed.t2.data, self.sample_subject.t2.data + 1))
assert torch.all(torch.eq(
transformed['label'][DATA], self.sample_subject['label'][DATA] + 1))
transformed.label.data, self.sample_subject.label.data + 1))

def test_image_types(self):
transform = Lambda(lambda x: x + 1, types_to_apply=[LABEL])
transformed = transform(self.sample_subject)
assert torch.all(torch.eq(
transformed['t1'][DATA], self.sample_subject['t1'][DATA]))
transformed.t1.data, self.sample_subject.t1.data))
assert torch.all(torch.eq(
transformed['t2'][DATA], self.sample_subject['t2'][DATA]))
transformed.t2.data, self.sample_subject.t2.data))
assert torch.all(torch.eq(
transformed['label'][DATA], self.sample_subject['label'][DATA] + 1))
transformed.label.data, self.sample_subject.label.data + 1))
6 changes: 3 additions & 3 deletions torchio/data/image.py
Expand Up @@ -151,7 +151,7 @@ def __getitem__(self, item):
return super().__getitem__(item)

def __array__(self):
return self[DATA].numpy()
return self.data.numpy()

def __copy__(self):
kwargs = dict(
Expand Down Expand Up @@ -434,7 +434,7 @@ def save(self, path: TypePath, squeeze: bool = True) -> None:
before saving.
"""
write_image(
self[DATA],
self.data,
self.affine,
path,
squeeze=squeeze,
Expand All @@ -449,7 +449,7 @@ def numpy(self) -> np.ndarray:

def as_sitk(self, **kwargs) -> sitk.Image:
"""Get the image as an instance of :class:`sitk.Image`."""
return nib_to_sitk(self[DATA], self.affine, **kwargs)
return nib_to_sitk(self.data, self.affine, **kwargs)

def as_pil(self) -> ImagePIL:
"""Get the image as an instance of :class:`PIL.Image`."""
Expand Down
6 changes: 3 additions & 3 deletions torchio/data/sampler/label.py
Expand Up @@ -4,7 +4,7 @@

from ...data.subject import Subject
from ...typing import TypePatchSize
from ...constants import DATA, TYPE, LABEL
from ...constants import TYPE, LABEL
from .weighted import WeightedSampler


Expand Down Expand Up @@ -64,10 +64,10 @@ def get_probability_map(self, subject: Subject) -> torch.Tensor:
if self.probability_map_name is None:
for image in subject.get_images(intensity_only=False):
if image[TYPE] == LABEL:
label_map_tensor = image[DATA]
label_map_tensor = image.data
break
elif self.probability_map_name in subject:
label_map_tensor = subject[self.probability_map_name][DATA]
label_map_tensor = subject[self.probability_map_name].data
else:
message = (
f'Image "{self.probability_map_name}"'
Expand Down
4 changes: 2 additions & 2 deletions torchio/datasets/mni/colin.py
@@ -1,6 +1,6 @@
import urllib.parse
from ...utils import get_torchio_cache_dir, download_and_extract_archive
from ... import ScalarImage, LabelMap, DATA
from ... import ScalarImage, LabelMap
from .mni import SubjectMNI


Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(self, version=1998):
if version == 2008:
path = download_root / 'colin27_cls_tal_hires.nii'
cls_image = LabelMap(path)
cls_image.data = cls_image[DATA].round().byte()
cls_image.data = cls_image.data.round().byte()
cls_image.save(path)

if version == 1998:
Expand Down
4 changes: 2 additions & 2 deletions torchio/datasets/mni/icbm.py
Expand Up @@ -5,7 +5,7 @@
compress,
download_and_extract_archive,
)
from ... import ScalarImage, LabelMap, DATA
from ... import ScalarImage, LabelMap
from .mni import SubjectMNI


Expand Down Expand Up @@ -58,7 +58,7 @@ def __init__(self, load_4d_tissues: bool = True):
gm = LabelMap(f'{p}_gm_{m}.nii')
wm = LabelMap(f'{p}_wm_{m}.nii')
csf = LabelMap(f'{p}_csf_{m}.nii')
gm.data = torch.cat((gm[DATA], wm[DATA], csf[DATA]))
gm.data = torch.cat((gm.data, wm.data, csf.data))
gm.save(tissues_path)

for fp in files_dir.glob('*.nii'):
Expand Down
Expand Up @@ -4,7 +4,6 @@
import torch
import numpy as np

from ....constants import DATA
from ....typing import TypeData
from ....data.subject import Subject
from ... import IntensityTransform
Expand Down Expand Up @@ -110,7 +109,7 @@ def apply_transform(self, subject: Subject) -> Subject:
image.data, order, coefficients)
if self.invert_transform:
np.divide(1, bias_field, out=bias_field)
image.data = image[DATA] * torch.from_numpy(bias_field)
image.data = image.data * torch.from_numpy(bias_field)
return subject

@staticmethod
Expand Down
Expand Up @@ -2,7 +2,6 @@

import torch

from ....constants import DATA
from ....utils import check_sequence
from ....data.subject import Subject
from ....typing import TypeData, TypeRangeFloat
Expand Down Expand Up @@ -186,7 +185,7 @@ def apply_transform(self, subject: Subject) -> Subject:
'discretize': self.discretize,
}

label_map = subject[self.label_key][DATA]
label_map = subject[self.label_key].data

# Find out if we face a partial-volume image or a label map.
# One-hot-encoded label map is considered as a partial-volume image
Expand Down Expand Up @@ -353,7 +352,7 @@ def apply_transform(self, subject: Subject) -> Subject:
bg_mask = label_map == -1
else:
bg_mask = label_map.sum(dim=0, keepdim=True) < 0.5
final_image[DATA][bg_mask] = original_image[DATA][bg_mask].float()
final_image.data[bg_mask] = original_image.data[bg_mask].float()

subject.add_image(final_image, self.image_key)
return subject
Expand Down
5 changes: 2 additions & 3 deletions torchio/transforms/augmentation/intensity/random_noise.py
Expand Up @@ -2,7 +2,6 @@
from typing import Tuple, Union, Dict, Sequence

import torch
from ....constants import DATA
from ....data.subject import Subject
from ... import IntensityTransform
from .. import RandomTransform
Expand Down Expand Up @@ -93,10 +92,10 @@ def apply_transform(self, subject: Subject) -> Subject:
if self.arguments_are_dict():
mean, std, seed = [arg[name] for arg in args]
with self._use_seed(seed):
noise = get_noise(image[DATA], mean, std)
noise = get_noise(image.data, mean, std)
if self.invert_transform:
noise *= -1
image.data = image[DATA] + noise
image.data = image.data + noise
return subject


Expand Down
3 changes: 1 addition & 2 deletions torchio/transforms/augmentation/intensity/random_spike.py
Expand Up @@ -4,7 +4,6 @@
import torch
import numpy as np

from ....constants import DATA
from ....data.subject import Subject
from ... import IntensityTransform, FourierTransform
from .. import RandomTransform
Expand Down Expand Up @@ -111,7 +110,7 @@ def apply_transform(self, subject: Subject) -> Subject:
spikes_positions = self.spikes_positions[image_name]
intensity = self.intensity[image_name]
transformed_tensors = []
for channel in image[DATA]:
for channel in image.data:
transformed_tensor = self.add_artifact(
channel,
spikes_positions,
Expand Down
2 changes: 1 addition & 1 deletion torchio/transforms/augmentation/spatial/random_affine.py
Expand Up @@ -272,7 +272,7 @@ def apply_transform(self, subject: Subject) -> Subject:
center = None

transformed_tensors = []
for tensor in image[DATA]:
for tensor in image.data:
transformed_tensor = self.apply_affine_transform(
tensor,
image.affine,
Expand Down
4 changes: 2 additions & 2 deletions torchio/transforms/lambda_transform.py
Expand Up @@ -2,7 +2,7 @@
import torch
from ..typing import TypeCallable
from ..data.subject import Subject
from ..constants import DATA, TYPE
from ..constants import TYPE
from .transform import Transform


Expand Down Expand Up @@ -44,7 +44,7 @@ def apply_transform(self, subject: Subject) -> Subject:
if image_type not in self.types_to_apply:
continue

function_arg = image[DATA]
function_arg = image.data
result = self.function(function_arg)
if not isinstance(result, torch.Tensor):
message = (
Expand Down

0 comments on commit 48e8a7d

Please sign in to comment.