In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from nb_001b import *
from PIL import Image
import PIL, matplotlib.pyplot as plt
from torch.utils.data import Dataset
from operator import itemgetter, attrgetter

# Temp storage

In [None]:
def is_listy(x):
    "Return True when `x` is a `list` or a `tuple` (subclass instances included)."
    return isinstance(x, list) or isinstance(x, tuple)

# Carvana

In [None]:
# Root folder of the Carvana data; every other location below hangs off it.
PATH = Path('data/carvana')
# Destination of the PNG copies of the masks (written by convert_img).
PATH_PNG = PATH/'train_masks_png'
# Destination of the 128x128 training images (written by resize_img with dirname='train-128').
PATH_X = PATH/'train-128'
# Destination of the 128x128 masks (written by resize_img with dirname='train_masks-128').
PATH_Y = PATH/'train_masks-128'

## Convert and resize data

In [None]:
# Create the three output folders. exist_ok=True keeps the cell re-runnable.
# NOTE(review): no parents=True is passed, so (assuming Path is pathlib.Path) PATH itself
# must already exist — confirm the dataset has been downloaded there first.
PATH_PNG.mkdir(exist_ok=True)
PATH_X.mkdir(exist_ok=True)
PATH_Y.mkdir(exist_ok=True)

In [None]:
def convert_img(fn):
    """Re-save the image at path `fn` as a PNG with the same base name inside `PATH_PNG`.

    `fn` must be a path object (its `.stem` is used to build the target name), so the
    source extension may have any length.
    """
    # The context manager closes the source file as soon as the copy is written instead of
    # leaving the handle open until garbage collection (this runs in 8 worker threads).
    with Image.open(fn) as img: img.save(PATH_PNG/f'{fn.stem}.png')

In [None]:
# Convert every file found in PATH/'train_masks' to PNG, using 8 worker threads.
# NOTE(review): the iterator returned by e.map is never consumed, so (assuming this is
# concurrent.futures.ThreadPoolExecutor) an exception raised inside convert_img would not
# be re-raised here — confirm that silently skipping failed files is intended.
files = list((PATH/'train_masks').iterdir())
with ThreadPoolExecutor(8) as e: e.map(convert_img, files)

In [None]:
def resize_img(fn, dirname, size=(128,128)):
    """Resize the image at path `fn` to `size` and save it under a sibling folder.

    The result keeps the file name of `fn` and is written to `<grandparent of fn>/dirname/`,
    i.e. `dirname` is created next to the folder that contains `fn`. `size` is the
    `(width, height)` passed to `Image.resize`; it defaults to the original 128x128.
    """
    dest = fn.parent.parent/dirname/fn.name
    # Close the source file as soon as the resized copy is written rather than waiting for
    # garbage collection (this runs in 8 worker threads).
    with Image.open(fn) as img: img.resize(size).save(dest)

In [None]:
# Resize every PNG mask in PATH_PNG into PATH/'train_masks-128', using 8 worker threads.
# (The original line had unbalanced parentheses and could not be parsed.)
files = list(PATH_PNG.iterdir())
with ThreadPoolExecutor(8) as e: e.map(partial(resize_img, dirname='train_masks-128'), files)

In [None]:
# Resize every file found in PATH/'train' into PATH/'train-128', using 8 worker threads.
# NOTE(review): the iterator returned by e.map is never consumed, so (assuming this is
# concurrent.futures.ThreadPoolExecutor) an exception raised inside resize_img would not
# be re-raised here — confirm that silently skipping failed files is intended.
files = list((PATH/'train').iterdir())
with ThreadPoolExecutor(8) as e: e.map(partial(resize_img, dirname='train-128'), files)

## Basic transforms

In [None]:
# Use the first entry yielded by PATH_X.iterdir() as the sample input image and display it.
# NOTE(review): which file comes first presumably depends on the file system — confirm if a
# specific sample is expected.
img_f = next(PATH_X.iterdir())
img_x = open_image(img_f)
show_image(img_x)

In [None]:
def get_y_fn(x_fn):
    "Build the mask file name for image file name `x_fn`: drop its last 4 characters and append `_mask.png`."
    stem = x_fn[:-4]
    return f'{stem}_mask.png'

