# U-Net Test: *Multi-class* CT Bone Segmentation

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import random

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torch.optim import lr_scheduler
from losses import FocalLoss, MixedLoss, dice, IoU
from model import UNet
from dataset import CTMulticlassDataset

from train import train_net, train_net_multiclass
from eval import eval_net
from predict import predict
from utils import (readFloatImage, readBinImage, readUCharImage, 
                   plotSomeImages, matchFilesFromPatient)

from datetime import datetime

## Define Datasets
Collect the CT images and their label masks for each patient and day, then split them into training and validation Datasets.

In [None]:
# Collect the matched CT / label-mask entries for every selected patient and
# day from the nested file system.
# NOTE(review): matchFilesFromPatient is project code; presumably each returned
# entry describes one CT image and its masks — confirm in utils.
patient_idxs = [1, 3, 4, 5]
ct_data = []
for idx in patient_idxs:
    for day_selection in range(1, 4):
        matched_data = matchFilesFromPatient(idx,
                                             day_selection,
                                             mode='CT_SPINE_STERNUM_PELVIS')
        ct_data.extend(matched_data)

# Shuffle before splitting. The split below is done on individual entries, not
# on patients, so entries from the same patient can land in both sets.
random.shuffle(ct_data)

# Training / validation sizes: validation is 20% of the training size.
train_set_size = 1000
val_set_size = int(0.2 * train_set_size)
train_data = ct_data[0:train_set_size]
val_data = ct_data[train_set_size:train_set_size + val_set_size]

# Fail early with a clear message instead of failing later on an empty
# validation set.
if not val_data:
    raise ValueError(
        'Only {} samples found; need more than train_set_size={} to build a '
        'validation set.'.format(len(ct_data), train_set_size))

# Augmentation is only applied to the training data.
train_dataset = CTMulticlassDataset(train_data, augment=True)
val_dataset = CTMulticlassDataset(val_data, augment=False)

batch_size = 3
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=1)

# Validation order does not affect the metrics, so keep it deterministic.
val_loader = DataLoader(val_dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=1)

## Dataset verification
Plot a random training sample to check that the dataset returns the expected CT image and mask channels.

In [None]:
# Draw one random training sample. train_dataset was built with augment=True,
# so the displayed sample may be augmented.
idx = np.random.choice(len(train_dataset))
sample_data = train_dataset[idx]

# Channel 0 of 'image' is the CT slice; the 'target' channels are read in the
# order spine, sternum, pelvis.
ct = sample_data['image'][0, :, :]
spine_mask, stern_mask, pelvi_mask = (sample_data['target'][channel, :, :]
                                      for channel in range(3))

# Show the CT slice and its three masks side by side (1 row x 4 columns).
images = dict(ct=ct,
              spine=spine_mask,
              sternum=stern_mask,
              pelvis=pelvi_mask)
plotSomeImages(images, 1, 4)

## Create datalog folder

In [None]:
# `os` is not imported in the notebook's import cell, so import it here.
import os

# Create a timestamped folder for this run's checkpoints,
# e.g. 'trainlog-2024.01.31-12.00.00'.
current_datetime = datetime.now().strftime('%Y.%m.%d-%H.%M.%S')
output_directory = 'trainlog-' + current_datetime
# exist_ok replaces the separate isdir check (and its check-then-create race).
os.makedirs(output_directory, exist_ok=True)

## Training initialization
Define the device, model, loss function, optimizer, learning-rate schedule, and other training hyperparameters.

In [None]:
# Use the GPU when requested AND available; otherwise fall back to the CPU
# instead of failing when the model is moved to a non-existent CUDA device.
use_cuda = True
device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")

# model hyperparameters
lr_model = 0.0003
decay_step_size = 500

# U-Net with a single input channel (CT) and one output channel per bone
# class (spine, sternum, pelvis).
model = UNet(n_channels=1,
             n_classes=3,
             large_model=True)

model.to(device)

# BCEWithLogitsLoss treats each output channel as an independent binary mask
# (multi-label), matching the per-channel sigmoid used at inference time.
criterion = nn.BCEWithLogitsLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=lr_model)

# Multiply the learning rate by 0.1 every `decay_step_size` scheduler steps.
# NOTE(review): whether a step is a batch or an epoch depends on where
# train_net_multiclass calls learning_scheduler.step(); with num_epochs=5 an
# epoch-level step would never reach the decay — confirm.
learning_scheduler = lr_scheduler.StepLR(optimizer,
                                         step_size=decay_step_size,
                                         gamma=0.1)
num_epochs = 5
ckpt_save_interval = 2

# loss log initialization
train_losses = []
val_losses = []
best_score = 0.0

## Training loop
Training loop: one training pass per epoch, with a model checkpoint saved after every epoch. Validation (and therefore best-model selection) is not yet implemented for the multi-class model.

In [None]:
for epoch in range(num_epochs):
    # Run one training epoch; the return value is kept as this epoch's
    # training loss.
    train_loss = train_net_multiclass(model,
                                      device,
                                      train_loader,
                                      batch_size,
                                      criterion,
                                      optimizer,
                                      learning_scheduler,
                                      epoch,
                                      print_log=True)
    # Record the per-epoch history so it can be plotted afterwards.
    train_losses.append(train_loss)

    # Save a checkpoint after every epoch.
    # NOTE(review): ckpt_save_interval is defined in the init cell but unused.
    torch.save(model.state_dict(),
               '{}/unet_model_{}.pt'.format(output_directory, epoch))

# TODO: no validation yet for the multi-class model, so val_losses stays empty
# and no "best" model is selected; training only minimizes the training loss.

## Training charts
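Plot the recorded per-epoch metrics. Note: validation is not yet implemented for the multi-class model, so no Dice/IoU scores are recorded yet.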

In [None]:
# Plot whatever per-epoch metrics have been recorded so far.
fig = plt.figure(figsize=(12, 8))

if train_losses:
    # float() accepts Python/NumPy scalars and one-element tensors alike.
    # NOTE(review): assumes train_net_multiclass returns a scalar loss — confirm.
    losses = [float(loss) for loss in train_losses]
    plt.plot(range(len(losses)), losses, label='train loss')

if val_losses:
    # Each val_losses entry is expected to be a (dice, iou) pair.
    dice_scores, iou_scores = zip(*val_losses)
    n = range(0, len(dice_scores))
    plt.plot(n, dice_scores, label='dice')
    plt.plot(n, iou_scores, label='iou')
else:
    # zip(*[]) yields nothing, so unpacking an empty val_losses would raise.
    print('No validation scores recorded yet; skipping dice/iou curves.')

plt.grid(True)
plt.legend()

## Inference test
Pick a random CT image from the validation dataset and run a prediction on it.

In [None]:
# Put the model in evaluation mode before predicting.
model.eval()

# Sample over the whole validation set: len(val_loader) is the number of
# batches, not samples, so using it would only reach the first part of the set.
ridx = random.randrange(len(val_dataset))
val_sample = val_dataset[ridx]

# Add a batch dimension; only the network input needs to be on `device`.
ct = val_sample['image'].unsqueeze(0).to(device)
mask = val_sample['target'].unsqueeze(0).cpu()

with torch.no_grad():
    # Single forward pass without gradient tracking. The sigmoid turns the
    # logits into per-channel probabilities (training used BCEWithLogitsLoss).
    pred = torch.sigmoid(model(ct)).cpu().numpy()

# Output channels: 0 = spine, 1 = sternum, 2 = pelvis.
pred_spine = pred[0, 0, :, :]
pred_stern = pred[0, 1, :, :]
pred_pelvi = pred[0, 2, :, :]

# Ground truth in the top row, predictions in the bottom row (2 x 3 grid).
images = {'true_spine': mask[0, 0, :, :],
          'true_stern': mask[0, 1, :, :],
          'true_pelvi': mask[0, 2, :, :],
          'pred_spine': pred_spine,
          'pred_stern': pred_stern,
          'pred_pelvi': pred_pelvi
         }
plotSomeImages(images, 2, 3)