In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#export
from nb_004c import *

from functools import wraps

# Carvana

## Setup

(See final section of notebook for one-time data processing steps.)

In [None]:
# Root folder of the Carvana data; the sub-folders below are resolved relative to it.
PATH = Path('data/carvana')
PATH_PNG = PATH/'train_masks_png'
# Folders this notebook lists input images (x) from and builds mask (y) paths in.
PATH_X = PATH/'train-128'
PATH_Y = PATH/'train_masks-128'

In [None]:
# Use whichever entry `PATH_X.iterdir()` yields first as the sample image and display it.
img_f = next(PATH_X.iterdir())
x = open_image(img_f)
show_image(x)

In [None]:
def get_y_fn(x_fn):
    """Return the mask path inside `PATH_Y` that belongs to the image path `x_fn`.

    The mask file name is the image file name without its final extension,
    followed by `_mask.png`. `Path.stem` is used instead of slicing off four
    characters so that extensions of any length (e.g. `.jpeg`) are handled.
    """
    return PATH_Y/f'{x_fn.stem}_mask.png'

In [None]:
#export
def pil2tensor(image, as_mask=False):
    """Turn `image` into a float tensor with the channel axis first.

    `image` must provide `tobytes()` and a `size` whose first entry becomes the
    last (column) axis and whose second entry becomes the middle (row) axis.
    The raw bytes are read as unsigned 8-bit values. When `as_mask` is falsy the
    values are divided by 255 in place; otherwise they are returned unscaled.
    """
    raw = torch.ByteStorage.from_buffer(image.tobytes())
    n_cols, n_rows = image.size[0], image.size[1]
    # Bytes arrive row by row with the channels interleaved per pixel.
    pixels = torch.ByteTensor(raw).view(n_rows, n_cols, -1)
    chw = pixels.permute(2,0,1).float()
    if as_mask: return chw
    return chw.div_(255)

def open_image(fn, as_mask=False):
    """Open the image file `fn` with PIL and convert it with `pil2tensor`.

    Unless `as_mask` is set, the image is first converted to RGB; a mask is
    passed on in whatever mode the file was stored in. `as_mask` is forwarded to
    `pil2tensor` unchanged.
    """
    img = PIL.Image.open(fn)
    if as_mask: return pil2tensor(img, as_mask=as_mask)
    return pil2tensor(img.convert('RGB'), as_mask=as_mask)

def image2np(image):
    """Convert a channels-first tensor into a numpy array with the channel axis last.

    The tensor is moved to the CPU first. If there is exactly one channel, the
    channel axis is dropped and a 2-D array is returned.
    """
    hwc = image.cpu().permute(1,2,0).numpy()
    if hwc.shape[2] == 1: return hwc[...,0]
    return hwc

