## imports / setup

In [21]:
# Re-import edited modules automatically before each cell executes, so changes to the
# local helper modules imported below are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [22]:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

from pathlib import Path
import torch
from torch import optim, nn
import PIL 
import pydicom

from functools import partial, reduce
from enum import Enum

In [23]:
# helper function from the competition
import mask_functions

In [1]:
from dataloader_full import *
from Learner import Learner

In [27]:
# Root of the SIIM-ACR competition data. Set the SIIM_DATA_DIR environment variable to
# point elsewhere instead of editing this machine-specific default.
import os

path = Path(os.environ.get('SIIM_DATA_DIR', '/home/ubuntu/data/fastai/data/siim-acr/'))

## model: encoder / decoder building blocks and the U-Net

In [28]:
class ResBlock(nn.Module):
    """Residual block: two 3x3 convolutions (stride 1, padding 1) plus a skip connection.

    Spatial size is preserved. Output = relu(c2(relu(c1(x)))) + shortcut(x).

    Args:
        ni: number of input channels.
        nf: number of output channels. When ``nf != ni`` the skip path is a 1x1
            convolution so the channel counts match; when they are equal the skip is
            the identity and no extra parameters are created (checkpoint layout is
            unchanged for that case).
    """
    def __init__(self, ni, nf):
        super().__init__()
        self.c1 = nn.Conv2d(ni, nf, kernel_size=3, stride=1, padding=1)
        # c2 consumes c1's output, so its input width is nf (not ni).
        self.c2 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1)
        self.shortcut = nn.Identity() if ni == nf else nn.Conv2d(ni, nf, kernel_size=1)

    def forward(self, x):
        a1 = torch.relu(self.c1(x))
        # Chain c2 onto c1 (previously c2 was applied to x and a1 was discarded).
        a2 = torch.relu(self.c2(a1))
        return a2 + self.shortcut(x)
        

In [29]:
# resb = ResBlock(1, 16).cuda()

# yp = resb(xb)
# assert(yp.shape == (1,16,1024,1024))

In [30]:
def conv2d(ni, nf, s=2):
    """3x3 convolution, padding 1, stride ``s``: output is ceil(H/s) x ceil(W/s)."""
    return nn.Conv2d(ni, nf, kernel_size=3, stride=s, padding=1)

def conv_layer(ni, nf, s=2, p=0.5):
    """Downsampling block: conv2d -> BatchNorm2d -> channel dropout -> ReLU.

    Args:
        ni: input channels.
        nf: output channels.
        s: convolution stride (2 halves H and W).
        p: probability of zeroing each feature channel during training.
    """
    return nn.Sequential(
        conv2d(ni, nf, s),
        nn.BatchNorm2d(nf),
        # Dropout2d drops whole channels of an (N, C, H, W) tensor. Dropout3d would read
        # the 4-D input as unbatched (C, D, H, W) and zero entire samples instead.
        nn.Dropout2d(p),
        nn.ReLU(),
    )

In [31]:
def deconv2d(ni, nf, s=2):
    """3x3 transposed convolution that upsamples H and W by exactly ``s``.

    output_padding is tied to the stride (s - 1) so the output is H*s x W*s; PyTorch
    requires output_padding < stride, which a fixed value of 1 violates for s=1.
    """
    return nn.ConvTranspose2d(ni, nf, kernel_size=3, stride=s, padding=1, output_padding=s - 1)

def deconv_layer(ni, nf, s=2, p=0.5):
    """Upsampling block: deconv2d -> BatchNorm2d -> channel dropout -> ReLU.

    Args:
        ni: input channels.
        nf: output channels.
        s: upsampling factor (2 doubles H and W).
        p: probability of zeroing each feature channel during training.
    """
    return nn.Sequential(
        deconv2d(ni, nf, s),
        nn.BatchNorm2d(nf),
        # Dropout2d for (N, C, H, W); Dropout3d would treat the batch axis as channels.
        nn.Dropout2d(p),
        nn.ReLU(),
    )

In [32]:
# Shape check for one decoder stage: a stride-2 transposed conv should double H and W
# (16x16 -> 32x32) while going from 512 to 256 channels.
layer = nn.Sequential(
    nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
    nn.BatchNorm2d(256),
    nn.Dropout2d(0.1),  # channel dropout for (N, C, H, W); Dropout3d would treat N as channels
    nn.ReLU()
)

input_x = torch.rand([1, 512, 16, 16])

with torch.no_grad():
    out_a = layer(input_x)

print(input_x.shape, out_a.shape)
assert out_a.shape == (1, 256, 32, 32)

torch.Size([1, 512, 16, 16]) torch.Size([1, 256, 32, 32])


