In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from IPython.display import clear_output
from PIL import Image
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models

# U-Net Architecture

![UNET Architecture](UNET_architecture.png)

In [2]:
def double_conv(in_channels, out_channels):
    """Build two unpadded 3x3 convolutions, each followed by an in-place ReLU.

    The first convolution maps ``in_channels`` to ``out_channels``; the second
    keeps ``out_channels``. With no padding, each convolution trims 2 pixels
    from height and width, so the block shrinks each spatial side by 4.

    Returns:
        nn.Sequential: ``[Conv2d, ReLU, Conv2d, ReLU]``.
    """
    layers = []
    # Only the input width differs between the two conv/ReLU pairs.
    for conv_inputs in (in_channels, out_channels):
        layers.append(nn.Conv2d(conv_inputs, out_channels, kernel_size=3))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)

def up_trans(in_channels, out_channels):
    """Return a 2x2, stride-2 transposed convolution.

    With kernel size and stride both equal to 2, the layer doubles the height
    and width of its input while mapping ``in_channels`` to ``out_channels``.
    """
    upsample = nn.ConvTranspose2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=2,
        stride=2,
    )
    return upsample

def crop(original_tensor, target_tensor):
    """Center-crop ``original_tensor`` to the spatial size of ``target_tensor``.

    Both tensors are indexed as ``(batch, channel, height, width)``. Height and
    width are handled independently, so non-square tensors are supported. When
    the size difference along an axis is odd, the extra row/column is removed
    from the bottom/right so the result always matches the target size exactly.

    Args:
        original_tensor: Tensor to crop.
        target_tensor: Tensor whose height and width define the crop size.

    Returns:
        A view of ``original_tensor`` with the target's height and width.

    Raises:
        ValueError: If the target is larger than the original along height or
            width (cropping cannot enlarge a tensor).
    """
    target_h, target_w = target_tensor.shape[2], target_tensor.shape[3]
    original_h, original_w = original_tensor.shape[2], original_tensor.shape[3]

    if target_h > original_h or target_w > original_w:
        raise ValueError(
            f'cannot crop {original_h}x{original_w} down to the larger size '
            f'{target_h}x{target_w}'
        )

    top = (original_h - target_h) // 2
    left = (original_w - target_w) // 2
    # Derive the end index from the target size (not from the original size)
    # so an odd difference cannot leave the crop one pixel too large.
    return original_tensor[:, :, top:top + target_h, left:left + target_w]

