In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#export
from nb_005a import *
from nb_005b import *

from nb_006 import *
import fast_progress as fp

# CamVid

## Setup

In [None]:
PATH = Path('data/camvid')
PATH_X, PATH_Y = PATH/'701_StillsRaw_full', PATH/'LabeledApproved_full'  # raw images, RGB label images
PATH_Y_PROCESSED = PATH/'LabelProcessed'  # destination for the converted class-id masks
label_csv = PATH/'label_colors.txt'

PATH_Y_PROCESSED.mkdir(exist_ok=True)

In [None]:
[*PATH_Y.iterdir()][0]  # peek at one label-image path

In [None]:
def parse_code(l):
    """Parse one colour-table line of the form 'R G B<TAB>name' into ((R, G, B), name).

    Empty tab-separated fields are ignored. Raises ValueError if the line does not hold
    exactly two non-empty fields or a colour component is not an integer.
    """
    color, name = [c for c in l.strip().split("\t") if c]
    # split() with no argument tolerates repeated spaces; split(' ') produced '' and int('') failed.
    return tuple(int(o) for o in color.split()), name
# Read the colour -> class-name table (blank lines skipped), then add a catch-all 'unk' class.
with open(label_csv) as f:
    label_codes, label_names = zip(*[parse_code(l) for l in f if l.strip()])
label_codes, label_names = list(label_codes), list(label_names)
code2id = {v:k for k,v in enumerate(label_codes)}
# 'unk' is appended at index len(label_codes), so that is its id (len+1 pointed one past the end).
failed_code = len(label_codes)
label_codes.append((0,0,0))
label_names.append('unk')


In [None]:
def get_y_fn(x_fn):
    "Path of the RGB label image for raw image `x_fn`: `<stem>_L.png` inside PATH_Y."
    # `stem` strips any extension; name[:-4] silently assumed a 3-character one.
    return PATH_Y/f'{x_fn.stem}_L.png'
def get_y_proc_fn(y_fn):
    "Path of the processed mask for label image `y_fn`: `<stem>_L.png` -> `<stem>_P.png` inside PATH_Y_PROCESSED."
    return PATH_Y_PROCESSED/f'{y_fn.stem[:-2]}_P.png'

In [None]:
x_fns = [fn for fn in PATH_X.iterdir() if fn.is_file()]  # raw images
y_fns = list(map(get_y_fn, x_fns))                       # matching RGB label images
y_proc_fns = list(map(get_y_proc_fn, y_fns))             # where the class-id masks are written

In [None]:
def process_file(fns):
    """Convert one RGB label image into a single-channel uint8 class-id PNG.

    `fns` is a `(label_path, output_path)` pair; nothing is done if `output_path`
    already exists. Each pixel's colour tuple is looked up in `code2id`; colours not
    in the table get id 0.
    NOTE(review): id 0 is the first entry of the colour table, not the 'unk' class —
    confirm that unmatched colours should fall into that class.
    Returns `output_path`.
    """
    yfn, pfn = fns
    if not pfn.exists():
        y_data = open_mask(yfn).px.long()
        h, w = y_data.shape[1:3]
        colors = y_data.view(3, -1).numpy().T  # one row of 3 channel values per pixel
        # Look up each *distinct* colour once rather than looping over every pixel in Python:
        # an image has few colours but hundreds of thousands of pixels.
        uniq, inverse = np.unique(colors, axis=0, return_inverse=True)
        uniq_ids = np.array([code2id.get(tuple(c), 0) for c in uniq.tolist()], dtype=np.uint8)
        proc_data = uniq_ids[inverse.reshape(-1)].reshape(h, w)
        PIL.Image.fromarray(proc_data).save(pfn)
    return pfn

from concurrent.futures import ProcessPoolExecutor
def process_label_files(y_fns, y_proc_fns, max_workers=16):
    "Run `process_file` on every `(y_fns[i], y_proc_fns[i])` pair in a pool of `max_workers` processes."
    # The `with` block shuts the pool down (joining its worker processes) when done;
    # consuming the `map` iterator re-raises any exception raised inside a worker.
    with ProcessPoolExecutor(max_workers) as ex:
        for _ in ex.map(process_file, zip(y_fns, y_proc_fns)): pass

In [None]:
# Convert every RGB label image to a class-id mask (masks that already exist are skipped).
%time process_label_files(y_fns, y_proc_fns)

