diff --git a/torch_em/transform/raw.py b/torch_em/transform/raw.py index 8772f625..277196f0 100644 --- a/torch_em/transform/raw.py +++ b/torch_em/transform/raw.py @@ -92,7 +92,7 @@ class RandomContrast(): """ Adjust contrast by scaling image to `mean + alpha * (image - mean)`. """ - def __init__(self, alpha=(0.05, 4), mean=0.5, clip_kwargs={'a_min': 0, 'a_max': 1}): + def __init__(self, alpha=(0.5, 2), mean=0.5, clip_kwargs={'a_min': 0, 'a_max': 1}): self.alpha = alpha self.mean = mean self.clip_kwargs = clip_kwargs @@ -109,12 +109,15 @@ class AdditiveGaussianNoise(): """ Add random Gaussian noise to image. """ - def __init__(self, scale=(0.0, 0.75)): + def __init__(self, scale=(0.0, 0.3), clip_kwargs={'a_min': 0, 'a_max': 1}): self.scale = scale + self.clip_kwargs = clip_kwargs def __call__(self, img): std = np.random.uniform(self.scale[0], self.scale[1]) gaussian_noise = np.random.normal(0, std, size=img.shape) + if self.clip_kwargs: + return np.clip(img + gaussian_noise, 0, 1) return img + gaussian_noise @@ -124,7 +127,7 @@ class AdditivePoissonNoise(): """ # TODO: not sure if Poisson noise like this does make sense # for data that is already normalized - def __init__(self, lam=(0.0, 0.3), clip_kwargs={'a_min': 0, 'a_max': 1}): + def __init__(self, lam=(0.0, 0.1), clip_kwargs={'a_min': 0, 'a_max': 1}): self.lam = lam self.clip_kwargs = clip_kwargs @@ -140,7 +143,7 @@ class PoissonNoise(): """ Add random data-dependent Poisson noise to image. """ - def __init__(self, multiplier=(1.0, 10.0), clip_kwargs={'a_min': 0, 'a_max': 1}): + def __init__(self, multiplier=(5.0, 10.0), clip_kwargs={'a_min': 0, 'a_max': 1}): self.multiplier = multiplier self.clip_kwargs = clip_kwargs @@ -160,7 +163,7 @@ class GaussianBlur(): """ Blur the image. """ - def __init__(self, kernel_size=(2, 24), sigma=(0, 5)): + def __init__(self, kernel_size=(2, 12), sigma=(0, 2.5)): self.kernel_size = kernel_size self.sigma = sigma @@ -201,13 +204,13 @@ def get_raw_transform(normalizer=standardize, augmentation1=None, augmentation2= # The default values are made for an image with pixel values in # range [0, 1]. That the image is in this range is ensured by an # initial normalizations step. -def get_default_mean_teacher_augmentations(p=0.5): +def get_default_mean_teacher_augmentations(p=0.3): norm = normalize aug1 = transforms.Compose([ normalize, transforms.RandomApply([GaussianBlur()], p=p), - transforms.RandomApply([AdditiveGaussianNoise()], p=p), - transforms.RandomApply([PoissonNoise()], p=p) + transforms.RandomApply([PoissonNoise()], p=p/2), + transforms.RandomApply([AdditiveGaussianNoise()], p=p/2), ]) aug2 = transforms.RandomApply([RandomContrast()], p=p) return get_raw_transform(