# Data augmentation on the GPU

In [None]:
from nb_200 import *
import pickle

## Get the data

In [None]:
class PetsData(DataBlock):
    "Pets demo data: the first image of the folder repeated 100 times, labelled by a regex on its file name."
    type_cls = (Image, Category)
    def get_source(self):        return untar_data(URLs.PETS)/"images"
    def get_items(self, source): return [get_image_files(source)[0]]*100
    def split(self, items):      return random_splitter(items)
    # Label = the file-name part before `_<digits>.jpg`; the dot is escaped so only a literal ".jpg" matches
    def label_func(self, item):  return re_labeller(pat = r'/([^/]+)_\d+\.jpg$')(item)

In [None]:
class CamvidData(DataBlock):
    "CamVid-tiny demo data: one image repeated 100 times, with its segmentation mask as target."
    type_cls = (Image, SegmentMask)

    def get_source(self):
        return untar_data(URLs.CAMVID_TINY)

    def get_items(self, source):
        "Remember the label folder, then return the first image repeated 100 times."
        self.path_lbl = source/'labels'
        first_img = get_image_files(source/'images')[0]
        return [first_img] * 100

    def split(self, items):
        return random_splitter(items)

    def label_func(self, item):
        "Mask file for `item`: `<stem>_P<suffix>` inside the label folder."
        mask_name = f'{item.stem}_P{item.suffix}'
        return self.path_lbl/mask_name

In [None]:
class BiwiData(DataBlock):
    "BIWI sample demo data: one image repeated 100 times, labelled with four fixed points."
    type_cls = (Image, Points)

    def get_source(self):
        return untar_data(URLs.BIWI_SAMPLE)

    def get_items(self, source):
        first_img = get_image_files(source/'images')[0]
        return [first_img] * 100

    def split(self, items):
        return random_splitter(items)

    def label_func(self, item):
        "Return the same four dummy points for every item (`item` is ignored)."
        return [[0, 0], [120, 0], [0, 160], [120,160]]

In [None]:
class CocoData(DataBlock):
    "COCO-tiny demo data: one training image repeated 100 times, with its bounding boxes as target."
    type_cls = (Image,BBox)

    def get_source(self):
        return untar_data(URLs.COCO_TINY)

    def get_items(self, source):
        "Index the annotations by image file name, then return training image #18 repeated 100 times."
        images, lbl_bbox = get_annotations(source/'train.json')
        self.img2bbox = {img: bb for img, bb in zip(images, lbl_bbox)}
        sample = get_image_files(source/'train')[18]
        return [sample] * 100

    def split(self, items):
        return random_splitter(items)

    def label_func(self, item):
        return self.img2bbox[item.name]

    def databunch(self, ds_tfms=None, dl_tfms=None, bs=64, tfm_kwargs=None, **kwargs):
        "Forward to `DataBlock.databunch`, always forcing `collate_fn=bb_pad_collate` (overrides any caller value)."
        kwargs.update(collate_fn=bb_pad_collate)
        return super().databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=bs, tfm_kwargs=tfm_kwargs, **kwargs)

In [None]:
# Dataset-level pipeline: decode the image, ResizeFixed(128), then convert to a byte tensor and a float tensor
ds_tfms = [DecodeImg(), ResizeFixed(128), ToByteTensor(), ToFloatTensor()]

In [None]:
# Uncomment exactly one line to pick the dataset (each one exercises a different target type)
#data = PetsData().databunch(ds_tfms=ds_tfms, bs=16)
#data = CamvidData().databunch(ds_tfms=ds_tfms, bs=16)
#data = BiwiData().databunch(ds_tfms=ds_tfms, bs=16)
data = CocoData().databunch(ds_tfms=ds_tfms, bs=16)

In [None]:
data.show_batch()

## Affine and coords

For now, batch transforms are written as ordinary transforms; they could easily be moved into callbacks later.

In [None]:
# Use the first GPU when one is present; fall back to the CPU so the notebook still runs without CUDA
device = torch.device('cuda',0) if torch.cuda.is_available() else torch.device('cpu')

