In [1]:
import os
import re
import numpy as np 
import glob
import gc
import pandas as pd 
import pydicom as dicom
import seaborn as sns
import joblib
import matplotlib.pyplot as plt

from tqdm import tqdm
from scipy.ndimage import zoom

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold

import torch 
import pytorch_lightning as pl
import torchmetrics
from torch import nn
from torch.utils.data import DataLoader
# Seed PyTorch's global random number generator with a fixed value.
torch.manual_seed(42)


def load_source_dicom_line(path):
    """Read every file found directly under ``path``, ordered by slice index.

    The slice index is parsed from the file path: the last four characters
    are dropped and the text after the final ``-`` is converted to ``int``
    (this fits names such as ``Image-12.dcm`` -- presumably the dataset's
    naming scheme; verify against the data).

    Parameters
    ----------
    path : str
        Directory holding the files of one series.

    Returns
    -------
    list
        The objects returned by ``dicom.read_file``, in ascending index
        order. Empty when the directory contains no entries.

    Raises
    ------
    ValueError
        If a file name does not end in an integer after its last ``-``.
    """
    def _slice_index(filename):
        stem = filename[:-4]
        return int(stem.split("-")[-1])

    ordered_files = sorted(glob.glob(os.path.join(path, "*")), key=_slice_index)
    # Slices are kept even when their pixel data is entirely zero.
    return [dicom.read_file(filename) for filename in ordered_files]


def load_dicom_xyz_v2(path):
    """Load one DICOM series as a 3-D array, resample it and re-orient it.

    Steps, in order:

    1. Read the slices with ``load_source_dicom_line``.
    2. Take the rounded in-plane direction components (x and y of the row
       and column vectors) from the first slice's ``ImageOrientationPatient``.
    3. Estimate the slice spacing as the distance between the first and the
       last slice divided by the number of slices, using ``SliceLocation``
       when present and ``ImagePositionPatient`` otherwise.
    4. Stack the pixel arrays and zoom the slice axis by
       ``slice spacing / PixelSpacing[0]``; the in-plane axes are untouched.
    5. For three recognised orientations, optionally reverse the slice order
       and transpose / rotate the volume. The three cases are presumably
       the sagittal / coronal / axial acquisitions -- TODO confirm. Any
       other orientation is returned as resampled, without re-ordering.

    Parameters
    ----------
    path : str
        Directory of a single series.

    Returns
    -------
    numpy.ndarray
        The resampled, re-oriented 3-D volume.
    """
    slices = load_source_dicom_line(path)
    first, last = slices[0], slices[-1]
    n_slices = len(slices)

    # Orientation signature: x/y components of the row and column vectors.
    x1, y1, _, x2, y2, _ = [round(c) for c in first.ImageOrientationPatient]
    orientation = [x1, y1, x2, y2]
    pos_first, pos_last = first.ImagePositionPatient, last.ImagePositionPatient

    # Slice spacing; fall back to the positions when SliceLocation is absent.
    try:
        extent = np.abs(float(last.SliceLocation) - float(first.SliceLocation))
    except AttributeError:
        extent = np.linalg.norm(
            np.array(last.ImagePositionPatient) - np.array(first.ImagePositionPatient),
            ord=2,
        )
    spacing_z = extent / n_slices
    in_plane_spacing = first.PixelSpacing

    # Build the (slice, row, col) tensor and rescale the slice axis only.
    volume = np.stack([s.pixel_array for s in slices])
    volume = zoom(volume, (spacing_z / in_plane_spacing[0], 1, 1))

    # Re-order the planes if needed and rotate the voxel grid.
    if orientation == [1, 0, 0, 0]:
        if pos_first[1] < pos_last[1]:
            volume = volume[::-1]
        return volume.transpose((1, 0, 2))

    if orientation == [0, 1, 0, 0]:
        if pos_first[0] < pos_last[0]:
            volume = volume[::-1]
        return np.rot90(volume.transpose((1, 2, 0)), 2, axes=(1, 2))

    if orientation == [1, 0, 0, 1]:
        if pos_first[2] > pos_last[2]:
            volume = volume[::-1]
        return np.rot90(volume, 2)

    return volume
    

