In [None]:
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import model_from_json

import nibabel as nib
import gdown
import imageio
from IPython.display import clear_output
import os

from utils.generate_datasets import norm, resize_img, resize_mask, draw_grid, elastic_transform, generate_dataset
from utils.generate_models import generate_unet, generate_unet_512
from utils.make_predictions import color_mask, create_mask, show_training_predictions, plot_losses

# Export the TensorFlow GPU-memory-growth switch for this process.
os.environ.update({'TF_FORCE_GPU_ALLOW_GROWTH': 'true'})

In [None]:
# Render figure text with LaTeX in a serif font.
# LaTeX rendering sometimes gives trouble; if it does, comment this statement out.
plt.rcParams.update({'text.usetex': True, 'font.family': 'serif'})

In [None]:
# Uncomment the next line to hide every GPU from TensorFlow and run on the CPU only
# (useful when the GPU memory is not enough).
# tf.config.set_visible_devices([], 'GPU')

# Use the stable `tf.config.list_physical_devices` rather than its
# `tf.config.experimental` alias.
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

## Dataset

The images are provided by the [COVID-19 CT segmentation dataset](http://medicalsegmentation.com/covid19/), a dataset of 100 axial CT images from more than 40 patients with COVID-19. The images were segmented by a radiologist using 3 labels:

- Ground-glass (mask value = 1)
- Consolidation (mask value = 2)
- Pleural effusion (mask value = 3)

In [None]:
# Create the working folders in pure Python so the cell is portable and can be
# re-run safely. `models` is included because the checkpoint and model-saving
# cells further down write into it.
for folder in ('images', 'gif', 'models'):
    os.makedirs(folder, exist_ok=True)

In [None]:
# Local destination -> Google Drive URL for every file of the dataset.
dataset_files = {
    # Training images
    'images/tr_im.nii': 'https://drive.google.com/uc?id=1nYbe37SmMIwBQJ35MR3coDEKqaMeuiCu',
    # Training masks
    'images/tr_mask.nii': 'https://drive.google.com/uc?id=16Wdd97TAI3IBFTaQ7yth1qSo7wsEcZCc',
    # Validation dataset
    'images/val_im.nii': 'https://drive.google.com/uc?id=1xNVxrnIlO96ydXy5b6rLLuAvgbFT2Tz0',
}

# Make sure the destination folder exists even if this cell is run on its own.
os.makedirs('images', exist_ok=True)

for output_path, url in dataset_files.items():
    # Skip files that are already on disk so re-running the notebook does not
    # download everything again.
    if os.path.exists(output_path):
        continue
    # Fail here, with a clear message, rather than later when the file is loaded.
    if gdown.download(url, output=output_path, quiet=True) is None:
        raise RuntimeError('Could not download {} from {}'.format(output_path, url))

In [None]:
# Open the three NIfTI files: training images, training masks and validation images.
imgs, masks, validation = map(
    nib.load,
    ('images/tr_im.nii', 'images/tr_mask.nii', 'images/val_im.nii'),
)

# Number of slices along the third axis of the training volume.
n_slices = imgs.shape[2]

x_o, y_o = imgs.get_fdata(), masks.get_fdata()

# Resize every slice to 512; image slices are additionally passed through norm().
x_o = np.array([norm(resize_img(x_o[:, :, k], 512)) for k in range(n_slices)])
y_o = np.array([resize_mask(y_o[:, :, k], 512) for k in range(n_slices)])

x_o.shape, y_o.shape

## Generating datasets and training the model

In [None]:
class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that refreshes the cell output with predictions after each epoch.

    The callback reads the notebook-level variables ``x_val``, ``y_val`` and
    ``size`` when an epoch ends, so they must be defined before training starts.
    """

    def on_epoch_end(self, epoch, logs=None):
        """Clear the current cell output and show predictions for this epoch.

        Args:
            epoch: Index of the epoch that just finished, as passed by Keras.
            logs: Metrics dictionary passed by Keras; forwarded unchanged to
                ``show_training_predictions``. Defaults to ``None`` to match
                the signature of the Keras base class.
        """
        # wait=True delays clearing until new output is available, which avoids flicker.
        clear_output(wait=True)
        show_training_predictions(self.model, x_val, y_val, size, epoch, logs)

### 512 $\times$ 512 model

In [None]:
# Build the 512x512 U-Net and compile it with Adam; the loss is configured
# for integer labels and raw logits (from_logits=True).
model_512 = generate_unet_512()
loss_512 = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model_512.compile(loss=loss_512, optimizer='adam', metrics=['accuracy'])

In [None]:
# Hyperparameters for the 512x512 run
size = 512        # image side length handed to generate_dataset
samples = 4000    # total number of augmented samples to generate

EPOCHS = 40
BATCH_SIZE, BUFFER_SIZE = 100, 1000
TRAIN_LENGTH = samples
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

In [None]:
# Write a checkpoint only when the validation loss reaches a new minimum.
checkpoint_512 = ModelCheckpoint(
    filepath="models/best_weights_512.h5",
    save_best_only=True,
    monitor='val_loss',
    mode='min',
    verbose=1,
)

The 512$\times$512 dataset is quite heavy (for my computer), so it is better to divide it into *batches* and train on one batch at a time.

In [None]:
# Split the volume into `batches` consecutive groups of slices and train on
# one group at a time, collecting the per-epoch losses of every group.
batches = 10
imgs_per_batch = imgs.shape[2] // batches
samples_per_batch = samples // batches
losses, val_losses = [], []

# NOTE(review): STEPS_PER_EPOCH is derived from the full `samples` count, while
# each fit() call below only receives `samples_per_batch` samples - confirm
# this is intended.
for chunk in range(batches):
    first, last = imgs_per_batch * chunk, imgs_per_batch * (chunk + 1)
    # Same selection as [:, :, first:last] on the NIfTI slicer.
    chunk_slices = (slice(None), slice(None), slice(first, last))

    print("Generating augmented 512x512 dataset of size {}".format(samples_per_batch))
    x_val, y_val, x, y = generate_dataset(imgs.slicer[chunk_slices],
                                          masks.slicer[chunk_slices],
                                          size, samples_per_batch)

    print("Training 512x512 (step {})".format(chunk + 1))
    callbacks_512 = [checkpoint_512, DisplayCallback()]
    trained_512 = model_512.fit(x, y,
                                epochs=EPOCHS,
                                steps_per_epoch=STEPS_PER_EPOCH,
                                validation_data=(x_val, y_val),
                                callbacks=callbacks_512)

    history_512 = trained_512.history
    losses.append(history_512['loss'])
    val_losses.append(history_512['val_loss'])

In [None]:
# `losses` and `val_losses` hold one list of per-epoch values for each training
# chunk. Flatten them so the plot covers the whole training run instead of only
# the first chunk.
all_losses = [value for chunk_losses in losses for value in chunk_losses]
all_val_losses = [value for chunk_losses in val_losses for value in chunk_losses]
plot_losses(all_losses, all_val_losses)

Save the model.

In [None]:
# Store the architecture as JSON ...
model_json = model_512.to_json()
with open('models/model_512.json', 'w') as architecture_file:
    architecture_file.write(model_json)

# ... and the trained weights as HDF5
model_512.save_weights('models/model_512.h5')
print('Saved model to disk')

### 224 $\times$ 224 model

In [None]:
# Build the 224x224 U-Net and compile it with Adam; the loss is configured
# for integer labels and raw logits (from_logits=True).
model_224 = generate_unet(224)
loss_224 = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model_224.compile(loss=loss_224, optimizer='adam', metrics=['accuracy'])

In [None]:
# Hyperparameters for the 224x224 run
size = 224        # image side length handed to generate_dataset
samples = 4000    # total number of augmented samples to generate

EPOCHS = 40
BATCH_SIZE, BUFFER_SIZE = 100, 1000
TRAIN_LENGTH = samples
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

In [None]:
# Write a checkpoint only when the validation loss reaches a new minimum.
checkpoint_224 = ModelCheckpoint(
    filepath="models/best_weights_224.h5",
    save_best_only=True,
    monitor='val_loss',
    mode='min',
    verbose=1,
)

# Build the augmented dataset from the full volumes in one go.
print("Generating augmented 224x224 dataset of size {}".format(samples))
x_val, y_val, x, y = generate_dataset(imgs, masks, size, samples)

print("Training 224x224")
callbacks_224 = [checkpoint_224, DisplayCallback()]
trained_224 = model_224.fit(x, y,
                            epochs=EPOCHS,
                            steps_per_epoch=STEPS_PER_EPOCH,
                            validation_data=(x_val, y_val),
                            callbacks=callbacks_224)

In [None]:
# Plot training loss against validation loss for the 224x224 run.
history_224 = trained_224.history
plot_losses(history_224['loss'], history_224['val_loss'])

Save the model.

In [None]:
# Store the architecture as JSON ...
model_json = model_224.to_json()
with open('models/model_224.json', 'w') as architecture_file:
    architecture_file.write(model_json)

# ... and the trained weights as HDF5
model_224.save_weights('models/model_224.h5')
print('Saved model to disk')

### 192 $\times$ 192 model

In [None]:
# Build the 192x192 U-Net and compile it with Adam; the loss is configured
# for integer labels and raw logits (from_logits=True).
model_192 = generate_unet(192)
loss_192 = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model_192.compile(loss=loss_192, optimizer='adam', metrics=['accuracy'])

In [None]:
# Hyperparameters for the 192x192 run
size = 192        # image side length handed to generate_dataset
samples = 4000    # total number of augmented samples to generate

EPOCHS = 40
BATCH_SIZE, BUFFER_SIZE = 100, 1000
TRAIN_LENGTH = samples
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

In [None]:
# Write a checkpoint only when the validation loss reaches a new minimum.
checkpoint_192 = ModelCheckpoint(
    filepath="models/best_weights_192.h5",
    save_best_only=True,
    monitor='val_loss',
    mode='min',
    verbose=1,
)

# Build the augmented dataset from the full volumes in one go.
print("Generating augmented 192x192 dataset of size {}".format(samples))
x_val, y_val, x, y = generate_dataset(imgs, masks, size, samples)

print("Training 192x192")
callbacks_192 = [checkpoint_192, DisplayCallback()]
trained_192 = model_192.fit(x, y,
                            epochs=EPOCHS,
                            steps_per_epoch=STEPS_PER_EPOCH,
                            validation_data=(x_val, y_val),
                            callbacks=callbacks_192)

In [None]:
# Plot training loss against validation loss for the 192x192 run.
history_192 = trained_192.history
plot_losses(history_192['loss'], history_192['val_loss'])

Save the model.

In [None]:
# Store the architecture as JSON ...
model_json = model_192.to_json()
with open('models/model_192.json', 'w') as architecture_file:
    architecture_file.write(model_json)

# ... and the trained weights as HDF5
model_192.save_weights('models/model_192.h5')
print('Saved model to disk')

## GIF

Create a simple animation of the training process for one of the models.

In [None]:
# One frame per training epoch is read from gif/192_<epoch>.jpg
# (presumably written during training by show_training_predictions - confirm).
filenames = ['gif/192_{}.jpg'.format(i) for i in range(EPOCHS)]
images = [imageio.imread(filename) for filename in filenames]

# Only `duration` is passed: it already fixes the time per frame, so the extra
# `fps` argument was contradictory (ignored by the legacy GIF writer and not
# accepted by newer imageio releases). loop=0 makes the animation repeat forever.
imageio.mimsave('gif/training_192.gif', images, duration=0.4, loop=0)