Skip to content

Commit

Permalink
Fix random swap per channel
Browse files Browse the repository at this point in the history
  • Loading branch information
fepegar committed Aug 17, 2020
1 parent edc2a84 commit 27caf58
Showing 1 changed file with 11 additions and 13 deletions.
24 changes: 11 additions & 13 deletions torchio/transforms/augmentation/intensity/random_swap.py
Expand Up @@ -53,10 +53,8 @@ def get_params():
def apply_transform(self, sample: Subject) -> dict:
for image in self.get_images(sample):
tensors = []
for tensor in image[DATA]:
tensor = swap(tensor, self.patch_size, self.num_iterations)
tensors.append(tensor)
image[DATA] = torch.stack(tensors)
tensor = image[DATA]
image[DATA] = swap(tensor, self.patch_size, self.num_iterations)
return sample


Expand All @@ -69,12 +67,12 @@ def swap(
patch_size = to_tuple(patch_size)
for _ in range(num_iterations):
first_ini, first_fin = get_random_indices_from_shape(
tensor.shape,
tensor.shape[-3:],
patch_size,
)
while True:
second_ini, second_fin = get_random_indices_from_shape(
tensor.shape,
tensor.shape[-3:],
patch_size,
)
larger_than_initial = np.all(second_ini >= first_ini)
Expand All @@ -91,10 +89,10 @@ def swap(


def insert(tensor: TypeData, patch: TypeData, index_ini: np.ndarray) -> None:
index_fin = index_ini + np.array(patch.shape)
index_fin = index_ini + np.array(patch.shape[-3:])
i_ini, j_ini, k_ini = index_ini
i_fin, j_fin, k_fin = index_fin
tensor[i_ini:i_fin, j_ini:j_fin, k_ini:k_fin] = patch
tensor[:, i_ini:i_fin, j_ini:j_fin, k_ini:k_fin] = patch


def crop(
Expand All @@ -104,20 +102,20 @@ def crop(
) -> Union[np.ndarray, torch.Tensor]:
i_ini, j_ini, k_ini = index_ini
i_fin, j_fin, k_fin = index_fin
return image[..., i_ini:i_fin, j_ini:j_fin, k_ini:k_fin]
return image[:, i_ini:i_fin, j_ini:j_fin, k_ini:k_fin]


def get_random_indices_from_shape(
shape: TypeTripletInt,
spatial_shape: TypeTripletInt,
patch_size: TypeTripletInt,
) -> Tuple[np.ndarray, np.ndarray]:
shape_array = np.array(shape)
shape_array = np.array(spatial_shape)
patch_size_array = np.array(patch_size)
max_index_ini = shape_array - patch_size_array
if (max_index_ini < 0).any():
message = (
f'Patch size {patch_size} must not be'
f' larger than image size {shape}'
f'Patch size {patch_size} cannot be'
f' larger than image spatial shape {spatial_shape}'
)
raise ValueError(message)
max_index_ini = max_index_ini.astype(np.uint16)
Expand Down

0 comments on commit 27caf58

Please sign in to comment.