In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#export
from nb_006 import *
import gc

# Carvana

## Setup

In [None]:
# Dataset root and the image / mask folders beneath it.
PATH = Path('data/carvana')
# Full-resolution images and their PNG masks.
PATH_X_FULL, PATH_PNG = PATH/'train', PATH/'train_masks_png'
PATH_Y_FULL = PATH_PNG
# Downsized copies (128px, judging by the folder names and the `size<=128` check in get_tfm_datasets).
PATH_X_128, PATH_Y_128 = PATH/'train-128', PATH/'train_masks-128'

# Folders used by default (get_y_fn reads PATH_Y).
PATH_X, PATH_Y = PATH_X_128, PATH_Y_128

In [None]:
def get_y_fn(x_fn):
    """Return the mask path for image file `x_fn`: `<mask folder>/<image stem>_mask.png`.

    The mask folder is chosen to match the image folder (PATH_X_FULL -> PATH_Y_FULL,
    PATH_X_128 -> PATH_Y_128); images from any other folder fall back to `PATH_Y`.
    """
    # Matching by folder keeps full-resolution images (read by get_tfm_datasets when size > 128)
    # paired with full-resolution masks instead of the 128px ones.
    mask_dir = {PATH_X_FULL: PATH_Y_FULL, PATH_X_128: PATH_Y_128}.get(x_fn.parent, PATH_Y)
    # `stem` strips the extension whatever its length (`name[:-4]` assumed a 3-letter extension).
    return mask_dir/f'{x_fn.stem}_mask.png'

In [None]:
def get_datasets(path, n_val=1008):
    """Pair every image file in `path` with its mask (via `get_y_fn`) and split into datasets.

    The first `n_val` files, in sorted filename order, form the validation set.
    Returns `(train_ds, valid_ds)` as `MatchedFilesDataset`s.
    """
    # Sort so the split is reproducible: `Path.iterdir` yields entries in filesystem-dependent order.
    # NOTE(review): if filenames are `<car id>_<view>.jpg`, sorting also keeps each car's views
    # together, so a contiguous validation block holds out whole cars (1008 = 63 x 16) — confirm.
    x_fns = sorted(o for o in path.iterdir() if o.is_file())
    y_fns = [get_y_fn(o) for o in x_fns]
    val_idxs = list(range(n_val))
    ((val_x,trn_x),(val_y,trn_y)) = split_arrs(val_idxs, x_fns, y_fns)
    return (MatchedFilesDataset(trn_x, trn_y),
            MatchedFilesDataset(val_x, val_y))

In [None]:
size = 128  # training image side length; get_tfm_datasets reads the 128px folders when size <= 128

In [None]:
def get_tfm_datasets(size):
    """Return `transform_datasets(...)` applied to the train/valid split from `get_datasets`.

    Images are read from PATH_X_FULL when `size` > 128 and from PATH_X_128 otherwise.
    Augmentation: flips, rotation up to 4 degrees and lighting up to 0.2; `tfm_y=True`
    presumably applies the same transforms to the masks.
    """
    src = PATH_X_FULL if size > 128 else PATH_X_128
    trn_ds, val_ds = get_datasets(src)
    aug = get_transforms(do_flip=True, max_rotate=4, max_lighting=0.2)
    return transform_datasets(trn_ds, val_ds, aug, tfm_y=True, size=size)

In [None]:
# ImageNet normalization to match the pretrained encoder; `default_denorm` is used later to
# display images (presumably its inverse).
default_norm, default_denorm = normalize_funcs(*imagenet_stats)
bs = 32  # batch size for the 128px stage

In [None]:
def get_data(size, bs):
    "Create a `DataBunch` from `get_tfm_datasets(size)` with batch size `bs`, normalized by `default_norm`."
    tfm_dss = get_tfm_datasets(size)
    return DataBunch.create(*tfm_dss, bs=bs, tfms=default_norm)

In [None]:
data = get_data(size=size, bs=bs)  # 128px DataBunch for the first training stage

## Unet

In [None]:
# TODO: decide whether the decoder layers need explicit weight initialization (e.g. via apply_init).

In [None]:
def ifnone(a,b):
    "`a` if it's not None, otherwise `b`."
    if a is not None: return a
    return b

def children(m):
    "Return the direct child modules of `m` as a list."
    return [*m.children()]
def num_children(m):
    "Count the direct child modules of `m`."
    return sum(1 for _ in m.children())
def range_children(m):
    "`range` over the indices of `m`'s direct children."
    return range(len(children(m)))

