In [None]:
from nb_200 import *

## Get the data

In [None]:
class PetsData(DataBlock):
    "`DataBlock` for the Oxford-IIIT Pets images (`URLs.PETS`), labelled from each file's name by a regex."
    get_x_cls = ImageGetter
    get_y_cls = CategoryGetter

    def get_source(self):
        return untar_data(URLs.PETS)

    def get_items(self, source):
        images_dir = source/"images"
        return get_image_files(images_dir)

    def split(self, items):
        return random_splitter(items)

    def label(self, items):
        # The label is the part of the file name before the trailing `_<digits>.jpg`.
        name_pat = r'/([^/]+)_\d+.jpg$'
        return re_labeller(items, pat=name_pat)

In [None]:
# Item-level transforms for the images (x): RGB conversion, ResizeFixed(128), then byte -> float tensor.
tfms = [make_rgb, ResizeFixed(128), to_byte_tensor, to_float_tensor]

In [None]:
# Build the DataBunch; only the inputs get the transforms above (labels use the block's defaults).
data = PetsData(tfms_x=tfms).databunch()

In [None]:
# Sanity-check an un-augmented batch before adding the GPU augmentations below.
data.show_batch()

## Data augmentation on the GPU

For now, batch-level augmentations are written as plain transforms applied to whole batches; they can easily be moved into callbacks later.

In [None]:
# Use the first GPU when one is available; fall back to the CPU so the notebook still runs without CUDA.
device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')

In [None]:
class Cuda():
    "Batch transform that moves the first two elements of a batch, `(x, y)`, onto `device`."
    _order = 0

    def __init__(self, device):
        self.device = device

    def __call__(self, b):
        target_device = self.device
        xb, yb = b[0], b[1]
        return (xb.to(target_device), yb.to(target_device))

In [None]:
class DataAugTfm():
    "Base class for random batch augmentations; subclasses draw fresh random parameters in `randomize`."
    _order = 5

    def randomize(self, bs):
        "Draw new random parameters for a batch; the base class has none, so this does nothing."
        return None

In [None]:
import torch.nn.functional as F

In [None]:
class AffineAndCoordTfm(DataAugTfm):
    """Batch augmentation that resamples a batch through a composed affine matrix and coordinate transforms.

    aff_tfms:      callables returning a (bs, 3, 3) matrix (or None to skip); multiplied together in order.
    coord_tfms:    callables mapping a sampling grid to a new grid; chained with `compose`.
    size:          output spatial size (int or (h, w)); defaults to the input's own spatial size.
    tfm_y:         if True, the target gets exactly the same random transform as the input.
    mode:          interpolation mode for floating-point tensors; non-float tensors always use 'nearest'.
    padding_mode:  how `grid_sample` fills samples that fall outside the image.
    align_corners: forwarded to `affine_grid`/`grid_sample` when not None; None keeps PyTorch's default.
    """
    def __init__(self, aff_tfms, coord_tfms, size=None, tfm_y=False, mode='bilinear', padding_mode='reflection',
                 align_corners=None):
        self.aff_tfms,self.coord_tfms,self.mode,self.padding_mode = aff_tfms,coord_tfms,mode,padding_mode
        self.tfm_y = tfm_y
        self.size = None if size is None else (size,size) if isinstance(size, int) else tuple(size)
        # Only pass `align_corners` when requested, so PyTorch versions without that argument keep working.
        self._grid_kw = {} if align_corners is None else {'align_corners': align_corners}

    def randomize(self, x):
        "Resample every sub-transform's random parameters; `x` provides batch size, dtype and device."
        for t in self.aff_tfms+self.coord_tfms: t.randomize(x)

    def __call__(self, b):
        # Randomize once per batch, so x and (optionally) y receive the very same transform.
        self.randomize(b[0])
        x = self.apply_tfms(b[0])
        y = self.apply_tfms(b[1]) if self.tfm_y else b[1]
        return (x, y)

    def apply_tfms(self, x):
        "Resample the 4-d batch `x` (bs, c, h, w) with the current random parameters."
        # grid_sample needs floating-point input: non-float tensors (e.g. label masks of any integer
        # dtype) are sampled as floats with nearest-neighbour interpolation and cast back afterwards.
        orig_dtype = x.dtype
        isint = not x.is_floating_point()
        if isint: x = x.float()
        bs = x.size(0)
        size = tuple(x.shape[-2:]) if self.size is None else self.size
        size = (bs, x.size(1)) + size
        aff_m = torch.eye(3, dtype=x.dtype, device=x.device).unsqueeze(0).expand(bs, 3, 3)
        ms = [tfm() for tfm in self.aff_tfms]
        for m in ms:
            if m is not None: aff_m = aff_m @ m
        # affine_grid only takes the top two rows of the homogeneous 3x3 matrices.
        coords = F.affine_grid(aff_m[:,:2], size, **self._grid_kw)
        coords = compose(coords, self.coord_tfms)
        mode = 'nearest' if isint else self.mode
        res = F.grid_sample(x, coords, mode=mode, padding_mode=self.padding_mode, **self._grid_kw)
        return res.to(orig_dtype) if isint else res