In [None]:
def get_datasets(path, valid_pct=0.2, seed=None):
    """Randomly split the images in `path` into train/valid `MatchedFilesDataset`s.

    Each image is paired with its processed mask path (`get_y_proc_fn(get_y_fn(x))`),
    and each file is independently marked for validation with probability `valid_pct`.
    Pass `seed` for a reproducible split; with `None` the global numpy RNG is used.
    Returns `(train_ds, valid_ds)`.
    """
    # Sort so the file order (and hence a seeded split) doesn't depend on directory-listing order.
    x_fns = sorted(o for o in path.iterdir() if o.is_file())
    y_proc_fns = [get_y_proc_fn(get_y_fn(o)) for o in x_fns]
    rng = np.random if seed is None else np.random.RandomState(seed)
    is_valid = rng.uniform(size=(len(x_fns),)) < valid_pct
    ((val_x,trn_x),(val_y,trn_y)) = split_arrs(is_valid, x_fns, y_proc_fns)
    return (MatchedFilesDataset(trn_x, trn_y),
            MatchedFilesDataset(val_x, val_y))

In [None]:
def get_tfm_datasets(size, tfms=None):
    """Build the train/valid datasets from PATH_X and pass them through `transform_datasets`.

    `tfms` defaults to `get_transforms(do_flip=True, max_rotate=4, max_lighting=0.2)`.
    `tfm_y=True` presumably applies the transforms to the masks as well — confirm.
    `size` is forwarded unchanged.
    """
    datasets = get_datasets(PATH_X)
    if tfms is None: tfms = get_transforms(do_flip=True, max_rotate=4, max_lighting=0.2)
    return transform_datasets(*datasets, tfms=tfms, tfm_y=True, size=size)

In [None]:
# Normalisation functions from ImageNet statistics, plus default image and batch sizes.
default_norm, default_denorm = normalize_funcs(*imagenet_stats)
size, bs = 512, 8

In [None]:
tfms = get_transforms(max_lighting=0.2, max_rotate=4, do_flip=True)  # augmentation settings

In [None]:
def get_data(size, bs):
    "Create a `DataBunch` with batch size `bs` at image size `size`, normalised by `default_norm`."
    tfm_datasets = get_tfm_datasets(size)
    return DataBunch.create(*tfm_datasets, bs=bs, tfms=default_norm)

In [None]:
data = get_data(size=size, bs=bs)

In [None]:
x, y = data.train_ds[0]  # first training (image, mask) pair
(x.shape, y.shape, y.data.dtype)

## Unet

In [None]:
def in_channels(m):
    "Return `weight.shape[1]` of the first layer of `m` (via `flatten_model`) that has a `weight` attribute."
    weighted = (l for l in flatten_model(m) if hasattr(l, 'weight'))
    first = next(weighted, None)
    if first is None: raise Exception('No weight layer')
    return first.weight.shape[1]

def model_sizes(m, size=(256,256), full=True):
    """Pass a zero batch of spatial size `size` through `m` and collect each hook's stored shape.

    Returns `(sizes, out, hooks)` when `full` is True (hooks stay attached so callers can
    read their stored activations), otherwise just `sizes` with the hooks removed.
    Note: calls `m.eval()` and leaves `m` in eval mode.
    """
    hooks = hook_outputs(m)
    try:
        ch_in = in_channels(m)
        x = torch.zeros(1,ch_in,*size)
        x = m.eval()(x)
        res = [o.stored.shape for o in hooks]
    except Exception:
        hooks.remove()  # don't leave dangling hooks on `m` if the probe fails
        raise
    # Explicit branches: `return res,x,hooks if full else res` returned (res, x, res) when full=False.
    if not full:
        hooks.remove()
        return res
    return res, x, hooks

def get_sfs_idxs(sizes, last=True):
    """Indexes of the layers in `sizes` to use as skip connections.

    `sizes` is a sequence of activation shapes compared on their last dimension.
    With `last=True`, returns every index `i` whose size differs from layer `i+1`
    (the last layer before the size changes), prefixed with 0 when the first two sizes
    differ. With `last=False`, returns every index.
    """
    # Was `len(sfs)`, an undefined name (NameError).
    if not last: return list(range(len(sizes)))
    feature_szs = [size[-1] for size in sizes]
    sfs_idxs = [int(i) for i in np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0]]
    # NOTE(review): if the first two sizes differ, 0 is already in `sfs_idxs`, so it ends up twice — confirm intended.
    if len(feature_szs) > 1 and feature_szs[0] != feature_szs[1]: sfs_idxs = [0] + sfs_idxs
    return sfs_idxs

In [None]:
class UnetBlock(nn.Module):
    """Decoder block: upsample `up_in`, concatenate the activation stored on `hook`,
    then two conv + ReLU layers followed by a BatchNorm."""
    def __init__(self, up_in_c, x_in_c, hook):
        super().__init__()
        self.hook = hook
        up_out_c = up_in_c // 2
        cat_c = up_out_c + x_in_c
        out_c = cat_c // 2
        self.upconv = conv2d_trans(up_in_c, up_out_c) # H, W -> 2H, 2W
        self.conv1 = conv2d(cat_c, out_c)
        self.conv2 = conv2d(out_c, out_c)
        self.bn = nn.BatchNorm2d(out_c)

    def forward(self, up_in):
        upsampled = self.upconv(up_in)
        merged = torch.cat([upsampled, self.hook.stored], dim=1)
        hidden = F.relu(self.conv1(merged))
        return self.bn(F.relu(self.conv2(hidden)))

