# COVID-19 CT-scan segmentation with Unet3+

## Setting up

In [None]:
!cp -r /kaggle/input/source-covid/* /kaggle/working/

In [None]:
%pip install -r requirements.txt

In [None]:
import os
import shutil
from glob import glob

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from src.unet3plus import unet3plus
from src.losses import init_num_classes, iou, dice_coef, dice_coef_loss

import tensorflow as tf
from tensorflow.keras import ops
import keras_cv

import seaborn as sns

## Config

In [None]:
input_shape = (512, 512, 3)  # (height, width, channels); images and masks are resized to height x width on load
num_classes = 3  # number of mask classes; evaluation names them non-lung/lung/infection
batch_size = 2  # per-step batch size used by the tf.data pipelines and model.predict
gradient_accumulation_steps = 16 # Actual batch size = batch_size * gradient_accumulation_steps
num_epochs = 1000  # upper bound on epochs; EarlyStopping is expected to stop training earlier
learning_rate = 7e-5  # Adam learning rate
weight_decay = 0.0  # Adam weight decay (0.0 disables it)
initial_epoch = 0  # NOTE(review): not read anywhere; run_experiment assigns its own local initial_epoch

# Presumably registers the class count with the metrics/losses in src.losses — TODO confirm.
init_num_classes(num_classes)

## Dataset loading and augmentation

In [None]:
class Augment(tf.keras.layers.Layer):
    """Batch-level augmentation that keeps images and segmentation masks aligned.

    The image and (replicated) mask are stacked along the channel axis so that the
    geometric/occlusion augmentations - a horizontal flip and 20 rounds of random
    cutout with fill value 0.0 - hit both in exactly the same way. The mask is then
    read back out of the stack. Contrast and brightness jitter are applied afterwards,
    and only the image channels of that result are kept.

    Intended for ``dataset.map`` on batched ``(images, masks)`` pairs. Assumes images
    are batched ``(batch, height, width, 3)`` with values in [0, 1] (the brightness
    layer uses that value range) — TODO confirm for other callers. The returned masks
    are ``uint8`` with a trailing channel axis of size 1.
    """

    def __init__(self):
        super().__init__()
        # Sublayers are stored as attributes so Keras tracks them.
        self.augment_RandomFlip = tf.keras.layers.RandomFlip(mode='horizontal')
        self.augment_RandomCutout = keras_cv.layers.RandomCutout(
            height_factor=0.1,
            width_factor=0.1,
            fill_mode="constant",
            fill_value=0.0,
        )
        self.augment_RandomContrast = tf.keras.layers.RandomContrast(factor=0.2)
        self.augment_RandomBrightness = tf.keras.layers.RandomBrightness(
            factor=0.2, value_range=[0.0, 1.0]
        )

    def call(self, inputs, labels):
        # Three float copies of the first mask channel, appended after the image
        # channels: channels 0-2 are the image, channels 3-5 the mask copies.
        mask_copies = tf.cast(tf.repeat(labels[..., :1], repeats=3, axis=-1), 'float32')
        stacked = tf.concat([inputs, mask_copies], axis=-1)

        # Spatial augmentations: applied to image and mask together.
        stacked = self.augment_RandomFlip(stacked)
        for _ in range(20):
            stacked = self.augment_RandomCutout(stacked)

        # Pull the mask back out (channel 4 = middle mask copy) before the photometric
        # jitter, so contrast/brightness never touch the labels.
        new_labels = tf.expand_dims(tf.cast(stacked[..., 4], 'uint8'), axis=-1)

        stacked = self.augment_RandomBrightness(self.augment_RandomContrast(stacked))
        return stacked[..., 0:3], new_labels

In [None]:
def load_data(path):
    """Collect sorted image/mask file paths for the train, test and val splits.

    Looks for ``<path>/<split>/images/*`` and ``<path>/<split>/masks/*`` with
    split in ``train``, ``test``, ``val``. A missing folder simply yields an
    empty list. Images and masks are paired only by their sort order.

    Returns:
        ``((train_images, train_masks), (test_images, test_masks),
        (validation_images, validation_masks))`` - each entry a sorted list of paths.
    """
    def _split_files(split):
        # (sorted image paths, sorted mask paths) for one split directory.
        return tuple(
            sorted(glob(os.path.join(path, split, kind, '*')))
            for kind in ('images', 'masks')
        )

    return tuple(_split_files(split) for split in ('train', 'test', 'val'))


def read_image(path):
    """Load the RGB image at ``path`` as a float32 array.

    ``path`` must be ``bytes``; it is decoded before loading, which matches what
    ``tf.numpy_function`` passes in from ``tf_parse``. The image is resized
    bilinearly to ``(input_shape[0], input_shape[1])``. Pixel values are not
    rescaled here; ``tf_parse`` divides them by 255.
    """
    height, width = input_shape[0], input_shape[1]
    pil_image = tf.keras.utils.load_img(
        path.decode(),
        color_mode="rgb",
        target_size=(height, width),
        interpolation="bilinear",
    )
    return tf.keras.utils.img_to_array(pil_image, dtype='float32')


def read_mask(path):
  """Load the segmentation mask at ``path`` as a float32 array.

  ``path`` must be ``bytes`` (it is decoded here), which matches what
  ``tf.numpy_function`` passes in from ``tf_parse``. The mask is read as
  single-channel grayscale and resized to ``(input_shape[0], input_shape[1])``,
  giving shape ``(input_shape[0], input_shape[1], 1)``. Pixel values are kept
  as stored. Downstream code compares them against the class ids 0, 1 and 2.

  Resizing uses nearest-neighbour interpolation. Mask pixels are class ids, and
  bilinear interpolation would blend neighbouring ids at region borders into
  values that are not valid classes (e.g. 0.5), or into the wrong class (0 and
  2 averaging to 1).
  """
  path = path.decode()

  img = tf.keras.utils.load_img(
      path,
      color_mode="grayscale",
      target_size=(input_shape[0], input_shape[1]),
      interpolation="nearest",
  )
  img = tf.keras.utils.img_to_array(img, dtype='float32')
  return img


def tf_parse(image, mask):
    """Decode an (image path, mask path) pair inside a ``tf.data`` pipeline.

    File loading runs in Python through ``tf.numpy_function``, which drops static
    shape information. The shapes are therefore re-declared from ``input_shape``:
    the image gets ``input_shape[:3]`` and the mask ``(H, W, 1)``.

    Returns:
        ``(image, mask)``: the image divided by 255 and the mask unchanged, both
        float32.
    """
    def _load_pair(image_path, mask_path):
        return read_image(image_path), read_mask(mask_path)

    image, mask = tf.numpy_function(
        _load_pair, [image, mask], [tf.float32, tf.float32]
    )
    image.set_shape(input_shape[:3])
    mask.set_shape([input_shape[0], input_shape[1], 1])
    return tf.cast(image, tf.float32) / 255.0, mask


def tf_dataset(image, mask, batch=batch_size, aug=False, shuffle=True):
    """Build a batched, prefetched ``tf.data`` pipeline of (image, mask) pairs.

    Args:
        image: list of image file paths.
        mask: list of mask file paths, aligned index-by-index with ``image``.
        batch: batch size.
        aug: if truthy, apply the ``Augment`` layer to every batch.
        shuffle: if True (the default, as before), shuffle the samples over the
            whole dataset; the order is reshuffled on each iteration.

    Returns:
        ``(dataset, num_samples)`` where ``num_samples`` is the number of pairs.
    """
    dataset = tf.data.Dataset.from_tensor_slices((image, mask))
    num_samples = len(dataset)
    if shuffle and num_samples > 0:
        # Shuffle the cheap path pairs *before* decoding. Shuffling after the
        # decode map would fill a buffer with num_samples decoded full-size images.
        dataset = dataset.shuffle(num_samples)
    dataset = dataset.map(tf_parse, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch)
    if aug:
        dataset = dataset.map(Augment(), num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset, num_samples

In [None]:
def display_images(display_list):
    """Show up to three images side by side in one row.

    Panels are titled, in order, 'Input Image', 'True Mask' and 'Predicted Mask'.
    Each entry is converted with ``array_to_img`` before plotting, and the axes
    are hidden.
    """
    titles = ['Input Image', 'True Mask', 'Predicted Mask']
    n_panels = len(display_list)

    plt.figure(figsize=(10, 10))
    for panel_idx, array in enumerate(display_list, start=1):
        plt.subplot(1, n_panels, panel_idx)
        plt.title(titles[panel_idx - 1])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(array))
        plt.axis('off')
    plt.show()

In [None]:
# Load datasets
# Sorted file lists per split, read from dataset/{train,test,val}/{images,masks}.
(train_images, train_masks), (test_images, test_masks), (validation_images, validation_masks) = load_data('dataset')

# Batched tf.data pipelines for each split. Augmentation (the Augment layer) is
# disabled for every split here, including training.
# NOTE(review): confirm that training without augmentation is intended.
train_dataset, train_num_samples = tf_dataset(train_images, train_masks, batch_size, aug=False)
test_dataset, test_num_samples = tf_dataset(test_images, test_masks, batch_size, aug=False)
validation_dataset, val_num_samples = tf_dataset(validation_images, validation_masks, batch_size, aug=False)

In [None]:
# Visualize the training dataset
# For each of the first three training batches, show its first image and mask.
for image, mask in train_dataset.take(3):
    display_images([image[0], mask[0]])

## The Unet3+ model

In [None]:
# Build the UNet3+ network from src.unet3plus. The meaning of use_pretrain and
# fine_tune_at is defined there; presumably this trains from scratch — TODO confirm.
model = unet3plus(input_shape=input_shape, num_classes=num_classes, use_pretrain = False, fine_tune_at = False)

# Adam with gradient accumulation: weights are updated once every
# `gradient_accumulation_steps` batches, giving an effective batch of
# batch_size * gradient_accumulation_steps. The loss is the object returned by
# dice_coef_loss() from src.losses; Dice and IoU are tracked as metrics.
model.compile(
    optimizer=tf.keras.optimizers.Adam(
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        gradient_accumulation_steps=gradient_accumulation_steps
    ),
    loss=dice_coef_loss(),
    metrics=[
        dice_coef,
        iou,
    ],
)

model.summary()

## Run experiment

In [None]:
# Per-epoch metrics go to logs/result.csv. The file is appended to, so a resumed
# run continues the same log. run_experiment(resume=True) and plot_history read it.
csv_logger = tf.keras.callbacks.CSVLogger(
    'logs/result.csv',
    separator=",",
    append=True
)

# Saves the model to logs/checkpoint.model.keras. save_best_only is not set, so
# the file is overwritten after every epoch (monitor/mode only matter when
# save_best_only=True). run_experiment(resume=True) reloads it.
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    'logs/checkpoint.model.keras',
    monitor="val_loss",
    verbose=0,
    mode="min",
)

# Stop after 5 epochs without val_loss improvement, only monitoring from epoch 10
# on, and restore the best weights seen.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=5,
    verbose=1,
    mode="min",
    restore_best_weights=True,
    start_from_epoch=10,
)

callbacks=[
    csv_logger,
    model_checkpoint, 
    early_stopping,
]

In [None]:
# Count how many training-mask pixels carry each class id (0, 1, 2). Masks are
# streamed one at a time instead of being collected into a list first, so the
# whole training mask set is never held in memory at once.
unbatch_train_ds = train_dataset.unbatch()

pixel_0 = pixel_1 = pixel_2 = 0
for mask in unbatch_train_ds.map(lambda x, y: y):
    mask = mask.numpy()
    pixel_0 += np.sum(mask == 0)
    pixel_1 += np.sum(mask == 1)
    pixel_2 += np.sum(mask == 2)

total_pixel = np.sum([pixel_0, pixel_1, pixel_2])

# Inverse-frequency class weights: weight_c = total / (n_classes * count_c).
# These weights are applied per pixel by add_sample_weights. To counter class
# imbalance, rare classes must get large weights and the dominant class a small
# one. A class with no pixels gets weight 0 instead of infinity.
pixel_counts = np.array([pixel_0, pixel_1, pixel_2], dtype=np.float64)
with np.errstate(divide='ignore', invalid='ignore'):
    inverse_freq = np.where(
        pixel_counts > 0,
        total_pixel / (len(pixel_counts) * pixel_counts),
        0.0,
    )
weight_for_0, weight_for_1, weight_for_2 = inverse_freq

print('Weight for class 0: {:.10f}'.format(weight_for_0))
print('Weight for class 1: {:.10f}'.format(weight_for_1))
print('Weight for class 2: {:.10f}'.format(weight_for_2))

# Normalized so the three weights sum to 1.
class_weights = tf.constant([weight_for_0, weight_for_1, weight_for_2])
class_weights = class_weights/tf.reduce_sum(class_weights)

def add_sample_weights(image=0, label=0):
    """Attach a per-pixel sample-weight map to an ``(image, label)`` pair.

    The module-level ``weight_for_0/1/2`` values are normalized to sum to 1. Each
    pixel then receives the weight of its label id; labels are cast to int32
    for the lookup.

    Returns:
        ``(image, label, sample_weights)`` where ``sample_weights`` has the
        same shape as ``label``.
    """
    per_class = tf.constant([weight_for_0, weight_for_1, weight_for_2])
    per_class /= tf.reduce_sum(per_class)
    label_ids = tf.cast(label, tf.int32)
    return image, label, tf.gather(per_class, indices=label_ids)

In [None]:
def run_experiment(model, resume=False):
    """Train ``model`` and save the result to ``logs/model.keras``.

    Training uses the module-level ``train_dataset`` (with per-pixel sample
    weights from ``add_sample_weights``), ``validation_dataset``, ``callbacks``
    and ``num_epochs``.

    Args:
        model: compiled Keras model.
        resume: if False, the ``logs`` directory is wiped and recreated and
            training starts at epoch 0. If True, weights are loaded from
            ``logs/checkpoint.model.keras``. Training then continues with the
            epoch after the last one recorded in ``logs/result.csv``.

    Returns:
        The trained model.
    """
    if not resume:
        initial_epoch = 0
        # A missing directory is fine. Any other failure (e.g. permissions) is
        # surfaced instead of being silently swallowed.
        try:
            shutil.rmtree('logs')
        except FileNotFoundError:
            print('directory not found')
        os.mkdir('logs')
    else:
        model.load_weights('logs/checkpoint.model.keras')
        df = pd.read_csv('logs/result.csv')
        # CSVLogger records the 0-based index of each *completed* epoch. Resume
        # with the next one so the last logged epoch is not trained twice.
        initial_epoch = int(df['epoch'].iloc[-1]) + 1 if len(df) else 0

    model.fit(
        train_dataset.map(add_sample_weights),
        epochs=num_epochs,
        initial_epoch=initial_epoch,
        callbacks=callbacks,
        validation_data=validation_dataset,
    )

    model.save("logs/model.keras")

    return model

In [14]:
# Train from scratch (wipes logs/). Pass resume=True to continue from logs/checkpoint.model.keras.
model = run_experiment(model, resume=False)

## Model evaluation

In [None]:
def create_mask(pred_mask):
    """Convert per-pixel class scores into a label map.

    Takes the argmax over the last axis and re-adds a trailing channel axis of
    size 1, so the result lines up with the ground-truth masks.
    """
    labels = tf.argmax(pred_mask, axis=-1)
    return tf.expand_dims(labels, axis=-1)

In [None]:
def get_test_image_and_annotation_arrays():
    """Gather the entire test split into one ``(images, masks)`` pair of tensors.

    ``test_dataset`` is unbatched and then re-batched with a batch size of
    ``test_num_samples``, so one batch holds every test sample. That batch is
    returned.
    """
    full_batch_ds = test_dataset.unbatch().batch(test_num_samples)

    for images, masks in full_batch_ds.take(test_num_samples):
        pass  # a single batch is expected; the last one iterated is kept

    return images, masks


y_true_images, y_true_segments = get_test_image_and_annotation_arrays()

In [None]:
def plot_history(csv):
    """Plot training and validation loss against epoch.

    ``csv`` is a metrics log with ``epoch``, ``loss`` and ``val_loss`` columns,
    such as the CSVLogger output ``logs/result.csv``.
    """
    history = pd.read_csv(csv)
    curves = history[['epoch', 'loss', 'val_loss']]
    curves.plot(
        x='epoch',
        y=['loss', 'val_loss'],
        xlabel='epoch',
        ylabel='loss',
        title='Train and Validation Loss Over Epochs',
    )
    plt.legend()
    plt.grid()
    plt.show()

In [None]:
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator`` as a float, or nan when the denominator is 0."""
    return float(numerator) / float(denominator) if denominator else float('nan')


def _confusion_matrix(y_trues, y_preds, num_classes):
    """Return the ``(num_classes, num_classes)`` pixel confusion matrix.

    Entry ``[i, j]`` counts pixels whose true label equals ``i`` and whose
    predicted label equals ``j``. Samples are paired along the first axis and
    processed one at a time, which keeps the temporary boolean arrays small.
    Pixels whose value is not exactly a class id are not counted.
    """
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    for y_true, y_pred in zip(y_trues, y_preds):
        for i in range(num_classes):
            is_true_i = y_true == i
            for j in range(num_classes):
                cm[i, j] += np.count_nonzero(is_true_i & (y_pred == j))
    return cm


def compute_metrics(y_trues, y_preds, show_ncm=False):
    """Print per-class segmentation metrics for the three mask classes and return them.

    Args:
        y_trues: ground-truth label maps (TF tensor or array-like) with class ids
            0, 1, 2.
        y_preds: predicted label maps with the same shape as ``y_trues``.
        show_ncm: if True, also plot the row-normalized confusion matrix and save
            it to ``logs/normalized_confusion_atrix.png``.

    Returns:
        dict with per-class lists ``precision``, ``recall``, ``dice`` and ``iou``,
        indexed by class id. An entry is ``nan`` when its denominator is zero. The
        dict also holds the raw 3x3 ``confusion_matrix`` (rows = true class,
        columns = predicted class).
    """
    class_names = ['Non lung nor infection', 'Lung', 'Infection']
    num_cls = len(class_names)

    # np.asarray accepts eager TF tensors as well as NumPy arrays.
    cm = _confusion_matrix(np.asarray(y_trues), np.asarray(y_preds), num_cls)

    tp = np.diag(cm)
    # Column k = everything predicted as k; its off-diagonal part is false positives.
    fp = cm.sum(axis=0) - tp
    # Row k = everything truly k; its off-diagonal part is false negatives.
    fn = cm.sum(axis=1) - tp

    precision = [_safe_ratio(tp[k], tp[k] + fp[k]) for k in range(num_cls)]
    recall = [_safe_ratio(tp[k], tp[k] + fn[k]) for k in range(num_cls)]
    dice_score = [_safe_ratio(2 * tp[k], 2 * tp[k] + fn[k] + fp[k]) for k in range(num_cls)]
    iou_score = [_safe_ratio(tp[k], tp[k] + fn[k] + fp[k]) for k in range(num_cls)]

    display_string_list = ["Mask {}: IOU: {} Dice Score: {}".format(idx, i, dc) for idx, (i, dc) in enumerate(zip(np.round(iou_score, 4), np.round(dice_score, 4)))]
    display_string = "\n\n".join(display_string_list)
    print(display_string)

    display_string_list = ["Mask {}: Precision: {} Recall: {}".format(idx, i, dc) for idx, (i, dc) in enumerate(zip(np.round(precision, 4), np.round(recall, 4)))]
    display_string = "\n\n".join(display_string_list)
    print(f'\n{display_string}')

    print(f"\nMean dice score: {round(np.mean(dice_score),4)}\n")
    print(f"Mean iou: {round(np.mean(iou_score),4)}")

    if show_ncm:
        # Each row divided by its total: the share of true-class pixels predicted as each class.
        ncm = np.round(cm / cm.sum(axis=1, keepdims=True), 4)
        _, ax = plt.subplots(figsize=(12, 8))
        ax = sns.heatmap(ncm, annot=True, cmap='Blues', fmt='g', annot_kws={"size": 15})
        ax.set_title('Normalized confusion matrix\n\n', fontsize=15)
        ax.set_xlabel('\nPredicted label', fontsize=15)
        ax.set_ylabel('True label ', fontsize=15)
        ax.xaxis.set_ticklabels(class_names, fontsize=13)
        ax.yaxis.set_ticklabels(class_names, fontsize=13)
        plt.savefig('logs/normalized_confusion_atrix.png')
        plt.show()

    return {
        'precision': precision,
        'recall': recall,
        'dice': dice_score,
        'iou': iou_score,
        'confusion_matrix': cm,
    }

In [None]:
def show_predictions(dataset=None, num=1):
    """Visualize and score the global ``model`` on the first ``num`` batches of ``dataset``.

    For each batch, only the first sample is plotted (input, true mask, predicted
    mask) and passed to ``compute_metrics``. Does nothing if ``dataset`` is falsy.
    """
    if not dataset:
        return
    for image, mask in dataset.take(num):
        predicted = create_mask(model.predict(image))
        display_images([image[0], mask[0], predicted[0]])
        compute_metrics(mask[0], predicted[0])

In [None]:
# Reload the final model that run_experiment saved at the end of training.
model.load_weights('logs/model.keras')

# Predict label maps for the whole test split (gathered above) and report
# per-class metrics plus the normalized confusion matrix.
y_pred_masks = model.predict(y_true_images, batch_size=batch_size)
y_pred_masks = create_mask(y_pred_masks)

compute_metrics(y_true_segments, y_pred_masks, show_ncm=True)
# Train/validation loss curves from the CSV log.
plot_history('logs/result.csv')

# Visual check: the first sample of each of three test batches.
show_predictions(test_dataset, num=3)

In [None]:
!zip -r logs.zip logs