In [None]:
from nb_200 import *

## Get the data

In [None]:
class PetsData(DataBlock):
    "Pets dataset: image files from `images/`, labelled by the name part of the filename (before `_<number>.jpg`)."
    get_x_cls = ImageGetter
    get_y_cls = CategoryGetter

    def get_source(self):        return untar_data(URLs.PETS)
    def get_items(self, source): return get_image_files(source/"images")
    def split(self, items):      return random_splitter(items)
    def label(self, items):
        # Escape the dot so only a literal `.jpg` extension matches (an unescaped `.` matches any char)
        return re_labeller(items, pat = r'/([^/]+)_\d+\.jpg$')

In [None]:
class PlanetData(DataBlock):
    "Planet sample: multi-label images whose file names and space-separated tags are read from `labels.csv`."
    get_x_cls = ImageGetter
    get_y_cls = MultiCategoryGetter

    def get_source(self):
        # Remember the extracted folder: `get_items` needs it to build the image paths
        self.path = untar_data(URLs.PLANET_SAMPLE)
        return pd.read_csv(self.path/'labels.csv')
    def get_items(self, source): return read_column(source, 'image_name', prefix=f'{self.path}/train/', suffix='.jpg')
    def split(self, items):      return random_splitter(items)
    # NOTE(review): assumes DataBlock stores the result of `get_source` as `self.source` -- confirm
    def label(self, items):      return read_column(self.source, 'tags', delim=' ')

In [None]:
class CamvidData(DataBlock):
    "CamVid tiny: images from `images/`, each paired with the mask file `labels/<stem>_P<suffix>`."
    get_x_cls = ImageGetter
    get_y_cls = SegmentMaskGetter

    def get_source(self):        return untar_data(URLs.CAMVID_TINY)
    def get_items(self, source): return get_image_files(source/'images')
    def split(self, items):      return random_splitter(items)
    def label(self, items):
        path_lbl = self.source/'labels'
        # Keep the contents of codes.txt (presumably one class name per mask value) on the
        # instance so later cells can use them, instead of loading and discarding them.
        self.codes = np.loadtxt(self.source/'codes.txt', dtype=str)
        # e.g. images/foo.png -> labels/foo_P.png
        return func_labeller(items, lambda x: path_lbl/f'{x.stem}_P{x.suffix}')

In [None]:
class BiwiData(DataBlock):
    "BIWI sample: images from `images/`, each labelled with the point stored under its file name in `centers.pkl`."
    get_x_cls = ImageGetter
    get_y_cls = PointsGetter

    def get_source(self):        return untar_data(URLs.BIWI_SAMPLE)
    def get_items(self, source): return get_image_files(source/'images')
    def split(self, items):      return random_splitter(items)
    def label(self, items):
        # Context manager so the file handle is closed as soon as the mapping is loaded.
        # NOTE: unpickling can execute arbitrary code -- only load centers.pkl from a trusted download.
        with open(self.source/'centers.pkl', 'rb') as f: fn2ctr = pickle.load(f)
        return func_labeller(items, lambda o:fn2ctr[o.name])

In [None]:
class CocoData(DataBlock):
    "COCO tiny: images from `train/`, labelled with the box annotations that `get_annotations` reads from `train.json`."
    get_x_cls = ImageGetter
    get_y_cls = BBoxGetter

    def get_source(self):
        return untar_data(URLs.COCO_TINY)

    def get_items(self, source):
        return get_image_files(source/'train')

    def split(self, items):
        return random_splitter(items)

    def label(self, items):
        images, lbl_bbox = get_annotations(self.source/'train.json')
        # Index annotations by image file name so each item is a direct lookup
        ann_by_name = {img: ann for img, ann in zip(images, lbl_bbox)}
        return func_labeller(items, lambda o: ann_by_name[o.name])

    def databunch(self, bs=64, **kwargs):
        # Always collate with `bb_pad_collate` (presumably pads the varying number of boxes
        # per image); this overrides any `collate_fn` the caller passed.
        return super().databunch(bs=bs, **{**kwargs, 'collate_fn': bb_pad_collate})

In [None]:
# Item-level pipeline used by every DataBlock below, applied in this order
tfms = [
    DecodeImg(),
    ResizeFixed(128),
    ToByteTensor(),
    ToFloatTensor(),
]

In [None]:
# Build the Pets DataBunch with the item-level `tfms` above (databunch's default arguments)
data = PetsData(tfms=tfms).databunch()

In [None]:
# Visual sanity check of the item pipeline
data.show_batch()

## Data augmentation on the GPU

For now, the batch-level augmentations are written as ordinary transforms; they could easily be moved into callbacks later.

In [None]:
# Run the batch augmentations on the first GPU when one is present; fall back to the CPU
# so the notebook still runs end-to-end (Restart & Run All) on a CPU-only machine.
device = torch.device('cuda',0) if torch.cuda.is_available() else torch.device('cpu')

