In [1]:
!pip install torch
!pip install torchmetrics

[0m

# Imports
Packages required for experiment

In [2]:
# General use for DL
import torch
import numpy as np

# To loader dataset
from torch.utils.data import DataLoader

# Data Measurements
from torchmetrics import JaccardIndex
from torchmetrics import Dice

# Dataset import
from dataset import ISICSegmentationDataset
from base_dataset import BaseDataset

# Stopping Rule
When this function returns a value less than 0.1 % = 0.001 then we know to stop training

In [3]:
# Finds percent change between previous and current loss
# and if it is less than threshold return false
def stopping_rule(L_k, L_k1, threshold):
    return abs(L_k - L_k1) / L_k > threshold

In [None]:
# Calculates the moving average
def moving_avg(alpha, L_MA, L_k):
    return alpha * L_MA + (1-alpha) * L_k

# Training Function

In [4]:
def train_model(model, loss_fn, device, train_loader, optimizer):
    # Initalize loss
    average_loss = 0
    # Train on dataset
    model.train()
    for batch_idx, (X,y) in enumerate(train_loader):
        # Get batch
        image, mask = X.to(device), y.to(device)
        # Get results
        output = model(image)
        # Compute loss
        loss = loss_fn(output, mask)
        average_loss += loss.item()
        # Optimize model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Return average loss
    return average_loss / len(train_loader)

Specify hyperparameters

In [5]:
## Preliminary variables ##

# Specifies whether to train on GPU or CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Loss for training
loss_fn = torch.nn.BCELoss()

# Measurements
jaccard = JaccardIndex(task='multiclass', num_classes = 2, average = 'micro').to(device)
dice = Dice(num_classes = 2, average = 'micro').to(device)

# Batch Size
BATCH_SIZE = 16

# Stopping threshold
threshold = 0.001

# Speeds up training
num_workers = 8

# Alpha for EMA
alpha = 0.9

# Get Masks
Function to extracts iamges and masks of given indices

In [6]:
def get_masks(indices, dataset):
    masks = []
    images = []
    for i in indices:
        image, mask = dataset.__getitem__(i)
        images.append(image)
        masks.append(mask)
    return images, masks
    

# Create Masks
Function to create masks of given images

In [7]:
def create_masks(model, device, loader):
    # Initialize 
    masks = []
    images = []
    # Create masks
    model.eval()
    for batch_idx, (X,y) in enumerate(loader):
        # Get batch
        image, mask = X.to(device), y.to(device)
        # Get results
        output = model(image)
        # Detach from CPU and squeeze batch(1) dimension
        masks.append(output.detach().cpu().squeeze(0))
        images.append(image.detach().cpu().squeeze(0))
    # Segmented masks
    return images, masks

# Testing Function

In [8]:
def test_model(model, device, test_loader, jaccard, dice):
    # Initalize average jaccard and dice
    average_jaccard = 0
    average_dice = 0
    # Test the model
    model.eval()
    for batch_idx, (X,y) in enumerate(test_loader):
        # Get batch
        image, mask = X.to(device), y.to(device)
        # Get results
        output = model(image)
        average_jaccard += jaccard(torch.where(output > 0.5, 1, 0),torch.where(mask > 0.50, 1, 0)).item()
        average_dice += dice(torch.where(output > 0.5, 1, 0),torch.where(mask > 0.50, 1, 0)).item()
    # Get average of dice and jaccard scores
    average_jaccard /= len(test_loader)
    average_dice /= len(test_loader)

    # Return values
    return average_jaccard, average_dice

# Data Processing
Importing the dataset into the Juypter Notebook enviroment for use

In [9]:
# Paths to masks and images
image_path = "./ISIC/images/ISIC2018_Task1-2_Training_Input/"
masks_path = "./ISIC/masks/ISIC2018_Task1_Training_GroundTruth/"
# Size of image
size = 256
# Define dataset
dataset = ISICSegmentationDataset(image_path, masks_path, size)

100%|██████████| 2594/2594 [00:00<00:00, 3712052.06it/s]
100%|██████████| 2594/2594 [00:00<00:00, 3631516.88it/s]


# Load data
Load the baseline data and indices splits from previous stage.

In [10]:
# File name
name = 'baseline.pt'
baseline = torch.load(f=name)
# Extract indices and baseline jaccard and dice scores
all_indices = baseline['fold_dict']
baseline_jaccard = baseline['baseline_jaccard']
baseline_dice = baseline['baseline_dice']

# Create new dataset
Using the model trained on a partition we will create masks the remaining data from the train partition.

In [11]:
splits = ['A', 'B', 'C', 'D', 'E']
partitions = [0.1,0.2,0.3,0.4,0.5,0.6,0.7]

# Cycle through all models
for split in splits:
    for partition in partitions:
        # Load the stored data
        name = f'./model/Split:{split}|Partition:{partition}|New'
        data = torch.load(name)
        # Extract data
        train_indices = data['train_indices']
        remaining_indices = data['remaining_indices']
        test_indices = data['test_indices']
        orig_jaccard_score = data['jaccard_score']
        orig_dice_score = data['dice_score']
        orig_iterations = data['num_iterations']
        state_dict = data['state_dict']
        # Load base model
        trained_model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
            in_channels=3, out_channels=1, init_features=64, pretrained=False, trust_repo=True).to(device)
        # Load saved model
        trained_model.load_state_dict(state_dict)
        
        # Create new dataset
        new_dataset = []
        # Get train loader for fold
        remaining_loader = DataLoader(
            dataset=dataset,
            batch_size=1,
            sampler=torch.utils.data.SubsetRandomSampler(remaining_indices),
            num_workers = num_workers
        )
        # Create masks of remaining data
        images, new_masks = create_masks(trained_model, device, remaining_loader)
        # Create dataset with new masks
        new_dataset.append(BaseDataset(images, new_masks))
        # Get images and masks used to train saved model
        base_images, base_masks = get_masks(train_indices,dataset)
        # Create dataset with ground truth masks and images
        new_dataset.append(BaseDataset(base_images, base_masks))
        # Concatenate the two so we have a dataset with generated masks and truth maks
        # this will be our train dataset
        train_dataset = torch.utils.data.ConcatDataset(new_dataset)
        
        # Create train loader for new dataset
        train_loader = DataLoader(
            dataset=train_dataset,
            batch_size = BATCH_SIZE,
            shuffle = True,
            num_workers = num_workers,
        )    
        
        ## Initialize the model ##
        
        # We will begin our learning rate at 0.01 
        lr = 0.01
        # Optimizer for model
        optimizer = torch.optim.Adam(trained_model.parameters(), lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,100)
        
        ## Initialize the training ##
    
        # Initialize previous and current loss for stopping rule
        
        L_MA = 1 # Moving average of loss
        L_k = 0 # Current loss
        
        # To determine in threshold is too low
        counter = 0
        # Train model until stopping rule is reached
        while(stopping_rule(L_MA, L_k, threshold) or counter < 10):
            # Train model and compute loss
            L_k = train_model(trained_model, loss_fn, device, train_loader, optimizer)
            
            # Initialization
            if(L_MA == 0):
                L_MA = L_k
            
            # Find EMA of losses
            L_MA = moving_avg(alpha, L_MA, L_k)
            counter += 1
            
        # To determine in threshold is too low
        print(counter) 
        
        # Get test dataset
        test_loader = DataLoader(
                dataset=dataset,
                batch_size=BATCH_SIZE,
                sampler=torch.utils.data.SubsetRandomSampler(test_indices),
                num_workers = num_workers,
        )
        
        # Test model on test split
        jaccard_score, dice_score = test_model(model, device, test_loader, jaccard, dice)
        print(f'Split:{split}|Partition:{partition}|NewJaccard:{jaccard_score}|OGJaccard:{orig_jaccard_score}|NewDice:{dice_score}|OGDice:{orig_dice_score}')
        name = f'./new_models/Split:{split}|Partition:{partition}|New'
        # Save model and results
        state = {
            'state_dict' : model.state_dict(),
            'jaccard_score' : jaccard_score,
            'dice_score' : dice_score,
            'num_iterations' : counter,
            'orig_jaccard_score' : orig_jaccard_score,
            'orig_dice_score' : orig_dice_score,
            'orig_iterations' : orig_iterations,
        }
        torch.save(state,f=name)
        
        
        

Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