def show_image(img, ax=None, figsize=(3,3), hide_axis=True, cmap='binary', alpha=None):
    """Draw `img` (converted with `image2np`) on `ax` and return that axes.

    A new figure of `figsize` is created when `ax` is None. `cmap` and `alpha`
    are passed to `imshow`; `hide_axis` turns the axis decorations off.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    pixels = image2np(img)
    ax.imshow(pixels, cmap=cmap, alpha=alpha)
    if hide_axis:
        ax.axis('off')
    return ax
        
def show_xy_image(xim, yim, ax=None, figsize=(3,3), alpha=0.5, hide_axis=True, cmap='RdGy_r'):
    """Draw `xim` and overlay `yim` on the same axes with transparency `alpha`.

    A new figure of `figsize` is created when `ax` is None. Both layers are
    drawn with `show_image`, which also hides the axis when `hide_axis` is set.
    Returns None.
    """
    # Test against None explicitly (as `show_image` does) rather than relying on
    # the truthiness of whatever object was passed in.
    if ax is None: _,ax = plt.subplots(figsize=figsize)
    ax1 = show_image(xim, ax=ax, hide_axis=hide_axis, cmap=cmap)
    show_image(yim, ax=ax1, alpha=alpha, hide_axis=hide_axis, cmap=cmap)
        
def show_xy_images(x,y,rows,figsize=(9,9)):
    """Show the first `rows*rows` (x, y) pairs as overlays in a `rows` x `rows` grid.

    `x` and `y` are indexed by position. If fewer pairs than grid cells are
    available, the leftover cells are blanked instead of raising an IndexError.
    """
    # squeeze=False keeps `axs` a 2-D array even for rows == 1, where
    # plt.subplots would otherwise return a bare Axes without `.flatten()`.
    fig, axs = plt.subplots(rows,rows,figsize=figsize, squeeze=False)
    n_pairs = min(len(x), len(y))
    for i, ax in enumerate(axs.flatten()):
        if i < n_pairs: show_xy_image(x[i], y[i], ax)
        else: ax.axis('off')
    plt.tight_layout()


In [None]:
# Load the mask that belongs to the sample image; `as_mask=True` skips the RGB
# conversion and the division by 255 (see `open_image` / `pil2tensor`).
img_y_f = get_y_fn(img_f)
y = open_image(img_y_f, as_mask=True)
show_image(y)

In [None]:
show_xy_image(x,y)

In [None]:
# Image-only augmentation list used for the preview a few cells below.
# `zoom_crop(...)` is unpacked with `*`, so it contributes several entries.
tfms = [
    rotate(degrees=(-20,20.), p=0.75),
    zoom(scale=(0.5,2), p=0.75),
    contrast(scale=(0.6,1.4)),
    brightness(change=(0.3,0.7)),
    *zoom_crop(scale=(1.,2.), p=0.5, do_rand=True)

]

In [None]:
tfms

In [None]:
# Apply `tfms` to the sample image four times and show the results side by side.
_,axes = plt.subplots(1,4, figsize=(12,3))
for ax in axes: show_image(apply_tfms(tfms,x,size=128), ax)

## xy transforms

- data types: regression, classification, segmentation, bounding box, polygon, generative (super-resolution, colorization), custom

In [None]:
#export

def skip_ys(*args):
    """Mark every transform in `args` as x-only by setting its `_y_tfm` to None.

    The transforms are modified in place and the same `args` tuple is returned,
    so the call can be unpacked straight into a transform list.
    """
    for x_only in args: setattr(x_only, '_y_tfm', None)
    return args

def kw_override(tfm, override):
    """Attach to `tfm` a y-variant whose `kwargs` are `override`, and return `tfm`.

    The variant is a shallow copy of `tfm` taken before `_y_tfm` is set on it;
    its `kwargs` attribute is replaced by `override` (not merged with the
    original kwargs). `tfm` is modified in place.
    """
    y_variant = copy(tfm)
    y_variant.kwargs = override
    setattr(tfm, '_y_tfm', y_variant)
    return tfm

def force_nearest(tfm):
    """Give `tfm` a y-variant with kwargs `{'mode': 'nearest'}` if it wraps a `TfmAffine`.

    Transforms whose `.tfm` is not a `TfmAffine` are returned untouched.
    """
    if isinstance(tfm.tfm,TfmAffine):
        return kw_override(tfm, {'mode': 'nearest'})
    return tfm
        
def affine_y_mode_nearest(*args):
    """Run `force_nearest` over every transform in `args` and return the results as a list."""
    return list(map(force_nearest, args))

In [None]:
# Augmentations for (image, mask) pairs. `contrast` and `brightness` go through
# `skip_ys`, which sets their `_y_tfm` to None so they are dropped for the mask.
xy_tfms = [
    flip_lr(p=0.5),
    rotate(degrees=(-20,20.), p=0.75),
    zoom(scale=(0.5,2), p=0.75),
    *skip_ys(
        contrast(scale=(0.6,1.4)),
        brightness(change=(0.3,0.7))
    ),
    *zoom_crop(scale=(1.,2.), p=0.5, do_rand=True)
]

In [None]:
#export

def choose_tfm(tfm, do_y):
    """Pick the transform to use for x (`do_y` falsy) or for y (`do_y` truthy).

    For y, the transform's `_y_tfm` attribute is returned when it exists (this
    may be None, meaning "skip for y"); otherwise `tfm` itself is returned.
    """
    if not do_y: return tfm
    return getattr(tfm, '_y_tfm', tfm)
        
        
def resolve_tfms_xy(tfms):
    """Resolve `tfms`, then resolve each attached y-variant and merge in the x values.

    After `resolve_tfms` has run on the list, every transform with a truthy
    `_y_tfm` has that variant resolved too; the variant's `resolved` dict is then
    rebuilt from the x transform's `resolved` with the variant's own entries
    taking precedence.
    """
    tfm_list = listify(tfms)
    resolve_tfms(tfm_list)
    for x_tfm in tfm_list:
        y_tfm = getattr(x_tfm, '_y_tfm', None)
        if not y_tfm: continue
        y_tfm.resolve()
        y_tfm.resolved = {**x_tfm.resolved, **y_tfm.resolved}
        
def apply_tfms(tfms, x, do_resolve=True, xtra=None, size=None,
               mult=32, do_crop=True, padding_mode='reflect', do_y=False, **kwargs):
    """Apply the transforms `tfms` to tensor `x` and return the resulting pixels (`.px`).

    - `tfms`: one transform or a list of them; if falsy, `x` is returned untouched
      (no clone, no resize).
    - `do_resolve`: when true, `resolve_tfms_xy` is called on the selected transforms.
    - `xtra`: optional dict keyed by a transform's `.tfm`; its value is passed as
      extra keyword arguments when that transform is called.
    - `size`, `mult`, `do_crop`: when `size` is truthy the image is resized to the
      target computed by `get_crop_target` / `get_resize_target`.
    - `padding_mode`: forwarded to `set_sample` and to `TfmCrop` transforms.
    - `do_y`: treat `x` as a target: use each transform's `_y_tfm` (dropping those
      set to None) and sample with mode 'nearest' instead of 'bilinear'.
    - `kwargs`: forwarded to `Image.set_sample`.
    """
    if not tfms: return x
    if not xtra: xtra={}
    # Order by the wrapped transform's `order` before swapping in y-variants.
    tfms = sorted(listify(tfms), key=lambda o: o.tfm.order)
    tfms = [choose_tfm(tfm, do_y) for tfm in tfms]
    # A None here means the transform was marked to be skipped for y.
    tfms = [tfm for tfm in tfms if not tfm is None]
    
    if do_resolve: 
        resolve_tfms_xy(tfms)
    # Work on a clone so the caller's tensor is not passed to the transforms.
    x = Image(x.clone())
    mode = 'nearest' if do_y else 'bilinear'
    x.set_sample(padding_mode=padding_mode, mode=mode, **kwargs)
    if size:
        crop_target = get_crop_target(size, mult=mult)
        target = get_resize_target(x, crop_target, do_crop=do_crop)
        x.resize(target)

    # Transforms wrapping a TfmCrop additionally receive `size` and `padding_mode`.
    size_tfms = [o for o in tfms if isinstance(o.tfm,TfmCrop)]
    for tfm in tfms:
        if tfm.tfm in xtra: x = tfm(x, **xtra[tfm.tfm])
        elif tfm in size_tfms: x = tfm(x, size=size, padding_mode=padding_mode)
        else: x = tfm(x)
    return x.px


def build_xy_tfm(tfms):
    """Build a function that applies `tfms` to an `(x, y)` pair via `apply_tfms`.

    The returned function takes the pair plus keyword arguments for
    `apply_tfms`. `x` is transformed first with those arguments. `y` is then
    transformed with `do_y=True` and with `do_resolve` set to the optional
    `do_resolve_y` keyword (default False), which is removed from the keyword
    arguments before either call.
    """
    def xy_tfm(x_y, **kwargs):
        shared = dict(kwargs)
        do_resolve_y = shared.pop('do_resolve_y', False)
        x, y = x_y
        x_out = apply_tfms(tfms, x, **shared)
        y_kwargs = {**shared, 'do_resolve': do_resolve_y}
        y_out = apply_tfms(tfms, y, do_y=True, **y_kwargs)
        return x_out, y_out
    return xy_tfm

In [None]:
# Preview four transformed (image, mask) overlays of the sample pair.
size=128
_,axes = plt.subplots(1,4, figsize=(12,6))
for i in range(4):
    tfm = build_xy_tfm(xy_tfms)
    imgx,imgy = tfm((x,y),size=size)
    show_xy_image(imgx, imgy, axes[i])

## Dataset

In [None]:
#export
@dataclass
class MatchedFilesDataset(Dataset):
    """Dataset of image/mask file pairs: item `i` is read from `x_fns[i]` and `y_fns[i]`."""
    x_fns: List[Path]
    y_fns: List[Path]

    def __post_init__(self):
        # Both lists must pair up one-to-one.
        assert len(self.x_fns)==len(self.y_fns)

    def __repr__(self):
        return f'{type(self).__name__} of len {len(self.x_fns)}'

    def __len__(self):
        return len(self.x_fns)

    def __getitem__(self, i):
        # The image is opened normally, the target with `as_mask=True`.
        image = open_image(self.x_fns[i])
        mask = open_image(self.y_fns[i],as_mask=True)
        return image, mask
    
def split_by_idxs(seq, idxs):
    """Generator yielding consecutive slices of `seq`, cut at each index in `idxs`.

    Yields `seq[previous:idx]` for every index in turn and finally the
    remainder. Raises KeyError (when the generator reaches that index) if an
    index is not within `-len(seq) <= idx < len(seq)`.
    """
    start = 0
    for cut in idxs:
        in_bounds = -len(seq) <= cut < len(seq)
        if not in_bounds:
            raise KeyError(f'Idx {cut} is out-of-bounds')
        piece = seq[start:cut]
        yield piece
        start = cut
    yield seq[start:]

def split_by_idx(idxs, *a):
    """
    Split each array passed in *a into a pair (elements selected by idxs, remaining elements).
    This can be used to split several arrays of training data into a validation and a training set.
    :param idxs [int]: list of selected indexes; may be empty, in which case nothing is selected
    :param a list: np.array objects, each with the same number of elements in the first dimension
    :return: list of tuples, one per array in *a.
            First element of each tuple is the array of elements selected by idxs,
            second element is the array of the remaining elements.
    """
    mask = np.zeros(len(a[0]),dtype=bool)
    idxs = np.asarray(idxs)
    # An empty list converts to a float64 array, which numpy refuses as an index;
    # with nothing to select the mask simply stays all-False.
    if idxs.size: mask[idxs] = True
    return [(o[mask],o[~mask]) for o in a]

class DatasetTfmXY(Dataset):
    """Wrap dataset `ds` so each `(x, y)` item is passed through paired transforms.

    `tfms` is turned into a pair transform with `build_xy_tfm`; with `tfms=None`
    items are returned untransformed. Extra keyword arguments are stored and
    forwarded to the pair transform on every access.
    """
    def __init__(self, ds:Dataset, tfms:Collection[Callable]=None, **kwargs):
        self.ds = ds
        self.kwargs = kwargs
        self.tfm = None if tfms is None else build_xy_tfm(tfms)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self,idx):
        x,y = self.ds[idx]
        if not self.tfm: return x,y
        return self.tfm((x,y), **self.kwargs)

In [None]:
# Sort the listing: `iterdir()` yields entries in arbitrary order, and the
# validation split below is simply "the first 1008 files", so without sorting
# the split would not be reproducible across machines or runs.
# NOTE(review): 1008 presumably covers a whole number of cars in name order - confirm.
x_fns = sorted(o for o in PATH_X.iterdir() if o.is_file())
y_fns = [get_y_fn(o) for o in x_fns]
val_idxs = list(range(1008))
((val_x,trn_x),(val_y,trn_y)) = split_by_idx(val_idxs, np.array(x_fns), np.array(y_fns))
train_ds = MatchedFilesDataset(trn_x, trn_y)
val_ds = MatchedFilesDataset(val_x, val_y)
train_ds, val_ds

In [None]:
# Fetch the first training pair (rebinding the notebook-level `x`, `y`) and check the tensor shapes.
x,y = next(iter(train_ds))
x.shape, y.shape

In [None]:
# Same preview as before, but on the first four items of the training dataset.
_,axes = plt.subplots(1,4, figsize=(12,6))
for i in range(4):
    tfm = build_xy_tfm(xy_tfms)
    imgx,imgy = tfm(train_ds[i],size=size)
    show_xy_image(imgx, imgy, axes[i])

In [None]:
size=128
# Training augmentations; contrast/brightness are marked x-only through `skip_ys`.
xy_train_tfms = [
    flip_lr(p=0.5),
    rotate(degrees=(-20,20.), p=0.75),
    zoom(scale=(0.5,2), p=0.75),
    *skip_ys(
        contrast(scale=(0.6,1.4)),
        brightness(change=(0.3,0.7))
    ),
    *zoom_crop(scale=(1.,2.), p=0.5, do_rand=True)
]

# The validation set only gets `crop_pad()`.
xy_valid_tfms = [
    crop_pad()
]

# `size` is stored by DatasetTfmXY and forwarded to the pair transform on every item access.
train_tds = DatasetTfmXY(train_ds, xy_train_tfms,size=size)
valid_tds = DatasetTfmXY(val_ds, xy_valid_tfms,size=size)

In [None]:
# Index the same training item once per panel; every access re-applies the transforms.
_,axes = plt.subplots(1,4, figsize=(12,9))
for ax in axes.flat: show_xy_image(*train_tds[1], ax)
train_tds[1][0].shape, train_tds[1][1].shape

In [None]:
#export
def normalize(x, mean,std):
    """Subtract `mean` from `x` and divide by `std`; both get two trailing unit axes for broadcasting."""
    centered = x - mean[...,None,None]
    return centered / std[...,None,None]
def denormalize(x, mean,std):
    """Multiply `x` by `std` and add `mean`; both get two trailing unit axes for broadcasting."""
    scaled = x * std[...,None,None]
    return scaled + mean[...,None,None]

def normalize_batch(b, mean, std, do_y=False):
    """Normalize the `x` part of batch `b = (x, y)`, and `y` too when `do_y` is set.

    Returns the pair `(x, y)` with `normalize(..., mean, std)` applied to the
    selected parts.
    """
    x,y = b
    x_norm = normalize(x,mean,std)
    y_out = normalize(y,mean,std) if do_y else y
    return x_norm, y_out

def normalize_funcs(mean, std, do_y=False, device=None):
    """Return a `(normalize_batch, denormalize)` pair of partials bound to `mean` and `std`.

    The normalizing partial holds copies of `mean`/`std` moved to `device`
    (`default_device` when None). The denormalizing partial keeps the tensors
    exactly as they were passed in.
    """
    target = default_device if device is None else device
    norm_fn = partial(normalize_batch, mean=mean.to(target), std=std.to(target), do_y=do_y)
    denorm_fn = partial(denormalize, mean=mean, std=std)
    return norm_fn, denorm_fn



In [None]:
# ImageNet per-channel mean and std, used to build the batch (de)normalization functions.
default_mean, default_std = Tensor([0.485, 0.456, 0.406]), Tensor([0.229, 0.224, 0.225])
default_norm,default_denorm = normalize_funcs(default_mean,default_std)

# Batch size; `default_norm` is handed to the DataBunch as `dl_tfms`.
bs = 64
data = DataBunch.create(train_tds, valid_tds, bs=bs, dl_tfms=default_norm)


In [None]:
# Pull one training batch and move it to the CPU: `default_denorm` holds the
# mean/std tensors as originally created, not moved to a device.
x,y = next(iter(data.train_dl))
x = x.cpu()
y = y.cpu()
# Statistics of the normalized inputs.
print(x.min(),x.max(),x.mean(),x.std())
x = default_denorm(x)
#y = default_denorm(y)
show_xy_images(x,y,6, figsize=(9,10))
x.shape, y.shape

In [None]:
from torchvision.models import resnet34

# Per-architecture settings, unpacked below as [cut, lr_cut]: `cut` is how many
# leading child modules `get_base` keeps, `lr_cut` is where `UnetModel`
# splits the backbone into layer groups.
model_meta = {
    resnet34:[8,6]
}

# Architecture constructor used by `get_base`.
f = resnet34
cut,lr_cut = model_meta[f]

def cut_model(m, cut):
    """Return the layers of `m` to keep, always as a list.

    With a truthy `cut` this is the first `cut` child modules of `m`. With a
    falsy `cut` it is `[m]`, the whole model as a single layer, so callers can
    always unpack the result (e.g. `nn.Sequential(*layers)`), which is not
    possible with a bare, non-iterable module.
    """
    return list(m.children())[:cut] if cut else [m]

def get_base():
    """Build the encoder: call the notebook-level constructor `f`, cut it at `cut`, wrap in `nn.Sequential`.

    `f` is called with the single positional argument `True`.
    NOTE(review): presumably the torchvision `pretrained` flag - confirm for the installed version.
    """
    backbone = f(True)
    return nn.Sequential(*cut_model(backbone, cut))

def dice(pred, targs):
    """Dice score between thresholded predictions and targets.

    `pred` is binarized with `pred > 0`. The score is
    `2 * sum(pred * targs) / sum(pred + targs)`, returned as a 0-dim tensor.
    When the denominator is zero (nothing predicted and nothing in the target)
    the score is 1 rather than the NaN that 0/0 would produce, so an empty batch
    cannot poison an averaged metric.
    NOTE(review): assumes `targs` holds 0/1 values - confirm against the masks.
    """
    pred = (pred>0).float()
    overlap = (pred*targs).sum()
    total = (pred+targs).sum()
    if total == 0: return torch.ones_like(total)
    return 2. * overlap / total

def accuracy(out, yb):
    """Fraction of rows in `out` whose arg-max along dim 1 equals the entry in `yb`."""
    _, preds = torch.max(out, dim=1)
    hits = (preds==yb).float()
    return hits.mean()

USE_GPU = torch.cuda.is_available()
def to_gpu(x, *args, **kwargs):
    """Return `x.cuda(*args, **kwargs)` when `USE_GPU` is true, otherwise `x` unchanged."""
    if not USE_GPU: return x
    return x.cuda(*args, **kwargs)

class SaveFeatures():
    """Forward hook that stores the most recent output of module `m` in `features`."""
    # Stays None until the hooked module has run a forward pass.
    features=None

    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        self.features = output

    def remove(self):
        self.hook.remove()
        
class UnetBlock(nn.Module):
    """Decoder block: upsample `up_p`, project the skip input `x_p`, concatenate, ReLU, BatchNorm.

    - `up_in`: channels of the tensor coming up from the deeper layer.
    - `x_in`: channels of the skip-connection tensor.
    - `n_out`: channels of the block output, split between the two branches.
    """
    def __init__(self, up_in, x_in, n_out):
        super().__init__()
        up_out = n_out//2
        # The skip branch takes the remainder so both halves always add up to
        # `n_out`; with `n_out//2` for both, an odd `n_out` would feed
        # n_out-1 channels into BatchNorm2d(n_out).
        x_out = n_out - up_out
        self.x_conv  = nn.Conv2d(x_in,  x_out,  1)
        self.tr_conv = nn.ConvTranspose2d(up_in, up_out, 2, stride=2)
        self.bn = nn.BatchNorm2d(n_out)
        
    def forward(self, up_p, x_p):
        # The stride-2 transposed conv doubles the spatial size of `up_p`, so
        # `x_p` must already have that doubled size for the concatenation.
        up_p = self.tr_conv(up_p)
        x_p = self.x_conv(x_p)
        cat_p = torch.cat([up_p,x_p], dim=1)
        return self.bn(F.relu(cat_p))
    
class Unet34(nn.Module):
    """U-Net style decoder on top of the indexable encoder `rn`.

    Forward hooks on `rn[2]`, `rn[4]`, `rn[5]` and `rn[6]` capture intermediate
    activations, which the decoder blocks consume from deepest to shallowest.
    The final block uses the network input itself as its skip connection.
    NOTE(review): the channel counts assume the cut resnet34 from `get_base` - confirm before swapping encoders.
    """
    def __init__(self, rn):
        super().__init__()
        self.rn = rn
        hooked_layers = [2,4,5,6]
        self.sfs = [SaveFeatures(rn[idx]) for idx in hooked_layers]
        self.up1 = UnetBlock(512,256,256)
        self.up2 = UnetBlock(256,128,256)
        self.up3 = UnetBlock(256,64,256)
        self.up4 = UnetBlock(256,64,256)
        self.up5 = UnetBlock(256,3,16)
        self.up6 = nn.ConvTranspose2d(16, 1, 1)

    def forward(self,x):
        inp = x
        x = F.relu(self.rn(x))
        decoder = (self.up1, self.up2, self.up3, self.up4)
        # Pair up1..up4 with the saved activations, last hook first.
        for up, saved in zip(decoder, reversed(self.sfs)):
            x = up(x, saved.features)
        x = self.up5(x, inp)
        # The single-channel axis is kept in the output.
        return self.up6(x)

    def close(self):
        """Remove all forward hooks registered on the encoder."""
        for hook in self.sfs:
            hook.remove()

            
class UnetModel():
    """Thin wrapper holding a U-Net `model` and a `name`, with layer groups for per-group learning rates.

    The class used to be defined twice, verbatim, in this cell; the second copy
    silently replaced the first, so a single definition is kept.
    """
    def __init__(self,model,name='unet'):
        self.model,self.name = model,name

    def get_layer_groups(self, precompute):
        """Return the encoder children split at `lr_cut`, plus the model's children after the first.

        `precompute` is accepted but not used.
        """
        lgs = list(split_by_idxs(children(self.model.rn), [lr_cut]))
        return lgs + [children(self.model)[1:]]

In [None]:
# Build the encoder, wrap it in the U-Net and move it to the GPU when one is available.
m_base = get_base()
model = to_gpu(Unet34(m_base))
learn = Learner(data, model)
learn.metrics = [dice]
# The model emits a single channel of raw logits, hence the with-logits loss.
learn.loss_fn = nn.BCEWithLogitsLoss()

In [None]:
lr_find(learn, start_lr=0.01, end_lr=100)

In [None]:
learn.recorder.plot()

In [None]:
# NOTE(review): the scheduler receives 0.3 while `fit` receives 0.1 - presumably
# both are learning rates and the scheduler's value takes effect; confirm.
sched = OneCycleScheduler(learn, 0.3, 10)
learn.fit(10, 0.1, callbacks=[sched])

In [None]:
learn.recorder.plot_losses()

In [None]:
learn.recorder.plot_metrics()

In [None]:
# Predict on one validation batch.
# NOTE(review): the model is called without `torch.no_grad()` or an explicit
# `.eval()`; confirm which mode the Learner leaves the model in after `fit`.
x,y = next(iter(data.valid_dl))
py = learn.model(x)
py = py.detach()

In [None]:
# Ground-truth mask next to the prediction, thresholded at 0 as in `dice`.
show_image(y[0]), show_image(py[0]>0)