Acknowledgements: <br>
https://www.kaggle.com/ohbewise/a-rsna-mri-solution-from-dicom-to-submission <br>
U.Baid, et al., “The RSNA-ASNR-MICCAI BraTS 2021 Benchmark on Brain Tumor Segmentation and Radiogenomic Classification”, arXiv:2107.02314, 2021.

## Install and import libraries

In [None]:
# Install the packages listed in requirements.txt from the local
# ../input/pip-download-torchio/ directory; --no-index keeps pip from
# contacting the package index, so this works without internet access.
# If this line fails please see the prerequisite above
!pip install --quiet --no-index --find-links ../input/pip-download-torchio/ --requirement ../input/pip-download-torchio/requirements.txt

In [None]:
# import libraries
import os
import csv
import pickle
import numpy as np
import pandas as pd
import nibabel as nib
import torchio as tio
import tensorflow as tf
from pathlib import Path
import matplotlib.pyplot as plt

# Parameters to limit the processing power needed.
# Only the scan types listed here are preprocessed and turned into datasets;
# each entry is used as a per-patient sub-folder name.
scan_types    = ['FLAIR','T2w'] # add more scan-type folder names here to process them as well

## Preprocess data: DICOM to normalized NIfTI with TorchIO


In [None]:
# Preprocess data: turn every patient's DICOM series into a normalized,
# fixed-size NIfTI volume stored under `out_dir`.
data_dir   = '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/'
out_dir    = '/kaggle/working/processed'

# Cases the competition host said to exclude
# https://www.kaggle.com/c/rsna-miccai-brain-tumor-radiogenomic-classification/discussion/262046
excluded_patients = {'00109', '00123', '00709'}

# The pipeline is the same for every scan, so it is built once instead of
# once per (patient, scan type) pair.
transform = tio.Compose([
    tio.ToCanonical(),
    tio.Resample(1),
    tio.ZNormalization(masking_method=tio.ZNormalization.mean),
    tio.CropOrPad((128,128,64)),  # every saved volume ends up 128 x 128 x 64
])

for dataset in ['train']:
    dataset_dir = f'{data_dir}{dataset}'
    patients = [p for p in os.listdir(dataset_dir) if p not in excluded_patients]

    print(f'Total patients in {dataset} dataset: {len(patients)}')

    for count, patient in enumerate(patients, start=1):
        print(f'{dataset}: {count}/{len(patients)}')

        for scan_type in scan_types:
            scan_src  = f'{dataset_dir}/{patient}/{scan_type}/'
            scan_dest = f'{out_dir}/{dataset}/{patient}/{scan_type}/'
            Path(scan_dest).mkdir(parents=True, exist_ok=True)
            # ScalarImage is given the scan directory, not a single file.
            image = tio.ScalarImage(scan_src)
            preprocessed = transform(image)
            preprocessed.save(f'{scan_dest}/{scan_type}.nii.gz')

## Build datasets: NIfTI to Split Dataset with NiBabel

In [None]:
# build datasets

# dataset processing functions
def read_nifti_file(filepath):
    """Load the image stored at ``filepath`` with NiBabel and return its data.

    The returned array is whatever ``get_fdata()`` yields for the loaded image.
    """
    return nib.load(filepath).get_fdata()

def add_batch_channel(volume):
    """Return ``volume`` with a trailing channel axis and a leading batch axis."""
    with_channel = tf.expand_dims(volume, axis=-1)
    return tf.expand_dims(with_channel, axis=0)

def process_scan(filepath):
    """Read the NIfTI file at ``filepath`` and add batch and channel axes to it."""
    return add_batch_channel(read_nifti_file(filepath))

from sklearn.model_selection import train_test_split

# get labels (the first CSV column becomes the index used for the lookups below)
labels_df = pd.read_csv(data_dir+'train_labels.csv', index_col=0)

# split patients
# os.listdir returns entries in arbitrary order; sorting first is what makes
# the fixed random_state actually yield the same split on every run.
patients = sorted(os.listdir(f'{out_dir}/train'))
train, validation = train_test_split(patients, test_size=0.3, random_state=42)
print(f'{len(patients)} total patients.\n   {len(train)} in the train split.\n   {len(validation)} in the validation split')

splits_dict = {'train':train, 'validation':validation}

