Skip to content

Commit

Permalink
Merge pull request #79 from Buglakova/elastic_3d
Browse files Browse the repository at this point in the history
Elastic 3d
  • Loading branch information
constantinpape committed Jul 25, 2022
2 parents babd160 + c8b66cb commit 88616f4
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 75 deletions.
1 change: 1 addition & 0 deletions scripts/augmentation/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
data/
114 changes: 64 additions & 50 deletions scripts/augmentation/check_augmentation.py
Original file line number Diff line number Diff line change
@@ -1,83 +1,97 @@
import napari
import numpy as np
import os

import h5py
import kornia
import napari
import numpy as np

from torch_em.transform.augmentation import KorniaAugmentationPipeline, get_augmentations
from torch_em.transform.augmentation import RandomElasticDeformation
import torch_em.transform.augmentation as augmentation
from torch_em.data.datasets.uro_cell import _require_urocell_data

pr = '/g/schwab/hennies/project_segmentation_paper/ds_sbem-6dpf-1-whole/seg_210122_mito/seg_10nm/gt_cubes/gt000/raw_256.h5'
pgt = '/g/schwab/hennies/project_segmentation_paper/ds_sbem-6dpf-1-whole/seg_210122_mito/seg_10nm/gt_cubes/gt000/mito.h5'

bb = np.s_[:32, :128, :128]
with h5py.File(pr, 'r') as f:
raw = f['data'][bb].astype('float32')
with h5py.File(pgt, 'r') as f:
seg = f['data'][bb]
def get_data():
_require_urocell_data("./data", download=True)
path = "./data/fib1-3-3-0.h5"
assert os.path.exists(path)
bb = np.s_[32:64, 128:256, 128:256]
with h5py.File(path, "r") as f:
raw = f["raw"][bb].astype("float32")
seg = f["labels/mito"][bb]
return raw, seg


def check_kornia_augmentation():
rot = kornia.augmentation.RandomRotation(
degrees=90., p=1.
)
raw, seg = get_data()

trafo = KorniaAugmentationPipeline(
rot
)
rot = kornia.augmentation.RandomRotation(degrees=90.0, p=1.0)
trafo = augmentation.KorniaAugmentationPipeline(rot)

transformed_raw, transformed_seg = trafo(raw, seg)
transformed_raw = transformed_raw.numpy().squeeze()
transformed_seg = transformed_seg.numpy().squeeze()

with napari.gui_qt():
viewer = napari.Viewer()
viewer.add_image(raw)
viewer.add_image(transformed_raw)
viewer.add_labels(seg)
viewer.add_labels(transformed_seg)
viewer = napari.Viewer()
viewer.add_image(raw)
viewer.add_image(transformed_raw)
viewer.add_labels(seg)
viewer.add_labels(transformed_seg)
napari.run()


def check_default_augmentation():
trafo = get_augmentations()
raw, seg = get_data()

trafo = augmentation.get_augmentations()

transformed_raw, transformed_seg = trafo(raw, seg)
transformed_raw = transformed_raw.numpy().squeeze()
transformed_seg = transformed_seg.numpy().squeeze()
transformed_seg = transformed_seg.numpy().squeeze().astype("uint32")

with napari.gui_qt():
viewer = napari.Viewer()
viewer.add_image(raw)
viewer.add_image(transformed_raw)
viewer.add_labels(seg)
viewer.add_labels(transformed_seg)
viewer = napari.Viewer()
viewer.add_image(raw)
viewer.add_image(transformed_raw)
viewer.add_labels(seg)
viewer.add_labels(transformed_seg)
napari.run()


def check_elastic_2d():
import torch
raw_ = raw[0]
traw = torch.from_numpy(raw_[None, None])
raw, seg = get_data()
raw, seg = raw[0], seg[0]

# trafo = RandomElasticDeformation(alpha=(1., 1.), p=1)
# transformed_raw = trafo(traw)
deform = augmentation.RandomElasticDeformation(alpha=(1.0, 1.0), p=1)
trafo = augmentation.KorniaAugmentationPipeline(deform)
transformed_raw, transformed_seg = trafo(raw[None, None], seg[None, None])
transformed_raw = transformed_raw.numpy().squeeze()
transformed_seg = transformed_seg.numpy().squeeze().astype("uint32")

