In [None]:
# Execution-environment switch: True when running inside a Kaggle kernel,
# False for a local run. It selects input paths, model paths and the
# number of DataLoader workers further down.
# is_on_kaggle = True
is_on_kaggle = False
# Debug switch: when True, the slice strides in cfg are overridden below with
# a very large value so that very few slices per series are processed.
is_debugging = True
# is_debugging = False

## Import ##

In [None]:
# import packages
import os
import multiprocessing
from pathlib import Path
import random
from collections import defaultdict
from glob import glob
import pickle
from joblib import Parallel, delayed
import gc
from tqdm.notebook import tqdm
from tabulate import tabulate
import yaml
import datetime
from logging import getLogger
import wandb

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold, GroupKFold, StratifiedGroupKFold
from sklearn.metrics import accuracy_score, roc_auc_score

import cv2
import pydicom
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim import Adam, AdamW
from torchvision import models
# from torchvision.transforms import (Compose, Resize, ToTensor, Normalize, RandomHorizontalFlip,)
import torchvision.transforms.v2 as t
from torchvision.transforms.v2 import (Resize, Compose, RandomHorizontalFlip, 
                                       ColorJitter, RandomAffine, RandomErasing, ToTensor)
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
import timm
import albumentations as A


## Constants ##

In [None]:
# Central configuration for the inference run.
cfg = {
    'general': {
        'project_name': '202309-rsna-atd',
        'inference_version': 'v2-8',  # part of the Kaggle model-dataset path
        'seed': 42,  # passed to seed_everything / random_seed
    },
    'data': {
        'n_folds': 5, 
        'fold_i': 0,  # fold index embedded in the model file names
        'batch_size_inference': 16,  # DataLoader batch size
        # *_slice_start: fraction of the sorted series at which the slice window begins
        'kls_slice_start': 0.6, 
        'b_e_1_slice_start': 0.0, 
        'b_e_2_slice_start': 0.0, 
        # *_stride: keep every N-th slice inside the window
        'kls_stride': 4, 
        'b_e_1_stride': 16, 
        'b_e_2_stride': 48, 
        'calc_cv_score': False, 
        'apply_aug': True,  # forwarded to the test datasets as apply_aug
    }, 
    'model': {
        'kls_model_name': 'tf_efficientnetv2_s',
        'b_e_model_name_1': 'tf_efficientnetv2_s',
        'b_e_model_name_2': 'maxvit_tiny_tf_384.in1k',
        'pretrained': False, # for inference
        'in_chans': 1, 
        'num_classes': 0, # to use as backbone
        'global_pool': 'max',
        'drop_rate': 0.8, 
        'drop_path_rate': 0.2, 
        # class weights handed to nn.CrossEntropyLoss in the model classes
        'kls_weights': [1.0, 8.0, 16.0],  # healthy, low, high
        'b_weights': [1.0, 8.0],  # healthy, injury
        'e_weights': [1.0, 24.0],  # healthy, injury
        'hidden_dim': 128,  # output width of the linear "neck"
        'p_dropout': 0.3,  # dropout probability in the neck
        'lr': 1.0e-4,  # AdamW learning rate
    },
    'pl_params': {
        'accelerator': 'auto',
        'precision': 16,  # 16 or 32
        'enable_progress_bar': True, 
    }
}

# Debug mode: a stride of 1000 keeps only every 1000th slice of each window.
if is_debugging:
    cfg['data']['kls_stride'] = 1000
    cfg['data']['b_e_1_stride'] = 1000
    cfg['data']['b_e_2_stride'] = 1000



## Paths ##

In [None]:
# Competition input root (Kaggle) or the current directory (local run).
BASE_PATH = '/kaggle/input/rsna-2023-abdominal-trauma-detection' if is_on_kaggle else '.'

# Roots that differ between the two environments; the file names are shared.
_fold_i = cfg['data']['fold_i']
_data_root = '.' if is_on_kaggle else '..'
_model_root = (
    f"/kaggle/input/models-{cfg['general']['inference_version']}"
    if is_on_kaggle
    else '../models'
)

# Output folders for the preprocessed PNG slices of each model.
KLS_TEST_DATA_DIR = f'{_data_root}/kls_test_data_v2_8'
B_E_1_TEST_DATA_DIR = f'{_data_root}/b_e_1_test_data_v2_8'
B_E_2_TEST_DATA_DIR = f'{_data_root}/b_e_2_test_data_v2_8'

# Serialized models for the configured fold.
KLS_MODEL_PATH = f"{_model_root}/{cfg['model']['kls_model_name']}_exp005_kls_fold{_fold_i}_aug.pt"
B_E_MODEL_1_PATH = f"{_model_root}/{cfg['model']['b_e_model_name_1']}_exp010_b_e_fold{_fold_i}.pt"
B_E_MODEL_2_PATH = f"{_model_root}/{cfg['model']['b_e_model_name_2']}_exp011_b_e_fold{_fold_i}.pt"

# The raw test inputs are only available (and only read) on Kaggle.
if is_on_kaggle:
    TEST_DATA_DIR = f'{BASE_PATH}/test_images'
    df_dicom_test = pd.read_parquet(f'{BASE_PATH}/test_dicom_tags.parquet')
    sample_submission = pd.read_csv(f'{BASE_PATH}/sample_submission.csv')

for _out_dir in (KLS_TEST_DATA_DIR, B_E_1_TEST_DATA_DIR, B_E_2_TEST_DATA_DIR):
    os.makedirs(_out_dir, exist_ok=True)



In [None]:
# Global runtime setup: free cached GPU memory, seed the libraries and
# decide how many DataLoader workers to use.
torch.cuda.empty_cache()
# multiprocessing.set_start_method('spawn', force=True)
seed_everything(cfg['general']['seed'], workers=True)
# All CPU cores on Kaggle; 0 locally (data is then loaded in the main process).
num_workers = os.cpu_count() if is_on_kaggle else 0
gpu_count = torch.cuda.device_count()
print('num_workers:', num_workers)
print('gpu_count:', gpu_count)

def random_seed(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and current CUDA device) with `seed`.

    Also exports the seed as the PYTHONHASHSEED environment variable.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

