Skip to content

Commit

Permalink
Merge 87496d0 into 417c33f
Browse files Browse the repository at this point in the history
  • Loading branch information
fepegar committed Feb 10, 2021
2 parents 417c33f + 87496d0 commit 15a9e31
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
5 changes: 3 additions & 2 deletions tests/data/sampler/test_label_sampler.py
Expand Up @@ -19,9 +19,10 @@ def test_label_probabilities(self):
)
subject = tio.SubjectsDataset([subject])[0]
probs_dict = {0: 0, 1: 50, 2: 25, 3: 25}
sampler = tio.LabelSampler(5, 'label', label_probabilities=probs_dict)
patch_size = (1, 1, 5)
sampler = tio.LabelSampler(patch_size, label_probabilities=probs_dict)
probabilities = sampler.get_probability_map(subject)
fixture = torch.Tensor((0, 0, 2 / 12, 2 / 12, 3 / 12, 2 / 12, 0))
fixture = torch.Tensor((0, 0, 1 / 4, 1 / 4, 1 / 4, 0, 0))
assert torch.all(probabilities.squeeze().eq(fixture))

def test_inconsistent_shape(self):
Expand Down
35 changes: 31 additions & 4 deletions torchio/data/sampler/label.py
@@ -1,6 +1,7 @@
from typing import Dict, Optional

import torch
import numpy as np

from ...data.image import LabelMap
from ...data.subject import Subject
Expand Down Expand Up @@ -39,8 +40,12 @@ class will be sampled. Probabilities do not need to be normalized.
>>> subject = tio.datasets.Colin27()
>>> subject
Colin27(Keys: ('t1', 'head', 'brain'); images: 3)
>>> subject = tio.SubjectsDataset([subject])[0]
>>> sampler = tio.data.LabelSampler(64, 'brain')
>>> probabilities = {0: 0.5, 1: 0.5}
>>> sampler = tio.data.LabelSampler(
... patch_size=64,
... label_name='brain',
... label_probabilities=probabilities,
... )
>>> generator = sampler(subject)
>>> for patch in generator:
... print(patch.shape)
Expand Down Expand Up @@ -85,23 +90,38 @@ def get_probability_map_image(self, subject: Subject) -> LabelMap:
return label_map

def get_probability_map(self, subject: Subject) -> torch.Tensor:
label_map_tensor = self.get_probability_map_image(subject).data
label_map_tensor = label_map_tensor.float()
label_map_tensor = self.get_probability_map_image(subject).data.float()

if self.label_probabilities_dict is None:
return label_map_tensor > 0
probability_map = self.get_probabilities_from_label_map(
label_map_tensor,
self.label_probabilities_dict,
self.patch_size,
)
return probability_map

@staticmethod
def get_probabilities_from_label_map(
label_map: torch.Tensor,
label_probabilities_dict: Dict[int, float],
patch_size: np.ndarray,
) -> torch.Tensor:
"""Create probability map according to label map probabilities."""
patch_size = patch_size.astype(int)
ini_i, ini_j, ini_k = patch_size // 2
spatial_shape = np.array(label_map.shape[1:])
if np.any(patch_size > spatial_shape):
message = (
f'Patch size {patch_size}'
f'larger than label map {spatial_shape}'
)
raise RuntimeError(message)
crop_fin_i, crop_fin_j, crop_fin_k = crop_fin = (patch_size - 1) // 2
fin_i, fin_j, fin_k = spatial_shape - crop_fin
# See https://github.com/fepegar/torchio/issues/458
label_map = label_map[:, ini_i:fin_i, ini_j:fin_j, ini_k:fin_k]

multichannel = label_map.shape[0] > 1
probability_map = torch.zeros_like(label_map)
label_probs = torch.Tensor(list(label_probabilities_dict.values()))
Expand All @@ -122,4 +142,11 @@ def get_probabilities_from_label_map(
probability_map[mask] = prob_voxels
if multichannel:
probability_map = probability_map.sum(dim=0, keepdim=True)

# See https://github.com/fepegar/torchio/issues/458
padding = ini_k, crop_fin_k, ini_j, crop_fin_j, ini_i, crop_fin_i
probability_map = torch.nn.functional.pad(
probability_map,
padding,
)
return probability_map

0 comments on commit 15a9e31

Please sign in to comment.