In [None]:
import math
from torch import stack

### Rotate

In [None]:
from torch import zeros_like as t0, ones_like as t1

In [None]:
def rotation_matrix(thetas):
    """Build a batch of 3x3 homogeneous rotation matrices.

    thetas: angles in degrees, one per batch item (assumed 1-d). The tensor is NOT modified.
    Returns a (len(thetas), 3, 3) tensor with the same dtype/device as the converted angles.
    """
    # Convert to radians out-of-place so the caller's tensor is left untouched.
    rads = thetas * (math.pi/180)
    cos, sin = rads.cos(), rads.sin()
    zero, one = t0(rads), t1(rads)
    rows = [stack([cos,  sin,  zero], dim=1),
            stack([-sin, cos,  zero], dim=1),
            stack([zero, zero, one],  dim=1)]
    return stack(rows, dim=1)

In [None]:
from torch.distributions.bernoulli import Bernoulli

In [None]:
class RandomRotation(DataAugTfm):
    "Affine transform rotating each item by a random angle in [-degrees, degrees], applied with probability `p`."
    def __init__(self, degrees, p=0.5):
        self.degrees = degrees
        self.p = p
        self.mat = None

    def randomize(self, x):
        n = x.size(0)
        # Items whose Bernoulli draw is 0 get an angle of 0, i.e. the identity matrix.
        keep = x.new_empty(n).bernoulli_(self.p)
        angles = x.new_empty(n).uniform_(-self.degrees, self.degrees)
        self.mat = rotation_matrix(angles * keep)

    def __call__(self):
        return self.mat

In [None]:
# Batch-level pipeline (reuses the name `tfms`): move the batch to `device`, then rotate by up to ±30 degrees.
tfms = [Cuda(device), AffineAndCoordTfm([RandomRotation(30.)], [])]

In [None]:
batch = next(iter(data.train_dl))
# Repeat the first item 16 times so the displayed samples differ only by the random augmentation.
b = (batch[0][0].unsqueeze(0).expand(16,*batch[0].shape[1:]), batch[1][0].unsqueeze(0).expand(16))

In [None]:
def show_batch(data, batch, items=9, **kwargs):
    """Display the first `items` samples of an already-collated (and possibly augmented) batch.

    data:   object exposing `train_ds.x` / `train_ds.y`, each with `deproc` and `item_get`.
    batch:  (x, y) pair of tensors; both are moved to the CPU before display.
    items:  number of samples to show; clamped to the batch size.
    kwargs: forwarded to `show_xys`.
    """
    x,y = batch[0].cpu(),batch[1].cpu()
    # Clamp so that batches smaller than `items` don't raise an IndexError.
    n = min(items, len(x))
    x_ds,y_ds = data.train_ds.x,data.train_ds.y
    x_ds.item_get.show_xys(
        [x_ds.deproc(x[i]) for i in range(n)],
        [y_ds.deproc(y[i]) for i in range(n)],
        y_ds.item_get,
        **kwargs)