def cond_init(m, init_fn):
    """Initialize a single module: `init_fn(m.weight)` and zero `m.bias`.

    BatchNorm layers are skipped. Weights or biases that are absent, or registered
    as None (e.g. `bias=False` convs, non-affine norm layers), are left alone.
    """
    if isinstance(m, (nn.BatchNorm1d,nn.BatchNorm2d,nn.BatchNorm3d)): return
    # `hasattr` alone is not enough: many layers register `weight` as None, which would crash init_fn.
    if getattr(m, 'weight', None) is not None: init_fn(m.weight)
    if hasattr(m, 'bias') and hasattr(m.bias, 'data'): m.bias.data.fill_(0.)

def apply_init(m, init_fn):
    "Apply `cond_init(x, init_fn)` to `m` and every submodule (via `nn.Module.apply`)."
    m.apply(lambda x: cond_init(x, init_fn))

def apply_leaf(m, f):
    """Call `f` on `m` and on every module below it, parents before children, in child order.

    Despite the name, `f` is applied to every `nn.Module` in the tree, not only to leaves.
    """
    pending = [m]
    while pending:
        node = pending.pop()
        # Snapshot the children before `f` runs, matching the recursive formulation.
        kids = children(node)
        if isinstance(node, nn.Module): f(node)
        # Reverse so the first child is popped (visited) first.
        pending.extend(reversed(kids))

In [None]:
class Hook():
    "Register a forward (or backward) hook on `m` and keep the latest result of `hook_func` in `.stored`."
    def __init__(self, m, hook_func, is_forward=True):
        self.hook_func,self.stored = hook_func,None
        f = m.register_forward_hook if is_forward else m.register_backward_hook
        self.hook = f(self.hook_fn)

    def hook_fn(self, module, input, output):
        "Detach the tensors inside list/tuple `input`/`output`, then store `hook_func(module, input, output)`."
        # Tuples, not generators: a stored generator could only be iterated once.
        # A bare tensor is left attached (as before): UnetBlock concatenates `hook.stored`,
        # so detaching it would stop gradients reaching the encoder through the skip connections.
        input  = self._detach_all(input ) if is_listy(input ) else input
        output = self._detach_all(output) if is_listy(output) else output
        self.stored = self.hook_func(module, input, output)

    @staticmethod
    def _detach_all(xs):
        "Return `xs` as a tuple with every tensor detached; non-tensors (e.g. None grads) pass through."
        return tuple(o.detach() if isinstance(o, torch.Tensor) else o for o in xs)

    def remove(self):
        "Unregister the hook."
        self.hook.remove()

def hook_output(module):
    "Hook `module` and keep its output in `.stored`."
    return Hook(module, lambda m,i,o: o)

In [None]:
class Hooks():
    "A list-like group of `Hook`s: one per module in `ms`, all sharing `hook_func`."
    def __init__(self, ms, hook_func, is_forward=True):
        self.hooks = list(map(lambda mod: Hook(mod, hook_func, is_forward), ms))

    def __len__(self): return len(self.hooks)
    def __getitem__(self,i): return self.hooks[i]
    def __iter__(self): return iter(self.hooks)

    def remove(self):
        "Unregister every hook in the group."
        for hk in self.hooks: hk.remove()

def hook_outputs(modules):
    "Hook every module in `modules`, storing each one's output."
    return Hooks(modules, lambda m,i,o: o)

In [None]:
def in_channels(m):
    "Return `weight.shape[1]` of the first layer in `flatten_model(m)` that has a `weight` attribute."
    weighted = (layer for layer in flatten_model(m) if hasattr(layer, 'weight'))
    first = next(weighted, None)
    if first is None: raise Exception('No weight layer')
    return first.weight.shape[1]

In [None]:
def model_sizes(m, size=(256,256)):
    """Run a zero batch of shape `(1, in_channels(m), *size)` through `m` and record output shapes.

    Returns `(shapes, x)`: `shapes[i]` is the output shape of the i-th module yielded by iterating
    `m` (its children for an `nn.Sequential`), and `x` is the dummy input.
    Side effect: `m` is left in eval mode.
    """
    hooks = hook_outputs(m)
    try:
        ch_in = in_channels(m)
        # Put the dummy input on the model's device so GPU models work too (default device otherwise).
        p = next(m.parameters(), None)
        x = torch.zeros(1, ch_in, *size, device=None if p is None else p.device)
        # Only shapes are needed, so skip building an autograd graph.
        # NOTE(review): eval mode is intentionally not restored — DynamicUnet runs the encoder again
        # right after this call, and eval mode keeps that pass from updating BatchNorm running stats.
        with torch.no_grad(): m.eval()(x)
        res = [o.stored.shape for o in hooks]
    finally:
        # Always unregister the hooks, even if the forward pass fails.
        hooks.remove()
    return res,x

