In [1]:
from glob import glob
import cv2
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt


import torch.nn as nn
import torch

from fastai.conv_learner import *
from fastai.dataset import *


%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
class UnetBlock(nn.Module):
    """Decoder stage: upsample the deeper feature map and fuse it with a skip map.

    The deeper map goes through a kernel-2, stride-2 transposed convolution
    (which doubles its spatial size), the skip map through a 1x1 convolution,
    and the two results are concatenated along the channel axis, passed
    through ReLU and then batch-normalised.

    Args:
        up_in: number of channels of the deeper feature map.
        x_in: number of channels of the skip-connection feature map.
        n_out: number of channels produced by the block.
    """
    def __init__(self, up_in, x_in, n_out):
        super().__init__()
        # Split n_out between the two branches. The skip branch takes the
        # remainder so that an odd n_out still concatenates to exactly n_out
        # channels (the BatchNorm below is sized for n_out).
        up_out = n_out//2
        x_out = n_out - up_out
        self.x_conv  = nn.Conv2d(x_in,  x_out,  1)
        self.tr_conv = nn.ConvTranspose2d(up_in, up_out, 2, stride=2)
        self.bn = nn.BatchNorm2d(n_out)

    def forward(self, up_p, x_p):
        """Fuse `up_p` (deeper map) with `x_p` (skip map) and return the result.

        After upsampling, `up_p` must have the same spatial size as `x_p`,
        otherwise the channel-wise concatenation fails.
        """
        up_p = self.tr_conv(up_p)
        x_p = self.x_conv(x_p)
        cat_p = torch.cat([up_p,x_p], dim=1)
        return self.bn(F.relu(cat_p))
    
class Unet34(nn.Module):
    """U-Net style segmentation network built around a backbone `rn`.

    Forward hooks record the outputs of backbone children 2, 4, 5 and 6. The
    decoder consumes those recordings from the last hook to the first, and a
    final transposed convolution produces a single-channel map.
    """
    def __init__(self, rn):
        super().__init__()
        self.rn = rn
        # Backbone children whose activations feed the decoder as skip connections.
        tap_idxs = (2, 4, 5, 6)
        self.sfs = [SaveFeatures(rn[idx]) for idx in tap_idxs]
        self.up1 = UnetBlock(512, 256, 256)
        self.up2 = UnetBlock(256, 128, 256)
        self.up3 = UnetBlock(256, 64, 256)
        self.up4 = UnetBlock(256, 64, 256)
        self.up5 = nn.ConvTranspose2d(256, 1, 2, stride=2)

    def forward(self, x):
        """Run the backbone, then the decoder; returns channel 0 of the final map."""
        out = F.relu(self.rn(x))
        stages = (self.up1, self.up2, self.up3, self.up4)
        # up1 pairs with the last hook, up4 with the first.
        for stage, saved in zip(stages, reversed(self.sfs)):
            out = stage(out, saved.features)
        out = self.up5(out)
        # up5 emits one channel; index it away so the result has no channel axis.
        return out[:, 0]

    def close(self):
        """Remove every forward hook registered in __init__."""
        for saved in self.sfs:
            saved.remove()
            
class UnetModel():
    """Pairs a model with a name and exposes its layer groups."""
    def __init__(self, model, name='unet'):
        self.model = model
        self.name = name

    def get_layer_groups(self, precompute):
        """Return the backbone's children split at `lr_cut`, plus one trailing group.

        The trailing group holds every child of the wrapped model after the
        first one. `precompute` is accepted but not used.
        """
        backbone_groups = list(split_by_idxs(children(self.model.rn), [lr_cut]))
        remaining_children = children(self.model)[1:]
        return backbone_groups + [remaining_children]

