# Image transformers

In [101]:
import torchvision.transforms.functional as TF
import torch
# Seed torch's global RNG so the random parameters drawn by the transformers
# below (torch.randint / torch.rand / torch.randn) are reproducible run to run.
torch.random.manual_seed(42)

class Transformer:
    '''
    Base class for the image transformers in this notebook.

    Subclasses override ``__call__`` to transform an image and, when they
    use random parameters, ``_update`` to (re)generate those parameters.
    '''
    def __init__(self):
        pass

    def __call__(self, image=None):
        '''
        Identity transform: return ``image`` unchanged.

        ``image`` is optional so the historical no-argument call still
        works (it returns None); subclasses are called as ``t(image)``,
        which used to raise TypeError on the base class.
        '''
        return image

    def _update(self):
        '''Regenerate random parameters; the base class has none.'''
        pass

class Resize(Transformer):
    '''
    Resize an image to a fixed size.

    Args:
        size - target size forwarded unchanged to ``TF.resize``
               (e.g. a tuple of ints)
        image - PIL image (argument of ``__call__``)
    '''
    def __init__(self, size):
        self.size = size

    def __call__(self, image):
        target = self.size
        return TF.resize(image, size=target)
    
class ToTensor(Transformer):
    '''
    Convert a PIL image to a torch Tensor using ``TF.to_tensor``.

    Args:
        image - PIL image (argument of ``__call__``)
    '''
    def __call__(self, image):
        tensor = TF.to_tensor(image)
        return tensor
    
class RandAffine(Transformer):
    '''
    Apply a random affine transformation to an image via ``TF.affine``
    (which keeps the image center invariant).

    The parameters are drawn by ``_update`` (called from ``__init__``) and
    reused by every call until ``_update`` is called again. The defaults
    reproduce the original hard-coded ranges and RNG draw order exactly.

    Args:
        degrees - (low, high) ints; rotation angle sampled from [low, high)
        translate - (low, high) ints; horizontal and vertical shift (pixels)
                    each sampled independently from [low, high)
        scale - (low, high); scale factor sampled on a 0.01 grid from
                [low, high)
        shear_std - standard deviation of the zero-mean normal shear value
        image - PIL image (argument of ``__call__``)
    '''
    def __init__(self, degrees=(-180, 180), translate=(0, 5),
                 scale=(0.5, 2.0), shear_std=1.0):
        # Ranges are stored under distinct names because ``_update`` writes
        # the sampled values to ``self.translate`` / ``self.scale``.
        self.degrees = degrees
        self.translate_range = translate
        self.scale_range = scale
        self.shear_std = shear_std
        self._update()

    def __call__(self, image):
        return TF.affine(image, self.angle, self.translate,
                         self.scale, self.shear)

    @staticmethod
    def _randint(low, high):
        '''Return one Python int drawn uniformly from [low, high).'''
        return torch.randint(low, high, (1,)).item()

    def _update(self):
        # Draw order (angle, tx, ty, scale, shear) is kept fixed so a seeded
        # torch RNG reproduces the same parameters as before.
        self.angle = self._randint(*self.degrees)
        self.translate = (self._randint(*self.translate_range),
                          self._randint(*self.translate_range))
        lo, hi = self.scale_range
        # integer draw on a 0.01 grid, then true division (float tensor)
        self.scale = (torch.randint(round(lo * 100), round(hi * 100), (1,))
                      / 100.).item()
        self.shear = torch.randn(1).item() * self.shear_std

class RandHFlip(Transformer):
    '''
    Randomly flip an image horizontally.

    The flip decision is drawn by ``_update`` (called from ``__init__``)
    and reused by every call until ``_update`` is called again.

    Args:
        p - probability of flipping, in [0, 1] (default 0.5)
        image - PIL image or torch Tensor (argument of ``__call__``)
    '''
    def __init__(self, p=0.5):
        self.p = p
        self._update()

    def __call__(self, image):
        # rand_val is uniform on [0, 1), so ``rand_val < p`` holds with
        # probability p: p=0 never flips, p=1 always flips. (A ``>=`` test
        # would flip with probability 1 - p.)
        if self.rand_val < self.p:
            return TF.hflip(image)
        return image

    def _update(self):
        # draw the uniform sample that decides the flip
        self.rand_val = torch.rand(1).item()

