In [None]:
# Run the notebook that defines the helper functions this notebook relies on.
# Adjust the path below if the notebook is stored in a different location.
%run /content/src/image_preprocessing.ipynb

In [None]:
import random
import cv2
import numpy as np
import sys
import pandas as pd 
import matplotlib.pyplot as plt
import tensorflow as tf
import pydicom
import tensorflow.keras

from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from tqdm.notebook import tqdm

import wandb
from wandb.keras import WandbCallback

In [None]:
def get_balanced_train_valid_tuples(dataframe, train_data, valid_data):
    """
    Returns balanced training data IDs and their labels and validation data IDs and labels.

    The training set is balanced by drawing the same number of samples (one sixth of the
    number of unique training images) from six groups: the images labelled positive for
    each of the five hemorrhage subtypes, and the images whose "any" label is 0.
    A group with fewer images than required is oversampled by repeating its images.

    The dataframe is expected to hold six consecutive rows per image ID, because IDs are
    taken from every sixth row and labels are regrouped into rows of six.

    :param dataframe: the dataframe with image IDs and labels
    :param train_data: list with training data studies
    :param valid_data: list with validation data studies

    :return: training IDs, training labels, validation IDs, validation labels
    :raises ValueError: if one of the six groups has no image in the training data
    """
    valid_df = dataframe.loc[dataframe['Study'].isin(valid_data)].reset_index(drop=True)
    train_df = dataframe.loc[dataframe['Study'].isin(train_data)].reset_index(drop=True)

    one_type_cnt = len(train_df['ID'].unique()) // 6
    balanced = []

    for ich_type in ["epidural", "intraventricular", "subarachnoid", "intraparenchymal", "subdural", "any"]:
        if one_type_cnt == 0:
            # nothing to draw for any group
            continue

        # subtype groups collect positive images, the "any" group collects negative ones
        wanted_label = 0 if ich_type == "any" else 1
        group_ids = train_df.loc[(train_df["Type"] == ich_type) & (train_df["Label"] == wanted_label), 'ID'].unique()
        id_and_label = train_df.loc[train_df['ID'].isin(group_ids), ['ID', 'Label']]

        ids = list(id_and_label['ID'][::6])
        labels = np.reshape(id_and_label['Label'].values, (-1, 6))
        pool = list(zip(ids, labels))

        if not pool:
            # without this guard the oversampling below could never reach the target size
            raise ValueError(f'No training image found for group "{ich_type}"; the training data cannot be balanced.')

        # repeat the whole pool just often enough to allow drawing one_type_cnt samples
        repeats = -(-one_type_cnt // len(pool))
        balanced.extend(random.sample(pool * repeats, one_type_cnt))

    train_ids = np.array([pair[0] for pair in balanced])
    train_labels = np.array([pair[1] for pair in balanced])
    valid_ids = valid_df['ID'].values[::6]
    valid_labels = np.reshape(valid_df['Label'].values, (-1, 6))
    return train_ids, train_labels, valid_ids, valid_labels

def get_train_valid_tuples(dataframe, train_data, valid_data):
    """
    Splits the dataframe into training and validation image IDs with their labels.

    IDs are taken from every sixth row and labels are regrouped into rows of six,
    so the dataframe is expected to hold six consecutive rows per image.

    :param dataframe: the dataframe with image IDs and labels
    :param train_data: list with training data studies
    :param valid_data: list with validation data studies

    :return: training IDs, training labels, validation IDs, validation labels
    """
    def rows_of(studies):
        # keep only the rows of the requested studies, renumbered from zero
        selected = dataframe.loc[dataframe['Study'].isin(studies)]
        return selected.reset_index().drop("index", axis=1)

    def ids_and_labels(frame):
        image_ids = frame['ID'].values[::6]
        label_rows = np.reshape(frame['Label'].values, (-1, 6))
        return image_ids, label_rows

    valid_df = rows_of(valid_data)
    train_df = rows_of(train_data)

    train_ids, train_labels = ids_and_labels(train_df)
    valid_ids, valid_labels = ids_and_labels(valid_df)
    return train_ids, train_labels, valid_ids, valid_labels

In [None]:
def find_neighbors_to_image(image_id, dataframe):
    """
    Finds the neighbouring slices to an image within the image's study.

    The slices of the study are ordered by their "Position" value. When the image is the
    lowest (or highest) slice of its study, there is no bottom (or top) neighbour and the
    image's own ID is returned in that place.

    :param image_id: the image ID for which the neighbours will be found
    :param dataframe: the dataframe with image IDs, studies and vertical positions

    :return: (bottom neighbour ID, top neighbour ID)
    :raises IndexError: if the image ID is not present in the dataframe
    """
    image_rows = dataframe[dataframe['ID'] == image_id]
    if len(image_rows) == 0:
        # IndexError is what the unguarded lookup raised before; only the message is new
        raise IndexError(f'Image ID "{image_id}" was not found in the dataframe.')
    study_id = image_rows["Study"].values[0]
    pos = image_rows["Position"].values[0]

    # one row per image of the study, instead of re-filtering the whole dataframe per image
    study_slices = dataframe.loc[dataframe['Study'] == study_id, ["ID", "Position"]].drop_duplicates(subset="ID")

    # NOTE: slices sharing the same position collapse into one entry (the later one wins)
    positions_dict = dict(zip(study_slices["Position"].values, study_slices["ID"].values))
    sorted_keys = sorted(positions_dict)
    pos_index = sorted_keys.index(pos)

    # clamp at the study boundaries so that a missing neighbour maps to the image itself
    b_index = max(pos_index - 1, 0)
    t_index = min(pos_index + 1, len(sorted_keys) - 1)
    return positions_dict[sorted_keys[b_index]], positions_dict[sorted_keys[t_index]]

In [None]:
def img_preprocessing(image, b_neigh=None, t_neigh=None, dim=(224,224), window=(40,80), apply_clahe=False):
    """
    Preprocesses a CT image: windowing of the HU values, brain segmentation, resampling
    to a new spacing, cropping/reshaping and conversion to a three channel image.

    When both neighbouring slices are given, each of the three slices is preprocessed
    separately and they are combined (bottom, image, top) into one three channel image.
    Otherwise the single preprocessed slice is converted from grayscale to BGR.

    :param image: the image to preprocess
    :param b_neigh: the bottom neighbouring slice
    :param t_neigh: the top neighbouring slice
    :param dim: a tuple of the desired dimensions of the image
    :param window: a tuple representing the window center and window width
    :param apply_clahe: whether to use CLAHE or not

    :return: the preprocessed image
    """
    def preprocess_slice(ct_slice):
        # the per-slice pipeline shared by both output variants
        pixels = hu_to_pixels(ct_slice.pixel_array, ct_slice.RescaleIntercept, ct_slice.RescaleSlope,
                              window[0], window[1])
        pixels = segment_brain(pixels)
        pixels = apply_new_spacing(pixels, np.array(ct_slice.PixelSpacing), [1, 1])
        pixels = crop_or_reshape(pixels, dim)
        return clahe(pixels) if apply_clahe else pixels

    if b_neigh is None or t_neigh is None:
        # no 3D context: replicate the single slice into three channels
        gray = preprocess_slice(image)
        return cv2.cvtColor(gray.astype('uint8'), cv2.COLOR_GRAY2BGR)

    bottom, middle, top = [preprocess_slice(ct_slice) for ct_slice in (b_neigh, image, t_neigh)]
    return to_3_channels(bottom, middle, top)

In [None]:
class DataGenerator(tensorflow.keras.utils.Sequence):
    """
    A class for training and test data generator.

    Each batch is a tuple (X, y) where X holds three channel images of size `dim` and y holds
    the corresponding labels. Only full batches are produced: samples that do not fill the
    last batch of an epoch are left out.
    """
    def __init__(self, list_IDs, labels, img_path, dataframe=None, batch_size=32, dim=(224,224), n_classes=1, shuffle=True, 
                 window=(40, 80), context3d=True, augment=True, clahe=False, noise=False, image_format='jpg'):
        """
        :param list_IDs: the list of IDs of the data to generate
        :param labels: labels corresponding to the IDs
        :param img_path: a path to the images which will be generated
        :param dataframe: a dataframe with the data IDs, studies and position - required only when the data format is dicom
        :param batch_size: the batch size of the input data
        :param dim: the dimensions of the image data
        :param n_classes: the number of predicted classes
        :param shuffle: whether to randomly shuffle data or not
        :param window: the window to use on dicom images
        :param context3d: whether to use 3D context or not (only applied to dicom images)
        :param augment: whether to augment the data or not
        :param clahe: whether to apply CLAHE on the images (only applied to dicom images)
        :param noise: whether to add noise to the images
        :param image_format: the format of the images - supported are "jpg" and "dicom"

        :raises ValueError: if the image format is not supported
        """
        if image_format not in ("jpg", "dicom"):
            # fail at construction time instead of in the middle of the first batch
            raise ValueError(f'Format "{image_format}" is not supported. Valid formats are "jpg" or "dicom".')
        # lets the Keras base class set up its own state, where it has any
        super().__init__()

        self.dim = dim
        self.batch_size = batch_size
        self.labels = labels
        self.list_IDs = list_IDs
        self.img_path = img_path
        self.dataframe = dataframe
        self.n_classes = n_classes
        self.shuffle = shuffle
        
        self.window = window
        self.augmentation = DataAugmentation(123, max_angle=30)
        self.augment = augment
        self.context3d = context3d
        self.clahe = clahe
        self.noise = noise
        self.image_format = image_format

        if image_format == "dicom" and context3d and dataframe is None:
            # reported once here rather than once per generated sample
            print('If you do not include the dataframe, the 3D context will not be generated.')
        
        self.on_epoch_end()
    
    def on_epoch_end(self):
        """
        Resets the sample order and reshuffles it when shuffling is enabled.
        """
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def _load_sample(self, ID):
        """
        Loads one image in the configured image format.

        :param ID: the ID of the image to load

        :return: the image as read from a jpg file, or the preprocessed dicom image
        :raises FileNotFoundError: if a jpg image could not be read
        :raises ValueError: if the image format is not supported
        """
        if self.image_format == "jpg":
            path = self.img_path + ID + ".jpg"
            sample = cv2.imread(path)
            if sample is None:
                # cv2.imread signals an unreadable file with None instead of raising
                raise FileNotFoundError(f'Image "{path}" could not be read.')
            return sample

        if self.image_format == "dicom":
            sample = pydicom.filereader.dcmread(self.img_path + ID + ".dcm")
            if self.context3d and self.dataframe is not None:
                b_neigh_id, t_neigh_id = find_neighbors_to_image(ID, self.dataframe)
                # the neighbours go through the same dicom preprocessing as the sample,
                # so they have to be read as dicom files as well
                b_neigh = pydicom.filereader.dcmread(self.img_path + b_neigh_id + ".dcm")
                t_neigh = pydicom.filereader.dcmread(self.img_path + t_neigh_id + ".dcm")
                return img_preprocessing(sample, b_neigh, t_neigh, self.dim, self.window, self.clahe)
            return img_preprocessing(sample, None, None, self.dim, self.window, self.clahe)

        # only reachable when image_format was changed after construction
        raise ValueError(f'Format "{self.image_format}" is not supported. Valid formats are "jpg" or "dicom".')
            
    def __data_generation(self, list_IDs_temp, indexes):
        """
        Builds one batch of images and labels.

        :param list_IDs_temp: the image IDs of the batch
        :param indexes: the positions of these IDs in the labels

        :return: the images and the labels of the batch
        """
        # Initialization
        batch_len = len(list_IDs_temp)
        X = np.empty((batch_len, *self.dim, 3))
        y = np.empty((batch_len, self.n_classes), dtype=int)

        # Generate data
        for i, ID in enumerate(list_IDs_temp):
            sample = self._load_sample(ID)
            if self.augment:
                X[i,] = self.augmentation.random_augment(sample, self.noise)
            else:
                X[i,] = sample
            y[i] = self.labels[indexes[i]]
        
        return X, y
    
    def __len__(self):
        """
        Returns the number of full batches per epoch.
        """
        return int(np.floor(len(self.list_IDs) / self.batch_size))
    
    def __getitem__(self, index):
        """
        Returns the batch with the given index.

        :param index: the index of the batch within the epoch

        :return: the images and the labels of the batch
        """
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        list_IDs_temp = [self.list_IDs[k] for k in indexes]
        X, y = self.__data_generation(list_IDs_temp, indexes)

        return X, y

In [None]:
def weighted_multi_label_log_loss(y_true, y_pred):
    """
    Weighted binary log loss over six labels, where the last label ('any') counts twice.

    Predictions are clipped away from 0 and 1 before the logarithm is taken and
    the weighted per-label losses are averaged over the last axis.

    :param y_true: the ground truth labels
    :param y_pred: the predicted labels

    :return: the loss
    """
    backend = tf.keras.backend
    label_weights = [1., 1., 1., 1., 1., 2.]
    eps = backend.epsilon()

    targets = tf.cast(y_true, tf.float32)
    # keep the predictions strictly inside (0, 1) so that both logarithms stay finite
    probs = backend.clip(y_pred, eps, 1.0 - eps)

    positive_term = targets * backend.log(probs) * label_weights
    negative_term = (1.0 - targets) * backend.log(1.0 - probs) * label_weights
    per_label_loss = -(positive_term + negative_term)

    return backend.mean(per_label_loss, axis=-1)

In [None]:
class ComputeAnyLabel(tf.keras.layers.Layer):
    """
    A layer which computes the 'any' label as the max value of other labels.

    The maximum over axis 1 of the input is appended to the input as an additional
    last column, so an input with N columns yields an output with N + 1 columns.
    """
    def __init__(self, num_outputs, **kwargs):
        """
        :param num_outputs: the number of outputs of the layer (stored, not used in the computation)
        :param kwargs: standard layer arguments (e.g. name) forwarded to the base layer
        """
        super(ComputeAnyLabel, self).__init__(**kwargs)
        self.num_outputs = num_outputs

    def build(self, input_shape):
        # the layer has no weights to create
        pass

    def call(self, input_tensor):
        """
        :param input_tensor: the predicted labels, one row per sample

        :return: the input with the row-wise maximum appended as the last column
        """
        # keepdims keeps the reduced axis as size 1, so the batch size is never needed explicitly
        input_max = tf.math.reduce_max(input_tensor, axis=1, keepdims=True)
        return tf.concat([input_tensor, input_max], 1)

    def get_config(self):
        """
        Returns the layer configuration including the constructor argument.
        """
        config = super(ComputeAnyLabel, self).get_config()
        config.update({"num_outputs": self.num_outputs})
        return config

In [None]:
def create_model(base, shape, pooling, optimizer, lr, loss, metrics, weights=None):
    """
    Creates and compiles a CNN model.

    The model consists of the chosen base architecture without its top, followed by a dense
    layer with five sigmoid outputs and a layer which appends the 'any' label as the sixth output.
    
    :param base: the name of the base architecture to use (a DenseNet, EfficientNet, Inception or ResNet variant)
    :param shape: the shape of the input data
    :param pooling: the pooling to use in the model
    :param optimizer: the optimizer to use during the model training - it is called with the learning rate
    :param lr: the learning rate to use during the model training
    :param loss: the loss function to use during the model training
    :param metrics: the metrics to use for the model evaluation
    :param weights: optional weights which are loaded into the assembled model
    
    :return: a compiled model
    :raises ValueError: if the base architecture is not supported
    """
    supported_bases = (
        "DenseNet121", "DenseNet169", "DenseNet201",
        "EfficientNetB0", "EfficientNetB1", "EfficientNetB2",
        "EfficientNetB3", "EfficientNetB4", "EfficientNetB5",
        "InceptionV3", "InceptionResNetV2",
        "ResNet50", "ResNet50V2", "ResNet101", "ResNet101V2", "ResNet152", "ResNet152V2",
    )
    if base not in supported_bases:
        raise ValueError(f'Model {base} is not supported. Please use an another model or modify the create_model() function in model_functions.ipynb.')

    # every supported name equals the name of its constructor in tf.keras.applications,
    # so each architecture resolves to its own constructor
    xbase_model = getattr(tf.keras.applications, base)

    base_model = xbase_model(input_shape=shape, pooling=pooling, include_top=False)
    subtype_outputs = tf.keras.layers.Dense(5, activation="sigmoid")(base_model.output)
    model = tf.keras.models.Model(inputs=base_model.input, 
                                  outputs=ComputeAnyLabel(6)(subtype_outputs))
    if weights is not None:
        model.load_weights(weights)
    model.compile(optimizer=optimizer(lr), loss=loss, metrics=metrics)
    return model