In [2]:
import os
import sys
import random
import warnings

import numpy as np
import pandas as pd
import nibabel as nib
import matplotlib.pyplot as plt

from tqdm import tqdm
from itertools import chain
from skimage.io import imread, imshow, imread_collection, concatenate_images
from skimage.transform import resize
from skimage.morphology import label

from keras.models import Model, load_model
from keras.layers import Input
from keras.layers.core import Dropout, Lambda
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.pooling import MaxPooling2D
from keras.layers.merge import concatenate
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras import backend as K

import tensorflow as tf

Using TensorFlow backend.


In [1]:
#Download the dataset
!wget https://zenodo.org/record/3757476/files/COVID-19-CT-Seg_20cases.zip?download=1
!wget https://zenodo.org/record/3757476/files/Lung_and_Infection_Mask.zip?download=1

--2020-05-28 13:49:31--  https://zenodo.org/record/3757476/files/COVID-19-CT-Seg_20cases.zip?download=1
Resolving zenodo.org (zenodo.org)... 188.184.117.155
Connecting to zenodo.org (zenodo.org)|188.184.117.155|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1105395143 (1.0G) [application/octet-stream]
Saving to: ‘COVID-19-CT-Seg_20cases.zip?download=1’


2020-05-28 13:52:08 (7.01 MB/s) - ‘COVID-19-CT-Seg_20cases.zip?download=1’ saved [1105395143/1105395143]

--2020-05-28 13:52:09--  https://zenodo.org/record/3757476/files/Lung_and_Infection_Mask.zip?download=1
Resolving zenodo.org (zenodo.org)... 188.184.117.155
Connecting to zenodo.org (zenodo.org)|188.184.117.155|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11714838 (11M) [application/octet-stream]
Saving to: ‘Lung_and_Infection_Mask.zip?download=1’


2020-05-28 13:53:01 (4.12 MB/s) - ‘Lung_and_Infection_Mask.zip?download=1’ saved [11714838/11714838]



In [0]:
#Rename zip file
!mv COVID-19-CT-Seg_20cases.zip?download=1 COVID-19-CT-Seg_20cases.zip
!mv Lung_and_Infection_Mask.zip?download=1 Lung_and_Infection_Mask.zip

In [4]:
#extract data 
!mkdir images && unzip COVID-19-CT-Seg_20cases.zip -d images
!mkdir masks && unzip Lung_and_Infection_Mask.zip -d masks
!rm images/ReadMe.txt


Archive:  COVID-19-CT-Seg_20cases.zip
  inflating: images/coronacases_001.nii.gz  
  inflating: images/coronacases_002.nii.gz  
  inflating: images/coronacases_003.nii.gz  
  inflating: images/coronacases_004.nii.gz  
  inflating: images/coronacases_005.nii.gz  
  inflating: images/coronacases_006.nii.gz  
  inflating: images/coronacases_007.nii.gz  
  inflating: images/coronacases_008.nii.gz  
  inflating: images/coronacases_009.nii.gz  
  inflating: images/coronacases_010.nii.gz  
  inflating: images/radiopaedia_10_85902_1.nii.gz  
  inflating: images/radiopaedia_10_85902_3.nii.gz  
  inflating: images/radiopaedia_14_85914_0.nii.gz  
  inflating: images/radiopaedia_27_86410_0.nii.gz  
  inflating: images/radiopaedia_29_86490_1.nii.gz  
  inflating: images/radiopaedia_29_86491_1.nii.gz  
  inflating: images/radiopaedia_36_86526_0.nii.gz  
  inflating: images/radiopaedia_40_86625_0.nii.gz  
  inflating: images/radiopaedia_4_85506_1.nii.gz  
  inflating: images/radiopaedia_7_85703_0.nii

In [0]:
#load data
# NOTE(review): this cell contains no code in the visible notebook, yet the
# training cell further down calls model.fit(X_train, Y_train, ...).
# X_train / Y_train have to be built before training can run on a fresh
# kernel -- TODO confirm where they are defined and restore that code here.

In [0]:
#IOU metrics
def mean_iou(y_true, y_pred):
    """Mean IoU of `y_pred` against `y_true`, averaged over several thresholds.

    The prediction is binarised at each threshold produced by
    `np.arange(0.5, 1.0, 0.05)`; for every threshold a two-class
    `tf.metrics.mean_iou` is evaluated, and the per-threshold scores are
    averaged into a single tensor.

    Relies on the TF1 graph/session API (`tf.metrics`, `K.get_session`).

    Args:
        y_true: ground-truth label tensor.
        y_pred: predicted score tensor, compared against each threshold.

    Returns:
        Tensor holding the mean of the per-threshold IoU scores.
    """
    prec = []
    for t in np.arange(0.5, 1.0, 0.05):
        # tf.cast replaces the deprecated tf.to_int32 (same bool -> int32 cast).
        y_pred_ = tf.cast(y_pred > t, tf.int32)
        score, up_opt = tf.metrics.mean_iou(y_true, y_pred_, 2)
        # The metric creates local variables; initialise them before use.
        K.get_session().run(tf.local_variables_initializer())
        # Make reading the score depend on the metric's update op.
        with tf.control_dependencies([up_opt]):
            score = tf.identity(score)
        prec.append(score)
    return K.mean(K.stack(prec), axis=0)

