In [None]:
import os
from os import listdir
from os.path import join
import logging
import random

import awscli

import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from keras_unet.models import custom_vnet
import keras.backend as K

from google.colab import drive
drive.mount("/content/drive", force_remount=True)
import SimpleITK as sitk

In [None]:
!cat /content/drive/My\ Drive/config/awscli.ini
!export AWS_SHARED_CREDENTIALS_FILE=/content/drive/My\ Drive/config/awscli.ini
PATH = "/content/drive/My Drive/config/awscli.ini"
os.environ['AWS_SHARED_CREDENTIALS_FILE'] = PATH

!aws s3 cp s3://medical-image-segmentation/lungs/smaller-resampled/train-nrrd-resampled.zip .
!aws s3 cp s3://medical-image-segmentation/lungs/smaller-resampled/val-nrrd-resampled.zip .
!aws s3 cp s3://medical-image-segmentation/lungs/smaller-resampled/test-nrrd-resampled.zip .

!unzip train-nrrd-resampled
!unzip val-nrrd-resampled
!unzip test-nrrd-resampled

In [None]:
def data_gen(split):
    """
    Creates an endless data generator for the requested split.

    The generator lists the entries of "<split>-nrrd-resampled" once, then on
    every pass shuffles them and yields one (image, mask) pair per entry, read
    from "<entry>/image.nrrd" and "<entry>/mask.nrrd".

    Args:
        split (str): the name of the split - train, val, or test
    Returns:
        the data generator; each item is a tuple (img, mask) of arrays reshaped
        to (1, d0, d1, d2, 1), where (d0, d1, d2) is the shape of the image array
    Raises:
        ValueError: on the first iteration, if the split directory has no
            entries (otherwise the loop would spin forever without yielding)
    """
    directory = split + "-nrrd-resampled"
    patient_list = listdir(directory)
    if not patient_list:
        raise ValueError("no patient folders found in '{}'".format(directory))

    while True:
        # a new random patient order for every pass over the data
        random.shuffle(patient_list)
        for patient in patient_list:
            # reading image and mask nrrd files and getting their data arrays
            img_sitk = sitk.ReadImage(join(directory, patient, "image.nrrd"))
            mask_sitk = sitk.ReadImage(join(directory, patient, "mask.nrrd"))

            img_data = sitk.GetArrayFromImage(img_sitk)
            mask_data = sitk.GetArrayFromImage(mask_sitk)

            # adding a leading batch axis and a trailing channel axis; the mask
            # is reshaped with the image's dimensions, so a mask with a
            # different number of elements fails here
            d0, d1, d2 = img_data.shape
            batch_shape = (1, d0, d1, d2, 1)
            yield img_data.reshape(batch_shape), mask_data.reshape(batch_shape)

In [None]:
def dice_loss(targets, inputs, smooth=1e-6):
    """
    Computes the dice loss given targets and predictions

    The loss is 1 - (2 * |targets * inputs| + smooth) / (|targets| + |inputs| + smooth),
    where |.| is the sum over all elements.

    Args:
        targets (array): the ground truth masks
        inputs (array): the predicted masks
        smooth (num): additional overlapping surface area; keeps the ratio
                      defined when both masks are empty
    Returns:
        the dice loss value
    """
    # masks may arrive with an integer dtype while predictions are floating
    # point; align the dtypes so the element-wise product below does not fail
    targets = K.cast(targets, inputs.dtype)

    inputs = K.flatten(inputs)
    targets = K.flatten(targets)
    intersection = K.sum(targets * inputs)
    dice = (2*intersection + smooth) / (K.sum(targets) + K.sum(inputs) + smooth)
    return 1 - dice