for scan_type in scan_types:
    print(f'{scan_type} start')
    for split_name, split_list in splits_dict.items():
        print(f'   {split_name} start')

        # Patient folder names are strings while the label table is looked up
        # by integer id, hence int(patient). `.at` is the public scalar
        # accessor (replaces the private `_get_value`).
        label_list = [
            add_batch_channel(labels_df.at[int(patient), 'MGMT_value'])
            for patient in split_list
        ]
        filepaths = [
            f'{out_dir}/train/{patient}/{scan_type}/{scan_type}.nii.gz'
            for patient in split_list
        ]

        features = np.array([process_scan(filepath) for filepath in filepaths])
        labels = np.array(label_list, dtype=np.uint8)
        dataset = tf.data.Dataset.from_tensor_slices((features, labels))

        # save dataset
        tf_data_path = f'./datasets/{scan_type}_{split_name}_dataset'
        tf.data.experimental.save(dataset, tf_data_path, compression='GZIP')
        with open(tf_data_path + '/element_spec', 'wb') as out_:  # also save the element_spec to disk for future loading
            pickle.dump(dataset.element_spec, out_)
        print(f'   {split_name} done')
    print(f'{scan_type} done')

## Define, train, and evaluate model: Dataset to Model with PyTorch


Train/validation split:

In [None]:
# Load the label table and show it in the notebook output.
# NOTE(review): `data_directory` and `sk_model_selection` are not defined in
# the visible part of this notebook -- confirm an earlier cell provides them.
train_df = pd.read_csv(f"{data_directory}/train_labels.csv")
display(train_df)

# Hold out 20% of the rows for validation, stratified on MGMT_value.
split_options = dict(test_size=0.2, random_state=12, stratify=train_df["MGMT_value"])
df_train, df_valid = sk_model_selection.train_test_split(train_df, **split_options)
df_train.tail()

model definition and training:

In [None]:
class Model(nn.Module):
    """EfficientNet3D "efficientnet-b7" network with a single-output linear head."""

    def __init__(self):
        super().__init__()
        backbone = EfficientNet3D.from_name(
            "efficientnet-b7",
            override_params={'num_classes': 2},
            in_channels=1,
        )
        # Replace the stock `_fc` layer with a one-unit linear layer that
        # keeps the same number of input features.
        backbone._fc = nn.Linear(
            in_features=backbone._fc.in_features,
            out_features=1,
            bias=True,
        )
        self.net = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.net(x)

