In [12]:
import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import sys
import os
from optparse import OptionParser
import numpy as np
from torch import optim
from PIL import Image
import progressbar
from torch.autograd import Function, Variable
import matplotlib.pyplot as plt
import matplotlib
from torchvision import transforms
from glob import glob
from skimage import io
from sklearn.feature_extraction.image import extract_patches_2d,reconstruct_from_patches_2d
import pickle
from torch.utils.data import Dataset
%matplotlib inline
import cv2

In [13]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cpu


Define base model, VGG:

In [63]:

class VGG16(nn.Module):
    def __init__(self):
        super(VGG16, self).__init__()
        self.conv1_1 = nn.Conv2d(4, 64, 3, padding=100)
        self.relu1_1 = nn.ReLU(inplace=True)
        self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)
        self.relu1_2 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/2

        # conv2
        self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
        self.relu2_1 = nn.ReLU(inplace=True)
        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
        self.relu2_2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/4

        # conv3
        self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
        self.relu3_1 = nn.ReLU(inplace=True)
        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
        self.relu3_2 = nn.ReLU(inplace=True)
        self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
        self.relu3_3 = nn.ReLU(inplace=True)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/8

        # conv4
        self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
        self.relu4_1 = nn.ReLU(inplace=True)
        self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu4_2 = nn.ReLU(inplace=True)
        self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu4_3 = nn.ReLU(inplace=True)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/16

        # conv5
        self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu5_1 = nn.ReLU(inplace=True)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu5_2 = nn.ReLU(inplace=True)
        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu5_3 = nn.ReLU(inplace=True)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/32

    def forward(self, x):
        x = self.relu1_1(self.conv1_1(x))
        x = self.relu1_2(self.conv1_2(x))
        x = self.pool1(x)

        x = self.relu2_1(self.conv2_1(x))
        x = self.relu2_2(self.conv2_2(x))
        x = self.pool2(x)

        x = self.relu3_1(self.conv3_1(x))
        x = self.relu3_2(self.conv3_2(x))
        x = self.relu3_3(self.conv3_3(x))
        x = self.pool3(x)

        x = self.relu4_1(self.conv4_1(x))
        x = self.relu4_2(self.conv4_2(x))
        x = self.relu4_3(self.conv4_3(x))
        x = self.pool4(x)

        x = self.relu5_1(self.conv5_1(x))
        x = self.relu5_2(self.conv5_2(x))
        x = self.relu5_3(self.conv5_3(x))
        x = self.pool5(x)
        return x


Define connec. layer:

In [64]:
def conv_layer(in_channels, out_channles, kernel_size, stride=1, padding=0, bias=True):
    layer = nn.Conv2d(in_channels, out_channles, kernel_size, stride, padding, bias=bias)
    layer.weight.data.zero_()
    if bias:
        layer.bias.data.zero_()
    return layer

In [65]:
def fully_connected(in_channels, out_channles, bias=True):
    layer = nn.Linear(in_channels, out_channles, bias=True)
    if bias:
        layer.bias.data.zero_()
    return layer


Define the FCN model (using VGG model):