def get_pseudo_rgb(img):
    """Project a 3-D volume along each axis into three 3-channel images.

    Voxels equal to the volume's global minimum are replaced by NaN and so
    are ignored by the statistics (presumably they are background). For each
    of the three axes the volume is collapsed into a 2-D image whose
    channels are, in this order:

    0. the NaN-aware mean,
    1. the NaN-aware maximum,
    2. the NaN-aware standard deviation,

    each divided by its own maximum. Remaining NaNs (for example positions
    where every voxel along the axis was masked) are set to 0 and the
    result is clipped to ``[-1, 1]``.

    Parameters
    ----------
    img : numpy.ndarray
        3-D array.

    Returns
    -------
    list of numpy.ndarray
        Three float64 arrays of shape ``(H, W, 3)``, one per projection
        axis (axis 0, 1, 2), where ``(H, W)`` is ``img.shape`` with the
        projected axis removed.
    """
    masked = np.where(img == img.min(), np.nan, img)
    reducers = (np.nanmean, np.nanmax, np.nanstd)

    pimages = []
    for axis in range(3):
        # The projection shape is the input shape minus the reduced axis;
        # no extra reduction pass is needed just to learn it.
        plane_shape = masked.shape[:axis] + masked.shape[axis + 1:]
        # float64 buffer, fully overwritten below.
        prgb = np.empty((len(reducers),) + plane_shape)
        for channel, reduce_fn in enumerate(reducers):
            prgb[channel] = reduce_fn(masked, axis=axis)
            prgb[channel] /= np.nanmax(prgb[channel])

        # channels-first -> channels-last
        prgb = np.moveaxis(prgb, 0, -1)
        prgb = np.where(np.isnan(prgb), 0, prgb)
        pimages.append(np.clip(prgb, -1, 1))

    return pimages


def create_pseudorgb_set(img, canvas_size=512):
    """Build the pseudo-RGB projections of ``img`` centred on square canvases.

    Each projection returned by ``get_pseudo_rgb`` is cropped to the
    bounding box of its non-zero pixels and pasted into the middle of a
    zero-filled ``canvas_size x canvas_size x 3`` canvas.

    Parameters
    ----------
    img : numpy.ndarray
        3-D volume, passed straight to ``get_pseudo_rgb``.
    canvas_size : int, optional
        Side length of the output canvases. Defaults to 512.

    Returns
    -------
    list of numpy.ndarray
        One ``float32`` array of shape ``(canvas_size, canvas_size, 3)`` per
        projection. A projection without any non-zero pixel yields an
        all-zero canvas.

    Notes
    -----
    A bounding box larger than the canvas is not handled: the paste offsets
    become negative and the assignment fails.
    """
    centered = []
    for p in get_pseudo_rgb(img):
        canvas = np.zeros((canvas_size, canvas_size, 3))

        # A pixel belongs to the content when any of its channels is non-zero.
        occupied = np.any(p != 0, axis=2)
        rows = np.flatnonzero(occupied.any(axis=1))
        cols = np.flatnonzero(occupied.any(axis=0))

        if rows.size:
            y_min, y_max = rows[0], rows[-1]
            x_min, x_max = cols[0], cols[-1]
            # NOTE: y_max / x_max are used as exclusive slice ends, so the
            # last content row and column are dropped. This is kept on
            # purpose so the images stay identical to what this
            # preprocessing has always produced.
            y_size, x_size = y_max - y_min, x_max - x_min
            start_y = (canvas_size - y_size) // 2
            start_x = (canvas_size - x_size) // 2
            canvas[start_y:start_y + y_size, start_x:start_x + x_size, :] = \
                p[y_min:y_max, x_min:x_max, :]

        centered.append(canvas.astype('float32'))

    return centered

