In [None]:
import numpy as np
# Seed every global RNG this notebook draws from so Restart & Run All is
# repeatable: numpy's legacy global generator, and Python's `random` module
# (batch_generator shuffles with `random.shuffle`, which np.random.seed does
# not affect).
import random

SEED = 42
np.random.seed(SEED)
random.seed(SEED)

# progress bars
from tqdm import tqdm_notebook as tqdm

# limit to single device
import os
# Order CUDA devices by PCI bus ID and expose only device 0 to this process.
# Presumably this must happen before torch initialises CUDA.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

# pytorch imports
import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision

from torch import nn
from torch.autograd import Variable
from torchvision import transforms

# Tensor type used for the inputs/targets built in the training loop (CPU).
# NOTE: switching to torch.cuda.FloatTensor also requires moving the model to
# the GPU (e.g. net.cuda()); otherwise the forward pass mixes devices.
dtype = torch.FloatTensor
# dtype = torch.cuda.FloatTensor  # uncomment to run on GPU

# Seed torch's RNG as well (weight initialisation, dropout masks): the numpy
# seed does not cover it. Same value as the numpy seed.
torch.manual_seed(42)

# matplotlib for plotting
import matplotlib.pyplot as plt
fig_size = (7, 7)
plt.rcParams['axes.spines.left'] = False
plt.rcParams['axes.spines.bottom'] = False
plt.rcParams['axes.spines.top'] = False
plt.rcParams['axes.spines.right'] = False
plt.rcParams['figure.figsize'] = fig_size
plt.rcParams['image.cmap'] = 'gray'
plt.rcParams['image.interpolation'] = 'none'
plt.rcParams['xtick.top'] = False
plt.rcParams['xtick.bottom'] = False
plt.rcParams['xtick.color'] = 'white'
plt.rcParams['ytick.left'] = False
plt.rcParams['ytick.right'] = False
plt.rcParams['ytick.color'] = 'white'
%matplotlib inline

# batching
from random import shuffle

def batch_generator(dataset, batch_size=5):
    """Yield shuffled, full-size batches from a sequence of ``(img, mask)`` pairs.

    The pairs are shuffled with ``random.shuffle`` into a private copy, so the
    caller's sequence is left untouched. Only ``len(dataset) // batch_size``
    full batches are produced. A trailing partial batch is dropped, so every
    batch has exactly ``batch_size`` elements; the training loop relies on
    this when it reshapes with ``batch_size``.

    Args:
        dataset: sequence of 2-tuples ``(img, mask)``.
        batch_size: number of pairs per batch; must be positive.

    Yields:
        ``(imgs, masks)``: two tuples, each of length ``batch_size``.

    Raises:
        ValueError: if ``batch_size`` is not positive. Because this is a
            generator, this is raised on the first iteration, not at call time.
    """
    if batch_size <= 0:
        raise ValueError('batch_size must be positive, got {}'.format(batch_size))

    # Shuffle a copy: shuffling `dataset` in place would silently reorder the
    # caller's list (e.g. dataset_train) as a side effect of iterating.
    order = list(dataset)
    shuffle(order)

    n_full_batches = len(order) // batch_size
    for i in range(n_full_batches):
        chunk = order[i * batch_size:(i + 1) * batch_size]
        imgs, masks = zip(*chunk)
        yield imgs, masks

In [None]:
from datasets.drive.extract_patches import get_data_training, get_data_testing

batch_size = 32
# Unet pools by 2 five times and upsamples back; with a multiple of 32 the
# output map has the same spatial size as the input.
img_size = 64
n_train_patches = 2000  # value passed as N to get_data_training

patches_imgs, patches_masks = get_data_training(N=n_train_patches, img_size=img_size)
# zip() silently truncates to the shorter input, so check the pairing first.
assert len(patches_imgs) == len(patches_masks), 'image/mask patch count mismatch'
dataset_train = list(zip(patches_imgs, patches_masks))

test_imgs, test_masks = get_data_testing()

In [None]:
class UnetDown(nn.Module):
    """Encoder downsampling step: 2-D max pooling with a square window.

    Args:
        kernel: pooling window size. nn.MaxPool2d defaults its stride to the
            window size, so the default of 2 halves height and width.
    """

    def __init__(self, kernel=2):
        super().__init__()
        pool = nn.MaxPool2d(kernel)
        self.down = pool

    def forward(self, inputs):
        pooled = self.down(inputs)
        return pooled

class UnetUp(nn.Module):
    """Decoder upsampling step: pad -> transposed conv -> batch norm -> ReLU.

    Args:
        in_size: input channel count.
        out_size: output channel count.
        kernel: transposed-convolution kernel size.
        stride: transposed-convolution stride. With the defaults (kernel 2,
            stride 2) the padded input's height and width are doubled.
        padding: 4-tuple ``(left, right, top, bottom)`` given to ``F.pad``
            before the transposed convolution; the convolution itself uses
            no padding.
    """

    def __init__(self, in_size, out_size,
                 kernel=2, stride=2, padding=(1, 1, 1, 1)):
        super().__init__()
        self.padding = padding
        self.deconv = nn.ConvTranspose2d(in_size, out_size, kernel, stride, padding=0)
        self.norm = nn.BatchNorm2d(out_size)
        self.act = nn.ReLU()

    def forward(self, inputs):
        padded = F.pad(inputs, self.padding)
        return self.act(self.norm(self.deconv(padded)))


