In [None]:
# Colab setup cell: mount Google Drive, install the packages that are not
# preinstalled on the runtime, then import everything the notebook uses.
from google.colab import drive
# force_remount=True remounts even when Drive is already mounted, so the cell can be re-run.
drive.mount("/content/drive", force_remount=True)

# NOTE(review): pydicom / dcmread are not used in the visible cells.
!pip install pydicom
import pydicom
from pydicom import dcmread

# Installs the `aws` CLI used by the `!aws s3 cp` downloads below; the Python
# module import itself is not used in the visible cells.
!pip install awscli
import awscli

# NOTE(review): `nrrd` is not used in the visible cells — the NRRD volumes are read through SimpleITK.
!pip install pynrrd
import nrrd

# The only pinned install here; every other package floats to whatever pip resolves.
!pip install SimpleITK==1.2.4
import SimpleITK as sitk

import os
from os import listdir
from os.path import join

import matplotlib.pyplot as plt
import numpy as np
import random

# NOTE(review): ImageDataGenerator and `keras` are not used in the visible cells (they go through `tf.keras`).
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator 
from tensorflow import keras
# Provides custom_vnet, the V-Net model factory used by the model cells.
!pip install keras_unet
from keras_unet.models import custom_vnet

# NOTE(review): logging is not used in the visible cells.
import logging

In [None]:
!cat /content/drive/My\ Drive/config/awscli.ini
path = "/content/drive/My Drive/config/awscli.ini"

import os
!export AWS_SHARED_CREDENTIALS_FILE=/content/drive/My\ Drive/config/awscli.ini
path = "/content/drive/My Drive/config/awscli.ini"
os.environ['AWS_SHARED_CREDENTIALS_FILE'] = path

!aws s3 cp s3://medical-image-segmentation/lungs/smaller-resampled/train-nrrd-resampled.zip .
!aws s3 cp s3://medical-image-segmentation/lungs/smaller-resampled/val-nrrd-resampled.zip .
!aws s3 cp s3://medical-image-segmentation/lungs/smaller-resampled/test-nrrd-resampled.zip .

!unzip train-nrrd-resampled
!unzip val-nrrd-resampled
!unzip test-nrrd-resampled

In [None]:
def _list_patients(root):
  """Return the sorted patient sub-directory names of `root`.

  Sorting makes the order independent of the filesystem, so seeding `random`
  gives a reproducible shuffle. Plain files (stray archives, notes, ...) are
  skipped because they are not patients.

  Raises:
    ValueError: if `root` has no sub-directories. The endless generator below
      would otherwise loop forever without ever yielding.
  """
  patients = sorted(p for p in listdir(root) if os.path.isdir(join(root, p)))
  if not patients:
    raise ValueError(f"no patient directories found in {root!r}")
  return patients


def _to_model_batch(volume):
  """Add a leading batch axis and a trailing channel axis, as float32.

  A volume of shape S becomes (1, *S, 1), matching the model's
  ``input_shape=(None, None, None, 1)``. The float cast lets the mask be
  combined with the float predictions inside the loss.
  """
  return np.asarray(volume, dtype=np.float32)[np.newaxis, ..., np.newaxis]


def _volume_pair_gen(root, patients):
  """Yield (image, mask) batches of one patient each, forever.

  Every pass over `patients` starts with an in-place shuffle that uses the
  global `random` state. Each patient directory must contain ``image.nrrd``
  and ``mask.nrrd``.
  """
  while True:
    random.shuffle(patients)
    for patient in patients:
      img_data = sitk.GetArrayFromImage(sitk.ReadImage(join(root, patient, "image.nrrd")))
      mask_data = sitk.GetArrayFromImage(sitk.ReadImage(join(root, patient, "mask.nrrd")))
      yield _to_model_batch(img_data), _to_model_batch(mask_data)


def val_gen(root="val-nrrd-resampled"):
  """Endless (image, mask) generator over the validation patients in `root`.

  Raises ValueError immediately if `root` holds no patient directories.
  """
  return _volume_pair_gen(root, _list_patients(root))


def train_gen(root="train-nrrd-resampled"):
  """Endless (image, mask) generator over the training patients in `root`.

  Raises ValueError immediately if `root` holds no patient directories.
  """
  return _volume_pair_gen(root, _list_patients(root))

In [None]:
# Use the tf.keras backend, matching the tf.keras optimizer/metrics used in
# compile(); `import keras.backend` would load the separate standalone keras package.
from tensorflow.keras import backend as K
import math  # NOTE(review): not used in this cell.


def DiceLoss(targets, inputs, smooth=1e-6):
    """Soft Dice loss, ``1 - dice``, computed over the whole batch at once.

    All voxels of all samples are flattened into one vector, so this is a
    single batch-level Dice rather than a mean of per-sample Dice scores.

    Args:
        targets: ground-truth mask tensor (Keras ``y_true``).
        inputs: predicted probabilities (Keras ``y_pred``), e.g. the sigmoid
            output of the model.
        smooth: added to numerator and denominator. When the target and the
            prediction are both empty, the loss is 0 instead of 0/0.

    Returns:
        Scalar tensor. For binary masks and probabilities in [0, 1] it lies in
        [0, 1], with 0 meaning perfect overlap.
    """
    inputs = K.flatten(inputs)
    # Cast so an integer-typed mask (as read from disk) can be multiplied with
    # the float predictions; mixing dtypes in `targets * inputs` is an error.
    targets = K.cast(K.flatten(targets), inputs.dtype)

    intersection = K.sum(targets * inputs)
    dice = (2 * intersection + smooth) / (K.sum(targets) + K.sum(inputs) + smooth)
    return 1 - dice