In [None]:
# initializing multi GPU model through distributed strategy
# STATS stays defined at module level for any other cell that reads it; it is
# not passed to the distributed model below.
STATS = [dice_loss, tf.keras.metrics.Precision(), tf.keras.metrics.Recall()]
STRATEGY = tf.distribute.MirroredStrategy()
with STRATEGY.scope():
    MODEL = custom_vnet(
        input_shape=(None, None, None, 1),
        use_batch_norm=True,
        num_classes=1,
        filters=16,
        dropout=0.25,
        output_activation='sigmoid',
        )
    # Metric objects own state variables, and Keras requires those variables to
    # be created under the same distribution strategy scope as the model, so
    # the metrics for this model are instantiated here rather than reused from
    # the list built outside the scope.
    MODEL.compile(
        optimizer='adam',
        loss=dice_loss,
        metrics=[dice_loss, tf.keras.metrics.Precision(), tf.keras.metrics.Recall()],
        run_eagerly=False,
    )

In [None]:
# initializing single GPU model
# The metrics are created in this cell instead of being read from a list built
# in another cell, so the cell runs on its own and the model gets metric
# objects that are not shared with any other model.
MODEL = custom_vnet(
    input_shape=(None, None, None, 1),
    use_batch_norm=True,
    num_classes=1,
    filters=20,
    dropout=0.25,
    output_activation='sigmoid',
    num_layers=3
)
MODEL.compile(
    optimizer='adam',
    loss=dice_loss,
    metrics=[dice_loss, tf.keras.metrics.Precision(), tf.keras.metrics.Recall()],
    run_eagerly=True,
)

In [None]:
# Endless generators over the train and val splits; each item is one patient.
TRAIN = data_gen("train")
VAL = data_gen("val")
# The compiled model is bound to the upper-case name MODEL (the lower-case
# `model` is never defined in this notebook).
# steps_per_epoch / validation_steps are the number of generator items drawn
# per epoch - presumably the patient counts of the two splits; TODO confirm.
MODEL.fit_generator(generator=TRAIN,
                    steps_per_epoch=42,
                    validation_data=VAL,
                    validation_steps=6,
                    epochs=50)

In [None]:
# Persist the trained model to an HDF5 file in the working directory.
MODEL.save('LCTSC-preliminary-3d-model.h5')

In [None]:
def show_predictions(path):
    """
    Displays images, masks, and predictions side by side in a grid format

    One figure is drawn per patient folder. Every slice along the first axis of
    the image array takes three neighbouring cells of the grid: the image, the
    ground truth mask, and the prediction of the module-level MODEL thresholded
    at 0.5.

    Args:
        path (str): the path to the folder containing patient folders, which each
                    contain an image and mask nrrd file
    Returns:
        None
    """
    panels_per_slice = 3   # image, mask, prediction
    slices_per_row = 10
    n_cols = panels_per_slice * slices_per_row

    for patient in listdir(path):
        # lazy %-style formatting; passing the name without a placeholder makes
        # the logging module report a formatting error instead of the message
        logging.info('showing patient: %s', patient)
        img_data = sitk.GetArrayFromImage(sitk.ReadImage(join(path, patient, "image.nrrd")))
        mask_data = sitk.GetArrayFromImage(sitk.ReadImage(join(path, patient, "mask.nrrd")))
        z_dim, x_dim, y_dim = img_data.shape
        volume_shape = (z_dim, x_dim, y_dim)

        # getting the mask and predicted mask as arrays shaped like the image;
        # the model expects a leading batch axis and a trailing channel axis
        mask = mask_data.reshape(volume_shape)
        batch = img_data.reshape((1, z_dim, x_dim, y_dim, 1))
        pred = MODEL.predict(batch).reshape(volume_shape) > 0.5

        # enough rows to hold every slice (at least one, so the grid is valid)
        rows = max(1, (z_dim + slices_per_row - 1) // slices_per_row)
        fig = plt.figure(figsize=(200, 80), dpi=100)

        # showing image, mask, prediction respectively in matplotlib plot
        for i in range(z_dim):
            panels = (img_data[i], mask[i], pred[i])
            for offset, panel in enumerate(panels, start=1):
                ax = fig.add_subplot(rows, n_cols, i * panels_per_slice + offset)
                ax.imshow(panel, cmap="gray")
                ax.set_axis_off()
        plt.show()
        # release the figure so memory does not grow with every patient
        plt.close(fig)

In [None]:
# Draw the image / mask / prediction grid for every patient of the test split.
show_predictions(path="test-nrrd-resampled")