In [None]:
# Load general libraries
import os
import numpy as np
import nibabel as nib
from tqdm.notebook import tqdm
import random

# Load plotting and images
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image

# Load sklearn functions
from sklearn.model_selection import train_test_split
from sklearn.utils import resample
from imblearn.under_sampling import RandomUnderSampler

# For parallelization
import dask

# Tensorflow and training
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras.models import Model
from tensorflow.keras.backend import int_shape
from tensorflow.keras import backend as K
from keras.optimizers import Adam, SGD

from livelossplot import PlotLossesKeras

In [None]:
# Check for GPU support.
# Query the device name once; the message is built with an f-string because a
# backslash line-continuation inside a string literal would embed the next
# line's indentation spaces in the printed text.
gpu_name = tf.test.gpu_device_name()
if gpu_name:
    print(f"Default GPU Device: {gpu_name}")
else:
    print("Please install GPU version of TF")

In [None]:
# Load custom helper functions
from utility import *

In [None]:
# Load custom functions and network for training
from custom_unet import *
from metrics import *

In [None]:
def make_subplots(plot_imgs, i, depth=0, n=8):
    """Show `n` consecutive volumes side by side at a single depth slice.

    Args:
        plot_imgs: indexable sequence of volumes, each indexed as
            ``vol[:, :, depth]`` (presumably (H, W, D) — TODO confirm).
        i: index of the first volume; volumes ``i .. i + n - 1`` are drawn.
        depth: slice index along the last axis of each volume.
        n: number of panels to draw (default 8, the previous fixed count).
    """
    # squeeze=False always yields a 2-D axes array, so n == 1 also works
    f, axarr = plt.subplots(1, n, figsize=(15, 15), squeeze=False)
    for k, ax in enumerate(axarr[0]):
        ax.imshow(plot_imgs[i + k][:, :, depth])

In [None]:
def make_plots_depth(img, depth=3, offset=0):
    """Plot `depth` consecutive slices of one volume side by side.

    Args:
        img: volume indexed as ``img[:, :, k]``; slices
            ``offset .. offset + depth - 1`` along the last axis are drawn.
        depth: number of slices (panels) to show, must be >= 1.
        offset: index of the first slice to show.
    """
    # squeeze=False keeps a 2-D axes array even when depth == 1; without it
    # plt.subplots returns a bare Axes object and indexing it fails.
    f, axarr = plt.subplots(1, depth, figsize=(15, 15), squeeze=False)
    for k, ax in enumerate(axarr[0]):
        ax.imshow(img[:, :, offset + k])