In [None]:
# Locate the mask that belongs to the sample image (same name with '_mask.png' in place of
# the last 4 characters, looked up in PATH_Y), open it and display it.
img_y_f = PATH_Y/get_y_fn(img_f.name)
img_y = open_image(img_y_f)
show_image(img_y)

In [None]:
def x():
    "Open the sample input image at `img_f` again and return it."
    return open_image(img_f)

def y():
    "Open the sample mask at `img_y_f` again and return it."
    return open_image(img_y_f)

In [None]:
# Transform list used by the two demo cells that follow: flip, rotation, zoom, contrast and
# brightness. NOTE(review): `p` is presumably the probability of applying each transform and
# the tuples the ranges the parameter is drawn from — confirm against nb_001b.
tfms = [flip_lr_tfm(p=0.5),
        rotate_tfm(degrees=(-10,10.), p=0.25),
        zoom_tfm(scale=(0.8,1.2), p=0.25),
        contrast_tfm(scale=(0.8,1.2)),
        brightness_tfm(change=(0.4,0.6))
]

In [None]:
# Draw the sample input four times, each freshly opened and passed through `tfms`.
# NOTE(review): the call order here is (image, tfms), while the apply_pipeline defined later
# in this notebook takes (tfms, x, ...). Presumably this cell relies on the version imported
# from nb_001b; re-running it after the later definition would pass the arguments swapped.
_,axes = plt.subplots(1,4, figsize=(12,3))
for ax in axes: show_image(apply_pipeline(x(), tfms), ax)

In [None]:
# Draw the sample mask four times, each freshly opened and passed through `tfms`.
# NOTE(review): the call order here is (image, tfms), while the apply_pipeline defined later
# in this notebook takes (tfms, x, ...). Presumably this cell relies on the version imported
# from nb_001b; re-running it after the later definition would pass the arguments swapped.
_,axes = plt.subplots(1,4, figsize=(12,3))
for ax in axes: show_image(apply_pipeline(y(), tfms), ax)

# Dependent var transforms

## Rotation

In [None]:
def xy():
    "Re-open both sample images and return them as an `(input, mask)` pair."
    img, mask = x(), y()
    return img, mask

In [None]:
# Display what resolve_args produces for `brightness` with change=(0.4,0.6); the value is
# shown as the cell output and not stored.
resolve_args(brightness, change=(0.4,0.6))

In [None]:
def rotate_rand(x, y=None, smooth_y=True):
    """Apply one rotation matrix to `x` and, when given, to `y`.

    The rotation arguments come from `resolve_args(rotate, degrees=(-45,45.))` (presumably a
    random angle in that range — confirm in nb_001b). Returns the transformed `x` alone when
    `y` is None, otherwise `(x, y)`; with `smooth_y=False` the transformed `y` is rounded in
    place so it keeps only whole values.
    """
    rot_kwargs = resolve_args(rotate, degrees=(-45,45.))
    mat = rotate(**rot_kwargs)
    # NOTE(review): `mat` is passed as the second positional argument. The do_affine defined
    # later in this notebook takes `img_y` in that position, so this call presumably targets
    # the do_affine imported from nb_001b — confirm before re-running out of order.
    x = do_affine(x, mat)
    if y is not None:
        # Same matrix for the target so input and mask stay aligned.
        y = do_affine(y, mat)
        if not smooth_y: torch.round_(y)
        return x, y
    return x

In [None]:
# With smooth_y=False the transformed mask is rounded, so no value may lie strictly
# between 0 and 1.
imgx,imgy = rotate_rand(*xy(), smooth_y=False)
assert(torch.any((imgy>0.) & (imgy<1.)) == 0)

In [None]:
# Four rotate_rand draws on the (input, mask) pair with a rounded mask:
# inputs on the top row, the matching masks on the bottom row.
_,axes = plt.subplots(2,4, figsize=(12,6))
for i in range(4):
    imgx,imgy = rotate_rand(*xy(), smooth_y=False)
    show_image(imgx, axes[0][i])
    show_image(imgy, axes[1][i])

In [None]:
# Pass the input image as both x and y (default smooth_y=True, so no rounding): the two
# rows show the result returned for each of the two arguments.
_,axes = plt.subplots(2,4, figsize=(12,6))
for i in range(4):
    imgx,imgy = rotate_rand(x(),x())
    show_image(imgx, axes[0][i])
    show_image(imgy, axes[1][i])

