In [None]:
# `IPython.core.display` is a deprecated import location for HTML; the public
# entry point is `IPython.display`.
from IPython.display import HTML

# CSS for the teal, centred banner used by the `section-heading` <h2> elements
# in the markdown cells of this notebook.
styles = """
<style>
.section-heading { 
  background:#008080; 
  border:0; 
  color:white; 
  text-align:center; 
  height: 100px; 
  display: flex;  
  justify-content: center;
  align-items: center;
}
</style>
"""
# Last expression of the cell: rendering the object injects the <style> element.
HTML(styles)

<a id="top"></a>

<h2 class="section-heading">
    <span>
        Quick Navigation
    </span>
</h2>

* [Overview](#overview)
* [Data Visualization](#data_viz)
    

* [Competition Metric](#10)
* [Sample Submission](#20)
    

* [Modeling](#modeling)

<a id="overview"></a>

<h2 class="section-heading">
    <span>
      Overview
    </span>
</h2>

* TODO

In [None]:
import pathlib
import pandas as pd
from utils import competition_name, path
import matplotlib.pyplot as plt
import seaborn as sns
import pydicom
import numpy as np
import random
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score
import cv2

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import (InputLayer, Conv3D, MaxPool3D, Dropout, Flatten, 
                                     BatchNormalization, GlobalAveragePooling3D, Dense)
from tensorflow.keras.layers.experimental.preprocessing import Rescaling
from tensorflow.keras.metrics import AUC

# Apply one seaborn theme up front so that every figure drawn later shares it.
sns.set_theme(style="whitegrid", palette="pastel")
# IPython magic: draw matplotlib figures inline in the cell output.
%matplotlib inline

<a id="extraction"></a>

<h2 class="section-heading">
    <span>
      Extraction
    </span>
</h2>

In [None]:
# Folder that holds one sub-directory per training patient.
train_path = path / 'train'

# Read the IDs as strings so that zero-padded identifiers (e.g. '00109') are
# kept verbatim, then use them as the index.
labels = pd.read_csv(
    path / 'train_labels.csv',
    dtype={'BraTS21ID': str},
).set_index('BraTS21ID')

labels.head()

**NOTE**: There are some unexpected issues with the following three cases in the training dataset; participants can exclude these cases during training: [00109, 00123, 00709]. We have checked and confirmed that the testing dataset is free from such issues.

In [None]:
# The three case IDs listed in the note above.
exclusions = ['00109', '00123', '00709']
# Boolean array aligned with `labels`: True where the row's ID is in `exclusions`.
mask = labels.index.isin(exclusions)
# Last expression of the cell: display the rows selected by the mask.
labels.loc[mask, ]

In [None]:
# Recompute the exclusion mask from the *current* `labels` rather than reusing
# `mask` from the previous cell: once this cell has run, `labels` is shorter
# than that mask, so running the cell a second time would fail on the length
# mismatch. With a fresh mask the cell is idempotent.
print(f"Label Count Pre-Removal:  {len(labels):4d}")
labels = labels.loc[~labels.index.isin(exclusions)]
print(f"Label Count Post-Removal: {len(labels):4d}")

Let's split our training and validation sets first, using just the patient IDs (the index of `labels`).

`DataGenerator` inspiration from [this blog](https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly) from Stanford.

In [None]:
class MRILoader:
    """Load DICOM slices of one MRI sequence type and build fixed-depth volumes.

    Files are expected at ``PATH / split / patient_id / mri_type / Image-<n>.dcm``
    (the layout produced by :meth:`create_path`).

    Parameters
    ----------
    mri_type : str
        Name of the sequence sub-folder, e.g. one of ``MRI_TYPES``.
    num_images : int, default 64
        Number of slices along the last spatial axis of a returned volume.
    image_size : int, default 256
        Every slice is resized to ``image_size x image_size``.
    """
    MRI_TYPES = ("FLAIR", "T1w", "T1wCE", "T2w")
    ROT_CHOICES = [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]
    PATH = path

    def __init__(self, mri_type, num_images=64, image_size=256):
        self.mri_type = mri_type
        self.image_size = image_size
        self.num_images = num_images

    @staticmethod
    def _slice_number(dicom_file):
        """Return the integer ``n`` of a file named ``<prefix>-<n>.dcm``."""
        return int(dicom_file.stem.split('-')[-1])

    def create_path(self, patient_id, image_num, split='train'):
        """Return the path of slice ``image_num`` for ``patient_id`` in ``split``."""
        return self.PATH / split / patient_id / self.mri_type / f'Image-{image_num}.dcm'

    def load_dicom_image(self, path, voi_lut=True, rotation=None):
        """Read a single DICOM file and return it as a square 2-D array.

        Parameters
        ----------
        path : path-like
            Location of the ``.dcm`` file.
        voi_lut : bool, default True
            Apply the dataset's VOI LUT / windowing to the pixel data.
        rotation : int or None
            A ``cv2.ROTATE_*`` code (see ``ROT_CHOICES``); ``None`` skips rotation.

        Returns
        -------
        numpy.ndarray of shape ``(image_size, image_size)``.
        """
        # Read the file once and reuse the dataset for both the pixels and the
        # VOI LUT (previously the file was parsed twice). `dcmread` is the
        # current name of the deprecated `read_file`.
        dicom = pydicom.dcmread(path)
        image = dicom.pixel_array

        if voi_lut:
            image = pydicom.pixel_data_handlers.util.apply_voi_lut(image, dicom)

        if rotation is not None:
            image = cv2.rotate(image, rotation)

        image = cv2.resize(image, (self.image_size, self.image_size))

        return image

    def load_dicom_images_3d(self, patient_id, split="train", rotation=None):
        """Return a min-max scaled volume built from the central slices of a scan.

        The result has shape ``(image_size, image_size, num_images, 1)``. Scans
        with fewer than ``num_images`` selected slices are zero-padded at the
        end of the slice axis.

        Raises
        ------
        ValueError
            If the folder holds no ``.dcm`` files (nothing to stack) or a file
            name does not end in ``-<integer>``.
        """
        # Use the class-level root (same object as the module-level `path`
        # unless PATH is overridden), consistent with `create_path`.
        filepath = self.PATH / split / patient_id / self.mri_type
        # Sort numerically: a plain string sort would put "Image-10" before
        # "Image-2" and scramble the slice order.
        dicom_filenames = sorted(filepath.glob("*.dcm"), key=self._slice_number)
        s = self.images_selector(len(dicom_filenames))

        # np.stack yields (slice, row, col); `.T` reverses all axes so the slice
        # axis comes last. NOTE(review): this also transposes each slice
        # in-plane — kept as in the original, confirm it is intended.
        scan_images = np.stack([self.load_dicom_image(f, rotation=rotation) for f in dicom_filenames[s]]).T

        if scan_images.shape[-1] < self.num_images:
            cnt_of_images_to_add = self.num_images - scan_images.shape[-1]
            n_zero = np.zeros((self.image_size, self.image_size, cnt_of_images_to_add))
            scan_images = np.concatenate((scan_images, n_zero), axis=-1)

        # Scale to [0, 1]; skipped for a constant volume to avoid dividing by zero.
        if np.min(scan_images) < np.max(scan_images):
            scan_images = scan_images - np.min(scan_images)
            scan_images = scan_images / np.max(scan_images)

        # Append the channel axis.
        return np.expand_dims(scan_images, -1)

    def images_selector(self, num_files):
        """Return a ``slice`` picking up to ``num_images`` files around the middle.

        The window starts ``num_images // 2`` before the middle index and spans
        ``num_images`` entries, clipped to ``[0, num_files]``. Anchoring the end
        to the start (instead of ``middle + num_images // 2``) gives the same
        window for even ``num_images`` and no longer loses one slice when
        ``num_images`` is odd.
        """
        middle = num_files // 2
        p1 = max(0, middle - self.num_images // 2)
        p2 = min(num_files, p1 + self.num_images)
        return slice(p1, p2)
    

# One ready-made loader (default sizes) per MRI sequence type, keyed by type name.
loaders = {
    sequence_name: MRILoader(sequence_name)
    for sequence_name in MRILoader.MRI_TYPES
}

In [None]:
# Quick manual check of the loader: read one slice, then assemble one volume.
ldr = MRILoader('FLAIR')
dicom_path = ldr.create_path('00675', 90)
# NOTE(review): the returned image is discarded, so this line only verifies
# that the read does not raise.
ldr.load_dicom_image(dicom_path)
image = ldr.load_dicom_images_3d("00100")
# Last expression of the cell: shape of the assembled volume.
image.shape

<a id="data_generator"></a>

<h2 class="section-heading">
    <span>
      Data Generator
    </span>
</h2>

The `DataGenerator` below is inspired by [this blog post](https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly) from Stanford.

In [None]:
class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that yields batches of MRI volumes with their labels.

    Parameters
    ----------
    mri_type : str
        MRI sequence type passed on to ``MRILoader``.
    brats21ids : sequence of str
        Patient IDs this generator iterates over (indexed positionally).
    labels : pandas.DataFrame
        Indexed by patient ID, with an ``MGMT_value`` column.
    batch_size : int, default 4
    dim : tuple, default (128, 128, 64)
        Spatial shape of one volume. ``dim[0]`` is used as the loader's image
        size and ``dim[2]`` as its slice count.
        NOTE(review): the loader produces square slices, so ``dim[0] == dim[1]``
        is assumed.
    n_channels : int, default 1
    shuffle : bool, default True
        Reshuffle the sample order at construction and after every epoch.
    """

    def __init__(self, mri_type, brats21ids, labels, batch_size=4, dim=(128, 128, 64), n_channels=1, shuffle=True):
        # Let the Keras base class initialise itself (newer Keras versions
        # expect this; it is a no-op on versions without a base __init__).
        super().__init__()
        self.dim = dim
        self.batch_size = batch_size
        self.labels = labels
        self.brats21ids = brats21ids
        self.n_channels = n_channels
        self.shuffle = shuffle

        self.mri_type = mri_type
        # Forward the depth as well, so a non-default `dim[2]` no longer
        # produces volumes that do not fit the batch buffer.
        self.loader = MRILoader(mri_type, num_images=dim[2], image_size=dim[0])

        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch.

        Rounded up so that the trailing samples form a final, smaller batch
        instead of being silently dropped.
        """
        return int(np.ceil(len(self.brats21ids) / self.batch_size))

    def __getitem__(self, index):
        """Return batch number ``index`` as ``(X, y)``."""
        start, stop = index * self.batch_size, (index + 1) * self.batch_size
        indexes = self.indexes[start:stop]

        # Map the (possibly shuffled) positions back to patient IDs.
        batch_brats21ids = [self.brats21ids[i] for i in indexes]

        return self.__generate_data(batch_brats21ids)

    def on_epoch_end(self):
        """Reset the sample order and reshuffle it when ``shuffle`` is set."""
        self.indexes = np.arange(len(self.brats21ids))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __generate_data(self, batch_brats21ids):
        """Load the volumes and labels for the given patient IDs.

        Returns ``X`` of shape ``(n, *dim, n_channels)`` and ``y`` of shape
        ``(n, 1)``, where ``n`` is the number of IDs actually supplied.
        """
        batch_brats21ids = list(batch_brats21ids)
        # Size the buffers by the real batch length: with `batch_size` rows a
        # short final batch would leave uninitialised `np.empty` rows behind.
        n_samples = len(batch_brats21ids)
        X = np.empty((n_samples, *self.dim, self.n_channels))
        y = np.empty((n_samples, 1), dtype=int)

        for i, brats21id in enumerate(batch_brats21ids):
            X[i,] = self.loader.load_dicom_images_3d(brats21id)
            y[i] = self.labels.loc[brats21id, 'MGMT_value']

        return X, y

In [None]:
# Split the patient IDs only (no images are read here): 18% go to validation
# and the seed is fixed.
# NOTE(review): the split is not stratified on MGMT_value — confirm that is intended.
brats21ids_train, brats21ids_valid = train_test_split(labels.index, test_size=0.18, random_state=11)
# Report both split sizes in two aligned columns.
print(f'{"Training Size:":20}{len(brats21ids_train): >6d}\n'
      f'{"Validation Size:":20}{len(brats21ids_valid): >6d}')

<a id="data_viz"></a>

<h2 class="section-heading">
    <span>
    Data Visualization
    </span>
</h2>

In [None]:
# Show the middle slice of each MRI sequence for one randomly drawn patient.
# (The commented-out `random.seed(11)` was dropped: `labels.sample` does not
# draw from the stdlib `random` module, so it could not pin the choice.)
brats21id = labels.sample(1).index[0]
patient_path = path / 'train' / brats21id
# The label is per patient, so look it up once instead of once per panel.
outcome = labels.loc[brats21id, 'MGMT_value']

fig, axes = plt.subplots(1, 4, figsize=(16, 5))

for ax, mri_type in zip(axes, MRILoader.MRI_TYPES):
    ldr = loaders[mri_type]
    mri_type_path = patient_path / mri_type
    # Sort by the numeric suffix: a string sort puts "Image-10" before
    # "Image-2", so the "middle" entry would not be the middle slice.
    image_paths = sorted(mri_type_path.glob('*.dcm'), key=lambda x: int(x.stem.split('-')[-1]))

    image_path = image_paths[len(image_paths) // 2]
    data = ldr.load_dicom_image(image_path)

    ax.imshow(data, cmap="gray")
    ax.set_title(mri_type, fontsize=16)
    ax.axis("off")

fig.suptitle(f'Patient #{brats21id}: {outcome}', size=24)
plt.tight_layout()
plt.show()

In [None]:
# Number of labelled patients, shown in the title.
training_size = labels.shape[0]

# Share of each class as a percentage string, ordered by class value.
bar_labels = [
    '{:0.1%}'.format(share)
    for share in labels.MGMT_value.value_counts(normalize=True).sort_index()
]

ax = sns.countplot(data=labels, x='MGMT_value')
sns.despine(right=False)

ax.set_title(f'MGMT Distribution ({training_size:0.0f} records)', size=16)
ax.set_xticklabels(('Not Present', 'Present'))
ax.set_xlabel('')
ax.set_ylabel('')

# Write each percentage just above its bar.
for bar, pct_text in zip(ax.patches, bar_labels):
    anchor = (bar.get_x() + 0.375, bar.get_height() + 0.15)
    ax.annotate(pct_text, anchor)

plt.show()

In [None]:
patient_dirs = train_path.glob('*')

scan_sizes = dict()

# patient folder name -> {MRI type -> number of .dcm files in that sub-folder}
records = {
    patient_dir.name: {
        mri_type: sum(1 for _ in (patient_dir / mri_type).glob('*.dcm'))
        for mri_type in MRILoader.MRI_TYPES
    }
    for patient_dir in patient_dirs
}

In [None]:
# Transpose so that each row is a patient and each column an MRI type.
df_mri_image_counts = pd.DataFrame(records).T

# One box per MRI type showing the spread of per-patient file counts.
ax = sns.boxplot(data=df_mri_image_counts)
ax.set_title('Distribution of MRI Scan Image Counts')

# Last expression of the cell: summary statistics formatted as whole numbers.
df_mri_image_counts.describe().applymap('{:,.0f}'.format)

In [None]:
# Patients whose FLAIR folder holds fewer than 60 files.
df_mri_image_counts[df_mri_image_counts['FLAIR'] < 60]

In [None]:
# Histogram of the per-patient FLAIR file counts.
sns.histplot(data=df_mri_image_counts, x='FLAIR')
plt.show()

# Last expression of the cell: the five most frequent FLAIR file counts.
df_mri_image_counts.loc[:, 'FLAIR'].value_counts().nlargest(5)

In [None]:
# Peek at five random patients (unseeded, so the rows differ between runs).
df_mri_image_counts.sample(5)

**NOTE**: There are some unexpected issues with the following three cases in the training dataset; participants can exclude these cases during training: [00109, 00123, 00709]. We have checked and confirmed that the testing dataset is free from such issues.

<a id="modeling"></a>

<h2 class="section-heading">
    <span>
        Modeling
    </span>
</h2>

In [None]:
def create_model(image_size, num_images=64):
    """Build the 3-D CNN binary classifier (uncompiled).

    Parameters
    ----------
    image_size : int
        Height and width of each input slice.
    num_images : int, default 64
        Number of slices per volume. Previously hard-coded to 64; exposed so
        the network can follow the depth used by the data generator.

    Returns
    -------
    keras.Model
        Maps inputs of shape ``(image_size, image_size, num_images, 1)`` to a
        single sigmoid output.
    """
    inputs = keras.Input((image_size, image_size, num_images, 1))

    # Four identical stages (conv -> max-pool -> batch-norm) that differ only
    # in the number of filters.
    X = inputs
    for n_filters in (64, 128, 256, 512):
        X = Conv3D(n_filters, kernel_size=3, activation='relu')(X)
        X = MaxPool3D(pool_size=2)(X)
        X = BatchNormalization()(X)

    # Classification head.
    X = GlobalAveragePooling3D()(X)
    X = Dense(units=1024, activation='relu')(X)
    X = Dropout(0.3)(X)

    outputs = Dense(units=1, activation='sigmoid')(X)

    return Model(inputs=inputs, outputs=outputs)

In [None]:
# Build the network for 128x128 slices and print its layer summary.
model = create_model(128)
model.summary()

#### Compile Model

In [None]:
# Adam driven by a staircase exponential learning-rate schedule.
initial_learning_rate = 0.0001
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=100000,
    decay_rate=0.96,
    staircase=True,
)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

# Binary classification: cross-entropy loss, tracking accuracy and AUC.
model.compile(
    optimizer=optimizer,
    loss='binary_crossentropy',
    metrics=['accuracy', AUC(name='auc')],
)

#### Callbacks

In [None]:
# Keep only the best model seen so far (ModelCheckpoint's default monitor).
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    "3d_image_classification.h5", save_best_only=True
)
# Stop on the *validation* AUC rather than the training AUC, and state the
# direction explicitly: AUC is better when higher, and the automatic mode
# inference is not guaranteed to recognise the metric name on every version.
early_stopping_cb = tf.keras.callbacks.EarlyStopping(monitor="val_auc", mode="max", patience=15)

#### Data Generators

In [None]:
# Sequence type and batch/volume sizes shared by both generators.
mri_type = 'FLAIR'
sizes = dict(
    batch_size = 2,
    dim = (128, 128, 64)
)
training_set_generator = DataGenerator(mri_type, brats21ids_train, labels, **sizes)
# Validation data gains nothing from shuffling; keeping it in ID order makes
# its batches line up with `brats21ids_valid`.
validation_set_generator = DataGenerator(mri_type, brats21ids_valid, labels, shuffle=False, **sizes)

#### Run Models

In [None]:
# Train for 5 epochs, validating after each one; the callbacks handle
# checkpointing and early stopping.
model.fit(training_set_generator,
          validation_data=validation_set_generator,
          # use_multiprocessing=True, # can't use these as there is a multiprocessing ipython issue :\
          # workers=6,
          epochs=5, 
          verbose=1,
          shuffle=True,
          callbacks=[early_stopping_cb, checkpoint_cb])

In [None]:
# `model.predict()` needs input data; calling it with no arguments raises.
# Predict from a dedicated generator that is unshuffled and yields one patient
# per batch, so there is exactly one prediction per validation ID, in the same
# order as `brats21ids_valid`.
prediction_generator = DataGenerator(mri_type, brats21ids_valid, labels,
                                     batch_size=1, dim=sizes['dim'], shuffle=False)
yhat_valid_probs = model.predict(prediction_generator)

In [None]:
# The original cell referenced names that are never defined in the notebook
# (`brats21id_idx_valid`, `yhat_valid`, `training_labels`); use the validation
# IDs, the predicted probabilities and the label frame that do exist.
yhat_valid = np.ravel(yhat_valid_probs)
assert len(yhat_valid) == len(brats21ids_valid), (
    "expected exactly one prediction per validation patient"
)

frames = [pd.Series(np.asarray(brats21ids_valid)), pd.Series(yhat_valid)]
# Average per patient (a no-op with one prediction each), then attach the true
# label. The earlier rename of 'MGMT_value' was dropped: that column does not
# exist before the merge and later cells read 'MGMT_value_pred'.
results = (pd.concat(frames, axis=1, keys=['brats21id', 'MGMT_value_pred'])
           .groupby('brats21id', as_index=False).mean()
           .merge(labels, left_on='brats21id', right_index=True, how='left')
          )

In [None]:
# Inspect the column names of the merged results frame.
results.columns

In [None]:
# ROC AUC of the predicted values against the true MGMT labels.
roc_auc_score(results.MGMT_value, results.MGMT_value_pred)