In [None]:
class Debugger(nn.Module):
    "Identity layer that calls `set_trace()` each time data flows through it."
    def forward(self, x):
        set_trace()  # inspect `x` interactively, then continue
        return x

class DynamicUnet(nn.Sequential):
    """U-Net decoder built on `encoder`, sized by a dummy forward pass at 256x256.

    `model_sizes` records the encoder's hooked activation shapes and keeps the hooks in
    `self.sfs`. One `UnetBlock` is added per index from `get_sfs_idxs`, in reverse index
    order, each reading its skip activation from the matching hook. The hooks are
    removed in `__del__`.
    """
    def __init__(self, encoder, last=True, n_classes=3):
        imsize = (256,256)
        sfs_szs,x,self.sfs = model_sizes(encoder, size=imsize)
        sfs_idxs = reversed(get_sfs_idxs(sfs_szs, last))

        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(conv2d_relu(ni, ni*2, bn=True), conv2d_relu(ni*2, ni, bn=True))
        layers = [encoder, nn.ReLU(), middle_conv]
        # The dummy passes only determine shapes, so don't record an autograd graph for them.
        # NOTE(review): the new layers are in train mode here, so BatchNorm layers update their running stats from this dummy input.
        with torch.no_grad(): x = middle_conv(x)

        for idx in sfs_idxs:
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            unet_block = UnetBlock(up_in_c, x_in_c, self.sfs[idx])
            layers.append(unet_block)
            with torch.no_grad(): x = unet_block(x)

        # Read channels off the activation rather than the loop variable, which is unbound when there are no skip indexes.
        ni = int(x.shape[1])
        if imsize != sfs_szs[0][-2:]: layers.append(conv2d_trans(ni, ni))
        layers.append(conv2d(ni, n_classes, 1))
        super().__init__(*layers)

    def __del__(self):
        if hasattr(self, "sfs"): self.sfs.remove()

In [None]:
from nb_005b import accuracy_thresh
lr = 1e-3
# NOTE(review): `accuracy_thresh`/`dice` look like threshold/binary metrics while the loss is
# multi-class cross-entropy — confirm they handle an n_classes-channel prediction.
metrics = [accuracy_thresh, dice]

In [None]:
def my_loss(pred, target):
    """Pixel-wise cross-entropy between logits `pred` and an integer class mask `target`.

    Assumes `target` carries a singleton channel dim, e.g. (bs, 1, H, W) — TODO confirm.
    Only dim 1 is squeezed: a bare `squeeze()` also dropped the batch dim when bs == 1,
    which made `cross_entropy` fail on the resulting shape mismatch.
    """
    return F.cross_entropy(pred, target.squeeze(1).long())

body = create_body(tvm.resnet34(True), 2)
model = DynamicUnet(body, n_classes=len(label_codes)).cuda()
learn = Learner(data, model, metrics=metrics, loss_fn=my_loss)
# Layer-group boundaries: encoder child 6 and model[1] (the ReLU placed right after the encoder).
learn.split([model[0][6], model[1]])
learn.freeze()

In [None]:
# Presumably an LR range test; its loss-vs-LR plot is used to choose `lr` below.
lr_find(learn)
learn.recorder.plot()

In [None]:
lr = 0.05  # learning rate for the frozen stage

In [None]:
# First one-cycle run after learn.freeze(), at slice(lr).
learn.fit_one_cycle(1, slice(lr), pct_start=0.05)

In [None]:
# A second, longer one-cycle run with the same settings (first argument presumably epochs).
learn.fit_one_cycle(6, slice(lr), pct_start=0.05)

In [None]:
# Checkpoint the frozen-stage weights under the name 'u0' before unfreezing.
learn.save('u0')

In [None]:
# x,py = learn.pred_batch()

# for i, ax in enumerate(plt.subplots(4,4,figsize=(10,10))[1].flat):
#     show_image(default_denorm(x[i].cpu()), py[i]>0, ax=ax)

In [None]:
lr = 1e-2         # peak learning rate for the unfrozen stage
learn.unfreeze()  # undo the earlier learn.freeze()

In [None]:
# Unfrozen training; slice(lr/100, lr) presumably spreads rates from lr/100 (earliest layer group) to lr (last group).
learn.fit_one_cycle(6, slice(lr/100,lr), pct_start=0.05)

In [None]:
# Continue at a larger image size with a smaller batch.
size, bs = 640, 4
learn.data = get_data(size, bs)

In [None]:
#learn.freeze()

In [None]:
# NOTE(review): the model is still unfrozen here (the freeze cell above is commented out), yet this uses
# slice(lr) instead of the slice(lr/100, lr) of the previous unfrozen run — confirm the encoder should train at this LR.
learn.fit_one_cycle(6, slice(lr), pct_start=0.05)

## Fin