In [None]:
# Call rotate_rand without a target: it then returns the transformed input alone.
_,axes = plt.subplots(1,4, figsize=(12,6))
for ax in axes: show_image(rotate_rand(x()), ax)

## Affine transforms

In [None]:
def do_affine(img_x, img_y=None, m=None, funcs=None, smooth_y=True):
    """Resample `img_x` (and `img_y`, when given) on one shared coordinate grid.

    The grid is built by `affine_grid` from matrix `m` (`eye_new(img_x, 3)` when `m` is None)
    and then passed through `compose(funcs)`. Both images are sampled with
    `padding='zeros'`. Returns the resampled `img_x` alone when `img_y` is None, otherwise
    `(img_x, img_y)`; with `smooth_y=False` the resampled `img_y` is rounded in place.
    NOTE(review): `funcs` defaults to None and is handed to `compose` as is — presumably the
    one-argument `compose` from nb_001b accepts that; confirm.
    """
    if m is None:
        m = eye_new(img_x, 3)
    grid = affine_grid(img_x,  img_x.new_tensor(m))
    coords = compose(funcs)(grid)
    img_x = grid_sample(img_x, coords, padding='zeros')
    if img_y is not None:
        # Same coordinates for the target so input and target stay aligned.
        img_y = grid_sample(img_y, coords, padding='zeros')
        if not smooth_y:
            torch.round_(img_y)
        return img_x, img_y
    return img_x

In [None]:
def apply_pixel_tfm(func):
    """Wrap `func` so that it runs between `logit_` and `sigmoid`.

    The returned callable takes `x` and an optional `y`. It calls `logit_` on each argument
    it was given (the trailing underscore suggests an in-place operation — confirm in
    nb_001b), runs `func(x)` or `func(x, y)`, and applies `.sigmoid()` to every result.
    It returns a single value without `y` and an `(x, y)` pair with it.
    """
    def _inner(x, y=None):
        logit_(x)
        if y is not None:
            logit_(y)
            out_x, out_y = func(x, y)
            return out_x.sigmoid(), out_y.sigmoid()
        return func(x).sigmoid()

    return _inner

In [None]:
def apply_pipeline(tfms, x, y=None, smooth_y=True):
    """Apply the transforms in `tfms` to `x` and, when given, to the target `y`.

    `tfms` (a single transform or a list) is grouped by each transform's `return`
    annotation. Pixel transforms run first through `apply_pixel_tfm`; the affine transforms
    are then called to produce matrices, which `affines_mat` combines for a single
    `do_affine` call that also receives the coordinate transforms as `funcs`.
    Returns the transformed `x` when `y` is None, otherwise `(x, y)`. `smooth_y` is
    forwarded to `do_affine`.
    """
    tfms = listify(tfms)
    # Nothing to apply: still honour the "pair in, pair out" contract so callers that unpack
    # `(x, y)` do not break on an empty transform list.
    if len(tfms)==0: return x if y is None else (x, y)
    grouped_tfms = dict_groupby(tfms, lambda o: o.__annotations__['return'])
    # NOTE(review): this unpacking needs TfmType to iterate exactly three members in
    # pixel, coord, affine order — presumably the TfmType imported from nb_001b. The 4-member
    # TfmType defined further down in this notebook would make this line fail; confirm.
    pixel_tfms,coord_tfms,affine_tfms = map(grouped_tfms.get, TfmType)
    x = apply_pixel_tfm(compose(pixel_tfms))(x,y)
    # With a target the pixel step hands back a pair; split it again.
    if isinstance(x,tuple): x,y = x
    matrices = [f() for f in listify(affine_tfms)]
    return do_affine(x, y, affines_mat(x, matrices), funcs=coord_tfms, smooth_y=smooth_y)

In [None]:
# Transform list for the apply_pipeline demos below: one rotation and one brightness
# transform (this rebinds the `tfms` name used earlier in the notebook).
tfms = [rotate_tfm(degrees=(-45,45.)), brightness_tfm(change=(0.3,0.7))]

In [None]:
# smooth_y=False is forwarded to do_affine, which rounds the resampled mask, so no value
# may lie strictly between 0 and 1.
imgx,imgy = apply_pipeline(tfms, *xy(), smooth_y=False)
assert(torch.any((imgy>0.) & (imgy<1.)) == 0)