In [3]:
class UNet(nn.Module):
    """U-Net built from unpadded 3x3 double-convolution blocks.

    The contracting path applies five double-conv blocks (64, 128, 256, 512,
    1024 channels) separated by 2x2 max-pooling. The expansive path applies
    four transposed-convolution upsamplings; after each one the feature map of
    the matching contracting block is center-cropped, concatenated along the
    channel axis and passed through another double-conv block. A final 1x1
    convolution maps the 64 remaining channels to ``out_channels``.

    Because the convolutions are unpadded, the output is spatially smaller than
    the input (a 572x572 input produces a 388x388 output). No activation is
    applied after the final 1x1 convolution.

    Args:
        in_channels: Number of channels of the input image. Defaults to 1.
        out_channels: Number of channels of the output map. Defaults to 1.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Contracting path: channel count doubles at every block.
        self.down_conv_1 = double_conv(in_channels, 64)
        self.down_conv_2 = double_conv(64, 128)
        self.down_conv_3 = double_conv(128, 256)
        self.down_conv_4 = double_conv(256, 512)
        self.down_conv_5 = double_conv(512, 1024)

        # Expansive path: each upsampling halves the channel count ...
        self.up_trans_6 = up_trans(1024, 512)
        self.up_trans_7 = up_trans(512, 256)
        self.up_trans_8 = up_trans(256, 128)
        self.up_trans_9 = up_trans(128, 64)

        # ... and the concatenation with the skip connection doubles it again,
        # which is why each up_conv takes twice the channels up_trans emits.
        self.up_conv_6 = double_conv(1024, 512)
        self.up_conv_7 = double_conv(512, 256)
        self.up_conv_8 = double_conv(256, 128)
        self.up_conv_9 = double_conv(128, 64)

        self.out = nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=1)

    def forward(self, img, verbose=0):
        """Run the network on a batch of images.

        Args:
            img: Input tensor indexed as ``(batch, channel, height, width)``.
            verbose: When truthy, print the shape produced by every block.

        Returns:
            The output of the final 1x1 convolution.
        """
        down_blocks = (
            self.down_conv_1,
            self.down_conv_2,
            self.down_conv_3,
            self.down_conv_4,
            self.down_conv_5,
        )
        up_blocks = (
            (self.up_trans_6, self.up_conv_6),
            (self.up_trans_7, self.up_conv_7),
            (self.up_trans_8, self.up_conv_8),
            (self.up_trans_9, self.up_conv_9),
        )

        # Contracting path. Every block output is remembered for the skip
        # connections; max-pooling is applied between blocks, not before the
        # first one.
        features = img
        skips = []
        for number, block in enumerate(down_blocks, start=1):
            if number > 1:
                features = self.max_pool_2x2(features)
            features = block(features)
            if verbose:
                # Message text kept verbatim so previously recorded logs match.
                print(f'Conntracting Block {number}: {features.shape}')
            skips.append(features)

        # The deepest block feeds the expansive path directly; it is not a skip.
        skips.pop()

        # Expansive path. Skips are consumed deepest-first (block 4 -> block 1).
        for number, (upsample, convs) in enumerate(up_blocks, start=6):
            features = upsample(features)
            skip = crop(skips.pop(), features)
            features = convs(torch.cat([skip, features], dim=1))
            if verbose:
                print(f'Expansive Block {number}: {features.shape}')

        return self.out(features)
        

In [4]:
# Test U-Net with 1 random image
model = UNet()
img = torch.rand(1, 1, 572, 572) # batch_size, channel, height, width
print(f'Input: {img.shape}')
# This is only a shape check, so skip autograd bookkeeping: otherwise `output`
# keeps the whole graph of intermediate activations alive in memory.
with torch.no_grad():
    output = model(img, verbose=1)
print(f'Output: {output.shape}')

Input: torch.Size([1, 1, 572, 572])
Conntracting Block 1: torch.Size([1, 64, 568, 568])
Conntracting Block 2: torch.Size([1, 128, 280, 280])
Conntracting Block 3: torch.Size([1, 256, 136, 136])
Conntracting Block 4: torch.Size([1, 512, 64, 64])
Conntracting Block 5: torch.Size([1, 1024, 28, 28])
Expansive Block 6: torch.Size([1, 512, 52, 52])
Expansive Block 7: torch.Size([1, 256, 100, 100])
Expansive Block 8: torch.Size([1, 128, 196, 196])
Expansive Block 9: torch.Size([1, 64, 388, 388])
Output: torch.Size([1, 1, 388, 388])


# Custom Dataset

In [5]:
class MRIDataset(Dataset):
    """Paired image / mask dataset read from two directories.

    Every entry of ``image_dir`` is treated as one sample. Its mask is looked up
    in ``mask_dir`` under the image file name with ``'.tif'`` replaced by
    ``'_mask.tif'``.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the masks.
        transform: Optional callable applied to the image tensor only.
        image_size: ``(width, height)`` the image is resized to.
            Defaults to ``(572, 572)``.
        mask_size: ``(width, height)`` the mask is resized to.
            Defaults to ``(388, 388)``.
    """

    def __init__(self, image_dir, mask_dir, transform=None,
                 image_size=(572, 572), mask_size=(388, 388)):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.image_size = tuple(image_size)
        self.mask_size = tuple(mask_size)
        # os.listdir returns entries in arbitrary order; sorting makes a given
        # index refer to the same file on every run and every machine.
        self.images = sorted(os.listdir(image_dir))
        self.masks = sorted(os.listdir(mask_dir))

    def __len__(self):
        """Return the number of entries found in ``image_dir``."""
        return len(self.images)

    @staticmethod
    def _load_grayscale(path, size, resample=None):
        """Load ``path``, resize to ``size`` and return a 2-D float32 array."""
        # The context manager closes the underlying file handle deterministically.
        with Image.open(path) as picture:
            if resample is None:
                resized = picture.resize(size)
            else:
                resized = picture.resize(size, resample)
            return np.array(resized.convert('L'), dtype=np.float32)

    def __getitem__(self, index):
        """Return ``(image, mask)`` tensors for sample ``index``.

        Returns:
            tuple: float32 image tensor of shape ``(1, H, W)`` holding the
            grayscale pixel values, and float32 mask tensor of shape
            ``(1, H, W)`` holding the mask pixel values divided by 255.

        Raises:
            FileNotFoundError: If the image or its mask file does not exist.
        """
        file_name = self.images[index]
        img_path = os.path.join(self.image_dir, file_name)
        mask_path = os.path.join(self.mask_dir, file_name.replace('.tif', '_mask.tif'))
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        for path in (img_path, mask_path):
            if not os.path.exists(path):
                raise FileNotFoundError(f'Missing file for sample {index}: {path}')

        image = self._load_grayscale(img_path, self.image_size)
        # Nearest-neighbour resampling keeps mask pixels at their original
        # values instead of blending foreground and background along edges.
        mask = self._load_grayscale(mask_path, self.mask_size, Image.NEAREST)
        mask = mask / 255.0

        # Add the leading channel axis: (H, W) -> (1, H, W).
        tensor_img = torch.from_numpy(image[np.newaxis, :, :])
        tensor_mask = torch.from_numpy(mask[np.newaxis, :, :])

        # Apply transformation
        if self.transform is not None:
            tensor_img = self.transform(tensor_img)

        return tensor_img, tensor_mask

# Data Loader

In [6]:
# Hyperparameters
LEARNING_RATE = 1e-4

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Dataset locations, relative to the notebook's working directory.
TRAIN_IMG_DIR, TRAIN_MASK_DIR = 'data/train_images/', 'data/train_masks/'
VAL_IMG_DIR, VAL_MASK_DIR = 'data/val_images/', 'data/val_masks/'

In [7]:
BATCH_SIZE = 10

train_set = MRIDataset(TRAIN_IMG_DIR, TRAIN_MASK_DIR)
val_set = MRIDataset(VAL_IMG_DIR, VAL_MASK_DIR)

# Reshuffle the training samples every epoch so batches are not always made of
# the same neighbouring files; validation order does not matter, keep it fixed.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)

In [8]:
# Check shape of image
# Fetch the first training sample and display (image shape, mask shape).
img, mask = train_set[0]
tuple(tensor.shape for tensor in (img, mask))

(torch.Size([1, 572, 572]), torch.Size([1, 388, 388]))

# Loss Function

In [9]:
def soft_dice(prediction, ground_truth, eps=1e-7):
    """Soft Dice coefficient between two tensors of the same shape.

    Computes ``(2 * sum(p * g) + eps) / (sum(p) + sum(g) + eps)`` with the sums
    taken over every element of the batch. The result is a similarity score
    (1.0 for a perfect match), so minimise ``1 - soft_dice(...)`` when using it
    as a training objective.

    Args:
        prediction: Predicted values, expected to lie in ``[0, 1]``.
        ground_truth: Reference values, same shape as ``prediction``.
        eps: Small smoothing term. It prevents a 0/0 = NaN result when both
            tensors are all zero (which then scores 1.0) and is negligible
            otherwise.

    Returns:
        A scalar tensor holding the Dice coefficient.
    """
    overlap = torch.mul(prediction, ground_truth).sum()
    total = prediction.sum() + ground_truth.sum()
    return (2 * overlap + eps) / (total + eps)

In [10]:
# Sanity check: evaluate soft_dice on two independent random tensors and show
# the dtype and value of the result.
pred, truth = (torch.rand(1, 1, 338, 338) for _ in range(2))
test_loss = soft_dice(pred, truth)
(test_loss.dtype, test_loss)

(torch.float32, tensor(0.4992))

# Train

In [11]:
# TODO:
# - train, validation
# - next time: show the training results
# - ask in the Telegram chat

In [13]:
EPOCHS = 1
train_loss_history = []

# Run on the device selected in the configuration cell and use the configured
# learning rate instead of a separate hard-coded value.
model = model.to(DEVICE)
model.train()
opt = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(EPOCHS):
    print(f'========== Epoch {epoch} ==========')
    for counter, (X, y) in enumerate(train_loader, start=1):
        print(f'---------- Batch {counter} ----------')

        # The dataset may hand back float64 tensors; the model weights are float32.
        X = X.float().to(DEVICE)
        y = y.float().to(DEVICE)

        # The network emits unbounded logits; squash them to [0, 1] so the
        # Dice overlap with the mask is meaningful.
        pred = torch.sigmoid(model(X))

        # soft_dice is a similarity (higher is better), so the quantity to
        # minimise is its complement.
        loss = 1 - soft_dice(pred, y)

        # Clear stale gradients, back-propagate, then update the weights.
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Save train loss history as plain Python floats.
        train_loss_history.append(loss.item())

---------- Batch 1 ----------
---------- Batch 2 ----------


KeyboardInterrupt: 

In [None]:
# Plot the per-batch training loss recorded by the training loop.
fig, ax = plt.subplots()
ax.plot(train_loss_history)
ax.set(title='Training loss', xlabel='Batch', ylabel='Loss')
plt.show()

In [None]:
def Trainer(model, train_loader, val_loader,  optimizer, epochs, sheduler = None, loss_fn = None):
    """Train ``model`` and plot the running loss after every epoch.

    Args:
        model: Network to train. Batches are moved to the device of its
            parameters before the forward pass.
        train_loader: Iterable yielding ``(X, y)`` batches.
        val_loader: Accepted for interface compatibility; not used yet.
            TODO: add a validation pass.
        optimizer: Optimizer that updates ``model``'s parameters.
        epochs: Number of passes over ``train_loader``.
        sheduler: Optional scheduler; ``sheduler.step(metric)`` is called once
            per epoch with the most recent batch loss.
        loss_fn: Optional callable ``(prediction, target) -> scalar tensor``.
            Defaults to ``1 - soft_dice(sigmoid(prediction), target)``.

    Returns:
        list[float]: The loss of every training batch, in order.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    def _dice_loss(prediction, target):
        # The model emits logits; Dice expects values in [0, 1]. soft_dice is a
        # similarity, so its complement is the quantity to minimise.
        return 1 - soft_dice(torch.sigmoid(prediction), target)

    criterion = loss_fn if loss_fn is not None else _dice_loss
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None
    train_loss_history = []

    for epoch in range(epochs):
        model.train(True)
        for (X, y) in train_loader:
            X = X.float()
            y = y.float()
            if device is not None:
                X = X.to(device)
                y = y.to(device)
            loss = criterion(model(X), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss_history.append(loss.item())

        if not train_loss_history:
            raise ValueError('train_loader produced no batches')

        clear_output()
        if sheduler is not None:
            sheduler.step(train_loss_history[-1])
        print("Last loss:\t{}\nEpoch number:\t{}\nCurrent Learning rate:{}".format(train_loss_history[-1], epoch, optimizer.param_groups[0]['lr']))
        plt.plot(train_loss_history)
        plt.show()

    return train_loss_history


# Fresh model and optimizer so this run does not inherit weights from earlier cells.
model = UNet().to(DEVICE)
opt = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
shedule = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', patience=20)
num_epochs = 1
batch_size = 20

# Quick run on the first `batch_size` samples only. The dataset does not
# support slicing, and Trainer iterates over (X, y) batches, so take a Subset
# and wrap it in a DataLoader instead of indexing with a slice.
train_batch = DataLoader(
    torch.utils.data.Subset(train_set, range(min(batch_size, len(train_set)))),
    batch_size=batch_size,
)
test_batch = DataLoader(
    torch.utils.data.Subset(val_set, range(min(batch_size, len(val_set)))),
    batch_size=batch_size,
)

try:
    Trainer(model, train_batch, test_batch, epochs=num_epochs, optimizer=opt, sheduler=shedule)
except KeyboardInterrupt:
    # Deliberate: stopping the cell by hand should not leave a traceback.
    pass