In [None]:
from fastai.torch_core import to_device, to_cpu

In [None]:
class Cuda(Transform):
    "Batch transform that moves a batch onto `device`; `decode` moves it back to the CPU."
    _order = 0

    def __init__(self, device):
        self.device = device

    def __call__(self, b, tfm_y=TfmY.No):
        "Send batch `b` to `self.device` (`tfm_y` is accepted but unused)."
        return to_device(b, self.device)

    def decode(self, b):
        return to_cpu(b)

In [None]:
import torch.nn.functional as F

In [None]:
class AffineAndCoordTfm(ImageTransform):
    """Batch augmentation: multiply the matrices from `aff_tfms` into one affine map, post-process the
    resulting sampling grid with `coord_tfms`, and resample the batch with `F.grid_sample`.

    Points and bounding boxes are moved with the inverse of that sampling map so they stay aligned
    with the resampled image.
    """
    _data_aug=True
    def __init__(self, aff_tfms, coord_tfms, size=None, mode='bilinear', padding_mode='reflection'):
        self.aff_tfms,self.coord_tfms,self.mode,self.padding_mode = aff_tfms,coord_tfms,mode,padding_mode
        # Normalise `size` to an (h, w) tuple; None keeps the input's spatial size
        self.size = None if size is None else (size,size) if isinstance(size, int) else tuple(size)

    def randomize(self):
        "Let every sub-transform draw new random parameters from the current batch `self.x`."
        for t in self.aff_tfms+self.coord_tfms: t.randomize(self.x)

    def _get_affine_mat(self):
        "Return a (bs, 3, 3) batch: the identity multiplied by every non-None matrix from `aff_tfms`, in order."
        aff_m = torch.eye(3, dtype=self.x.dtype, device=self.x.device)
        aff_m = aff_m.unsqueeze(0).expand(self.x.size(0), 3, 3)
        ms = [tfm() for tfm in self.aff_tfms]
        # Sub-transforms that have not been randomized yet return None and are skipped
        ms = [m for m in ms if m is not None]
        for m in ms: aff_m = aff_m @ m
        return aff_m

    def apply(self, x):
        "Resample the (bs, c, h, w) batch `x` at the affine grid post-processed by `coord_tfms`."
        bs = x.size(0)
        size = tuple(x.shape[-2:]) if self.size is None else self.size
        size = (bs,x.size(1)) + size
        # Sampling locations: affine map first, then the coordinate transforms
        coords = F.affine_grid(self._get_affine_mat()[:,:2], size)
        coords = apply_all(coords, self.coord_tfms)
        return F.grid_sample(x, coords, mode=self.mode, padding_mode=self.padding_mode)

    def apply_mask(self, y):
        "Resample a mask with nearest-neighbour interpolation so label values are never blended."
        self.old_mode,self.mode = self.mode,'nearest'
        # Restore the interpolation mode even if resampling fails
        try: res = self.apply(y.float())
        finally: self.mode = self.old_mode
        return res.long()

    def apply_point(self, y):
        "Move points `y` (bs, n, 2) with the inverse of the sampling map used in `apply`."
        # `apply` samples the input at coord_tfms(affine(p)), so the inverse must undo the coordinate
        # transforms FIRST and the affine map SECOND (the reverse of the forward order).
        # NOTE(review): assumes points use the same normalised (x, y) convention as `F.affine_grid` - confirm.
        y = apply_all(y, self.coord_tfms, filter_kwargs=True, invert=True)
        m = self._get_affine_mat()[:,:2]
        # For the orthogonal, translation-free matrices built by `rotation_matrix`, `y @ A` is the inverse map.
        # NOTE(review): general affines (scale/shear/translation) would need an explicit matrix inverse.
        return y @ m[:,:,:2] + m[:,:,2].unsqueeze(1)

    def apply_bbox(self, y):
        "Transform boxes by moving their four corners and taking the axis-aligned hull of the result."
        bbox,label = y
        bs,n = bbox.shape[:2]
        # Corners (b0,b1), (b0,b3), (b2,b1), (b2,b3) -> (bs, n, 4, 2)
        pnts = stack([bbox[...,:2], stack([bbox[...,0],bbox[...,3]],dim=2),
                      stack([bbox[...,2],bbox[...,1]],dim=2), bbox[...,2:]], dim=2)
        pnts = self.apply_point(pnts.view(bs, 4*n, 2))
        pnts = pnts.view(bs, n, 4, 2)
        tl,dr = pnts.min(dim=2)[0],pnts.max(dim=2)[0]
        return [torch.cat([tl, dr], dim=2), label]

