From 9f4092e4d136b5b0487996b3c59fd88f31fb84d0 Mon Sep 17 00:00:00 2001 From: Buglakova Date: Sun, 24 Jul 2022 06:44:37 +0200 Subject: [PATCH 1/6] Add class for 3D elastic deformation. --- torch_em/transform/augmentation.py | 64 +++++++++++++++++++++++++++++- 1 file changed, 62 insertions(+), 2 deletions(-) diff --git a/torch_em/transform/augmentation.py b/torch_em/transform/augmentation.py index f4dc6c74..6f1b9087 100644 --- a/torch_em/transform/augmentation.py +++ b/torch_em/transform/augmentation.py @@ -6,7 +6,67 @@ from ..util import ensure_tensor -# TODO RandomElastic3D ? +class RandomElasticDeformation3D(kornia.augmentation.AugmentationBase3D): + def __init__(self, + control_point_spacing=1, + sigma=(32., 32.), + alpha=(4., 4.), + interpolation=kornia.constants.Resample.BILINEAR, + p=0.5, + keepdim=False, + same_on_batch=True): + super().__init__(p=p, # keepdim=keepdim, + same_on_batch=same_on_batch) + if isinstance(control_point_spacing, int): + self.control_point_spacing = [control_point_spacing] * 2 + else: + self.control_point_spacing = control_point_spacing + assert len(self.control_point_spacing) == 2 + self.interpolation = interpolation + self.flags = dict( + interpolation=torch.tensor(self.interpolation.value), + sigma=sigma, + alpha=alpha + ) + + # The same transformation applied to all samples in a batch + def generate_parameters(self, batch_shape): + assert len(batch_shape) == 5 + shape = batch_shape[3:] + control_shape = tuple( + sh // spacing for sh, spacing in zip(shape, self.control_point_spacing) + ) + deformation_fields = [ + np.random.uniform(-1, 1, control_shape), + np.random.uniform(-1, 1, control_shape) + ] + deformation_fields = [ + resize(df, shape, order=3)[None] for df in deformation_fields + ] + noise = np.concatenate(deformation_fields, axis=0)[None].astype('float32') + noise = torch.from_numpy(noise) + return {'noise': noise} + + def __call__(self, input, params=None): + assert(len(input.shape) == 5) + if params is None: + params = self.generate_parameters(input.shape) + self._params = params + + noise = params['noise'] + mode = 'bilinear' if (self.flags['interpolation'] == 1).all() else 'nearest' + noise_ch = noise.expand(input.shape[1], -1, -1, -1) + input_transformed = [] + for i, x in enumerate(torch.unbind(input, dim=0)): + x_transformed = kornia.geometry.transform.elastic_transform2d( + x, noise_ch, sigma=self.flags['sigma'], + alpha=self.flags['alpha'], mode=mode, + padding_mode="reflection" + ) + input_transformed.append(x_transformed) + input_transformed = torch.stack(input_transformed) + return input_transformed + class RandomElasticDeformation(kornia.augmentation.AugmentationBase2D): @@ -129,7 +189,7 @@ def halo(self, shape): "RandomRotation": {"degrees": 90}, "RandomRotation3D": {"degrees": (90, 90, 90)}, "RandomVerticalFlip": {}, - "RandomVerticalFlip3D": {}, + "RandomVerticalFlip3D": {} } From aa849153599b3a77c66238ef332553dee22d2399 Mon Sep 17 00:00:00 2001 From: Buglakova Date: Sun, 24 Jul 2022 06:54:02 +0200 Subject: [PATCH 2/6] Add new augmentation to the dict. --- torch_em/transform/augmentation.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/torch_em/transform/augmentation.py b/torch_em/transform/augmentation.py index 6f1b9087..da263ca6 100644 --- a/torch_em/transform/augmentation.py +++ b/torch_em/transform/augmentation.py @@ -189,7 +189,8 @@ def halo(self, shape): "RandomRotation": {"degrees": 90}, "RandomRotation3D": {"degrees": (90, 90, 90)}, "RandomVerticalFlip": {}, - "RandomVerticalFlip3D": {} + "RandomVerticalFlip3D": {}, + "RandomElasticDeformation3D": {"alpha": [5, 5], "sigma": [30, 30]} } @@ -209,6 +210,14 @@ def halo(self, shape): ] +def create_augmentation(trafo): + assert trafo in dir(kornia.augmentation) or trafo in globals().keys(), f"Transformation {trafo} not defined" + if trafo in dir(kornia.augmentation): + return getattr(kornia.augmentation, trafo)(**AUGMENTATIONS[trafo]) + + return globals()[trafo](**AUGMENTATIONS[trafo]) + + def get_augmentations(ndim=2, transforms=None, dtype=torch.float32): @@ -220,10 +229,7 @@ def get_augmentations(ndim=2, transforms = DEFAULT_3D_AUGMENTATIONS else: transforms = DEFAULT_ANISOTROPIC_AUGMENTATIONS - transforms = [ - getattr(kornia.augmentation, trafo)(**AUGMENTATIONS[trafo]) - for trafo in transforms - ] + transforms = [create_augmentation(trafo) for trafo in transforms] assert all(isinstance(trafo, kornia.augmentation.base._AugmentationBase) for trafo in transforms) From 2219a7b3d989b506098c39221bedde1f2ef90825 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 25 Jul 2022 11:19:30 +0200 Subject: [PATCH 3/6] Fix issues in RandomElasticDeformation --- torch_em/transform/augmentation.py | 38 +++++++++++++++--------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/torch_em/transform/augmentation.py b/torch_em/transform/augmentation.py index da263ca6..94cc6529 100644 --- a/torch_em/transform/augmentation.py +++ b/torch_em/transform/augmentation.py @@ -43,9 +43,9 @@ def generate_parameters(self, batch_shape): deformation_fields = [ resize(df, shape, order=3)[None] for df in deformation_fields ] - noise = np.concatenate(deformation_fields, axis=0)[None].astype('float32') + noise = np.concatenate(deformation_fields, axis=0)[None].astype("float32") noise = torch.from_numpy(noise) - return {'noise': noise} + return {"noise": noise} def __call__(self, input, params=None): assert(len(input.shape) == 5) @@ -53,14 +53,14 @@ def __call__(self, input, params=None): params = self.generate_parameters(input.shape) self._params = params - noise = params['noise'] - mode = 'bilinear' if (self.flags['interpolation'] == 1).all() else 'nearest' + noise = params["noise"] + mode = "bilinear" if (self.flags["interpolation"] == 1).all() else "nearest" noise_ch = noise.expand(input.shape[1], -1, -1, -1) input_transformed = [] for i, x in enumerate(torch.unbind(input, dim=0)): x_transformed = kornia.geometry.transform.elastic_transform2d( - x, noise_ch, sigma=self.flags['sigma'], - alpha=self.flags['alpha'], mode=mode, + x, noise_ch, sigma=self.flags["sigma"], + alpha=self.flags["alpha"], mode=mode, padding_mode="reflection" ) input_transformed.append(x_transformed) @@ -68,7 +68,6 @@ def __call__(self, input, params=None): return input_transformed - class RandomElasticDeformation(kornia.augmentation.AugmentationBase2D): def __init__(self, control_point_spacing=1, @@ -79,8 +78,7 @@ def __init__(self, keepdim=False, same_on_batch=False): super().__init__(p=p, # keepdim=keepdim, - same_on_batch=same_on_batch, - return_transform=False) + same_on_batch=same_on_batch) if isinstance(control_point_spacing, int): self.control_point_spacing = [control_point_spacing] * 2 else: @@ -107,17 +105,19 @@ def generate_parameters(self, batch_shape): deformation_fields = [ resize(df, shape, order=3)[None] for df in deformation_fields ] - noise = np.concatenate(deformation_fields, axis=0)[None].astype('float32') + noise = np.concatenate(deformation_fields, axis=0)[None].astype("float32") noise = torch.from_numpy(noise) - return {'noise': noise} + return {"noise": noise} - def apply_transform(self, input, params): - noise = params['noise'] - mode = 'bilinear' if (self.flags['resample'] == 1).all() else 'nearest' - # NOTE mode is currently only available on my fork, need kornia PR: - # https://github.com/kornia/kornia/pull/883 + def __call__(self, input, params=None): + if params is None: + params = self.generate_parameters(input.shape) + self._params = params + noise = params["noise"] + mode = "bilinear" if (self.flags["resample"] == 1).all() else "nearest" return kornia.geometry.transform.elastic_transform2d( - input, noise, sigma=self.flags['sigma'], alpha=self.flags['alpha'], mode=mode + input, noise, sigma=self.flags["sigma"], alpha=self.flags["alpha"], mode=mode, + padding_mode="reflection" ) @@ -214,9 +214,9 @@ def create_augmentation(trafo): assert trafo in dir(kornia.augmentation) or trafo in globals().keys(), f"Transformation {trafo} not defined" if trafo in dir(kornia.augmentation): return getattr(kornia.augmentation, trafo)(**AUGMENTATIONS[trafo]) - + return globals()[trafo](**AUGMENTATIONS[trafo]) - + def get_augmentations(ndim=2, transforms=None, From 433d1c92d52950b3c03d567f5a548d3ab8b8191e Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 25 Jul 2022 11:19:49 +0200 Subject: [PATCH 4/6] Update check augmentation script --- scripts/augmentation/.gitignore | 1 + scripts/augmentation/check_augmentation.py | 113 ++++++++++++--------- 2 files changed, 67 insertions(+), 47 deletions(-) create mode 100644 scripts/augmentation/.gitignore diff --git a/scripts/augmentation/.gitignore b/scripts/augmentation/.gitignore new file mode 100644 index 00000000..8fce6030 --- /dev/null +++ b/scripts/augmentation/.gitignore @@ -0,0 +1 @@ +data/ diff --git a/scripts/augmentation/check_augmentation.py b/scripts/augmentation/check_augmentation.py index 06e704f7..6b1c77d0 100644 --- a/scripts/augmentation/check_augmentation.py +++ b/scripts/augmentation/check_augmentation.py @@ -1,83 +1,102 @@ -import napari -import numpy as np +import os + import h5py import kornia +import napari +import numpy as np +import torch -from torch_em.transform.augmentation import KorniaAugmentationPipeline, get_augmentations -from torch_em.transform.augmentation import RandomElasticDeformation +import torch_em.transform.augmentation as augmentation +from torch_em.data.datasets.uro_cell import _require_urocell_data -pr = '/g/schwab/hennies/project_segmentation_paper/ds_sbem-6dpf-1-whole/seg_210122_mito/seg_10nm/gt_cubes/gt000/raw_256.h5' -pgt = '/g/schwab/hennies/project_segmentation_paper/ds_sbem-6dpf-1-whole/seg_210122_mito/seg_10nm/gt_cubes/gt000/mito.h5' -bb = np.s_[:32, :128, :128] -with h5py.File(pr, 'r') as f: - raw = f['data'][bb].astype('float32') -with h5py.File(pgt, 'r') as f: - seg = f['data'][bb] +def get_data(): + _require_urocell_data("./data", download=True) + path = "./data/fib1-3-3-0.h5" + assert os.path.exists(path) + bb = np.s_[32:64, 128:256, 128:256] + with h5py.File(path, "r") as f: + raw = f["raw"][bb].astype("float32") + seg = f["labels/mito"][bb] + return raw, seg def check_kornia_augmentation(): - rot = kornia.augmentation.RandomRotation( - degrees=90., p=1. - ) + raw, seg = get_data() - trafo = KorniaAugmentationPipeline( - rot - ) + rot = kornia.augmentation.RandomRotation(degrees=90.0, p=1.0) + trafo = augmentation.KorniaAugmentationPipeline(rot) transformed_raw, transformed_seg = trafo(raw, seg) transformed_raw = transformed_raw.numpy().squeeze() transformed_seg = transformed_seg.numpy().squeeze() - with napari.gui_qt(): - viewer = napari.Viewer() - viewer.add_image(raw) - viewer.add_image(transformed_raw) - viewer.add_labels(seg) - viewer.add_labels(transformed_seg) + viewer = napari.Viewer() + viewer.add_image(raw) + viewer.add_image(transformed_raw) + viewer.add_labels(seg) + viewer.add_labels(transformed_seg) + napari.run() def check_default_augmentation(): - trafo = get_augmentations() + raw, seg = get_data() + + trafo = augmentation.get_augmentations() transformed_raw, transformed_seg = trafo(raw, seg) transformed_raw = transformed_raw.numpy().squeeze() - transformed_seg = transformed_seg.numpy().squeeze() + transformed_seg = transformed_seg.numpy().squeeze().astype("uint32") - with napari.gui_qt(): - viewer = napari.Viewer() - viewer.add_image(raw) - viewer.add_image(transformed_raw) - viewer.add_labels(seg) - viewer.add_labels(transformed_seg) + viewer = napari.Viewer() + viewer.add_image(raw) + viewer.add_image(transformed_raw) + viewer.add_labels(seg) + viewer.add_labels(transformed_seg) + napari.run() def check_elastic_2d(): - import torch - raw_ = raw[0] - traw = torch.from_numpy(raw_[None, None]) + raw, seg = get_data() + raw = raw[0] + traw = torch.from_numpy(raw[None, None]) - # trafo = RandomElasticDeformation(alpha=(1., 1.), p=1) - # transformed_raw = trafo(traw) + trafo = augmentation.RandomElasticDeformation(alpha=(1., 1.), p=1) + transformed_raw = trafo(traw) # noise_shape = (1, 2) + raw_.shape # noise = torch.zeros(noise_shape) - amp = 1. / raw.shape[0] - noise = np.concatenate([np.random.uniform(-amp, amp, traw.shape), - np.random.uniform(-amp, amp, traw.shape)], axis=1).astype('float32') - noise = torch.from_numpy(noise) + # amp = 1. / raw.shape[0] + # noise = np.concatenate([np.random.uniform(-amp, amp, traw.shape), + # np.random.uniform(-amp, amp, traw.shape)], axis=1).astype('float32') + # noise = torch.from_numpy(noise) + + # alpha = 1. + # transformed_raw = kornia.geometry.transform.elastic_transform2d(traw, noise, alpha=(alpha, alpha)) + + transformed_raw = transformed_raw.numpy().squeeze() + viewer = napari.Viewer() + viewer.add_image(raw) + viewer.add_image(transformed_raw) + napari.run() + + +def check_elastic_3d(): + raw, seg = get_data() + traw = torch.from_numpy(raw[None, None]) - alpha = 1. - transformed_raw = kornia.geometry.transform.elastic_transform2d(traw, noise, alpha=(alpha, alpha)) + trafo = augmentation.RandomElasticDeformation3D(alpha=(1., 1.), p=1) + transformed_raw = trafo(traw) transformed_raw = transformed_raw.numpy().squeeze() - with napari.gui_qt(): - viewer = napari.Viewer() - viewer.add_image(raw_) - viewer.add_image(transformed_raw) + viewer = napari.Viewer() + viewer.add_image(raw) + viewer.add_image(transformed_raw) + napari.run() -if __name__ == '__main__': +if __name__ == "__main__": # check_kornia_augmentation() # check_default_augmentation() check_elastic_2d() + # check_elastic_3d() From 7e0010d089e0732c3e24a1975a4741775f9e766e Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 25 Jul 2022 11:33:52 +0200 Subject: [PATCH 5/6] Update check_augmentation and some cosmetics --- scripts/augmentation/check_augmentation.py | 39 ++++++++++------------ torch_em/transform/augmentation.py | 13 ++++---- 2 files changed, 23 insertions(+), 29 deletions(-) diff --git a/scripts/augmentation/check_augmentation.py b/scripts/augmentation/check_augmentation.py index 6b1c77d0..647f6be1 100644 --- a/scripts/augmentation/check_augmentation.py +++ b/scripts/augmentation/check_augmentation.py @@ -4,7 +4,6 @@ import kornia import napari import numpy as np -import torch import torch_em.transform.augmentation as augmentation from torch_em.data.datasets.uro_cell import _require_urocell_data @@ -58,45 +57,41 @@ def check_default_augmentation(): def check_elastic_2d(): raw, seg = get_data() - raw = raw[0] - traw = torch.from_numpy(raw[None, None]) - - trafo = augmentation.RandomElasticDeformation(alpha=(1., 1.), p=1) - transformed_raw = trafo(traw) - - # noise_shape = (1, 2) + raw_.shape - # noise = torch.zeros(noise_shape) - # amp = 1. / raw.shape[0] - # noise = np.concatenate([np.random.uniform(-amp, amp, traw.shape), - # np.random.uniform(-amp, amp, traw.shape)], axis=1).astype('float32') - # noise = torch.from_numpy(noise) - - # alpha = 1. - # transformed_raw = kornia.geometry.transform.elastic_transform2d(traw, noise, alpha=(alpha, alpha)) + raw, seg = raw[0], seg[0] + deform = augmentation.RandomElasticDeformation(alpha=(1., 1.), p=1) + trafo = augmentation.KorniaAugmentationPipeline(deform) + transformed_raw, transformed_seg = trafo(raw[None, None], seg[None, None]) transformed_raw = transformed_raw.numpy().squeeze() + transformed_seg = transformed_seg.numpy().squeeze().astype("uint32") + viewer = napari.Viewer() viewer.add_image(raw) viewer.add_image(transformed_raw) + viewer.add_labels(seg) + viewer.add_labels(transformed_seg) napari.run() def check_elastic_3d(): raw, seg = get_data() - traw = torch.from_numpy(raw[None, None]) - - trafo = augmentation.RandomElasticDeformation3D(alpha=(1., 1.), p=1) - transformed_raw = trafo(traw) + deform = augmentation.RandomElasticDeformation3D(alpha=(1., 1.), p=1) + trafo = augmentation.KorniaAugmentationPipeline(deform) + transformed_raw, transformed_seg = trafo(raw[None, None], seg[None, None]) transformed_raw = transformed_raw.numpy().squeeze() + transformed_seg = transformed_seg.numpy().squeeze().astype("uint32") + viewer = napari.Viewer() viewer.add_image(raw) viewer.add_image(transformed_raw) + viewer.add_labels(seg) + viewer.add_labels(transformed_seg) napari.run() if __name__ == "__main__": # check_kornia_augmentation() # check_default_augmentation() - check_elastic_2d() - # check_elastic_3d() + # check_elastic_2d() + check_elastic_3d() diff --git a/torch_em/transform/augmentation.py b/torch_em/transform/augmentation.py index 94cc6529..579c00e5 100644 --- a/torch_em/transform/augmentation.py +++ b/torch_em/transform/augmentation.py @@ -93,7 +93,7 @@ def __init__(self, # TODO do we need special treatment for batches, channels > 1? def generate_parameters(self, batch_shape): - assert len(batch_shape) == 4 + assert len(batch_shape) == 4, f"{len(batch_shape)}" shape = batch_shape[2:] control_shape = tuple( sh // spacing for sh, spacing in zip(shape, self.control_point_spacing) @@ -125,7 +125,7 @@ def __call__(self, input, params=None): # so that we can load a bigger block and cut it away class KorniaAugmentationPipeline(torch.nn.Module): interpolatable_torch_types = [torch.float16, torch.float32, torch.float64] - interpolatable_numpy_types = [np.dtype('float32'), np.dtype('float64')] + interpolatable_numpy_types = [np.dtype("float32"), np.dtype("float64")] def __init__(self, *kornia_augmentations, dtype=torch.float32): super().__init__() @@ -134,7 +134,7 @@ def __init__(self, *kornia_augmentations, dtype=torch.float32): self.halo = self.compute_halo() # for now we only add a halo for the random rotation trafos and - # also don't compute the halo dynamically based on the input shape + # also don"t compute the halo dynamically based on the input shape def compute_halo(self): halo = None for aug in self.augmentations: @@ -151,11 +151,10 @@ def is_interpolatable(self, tensor): return tensor.dtype in self.interpolatable_numpy_types def transform_tensor(self, augmentation, tensor, interpolatable, params=None): - interpolating = 'interpolation' in getattr(augmentation, 'flags', []) + interpolating = "interpolation" in getattr(augmentation, "flags", []) if interpolating: - resampler = kornia.constants.Resample.get('BILINEAR' if interpolatable else 'NEAREST') - augmentation.flags['interpolation'] = torch.tensor(resampler.value) - + resampler = kornia.constants.Resample.get("BILINEAR" if interpolatable else "NEAREST") + augmentation.flags["interpolation"] = torch.tensor(resampler.value) transformed = augmentation(tensor, params) return transformed, augmentation._params From c8b66cb9c98dd29d6ed436f0eca02741101471ab Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 25 Jul 2022 19:50:54 +0200 Subject: [PATCH 6/6] Rename trafo --- scripts/augmentation/check_augmentation.py | 4 ++-- torch_em/transform/augmentation.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/scripts/augmentation/check_augmentation.py b/scripts/augmentation/check_augmentation.py index 647f6be1..43575f7a 100644 --- a/scripts/augmentation/check_augmentation.py +++ b/scripts/augmentation/check_augmentation.py @@ -59,7 +59,7 @@ def check_elastic_2d(): raw, seg = get_data() raw, seg = raw[0], seg[0] - deform = augmentation.RandomElasticDeformation(alpha=(1., 1.), p=1) + deform = augmentation.RandomElasticDeformation(alpha=(1.0, 1.0), p=1) trafo = augmentation.KorniaAugmentationPipeline(deform) transformed_raw, transformed_seg = trafo(raw[None, None], seg[None, None]) transformed_raw = transformed_raw.numpy().squeeze() @@ -76,7 +76,7 @@ def check_elastic_2d(): def check_elastic_3d(): raw, seg = get_data() - deform = augmentation.RandomElasticDeformation3D(alpha=(1., 1.), p=1) + deform = augmentation.RandomElasticDeformationStacked(alpha=(1.0, 1.0), p=1) trafo = augmentation.KorniaAugmentationPipeline(deform) transformed_raw, transformed_seg = trafo(raw[None, None], seg[None, None]) transformed_raw = transformed_raw.numpy().squeeze() diff --git a/torch_em/transform/augmentation.py b/torch_em/transform/augmentation.py index 579c00e5..29eddae4 100644 --- a/torch_em/transform/augmentation.py +++ b/torch_em/transform/augmentation.py @@ -6,11 +6,11 @@ from ..util import ensure_tensor -class RandomElasticDeformation3D(kornia.augmentation.AugmentationBase3D): +class RandomElasticDeformationStacked(kornia.augmentation.AugmentationBase3D): def __init__(self, control_point_spacing=1, - sigma=(32., 32.), - alpha=(4., 4.), + sigma=(32.0, 32.0), + alpha=(4.0, 4.0), interpolation=kornia.constants.Resample.BILINEAR, p=0.5, keepdim=False, @@ -71,8 +71,8 @@ def __call__(self, input, params=None): class RandomElasticDeformation(kornia.augmentation.AugmentationBase2D): def __init__(self, control_point_spacing=1, - sigma=(4., 4.), - alpha=(32., 32.), + sigma=(4.0, 4.0), + alpha=(32.0, 32.0), resample=kornia.constants.Resample.BILINEAR, p=0.5, keepdim=False,