In [None]:
# Single-device V-Net: 20 base filters and num_layers=3 (the MirroredStrategy
# cell below uses 16 filters and leaves num_layers at the library default).
# NOTE(review): that cell rebinds `model`, so under "Run All" the model built
# here is discarded. Run one of the two model cells, not both.
model = custom_vnet(
    num_classes=1,
    num_layers=3,
    filters=20,
    dropout=0.25,
    use_batch_norm=True,
    output_activation='sigmoid',
    input_shape=(None, None, None, 1),
)

In [None]:
# Data-parallel training across the visible GPUs. The model, the optimizer
# created by compile() and the Precision/Recall metric objects are all built
# inside the strategy scope, so their variables belong to that strategy.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = custom_vnet(
        input_shape=(None, None, None, 1),
        num_classes=1,
        filters=16,
        dropout=0.25,
        use_batch_norm=True,
        output_activation='sigmoid',
    )
    model.compile(
        optimizer='adam',
        loss=DiceLoss,
        metrics=[
            DiceLoss,
            tf.keras.metrics.Precision(),
            tf.keras.metrics.Recall(),
        ],
        run_eagerly=False,
    )

In [None]:
# Compile within the distribution strategy `model` was built under. If `model`
# came from the MirroredStrategy cell, Keras requires the Precision/Recall
# metric objects to be created in that same scope, and compile() fails
# otherwise. For a model built without a strategy this is the default
# strategy, so it changes nothing.
# run_eagerly=True executes the model step by step: easier to debug, but much
# slower than the graph mode used by the MirroredStrategy cell.
with model.distribute_strategy.scope():
    model.compile(optimizer='adam', loss=DiceLoss, metrics=[DiceLoss, tf.keras.metrics.Precision(), tf.keras.metrics.Recall()], run_eagerly=True)

In [None]:
# Both generators loop forever, so Keras needs explicit step counts. Each step
# consumes one patient volume (the generators yield a batch of one).
# NOTE(review): 42 / 6 presumably equal the number of train / val patients,
# i.e. one pass over each set per epoch — confirm.
# Model.fit accepts generators directly; fit_generator is deprecated.
train = train_gen()
val = val_gen()
history = model.fit(
    train,
    steps_per_epoch=42,
    validation_data=val,
    validation_steps=6,
    epochs=50,
)

In [None]:
# Relative path: the file lands in the current working directory (on Colab,
# the VM's non-persistent disk), so copy it to Drive to keep it.
# The model was compiled with the custom DiceLoss, so reload it with
#   tf.keras.models.load_model(path, custom_objects={'DiceLoss': DiceLoss})
model.save('LCTSC-preliminary-3d-model.h5')

In [None]:
def _read_volume(path, patient, filename):
  """Load ``<path>/<patient>/<filename>`` with SimpleITK as a numpy array."""
  return sitk.GetArrayFromImage(sitk.ReadImage(join(path, patient, filename)))


def _plot_slice_triplets(img_data, mask_data, pred_data, slices_per_row, figsize, dpi):
  """Plot image | ground-truth mask | prediction for every slice along axis 0.

  Each slice takes three adjacent panels, and `slices_per_row` slices share
  one figure row.
  """
  n_slices = len(img_data)
  n_cols = 3 * slices_per_row
  # Ceiling division: just enough rows for every slice (at least one).
  n_rows = max(1, -(-n_slices // slices_per_row))
  fig = plt.figure(figsize=figsize, dpi=dpi)
  for i in range(n_slices):
    for offset, volume in enumerate((img_data, mask_data, pred_data), start=1):
      ax = fig.add_subplot(n_rows, n_cols, 3 * i + offset)
      ax.imshow(volume[i], cmap="gray")
      ax.set_axis_off()
  plt.show()
  # One very large figure is created per patient; free it explicitly.
  plt.close(fig)


def show_predictions(path, predictor=None, threshold=0.5, slices_per_row=10,
                     figsize=(200, 80), dpi=100):
  """Show image, ground-truth mask and thresholded prediction for each patient.

  Args:
    path: directory with one sub-directory per patient, each holding
      ``image.nrrd`` and ``mask.nrrd``.
    predictor: Keras model used for prediction; defaults to the
      notebook-global ``model``.
    threshold: predicted value above which a voxel is drawn as foreground.
    slices_per_row: slices drawn per figure row (three panels each).
    figsize, dpi: passed to ``plt.figure``.
  """
  if predictor is None:
    predictor = model
  patients = sorted(p for p in listdir(path) if os.path.isdir(join(path, p)))
  for patient in patients:
    print('showing patient: ', patient)
    img_data = _read_volume(path, patient, "image.nrrd")
    mask_data = _read_volume(path, patient, "mask.nrrd")
    # The network takes (batch, *volume, channel): predict on a batch of one,
    # then drop the batch and channel axes again before thresholding.
    probs = predictor.predict(img_data[np.newaxis, ..., np.newaxis])
    pred_data = probs.reshape(img_data.shape) > threshold
    _plot_slice_triplets(img_data, mask_data, pred_data, slices_per_row, figsize, dpi)