# Bottleneck residual block together with its shortcut branch
class BigBottleneck(nn.Module):
    """Residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    Parameters
    ----------
    in_channels : int
        Number of channels of the block input.
    kernels : sequence of three ints
        Output channel counts of the three convolutions, in order.
    is_shortcut : bool, optional
        When true, the first convolution uses stride 2 and the shortcut is
        a stride-2 1x1 convolution followed by batch norm that maps the
        input to ``kernels[2]`` channels. Otherwise every stride is 1 and
        the shortcut is the identity, so ``in_channels`` has to equal
        ``kernels[2]`` for the addition in ``forward`` to work.
    """

    def __init__(self, in_channels, kernels, is_shortcut=False):
        super().__init__()
        reduce_ch, mid_ch, out_ch = kernels
        stride = 2 if is_shortcut else 1

        main_path = [
            nn.Conv2d(in_channels, reduce_ch, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(reduce_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduce_ch, mid_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_ch, out_ch, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_ch),
        ]
        self.bottleneck = nn.Sequential(*main_path)

        if is_shortcut:
            # Projection shortcut: match both resolution and channel count.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_ch, kernel_size=1, stride=2, bias=False),
                nn.BatchNorm2d(out_ch),
            )
        else:
            self.shortcut = nn.Identity()

        self.last_relu = nn.ReLU()

    def forward(self, inp):
        """Return ``relu(bottleneck(inp) + shortcut(inp))``."""
        residual = self.shortcut(inp)
        return self.last_relu(self.bottleneck(inp) + residual)
        
        
# resnet50 model
class Resnet50(nn.Module):
    """ResNet-50-style network with a single sigmoid output.

    Layout: a stem (7x7 stride-2 convolution, batch norm, ReLU, 3x3
    stride-2 max pool), sixteen ``BigBottleneck`` blocks arranged in four
    stages of 3, 4, 6 and 3 blocks, and a head made of global average
    pooling, two linear layers (2048 -> 1000 -> 1) and a sigmoid. The first
    block of every stage, the first stage included, is built with
    ``is_shortcut=True``.

    Parameters
    ----------
    in_channels : int, optional
        Number of channels of the input images. Defaults to 3.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        self.start = nn.Sequential(
            nn.Conv2d(in_channels, 64, 7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2, padding=1),
        )

        # (channels entering the stage, bottleneck widths, number of blocks)
        stage_specs = (
            (64, [64, 64, 256], 3),
            (256, [128, 128, 512], 4),
            (512, [256, 256, 1024], 6),
            (1024, [512, 512, 2048], 3),
        )
        blocks = []
        for stage_in, widths, depth in stage_specs:
            # The first block of a stage changes resolution and channel count ...
            blocks.append(BigBottleneck(stage_in, widths, is_shortcut=True))
            # ... the remaining ones keep both.
            for _ in range(depth - 1):
                blocks.append(BigBottleneck(widths[-1], widths))
        self.body = nn.Sequential(*blocks)

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(2048, 1000),
            nn.Linear(1000, 1),
            nn.Sigmoid(),
        )

    def forward(self, inp):
        """Run stem, body and head; returns the sigmoid output, one value per sample."""
        features = self.body(self.start(inp))
        return self.classifier(features)
        
        
class LITResnet50(pl.LightningModule):
    """Lightning wrapper that trains ``Resnet50`` with binary cross-entropy.

    Per-epoch loss, F1 and ROC AUC for the training and validation sets are
    appended to ``self.scores``. ``train_loss``, ``train_roc_auc``,
    ``val_loss`` and ``val_roc_auc`` are also sent to the logger.

    Parameters
    ----------
    lr : float, optional
        Learning rate of the Adam optimizer. Defaults to ``1e-5``.
    """

    def __init__(self, lr=1e-5):
        super().__init__()
        self.lr = lr
        self.model = Resnet50()
        self.loss_f = nn.BCELoss()
        self.roc_auc_f = torchmetrics.AUROC(num_classes=1)
        self.f1_f = torchmetrics.F1()
        # History of the epoch-level metrics, one entry per epoch.
        self.scores = {
            'val_loss': [],
            'val_f1': [],
            'val_rocauc': [],
            'train_loss': [],
            'train_rocauc': [],
            'train_f1': []
        }
        # True until ``validation_epoch_end`` has run once; see that method.
        self.first_epoch = True

    def forward(self, inp):
        """Return the wrapped model's output for ``inp``."""
        return self.model(inp)

    def configure_optimizers(self):
        """Return an Adam optimizer over this module's own parameters."""
        # Use ``self``, not a module-level ``model`` variable that may be
        # undefined or refer to a different network.
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def _flat_predictions(self, x):
        """Run the model and return its output as a flat float64 tensor."""
        pred_y = torch.flatten(self.model(x))
        # float64 so the dtype matches the targets handed to BCELoss
        # (presumably float64 because labels are Python floats -- confirm).
        return pred_y.type(torch.float64)

    @staticmethod
    def _gather(outputs):
        """Concatenate the per-step ``pred`` / ``y`` tensors of an epoch."""
        preds = torch.cat(tensors=[out['pred'] for out in outputs])
        targs = torch.cat(tensors=[out['y'] for out in outputs])
        return preds, targs

    def training_step(self, train_batch, batch_idx):
        """Compute and log the batch loss; also return predictions and targets."""
        x, y = train_batch

        pred_y = self._flat_predictions(x)

        loss = self.loss_f(pred_y, y)
        self.log("train_loss", loss, on_epoch=True, logger=True)

        # The predictions are only needed for the epoch-level metrics; detach
        # them so every batch's autograd graph is not kept alive until the
        # end of the epoch.
        return {'loss': loss, 'pred': pred_y.detach(), 'y': y}

    def training_epoch_end(self, outputs):
        """Compute the metrics over the whole training epoch and record them."""
        preds, targs = self._gather(outputs)

        rocauc = self.roc_auc_f(preds, targs.type(torch.int))
        loss = self.loss_f(preds, targs)
        f1 = self.f1_f(preds, targs.type(torch.int))

        self.log("train_roc_auc", rocauc, logger=True, on_epoch=True)

        self.scores['train_loss'].append(loss)
        self.scores['train_f1'].append(f1)
        self.scores['train_rocauc'].append(rocauc)

    def validation_step(self, val_batch, batch_idx):
        """Log the batch loss and return predictions and targets."""
        x, y = val_batch

        pred_y = self._flat_predictions(x)

        loss = self.loss_f(pred_y, y)
        self.log("val_loss", loss, on_epoch=True, logger=True)

        return {
            'y': y,
            'pred': pred_y
        }

    def validation_epoch_end(self, outputs):
        """Compute the metrics over the whole validation epoch and record them."""
        preds, targs = self._gather(outputs)

        rocauc = self.roc_auc_f(preds, targs.type(torch.int))
        loss = self.loss_f(preds, targs)
        f1 = self.f1_f(preds, targs.type(torch.int))

        self.log("val_roc_auc", rocauc, logger=True, on_epoch=True, prog_bar=True)
        self.log("val_loss", loss, logger=True, on_epoch=True, prog_bar=True)
        # The first call is skipped in the history because, per the original
        # author's note, this hook is also invoked once when training starts
        # (presumably Lightning's sanity-check pass -- confirm).
        if not self.first_epoch:
            self.scores['val_loss'].append(loss)
            self.scores['val_f1'].append(f1)
            self.scores['val_rocauc'].append(rocauc)
        self.first_epoch = False
   
    