# noise_shape = (1, 2) + raw_.shape
# noise = torch.zeros(noise_shape)
amp = 1. / raw.shape[0]
noise = np.concatenate([np.random.uniform(-amp, amp, traw.shape),
np.random.uniform(-amp, amp, traw.shape)], axis=1).astype('float32')
noise = torch.from_numpy(noise)
viewer = napari.Viewer()
viewer.add_image(raw)
viewer.add_image(transformed_raw)
viewer.add_labels(seg)
viewer.add_labels(transformed_seg)
napari.run()

alpha = 1.
transformed_raw = kornia.geometry.transform.elastic_transform2d(traw, noise, alpha=(alpha, alpha))

def check_elastic_3d():
raw, seg = get_data()

deform = augmentation.RandomElasticDeformationStacked(alpha=(1.0, 1.0), p=1)
trafo = augmentation.KorniaAugmentationPipeline(deform)
transformed_raw, transformed_seg = trafo(raw[None, None], seg[None, None])
transformed_raw = transformed_raw.numpy().squeeze()
with napari.gui_qt():
viewer = napari.Viewer()
viewer.add_image(raw_)
viewer.add_image(transformed_raw)
transformed_seg = transformed_seg.numpy().squeeze().astype("uint32")

viewer = napari.Viewer()
viewer.add_image(raw)
viewer.add_image(transformed_raw)
viewer.add_labels(seg)
viewer.add_labels(transformed_seg)
napari.run()


if __name__ == '__main__':
if __name__ == "__main__":
# check_kornia_augmentation()
# check_default_augmentation()
check_elastic_2d()
# check_elastic_2d()
check_elastic_3d()
115 changes: 90 additions & 25 deletions torch_em/transform/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,79 @@
from ..util import ensure_tensor


# TODO RandomElastic3D ?
class RandomElasticDeformationStacked(kornia.augmentation.AugmentationBase3D):
def __init__(self,
control_point_spacing=1,
sigma=(32.0, 32.0),
alpha=(4.0, 4.0),
interpolation=kornia.constants.Resample.BILINEAR,
p=0.5,
keepdim=False,
same_on_batch=True):
super().__init__(p=p, # keepdim=keepdim,
same_on_batch=same_on_batch)
if isinstance(control_point_spacing, int):
self.control_point_spacing = [control_point_spacing] * 2
else:
self.control_point_spacing = control_point_spacing
assert len(self.control_point_spacing) == 2
self.interpolation = interpolation
self.flags = dict(
interpolation=torch.tensor(self.interpolation.value),
sigma=sigma,
alpha=alpha
)

# The same transformation applied to all samples in a batch
def generate_parameters(self, batch_shape):
assert len(batch_shape) == 5
shape = batch_shape[3:]
control_shape = tuple(
sh // spacing for sh, spacing in zip(shape, self.control_point_spacing)
)
deformation_fields = [
np.random.uniform(-1, 1, control_shape),
np.random.uniform(-1, 1, control_shape)
]
deformation_fields = [
resize(df, shape, order=3)[None] for df in deformation_fields
]
noise = np.concatenate(deformation_fields, axis=0)[None].astype("float32")
noise = torch.from_numpy(noise)
return {"noise": noise}

def __call__(self, input, params=None):
assert(len(input.shape) == 5)
if params is None:
params = self.generate_parameters(input.shape)
self._params = params

noise = params["noise"]
mode = "bilinear" if (self.flags["interpolation"] == 1).all() else "nearest"
noise_ch = noise.expand(input.shape[1], -1, -1, -1)
input_transformed = []
for i, x in enumerate(torch.unbind(input, dim=0)):
x_transformed = kornia.geometry.transform.elastic_transform2d(
x, noise_ch, sigma=self.flags["sigma"],
alpha=self.flags["alpha"], mode=mode,
padding_mode="reflection"
)
input_transformed.append(x_transformed)
input_transformed = torch.stack(input_transformed)
return input_transformed