# Re-seed random / numpy / torch with the configured seed.
random_seed(cfg['general']['seed'])

## Preparing test data ##

In [None]:
# ## check to solve scoring error
# if is_on_kaggle:
#     TRAIN_DATA_DIR = f'{BASE_PATH}/train_images'
#     TRAIN_LABEL = f'{BASE_PATH}/train.csv'
#     train_series_meta = pd.read_csv(f'{BASE_PATH}/train_series_meta.csv')
#     train_df = pd.read_csv(TRAIN_LABEL)

#     # use train data instead of test data
#     TEST_DATA_DIR = TRAIN_DATA_DIR
#     sample_submission = train_df
# #     unique_patients = train_series_meta.patient_id.unique()
# #     sample_submission = pd.DataFrame(unique_patients, columns=['patient_id'])
    

In [None]:
def standardize_pixel_array(dcm: pydicom.dataset.FileDataset) -> np.ndarray:
    """
    Return the pixel data of a DICOM slice, rescaled and clipped to its window.

    Source : https://www.kaggle.com/competitions/rsna-2023-abdominal-trauma-detection/discussion/427217

    Steps:
      1. If PixelRepresentation == 1, shift left then right by
         (BitsAllocated - BitsStored) to correct the stored values.
      2. Apply the modality rescale: ``pixel * RescaleSlope + RescaleIntercept``.
      3. Clip to ``[center - width / 2, center + width / 2]`` using
         WindowCenter / WindowWidth.

    WindowCenter and WindowWidth may be multi-valued (several window presets);
    in that case the first preset is used instead of failing in ``int()``.

    Args:
        dcm: a dataset as returned by ``pydicom.dcmread``.

    Returns:
        Array of rescaled values clipped to the window range.
    """
    def _first_value(value):
        # Multi-valued tags come back as a pydicom MultiValue sequence.
        if isinstance(value, pydicom.multival.MultiValue):
            return value[0]
        return value

    # Correct DICOM pixel_array if PixelRepresentation == 1.
    pixel_array = dcm.pixel_array
    if dcm.PixelRepresentation == 1:
        bit_shift = dcm.BitsAllocated - dcm.BitsStored
        dtype = pixel_array.dtype
        pixel_array = (pixel_array << bit_shift).astype(dtype) >> bit_shift

    intercept = float(dcm.RescaleIntercept)
    slope = float(dcm.RescaleSlope)
    center = int(_first_value(dcm.WindowCenter))
    width = int(_first_value(dcm.WindowWidth))
    low = center - width / 2
    high = center + width / 2

    pixel_array = (pixel_array * slope) + intercept
    pixel_array = np.clip(pixel_array, low, high)

    return pixel_array


In [None]:
def preprocess(patient, series, slice_start, slice_window, stride=10, size=256, flag='', test_data_dir='', save_folder=''):
    """Convert a window of DICOM slices of one series into 8-bit PNG files.

    The ``*.dcm`` files under ``test_data_dir/patient/series`` are sorted by the
    integer in their file name. The slices in
    ``[len * slice_start, len * (slice_start + slice_window))`` are selected and
    every ``stride``-th one is read, windowed with ``standardize_pixel_array``,
    min-max scaled to [0, 1] and inverted when the photometric interpretation
    is MONOCHROME1. The images are then written in ascending order of the last
    component of tag (0020,0032).

    Args:
        patient: patient directory name.
        series: series directory name.
        slice_start: start of the slice window as a fraction of the series.
        slice_window: length of the slice window as a fraction of the series.
        stride: step between selected slices inside the window.
        size: output side length in pixels; ``None`` keeps the original size.
        flag: ``'b_e_2'`` shifts the window forward by
            ``min(8, len / 2)`` slices; any other value leaves it unchanged.
        test_data_dir: root directory that contains the patient directories.
        save_folder: output directory (``str``), or an object with a
            ``writestr(name, data)`` method (e.g. an open zip archive).

    Note:
        Images are keyed by z-position, so two selected slices that share the
        same z-position overwrite each other and only the last one is written.
    """
    def _instance_number(path):
        # File names look like "<number>.dcm"; sort numerically, not lexically.
        return int(os.path.basename(path).split('.')[0])

    sorted_img_paths = sorted(
        glob(os.path.join(test_data_dir, patient, series, "*.dcm")),
        key=_instance_number,
    )
    n_slices = len(sorted_img_paths)
    start_index = int(n_slices * slice_start)
    end_index = int(n_slices * (slice_start + slice_window))
    if flag == 'b_e_2':
        # slide = int(min(stride / 2, len(sorted_img_paths) / 2))
        slide = int(min(8, n_slices / 2))
        start_index += slide
        end_index += slide
    if start_index == end_index:
        # Very short series: make sure the window is at least one slice wide.
        end_index += 1

    imgs = {}
    for f in sorted_img_paths[start_index:end_index][::stride]:
        # skip this corrupted file, test_images/3124/5842/514.dcm
        # (compare path components with the platform separator, not a literal '/')
        if os.path.normpath(f).split(os.sep)[-3:] == ['3124', '5842', '514.dcm']:
            continue
        dicom = pydicom.dcmread(f)
        pos_z = dicom[(0x20, 0x32)].value[-1]
        img = standardize_pixel_array(dicom)
        # 1e-6 guards against division by zero on constant images.
        img = (img - img.min()) / (img.max() - img.min() + 1e-6)
        if dicom.PhotometricInterpretation == 'MONOCHROME1':
            img = 1 - img
        imgs[pos_z] = img

    for i, k in enumerate(sorted(imgs.keys())):
        img = imgs[k]
        if size is not None:
            img = cv2.resize(img, (size, size))
        img_u8 = (img * 255).astype(np.uint8)
        if isinstance(save_folder, str):
            cv2.imwrite(os.path.join(save_folder, f'{patient}_{series}_{i}.png'), img_u8)
        else:
            im = cv2.imencode('.png', img_u8)[1]
            save_folder.writestr(f'{patient}_{series}_{i:04d}.png', im)



