In [2]:
import math, glob, random, os, time
import pydicom
import cv2
from functools import partial
from pathlib import Path

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data as torch_data
from sklearn import model_selection as sk_model_selection
from tqdm.notebook import tqdm
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt

### Config

In [3]:
class Config():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    size = 256
    num_images = 64
    base_path = Path('/home/RSNA_MICCAI_Brain_Tumor/data')
    assert base_path.exists(), f'{base_path} does not exist'
    models_path = Path('/home/RSNA_MICCAI_Brain_Tumor/models')
    assert models_path.exists()
    seed = 42
    test_size = 0.1
    clahe=False
    mri_types = ['FLAIR','T1w','T1wCE','T2w']
    batch_size = 4
    num_workers = 7
    epochs = 18
    extra_check_epochs = [14,15]
    check_frequency = 4
    
cfg = Config()

In [4]:
cfg.device

device(type='cuda')

### Functions for loading images

In [5]:
def load_dicom_image(path, img_size=cfg.size, clahe=False):
    dicom = pydicom.read_file(path)
    data = dicom.pixel_array
    max_data = np.max(data)
    min_data = np.min(data)
    if min_data == max_data:
        data = np.zeros((img_size,img_size))
        return data
    data = data - min_data
    if max_data != 0:
        data = data / (max_data - min_data) 
    
    #data = (data * 255).astype(np.uint8)
    data = cv2.resize(data, (img_size, img_size))
    return data

In [6]:
def load_dicom_images_3d(scan_id, num_imgs=cfg.num_images, img_size=cfg.size, mri_type="FLAIR", split="train", clahe=False):

    files = sorted(glob.glob(f"{cfg.base_path}/{split}/{scan_id}/{mri_type}/*.dcm"))
    assert len(files) > 0
    
    middle = len(files)//2
    num_imgs2 = num_imgs//2
    p1 = max(0, middle - num_imgs2)
    p2 = min(len(files), middle + num_imgs2)
    img3d = np.stack([load_dicom_image(f, clahe=False) for f in files[p1:p2]]).T 
    if img3d.shape[-1] < num_imgs:
        n_zero = np.zeros((img_size, img_size, num_imgs - img3d.shape[-1]))
        img3d = np.concatenate((img3d,  n_zero), axis = -1)
    img3d = img3d[:,:,:]
    # Transforms
#     img3d = flip3D(brightness(img3d))
#     if convert_vertical_flag:
#         img3d_vert = convert_vertical(img3d, size=img3d.shape[0])
#         img3d = np.concatenate([img3d, img3d_vert], axis=-1)
    return np.expand_dims(img3d[:,:,:], 0)

In [7]:
sample_image_0 = load_dicom_images_3d("00000")
sample_image_0.shape

(1, 256, 256, 64)

### Seeding