class PseudoRGBDataset(torch.utils.data.Dataset):
    """In-memory dataset of ``(image, target)`` pairs.

    Parameters
    ----------
    csv_path : str
        CSV file with the columns ``BraTS21ID`` and ``MGMT_value``.
    pimgs : mapping
        Maps a ``BraTS21ID`` to the sequence of channels-last pseudo-RGB
        images of that case, or to the scalar ``-1`` when the case has to be
        skipped. Only the first image of each sequence is used.
    idxs : sequence of int, optional
        Positional row indices (``DataFrame.iloc``) selecting a subset of
        the CSV rows. All rows are used when omitted.

    Each item is a tuple ``(tensor, float)``: the image converted to
    channels-first layout and the target as a Python float.
    """

    def __init__(self, csv_path, pimgs, idxs=None):
        df = pd.read_csv(csv_path)
        if idxs is not None:
            df = df.iloc[idxs]

        self.data = []
        for bid, target in zip(df['BraTS21ID'], df['MGMT_value']):
            projections = pimgs[bid]
            # ``-1`` is the "no usable image" marker. Checking for a scalar
            # first keeps the comparison safe when the entry is an array.
            if np.isscalar(projections) and projections == -1:
                continue
            # channels-last (H, W, C) -> channels-first (C, H, W)
            image = torch.from_numpy(np.moveaxis(projections[0], -1, 0))
            self.data.append((image, float(target)))

        # One collection after loading is enough; collecting after every row
        # only slows construction down.
        gc.collect()

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(image, target)`` pair at position ``idx``."""
        return self.data[idx]

In [2]:
if __name__ == "__main__":
    # Inference script: average the predictions of the fold checkpoints on
    # the FLAIR series of every test case and write ``submission.csv``.
    test_img_path = "../input/rsna-miccai-brain-tumor-radiogenomic-classification/test/"
    submit_sub_path = '../input/rsna-miccai-brain-tumor-radiogenomic-classification/sample_submission.csv'

    models_path = [
        "../input/train-resnet50-pseudorgb-v1-3f/0LOSS-epoch=4-val_roc_auc=0.54-val_loss=0.69.ckpt",
        "../input/train-resnet50-pseudorgb-v1-3f/1LOSS-epoch=0-val_roc_auc=0.42-val_loss=0.69.ckpt",
        "../input/train-resnet50-pseudorgb-v1-3f/2LOSS-epoch=0-val_roc_auc=0.47-val_loss=0.69.ckpt"
    ]

    # load models
    # The checkpoints were saved from the Lightning wrapper, whose network
    # lives in the ``model`` attribute, so each key carries a ``model.``
    # prefix. Remove it as a prefix: ``str.lstrip`` would strip any leading
    # run of the characters "m", "o", "d", "e", "l", "." instead.
    prefix = 'model.'
    models = []
    for m in models_path:
        state_dict = torch.load(m, map_location=torch.device('cpu'))['state_dict']
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
        model = Resnet50()
        model.load_state_dict(state_dict)
        model.eval()
        models.append(model)

    # load data
    submit = pd.read_csv(submit_sub_path)

    # preproc. data and predict
    preds = []
    for i in submit['BraTS21ID'].tolist():
        path = test_img_path + str(i).zfill(5) + '/' + 'FLAIR/'

        img = load_dicom_xyz_v2(path)
        inp = create_pseudorgb_set(img)

        # Only the first projection is fed to the network:
        # (H, W, C) -> (1, C, H, W).
        inp = inp[0]
        inp = torch.from_numpy(np.moveaxis(inp, -1, 0))
        inp = torch.unsqueeze(inp, 0)

        # No gradients are needed for inference.
        with torch.no_grad():
            out = sum(m(inp).item() for m in models) / len(models)
        print(out)
        # Store the ensemble average (not the sum), so the submitted value
        # stays within [0, 1].
        preds.append(out)

        gc.collect()

    submit['MGMT_value'] = preds
    submit.to_csv('submission.csv',index=False)

0.5461922685305277
0.5422345399856567
0.5589451591173807
0.5501331686973572
0.5557267069816589
0.5437542994817098
0.5546454985936483
0.5539363423983256
0.5392972032229105
0.547537644704183


  keepdims=keepdims)


0.5153825879096985
0.5246059695879618
0.5438394149144491
0.515596608320872
0.5350794593493143
0.5222155054410299
0.49738632639249164
0.49118103583653766
0.5162718097368876
0.5125235716501871
0.4949304560820262
0.5119198560714722
0.509886751572291
0.48435888687769574
0.4912083049615224
0.49034929275512695
0.49961047371228534
0.4983927309513092
0.48967164754867554
0.5077093144257864
0.5142026742299398
0.5115422407786051
0.5045379896958669
0.5134802261988322
0.535989006360372
0.5115026632944742
0.4996361533800761
0.4861648579438527
0.5087830026944479
0.4871824085712433
0.48706498742103577
0.500799278418223
0.5054487387339274
0.5012368559837341
0.48556403319040936
0.541254997253418
0.5175691246986389
0.5441183646519979
0.5358957250912985
0.5457772413889567
0.5569382508595785
0.5433076620101929
0.5552479028701782
0.5461958249409994
0.5500989754994711
0.5494295557339987
0.537970503171285
0.5475536982218424
0.5579876899719238
0.5390849113464355
0.5045947730541229
0.5123843550682068
0.51242371