In [1]:
!pip install torch
!pip install torchmetrics

[0m

# Imports
Packages required for experiment

In [2]:
# General use for DL
import torch
import numpy as np

# To loader dataset
from torch.utils.data import DataLoader

# Data Measurements
from torchmetrics import JaccardIndex
from torchmetrics import Dice

# Dataset import
from dataset import BUIDSegmentationDataset

# Stopping Rule
When this function returns a value less than 0.1 % = 0.001 then we know to stop training

In [3]:
# Finds percent change between previous and current loss
# and if it is less than threshold return false
def stopping_rule(L_k, L_k1, threshold):
    return abs(L_k - L_k1) / L_k > threshold

In [4]:
# Calculates the moving average
def moving_avg(alpha, L_MA, L_k):
    return alpha * L_MA + (1-alpha) * L_k

# Training Function

In [5]:
def train_model(model, loss_fn, device, train_loader, optimizer):
    # Initalize loss
    average_loss = 0
    # Train on dataset
    model.train()
    for batch_idx, (X,y) in enumerate(train_loader):
        # Get batch
        image, mask = X.to(device), y.to(device)
        # Get results
        output = model(image)
        # Compute loss
        loss = loss_fn(output, mask)
        average_loss += loss.item()
        # Optimize model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Return average loss
    return average_loss / len(train_loader)

# Testing Function

In [6]:
def test_model(model, device, test_loader, jaccard, dice):
    # Initalize average jaccard and dice
    average_jaccard = 0
    average_dice = 0
    # Test the model
    model.eval()
    for batch_idx, (X,y) in enumerate(test_loader):
        # Get batch
        image, mask = X.to(device), y.to(device)
        # Get results
        output = model(image)
        average_jaccard += jaccard(torch.where(output > 0.5, 1, 0),torch.where(mask > 0.50, 1, 0)).item()
        average_dice += dice(torch.where(output > 0.5, 1, 0),torch.where(mask > 0.50, 1, 0)).item()
    # Get average of dice and jaccard scores
    average_jaccard /= len(test_loader)
    average_dice /= len(test_loader)

    # Return values
    return average_jaccard, average_dice

# Function to partition list
We will use this function to partition the list of train indices into eighths. If we remember the train indices consist of 80 % of the original dataset, so each partition will contain 10 % of the original dataset. 

In [7]:
def split_into_eights(list):
    # Floor division of length of list
    partition_size = len(list) // 8
    remainder  = len(list) % 8
    
    # List that will store each parition
    partitions = []
    
    # Partition the list, if partition is not even distrubutes 
    # remainder between beginning paritions
    start = 0
    for i in range(8):
        end = start + partition_size + (1 if i < remainder and i != 0 else 0)
        partitions.append(list[start:end])
        start = end
    return partitions  

# Load Data
Load the baseline data and indices splits from previous stage.

In [8]:
# File name
name = 'baseline.pt'
baseline = torch.load(f=name)
# Extract indices and baseline jaccard and dice scores
all_indices = baseline['fold_dict']
baseline_jaccard = baseline['baseline_jaccard']
baseline_dice = baseline['baseline_dice']


Load the dataset

In [9]:
# Path to images and mask
root_dir = './BUID'
# U-Net we are using takes 3 x 256 x 256 images
size = 256
# Importing the dataset
dataset = BUIDSegmentationDataset(root_dir, size)


Specify hyperparameters

In [10]:
## Preliminary variables ##

# Specifies whether to train on GPU or CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Loss for training
loss_fn = torch.nn.BCELoss()

# Measurements
jaccard = JaccardIndex(task='multiclass', num_classes = 2, average = 'micro').to(device)
dice = Dice(num_classes = 2, average = 'micro').to(device)

# Batch Size
BATCH_SIZE = 16

# Stopping threshold
threshold = 0.001

# Speeds up training
num_workers = 8

# Speeds up training
num_workers = 8

# Alpha for EMA
alpha = 0.9


In [None]:
splits = ['A', 'B', 'C', 'D', 'E']
partitions = [0,0.1,0.2,0.3,0.4,0.5,0.6,0.7]
# Cycle through all test splits
for split in splits:
    # Get indices of test and train points in dataset
    indices = all_indices[split]
    train_indices = indices[0]
    test_indices = indices[1]
    # Split train indices into eighths
    partition_indices = split_into_eights(train_indices)
    # Create test dataloader for fold
    test_loader = DataLoader(
        dataset=dataset,
        batch_size=BATCH_SIZE,
        sampler=torch.utils.data.SubsetRandomSampler(test_indices),
        num_workers = num_workers
    )
    # Cycle through all partition lengths of training data (0.1,0.2,...,0.7)
    for i in range(1, len(partition_indices)):
        
        # Initialize Jaccard and Dice
        average_jaccard = 0
        average_dice = 0
                
        train_indices_i = np.hstack(partition_indices[0:i])
        remaining_indices = np.hstack(partition_indices[i:])
        
        # Create a train dataloader for partition
        train_loader = DataLoader(
            dataset=dataset,
            batch_size = BATCH_SIZE,
            sampler=torch.utils.data.SubsetRandomSampler(train_indices_i),
            num_workers = num_workers
        )
        
        ## Initialize the model ##

        # Loading an untrained model to GPU/CPU
        model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
            in_channels=3, out_channels=1, init_features=64, pretrained=False, trust_repo=True).to(device)
        # We will begin our learning rate at 0.01 
        lr = 0.01
        # Optimizer for model
        optimizer = torch.optim.Adam(model.parameters(), lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,25)
        
        ## Initialize the training ##
    
        # Initialize previous and current loss for stopping rule
        
        L_MA = 1 # Moving average of loss
        L_k = 0 # Current loss
        
        # To determine in threshold is too low
        counter = 0
        # Train model until stopping rule is reached
        while(stopping_rule(L_MA, L_k, threshold) or counter < 10):
            # Train model and compute loss
            L_k = train_model(model, loss_fn, device, train_loader, optimizer)
            
            # Initialization of EMA
            if(L_MA == 0):
                L_MA = L_k
            
            # Find EMA of losses
            L_MA = moving_avg(alpha, L_MA, L_k)
            counter += 1
               
        # To determine in threshold is too low
        print(counter)
        # Test model on test split
        jaccard_score, dice_score = test_model(model, device, test_loader, jaccard, dice)
        print(f'Split:{split}|Partition:{partitions[i]}|Jaccard:{jaccard_score}|Dice:{dice_score}')
        # Save model, scores, and indices for next step
        name = f'./model/Split:{split}|Partition:{partitions[i]}|New'
        state = {
            'train_indices' : train_indices_i,
            'remaining_indices' : remaining_indices,
            'test_indices' : test_indices,
            'jaccard_score' : jaccard_score,
            'dice_score' : dice_score,
            'num_iterations' : counter,
            'state_dict' : model.state_dict(),
        }
        torch.save(state, f=name)
        
    

Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


