In [None]:
!pip install torchio
import torchio as tio

In [None]:
import os
import sys 
import json
import glob
import random
import collections
import time

import numpy as np
import pandas as pd
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import cv2
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch import nn
from torch.utils import data as torch_data
from sklearn import model_selection as sk_model_selection
from torch.nn import functional as torch_functional
import torch.nn.functional as F
from torchvision import transforms, utils

from sklearn import model_selection
from sklearn import metrics
from skimage import exposure

from albumentations import Resize, Normalize, Compose
from albumentations.pytorch import ToTensorV2
import albumentations as album

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score

from tqdm import tqdm 

import warnings
# Silence every warning category for the rest of the session.
# NOTE(review): this also hides deprecation and numerical warnings raised by
# the libraries used below — consider narrowing the filter.
warnings.filterwarnings("ignore")
# Dark Matplotlib theme for figures drawn until the style is changed again.
plt.style.use("dark_background")

# **Setting up Configurations**

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using {device} device")

In [None]:
# CONFIG

# -- Common --
SEED = 42  # shared by set_seed() and the train/validation split
data_directory = '../input/rsna-miccai-brain-tumor-radiogenomic-classification'

# -- Data --
mri_types = ['T1w']  # MRI sequences loaded per scan; each becomes one channel
SIZE = 256           # default side length every slice is resized to
PAD_SIZE = 512       # default target side length for pad_images()
NUM_IMAGES = 64      # default number of slices kept per volume