In [8]:
def set_seed(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

set_seed(cfg.seed)

### Tabular Data

In [9]:
train_df = pd.read_csv(f"{cfg.base_path}/train_labels.csv")
train_df

Unnamed: 0,BraTS21ID,MGMT_value
0,0,1
1,2,1
2,3,0
3,5,1
4,6,1
...,...,...
580,1005,1
581,1007,1
582,1008,1
583,1009,0


In [10]:
train_df = train_df.drop(train_df[train_df['BraTS21ID'].isin([109, 123, 709])].index)
train_df

Unnamed: 0,BraTS21ID,MGMT_value
0,0,1
1,2,1
2,3,0
3,5,1
4,6,1
...,...,...
580,1005,1
581,1007,1
582,1008,1
583,1009,0


### Train / Test Split

In [11]:
df_train, df_valid = sk_model_selection.train_test_split(
    train_df, 
    test_size=cfg.test_size, 
    random_state=42, 
    stratify=train_df["MGMT_value"],
)

In [12]:
len(df_valid) / len(df_train)

0.11281070745697896

In [13]:
df_valid = df_valid.append(pd.DataFrame([{'BraTS21ID': 1, 'MGMT_value': 1}, {'BraTS21ID': 13, 'MGMT_value': 1}]))

In [14]:
len(df_valid) / len(df_train)

0.11663479923518165

### Dataset

In [15]:
class Dataset(torch_data.Dataset):
    def __init__(self, paths, targets=None, mri_type=None, label_smoothing=0.01, split="train"):
        self.paths = paths
        self.targets = targets
        self.mri_type = mri_type
        self.label_smoothing = label_smoothing
        self.split = split
          
    def __len__(self):
        return len(self.paths)
    
    def __getitem__(self, index):
        scan_id = self.paths[index]
        if self.targets is None:
            data = load_dicom_images_3d(str(scan_id).zfill(5), mri_type=self.mri_type[index], split=self.split, clahe=cfg.clahe)
        else:
            data = load_dicom_images_3d(str(scan_id).zfill(5), mri_type=self.mri_type[index], split="train", clahe=cfg.clahe)

        if self.targets is None:
            return {"X": torch.tensor(data).float(), "id": scan_id}
        else:
            y = torch.tensor(abs(self.targets[index]-self.label_smoothing), dtype=torch.float)
            return {"X": torch.tensor(data).float(), "y": y}


In [16]:
df_train.loc[:,"MRI_Type"] = 'FLAIR'
sample_ds = Dataset(df_train['BraTS21ID'], df_train['MGMT_value'], df_train['MRI_Type'])

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.obj[key] = value
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self._setitem_single_column(loc, value, pi)


In [17]:
sample_ds[0]['X'].shape

torch.Size([1, 256, 256, 64])

### Model

In [18]:
def get_inplanes():
    return [64, 128, 256, 512]

def conv3x3x3(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes,
                     out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=1,
                     bias=False)

In [19]:
def conv1x1x1(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes,
                     out_planes,
                     kernel_size=1,
                     stride=stride,
                     bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()

        self.conv1 = conv3x3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()

        self.conv1 = conv1x1x1(in_planes, planes)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = conv3x3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = conv1x1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm3d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self,
                 block,
                 layers,
                 block_inplanes,
                 n_input_channels=3,
                 conv1_t_size=7,
                 conv1_t_stride=1,
                 no_max_pool=False,
                 shortcut_type='B',
                 widen_factor=1.0,
                 n_classes=400):
        super().__init__()

        block_inplanes = [int(x * widen_factor) for x in block_inplanes]

        self.in_planes = block_inplanes[0]
        self.no_max_pool = no_max_pool

        self.conv1 = nn.Conv3d(n_input_channels,
                               self.in_planes,
                               kernel_size=(conv1_t_size, 7, 7),
                               stride=(conv1_t_stride, 2, 2),
                               padding=(conv1_t_size // 2, 3, 3),
                               bias=False)
        self.bn1 = nn.BatchNorm3d(self.in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, block_inplanes[0], layers[0],
                                       shortcut_type)
        self.layer2 = self._make_layer(block,
                                       block_inplanes[1],
                                       layers[1],
                                       shortcut_type,
                                       stride=2)
        self.layer3 = self._make_layer(block,
                                       block_inplanes[2],
                                       layers[2],
                                       shortcut_type,
                                       stride=2)
        self.layer4 = self._make_layer(block,
                                       block_inplanes[3],
                                       layers[3],
                                       shortcut_type,
                                       stride=2)

        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(block_inplanes[3] * block.expansion, n_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _downsample_basic_block(self, x, planes, stride):
        out = F.avg_pool3d(x, kernel_size=1, stride=stride)
        zero_pads = torch.zeros(out.size(0), planes - out.size(1), out.size(2),
                                out.size(3), out.size(4))
        if isinstance(out.data, torch.cuda.FloatTensor):
            zero_pads = zero_pads.cuda()

        out = torch.cat([out.data, zero_pads], dim=1)

        return out

    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1):
        downsample = None
        if stride != 1 or self.in_planes != planes * block.expansion:
            if shortcut_type == 'A':
                downsample = partial(self._downsample_basic_block,
                                     planes=planes * block.expansion,
                                     stride=stride)
            else:
                downsample = nn.Sequential(
                    conv1x1x1(self.in_planes, planes * block.expansion, stride),
                    nn.BatchNorm3d(planes * block.expansion))

        layers = []
        layers.append(
            block(in_planes=self.in_planes,
                  planes=planes,
                  stride=stride,
                  downsample=downsample))
        self.in_planes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.in_planes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if not self.no_max_pool:
            x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)

        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def generate_model(model_depth, **kwargs):
    assert model_depth in [10, 18, 34, 50, 101, 152, 200]

    if model_depth == 10:
        model = ResNet(BasicBlock, [1, 1, 1, 1], get_inplanes(), **kwargs)
    elif model_depth == 18:
        model = ResNet(BasicBlock, [2, 2, 2, 2], get_inplanes(), **kwargs)
    elif model_depth == 34:
        model = ResNet(BasicBlock, [3, 4, 6, 3], get_inplanes(), **kwargs)
    elif model_depth == 50:
        model = ResNet(Bottleneck, [3, 4, 6, 3], get_inplanes(), **kwargs)
    elif model_depth == 101:
        model = ResNet(Bottleneck, [3, 4, 23, 3], get_inplanes(), **kwargs)
    elif model_depth == 152:
        model = ResNet(Bottleneck, [3, 8, 36, 3], get_inplanes(), **kwargs)
    elif model_depth == 200:
        model = ResNet(Bottleneck, [3, 24, 36, 3], get_inplanes(), **kwargs)

    return model

In [20]:
class AttentionHead(nn.Module):
    
    def __init__(self, in_features, hidden_dim, num_targets):
        super().__init__()
        self.in_features = in_features
        
        self.hidden_layer = nn.Linear(in_features, hidden_dim)
        self.final_layer = nn.Linear(hidden_dim, num_targets)
        self.out_features = hidden_dim
        
    def forward(self, features):
        att = torch.tanh(self.hidden_layer(features))
        score = self.final_layer(att)
        attention_weights = torch.softmax(score, dim=1)
        context_vector = torch.mean(attention_weights * features, dim=1) 
        return context_vector.unsqueeze(1)

In [21]:
def create_model():
    model = generate_model(34, n_input_channels=1, n_classes=1)
    return model.to(cfg.device)

In [22]:
resnet = create_model()

In [23]:
# sample_data = torch.randn([8, 1, 64, 256, 256]).to(cfg.device)
resnet(sample_data)

In [25]:
del sample_data
import gc
gc.collect()
torch.cuda.empty_cache()

### Trainer

In [28]:
class Trainer:
    def __init__(
        self, 
        model, 
        device, 
        optimizer, 
        criterion
    ):
        self.model = model
        self.device = device
        self.optimizer = optimizer
        self.criterion = criterion

        self.best_valid_score = np.inf
        self.n_patience = 0
        self.lastmodel = None
        
    def fit(self, epochs, train_loader, valid_loader, save_path, patience):  
        tbar = tqdm(range(1, epochs + 1), total=epochs)
#         self.lr_sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer, T_0=int(len(train_loader) / 2), T_mult=1, eta_min=0.0001, last_epoch=-1)
        self.lr_sched = torch.optim.lr_scheduler.OneCycleLR(self.optimizer, max_lr=1e-3, steps_per_epoch=len(train_loader), epochs=epochs)
        self.lrs = []
        
        for n_epoch in tbar:
            
            self.info_message("EPOCH: {}", n_epoch)
            tbar.set_description(f'EPOCH: {n_epoch}')
            
            train_loss, train_time = self.train_epoch(train_loader, valid_loader, n_epoch, save_path)
            valid_loss, valid_auc, valid_time = self.valid_epoch(valid_loader)
            
            self.info_message(
                "[Epoch Train: {}] loss: {:.4f}, time: {:.2f} s            ",
                n_epoch, train_loss, train_time
            )
            
            self.print_valid_message(valid_loss, valid_auc, valid_time, n_epoch)
            
            self.check_and_save(valid_loss, valid_auc, valid_time, n_epoch, save_path)
            
            if self.n_patience >= patience:
                self.info_message("\nValid auc didn't improve last {} epochs.", patience)
                break
        plt.plot(self.lrs)
        
    def check_and_save(self, valid_loss, valid_auc, valid_time, n_epoch, save_path):
        # if True:
        # if self.best_valid_score < valid_auc: 
        if self.best_valid_score > valid_loss and valid_auc > 0.5: 
            self.save_model(n_epoch, save_path, valid_loss, valid_auc)
            self.info_message(
                 "auc improved from {:.4f} to {:.4f}. Saved model to '{}'", 
                self.best_valid_score, valid_loss, self.lastmodel
            )
            self.best_valid_score = valid_loss
            self.n_patience = 0
        else:
            self.n_patience += 1
            
    def print_valid_message(self, valid_loss, valid_auc, valid_time, n_epoch):
        self.info_message(
            "[Epoch Valid: {}] loss: {:.4f}, auc: {:.4f}, time: {:.2f} s",
            n_epoch, valid_loss, valid_auc, valid_time
        )
            
    def train_epoch(self, train_loader, valid_loader, n_epoch, save_path):
        self.model.train()
        t = time.time()
        sum_loss = 0

        tbar = tqdm(enumerate(train_loader, 1), total=len(train_loader))
        
        scaler = torch.cuda.amp.GradScaler() # fp16
        for step, batch in tbar:
            X = batch["X"].to(self.device)
            targets = batch["y"].to(self.device)
            self.optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                outputs = self.model(X).squeeze(1)
                loss = self.criterion(outputs, targets)
                
#             loss.backward()
            scaler.scale(loss).backward()
            scaler.step(self.optimizer)
            scaler.update()

            sum_loss += loss.detach().item()
#             torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
#             self.optimizer.step()
            self.lr_sched.step()
            current_lr = self.optimizer.param_groups[0]["lr"]
            self.lrs.append(
                current_lr
            )
            
            message = 'Train Step {}/{}, train_loss: {:.4f}'
            self.info_message(message, step, len(train_loader), sum_loss/step, end="\r")
            
            if n_epoch in cfg.extra_check_epochs:
                if step % cfg.check_frequency == 0:
                    valid_loss, valid_auc, valid_time = self.valid_epoch(valid_loader)
                    self.check_and_save(valid_loss, valid_auc, valid_time, n_epoch, save_path)
                    self.print_valid_message(valid_loss, valid_auc, valid_time, n_epoch)
            
        
        return sum_loss/len(train_loader), int(time.time() - t)
    
    def valid_epoch(self, valid_loader):
        self.model.eval()
        t = time.time()
        sum_loss = 0
        y_all = []
        outputs_all = []

        tbar = tqdm(enumerate(valid_loader, 1), total=len(valid_loader))
        for step, batch in tbar:
            with torch.no_grad():
                X = batch["X"].to(self.device)
                targets = batch["y"].to(self.device)

                outputs = torch.sigmoid(self.model(X).squeeze(1))
                loss = self.criterion(outputs, targets)
                print('outputs', outputs.mean().item(), outputs.std().item())
                print('targets', targets.mean().item(), targets.std().item())
                
                sum_loss += loss.detach().item()
                y_all.extend(batch["y"].tolist())
                outputs_all.extend(outputs.tolist())

            message = 'Valid Step {}/{}, valid_loss: {:.4f}'
            self.info_message(message, step, len(valid_loader), sum_loss/step, end="\r")
            
        y_all = [1 if x > 0.5 else 0 for x in y_all]
        auc = roc_auc_score(y_all, outputs_all)
        
        return sum_loss/len(valid_loader), auc, int(time.time() - t)
    
    def save_model(self, n_epoch, save_path, loss, auc):
        self.lastmodel = str(cfg.models_path/f"{save_path}-e{n_epoch}-loss{loss:.3f}-auc{auc:.3f}.pth")
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "best_valid_score": self.best_valid_score,
                "n_epoch": n_epoch,
            },
            self.lastmodel,
        )
    
    @staticmethod
    def info_message(message, *args, end="\n"):
        print(message.format(*args), end=end)

