In [1]:
!pip install torch
!pip install torchmetrics

[0m

# Imports
Packages required for experiment

In [2]:
# General use for DL
import torch
import numpy as np

# To loader dataset
from torch.utils.data import DataLoader

# Data Measurements
from torchmetrics import JaccardIndex
from torchmetrics import Dice

# Dataset import
from dataset import ChestXRaysSegmentationDataset

# Stopping Rule
When this function returns a value less than 0.1 % = 0.001 then we know to stop training

In [3]:
# Finds percent change between previous and current loss
# and if it is less than threshold return false
def stopping_rule(L_k, L_k1, threshold):
    return abs(L_k - L_k1) / L_k > threshold

In [4]:
# Calculates the moving average
def moving_avg(alpha, L_MA, L_k):
    return alpha * L_MA + (1-alpha) * L_k

# Training Function

In [5]:
def train_model(model, loss_fn, device, train_loader, optimizer):
    # Initalize loss
    average_loss = 0
    # Train on dataset
    model.train()
    for batch_idx, (X,y) in enumerate(train_loader):
        # Get batch
        image, mask = X.to(device), y.to(device)
        # Get results
        output = model(image)
        # Compute loss
        loss = loss_fn(output, mask)
        average_loss += loss.item()
        # Optimize model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Return average loss
    return average_loss / len(train_loader)

# Testing Function

In [6]:
def test_model(model, device, test_loader, jaccard, dice):
    # Initalize average jaccard and dice
    average_jaccard = 0
    average_dice = 0
    # Test the model
    model.eval()
    for batch_idx, (X,y) in enumerate(test_loader):
        # Get batch
        image, mask = X.to(device), y.to(device)
        # Get results
        output = model(image)
        average_jaccard += jaccard(torch.where(output > 0.5, 1, 0),torch.where(mask > 0.50, 1, 0)).item()
        average_dice += dice(torch.where(output > 0.5, 1, 0),torch.where(mask > 0.50, 1, 0)).item()
    # Get average of dice and jaccard scores
    average_jaccard /= len(test_loader)
    average_dice /= len(test_loader)

    # Return values
    return average_jaccard, average_dice

# Function to partition list
We will use this function to partition the list of train indices into eighths. If we remember the train indices consist of 80 % of the original dataset, so each partition will contain 10 % of the original dataset. 

In [7]:
def split_into_eights(list):
    # Floor division of length of list
    partition_size = len(list) // 8
    remainder  = len(list) % 8
    
    # List that will store each parition
    partitions = []
    
    # Partition the list, if partition is not even distrubutes 
    # remainder between beginning paritions
    start = 0
    for i in range(8):
        end = start + partition_size + (1 if i < remainder and i != 0 else 0)
        partitions.append(list[start:end])
        start = end
    return partitions  

# Load Data
Load the baseline data and indices splits from previous stage.

In [8]:
# File name
name = 'baseline.pt'
baseline = torch.load(f=name)
# Extract indices and baseline jaccard and dice scores
all_indices = baseline['fold_dict']
baseline_jaccard = baseline['baseline_jaccard']
baseline_dice = baseline['baseline_dice']


Load the dataset

In [9]:
# Path to images and masks
image_path = './chest_xray/images'
mask_path = './chest_xray/masks'
size = 256
# Import the dataset
dataset = ChestXRaysSegmentationDataset(image_path,mask_path,size)

100%|██████████| 704/704 [00:00<00:00, 2801508.55it/s]
100%|██████████| 704/704 [00:00<00:00, 2713961.41it/s]


Specify hyperparameters

In [10]:
## Preliminary variables ##

# Specifies whether to train on GPU or CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Loss for training
loss_fn = torch.nn.BCELoss()

# Measurements
jaccard = JaccardIndex(task='multiclass', num_classes = 2, average = 'micro').to(device)
dice = Dice(num_classes = 2, average = 'micro').to(device)

# Batch Size
BATCH_SIZE = 16

# Stopping threshold
threshold = 0.001

# Speeds up training
num_workers = 8

# Alpha for EMA
alpha = 0.9

In [11]:
splits = ['A', 'B', 'C', 'D', 'E']
partitions = [0,0.1,0.2,0.3,0.4,0.5,0.6,0.7]
# Cycle through all test splits
for split in splits:
    # Get indices of test and train points in dataset
    indices = all_indices[split]
    train_indices = indices[0]
    test_indices = indices[1]
    # Split train indices into eighths
    partition_indices = split_into_eights(train_indices)
    # Create test dataloader for fold
    test_loader = DataLoader(
        dataset=dataset,
        batch_size=BATCH_SIZE,
        sampler=torch.utils.data.SubsetRandomSampler(test_indices),
        num_workers = num_workers
    )
    # Cycle through all partition lengths of training data (0.1,0.2,...,0.7)
    for i in range(1, len(partition_indices)):
        
        # Initialize Jaccard and Dice
        average_jaccard = 0
        average_dice = 0
                
        train_indices_i = np.hstack(partition_indices[0:i])
        remaining_indices = np.hstack(partition_indices[i:])
        
        # Create a train dataloader for partition
        train_loader = DataLoader(
            dataset=dataset,
            batch_size = BATCH_SIZE,
            sampler=torch.utils.data.SubsetRandomSampler(train_indices_i),
            num_workers = num_workers
        )
        
        ## Initialize the model ##

        # Loading an untrained model to GPU/CPU
        model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
            in_channels=3, out_channels=1, init_features=64, pretrained=False, trust_repo=True).to(device)
        # We will begin our learning rate at 0.01 
        lr = 0.01
        # Optimizer for model
        optimizer = torch.optim.Adam(model.parameters(), lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,25)
        
        ## Initialize the training ##
    
        # Initialize previous and current loss for stopping rule
        
        L_MA = 1 # Moving average of loss
        L_k = 0 # Current loss
        
        # To determine in threshold is too low
        counter = 0
        # Train model until stopping rule is reached
        while(stopping_rule(L_MA, L_k, threshold) or counter < 10):
            # Train model and compute loss
            L_k = train_model(model, loss_fn, device, train_loader, optimizer)
            
            # Initialization of EMA
            if(L_MA == 0):
                L_MA = L_k
            
            # Find EMA of losses
            L_MA = moving_avg(alpha, L_MA, L_k)
            counter += 1
               
        # To determine in threshold is too low
        print(counter)
        # Test model on test split
        jaccard_score, dice_score = test_model(model, device, test_loader, jaccard, dice)
        print(f'Split:{split}|Partition:{partitions[i]}|Jaccard:{jaccard_score}|Dice:{dice_score}')
        # Save model, scores, and indices for next step
        name = f'./model/Split:{split}|Partition:{partitions[i]}|New'
        state = {
            'train_indices' : train_indices_i,
            'remaining_indices' : remaining_indices,
            'test_indices' : test_indices,
            'jaccard_score' : jaccard_score,
            'dice_score' : dice_score,
            'num_iterations' : counter,
            'state_dict' : model.state_dict(),
        }
        torch.save(state, f=name)
        
    

Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