In [None]:
# kls
# preprocess all test images in parallel (joblib, threading backend)
if is_on_kaggle:
    # One (patient, series) task per series directory; reused by the b_e cells.
    tasks = [
        (patient, series)
        for patient in os.listdir(TEST_DATA_DIR)
        for series in sorted(os.listdir(os.path.join(TEST_DATA_DIR, patient)))
    ]

    slice_start = cfg['data']['kls_slice_start']
    slice_window = 0.2
    stride = cfg['data']['kls_stride']
    print('len(tasks):', len(tasks))
    print('kls_slice_start:', slice_start)
    print('kls_stride:', stride)

    run_in_threads = Parallel(n_jobs=-1, backend='threading')
    _ = run_in_threads(
        delayed(preprocess)(
            patient, series, slice_start, slice_window, stride,
            test_data_dir=TEST_DATA_DIR, save_folder=KLS_TEST_DATA_DIR,
        )
        for patient, series in tqdm(tasks)
    )

    # The returned list holds only None values; drop it and reclaim memory.
    del _
    gc.collect()

In [None]:
# b_e 1
# preprocess all test images in parallel (joblib, threading backend)
if is_on_kaggle:
    slice_start = cfg['data']['b_e_1_slice_start']
    slice_window = 0.8
    stride = cfg['data']['b_e_1_stride']
    print('len(tasks):', len(tasks))
    print('b_e_slice_start:', slice_start)
    print('b_e_stride:', stride)

    run_in_threads = Parallel(n_jobs=-1, backend='threading')
    _ = run_in_threads(
        delayed(preprocess)(
            patient, series, slice_start, slice_window, stride,
            test_data_dir=TEST_DATA_DIR, save_folder=B_E_1_TEST_DATA_DIR,
        )
        for patient, series in tqdm(tasks)
    )

    # The returned list holds only None values; drop it and reclaim memory.
    del _
    gc.collect()

In [None]:
# b_e 2
# preprocess all test images in parallel (joblib, threading backend)
if is_on_kaggle:
    # The 'b_e_2' flag makes preprocess() shift the slice window forward.
    flag = 'b_e_2'
    slice_start = cfg['data']['b_e_2_slice_start']
    slice_window = 0.8
    stride = cfg['data']['b_e_2_stride']
    print('len(tasks):', len(tasks))
    print('b_e_slice_start:', slice_start)
    print('b_e_stride:', stride)

    run_in_threads = Parallel(n_jobs=-1, backend='threading')
    _ = run_in_threads(
        delayed(preprocess)(
            patient, series, slice_start, slice_window, stride, flag=flag,
            test_data_dir=TEST_DATA_DIR, save_folder=B_E_2_TEST_DATA_DIR,
        )
        for patient, series in tqdm(tasks)
    )

    # The returned list holds only None values; drop it and reclaim memory.
    del _
    gc.collect()

In [None]:
# dataset
class _AbdominalDataTestBase(Dataset):
    """Shared implementation of the PNG-slice test datasets.

    Every item is a dict with
      * ``'image'``: float tensor of shape (1, H, W) scaled by 1/255, where
        H = W = 256, or 384 for the ``'maxvit_tiny_tf_384.in1k'`` model;
      * ``'patient_id'``: the integer prefix of the file name.

    File names are expected to look like ``<patient>_<series>_<index>.png``.

    When ``apply_aug`` is true, random albumentations transforms (each with
    p=0.5) are applied to every image, i.e. repeated reads of the same index
    can return different tensors.

    Subclasses only provide the augmentation strengths listed below.
    """

    # Augmentation strengths; defined by the concrete subclasses.
    AFFINE_TRANSLATE_PERCENT = None
    AFFINE_ROTATE = None
    GAMMA_LIMIT = None
    BRIGHTNESS_CONTRAST_LIMIT = None

    def __init__(self, cfg, model_name, test_img_dir, apply_aug=True):
        """
        Args:
            cfg: configuration dict (stored, not read by this class).
            model_name: timm model name; selects the optional 384x384 resize.
            test_img_dir: directory containing the preprocessed PNG slices.
            apply_aug: whether to apply the random augmentations.
        """
        super().__init__()
        self.cfg = cfg
        self.model_name = model_name
        self.augmentation = apply_aug

        self.test_img_paths = self._fetch_test_img_paths(test_img_dir)

        # Applied to an already-built tensor in __getitem__.
        self.normalize = Compose([
            ToTensor(),
        ])

        # augmentation
        # flip
        self.aug_h_flip = A.HorizontalFlip(p=0.5)
        self.aug_v_flip = A.VerticalFlip(p=0.5)
        # elastic and grid
        self.aug_distortion = A.GridDistortion(p=0.5)
        # NOTE: aug_elastic is created but never applied in __getitem__.
        self.aug_elastic = A.ElasticTransform(p=0.5)
        # affine
        self.aug_affine = A.Affine(
            scale=(0.8, 1.2),
            translate_percent=self.AFFINE_TRANSLATE_PERCENT,
            rotate=self.AFFINE_ROTATE,
            shear=(-15, 15),
            p=0.5)
        # clahe
        self.aug_clahe = A.CLAHE(p=0.5)
        # bright
        self.aug_bright = A.OneOf([
            A.RandomGamma(gamma_limit=self.GAMMA_LIMIT, p=0.5),
            A.RandomBrightnessContrast(
                brightness_limit=self.BRIGHTNESS_CONTRAST_LIMIT,
                contrast_limit=self.BRIGHTNESS_CONTRAST_LIMIT,
                p=0.5)
        ], p=0.5)
        # cutout
        self.aug_cutout = A.CoarseDropout(max_height=8, max_width=8, p=0.5)
        # randomcrop
        self.aug_randomcrop = A.RandomResizedCrop(
            height=256,
            width=256,
            scale=(0.8, 1.0),
            ratio=(3/4, 4/3),
            p=0.5)

    def __len__(self):
        """Number of PNG slices found in the test image directory."""
        return len(self.test_img_paths)

    def __getitem__(self, idx):
        """Load, optionally augment and tensorize the slice at position ``idx``."""
        sample_img_path = self.test_img_paths[idx]
        patient_id = int(os.path.basename(sample_img_path).split('_')[0])

        # preprocess image
        img = self._process_img(sample_img_path)
        # img.shape: (256, 256)

        # augmentation (the order is significant and kept fixed)
        if self.augmentation:
            for aug in (
                self.aug_h_flip,
                self.aug_v_flip,
                self.aug_distortion,
                self.aug_clahe,
                self.aug_affine,
                self.aug_bright,
                self.aug_cutout,
                self.aug_randomcrop,
            ):
                img = aug(image=img)["image"]

        img = img.astype('float32') / 255
        # img.shape: (256, 256)

        img = torch.tensor(img, dtype=torch.float).unsqueeze(dim=0)
        # img.shape: (1, 256, 256)
        if self.model_name == 'maxvit_tiny_tf_384.in1k':
            img = Compose([Resize((384, 384), antialias=True)])(img)
        img = self.normalize(img)
        if is_on_kaggle and self.model_name == 'maxvit_rmlp_pico_rw_256.sw_in1k':
            # convert torch.FloatTensor into torch.cuda.FloatTensor
            img = img.cuda()

        return {
            'image': img,
            'patient_id': patient_id,
        }

    def _fetch_test_img_paths(self, img_dir):
        """Return all image paths in ``img_dir`` grouped by patient and series.

        Within each series the paths are ordered by the integer slice index at
        the end of the file name. Patients and series appear in the order in
        which ``os.listdir`` first yields one of their files.
        """
        patients_to_series_to_img_paths = defaultdict(lambda: defaultdict(list))
        for filename in os.listdir(img_dir):
            patient_id, series_id, _ = filename.split('_')
            patients_to_series_to_img_paths[patient_id][series_id].append(os.path.join(img_dir, filename))

        paths = []
        for series_to_img_paths in patients_to_series_to_img_paths.values():
            for imgs in series_to_img_paths.values():
                # sort by instance number
                paths.extend(sorted(imgs, key=lambda x: int(x.split('_')[-1].split('.')[0])))

        return paths

    def _process_img(self, img_path):
        """Read a PNG and return it as a 256x256 single-channel uint8 array."""
        image = cv2.imread(img_path)
        # image = image.astype('float32') / 255
        # NOTE(review): cv2.imread already yields 8-bit values, so multiplying
        # by 255 and casting back to uint8 goes out of range for most pixels.
        # Kept unchanged because altering it would change the model inputs;
        # presumably the models were trained with the same step - confirm.
        image = (image.astype('float32') * 255).astype('uint8')
        greyscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        greyscale = cv2.resize(greyscale, (256, 256))
        return greyscale