In [29]:
def loss_func(input, target):
#     return F.binary_cross_entropy_with_logits(input, target) * 0.5 + F.mse_loss(input, target) * 0.5
#     return F.binary_cross_entropy_with_logits(input, target) * 0.2 + dice_loss(input, target) * 0.8
    return F.binary_cross_entropy_with_logits(input, target)

In [30]:
def train_mri_type(df_train, df_valid, mri_type):
    if mri_type=="all":
        train_list = []
        valid_list = []
        for mri_type in mri_types:
            df_train.loc[:,"MRI_Type"] = mri_type
            train_list.append(df_train.copy())
            df_valid.loc[:,"MRI_Type"] = mri_type
            valid_list.append(df_valid.copy())

        df_train = pd.concat(train_list)
        df_valid = pd.concat(valid_list)
    else:
        df_train.loc[:,"MRI_Type"] = mri_type
        df_valid.loc[:,"MRI_Type"] = mri_type

    print(df_train.shape, df_valid.shape)
    display(df_train.head())
    
    train_data_retriever = Dataset(
        df_train["BraTS21ID"].values, 
        df_train["MGMT_value"].values, 
        df_train["MRI_Type"].values
    )

    valid_data_retriever = Dataset(
        df_valid["BraTS21ID"].values, 
        df_valid["MGMT_value"].values,
        df_valid["MRI_Type"].values
    )

    train_loader = torch_data.DataLoader(
        train_data_retriever,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
    )

    valid_loader = torch_data.DataLoader(
        valid_data_retriever, 
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
    )

    model = create_model()

    #checkpoint = torch.load("best-model-all-auc0.555.pth")
    #model.load_state_dict(checkpoint["model_state_dict"])

    #print(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    #optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    criterion = loss_func

    trainer = Trainer(
        model, 
        cfg.device, 
        optimizer, 
        criterion
    )

    history = trainer.fit(
        cfg.epochs, 
        train_loader, 
        valid_loader, 
        f"{mri_type}", 
        cfg.epochs + 100,
    )
    
    return trainer.lastmodel