In [66]:
class FCN16s(nn.Module):
    def __init__(self):
        super(FCN16s, self).__init__()
        self.backbone = VGG16()
        num_classes = 5

        # fc1
        self.fc1 = conv_layer(512, 4096, 7)
        self.relu1 = nn.ReLU()
        self.drop1 = nn.Dropout2d()

        # fc2
        self.fc2 = conv_layer(4096, 4096, 1)
        self.relu2 = nn.ReLU()
        self.drop2 = nn.Dropout2d()

        self.score_fr = conv_layer(4096, num_classes, 1)
        self.score_pool4 = conv_layer(512, num_classes, 1)

        self.upscore2 = bilinear_upsampling(num_classes, num_classes, 4, stride=2, bias=False)
        self.upscore16 = bilinear_upsampling(num_classes, num_classes, 32, stride=16, bias=False)

    def forward(self, x):
        _, _, h, w = x.size()
        x = self.backbone.conv1_1(x)
        x = self.backbone.relu1_1(x)
        x = self.backbone.conv1_2(x)
        x = self.backbone.relu1_2(x)
        x = self.backbone.pool1(x)

        x = self.backbone.conv2_1(x)
        x = self.backbone.relu2_1(x)
        x = self.backbone.conv2_2(x)
        x = self.backbone.relu2_2(x)
        x = self.backbone.pool2(x)

        x = self.backbone.conv3_1(x)
        x = self.backbone.relu3_1(x)
        x = self.backbone.conv3_2(x)
        x = self.backbone.relu3_2(x)
        x = self.backbone.conv3_3(x)
        x = self.backbone.relu3_3(x)
        x = self.backbone.pool3(x)

        x = self.backbone.conv4_1(x)
        x = self.backbone.relu4_1(x)
        x = self.backbone.conv4_2(x)
        x = self.backbone.relu4_2(x)
        x = self.backbone.conv4_3(x)
        x = self.backbone.relu4_3(x)
        x = self.backbone.pool4(x)
        pool4 = x  # 1/16

        x = self.backbone.conv5_1(x)
        x = self.backbone.relu5_1(x)
        x = self.backbone.conv5_2(x)
        x = self.backbone.relu5_2(x)
        x = self.backbone.conv5_3(x)
        x = self.backbone.relu5_3(x)
        x = self.backbone.pool5(x)
        
        x = self.relu1(self.fc1(x))
        x = self.drop1(x)

        x = self.relu2(self.fc2(x))
        x = self.drop2(x)

        x = self.score_fr(x)
        x = self.upscore2(x)
        upscore2 = x

        x = self.score_pool4(pool4)
        x = x[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3]]
        score_pool4c = x  # 1/16

        x = upscore2 + score_pool4c

        x = self.upscore16(x)
        x = x[:, :, 27:27 + h, 27:27 + w].contiguous()
        return x


In [67]:

class DiceCoeff(Function):
    """Dice coeff for one pair of input image and target image"""
    def forward(self, prediction, target):
        self.save_for_backward(prediction, target)
        eps = 0.0001 # in case union = 0
        
        A = prediction.contiguous().view(-1)
        B = target.contiguous().view(-1)
        self.inter = torch.dot(A, B)
        self.union = A.sum()+B.sum()- torch.dot(A, B)+ eps
        # Calculate DICE 
        d = self.inter /  self.union  
        return d

# Calculate dice coefficients for batches
def dice_coeff(prediction, target):
    """Dice coeff for batches"""
    s = torch.FloatTensor(1).zero_()

    for i, (a,b) in enumerate(zip(prediction, target)):
        s = DiceCoeff().forward(a,b)  
    s = s / (i + 1)
    return s

In [68]:
class Flip(object):
    """
    Flip the image left or right for data augmentation, but prefer original image.
    """
    def __init__(self,ori_probability=0.60):
        self.ori_probability = ori_probability
 
    def __call__(self, sample):
        if random.uniform(0,1) < self.ori_probability:
            return sample
        else:
            img, label = sample['img'], sample['label']
            img_flip = img[:,:,::-1]
            label_flip = label[:,::-1]
            
            return {'img': img_flip, 'label': label_flip}
        
class ToTensor(object):
    """
    Convert ndarrays in sample to Tensors.
    """
    def __init__(self):
        pass

    def __call__(self, sample):
        image, label = sample['img'], sample['label']

        return {'img': torch.from_numpy(image.copy()).type(torch.FloatTensor),
                'label': torch.from_numpy(label.copy()).type(torch.FloatTensor)}

In [69]:
class CustomDataset(Dataset):
    def __init__(self, image_masks, transforms=None): 

        self.image_masks = image_masks
        self.transforms = transforms
    
    def __len__(self):  # return count of sample we have

        return len(self.image_masks)
    
    def __getitem__(self, index):

        image = np.array(self.image_masks[index][0]) # Channel,H, W
        mask = self.image_masks[index][1]
        
        sample = {'img': image, 'label': mask}
        
        if transforms:
            sample = self.transforms(sample)
            
        return sample