class AbdominalKLSDataTest(_AbdominalDataTestBase):
    """Test dataset for the kidney / liver / spleen model."""

    AFFINE_TRANSLATE_PERCENT = (0.0, 0.2)
    AFFINE_ROTATE = (-45, 45)
    GAMMA_LIMIT = (50, 150)
    BRIGHTNESS_CONTRAST_LIMIT = 0.5


class AbdominalBEDataTest(_AbdominalDataTestBase):
    """Test dataset for the bowel / extravasation models (milder augmentation)."""

    AFFINE_TRANSLATE_PERCENT = (0.0, 0.1)
    AFFINE_ROTATE = (-35, 35)
    GAMMA_LIMIT = (60, 140)
    BRIGHTNESS_CONTRAST_LIMIT = 0.4

    

In [None]:
# Report the inputs, then build one test dataset per model.
for _label, _value in (
    ('KLS_TEST_DATA_DIR:', KLS_TEST_DATA_DIR),
    ('B_E_1_TEST_DATA_DIR:', B_E_1_TEST_DATA_DIR),
    ('B_E_2_TEST_DATA_DIR:', B_E_2_TEST_DATA_DIR),
    ('apply_aug:', cfg['data']['apply_aug']),
):
    print(_label, _value)

kls_test_data = AbdominalKLSDataTest(
    cfg, cfg['model']['kls_model_name'], KLS_TEST_DATA_DIR,
    apply_aug=cfg['data']['apply_aug'])
b_e_1_test_data = AbdominalBEDataTest(
    cfg, cfg['model']['b_e_model_name_1'], B_E_1_TEST_DATA_DIR,
    apply_aug=cfg['data']['apply_aug'])
b_e_2_test_data = AbdominalBEDataTest(
    cfg, cfg['model']['b_e_model_name_2'], B_E_2_TEST_DATA_DIR,
    apply_aug=cfg['data']['apply_aug'])


In [None]:
# Number of preprocessed slices available to each of the three models.
len(kls_test_data), len(b_e_1_test_data), len(b_e_2_test_data)

## Model Architecture ##