In [None]:
# Run the repeated-item batch through the batch-level transforms via `compose`.
batch_tfmed = compose(b, tfms)

In [None]:
show_batch(data, batch_tfmed)

### Warp

In [None]:
def find_coeffs(p1, p2):
    """Solve, per batch item, for the 8 coefficients of the perspective transform mapping `p2` onto `p1`.

    p1, p2: (bs, 4, 2) tensors of 4 corresponding (x, y) points per item (exactly 4 are needed
            for the 8x8 system below).
    Returns a (bs, 8, 1) tensor [a, b, c, d, e, f, g, h] such that, for (x, y) in p2 and (u, v) in p1,
    u = (a*x + b*y + c) / (g*x + h*y + 1) and v = (d*x + e*y + f) / (g*x + h*y + 1).
    """
    p = p1[:,0,0]   # only a template (batch size, dtype, device) for the constant columns
    zero, one = t0(p), t1(p)
    rows = []
    # Each point pair contributes two linear equations in the 8 unknowns.
    for i in range(p1.shape[1]):
        x, y = p2[:,i,0], p2[:,i,1]
        u, v = p1[:,i,0], p1[:,i,1]
        rows.append(stack([x, y, one, zero, zero, zero, -u*x, -u*y]))
        rows.append(stack([zero, zero, zero, x, y, one, -v*x, -v*y]))
    # (equations, unknowns, bs) -> (bs, 8, 8): solve A X = B batched over bs.
    A = stack(rows).permute(2, 0, 1)
    # reshape (not view) so non-contiguous inputs are accepted too.
    B = p1.reshape(p1.shape[0], 8, 1)
    # `torch.solve(B, A)` is deprecated and removed in recent PyTorch; `torch.linalg.solve(A, B)` replaces
    # it (note the swapped argument order). Fall back on versions that predate `torch.linalg.solve`.
    linalg_solve = getattr(getattr(torch, 'linalg', None), 'solve', None)
    if linalg_solve is not None:
        return linalg_solve(A, B)
    return torch.solve(B, A)[0]

In [None]:
def apply_perspective(coords, coeffs):
    """Apply a batch of perspective transforms to a sampling grid.

    coords: (bs, h, w, 2) grid of (x, y) points, e.g. the output of `F.affine_grid`.
    coeffs: (bs, 8) or (bs, 8, 1) coefficients as returned by `find_coeffs`.
    Returns a new (bs, h, w, 2) grid; `coords` is not modified.
    """
    # Append the fixed 9th coefficient (1) and reshape into 3x3 matrices M.
    mats = torch.cat([coeffs, t1(coeffs[:,:1])], dim=1).view(coeffs.shape[0], 3, 3)
    # Homogeneous product M @ (x, y, 1): x/y columns via einsum, plus the last column.
    homog = torch.einsum('bijk,blk -> bijl', coords, mats[...,:2]) + mats[:,None,None,:,2]
    # Out-of-place projective divide: an in-place `div_` by a view of the same tensor can read
    # values it has already overwritten (order-dependent, a race on the GPU).
    return homog[...,:2] / homog[...,2:]

