In [1]:
import torch
import torch.nn as nn
import torchvision


from torch.utils.data import DataLoader
from torch.utils.data import Dataset, random_split
import torchvision.transforms.functional as TF
import torch.optim as optim

import sys
import datetime

from PIL import Image

import numpy as np
from matplotlib import pyplot as plt

In [4]:

# Training toggle; later cells (not visible here) presumably branch on it — confirm.
should_train = False

# Checkpoint and loss-log locations for the LON-UNet run (relative paths,
# resolved against the notebook's working directory).
path_to_trained_model, path_to_train_loss, path_to_val_loss = (
    'models_final/lon_unet/trained_unet_model.pth',
    'models_final/lon_unet/lon_unet_train_losses.txt',
    'models_final/lon_unet/lon_unet_val_losses.txt',
)


In [None]:
class HistActivation(nn.Module):
    """Element-wise derivative of the logistic sigmoid: s(x) * (1 - s(x)).

    The result is a bell-shaped bump that peaks at 0.25 for x == 0 and decays
    towards 0 as |x| grows. ``Hist`` uses it as a soft bin-membership function.
    """

    def sigmoid(self, x):
        """Logistic sigmoid, spelled out as 1 / (1 + exp(-x))."""
        denom = 1 + torch.exp(-x)
        return 1 / denom

    def sigmoid_derivative(self, x):
        """sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)); sigmoid evaluated once."""
        s = self.sigmoid(x)
        return s * (1 - s)

    def forward(self, x):
        return self.sigmoid_derivative(x)

class Hist(nn.Module):
    """Learnable soft local-histogram layer producing ``nBins`` feature maps.

    Forward pipeline:
      1. ``IK = conv2d(I, K)`` applies a single learnable ``KSize`` kernel. Because
         ``K`` has shape ``(1, 1, *KSize)``, the input must have exactly one channel.
      2. ``act(b_k - IK)`` is computed for each of the ``nBins`` learnable centres
         ``b_k``. ``act`` is ``HistActivation`` (the sigmoid derivative), which
         peaks where ``IK == b_k``, so this gives a soft membership map per bin.
      3. Every membership map is convolved with the same learnable ``WSize``
         window ``W`` (depthwise, ``groups=nBins``). This pools the memberships
         over a neighbourhood.

    Args:
        nBins: number of histogram bins (output channels).
        KSize: spatial size of the pre-filter kernel ``K``.
        WSize: spatial size of the aggregation window ``W``.

    Returns (forward):
        Tensor of shape ``(N, nBins, H, W)``. Both convolutions use
        ``padding='same'``, so the spatial size is preserved.
    """

    def __init__(self, nBins=10, KSize=(3, 3), WSize=(3, 3)):
        super().__init__()
        self.nBins = nBins
        # Registered as real Parameters. The old ``nn.Parameter(...).to(device)``
        # returns a plain tensor whenever a copy is made (e.g. a GPU/MPS device).
        # Such tensors are never registered, so they are neither trained nor
        # saved. Move the whole model with ``model.to(device)`` instead.
        # NOTE(review): checkpoints written by the old code on a non-CPU device
        # may lack these keys; load with ``strict=False`` or retrain.
        self.b = nn.Parameter(torch.randn(nBins))
        self.K = nn.Parameter(torch.randn(1, 1, *KSize))  # pre-filter kernel init
        self.W = nn.Parameter(torch.randn(1, 1, *WSize))  # aggregation window init
        self.act = HistActivation()

    @property
    def V(self):
        """Window ``W`` replicated once per bin: shape ``(nBins, 1, *WSize)``.

        This is computed on access rather than cached in ``__init__``. A cached
        copy would keep the initial ``W`` forever, because optimizer steps
        update ``W`` and not the copy, and it would not follow ``model.to()``.
        """
        return self.W.repeat(self.nBins, 1, 1, 1)

    @property
    def bias(self):
        """Bin centres ``b`` reshaped to ``(1, nBins, 1, 1)`` for broadcasting."""
        return self.b.view(1, self.nBins, 1, 1)

    def forward(self, I):
        # Follow the parameters' device rather than a notebook-global ``device``.
        I = I.to(self.K.device)
        # 'same' keeps H, W for any KSize (the old padding=1 was only right for 3x3).
        IK = nn.functional.conv2d(I, self.K, None, stride=1, padding='same')
        return nn.functional.conv2d(self.act(self.bias - IK), self.V, None,
                                    padding='same', groups=self.nBins)
