In [1]:
import numpy as np
import pandas as pd
import joblib

from tqdm import tqdm

import torch
from torch import nn

import os
import re
import numpy as np 
import glob
import gc
import pandas as pd 
import pydicom as dicom
import seaborn as sns
import joblib
import matplotlib.pyplot as plt

from tqdm import tqdm
from scipy.ndimage import zoom


# Bottleneck residual unit: 1x1 -> 3x3 -> 1x1 convolutions plus a skip path.
class BigBottleneck(nn.Module):
    """ResNet bottleneck block.

    Args:
        in_channels: number of channels of the incoming feature map.
        kernels: ``(k1, k2, k3)`` output channels of the three convolutions;
            ``k3`` is the channel count the block emits.
        is_shortcut: if True, the first 1x1 convolution uses stride 2 and the
            skip path is a stride-2 1x1 convolution + BatchNorm projection to
            ``k3`` channels; otherwise the skip path is the identity (which
            requires ``in_channels == k3``).
    """

    def __init__(self, in_channels, kernels, is_shortcut=False):
        super().__init__()
        reduce_ch, mid_ch, out_ch = kernels
        stride = 2 if is_shortcut else 1

        # The attribute names and Sequential indices below define the
        # state_dict keys (e.g. "bottleneck.7.running_mean"); keep them stable
        # so existing checkpoints still load.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels, reduce_ch, 1, stride, bias=False),
            nn.BatchNorm2d(reduce_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduce_ch, mid_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch),
        )

        if is_shortcut:
            projection = nn.Sequential(
                nn.Conv2d(in_channels, out_ch, 1, 2, bias=False),
                nn.BatchNorm2d(out_ch),
            )
        else:
            projection = nn.Identity()
        self.shortcut = projection

        self.last_relu = nn.ReLU()

    def forward(self, inp):
        """Return ``ReLU(bottleneck(inp) + shortcut(inp))``."""
        residual = self.bottleneck(inp)
        skip = self.shortcut(inp)
        return self.last_relu(residual + skip)
        
        
