In [None]:
# Inject the CSS used by the <h2 class="section-heading"> banners in the markdown cells.
# HTML is imported from IPython.display, the documented public module for display helpers.
from IPython.display import HTML

styles = """
<style>
.section-heading { 
  background:#008080; 
  border:0; 
  color:white; 
  text-align:center; 
  height: 100px; 
  display: flex;  
  justify-content: center;
  align-items: center;
}
</style>
"""
# Last expression of the cell: rendering the HTML object applies the stylesheet.
HTML(styles)

<a id="top"></a>

<h2 class="section-heading">
    <span>
        Quick Navigation
    </span>
</h2>

* [Overview](#overview)
* [Extraction](#extraction)
* [Data Visualization](#data_viz)
    

* [Competition Metric](#10)
* [Sample Submission](#20)
    

* [Modeling](#modeling)

<a id="overview"></a>

<h2 class="section-heading">
    <span>
      Overview
    </span>
</h2>

* TODO

In [None]:
import pathlib
import pandas as pd
from utils import competition_name, path
import matplotlib.pyplot as plt
import seaborn as sns
import pydicom
import numpy as np
import random
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score
import cv2

import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.layers import (InputLayer, Conv3D, MaxPool3D, Dropout, Flatten, 
                                     BatchNormalization, GlobalAveragePooling3D, Dense)
from tensorflow.keras.layers.experimental.preprocessing import Rescaling
from tensorflow.keras.metrics import AUC

sns.set_theme(style="whitegrid", palette="pastel")
%matplotlib inline

<a id="extraction"></a>

<h2 class="section-heading">
    <span>
      Extraction
    </span>
</h2>

In [None]:
# List the entries directly under `path` (imported from `utils`).
list(path.glob('*'))

In [None]:
# Entries under <path>/train; their stems are used as BraTS21IDs further down.
train_path = path / 'train'
patient_paths = [entry for entry in train_path.glob('*')]
# Show the first five entries together with the total count.
patient_paths[:5], len(patient_paths)

In [None]:
# BraTS21ID is read as a string so zero-padded IDs such as "00109" keep their padding,
# then used as the index for label lookups by patient.
training_labels = pd.read_csv(
    path / 'train_labels.csv',
    dtype={'BraTS21ID': str},
).set_index('BraTS21ID')
training_labels.head()

**NOTE**: There are some unexpected issues with the following three cases in the training dataset, participants can exclude the cases during training: [00109, 00123, 00709]. We have checked and confirmed that the testing dataset is free from such issues.

In [None]:
# IDs listed in the note above as problematic training cases.
exclusions = ['00109', '00123', '00709']
# Boolean array marking which label rows belong to those cases.
mask = training_labels.index.isin(exclusions)
# Display the rows that are about to be dropped.
training_labels.loc[mask]

In [None]:
# Drop the excluded cases. The filter is recomputed from the current index instead of
# reusing `mask`: `mask` has the length of the unfiltered table, so re-running this cell
# with it would fail once the rows are gone. Recomputing makes the cell idempotent.
training_labels = training_labels.loc[~training_labels.index.isin(exclusions)]
training_labels.shape

Let's split our training and test set first using just the indexes.

In [None]:
class Loader:
    """Load DICOM files of one MRI sequence type as 2-D images or a stacked 3-D volume.

    Files are addressed as ``PATH / split / patient_id / mri_type / Image-<n>.dcm``.

    Args:
        mri_type: Name of the sequence sub-directory (e.g. one of ``MRI_TYPES``).
        num_images: Depth (number of slices) of the volume built by
            ``load_dicom_images_3d``.
        image_size: Width and height every slice is resized to.
    """
    MRI_TYPES = ("FLAIR", "T1w", "T1wCE", "T2w")
    ROT_CHOICES = [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]
    PATH = path

    def __init__(self, mri_type, num_images=64, image_size=128):
        self.mri_type = mri_type
        self.image_size = image_size
        self.num_images = num_images

    def create_path(self, patient_id, image_num, split='train'):
        """Return the path of slice ``image_num`` for ``patient_id`` in ``split``."""
        return self.PATH / split / patient_id / self.mri_type / f'Image-{image_num}.dcm'

    @staticmethod
    def _slice_sort_key(dicom_path):
        """Sort key that orders ``Image-<n>`` files by the numeric value of ``<n>``.

        A plain string sort would put "Image-10" before "Image-2". Stems whose
        suffix is not purely digits are placed after the numeric ones, ordered
        by that suffix as a string.
        """
        suffix = dicom_path.stem.split('-')[-1]
        if suffix.isdigit():
            return (0, int(suffix), '')
        return (1, 0, suffix)

    def load_dicom_image(self, path, voi_lut=True, rotation=None):
        """Read one DICOM file and return it as an ``image_size`` x ``image_size`` array.

        Args:
            path: Location of the ``.dcm`` file.
            voi_lut: When true, pass the pixels through pydicom's ``apply_voi_lut``.
            rotation: Optional ``cv2.rotate`` code (see ``ROT_CHOICES``); ``None``
                leaves the orientation untouched.

        Returns:
            The resized 2-D pixel array.
        """
        # Read the file once and reuse the dataset for both the pixels and the LUT.
        dicom = pydicom.dcmread(path)
        image = dicom.pixel_array
        if voi_lut:
            image = pydicom.pixel_data_handlers.util.apply_voi_lut(image, dicom)

        if rotation is not None:
            image = cv2.rotate(image, rotation)

        return cv2.resize(image, (self.image_size, self.image_size))

    def load_dicom_images_3d(self, patient_id, split="train", rotation=None):
        """Build a min-max scaled volume from the central slices of one patient.

        Args:
            patient_id: Patient directory name (BraTS21ID string).
            split: Sub-directory of ``PATH`` to read from.
            rotation: Optional ``cv2.rotate`` code applied to every slice.

        Returns:
            Array of shape ``(image_size, image_size, num_images, 1)``. Volumes with
            fewer than ``num_images`` slices are zero-padded along the slice axis.
            Values are scaled to [0, 1] unless the volume is constant.

        Raises:
            ValueError: If the patient has no ``.dcm`` files for this sequence.
        """
        filepath = self.PATH / split / patient_id / self.mri_type
        dicom_filenames = sorted(filepath.glob("*.dcm"), key=self._slice_sort_key)
        if not dicom_filenames:
            raise ValueError(f"No .dcm files found in {filepath}")
        selected = dicom_filenames[self.images_selector(len(dicom_filenames))]

        # np.stack gives (n, rows, cols); .T reverses every axis, so the result is
        # (cols, rows, n): the slice axis moves last and each slice is transposed.
        scan_images = np.stack(
            [self.load_dicom_image(f, rotation=rotation) for f in selected]
        ).T

        if scan_images.shape[-1] < self.num_images:
            cnt_of_images_to_add = self.num_images - scan_images.shape[-1]
            n_zero = np.zeros((self.image_size, self.image_size, cnt_of_images_to_add))
            scan_images = np.concatenate((scan_images, n_zero), axis=-1)

        # Min-max scale; skipped for a constant volume to avoid dividing by zero.
        if np.min(scan_images) < np.max(scan_images):
            scan_images = scan_images - np.min(scan_images)
            scan_images = scan_images / np.max(scan_images)

        return np.expand_dims(scan_images, -1)

    def images_selector(self, num_files):
        """Return a slice selecting up to ``num_images`` files centred on the middle one.

        The window is ``num_images`` wide for odd values too; for even values it is
        the same window as ``middle +/- num_images // 2``.
        """
        middle = num_files // 2
        p1 = max(0, middle - self.num_images // 2)
        p2 = min(num_files, p1 + self.num_images)
        return slice(p1, p2)


# One loader per MRI sequence type, all with the default volume size.
loaders = {mri_type: Loader(mri_type) for mri_type in Loader.MRI_TYPES}

In [None]:
# Smoke-test the loader: read a single FLAIR slice, then build a full volume.
ldr = Loader('FLAIR')
dicom_path = ldr.create_path('00675', 90)
ldr.load_dicom_image(dicom_path)  # return value is discarded; this only checks the read works
image = ldr.load_dicom_images_3d("00100")
# Displayed as the cell output.
image.shape

In [None]:
# random.seed(11)


def _slice_order(dicom_path):
    """Sort key ordering ``Image-<n>`` files by the number ``<n>``, not by string."""
    suffix = dicom_path.stem.split('-')[-1]
    if suffix.isdigit():
        return (0, int(suffix), '')
    return (1, 0, suffix)


# Only patients still present in `training_labels` can be previewed: the excluded
# cases were dropped from the label table, so looking them up would raise KeyError.
labelled_patient_paths = [
    candidate for candidate in patient_paths if candidate.stem in training_labels.index
]
patient_path = random.choice(labelled_patient_paths)

brats21id = patient_path.stem
# The label is per patient, so it is looked up once rather than once per sequence.
outcome = training_labels.loc[brats21id, 'MGMT_value']

fig, axes = plt.subplots(1, len(Loader.MRI_TYPES), figsize=(16, 5))

for ax, mri_type in zip(axes, Loader.MRI_TYPES):
    ldr = loaders[mri_type]
    mri_type_path = patient_path / mri_type
    # Numeric ordering makes the middle list element the middle slice number.
    image_paths = sorted(mri_type_path.glob('*'), key=_slice_order)

    image_path = image_paths[len(image_paths) // 2]
    data = ldr.load_dicom_image(image_path)

    ax.imshow(data, cmap="gray")
    ax.set_title(mri_type, fontsize=16)
    ax.axis("off")

fig.suptitle(f'Patient #{brats21id}: {outcome}', size=24)
plt.tight_layout()
plt.show()

<a id="data_viz"></a>

<h2 class="section-heading">
    <span>
    Data Visualization
    </span>
</h2>

In [None]:
# Class balance of MGMT_value; each bar is annotated with its share of all records.
training_size = training_labels.shape[0]
class_shares = training_labels.MGMT_value.value_counts(normalize=True).sort_index()
bar_labels = ['{:0.1%}'.format(share) for share in class_shares]

ax = sns.countplot(data=training_labels, x='MGMT_value')
sns.despine(right=False)

ax.set_title('MGMT Distribution ({:0.0f} records)'.format(training_size), size=16)
ax.set_xticklabels(('Not Present', 'Present'))
ax.set(xlabel='', ylabel='')

# Place each percentage just above its bar; the offsets are in data coordinates.
for bar, label in zip(ax.patches, bar_labels):
    anchor = (bar.get_x() + 0.375, bar.get_height() + 0.15)
    ax.annotate(label, anchor)

plt.show()

In [None]:
# Peek at the first entry yielded by the glob over `train_path` and show its name.
patient_dirs = train_path.glob('*')
patient_dir = next(patient_dirs)
patient_dir.name

In [None]:
# Count the .dcm files available for every patient directory and MRI sequence type.
patient_dirs = train_path.glob('*')

scan_sizes = {}

# Maps directory name -> {mri_type: number of .dcm files}.
records = {}

for patient_dir in patient_dirs:
    records[patient_dir.name] = {
        mri_type: len(list((patient_dir / mri_type).glob('*.dcm')))
        for mri_type in Loader.MRI_TYPES
    }

In [None]:
# Transpose so each row is a patient directory and each column an MRI sequence type.
df_mri_image_counts = pd.DataFrame(records).T

ax = sns.boxplot(data=df_mri_image_counts)
ax.set_title('Distribution of MRI Scan Image Counts')

# Format every summary statistic as a whole number with thousands separators.
# Column-wise Series.map gives the same table as DataFrame.applymap, which recent
# pandas releases deprecate, and it also works on older releases.
df_mri_image_counts.describe().apply(lambda col: col.map('{:,.0f}'.format))

In [None]:
# Rows whose FLAIR directory holds fewer than 60 .dcm files.
df_mri_image_counts.loc[df_mri_image_counts.FLAIR < 60, ]

In [None]:
# Number of rows for each distinct FLAIR file count.
sns.countplot(data=df_mri_image_counts, x='FLAIR')

In [None]:
# Five randomly drawn rows of the per-patient file counts.
df_mri_image_counts.sample(5)

**NOTE**: There are some unexpected issues with the following three cases in the training dataset, participants can exclude the cases during training: [00109, 00123, 00709]. We have checked and confirmed that the testing dataset is free from such issues.

<a id="modeling"></a>

<h2 class="section-heading">
    <span>
        Modeling
    </span>
</h2>

In [None]:
class Patient:
    """A labelled training case, identified by its BraTS21ID.

    Args:
        brats21id: Patient id as an int or string; it is converted to a string
            and zero-padded to five characters (``675`` -> ``"00675"``).

    Attributes:
        brats21id: The zero-padded id string.
        outcome: The ``MGMT_value`` stored for this id in ``TRAINING_LABELS``.
        patient_train_path: ``TRAIN_PATH / brats21id``.

    Raises:
        KeyError: If the id is not present in ``TRAINING_LABELS``.
    """
    MRI_TYPES = ("FLAIR", "T1w", "T1wCE", "T2w")
    TRAINING_LABELS = training_labels
    PATH = path
    TRAIN_PATH = PATH / 'train'

    def __init__(self, brats21id):
        self.brats21id = str(brats21id).zfill(5)
        # Read from the class attribute instead of the module-level global, so the
        # class uses the table it declares.
        self.outcome = self.TRAINING_LABELS.loc[self.brats21id, 'MGMT_value']
        self.patient_train_path = self.TRAIN_PATH / self.brats21id

    def __repr__(self):
        return f"Patient(brats21id={self.brats21id})"

    # A disabled `load_mri_images` draft used to sit here inside a string literal.
    # It called an undefined `load_dicom`; slice loading is done by `Loader`.

In [None]:
# Sanity checks for Patient: an integer id is zero-padded to five digits in the repr,
# and the label stored for patient 00675 is 1.
p = Patient(675)
assert str(p) == 'Patient(brats21id=00675)'
assert p.outcome == 1
# Disabled checks for the image-loading method that is not currently part of Patient:
# images = p.load_mri_images('FLAIR')
# assert len(images) == 196

In [None]:
brats21ids_idx = training_labels.index.tolist()
training_cnt = len(brats21ids_idx)
mri_type = 'FLAIR'
ldr = loaders[mri_type]

# Volume shape comes from the loader so the buffer always matches what it returns.
# float32 halves the memory of the default float64; Keras trains in float32 by default.
X = np.zeros((training_cnt, ldr.image_size, ldr.image_size, ldr.num_images, 1), dtype=np.float32)
y = np.zeros((training_cnt, 1))
# Patient ids aligned row-for-row with X and y, so they can be split together with
# them and joined back to the labels after prediction.
brats21ids = np.array(brats21ids_idx)

for i, brats21id in enumerate(tqdm(brats21ids_idx)):
    p = Patient(brats21id)
    X[i] = ldr.load_dicom_images_3d(brats21id)
    y[i] = p.outcome

In [None]:
# Hold out 20% of the rows for validation. Volumes, labels and ids are passed in one
# call so they are split together.
splits = train_test_split(
    X, y, brats21ids,
    test_size=0.2,
    random_state=11,
)
(X_train, X_valid,
 y_train, y_valid,
 brats21ids_train, brats21ids_valid) = splits

In [None]:
# 3-D CNN binary classifier: four Conv3D -> MaxPool3D -> BatchNormalization stages
# (64, 64, 128 and 256 filters), global average pooling, a 512-unit dense layer with
# dropout, and a single sigmoid output unit.
model = tf.keras.Sequential(layers=[
    # Same per-sample shape as the volumes allocated for X: (128, 128, 64, 1).
    InputLayer(input_shape=(128, 128, 64, 1)),
    # Rescaling is left disabled; Loader.load_dicom_images_3d already min-max scales.
    # Rescaling(1.0/255, name='rescaling_1'),
    
    Conv3D(64, kernel_size=3, activation='relu', name='conv3d_1'),
    MaxPool3D(pool_size=2, name='max_pooling3d_1'),
    BatchNormalization(),
    
    Conv3D(64, kernel_size=3, activation='relu', name='conv3d_2'),
    MaxPool3D(pool_size=2, name='max_pooling3d_2'),
    BatchNormalization(),
    
    Conv3D(128, kernel_size=3, activation='relu', name='conv3d_3'),
    MaxPool3D(pool_size=2, name='max_pooling3d_3'),
    BatchNormalization(),
    
    Conv3D(256, kernel_size=3, activation='relu', name='conv3d_4'),
    MaxPool3D(pool_size=2, name='max_pooling3d_4'),
    BatchNormalization(),
    
    # Collapse the remaining spatial dimensions into one value per channel.
    GlobalAveragePooling3D(),
    Dense(units=512, activation='relu'),
    Dropout(0.3),
    
    # One sigmoid unit: the output is a single value in (0, 1) per volume.
    Dense(units=1, activation='sigmoid')    
])

In [None]:
# Print the layer-by-layer summary of the model defined above.
model.summary()

In [None]:
# IPython `??` introspection of model.fit; an interactive development aid that the
# rest of the notebook does not depend on.
??model.fit

In [None]:
initial_learning_rate = 0.0001
# With staircase=True the rate is multiplied by 0.96 only after every 100000
# optimizer steps; below that step count it stays at the initial value.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=100000, decay_rate=0.96, staircase=True
)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
# The AUC metric is named 'auc', so Keras logs it as 'auc' (training) and
# 'val_auc' (validation data passed to fit).
model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy', AUC(name='auc')])
# save_best_only keeps only the best epoch according to the callback's default
# monitored quantity.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    "3d_image_classification.h5", save_best_only=True
)
# Stop on the validation AUC rather than the training AUC, which can keep improving
# while the model overfits. mode="max" states explicitly that higher is better
# instead of relying on Keras inferring the direction from the metric name.
early_stopping_cb = tf.keras.callbacks.EarlyStopping(monitor="val_auc", mode="max", patience=15)

In [None]:
# Train for up to 100 epochs in batches of 3 volumes, evaluating on the validation split
# each epoch; the two callbacks handle checkpointing and early stopping.
model.fit(X_train, y_train, validation_data=(X_valid, y_valid), batch_size=3, epochs=100, callbacks=[checkpoint_cb, early_stopping_cb])

In [None]:
# Sigmoid outputs of the model for the validation volumes.
yhat_valid_probs = model.predict(X_valid)

In [None]:
# Build a per-patient table of predicted probability and true label.
# Uses the names the notebook actually defines: `brats21ids_valid` from the split cell
# and `yhat_valid_probs` from the predict cell. Both are flattened because pd.Series
# needs 1-D input and model.predict returns one column per output unit.
# NOTE(review): assumes brats21ids_valid holds the BraTS21ID strings aligned
# row-for-row with X_valid; confirm against the cell that fills `brats21ids`.
frames = [pd.Series(np.ravel(brats21ids_valid)), pd.Series(np.ravel(yhat_valid_probs))]
results = (pd.concat(frames, axis=1, keys=['brats21id', 'MGMT_value_pred'])
           # Average in case a patient appears more than once.
           .groupby('brats21id', as_index=False).mean()
           # Attach the true MGMT_value from the label table (indexed by BraTS21ID).
           .merge(training_labels, left_on='brats21id', right_index=True, how='left')
          )

In [None]:
# Inspect the column names of the results table.
results.columns

In [None]:
# ROC AUC of the predictions against the true labels in the results table.
roc_auc_score(results.MGMT_value, results.MGMT_value_pred)