In [31]:
!rm -rf {str(cfg.models_path)}/*.pth

In [32]:
modelfiles = None

if not modelfiles:
    modelfiles = [train_mri_type(df_train, df_valid, m) for m in cfg.mri_types]
    print(modelfiles)

(523, 3) (61, 3)


Unnamed: 0,BraTS21ID,MGMT_value,MRI_Type
549,803,0,FLAIR
353,520,1,FLAIR
272,399,0,FLAIR
137,206,0,FLAIR
292,423,0,FLAIR


HBox(children=(FloatProgress(value=0.0, max=18.0), HTML(value='')))

EPOCH: 1


HBox(children=(FloatProgress(value=0.0, max=131.0), HTML(value='')))

Train Step 131/131, train_loss: 0.7111


HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

outputs 0.5619516968727112 0.00421748124063015
targets 0.5 0.5658032894134521
outputs 0.5593413710594177 0.0005276636802591383
targets 0.5 0.5658032894134521
outputs 0.6018283367156982 0.05345012992620468
targets 0.7450000047683716 0.49000000953674316
outputs 0.5599985718727112 0.0003556534938979894
targets 0.5 0.5658032894134521
outputs 0.559741199016571 0.0009763562120497227
targets 0.7450000047683716 0.49000000953674316
outputs 0.5599250793457031 0.0001746349298628047
targets 0.2549999952316284 0.49000000953674316
outputs 0.560142993927002 0.0007714470848441124
targets 0.5 0.5658032894134521
outputs 0.559507429599762 0.0007505397079512477
targets 0.7450000047683716 0.49000000953674316
outputs 0.5598040819168091 0.0010167293949052691
targets 0.5 0.5658032894134521
outputs 0.5596222877502441 0.0001441489439457655
targets 0.2549999952316284 0.49000000953674316
outputs 0.5982738733291626 0.0776875913143158
targets 0.2549999952316284 0.49000000953674316
outputs 0.5596616268157959 0.00070

HBox(children=(FloatProgress(value=0.0, max=131.0), HTML(value='')))

Train Step 131/131, train_loss: 0.7003


HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

outputs 0.9810956716537476 0.03382173553109169
targets 0.5 0.5658032894134521
outputs 0.9611643552780151 0.033620547503232956
targets 0.5 0.5658032894134521
outputs 0.9726222157478333 0.0518275648355484
targets 0.7450000047683716 0.49000000953674316
outputs 0.9617067575454712 0.04799306020140648
targets 0.5 0.5658032894134521
outputs 0.9314550161361694 0.08098644018173218
targets 0.7450000047683716 0.49000000953674316
outputs 0.9579074382781982 0.03122856467962265
targets 0.2549999952316284 0.49000000953674316
outputs 0.8811644315719604 0.19997639954090118
targets 0.5 0.5658032894134521
outputs 0.7942914366722107 0.23394252359867096
targets 0.7450000047683716 0.49000000953674316
outputs 0.8916712999343872 0.2085057497024536
targets 0.5 0.5658032894134521
outputs 0.8505990505218506 0.16245783865451813
targets 0.2549999952316284 0.49000000953674316
outputs 0.8866868615150452 0.18741418421268463
targets 0.2549999952316284 0.49000000953674316
outputs 0.9709435701370239 0.021835803985595703

HBox(children=(FloatProgress(value=0.0, max=131.0), HTML(value='')))

Train Step 131/131, train_loss: 0.6989


HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

outputs 0.6984723806381226 0.09447672963142395
targets 0.5 0.5658032894134521
outputs 0.6009507179260254 0.14203420281410217
targets 0.5 0.5658032894134521
outputs 0.6978024244308472 0.17459382116794586
targets 0.7450000047683716 0.49000000953674316
outputs 0.6642757058143616 0.03914386034011841
targets 0.5 0.5658032894134521
outputs 0.6214154958724976 0.13848333060741425
targets 0.7450000047683716 0.49000000953674316
outputs 0.6423453688621521 0.06057171896100044
targets 0.2549999952316284 0.49000000953674316
outputs 0.6996880769729614 0.13289178907871246
targets 0.5 0.5658032894134521
outputs 0.5820797681808472 0.0943070501089096
targets 0.7450000047683716 0.49000000953674316
outputs 0.6369189620018005 0.1406155228614807
targets 0.5 0.5658032894134521
outputs 0.6340330839157104 0.12808674573898315
targets 0.2549999952316284 0.49000000953674316
outputs 0.6811197996139526 0.19504819810390472
targets 0.2549999952316284 0.49000000953674316
outputs 0.6279851794242859 0.11497769504785538
t

HBox(children=(FloatProgress(value=0.0, max=131.0), HTML(value='')))

Train Step 131/131, train_loss: 0.6960


HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

outputs 0.5673908591270447 0.06521127372980118
targets 0.5 0.5658032894134521
outputs 0.563346803188324 0.08309292793273926
targets 0.5 0.5658032894134521
outputs 0.610433042049408 0.1522703468799591
targets 0.7450000047683716 0.49000000953674316
outputs 0.5332552790641785 0.012994904071092606
targets 0.5 0.5658032894134521
outputs 0.5642657279968262 0.082792267203331
targets 0.7450000047683716 0.49000000953674316
outputs 0.5284124612808228 0.005423514172434807
targets 0.2549999952316284 0.49000000953674316
outputs 0.5886675715446472 0.0743265226483345
targets 0.5 0.5658032894134521
outputs 0.534615159034729 0.02536025457084179
targets 0.7450000047683716 0.49000000953674316
outputs 0.5697848796844482 0.08784198760986328
targets 0.5 0.5658032894134521
outputs 0.572883129119873 0.09925337135791779
targets 0.2549999952316284 0.49000000953674316
outputs 0.6119500994682312 0.1582871377468109
targets 0.2549999952316284 0.49000000953674316
outputs 0.5412476062774658 0.03244483098387718
target

HBox(children=(FloatProgress(value=0.0, max=131.0), HTML(value='')))

Train Step 131/131, train_loss: 0.6945


HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

outputs 0.5468940734863281 0.009833499789237976
targets 0.5 0.5658032894134521
outputs 0.5500358939170837 0.012979181483387947
targets 0.5 0.5658032894134521
outputs 0.5641876459121704 0.0401482991874218
targets 0.7450000047683716 0.49000000953674316
outputs 0.5422877073287964 0.0008438414661213756
targets 0.5 0.5658032894134521
outputs 0.5503264665603638 0.012121804058551788
targets 0.7450000047683716 0.49000000953674316
outputs 0.5444986820220947 0.002406473970040679
targets 0.2549999952316284 0.49000000953674316
outputs 0.5526191592216492 0.01009565033018589
targets 0.5 0.5658032894134521
outputs 0.5440539717674255 0.0006733246264047921
targets 0.7450000047683716 0.49000000953674316
outputs 0.5501317977905273 0.014068924821913242
targets 0.5 0.5658032894134521
outputs 0.5511176586151123 0.01431281864643097
targets 0.2549999952316284 0.49000000953674316
outputs 0.5628861784934998 0.03901660442352295
targets 0.2549999952316284 0.49000000953674316
outputs 0.5462760925292969 0.005878181

HBox(children=(FloatProgress(value=0.0, max=131.0), HTML(value='')))

Train Step 131/131, train_loss: 0.6955


HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

outputs 0.5453540086746216 0.022168705239892006
targets 0.5 0.5658032894134521
outputs 0.553899884223938 0.026222404092550278
targets 0.5 0.5658032894134521
outputs 0.5673173666000366 0.060350220650434494
targets 0.7450000047683716 0.49000000953674316
outputs 0.5363419055938721 0.0030073223169893026
targets 0.5 0.5658032894134521
outputs 0.5532737374305725 0.026950068771839142
targets 0.7450000047683716 0.49000000953674316
outputs 0.5377075672149658 0.0032907300628721714
targets 0.2549999952316284 0.49000000953674316
outputs 0.5525811910629272 0.016880212351679802
targets 0.5 0.5658032894134521
outputs 0.5402600765228271 0.0010188436135649681
targets 0.7450000047683716 0.49000000953674316
outputs 0.5536841154098511 0.031131666153669357
targets 0.5 0.5658032894134521
outputs 0.5594863891601562 0.03914623335003853
targets 0.2549999952316284 0.49000000953674316
outputs 0.572860836982727 0.06745150685310364
targets 0.2549999952316284 0.49000000953674316
outputs 0.5390573740005493 0.0019786

HBox(children=(FloatProgress(value=0.0, max=131.0), HTML(value='')))

Train Step 131/131, train_loss: 0.6927


HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

outputs 0.5782653093338013 0.0833451896905899
targets 0.5 0.5658032894134521
outputs 0.5702702403068542 0.0840858668088913
targets 0.5 0.5658032894134521
outputs 0.6087316870689392 0.1330646127462387
targets 0.7450000047683716 0.49000000953674316
outputs 0.5357388257980347 0.0104991989210248
targets 0.5 0.5658032894134521
outputs 0.5715285539627075 0.0872674360871315
targets 0.7450000047683716 0.49000000953674316
outputs 0.5289150476455688 0.0008937720558606088
targets 0.2549999952316284 0.49000000953674316
outputs 0.585040807723999 0.06752722710371017
targets 0.5 0.5658032894134521
outputs 0.5377939939498901 0.019214056432247162
targets 0.7450000047683716 0.49000000953674316
outputs 0.5809186697006226 0.10075946897268295
targets 0.5 0.5658032894134521
outputs 0.5831549167633057 0.11086244136095047
targets 0.2549999952316284 0.49000000953674316
outputs 0.6129534840583801 0.15498292446136475
targets 0.2549999952316284 0.49000000953674316
outputs 0.5311896204948425 0.006401554681360722
t

HBox(children=(FloatProgress(value=0.0, max=131.0), HTML(value='')))

Train Step 131/131, train_loss: 0.6920


HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

outputs 0.5448494553565979 0.03520948812365532
targets 0.5 0.5658032894134521
outputs 0.549673855304718 0.04566211253404617
targets 0.5 0.5658032894134521
outputs 0.5710789561271667 0.0783141478896141
targets 0.7450000047683716 0.49000000953674316
outputs 0.5275033712387085 0.0018464099848642945
targets 0.5 0.5658032894134521
outputs 0.549824059009552 0.04642918333411217
targets 0.7450000047683716 0.49000000953674316
outputs 0.5273122191429138 0.0007129275472834706
targets 0.2549999952316284 0.49000000953674316
outputs 0.5470120906829834 0.02596450224518776
targets 0.5 0.5658032894134521
outputs 0.5282972455024719 0.0033393169287592173
targets 0.7450000047683716 0.49000000953674316
outputs 0.5508614778518677 0.048297103494405746
targets 0.5 0.5658032894134521
outputs 0.5607430934906006 0.06815623492002487
targets 0.2549999952316284 0.49000000953674316
outputs 0.5760659575462341 0.09721086174249649
targets 0.2549999952316284 0.49000000953674316
outputs 0.5270541906356812 0.0005034782807

HBox(children=(FloatProgress(value=0.0, max=131.0), HTML(value='')))

Train Step 131/131, train_loss: 0.6909


HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

outputs 0.6999915838241577 0.16750402748584747
targets 0.5 0.5658032894134521
outputstep 1/16, valid_loss: 0.8045 0.6916244029998779 0.19092561304569244
targets 0.5 0.5658032894134521
outputs 0.7348294258117676 0.20367401838302612
targets 0.7450000047683716 0.49000000953674316
outputs 0.5963379144668579 0.08279437571763992
targets 0.5 0.5658032894134521
outputs 0.6958823800086975 0.2122759073972702
targets 0.7450000047683716 0.49000000953674316
outputs 0.5513530373573303 0.04362880066037178
targets 0.2549999952316284 0.49000000953674316
outputs 0.6970205307006836 0.20434296131134033
targets 0.5 0.5658032894134521
outputs 0.6021296381950378 0.09830129891633987
targets 0.7450000047683716 0.49000000953674316
outputs 0.7159563302993774 0.19859352707862854
targets 0.5 0.5658032894134521
outputs 0.654334545135498 0.22737571597099304
targets 0.2549999952316284 0.49000000953674316
outputs 0.6884310841560364 0.21543128788471222
targets 0.2549999952316284 0.49000000953674316
outputs 0.6165517568

HBox(children=(FloatProgress(value=0.0, max=131.0), HTML(value='')))

Train Step 131/131, train_loss: 0.6931


HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

outputs 0.5370900630950928 0.03518941253423691
targets 0.5 0.5658032894134521
outputs 0.5474875569343567 0.05953552573919296
targets 0.5 0.5658032894134521
outputs 0.5653918981552124 0.08796420693397522
targets 0.7450000047683716 0.49000000953674316
outputs 0.5178359746932983 0.001887400052510202
targets 0.5 0.5658032894134521
outputs 0.5463934540748596 0.05778935179114342
targets 0.7450000047683716 0.49000000953674316
outputs 0.5173048973083496 0.0008361215004697442
targets 0.2549999952316284 0.49000000953674316
outputs 0.5517804026603699 0.04109799489378929
targets 0.5 0.5658032894134521
outputs 0.520357608795166 0.005324396770447493
targets 0.7450000047683716 0.49000000953674316
outputs 0.5508995056152344 0.06683829426765442
targets 0.5 0.5658032894134521
outputs 0.5577860474586487 0.08037297427654266
targets 0.2549999952316284 0.49000000953674316
outputs 0.5739660263061523 0.10904719680547714
targets 0.2549999952316284 0.49000000953674316
outputs 0.5199651718139648 0.00530187785625

HBox(children=(FloatProgress(value=0.0, max=131.0), HTML(value='')))

Train Step 131/131, train_loss: 0.6908


HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

outputs 0.573052167892456 0.09805992990732193
targets 0.5 0.5658032894134521
outputs 0.5791487693786621 0.11783052980899811
targets 0.5 0.5658032894134521
outputs 0.6169054508209229 0.17702311277389526
targets 0.7450000047683716 0.49000000953674316
outputs 0.5210277438163757 0.00272765732370317
targets 0.5 0.5658032894134521
outputs 0.5806353092193604 0.12194384634494781
targets 0.7450000047683716 0.49000000953674316
outputs 0.5199776291847229 0.0008268571691587567
targets 0.2549999952316284 0.49000000953674316
outputs 0.5886539220809937 0.0858873650431633
targets 0.5 0.5658032894134521
outputs 0.5236791372299194 0.007929030805826187
targets 0.7450000047683716 0.49000000953674316
outputs 0.5843257904052734 0.12201215326786041
targets 0.5 0.5658032894134521
outputs 0.6087843179702759 0.1785248965024948
targets 0.2549999952316284 0.49000000953674316
outputs 0.6216048002243042 0.1992555558681488
targets 0.2549999952316284 0.49000000953674316
outputs 0.5219696164131165 0.003103223163634538

HBox(children=(FloatProgress(value=0.0, max=131.0), HTML(value='')))

Train Step 131/131, train_loss: 0.6904


HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

outputs 0.5455284118652344 0.06384609639644623
targets 0.5 0.5658032894134521
outputs 0.5543395280838013 0.053852301090955734
targets 0.5 0.5658032894134521
outputs 0.5563238859176636 0.08312436938285828
targets 0.7450000047683716 0.49000000953674316
outputs 0.5131838321685791 0.00019712618086487055
targets 0.5 0.5658032894134521
outputs 0.547220766544342 0.04682491347193718
targets 0.7450000047683716 0.49000000953674316
outputs 0.5133170485496521 0.0002440283860778436
targets 0.2549999952316284 0.49000000953674316
outputs 0.5696406960487366 0.11274918913841248
targets 0.5 0.5658032894134521
outputs 0.5230398774147034 0.018984802067279816
targets 0.7450000047683716 0.49000000953674316
outputs 0.557722806930542 0.050998374819755554
targets 0.5 0.5658032894134521
outputs 0.5924950242042542 0.1577538251876831
targets 0.2549999952316284 0.49000000953674316
outputs 0.5753945112228394 0.11844511330127716
targets 0.2549999952316284 0.49000000953674316
outputs 0.5318000316619873 0.029330112040

HBox(children=(FloatProgress(value=0.0, max=131.0), HTML(value='')))

Train Step 131/131, train_loss: 0.6909


HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

outputs 0.5346370935440063 0.044946227222681046
targets 0.5 0.5658032894134521
outputs 0.5454232096672058 0.057488128542900085
targets 0.5 0.5658032894134521
outputs 0.5691969394683838 0.10264257341623306
targets 0.7450000047683716 0.49000000953674316
outputs 0.5115153193473816 0.0005167420022189617
targets 0.5 0.5658032894134521
outputs 0.5446833968162537 0.054588593542575836
targets 0.7450000047683716 0.49000000953674316
outputs 0.5117863416671753 0.00023105350555852056
targets 0.2549999952316284 0.49000000953674316
outputs 0.5455410480499268 0.06510120630264282
targets 0.5 0.5658032894134521
outputs 0.515966534614563 0.006406803149729967
targets 0.7450000047683716 0.49000000953674316
outputs 0.5460360050201416 0.046217530965805054
targets 0.5 0.5658032894134521
outputs 0.5783478021621704 0.1326047033071518
targets 0.2549999952316284 0.49000000953674316
outputs 0.5740985870361328 0.12375087291002274
targets 0.2549999952316284 0.49000000953674316
outputs 0.5204766392707825 0.014508069

HBox(children=(FloatProgress(value=0.0, max=131.0), HTML(value='')))

Train Step 4/131, train_loss: 0.6730

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

outputs 0.5365814566612244 0.04946257174015045
targets 0.5 0.5658032894134521
outputs 0.5470962524414062 0.05903320759534836
targets 0.5 0.5658032894134521
outputs 0.5710278153419495 0.10642357915639877
targets 0.7450000047683716 0.49000000953674316
outputs 0.5107925534248352 0.0006290273158811033
targets 0.5 0.5658032894134521
outputs 0.5464175939559937 0.05672493204474449
targets 0.7450000047683716 0.49000000953674316
outputs 0.5109689235687256 0.00035920782829634845
targets 0.2549999952316284 0.49000000953674316
outputs 0.5469521880149841 0.06887102872133255
targets 0.5 0.5658032894134521
outputs 0.5163810849189758 0.00887678749859333
targets 0.7450000047683716 0.49000000953674316
outputs 0.5478881597518921 0.04806486889719963
targets 0.5 0.5658032894134521
outputs 0.5797995924949646 0.13752619922161102
targets 0.2549999952316284 0.49000000953674316
outputs 0.5753265619277954 0.1276775598526001
targets 0.2549999952316284 0.49000000953674316
outputs 0.5219423770904541 0.0177046004682

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

outputs 0.5352742075920105 0.04626751318573952
targets 0.5 0.5658032894134521
outputs 0.5484423041343689 0.05457363650202751
targets 0.5 0.5658032894134521
outputs 0.5684041976928711 0.09939759224653244
targets 0.7450000047683716 0.49000000953674316
outputs 0.5111114382743835 0.0006313332123681903
targets 0.5 0.5658032894134521
outputs 0.5469105839729309 0.05248023197054863
targets 0.7450000047683716 0.49000000953674316
outputs 0.5111649036407471 0.0004915338358841836
targets 0.2549999952316284 0.49000000953674316
outputs 0.5477614402770996 0.07131069153547287
targets 0.5 0.5658032894134521
outputs 0.5188533067703247 0.013147802092134953
targets 0.7450000047683716 0.49000000953674316
outputs 0.5481848120689392 0.044978998601436615
targets 0.5 0.5658032894134521
outputs 0.5786195993423462 0.13488684594631195
targets 0.2549999952316284 0.49000000953674316
outputs 0.572895884513855 0.12012758105993271
targets 0.2549999952316284 0.49000000953674316
outputs 0.5255680084228516 0.022231630980

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

outputs 0.5773868560791016 0.11384860426187515
targets 0.5 0.5658032894134521
outputs 0.6201825737953186 0.12059280276298523
targets 0.5 0.5658032894134521
outputs 0.6348371505737305 0.2010057270526886
targets 0.7450000047683716 0.49000000953674316
outputs 0.5152238607406616 0.004959757439792156
targets 0.5 0.5658032894134521
outputs 0.615198016166687 0.1306431144475937
targets 0.7450000047683716 0.49000000953674316
outputs 0.5139803290367126 0.004308292176574469
targets 0.2549999952316284 0.49000000953674316
outputs 0.601093590259552 0.16460341215133667
targets 0.5 0.5658032894134521
outputs 0.5446099042892456 0.05855034664273262
targets 0.7450000047683716 0.49000000953674316
outputs 0.6125423908233643 0.11616629362106323
targets 0.5 0.5658032894134521
outputs 0.6235411167144775 0.22485627233982086
targets 0.2549999952316284 0.49000000953674316
outputs 0.6318617463111877 0.20987990498542786
targets 0.2549999952316284 0.49000000953674316
outputs 0.5652605891227722 0.0741700828075409
ta

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

outputs 0.9743784666061401 0.0500052385032177
targets 0.5 0.5658032894134521
outputs 0.9982879161834717 0.0032708137296140194
targets 0.5 0.5658032894134521
outputs 0.9904892444610596 0.016649959608912468
targets 0.7450000047683716 0.49000000953674316
outputs 0.8159905076026917 0.21714235842227936
targets 0.5 0.5658032894134521
outputs 0.9777633547782898 0.039760638028383255
targets 0.7450000047683716 0.49000000953674316
outputs 0.7718452215194702 0.18658216297626495
targets 0.2549999952316284 0.49000000953674316
outputs 0.7939730882644653 0.24064551293849945
targets 0.5 0.5658032894134521
outputs 0.7990971207618713 0.2399810403585434
targets 0.7450000047683716 0.49000000953674316
outputs 0.8760709762573242 0.24290238320827484
targets 0.5 0.5658032894134521
outputs 0.7595059871673584 0.2747548818588257
targets 0.2549999952316284 0.49000000953674316
outputs 0.8885599374771118 0.22221386432647705
targets 0.2549999952316284 0.49000000953674316
outputs 0.9023914337158203 0.1912758052349090

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

outputs 0.9998783469200134 0.00024330615997314453
targets 0.5 0.5658032894134521
outputs 0.9999997615814209 4.172325134277344e-07
targets 0.5 0.5658032894134521
outputs 0.9999642372131348 7.146596908569336e-05
targets 0.7450000047683716 0.49000000953674316
outputs 0.9161070585250854 0.13583067059516907
targets 0.5 0.5658032894134521
outputs 0.999893069267273 0.00021006491442676634
targets 0.7450000047683716 0.49000000953674316
outputs 0.9100407958030701 0.15735051035881042
targets 0.2549999952316284 0.49000000953674316
outputs 0.8832482099533081 0.17295342683792114
targets 0.5 0.5658032894134521
outputs 0.8689742684364319 0.19629482924938202
targets 0.7450000047683716 0.49000000953674316
outputs 0.8779575228691101 0.24408498406410217
targets 0.5 0.5658032894134521
outputs 0.8025091886520386 0.22822020947933197
targets 0.2549999952316284 0.49000000953674316
outputs 0.9153324365615845 0.16933518648147583
targets 0.2549999952316284 0.49000000953674316
outputs 0.9598536491394043 0.08029270

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

outputs 0.9998783469200134 0.00024330615997314453
targets 0.5 0.5658032894134521
outputs 0.9999997615814209 4.172325134277344e-07
targets 0.5 0.5658032894134521
outputs 0.9999642372131348 7.146596908569336e-05
targets 0.7450000047683716 0.49000000953674316
outputs 0.9161070585250854 0.13583067059516907
targets 0.5 0.5658032894134521
outputs 0.999893069267273 0.00021006491442676634
targets 0.7450000047683716 0.49000000953674316
outputs 0.9100407958030701 0.15735051035881042
targets 0.2549999952316284 0.49000000953674316
outputs 0.8832482099533081 0.17295342683792114
targets 0.5 0.5658032894134521
outputs 0.8689742684364319 0.19629482924938202
targets 0.7450000047683716 0.49000000953674316
outputs 0.8779575228691101 0.24408498406410217
targets 0.5 0.5658032894134521
outputs 0.8025091886520386 0.22822020947933197
targets 0.2549999952316284 0.49000000953674316
outputs 0.9153324365615845 0.16933518648147583
targets 0.2549999952316284 0.49000000953674316
outputs 0.9598536491394043 0.08029270

KeyboardInterrupt: 

## Predict Function

In [None]:
def predict(modelfile, df, mri_type, split):
    print("Predict:", modelfile, mri_type, df.shape)
    df.loc[:,"MRI_Type"] = mri_type
    data_retriever = Dataset(
        df.index.values, 
        mri_type=df["MRI_Type"].values,
        split=split
    )

    data_loader = torch_data.DataLoader(
        data_retriever,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=8,
    )
   
    model = create_model()
    
    checkpoint = torch.load(modelfile)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    
    y_pred = []
    ids = []

    for e, batch in enumerate(data_loader,1):
        print(f"{e}/{len(data_loader)}", end="\r")
        with torch.no_grad():
            tmp_pred = torch.sigmoid(model(batch["X"].to(device)).cpu().numpy().squeeze())
            print('tmp_pred', tmp_pred)
            if tmp_pred.size == 1:
                y_pred.append(tmp_pred)
            else:
                y_pred.extend(tmp_pred.tolist())
            ids.extend(batch["id"].numpy().tolist())
            
    preddf = pd.DataFrame({"BraTS21ID": ids, "MGMT_value": y_pred}) 
    preddf = preddf.set_index("BraTS21ID")
    return preddf

### Ensemble for Validation

In [None]:
df_valid = df_valid.set_index("BraTS21ID")

In [None]:
def normalize_results(preds, train):
    return (preds - (preds.mean() - train.mean()))

In [None]:
df_valid["MGMT_pred"] = 0
for m, mtype in zip(modelfiles,  mri_types):
    pred = predict(m, df_valid, mtype, "train")
    df_valid["MGMT_pred"] += pred["MGMT_value"]
df_valid["MGMT_pred"] /= len(modelfiles)
auc = roc_auc_score(df_valid["MGMT_value"], df_valid["MGMT_pred"])
print(f"Validation ensemble AUC: {auc:.4f}")
sns.displot(df_valid["MGMT_pred"])