In [None]:
# Model Architecture
class KLSNet(pl.LightningModule):
    """Kidney / liver / spleen classifier: timm backbone, linear neck, three heads.

    Each head outputs 3 logits (healthy, low, high). The training loss is the
    sum of the three class-weighted, label-smoothed cross-entropy losses.

    NOTE(review): ``__init__`` reads ``cfg['model']['model_name']``. The cfg
    dict defined in this notebook has no such key, so instantiating the class
    here would raise KeyError; the notebook instead loads already-serialized
    model objects with ``torch.load``, which needs this class definition.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.backbone = timm.create_model(
            model_name=cfg['model']['model_name'],
            pretrained=cfg['model']['pretrained'],
            in_chans=cfg['model']['in_chans'],
            num_classes=cfg['model']['num_classes'],
            global_pool=cfg['model']['global_pool'],
            drop_rate=cfg["model"]["drop_rate"],
            drop_path_rate=cfg["model"]["drop_path_rate"],
        )
        # for param in self.backbone.parameters():
        #     param.requires_grad = False

        self.in_features = self.backbone.num_features  # 1280
        hidden_dim = cfg['model']['hidden_dim']
        self.neck = nn.Sequential(
            nn.Linear(self.in_features, hidden_dim),
            nn.ReLU(),
            nn.Dropout(cfg['model']['p_dropout']),
        )

        self.kidney = nn.Linear(hidden_dim, 3)
        self.liver = nn.Linear(hidden_dim, 3)
        self.spleen = nn.Linear(hidden_dim, 3)

        self.cce = nn.CrossEntropyLoss(label_smoothing=0.05, weight=torch.tensor(cfg['model']['kls_weights']))

        # Per-epoch accumulators, cleared at the end of every epoch.
        self.train_epoch_loss = []
        self.val_epoch_loss = []
        self.probs = defaultdict(list)
        self.targets = defaultdict(list)
        self.auc_scores = dict()

    def forward(self, x):
        """Return the (kidney, liver, spleen) logits for an image batch."""
        # extract features
        x = self.backbone(x)
        x = self.neck(x)

        # output logits
        kidney = self.kidney(x)
        liver = self.liver(x)
        spleen = self.spleen(x)

        return kidney, liver, spleen

    def _shared_step(self, batch):
        """Forward one batch; return (summed loss, logits tuple, targets tuple)."""
        kidney = batch['kidney']
        liver = batch['liver']
        spleen = batch['spleen']

        k, l, s = self.forward(batch['image'])
        loss = self.cce(k, kidney) + self.cce(l, liver) + self.cce(s, spleen)
        return loss, (k, l, s), (kidney, liver, spleen)

    def training_step(self, batch, batch_idx):
        """Compute, record and log the summed loss of one training batch."""
        loss, _, _ = self._shared_step(batch)
        self.train_epoch_loss.append(loss.item())

        self.log('train_loss', loss, prog_bar=True, logger=True, on_epoch=True, on_step=True, sync_dist=True)
        return loss

    def on_train_epoch_end(self):
        """Drop the per-step losses so the list does not grow across epochs."""
        self.train_epoch_loss.clear()

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and collect probabilities and targets."""
        loss, logits, labels = self._shared_step(batch)
        self.val_epoch_loss.append(loss.item())

        for key, head_logits, head_labels in zip(('k', 'l', 's'), logits, labels):
            self.probs[key].extend(F.softmax(head_logits, dim=1).detach().cpu().numpy())
            self.targets[key].extend(head_labels.detach().cpu().numpy())

        self.log('val_loss', loss, prog_bar=True, logger=True, on_epoch=True, on_step=True, sync_dist=True)
        return loss

    def on_validation_epoch_end(self):
        """Log the one-vs-one ROC AUC per organ, then reset the accumulators."""
        for t in ['k', 'l', 's']:
            self.auc_scores[t] = roc_auc_score(
                self.targets.get(t),
                self.probs.get(t),
                multi_class='ovo', labels=[0, 1, 2])

        self.log('val_auc_score_k', self.auc_scores.get('k'), prog_bar=True, sync_dist=True)
        self.log('val_auc_score_l', self.auc_scores.get('l'), prog_bar=True, sync_dist=True)
        self.log('val_auc_score_s', self.auc_scores.get('s'), prog_bar=True, sync_dist=True)
        self.val_epoch_loss.clear()
        self.probs.clear()
        self.targets.clear()
        self.auc_scores.clear()

    def configure_optimizers(self):
        """AdamW over all parameters with the configured learning rate."""
        optimizer = AdamW(self.parameters(), lr=float(self.cfg['model']['lr']))
        # optimizer = AdamW(filter(lambda p: p.requires_grad, self.parameters()), lr=float(self.cfg['model']['lr']))
        return optimizer

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Intentionally a no-op (returns None); inference calls forward directly."""
        pass


# Model Architecture
class BENet(pl.LightningModule):
    """Bowel / extravasation classifier: timm backbone, linear neck, two heads.

    Each head outputs 2 logits (healthy, injury). The training loss is the sum
    of two label-smoothed cross-entropy losses with separate class weights.

    NOTE(review): ``__init__`` reads ``cfg['model']['model_name']``. The cfg
    dict defined in this notebook has no such key, so instantiating the class
    here would raise KeyError; the notebook instead loads already-serialized
    model objects with ``torch.load``, which needs this class definition.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.backbone = timm.create_model(
            model_name=cfg['model']['model_name'],
            pretrained=cfg['model']['pretrained'],
            in_chans=cfg['model']['in_chans'],
            num_classes=cfg['model']['num_classes'],
            global_pool=cfg['model']['global_pool'],
            drop_rate=cfg["model"]["drop_rate"],
            drop_path_rate=cfg["model"]["drop_path_rate"],
        )
        # for param in self.backbone.parameters():
        #     param.requires_grad = False

        self.in_features = self.backbone.num_features  # 1280
        hidden_dim = cfg['model']['hidden_dim']
        self.neck = nn.Sequential(
            nn.Linear(self.in_features, hidden_dim),
            nn.ReLU(),
            nn.Dropout(cfg['model']['p_dropout']),
        )

        self.bowel = nn.Linear(hidden_dim, 2)
        self.extravasation = nn.Linear(hidden_dim, 2)

        self.cce_b = nn.CrossEntropyLoss(label_smoothing=0.05, weight=torch.tensor(cfg['model']['b_weights']))
        self.cce_e = nn.CrossEntropyLoss(label_smoothing=0.05, weight=torch.tensor(cfg['model']['e_weights']))

        # Per-epoch accumulators, cleared at the end of every epoch.
        self.train_epoch_loss = []
        self.val_epoch_loss = []
        self.probs = defaultdict(list)
        self.targets = defaultdict(list)
        self.auc_scores = dict()

    def forward(self, x):
        """Return the (bowel, extravasation) logits for an image batch."""
        # extract features
        x = self.backbone(x)
        x = self.neck(x)

        # output logits
        bowel = self.bowel(x)
        extravasation = self.extravasation(x)

        return bowel, extravasation

    def _shared_step(self, batch):
        """Forward one batch; return (summed loss, logits tuple, targets tuple)."""
        bowel = batch['bowel']
        extravasation = batch['extravasation']

        b, e = self.forward(batch['image'])
        loss = self.cce_b(b, bowel) + self.cce_e(e, extravasation)
        return loss, (b, e), (bowel, extravasation)

    def training_step(self, batch, batch_idx):
        """Compute, record and log the summed loss of one training batch."""
        loss, _, _ = self._shared_step(batch)
        self.train_epoch_loss.append(loss.item())

        self.log('train_loss', loss, prog_bar=True, logger=True, on_epoch=True, on_step=True, sync_dist=True)
        return loss

    def on_train_epoch_end(self):
        """Drop the per-step losses so the list does not grow across epochs."""
        self.train_epoch_loss.clear()

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and collect probabilities and targets."""
        loss, logits, labels = self._shared_step(batch)
        self.val_epoch_loss.append(loss.item())

        for key, head_logits, head_labels in zip(('b', 'e'), logits, labels):
            self.probs[key].extend(F.softmax(head_logits, dim=1).detach().cpu().numpy())
            self.targets[key].extend(head_labels.detach().cpu().numpy())

        self.log('val_loss', loss, prog_bar=True, logger=True, on_epoch=True, on_step=True, sync_dist=True)
        return loss

    def on_validation_epoch_end(self):
        """Log the binary ROC AUC per target, then reset the accumulators.

        A target whose collected labels do not contain both classes has no
        defined AUC and is skipped. The accumulators are cleared in every case,
        so an epoch with a skipped target does not leak its predictions into
        the following epoch.
        """
        for t in ['b', 'e']:
            y_true = np.ravel(self.targets.get(t))
            if len(np.unique(y_true)) != 2:
                continue
            prob_array = np.array(self.probs.get(t))
            self.auc_scores[t] = roc_auc_score(y_true, prob_array[:, 1])
            self.log(f'val_auc_score_{t}', self.auc_scores[t], prog_bar=True, sync_dist=True)

        self.val_epoch_loss.clear()
        self.probs.clear()
        self.targets.clear()
        self.auc_scores.clear()

    def configure_optimizers(self):
        """AdamW over all parameters with the configured learning rate."""
        optimizer = AdamW(self.parameters(), lr=float(self.cfg['model']['lr']))
        # optimizer = AdamW(filter(lambda p: p.requires_grad, self.parameters()), lr=float(self.cfg['model']['lr']))
        return optimizer

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Intentionally a no-op (returns None); inference calls forward directly."""
        pass



