# 4x EfficientNet Classifiers - Inference

In this notebook, we load **four** trained 3D convolutional neural networks, one per MRI type, and average their predictions to classify the MGMT promoter methylation status of each brain tumor.

For each case (i.e., patient), the label `MGMT_value` states whether the tumor's MGMT promoter is methylated (1) or not (0). Each case has a dedicated folder identified by a five-digit number. Within each of these case folders, there are four sub-folders, one per MRI scan type. The MRI scans are:

* Fluid Attenuated Inversion Recovery (FLAIR; folder `FLAIR`)
* T1-weighted pre-contrast (T1w; folder `T1w`)
* T1-weighted post-contrast (T1Gd; folder `T1wCE`)
* T2-weighted (T2; folder `T2w`)
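
The folder layout above maps onto file paths as in the sketch below. It assumes the `{root}/{split}/{case_id}/{mri_type}/` structure that `load_images` uses further down, with the folder names `FLAIR`, `T1w`, `T1wCE` and `T2w` and a case id zero-padded to five digits; `slice_glob` is a hypothetical helper name.

```python
import os

# Folder names of the four MRI sequences in this data set
MRI_TYPES = ("FLAIR", "T1w", "T1wCE", "T2w")

def slice_glob(root, split, case_id, mri_type, ext="png"):
    """Build the glob pattern matching every slice of one case and one MRI type."""
    if mri_type not in MRI_TYPES:
        raise ValueError(f"unknown MRI type: {mri_type}")
    # Case folders use a five-digit, zero-padded id (e.g. 7 -> "00007")
    case_dir = str(case_id).zfill(5)
    return os.path.join(root, split, case_dir, mri_type, f"*.{ext}")

print(slice_glob("..", "train", 7, "T1wCE"))
```

Raising on an unknown name makes a typo such as `"T1Gd"` fail loudly instead of silently matching no files.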

I am using the data set created by [Jonathan Besomi](https://www.kaggle.com/c/rsna-miccai-brain-tumor-radiogenomic-classification/discussion/253000#1388021). Many thanks for creating the data set!

**3D images**
* each image has the shape Channel x Width x Height x Depth (i.e., 1 x Width x Height x Depth)
* *depth:* the number of slices. (I tried various values >= 30. If, for a given id, a scan has fewer slices than the chosen depth (e.g., 30), I filled the missing slices with zero matrices; see [Zabir Al Nazi Nabil](https://www.kaggle.com/furcifer/torch-efficientnet3d-for-mri-no-train).)
* added some Albumentations augmentations, such as CLAHE, brightness/contrast, and CoarseDropout, for the training images
* cropped away black boundaries (see [Zabir Al Nazi Nabil](https://www.kaggle.com/furcifer/torch-efficientnet3d-for-mri-no-train))
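
Removing black boundaries means cropping each slice to the region that contains non-zero intensities before it is resized. The sketch below shows the idea as a full bounding-box crop over rows and columns. Note that the notebook's own `remove_black_boundaries` (further down) only slices along the second axis, so treat this as an illustration of the technique, not a copy of that function; `crop_to_foreground` and `min_size` are hypothetical names.

```python
import numpy as np

def crop_to_foreground(img, min_size=10):
    """Crop a 2D slice to the bounding box of its non-zero pixels.

    Slices with no foreground, or whose foreground spans at most `min_size`
    pixels in either direction, are returned unchanged.
    """
    rows, cols = np.where(img > 0)
    if rows.size == 0:
        return img  # all-black slice: nothing to crop
    r0, r1 = rows.min(), rows.max()
    c0, c1 = cols.min(), cols.max()
    if (r1 - r0) <= min_size or (c1 - c0) <= min_size:
        return img
    # +1 because the maximum index is inclusive
    return img[r0:r1 + 1, c0:c1 + 1]

slice_ = np.zeros((64, 64), dtype=np.uint8)
slice_[10:40, 20:50] = 200
print(crop_to_foreground(slice_).shape)  # (30, 30)
```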

**3D CNN**
* model used: [EfficientNet-PyTorch-3D](https://github.com/shijianjian/EfficientNet-PyTorch-3D)

#### Imports

In [None]:
import os
import random
import glob
import re

import pandas as pd

import numpy as np

import math

from functools import partial

from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms, utils
from torch.autograd import Variable
from torch.optim import lr_scheduler

import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut

import cv2

import albumentations as A

from tqdm import tqdm

import wandb

from efficientnet_pytorch_3d import EfficientNet3D

import warnings
warnings.simplefilter("ignore", UserWarning)
warnings.simplefilter("ignore", RuntimeWarning)

#### Seed

In [None]:
def set_seed(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
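    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Ask cuDNN for deterministic algorithms and disable autotuning,
        # which may otherwise select different kernels between runs
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False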


set_seed(42)
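
Re-seeding a generator restarts its sequence, which is why `set_seed` is called once before any data is loaded. A minimal check with the Python and NumPy generators only (`torch.manual_seed` behaves analogously but is left out to keep the sketch light); `draw` is a hypothetical helper:

```python
import random

import numpy as np

def draw(seed):
    # Re-seed both generators, then draw from each
    random.seed(seed)
    np.random.seed(seed)
    return random.random(), np.random.rand(3).tolist()

first = draw(42)
second = draw(42)
print(first == second)  # True
```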

#### Config

In [None]:
PATH = '..'

config = dict(
    # Pre-processing
    SLICE_NUMBER = 64,
    REMOVE_BLACK_BOUNDARIES = True,
    DICOM=False,
    # Albumentation
    RRC_SIZE = 256,
    RRC_MIN_SCALE = 0.85,
    RRC_RATIO = (1., 1.),
    CLAHE_CLIP_LIMIT = 2.0,
    CLAHE_TILE_GRID_SIZE = (8, 8),
    CLAHE_PROB = 0.50,
    BRIGHTNESS_LIMIT = (-0.2,0.2),
    BRIGHTNESS_PROB = 0.40,
    HUE_SHIFT = (-15, 15),
    SAT_SHIFT = (-15, 15),
    VAL_SHIFT = (-15, 15),
    HUE_PROB = 0.64,
    COARSE_MAX_HOLES = 16,
    COARSE_PROB = 0.7,
    # Training
    N_EPOCHS = 20,
    BATCH_SIZE = 16,
    LEARNING_RATE = 0.001,
    WEIGHT_DECAY = 0.02,
    LABEL_SMOOTHING = 0.02,
    OPTIMIZER = "SGD",
    MOMENTUM = 0.9,
    SCHEDULER = "ReduceLROnPlateau",
    # Logging
    VERBOSE = False,
    MODELNAME = "04-3D-4-EfficientNet"
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
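
The pre-processing and training values above fix the size of every input tensor. A back-of-the-envelope check for the input batch alone (activations inside the network need considerably more memory, which this does not estimate):

```python
# Values copied from the config above
RRC_SIZE = 256
SLICE_NUMBER = 64
BATCH_SIZE = 16
BYTES_PER_FLOAT32 = 4

voxels = 1 * RRC_SIZE * RRC_SIZE * SLICE_NUMBER  # channels x width x height x depth
volume_mib = voxels * BYTES_PER_FLOAT32 / 2**20
batch_mib = volume_mib * BATCH_SIZE

print(voxels)      # 4194304
print(volume_mib)  # 16.0
print(batch_mib)   # 256.0
```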

#### wandb

In [None]:
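# Logging is disabled for inference, so no login is needed; wandb.config is
# still used as an attribute-style view of the config dict.
run = wandb.init(project='rsna-miccai', config=config, mode="disabled")
config = wandb.config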

### 1. Load Data

To create a 3D image, we pick the SLICE_NUMBER middle slices of a single MRI type (i.e., FLAIR, T1w, T1wCE, or T2w). For example, if we set *SLICE_NUMBER=30*, each 3D image has the shape 1 x Width x Height x 30. Further, one image contains the slices of only *one* MRI type, as opposed to other notebooks where all MRI types are combined in a single image; each MRI type therefore has its own model.

* If, for a given MRI type, the number of images is smaller than SLICE_NUMBER, then we "fill up" the remaining *SLICE_NUMBER - number_of_images* slices with all-black images
* We crop away black boundaries (for more information, see [here](https://www.kaggle.com/furcifer/torch-efficientnet3d-for-mri-no-train))
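
With SLICE_NUMBER = 30, a folder with 100 slices contributes slices 35 to 64, while a folder with only 20 slices contributes all of them plus 10 black slices. A NumPy sketch of this selection-and-padding step, mirroring the logic of `get_3d_image` below but with hypothetical helper names `middle_window` and `pad_depth`:

```python
import numpy as np

def middle_window(n_slices, target):
    """Start/stop indices of (at most) `target` slices centred on the middle slice."""
    half = target // 2
    start = max(n_slices // 2 - half, 0)
    stop = min(n_slices // 2 + half, n_slices)
    return start, stop

def pad_depth(volume, target):
    """Append all-zero (black) slices along the last axis until depth == target."""
    missing = target - volume.shape[-1]
    if missing <= 0:
        return volume
    pad = np.zeros(volume.shape[:-1] + (missing,), dtype=volume.dtype)
    return np.concatenate((volume, pad), axis=-1)

# A scan with only 20 slices of size 4x4, stacked along the last axis
volume = np.ones((4, 4, 20), dtype=np.float32)
start, stop = middle_window(volume.shape[-1], target=30)
padded = pad_depth(volume[..., start:stop], target=30)
print(start, stop, padded.shape)  # 0 20 (4, 4, 30)
```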

### 1.1 Utilities

#### 1.1.1 Augmentation

In [None]:
train_transform = A.Compose([
    A.RandomResizedCrop(
        config.RRC_SIZE, config.RRC_SIZE,            
        scale=(config.RRC_MIN_SCALE, 1.0),
        ratio=config.RRC_RATIO,
        p=1.0
    ),
    A.CLAHE(
        clip_limit=config.CLAHE_CLIP_LIMIT,
        tile_grid_size=config.CLAHE_TILE_GRID_SIZE,
        p=config.CLAHE_PROB
    ),
    A.RandomBrightnessContrast(
        brightness_limit=config.BRIGHTNESS_LIMIT,
        p=config.BRIGHTNESS_PROB
    ),
    A.HueSaturationValue(
        hue_shift_limit=config.HUE_SHIFT, 
        sat_shift_limit=config.SAT_SHIFT, 
        val_shift_limit=config.VAL_SHIFT, 
        p=config.HUE_PROB
    ),
    A.CoarseDropout(
        max_holes=config.COARSE_MAX_HOLES,
        p=config.COARSE_PROB
    ),
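    # A.Normalize computes (img - mean * max_pixel_value) / (std * max_pixel_value);
    # keep these values identical to those used when training the checkpoints.
    A.Normalize(
        mean=(123.675),
        std=(58.39),
        max_pixel_value=255.0,
        always_apply=True
    )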
])

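# RandomResizedCrop is stochastic, so predictions for the same scan can vary
# slightly between runs; it is kept to match the preprocessing the
# checkpoints were presumably trained with.
valid_transform = A.Compose([
    A.RandomResizedCrop(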
        config.RRC_SIZE, config.RRC_SIZE,            
        scale=(config.RRC_MIN_SCALE, 1.0),
        ratio=config.RRC_RATIO,
        p=1.0
    ),
    A.Normalize(
        mean=(123.675),
        std=(58.39),
        max_pixel_value=255.0,
        always_apply=True
    )
])
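
According to the Albumentations documentation, `A.Normalize` applies `img = (img - mean * max_pixel_value) / (std * max_pixel_value)`. Because `mean=123.675` is already in pixel units, the configured transform subtracts about 31,537 and maps every 8-bit value into a narrow band near -2.1. Since `get_3d_image` appends its zero-valued padding slices after this transform, those zeros end up above every normalized pixel, and the per-volume min-max rescaling in `RSNADataset` then squeezes the real slices into a small range. The NumPy re-implementation below (hypothetical function name) compares the configured values with the fractional form `mean=0.485, std=0.229`. Whatever is chosen, inference has to use the same values as training.

```python
import numpy as np

def albumentations_style_normalize(img, mean, std, max_pixel_value=255.0):
    """NumPy re-implementation of the formula documented for A.Normalize."""
    img = img.astype(np.float32)
    return (img - mean * max_pixel_value) / (std * max_pixel_value)

pixels = np.array([0, 128, 255], dtype=np.uint8)

# mean/std as configured in the transforms above (already in pixel units)
as_configured = albumentations_style_normalize(pixels, mean=123.675, std=58.39)
# ImageNet statistics expressed as fractions of 255 (0.485 * 255 = 123.675)
fractional = albumentations_style_normalize(pixels, mean=0.485, std=0.229)

print(as_configured.round(3))  # every value lands close to -2.1
print(fractional.round(3))     # spans roughly -2.1 to 2.2
```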

#### 1.1.2 Loading Images

In [None]:
def dicom_2_image(file, voi_lut=True, fix_monochrome=True):
    # pydicom.read_file is deprecated; dcmread is the current reader
    dicom = pydicom.dcmread(file)
    # VOI LUT (if provided by the DICOM device) is used to
    # transform raw DICOM data into a "human-friendly" view
    if voi_lut:
        img = apply_voi_lut(dicom.pixel_array, dicom)
    else:
        img = dicom.pixel_array
    # MONOCHROME1 stores inverted intensities: flip so that higher = brighter
    if fix_monochrome and dicom.PhotometricInterpretation == "MONOCHROME1":
        img = np.amax(img) - img

    # Rescale to [0, 255]; guard against all-black slices (max == 0)
    img = img - np.min(img)
    if np.max(img) > 0:
        img = img / np.max(img)
    img = (img * 255).astype(np.uint8)
    return img

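def remove_black_boundaries(img):
    (x, y) = np.where(img > 0)
    if len(x) > 0 and len(y) > 0:
        x_mn, x_mx = np.min(x), np.max(x)
        y_mn, y_mx = np.min(y), np.max(y)
        # Only crop if the foreground is larger than 10 pixels in both directions
        if (x_mx - x_mn) > 10 and (y_mx - y_mn) > 10:
            # Note: only the second axis (columns) is cropped; all rows are kept
            img = img[:, y_mn:y_mx]
    return img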

def get_3d_image(files, aug, dicom):
    # `files` is the sorted list of slice files of one MRI type
    if config.VERBOSE:
        print(f"Length of folder: {len(files)}")
    # Take SLICE_NUMBER slices from the middle of the scan
    threshold = config.SLICE_NUMBER // 2
    minimum_idx = max(len(files) // 2 - threshold, 0)
    maximum_idx = len(files) // 2 + threshold  # may exceed the length; slicing clips it
    if config.VERBOSE:
        print(f"Minimum {minimum_idx}")
        print(f"Maximum {maximum_idx}")
    transform = train_transform if aug else valid_transform
    # Collect the transformed slices
    mri_img = []
    for file in files[minimum_idx:maximum_idx]:
        # Read image
        if dicom:
            img = dicom_2_image(file)
        else:
            img = cv2.imread(file, cv2.IMREAD_GRAYSCALE)
        # Remove black boundaries
        if config.REMOVE_BLACK_BOUNDARIES:
            img = remove_black_boundaries(img)
        # Apply Albumentations transforms
        img = transform(image=img)["image"]
        mri_img.append(np.array(img))
    # (depth, height, width) -> (width, height, depth)
    mri_img = np.array(mri_img).T
    # If there are fewer than SLICE_NUMBER slices, append all-zero slices
    if mri_img.shape[-1] < config.SLICE_NUMBER:
        if config.VERBOSE:
            print(f"Current slices: {mri_img.shape[-1]}")
        n_zero = config.SLICE_NUMBER - mri_img.shape[-1]
        mri_img = np.concatenate((mri_img, np.zeros((config.RRC_SIZE, config.RRC_SIZE, n_zero))), axis=-1)
    if config.VERBOSE:
        print(f"Shape of mri_img: {mri_img.shape}")
    return mri_img
    
def load_images(scan_id, mri_type, aug=True, split="train", dicom=False):
    file_ext = "dcm" if dicom else "png"
    if config.VERBOSE:
        print(f"Scan id {scan_id}")
    if mri_type not in ("FLAIR", "T1w", "T1wCE", "T2w"):
        raise ValueError(f"Unknown MRI type: {mri_type}")

    # Sort slices in ascending numeric order (e.g. Image-9 before Image-10)
    files = sorted(
        glob.glob(f"{PATH}/{split}/{scan_id}/{mri_type}/*.{file_ext}"),
        key=lambda f: int(re.sub(r"\D", "", f)),
    )
    img = get_3d_image(files, aug, dicom)

    # Return 3D image: Channels x Width x Height x Depth
    img = img.reshape((1, config.RRC_SIZE, config.RRC_SIZE, config.SLICE_NUMBER))
    return img
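
Plain `sorted` orders file names lexicographically, so `Image-10.png` would come before `Image-9.png`. `load_images` avoids this with a key that turns the digits into an integer. A variant restricted to the file name makes the intent explicit; it assumes names like `Image-<n>.png`, and `slice_number` is a hypothetical helper:

```python
import os
import re

def slice_number(path):
    """Numeric slice index, taken from the digits in the file name only."""
    digits = re.sub(r"\D", "", os.path.basename(path))
    return int(digits)

files = ["../train/00000/FLAIR/Image-10.png",
         "../train/00000/FLAIR/Image-9.png",
         "../train/00000/FLAIR/Image-100.png"]
print(sorted(files))                    # lexicographic: 10, 100, 9
print(sorted(files, key=slice_number))  # numeric: 9, 10, 100
```

Using only the base name keeps digits in the case id or in folder names such as `T1w` out of the key.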

### 1.2 Dataset and Dataloader

Create a PyTorch Dataset and DataLoader **for each** MRI type

In [None]:
# Load the test case ids from the sample submission file
df_test = pd.read_csv(f"{PATH}/sample_submission.csv")

In [None]:
class RSNADataset(Dataset):
    def __init__(self, ids, labels, mri_type="FLAIR", split="train", dicom=False, label_smoothing=0.0):
        self.ids = ids
        self.labels = labels
        self.mri_type = mri_type
        self.split = split
        self.dicom = dicom
        self.label_smoothing = label_smoothing
        
        if split == "train":
            remove_ids = [709, 109, 123]
            self.ids = [id_ for id_ in self.ids if id_ not in remove_ids]  
    
    def __len__(self):
        return len(self.ids)
    
    def __getitem__(self, idx):
        patient_id = self.ids[idx]
        patient_id = str(patient_id).zfill(5)
        if self.split == "train":
            imgs = load_images(patient_id, self.mri_type, aug=True, dicom=self.dicom)
        elif self.split == "valid":
            imgs = load_images(patient_id, self.mri_type, aug=False, dicom=self.dicom)
        else:
            imgs = load_images(patient_id, self.mri_type, aug=False, split=self.split, dicom=self.dicom)
        # Normalize
        imgs = imgs - imgs.min()
        imgs = (imgs + 1e-5) / (imgs.max() - imgs.min() + 1e-5)

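        if self.split != "test":
            # Class indices for a two-output model must be integers. Casting a
            # smoothed label (e.g. 1 - 0.02 = 0.98) to torch.long truncates it
            # to 0, so the hard label is returned; apply label smoothing in the
            # loss instead, e.g. nn.CrossEntropyLoss(label_smoothing=...).
            label = int(self.labels[idx])
            return torch.tensor(imgs, dtype=torch.float32), torch.tensor(label, dtype=torch.long)
        else:
            # For the test split, return the case id instead of a label
            return torch.tensor(imgs, dtype=torch.float32), torch.tensor(self.ids[idx], dtype=torch.long)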
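
Casting a label-smoothed scalar such as `abs(1 - 0.02) = 0.98` to `torch.long` truncates it to 0, so smoothing cannot be carried in an integer class-index tensor. A common alternative is to smooth one-hot targets, sketched below in NumPy with a hypothetical `smooth_one_hot`; PyTorch's `nn.CrossEntropyLoss` also accepts a `label_smoothing` argument from version 1.10 onward.

```python
import numpy as np

def smooth_one_hot(labels, eps, n_classes=2):
    """Label-smoothed one-hot targets: (1 - eps) * one_hot + eps / n_classes."""
    labels = np.asarray(labels)
    one_hot = np.eye(n_classes)[labels]
    return (1.0 - eps) * one_hot + eps / n_classes

# Casting a smoothed scalar label to an integer dtype truncates it
print(np.array([0.98, 0.02]).astype(np.int64))  # [0 0]
print(smooth_one_hot([0, 1], eps=0.02))
```

Each row of the smoothed targets still sums to 1, so they remain a valid probability distribution for a softmax output.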

In [None]:
def get_dataloader(df, mri_type, shuffle=False, split="test"):
    # shuffle=False keeps the order fixed, so the four models' predictions line up
    ds = RSNADataset(df["BraTS21ID"].to_numpy(), df["MGMT_value"].to_numpy(), mri_type=mri_type,
                     label_smoothing=config.LABEL_SMOOTHING, dicom=config.DICOM, split=split)
    dl = DataLoader(ds, batch_size=config.BATCH_SIZE, shuffle=shuffle)
    return dl

In [None]:
# FLAIR DataLoader
test_flair_dl = get_dataloader(df_test, "FLAIR", False)

# T1w DataLoader
test_t1w_dl = get_dataloader(df_test, "T1w", False)

# T1wCE DataLoader
test_t1wce_dl = get_dataloader(df_test, "T1wCE", False)

# T2w DataLoader
test_t2w_dl = get_dataloader(df_test, "T2w", False)

In [None]:
# For the test split, the second element of each batch holds case ids, not labels
images, ids = next(iter(test_flair_dl))
print(f"Shape of the batch {images.shape}")
print(f"Batch size: {images.shape[0]}")
print(f"Number of channels each image has: {images.shape[1]}")
print(f"Size of each image is: {images.shape[2]}x{images.shape[3]}")
print(f"Depth of each channel/sequence: {images.shape[-1]}")

### 1.3 Model

In [None]:
class GliobCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # EfficientNet3D-B0 with a single input channel (one MRI type per model)
        self.efficient = EfficientNet3D.from_name("efficientnet-b0", override_params={'num_classes': 2}, in_channels=1)
        # Two-class head: logits for MGMT_value = 0 and MGMT_value = 1
        self.efficient._fc = nn.Linear(in_features=self.efficient._fc.in_features, out_features=2, bias=True)

    def forward(self, X):
        out = self.efficient(X)
        return out

In [None]:
def load_model(path):
    model = GliobCNN()
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine
    model.load_state_dict(torch.load(path, map_location=device))
    return model.to(device)

# Note: the checkpoint file names mention ResNet18, but they are loaded into
# GliobCNN (EfficientNet3D-B0); load_state_dict raises an error if the saved
# weights do not match this architecture.
flair_model = load_model("../models/01-3D-4-ResNet18-pretrained-FLAIR-roc-0.68.pt")
t1w_model = load_model("../models/01-3D-4-ResNet18-pretrained-T1w-roc-0.61.pt")
t1wce_model = load_model("../models/01-3D-4-ResNet18-pretrained-T1wCE-roc-0.59.pt")
t2w_model = load_model("../models/01-3D-4-ResNet18-pretrained-T2w-roc-0.71.pt")

### 1.4 Testing


In [None]:
@torch.no_grad()  # no gradients are needed at inference time
def submission(net, test_dl):
    """Return P(MGMT_value = 1) and the case id for every scan in test_dl."""
    y_hats, idx_list = [], []

    net.eval()
    for xb, idxb in tqdm(test_dl, desc="Testing"):
        xb = xb.to(device)
        logits = net(xb)
        # Probability of class 1 (softmax over the two output logits)
        y_hats.append(F.softmax(logits, dim=1)[:, 1].cpu().numpy())
        idx_list.append(idxb.numpy())

    return np.concatenate(y_hats), np.concatenate(idx_list)
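
`submission` keeps column 1 of the softmax, i.e. the predicted probability of `MGMT_value = 1`, so each model returns a value in [0, 1] that can be averaged directly. With two output units this equals a sigmoid of the logit difference, because e^z1 / (e^z0 + e^z1) = 1 / (1 + e^(z0 - z1)). A NumPy check of that identity (the logits are made up):

```python
import numpy as np

def softmax(logits):
    # Subtract the row maximum for numerical stability
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

logits = np.array([[2.0, -1.0], [0.5, 0.5], [-3.0, 4.0]])
p_class1 = softmax(logits)[:, 1]
# For two classes, softmax(z)[1] == sigmoid(z1 - z0)
print(np.allclose(p_class1, sigmoid(logits[:, 1] - logits[:, 0])))  # True
```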

In [None]:
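y_flair_hats, idx_list = submission(flair_model, test_flair_dl)
y_t1w_hats, idx_t1w = submission(t1w_model, test_t1w_dl)
y_t1wce_hats, idx_t1wce = submission(t1wce_model, test_t1wce_dl)
y_t2w_hats, idx_t2w = submission(t2w_model, test_t2w_dl)
# All four loaders are unshuffled, so the case ids must line up
assert all((idx_list == other).all() for other in (idx_t1w, idx_t1wce, idx_t2w))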

In [None]:
df = pd.DataFrame({'BraTS21ID': idx_list, 'FLAIR': y_flair_hats, 'T1w': y_t1w_hats, 'T1wCE': y_t1wce_hats,
                   'T2w': y_t2w_hats})
df["MGMT_value"] = (df["FLAIR"] + df["T1w"] + df["T1wCE"] + df["T2w"]) / 4
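
The final `MGMT_value` is an unweighted mean of the four per-sequence probabilities. The same computation in NumPy, with an optional weight vector as a hypothetical extension (for instance, one might trust some MRI types more than others); `average_probabilities`, the probabilities and the weights are all made up for illustration:

```python
import numpy as np

def average_probabilities(probs, weights=None):
    """Average per-model probabilities of shape (n_models, n_cases).

    weights=None gives the plain mean used above; `weights` is a hypothetical
    extension that lets some MRI types count more than others.
    """
    probs = np.asarray(probs, dtype=float)
    if weights is None:
        return probs.mean(axis=0)
    weights = np.asarray(weights, dtype=float)
    return (weights[:, None] * probs).sum(axis=0) / weights.sum()

# Rows: FLAIR, T1w, T1wCE, T2w; columns: two hypothetical cases
probs = [[0.6, 0.2], [0.5, 0.4], [0.4, 0.3], [0.7, 0.5]]
print(average_probabilities(probs))                        # plain mean
print(average_probabilities(probs, weights=[2, 1, 1, 2]))  # weighted mean
```

With equal weights, the weighted form reduces to the plain mean, so it only changes the result when the weights differ.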

In [None]:
df_submission = df[["BraTS21ID", "MGMT_value"]]
df_submission.to_csv("submission.csv", index=False)
df_submission

Done