In [1]:
import os
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from PIL import Image
import random
from PIL import Image, ImageSequence
import os
import numpy as np

In [2]:
# Function to load the training images and their corresponding label masks
# TODO: Retrain the model after applying data augmentation
def load_train_dataset(data_folder):
    """Load the training image stack and its one-hot encoded label stack.

    Reads ``train-volume.tif`` and ``train-labels.tif`` from ``data_folder``.

    Args:
        data_folder: Directory containing the two multi-page TIFF files.

    Returns:
        A ``zip`` iterator of ``(image, label)`` pairs, one per TIFF page.
        ``image`` is the page wrapped in two leading singleton axes
        (shape ``(1, 1, H, W)``); ``label`` is a float64 array of shape
        ``(H, W, 2)`` holding the one-hot encoding of ``pixel // 255``.
    """
    path_train_images = os.path.join(data_folder, 'train-volume.tif')
    path_train_labels = os.path.join(data_folder, 'train-labels.tif')

    # np.array(page) copies the pixel data, so the files can be closed as soon
    # as every page has been read.
    with Image.open(path_train_images) as volume:
        images = np.array([[[np.array(page)]] for page in ImageSequence.Iterator(volume)])
    with Image.open(path_train_labels) as volume:
        raw_labels = np.array([np.array(page) // 255 for page in ImageSequence.Iterator(volume)])

    # One-hot encode by indexing the 2x2 identity matrix with the class ids.
    # The result has shape raw_labels.shape + (2,), so the number of pages and
    # the image size are taken from the data instead of being hard-coded.
    labels = np.eye(2, dtype=np.float64)[raw_labels]
    return zip(images, labels)

In [3]:
def load_test_dataset(data_folder):
    """Load the test image stack from ``test-volume.tif`` in ``data_folder``.

    Args:
        data_folder: Directory containing the multi-page TIFF file.

    Returns:
        A NumPy array with one entry per TIFF page; each page is wrapped in
        two leading singleton axes, so a single entry has shape ``(1, 1, H, W)``.
    """
    path_test_images = os.path.join(data_folder, 'test-volume.tif')

    # np.array(page) copies the pixel data, so the file can be closed once
    # every page has been read.
    with Image.open(path_test_images) as volume:
        images = np.array([[[np.array(page)]] for page in ImageSequence.Iterator(volume)])
    return images

### The contraction path (downsampling) follows a typical CNN architecture:
### two consecutive 3x3 convolutions (blue arrows)
### followed by a 2x2 max pooling (red arrow) for downsampling.
### At each downsampling step, the number of channels is doubled.

In [4]:
import torch
from torch import nn
import torch.nn.functional as F

class double_conv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by ReLU and dropout (p=0.2).

    Both convolutions use ``padding=1``, so the spatial size is unchanged.
    The first convolution maps ``in_ch`` to ``out_ch`` channels and the second
    keeps ``out_ch`` channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # Build conv -> ReLU -> dropout twice; the layer order (and therefore
        # the Sequential indices used in state_dict keys) is fixed.
        for source_ch in (in_ch, out_ch):
            stages.append(nn.Conv2d(source_ch, out_ch, 3, padding=1))
            stages.append(nn.ReLU(inplace=True))
            stages.append(nn.Dropout(p=0.2))  # dropout for regularisation
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the convolution stack to ``x`` and return the result."""
        return self.conv(x)

### Each upsample block first applies a 2x2 up-convolution (transposed convolution) and concatenates the result with the feature maps of the corresponding contraction layer.
### The concatenated tensor is then passed through two 3x3 convolution layers.
### After each block the number of feature maps is halved to maintain symmetry with the contraction path.
### The concatenation ensures that the features learned while contracting the image are used to reconstruct it.


In [5]:
class up(nn.Module):
    """Upsample a feature map and concatenate it with a skip connection.

    ``x2`` is upsampled by a 2x2 transposed convolution with stride 2 (mapping
    ``in_ch`` to ``out_ch`` channels) and then concatenated along the channel
    axis with ``x1``; the upsampled tensor comes first.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up_scale = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)

    def forward(self, x1, x2):
        """Return ``cat([up_scale(x2), x1])`` along dimension 1."""
        upsampled = self.up_scale(x2)
        return torch.cat((upsampled, x1), dim=1)

In [6]:
class down_layer(nn.Module):
    """Encoder stage: 2x2 max pooling followed by a ``double_conv`` block."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.pool = nn.MaxPool2d(2, stride=2, padding=0)
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        """Pool ``x`` (halving height and width), then convolve it."""
        pooled = self.pool(x)
        return self.conv(pooled)

class up_layer(nn.Module):
    """Decoder stage: upsample-and-concatenate, then a ``double_conv`` block."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = up(in_ch, out_ch)
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        """Merge skip tensor ``x1`` with upsampled ``x2`` and convolve the result."""
        merged = self.up(x1, x2)
        return self.conv(merged)



### Define neural architecture

In [7]:
class UNet(nn.Module):
    """U-Net with four downsampling and four upsampling stages.

    The network takes a single-channel input, widens it from 64 up to 1024
    channels along the contracting path, narrows it back to 64 channels along
    the expanding path, and maps it to ``dimensions`` output channels with a
    1x1 convolution followed by a sigmoid.

    Args:
        dimensions: Number of output channels.
    """

    def __init__(self, dimensions=2):
        super().__init__()
        # Contracting path.
        self.conv1 = double_conv(1, 64)
        self.down1 = down_layer(64, 128)
        self.down2 = down_layer(128, 256)
        self.down3 = down_layer(256, 512)
        self.down4 = down_layer(512, 1024)
        # Expanding path.
        self.up1 = up_layer(1024, 512)
        self.up2 = up_layer(512, 256)
        self.up3 = up_layer(256, 128)
        self.up4 = up_layer(128, 64)
        # 1x1 convolution producing the per-pixel output channels.
        self.last_conv = nn.Conv2d(64, dimensions, 1)

    def forward(self, x):
        """Run the network on ``x`` and return sigmoid-activated outputs."""
        # Collect the output of every encoder stage; all but the deepest one
        # are reused as skip connections on the way up.
        encoder_outputs = [self.conv1(x)]
        for down_stage in (self.down1, self.down2, self.down3, self.down4):
            encoder_outputs.append(down_stage(encoder_outputs[-1]))

        features = encoder_outputs.pop()
        for up_stage in (self.up1, self.up2, self.up3, self.up4):
            # Skip connections are consumed deepest-first.
            features = up_stage(encoder_outputs.pop(), features)

        return torch.sigmoid(self.last_conv(features))


In [8]:
# Training configuration read by train() as module-level globals.
# Directory that holds the TIFF volumes.
data_folder = 'data'
# Checkpoint file that train() loads (if present) and saves to.
model_path = 'model/unet.pt'
# train() writes a checkpoint every `saving_interval` iterations.
saving_interval = 5
# Number of iterations of the training loop.
epoch_number = 25
# NOTE(review): these imports rebind load_train_dataset, load_test_dataset and
# UNet, which are also defined in earlier cells of this notebook, to whatever
# the external `dataset` and `unet` modules provide. Confirm those modules
# exist and match the in-notebook definitions.
from dataset import load_train_dataset, load_test_dataset
from unet import UNet
# NOTE(review): train() builds its own list from load_train_dataset(data_folder),
# so this module-level copy looks unused; confirm before removing.
training_dataset = list(load_train_dataset(data_folder))


## Initialization of weights
### Initialize the weights from a Gaussian distribution with a standard deviation of 1/sqrt(N),
### where N is the number of input channels of the layer.


In [9]:
# Takes in a module and applies the specified weight initialization
def weights_init_normal(m):
    """Initialise an ``nn.Conv2d`` module in place; leave other modules alone.

    Weights are drawn from a normal distribution with mean 0 and standard
    deviation ``1 / sqrt(in_channels)``. The bias, when the layer has one, is
    set to zero. Intended to be passed to ``nn.Module.apply``.

    Args:
        m: Any module; only instances of ``nn.Conv2d`` are modified.
    """
    # Only Conv2d layers are re-initialised.
    if not isinstance(m, nn.Conv2d):
        return

    std = float(1.0 / np.sqrt(m.in_channels))
    nn.init.normal_(m.weight, mean=0.0, std=std)
    # Layers created with bias=False have m.bias set to None.
    if m.bias is not None:
        nn.init.zeros_(m.bias)

### Use cross entropy loss function 

### According to paper use a high momentum (0.99)
### such that a large number of the previously seen training samples determine the
### update in the current optimization step.

In [10]:
def train():
    """Train the U-Net on randomly drawn training images and save checkpoints.

    Reads the module-level settings ``data_folder``, ``model_path``,
    ``epoch_number`` and ``saving_interval``. If a checkpoint already exists at
    ``model_path`` training resumes from it. Each iteration of the loop
    performs one SGD step on a single randomly chosen image, so
    ``epoch_number`` counts optimisation steps rather than passes over the
    data set.

    Raises:
        ValueError: If the training data set is empty.
    """
    model = UNet(dimensions=2)
    model.apply(weights_init_normal)
    if os.path.isfile(model_path):
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        model.load_state_dict(torch.load(model_path))
    # Make the training mode explicit (dropout active).
    model.train()

    # High momentum (0.99) so that many previously seen samples contribute to
    # each update step.
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.99)

    training_dataset = list(load_train_dataset(data_folder))
    if not training_dataset:
        raise ValueError('training dataset is empty')

    # torch.save does not create missing directories.
    checkpoint_dir = os.path.dirname(model_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(epoch_number):
        sample_index = random.randint(0, len(training_dataset) - 1)
        image, label = training_dataset[sample_index]

        optimizer.zero_grad()
        inputs = torch.from_numpy(image.astype(np.float32))
        # Add a leading batch axis so the target has the same shape as the
        # permuted model output; binary_cross_entropy rejects mismatched sizes.
        target = torch.from_numpy(label).float().unsqueeze(0)
        output = model(inputs).permute(0, 2, 3, 1)
        loss = F.binary_cross_entropy(output, target)
        step_loss = loss.item()
        print(f'Epoch: {epoch} \tLoss: {step_loss}')

        loss.backward()
        optimizer.step()

        if (epoch + 1) % saving_interval == 0:
            print('Saving model')

            torch.save(model.state_dict(), model_path)
    # Always persist the final weights, even when epoch_number is not a
    # multiple of saving_interval.
    torch.save(model.state_dict(), model_path)
    return

# Start training when this code runs as the main module.
if __name__ == "__main__":
    train()




Epoch: 0 	Loss: 0.41935834288597107
Epoch: 1 	Loss: 0.3835543096065521
Epoch: 2 	Loss: 0.4188568890094757
Epoch: 3 	Loss: 0.519564151763916
Epoch: 4 	Loss: 0.421855092048645
Saving model
Epoch: 5 	Loss: 0.4103362262248993
Epoch: 6 	Loss: 0.45291823148727417
Epoch: 7 	Loss: 0.5653619766235352
Epoch: 8 	Loss: 0.5040107369422913
Epoch: 9 	Loss: 0.4135444462299347
Saving model
Epoch: 10 	Loss: 0.47338300943374634
Epoch: 11 	Loss: 0.43941980600357056
Epoch: 12 	Loss: 0.41271787881851196
Epoch: 13 	Loss: 0.3778258264064789
Epoch: 14 	Loss: 0.41888412833213806
Saving model
Epoch: 15 	Loss: 0.40619927644729614
Epoch: 16 	Loss: 0.40149495005607605
Epoch: 17 	Loss: 0.5137727856636047
Epoch: 18 	Loss: 0.3999673128128052
Epoch: 19 	Loss: 0.374215304851532
Saving model
Epoch: 20 	Loss: 0.4518355429172516
Epoch: 21 	Loss: 0.41597214341163635
Epoch: 22 	Loss: 0.38572871685028076
Epoch: 23 	Loss: 0.3439171016216278
Epoch: 24 	Loss: 0.3778289556503296
Saving model


### It seems that performance is not improving much after 10 epochs

In [12]:
# Settings read by predict() as module-level globals.
# Directory that holds the test TIFF volume.
data_folder = 'data'
# Checkpoint file whose weights predict() loads.
model_path = 'model/unet.pt'

def predict():
    """Segment every test image and save input/prediction pairs as PNG files.

    Loads the weights stored at the module-level ``model_path``, runs the
    model on each image returned by ``load_test_dataset(data_folder)`` and
    writes ``output/input_image<N>.png`` and ``output/output_image<N>.png``
    (``N`` starting at 1). The prediction is the per-pixel argmax over the
    output channels, scaled to 0/255. The number of each processed image is
    printed after it has been saved.
    """
    model = UNet()
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint)
    # Switch to evaluation mode so the dropout layers are disabled; otherwise
    # predictions are randomly perturbed.
    model.eval()

    # Image.save does not create missing directories.
    os.makedirs('output', exist_ok=True)

    # Gradients are not needed for inference.
    with torch.no_grad():
        for image_count, image in enumerate(load_test_dataset(data_folder), start=1):
            batch = torch.from_numpy(image.astype(np.float32))
            output = model(batch).permute(0, 2, 3, 1).numpy()

            # Take the spatial size from the data instead of assuming 512x512.
            height, width = image.shape[-2:]
            input_array = image.reshape((height, width))
            output_array = output.argmax(3).reshape((height, width)) * 255

            input_img = Image.fromarray(input_array)
            # Values are 0 or 255, so uint8 yields an 8-bit greyscale image directly.
            output_img = Image.fromarray(output_array.astype(np.uint8))
            input_img.save('output/input_image'+str(image_count)+".png")
            output_img.save('output/output_image'+str(image_count)+".png")
            print(image_count)
    return

# Run inference when this code runs as the main module.
if __name__ == "__main__":
    predict()



2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