### Rotate

In [None]:
import math
from torch import stack, zeros_like as t0, ones_like as t1
from torch.distributions.bernoulli import Bernoulli

In [None]:
def rotation_matrix(thetas):
    """Build a batch of 3x3 homogeneous rotation matrices from angles in degrees.

    `thetas` is a 1-D tensor with one angle per batch item; it is left unmodified.
    Returns a (len(thetas), 3, 3) tensor with rows [cos, sin, 0], [-sin, cos, 0], [0, 0, 1].
    """
    # Convert to radians out of place so the caller's tensor is not rewritten
    rads = thetas * (math.pi/180)
    cos, sin = rads.cos(), rads.sin()
    zero, one = t0(rads), t1(rads)
    rows = [stack([cos,  sin,  zero], dim=1),
            stack([-sin, cos,  zero], dim=1),
            stack([zero, zero, one],  dim=1)]
    return stack(rows, dim=1)

In [None]:
class DataAugTfm:
    "Base for random sub-transforms: `randomize(x)` draws parameters from a batch; by default it draws nothing."
    _order = 0

    def randomize(self, x):
        return None

In [None]:
def mask_tensor(x, p=0.5, neutral=0.):
    "Keep each element of `x` with probability `p`, set the others to `neutral`; modifies and returns `x`."
    if p == 1.:
        return x
    shift = neutral != 0
    # Centre on `neutral` so that zeroing an element leaves it at `neutral` after shifting back
    if shift:
        x.add_(-neutral)
    keep = x.new_empty(*x.size()).bernoulli_(p)
    x.mul_(keep)
    if shift:
        x.add_(neutral)
    return x

In [None]:
def masked_uniform(x, a, b, *sz, p=0.5, neutral=0.):
    "Sample shape `sz` from U(a, b) (dtype/device of `x`); keep each value with probability `p`, else `neutral`."
    samples = x.new_empty(*sz)
    samples.uniform_(a, b)
    return mask_tensor(samples, p=p, neutral=neutral)

In [None]:
class RandomRotation(DataAugTfm):
    "Random rotation sub-transform: one angle per batch item drawn from U(-degrees, degrees)."
    def __init__(self, degrees, p=0.5):
        # `mat` stays None until `randomize` has been called
        self.mat,self.degrees,self.p = None,degrees,p

    def randomize(self, x):
        "Draw the angles; with probability 1-p an item gets 0 degrees (identity rotation)."
        thetas = masked_uniform(x, -self.degrees, self.degrees, x.size(0), p=self.p)
        self.mat = rotation_matrix(thetas)

    def __call__(self): return self.mat

In [None]:
# Batch pipeline: move the batch to `device`, then random rotations in [-30, 30] degrees (no coordinate transforms)
dl_tfms = [Cuda(device), AffineAndCoordTfm([RandomRotation(30.)], [])]

In [None]:
# Uncomment exactly one line to pick the dataset / target type to visualise
#data = PetsData().databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)
#data = CamvidData().databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)
#data = BiwiData().databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)
data = CocoData().databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)

In [None]:
data.show_batch()

### Warp