class RandomElasticDeformation(kornia.augmentation.AugmentationBase2D):
def __init__(self,
control_point_spacing=1,
sigma=(4., 4.),
alpha=(32., 32.),
sigma=(4.0, 4.0),
alpha=(32.0, 32.0),
resample=kornia.constants.Resample.BILINEAR,
p=0.5,
keepdim=False,
same_on_batch=False):
super().__init__(p=p, # keepdim=keepdim,
same_on_batch=same_on_batch,
return_transform=False)
same_on_batch=same_on_batch)
if isinstance(control_point_spacing, int):
self.control_point_spacing = [control_point_spacing] * 2
else:
Expand All @@ -35,7 +93,7 @@ def __init__(self,

# TODO do we need special treatment for batches, channels > 1?
def generate_parameters(self, batch_shape):
assert len(batch_shape) == 4
assert len(batch_shape) == 4, f"{len(batch_shape)}"
shape = batch_shape[2:]
control_shape = tuple(
sh // spacing for sh, spacing in zip(shape, self.control_point_spacing)
Expand All @@ -47,25 +105,27 @@ def generate_parameters(self, batch_shape):
deformation_fields = [
resize(df, shape, order=3)[None] for df in deformation_fields
]
noise = np.concatenate(deformation_fields, axis=0)[None].astype('float32')
noise = np.concatenate(deformation_fields, axis=0)[None].astype("float32")
noise = torch.from_numpy(noise)
return {'noise': noise}

def apply_transform(self, input, params):
noise = params['noise']
mode = 'bilinear' if (self.flags['resample'] == 1).all() else 'nearest'
# NOTE mode is currently only available on my fork, need kornia PR:
# https://github.com/kornia/kornia/pull/883
return {"noise": noise}

def __call__(self, input, params=None):
if params is None:
params = self.generate_parameters(input.shape)
self._params = params
noise = params["noise"]
mode = "bilinear" if (self.flags["resample"] == 1).all() else "nearest"
return kornia.geometry.transform.elastic_transform2d(
input, noise, sigma=self.flags['sigma'], alpha=self.flags['alpha'], mode=mode
input, noise, sigma=self.flags["sigma"], alpha=self.flags["alpha"], mode=mode,
padding_mode="reflection"
)


# TODO implement 'require_halo', and estimate the halo from the transformations
# so that we can load a bigger block and cut it away
class KorniaAugmentationPipeline(torch.nn.Module):
interpolatable_torch_types = [torch.float16, torch.float32, torch.float64]
interpolatable_numpy_types = [np.dtype('float32'), np.dtype('float64')]
interpolatable_numpy_types = [np.dtype("float32"), np.dtype("float64")]

def __init__(self, *kornia_augmentations, dtype=torch.float32):
super().__init__()
Expand All @@ -74,7 +134,7 @@ def __init__(self, *kornia_augmentations, dtype=torch.float32):
self.halo = self.compute_halo()

# for now we only add a halo for the random rotation trafos and
# also don't compute the halo dynamically based on the input shape
# also don"t compute the halo dynamically based on the input shape
def compute_halo(self):
halo = None
for aug in self.augmentations:
Expand All @@ -91,11 +151,10 @@ def is_interpolatable(self, tensor):
return tensor.dtype in self.interpolatable_numpy_types

def transform_tensor(self, augmentation, tensor, interpolatable, params=None):
interpolating = 'interpolation' in getattr(augmentation, 'flags', [])
interpolating = "interpolation" in getattr(augmentation, "flags", [])
if interpolating:
resampler = kornia.constants.Resample.get('BILINEAR' if interpolatable else 'NEAREST')
augmentation.flags['interpolation'] = torch.tensor(resampler.value)

resampler = kornia.constants.Resample.get("BILINEAR" if interpolatable else "NEAREST")
augmentation.flags["interpolation"] = torch.tensor(resampler.value)
transformed = augmentation(tensor, params)
return transformed, augmentation._params

Expand Down Expand Up @@ -130,6 +189,7 @@ def halo(self, shape):
"RandomRotation3D": {"degrees": (90, 90, 90)},
"RandomVerticalFlip": {},
"RandomVerticalFlip3D": {},
"RandomElasticDeformation3D": {"alpha": [5, 5], "sigma": [30, 30]}
}


Expand All @@ -149,6 +209,14 @@ def halo(self, shape):
]


def create_augmentation(trafo):
assert trafo in dir(kornia.augmentation) or trafo in globals().keys(), f"Transformation {trafo} not defined"
if trafo in dir(kornia.augmentation):
return getattr(kornia.augmentation, trafo)(**AUGMENTATIONS[trafo])

return globals()[trafo](**AUGMENTATIONS[trafo])


def get_augmentations(ndim=2,
transforms=None,
dtype=torch.float32):
Expand All @@ -160,10 +228,7 @@ def get_augmentations(ndim=2,
transforms = DEFAULT_3D_AUGMENTATIONS
else:
transforms = DEFAULT_ANISOTROPIC_AUGMENTATIONS
transforms = [
getattr(kornia.augmentation, trafo)(**AUGMENTATIONS[trafo])
for trafo in transforms
]
transforms = [create_augmentation(trafo) for trafo in transforms]

assert all(isinstance(trafo, kornia.augmentation.base._AugmentationBase)
for trafo in transforms)
Expand Down

0 comments on commit 88616f4

Please sign in to comment.