#  U-Net Test: CT Bone Segmentation

In [None]:
from PIL import Image
import numpy as np
import scipy.io as scio
import glob
import torch

import torch.nn as nn
import torch.nn.functional as F
import os
import random
from tqdm import tqdm

from torch.utils.data import DataLoader
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt

from utils import match_files_from_patient, readBinImage, readUCharImage
from dataset import CTMaskDataset, CTPTMaskDataset
from model import UNet
from train import train_net, test_net
from losses import FocalLoss, MixedLoss, dice, IoU

import imgaug as iaa
from datetime import datetime
import pickle

## Define Datasets
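Collect the CT/PT image files and spine label masks for each patient and day. Split them into train/dev/test sets (test patients are held out entirely) and wrap them in Datasets and DataLoaders.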

In [None]:
def _collect_patient_data(patient_idxs, days=range(1, 4), mode='CT_PT_SPINE'):
    """Return the matched file records for every (patient, day) pair.

    Concatenates the results of ``match_files_from_patient`` over
    ``patient_idxs`` x ``days``, in iteration order.
    """
    records = []
    for patient_idx in patient_idxs:
        for day in days:
            records.extend(match_files_from_patient(patient_idx, day, mode=mode))
    return records


train_patient_idxs = [1, 2, 4, 5]
train_dev_data = _collect_patient_data(train_patient_idxs)

test_patient_idxs = [3, 6]
test_data = _collect_patient_data(test_patient_idxs)

# Hold out 10% of the train/dev pool for validation. The shuffle is seeded with
# `seed` so the split is reproducible between runs. Note that dev records come
# from the same patients as train; only the test patients are fully held out.
seed = 544
K = int(0.1 * len(train_dev_data))
random.Random(seed).shuffle(train_dev_data)
dev_data = train_dev_data[:K]
train_data = train_dev_data[K:]
print('train: {}, dev: {}, test: {}'.format(len(train_data), len(dev_data), len(test_data)))

batch_size = 2

train_dataset = CTPTMaskDataset(train_data)
train_generator = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=6)
dev_dataset = CTPTMaskDataset(dev_data, offset=(0,0), output_size=(512, 512), augment=False)
# Evaluation does not need a random sample order; a fixed order keeps
# validation runs comparable from epoch to epoch.
dev_generator = DataLoader(dev_dataset, batch_size=batch_size, shuffle=False, num_workers=6)

## Dataset verification
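Display a random training sample (the CT channel with the PT channel overlaid, and the PT channel alone) and print its shapes, to check that the dataset returns correct data.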

In [None]:
# Pick a random training sample and show its CT channel (with the PT channel
# overlaid) next to its PT channel.
# NOTE(review): train_dataset is built without augment=False (unlike dev), so it
# presumably augments by default — repeated runs on the same idx may differ.
idx = np.random.choice(len(train_dataset))
sample_data = train_dataset[idx]

ct_pt = sample_data['data']          # channel-first: [0] = CT, [1] = PT
ct, pt = ct_pt[0, :, :], ct_pt[1, :, :]
mask = sample_data['label']
print(idx, ct_pt.shape, mask.shape)

# plt.subplots creates a pyplot-managed figure, so figsize is honored.
# (plt.Figure(...) builds a detached figure that plt.subplot never draws on.)
fig, (ax, bx) = plt.subplots(1, 2, figsize=(20, 10))
ax.imshow(ct)
ax.imshow(pt, alpha=0.3, cmap='jet')
ax.set_title('CT + PT overlay')
bx.imshow(pt)
bx.set_title('PT')
plt.show()

## Create datalog folder

In [None]:

# Create a timestamped directory for this run's checkpoints and outputs.
# exist_ok=True replaces a separate isdir() check (no check-then-create race).
current_datetime = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
output_directory = 'logs-' + current_datetime
os.makedirs(output_directory, exist_ok=True)

## Training initialization
Define the compute device, hyperparameters, model, loss criterion, optimizer, and learning-rate scheduler.

In [None]:
# Use the GPU when requested AND available; otherwise fall back to CPU instead
# of failing later at model.to(device).
use_cuda = True
device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")

# model hyperparameters
lr_model = 0.0003
# StepLR period. NOTE(review): whether this counts batches or epochs depends on
# where train_net calls scheduler.step(); if per epoch, 500 > num_epochs (200)
# and the learning rate would never decay — confirm.
decay_step_size = 500
# Passed to MixedLoss (presumably the focal-loss exponent); unrelated to the
# StepLR decay factor gamma=0.1 below.
gamma = 2.0
focal_gain = 10.0

# declare model: 2 input channels (CT, PT) -> 1 output channel
channel_in, num_classes = 2, 1
model = UNet(channel_in, num_classes)
model.to(device)

# learning schema
criterion = MixedLoss(focal_gain, gamma)
optimizer = torch.optim.Adam(model.parameters(), lr=lr_model)
learning_scheduler = lr_scheduler.StepLR(optimizer, step_size=decay_step_size, gamma=0.1)

num_epochs = 200
# NOTE(review): not referenced by the training loop in this notebook.
ckpt_save_interval = 2

# loss log initialization
train_losses = []
val_losses = []
# Best validation score so far (first element returned by test_net; higher is better).
best_score = 0.0


## Training loop
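Train for `num_epochs` epochs, run validation after each epoch, and save a checkpoint whenever the validation score improves.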

In [None]:
for epoch in range(num_epochs):
    # Training pass over train_generator; the scheduler is handed to train_net.
    train_loss = train_net(
        model, device, train_generator, batch_size,
        criterion, optimizer, learning_scheduler, epoch,
        log_interval=100, print_log=True,
    )

    # Validation pass. The first element of the returned record is the
    # model-selection score (higher is better); the charts cell unpacks each
    # record as (dice, iou).
    val_loss = test_net(model, device, dev_generator, print_log=True)

    # Write a checkpoint for every epoch that beats the best score so far.
    # NOTE(review): ckpt_save_interval (initialization cell) is not used here.
    if best_score < val_loss[0]:
        best_score = val_loss[0]
        torch.save(model.state_dict(), f'{output_directory}/ckpt.model-{epoch}.pt')

    # Keep the per-epoch records for plotting.
    train_losses.append(train_loss)
    val_losses.append(val_loss)

## Training charts
Plot the per-epoch validation Dice and IoU scores collected during training. (Not yet verified to work.)

In [None]:
# Plot the per-epoch validation scores recorded by the training loop.
# Each val_losses entry is unpacked as a (dice, iou) pair.
if not val_losses:
    # zip(*[]) would raise on unpacking; report instead of crashing.
    print('No validation scores recorded yet; run the training loop first.')
else:
    dice_scores, iou_scores = zip(*val_losses)
    n = range(len(dice_scores))
    fig = plt.figure(figsize=(12, 8))
    plt.plot(n, dice_scores, label='dice')
    plt.plot(n, iou_scores, label='iou')
    plt.xlabel('epoch')
    plt.ylabel('validation score')
    plt.grid(True)  # documented boolean form (instead of the string 'on')
    plt.legend()
    plt.show()

## Inference test
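Run a saved checkpoint over every slice of one patient/day, then save the predicted masks and the input images to a `.mat` file. (Reported as not working; verify before relying on it.)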

In [None]:
# Run a trained checkpoint over one patient/day and save the predicted
# probability maps, together with the network inputs, to a .mat file.
# NOTE(review): ckpt_path is relative to the working directory, not output_directory.
ckpt_path = 'ckpt.model-87.pt'

# Fall back to CPU when CUDA is requested but unavailable.
use_cuda = True
device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")
model = UNet(2, 1)
model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
model.to(device)
model.eval()

patient_idx = 9
day_selection = 1

# Same file matcher as the train/dev/test cells (get_ct_pt_mask_data is not
# defined anywhere in this notebook).
patient_data = match_files_from_patient(patient_idx, day_selection, mode='CT_PT_SPINE')
test_dataset = CTPTMaskDataset(patient_data, offset=(96, 96), output_size=(512, 512), augment=False)
# shuffle=False keeps batches in dataset order, so each slice is written to its own index below.
test_generator = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=10)

num_slices = len(test_dataset)
spine_masks = np.zeros((num_slices, 512, 512))
ct_images = np.zeros((num_slices, 2, 512, 512))
with torch.no_grad():
    start = 0
    for batch in test_generator:
        batch_data = batch['data']
        stop = start + batch_data.shape[0]
        ct_images[start:stop] = batch_data.numpy()

        # Cast explicitly: the model's parameters are float32, and the dataset
        # dtype is not guaranteed to match.
        inputs = batch_data.to(device=device, dtype=torch.float32)
        # Single output channel -> per-pixel probability via sigmoid
        # (torch.sigmoid replaces the deprecated F.sigmoid).
        probs = torch.sigmoid(model(inputs))[:, 0]
        spine_masks[start:stop] = probs.cpu().numpy()
        start = stop

# NOTE(review): 'mask' is saved as (rows, cols, slices) but 'ct' as
# (slices, rows, cols, channels); confirm which layout downstream code expects.
scio.savemat(os.path.join(output_directory, 'test-{}-{}.mat'.format(patient_idx, day_selection)),
             {'mask': np.transpose(spine_masks, [1, 2, 0]),
              'ct': np.transpose(ct_images, [0, 2, 3, 1])})