In [1]:
import numpy as np
import pandas as pd
import pydicom 
import math
import cv2
import gc
import glob
import re
import torch
import pytorch_lightning as pl

import torchmetrics
from torch import nn
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from pydicom.pixel_data_handlers.util import apply_voi_lut

# Side length (pixels) every DICOM slice is resized to, and number of slices kept per
# volume (centred on the middle slice, zero-padded when a series is shorter).
IMAGE_SIZE, NUM_IMAGES = 256, 64

In [2]:
# Functions

def load_dicom_image(path, img_size=IMAGE_SIZE, voi_lut=True, rotate=0):
    """Read one DICOM slice and return it as a square, resized 2D array.

    Args:
        path: Path to a single ``.dcm`` file.
        img_size: Side length (pixels) of the returned image.
        voi_lut: If True, pass the pixel data through ``apply_voi_lut`` using the
            file's own VOI LUT / windowing attributes; otherwise use raw pixel data.
        rotate: 0 (or negative) for no rotation; 1, 2, 3 select 90 degrees clockwise,
            90 degrees counter-clockwise and 180 degrees respectively.

    Returns:
        A ``(img_size, img_size)`` numpy array. Its dtype follows the pixel data or
        the ``apply_voi_lut`` output.

    Raises:
        IndexError: if ``rotate`` is greater than 3.
    """
    # dcmread is the supported reader; read_file was deprecated and removed in pydicom 3.
    dicom = pydicom.dcmread(path)
    pixels = dicom.pixel_array
    data = apply_voi_lut(pixels, dicom) if voi_lut else pixels

    if rotate > 0:
        # Index 0 is only a placeholder: rotate == 0 never reaches this branch.
        rot_choices = [0, cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]
        data = cv2.rotate(data, rot_choices[rotate])

    return cv2.resize(data, (img_size, img_size))


def _natural_sort_key(text):
    """Sort key that orders embedded numbers numerically ("Image-2" before "Image-10").

    Splits ``text`` into digit runs (converted to int) and single non-digit characters.
    """
    return [int(tok) if tok.isdigit() else tok for tok in re.findall(r'[^0-9]|[0-9]+', text)]


def load_dicom_images_3d(path, num_imgs=NUM_IMAGES, img_size=IMAGE_SIZE, rotate=0):
    """Load the central ``num_imgs`` slices of a DICOM series as a normalized 3D volume.

    The slices are the ``*.dcm`` files directly inside ``path``, ordered by natural
    filename sort. Up to ``num_imgs`` slices centred on the middle file are loaded
    with ``load_dicom_image``. If fewer are available, the volume is zero-padded
    along the slice axis. The whole volume (padding included) is then min-max
    scaled to [0, 1], unless it is constant.

    Args:
        path: Directory containing the series' ``.dcm`` files.
        num_imgs: Number of slices in the returned volume.
        img_size: In-plane side length each slice is resized to.
        rotate: Rotation code forwarded to ``load_dicom_image`` (0 = none).

    Returns:
        float64 array of shape ``(1, img_size, img_size, num_imgs)``.

    Raises:
        FileNotFoundError: if ``path`` contains no ``.dcm`` files.
    """
    files = sorted(glob.glob(f"{path}/*.dcm"), key=_natural_sort_key)
    if not files:
        # np.stack would otherwise fail with an opaque "need at least one array" error.
        raise FileNotFoundError(f"No .dcm files found in {path!r}")

    middle = len(files) // 2
    half = num_imgs // 2
    start = max(0, middle - half)
    # Upper bound uses num_imgs - half so an odd num_imgs still yields num_imgs slices.
    stop = min(len(files), middle + (num_imgs - half))

    # img_size must be forwarded so every slice matches the zero-padding shape below.
    slices = [load_dicom_image(f, img_size=img_size, rotate=rotate) for f in files[start:stop]]
    # (n, S, S) -> .T -> (S, S, n): slice axis last. .T also swaps the two in-plane axes;
    # kept unchanged, presumably so the layout matches existing checkpoints.
    # Cast to float64 so the min subtraction below cannot overflow integer pixel dtypes.
    img3d = np.stack(slices).T.astype(np.float64)

    missing = num_imgs - img3d.shape[-1]
    if missing > 0:
        img3d = np.concatenate((img3d, np.zeros((img_size, img_size, missing))), axis=-1)

    lo, hi = img3d.min(), img3d.max()
    if lo < hi:
        img3d = (img3d - lo) / (hi - lo)

    return np.expand_dims(img3d, 0)