In [None]:
# Load the three serialized model objects.
# NOTE: torch.load unpickles whole objects here - only use trusted model files.
if is_on_kaggle:
    # Tensors are restored to the device they were saved from.
    kls_model = torch.load(KLS_MODEL_PATH)
    b_e_model_1 = torch.load(B_E_MODEL_1_PATH)
    b_e_model_2 = torch.load(B_E_MODEL_2_PATH)
else:
    # Local run: remap to the GPU when one is available, otherwise to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    kls_model = torch.load(KLS_MODEL_PATH, map_location=device)
    b_e_model_1 = torch.load(B_E_MODEL_1_PATH, map_location=device)
    b_e_model_2 = torch.load(B_E_MODEL_2_PATH, map_location=device)


In [None]:
# The three loaders share every setting; only the dataset differs.
# shuffle=False keeps the slices in dataset order.
_loader_kwargs = dict(
    batch_size=cfg['data']['batch_size_inference'],
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
)

kls_test_dataloader = DataLoader(kls_test_data, **_loader_kwargs)
b_e_1_test_dataloader = DataLoader(b_e_1_test_data, **_loader_kwargs)
b_e_2_test_dataloader = DataLoader(b_e_2_test_data, **_loader_kwargs)



## Predicting ##

In [None]:
# predict for kls
print('batch_size_inference:', cfg['data']['batch_size_inference'])
kls_model.eval()
# Batches come from the DataLoader on the CPU; move them to wherever the loaded
# model lives so a CUDA-resident model does not fail on CPU inputs.
kls_device = next(kls_model.parameters()).device
kls_test_predictions = []
with torch.no_grad():
    # Wrapping the loader (not enumerate) lets tqdm show the number of batches.
    for batch_idx, batch_data in enumerate(tqdm(kls_test_dataloader)):
        image = batch_data['image'].to(kls_device)
        patient_id = batch_data['patient_id']
        k, l, s = kls_model(image)

        # Logits -> class probabilities, transferred back to the CPU as float64.
        k_probs, l_probs, s_probs = (
            F.softmax(logits, dim=1).cpu().numpy().astype(np.float64)
            for logits in (k, l, s)
        )

        # One row per slice: patient id followed by the 3 x 3 probabilities.
        for i in range(k_probs.shape[0]):  # all arrays share the batch dimension
            kls_test_predictions.append([
                int(patient_id[i]),
                *k_probs[i],
                *l_probs[i],
                *s_probs[i]
            ])

column_names = ['patient_id',
                'kidney_healthy', 'kidney_low', 'kidney_high',
                'liver_healthy', 'liver_low', 'liver_high',
                'spleen_healthy', 'spleen_low', 'spleen_high']

kls_preds = pd.DataFrame(kls_test_predictions, columns=column_names)


In [None]:
# Peek at the slice-level KLS predictions (one row per slice at this point).
kls_preds.head()

In [None]:
# Aggregate to patient level: average every probability column over all
# slices that share a patient_id.
kls_preds = kls_preds.groupby('patient_id').mean().reset_index()


In [None]:
# Print the column dtypes and non-null counts of the kls prediction frame.
kls_preds.info()

In [None]:
# predict for b_e 1
# Runs the first bowel/extravasation model over its test dataloader and
# collects one row of class probabilities per sample into `b_e_1_preds`.
print('batch_size_inference:', cfg['data']['batch_size_inference'])
b_e_model_1.eval()  # evaluation mode for inference
b_e_1_test_predictions = []
with torch.no_grad():  # gradients are not needed while predicting
    # tqdm wraps the dataloader itself rather than enumerate(...): an enumerate
    # object has no len(), so the progress bar could not show a total or ETA.
    # NOTE(review): the batch is fed to the model without an explicit device
    # transfer -- confirm that model and batch live on the same device.
    for batch_data in tqdm(b_e_1_test_dataloader):
        image = batch_data['image']
        patient_id = batch_data['patient_id']
        b_logits, e_logits = b_e_model_1(image)

        # Softmax over dim=1 turns each head's outputs into probabilities,
        # which are then moved to the CPU as float64 NumPy arrays.
        b_probs, e_probs = (
            F.softmax(logits, dim=1).cpu().numpy().astype(np.float64)
            for logits in (b_logits, e_logits)
        )

        # One row per sample in the batch: patient id followed by the bowel and
        # extravasation probabilities (same order as the column names below).
        b_e_1_test_predictions.extend(
            [int(pid), *b_row, *e_row]
            for pid, b_row, e_row in zip(patient_id, b_probs, e_probs)
        )