14
Split:A|Partition:0.1|NewJaccard:0.8328580458958944|OGJaccard:0.8005861387108312|NewDice:0.9082560232191375|OGDice:0.8886485153978522


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


39
Split:A|Partition:0.2|NewJaccard:0.823756481661941|OGJaccard:0.8260780067154856|NewDice:0.902275769999533|OGDice:0.9038065671920776


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


62
Split:A|Partition:0.3|NewJaccard:0.834005305261323|OGJaccard:0.8396076957384745|NewDice:0.9088735995870648|OGDice:0.9117339741099965


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


6
Split:A|Partition:0.4|NewJaccard:0.8102510047681404|OGJaccard:0.8023915778506886|NewDice:0.8945544849742543|OGDice:0.8897204055930629


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


65
Split:A|Partition:0.5|NewJaccard:0.9001473784446716|OGJaccard:0.8930019971096155|NewDice:0.947074476516608|OGDice:0.9432177724260272


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


23
Split:A|Partition:0.6|NewJaccard:0.8947329792109403|OGJaccard:0.8860039169138129|NewDice:0.944011056061947|OGDice:0.9390563513293411


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


35
Split:A|Partition:0.7|NewJaccard:0.8962158217574611|OGJaccard:0.9002582141847322|NewDice:0.9446930487950643|OGDice:0.9471398230754968


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


24
Split:B|Partition:0.1|NewJaccard:0.7772222865711559|OGJaccard:0.7807320591175195|NewDice:0.8726213303479281|OGDice:0.8752328861843456


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


49
Split:B|Partition:0.2|NewJaccard:0.8658977891459609|OGJaccard:0.8660176775672219|NewDice:0.927350969025583|OGDice:0.9273869937116449


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