## Import data
Pretrain on the Chest CT Segmentation dataset from [Kaggle](https://www.kaggle.com/polomarco/chest-ct-segmentation); download it and extract it into the directory `./transfer`.

In [None]:
# Directories holding the extracted slice images and their mask images
image_dir = "./transfer/images/images"
lbl_dir = "./transfer/masks/masks"

# Mask channel index used to select each organ's segmentation
label_map = dict(lung=0, heart=1, trachea=2)

# Patients without segmentation masks; they are skipped during loading
no_mask_p = set([
    "ID00149637202232704462834",
    "ID00222637202259066229764",
])

In [None]:
# Function to load the downloaded images and their segmentation masks.
# Uses the directories configured above and takes one of the three labels
# {trachea, lung, heart} to select the corresponding segmentation mask.
def get_transfer_images(label="trachea"):
    """Load every patient's CT volume and the chosen organ's segmentation mask.

    Image slices are read from ``image_dir`` (named ``<patient_id>_<slice>.<ext>``)
    and masks from ``lbl_dir`` (named ``<patient_id>_mask_<slice>.jpg``). For each
    patient the slices are sorted by slice number and stacked along the last
    axis. Patients listed in ``no_mask_p`` are skipped.

    Args:
        label: key of ``label_map`` ("lung", "heart" or "trachea"); selects which
            channel of the mask images is returned.

    Returns:
        ``(images, lbls)``: lists with one float32 array per patient. Patients are
        processed in sorted id order so the list order (and therefore a
        ``train_test_split`` with a fixed ``random_state``) is reproducible across
        runs. Iterating a ``set`` of strings is not reproducible, because string
        hashing is randomized per interpreter session.

    Raises:
        KeyError: if ``label`` is not a key of ``label_map``.
        ValueError: if no patient could be loaded.
    """
    channel = label_map[label]

    def _split_name(file):
        # "<patient_id>_<slice>.<ext>" -> ("<patient_id>", "<slice>")
        stem, _ = os.path.splitext(file)
        patient_id, _, slice_str = stem.rpartition("_")
        return patient_id, slice_str

    def _read(path):
        # The context manager closes the file as soon as the pixels are copied
        with Image.open(path) as im:
            return np.asarray(im)

    def load_patient(patient_id, slice_files):
        # Rebuild the 3D volume and its mask from the patient's slices, in slice order
        img, lbl = [], []
        for slice_idx, img_f in sorted(slice_files):
            img.append(_read(f"{image_dir}/{img_f}"))
            mask = _read(f"{lbl_dir}/{patient_id}_mask_{slice_idx}.jpg")
            lbl.append(mask[:, :, channel])
        return (np.moveaxis(np.array(img, dtype=np.float32), 0, -1),
                np.moveaxis(np.array(lbl, dtype=np.float32), 0, -1))

    # Group the directory listing by patient in a single pass, instead of
    # rescanning every file (with substring matching) once per patient.
    slices_by_patient = {}
    for file in os.listdir(image_dir):
        patient_id, slice_str = _split_name(file)
        if not patient_id or not slice_str.isdigit():
            continue  # not a "<patient_id>_<slice>" image file
        slices_by_patient.setdefault(patient_id, []).append((int(slice_str), file))

    images, lbls = [], []
    for patient_id in tqdm(sorted(slices_by_patient)):
        if patient_id in no_mask_p:
            continue
        p_img, p_lbl = load_patient(patient_id, slices_by_patient[patient_id])
        images.append(p_img)
        lbls.append(p_lbl)

    print(f"Imported {len(images)} from {len(slices_by_patient)} patients")
    if not images:
        raise ValueError(f"No loadable patients found in {image_dir}")
    print("Shape img:", images[0].shape, " Shape lbl:", lbls[0].shape)
    return images, lbls

#### Load images
Load the CT volumes together with their trachea segmentation masks.

In [None]:
# Every patient's volume together with its trachea mask
imgs, lbls = get_transfer_images(label="trachea")

In [None]:
# Correct labels to only contain 0 and 1
def lable_correction(l, threshold=0):
    """Binarize a label array in place: values > threshold -> 1, all others -> 0.

    Args:
        l: numpy array of mask values; it is modified in place.
        threshold: values strictly greater than this count as foreground. The
            default 0 matches the previous behaviour for non-negative masks.
            NOTE(review): the masks are loaded from .jpg files, whose lossy
            compression can leave small non-zero values near edges; a higher
            threshold (e.g. 127 for 0-255 masks) may be more appropriate —
            confirm on the data.

    Returns:
        The same array object ``l``, now containing only 0 and 1.
    """
    # Compute the mask before writing, so both assignments use the original values
    foreground = l > threshold
    l[foreground] = 1
    l[~foreground] = 0
    return l

# Binarize every label volume (in place) with dask's default (threaded) scheduler.
# Extra keyword arguments to dask.compute are forwarded to the scheduler:
# `num_workers` sets the pool size, whereas the previous `njobs=16` is not a
# scheduler option and was silently ignored.
dask_objs = [dask.delayed(lable_correction)(lbl) for lbl in tqdm(lbls)]
lbls = list(dask.compute(*dask_objs, num_workers=16))

#### Validation split

In [None]:
# Hold out 10% of the patient volumes for validation (fixed seed -> repeatable split)
(imgs_train, imgs_valid,
 lbls_train, lbls_valid) = train_test_split(imgs, lbls, test_size=0.1, random_state=42)
print("Data", len(imgs_train), len(imgs_valid))

#### Conversion for training
Convert the 3D volumes into per-slice training samples. Each sample is centered on one slice. The input stacks that slice with its 2 neighboring slices on each side, and the output stacks it with 1 neighboring slice on each side.
The model thus learns a mapping from a 5-channel slice stack
to a 3-channel slice stack.

In [None]:
# Slices taken on each side of the centre slice: model input vs. model output
NEIGHBORS, OUTPUT_NEIGHBORS = 2, 1

In [None]:
# Split each input volume into per-slice samples with NEIGHBORS context slices per side
train_imgs_sep, valid_imgs_sep = (
    convert_depth_to_imgs_keras(volumes, neighbors=NEIGHBORS)
    for volumes in (imgs_train, imgs_valid)
)

In [None]:
# Same conversion for the masks, with OUTPUT_NEIGHBORS context slices per side
train_lbls_sep, valid_lbls_sep = (
    convert_depth_to_imgs_keras(masks, neighbors=OUTPUT_NEIGHBORS)
    for masks in (lbls_train, lbls_valid)
)

##### Setup data for training

In [None]:
# Validation covers whole held-out patients; stack the per-slice samples into
# contiguous float32 arrays.
valid_x, valid_y = (
    np.array(samples, dtype=np.float32)
    for samples in (valid_imgs_sep, valid_lbls_sep)
)

# The arrays now own the data, so drop the per-slice intermediates
del valid_imgs_sep, valid_lbls_sep

print("Validation data:", len(valid_x))

#### Data Augmentation
Augment the training data with random flips, rotations and multiplicative noise to make training more robust.

In [None]:
def apply_transformations(img, lbl):
    """Randomly augment one (image, label) training pair.

    Each flip and the rotation use a single random draw, which is applied to
    both image and label so they stay spatially aligned. Multiplicative
    Gaussian noise (mean 1.0, std 0.05) is applied to the image only.

    Args:
        img: image accepted by ``tf.image.flip_*`` and ``tfa.image.rotate``
            (presumably (H, W, C) — TODO confirm against the helper output).
        lbl: label with the same spatial layout as ``img``.

    Returns:
        ``(noise_img, lbl)``: the float32 noisy image and the transformed label.
    """
    # Flip left-right randomly
    choice = tf.random.uniform(shape=[], minval=0., maxval=1., dtype=tf.float32)
    img = tf.cond(choice < 0.5, lambda: img, lambda: tf.image.flip_left_right(img))
    lbl = tf.cond(choice < 0.5, lambda: lbl, lambda: tf.image.flip_left_right(lbl))

    # Flip up-down randomly
    choice = tf.random.uniform(shape=[], minval=0., maxval=1., dtype=tf.float32)
    img = tf.cond(choice < 0.5, lambda: img, lambda: tf.image.flip_up_down(img))
    lbl = tf.cond(choice < 0.5, lambda: lbl, lambda: tf.image.flip_up_down(lbl))

    # Rotate by a random whole-degree angle in [0, 360). tfa.image.rotate expects
    # the angle in radians, so convert; passing the raw degree value would make
    # it be interpreted as radians.
    angle_deg = tf.random.uniform(shape=[], minval=0, maxval=360, dtype=tf.int32)
    angle = tf.dtypes.cast(angle_deg, tf.float32) * (np.pi / 180.0)
    img = tfa.image.rotate(img, angle)
    lbl = tfa.image.rotate(lbl, angle)

    # Multiplicative noise around 1.0 on the image only; labels stay untouched
    noise = tf.random.normal(shape=tf.shape(img), mean=1.0, stddev=0.05, dtype=tf.float32)
    noise_img = tf.dtypes.cast(img, tf.float32) * noise

    return (noise_img, lbl)

In [None]:
train_x = []
train_y = []

# load only a subset of the training data
MAX_SAMPLES = 6000
SUBSET_SEED = 42

assert len(train_imgs_sep) == len(train_lbls_sep), "inputs and labels are misaligned"

# Draw the subset at random (fixed seed) rather than taking the first
# MAX_SAMPLES entries: the per-slice samples are presumably ordered by patient,
# so a head slice would only cover the first few patients. Indices are sorted
# to keep the original relative order of the chosen samples.
n_train = min(MAX_SAMPLES, len(train_imgs_sep))
subset_idx = sorted(random.Random(SUBSET_SEED).sample(range(len(train_imgs_sep)), n_train))

# Seed TF's global RNG so the random augmentations below are reproducible
tf.random.set_seed(SUBSET_SEED)

# apply augmentations
# NOTE(review): augmentation is applied once up front, so every epoch sees the
# same augmented copies and never the originals; a tf.data pipeline mapping
# apply_transformations would re-augment on every epoch.
for idx in tqdm(subset_idx):
    x_out, y_out = apply_transformations(train_imgs_sep[idx], train_lbls_sep[idx])
    train_x.append(np.array(x_out))
    train_y.append(np.array(y_out))

# Load into contiguous memory arrays
train_x = np.array(train_x, dtype=np.float32)
train_y = np.array(train_y, dtype=np.float32)

# Save memory
del train_imgs_sep
del train_lbls_sep

print("Training data:", len(train_x))

## Setup Model Training

#### Focal Tversky Loss
Implementation and further resource references at [Kaggle Post](https://www.kaggle.com/bigironsphere/loss-function-library-keras-pytorch)

In [None]:
# Focal Tversky loss, Keras-backend implementation
ALPHA = 0.5    # weight on false positives
BETA = 1.0     # weight on false negatives
GAMMA = 4.0    # focal exponent applied to (1 - Tversky index)

def FocalTverskyLoss(targets, inputs, alpha=ALPHA, beta=BETA, gamma=GAMMA, smooth=1e-6):
    """Focal Tversky loss accumulated over the whole batch.

    Both tensors are flattened, soft TP/FP/FN sums are formed, and the loss is
    ``(1 - TI) ** gamma`` with
    ``TI = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth)``.

    Args:
        targets: ground-truth tensor (passed by Keras as ``y_true``).
        inputs: predicted probabilities with the same number of elements
            (passed by Keras as ``y_pred``).
        alpha: weight on false positives.
        beta: weight on false negatives.
        gamma: focal exponent.
        smooth: small constant keeping the ratio finite for empty masks.

    Returns:
        Scalar loss tensor.
    """
    y_pred = K.flatten(inputs)
    y_true = K.flatten(targets)

    # Soft confusion-matrix entries
    true_pos = K.sum(y_pred * y_true)
    false_pos = K.sum((1 - y_true) * y_pred)
    false_neg = K.sum(y_true * (1 - y_pred))

    tversky_index = (true_pos + smooth) / (true_pos + alpha * false_pos + beta * false_neg + smooth)
    return K.pow(1 - tversky_index, gamma)

### Training
Train custom U-Net architecture from [Github: karolzak/keras-unet](https://github.com/karolzak/keras-unet)

In [None]:
# Load network
# MirroredStrategy() without an explicit device list replicates over all visible
# GPUs (falling back to the CPU when none is found), instead of failing on
# machines that do not have exactly GPU:0-GPU:3.
strategy = tf.distribute.MirroredStrategy()
print("Number of replicas:", strategy.num_replicas_in_sync)
with strategy.scope():
    # Model, optimizer and metrics are created inside the strategy scope so
    # their variables are mirrored across replicas.
    unet = custom_unet(
        train_x[0].shape,
        num_classes=OUTPUT_NEIGHBORS * 2 + 1, # Configure number of output channels
        filters=64, # number of filters in the first convolutional block (increased by factor 2 with depth)
        use_batch_norm=True, # use batch normalization
        dropout=0.2,  # set to value to use dropout after initial conv block
        dropout_change_per_layer=0.0, # keep dropout on each layer constant
        dropout_type='spatial', # use spatial dropout i.e. drop entire filters
        num_layers=4, # 4 convolutional blocks (original U-Net depth)
        upsample_mode='deconv', # use transposed convolutions in the upsampling part of the network
        use_dropout_on_upsampling=False # don't use dropout in the upsampling part of the network
    )

    unet.compile(
        # Use the tf.keras optimizer to match the tf.keras model; the standalone
        # `keras.optimizers.Adam` import is not guaranteed to be compatible with
        # tf.keras models across TF/Keras versions.
        optimizer=tf.keras.optimizers.Adam(),
        loss=FocalTverskyLoss,
        metrics=[iou, iou_thresholded, tf.keras.metrics.AUC()]
    )

In [None]:
# After every epoch, keep the full model with the lowest validation loss so far
save_best_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath='./model/best_checkpoint_transfer',
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=False,
    save_freq='epoch',
    verbose=1,
)

In [None]:
# Pretraining on the trachea data: 40 passes over the augmented subset
EPOCHS, BATCH_SIZE = 40, 24

history = unet.fit(
    x=train_x,
    y=train_y,
    validation_data=(valid_x, valid_y),
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[PlotLossesKeras(), save_best_cb],
)

### Store model
Store the pretrained model for later use on the actual colon cancer dataset.

In [None]:
# Export the pretrained network. With overwrite=False, Keras asks before
# replacing an existing export instead of silently overwriting it.
MODEL_DIR = "./model"
os.makedirs(MODEL_DIR, exist_ok=True)
unet.save(f"{MODEL_DIR}/pretrained_transfer_trachea_1", overwrite=False)