In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#export
from nb_006 import *

# Carvana: car-mask segmentation with a dynamic U-Net

## Setup

In [None]:
# Dataset root; every other folder below is relative to it.
PATH = Path('data/carvana')

# Full-size inputs and PNG masks.
PATH_X_FULL = PATH/'train'
PATH_PNG = PATH/'train_masks_png'
PATH_Y_FULL = PATH_PNG

# Pre-shrunk 128px copies.
PATH_X_128 = PATH/'train-128'
PATH_Y_128 = PATH/'train_masks-128'

# Active folders: start with the 128x128 images.
PATH_X, PATH_Y = PATH_X_128, PATH_Y_128

In [None]:
def get_y_fn(x_fn, y_path=None):
    """Return the mask path for image `x_fn`: `<y_path>/<x_fn stem>_mask.png`.

    Args:
        x_fn: `Path` of an input image.
        y_path: mask folder; defaults to the module-level `PATH_Y`, looked up at call
            time so re-pointing `PATH_Y` later is still honoured.
    """
    if y_path is None: y_path = PATH_Y
    # `.stem` drops the extension whatever its length; slicing off 4 chars only
    # worked for 3-letter extensions such as `.jpg`.
    return y_path/f'{x_fn.stem}_mask.png'

In [None]:
def get_datasets(path, n_valid=1008):
    """Build matched (image, mask) train/valid datasets from the image files in `path`.

    The first `n_valid` files, in sorted filename order, form the validation set and the
    rest form the training set. Mask paths come from `get_y_fn`.

    Args:
        path: directory of input images; only regular files are used.
        n_valid: number of files held out for validation (default 1008, as before).

    Returns:
        `(train_ds, valid_ds)` as `MatchedFilesDataset`s.
    """
    # `Path.iterdir()` yields entries in arbitrary, filesystem-dependent order; sort so
    # the split is reproducible across machines and runs.
    # NOTE(review): if filenames are `<car_id>_<view>.jpg`, sorting also keeps all views of
    # a car on the same side of the split (1008 = 63 cars x 16 views?) — confirm.
    x_fns = sorted(o for o in path.iterdir() if o.is_file())
    y_fns = [get_y_fn(o) for o in x_fns]
    val_idxs = list(range(n_valid))
    ((val_x,trn_x),(val_y,trn_y)) = split_arrs(val_idxs, x_fns, y_fns)
    return (MatchedFilesDataset(trn_x, trn_y),
            MatchedFilesDataset(val_x, val_y))

In [None]:
# Inspect the 128px train/valid split.
train_ds, valid_ds = get_datasets(PATH_X_128)
train_ds, valid_ds

In [None]:
size = 128  # image size (px) for the first, low-resolution stage

In [None]:
def get_tfm_datasets(size):
    """Return train/valid datasets with augmentation transforms applied at `size` px.

    Images are read from `PATH_X_128` when `size <= 128`, otherwise from `PATH_X_FULL`.
    `tfm_y=True` applies the same transforms to the masks.
    """
    # Use the datasets built here (not the module-level `train_ds`/`valid_ds`, which are
    # always the 128px split) so the size-dependent folder choice actually takes effect.
    train_ds, valid_ds = get_datasets(PATH_X_128 if size<=128 else PATH_X_FULL)
    tfms = get_transforms(do_flip=True, max_rotate=4, max_lighting=0.2)
    # NOTE(review): masks are still located via `get_y_fn`, which reads the global `PATH_Y`
    # (128px masks by default) regardless of `size` — confirm that is intended for size > 128.
    return transform_datasets(train_ds, valid_ds, tfms, tfm_y=True, size=size)

In [None]:
# Normalize / de-normalize with ImageNet statistics (the encoder below is built with pretrained=True).
default_norm, default_denorm = normalize_funcs(*imagenet_stats)
bs = 32  # batch size

In [None]:
def get_data(size, bs):
    "Return a `DataBunch` of the transformed datasets at `size` px, batch size `bs`, normalized by `default_norm`."
    tfm_datasets = get_tfm_datasets(size)
    return DataBunch.create(*tfm_datasets, bs=bs, tfms=default_norm)

In [None]:
data = get_data(size=size, bs=bs)  # first-stage DataBunch

## Unet