In [None]:
def to_device(x, device):
    "Recursively move `x` (a tensor, or a nested list/tuple of tensors) to `device`; sequences come back as lists."
    if not isinstance(x, (list, tuple)):
        return x.to(device)
    return [to_device(item, device) for item in x]

In [None]:
class Cuda(Transform):
    "Batch transform that moves the whole batch (inputs and targets alike) onto `device`."
    # Presumably the sort key used when composing transforms -- 0 puts it first
    _order = 0

    def __init__(self, device):
        self.device = device

    def __call__(self, b, tfm_y=TfmY.No):
        # `tfm_y` is accepted for interface compatibility; the device move ignores it
        moved = to_device(b, self.device)
        return moved

In [None]:
import torch.nn.functional as F

In [None]:
class AffineAndCoordTfm(Transform):
    """Batch augmentation combining affine matrices (`aff_tfms`) and coordinate warps (`coord_tfms`).

    Images and masks are resampled in a single `grid_sample` pass; points and bboxes are
    moved with the inverse mapping so they stay aligned with the resampled image.
    `size` may be None (keep the input size), an int (square output) or an (h, w) pair.
    """
    def __init__(self, aff_tfms, coord_tfms, size=None, mode='bilinear', padding_mode='reflection'):
        self.aff_tfms,self.coord_tfms,self.mode,self.padding_mode = aff_tfms,coord_tfms,mode,padding_mode
        self.size = None if size is None else (size,size) if isinstance(size, int) else tuple(size)

    def randomize(self):
        # Draw fresh random parameters for every sub-transform from the current batch
        # (NOTE(review): `self.x` is presumably set by the Transform base -- confirm)
        for t in self.aff_tfms+self.coord_tfms: t.randomize(self.x)

    def _get_affine_mat(self):
        "Product of all non-None affine matrices: one 3x3 matrix per batch item (identity if none)."
        aff_m = torch.eye(3, dtype=self.x.dtype, device=self.x.device)
        aff_m = aff_m.unsqueeze(0).expand(self.x.size(0), 3, 3)
        ms = [tfm() for tfm in self.aff_tfms]
        for m in ms:
            if m is not None: aff_m = aff_m @ m
        return aff_m

    def apply(self, x):
        "Resample the batch `x`; output pixel p samples the input at W(A @ p + t)."
        bs = x.size(0)
        size = tuple(x.shape[-2:]) if self.size is None else self.size
        size = (bs,x.size(1)) + size
        coords = F.affine_grid(self._get_affine_mat()[:,:2], size)
        coords = compose(coords, self.coord_tfms)
        return F.grid_sample(x, coords, mode=self.mode, padding_mode=self.padding_mode)

    def apply_mask(self, y):
        "Resample an integer mask with nearest-neighbour interpolation so no new values are invented."
        self.old_mode,self.mode = self.mode,'nearest'
        # Restore the interpolation mode even if resampling raises
        try: res = self.apply(y.float())
        finally: self.mode = self.old_mode
        return res.long()

    def apply_point(self, y):
        "Move points (in the grid's normalized coordinates) from input space to output space."
        # `apply` samples the input at p_in = W(A @ p_out + t), so a point p_in lands at
        # p_out = A^-1 (W^-1(p_in) - t): undo the coordinate warps first, then the affine part.
        # Warps are undone last-to-first (assumes `compose` applies tfms in list order).
        y = compose(y, self.coord_tfms[::-1], invert=True)
        # A true inverse (not a transpose) so non-orthogonal affines (zoom, shear, translation) work
        inv = torch.inverse(self._get_affine_mat())[:,:2]
        return y @ inv[:,:,:2].transpose(1,2) + inv[:,:,2].unsqueeze(1)

    def apply_bbox(self, y):
        "Transform `(bbox, label)`: move each box's 4 corners, then take their axis-aligned bounding box."
        bbox,label = y
        bs,n = bbox.shape[:2]
        # Corners (b0,b1), (b0,b3), (b2,b1), (b2,b3) -> (bs, n, 4, 2); assumes boxes are (x1, y1, x2, y2)
        # torch.stack is used explicitly so this cell does not rely on a later `from torch import stack`
        pnts = torch.stack([bbox[...,:2], torch.stack([bbox[...,0],bbox[...,3]],dim=2),
                            torch.stack([bbox[...,2],bbox[...,1]],dim=2), bbox[...,2:]], dim=2)
        pnts = self.apply_point(pnts.reshape(bs, 4*n, 2))
        pnts = pnts.reshape(bs, n, 4, 2)
        # NOTE(review): padding boxes added by the collate function are transformed too -- confirm they are masked downstream
        tl,dr = pnts.min(dim=2)[0],pnts.max(dim=2)[0]
        return [torch.cat([tl, dr], dim=2), label]