class convolution(nn.Module):
    """U-Net double-convolution block: (3x3 conv -> BatchNorm -> ReLU) twice.

    Padding 1 with 3x3 kernels keeps H and W unchanged; channels go
    ``in_c -> out_c`` in the first conv and stay ``out_c`` in the second.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        # Attribute names (and creation order) are kept so saved state_dicts still load.
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.batch_norm1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
        self.batch_norm2 = nn.BatchNorm2d(out_c)
        self.relu = nn.ReLU()

    def forward(self, data):
        x = data
        stages = ((self.conv1, self.batch_norm1), (self.conv2, self.batch_norm2))
        for conv, norm in stages:
            x = self.relu(norm(conv(x)))
        return x
    
class encoder(nn.Module):
    """U-Net encoder stage: one double-convolution block, ``in_c -> out_c``.

    No pooling happens here; ``LON_UNet.forward`` applies max-pooling between stages.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = convolution(in_c, out_c)

    def forward(self, data):
        return self.conv(data)
    
# U-net decoder
class decoder(nn.Module):
    """U-Net decoder stage: 2x learned upsampling, skip concatenation, double conv.

    Args:
        in_c: channels of the incoming (coarser) feature map.
        out_c: channels after upsampling. ``skip`` is expected to have ``out_c``
            channels too, because the double conv consumes ``2 * out_c``.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        # kernel 2 / stride 2 doubles H and W while mapping in_c -> out_c channels.
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = convolution(2 * out_c, out_c)

    def forward(self, data, skip):  # skip connections
        # The upsampled map and the skip must share H and W for the channel-wise concat.
        merged = torch.cat((self.up(data), skip), dim=1)
        return self.conv(merged)

class LON(nn.Module):
    """Bank of ``Hist`` layers followed by a per-pixel linear channel mix.

    Each of the ``nKernels`` ``Hist`` layers yields ``nBins`` channels. These are
    concatenated (``nKernels * nBins`` channels) and mapped to ``nOut`` channels
    by an ``nn.Linear`` applied at every pixel, which is equivalent to a 1x1
    convolution. The spatial size is unchanged.
    """

    def __init__(self, nKernels, nBins, nOut):
        super().__init__()
        self.convs = nn.ModuleList(Hist(nBins) for _ in range(nKernels))
        self.lin = nn.Linear(nKernels * nBins, nOut)

    def forward(self, X):
        stacked = torch.cat([hist(X) for hist in self.convs], dim=1)
        # Channels-last so nn.Linear mixes channels per pixel, then back to NCHW.
        mixed = self.lin(stacked.permute(0, 2, 3, 1))
        return mixed.permute(0, 3, 1, 2)
    
class LON_UNet(nn.Module):
    """U-Net whose first encoder stage is a ``LON`` (histogram-feature) front end.

    Input: 1-channel images, because ``Hist`` kernels are ``(1, 1, k, k)``. The
    network applies four 2x2 max-pools and four stride-2 upsamplings joined to
    skips, so H and W should be divisible by 16. Output: a 1-channel map
    squashed to (0, 1) by a sigmoid.
    """

    def __init__(self):
        super().__init__()
        # Registration order matches the original so parameters()/state_dict order is unchanged.
        self.pool = nn.MaxPool2d((2, 2))

        # Encoding: LON front end (2 Hist kernels x 10 bins -> 64 channels),
        # then double-conv stages that double the width each time.
        self.en1 = LON(2, 10, 64)
        self.en2 = encoder(64, 128)
        self.en3 = encoder(128, 256)
        self.en4 = encoder(256, 512)

        # Bottleneck
        self.bottle = convolution(512, 1024)

        # Decoding: each stage upsamples 2x and fuses the matching encoder skip.
        self.de1 = decoder(1024, 512)
        self.de2 = decoder(512, 256)
        self.de3 = decoder(256, 128)
        self.de4 = decoder(128, 64)

        # Classifier: 1x1 conv to a single output channel.
        self.last = nn.Conv2d(64, 1, kernel_size=1, padding=0)

    def forward(self, data):
        # Encoding: keep each stage's pre-pool output as a skip connection.
        skips = []
        x = data
        for stage in (self.en1, self.en2, self.en3, self.en4):
            x = stage(x)
            skips.append(x)
            x = self.pool(x)

        # Bottleneck
        x = self.bottle(x)

        # Decoding: deepest skip first.
        for stage, skip in zip((self.de1, self.de2, self.de3, self.de4), reversed(skips)):
            x = stage(x, skip)

        # Classifier
        return torch.sigmoid(self.last(x))


class Dataset(torch.utils.data.Dataset):
    """Training set that tiles each (image, label) TIFF pair into patch batches.

    The base class is named explicitly (``torch.utils.data.Dataset``) rather than
    via the bare name ``Dataset``. This class rebinds that name, so re-running
    the cell would otherwise subclass the previous notebook class instead of
    torch's.

    Args:
        ids: sequence of sample ids. Item ``index`` loads
            ``{image_dir}/train_image_{id}.tif`` and ``{label_dir}/train_label_{id}.tif``.
        image_dir, label_dir: directories holding the images and labels. The
            defaults are the original hard-coded locations.
        patch_size: side length of the square crops.
        batch_size: number of patches stacked per batch.

    ``__getitem__`` returns a *list* of ``((X_batch, y_batch), id)`` tuples. Each
    batch has shape ``(B, C, patch_size, patch_size)`` with ``B <= batch_size``.
    """

    def __init__(self, ids,
                 image_dir="/Users/leeannquynhdo/Datalogi/MSc_thesis/unet_implementation/train_images",
                 label_dir="/Users/leeannquynhdo/Datalogi/MSc_thesis/unet_implementation/train_labels",
                 patch_size=512, batch_size=8):
        self.ids = ids
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.patch_size = patch_size
        self.batch_size = batch_size

    def transform(self, train_data, train_labels):
        return TF.to_tensor(train_data), TF.to_tensor(train_labels)

    def __len__(self):
        return len(self.ids)

    def _to_patch_batches(self, img):
        """Tile ``img`` into ``patch_size`` squares and stack them into batches."""
        size = self.patch_size
        # Column-major tiling (x outer, y inner), identical for image and label so
        # patches stay aligned. If the image size is not a multiple of ``size``,
        # the last crops extend past the border; PIL fills that area (see
        # Image.crop docs) — confirm this is acceptable for the labels.
        patches = [TF.to_tensor(img.crop((i, j, i + size, j + size)))
                   for i in range(0, img.width, size)
                   for j in range(0, img.height, size)]
        return [torch.stack(patches[k:k + self.batch_size])
                for k in range(0, len(patches), self.batch_size)]

    def __getitem__(self, index):
        sample_id = self.ids[index]
        # ``with`` closes both file handles. All crops are converted to tensors
        # inside the block, before the files are closed.
        with Image.open(f"{self.image_dir}/train_image_{sample_id}.tif") as X:
            with Image.open(f"{self.label_dir}/train_label_{sample_id}.tif") as y:
                X_batches = self._to_patch_batches(X)
                y_batches = self._to_patch_batches(y)

        return [(batch, sample_id) for batch in zip(X_batches, y_batches)]