In [1]:
!pip install torch
!pip install torchmetrics

[0m

# Imports
Packages required for experiment

In [2]:
# General use for DL
import torch
import numpy as np

# To load the dataset in batches
from torch.utils.data import DataLoader

# Data Measurements
from torchmetrics import JaccardIndex
from torchmetrics import Dice

# Dataset import
from dataset import BUIDSegmentationDataset
from base_dataset import BaseDataset

# Stopping Rule
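Training stops once the relative change between the moving-average loss and the current loss is no longer above the threshold of 0.1 % (0.001); at that point `stopping_rule` returns `False`.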

In [3]:
# Finds the relative change between the previous and the current loss
# and returns False once it is no longer above the threshold
def stopping_rule(L_k, L_k1, threshold):
    """Tell the caller whether training should continue.

    Args:
        L_k: Reference loss (previous or moving-average loss); the change is
            measured relative to its magnitude.
        L_k1: Current loss.
        threshold: Relative change at or below which training should stop.

    Returns:
        True while the relative change is above ``threshold`` (keep training),
        False otherwise (stop).
    """
    change = abs(L_k - L_k1)
    scale = abs(L_k)
    if scale == 0:
        # The relative change is undefined for a zero reference: treat any
        # movement away from zero as "still changing" and no movement as
        # converged, instead of raising ZeroDivisionError.
        return change > 0
    # Normalise by the magnitude so a negative reference loss does not flip
    # the sign of the ratio and end training prematurely.
    return change / scale > threshold

In [4]:
# Exponential moving average update of the loss
def moving_avg(alpha, L_MA, L_k):
    """Blend the running average with the newest loss.

    Args:
        alpha: Weight kept on the previous average.
        L_MA: Previous moving average.
        L_k: Newest loss value.

    Returns:
        ``alpha * L_MA + (1 - alpha) * L_k``.
    """
    carried_over = alpha * L_MA
    newly_added = (1-alpha) * L_k
    return carried_over + newly_added

# Training Function

In [5]:
def train_model(model, loss_fn, device, train_loader, optimizer):
    """Run one pass over ``train_loader`` and return the mean batch loss.

    Args:
        model: Network to optimise; switched to training mode.
        loss_fn: Callable taking ``(output, target)`` and returning a scalar loss.
        device: Device every batch is moved to before the forward pass.
        train_loader: Sized iterable yielding ``(X, y)`` batches.
        optimizer: Optimizer that updates ``model`` once per batch.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    model.train()
    running_loss = 0
    for inputs, targets in train_loader:
        # Move the batch to the training device
        inputs, targets = inputs.to(device), targets.to(device)
        # Forward pass and loss
        predictions = model(inputs)
        batch_loss = loss_fn(predictions, targets)
        running_loss += batch_loss.item()
        # Backward pass and parameter update
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
    # Mean over batches
    return running_loss / len(train_loader)

Specify hyperparameters

In [6]:
## Preliminary variables ##

# Mini-batch size used by the training and test loaders
BATCH_SIZE = 16

# Relative loss change (0.1 %) below which the stopping rule ends training
threshold = 0.001

# Weight kept on the previous value in the moving average of the loss
alpha = 0.9

# Number of worker processes handed to each DataLoader
num_workers = 8

# Train on the GPU when one is available, otherwise on the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Binary cross-entropy loss used for training
loss_fn = torch.nn.BCELoss()

# Evaluation metrics, moved to the device chosen above
jaccard = JaccardIndex(task='multiclass', num_classes = 2, average = 'micro').to(device)
dice = Dice(num_classes = 2, average = 'micro').to(device)

# Get Masks
Function to extract the images and masks at the given indices

In [7]:
def get_masks(indices, dataset):
    """Collect the images and masks stored at ``indices`` in ``dataset``.

    Args:
        indices: Iterable of keys accepted by ``dataset[...]``.
        dataset: Indexable object whose items unpack into ``(image, mask)``.

    Returns:
        Tuple ``(images, masks)`` of two lists, both in the order of ``indices``.
    """
    images = []
    masks = []
    for i in indices:
        # Plain subscripting instead of calling the dunder method by hand
        image, mask = dataset[i]
        images.append(image)
        masks.append(mask)
    return images, masks
    

# Create Masks
Function to create masks of given images

In [8]:
def create_masks(model, device, loader):
    """Run ``model`` over ``loader`` and collect its inputs and predictions.

    Args:
        model: Network called on every input batch; switched to eval mode.
        device: Device the input batches are moved to before inference.
        loader: Iterable yielding ``(X, y)`` batches. ``y`` is ignored. The
            leading axis is only squeezed away when it has size 1, so the
            loader is expected to use a batch size of 1.

    Returns:
        Tuple ``(images, masks)``: two lists of CPU tensors, one entry per
        batch, holding the input and the raw model output respectively.
    """
    # Initialize
    images = []
    masks = []
    # Create masks
    model.eval()
    # Inference only: do not record an autograd graph for the forward passes
    with torch.no_grad():
        for X, _ in loader:
            # Only the input is needed; the loader's mask is not used here
            image = X.to(device)
            # Get results
            output = model(image)
            # Move to the CPU and squeeze the batch(1) dimension
            masks.append(output.detach().cpu().squeeze(0))
            images.append(image.detach().cpu().squeeze(0))
    # Inputs and their segmented masks
    return images, masks

# Testing Function

In [9]:
def test_model(model, device, test_loader, jaccard, dice):
    # Initalize average jaccard and dice
    average_jaccard = 0
    average_dice = 0
    # Test the model
    model.eval()
    for batch_idx, (X,y) in enumerate(test_loader):
        # Get batch
        image, mask = X.to(device), y.to(device)
        # Get results
        output = model(image)
        average_jaccard += jaccard(torch.where(output > 0.5, 1, 0),torch.where(mask > 0.50, 1, 0)).item()
        average_dice += dice(torch.where(output > 0.5, 1, 0),torch.where(mask > 0.50, 1, 0)).item()
    # Get average of dice and jaccard scores
    average_jaccard /= len(test_loader)
    average_dice /= len(test_loader)

    # Return values
    return average_jaccard, average_dice

# Data Processing
Importing the dataset into the Jupyter Notebook environment for use

In [10]:
# Directory holding the images and masks, relative to the notebook's working directory
root_dir = './BUID'
# U-Net we are using takes 3 x 256 x 256 images
# NOTE(review): presumably the dataset resizes every sample to size x size - confirm in dataset.py
size = 256
# Build the segmentation dataset; it is indexed as (image, mask) pairs by the cells below
dataset = BUIDSegmentationDataset(root_dir, size)


# Load data
Load the baseline data and indices splits from previous stage.

In [11]:
# File written by the previous stage, resolved relative to the working directory
name = 'baseline.pt'
# NOTE(review): torch.load unpickles the file, so only load files from a trusted source
baseline = torch.load(f=name)
# Extract the stored fold indices and the baseline jaccard and dice scores
# NOTE(review): the exact structure of 'fold_dict' is not visible here - confirm against the stage that wrote it
all_indices = baseline['fold_dict']
baseline_jaccard = baseline['baseline_jaccard']
baseline_dice = baseline['baseline_dice']

# Create new dataset
Using the model trained on a partition, we will create masks for the remaining data from the train partition.

In [None]:
splits = ['A', 'B', 'C', 'D', 'E']
partitions = [0.1,0.2,0.3,0.4,0.5,0.6,0.7]

# For every (split, partition) checkpoint:
#   1. rebuild the U-Net and restore the saved weights,
#   2. let it predict masks for the remaining (not yet used) train images,
#   3. fine-tune it on those predicted masks plus the ground-truth masks of
#      the images it was originally trained on,
#   4. score it on the test indices and save the result together with the
#      original scores.
for split in splits:
    for partition in partitions:
        # Load the stored data
        name = f'./model/Split:{split}|Partition:{partition}|New'
        # map_location keeps a checkpoint saved from GPU tensors loadable when
        # `device` has fallen back to 'cpu'.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        data = torch.load(name, map_location=device)
        # Extract data
        train_indices = data['train_indices']
        remaining_indices = data['remaining_indices']
        test_indices = data['test_indices']
        orig_jaccard_score = data['jaccard_score']
        orig_dice_score = data['dice_score']
        orig_iterations = data['num_iterations']
        state_dict = data['state_dict']
        # Load base model (3 input channels, 1 output channel, no pretrained weights)
        trained_model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
            in_channels=3, out_channels=1, init_features=64, pretrained=False, trust_repo=True).to(device)
        # Load saved model
        trained_model.load_state_dict(state_dict)

        # Create new dataset
        new_dataset = []
        # Loader over the remaining indices; batch_size must stay 1 because
        # create_masks squeezes the batch dimension of every output
        remaining_loader = DataLoader(
            dataset=dataset,
            batch_size=1,
            sampler=torch.utils.data.SubsetRandomSampler(remaining_indices),
            num_workers = num_workers
        )
        # Create masks of remaining data
        images, new_masks = create_masks(trained_model, device, remaining_loader)
        # Create dataset with new masks
        new_dataset.append(BaseDataset(images, new_masks))
        # Get images and masks used to train saved model
        base_images, base_masks = get_masks(train_indices,dataset)
        # Create dataset with ground truth masks and images
        new_dataset.append(BaseDataset(base_images, base_masks))
        # Concatenate the two so we have a dataset with generated masks and
        # ground-truth masks; this will be our train dataset
        train_dataset = torch.utils.data.ConcatDataset(new_dataset)

        # Create train loader for new dataset
        train_loader = DataLoader(
            dataset=train_dataset,
            batch_size = BATCH_SIZE,
            shuffle = True,
            num_workers = num_workers,
        )

        ## Initialize the model ##

        # We will begin our learning rate at 0.01
        lr = 0.01
        # Optimizer for model
        optimizer = torch.optim.Adam(trained_model.parameters(), lr)
        # NOTE(review): scheduler.step() is never called below, so the learning
        # rate stays at `lr` for the whole run. Either step it once per epoch or
        # drop it; left untouched here because stepping it changes the results.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,100)

        ## Initialize the training ##

        # Initialize moving-average and current loss for the stopping rule
        L_MA = 1 # Moving average of loss
        L_k = 0 # Current loss

        # Number of completed epochs; also shows whether the threshold is too low
        counter = 0
        # Train until the stopping rule is reached, but for at least 10 epochs
        while(stopping_rule(L_MA, L_k, threshold) or counter < 10):
            # Train model and compute loss
            L_k = train_model(trained_model, loss_fn, device, train_loader, optimizer)

            # Initialization
            # NOTE(review): L_MA starts at 1, so this branch never runs and the
            # moving average is seeded with 1 instead of the first loss. Seeding
            # it properly would change when training stops - confirm intent.
            if(L_MA == 0):
                L_MA = L_k

            # Find EMA of losses
            L_MA = moving_avg(alpha, L_MA, L_k)
            counter += 1

        # Report how many epochs were needed
        print(counter)

        # Get test dataset
        test_loader = DataLoader(
                dataset=dataset,
                batch_size=BATCH_SIZE,
                sampler=torch.utils.data.SubsetRandomSampler(test_indices),
                num_workers = num_workers,
        )

        # Test model on test split
        jaccard_score, dice_score = test_model(trained_model, device, test_loader, jaccard, dice)
        print(f'Split:{split}|Partition:{partition}|NewJaccard:{jaccard_score}|OGJaccard:{orig_jaccard_score}|NewDice:{dice_score}|OGDice:{orig_dice_score}')
        name = f'./new_models/Split:{split}|Partition:{partition}|New'
        # Save the fine-tuned weights and the new scores next to the original ones
        state = {
            'state_dict' : trained_model.state_dict(),
            'jaccard_score' : jaccard_score,
            'dice_score' : dice_score,
            'num_iterations' : counter,
            'orig_jaccard_score' : orig_jaccard_score,
            'orig_dice_score' : orig_dice_score,
            'orig_iterations' : orig_iterations,
        }
        torch.save(state,f=name)
        
        
        

Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


95
Split:A|Partition:0.1|NewJaccard:0.8672114312648773|OGJaccard:0.8460275053977966|NewDice:0.9286053955554963|OGDice:0.9164737701416016


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


104
Split:A|Partition:0.2|NewJaccard:0.8734017372131347|OGJaccard:0.8686314165592194|NewDice:0.9320660591125488|OGDice:0.9295298278331756


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


104
Split:A|Partition:0.3|NewJaccard:0.900985598564148|OGJaccard:0.8930659413337707|NewDice:0.9477347075939179|OGDice:0.9431972205638885


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


180
Split:A|Partition:0.4|NewJaccard:0.8998079001903534|OGJaccard:0.8934301137924194|NewDice:0.9470205962657928|OGDice:0.9433905601501464


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


53
Split:A|Partition:0.5|NewJaccard:0.9061295866966248|OGJaccard:0.9102159023284913|NewDice:0.9504400253295898|OGDice:0.952603530883789


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


176
Split:A|Partition:0.6|NewJaccard:0.9086047232151031|OGJaccard:0.9109900891780853|NewDice:0.9519734084606171|OGDice:0.9532379150390625


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


82
Split:A|Partition:0.7|NewJaccard:0.9029858469963074|OGJaccard:0.9092088997364044|NewDice:0.9484447777271271|OGDice:0.9521798133850098


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


166
Split:B|Partition:0.1|NewJaccard:0.8753970265388489|OGJaccard:0.86333127617836|NewDice:0.9334462463855744|OGDice:0.9264231383800506


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


139
Split:B|Partition:0.2|NewJaccard:0.9019554555416107|OGJaccard:0.891900897026062|NewDice:0.9482815444469452|OGDice:0.9427349388599395


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


159
Split:B|Partition:0.3|NewJaccard:0.9202809631824493|OGJaccard:0.9102275788784027|NewDice:0.9582464218139648|OGDice:0.9529102981090546


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


295
Split:B|Partition:0.4|NewJaccard:0.9361911892890931|OGJaccard:0.9310519218444824|NewDice:0.9669740974903107|OGDice:0.9641994774341583


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


178
Split:B|Partition:0.5|NewJaccard:0.9379238665103913|OGJaccard:0.9317119181156158|NewDice:0.967889529466629|OGDice:0.9645150184631348


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


149
Split:B|Partition:0.6|NewJaccard:0.9341915428638459|OGJaccard:0.935376501083374|NewDice:0.9658273041248322|OGDice:0.9665213882923126


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


181
Split:B|Partition:0.7|NewJaccard:0.9380175173282623|OGJaccard:0.9324437916278839|NewDice:0.9678819954395295|OGDice:0.9649609863758087


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


135
Split:C|Partition:0.1|NewJaccard:0.8638936340808868|OGJaccard:0.852782028913498|NewDice:0.9268801033496856|OGDice:0.9202640533447266


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


127
Split:C|Partition:0.2|NewJaccard:0.8696028411388397|OGJaccard:0.8524227738380432|NewDice:0.9299839973449707|OGDice:0.9199829399585724


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


161
Split:C|Partition:0.3|NewJaccard:0.8815026879310608|OGJaccard:0.8639395058155059|NewDice:0.936619633436203|OGDice:0.9265961349010468


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


199
Split:C|Partition:0.4|NewJaccard:0.8918827056884766|OGJaccard:0.8869680285453796|NewDice:0.9427349090576171|OGDice:0.9395377814769745


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


105
Split:C|Partition:0.5|NewJaccard:0.8981361389160156|OGJaccard:0.895977133512497|NewDice:0.9460402190685272|OGDice:0.9450104057788848


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


189
Split:C|Partition:0.6|NewJaccard:0.9072723925113678|OGJaccard:0.904334819316864|NewDice:0.951031494140625|OGDice:0.9494717597961426


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


137
Split:C|Partition:0.7|NewJaccard:0.902508270740509|OGJaccard:0.9055923461914063|NewDice:0.9487021446228028|OGDice:0.9503263473510742


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master


99
Split:D|Partition:0.1|NewJaccard:0.8615512669086456|OGJaccard:0.8586161732673645|NewDice:0.9252488434314727|OGDice:0.9236314773559571


Using cache found in /root/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