### Rotate

In [None]:
import math
from torch import stack, zeros_like as t0, ones_like as t1
from torch.distributions.bernoulli import Bernoulli

In [None]:
def rotation_matrix(thetas):
    "Batch of 3x3 homogeneous rotation matrices, shape (bs, 3, 3), for angles `thetas` (bs,) in degrees."
    # Out-of-place conversion: an in-place `mul_` would silently rescale the caller's tensor
    # (and fails outright on integer tensors).
    thetas = thetas * (math.pi/180)
    c,s = thetas.cos(),thetas.sin()
    zero,one = t0(thetas),t1(thetas)
    rows = [stack([c,    s,    zero], dim=1),
            stack([-s,   c,    zero], dim=1),
            stack([zero, zero, one ], dim=1)]
    return stack(rows, dim=1)

In [None]:
class DataAugTfm:
    "Base class for random augmentations: subclasses override `randomize` to draw their parameters from a batch."
    # Presumably the sort key used when composing transforms
    _order = 0

    def randomize(self, x):
        # Default: nothing random to draw
        return None

In [None]:
class RandomRotation(DataAugTfm):
    """Random rotation of up to +/-`degrees`, applied to each batch item with probability `p`.

    Inherits `DataAugTfm` like `RandomWarp`, so it carries the same `_order`.
    `randomize(x)` draws one angle per item of `x`; `__call__()` returns the (bs, 3, 3) matrices.
    """
    def __init__(self, degrees, p=0.5):
        self.mat,self.degrees,self.p = None,degrees,p

    def randomize(self, x):
        # 1 where the item is rotated; 0 gives angle 0, i.e. the identity matrix
        mask = x.new_empty(x.size(0)).bernoulli_(self.p)
        thetas = x.new_empty(x.size(0)).uniform_(-self.degrees,self.degrees)
        self.mat = rotation_matrix(thetas * mask)

    def __call__(self): return self.mat

In [None]:
# Batch transforms: move to `device`, then random rotations of up to +/-30 degrees (no coordinate warps yet)
btfms = [
    Cuda(device),
    AffineAndCoordTfm(aff_tfms=[RandomRotation(30.)], coord_tfms=[]),
]

In [None]:
class BiwiDataCorner(DataBlock):
    """Variant of `BiwiData` where every image gets the same fixed point targets (`corners`).

    Presumably a debugging aid: known coordinates make it easy to check that point
    augmentations stay aligned with the image.
    """
    get_x_cls = ImageGetter
    get_y_cls = PointsGetter
    # Fixed targets given to every image; e.g. add [0,160], [120,160] to track more points
    corners = [[0,0], [120,0]]

    def get_source(self):        return untar_data(URLs.BIWI_SAMPLE)
    def get_items(self, source): return get_image_files(source/'images')
    def split(self, items):      return random_splitter(items)
    def label(self, items):
        # Fresh nested lists per item so labels never share mutable objects
        return func_labeller(items, lambda o: [list(c) for c in self.corners])

In [None]:
# Any DataBlock defined above works here (PetsData, CamvidData, BiwiData, BiwiDataCorner);
# swap the class to see the GPU augmentations on a different kind of target.
data = CocoData(tfms=tfms).databunch(bs=16)

In [None]:
def expand(sample, bs=16):
    "Broadcast a single sample (tensor or nested list/tuple of tensors) into a batch of `bs` identical views."
    if not isinstance(sample, (list, tuple)):
        batched = sample.unsqueeze(0)
        return batched.expand(bs, *sample.shape)
    return [expand(part, bs=bs) for part in sample]

In [None]:
# Take one training batch and turn its first sample into a batch of 16 copies, so every row
# of the augmented batch below starts from the same image.
# NOTE(review): assumes grab_idx(batch, 0) selects sample 0 of both inputs and targets.
batch = next(iter(data.train_dl))
b = expand(grab_idx(batch, 0))

In [None]:
# Apply the batch transforms to the copies; `tfm_y` presumably tells them how to treat the
# target (mask / points / bbox), taken from the dataset's y getter.
batch_tfmed = compose(b, btfms, tfm_y=data.train_ds.y.item_get.default_tfm)

In [None]:
# Show the randomly rotated copies of the same sample
data.show_batch(batch=batch_tfmed)

### Warp

