**Note**: the network architecture and the original data split strategy are based on [this notebook](https://www.kaggle.com/mmellinger66/brain-tumor-basic-tensorflow-model).

In this notebook, I want to demonstrate the importance of splitting data properly, especially when using a 2D CNN where each data point is a single slice. If we split the data at the slice level, the validation AUC can be deceptively high because similar slices of the same patient can appear in both the train and the validation set. Thus we should split the data based on patient ID. Here are some notable changes I made compared to the original notebook:
1. Use split-by-patient strategy on top of their original split strategy (I called it holdout sets in this notebook)
2. Reimplement using pytorch
3. The prediction for a patient is the average of its slice probabilities (instead of the average of the 0/1 slice predictions as in the original notebook)
4. Track both patient AUC and slice AUC when training
5. Calculate out-of-fold AUC and average AUC across all folds to get a more reliable result
6. Remove the result rounding part because I didn't find it meaningful

In [1]:
import sys
sys.path.append('../input/monai-v070')

In [2]:
import os
import cv2
import glob
import pydicom
import random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import time
import datetime
from dataclasses import dataclass, field
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
from sklearn.metrics import roc_auc_score
from copy import deepcopy

from monai.data import CacheDataset, DataLoader
from monai.transforms import *

def seed_everything(seed):
    """Seed every random number generator used in this notebook.

    Covers Python's ``random``, NumPy and PyTorch (CPU and every CUDA device)
    and pins cuDNN to deterministic, non-autotuned kernels so repeated runs
    select the same algorithms.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    # Only takes effect for Python interpreters launched after this point;
    # the hash seed of the running interpreter is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Also seed the remaining GPUs when more than one is visible.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels from run to run.
    torch.backends.cudnn.benchmark = False


seed_everything(42)

class AverageMeter:
    """Keep the most recent value and the running weighted mean of a metric.

    Attributes are plain and public: ``val`` (last value), ``sum`` (weighted
    sum), ``count`` (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear everything recorded so far."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` new samples.

        Args:
            val: Value observed for the new samples (e.g. a batch loss).
            n: Number of samples ``val`` was averaged over.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # An empty update on a fresh meter must not divide by zero.
        self.avg = self.sum / self.count if self.count else 0

In [3]:
# Root folder of the competition data; Config derives the train/test folders
# and the label CSV path from it.
DATA_DIR = '../input/rsna-miccai-brain-tumor-radiogenomic-classification/'
# MRI sequence names accepted by BrainTumorDataset; each one is used as a
# sub-folder name inside a patient folder.
MRI_TYPES = ["FLAIR", "T1w", "T2w", "T1wCE"]

## Dataset

In [4]:
class BrainTumorDataset(CacheDataset):
    """Slice-level dataset: one item per selected DICOM slice of each patient.

    Every item is a dict with the keys ``'image'`` (path of the slice),
    ``'label'`` (the patient's ``MGMT_value``, or ``0`` when no annotations
    are given) and ``'patient_id'``.

    Args:
        root_dir: Folder holding one zero-padded (5 digits) sub-folder per patient.
        patient_ids: Patient ids to include.
        mri_types: MRI sequences to read; must be a subset of ``MRI_TYPES``.
        annotations: DataFrame with ``BraTS21ID`` and ``MGMT_value`` columns,
            or ``None`` for unlabeled data.
        section: ``None`` keeps every slice. ``'train'`` keeps the 80% part and
            any other value the 20% part of a fixed-seed random split.
        *args, **kwargs: Forwarded to ``CacheDataset``.
    """

    def __init__(self, root_dir, patient_ids, mri_types, annotations, section, *args, **kwargs):
        self.root_dir = root_dir
        self.patient_ids = patient_ids
        self.mri_types = mri_types
        self.annotations = annotations
        data = self.get_data()
        if section is not None:
            # The split is over individual slices, so slices of one patient can
            # end up on both sides (this is the "original" split strategy).
            train_data, val_data = train_test_split(data, test_size=0.2, random_state=42)
            data = train_data if section == 'train' else val_data
        super().__init__(data, *args, **kwargs)

    def get_data(self):
        """Build the list of item dicts for all patients in ``self.patient_ids``."""
        data = []
        for patient_id in tqdm(self.patient_ids):
            if self.annotations is not None:
                is_patient = self.annotations['BraTS21ID'] == int(patient_id)
                # .item() raises unless exactly one annotation row matches.
                label = self.annotations[is_patient]['MGMT_value'].item()
            else:
                label = 0  # dummy value
            data.extend(
                {'image': slice_path, 'label': label, 'patient_id': patient_id}
                for slice_path in self.get_patient_slice_paths(patient_id)
            )
        return data

    @staticmethod
    def _slice_number(path):
        """Return the integer after the last '-' in the file name (without extension)."""
        stem = os.path.splitext(os.path.basename(path))[0]
        return int(stem.rsplit('-', 1)[-1])

    def get_patient_slice_paths(self, patient_id):
        """Return the selected slice paths of one patient for all ``self.mri_types``.

        For each MRI type the slices are ordered by their slice number, only
        the middle half (25% to 75%) is kept and, unless the series has fewer
        than 10 slices, only every third slice of that range is returned.

        Raises:
            ValueError: If ``self.mri_types`` contains a name not in ``MRI_TYPES``.
        """
        # An explicit check instead of `assert`, which is removed under `python -O`.
        unknown_types = set(self.mri_types) - set(MRI_TYPES)
        if unknown_types:
            raise ValueError(f'Unsupported MRI types: {sorted(unknown_types)}')

        patient_path = os.path.join(self.root_dir, str(patient_id).zfill(5))
        patient_slice_paths = []
        for mri_type in self.mri_types:
            paths = sorted(
                glob.glob(os.path.join(patient_path, mri_type, "*.dcm")),
                key=self._slice_number,
            )

            num_images = len(paths)
            start = int(num_images * 0.25)
            end = int(num_images * 0.75)

            # Very short series are not subsampled.
            interval = 1 if num_images < 10 else 3
            patient_slice_paths.extend(paths[start:end:interval])
        return patient_slice_paths
    
class LoadDicomd(MapTransform):
    """Dictionary transform that replaces DICOM paths with loaded image arrays.

    Args:
        img_size: Side length (pixels) of the square output image.
        *args, **kwargs: Forwarded to ``MapTransform`` (e.g. ``keys``).
    """

    def __init__(self, img_size, *args, **kwargs):
        self.img_size = img_size
        super().__init__(*args, **kwargs)

    def __call__(self, data):
        """Return a shallow copy of ``data`` with every key in ``self.keys`` loaded."""
        d = dict(data)
        for key in self.keys:
            d[key] = self.load_dicom(d[key])
        return d

    def load_dicom(self, path):
        """Read one DICOM file and return it as a resized float array.

        The pixels are divided by the image maximum (skipped for an all-zero
        image), quantized to ``uint8`` in 0..255, resized to
        ``img_size x img_size`` and divided by 255 again, so the returned
        values lie in [0, 1]. A leading axis of length 1 is added
        (presumably the channel axis of a 2D slice).
        """
        # `dcmread` is the supported name; `read_file` is a deprecated alias
        # that is removed in newer pydicom releases.
        dicom = pydicom.dcmread(path)
        data = dicom.pixel_array
        peak = np.max(data)
        if peak != 0:
            data = data / peak
        data = (data * 255).astype(np.uint8)
        data = cv2.resize(data, (self.img_size, self.img_size)) / 255
        return np.expand_dims(data, axis=0)

## Model

In [5]:
class Simple2dCNN(nn.Module):
    """Small two-convolution classifier for square single-slice images.

    Layout: conv(k=4) -> ReLU -> maxpool(2) -> conv(k=2) -> ReLU -> maxpool(1)
    -> dropout -> flatten -> fc1 -> ReLU -> fc2. The forward pass returns raw
    logits of shape ``(batch, n_classes)``.

    Args:
        input_channels: Channels of the input image.
        n_classes: Number of output logits.
        img_size: Height/width of the (square) input image.
        conv1_filters: Output channels of the first convolution.
        conv2_filters: Output channels of the second convolution.
        dropout_prob: Dropout probability applied before flattening.
        fc1_units: Width of the hidden fully connected layer.
    """

    def __init__(self,
                 input_channels=1,
                 n_classes=2,
                 img_size=32,
                 conv1_filters=128,
                 conv2_filters=64,
                 dropout_prob=0.1,
                 fc1_units=48):
        super().__init__()

        self.relu = nn.ReLU()

        self.conv1 = nn.Conv2d(in_channels=input_channels,
                               out_channels=conv1_filters,
                               kernel_size=4)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2 = nn.Conv2d(in_channels=conv1_filters,
                               out_channels=conv2_filters,
                               kernel_size=2)
        self.maxpool2 = nn.MaxPool2d(kernel_size=1)

        self.dropout = nn.Dropout(p=dropout_prob)

        # Track the spatial side length through the unpadded layers.
        side = img_size - 3   # conv1, kernel 4
        side = side // 2      # maxpool1, kernel 2
        side = side - 1       # conv2, kernel 2 (maxpool2 keeps the size)
        self.fc1 = nn.Linear(conv2_filters * side ** 2, fc1_units)
        self.fc2 = nn.Linear(fc1_units, n_classes)

    def forward(self, x):
        """Map images ``(batch, C, H, W)`` to logits ``(batch, n_classes)``."""
        # Shapes in the comments assume the default arguments (32x32 input).
        feats = self.maxpool1(self.relu(self.conv1(x)))      # (None, 128, 14, 14)
        feats = self.maxpool2(self.relu(self.conv2(feats)))  # (None, 64, 13, 13)
        feats = self.dropout(feats)

        flat = feats.view(feats.size(0), -1)                 # (None, 64 * 13 * 13)
        hidden = self.relu(self.fc1(flat))                   # (None, 48)
        return self.fc2(hidden)                              # (None, 2)

## Pipeline

In [6]:
@dataclass
class Config:
    """Paths and settings read by Pipeline (accessed there as ``self.args``)."""
    # Folder with the training patients; used as dataset root for train/val/holdout.
    train_dir: str = os.path.join(DATA_DIR, 'train')
    # Folder with the test patients; its entries are listed to get the test ids.
    test_dir: str = os.path.join(DATA_DIR, 'test')
    # CSV read by Pipeline.load_annotations (BraTS21ID / MGMT_value columns).
    annotation_path: str = os.path.join(DATA_DIR, 'train_labels.csv')
    # Number of logits produced by the model.
    n_classes: int = 2
    # Side length the slices are resized to and the model is built for.
    img_size: int = 32
    # Worker count passed to both the datasets and the data loaders.
    n_workers: int = 4
    # Early-stopping patience used by Pipeline.fit.
    early_stopping_rounds: int = 3
    # Number of stratified patient-level folds.
    n_folds: int = 5
        
        
class Pipeline:
    def __init__(self, config):
        """Store the config, pick the device, build a fresh model and the transforms.

        Args:
            config: ``Config`` instance; kept as ``self.args``.
        """
        self.args = config
        # First GPU when CUDA is available, otherwise the CPU.
        device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device_name)
        self.annotations = None  # filled by load_annotations()
        self.model = None
        self.load_model()

        # Transform stages; prepare_datasets() chains them with Compose.
        load_image = LoadDicomd(keys="image", img_size=self.args.img_size)
        image_to_tensor = ToTensord(keys="image", dtype=torch.float)
        label_to_tensor = ToTensord(keys="label", dtype=torch.int64)
        self.preaugment_transform = [load_image]
        self.augment_transform = []  # todo: add some augmentations
        self.postaugment_transform = [image_to_tensor, label_to_tensor]
        
    def load_annotations(self):
        """Read the label CSV and assign every patient to a stratified fold.

        The result is stored in ``self.annotations`` with an extra integer
        ``fold`` column: the index of the fold in which the patient is part
        of the validation (holdout) split.
        """
        labels = pd.read_csv(self.args.annotation_path)
        # exclude 3 cases
        excluded_ids = [109, 123, 709]
        keep = ~labels['BraTS21ID'].isin(excluded_ids)
        # A fresh 0..n-1 index is required: the fold indices below are positional.
        self.annotations = labels[keep].reset_index(drop=True)

        # split by patient, stratify based on target value
        splitter = StratifiedKFold(n_splits=self.args.n_folds, shuffle=True, random_state=42)
        patient_ids = self.annotations['BraTS21ID'].values
        targets = self.annotations['MGMT_value'].values
        for fold_no, (_, holdout_rows) in enumerate(splitter.split(patient_ids, targets)):
            self.annotations.loc[holdout_rows, 'fold'] = fold_no
        self.annotations['fold'] = self.annotations['fold'].astype(int)
    
    def load_model(self, weights_path=None):
        """Replace ``self.model`` with a new network, optionally loading weights.

        Args:
            weights_path: Path of a saved ``state_dict``; when falsy the model
                keeps its fresh initialization.
        """
        network = Simple2dCNN(input_channels=1,
                              n_classes=self.args.n_classes,
                              img_size=self.args.img_size)
        self.model = network.to(self.device)
        if not weights_path:
            return
        state_dict = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        
    def prepare_datasets(self, mri_types, fold, cache_rate):
        """
        Build the train, validation and holdout datasets of one fold.

        Patients whose ``fold`` value differs from ``fold`` feed both the train
        and the validation dataset (their slices are split by the dataset's
        'train' / 'val' section); patients of ``fold`` form the holdout set.

        Data format:
        {
            'image': torch tensor (batch_size, 1, 32, 32),
            'label': torch tensor (batch_size, )
            'patient_id'
        }
        Output: torch tensor (batch_size, 2)

        Returns: (train_ds, val_ds, val_holdout_ds)
        """
        train_transform = Compose(
            self.preaugment_transform
            + self.augment_transform
            + self.postaugment_transform
        )
        # Evaluation data never goes through the augmentation stage.
        val_transform = Compose(self.preaugment_transform + self.postaugment_transform)

        in_holdout = self.annotations['fold'] == fold
        train_ids = self.annotations[self.annotations['fold'] != fold]['BraTS21ID'].values.tolist()
        val_holdout_ids = self.annotations[in_holdout]['BraTS21ID'].values.tolist()

        # Arguments shared by all three datasets.
        shared = dict(root_dir=self.args.train_dir,
                      mri_types=mri_types,
                      annotations=self.annotations,
                      cache_rate=cache_rate,
                      num_workers=self.args.n_workers)

        train_ds = BrainTumorDataset(patient_ids=train_ids,
                                     transform=train_transform,
                                     section='train',
                                     **shared)
        val_ds = BrainTumorDataset(patient_ids=train_ids,
                                   transform=val_transform,
                                   section='val',
                                   **shared)
        val_holdout_ds = BrainTumorDataset(patient_ids=val_holdout_ids,
                                           transform=val_transform,
                                           section=None,
                                           **shared)
        return train_ds, val_ds, val_holdout_ds
    
    def prepare_test_dataset(self, mri_types, cache_rate):
        """Build the unlabeled dataset over every patient folder in the test dir.

        Args:
            mri_types: MRI sequences to read for each patient.
            cache_rate: Forwarded to the dataset.

        Returns:
            A ``BrainTumorDataset`` without annotations (labels are dummy
            zeros) covering the test patients in ascending id order.
        """
        test_transform = Compose(
            self.preaugment_transform +
            self.postaugment_transform
        )
        # Patient folders are named by their numeric id; skip anything else
        # (stray or hidden files) instead of crashing inside int().
        test_ids = sorted(
            int(entry) for entry in os.listdir(self.args.test_dir) if entry.isdigit()
        )
        test_ds = BrainTumorDataset(root_dir=self.args.test_dir,
                                    patient_ids=test_ids,
                                    mri_types=mri_types,
                                    annotations=None,
                                    transform=test_transform,
                                    section=None,
                                    cache_rate=cache_rate,
                                    num_workers=self.args.n_workers)
        return test_ds
    
    def train_epoch(self, loader, loss_function, optimizer, verbose):
        """Run one optimization pass over ``loader``.

        Args:
            loader: Iterable of batches with ``"image"`` and ``"label"`` entries.
            loss_function: Callable ``(outputs, labels) -> loss``.
            optimizer: Optimizer stepping the model parameters.
            verbose: When true, print a progress line after every step.

        Returns:
            The sample-weighted mean training loss of the epoch.
        """
        self.model.train()
        loss_meter = AverageMeter()
        started_at = time.time()
        total_steps = len(loader)
        for step, batch in enumerate(loader, start=1):
            images = batch["image"].to(self.device)   # (None, 1, 32, 32)
            targets = batch["label"].to(self.device)  # (None, )

            # forward + backward + parameter update
            optimizer.zero_grad()
            logits = self.model(images)  # (None, 2)
            loss = loss_function(logits, targets)
            loss.backward()
            optimizer.step()

            # weight the batch loss by its size so the mean is per sample
            loss_meter.update(loss.item(), images.size(0))
            if verbose:
                progress = 'Train step {}/{}, loss: {:.5f}'.format(step, total_steps,
                                                                   loss_meter.avg)
                print(progress, end='\r')
        elapsed_time = str(datetime.timedelta(seconds=time.time() - started_at))
        print('Train loss: {:.5f} - time: {}'.format(loss_meter.avg, elapsed_time))
        return loss_meter.avg
    
    def evaluate_epoch(self, loader, loss_function, verbose):
        """Score the model on ``loader`` without updating any weights.

        Prints the mean loss, the slice-level AUC and the patient-level AUC
        (computed after averaging slice probabilities per patient).

        Args:
            loader: Iterable of batches with ``"image"``, ``"label"`` and
                ``"patient_id"`` entries.
            loss_function: Callable ``(outputs, labels) -> loss``.
            verbose: When true, print a progress line after every step.

        Returns:
            ``(mean_loss, patient_auc, per_patient)`` where ``per_patient`` is a
            DataFrame with one row per ``BraTS21ID`` holding the mean
            ``probability`` and mean ``label`` of its slices.
        """
        self.model.eval()
        loss_meter = AverageMeter()
        started_at = time.time()
        total_steps = len(loader)
        seen_ids = []
        seen_probs = []
        seen_labels = []
        with torch.no_grad():
            for step, batch in enumerate(loader, start=1):
                images = batch["image"].to(self.device)   # (None, 1, 32, 32)
                targets = batch["label"].to(self.device)  # (None, )
                batch_ids = batch["patient_id"]           # (None, )

                # forward pass only
                logits = self.model(images)  # (None, 2)
                loss = loss_function(logits, targets)

                # keep the probability of class 1 for every slice
                positive_probs = F.softmax(logits, dim=1)[:, 1]
                seen_probs.extend(positive_probs.tolist())
                seen_labels.extend(targets.tolist())
                seen_ids.extend(batch_ids)

                loss_meter.update(loss.item(), images.size(0))
                if verbose:
                    progress = 'Val step {}/{}, loss: {:.5f}'.format(step, total_steps,
                                                                     loss_meter.avg)
                    print(progress, end='\r')
        elapsed_time = str(datetime.timedelta(seconds=time.time() - started_at))
        print('Val loss: {:.5f} - time: {}'.format(loss_meter.avg, elapsed_time))

        per_slice = pd.DataFrame({
            'BraTS21ID': [pid.item() for pid in seen_ids],
            'probability': seen_probs,
            'label': seen_labels,
        })
        slice_auc = roc_auc_score(per_slice['label'], per_slice['probability'])
        # Averaging leaves the label unchanged: it is constant within a patient.
        per_patient = per_slice.groupby("BraTS21ID", as_index=False).mean()
        patient_auc = roc_auc_score(per_patient['label'], per_patient['probability'])
        print('Patient AUC: {:.5f} - Slice AUC: {:.5f}'.format(patient_auc, slice_auc))

        return loss_meter.avg, patient_auc, per_patient
    
    def infer_epoch(self, loader, verbose):
        self.model.eval()
        start = time.time()
        n = len(loader)
        patient_ids_all = []
        probabilities_all = []
        with torch.no_grad():
            for step, batch_data in enumerate(loader):
                inputs, patient_ids = (
                    batch_data["image"].to(self.device), # (None, 1, 32, 32)
                    batch_data["patient_id"], # (None, )
                )
                batch_size = inputs.size(0)
                # forward
                outputs = self.model(inputs) # (None, 2)
                # update stats
                probabilities = F.softmax(outputs, dim=1)[:, 1].tolist()
                probabilities_all.extend(probabilities)
                patient_ids_all.extend(patient_ids)
                if verbose:
                    print('Infer step {}/{}'.format(step + 1, n), end='\r')
        
        result = {
            'BraTS21ID': list(map(lambda x: x.item(), patient_ids_all)), 
            'probability': probabilities_all,
        }
        result = pd.DataFrame(result)
        result = result.groupby("BraTS21ID", as_index=False).mean()
        
        elapsed_time = str(datetime.timedelta(seconds=time.time() - start))
        print('Elapsed time: {}'.format(elapsed_time))
        
        return result
    
    def fit(self, train_ds, val_ds, val_holdout_ds, batch_size, epochs, lr, model_name, verbose):
        """Train ``self.model`` and save the weights of the best validation epoch.

        The best epoch is the one with the highest patient AUC on ``val_ds``.
        The holdout set is evaluated every epoch for monitoring only and never
        influences model selection. Training stops early once
        ``early_stopping_rounds`` consecutive epochs have passed without a new
        best validation AUC. The best ``state_dict`` is written to
        ``{model_name}_imgsize{..}_valloss{..}_valauc{..}.pth`` in the working
        directory; ``self.model`` itself keeps the weights of the last epoch.

        Args:
            train_ds: Dataset used for optimization.
            val_ds: Dataset used for model selection and early stopping.
            val_holdout_ds: Patient-level holdout dataset (reported only).
            batch_size: Batch size of all three loaders.
            epochs: Maximum number of epochs.
            lr: Learning rate of the Adam optimizer.
            model_name: Prefix of the saved weights file.
            verbose: When true, print per-step progress.
        """
        train_loader = DataLoader(train_ds,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=self.args.n_workers)
        val_loader = DataLoader(val_ds,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=self.args.n_workers)
        val_holdout_loader = DataLoader(val_holdout_ds,
                                        batch_size=batch_size,
                                        shuffle=False,
                                        num_workers=self.args.n_workers)
        loss_function = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

        best_metric = -np.inf
        best_loss = np.inf
        best_epoch = 1
        best_state_dict = None
        stopped_early = False
        save_path = '{}_imgsize{}_valloss{:.3f}_valauc{:.3f}.pth'
        for epoch in range(1, epochs + 1):
            print('\nEpoch {}/{}:'.format(epoch, epochs))
            self.train_epoch(train_loader, loss_function, optimizer, verbose)
            print(' Validation:')
            val_loss, val_metric, _ = self.evaluate_epoch(val_loader, loss_function, verbose)
            print(' Hold out:')
            self.evaluate_epoch(val_holdout_loader, loss_function, verbose)

            if val_metric > best_metric:
                print('Val AUC improved from {:.5f} to {:.5f}'.format(best_metric, val_metric))
                best_metric = val_metric
                best_loss = val_loss
                best_epoch = epoch
                # Snapshot now: later epochs keep mutating the live parameters.
                best_state_dict = deepcopy(self.model.state_dict())
            elif (epoch - best_epoch) >= self.args.early_stopping_rounds:
                # `>=` stops after exactly `early_stopping_rounds` stale epochs
                # (a strict `>` would allow one extra epoch).
                stopped_early = True
                break

        if best_state_dict is None:
            # No epoch ran (epochs < 1): nothing to report or save.
            return

        reason = 'Early stopping' if stopped_early else 'Finished training'
        print('{}. Best model is epoch {}'.format(reason, best_epoch))
        print('Val loss: {:.5f}, Val auc: {:.5f}'.format(best_loss, best_metric))
        print('Saving model...')
        torch.save(best_state_dict,
                   save_path.format(model_name,
                                    self.args.img_size,
                                    best_loss,
                                    best_metric))
                
    def evaluate(self, val_holdout_ds, batch_size, verbose):
        """Evaluate the current model on a holdout dataset.

        Args:
            val_holdout_ds: Dataset to score.
            batch_size: Batch size of the loader.
            verbose: When true, print per-step progress.

        Returns:
            ``(patient_auc, per_patient_result)`` as produced by
            ``evaluate_epoch``; the loss is printed there but not returned.
        """
        holdout_loader = DataLoader(val_holdout_ds,
                                    batch_size=batch_size,
                                    shuffle=False,
                                    num_workers=self.args.n_workers)
        criterion = nn.CrossEntropyLoss()
        print(' Hold out:')
        outcome = self.evaluate_epoch(holdout_loader, criterion, verbose)
        _, holdout_metric, holdout_result = outcome
        return holdout_metric, holdout_result
    
    def predict(self, test_ds, batch_size, verbose):
        """Return per-patient probabilities of the current model for ``test_ds``.

        Args:
            test_ds: Dataset to run inference on.
            batch_size: Batch size of the loader.
            verbose: When true, print per-step progress.
        """
        loader = DataLoader(test_ds,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=self.args.n_workers)
        return self.infer_epoch(loader, verbose)

In [7]:
# Notebook-level settings used by the cells below.
# MRI sequences to train and predict on.
mri_types = ['T1wCE']
# Side length the slices are resized to (passed to Config).
img_size = 32
# Batch size of every data loader.
batch_size = 32
# Worker count for datasets and data loaders (passed to Config).
n_workers = 4
# Early-stopping patience in epochs (passed to Config).
early_stopping_rounds = 3
# Number of patient-level folds (passed to Config and looped over below).
n_folds = 5
# Maximum number of training epochs per fold.
epochs = 50
# Learning rate of the optimizer.
lr = 1e-3

In [8]:
# Bundle the notebook-level settings into a Config and build the pipeline;
# paths and n_classes keep their Config defaults.
args = Config(
    img_size=img_size,
    n_workers=n_workers,
    early_stopping_rounds=early_stopping_rounds,
    n_folds=n_folds,
)
pipeline = Pipeline(args)

## Train

In [9]:
# Train one model per fold; every fold starts from freshly initialized weights.
pipeline.load_annotations()
for fold in range(n_folds):
    print(f'### Train {mri_types} on fold {fold}: ###')
    train_ds, val_ds, val_holdout_ds = pipeline.prepare_datasets(
        mri_types=mri_types,
        fold=fold,
        cache_rate=1.0,
    )
    pipeline.load_model()
    pipeline.fit(
        train_ds, val_ds, val_holdout_ds,
        batch_size=batch_size,
        epochs=epochs,
        lr=lr,
        model_name=f'{"_".join(mri_types)}_fold{fold}',
        verbose=True,
    )

### Train ['T1wCE'] on fold 0: ###


100%|██████████| 465/465 [00:08<00:00, 51.78it/s]
Loading dataset: 100%|██████████| 10317/10317 [00:34<00:00, 299.08it/s]
100%|██████████| 465/465 [00:00<00:00, 538.11it/s]
Loading dataset: 100%|██████████| 2580/2580 [00:08<00:00, 297.88it/s]
100%|██████████| 117/117 [00:02<00:00, 48.44it/s]
Loading dataset: 100%|██████████| 3299/3299 [00:11<00:00, 299.81it/s]



Epoch 1/50:
Train loss: 0.68535 - time: 0:00:03.711379
 Validation:
Val loss: 0.67889 - time: 0:00:00.639530
Patient AUC: 0.58403 - Slice AUC: 0.59365
 Hold out:
Val loss: 0.68701 - time: 0:00:00.768204
Patient AUC: 0.60235 - Slice AUC: 0.56842
Val AUC improved from -inf to 0.58403

Epoch 2/50:
Train loss: 0.68121 - time: 0:00:03.581786
 Validation:
Val loss: 0.67336 - time: 0:00:00.644600
Patient AUC: 0.59802 - Slice AUC: 0.60609
 Hold out:
Val loss: 0.68926 - time: 0:00:01.182018
Patient AUC: 0.60176 - Slice AUC: 0.57000
Val AUC improved from 0.58403 to 0.59802

Epoch 3/50:
Train loss: 0.67890 - time: 0:00:04.054740
 Validation:
Val loss: 0.67148 - time: 0:00:00.931953
Patient AUC: 0.60027 - Slice AUC: 0.60966
 Hold out:
Val loss: 0.68384 - time: 0:00:00.785311
Patient AUC: 0.60587 - Slice AUC: 0.57171
Val AUC improved from 0.59802 to 0.60027

Epoch 4/50:
Train loss: 0.67315 - time: 0:00:03.715514
 Validation:
Val loss: 0.66991 - time: 0:00:00.735603
Patient AUC: 0.61896 - Slice AUC

100%|██████████| 465/465 [00:00<00:00, 482.72it/s]
Loading dataset: 100%|██████████| 10437/10437 [00:31<00:00, 336.24it/s]
100%|██████████| 465/465 [00:00<00:00, 583.66it/s]
Loading dataset: 100%|██████████| 2610/2610 [00:07<00:00, 349.55it/s]
100%|██████████| 117/117 [00:00<00:00, 453.08it/s]
Loading dataset: 100%|██████████| 3149/3149 [00:09<00:00, 346.27it/s]



Epoch 1/50:
Train loss: 0.67898 - time: 0:00:04.320052
 Validation:
Val loss: 0.68940 - time: 0:00:00.657396
Patient AUC: 0.61370 - Slice AUC: 0.58954
 Hold out:
Val loss: 0.69982 - time: 0:00:00.809674
Patient AUC: 0.52635 - Slice AUC: 0.52039
Val AUC improved from -inf to 0.61370

Epoch 2/50:
Train loss: 0.67141 - time: 0:00:04.111772
 Validation:
Val loss: 0.67260 - time: 0:00:00.716331
Patient AUC: 0.62930 - Slice AUC: 0.61079
 Hold out:
Val loss: 0.70947 - time: 0:00:00.805286
Patient AUC: 0.52283 - Slice AUC: 0.51242
Val AUC improved from 0.61370 to 0.62930

Epoch 3/50:
Train loss: 0.66017 - time: 0:00:03.904916
 Validation:
Val loss: 0.67450 - time: 0:00:00.706395
Patient AUC: 0.65397 - Slice AUC: 0.62409
 Hold out:
Val loss: 0.72969 - time: 0:00:00.857139
Patient AUC: 0.54947 - Slice AUC: 0.53265
Val AUC improved from 0.62930 to 0.65397

Epoch 4/50:
Train loss: 0.65311 - time: 0:00:03.717113
 Validation:
Val loss: 0.65404 - time: 0:00:00.672204
Patient AUC: 0.67276 - Slice AUC

100%|██████████| 466/466 [00:00<00:00, 475.35it/s]
Loading dataset: 100%|██████████| 10248/10248 [00:29<00:00, 342.13it/s]
100%|██████████| 466/466 [00:00<00:00, 583.74it/s]
Loading dataset: 100%|██████████| 2562/2562 [00:07<00:00, 335.36it/s]
100%|██████████| 116/116 [00:00<00:00, 459.37it/s]
Loading dataset: 100%|██████████| 3386/3386 [00:09<00:00, 348.99it/s]



Epoch 1/50:
Train loss: 0.68612 - time: 0:00:04.289785
 Validation:
Val loss: 0.68509 - time: 0:00:00.993450
Patient AUC: 0.55936 - Slice AUC: 0.57183
 Hold out:
Val loss: 0.67084 - time: 0:00:00.879693
Patient AUC: 0.58838 - Slice AUC: 0.59110
Val AUC improved from -inf to 0.55936

Epoch 2/50:
Train loss: 0.68087 - time: 0:00:03.811641
 Validation:
Val loss: 0.67837 - time: 0:00:00.661708
Patient AUC: 0.59370 - Slice AUC: 0.59042
 Hold out:
Val loss: 0.68100 - time: 0:00:00.877582
Patient AUC: 0.56215 - Slice AUC: 0.56687
Val AUC improved from 0.55936 to 0.59370

Epoch 3/50:
Train loss: 0.67523 - time: 0:00:03.958835
 Validation:
Val loss: 0.67550 - time: 0:00:00.670949
Patient AUC: 0.61087 - Slice AUC: 0.60430
 Hold out:
Val loss: 0.68508 - time: 0:00:00.870734
Patient AUC: 0.55440 - Slice AUC: 0.55646
Val AUC improved from 0.59370 to 0.61087

Epoch 4/50:
Train loss: 0.66786 - time: 0:00:03.789062
 Validation:
Val loss: 0.66642 - time: 0:00:00.667412
Patient AUC: 0.64505 - Slice AUC

100%|██████████| 466/466 [00:01<00:00, 367.28it/s]
Loading dataset: 100%|██████████| 10393/10393 [00:30<00:00, 336.34it/s]
100%|██████████| 466/466 [00:01<00:00, 430.76it/s]
Loading dataset: 100%|██████████| 2599/2599 [00:07<00:00, 340.66it/s]
100%|██████████| 116/116 [00:00<00:00, 477.18it/s]
Loading dataset: 100%|██████████| 3204/3204 [00:09<00:00, 334.55it/s]



Epoch 1/50:
Train loss: 0.68188 - time: 0:00:03.681989
 Validation:
Val loss: 0.67805 - time: 0:00:00.680517
Patient AUC: 0.60439 - Slice AUC: 0.61067
 Hold out:
Val loss: 0.70457 - time: 0:00:00.786390
Patient AUC: 0.50045 - Slice AUC: 0.53603
Val AUC improved from -inf to 0.60439

Epoch 2/50:
Train loss: 0.67106 - time: 0:00:03.651992
 Validation:
Val loss: 0.66368 - time: 0:00:00.764350
Patient AUC: 0.62988 - Slice AUC: 0.63710
 Hold out:
Val loss: 0.71838 - time: 0:00:00.893962
Patient AUC: 0.51684 - Slice AUC: 0.50864
Val AUC improved from 0.60439 to 0.62988

Epoch 3/50:
Train loss: 0.66576 - time: 0:00:03.743775
 Validation:
Val loss: 0.65781 - time: 0:00:00.877862
Patient AUC: 0.64524 - Slice AUC: 0.64701
 Hold out:
Val loss: 0.72217 - time: 0:00:01.116984
Patient AUC: 0.49598 - Slice AUC: 0.51070
Val AUC improved from 0.62988 to 0.64524

Epoch 4/50:
Train loss: 0.65841 - time: 0:00:03.865525
 Validation:
Val loss: 0.64926 - time: 0:00:00.895567
Patient AUC: 0.66984 - Slice AUC

100%|██████████| 466/466 [00:01<00:00, 460.95it/s]
Loading dataset: 100%|██████████| 10430/10430 [00:31<00:00, 331.21it/s]
100%|██████████| 466/466 [00:00<00:00, 536.39it/s]
Loading dataset: 100%|██████████| 2608/2608 [00:07<00:00, 337.93it/s]
100%|██████████| 116/116 [00:00<00:00, 470.45it/s]
Loading dataset: 100%|██████████| 3158/3158 [00:09<00:00, 338.27it/s]



Epoch 1/50:
Train loss: 0.68420 - time: 0:00:03.787243
 Validation:
Val loss: 0.67861 - time: 0:00:00.697848
Patient AUC: 0.61644 - Slice AUC: 0.61180
 Hold out:
Val loss: 0.68934 - time: 0:00:00.800263
Patient AUC: 0.54456 - Slice AUC: 0.55619
Val AUC improved from -inf to 0.61644

Epoch 2/50:
Train loss: 0.67509 - time: 0:00:05.119620
 Validation:
Val loss: 0.66888 - time: 0:00:00.669850
Patient AUC: 0.62942 - Slice AUC: 0.62235
 Hold out:
Val loss: 0.70112 - time: 0:00:00.801106
Patient AUC: 0.53443 - Slice AUC: 0.54192
Val AUC improved from 0.61644 to 0.62942

Epoch 3/50:
Train loss: 0.66896 - time: 0:00:03.788467
 Validation:
Val loss: 0.67112 - time: 0:00:00.672243
Patient AUC: 0.64997 - Slice AUC: 0.63016
 Hold out:
Val loss: 0.69995 - time: 0:00:00.766972
Patient AUC: 0.53472 - Slice AUC: 0.54224
Val AUC improved from 0.62942 to 0.64997

Epoch 4/50:
Train loss: 0.65556 - time: 0:00:03.648032
 Validation:
Val loss: 0.66622 - time: 0:00:00.667194
Patient AUC: 0.64118 - Slice AUC

## Evaluate

In [10]:
def find_weight(prefix):
    """Return the weights file in the working directory saved for ``prefix``.

    Files are matched as ``<prefix>_*.pth`` (the naming used when the model is
    saved), so a prefix can never match the file of another model whose name
    merely contains it. When several files match, the first one in sorted
    order is returned so the choice is deterministic.

    Raises:
        FileNotFoundError: If no matching file exists.
    """
    candidates = sorted(
        name for name in os.listdir()
        if name.startswith(prefix + '_') and name.endswith('.pth')
    )
    if not candidates:
        raise FileNotFoundError(f'No saved weights found for {prefix!r}')
    return candidates[0]


metrics = []
results = []
weights_paths = [find_weight(f'{"_".join(mri_types)}_fold{fold}') for fold in range(n_folds)]
for fold, weights_path in enumerate(weights_paths):
    print(f'### Evaluate {mri_types} on fold {fold}: ###')
    # Only the patient-level holdout set of the fold is needed here.
    _, _, val_holdout_ds = pipeline.prepare_datasets(mri_types=mri_types,
                                                     fold=fold,
                                                     cache_rate=0.0)
    pipeline.load_model(weights_path)
    val_metric, val_result = pipeline.evaluate(val_holdout_ds, batch_size=batch_size, verbose=True)
    metrics.append(val_metric)
    results.append(val_result)
# Every patient is in exactly one holdout fold, so the concatenation holds one
# out-of-fold prediction per patient.
results = pd.concat(results, ignore_index=True)
mean_auc = np.mean(metrics)
oof_auc = roc_auc_score(results['label'], results['probability'])
print('---')
print(f'{mri_types} holdout result:')
print(' Mean AUC: {:.5f}'.format(mean_auc))
print(' Out-of-fold AUC: {:.5f}'.format(oof_auc))
print('---')

### Evaluate ['T1wCE'] on fold 0: ###


100%|██████████| 465/465 [00:01<00:00, 449.10it/s]
100%|██████████| 465/465 [00:00<00:00, 589.32it/s]
100%|██████████| 117/117 [00:00<00:00, 500.66it/s]

 Hold out:





Val loss: 1.46764 - time: 0:00:06.037797
Patient AUC: 0.41026 - Slice AUC: 0.44033
### Evaluate ['T1wCE'] on fold 1: ###


100%|██████████| 465/465 [00:00<00:00, 469.45it/s]
100%|██████████| 465/465 [00:00<00:00, 573.59it/s]
100%|██████████| 117/117 [00:00<00:00, 575.17it/s]

 Hold out:





Val loss: 1.55515 - time: 0:00:05.746653
Patient AUC: 0.51874 - Slice AUC: 0.50661
### Evaluate ['T1wCE'] on fold 2: ###


100%|██████████| 466/466 [00:00<00:00, 596.87it/s]
100%|██████████| 466/466 [00:00<00:00, 594.94it/s]
100%|██████████| 116/116 [00:00<00:00, 590.71it/s]


 Hold out:
Val loss: 0.92470 - time: 0:00:06.588217
Patient AUC: 0.51297 - Slice AUC: 0.51377
### Evaluate ['T1wCE'] on fold 3: ###


100%|██████████| 466/466 [00:01<00:00, 407.88it/s]
100%|██████████| 466/466 [00:01<00:00, 386.76it/s]
100%|██████████| 116/116 [00:00<00:00, 421.89it/s]

 Hold out:





Val loss: 1.24312 - time: 0:00:07.674981
Patient AUC: 0.49091 - Slice AUC: 0.50873
### Evaluate ['T1wCE'] on fold 4: ###


100%|██████████| 466/466 [00:01<00:00, 407.63it/s]
100%|██████████| 466/466 [00:01<00:00, 430.19it/s]
100%|██████████| 116/116 [00:00<00:00, 449.26it/s]

 Hold out:





Val loss: 1.31189 - time: 0:00:06.226836
Patient AUC: 0.58003 - Slice AUC: 0.51548
---
['T1wCE'] holdout result:
 Mean AUC: 0.50258
 Out-of-fold AUC: 0.50393
---


As we can see, although the validation AUCs are very high (0.8x to 0.9x) for every fold, the holdout AUC is only around 0.5, which is no better than random guessing.

## Inference

Final prediction for the test set is the average of the predictions of all 5 models.

In [11]:
test_results = []
# The test dataset does not depend on the fold: build it once and reuse it
# instead of re-listing the test folder for every model.
test_ds = pipeline.prepare_test_dataset(mri_types=mri_types, cache_rate=0.0)
for fold, weights_path in enumerate(weights_paths):
    print(f'### Inference {mri_types} on fold {fold}: ###')
    pipeline.load_model(weights_path)
    test_result = pipeline.predict(test_ds, batch_size=batch_size, verbose=True)
    test_results.append(test_result)

### Inference ['T1wCE'] on fold 0: ###


100%|██████████| 87/87 [00:02<00:00, 41.24it/s]


Elapsed time: 0:00:06.860399
### Inference ['T1wCE'] on fold 1: ###


100%|██████████| 87/87 [00:00<00:00, 523.89it/s]


Elapsed time: 0:00:04.786158
### Inference ['T1wCE'] on fold 2: ###


100%|██████████| 87/87 [00:00<00:00, 820.86it/s]


Elapsed time: 0:00:04.611454
### Inference ['T1wCE'] on fold 3: ###


100%|██████████| 87/87 [00:00<00:00, 860.81it/s]


Elapsed time: 0:00:04.934411
### Inference ['T1wCE'] on fold 4: ###


100%|██████████| 87/87 [00:00<00:00, 871.78it/s]


Elapsed time: 0:00:04.624844


In [12]:
# Align the per-fold predictions on the patient id and average them.
fold_predictions = [fold_result.set_index('BraTS21ID') for fold_result in test_results]
prediction = pd.concat(fold_predictions, axis=1).mean(axis=1)
prediction = prediction.to_frame(name='MGMT_value').reset_index()
prediction.to_csv('submission.csv', index=False)