In [0]:
#Dice Coefficient metrics
def dice_coef(y_true, y_pred, smooth=10e-6):
    """Smoothed Dice coefficient, reduced over the last axis only.

    Evaluates
        (2 * sum(|y_true * y_pred|) + smooth)
        / (sum(y_true ** 2) + sum(y_pred ** 2) + smooth)
    where every sum runs over the last axis, so the result keeps all
    leading axes of the inputs.

    Args:
        y_true: ground-truth tensor.
        y_pred: prediction tensor, same shape as `y_true`.
        smooth: constant added to numerator and denominator (default 10e-6,
            i.e. 1e-5) so the ratio stays defined when both sums are zero.

    Returns:
        Tensor of Dice scores with the last axis reduced away.
    """
    overlap = K.sum(K.abs(y_true * y_pred), axis=-1)
    true_sq_sum = K.sum(K.square(y_true), -1)
    pred_sq_sum = K.sum(K.square(y_pred), -1)

    numerator = 2. * overlap + smooth
    denominator = true_sq_sum + pred_sq_sum + smooth
    return numerator / denominator
def dice_coef_loss(y_true, y_pred):
    """Dice loss: one minus the Dice coefficient of `y_pred` against `y_true`."""
    overlap_score = dice_coef(y_true, y_pred)
    return 1 - overlap_score

In [23]:
# Build U-Net model
def _double_conv(tensor, n_filters, dropout_rate):
    """Apply Conv2D -> Dropout -> Conv2D, each conv with `n_filters` 3x3 filters.

    Both convolutions use ReLU activation, He-normal initialisation and
    'same' padding.
    """
    conv_kwargs = dict(activation='relu', kernel_initializer='he_normal', padding='same')
    tensor = Conv2D(n_filters, (3, 3), **conv_kwargs) (tensor)
    tensor = Dropout(dropout_rate) (tensor)
    return Conv2D(n_filters, (3, 3), **conv_kwargs) (tensor)


def _upsample_and_merge(tensor, skip, n_filters):
    """Upsample `tensor` 2x with a transposed conv, then concatenate `skip` on the last axis."""
    upsampled = Conv2DTranspose(n_filters, (2, 2), strides=(2, 2), padding='same') (tensor)
    return concatenate([upsampled, skip])


# Height and width are left unspecified; the input has 3 channels.
# The network halves the resolution four times and doubles it four times, so
# height and width should be multiples of 16 for the skip connections to match.
inputs = Input((None, None, 3))

# Contracting path: double conv, then 2x2 max-pooling.
c1 = _double_conv(inputs, 64, 0.1)
p1 = MaxPooling2D((2, 2)) (c1)

c2 = _double_conv(p1, 128, 0.1)
p2 = MaxPooling2D((2, 2)) (c2)

c3 = _double_conv(p2, 256, 0.2)
p3 = MaxPooling2D((2, 2)) (c3)

c4 = _double_conv(p3, 512, 0.2)
p4 = MaxPooling2D(pool_size=(2, 2)) (c4)

# Bottleneck.
c5 = _double_conv(p4, 1024, 0.3)

# Expanding path: upsample, merge with the matching encoder output, double conv.
u6 = _upsample_and_merge(c5, c4, 512)
c6 = _double_conv(u6, 512, 0.2)

u7 = _upsample_and_merge(c6, c3, 256)
c7 = _double_conv(u7, 256, 0.2)

u8 = _upsample_and_merge(c7, c2, 128)
c8 = _double_conv(u8, 128, 0.1)

u9 = _upsample_and_merge(c8, c1, 64)
c9 = _double_conv(u9, 64, 0.1)

# 1x1 convolution with softmax over 4 output channels.
outputs = Conv2D(4, (1, 1), activation='softmax') (c9)

model = Model(inputs=[inputs], outputs=[outputs])
# Trained with the Dice loss; the Dice coefficient itself is tracked as the metric.
model.compile(optimizer='adam', loss= dice_coef_loss, metrics= [dice_coef])
model.summary()

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            (None, None, None, 3 0                                            
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, None, None, 6 1792        input_4[0][0]                    
__________________________________________________________________________________________________
dropout_10 (Dropout)            (None, None, None, 6 0           conv2d_20[0][0]                  
__________________________________________________________________________________________________
conv2d_21 (Conv2D)              (None, None, None, 6 36928       dropout_10[0][0]                 
____________________________________________________________________________________________

In [0]:
# Train the model; `results.history` holds the per-epoch loss and metric values.
fit_options = {
    'validation_split': 0.1,  # fraction of the training data held out for validation
    'batch_size': 8,
    'epochs': 60,
}
results = model.fit(X_train, Y_train, **fit_options)

In [0]:
#plot the training progress
fig, axs = plt.subplots(1, 2, figsize = (15, 4))

training_loss = results.history['loss']
validation_loss = results.history['val_loss']

# The tracked metric is the Dice coefficient (see model.compile), not accuracy;
# the variable names are kept as they were, but the plot is labelled correctly.
training_accuracy = results.history['dice_coef']
validation_accuracy = results.history['val_dice_coef']

epoch_count = range(1, len(training_loss) + 1)

# Left panel: loss curves.
axs[0].plot(epoch_count, training_loss, 'r--')
axs[0].plot(epoch_count, validation_loss, 'b-')
axs[0].legend(['Training Loss', 'Validation Loss'])
axs[0].set(title='Loss', xlabel='Epoch', ylabel='Loss')

# Right panel: Dice coefficient curves.
axs[1].plot(epoch_count, training_accuracy, 'r--')
axs[1].plot(epoch_count, validation_accuracy, 'b-')
axs[1].legend(['Training Dice', 'Validation Dice'])
axs[1].set(title='Dice coefficient', xlabel='Epoch', ylabel='Dice coefficient')

# Render the figure explicitly so the cell does not echo a Legend repr.
plt.show()