In [None]:
def get_sfs_idxs(sfs, last=True):
    """Choose which hooked encoder layers feed the U-Net skip connections.

    Args:
        sfs: sequence of hook objects whose `.features` tensors are populated, in
            encoder-child order.
        last: if True, keep the index of the last layer before each change in size
            (compared on the final dimension only); if False, keep every index.

    Returns:
        Ascending list of int indices into `sfs` (empty when fewer than two hooks).
    """
    if not last: return list(range(len(sfs)))
    feature_szs = np.array([sfs_feats.features.size()[-1] for sfs_feats in sfs])
    # Index i is kept when layers i and i+1 differ in size. This already contains 0 when
    # the first two layers differ, so no separate `[0] +` prefix (it created a duplicate
    # index and hence a decoder block whose skip tensor has the wrong size).
    return [int(i) for i in np.where(feature_szs[:-1] != feature_szs[1:])[0]]

def conv_bn_relu(in_c, out_c, kernel_size, stride, padding):
    """Return a list of layers `[Conv2d, ReLU, BatchNorm2d]` for splatting into `nn.Sequential`.

    Despite the name, the order is conv -> ReLU -> batchnorm.
    """
    conv = nn.Conv2d(in_c, out_c, kernel_size=kernel_size, stride=stride, padding=padding)
    activation = nn.ReLU()
    norm = nn.BatchNorm2d(out_c)
    return [conv, activation, norm]

In [None]:
class UnetBlock(nn.Module):
    """One decoder step: 2x transposed-conv upsample, concatenate the skip input, two 3x3 convs.

    Args:
        up_in_c: channels of the lower-resolution tensor being upsampled.
        x_in_c: channels of the skip-connection tensor.

    Output channels are `(x_in_c + up_in_c // 2) // 2` (available as `conv2.out_channels`).
    """
    def __init__(self, up_in_c, x_in_c):
        super().__init__()
        half_up = up_in_c // 2
        cat_c = x_in_c + half_up
        out_c = cat_c // 2
        self.upconv = nn.ConvTranspose2d(up_in_c, half_up, 2, 2) # H, W -> 2H, 2W
        self.conv1 = nn.Conv2d(cat_c, out_c, 3, 1, 1)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1)
        self.bn = nn.BatchNorm2d(out_c)

    def forward(self, up_in, x_in):
        # assumes x_in's spatial size is exactly 2x up_in's, so the concat lines up — TODO confirm for odd sizes
        merged = torch.cat([self.upconv(up_in), x_in], dim=1)
        hidden = F.relu(self.conv1(merged))
        return self.bn(F.relu(self.conv2(hidden)))

In [None]:
class SaveFeatures():
    """Record the output of module `m` on every forward pass via a forward hook.

    Args:
        m: module to hook.
        detach: if True (default, the original behaviour) the stored tensor is detached
            from the autograd graph; pass False to keep it attached so gradients can flow
            back into `m` through whatever consumes `features`.

    `features` is None until `m` has run once. Call `remove()` to unregister the hook.
    """
    features=None
    def __init__(self, m, detach=True):
        self.detach = detach
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        self.features = output.detach() if self.detach else output

    def remove(self): self.hook.remove()

In [None]:
class DynamicUnet(nn.Module):
    """U-Net head built lazily around a pretrained `encoder` on the first forward pass.

    A `SaveFeatures` hook is attached to every child of `encoder`. On the first call to
    `forward`, the recorded activation sizes decide which children feed skip connections
    (`get_sfs_idxs`) and how many channels each layer needs. The middle conv, one
    `UnetBlock` per skip connection, an optional 2x `ConvTranspose2d` and a 1x1 output
    conv are then created with the input batch's tensor type.

    NOTE(review): layers created inside `forward` do not exist until one batch has been run,
    so an optimizer built before that will not see their parameters.
    NOTE(review): `SaveFeatures` stores `output.detach()`, so no gradient reaches the encoder
    through the skip connections — confirm that is intended when training unfrozen.

    Args:
        encoder: indexable module (e.g. `nn.Sequential`) whose children are hooked.
        last: forwarded to `get_sfs_idxs`.
        n_classes: output channels of the final 1x1 conv.
    """
    def __init__(self, encoder, last=True, n_classes=3):
        super().__init__()
        self.encoder = encoder
        self.n_children = len(list(encoder.children()))
        self.sfs = [SaveFeatures(encoder[i]) for i in range(self.n_children)]
        self.last = last
        self.n_classes = n_classes

    def forward(self, x):
        dtype = x.type()
        imsize = x.shape[-2:]
        x = F.relu(self.encoder(x))

        # First call only: read the hooked sizes and build the middle conv.
        if not hasattr(self, 'middle_conv'):
            self.sfs_szs = [sfs_feats.features.size() for sfs_feats in self.sfs]
            self.sfs_idxs = get_sfs_idxs(self.sfs, self.last)
            middle_in_c = self.sfs_szs[-1][1]
            self.middle_conv = nn.Sequential(*conv_bn_relu(middle_in_c, middle_in_c * 2, 3, 1, 1),
                *conv_bn_relu(middle_in_c * 2, middle_in_c, 3, 1, 1)).type(dtype)

        x = self.middle_conv(x)

        # First call only: build the decoder from channel counts.
        if not hasattr(self, 'upmodel'):
            self._build_decoder(int(x.size()[1]), imsize, dtype)

        # run upsample
        for block, idx in zip(self.upmodel, self.sfs_idxs[::-1]):
            x = block(x, self.sfs[idx].features)
        if hasattr(self, 'extra_block'): x = self.extra_block(x)

        return self.final_conv(x)

    def _build_decoder(self, up_in_c, imsize, dtype):
        "Create `upmodel`, the optional `extra_block` and `final_conv`, starting from `up_in_c` channels."
        upmodel = []
        for idx in self.sfs_idxs[::-1]:
            unet_block = UnetBlock(up_in_c, int(self.sfs_szs[idx][1])).type(dtype)
            upmodel.append(unet_block)
            # A UnetBlock emits `conv2.out_channels` channels, so no dummy pass through
            # the new (train-mode) blocks is needed to size the next one.
            up_in_c = unet_block.conv2.out_channels
        # Assigned once after the loop so it exists even when there are no skip indices.
        self.upmodel = nn.Sequential(*upmodel)

        # If the input size differs from the first hooked activation's size, add a 2x
        # upsample (assumes the gap is exactly 2x — TODO confirm for other encoders).
        if imsize != self.sfs_szs[0][-2:]:
            self.extra_block = nn.ConvTranspose2d(up_in_c, up_in_c, 2, 2).type(dtype)

        self.final_conv = nn.Conv2d(up_in_c, self.n_classes, 1).type(dtype)