In [33]:
class unet_simple(nn.Module):
    """Small U-Net-style encoder/decoder producing a 1-channel map at input resolution.

    Input is (N, 1, H, W) (cin1 takes 1 channel). Six stride-2 conv stages shrink H and W
    by 64x (1024 -> 16), with a ResBlock after every second stage. Six stride-2
    transposed-conv stages grow them back. Encoder features at 1/16 and 1/4 resolution
    are concatenated into the decoder as skip connections. For those concatenations to
    line up, H and W must be divisible by 64.
    """
    def __init__(self):
        super().__init__()

        # encoder
        self.cin1 = conv_layer(1, 16, 2, 0.5)     # 1024 -> 512
        self.cin2 = conv_layer(16, 32, 2, 0.5)    # 512 -> 256
        self.rin1 = ResBlock(32, 32)              # 256
        self.cin3 = conv_layer(32, 64, 2, 0.5)    # 256 -> 128
        self.cin4 = conv_layer(64, 128, 2, 0.5)   # 128 -> 64
        self.rin2 = ResBlock(128, 128)            # 64
        self.cin5 = conv_layer(128, 256, 2, 0.5)  # 64 -> 32
        self.cin6 = conv_layer(256, 512, 2, 0.5)  # 32 -> 16
        self.rin3 = ResBlock(512, 512)            # 16

        # decoder
        self.dcout1 = deconv_layer(512, 256, 2, 0.5)  # 16 -> 32
        self.dcout2 = deconv_layer(256, 128, 2, 0.5)  # 32 -> 64
        self.drout1 = ResBlock(256, 256)              # 64, after concat with rin2 output
        self.dcout3 = deconv_layer(256, 64, 2, 0.5)   # 64 -> 128
        self.dcout4 = deconv_layer(64, 32, 2, 0.5)    # 128 -> 256
        self.drout2 = ResBlock(64, 64)                # 256, after concat with rin1 output
        self.dcout5 = deconv_layer(64, 32, 2, 0.5)    # 256 -> 512
        # Output head (512 -> 1024) is a bare transposed conv. Routing the single
        # prediction channel through BatchNorm, p=0.5 dropout and ReLU (as deconv_layer
        # does) renormalised it per batch, clipped it at 0, and zeroed the whole
        # prediction for about half the samples during training.
        # NOTE: state dicts saved with the old head (dcout6.0.*/dcout6.1.* keys) will
        # not load with strict=True.
        self.dcout6 = nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1)

    def forward(self, x):
        # downsampling path
        a1 = self.cin1(x)
        a2 = self.cin2(a1)
        a3 = self.rin1(a2)   # kept for the 1/4-resolution skip
        a4 = self.cin3(a3)
        a5 = self.cin4(a4)
        a6 = self.rin2(a5)   # kept for the 1/16-resolution skip
        a7 = self.cin5(a6)
        a8 = self.cin6(a7)
        a9 = self.rin3(a8)

        # upsampling path
        a10 = self.dcout1(a9)
        a11 = self.dcout2(a10)
        a12 = self.drout1(torch.cat((a11, a6), 1))  # 128 + 128 = 256 channels
        a13 = self.dcout3(a12)
        a14 = self.dcout4(a13)
        a15 = self.drout2(torch.cat((a14, a3), 1))  # 32 + 32 = 64 channels
        a16 = self.dcout5(a15)
        # TODO: a full-resolution skip (concatenating x) was sketched but never wired up.
        return self.dcout6(a16)

In [34]:
# training hyperparameters
bs, lr, wd = 16, 1e-3, 1e-2  # batch size for get_data; Adam learning rate; Adam weight decay

In [35]:
# build the network and move its parameters to the current CUDA device
model = unet_simple().to('cuda')

In [37]:
# training and validation bundles; each is indexed with ['dl'] when building the Learner.
# shrink=0.1 presumably subsamples the data for quick iteration — TODO confirm in get_data.
(trn, val) = get_data(path, bs, shrink=0.1)

In [39]:
opt = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
# MSE between predicted and target maps; loaders are keyed 'trn' / 'val'
learn = Learner(
    path,
    model,
    opt,
    nn.MSELoss(),
    dict(trn=trn['dl'], val=val['dl']),
)

In [None]:
# train; the argument is presumably the number of epochs — TODO confirm in Learner.fit_
learn.fit_(1)

  9%|▉         | 5/54 [00:09<01:36,  1.97s/it]

In [1]:
# persist the trained weights under this name (storage location is decided by Learner).
# NOTE: the recorded NameError is from running this cell in a fresh kernel (In [1])
# before `learn` was re-created — re-run the cells above first.
learn.save('unet_basic_v1')

NameError: name 'learn' is not defined

In [32]:
# restore the weights previously saved under the same name
learn.load('unet_basic_v1')

In [33]:
# switch to inference mode (dropout off, BatchNorm uses running stats). eval() works in
# place, so `model` and `learn.model` refer to the same module object afterwards.
learn.model.eval()
model = learn.model

In [34]:
# Take one validation batch to visualise. `dl` is not defined anywhere in this notebook
# (it only worked through leftover kernel state), so use the validation loader that
# get_data returned above.
xb, yb = next(iter(val['dl']))

In [35]:
# Inference only: no_grad skips building the autograd graph (saves memory) and yields
# predictions without requires_grad, so they can be converted to numpy for plotting.
with torch.no_grad():
    ypred = model(xb)

In [1]:
# For each sample in the batch, pass [input, target, prediction] to plot_xyhat.
# detach() makes this safe even if ypred was computed with autograd enabled;
# numpy conversion of a tensor that requires grad raises.
for i in range(len(ypred)):
    arr = [t.detach().cpu() for t in (xb[i], yb[i], ypred[i])]
    plot_xyhat(arr)


NameError: name 'ypred' is not defined