class Dataset3D(torch.utils.data.Dataset):
    """Map-style dataset yielding one 3D MRI volume per row of a labels CSV.

    Item ``i`` is ``(volume, label)``, where:
    - ``volume`` is a float32 tensor built by ``load_dicom_images_3d`` from the
      directory ``{root_path}/{zero-padded BraTS21ID}/{scan_type}``;
    - ``label`` is ``int(MGMT_value)``.
    """

    def __init__(self, csv_path, root_path, scan_type='FLAIR', idxs=None):
        """Build the list of ``(series_dir, label)`` pairs.

        Args:
            csv_path: Anything ``pd.read_csv`` accepts. It must provide
                ``BraTS21ID`` and ``MGMT_value`` columns.
            root_path: Directory holding one sub-directory per 5-digit case id.
            scan_type: Name of the scan sub-directory inside each case folder.
            idxs: Optional positional row indices (applied with ``.iloc``) to keep
                only a subset of rows, e.g. a train/validation split.
        """
        labels_df = pd.read_csv(csv_path)
        if idxs is not None:
            labels_df = labels_df.iloc[idxs]
        # Read the two columns directly instead of using iterrows(). iterrows() packs
        # each row into a single Series, so a float MGMT_value column would upcast
        # BraTS21ID to float, and str(1.0).zfill(5) gives '001.0' instead of '00001'.
        self.data = [
            (f"{root_path}/{str(int(bid)).zfill(5)}/{scan_type}", int(label))
            for bid, label in zip(labels_df['BraTS21ID'], labels_df['MGMT_value'])
        ]

    def __len__(self):
        """Return the number of cases."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load case ``idx`` and return ``(float32 volume tensor, int label)``."""
        path, label = self.data[idx]
        img = load_dicom_images_3d(path)
        # Move the slice axis (last) to position 1: (1, S, S, D) -> (1, D, S, S).
        img = torch.from_numpy(np.moveaxis(img, -1, 1)).type(torch.float32)
        return img, label
    
# Model: 3D CNN (the PyTorch Lightning training wrapper is not defined in this notebook)
class Custom3DNet(nn.Module):
    """3D CNN that maps one single-channel volume to one sigmoid probability.

    Architecture:
    - Four blocks, each Conv3d(k=3, no padding) -> ReLU -> MaxPool3d -> BatchNorm3d
      -> Dropout3d. Channels widen 1 -> 64 -> 128 -> 256 -> 512.
    - A global average pool and a two-layer MLP head ending in Sigmoid.

    Input is (N, 1, D, H, W). With k=3 convs and pool size 2, every spatial
    dimension must be at least 46 for the last block to produce a non-empty map.

    The sub-module order defines the state_dict keys (``block1.0.weight``,
    ``classifier.2.weight``, ...) that checkpoints rely on. Do not reorder layers.
    """

    def __init__(self):
        super().__init__()
        self.block1 = self.__gen_block(1, 64, 3, 2, 0.01)
        self.block2 = self.__gen_block(64, 128, 3, 2, 0.02)
        self.block3 = self.__gen_block(128, 256, 3, 2, 0.03)
        self.block4 = self.__gen_block(256, 512, 3, 2, 0.04)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.Flatten(),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            # The tensor here is 2D (N, 1024), so use element-wise Dropout. Dropout3d on
            # non-4D/5D input is deprecated by PyTorch (it warns), and its behavior is
            # version-dependent: newer versions zero whole samples. Dropout has no
            # parameters, so checkpoint keys are unchanged.
            nn.Dropout(p=0.08),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, inp):
        """Return an (N, 1) tensor of probabilities for an (N, 1, D, H, W) input."""
        x = self.block1(inp)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        return self.classifier(x)

    def __gen_block(self, in_channels, out_channels, kernel_size, pool_size, dropout=None):
        """Build Conv3d -> ReLU -> MaxPool3d -> BatchNorm3d [-> Dropout3d(dropout)]."""
        layers = [
            nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=pool_size),
            nn.BatchNorm3d(num_features=out_channels),
        ]
        if dropout is not None:
            # Channel-wise dropout is appropriate here: these activations are 5D.
            layers.append(nn.Dropout3d(p=dropout))
        return nn.Sequential(*layers)

In [3]:
if __name__ == "__main__":
    test_img_path = "../input/rsna-miccai-brain-tumor-radiogenomic-classification/test/"
    submit_sub_path = '../input/rsna-miccai-brain-tumor-radiogenomic-classification/sample_submission.csv'
    model_path = "../input/train-3d-custom-v1/ROC-epoch=15-val_roc_auc=0.63-val_loss=0.67.ckpt"

    # Load the model. torch.load unpickles the file, so only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    # The checkpoint's keys are expected to start with "model." (presumably the training
    # wrapper's attribute name). Remove that exact prefix. str.lstrip('model.') would
    # strip any leading run of the characters m/o/d/e/l/. and mangle keys,
    # e.g. "model.decoder.w" -> "coder.w".
    prefix = 'model.'
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in checkpoint['state_dict'].items()
    }
    model = Custom3DNet()
    model.load_state_dict(state_dict)
    model.eval()

    # Load the list of test cases to predict.
    submit = pd.read_csv(submit_sub_path)

    # Preprocess each case and predict. Inference only: no_grad avoids building an
    # autograd graph for every volume.
    preds = []
    with torch.no_grad():
        for case_id in submit['BraTS21ID'].tolist():
            path = test_img_path + str(case_id).zfill(5) + '/' + 'FLAIR/'
            img = load_dicom_images_3d(path)
            # (1, S, S, D) -> (1, D, S, S), then add a batch dim -> (1, 1, D, S, S).
            img = torch.from_numpy(np.moveaxis(img, -1, 1)).unsqueeze(0).type(torch.float32)
            preds.append(model(img).item())
            gc.collect()

    submit['MGMT_value'] = preds
    submit.to_csv('submission.csv', index=False)