# Training of the Ostia Detector Model

In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from datetime import datetime
from utils import *
from math import *
from OstiaDetector import OstiaDetector
from scipy.stats import multivariate_normal

In [2]:
print(f"CUDA Available: {torch.cuda.is_available()}")
print(f"CUDA Device Count: {torch.cuda.device_count()}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

CUDA Available: True
CUDA Device Count: 1


## Load the data and define the training parameters

In [3]:
# Define the directories containing the files
image_directory = "/data/training_data/"
label_directory = "/data/training_data/"
graph_directory = "/data/training_data/"

# List all the NIfTI files in the directory
image_files = [os.path.join(image_directory, f) for f in os.listdir(image_directory) if f.endswith('img.nii.gz')]
label_files = [os.path.join(label_directory, f) for f in os.listdir(label_directory) if f.endswith('label.nii.gz')]
graph_files = [os.path.join(graph_directory, f) for f in os.listdir(graph_directory) if f.endswith('graph.json')]

In [4]:
# Create a list of dictionnaries with the data, exclude samples with alleged incorrect number of nodes in their ground-truth graphs
exclude = ["102624", "27badc", "0604cd", "26d228", "04222e", "22cdd3", "29db0c", "0b9189", "1a42d5", "1b3c33", "2f74d6", "108c99", 
           "0cce0b", "064c3e", "234666", "28955b", "0a5b04", "107165", "23c591", "0b06d2","089ee1", "18beb4", "1f594d"]
data = []

for i in image_files:
    if i.split(image_directory)[1][:6] in exclude:
        continue
    for l in label_files:
        for g in graph_files:
            if i.split(image_directory)[1][:6] == l.split(label_directory)[1][:6] == g.split(graph_directory)[1][:6]: # identify the sample
                data += [{'image': i, 'label': l, 'graph': g}]
                break

In [5]:
training, hold_out = data[:-15], data[-15:]

In [6]:
# Training parameters
epochs = 20
batch_size = 32
batches_per_epoch = len(training)
total_iterations = epochs * batch_size * batches_per_epoch
print("epochs: {}\nbatch size: {} \nbatches per epoch: {}\ntotal iterations: {}".format(
    epochs, batch_size, batches_per_epoch, total_iterations))

epochs: 20
batch size: 32 
batches per epoch: 162
total iterations: 103680


In [7]:
initial_lrate = 1e-1
ostia_model = OstiaDetector().to(device)
optimizer = Adam(ostia_model.parameters(), lr=initial_lrate)
loss_fn = nn.MSELoss()

In [8]:
def step_decay(epoch):
    epochs_drop = 10000 / batch_size / batches_per_epoch
    drop = 0.1
    
    lrate = initial_lrate * pow(drop, floor((1 + epoch) / (5 * epochs_drop)))
    return lrate

In [9]:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

## Training functions

In [10]:
# Preprocessing of data, compute reference proximity values for patches centered around ostia points and random points (negative samples)
def get_batch(id_elem):
    input_data, d_ostia = [], []
    
    # Constants (from Coronary Artery Centerline Extraction in Cardiac CT Angiography Using a CNN-Based Orientation Classifier)
    a = 6
    dm_ostium = 16
    
    # Fraction of patches centered around ostia points
    p = 7/10 
    
    # Get one data sample and retrieve the coronary ostia and the image
    sample = data[id_elem]
    image, ostia = load_seeds(sample['image'], sample['graph'])
    
    # List of 0 and 1 at random positions
    idx_references = np.random.randint(len(ostia), size=int(p * batch_size))
    
    # Add samples centered around ostia points
    for idx in idx_references:
        point = multivariate_normal.rvs(ostia[idx], cov=16) # add gaussian noise to the coordinates of the chosen ostium 
        patch, bounds = segment_image(image, point)

        if patch.shape == (19, 19, 19):
            d_ostium = 0
            dc_ostium = dm_ostium
            
            # Get the distance of the patch center to the closest ostium point
            for ostium in ostia:
                if np.all((ostium >= bounds[0]) & (ostium < bounds[1]), axis=0): # ensure the ostium point is within the patch bounds
                    if calculate_distance(point, ostium) < dc_ostium:
                        dc_ostium = calculate_distance(point, ostium)

            if dc_ostium < dm_ostium: # if the distance is below the defined threshold
                d_ostium = np.exp(a * (1 - (dc_ostium/dm_ostium))) - 1
                
            input_data.append(patch) 
            d_ostia.append(d_ostium)
                    
    # Add negative samples
    while len(input_data) < batch_size:
        point = np.random.randint(image.shape)
        patch, _ = segment_image(image, point)
        j = 0
        
        if patch.shape == (19, 19, 19):
            # Ensure the random point is not close to any ostium point
            for ostium in ostia:
                if calculate_distance(point, ostium) > dm_ostium:
                    j += 1
                    
            if j == len(ostia):
                input_data.append(patch)
                d_ostia.append(0)
                    
    # Formatting (converting into arrays before converting into tensors reduces the execution time)
    input_data = torch.tensor(np.array(input_data), dtype=torch.float32).reshape(-1, 1, 19, 19, 19)
    d_ostia = torch.tensor(np.array(d_ostia), dtype=torch.float32).reshape(-1, 1)

    return input_data, d_ostia

In [11]:
def train_one_batch(id_elem):
    running_loss = 0.
    
    X, Y = get_batch(id_elem)

    X, Y = X.to(device), Y.to(device)
    
    # Put gradients to zero for every batch
    optimizer.zero_grad()
    
    # Make predictions for this batch
    outputs = ostia_model(X)

    # Compute the loss and its gradients 
    loss = loss_fn(Y, outputs)
    loss.backward()

    # Adjust learning weights
    optimizer.step()

    running_loss += loss.item()
            
    return running_loss

## Training loop

In [12]:
epoch_number = 0
train_losses = np.empty(0)
validation_losses = np.empty(0)

for epoch in range(epochs):
    print('EPOCH {}:'.format(epoch_number + 1))
    
    # Make sure gradient tracking is on, and do a pass over the data
    ostia_model.train(True)
    
    running_loss = 0.
    running_long_loss = 0.
    
    id_elems = np.arange(len(training))
    np.random.shuffle(id_elems)
                         
    for i, id_elem in enumerate(id_elems):
        loss = train_one_batch(id_elem)
        running_loss += loss
        running_long_loss += loss
    
        if (i+1) % 20 == 0:
            avg_loss = running_loss / 20
            print('  batch {} loss: {}'.format(i + 1, avg_loss))
            train_losses = np.append(train_losses, avg_loss)
            running_loss = 0.
            
        if (i+1) % 100 == 0:
            avg_long_loss = running_long_loss / 100
            print('\n  BATCH100 loss: {}\n'.format(avg_long_loss))
            running_long_loss = 0.
            
    new_lrate = step_decay(epoch_number)
    optimizer.param_groups[0]['lr'] = new_lrate
    
    # We don't need gradients on to do reporting and testing
    ostia_model.train(False)
    
    print('TRAIN loss {}'.format(avg_loss))
    
    running_vloss = 0.
    id_elems = np.arange(len(training), len(data))
                         
    for id_elem in id_elems:
        vloss = train_one_batch(id_elem)
        running_vloss += vloss
    
    avg_vloss = running_vloss / len(id_elems)
    
    validation_losses = np.append(validation_losses, avg_vloss)
    
    print('TEST loss {}'.format(avg_vloss))
    
    epoch_number += 1   
    model_path = './models/ostia_model_{}_{}'.format(timestamp, epoch_number)
    torch.save(ostia_model.state_dict(), model_path)

EPOCH 1:
  batch 20 loss: 2593.609997558594
  batch 40 loss: 1846.0680358886718
  batch 60 loss: 1959.745669555664
  batch 80 loss: 1373.7342910766602
  batch 100 loss: 1464.800830078125

  BATCH100 loss: 1847.591764831543

  batch 120 loss: 1372.1501510620117
  batch 140 loss: 1074.2879638671875
  batch 160 loss: 1489.5362518310546
TRAIN loss 1489.5362518310546
TEST loss 1676.6760782877604
EPOCH 2:
  batch 20 loss: 1325.3512878417969
  batch 40 loss: 1199.1827606201173
  batch 60 loss: 1208.3746337890625
  batch 80 loss: 1168.408364868164
  batch 100 loss: 1214.1961196899415

  BATCH100 loss: 1223.1026333618165

  batch 120 loss: 1174.4673583984375
  batch 140 loss: 1021.461328125
  batch 160 loss: 896.9181655883789
TRAIN loss 896.9181655883789
TEST loss 1411.5084167480468
EPOCH 3:
  batch 20 loss: 980.699518585205
  batch 40 loss: 1428.490330505371
  batch 60 loss: 948.400991821289
  batch 80 loss: 1194.651988220215
  batch 100 loss: 1179.110333251953

  BATCH100 loss: 1146.270632476

In [13]:
np.savez("losses.npz", arr1=train_losses, arr2=validation_losses)