In [1]:
import numpy as np
import pandas as pd
import pydicom 
import math
import cv2
import gc
import glob
import re
import torch
import pytorch_lightning as pl

import torchmetrics
from torch import nn
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from pydicom.pixel_data_handlers.util import apply_voi_lut

IMAGE_SIZE = 256
NUM_IMAGES = 64

In [2]:
# Functions

def load_dicom_image(path, img_size=IMAGE_SIZE, voi_lut=True, rotate=0):
    dicom = pydicom.read_file(path)
    data = dicom.pixel_array
    if voi_lut:
        data = apply_voi_lut(dicom.pixel_array, dicom)
    else:
        data = dicom.pixel_array
        
    if rotate > 0:
        rot_choices = [0, cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]
        data = cv2.rotate(data, rot_choices[rotate])
        
    data = cv2.resize(data, (img_size, img_size))
    return data


def load_dicom_images_3d(path, num_imgs=NUM_IMAGES, img_size=IMAGE_SIZE, rotate=0):
    files = sorted(glob.glob(f"{path}/*.dcm"), 
               key=lambda var:[int(x) if x.isdigit() else x for x in re.findall(r'[^0-9]|[0-9]+', var)])
    middle = len(files)//2
    num_imgs2 = num_imgs//2
    p1 = max(0, middle - num_imgs2)
    p2 = min(len(files), middle + num_imgs2)
    
    img3d = np.stack([load_dicom_image(f, rotate=rotate) for f in files[p1:p2]]).T 
    if img3d.shape[-1] < num_imgs:
        n_zero = np.zeros((img_size, img_size, num_imgs - img3d.shape[-1]))
        img3d = np.concatenate((img3d,  n_zero), axis = -1)
        
    if np.min(img3d) < np.max(img3d):
        img3d = img3d - np.min(img3d)
        img3d = img3d / np.max(img3d)
            
    return np.expand_dims(img3d,0)


class Dataset3D(torch.utils.data.Dataset):
    def __init__(self, csv_path, root_path, scan_type='FLAIR', idxs=None):
        # load df
        df = pd.read_csv(csv_path)
        if not idxs is None:
            df = df.iloc[idxs]
        # process data
        self.data = []
        for _, r in df.iterrows():
            bid, label = str(r['BraTS21ID']).zfill(5), int(r['MGMT_value'])
            self.data.append((f"{root_path}/{bid}/{scan_type}", label))
        
        del df
        gc.collect()
            
            
    def __len__(self):
        return len(self.data)
    
    
    def __getitem__(self, idx):
        path, label = self.data[idx]
        
        img = load_dicom_images_3d(path)
        img = torch.from_numpy(np.moveaxis(img, -1, 1))
        img = img.type(torch.float32)

        return img, label

In [3]:
# Model, PTL wrapper
class Custom3DNet(nn.Module):
    def __init__(self):
        super(Custom3DNet, self).__init__()
        self.block1 = self.__gen_block(1, 64, 3, 2, 0.01)
        self.block2 = self.__gen_block(64, 128, 3, 2, 0.02)
        self.block3 = self.__gen_block(128, 256, 3, 2, 0.03)
        self.block4 = self.__gen_block(256, 512, 3, 2, 0.04)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.Flatten(),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout3d(p=0.08),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )
        
        
    def forward(self, inp):
        x = self.block1(inp)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        #print(nn.AdaptiveAvgPool3d(1)(x).shape)
        return self.classifier(x)
        
        
    def __gen_block(self, in_channels, out_channels, kernel_size, pool_size, dropout=None):
        layers = [
            nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=pool_size) ,
            nn.BatchNorm3d(num_features=out_channels)
        ]
        if not dropout is None:
            layers += [nn.Dropout3d(p=dropout)]
        return nn.Sequential(*layers)
        
        