class UnetConv(nn.Module):
    """Convolution -> batch norm -> optional activation.

    Args:
        in_size: input channel count.
        out_size: output channel count.
        kernel, stride, padding: passed positionally to ``nn.Conv2d``. The
            defaults (3, 1, 1) keep height and width unchanged.
        act: module applied after batch norm, or ``None`` to return the
            normalised convolution output unchanged. NOTE: the default
            ``nn.ReLU()`` is evaluated once, when the class is defined, so
            every instance built with the default shares that one ReLU object
            (ReLU holds no parameters).
    """

    def __init__(self, in_size, out_size,
                 kernel=3, stride=1, padding=1,
                 act=nn.ReLU()):
        super().__init__()
        self.conv = nn.Conv2d(in_size, out_size, kernel, stride, padding)
        self.norm = nn.BatchNorm2d(out_size)
        self.act = act

    def forward(self, inputs):
        normed = self.norm(self.conv(inputs))
        return normed if self.act is None else self.act(normed)

class UnetConc(nn.Module):
    """Skip connection: concatenate two feature maps along dim 1, then
    optionally apply dropout.

    Args:
        dropout: dropout probability applied to the concatenated tensor.
            ``False``, ``None`` or any value ``<= 0`` disables dropout.
            ``True`` is accepted and means the default rate, 0.5.
    """

    def __init__(self, dropout=0.5):
        super(UnetConc, self).__init__()

        if dropout is True:
            dropout = 0.5
        if dropout is not None and dropout > 0.:
            # Use the requested rate; a bare nn.Dropout() would always be 0.5.
            self.dropout = torch.nn.Dropout(p=float(dropout))
        else:
            self.dropout = None

    def forward(self, inputs1, inputs2):
        # dim 1 is the channel axis of (N, C, H, W) feature maps.
        x = torch.cat([inputs1, inputs2], 1)

        if self.dropout is not None:
            x = self.dropout(x)

        return x

