In [None]:
# General use for DL
import torch
import numpy as np

# To loader dataset
from torch.utils.data import DataLoader

# Data Measurements
from torchmetrics import JaccardIndex
from torchmetrics import Dice

# Dataset import
from dataset import ISICSegmentationDataset

# Plotting
import matplotlib.pyplot as plt

from exp_utils import stopping_rule, moving_avg
from train import train_model, test_model


Load Dataset

In [None]:
# Side length passed to the dataset constructor.
# NOTE(review): presumably the resize target in pixels — confirm in ISICSegmentationDataset.
size = 256

# ISIC 2018 Task 1 training inputs and their ground-truth segmentation masks
image_path = "./ISIC/images/ISIC2018_Task1-2_Training_Input/"
masks_path = "./ISIC/masks/ISIC2018_Task1_Training_GroundTruth/"

# Full image/mask dataset; folds later select subsets of it by index
dataset = ISICSegmentationDataset(image_path, masks_path, size)

Hyperparameters

In [None]:
## Preliminary variables ##

# Run on the GPU when PyTorch can see one; otherwise use the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Binary cross-entropy objective handed to train_model.
# (nn.BCELoss expects inputs that are already probabilities in [0, 1].)
loss_fn = torch.nn.BCELoss()

# Evaluation metrics: 2-class, micro-averaged, living on the same device as the model
jaccard = JaccardIndex(task='multiclass', num_classes=2, average='micro').to(device)
dice = Dice(num_classes=2, average='micro').to(device)

# DataLoader settings: samples per batch and parallel worker processes
BATCH_SIZE = 16
num_workers = 8

# Tolerance given to stopping_rule when deciding whether training should continue
threshold = 0.001

# Smoothing factor passed to moving_avg for the training-loss EMA
alpha = 0.9

Load Train/Test Fold Indices

In [None]:
# Checkpoint saved by the baseline experiment; only its cross-validation folds are reused,
# so every experiment here is evaluated on exactly the same train/test splits.
name = 'baseline.pt'

# NOTE(review): since PyTorch 2.6, torch.load defaults to weights_only=True.
# If 'fold_dict' holds NumPy arrays this load may raise — confirm against the saved file.
baseline = torch.load(f=name)

# Mapping split name -> indexable pair (train indices, test indices), as consumed by the training cell
all_indices = baseline['fold_dict']

Training

In [None]:
## Learning-curve experiment ##
# For every cross-validation split, grow a random training subset STEP_SIZE points
# at a time. Train a fresh U-Net from scratch on each subset and record its Dice
# score on the split's fixed test set.

# New training points added to the subset each round
STEP_SIZE = 10
# Every model trains at least this many epochs, even if stopping_rule says stop earlier.
# NOTE(review): the original comment ("to determine if threshold is too low") may have
# intended a maximum instead — confirm which bound is wanted.
MIN_EPOCHS = 10
# Initial Adam learning rate and cosine-annealing period (in epochs)
LR = 0.01
T_MAX = 25
# Seed for reproducible subset draws, weight initialisation and sampler order
SEED = 0

rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)


def make_loader(indices):
    """Return a DataLoader over the samples of `dataset` at `indices`, visited in random order."""
    return DataLoader(
        dataset=dataset,
        batch_size=BATCH_SIZE,
        sampler=torch.utils.data.SubsetRandomSampler(indices),
        num_workers=num_workers,
    )


def new_model():
    """Return an untrained U-Net (3 input channels, 1 output channel, 64 base features) on `device`."""
    return torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
        in_channels=3, out_channels=1, init_features=64, pretrained=False, trust_repo=True).to(device)


def train_until_stopped(model, train_loader):
    """Train `model` on `train_loader` with Adam and a cosine LR schedule.

    Training runs for at least MIN_EPOCHS epochs. After that it continues for as long
    as stopping_rule(EMA of loss, latest loss, threshold) returns a truthy value.

    Returns:
        The number of epochs trained.
    """
    optimizer = torch.optim.Adam(model.parameters(), LR)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_MAX)

    # L_MA: exponential moving average of the epoch loss; L_k: most recent epoch loss.
    # L_MA starts at a non-zero placeholder (presumably so stopping_rule never sees 0).
    # The first real loss replaces it below.
    L_MA = 1
    L_k = 0
    epoch = 0
    while stopping_rule(L_MA, L_k, threshold) or epoch < MIN_EPOCHS:
        L_k = train_model(model, loss_fn, device, train_loader, optimizer)
        # Advance the cosine schedule once per epoch; it was previously never stepped
        scheduler.step()
        # Seed the EMA with the first observed loss rather than the placeholder
        if epoch == 0:
            L_MA = L_k
        L_MA = moving_avg(alpha, L_MA, L_k)
        # Without this increment the MIN_EPOCHS guard is always true and the loop never ends
        epoch += 1
    return epoch


splits = ['A', 'B', 'C', 'D', 'E']
split_dices = {}
for split in splits:
    # Each fold entry holds (train indices, test indices)
    fold = all_indices[split]
    train_indices = fold[0]
    test_indices = fold[1]

    test_loader = make_loader(test_indices)

    # Indices of the growing training subset, and the Dice score after each round
    model_indices = []
    test_dice = []

    while len(train_indices) > 0:
        # Draw up to STEP_SIZE unused points without replacement (fewer on the final round)
        n_to_select = min(STEP_SIZE, len(train_indices))
        model_indices.extend(rng.choice(train_indices, n_to_select, replace=False))
        # Remove the drawn points from the pool for later rounds
        train_indices = np.setdiff1d(train_indices, model_indices)

        # Fresh, untrained model for every subset size
        model = new_model()
        train_until_stopped(model, make_loader(model_indices))

        _jaccard_score, dice_score = test_model(model, device, test_loader, jaccard, dice)
        print(f"Model split {split}, {len(model_indices)} data points | Dice Score: {dice_score}")
        # Store a plain float so the saved results hold no device tensors.
        # Assumes test_model returns a scalar (Python number or 0-d tensor).
        test_dice.append(float(dice_score))

    split_dices[split] = test_dice

# Persist the per-split learning curves
torch.save(split_dices, 'split_dices.pt')



Chart the average Dice score progression across splits

In [None]:
# Per-round Dice scores for each split, as float arrays. float() also accepts scalar
# tensors (CPU or CUDA), which NumPy cannot add directly.
per_split_scores = [np.array([float(score) for score in split_dices[split]]) for split in splits]

# Splits whose training pools differ in size can complete a different number of rounds.
# Average only over the rounds that every split completed.
n_rounds = min(len(scores) for scores in per_split_scores)
average_dices = np.mean([scores[:n_rounds] for scores in per_split_scores], axis=0)

fig, ax = plt.subplots()
ax.plot(np.arange(n_rounds), average_dices, label="Average Dice Scores")
ax.set(
    xlim=(0, max(n_rounds - 1, 1)),
    ylim=(0, 1),
    xlabel="Training round (10 new training images per round)",
    ylabel="Dice score",
    title=f"Test Dice score averaged over {len(splits)} splits",
)
ax.legend()
plt.show()