In [None]:
class DynamicUnet(nn.Module):
    """U-Net whose decoder is sized automatically from a pretrained `encoder`, eagerly in `__init__`.

    A `SaveFeatures` hook is attached to every child of `encoder`. One dummy batch of zeros
    of spatial size `imsize` is run through the encoder (eval mode, no grad) so each hook
    records an activation size. From those sizes the constructor picks the skip layers
    (`get_sfs_idxs`) and builds the middle conv, one `UnetBlock` per skip connection, an
    optional 2x `ConvTranspose2d` and a 1x1 output conv. Every layer therefore exists
    before `.cuda()` or optimizer creation.

    NOTE(review): `SaveFeatures` stores `output.detach()`, so no gradient reaches the encoder
    through the skip connections — confirm that is intended when training unfrozen.

    Args:
        encoder: indexable module (e.g. `nn.Sequential`) whose children are hooked.
        last: forwarded to `get_sfs_idxs`.
        n_classes: output channels of the final 1x1 conv.
        imsize: (height, width) of the dummy probe batch.

    Raises:
        Exception: if `encoder` has no layer with a weight (needed to size the probe).
    """
    def __init__(self, encoder, last=True, n_classes=3, imsize=(256, 256)):
        super().__init__()
        self.encoder = encoder
        self.n_children = len(list(encoder.children()))
        self.sfs = [SaveFeatures(encoder[i]) for i in range(self.n_children)]
        self.last = last
        self.n_classes = n_classes

        dtype, ch_in = self._first_weight_details(encoder)
        self._probe_encoder(ch_in, imsize, dtype)
        self.sfs_szs = [sfs_feats.features.size() for sfs_feats in self.sfs]
        self.sfs_idxs = get_sfs_idxs(self.sfs, self.last)

        middle_in_c = int(self.sfs_szs[-1][1])
        self.middle_conv = nn.Sequential(*conv_bn_relu(middle_in_c, middle_in_c * 2, 3, 1, 1),
            *conv_bn_relu(middle_in_c * 2, middle_in_c, 3, 1, 1)).type(dtype)

        # middle_conv outputs middle_in_c channels; each UnetBlock outputs conv2.out_channels.
        up_in_c = middle_in_c
        upmodel = []
        for idx in self.sfs_idxs[::-1]:
            unet_block = UnetBlock(up_in_c, int(self.sfs_szs[idx][1])).type(dtype)
            upmodel.append(unet_block)
            up_in_c = unet_block.conv2.out_channels
        self.upmodel = nn.Sequential(*upmodel)

        # If the input size differs from the first hooked activation's size, add a 2x
        # upsample (assumes the gap is exactly 2x — TODO confirm for other encoders).
        if tuple(imsize) != tuple(self.sfs_szs[0][-2:]):
            self.extra_block = nn.ConvTranspose2d(up_in_c, up_in_c, 2, 2).type(dtype)

        self.final_conv = nn.Conv2d(up_in_c, self.n_classes, 1).type(dtype)

    @staticmethod
    def _first_weight_details(m):
        "Return `(tensor type string, weight.shape[1])` of the first module in `m` with a non-None weight."
        for layer in m.modules():
            weight = getattr(layer, 'weight', None)
            if weight is not None: return weight.type(), weight.shape[1]
        raise Exception('No weight layer')

    def _probe_encoder(self, ch_in, imsize, dtype):
        "Run one zero batch through `encoder` so every SaveFeatures hook records a size."
        modes = [(mod, mod.training) for mod in self.encoder.modules()]
        # eval(): the probe must not update BatchNorm running statistics.
        self.encoder.eval()
        try:
            with torch.no_grad():
                self.encoder(torch.zeros(1, ch_in, *imsize).type(dtype))
        finally:
            # Restore each submodule's own flag (some may have been in eval deliberately).
            for mod, mode in modes: mod.training = mode

    def forward(self, x):
        x = F.relu(self.encoder(x))
        x = self.middle_conv(x)

        # run upsample
        for block, idx in zip(self.upmodel, self.sfs_idxs[::-1]):
            x = block(x, self.sfs[idx].features)
        if hasattr(self, 'extra_block'): x = self.extra_block(x)

        return self.final_conv(x)