class SaveFeatures():
    """Registers a forward hook on a module and keeps its most recent output."""
    # Stays None until the hooked module has run at least once.
    features = None

    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Hook callback: remember `output`; `module` and `input` are ignored."""
        self.features = output

    def remove(self):
        """Unregister the hook; `features` keeps whatever it last stored."""
        self.hook.remove()

In [3]:
# Backbone constructor, and the two cut indices that model_meta stores for it.
f = resnet34
cut, lr_cut = model_meta[f]


def get_base():
    """Build the backbone: the layers of `f(True)` kept by `cut_model`, as nn.Sequential.

    NOTE(review): the `True` argument presumably requests pretrained weights — confirm.
    """
    return nn.Sequential(*cut_model(f(True), cut))

In [4]:
# Build the backbone, wrap it in the U-Net and pass the result through to_gpu.
m_base = get_base()
m = to_gpu(Unet34(m_base))

In [5]:
TRAIN_DN = 'data/train'
TRAIN_LABEL_DN = 'data/train_mask'
# glob() returns matches in arbitrary order, so both listings are sorted to make
# x_names[i] and y_names[i] refer to the same sample.
# NOTE(review): this assumes mask files sort in the same order as their images
# (e.g. identical file names in both directories) — confirm against the data.
x_names = np.array(sorted(glob(str(Path(TRAIN_DN)/'*.jpg'))))
y_names = np.array(sorted(glob(str(Path(TRAIN_LABEL_DN)/'*.jpg'))))
if len(x_names) != len(y_names):
    raise ValueError('found %d images in %s but %d masks in %s'
                     % (len(x_names), TRAIN_DN, len(y_names), TRAIN_LABEL_DN))

# The first 10407 entries (in sorted order) are held out for validation.
val_idxs = list(range(10407))
((val_x, trn_x), (val_y, trn_y)) = split_by_idx(val_idxs, x_names, y_names)

In [6]:
class _BatchLoader:
    """Re-iterable source of (image_batch, label_batch) numpy arrays.

    Every call to iter() starts a fresh pass over the file lists, so the same
    loader can be reused for several epochs.
    """

    def __init__(self, x_fn_list, y_fn_list, bs, open_fn=None):
        if len(x_fn_list) != len(y_fn_list):
            raise ValueError('x_fn_list has %d entries but y_fn_list has %d'
                             % (len(x_fn_list), len(y_fn_list)))
        if bs < 1:
            raise ValueError('bs must be a positive integer, got %r' % (bs,))
        self.x_fn_list = x_fn_list
        self.y_fn_list = y_fn_list
        self.bs = bs
        self.open_fn = open_fn

    def __len__(self):
        # Number of batches per pass; the last one may be smaller than bs.
        return -(-len(self.x_fn_list) // self.bs)

    def __iter__(self):
        return self._batches()

    def _batches(self):
        # Resolve the default reader lazily so a custom open_fn never touches open_image.
        read = open_image if self.open_fn is None else self.open_fn
        total = len(self.x_fn_list)
        for start in range(0, total, self.bs):
            batch_image_arr, batch_label_arr = [], []
            for i in range(start, min(start + self.bs, total)):
                # Rescale the image as v -> 2*v - 1.
                image_arr = read(self.x_fn_list[i]) * 2 - 1
                # Only the first channel of the mask image is used as the label.
                label_arr = read(self.y_fn_list[i])[:, :, 0]
                # .T reverses the axis order of both arrays, so image and label
                # stay spatially aligned with each other.
                batch_image_arr.append(image_arr.T)
                batch_label_arr.append(label_arr.T)
            yield np.array(batch_image_arr), np.array(batch_label_arr)


def dataloader(x_fn_list, y_fn_list, bs, open_fn=None):
    """Return an iterable of (images, labels) batches read from paired file lists.

    Unlike a plain generator, the returned object can be iterated repeatedly:
    each `for` loop over it restarts from the first file. It is an iterable,
    not an iterator, so use `iter(loader)` before calling `next`.

    Args:
        x_fn_list: image file names.
        y_fn_list: mask file names, index-aligned with `x_fn_list`.
        bs: maximum number of samples per batch; the final batch may be smaller.
        open_fn: optional callable mapping a file name to an array with at
            least three axes; defaults to `open_image`.

    Returns:
        An object yielding `(np.ndarray, np.ndarray)` tuples. Images are
        rescaled with `v * 2 - 1` and transposed with `.T`; labels are channel
        0 of the mask, also transposed with `.T`.

    Raises:
        ValueError: if the two lists differ in length or `bs` is below 1.
    """
    return _BatchLoader(x_fn_list, y_fn_list, bs, open_fn)

In [7]:
# Batch sources for the training and validation splits, five samples per batch.
trainloader = dataloader(trn_x, trn_y, bs=5)
valloader = dataloader(val_x, val_y, bs=5)

In [8]:
def _loss_value(loss):
    """Return a scalar loss as a plain Python number on old and new PyTorch."""
    data = loss.data
    # 0-dim tensors (PyTorch >= 0.4) expose .item(); older 1-element tensors need [0].
    return data.item() if hasattr(data, 'item') else data[0]


def _to_variable(batch, use_cuda):
    """Wrap a numpy array or tensor in a Variable, moving it to the GPU if asked."""
    tensor = torch.from_numpy(batch) if isinstance(batch, np.ndarray) else batch
    if use_cuda:
        tensor = tensor.cuda()
    return Variable(tensor)


def train(model, loss_fn, optimizer, trainloader, testloader, num_epochs = 1):
    """Train `model` for `num_epochs`, printing per-batch and validation losses.

    Args:
        model: module to optimise; batches are moved to the GPU only when the
            model's parameters already live there.
        loss_fn: callable taking (scores, targets) and returning a scalar loss.
        optimizer: optimizer built over `model`'s parameters.
        trainloader: iterable of (x, y) batches (numpy arrays or tensors).
        testloader: iterable of (x, y) batches used for validation.
        num_epochs: number of passes over the loaders.

    Both loaders are iterated once per epoch, so a one-shot iterator (such as
    a generator) only provides data for the first epoch.
    """
    first_param = next(iter(model.parameters()), None)
    use_cuda = first_param is not None and first_param.is_cuda

    for epoch in range(num_epochs):
        print('Starting epoch %d / %d' % (epoch + 1, num_epochs))
        model.train()
        train_loss = []
        for x, y in trainloader:
            x = _to_variable(x, use_cuda)
            y = _to_variable(y, use_cuda)
            optimizer.zero_grad()

            scores = model(x)
            loss = loss_fn(scores, y)
            train_loss.append(_loss_value(loss))
            loss.backward()
            optimizer.step()
            print('train loss = {}'.format(train_loss[-1]))
        if not train_loss:
            print('warning: trainloader produced no batches in this epoch')

        # Evaluate in eval mode so layers such as BatchNorm neither use batch
        # statistics nor update their running averages on validation data.
        model.eval()
        val_loss = []
        for x, y in testloader:
            x = _to_variable(x, use_cuda)
            y = _to_variable(y, use_cuda)

            scores = model(x)
            loss = loss_fn(scores, y)
            val_loss.append(_loss_value(loss))
        # Leave the model in training mode, as it was after the training pass.
        model.train()
        # An empty validation pass reports nan without triggering numpy's empty-mean warning.
        print('val loss = {}'.format(np.mean(val_loss) if val_loss else float('nan')))

def check_accuracy(model, loader):
    """Print the fraction of samples whose arg-max prediction equals the label.

    Args:
        model: module to evaluate; it is switched to eval mode and left there.
        loader: iterable of (x, y) batches, as numpy arrays or tensors. If it
            has a `dataset.train` attribute that is truthy, the run is
            labelled as a validation run, otherwise as a test run.

    NOTE(review): predictions are the arg-max over dimension 1 of the scores,
    i.e. this treats the scores as (batch, classes). Unet34.forward returns a
    single logit map per image, so this metric does not apply to it as
    written — confirm the intended use before enabling the commented-out call.
    """
    # Plain iterables have no `.dataset`; treat them as a test set.
    dataset = getattr(loader, 'dataset', None)
    if getattr(dataset, 'train', False):
        print('Checking accuracy on validation set')
    else:
        print('Checking accuracy on test set')
    # Inputs follow the model's device instead of always being sent to the GPU.
    first_param = next(iter(model.parameters()), None)
    use_cuda = first_param is not None and first_param.is_cuda
    num_correct = 0
    num_samples = 0
    model.eval() # Put the model in test mode (the opposite of model.train(), essentially)
    for x, y in loader:
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if isinstance(y, np.ndarray):
            y = torch.from_numpy(y)
        x_var = Variable(x.cuda() if use_cuda else x)

        scores = model(x_var)
        _, preds = scores.data.cpu().max(1)
        # int() keeps the counter a plain Python integer on every PyTorch version.
        num_correct += int((preds == y.cpu()).sum())
        num_samples += preds.size(0)
    # An empty loader reports 0 accuracy instead of dividing by zero.
    acc = float(num_correct) / num_samples if num_samples else 0.0
    print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))

In [9]:
def dice(pred, targs):
    """Dice coefficient between thresholded predictions and target masks.

    Args:
        pred: tensor of raw scores; values above 0 count as foreground.
        targs: tensor of target masks with the same shape as `pred`.

    Returns:
        `(2 * |pred AND targs| + eps) / (|pred| + |targs| + eps)`, computed
        over all elements. The smoothing term is added once to the numerator
        and once to the denominator, so two empty masks score 1 rather than 2.
    """
    eps = 1e-8
    pred = (pred>0).float()
    targs = targs.float()
    intersection = (pred*targs).sum()
    total = (pred+targs).sum()
    return (2. * intersection + eps) / (total + eps)

In [10]:
# loss_fn = nn.CrossEntropyLoss()
# Binary cross-entropy computed from raw logits: Unet34.forward returns the
# output of its last transposed convolution without applying an activation.
loss_fn = nn.BCEWithLogitsLoss()
# SGD with momentum over every parameter of `m`, backbone included.
optimizer = optim.SGD(m.parameters(), lr=1e-3, momentum=0.9)

In [11]:
# Train for two epochs with the loss, optimizer and loaders defined in the cells above.
train(m, loss_fn, optimizer, trainloader, valloader, num_epochs=2)
# NOTE(review): `testloader` is not defined in any cell shown here.
# check_accuracy(m, testloader)

Starting epoch 1 / 2
train loss = 0.7050632834434509
train loss = 0.7049021124839783
train loss = 0.704341471195221
train loss = 0.7035050988197327
train loss = 0.7023150324821472
train loss = 0.7014855146408081
train loss = 0.700497031211853
train loss = 0.6990828514099121
train loss = 0.6973162889480591
train loss = 0.6958601474761963
train loss = 0.6943911910057068
train loss = 0.6924437880516052
train loss = 0.6906020045280457
train loss = 0.6886652708053589
train loss = 0.6867001056671143
train loss = 0.6850572228431702
train loss = 0.6826474666595459
train loss = 0.68092280626297
train loss = 0.6784927845001221
train loss = 0.6762816309928894
train loss = 0.6740670800209045
train loss = 0.6718223094940186
train loss = 0.6696360111236572
train loss = 0.6674396991729736
train loss = 0.6651862859725952
train loss = 0.6631177067756653
train loss = 0.660552442073822
train loss = 0.6582621932029724
train loss = 0.6562259197235107
train loss = 0.653497040271759
train loss = 0.6511470079

KeyboardInterrupt: 