In [None]:
class Trainer:
    """Train/validate a model with loss-based checkpointing and early stopping.

    Batches yielded by the loaders are indexed with ``batch["X"]`` (inputs)
    and ``batch["y"]`` (targets). A checkpoint is written whenever the
    validation loss improves on the best value seen so far.
    """

    def __init__(
        self,
        model,
        device,
        optimizer,
        criterion
    ):
        self.model = model
        self.device = device
        self.optimizer = optimizer
        self.criterion = criterion

        # Lowest validation loss seen so far (lower is better).
        self.best_valid_score = np.inf
        # Number of consecutive epochs without improvement.
        self.n_patience = 0
        # Path of the most recently written checkpoint, if any.
        self.lastmodel = None

    def fit(self, epochs, train_loader, valid_loader, save_path, patience):
        """Run up to ``epochs`` epochs and return the per-epoch metrics.

        Args:
            epochs: Maximum number of epochs to run.
            train_loader: Iterable of training batches.
            valid_loader: Iterable of validation batches.
            save_path: Prefix used for checkpoint file names.
            patience: Stop once this many consecutive epochs failed to lower
                the validation loss.

        Returns:
            A list with one dict per completed epoch holding ``epoch``,
            ``train_loss``, ``valid_loss`` and ``valid_auc``.
        """
        history = []

        for n_epoch in range(1, epochs + 1):
            self.info_message("EPOCH: {}", n_epoch)

            train_loss, train_time = self.train_epoch(train_loader)
            valid_loss, valid_auc, valid_time = self.valid_epoch(valid_loader)
            history.append(
                {
                    "epoch": n_epoch,
                    "train_loss": train_loss,
                    "valid_loss": valid_loss,
                    "valid_auc": valid_auc,
                }
            )

            self.info_message(
                "[Epoch Train: {}] loss: {:.4f}, time: {:.2f} s            ",
                n_epoch, train_loss, train_time
            )

            self.info_message(
                "[Epoch Valid: {}] loss: {:.4f}, auc: {:.4f}, time: {:.2f} s",
                n_epoch, valid_loss, valid_auc, valid_time
            )

            if valid_loss < self.best_valid_score:
                previous_best = self.best_valid_score
                # Update the best score BEFORE saving so the checkpoint
                # records this epoch's loss rather than the previous best.
                self.best_valid_score = valid_loss
                self.n_patience = 0
                self.save_model(n_epoch, save_path, valid_loss, valid_auc)
                self.info_message(
                    "valid loss improved from {:.4f} to {:.4f}. Saved model to '{}'",
                    previous_best, valid_loss, self.lastmodel
                )
            else:
                self.n_patience += 1

            if self.n_patience >= patience:
                self.info_message("\nValid loss didn't improve last {} epochs.", patience)
                break

        return history

    def train_epoch(self, train_loader):
        """Run one optimization pass over ``train_loader``.

        Returns:
            ``(mean_loss, elapsed_seconds)`` where the mean is taken over
            batches and the elapsed time is truncated to whole seconds.
        """
        self.model.train()
        t = time.time()
        sum_loss = 0

        for step, batch in enumerate(train_loader, 1):
            X = batch["X"].to(self.device)
            targets = batch["y"].to(self.device)
            self.optimizer.zero_grad()
            # squeeze(1) drops dimension 1 of the model output; presumably
            # the output is (batch, 1) and targets are (batch,) -- confirm.
            outputs = self.model(X).squeeze(1)

            loss = self.criterion(outputs, targets)
            loss.backward()

            sum_loss += loss.detach().item()

            self.optimizer.step()

            # end="\r" keeps the running average on a single console line.
            message = 'Train Step {}/{}, train_loss: {:.4f}'
            self.info_message(message, step, len(train_loader), sum_loss/step, end="\r")

        return sum_loss/len(train_loader), int(time.time() - t)

    def valid_epoch(self, valid_loader):
        """Evaluate the model on ``valid_loader`` without computing gradients.

        Returns:
            ``(mean_loss, auc, elapsed_seconds)``; the AUC is computed from
            the sigmoid of the model outputs against targets binarized at 0.5.
        """
        self.model.eval()
        t = time.time()
        sum_loss = 0
        y_all = []
        outputs_all = []

        with torch.no_grad():
            for step, batch in enumerate(valid_loader, 1):
                X = batch["X"].to(self.device)
                targets = batch["y"].to(self.device)

                outputs = self.model(X).squeeze(1)
                loss = self.criterion(outputs, targets)

                sum_loss += loss.detach().item()
                y_all.extend(batch["y"].tolist())
                outputs_all.extend(torch.sigmoid(outputs).tolist())

                message = 'Valid Step {}/{}, valid_loss: {:.4f}'
                self.info_message(message, step, len(valid_loader), sum_loss/step, end="\r")

        # Targets may be stored as floats; turn them into hard 0/1 labels.
        y_all = [1 if x > 0.5 else 0 for x in y_all]
        auc = roc_auc_score(y_all, outputs_all)

        return sum_loss/len(valid_loader), auc, int(time.time() - t)

    def save_model(self, n_epoch, save_path, loss, auc):
        """Write model and optimizer state to a file named after the metrics.

        The resulting path is stored in ``self.lastmodel``.
        """
        self.lastmodel = f"{save_path}-e{n_epoch}-loss{loss:.3f}-auc{auc:.3f}.pth"
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "best_valid_score": self.best_valid_score,
                "n_epoch": n_epoch,
            },
            self.lastmodel,
        )

    @staticmethod
    def info_message(message, *args, end="\n"):
        """Print ``message`` formatted with ``args`` via ``str.format``."""
        print(message.format(*args), end=end)

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def train_mri_type(df_train, df_valid, mri_type):
    """Train one model for ``mri_type`` and return its last checkpoint path.

    Args:
        df_train: Training rows; must provide ``BraTS21ID`` and ``MGMT_value``.
        df_valid: Validation rows with the same columns.
        mri_type: Scan type to train on, or ``"all"`` to repeat every row
            once per entry of the module-level ``mri_types``.

    Returns:
        ``trainer.lastmodel`` -- the path of the most recent checkpoint, or
        ``None`` if no checkpoint was written.

    The input frames are not modified; the ``MRI_Type`` column is added to
    copies.
    """
    # DataFrame.assign returns a new frame, so the caller's data stays
    # untouched. The comprehension variable is deliberately not called
    # `mri_type`: rebinding the parameter would turn the "all" checkpoint
    # prefix into the name of the last scan type.
    if mri_type == "all":
        df_train = pd.concat([df_train.assign(MRI_Type=one_type) for one_type in mri_types])
        df_valid = pd.concat([df_valid.assign(MRI_Type=one_type) for one_type in mri_types])
    else:
        df_train = df_train.assign(MRI_Type=mri_type)
        df_valid = df_valid.assign(MRI_Type=mri_type)

    print(df_train.shape, df_valid.shape)
    display(df_train.head())

    # Augmentation is requested for the training split only.
    train_data_retriever = Dataset(
        df_train["BraTS21ID"].values,
        df_train["MGMT_value"].values,
        df_train["MRI_Type"].values,
        augment=True
    )

    valid_data_retriever = Dataset(
        df_valid["BraTS21ID"].values,
        df_valid["MGMT_value"].values,
        df_valid["MRI_Type"].values
    )

    train_loader = torch_data.DataLoader(
        train_data_retriever,
        batch_size=4,
        shuffle=True,
        num_workers=8,pin_memory = True
    )

    valid_loader = torch_data.DataLoader(
        valid_data_retriever,
        batch_size=4,
        shuffle=False,
        num_workers=8,pin_memory = True
    )

    model = Model()
    model.to(device)

    #checkpoint = torch.load("best-model-all-auc0.555.pth")
    #model.load_state_dict(checkpoint["model_state_dict"])

    #print(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    #optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # The model emits raw logits, so use the logits variant of BCE.
    criterion = torch_functional.binary_cross_entropy_with_logits

    trainer = Trainer(
        model,
        device,
        optimizer,
        criterion
    )

    # 10 epochs with patience 10: patience equals the epoch count, so early
    # stopping can only fire on the very last epoch.
    trainer.fit(
        10,
        train_loader,
        valid_loader,
        f"{mri_type}",
        10,
    )

    return trainer.lastmodel