column_names = ['patient_id',
                'bowel_healthy', 'bowel_injury',
                'extravasation_healthy', 'extravasation_injury']

b_e_1_preds = pd.DataFrame(b_e_1_test_predictions, columns=column_names)


In [None]:
# Preview the first rows of the per-sample predictions of the first b_e model.
b_e_1_preds.head()

In [None]:
# Collapse the per-sample rows into one row per patient by averaging every
# probability column; patient_id stays a regular column.
b_e_1_preds = b_e_1_preds.groupby('patient_id', as_index=False).mean()


In [None]:
# Print the column dtypes and non-null counts of the first b_e prediction frame.
b_e_1_preds.info()

In [None]:
# predict for b_e 2
# Runs the second bowel/extravasation model over its test dataloader and
# collects one row of class probabilities per sample into `b_e_2_preds`.
print('batch_size_inference:', cfg['data']['batch_size_inference'])
b_e_model_2.eval()  # evaluation mode for inference
b_e_2_test_predictions = []
with torch.no_grad():  # gradients are not needed while predicting
    # tqdm wraps the dataloader itself rather than enumerate(...): an enumerate
    # object has no len(), so the progress bar could not show a total or ETA.
    # NOTE(review): the batch is fed to the model without an explicit device
    # transfer -- confirm that model and batch live on the same device.
    for batch_data in tqdm(b_e_2_test_dataloader):
        image = batch_data['image']
        patient_id = batch_data['patient_id']
        b_logits, e_logits = b_e_model_2(image)

        # Softmax over dim=1 turns each head's outputs into probabilities,
        # which are then moved to the CPU as float64 NumPy arrays.
        b_probs, e_probs = (
            F.softmax(logits, dim=1).cpu().numpy().astype(np.float64)
            for logits in (b_logits, e_logits)
        )

        # One row per sample in the batch: patient id followed by the bowel and
        # extravasation probabilities (same order as the column names below).
        b_e_2_test_predictions.extend(
            [int(pid), *b_row, *e_row]
            for pid, b_row, e_row in zip(patient_id, b_probs, e_probs)
        )

column_names = ['patient_id',
                'bowel_healthy', 'bowel_injury',
                'extravasation_healthy', 'extravasation_injury']

b_e_2_preds = pd.DataFrame(b_e_2_test_predictions, columns=column_names)


In [None]:
# Preview the first rows of the per-sample predictions of the second b_e model.
b_e_2_preds.head()

In [None]:
# Collapse the per-sample rows into one row per patient by averaging every
# probability column; patient_id stays a regular column.
b_e_2_preds = b_e_2_preds.groupby('patient_id', as_index=False).mean()


In [None]:
# Print the column dtypes and non-null counts of the second b_e prediction frame.
b_e_2_preds.info()

In [None]:
# Ensemble the two bowel/extravasation models: outer-join their patient-level
# predictions, average each probability column row-wise (DataFrame.mean skips
# NaN, so a patient present in only one frame keeps that frame's value), then
# discard the suffixed per-model columns.
b_e_preds = b_e_1_preds.merge(
    b_e_2_preds,
    how='outer',
    on='patient_id',
    suffixes=('_df1', '_df2'),
)
for column in ('bowel_healthy', 'bowel_injury', 'extravasation_healthy', 'extravasation_injury'):
    b_e_preds[column] = b_e_preds[[f"{column}_df1", f"{column}_df2"]].mean(axis=1)

b_e_preds = b_e_preds.drop(
    columns=[name for name in b_e_preds.columns if '_df1' in name or '_df2' in name]
)


In [None]:
# merge
# Outer-join the bowel/extravasation and kidney/liver/spleen predictions so a
# patient that appears in only one of the two frames is still kept.
df_preds = pd.merge(b_e_preds, kls_preds, on='patient_id', how='outer')

# fill na with median of each column
# The filled frame is assigned back explicitly. The previous
# `df_preds[column].fillna(..., inplace=True)` operated on a temporary Series
# (chained assignment): it emits a warning on recent pandas and leaves df_preds
# unchanged under Copy-on-Write, so NaNs would reach the submission.
# NOTE(review): a column that is entirely NaN has a NaN median and stays NaN.
columns_to_fill = df_preds.columns.difference(['patient_id'])
df_preds[columns_to_fill] = df_preds[columns_to_fill].fillna(df_preds[columns_to_fill].median())

# # Alternative kept for reference: fill na with the column mean multiplied by
# # 4 for '*_low' and 'bowel_injury', 6 for '*_high', 28 for 'extravasation_injury'
# columns_to_fill = df_preds.columns.difference(['patient_id'])
# for column in columns_to_fill:
#     if 'low' in column or 'bowel_injury' in column:
#         df_preds[column] = df_preds[column].fillna(df_preds[column].mean() * 4)
#     elif 'high' in column:
#         df_preds[column] = df_preds[column].fillna(df_preds[column].mean() * 6)
#     elif 'extravasation_injury' in column:
#         df_preds[column] = df_preds[column].fillna(df_preds[column].mean() * 28)
#     else:
#         df_preds[column] = df_preds[column].fillna(df_preds[column].mean())

df_preds.head()


In [None]:
# # probability calibration
# for column in df_preds.columns.difference(['patient_id']):
#     df_preds[column] = df_preds[column].apply(lambda x: 0.0 if x < 0.1 else x)
#     df_preds[column] = df_preds[column].apply(lambda x: 1.0 if x > 0.9 else x)



## Submission ##