In [None]:
def get_sfs_idxs(sizes, last=True):
    """Pick the layer indices whose activations feed the U-Net skip connections.

    sizes: per-layer activation shapes (e.g. from `model_sizes`); only the last dim (width) is read.
    last:  if True, return every index `i` whose width differs from layer `i+1`'s, i.e. the last
           layer at each resolution before it changes. If False, return every index.
    Returns a list of Python ints (empty when fewer than two sizes are given and `last` is True).
    """
    if not last: return list(range(len(sizes)))
    feature_szs = np.array([size[-1] for size in sizes])
    # This already contains 0 when layers 0 and 1 differ in width, so no separate first-layer check
    # is needed (prepending 0 again produced a duplicate skip and mismatched UnetBlocks).
    return [int(i) for i in np.where(feature_szs[:-1] != feature_szs[1:])[0]]

def conv2d(ni, nf, ks=3, stride=1, padding=None):
    "A 2d convolution from `ni` to `nf` channels; `padding` defaults to `ks//2` (size-preserving for odd `ks`, stride 1)."
    pad = ks//2 if padding is None else padding
    return nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=pad)

def conv2d_trans(ni, nf, ks=2, stride=2, padding=0):
    "A transposed 2d convolution from `ni` to `nf` channels; with the defaults it doubles height and width."
    return nn.ConvTranspose2d(in_channels=ni, out_channels=nf,
                              kernel_size=ks, stride=stride, padding=padding)

def conv_bn_relu(ni, nf, ks=3, stride=1, padding=None):
    """`conv2d(ni, nf, ...)` followed by ReLU and then BatchNorm, as an `nn.Sequential`.

    Despite the name, the order is conv -> ReLU -> BN (the same ordering UnetBlock uses).
    """
    layers = [conv2d(ni, nf, ks=ks, stride=stride, padding=padding), nn.ReLU(), nn.BatchNorm2d(nf)]
    return nn.Sequential(*layers)

In [None]:
class UnetBlock(nn.Module):
    """One decoder stage of the U-Net.

    Upsamples `up_in` 2x, concatenates the skip activation found in `hook.stored`, then applies
    conv -> ReLU -> conv -> ReLU -> BatchNorm.

    up_in_c: channels of the incoming lower-resolution tensor.
    x_in_c:  channels of the skip activation read from `hook.stored`.
    Output channels: (up_in_c//2 + x_in_c)//2.
    """
    def __init__(self, up_in_c, x_in_c, hook):
        super().__init__()
        self.hook = hook
        up_out_c = up_in_c//2
        cat_c = up_out_c + x_in_c
        out_c = cat_c//2
        self.upconv = conv2d_trans(up_in_c, up_out_c)  # H, W -> 2H, 2W
        self.conv1 = conv2d(cat_c, out_c)
        self.conv2 = conv2d(out_c, out_c)
        self.bn = nn.BatchNorm2d(out_c)

    def forward(self, up_in):
        # The list is built left to right, so the upsampling runs before the skip is read.
        x = torch.cat([self.upconv(up_in), self.hook.stored], dim=1)
        for conv in (self.conv1, self.conv2):
            x = F.relu(conv(x))
        return self.bn(x)

