In [1]:
import os
import sys

from pathlib import Path
import glob

import numpy as np
import pandas as pd

import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import cv2

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Module
from torch.utils.data import DataLoader, Dataset

import torchvision

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

from tqdm.notebook import tqdm

import warnings
warnings.filterwarnings("ignore")

import matplotlib.pyplot as plt
%matplotlib inline

In [26]:
model_name = 'SimpleConvNet'
BATCH_SIZE = 16
IMAGE_SIZE = 256
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-3
NUM_EPOCH = 100

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = 'cpu'
print(device)

cpu


In [4]:
label_root = Path('E:\\datasets\\rsna-miccai-brain-tumor-radiogenomic-classification')
data_root = Path('E:\\datasets\\RSNA MICCAI PNG')

In [5]:
SCAN_TYPE = ('FLAIR', 'T1w', 'T1wCE', 'T2w')

In [6]:
def one_hot(arr):
    return [[1, 0] if a_i == 0 else [0, 1] for a_i in arr]

In [7]:
trainval_label = pd.read_csv(str(label_root / 'train_labels.csv'))
print(trainval_label.shape)
trainval_label.head()

(585, 2)


Unnamed: 0,BraTS21ID,MGMT_value
0,0,1
1,2,1
2,3,0
3,5,1
4,6,1


In [8]:
trainval_path_list = glob.glob(str(data_root / 'train' / '*'))

train_path_list, valid_path_list = train_test_split(trainval_path_list, test_size = 0.1)

In [9]:
test_path_list = glob.glob(str(data_root / 'test' / '*'))

In [10]:
class ImageDataset(Dataset):
    def __init__(self, path_list, labels = None, transform = None):
        self.path_list = path_list
        self.labels = labels
        self.len = len(path_list)
        
        self.transform = transform
        
        data_list = []
        for path in tqdm(path_list):
            data = torch.zeros((len(SCAN_TYPE), IMAGE_SIZE, IMAGE_SIZE))
            for c, scan in enumerate(SCAN_TYPE):
                scan_path_list = glob.glob(path + f'/{scan}/*.png')
                if len(scan_path_list) > 0:
                    temp = torch.cat([torchvision.io.read_image(scan_path)/255 for scan_path in scan_path_list], axis = 0)
                else:
                    temp = torch.zeros((1, IMAGE_SIZE, IMAGE_SIZE))

                if self.transform:
                    temp = self.transform(temp)
                data[c,:,:] = torch.Tensor(temp.mean(axis = 0))
            data_list.append(data)
        self.data_list = data_list
        
    def __len__(self):
        return self.len
    
    def __getitem__(self, idx):
        path = self.path_list[idx]
        patient = path.split(os.path.sep)[-1]
        
        if self.labels is not None:
            label = self.labels.loc[self.labels['BraTS21ID'] == int(patient), 'MGMT_value'].values
            return self.data_list[idx], label
        else:
            return self.data_list[idx]

transform = torchvision.transforms.Resize((IMAGE_SIZE, IMAGE_SIZE))

train_dataset = ImageDataset(train_path_list, labels = trainval_label, transform = transform)
valid_dataset = ImageDataset(valid_path_list, labels = trainval_label, transform = transform)

train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True)
valid_loader = DataLoader(valid_dataset, batch_size = BATCH_SIZE)


  0%|          | 0/526 [00:00<?, ?it/s]

  0%|          | 0/59 [00:00<?, ?it/s]

In [34]:
class SimpleModel(Module):
    def __init__(self):
        super().__init__()
        
        self.conv11 = nn.Conv2d(4, 4, kernel_size = (3,3), padding = (1,1))
        self.bn11 = nn.BatchNorm2d(4)
        self.conv12 = nn.Conv2d(4, 4, kernel_size = (3,3), padding = (1,1))
        self.bn12 = nn.BatchNorm2d(4)
        self.conv13 = nn.Conv2d(4, 4, kernel_size = (3,3), padding = (1,1))
        self.bn13 = nn.BatchNorm2d(4)
        self.conv1 = nn.Conv2d(4, 16, kernel_size = (3,3), padding = (1,1))
        self.bn1 = nn.BatchNorm2d(16)
        self.pool1 = nn.MaxPool2d(kernel_size = (2,2))
        
        self.conv21 = nn.Conv2d(16, 16, kernel_size = (3,3), padding = (1,1))
        self.bn21 = nn.BatchNorm2d(16)
        self.conv22 = nn.Conv2d(16, 16, kernel_size = (3,3), padding = (1,1))
        self.bn22 = nn.BatchNorm2d(16)
        self.conv23 = nn.Conv2d(16, 16, kernel_size = (3,3), padding = (1,1))
        self.bn23 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size = (3,3), padding = (1,1))
        self.bn2 = nn.BatchNorm2d(32)
        self.pool2 = nn.MaxPool2d(kernel_size = (2,2))
        
        self.conv31 = nn.Conv2d(32, 32, kernel_size = (3,3), padding = (1,1))
        self.bn31 = nn.BatchNorm2d(32)
        self.conv32 = nn.Conv2d(32, 32, kernel_size = (3,3), padding = (1,1))
        self.bn32 = nn.BatchNorm2d(32)
        self.conv33 = nn.Conv2d(32, 32, kernel_size = (3,3), padding = (1,1))
        self.bn33 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, kernel_size = (3,3), padding = (1,1))
        self.bn3 = nn.BatchNorm2d(64)
        self.pool3 = nn.MaxPool2d(kernel_size = (2,2))
        
        self.fc1 = nn.Linear(64 * 32 * 32, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, 2)
        
        self.dropout06 = nn.Dropout(0.6)
        self.dropout03 = nn.Dropout(0.3)
        
    def forward(self, x):
        skip = x
        x = self.bn11(F.relu(self.conv11(x)))
        x += skip
        x = self.bn12(F.relu(self.conv12(x)))
        x += skip
        x = self.bn13(F.relu(self.conv13(x)))
        x += skip
        x = self.pool1(self.bn1(F.relu(self.conv1(x))))
        
        skip = x
        x = self.bn21(F.relu(self.conv21(x)))
        x += skip
        x = self.bn22(F.relu(self.conv22(x)))
        x += skip
        x = self.bn23(F.relu(self.conv23(x)))
        x += skip
        x = self.pool2(self.bn2(F.relu(self.conv2(x))))
        
        skip = x
        x = self.bn31(F.relu(self.conv31(x)))
        x += skip
        x = self.bn32(F.relu(self.conv32(x)))
        x += skip
        x = self.bn33(F.relu(self.conv33(x)))
        x += skip
        x = self.pool3(self.bn3(F.relu(self.conv3(x))))
        
        x = x.reshape(-1, 64*32*32)
        x = F.relu(self.fc1(x))
        x = self.dropout06(x)
        x = F.relu(self.fc2(x))
        x = self.dropout06(x)
        x = F.relu(self.fc3(x))
        x = self.dropout03(x)
        x = F.softmax(self.fc4(x))
        
        return x