# While `modelfiles` is empty/None, train one model per entry of `mri_types`
# and collect the returned checkpoint paths; a non-empty value skips training.
modelfiles = None

if not modelfiles:
    modelfiles = list(map(lambda m: train_mri_type(df_train, df_valid, m), mri_types))
    print(modelfiles)

## Write predictions to submission.csv: Model Prediction to Submission

In [None]:
# write predictions to submission.csv

# Set up directories
data_dir   = '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/'
test_dir   = f'{data_dir}test'
patients = os.listdir(test_dir)
if demo:
    patients = patients[:10]
print(f'Total patients: {len(patients)}\n\n')

out_dir    = '/kaggle/working/processed'

scan_types = ['FLAIR', 'T2w']

# Same preprocessing for every scan, so the pipeline is built once.
transform = tio.Compose([
    tio.ToCanonical(),
    tio.Resample(1),
    tio.ZNormalization(masking_method=tio.ZNormalization.mean),
    tio.CropOrPad((128,128,64)),
])

# Per-patient list of scores, one entry per scan type.
predictions = {patient: [] for patient in patients}

for scan_type in scan_types:
    # tf model: loaded once per scan type instead of once per patient.
    # NOTE(review): this expects a Keras model under ./models/<scan_type>,
    # while the training cells in this notebook write PyTorch .pth
    # checkpoints -- confirm where these models come from.
    model = tf.keras.models.load_model(f'./models/{scan_type}')

    for patient in patients:
        # dicom to nifti
        scan_src  = f'{test_dir}/{patient}/{scan_type}/'
        scan_dest = f'{out_dir}/test/{patient}/{scan_type}/'
        Path(scan_dest).mkdir(parents=True, exist_ok=True)
        image = tio.ScalarImage(scan_src)  # subclass of Image
        preprocessed = transform(image)
        filepath = f'{scan_dest}/{scan_type}.nii.gz'
        preprocessed.save(filepath)

        # process_scan
        case = process_scan(filepath)

        # get prediction
        prediction = model.predict(case)

        print(f'{patient},{prediction[0][0]}')
        predictions[patient].append(float(prediction[0][0]))

# Previously the file was reopened in 'w' mode for each scan type, so every
# scan type but the last was silently overwritten. The file is now written
# once, with each patient's score averaged over all scan types.
with open('/kaggle/working/submission.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['BraTS21ID','MGMT_value'])
    for patient in patients:
        scores = predictions[patient]
        writer.writerow([patient, sum(scores) / len(scores)])