# ResNet-50-style backbone with a single sigmoid output.
class Resnet50(nn.Module):
    """ResNet-50 layout (3-4-6-3 bottleneck blocks) ending in one probability.

    Args:
        in_channels: channels of the input image (3 for the pseudo-RGB input).

    ``forward`` maps a ``(batch, in_channels, H, W)`` tensor to a
    ``(batch, 1)`` tensor of sigmoid outputs.

    NOTE(review): the first block of every stage, including the first stage,
    is built with ``is_shortcut=True`` and therefore downsamples by 2. The
    head's two ``nn.Linear`` layers also have no activation between them.
    Both presumably match how the checkpoint was trained; confirm before
    changing.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        # Stem: 7x7 stride-2 conv, BN, ReLU, 3x3 stride-2 max-pool.
        self.start = nn.Sequential(
            nn.Conv2d(in_channels, 64, 7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2, padding=1)
        )

        # (stage input channels, [k1, k2, k3] widths, number of blocks)
        stage_specs = [
            (64, [64, 64, 256], 3),
            (256, [128, 128, 512], 4),
            (512, [256, 256, 1024], 6),
            (1024, [512, 512, 2048], 3),
        ]
        blocks = []
        for stage_in, widths, depth in stage_specs:
            # The first block of a stage downsamples and projects the skip path;
            # the remaining ones keep k3 channels with an identity shortcut.
            blocks.append(BigBottleneck(stage_in, widths, is_shortcut=True))
            for _ in range(depth - 1):
                blocks.append(BigBottleneck(widths[-1], widths))
        # "start", "body" (with its block order) and "classifier" define the
        # state_dict keys; keep them stable for checkpoint loading.
        self.body = nn.Sequential(*blocks)

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(2048, 1000),
            nn.Linear(1000, 1),
            nn.Sigmoid()
        )

    def forward(self, inp):
        """Run stem -> residual body -> pooled sigmoid head."""
        features = self.body(self.start(inp))
        return self.classifier(features)

    
def load_source_dicom_line(path):
    """Read every DICOM file in *path*, ordered by the slice number in its name.

    File names are expected to look like ``<prefix>-<number>.<ext>`` (for
    example ``Image-12.dcm``). They are sorted numerically by ``<number>``, so
    ``Image-2`` precedes ``Image-10``.

    Args:
        path: directory holding the slices of one DICOM series.

    Returns:
        list of pydicom datasets in slice-number order; empty if *path* contains
        no files.

    Raises:
        ValueError: if a file name does not end in ``-<integer>`` before its
            extension.
    """
    def slice_number(filename):
        # Use basename + splitext instead of chopping a fixed 4-char extension.
        stem = os.path.splitext(os.path.basename(filename))[0]
        return int(stem.split("-")[-1])

    t_paths = sorted(glob.glob(os.path.join(path, "*")), key=slice_number)
    # dcmread replaces read_file, which was deprecated in pydicom 2.2 and
    # removed in pydicom 3.0.
    return [dicom.dcmread(filename) for filename in t_paths]


def load_dicom_xyz_v2(path: str) -> np.ndarray:
    """Load a DICOM series as a 3D volume resampled and reoriented to a common layout.

    Steps:
      1. Read the slices with ``load_source_dicom_line`` (file-name order).
      2. Stack every slice's ``pixel_array`` along a new first axis. This
         assumes each ``pixel_array`` is 2D; TODO confirm for all series.
      3. Rescale only that first (slice) axis with ``scipy.ndimage.zoom`` by
         ``spacing_z / PixelSpacing[0]``.
      4. Depending on the rounded ``ImageOrientationPatient``, optionally
         reverse the slice order, then transpose/rotate the volume.

    Args:
        path: directory containing one DICOM series.

    Returns:
        The volume as a numpy array. If the orientation matches none of the
        three handled patterns, the volume is returned after step 3 only.

    NOTE(review): an empty directory raises IndexError at ``_3d[0]``.
    """
    _3d = load_source_dicom_line(path)
    # get metadata orientation: ImageOrientationPatient holds 6 values (per
    # DICOM, the row and column direction cosines). Rounding snaps them to
    # -1/0/1; only the x/y components of each vector are kept in `cords`.
    x1, y1, _, x2, y2, _ = [round(j) for j in _3d[0].ImageOrientationPatient]
    cords = [x1, y1, x2, y2]
    # Positions of the first and last slice in the sorted order; used below to
    # decide whether the slice order must be reversed.
    position_first, position_last = _3d[0].ImagePositionPatient, _3d[-1].ImagePositionPatient
    # get main metadata: slice spacing = distance from first to last slice
    # divided by the number of slices. SliceLocation is used when present,
    # otherwise the Euclidean distance between ImagePositionPatient values.
    # NOTE(review): N slices span N-1 gaps, so dividing by len(_3d) slightly
    # underestimates the spacing. Presumably the model was trained with this
    # same formula; confirm against the training pipeline before changing.
    try:
        spacing_z = np.abs(float(_3d[-1].SliceLocation) - float(_3d[0].SliceLocation)) / len(_3d)
    except AttributeError:
        spacing_z = np.linalg.norm(np.array(_3d[-1].ImagePositionPatient) - np.array(_3d[0].ImagePositionPatient), ord=2) / len(_3d)
    spacing_x_y = _3d[0].PixelSpacing
    
    
    # form tensor: stack slices along a new leading axis
    _3d = [np.expand_dims(i.pixel_array, axis=0) for i in _3d]
    _3d = np.concatenate(_3d)
    # rescale: resample only the slice axis so its step roughly matches the
    # first in-plane pixel spacing; in-plane axes keep factor 1
    _3d = zoom(_3d, (spacing_z/spacing_x_y[0], 1, 1))#first axis - rescaled
    
    # reorder planes if needed and rotate voxel. The plane names below are
    # inferred from the direction cosines, hence "presumably".
    # Orientations not listed (e.g. a negative cosine after rounding) are left
    # untouched.
    if cords == [1, 0, 0, 0]:
        # row along x, column without x/y component (presumably coronal):
        # reverse if the y position increases across the stack, then move the
        # slice axis to position 1.
        if position_first[1] < position_last[1]:
            _3d = _3d[::-1] 
        _3d = _3d.transpose((1, 0, 2))
        
    elif cords == [0, 1, 0, 0]:
        # row along y, column without x/y component (presumably sagittal):
        # reverse if the x position increases, move the slice axis last, then
        # rotate 180 degrees in the plane of axes 1 and 2.
        if position_first[0] < position_last[0]:
            _3d = _3d[::-1]
        _3d = _3d.transpose((1, 2, 0))
        _3d = np.rot90(_3d, 2, axes=(1, 2))
        
    elif cords == [1, 0, 0, 1]:
        # row along x, column along y (presumably axial): reverse if the z
        # position decreases, then rotate 180 degrees in the plane of axes 0
        # and 1 (np.rot90's default axes).
        if position_first[2] > position_last[2]:
            _3d = _3d[::-1]
        _3d = np.rot90(_3d, 2)

    return _3d
    

def get_pseudo_rgb(img):
    """Project a 3D volume along each of its axes into pseudo-RGB images.

    Voxels equal to the volume's minimum are treated as background: they are
    set to NaN and ignored by the projections. For each axis ``i`` the three
    channels are:

      * channel 0: mean along the axis,
      * channel 1: maximum along the axis,
      * channel 2: standard deviation along the axis,

    each divided by its own maximum. Rays containing only background (NaN)
    become 0, and the result is clipped to ``[-1, 1]``.

    Args:
        img: 3D numpy array.

    Returns:
        list of three float64 arrays. Element ``i`` has the shape of ``img``
        with axis ``i`` removed, plus a trailing channel axis of size 3.
    """
    import warnings

    img_nan = np.where(img == img.min(), np.nan, img)

    pimages = []
    for axis in range(3):
        # Rays through pure background give all-NaN slices. numpy then warns
        # ("Mean of empty slice", "Degrees of freedom <= 0", ...) once per
        # call. The resulting NaNs are zeroed below, so silence the noise.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            projections = (
                np.nanmean(img_nan, axis=axis),
                np.nanmax(img_nan, axis=axis),
                np.nanstd(img_nan, axis=axis),
            )
            channels = []
            for proj in projections:
                # Normalise in float64 regardless of the input dtype.
                proj = np.asarray(proj, dtype=np.float64)
                channels.append(proj / np.nanmax(proj))

        prgb = np.stack(channels, axis=-1)  # channels-last (H, W, 3)
        prgb = np.where(np.isnan(prgb), 0, prgb)
        pimages.append(np.clip(prgb, -1, 1))

    return pimages


def create_pseudorgb_set(img, canvas_size=512):
    """Build the three pseudo-RGB projections of *img*, each centred on a blank canvas.

    For every projection returned by ``get_pseudo_rgb``, the bounding box of
    the non-zero pixels (presumably the brain) is located. That region is
    copied into the centre of a ``(canvas_size, canvas_size, 3)`` zero canvas.
    A projection with no non-zero pixel yields an all-zero canvas. A bounding
    box larger than the canvas is centre-cropped to fit; previously this
    raised a broadcast error.

    NOTE(review): the crop uses an exclusive upper bound, so the last non-zero
    row and column of the bounding box are dropped. This is kept unchanged to
    match the preprocessing the checkpoint was presumably trained with.

    Args:
        img: 3D volume, passed unchanged to ``get_pseudo_rgb``.
        canvas_size: side length of the square output canvas (default 512).

    Returns:
        list of three float32 arrays of shape ``(canvas_size, canvas_size, 3)``.
    """
    new_pimgs = []
    for p in get_pseudo_rgb(img):
        canvas = np.zeros((canvas_size, canvas_size, 3))

        # A pixel is foreground when any of its channels is non-zero
        # (vectorised replacement of a per-pixel Python double loop).
        mask = np.abs(p).sum(axis=2) != 0
        rows = np.flatnonzero(mask.any(axis=1))
        cols = np.flatnonzero(mask.any(axis=0))

        if rows.size:  # an all-zero projection leaves the canvas blank
            crop = p[rows[0]:rows[-1], cols[0]:cols[-1], :]

            # Centre-crop anything larger than the canvas instead of failing.
            h, w = crop.shape[:2]
            dy = max(h - canvas_size, 0) // 2
            dx = max(w - canvas_size, 0) // 2
            crop = crop[dy:dy + canvas_size, dx:dx + canvas_size, :]

            h, w = crop.shape[:2]
            top, left = (canvas_size - h) // 2, (canvas_size - w) // 2
            canvas[top:top + h, left:left + w, :] = crop

        new_pimgs.append(canvas.astype('float32'))

    return new_pimgs


def strip_state_dict_prefix(state_dict, prefix='model.'):
    """Return a copy of *state_dict* with a leading *prefix* removed from each key.

    Keys that do not start with *prefix* are kept unchanged. Unlike
    ``str.lstrip(prefix)``, which strips any leading run of the *characters*
    in *prefix*, this removes the exact prefix once: ``'model.decoder'``
    becomes ``'decoder'``, not ``'coder'``.
    """
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


if __name__ == "__main__":
    test_img_path = "../input/rsna-miccai-brain-tumor-radiogenomic-classification/test/"
    submit_sub_path = '../input/rsna-miccai-brain-tumor-radiogenomic-classification/sample_submission.csv'
    model_path = "../input/train-resnet50-pseudorgb-v1/TEST-epoch=50-val_roc_auc=0.64-val_loss=0.66.ckpt"

    # load model. Weights live under 'state_dict', with keys presumably
    # prefixed by 'model.' (the attribute name in the training module).
    # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True,
    # which may reject a full training checkpoint; confirm with the torch
    # version in use.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))['state_dict']
    state_dict = strip_state_dict_prefix(state_dict)
    model = Resnet50()
    model.load_state_dict(state_dict)
    model.eval()

    # load data
    submit = pd.read_csv(submit_sub_path)

    # preprocess data and predict; no_grad because this is inference only,
    # so there is no need to build the autograd graph.
    preds = []
    with torch.no_grad():
        for i in submit['BraTS21ID'].tolist():
            path = os.path.join(test_img_path, str(i).zfill(5), 'FLAIR')

            img = load_dicom_xyz_v2(path)
            inp = create_pseudorgb_set(img)

            # Only the first projection (along axis 0) is fed to the model.
            inp = inp[0]
            inp = torch.from_numpy(np.moveaxis(inp, -1, 0))  # HWC -> CHW
            inp = torch.unsqueeze(inp, 0)  # add batch dimension
            out = model(inp).item()
            preds.append(out)

            gc.collect()

    submit['MGMT_value'] = preds
    submit.to_csv('submission.csv',index=False)

  keepdims=keepdims)