In [None]:
# Four apply_pipeline draws on the (input, mask) pair with a rounded mask:
# inputs on the top row, the matching masks on the bottom row.
_,axes = plt.subplots(2,4, figsize=(12,6))
for i in range(4):
    imgx,imgy = apply_pipeline(tfms, *xy(), smooth_y=False)
    show_image(imgx, axes[0][i])
    show_image(imgy, axes[1][i])

In [None]:
# Pass the input image as both x and y (default smooth_y=True, so no rounding): the two
# rows show the result returned for each of the two arguments.
_,axes = plt.subplots(2,4, figsize=(12,6))
for i in range(4):
    imgx,imgy = apply_pipeline(tfms, x(),x())
    show_image(imgx, axes[0][i])
    show_image(imgy, axes[1][i])

In [None]:
# Call apply_pipeline without a target and show the four results.
_,axes = plt.subplots(1,4, figsize=(12,6))
for ax in axes: show_image(apply_pipeline(tfms, x()), ax)

In [None]:
# Same demo with a single jitter transform, called without a target.
tfms2 = [jitter_tfm(magnitude=(-0.1,0.1))]

_,axes = plt.subplots(1,4, figsize=(12,6))
for ax in axes: show_image(apply_pipeline(tfms2, x()), ax)

# Extending the training loop

## LRFind

In [None]:
import collections

In [None]:
# Folder for saved model weights, created under the dataset root if it is missing.
MODEL_PATH = PATH/'models'
MODEL_PATH.mkdir(exist_ok=True)

# File name constant for a temporary checkpoint (not used by the cells visible here).
TEMP_MODEL_NAME = 'tmp.pt'

In [None]:
def save_model(model, fname):
    "Write `model.state_dict()` to `fname` with `torch.save`."
    weights = model.state_dict()
    torch.save(weights, fname)
def load_model(model, fname):
    "Read a state dict from `fname` with `torch.load` and load it into `model` in place."
    weights = torch.load(fname)
    model.load_state_dict(weights)

# Sylvain's transforms

## Add transforms

In [None]:
from enum import IntEnum

class TfmType(IntEnum):
    """How a `Transform` treats the target `y` (see `Transform.__call__`)."""
    # Only x is transformed; y is handed back untouched.
    NO = 1
    # x and y both go through `transform`.
    PIXEL = 2
    # x and y go through `transform_coord`.
    COORD = 3
    # Dispatched like PIXEL: x and y both go through `transform`.
    CLASS = 4

In [None]:
from abc import abstractmethod

class Transform():
    """Base class for a transform applied to an input `x` and, depending on `tfm_y`, its target `y`.

    `tfm_y` (a `TfmType`) selects how `__call__` treats `y`; `p` is the probability that the
    transform is applied on a given call (`p == 1` means always); `batch_lvl` marks the
    transform as a batch-level one (see `split_one_tfms`). Subclasses implement
    `do_transform`.
    """

    def __init__(self, tfm_y=TfmType.NO, p=1, batch_lvl = False):
        self.tfm_y,self.p,self.batch_lvl = tfm_y,p,batch_lvl

    def __call__(self, x, y):
        "Apply the transform according to `self.tfm_y`; always returns an `(x, y)` pair."
        if self.tfm_y == TfmType.NO: x = self.transform(x)
        elif self.tfm_y in (TfmType.PIXEL, TfmType.CLASS): x,y = self.transform(x, y)
        else: x,y = self.transform_coord(x, y)
        return x, y

    def set_device(self, device):
        "Remember `device` on the instance, but only when the transform is not batch-level."
        if not self.batch_lvl: self.device = device

    def _should_apply(self):
        "Decide whether to apply the transform on this call; no random number is drawn when `p == 1`."
        # `np.random.rand` has to be *called*: comparing the function object itself with
        # `self.p` raises TypeError on Python 3.
        return self.p == 1 or np.random.rand() < self.p

    def transform_coord(self, x, y):
        """Default coordinate handling: transform `x` and pass `y` through unchanged.

        Always returns an `(x, y)` pair. The probability check is left to `transform`, so it
        happens exactly once per call.
        """
        return self.transform(x), y

    def transform(self, x, y=None):
        """Apply `do_transform` to `x` (and to `y` when given) with probability `p`.

        Returns `x` alone when `y` is None and an `(x, y)` pair otherwise, whether or not the
        transform was actually applied.
        """
        if not self._should_apply():
            return x if y is None else (x, y)
        x = self.do_transform(x,False)
        return (x, self.do_transform(y,True)) if y is not None else x

    @abstractmethod
    def do_transform(self, x, is_y): raise NotImplementedError
    # In do_transform we can save a value (the angle of a random rotation, for instance) in
    # self.save_for_y; it will then be used when is_y is True.