In [None]:
# `body` is otherwise first created in a later cell; build it here so this and the
# following probe cells also run on a fresh kernel (Restart & Run All).
body = create_body(tvm.resnet34(True), 2)
l = list(body.children())[0]  # first child of the encoder body

In [None]:
def in_details(m):
    """Return `(tensor type string, input channels)` of the first weighted layer of `m`.

    Walks `flatten_model(m)` in order and uses the first layer whose `weight` is not None.
    Input channels are taken as `weight.shape[1]` (for `Conv2d`, weights are (out, in, kh, kw)).

    Raises:
        Exception: if no layer has a weight.
    """
    for layer in flatten_model(m):
        weight = getattr(layer, 'weight', None)
        # Modules such as BatchNorm2d(affine=False) have `weight = None`; skip them rather
        # than crashing on `None.type()`.
        if weight is not None: return weight.type(), weight.shape[1]
    raise Exception('No weight layer')

In [None]:
# Weight tensor type and input channel count of the encoder's first weighted layer.
type_in, ch_in = in_details(body)
ch_in

In [None]:
# Dummy probe batch. Zeros are used because a sized `FloatTensor(...)` constructor
# (torch.FloatTensor, presumably) allocates uninitialised memory, which may hold NaN/inf.
x = torch.zeros(1, ch_in, 256, 256).type(type_in)

In [None]:
metrics = [accuracy_thresh, dice]  # metrics handed to the Learner

In [None]:
body = create_body(tvm.resnet34(True), 2)  # ResNet-34 trunk, pretrained=True
model = DynamicUnet(body, n_classes=1).cuda()  # one output channel: a per-pixel mask logit
learn = Learner(data, model, metrics=metrics, loss_fn=F.binary_cross_entropy_with_logits)

In [None]:
# One forward pass on a validation batch.
# NOTE(review): no eval()/no_grad here, so BatchNorm layers may update running stats from this batch — confirm intended.
x, y = next(iter(data.valid_dl))
t = model(x)

In [None]:
# Display the model's module tree.
model

In [None]:
lr_find(learn)
learn.recorder.plot()  # plot what the recorder captured during lr_find

In [None]:
lr = 1e-3  # max learning rate, presumably read off the lr_find plot

In [None]:
# One-cycle training run (first arg presumably the number of epochs) with max learning rate `lr`.
learn.fit_one_cycle(1, slice(lr))

In [None]:
# Same call as the previous cell: a second one-cycle run on the same learner.
learn.fit_one_cycle(1, slice(lr))

In [None]:
# NOTE(review): no training call follows unfreeze() in this notebook before predictions,
# resizing and re-freezing below — confirm intended.
learn.unfreeze()

In [None]:
x, py = learn.pred_batch()  # inputs and predictions for one batch; predictions are thresholded at 0 below

In [None]:
fig, axes = plt.subplots(4, 4, figsize=(10, 10))
for i, ax in enumerate(axes.flat):
    # show each de-normalized image with its predicted mask (logit > 0, i.e. sigmoid > 0.5)
    show_image(default_denorm(x[i].cpu()), py[i] > 0, ax=ax)

In [None]:
# Second stage: larger images (size > 128 selects PATH_X_FULL), smaller batches.
size, bs = 512, 8
learn.data = get_data(size, bs)

In [None]:
# Re-freeze (presumably undoing the earlier unfreeze()) before continuing at the new size.
learn.freeze()

## Fin