366
Split:A|Partition:0.1|Jaccard:0.8460275053977966|Dice:0.9164737701416016


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


492
Split:A|Partition:0.2|Jaccard:0.8686314165592194|Dice:0.9295298278331756


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


340
Split:A|Partition:0.3|Jaccard:0.8930659413337707|Dice:0.9431972205638885


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


170
Split:A|Partition:0.4|Jaccard:0.8934301137924194|Dice:0.9433905601501464


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


321
Split:A|Partition:0.5|Jaccard:0.9102159023284913|Dice:0.952603530883789


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


311
Split:A|Partition:0.6|Jaccard:0.9109900891780853|Dice:0.9532379150390625


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


981
Split:A|Partition:0.7|Jaccard:0.9092088997364044|Dice:0.9521798133850098


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


514
Split:B|Partition:0.1|Jaccard:0.86333127617836|Dice:0.9264231383800506


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


406
Split:B|Partition:0.2|Jaccard:0.891900897026062|Dice:0.9427349388599395


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


357
Split:B|Partition:0.3|Jaccard:0.9102275788784027|Dice:0.9529102981090546


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


287
Split:B|Partition:0.4|Jaccard:0.9310519218444824|Dice:0.9641994774341583


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


121
Split:B|Partition:0.5|Jaccard:0.9317119181156158|Dice:0.9645150184631348


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


339
Split:B|Partition:0.6|Jaccard:0.935376501083374|Dice:0.9665213882923126


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


252
Split:B|Partition:0.7|Jaccard:0.9324437916278839|Dice:0.9649609863758087


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


461
Split:C|Partition:0.1|Jaccard:0.852782028913498|Dice:0.9202640533447266


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


278
Split:C|Partition:0.2|Jaccard:0.8524227738380432|Dice:0.9199829399585724


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


129
Split:C|Partition:0.3|Jaccard:0.8639395058155059|Dice:0.9265961349010468


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


337
Split:C|Partition:0.4|Jaccard:0.8869680285453796|Dice:0.9395377814769745


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


353
Split:C|Partition:0.5|Jaccard:0.895977133512497|Dice:0.9450104057788848


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


356
Split:C|Partition:0.6|Jaccard:0.904334819316864|Dice:0.9494717597961426


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


442
Split:C|Partition:0.7|Jaccard:0.9055923461914063|Dice:0.9503263473510742


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


124
Split:D|Partition:0.1|Jaccard:0.8586161732673645|Dice:0.9236314773559571


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


113
Split:D|Partition:0.2|Jaccard:0.8250739216804505|Dice:0.9039363861083984


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


160
Split:D|Partition:0.3|Jaccard:0.8928453803062439|Dice:0.9431806564331054


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


140
Split:D|Partition:0.4|Jaccard:0.8974138915538787|Dice:0.945737075805664


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


307
Split:D|Partition:0.5|Jaccard:0.9158359348773957|Dice:0.9559115707874298


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


126
Split:D|Partition:0.6|Jaccard:0.9145877718925476|Dice:0.9552316665649414


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


1159
Split:D|Partition:0.7|Jaccard:0.9211888492107392|Dice:0.9588438332080841


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


149
Split:E|Partition:0.1|Jaccard:0.8876965165138244|Dice:0.940396785736084


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


113
Split:E|Partition:0.2|Jaccard:0.8618147850036622|Dice:0.925605422258377


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


129
Split:E|Partition:0.3|Jaccard:0.9033583164215088|Dice:0.9489863693714142


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


190
Split:E|Partition:0.4|Jaccard:0.9198612451553345|Dice:0.9579935073852539


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


276
Split:E|Partition:0.5|Jaccard:0.9307354509830474|Dice:0.9638971626758576


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


310
Split:E|Partition:0.6|Jaccard:0.9304155051708222|Dice:0.9638356506824494


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