In [None]:
# merge with sample submission
if is_on_kaggle:
    sample_submission = sample_submission[['patient_id']]
    df_preds = pd.merge(sample_submission, df_preds, on='patient_id', how='left')
    # fill NaN with median of that column except patient_id 
    # in case there is no prediction for a patient in df_preds but there is in sample_submission
    columns_to_fill = df_preds.columns.difference(['patient_id'])
    df_preds[columns_to_fill] = df_preds[columns_to_fill].apply(lambda col: col.fillna(col.median()), axis=0)
    # for column in columns_to_fill:
    #     if 'low' in column or 'bowel_injury' in column:
    #         df_preds[column].fillna(df_preds[column].mean() * 4, inplace=True)
    #     elif 'high' in column:
    #         df_preds[column].fillna(df_preds[column].mean() * 6, inplace=True)
    #     elif 'extravasation_injury' in column:
    #         df_preds[column].fillna(df_preds[column].mean() * 28, inplace=True)
    #     else:
    #         df_preds[column].fillna(df_preds[column].mean(), inplace=True)

    # drop duplicates
    # df_preds = df_preds.drop_duplicates()
    
    !rm -rf {KLS_TEST_DATA_DIR}
    !rm -rf {B_E_1_TEST_DATA_DIR}
    !rm -rf {B_E_2_TEST_DATA_DIR}
    print("processed for kaggle submission!")
    


In [None]:
# Print the column dtypes and non-null counts of the final prediction frame,
# which makes any remaining NaN values visible before it is written out.
df_preds.info()

In [None]:
# Preview the first rows of the final predictions; the commented variant shows
# the same preview rounded to three decimals.
# df_preds.round(3).head()
df_preds.head()

In [None]:
# Write the final predictions to submission.csv without the DataFrame index.
# The commented variants write the same file with reduced float precision.
# df_preds.to_csv('submission.csv', index=False, float_format='%.2f')
# df_preds.round(3).to_csv("submission.csv", index=False)
df_preds.to_csv("submission.csv", index=False)


In [None]:
# sample_submission = pd.read_csv(f'{BASE_PATH}/sample_submission.csv')
# sample_submission.to_csv("submission.csv", index=False)
# print(sample_submission.info())
# sample_submission.head()

In [None]:
## check to solve scoring error

In [None]:
# # provided by competition organizer
# import numpy as np
# import pandas as pd
# import pandas.api.types
# import sklearn.metrics


# class ParticipantVisibleError(Exception):
#     pass


# def normalize_probabilities_to_one(df: pd.DataFrame, group_columns: list) -> pd.DataFrame:
#     # Normalize the sum of each row's probabilities to 100%.
#     # 0.75, 0.75 => 0.5, 0.5
#     # 0.1, 0.1 => 0.5, 0.5
#     row_totals = df[group_columns].sum(axis=1)
#     if row_totals.min() == 0:
#         raise ParticipantVisibleError('All rows must contain at least one non-zero prediction')
#     for col in group_columns:
#         df[col] /= row_totals
#     return df


# def score(solution: pd.DataFrame, submission: pd.DataFrame, row_id_column_name: str) -> float:
#     '''
#     Pseudocode:
#     1. For every label group (liver, bowel, etc):
#         - Normalize the sum of each row's probabilities to 100%.
#         - Calculate the sample weighted log loss.
#     2. Derive a new any_injury label by taking the max of 1 - p(healthy) for each label group
#     3. Calculate the sample weighted log loss for the new label group
#     4. Return the average of all of the label group log losses as the final score.
#     '''
# #     del solution[row_id_column_name]
# #     del submission[row_id_column_name]

#     # Run basic QC checks on the inputs
#     if not pandas.api.types.is_numeric_dtype(submission.values):
#         raise ParticipantVisibleError('All submission values must be numeric')

#     if not np.isfinite(submission.values).all():
#         raise ParticipantVisibleError('All submission values must be finite')

#     if solution.min().min() < 0:
#         raise ParticipantVisibleError('All labels must be at least zero')
#     if submission.min().min() < 0:
#         raise ParticipantVisibleError('All predictions must be at least zero')

#     # Calculate the label group log losses
#     binary_targets = ['bowel', 'extravasation']
#     triple_level_targets = ['kidney', 'liver', 'spleen']
#     all_target_categories = binary_targets + triple_level_targets

#     label_group_losses = []
#     for category in all_target_categories:
#         if category in binary_targets:
#             col_group = [f'{category}_healthy', f'{category}_injury']
#         else:
#             col_group = [f'{category}_healthy', f'{category}_low', f'{category}_high']

#         solution = normalize_probabilities_to_one(solution, col_group)

#         for col in col_group:
#             if col not in submission.columns:
#                 raise ParticipantVisibleError(f'Missing submission column {col}')
#         submission = normalize_probabilities_to_one(submission, col_group)
#         label_group_losses.append(
#             sklearn.metrics.log_loss(
#                 y_true=solution[col_group].values,
#                 y_pred=submission[col_group].values,
#                 sample_weight=solution[f'{category}_weight'].values
#             )
#         )

#     # Derive a new any_injury label by taking the max of 1 - p(healthy) for each label group
#     healthy_cols = [x + '_healthy' for x in all_target_categories]
#     any_injury_labels = (1 - solution[healthy_cols]).max(axis=1)
#     any_injury_predictions = (1 - submission[healthy_cols]).max(axis=1)
#     any_injury_loss = sklearn.metrics.log_loss(
#         y_true=any_injury_labels.values,
#         y_pred=any_injury_predictions.values,
#         sample_weight=solution['any_injury_weight'].values
#     )

#     label_group_losses.append(any_injury_loss)
#     return np.mean(label_group_losses)


In [None]:
# df_preds.head()

In [None]:
# train_df.head()

In [None]:
# # create solution df from train_df with weight columns
# train_df['bowel_weight'] = 1
# train_df['extravasation_weight'] = 1
# train_df['kidney_weight'] = 1
# train_df['liver_weight'] = 1
# train_df['spleen_weight'] = 1
# train_df['any_injury_weight'] = 1


In [None]:
# score(train_df, df_preds, '')