In [None]:
def find_coeffs(p1, p2):
    """Solve for the 8 perspective coefficients that send the points `p2` onto `p1`, per batch item.

    `p1`, `p2`: (bs, 4, 2) -- exactly 4 point pairs so the system is 8x8.
    Returns (bs, 8, 1) coefficients (a, b, c, d, e, f, g, h) such that for every pair
        x1 = (a*x2 + b*y2 + c) / (g*x2 + h*y2 + 1)
        y1 = (d*x2 + e*y2 + f) / (g*x2 + h*y2 + 1)
    """
    ref = p1[:,0,0]
    zero,one = t0(ref),t1(ref)
    rows = []
    # Each point pair yields two linear equations in the 8 unknowns
    for i in range(p1.shape[1]):
        x1,y1,x2,y2 = p1[:,i,0],p1[:,i,1],p2[:,i,0],p2[:,i,1]
        rows.append(stack([x2, y2, one, zero, zero, zero, -x1*x2, -x1*y2]))
        rows.append(stack([zero, zero, zero, x2, y2, one, -y1*x2, -y1*y2]))
    A = stack(rows).permute(2, 0, 1)          # (bs, 8, 8)
    # reshape (not view): p1 may be an expanded/permuted, non-contiguous tensor
    B = p1.reshape(p1.shape[0], 8, 1)         # (bs, 8, 1): x1, y1 interleaved per point
    # `torch.solve` is deprecated/removed in recent PyTorch; use `torch.linalg.solve` when available
    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'solve'):
        return torch.linalg.solve(A, B)
    return torch.solve(B, A)[0]

In [None]:
def apply_perspective(coords, coeffs):
    """Apply per-item perspective transforms to `coords`.

    `coords`: (bs, ..., 2) points or sampling grid (any memory layout).
    `coeffs`: (bs, 8, 1) as returned by `find_coeffs`.
    Returns the transformed coordinates with the same shape as `coords`.
    """
    sz = coords.shape
    # reshape (not view) so non-contiguous grids/points are accepted
    coords = coords.reshape(sz[0], -1, 2)
    # Append the fixed 9th coefficient (1) to form one 3x3 matrix H per batch item
    H = torch.cat([coeffs, t1(coeffs[:,:1])], dim=1).view(coeffs.shape[0], 3, 3)
    # Homogeneous transform of each (x, y, 1): H @ [x, y, 1]^T, written for row vectors
    coords = coords @ H[...,:2].transpose(1,2) + H[...,2].unsqueeze(1)
    # Perspective divide by the third homogeneous coordinate (out of place)
    coords = coords[...,:2] / coords[...,2:]
    return coords.reshape(sz)

In [None]:
class RandomWarp(DataAugTfm):
    """Random perspective warp of the coordinate grid, applied to each batch item with probability `p`.

    The corners of the [-1, 1] square are shifted by amounts drawn uniformly from
    [-magnitude, magnitude]; items not selected get zero shifts, i.e. the identity warp.
    """
    def __init__(self, magnitude, p=0.5):
        self.coeffs,self.magnitude,self.p = None,magnitude,p

    def randomize(self, x):
        bs = x.size(0)
        # 1 where the item is warped; multiplying the shifts by it makes unselected items identity
        mask = x.new_empty(bs).bernoulli_(self.p)
        up_t = x.new_empty(bs).uniform_(-self.magnitude,self.magnitude) * mask
        lr_t = x.new_empty(bs).uniform_(-self.magnitude,self.magnitude) * mask
        orig_pts = torch.tensor([[-1,-1], [-1,1], [1,-1], [1,1]], dtype=x.dtype, device=x.device)
        self.orig_pts = orig_pts.unsqueeze(0).expand(bs,4,2)
        # Each inner stack is (2, bs); the 4 corners give (4, 2, bs) -> permuted to (bs, 4, 2)
        targ_pts = stack([stack([-1-up_t, -1-lr_t]), stack([-1+up_t, 1+lr_t]),
                          stack([1+up_t, -1+lr_t]), stack([1-up_t, 1-lr_t])])
        self.targ_pts = targ_pts.permute(2,0,1)

    def __call__(self, x, invert=False):
        # find_coeffs(p1, p2) gives the map sending p2 onto p1; `invert=True` swaps the
        # roles so the returned coordinates undo the forward warp.
        coeffs = find_coeffs(self.targ_pts, self.orig_pts) if invert else find_coeffs(self.orig_pts, self.targ_pts)
        return apply_perspective(x, coeffs)

In [None]:
# Batch transforms: move to `device`, rotate by up to +/-10 degrees, then a random perspective warp (magnitude 0.2)
btfms = [
    Cuda(device),
    AffineAndCoordTfm(aff_tfms=[RandomRotation(10.)], coord_tfms=[RandomWarp(0.2)]),
]

In [None]:
# Re-apply the (now rotation + warp) batch transforms to the same 16 copies; `tfm_y`
# presumably tells them how to treat the target, taken from the dataset's y getter.
batch_tfmed = compose(b, btfms, tfm_y=data.train_ds.y.item_get.default_tfm)

In [None]:
# Show the rotated and warped copies of the same sample
data.show_batch(batch=batch_tfmed)