class LITCustom3DNet(pl.LightningModule):
    def __init__(self, lr=1e-3):
        super().__init__()
        self.lr = lr
        self.model = Custom3DNet()
        self.loss_f = nn.BCELoss()
        self.roc_auc_f = torchmetrics.AUROC(num_classes=1)
        self.f1_f = torchmetrics.F1()
        self.scores = {
            'val_loss': [],
            'val_f1': [],
            'val_rocauc': [],
            'train_loss': [],
            'train_rocauc': [],
            'train_f1': []
        }
        self.first_epoch = True
        
        
    def forward(self, inp):
        return self.model(inp)
    

    def configure_optimizers(self):
        return torch.optim.Adam(model.parameters(), lr=self.lr)

    
    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        
        y = y.type(torch.float32)
        pred_y = self.model(x)
        pred_y = torch.flatten(pred_y)
        pred_y = pred_y.type(torch.float32)
                                
        loss = self.loss_f(pred_y, y)
        self.log("train_loss", loss, on_epoch=True, logger=True)
        
        return {'loss': loss, 'pred': pred_y, 'y': y}
    
    
    def training_epoch_end(self, outputs):
        preds, targs = [], []
        for out in outputs:
            preds.append(out['pred'])
            targs.append(out['y'])
            
        preds = torch.cat(tensors=preds)
        targs = torch.cat(tensors=targs)
        
        rocauc = self.roc_auc_f(preds, targs.type(torch.int))
        loss = self.loss_f(preds, targs)
        f1 = self.f1_f(preds, targs.type(torch.int))
        
        self.log("train_roc_auc", rocauc, logger=True, on_epoch=True)
        
        self.scores['train_loss'].append(loss)
        self.scores['train_f1'].append(f1)
        self.scores['train_rocauc'].append(rocauc)

    
    
    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        y = y.type(torch.float32)
        pred_y = self.model(x)
        pred_y = torch.flatten(pred_y)
        pred_y = pred_y.type(torch.float32)
           
        loss = self.loss_f(pred_y, y)
        self.log("val_loss", loss, on_epoch=True, logger=True)
        
        return {
            'y': y, 
            'pred': pred_y
        }
    
    
    def validation_epoch_end(self, outputs):
        preds, targs = [], []
        for out in outputs:
            preds.append(out['pred'])
            targs.append(out['y'])
            
        preds = torch.cat(tensors=preds)
        targs = torch.cat(tensors=targs)
        
        rocauc = self.roc_auc_f(preds, targs.type(torch.int))
        loss = self.loss_f(preds, targs)
        f1 = self.f1_f(preds, targs.type(torch.int))
        
        self.log("val_roc_auc", rocauc, logger=True, on_epoch=True, prog_bar=True)
        self.log("val_loss", loss, logger=True, on_epoch=True, prog_bar=True)
        if not self.first_epoch:# because method call in start training
            self.scores['val_loss'].append(loss)
            self.scores['val_f1'].append(f1)
            self.scores['val_rocauc'].append(rocauc)
        self.first_epoch = False

In [4]:
RANDOM_SEED = 42
TEST_SIZE = 0.3
BATCH_SIZE = 2
csv_path = "../input/rsna-miccai-brain-tumor-radiogenomic-classification/train_labels.csv"
root_path = "../input/rsna-miccai-brain-tumor-radiogenomic-classification/train"

df = pd.read_csv(csv_path)
X, y = list(range(len(df))), df['MGMT_value'].tolist()
assert len(X) == len(y)
train_idx, test_idx, _, _ = train_test_split(X, y, test_size=TEST_SIZE, random_state=RANDOM_SEED, stratify=y)
print(f"Test idx: {test_idx} \nLen: {len(test_idx)}")

train_data, test_data = Dataset3D(csv_path, root_path, idxs=train_idx), Dataset3D(csv_path, root_path, idxs=test_idx)
train_data = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True,  num_workers=0)
test_data = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    
del df, X, y
gc.collect()

Test idx: [342, 260, 284, 41, 49, 556, 38, 262, 80, 376, 384, 322, 527, 254, 343, 173, 422, 31, 310, 90, 112, 328, 238, 430, 380, 47, 88, 234, 202, 160, 321, 547, 55, 404, 582, 54, 102, 518, 324, 568, 576, 40, 494, 36, 386, 114, 433, 355, 515, 125, 172, 570, 228, 309, 145, 7, 201, 198, 106, 436, 21, 146, 16, 209, 435, 58, 406, 261, 320, 480, 210, 349, 306, 425, 464, 387, 109, 474, 179, 11, 345, 194, 29, 275, 107, 305, 287, 523, 131, 33, 330, 428, 20, 333, 331, 256, 548, 50, 336, 447, 192, 354, 516, 57, 485, 545, 230, 506, 481, 95, 373, 521, 540, 531, 353, 30, 200, 519, 424, 113, 483, 427, 487, 409, 416, 211, 164, 4, 575, 174, 561, 60, 577, 315, 492, 361, 282, 381, 419, 566, 185, 456, 216, 501, 205, 318, 388, 257, 62, 441, 459, 43, 105, 258, 86, 176, 446, 78, 554, 449, 584, 539, 389, 130, 69, 563, 526, 126, 91, 377, 98, 6, 288, 273, 121, 497] 
Len: 176


0

In [5]:
# axial
GPUS = 1
EPOCHS = 30
LR = 3e-5
    
# train
model = LITCustom3DNet(lr=LR)
save_callback_val_loss = pl.callbacks.ModelCheckpoint(
    dirpath="./",
    filename='LOSS-{epoch}-{val_roc_auc:.2f}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_on_train_epoch_end=False,
    save_top_k=2
)
save_callback_roc_auc = pl.callbacks.ModelCheckpoint(
    dirpath="./",
    filename='ROC-{epoch}-{val_roc_auc:.2f}-{val_loss:.2f}',
    monitor='val_roc_auc',
    mode='max',
    save_on_train_epoch_end=False,
    save_top_k=2
)
trainer = pl.Trainer(max_epochs=EPOCHS, gpus=GPUS, callbacks=[save_callback_val_loss, save_callback_roc_auc])
trainer.fit(model, train_data, test_data)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
2021-09-18 13:29:14.339036: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0


Validation sanity check: 0it [00:00, ?it/s]

  f"Your {mode}_dataloader has `shuffle=True`, it is best practice to turn"


Training: -1it [00:00, ?it/s]

  f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]