In [None]:
def find_coeffs(p1, p2):
    """Solve for the perspective (homography) coefficients mapping points `p2` onto points `p1`.

    `p1` and `p2` are (bs, 4, 2) tensors of corresponding points. Returns a (bs, 8, 1) tensor of
    coefficients (a, b, c, d, e, f, g, h) such that, for each (x, y) in `p2` and its match in `p1`:
        p1_x = (a*x + b*y + c) / (g*x + h*y + 1)
        p1_y = (d*x + e*y + f) / (g*x + h*y + 1)
    """
    matrix = []
    p = p1[:,0,0]  # template for zeros/ones with the right batch size, dtype and device
    # Two linear equations per point pair
    for i in range(p1.shape[1]):
        matrix.append(stack([p2[:,i,0], p2[:,i,1], t1(p), t0(p), t0(p), t0(p), -p1[:,i,0]*p2[:,i,0], -p1[:,i,0]*p2[:,i,1]]))
        matrix.append(stack([t0(p), t0(p), t0(p), p2[:,i,0], p2[:,i,1], t1(p), -p1[:,i,1]*p2[:,i,0], -p1[:,i,1]*p2[:,i,1]]))
    # The 8 coefficients are the solution of AX = B
    A = stack(matrix).permute(2, 0, 1)
    # reshape (not view): `p1` may be an expanded or permuted tensor
    B = p1.reshape(p1.shape[0], 8, 1)
    linalg = getattr(torch, 'linalg', None)
    if linalg is not None and hasattr(linalg, 'solve'):
        return linalg.solve(A, B)
    # Older PyTorch without torch.linalg.solve (note the swapped argument order)
    return torch.solve(B,A)[0]

In [None]:
def apply_perspective(coords, coeffs):
    """Apply a batch of perspective transforms to coordinates.

    `coords` has shape (bs, ..., 2); `coeffs` is (bs, 8, 1) as returned by `find_coeffs`.
    Returns a new tensor with the same shape as `coords`.
    """
    sz = coords.shape
    # reshape (not view) so non-contiguous coordinate tensors are accepted
    coords = coords.reshape(sz[0], -1, 2)
    # Append the fixed bottom-right 1 to get a full 3x3 homography per batch item
    coeffs = torch.cat([coeffs, t1(coeffs[:,:1])], dim=1).view(coeffs.shape[0], 3, 3)
    coords = coords @ coeffs[...,:2].transpose(1,2) + coeffs[...,2].unsqueeze(1)
    # Perspective divide, out of place: dividing in place by a view of the same tensor aliases memory
    coords = coords[...,:2] / coords[...,2:]
    return coords.reshape(sz)

In [None]:
class RandomWarp(DataAugTfm):
    "Random perspective warp of the [-1, 1] coordinate square, applied to sampling coordinates."
    def __init__(self, magnitude, p=0.5):
        self.coeffs = None
        self.magnitude = magnitude
        self.p = p

    def randomize(self, x):
        "Draw per-item corner offsets and store the original and displaced corner sets, each (bs, 4, 2)."
        bs, mag = x.size(0), self.magnitude
        up_t = masked_uniform(x, -mag, mag, bs, p=self.p)
        lr_t = masked_uniform(x, -mag, mag, bs, p=self.p)
        corners = torch.tensor([[-1,-1], [-1,1], [1,-1], [1,1]], dtype=x.dtype, device=x.device)
        self.orig_pts = corners[None].expand(bs, 4, 2)
        moved = [stack([-1-up_t, -1-lr_t]), stack([-1+up_t, 1+lr_t]),
                 stack([ 1+up_t, -1+lr_t]), stack([ 1-up_t, 1-lr_t])]
        self.targ_pts = stack(moved).permute(2, 0, 1)

    def __call__(self, x, invert=False):
        "Warp coordinates `x`; `invert=True` swaps source and target corners to get the inverse mapping."
        src, dst = (self.targ_pts, self.orig_pts) if invert else (self.orig_pts, self.targ_pts)
        return apply_perspective(x, find_coeffs(src, dst))