In [None]:
class ChannelOrder(Transform):
    """Move axis 2 of the input to the front with `np.rollaxis(x, 2)`.

    Presumably this turns a height x width x channel array into channel-first order —
    confirm against the loader. The target is only reordered when `tfm_y` is
    `TfmType.PIXEL`.
    """
    # If we use PIL for data augmentation, maybe the conversion to a numpy array should be
    # handled here?
    def __init__(self, tfm_y=TfmType.NO):
        super().__init__(tfm_y=tfm_y)

    def do_transform(self, x, is_y):
        "Return `x` with axis 2 rolled to the front, or `x` unchanged for a non-PIXEL target."
        leave_as_is = is_y and not self.tfm_y == TfmType.PIXEL
        return x if leave_as_is else np.rollaxis(x, 2)

In [None]:
class Normalize(Transform):
    """Batch-level transform computing `(x - means) / stds`.

    `means` and `stds` are indexed as `[None,:,None,None]`, i.e. they are 1-d and broadcast
    along the second axis of `x` (presumably the channel axis of a batch x channel x height x
    width tensor — confirm). The target is only normalized when `tfm_y` is `TfmType.PIXEL`.
    """

    def __init__(self, means, stds, tfm_y=TfmType.NO):
        self.means,self.stds = means,stds
        super().__init__(tfm_y=tfm_y, batch_lvl=True)

    def set_device(self, device):
        "Turn `means`/`stds` into tensors on `device` unless `means` is already a tensor living there."
        super().set_device(device)
        # The device is only inspected once `means` is known to be a tensor.
        needs_move = type(self.means) != torch.Tensor or not self.means.device == device
        if needs_move:
            self.means, self.stds = (torch.Tensor(stat).to(device) for stat in (self.means, self.stds))

    def do_transform(self, x, is_y):
        "Return the normalized `x`, or `x` unchanged for a non-PIXEL target."
        if is_y and not self.tfm_y == TfmType.PIXEL: return x
        # Cast the statistics to the type of `x` before broadcasting.
        mean = self.means[None,:,None,None].type_as(x)
        std = self.stds[None,:,None,None].type_as(x)
        return (x - mean) / std

In [None]:
def compose(tfms, x, y):
    """Feed `(x, y)` through every callable in `tfms`, in order, and return the final pair.

    Each callable receives the current `x` and `y` and must return a new two-item result.
    NOTE(review): this three-argument definition replaces the one-argument `compose` that
    earlier cells call as `compose(funcs)(c)` — re-running those cells afterwards would
    presumably fail; confirm whether a distinct name is wanted.
    """
    pair = x, y
    for step in tfms:
        new_x, new_y = step(*pair)
        pair = new_x, new_y
    return pair

In [None]:
def split_one_tfms(tfms):
    """Split `tfms` by their `batch_lvl` flag.

    Returns `(ds_tfms, dl_tfms)`: first the transforms whose `batch_lvl` is falsy, then those
    whose `batch_lvl` is truthy, each in their original order. `tfms` is traversed twice.
    """
    ds_tfms = list(filter(lambda t: not t.batch_lvl, tfms))
    dl_tfms = list(filter(lambda t: t.batch_lvl, tfms))
    return ds_tfms, dl_tfms
    
def split_tfms(trn_tfms, val_tfms):
    """Split the training and validation transform lists with `split_one_tfms`.

    Returns four lists in this order: training non-batch-level, validation non-batch-level,
    training batch-level, validation batch-level.
    """
    trn_ds, trn_dl = split_one_tfms(trn_tfms)
    val_ds, val_dl = split_one_tfms(val_tfms)
    return trn_ds, val_ds, trn_dl, val_dl