In [None]:
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and, when CUDA is available, all CUDA devices), and exports
    ``PYTHONHASHSEED``.

    Args:
        seed (int): Seed value applied to all generators.
    """
    random.seed(seed)
    # The hash seed of the current interpreter was fixed at start-up; this
    # only affects Python processes launched afterwards.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        # The cuDNN autotuner can pick different kernels from run to run,
        # which defeats the deterministic flag above unless it is disabled.
        torch.backends.cudnn.benchmark = False

# Seed the Python, NumPy and PyTorch generators once with the configured SEED.
set_seed(SEED)

# **Useful Functions**

In [None]:
def pad_images(images, pad_size=PAD_SIZE):
    """Zero-pad an image so that its first two dimensions equal ``pad_size``.

    The padding is split between both sides of each axis; when the amount
    is odd, the extra pixel goes to the bottom / right edge.

    Args:
        images: Array whose first two axes are (height, width).
        pad_size: Target height and width.

    Returns:
        The padded array produced by ``cv2.copyMakeBorder``.

    Raises:
        AssertionError: If the result is not ``pad_size`` x ``pad_size``.
    """
    height, width = images.shape[:2]

    # divmod gives the floor half plus the leftover pixel (0 or 1).
    top, odd_row = divmod(pad_size - height, 2)
    bottom = top + odd_row
    left, odd_col = divmod(pad_size - width, 2)
    right = left + odd_col

    padded = cv2.copyMakeBorder(images, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)
    assert padded.shape[:2] == (pad_size, pad_size)

    return padded

In [None]:
def load_dicom_image(path, img_size=SIZE):
    """Read one DICOM slice and return it as a square 8-bit image.

    The pixel data is passed through the VOI LUT / windowing stored in the
    file, inverted when the photometric interpretation is ``MONOCHROME1``,
    resized to ``img_size`` x ``img_size`` and min-max scaled to 0..255.

    Args:
        path: Path of the ``.dcm`` file.
        img_size: Side length of the returned image.

    Returns:
        np.ndarray: ``uint8`` array; ``(img_size, img_size)`` for a
        single-frame slice. All zeros when the slice carries no contrast.
    """
    # dcmread is the supported reader; read_file is a deprecated alias.
    dicom = pydicom.dcmread(path)
    raw = dicom.pixel_array

    # Same dtype as the regular return value so stacked volumes stay uniform.
    blank = np.zeros((img_size, img_size), dtype=np.uint8)
    if np.min(raw) == np.max(raw):
        return blank

    # NOTE(review): a CLAHE pass (exposure.equalize_adapthist) used to be
    # computed here, but its result was overwritten by the line below before
    # being used. It is omitted so the output is unchanged; if equalisation
    # is actually wanted, apply it to the VOI-LUT output instead.
    data = apply_voi_lut(raw, dicom)

    if getattr(dicom, "PhotometricInterpretation", None) == "MONOCHROME1":
        # MONOCHROME1 stores bright tissue as low values; flip it.
        data = np.amax(data) - data

    data = cv2.resize(data, (img_size, img_size))

    data = data - np.min(data)
    peak = np.max(data)
    if peak == 0:
        # Windowing can flatten a slice that had contrast in the raw data;
        # avoid dividing by zero and casting NaNs to uint8.
        return blank

    return (data / peak * 255).astype(np.uint8)

In [None]:
def load_dicom_images_3d(scan_id, num_imgs=NUM_IMAGES, img_size=SIZE, mri_type="FLAIR", split="train"):
    """Load the central slices of one MRI series as a normalised 3D volume.

    Files matching ``{data_directory}/{split}/{scan_id}/{mri_type}/*.dcm``
    are ordered by the integer after the last ``-`` in the file name, and up
    to ``num_imgs`` slices around the middle of the series are kept.

    Args:
        scan_id: Directory name of the scan (already zero-padded by the caller).
        num_imgs: Depth of the returned volume; shorter series are
            zero-padded at the end.
        img_size: Side length each slice is resized to.
        mri_type: Sub-directory naming the MRI sequence.
        split: Dataset sub-directory to read from.

    Returns:
        np.ndarray: Array of shape ``(img_size, img_size, num_imgs)`` (slices
        along the last axis; the first two axes are transposed relative to
        each loaded slice). Rescaled to [0, 1] unless the volume is constant.

    Raises:
        ValueError: If no DICOM file matches the pattern.
    """
    pattern = f"{data_directory}/{split}/{scan_id}/{mri_type}/*.dcm"
    files = sorted(
        glob.glob(pattern),
        # basename instead of split('/') so the key also works with
        # platform-specific path separators.
        key=lambda p: int(os.path.basename(p).split('-')[-1].split('.')[0]),
    )
    if not files:
        raise ValueError(f"No DICOM files found for pattern: {pattern}")

    middle = len(files) // 2
    half = num_imgs // 2
    first = max(0, middle - half)
    last = min(len(files), middle + half)
    # Forward img_size so the slices agree with the zero padding below.
    img3d = np.stack([load_dicom_image(f, img_size=img_size) for f in files[first:last]]).T

    missing = num_imgs - img3d.shape[-1]
    if missing > 0:
        filler = np.zeros((img_size, img_size, missing))
        img3d = np.concatenate((img3d, filler), axis=-1)

    # Min-max scale the whole volume; a constant volume is left untouched.
    if np.min(img3d) < np.max(img3d):
        img3d = img3d - np.min(img3d)
        img3d = img3d / np.max(img3d)

    return img3d

In [None]:
def cropped_images(images, img_size=SIZE):
    """Crop a volume to the bounding box of its non-zero voxels and resize.

    Only the first two axes are cropped; the last axis is kept whole. A
    volume without any non-zero voxel is resized without cropping.

    Args:
        images: 3D array; the first two axes are cropped, the last is kept.
        img_size: Side length of the first two axes after resizing.

    Returns:
        The cropped volume resized by ``cv2.resize`` to
        ``(img_size, img_size)`` in its first two axes.
    """
    # One row of coordinates per axis; computed once and reused.
    coords = np.array(np.nonzero(images))
    if coords.shape[1] > 0:
        lower = coords.min(axis=1)
        # Slice ends are exclusive: +1 keeps the last non-zero row/column
        # and prevents an empty crop when only one row/column is non-zero.
        upper = coords.max(axis=1) + 1
        images = images[lower[0]:upper[0], lower[1]:upper[1], :]

    images = cv2.resize(images, (img_size, img_size))

    return images

In [None]:
def visualize(**images):
    """Show the given images side by side in a single row.

    Each keyword argument becomes one panel. Its name, with underscores
    replaced by spaces and title-cased, is used as the panel title.
    """
    total = len(images)
    plt.figure(figsize=(16, 12))
    for position, name in enumerate(images, start=1):
        plt.subplot(1, total, position)
        # Hide the tick marks on both axes.
        plt.xticks([])
        plt.yticks([])
        plt.title(name.replace('_', ' ').title())
        plt.imshow(images[name])
    plt.show()

# **3D Augmentation and Transformation**

In [None]:
# 3D Augmentation
# TorchIO transform objects, created once at module level. In the visible part
# of this notebook only `canonical` is used (by validation_augmentation_3d);
# the others are defined but never referenced.

# Random flip restricted to the inferior-superior axis.
flip = tio.RandomFlip(axes=['inferior-superior'])

# Random intensity / artefact augmentations.
# NOTE(review): `p` is presumably the probability of applying each transform —
# confirm against the TorchIO documentation.
swap = tio.RandomSwap(patch_size=[5, 5, 5], p=.4)
add_noise = tio.RandomNoise(std=0.5, p=.1)
bias_field = tio.RandomBiasField(coefficients=0.4, p=.6)
add_motion = tio.RandomMotion(num_transforms=1, image_interpolation='nearest', p=.2)

# Orientation and intensity normalisation steps.
canonical = tio.ToCanonical()
standardize = tio.ZNormalization(masking_method=tio.ZNormalization.mean)
intensity = tio.RescaleIntensity((-1, 1))

def validation_augmentation_3d():
    """Build the TorchIO pipeline applied to validation volumes.

    Returns:
        tio.Compose: Pipeline holding only the module-level ``canonical``
        transform.
    """
    # The intensity steps (`standardize`, `intensity`) are currently disabled.
    steps = [canonical]
    return tio.Compose(steps)

# **Loading the Dataset**

In [None]:
# Load the label table; later cells rely on its BraTS21ID and MGMT_value columns.
train_df = pd.read_csv(f"{data_directory}/train_labels.csv")
display(train_df)

# Hold out 30% of the labelled scans for validation. The split is stratified
# on MGMT_value and uses SEED so that it is repeatable.
df_train, df_valid = sk_model_selection.train_test_split(
    train_df, 
    test_size=0.3, 
    random_state=SEED, 
    stratify=train_df["MGMT_value"],
)

# **3D ResNet Model**

In [None]:
class BasicBlock(nn.Module):
    """Two-convolution 3D residual block.

    Both convolutions are 3x3x3 with padding 1 and no bias, each followed by
    batch normalisation. The first one always uses stride 1; the block's
    ``stride`` is applied by the second. The input, optionally passed through
    ``downsample``, is added to the result before the final ReLU.

    Args:
        in_planes: Number of channels of the incoming tensor.
        out_planes: Number of channels produced by both convolutions.
        stride: Stride of the second convolution.
        downsample: Optional module applied to the input so the shortcut
            matches the shape of the main branch.
    """
    expansion = 1

    def __init__(self, in_planes, out_planes, stride=1, downsample=None):
        super().__init__()

        # Settings shared by both convolutions.
        shared = dict(kernel_size=3, padding=1, bias=False)

        # Attribute names and registration order are kept stable: they define
        # the state_dict keys and the order in which weights are initialised.
        self.conv1 = nn.Conv3d(in_planes, out_planes, stride=1, **shared)
        self.conv2 = nn.Conv3d(out_planes, out_planes, stride=stride, **shared)

        self.bn1 = nn.BatchNorm3d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.bn2 = nn.BatchNorm3d(out_planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Main branch: conv -> bn -> relu -> conv -> bn
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Shortcut branch: identity unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)

    
class BasicStem(nn.Sequential):
    """
    Network stem: 3D convolution, then batch norm, then ReLU.

    ``in_planes`` is the number of channels the stem produces and
    ``in_channels`` the number it receives. The 7x7x7 convolution uses stride
    (1, 2, 2) and padding (1, 3, 3), so the last two spatial axes are halved
    while the first one shrinks by 4.
    """
    def __init__(self, in_planes=64, in_channels=1):
        conv = nn.Conv3d(
            in_channels,
            in_planes,
            kernel_size=(7, 7, 7),
            stride=(1, 2, 2),
            padding=(1, 3, 3),
            bias=False,
        )
        norm = nn.BatchNorm3d(in_planes)
        activation = nn.ReLU(inplace=True)
        super().__init__(conv, norm, activation)


class ResNet3D(nn.Module):
    """3D ResNet built from a stem, four residual stages and a linear head.

    Args:
        block: Residual block class; called as
            ``block(in_planes, planes, stride, downsample)`` and expected to
            expose an ``expansion`` attribute.
        stem: Stem class; called as ``stem(planes, in_channels)``.
        model_name: ``'resnet-10'`` (one block per stage) or ``'resnet-18'``
            (two blocks per stage).
        in_channels: Number of channels of the input volume.
        n_classes: Size of the output of the final linear layer.

    Raises:
        AssertionError: If ``model_name`` is not one of the known names.
    """

    def __init__(self, block, stem,
                 model_name='resnet-18',
                 in_channels=1,
                 n_classes=2
                ):
        super().__init__()

        depth_table = {
            'resnet-10': [1, 1, 1, 1],
            'resnet-18': [2, 2, 2, 2],
        }

        assert model_name in depth_table, f'Specified model name {model_name} cant be loaded\nAvailable models: {[model for model in depth_table]}'
        stage_depths = depth_table[model_name]
        # Running channel count; updated by _layer as stages are built.
        self.inplanes = 64

        # Stem
        self.stem = stem(self.inplanes, in_channels)

        # Residual stages layer1..layer4 as (planes, stride); only the first
        # stage keeps the resolution.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for number, ((planes, stride), depth) in enumerate(zip(stage_specs, stage_depths), start=1):
            setattr(self, f'layer{number}', self._layer(block, planes, depth, stride=stride))

        # Head: global average pooling followed by a linear layer.
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(512 * block.expansion, n_classes)

        # Weight initialisation for every convolution and batch-norm layer.
        for module in self.modules():
            if isinstance(module, nn.Conv3d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm3d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        x = self.stem(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = self.avgpool(x)

        # Collapse everything but the batch axis before the linear head.
        return self.fc(x.flatten(1))

    def _layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and, when the shape of the
        shortcut would not match, a 1x1x1 convolution + batch-norm projection.
        """
        out_channels = planes * block.expansion

        downsample = None
        if stride != 1 or self.inplanes != out_channels:
            downsample = nn.Sequential(
                nn.Conv3d(self.inplanes, out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(out_channels),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]

        self.inplanes = out_channels
        stage.extend(block(self.inplanes, planes) for _ in range(blocks - 1))

        return nn.Sequential(*stage)

# **Testing the model**

In [None]:
# Index the training rows by scan id. The guard makes the cell safe to
# re-run: once BraTS21ID is the index it is no longer a column, so an
# unconditional set_index would raise KeyError.
if df_train.index.name != "BraTS21ID":
    df_train = df_train.set_index("BraTS21ID")
# Placeholder column for predictions.
df_train["MGMT_pred"] = 0

In [None]:
# Index the validation rows by scan id. The guard makes the cell safe to
# re-run: once BraTS21ID is the index it is no longer a column, so an
# unconditional set_index would raise KeyError.
if df_valid.index.name != "BraTS21ID":
    df_valid = df_valid.set_index("BraTS21ID")
# Placeholder column for predictions.
df_valid["MGMT_pred"] = 0

In [None]:
# Path of the trained checkpoint; replace the placeholder before running.
modelfile = "[your-model-path]"

# Single-logit ResNet-10; the architecture has to match the checkpoint.
model = ResNet3D(
    block=BasicBlock,
    stem=BasicStem,
    model_name='resnet-10',
    in_channels=1,
    n_classes=1
)
model.to(device)
# This notebook only runs inference: use running batch-norm statistics.
model.eval()

# map_location=device covers both the GPU and the CPU-only case, and also
# remaps tensors saved on a device that is not present on this machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load(modelfile, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])

In [None]:
class TestDataset(torch_data.Dataset):
    """Dataset yielding one stacked MRI volume and its label per scan.

    For every sequence in the module-level ``mri_types`` list the scan is
    loaded with ``load_dicom_images_3d``, cropped with ``cropped_images`` and
    stacked along a leading channel axis.

    Args:
        paths: Sequence of scan ids; each is zero-padded to five characters
            to form the scan directory name.
        labels: Optional sequence of targets aligned with ``paths``.
        mri_type: Stored only. NOTE(review): the sequences actually loaded
            come from the module-level ``mri_types``.
        label_smoothing: Stored only; not applied by this class.
        augmentation: Optional 2D augmentation called as
            ``augmentation(image=rgb_slice)['image']`` on every slice.
        transformation: Optional callable applied to the stacked
            ``(C, D, H, W)`` array.
        split: Stored only. NOTE(review): it is not forwarded to the loader,
            which therefore reads from its default split.
    """
    def __init__(self, paths, labels=None, mri_type=None, label_smoothing=0.01, augmentation=None, transformation=None, split="train"):
        self.paths = paths
        self.labels = labels
        self.mri_type = mri_type
        self.label_smoothing = label_smoothing
        self.augmentation = augmentation
        self.transformation = transformation
        self.split = split

    def __len__(self):
        """Number of scans in the dataset."""
        return len(self.paths)

    def _augment_slices(self, image_3d):
        """Apply the 2D augmentation to every slice of an (H, W, D) volume."""
        for idx in range(image_3d.shape[-1]):
            # Volumes arrive scaled to [0, 1]; casting them straight to uint8
            # would truncate almost every voxel to 0, so scale to 0..255 first.
            slice_u8 = (np.clip(image_3d[:, :, idx], 0, 1) * 255).astype(np.uint8)
            # The slice is single-channel, so GRAY2RGB is the valid conversion
            # (BGR2RGB rejects 2D input).
            slice_rgb = cv2.cvtColor(slice_u8, cv2.COLOR_GRAY2RGB)
            augmented = self.augmentation(image=slice_rgb)['image'][:, :, 0]
            if augmented.dtype == np.uint8:
                # Back to the [0, 1] range of the surrounding volume.
                augmented = augmented / 255.0
            image_3d[:, :, idx] = augmented
        return image_3d

    def __getitem__(self, index):
        """Return ``(volume, label)`` for the scan at ``index``.

        ``volume`` is a float32 tensor built from the ``(C, D, H, W)`` array;
        ``label`` is ``labels[index]``, or ``-1`` when no labels were given.
        """
        scan_id = self.paths[index]
        volumes = []

        for sequence in mri_types:
            image_3d = load_dicom_images_3d(scan_id=str(scan_id).zfill(5), mri_type=sequence)
            image_3d = cropped_images(image_3d)

            if self.augmentation:
                image_3d = self._augment_slices(image_3d)
            volumes.append(image_3d)

        four_channel_pack = np.stack(volumes)
        # (C, first, second, D) -> (C, D, second, first): depth moves to the
        # front and the two in-plane axes are swapped.
        four_channel_pack = np.transpose(four_channel_pack, (0, 3, 2, 1))

        # transformation
        if self.transformation:
            four_channel_pack = self.transformation(four_channel_pack)

        # labels defaults to None; use a sentinel instead of failing on it.
        y = self.labels[index] if self.labels is not None else -1

        return torch.tensor(four_channel_pack).float(), y

# **AUC Score**

In [None]:
# The validation ids come from train_labels.csv, so their DICOM folders live
# under the "train" split.
data_retriever = TestDataset(
    df_valid.index.values,
    df_valid["MGMT_value"].values,
    transformation=validation_augmentation_3d(),
    split="train",
)

data_loader = torch_data.DataLoader(
    data_retriever,
    batch_size=1,
    shuffle=False,
    num_workers=8,
)

y_preds = []  # predicted probability per scan
y = []        # ground-truth label per scan

# Inference only: switch to eval mode once and disable gradient tracking.
model.eval()
with torch.no_grad():
    # TestDataset returns (volume, label) tuples, so each batch unpacks
    # directly into the two tensors.
    for image, label in tqdm(data_loader, desc="Predicting"):
        output_ = model(image.to(device))

        # The model emits a single logit; sigmoid turns it into a probability.
        prediction = output_.sigmoid().cpu().numpy().squeeze()

        y.append(label.numpy()[0])
        y_preds.append(prediction)

In [None]:
# score
# Convert the collected labels and sigmoid outputs to arrays (this rebinds
# y and y_preds).
y = np.array(y)
y_preds = np.array(y_preds)

# ROC curve with label 1 as the positive class; the area under it is the
# reported AUC. fpr, tpr and roc_auc are reused by the plotting cell.
fpr, tpr, thresholds = metrics.roc_curve(y, y_preds, pos_label=1)
roc_auc = metrics.auc(fpr, tpr)

print(f"AUC score is: {roc_auc}")

In [None]:
# Matplotlib 3.6 renamed "seaborn-white" to "seaborn-v0_8-white" and later
# releases dropped the old name; use whichever this installation provides.
roc_style = next(
    (name for name in ("seaborn-v0_8-white", "seaborn-white") if name in plt.style.available),
    "default",
)
plt.style.use(roc_style)

from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve

# ROC curve of the model against the chance diagonal.
plt.figure(figsize=[8, 8])
plt.plot(fpr, tpr, label='3D ResNet10 (area = %0.2f)' % roc_auc, color='blue')
plt.plot([0, 1], [0, 1],'r--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic [ROC AUC]')
plt.legend(loc="lower right")
# Save before show(): the figure may be cleared once it has been displayed.
plt.savefig('resnet10-rocauc.jpg')
plt.show()