In [None]:
# Rotations in [-10, 10] degrees followed by a random perspective warp of magnitude 0.2
dl_tfms = [Cuda(device), AffineAndCoordTfm([RandomRotation(10.)], [RandomWarp(0.2)])]

In [None]:
# Uncomment exactly one line to pick the dataset / target type to visualise
#data = PetsData().databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)
#data = CamvidData().databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)
data = BiwiData().databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)
#data = CocoData().databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)

In [None]:
data.show_batch()

## Lighting transforms

In [None]:
# export
def logit(x):
    "Logit (inverse sigmoid) of `x`, clamped to [1e-7, 1-1e-7] first so the result stays finite."
    clamped = x.clamp(min=1e-7, max=1-1e-7)
    return (1/clamped - 1).log().neg()

In [None]:
class LightingTransform(ImageTransform):
    "Batch transform running `tfms` on images in logit space, then mapping back with a sigmoid; masks pass through."
    _data_aug=True

    def __init__(self, tfms):
        self.tfms = listify(tfms)

    def randomize(self):
        "Let each sub-transform draw its random parameters from the current batch `self.x`."
        for tfm in self.tfms:
            tfm.randomize(self.x)

    def apply(self, x):
        logits = logit(x)
        logits = apply_all(logits, self.tfms)
        return logits.sigmoid()

    def apply_mask(self, x):
        "Lighting changes leave segmentation masks untouched."
        return x

In [None]:
from math import log
def masked_log_uniform(x, a, b, *sz, p=0.5, neutral=0.):
    "Log-uniform samples in [a, b] of shape `sz`; masked-out entries become exp(`neutral`) (1 by default)."
    log_samples = masked_uniform(x, log(a), log(b), *sz, p=p, neutral=neutral)
    return log_samples.exp()

In [None]:
class Brightness(DataAugTfm):
    "Apply a random change in brightness to a batch of images (meant to run inside LightingTransform, which feeds it logits)."
    def __init__(self, max_lighting=0.2, p=0.75):
        self.p = p
        # `change` is drawn around 0.5, the "no change" value once mapped through `logit`
        self.range = (0.5*(1-max_lighting), 0.5*(1+max_lighting))
    def randomize(self, x):
        # One value per batch item, shaped (bs, 1, ..., 1) to broadcast over the other dims;
        # items not selected (probability 1-p) get the neutral value 0.5
        self.change = masked_uniform(x, *self.range, x.size(0), *([1]*(x.dim()-1)), p=self.p, neutral=0.5)
    def __call__(self, x):
        # `x` holds logits, so the change must be moved to logit space too: logit(0.5) == 0 is a no-op
        return x.add_(logit(self.change))

class Contrast(DataAugTfm):
    "Apply a random change in contrast to a batch of images (scales the logits fed in by LightingTransform)."
    def __init__(self, max_lighting=0.2, p=0.75):
        # 1/(1-max_lighting) must be finite and the range well-ordered
        if not 0 <= max_lighting < 1:
            raise ValueError(f"max_lighting must be in [0, 1), got {max_lighting}")
        self.p = p
        # Log-uniform range, symmetric around the neutral scale 1 in log space
        self.range = (1-max_lighting, 1/(1-max_lighting))
    def randomize(self, x):
        # Unselected items get exp(0) == 1, i.e. no contrast change
        self.change = masked_log_uniform(x, *self.range, x.size(0), *([1]*(x.dim()-1)), p=self.p)
    def __call__(self, x): return x.mul_(self.change)

In [None]:
# Lighting only: Brightness(max_lighting=1) and Contrast(max_lighting=0.5), applied to the batch in logit space
dl_tfms = [Cuda(device), LightingTransform([Brightness(1), Contrast(0.5)])]

In [None]:
# Uncomment exactly one line to pick the dataset / target type to visualise
#data = PetsData().databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)
#data = CamvidData().databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)
#data = BiwiData().databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)
data = CocoData().databunch(ds_tfms=ds_tfms, dl_tfms=dl_tfms, bs=16)

In [None]:
data.show_batch()