class Unet(nn.Module):
    """Five-level U-Net-style segmentation network with a sigmoid output.

    Encoder: ``conv1`` at full resolution, then five ``UnetDown`` (2x max
    pool) steps, each followed by a conv (``conv2``..``conv6``).

    Decoder: five ``UnetUp`` (2x transposed conv) steps. After each of the
    first four, the upsampled features go through a conv. They are then
    concatenated with the pooled encoder features of the same resolution
    (``down4``..``down1``) and fused by another conv. The last step is
    followed by ``conv15`` and the 1x1-free projection ``conv16`` (no
    activation), then a sigmoid.

    With input height/width that are multiples of 32, the output has the same
    spatial size as the input. Other sizes may fail at a concatenation or give
    a smaller output.

    Args:
        debug: print the size of every intermediate tensor in ``forward``.
        in_channels: channels of the input image (default 3).
        out_channels: channels of the output map (default 1).
    """

    def __init__(self, debug=False, in_channels=3, out_channels=1):
        super(Unet, self).__init__()
        self.debug = debug

        # Attribute names double as state_dict keys; keep them stable.
        self.conv1 = UnetConv(in_channels, 32)
        self.conv2 = UnetConv(32, 32)
        self.conv3 = UnetConv(32, 32)
        self.conv4 = UnetConv(32, 32)
        self.conv5 = UnetConv(32, 32)
        self.conv6 = UnetConv(32, 32)
        self.conv7 = UnetConv(32, 32)
        self.conv8 = UnetConv(64, 32)   # 64 = 32 decoder + 32 skip channels
        self.conv9 = UnetConv(32, 32)
        self.conv10 = UnetConv(64, 32)
        self.conv11 = UnetConv(32, 32)
        self.conv12 = UnetConv(64, 32)
        self.conv13 = UnetConv(32, 32)
        self.conv14 = UnetConv(64, 32)
        self.conv15 = UnetConv(32, 32)
        # Final projection without activation; forward applies the sigmoid.
        self.conv16 = UnetConv(32, out_channels, act=None)

        self.down1 = UnetDown()
        self.down2 = UnetDown()
        self.down3 = UnetDown()
        self.down4 = UnetDown()
        self.down5 = UnetDown()

        # No pre-padding: a kernel-2 / stride-2 transposed conv exactly doubles
        # height and width, matching the pooled skip tensors.
        self.up1 = UnetUp(32, 32, padding=(0, 0, 0, 0))
        self.up2 = UnetUp(32, 32, padding=(0, 0, 0, 0))
        self.up3 = UnetUp(32, 32, padding=(0, 0, 0, 0))
        self.up4 = UnetUp(32, 32, padding=(0, 0, 0, 0))
        self.up5 = UnetUp(32, 32, padding=(0, 0, 0, 0))

        self.conc1 = UnetConc()
        self.conc2 = UnetConc()
        self.conc3 = UnetConc()
        self.conc4 = UnetConc()

    def _log(self, name, tensor):
        """Print ``'<name> : <size>'`` when debug output is enabled."""
        if self.debug:
            print('{} : {}'.format(name, tensor.size()))

    def forward(self, x):
        # ---- encoder ----
        conv1 = self.conv1(x)
        self._log('conv1', conv1)

        down1 = self.down1(conv1)
        conv2 = self.conv2(down1)
        self._log('down1', down1)
        self._log('conv2', conv2)

        down2 = self.down2(conv2)
        conv3 = self.conv3(down2)
        self._log('down2', down2)
        self._log('conv3', conv3)

        down3 = self.down3(conv3)
        conv4 = self.conv4(down3)
        self._log('down3', down3)
        self._log('conv4', conv4)

        down4 = self.down4(conv4)
        conv5 = self.conv5(down4)
        self._log('down4', down4)
        self._log('conv5', conv5)

        down5 = self.down5(conv5)
        conv6 = self.conv6(down5)
        self._log('down5', down5)
        self._log('conv6', conv6)

        # ---- decoder (skip connections use the pooled encoder tensors) ----
        up1 = self.up1(conv6)
        self._log('up1', up1)
        # conv7 is the decoder conv for this level (self.conv6 must not be reused here).
        conv7 = self.conv7(up1)
        conc1 = self.conc1(conv7, down4)
        conv8 = self.conv8(conc1)
        self._log('conv7', conv7)
        self._log('conc1', conc1)
        self._log('conv8', conv8)

        up2 = self.up2(conv8)
        self._log('up2', up2)
        conv9 = self.conv9(up2)
        conc2 = self.conc2(conv9, down3)
        conv10 = self.conv10(conc2)
        self._log('conv9', conv9)
        self._log('conc2', conc2)
        self._log('conv10', conv10)

        up3 = self.up3(conv10)
        self._log('up3', up3)
        conv11 = self.conv11(up3)
        conc3 = self.conc3(conv11, down2)
        conv12 = self.conv12(conc3)
        self._log('conv11', conv11)
        self._log('conc3', conc3)
        self._log('conv12', conv12)

        up4 = self.up4(conv12)
        self._log('up4', up4)
        conv13 = self.conv13(up4)
        conc4 = self.conc4(conv13, down1)
        conv14 = self.conv14(conc4)
        self._log('conv13', conv13)
        self._log('conc4', conc4)
        self._log('conv14', conv14)

        up5 = self.up5(conv14)
        self._log('up5', up5)
        conv15 = self.conv15(up5)
        conv16 = self.conv16(conv15)
        self._log('conv15', conv15)
        self._log('conv16', conv16)

        return torch.sigmoid(conv16)

net = Unet(debug=False)

# Shape smoke test on a random batch of one. It runs in eval mode under
# no_grad, so it neither updates BatchNorm running statistics with noise nor
# builds an autograd graph. Training mode is restored afterwards.
inputs = torch.rand(1, 3, img_size, img_size)
net.eval()
with torch.no_grad():
    smoke_out = net(inputs)
net.train()
assert smoke_out.size() == (1, 1, img_size, img_size), smoke_out.size()
smoke_out.size()

In [None]:
optimizer = optim.Adam(net.parameters(), lr=0.01,
                       betas=(0.9, 0.995), eps=1e-05)
criterion = nn.MSELoss()

n_epochs = 50
log_every = 25  # print the loss every `log_every` batches

for epoch in tqdm(range(n_epochs)):
    # Make sure BatchNorm/Dropout are in training mode whatever ran before.
    net.train()

    for idx, (imgs, masks) in enumerate(batch_generator(dataset_train, batch_size)):
        # batch_generator yields only full batches, so the leading dim is batch_size.
        # NOTE(review): reshape (not transpose) assumes the patches are already
        # stored channel-first, (3, H, W) and (1, H, W) — confirm against
        # get_data_training.
        imgs = np.asarray(imgs).reshape(batch_size, 3, img_size, img_size)
        masks = np.asarray(masks).reshape(batch_size, 1, img_size, img_size)

        inputs = torch.from_numpy(imgs).type(dtype)
        targets = torch.from_numpy(masks).type(dtype)

        optimizer.zero_grad()
        pred = net(inputs)
        loss = criterion(pred, targets)
        loss.backward()
        optimizer.step()

        # loss is a 0-dim tensor: .item() reads the scalar. The old
        # loss.data[0] indexing is deprecated since PyTorch 0.4 and an error
        # in current releases.
        current_loss = loss.item()

        if idx % log_every == 0:
            print('epoch {} -- batch {} -- loss {}'.format(epoch, idx, current_loss))