70
Split:A|Partition:0.1|Jaccard:0.8689605924818251|Dice:0.9295347929000854


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


945
Split:A|Partition:0.2|Jaccard:0.9219139085875617|Dice:0.9592072632577684


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


344
Split:A|Partition:0.3|Jaccard:0.938945902718438|Dice:0.9684959914949205


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


692
Split:A|Partition:0.4|Jaccard:0.9364785022205777|Dice:0.9671646356582642


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


190
Split:A|Partition:0.5|Jaccard:0.9423675338427225|Dice:0.9703133967187669


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


612
Split:A|Partition:0.6|Jaccard:0.9524468845791287|Dice:0.9756412837240431


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


406
Split:A|Partition:0.7|Jaccard:0.9588484764099121|Dice:0.9789837797482809


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


593
Split:B|Partition:0.1|Jaccard:0.8715821570820279|Dice:0.9308970438109504


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


364
Split:B|Partition:0.2|Jaccard:0.9259426660007901|Dice:0.9614260064231025


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


192
Split:B|Partition:0.3|Jaccard:0.9342492487695482|Dice:0.9659856160481771


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


189
Split:B|Partition:0.4|Jaccard:0.9233459234237671|Dice:0.960111571682824


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


176
Split:B|Partition:0.5|Jaccard:0.9450196160210503|Dice:0.971705953280131


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


300
Split:B|Partition:0.6|Jaccard:0.9482164780298868|Dice:0.9734048247337341


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


156
Split:B|Partition:0.7|Jaccard:0.9581540889210172|Dice:0.978609475824568


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


248
Split:C|Partition:0.1|Jaccard:0.8733338382509019|Dice:0.932131807009379


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


2040
Split:C|Partition:0.2|Jaccard:0.9353870815700955|Dice:0.9665791193644205


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


337
Split:C|Partition:0.3|Jaccard:0.9347047209739685|Dice:0.9661864572101169


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


278
Split:C|Partition:0.4|Jaccard:0.9419408043225607|Dice:0.9700709184010824


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


466
Split:C|Partition:0.6|Jaccard:0.941914955774943|Dice:0.9700656533241272


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


161
Split:C|Partition:0.7|Jaccard:0.9586775369114346|Dice:0.9788919422361586


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


50
Split:D|Partition:0.1|Jaccard:0.8724015752474467|Dice:0.9316100676854452


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


103
Split:D|Partition:0.2|Jaccard:0.9172800845570035|Dice:0.9567976792653402


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


194
Split:D|Partition:0.3|Jaccard:0.9223662681049771|Dice:0.9595598777135214


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


225
Split:D|Partition:0.4|Jaccard:0.9278005162874857|Dice:0.9624797900517782


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


312
Split:D|Partition:0.5|Jaccard:0.942874981297387|Dice:0.970564497841729


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


331
Split:D|Partition:0.6|Jaccard:0.950145939985911|Dice:0.9744140108426412


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


208
Split:D|Partition:0.7|Jaccard:0.9586351977454292|Dice:0.9788732992278205


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


124
Split:E|Partition:0.1|Jaccard:0.8861263460583158|Dice:0.9393250147501627


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


156
Split:E|Partition:0.2|Jaccard:0.9101462033059862|Dice:0.9528409441312155


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


434
Split:E|Partition:0.3|Jaccard:0.9410359329647489|Dice:0.9695913195610046


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


609
Split:E|Partition:0.4|Jaccard:0.9479281769858466|Dice:0.9732448061307272


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


164
Split:E|Partition:0.5|Jaccard:0.9438536630736457|Dice:0.9711029595798917


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


286
Split:E|Partition:0.6|Jaccard:0.9448421663708158|Dice:0.9716153343518575


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


541
Split:E|Partition:0.7|Jaccard:0.9597532219356961|Dice:0.9794496960110135
