In [1]:
from IPython.core.display import HTML

# Stylesheet for the `section-heading` class used by the <h2> banners in the
# markdown cells: a 100px-tall teal bar with white, centred text.
styles = """
<style>
.section-heading { 
  background:#008080; 
  border:0; 
  color:white; 
  text-align:center; 
  height: 100px; 
  display: flex;  
  justify-content: center;
  align-items: center;
}
</style>
"""
# Last expression of the cell, so the HTML object is rendered as the cell output.
HTML(styles)

<a id="top"></a>

<h2 class="section-heading">
    <span>
        Quick Navigation
    </span>
</h2>

* [Overview](#overview)
* [Extraction](#extraction)
* [Data Visualization](#data_viz)
    

* [Competition Metric](#10)
* [Sample Submission](#20)
    

* [Modeling](#modeling)

<a id="overview"></a>

<h2 class="section-heading">
    <span>
      Overview
    </span>
</h2>

* TODO

In [None]:
import pathlib
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pydicom
import seaborn as sns
import tensorflow as tf
from pydicom.pixel_data_handlers.util import apply_voi_lut
from sklearn.metrics import roc_auc_score, accuracy_score
from sklearn.model_selection import train_test_split
from tensorflow.keras.layers import InputLayer, Conv2D, MaxPool2D, Conv2D, Dropout, Flatten, Dense
from tensorflow.keras.layers.experimental.preprocessing import Rescaling
from tensorflow.keras.metrics import AUC
from tensorflow.keras.utils import to_categorical
from tqdm.notebook import tqdm

from utils import competition_name, path

# Use seaborn's "whitegrid" style and "pastel" palette for the plots below.
sns.set_theme(style="whitegrid", palette="pastel")
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Names of the per-patient sub-directories, one per MRI sequence.
MRI_TYPES = (
    "FLAIR",
    "T1w",
    "T1wCE",
    "T2w",
)

<a id="extraction"></a>

<h2 class="section-heading">
    <span>
      Extraction
    </span>
</h2>

In [None]:
# Show every entry directly under the competition data directory (`path` comes from utils).
list(path.glob('*'))

In [None]:
# One entry per item found directly under <path>/train.
train_path = path / 'train'
patient_paths = [entry for entry in train_path.glob('*')]

# Preview the first five entries together with the total count.
(patient_paths[:5], len(patient_paths))

In [None]:
# Read the IDs as strings so leading zeros (e.g. "00675") are kept,
# then index the label table by patient ID.
labels_csv = path / 'train_labels.csv'
training_labels = pd.read_csv(labels_csv, dtype={'BraTS21ID': str}).set_index('BraTS21ID')

training_labels.head()

In [None]:
# Tunable constants; IMAGE_SIZE and NUM_IMAGES are the default arguments of Loader's methods.
IMAGE_SIZE = 256   # side length (pixels) each slice is resized to
NUM_IMAGES = 64    # number of slices kept per 3D volume
BATCH_SIZE = 4

In [None]:
class Loader:
    """Load DICOM slices of one MRI sequence type and assemble them into a 3D volume.

    Scans are read from ``<path>/<split>/<scan_id>/<mri_type>/*.dcm`` where
    ``path`` is the module-level data directory imported from ``utils``.
    """

    # Rotation codes accepted by ``cv2.rotate``; pass one as ``rotation``.
    ROT_CHOICES = [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]

    def __init__(self, mri_type):
        """Remember which MRI sequence sub-directory (e.g. ``"FLAIR"``) to read."""
        self.mri_type = mri_type

    @staticmethod
    def _slice_order(dicom_path):
        """Sort key ordering files by the number after the last ``-`` in the stem.

        Numeric suffixes are compared as integers so that ``Image-2`` sorts
        before ``Image-10``; stems without a numeric suffix sort last, by text.
        """
        suffix = dicom_path.stem.split('-')[-1]
        if suffix.isdigit():
            return (0, int(suffix), suffix)
        return (1, 0, suffix)

    @staticmethod
    def load_dicom_image(path, img_size=IMAGE_SIZE, voi_lut=True, rotation=None):
        """Read one DICOM file and return its pixels as an ``img_size`` x ``img_size`` array.

        Args:
            path: location of the ``.dcm`` file (shadows the module-level ``path``
                inside this method only).
            img_size: side length passed to ``cv2.resize``.
            voi_lut: when true, run the pixels through ``apply_voi_lut``.
            rotation: optional ``cv2.rotate`` code (see ``ROT_CHOICES``); ``None``
                leaves the image unrotated.

        Returns:
            The resized 2D pixel array.
        """
        # Keep the dataset object around: apply_voi_lut needs it, not just the pixels.
        dicom = pydicom.dcmread(path)
        pixels = dicom.pixel_array

        data = apply_voi_lut(pixels, dicom) if voi_lut else pixels

        if rotation is not None:
            data = cv2.rotate(data, rotation)

        return cv2.resize(data, (img_size, img_size))

    def load_dicom_images_3d(self, scan_id, num_imgs=NUM_IMAGES, img_size=IMAGE_SIZE, split="train", rotation=None):
        """Build a min-max normalised volume from the central slices of one scan.

        Args:
            scan_id: patient/scan directory name under ``<path>/<split>``.
            num_imgs: depth of the returned volume; a window of
                ``2 * (num_imgs // 2)`` slices around the middle file is read and
                any missing depth is zero-padded.
            img_size: side length every slice is resized to.
            split: dataset sub-directory, ``"train"`` by default.
            rotation: optional ``cv2.rotate`` code applied to every slice.

        Returns:
            Array of shape ``(1, img_size, img_size, num_imgs)``. Values are
            scaled to [0, 1] unless the volume is constant (e.g. no files were
            found), in which case it is returned unscaled.
        """
        scan_dir = path / split / scan_id / self.mri_type
        files = sorted(scan_dir.glob("*.dcm"), key=self._slice_order)

        middle = len(files) // 2
        half_window = num_imgs // 2
        p1 = max(0, middle - half_window)
        p2 = min(len(files), middle + half_window)

        slices = [
            self.load_dicom_image(f, img_size=img_size, rotation=rotation)
            for f in files[p1:p2]
        ]

        if slices:
            # .T reverses all axes: (slice, row, col) -> (col, row, slice).
            img3d = np.stack(slices).T
        else:
            # np.stack rejects an empty list; start from a zero-depth volume instead.
            img3d = np.zeros((img_size, img_size, 0))

        missing = num_imgs - img3d.shape[-1]
        if missing > 0:
            padding = np.zeros((img_size, img_size, missing))
            img3d = np.concatenate((img3d, padding), axis=-1)

        lowest, highest = np.min(img3d), np.max(img3d)
        if lowest < highest:
            img3d = (img3d - lowest) / (highest - lowest)

        return np.expand_dims(img3d, 0)

In [None]:
# Hold out 20% of the patients. Split the unique patient IDs from the label
# table, which is already loaded at this point in the notebook.
brats21ids_train, brats21ids_test = train_test_split(
    training_labels.index.tolist(), test_size=0.2, random_state=11
)

In [None]:
# random.seed(11)  # uncomment to always draw the same patient


def _slice_order(dicom_path):
    """Sort key: order files by the integer after the last '-' in the stem.

    Comparing as integers keeps 'Image-2' before 'Image-10'; stems without a
    numeric suffix sort last, by text.
    """
    suffix = dicom_path.stem.split('-')[-1]
    if suffix.isdigit():
        return (0, int(suffix), suffix)
    return (1, 0, suffix)


patient_path = random.choice(patient_paths)

brats21id = patient_path.stem
# The label does not depend on the MRI type, so look it up once.
outcome = training_labels.loc[brats21id, 'MGMT_value']

fig, axes = plt.subplots(1, len(MRI_TYPES), figsize=(16, 5))

for ax, mri_type in zip(axes, MRI_TYPES):
    mri_type_path = patient_path / mri_type
    image_paths = sorted(mri_type_path.glob('*'), key=_slice_order)

    ax.set_title(mri_type, fontsize=16)
    ax.axis("off")

    if not image_paths:
        # Nothing to draw for this sequence; leave the panel blank.
        continue

    # Middle slice of the numerically ordered series.
    image_path = image_paths[len(image_paths) // 2]
    # NOTE(review): `load_dicom` is not defined or imported in the visible
    # cells — confirm where it comes from before relying on Restart & Run All.
    data = load_dicom(image_path)

    ax.imshow(data, cmap="gray")

fig.suptitle(f'Patient #{brats21id}: {outcome}', size=24)
plt.tight_layout()
plt.show()

<a id="data_viz"></a>

<h2 class="section-heading">
    <span>
    Data Visualization
    </span>
</h2>

In [None]:
# Share of each MGMT_value class, ordered by class value and formatted as a percentage.
class_shares = training_labels.MGMT_value.value_counts(normalize=True).sort_index()
bar_labels = class_shares.map('{:0.1%}'.format).to_list()
training_size = training_labels.shape[0]

ax = sns.countplot(data=training_labels, x='MGMT_value')
sns.despine(right=False)

ax.set_title('MGMT Distribution ({:0.0f} records)'.format(training_size), size=16)
ax.set_xticklabels(('Not Present', 'Present'))
ax.set(xlabel='', ylabel='')

# Write each percentage just above its bar.
for bar, label in zip(ax.patches, bar_labels):
    anchor = (bar.get_x() + 0.375, bar.get_height() + 0.15)
    ax.annotate(label, anchor)

plt.show()

**NOTE**: Three cases in the training dataset have unexpected issues, and participants may exclude them during training: [00109, 00123, 00709]. We have checked and confirmed that the testing dataset is free from such issues.

<a id="modeling"></a>

<h2 class="section-heading">
    <span>
        Modeling
    </span>
</h2>

In [None]:
class Patient:
    """One patient: zero-padded ID, MGMT label and access to their MRI slices."""

    MRI_TYPES = ("FLAIR", "T1w", "T1wCE", "T2w")
    # Snapshot of the label table and data directory at class-definition time.
    TRAINING_LABELS = training_labels
    PATH = path
    TRAIN_PATH = PATH / 'train'

    def __init__(self, brats21id):
        """Look up the label and training directory for ``brats21id``.

        Args:
            brats21id: patient ID as int or str; it is left-padded with zeros
                to five characters (``675`` -> ``"00675"``).

        Raises:
            KeyError: if the padded ID is not in ``TRAINING_LABELS``.
        """
        self.brats21id = str(brats21id).zfill(5)
        self.outcome = self.TRAINING_LABELS.loc[self.brats21id, 'MGMT_value']
        self.patient_train_path = self.TRAIN_PATH / self.brats21id

    def __repr__(self):
        return f"Patient(brats21id={self.brats21id})"

    @staticmethod
    def _slice_order(image_path):
        """Sort key: order files by the integer after the last '-' in the stem.

        Comparing as integers keeps ``Image-2`` before ``Image-10``; stems
        without a numeric suffix sort last, by text.
        """
        suffix = image_path.stem.split('-')[-1]
        if suffix.isdigit():
            return (0, int(suffix), suffix)
        return (1, 0, suffix)

    def load_mri_images(self, mri_type, subset='train'):
        """Load every slice of one MRI sequence for this patient, in slice order.

        Args:
            mri_type: sequence sub-directory name, e.g. one of ``MRI_TYPES``.
            subset: dataset sub-directory under ``PATH``; the default
                ``'train'`` reads from ``patient_train_path``.

        Returns:
            A list with one loaded image per file found (empty if none).
        """
        if subset == 'train':
            patient_dir = self.patient_train_path
        else:
            patient_dir = self.PATH / subset / self.brats21id

        mri_type_path = patient_dir / mri_type
        image_paths = sorted(mri_type_path.glob('*'), key=self._slice_order)

        # NOTE(review): `load_dicom` is not defined or imported in the visible
        # cells — confirm where it comes from before relying on Restart & Run All.
        return [load_dicom(image_path) for image_path in image_paths]

In [None]:
# Smoke test of Patient against one known training case.
p = Patient(brats21id=675)

# Integer IDs are zero-padded to five characters.
assert repr(p) == 'Patient(brats21id=00675)'
assert p.outcome == 1

images = p.load_mri_images(mri_type='FLAIR')
assert len(images) == 196

In [None]:
# Flatten every patient's slices of one sequence into parallel lists:
# X holds the images, y the patient's label repeated per image, and
# brats21id_idx the patient ID repeated per image.
brats21ids = training_labels.index.tolist()
X, y, brats21id_idx = [], [], []
mri_type = 'FLAIR'

for brats21id in tqdm(brats21ids):
    p = Patient(brats21id)
    images = p.load_mri_images(mri_type)
    n_images = len(images)

    X.extend(images)
    y.extend([p.outcome] * n_images)
    brats21id_idx.extend([brats21id] * n_images)

In [None]:
# 80/20 split at image level; the patient IDs are split alongside so each
# prediction can later be traced back to its patient.
# NOTE(review): images of the same patient can land in both train and valid
# because the split is per image, not per patient — confirm this is intended.
splits = train_test_split(X, y, brats21id_idx, test_size=0.2, random_state=11)
(X_train, X_valid,
 y_train, y_valid,
 brats21id_idx_train, brats21id_idx_valid) = splits

# Append a trailing channel axis to the image stacks.
X_train = tf.expand_dims(X_train, axis=-1)
X_valid = tf.expand_dims(X_valid, axis=-1)

# One-hot encode the labels for the softmax / categorical cross-entropy head.
y_train = to_categorical(y_train)
y_valid = to_categorical(y_valid)

In [None]:
# Shape of a single training sample after the channel axis was added.
X_train[0].shape

In [None]:
# Small CNN: rescale -> one conv block -> dense head with a 2-way softmax.
# NOTE(review): the input is fixed at 224x224 while IMAGE_SIZE is 256 — confirm
# the loaded images really are 224x224.
# NOTE(review): a (1, 1) max-pool looks like a no-op — confirm the intended pool size.
cnn_layers = [
    InputLayer(input_shape=(224, 224, 1)),
    Rescaling(1.0/255, name='rescaling_1'),
    Conv2D(64, kernel_size=(2, 2), activation='relu', name='conv2d_1'),
    MaxPool2D((1, 1), name='max_pooling2d_1'),
    Dropout(0.1, name='dropout_1'),
    Flatten(name='flatten'),
    Dense(32, activation='relu', name='dense_1'),
    Dense(2, activation='softmax', name='dense_2'),
]

model = tf.keras.Sequential(layers=cnn_layers)

In [None]:
# Print the layer-by-layer output shapes and parameter counts.
model.summary()

In [None]:
# Categorical cross-entropy matches the one-hot labels and the softmax output.
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy', AUC()],
)

In [None]:
# Train for 10 epochs, holding out 10% of the training images for monitoring.
model.fit(
    X_train,
    y_train,
    validation_split=0.1,
    batch_size=32,
    epochs=10,
)

In [None]:
# Per-image class probabilities, then the index of the most likely class.
yhat_valid_probs = model.predict(X_valid)
yhat_valid = yhat_valid_probs.argmax(axis=1)

In [None]:
# Aggregate the per-image predictions to one score per patient (the mean of
# the predicted class over that patient's validation images), then attach the
# true label from the label table.
# The prediction column is named `MGMT_value_pred`; the merge adds the label
# as `MGMT_value`.
results = (
    pd.DataFrame({
        'brats21id': brats21id_idx_valid,
        'MGMT_value_pred': yhat_valid,
    })
    .groupby('brats21id', as_index=False).mean()
    .merge(training_labels, left_on='brats21id', right_index=True, how='left')
)

In [None]:
# Check which columns the aggregated results frame ended up with.
results.columns

In [None]:
# ROC AUC of the true MGMT_value labels against the per-patient prediction column.
roc_auc_score(results.MGMT_value, results.MGMT_value_pred)