In [None]:
class DynamicUnet(nn.Module):
    """U-Net whose decoder is sized automatically from `encoder`.

    A dummy batch (see `model_sizes`) is run through `encoder` to find its children's activation
    shapes. `get_sfs_idxs` picks the skip-connection layers, and one `UnetBlock` is added per skip,
    deepest first. The layer stack is
    `encoder -> ReLU -> middle convs -> UnetBlocks -> [2x upconv] -> 1x1 conv`.

    encoder:   module whose children are hooked via `hook_outputs`. Since it is iterated directly,
               it is presumably an `nn.Sequential`-like container.
    last:      forwarded to `get_sfs_idxs`.
    n_classes: output channels of the final 1x1 convolution.
    Raises ValueError if no skip-connection layers are found.
    """
    def __init__(self, encoder, last=True, n_classes=3):
        super().__init__()

        sfs_szs,x = model_sizes(encoder)
        imsize = x.shape[-2:]
        sfs_idxs = get_sfs_idxs(sfs_szs, last)
        if not sfs_idxs:
            raise ValueError('DynamicUnet: no skip-connection layers found in the encoder')

        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(conv_bn_relu(ni, ni*2), conv_bn_relu(ni*2, ni))
        # Hook every encoder child; each UnetBlock reads its skip activation from one of these.
        self.sfs = hook_outputs(encoder)
        # The dummy passes only discover channel counts, so no autograd graph is needed.
        with torch.no_grad(): x = middle_conv(encoder(x))

        layers = [encoder, nn.ReLU(), middle_conv]
        for idx in sfs_idxs[::-1]:
            up_in_c, x_in_c = int(x.size(1)), int(sfs_szs[idx][1])
            unet_block = UnetBlock(up_in_c, x_in_c, self.sfs[idx])
            layers.append(unet_block)
            with torch.no_grad(): x = unet_block(x)

        final_in_c = unet_block.conv2.out_channels
        # If the shallowest hooked activation is smaller than the input, add one more 2x upsampling
        # (this only corrects a 2x mismatch).
        if imsize != sfs_szs[0][-2:]:
            layers.append(conv2d_trans(final_in_c, final_in_c))

        layers.append(conv2d(final_in_c, n_classes, 1))
        self.layers = nn.Sequential(*layers)

    def forward(self, x): return self.layers(x)

    def __del__(self):
        # Unregister the encoder hooks so they don't outlive the model.
        if hasattr(self, "sfs"): self.sfs.remove()

In [None]:
# Thresholded accuracy and Dice score (not defined in this notebook; presumably from nb_006).
metrics = [accuracy_thresh, dice]
lr = 0.001  # base learning rate

In [None]:
# Encoder: torchvision resnet34 with pretrained weights (`True` is torchvision's `pretrained` flag).
# NOTE(review): the `2` presumably says how many trailing layers (avgpool/fc) to cut — confirm in create_body.
body = create_body(tvm.resnet34(True), 2)
# One output channel of logits (binary mask), trained with BCE-with-logits below. Requires a GPU (.cuda()).
model = DynamicUnet(body, n_classes=1).cuda()

learn = Learner(data, model, metrics=metrics,
                loss_fn=F.binary_cross_entropy_with_logits)

# Layer-group boundaries: model.layers[0] is the encoder and model.layers[1] the ReLU that follows it
# (see DynamicUnet). Encoder child 6 is presumably resnet34's `layer3` — re-check for other bodies.
learn.split([model.layers[0][6], model.layers[1]])
# Presumably freezes every layer group except the last, so only the new decoder trains first.
learn.freeze()

In [None]:
# Learning-rate range test; the plot helps choose `lr`.
lr_find(learn)
learn.recorder.plot()

In [None]:
# One short one-cycle epoch at `lr` (`pct_start` is presumably the warm-up fraction of the cycle).
learn.fit_one_cycle(1, slice(lr), pct_start=0.05)

In [None]:
# Six more one-cycle epochs with the same settings, still with the encoder frozen.
learn.fit_one_cycle(6, slice(lr), pct_start=0.05)

In [None]:
# Checkpoint the current weights under the name 'u0'.
learn.save('u0')

In [None]:
# Predict one batch and show up to 16 images with their predicted masks.
# The model outputs logits (loss is BCE-with-logits), so `> 0` means probability > 0.5.
x,py = learn.pred_batch()

fig, axs = plt.subplots(4, 4, figsize=(10,10))
# `zip` stops at the shorter of batch and grid, so a batch smaller than 16 no longer raises IndexError.
for img, pred, ax in zip(x, py, axs.flat):
    show_image(default_denorm(img.cpu()), pred>0, ax=ax)

In [None]:
# Make every layer group trainable so the encoder is fine-tuned too.
learn.unfreeze()
lr=1e-3  # same value as before; presumably restated so this stage can be tweaked on its own

In [None]:
# Discriminative learning rates: presumably `lr/100` for the earliest layer group up to `lr` for the last.
learn.fit_one_cycle(6, slice(lr/100,lr), pct_start=0.05)

In [None]:
# Switch to 512px inputs; get_tfm_datasets reads images from PATH_X_FULL when size > 128.
# Batch size drops to 8, presumably so the larger images fit in GPU memory.
size=512
bs = 8
learn.data = get_data(size, bs)

In [None]:
# Refreeze before training at the new resolution (presumably only the last layer group trains again).
learn.freeze()

## Fin