In [None]:
class RandomWarp(DataAugTfm):
    """Coordinate transform applying a random perspective warp to each item with probability `p`.

    magnitude: maximum offset of each image corner, in the normalized [-1, 1] grid coordinates.
    p:         probability that a given batch item is warped; the others get an identity transform.
    """
    def __init__(self, magnitude, p=0.5):
        self.coeffs,self.magnitude,self.p = None,magnitude,p

    def randomize(self, x):
        "Draw a new perspective transform per batch item; `x` supplies batch size, dtype and device."
        bs = x.size(0)
        mask = x.new_empty(bs).bernoulli_(self.p)
        # Zero the corner offsets of items not selected by the mask, so they get the identity
        # warp; this is what makes `p` take effect.
        up_t = x.new_empty(bs).uniform_(-self.magnitude,self.magnitude) * mask
        lr_t = x.new_empty(bs).uniform_(-self.magnitude,self.magnitude) * mask
        orig_pts = torch.tensor([[-1,-1], [-1,1], [1,-1], [1,1]], dtype=x.dtype, device=x.device)
        orig_pts = orig_pts.unsqueeze(0).expand(bs,4,2)
        # Displace the four corners by ±up_t / ±lr_t; shape (4, 2, bs) -> (bs, 4, 2).
        targ_pts = stack([stack([-1-up_t, -1-lr_t]), stack([-1+up_t, 1+lr_t]),
                          stack([1+up_t, -1+lr_t]), stack([1-up_t, 1-lr_t])])
        targ_pts = targ_pts.permute(2,0,1)
        self.coeffs = find_coeffs(orig_pts, targ_pts)

    def __call__(self, x):
        "Warp the sampling grid `x`; before the first `randomize` the grid is returned unchanged."
        return x if self.coeffs is None else apply_perspective(x, self.coeffs)

In [None]:
# Rotation (up to ±10 degrees) as the affine part, plus a random perspective warp as a coordinate transform.
tfms = [Cuda(device), AffineAndCoordTfm([RandomRotation(10.)], [RandomWarp(0.2)])]

In [None]:
batch_tfmed = compose(b, tfms)

In [None]:
show_batch(data, batch_tfmed)

### Transforming the target (y) too

In [None]:
class CamvidData(DataBlock):
    "`DataBlock` for CamVid-tiny: images in `images/`, masks in `labels/` named `<stem>_P<suffix>`."
    get_x_cls = ImageGetter
    get_y_cls = SegmentMaskGetter

    def get_source(self):        return untar_data(URLs.CAMVID_TINY)
    def get_items(self, source): return get_image_files(source/'images')
    def split(self, items):      return random_splitter(items)
    def label(self, items):
        path_lbl = self.source/'labels'
        # Keep the entries of codes.txt (presumably the segmentation class names) instead of
        # loading them into a local that was never used.
        self.codes = np.loadtxt(self.source/'codes.txt', dtype=str)
        return func_labeller(items, lambda x: path_lbl/f'{x.stem}_P{x.suffix}')

In [None]:
# Image transforms as for Pets; masks use nearest-neighbour resizing so label values are never
# interpolated, and end as long tensors.
tfms_x = [make_rgb,  ResizeFixed(128), to_byte_tensor, to_float_tensor]
tfms_y = [make_mask, ResizeFixed(128, mode=PIL.Image.NEAREST), to_byte_tensor, to_long_tensor]

In [None]:
data = CamvidData(tfms_x=tfms_x, tfms_y=tfms_y).databunch(bs=16)

In [None]:
data.show_batch()

In [None]:
batch = next(iter(data.train_dl))
# Repeat the first image/mask pair 16 times so the samples differ only by the augmentation.
b = (batch[0][0].unsqueeze(0).expand(16,*batch[0].shape[1:]), batch[1][0].unsqueeze(0).expand(16,*batch[1].shape[1:]))

In [None]:
# tfm_y=True applies the same rotation/warp to the masks; mode='nearest' is used for the images here
# (integer masks are always sampled with 'nearest').
tfms = [Cuda(device), AffineAndCoordTfm([RandomRotation(10.)], [RandomWarp(0.2)], mode='nearest', tfm_y=True)]

In [None]:
batch_tfmed = compose(b, tfms)

In [None]:
show_batch(data, batch_tfmed)