In [35]:
model = SimpleModel()
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE, weight_decay = WEIGHT_DECAY)

In [36]:
def train(prog):
    model.train()
    
    sum_loss = 0
    sum_roc = 0
    sum_count = 0
    for X, y in prog:
        X = X.to(device)
        y = y.to(device).squeeze()
        pred = model(X)
        
        optimizer.zero_grad()
        loss = criterion(pred, y)
        loss.backward()
        optimizer.step()
        
        sum_loss += loss * X.shape[0]
        sum_roc += roc_auc_score(one_hot(y.tolist()), pred.tolist()) * X.shape[0]
        sum_count += X.shape[0]
        
        prog.set_description(f"TRAIN: loss {sum_loss / sum_count :.4}, roc {sum_roc / sum_count:.4}")
        
    
    return sum_loss / sum_count, sum_roc / sum_count

In [37]:
def valid(prog):
    model.eval()
    
    sum_loss = 0
    sum_roc = 0
    sum_count = 0
    for X, y in prog:
        X = X.to(device)
        y = y.to(device).squeeze()
        
        pred = model(X)
        
        loss = criterion(pred, y)
        
        sum_loss += loss * X.shape[0]
        sum_roc += roc_auc_score(one_hot(y.tolist()), pred.tolist()) * X.shape[0]
        sum_count += X.shape[0]
        
        prog.set_description(f"TRAIN: loss {sum_loss / sum_count :.4}, roc {sum_roc / sum_count:.4}")
        
    return sum_loss / sum_count, sum_roc / sum_count

In [38]:
try:
    model = torch.load(f'./models/{model_name}.pkl')
except:
    pass

In [None]:
max_valid_roc = 0

history = {
    "train_loss":[],
    "train_roc":[],
    "valid_loss":[],
    "valid_roc":[]
}

for epoch in tqdm(range(NUM_EPOCH)):
    print( "-------------------------------------------------------")
    print(f"|EPOCH: {epoch+1}/{NUM_EPOCH}")
    
    train_loss, train_roc = train(tqdm(train_loader))
    valid_loss, valid_roc = valid(tqdm(valid_loader))
    
    history['train_loss'].append(train_loss)
    history['train_roc'].append(train_roc)
    history['valid_loss'].append(valid_loss)
    history['valid_roc'].append(valid_roc)
    
    if valid_roc > max_valid_roc:
        print(f"|{epoch+1}-th model is checked!, *{model_name}-{epoch}-{valid_roc}.pkl*")
        max_valid_roc = valid_roc
        torch.save(model, f'./models/{model_name}-{epoch}-{valid_roc}.pkl')
        
    print(f"|TRAIN: loss={train_loss:.4f} roc={train_roc:.4f}|")
    print(f"|VALID: loss={valid_loss:.4f} roc={valid_roc:.4f}|")

  0%|          | 0/100 [00:00<?, ?it/s]

-------------------------------------------------------
|EPOCH: 1/100


  0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

|1-th model is checked!, *SimpleConvNet-0-0.4025221953188054.pkl*
|TRAIN: loss=0.8027 roc=0.4806|
|VALID: loss=0.6692 roc=0.4025|
-------------------------------------------------------
|EPOCH: 2/100


  0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

|2-th model is checked!, *SimpleConvNet-1-0.49771320957761633.pkl*
|TRAIN: loss=0.8019 roc=0.4978|
|VALID: loss=0.6692 roc=0.4977|
-------------------------------------------------------
|EPOCH: 3/100


  0%|          | 0/33 [00:00<?, ?it/s]

In [None]:
fig, ax = plt.subplots(2,1,figsize = (16,10))

ax[0].plot(history['train_loss'])
ax[0].plot(history['valid_loss'])

ax[1].plot(history['train_roc'])
ax[1].plot(history['valid_roc'])