# Import Libraries 

In [1]:
# https://www.kaggle.com/rluethy/efficientnet3d-with-one-mri-type
import os
import sys 
import json
import glob
import random
import collections
import time
from pathlib import Path
from tqdm.notebook import tqdm

import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import torch.nn.functional as F

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score

import warnings
warnings.filterwarnings("ignore")

# System Constants

In [2]:
# MRI sequence sub-folders available per patient; only SCAN_TYPE[0] ('FLAIR') is used below.
SCAN_TYPE = (
    "FLAIR",
    "T1w",
    "T1wCE",
    "T2w",
)

# Hyperparameters

In [3]:
# Reproducibility
SEED = 42

# Input size and optimisation settings
IMAGE_SIZE = 256        # every slice is resized to IMAGE_SIZE x IMAGE_SIZE
BATCH_SIZE = 64
LEARNING_RATE = 1e-6    # Adam step size
WEIGHT_DECAY = 1e-4     # weight_decay passed to Adam
EPOCH_NUM = 1000

# System setup (random seed, device, data paths)

In [4]:
def set_seed(seed):
    """Seed Python's `random`, NumPy's global RNG and PyTorch (CPU and, if present, CUDA).

    Note: assigning PYTHONHASHSEED at runtime only affects child processes; the
    running interpreter's str-hash randomisation was fixed when it started.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    # Request deterministic cuDNN kernels (can be slower).
    torch.backends.cudnn.deterministic = True

set_seed(seed = SEED)  # seed before the train/valid split and model weight initialisation below

In [5]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [6]:
# Dataset locations. The defaults are the author's local Windows folders; set the
# RSNA_PNG_DIR / RSNA_LABEL_DIR environment variables to run on another machine.
data_path = Path(os.environ.get('RSNA_PNG_DIR', 'E:\\datasets\\RSNA MICCAI PNG'))
label_path = Path(os.environ.get('RSNA_LABEL_DIR', 'E:\\datasets\\rsna-miccai-brain-tumor-radiogenomic-classification'))

# Load data
## Load label

In [7]:
# Per-patient MGMT labels (columns: BraTS21ID, MGMT_value).
df_label = pd.read_csv(label_path / "train_labels.csv")
print(df_label.shape)
df_label.head()

(585, 2)


Unnamed: 0,BraTS21ID,MGMT_value
0,0,1
1,2,1
2,3,0
3,5,1
4,6,1


## Load submission sample

In [8]:
# Submission template: one row per test patient with a placeholder MGMT_value.
df_submission = pd.read_csv(label_path / "sample_submission.csv")
print(df_submission.shape)
df_submission.head()

(87, 2)


Unnamed: 0,BraTS21ID,MGMT_value
0,1,0.5
1,13,0.5
2,15,0.5
3,27,0.5
4,37,0.5


## Train/validation split (by patient folder)

In [9]:
# Split at the patient-folder level so no patient's slices end up in both sets.
# glob order is filesystem-dependent, so sort first; random_state pins the split
# even if other cells consumed NumPy's global RNG before this one.
# TODO: consider stratifying by MGMT_value so both splits keep the class balance.
trainval_path_list = sorted(glob.glob(str(data_path / "train" / '*')))
train_path_list, valid_path_list = train_test_split(trainval_path_list, test_size = 0.1, random_state = SEED)

test_path_list = sorted(glob.glob(str(data_path / "test" / '*' / SCAN_TYPE[0] / '*.png')))

## Expand patient folders into FLAIR slice paths

In [10]:
def expand_to_slices(patient_dirs, scan_type = SCAN_TYPE[0]):
    """Return every `<patient_dir>/<scan_type>/*.png` path, patient folders in the given order.

    Slices within a patient are sorted so the resulting list is deterministic.
    """
    slice_paths = []
    for patient_dir in patient_dirs:
        slice_paths += sorted(glob.glob(str(Path(patient_dir) / scan_type / '*.png')))
    return slice_paths

# Replace the patient-folder lists from the split with per-slice file lists.
train_path_list = expand_to_slices(train_path_list)
valid_path_list = expand_to_slices(valid_path_list)

## Create Train DataFrame

In [11]:
df_train = pd.DataFrame(train_path_list, columns = ['file_path'])

# Patient id is the grandparent folder of each slice: .../<BraTS21ID>/FLAIR/<slice>.png
df_train['BraTS21ID'] = df_train['file_path'].map(lambda p: int(Path(p).parts[-3]))

# Vectorised label lookup (one join instead of scanning df_label once per slice);
# astype(int) fails loudly if any patient has no label (NaN after map).
df_train['MGMT_value'] = df_train['BraTS21ID'].map(df_label.set_index('BraTS21ID')['MGMT_value']).astype(int)

df_train

Unnamed: 0,file_path,BraTS21ID,MGMT_value
0,E:\datasets\RSNA MICCAI PNG\train\00628\FLAIR\...,628,1
1,E:\datasets\RSNA MICCAI PNG\train\00628\FLAIR\...,628,1
2,E:\datasets\RSNA MICCAI PNG\train\00628\FLAIR\...,628,1
3,E:\datasets\RSNA MICCAI PNG\train\00628\FLAIR\...,628,1
4,E:\datasets\RSNA MICCAI PNG\train\00628\FLAIR\...,628,1
...,...,...,...
48549,E:\datasets\RSNA MICCAI PNG\train\00154\FLAIR\...,154,0
48550,E:\datasets\RSNA MICCAI PNG\train\00154\FLAIR\...,154,0
48551,E:\datasets\RSNA MICCAI PNG\train\00154\FLAIR\...,154,0
48552,E:\datasets\RSNA MICCAI PNG\train\00154\FLAIR\...,154,0


## Create Valid DataFrame

In [12]:
df_valid = pd.DataFrame(valid_path_list, columns = ['file_path'])

# Patient id is the grandparent folder of each slice: .../<BraTS21ID>/FLAIR/<slice>.png
df_valid['BraTS21ID'] = df_valid['file_path'].map(lambda p: int(Path(p).parts[-3]))

# Vectorised label lookup (one join instead of scanning df_label once per slice);
# astype(int) fails loudly if any patient has no label (NaN after map).
df_valid['MGMT_value'] = df_valid['BraTS21ID'].map(df_label.set_index('BraTS21ID')['MGMT_value']).astype(int)

df_valid

Unnamed: 0,file_path,BraTS21ID,MGMT_value
0,E:\datasets\RSNA MICCAI PNG\train\00561\FLAIR\...,561,1
1,E:\datasets\RSNA MICCAI PNG\train\00561\FLAIR\...,561,1
2,E:\datasets\RSNA MICCAI PNG\train\00561\FLAIR\...,561,1
3,E:\datasets\RSNA MICCAI PNG\train\00561\FLAIR\...,561,1
4,E:\datasets\RSNA MICCAI PNG\train\00561\FLAIR\...,561,1
...,...,...,...
5123,E:\datasets\RSNA MICCAI PNG\train\00652\FLAIR\...,652,1
5124,E:\datasets\RSNA MICCAI PNG\train\00652\FLAIR\...,652,1
5125,E:\datasets\RSNA MICCAI PNG\train\00652\FLAIR\...,652,1
5126,E:\datasets\RSNA MICCAI PNG\train\00652\FLAIR\...,652,1


## Create Test DataFrame

In [13]:
df_test = pd.DataFrame(test_path_list, columns = ['file_path'])

# Patient id is the grandparent folder of each slice: .../<BraTS21ID>/FLAIR/<slice>.png
df_test['BraTS21ID'] = df_test['file_path'].map(lambda p: int(Path(p).parts[-3]))

df_test

Unnamed: 0,file_path,BraTS21ID
0,E:\datasets\RSNA MICCAI PNG\test\00001\FLAIR\I...,1
1,E:\datasets\RSNA MICCAI PNG\test\00001\FLAIR\I...,1
2,E:\datasets\RSNA MICCAI PNG\test\00001\FLAIR\I...,1
3,E:\datasets\RSNA MICCAI PNG\test\00001\FLAIR\I...,1
4,E:\datasets\RSNA MICCAI PNG\test\00001\FLAIR\I...,1
...,...,...
7921,E:\datasets\RSNA MICCAI PNG\test\01006\FLAIR\I...,1006
7922,E:\datasets\RSNA MICCAI PNG\test\01006\FLAIR\I...,1006
7923,E:\datasets\RSNA MICCAI PNG\test\01006\FLAIR\I...,1006
7924,E:\datasets\RSNA MICCAI PNG\test\01006\FLAIR\I...,1006


# Dataset

In [14]:
class CustomDataset(Dataset):
    """One item per PNG slice listed in `df`.

    `df` must have columns `file_path` and `BraTS21ID`, plus `MGMT_value` when
    `is_train` is True. Each image is read as grayscale, resized to
    IMAGE_SIZE x IMAGE_SIZE and returned as a float tensor of raw pixel values
    (no normalisation is applied).

    Returns `(X, y, bts_id)` when `is_train`, otherwise `(X, bts_id)`.
    """
    def __init__(self, df, is_train = False):
        self.df = df
        self.is_train = is_train

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        row = self.df.iloc[idx]  # look the row up once instead of per column
        path = row['file_path']
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread reports missing/unreadable files by returning None, which would
            # otherwise surface as an opaque assertion error inside cv2.resize.
            raise FileNotFoundError(f"Could not read image: {path}")
        X = torch.Tensor(cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE)))
        bts_id = row['BraTS21ID']

        if self.is_train:
            return X, row['MGMT_value'], bts_id
        return X, bts_id
        

In [15]:
# Labelled slice dataset; slices are reshuffled every epoch.
train_dataset = CustomDataset(df=df_train, is_train=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [16]:
# Both loaders walk df_valid in order (no shuffling):
#   each_valid_loader -> (X, y, id) batches for slice-level loss / ROC,
#   id_valid_loader   -> (X, id) batches for patient-level aggregation.
each_valid_dataset = CustomDataset(df=df_valid, is_train=True)
each_valid_loader = DataLoader(dataset=each_valid_dataset, batch_size=BATCH_SIZE)

id_valid_dataset = CustomDataset(df=df_valid, is_train=False)
id_valid_loader = DataLoader(dataset=id_valid_dataset, batch_size=BATCH_SIZE)

In [17]:
# Unlabelled test slices, yielded as (X, id) batches in file order.
test_dataset = CustomDataset(df=df_test, is_train=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE)

# Model

In [18]:
class SimpleModel(nn.Module):
    """Twelve residual single-channel 3x3 conv blocks followed by an MLP head.

    Input:  (batch, H, W) grayscale slices with H == W == image_size.
    Output: (batch, 1) probabilities -- a sigmoid is already applied, so pair the
            model with nn.BCELoss (not BCEWithLogitsLoss).

    Parameters
    ----------
    image_size : int, optional
        Side length of the square input; defaults to the notebook's IMAGE_SIZE.
    """
    N_CONV_BLOCKS = 12

    def __init__(self, image_size = None):
        super().__init__()
        if image_size is None:
            image_size = IMAGE_SIZE
        kernel_size = 3
        padding_size = kernel_size//2  # "same" padding keeps H x W unchanged

        # Registered as conv1..conv12 (not a ModuleList) so state_dict keys and
        # previously pickled checkpoints keep their original attribute names.
        # Registration order matches the original, so seeded init is unchanged.
        for i in range(1, self.N_CONV_BLOCKS + 1):
            setattr(self, f"conv{i}",
                    nn.Sequential(nn.Conv2d(1, 1, kernel_size = kernel_size, padding = padding_size, bias = False),
                                  nn.BatchNorm2d(1),
                                  nn.ReLU(inplace = False)))

        self.fc1 = nn.Linear(image_size * image_size, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, 128)
        self.fc_out = nn.Linear(128, 1)
        self.act = nn.Sigmoid()

        self.dropout10 = nn.Dropout(0.1)
        self.dropout20 = nn.Dropout(0.2)
        self.dropout30 = nn.Dropout(0.3)

    def forward(self, x):
        x = x.unsqueeze(1)  # (batch, H, W) -> (batch, 1, H, W)

        for i in range(1, self.N_CONV_BLOCKS + 1):
            # Out-of-place residual add. Each block ends in ReLU, whose backward pass
            # needs its own output unchanged; the former `x += skip` mutated that output
            # and made loss.backward() raise "modified by an inplace operation".
            x = x + getattr(self, f"conv{i}")(x)

        x = x.flatten(1)  # (batch, 1, H, W) -> (batch, H*W)
        x = self.dropout30(x)
        x = F.relu(self.fc1(x))
        x = self.dropout30(x)
        x = F.relu(self.fc2(x))
        x = self.dropout20(x)
        x = F.relu(self.fc3(x))
        x = self.dropout20(x)
        x = F.relu(self.fc4(x))
        x = self.dropout10(x)
        x = self.act(self.fc_out(x))
        return x

# Trainer

In [19]:
class Trainer():
    """Training / evaluation driver for a binary classifier that outputs probabilities.

    Parameters
    ----------
    model : nn.Module
        Returns one probability per sample; outputs are flattened with reshape(-1).
    criterion : callable
        Loss taking (probabilities, float targets), e.g. nn.BCELoss().
    optimizer : torch.optim.Optimizer
        Optimizer over `model`'s parameters.
    device : torch.device
        Device batches are moved to.
    """
    def __init__(self, model, criterion, optimizer, device):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device

        self.history = dict()

    @staticmethod
    def _safe_roc(y_true, y_score):
        """ROC AUC as a float, or NaN when undefined (only one class present)."""
        try:
            return float(roc_auc_score(y_true, y_score))
        except ValueError:
            return float("nan")

    def _epoch_metrics(self, y_all, y_hat_all):
        """(loss, roc) over all collected CPU batches; (nan, nan) if nothing was seen."""
        if not y_all:
            return float("nan"), float("nan")
        y = torch.cat(y_all)
        y_hat = torch.cat(y_hat_all)
        loss = float(self.criterion(y_hat, y))
        return loss, self._safe_roc(y.numpy(), y_hat.numpy())

    def train(self, loader):
        """Run one optimisation epoch over `loader`, which yields (X, y, _) batches.

        Returns (loss, roc) as Python floats, computed once over every prediction
        made during the epoch (roc is NaN if only one class was seen).
        """
        self.model.train()
        device = self.device

        y_all = []
        y_hat_all = []
        running_loss, seen = 0.0, 0
        prog_bar = tqdm(loader)
        for X, y, _ in prog_bar:
            X = X.to(device)
            y = y.to(device).to(torch.float)

            self.optimizer.zero_grad()

            y_hat = self.model(X).reshape(-1)

            loss = self.criterion(y_hat, y)
            loss.backward()
            self.optimizer.step()

            # Detach before storing: keeping y_hat itself would hold every batch's
            # autograd graph (and its saved activations) alive for the whole epoch.
            y_all.append(y.detach().cpu())
            y_hat_all.append(y_hat.detach().cpu())

            # Running sample-weighted mean of batch losses, for display only; the
            # returned metrics are computed once below instead of every batch.
            running_loss += loss.item() * y.numel()
            seen += y.numel()
            prog_bar.set_description(f"TRAIN loss {running_loss / seen:.4}")

        return self._epoch_metrics(y_all, y_hat_all)

    def valid(self, loader):
        """Evaluate on `loader` ((X, y, _) batches) without gradients.

        Returns (loss, roc) as Python floats over all slices (roc is NaN if only
        one class is present).
        """
        self.model.eval()
        device = self.device

        y_all = []
        y_hat_all = []
        running_loss, seen = 0.0, 0
        prog_bar = tqdm(loader)

        with torch.no_grad():
            for X, y, _ in prog_bar:
                X = X.to(device)
                y = y.to(device).to(torch.float)

                y_hat = self.model(X).reshape(-1)

                y_all.append(y.cpu())
                y_hat_all.append(y_hat.cpu())

                running_loss += float(self.criterion(y_hat, y)) * y.numel()
                seen += y.numel()
                prog_bar.set_description(f"VALID loss {running_loss / seen:.4}")

        return self._epoch_metrics(y_all, y_hat_all)

    def test(self, loader):
        """Predict every slice in `loader` ((X, bts_id) batches), grouped by patient.

        Returns a dict {BraTS21ID (int): [probability of each slice, in loader order]}.
        """
        self.model.eval()
        device = self.device

        pred = dict()

        with torch.no_grad():
            for X, bts_id in tqdm(loader):
                y_hat = self.model(X.to(device)).reshape(-1).cpu()
                ids = torch.as_tensor(bts_id).reshape(-1).tolist()
                for id_, prob in zip(ids, y_hat.tolist()):
                    pred.setdefault(int(id_), []).append(float(prob))

        return pred

    def result_to_roc(self, result, df_labels):
        """Patient-level ROC AUC: average each patient's slice probabilities, then score.

        Raises KeyError if a patient id in `result` has no row in `df_labels`.
        """
        label_by_id = df_labels.set_index('BraTS21ID')['MGMT_value']
        ids = list(result)
        pred = [float(np.mean(result[k])) for k in ids]
        true = label_by_id.loc[ids].to_numpy()
        return float(roc_auc_score(true, pred))

    def fit(self, train_loader, each_valid_loader, id_valid_loader, df_label, nepoch = 10):
        """Train for `nepoch` epochs, checkpointing on the best patient-level validation ROC.

        Each epoch runs `train`, a slice-level `valid`, then `test` + `result_to_roc`
        for a patient-level ROC. Whenever that ROC beats the best so far, the model
        is saved to ./model-check-<roc>.pkl. Returns `self.history` (per-epoch floats).
        """
        for key in ("train_loss", "train_roc", "valid_loss", "valid_roc", "valid_total_roc"):
            self.history[key] = []

        max_valid_roc = 0

        for epoch in tqdm(range(nepoch)):
            print(f"START {epoch}/{nepoch}")
            train_loss, train_roc = self.train(train_loader)
            valid_loss, valid_roc = self.valid(each_valid_loader)
            result = self.test(id_valid_loader)

            valid_roc_true = self.result_to_roc(result, df_label)

            if max_valid_roc < valid_roc_true:
                max_valid_roc = valid_roc_true
                self.save_model(f"./model-check-{valid_roc_true:.5}.pkl")
                print(f"./model-check-{valid_roc_true:.5}.pkl")

            print(f"TRAIN: loss {train_loss:.4}, roc {train_roc:.4}")
            print(f"VALID: loss {valid_loss:.4}, roc {valid_roc:.4}, total roc {valid_roc_true:.4}")

            self.history["train_loss"].append(train_loss)
            self.history["train_roc"].append(train_roc)
            self.history["valid_loss"].append(valid_loss)
            self.history["valid_roc"].append(valid_roc)
            self.history["valid_total_roc"].append(valid_roc_true)

        return self.history

    def save_model(self, name):
        """Pickle the whole model object (not just its state_dict) to `name`."""
        print(f"Save Model {name}")
        torch.save(self.model, name)

    def load_model(self, name):
        """Load a checkpoint written by `save_model` onto `self.device`.

        The file is a full pickled module and can run arbitrary code when
        unpickled -- only load checkpoints you produced yourself. Note that
        `self.optimizer` still references the previous model's parameters.
        """
        try:
            # torch >= 2.6 defaults to weights_only=True, which rejects full-module pickles.
            self.model = torch.load(name, map_location=self.device, weights_only=False)
        except TypeError:
            # Older torch releases have no weights_only keyword.
            self.model = torch.load(name, map_location=self.device)

# Build the model, loss, optimizer, and trainer

In [20]:
# nn.Module.to() moves parameters in place and returns the same module.
model = SimpleModel().to(device)

# SimpleModel already applies a sigmoid, so the loss consumes probabilities.
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

trainer = Trainer(model, criterion, optimizer, device)

# Fit

In [None]:
# Train for EPOCH_NUM epochs; Trainer.fit writes ./model-check-<roc>.pkl whenever the
# patient-level validation ROC improves.
history = trainer.fit(
    train_loader=train_loader,
    each_valid_loader=each_valid_loader,
    id_valid_loader=id_valid_loader,
    df_label=df_label,
    nepoch=EPOCH_NUM,
)

  0%|          | 0/1000 [00:00<?, ?it/s]

START 0/1000


  0%|          | 0/759 [00:00<?, ?it/s]

# Visualize Train Result

In [None]:
history = trainer.history
fig, axes = plt.subplots(2, 1, figsize = (16, 10), sharex = True)

for ax, metric, ylabel in ((axes[0], 'loss', 'Loss'), (axes[1], 'roc', 'Roc')):
    for split in ('train', 'valid'):
        # float() also accepts 0-dim tensors, including ones still attached to autograd,
        # which matplotlib cannot convert on its own.
        ax.plot([float(v) for v in history[f'{split}_{metric}']], label = split)
    ax.set_ylabel(ylabel)
    ax.legend()  # the labels above are only shown once a legend is drawn

axes[-1].set_xlabel('Epoch')
fig.suptitle('Training history');

# Load Model

In [None]:
# Load the best checkpoint written by Trainer.fit (file names encode the patient-level
# validation ROC) instead of hard-coding one particular run's file name.
checkpoints = glob.glob("./model-check-*.pkl")
if not checkpoints:
    raise FileNotFoundError("No ./model-check-*.pkl checkpoints found; run the Fit cell first.")
best_checkpoint = max(checkpoints, key = lambda p: float(Path(p).stem[len("model-check-"):]))
print(best_checkpoint)
trainer.load_model(best_checkpoint)

# Validate Model

In [None]:
# Patient-level ROC on the held-out split: slice probabilities are averaged per BraTS21ID.
result = trainer.test(loader=id_valid_loader)
valid_roc_true = trainer.result_to_roc(result=result, df_labels=df_label)
print(valid_roc_true)

In [None]:
# Class balance of the training labels (more informative than dumping all rows).
df_label['MGMT_value'].value_counts()

# Create Submission

In [None]:
result = trainer.test(loader=test_loader)  # {BraTS21ID: [slice probabilities, ...]}

In [None]:
# One probability per patient: the mean over that patient's slice predictions.
patient_prob = {bts_id: float(np.mean(probs)) for bts_id, probs in result.items()}

# Vectorised fill; patients without any predicted slice keep the sample value.
df_submission['MGMT_value'] = (
    df_submission['BraTS21ID'].map(patient_prob).fillna(df_submission['MGMT_value'])
)

n_missing = int((~df_submission['BraTS21ID'].isin(list(patient_prob))).sum())
if n_missing:
    print(f"WARNING: {n_missing} test patients have no slice predictions; kept the sample value")

df_submission

In [None]:
df_submission.to_csv(path_or_buf="submission.csv", index=False)  # no index column in the submission file