In [70]:
def get_upsampling_weight(in_channels, out_channels, kernel_size):
    """
    Make a 2D bilinear kernel suitable for unsampling
    """
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    bilinear_filter = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float32)
    weight[range(in_channels), range(out_channels), :, :] = bilinear_filter
    return torch.from_numpy(weight).float()


def bilinear_upsampling(in_channels, out_channels, kernel_size, stride, bias=False):
    initial_weight = get_upsampling_weight(in_channels, out_channels, kernel_size)
    layer = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, bias=bias)
    layer.weight.data.copy_(initial_weight)
    # weight is frozen because it's just a bilinear upsampling
    layer.weight.requires_grad = False
    return layer

Load train data:

In [71]:
train_img_masks_save_path = 'Pickles/train_img_masks.pickle'
if os.path.exists(train_img_masks_save_path):
    with open(train_img_masks_save_path,'rb') as f:
        train_img_masks = pickle.load(f)
    f.close()
else:
    pickle_store(train_img_masks_save_path,train_img_masks)

val_img_masks_save_path = 'Pickles/val_img_masks.pickle'
if os.path.exists(val_img_masks_save_path):
    with open(val_img_masks_save_path,'rb') as f:
        val_img_masks = pickle.load(f)
    f.close()
else:
    pickle_store(val_img_masks_save_path,val_img_masks)

In [72]:
train_dataset = CustomDataset(train_img_masks, transforms=transforms.Compose([Flip(),ToTensor()]))
val_dataset = CustomDataset(val_img_masks, transforms=transforms.Compose([Flip(),ToTensor()]))

Create net from the FCN model

In [73]:
net = FCN16s().to(device,dtype=torch.float32)
net.to(device) 
print(net)

FCN16s(
  (backbone): VGG16(
    (conv1_1): Conv2d(4, 64, kernel_size=(3, 3), stride=(1, 1), padding=(100, 100))
    (relu1_1): ReLU(inplace=True)
    (conv1_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu1_2): ReLU(inplace=True)
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
    (conv2_1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu2_1): ReLU(inplace=True)
    (conv2_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu2_2): ReLU(inplace=True)
    (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
    (conv3_1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu3_1): ReLU(inplace=True)
    (conv3_2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu3_2): ReLU(inplace=True)
    (conv3_3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu3_3):

Train net:

In [74]:
# Specify number of epochs, image scale factor, batch size and learning rate
epochs =  10        # e.g. 10, or more until CE converge
batch_size = 16    # e.g. 16
lr =   0.001          # e.g. 0.01
N_train = len(train_img_masks)

if not os.path.exists('Model_3'):
    os.mkdir('Model_3')
model_save_path = 'Model_3/'  # directory to same the model after each epoch. 

#optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9,weight_decay=1e-4)
optimizer=optim.Adam(net.parameters(),lr=lr,weight_decay=1e-4)

# The loss function we use is Cross Entropy
criterion = nn.CrossEntropyLoss()

# Start training  #This part takes very long time to run if using CPU
for epoch in range(epochs):
    print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
    net.train()
    # Reload images and masks for training and validation and perform random shuffling at the begining of each epoch
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    
    epoch_loss = 0
    count = 0
    for i, b in enumerate(train_loader):
        
        imgs = b['img'].to(device,dtype=torch.float32)
        true_masks = b['label'].to(device,dtype=torch.long)        
        masks_pred = net(imgs)
        # Calculate the loss by comparing the predicted masks vector and true masks vector
        # And sum the losses together 
        loss = criterion(masks_pred,true_masks.long())
        epoch_loss += loss
        if count % 20 == 0:  #Print status every 20 batch
            print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train, loss.item())) 
        count = count + 1

Starting epoch 1/10.
0.0000 --- loss: 1.609382
0.0800 --- loss: 1.609382


KeyboardInterrupt: 