In [81]:
class AdjBrightness(Transformer):
    '''
    Adjust the brightness of an image by a random factor.

    The factor is drawn by ``_update`` (called from ``__init__``) from
    (vmin, vmax] and reused until ``_update`` is called again.

    Args:
        vmin - lower bound of the brightness factor (default 0.5)
        vmax - upper bound of the brightness factor (default 1.5)
        image - PIL image (argument of ``__call__``)
    '''
    def __init__(self, vmin=0.5, vmax=1.5):
        self.vmin = vmin
        self.vmax = vmax
        self._update()

    def __call__(self, image):
        return TF.adjust_brightness(image, self.factor)

    def _update(self):
        # torch.rand is uniform on [0, 1), so the factor lies in (vmin, vmax].
        # .item() gives the plain float TF expects instead of a 1-element
        # tensor, consistent with RandAffine / RandHFlip.
        self.factor = ((self.vmin - self.vmax) * torch.rand(1)
                       + self.vmax).item()
    
class AdjContrast(Transformer):
    '''
    Adjust the contrast of an image by a random factor.

    The factor is drawn by ``_update`` (called from ``__init__``) from
    (vmin, vmax] and reused until ``_update`` is called again.

    Args:
        vmin - lower bound of the contrast factor (default 0.5)
        vmax - upper bound of the contrast factor (default 1.5)
        image - PIL image (argument of ``__call__``)
    '''
    def __init__(self, vmin=0.5, vmax=1.5):
        self.vmin = vmin
        self.vmax = vmax
        self._update()

    def __call__(self, image):
        return TF.adjust_contrast(image, self.factor)

    def _update(self):
        # torch.rand is uniform on [0, 1), so the factor lies in (vmin, vmax].
        # .item() gives the plain float TF expects instead of a 1-element
        # tensor, consistent with RandAffine / RandHFlip.
        self.factor = ((self.vmin - self.vmax) * torch.rand(1)
                       + self.vmax).item()
    
class AdjSaturation(Transformer):
    '''
    Adjust the color saturation of an image by a random factor.

    The factor is drawn by ``_update`` (called from ``__init__``) from
    (vmin, vmax] and reused until ``_update`` is called again.

    Args:
        vmin - lower bound of the saturation factor (default 0.5)
        vmax - upper bound of the saturation factor (default 1.5)
        image - PIL image (argument of ``__call__``)
    '''
    def __init__(self, vmin=0.5, vmax=1.5):
        self.vmin = vmin
        self.vmax = vmax
        self._update()

    def __call__(self, image):
        return TF.adjust_saturation(image, self.factor)

    def _update(self):
        # torch.rand is uniform on [0, 1), so the factor lies in (vmin, vmax].
        # .item() gives the plain float TF expects instead of a 1-element
        # tensor, consistent with RandAffine / RandHFlip.
        self.factor = ((self.vmin - self.vmax) * torch.rand(1)
                       + self.vmax).item()

    
class AdjHue(Transformer):
    '''
    Shift the hue of an image by a random factor.

    The factor is drawn by ``_update`` (called from ``__init__``) from
    (vmin, vmax] and reused until ``_update`` is called again.
    ``TF.adjust_hue`` only accepts factors in [-0.5, 0.5], so vmin and vmax
    should stay within that interval.

    Args:
        vmin - lower bound of the hue factor (default -0.5)
        vmax - upper bound of the hue factor (default 0.5)
        image - PIL image (argument of ``__call__``)
    '''
    def __init__(self, vmin=-0.5, vmax=0.5):
        self.vmin = vmin
        self.vmax = vmax
        self._update()

    def __call__(self, image):
        return TF.adjust_hue(image, self.factor)

    def _update(self):
        # torch.rand is uniform on [0, 1), so the factor lies in (vmin, vmax].
        # .item() gives the plain float TF expects instead of a 1-element
        # tensor, consistent with RandAffine / RandHFlip.
        self.factor = ((self.vmin - self.vmax) * torch.rand(1)
                       + self.vmax).item()
    