58
Split:B|Partition:0.3|NewJaccard:0.8906751275062561|OGJaccard:0.8739764383344939|NewDice:0.9417863036646987|OGDice:0.9322500427563986


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


29
Split:B|Partition:0.4|NewJaccard:0.895634544618202|OGJaccard:0.8926929072900252|NewDice:0.9446266636703954|OGDice:0.9429937745585586


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


10
Split:B|Partition:0.5|NewJaccard:0.8729685274037448|OGJaccard:0.8398388136516918|NewDice:0.9317316828352032|OGDice:0.9125800186937506


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


17
Split:B|Partition:0.6|NewJaccard:0.8918473756674564|OGJaccard:0.8900907744060863|NewDice:0.9425324364141985|OGDice:0.9414529962973162


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


34
Split:B|Partition:0.7|NewJaccard:0.8891258113311998|OGJaccard:0.883630431059635|NewDice:0.940983280991063|OGDice:0.9378901647798943


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


118
Split:C|Partition:0.1|NewJaccard:0.8577379999738751|OGJaccard:0.8450681747812213|NewDice:0.9230032834139738|OGDice:0.9156237930962534


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


8
Split:C|Partition:0.2|NewJaccard:0.8757391102386244|OGJaccard:0.8666279045018283|NewDice:0.933282303087639|OGDice:0.928167310628024


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


36
Split:C|Partition:0.3|NewJaccard:0.8439733097047517|OGJaccard:0.8483899253787417|NewDice:0.9146057096394625|OGDice:0.9174740206111561


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


16
Split:C|Partition:0.4|NewJaccard:0.8823944655331698|OGJaccard:0.8869172443043102|NewDice:0.9371631958267905|OGDice:0.9397015138105913


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


46
Split:C|Partition:0.5|NewJaccard:0.8887879505301967|OGJaccard:0.8901004917693861|NewDice:0.940779239842386|OGDice:0.9413657585779825


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


23
Split:C|Partition:0.6|NewJaccard:0.8910806757031065|OGJaccard:0.8919210614580096|NewDice:0.9420594428524827|OGDice:0.942535111398408


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


36
Split:C|Partition:0.7|NewJaccard:0.8974887439698884|OGJaccard:0.8819074450117169|NewDice:0.9455887187610973|OGDice:0.9368810364694307


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


75
Split:D|Partition:0.1|NewJaccard:0.792255851355466|OGJaccard:0.7772848804791769|NewDice:0.8833158828995444|OGDice:0.8730095209497394


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


26
Split:D|Partition:0.2|NewJaccard:0.8768581206148321|OGJaccard:0.8722107916167288|NewDice:0.9339657696810636|OGDice:0.9309458588108872


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


40
Split:D|Partition:0.3|NewJaccard:0.8875501264225353|OGJaccard:0.8847688273950056|NewDice:0.9398159511161573|OGDice:0.9382852370088751


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


30
Split:D|Partition:0.4|NewJaccard:0.8720780195611896|OGJaccard:0.8799027150327509|NewDice:0.9309219392863187|OGDice:0.9354553186532223


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


8
Split:D|Partition:0.5|NewJaccard:0.8104832660068165|OGJaccard:0.8305145736896631|NewDice:0.8944107601136873|OGDice:0.9064425970568801


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


16
Split:D|Partition:0.6|NewJaccard:0.8649323293657014|OGJaccard:0.8316080479910879|NewDice:0.9270441297328833|OGDice:0.9074085499301101


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


62
Split:D|Partition:0.7|NewJaccard:0.893074833985531|OGJaccard:0.8313601269866481|NewDice:0.9431628801605918|OGDice:0.9072384111809008


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


161
Split:E|Partition:0.1|NewJaccard:0.8220972057544824|OGJaccard:0.7971197926636898|NewDice:0.9015412095821265|OGDice:0.8853532328750148


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


7
Split:E|Partition:0.2|NewJaccard:0.774028870192441|OGJaccard:0.7271572893316095|NewDice:0.870843069119887|OGDice:0.8401159517692797


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


33
Split:E|Partition:0.3|NewJaccard:0.7658198457775693|OGJaccard:0.8201130592461788|NewDice:0.8660906101718093|OGDice:0.8999443252881368


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


17
Split:E|Partition:0.4|NewJaccard:0.8022179585514646|OGJaccard:0.7868092385205355|NewDice:0.8882634332685759|OGDice:0.8782555460929871


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


48
Split:E|Partition:0.5|NewJaccard:0.8852056535807523|OGJaccard:0.8829007509982947|NewDice:0.9386203451590105|OGDice:0.9371938506762186


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


26
Split:E|Partition:0.6|NewJaccard:0.8837962746620178|OGJaccard:0.8637632167700565|NewDice:0.9380025086980878|OGDice:0.9265058889533534


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


38
Split:E|Partition:0.7|NewJaccard:0.8741676446163293|OGJaccard:0.8878839413324991|NewDice:0.932399798523296|OGDice:0.940051636912606
