### Importing Libraries

In [1]:
import torch
import numpy as np
import os
import matplotlib.pyplot as plt
import PIL
from PIL import Image,ImageOps
from torchvision import transforms
from torch.utils.data import Dataset,DataLoader
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchmetrics import JaccardIndex


%load_ext autoreload
%autoreload 2

### Setting augmentation parameters

In [2]:
# Size of the original train dataset
train_size = 200
# Augmentation mode: the string 'baseline' selects the baseline-augmentation
# checkpoint name, any other truthy value selects the DatasetGAN-augmentation
# checkpoint name, and a falsy value selects the no-augmentation name
use_augmentation = 'baseline'

# The checkpoint path encodes both the augmentation mode and the train size
model_path = (
    'Models/model_baseline_augmentation_' if use_augmentation == 'baseline'
    else 'Models/model_datasetgan_augmentation_' if use_augmentation
    else 'Models/model_no_augmentation_'
) + str(train_size) + '.pt'

### Loading model and dataloaders

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Fix random seeds so that shuffling and weight initialisation are reproducible
# across Restart & Run All
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Datasets
from dataloader import get_datasets
train_dataset, valid_dataset, test_dataset = get_datasets(train_size, use_augmentation)

# Dataloaders: only the training set is shuffled. Evaluation order is kept fixed
# so the mean-of-batch validation loss (used for checkpoint selection) is not
# perturbed by a different partial last batch every epoch.
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=8, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)


# Loading Model: U-Net from torch.hub, 1 input channel, 1 output channel, trained from scratch
model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
                       in_channels=1, out_channels=1, init_features=32, pretrained=False)
model = model.to(device)

# Hyperparameters (Adam with beta1 = 0.5)
lr = 0.0002
beta1 = 0.5
optimizer = optim.Adam(model.parameters(), lr=lr, betas=(beta1, 0.999))
criterion = nn.BCELoss()
# NOTE(review): `num_classes` without `task` only works on older torchmetrics;
# newer releases require e.g. JaccardIndex(task="binary") — confirm installed version.
jaccard = JaccardIndex(num_classes=2)

Using cache found in /home/s2agarwal/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


### Dice loss helper (`dice_coeff` returns 1 − Dice coefficient)

In [4]:
#Dice Coefficient
def dice_coeff(pred, target):
    smooth = 1.
    pred = torch.where(pred>=0.5,1,0)
    iflat = pred.view(-1)
    tflat = target.view(-1)
    intersection = (iflat * tflat).sum()
    
    return 1 - ((2. * intersection + smooth) /
              (iflat.sum() + tflat.sum() + smooth))

### Visualizing model outputs

In [5]:
#Function for visualization
def visualize(data, pred, label):
    """Plot each sample of a batch as one row: input image, model prediction, label.

    Args:
        data: batch of input images; each ``data[i]`` is permuted with (1, 2, 0),
            so samples are 3-D, presumably (C, H, W) — TODO confirm.
        pred: model outputs for the same batch, same per-sample layout as ``data``.
        label: ground-truth masks for the same batch, same per-sample layout.
    """
    batchsize = data.size(0)
    # squeeze=False keeps `axs` 2-D even when batchsize == 1, so axs[i, j] always works
    fig, axs = plt.subplots(batchsize, 3, figsize=(10, 10), squeeze=False)
    for i in range(batchsize):
        img_plot = data[i].permute(1, 2, 0).detach().cpu()
        pred_plot = pred[i].permute(1, 2, 0).detach().cpu()
        label_plot = label[i].permute(1, 2, 0).detach().cpu()

        # The model takes 1 input channel, so draw the image in grayscale as well
        # (cmap is ignored by matplotlib for RGB/RGBA data)
        axs[i, 0].imshow(img_plot, cmap='gray')
        axs[i, 0].set_title('Image')

        axs[i, 1].imshow(pred_plot, cmap='gray')
        axs[i, 1].set_title('Prediction')

        axs[i, 2].imshow(label_plot, cmap='gray')
        axs[i, 2].set_title('Label')

    plt.show()

### Validation and Test Loops

In [6]:
#Validation Loop
def validation(model, dataloader, loss_fn=None, target_device=None):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    The model is evaluated in eval mode without gradient tracking. Its previous
    train/eval mode is restored afterwards, so a training loop that calls this
    between epochs keeps training in train mode.

    Args:
        model: the network to evaluate.
        dataloader: iterable of ``(data, targets)`` batches.
        loss_fn: loss to apply; defaults to the module-level ``criterion``.
        target_device: device to move batches to; defaults to the module-level ``device``.

    Returns:
        ``np.mean`` of the per-batch loss values (NaN if ``dataloader`` is empty).
    """
    loss_fn = criterion if loss_fn is None else loss_fn
    target_device = device if target_device is None else target_device

    was_training = model.training
    model.eval()
    losses = []
    try:
        with torch.no_grad():
            for data, targets in dataloader:
                data = data.to(target_device)
                # Cast like the training loop does: BCELoss needs float targets
                targets = targets.to(target_device).to(torch.float32)

                pred = model(data)
                losses.append(loss_fn(pred, targets).item())
    finally:
        # Put the model back in whatever mode it was in before validation
        model.train(was_training)
    return np.mean(losses)

#Test Loop
def test(model_path, dataloader):
    """Load a saved U-Net checkpoint and return its mean Jaccard index (IoU) over ``dataloader``.

    Args:
        model_path: path to a ``state_dict`` saved with ``torch.save``.
        dataloader: iterable of ``(data, targets)`` batches.

    Returns:
        ``np.mean`` of the per-batch values returned by the module-level ``jaccard`` metric.
    """
    model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
                           in_channels=1, out_channels=1, init_features=32, pretrained=False)
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    weights = torch.load(model_path, map_location=device)
    model.load_state_dict(weights)
    model = model.to(device)
    model.eval()
    iou_scores = []
    with torch.no_grad():
        for data, targets in dataloader:
            data = data.to(device)
            targets = targets.type(torch.int8).to(device)

            pred = model(data)
            # `jaccard` is a torchmetrics JaccardIndex, i.e. IoU (not Dice)
            iou_scores.append(jaccard(pred.cpu(), targets.cpu()))
    return np.mean(iou_scores)

### Training

In [None]:
# Training configuration; `valid_loss` holds the best validation loss seen so far
num_epochs = 100
valid_loss = float('inf')


print("Starting Training Loop...")
# For each epoch
for epoch in tqdm(range(num_epochs)):
    # validation() below switches the model to eval mode; switch back to train
    # mode at the start of every epoch so mode-dependent layers (e.g. batch-norm
    # or dropout, if present) train correctly after the first epoch.
    model.train()
    losses = []

    # For each batch in the dataloader
    for data, targets in train_dataloader:
        data = data.to(device)
        # BCELoss needs float targets
        targets = targets.to(device, dtype=torch.float32)
        optimizer.zero_grad()
        pred = model(data)
        loss = criterion(pred, targets)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    print("Epoch : %d, Loss : %2.5f" % (epoch, np.mean(losses)))

    # Every 10 epochs, plot up to 3 samples from the last training batch
    if (epoch + 1) % 10 == 0:
        visualize(data[0:3], pred[0:3], targets[0:3])

    # Save a checkpoint only when the validation loss improves on the best so far
    cur_loss = validation(model, valid_dataloader)
    if cur_loss < valid_loss:
        valid_loss = cur_loss
        torch.save(model.state_dict(), model_path)

### Test Jaccard index (IoU) of all the saved models

In [7]:
# Evaluate every saved checkpoint on the test set. The loop variable is deliberately
# not called `model`, so it does not overwrite the trained network defined above
# (which would break re-running the training cell).
model_names = sorted(name for name in os.listdir('Models') if name.endswith('.pt'))
for model_name in model_names:
    test_iou = test(os.path.join('Models', model_name), test_dataloader)
    print(model_name, test_iou)

Using cache found in /home/s2agarwal/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


model_baseline_augmentation_0.pt 0.79405504


Using cache found in /home/s2agarwal/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


model_baseline_augmentation_100.pt 0.9406273


Using cache found in /home/s2agarwal/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


model_baseline_augmentation_20.pt 0.90721476


Using cache found in /home/s2agarwal/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


model_baseline_augmentation_200.pt 0.94876254


Using cache found in /home/s2agarwal/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


model_baseline_augmentation_50.pt 0.9383766


Using cache found in /home/s2agarwal/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


model_datasetgan_augmentation_0.pt 0.82673806


Using cache found in /home/s2agarwal/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


model_datasetgan_augmentation_100.pt 0.94818574


Using cache found in /home/s2agarwal/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


model_datasetgan_augmentation_20.pt 0.9222699


Using cache found in /home/s2agarwal/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


model_datasetgan_augmentation_200.pt 0.95033234


Using cache found in /home/s2agarwal/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


model_datasetgan_augmentation_50.pt 0.9312889


Using cache found in /home/s2agarwal/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


model_no_augmentation_100.pt 0.9353795


Using cache found in /home/s2agarwal/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


model_no_augmentation_20.pt 0.86314094


Using cache found in /home/s2agarwal/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


model_no_augmentation_200.pt 0.9462622


Using cache found in /home/s2agarwal/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


model_no_augmentation_50.pt 0.9047231


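### Results - Test Jaccard Index (IoU)

Values are the mean per-batch Jaccard index (IoU) on the test set, as printed above, rounded to three decimals.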



|Train Size| Without Augmentation | Baseline Augmentation| DatasetGAN Augmentation |
| :-: | :-: |:-: | :-: |
| 0   | -       | 0.794 | 0.827 |
| 20  | 0.863  | 0.907 | 0.922  |
| 50  | 0.905  | 0.938 | 0.931 |
| 100 | 0.935 | 0.941 | 0.948 |
| 200 | 0.946 | 0.949 | 0.950 |