### Libraries

In [1]:
# Offline install (--no-index) of the python-gdcm and pylibjpeg wheels shipped in the
# attached 'for-pydicom' input dataset.
# NOTE(review): presumably these provide pixel-data decoders for pydicom — confirm.
!pip install -qU ../input/for-pydicom/python_gdcm-3.0.14-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl ../input/for-pydicom/pylibjpeg-1.4.0-py3-none-any.whl --find-links frozen_packages --no-index

In [2]:
# Offline install (--no-index) of kaggle_vol3d_classify, resolved from the frozen_packages
# folder of the attached input dataset.
# NOTE(review): presumably the source of the `kaggle_volclassif` import used below — confirm.
!pip install -q kaggle_vol3d_classify -f ../input/cervical-spine-fracture-detection-npz-3d-volumes/frozen_packages --no-index

In [23]:
# ---- efficientNet3D offline  ---
import torch.nn as nn
import sys
sys.path.append('../input/efficientnetpyttorch3d/EfficientNet-PyTorch-3D')
from efficientnet_pytorch_3d import EfficientNet3D

In [54]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.patches as patches
import seaborn as sns
sns.set(style='darkgrid', font_scale=1.6)
import cv2
import os
from os import listdir
import re
import gc
import random
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
from tqdm.auto import tqdm
from pprint import pprint
from time import time
import itertools
from skimage import measure
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import nibabel as nib
from glob import glob
import warnings
warnings.filterwarnings('ignore')
#warnings.filterwarnings("ignore", category=DeprecationWarning)
#warnings.filterwarnings("ignore", category=UserWarning)
#warnings.filterwarnings("ignore", category=FutureWarning)
import zipfile
from scipy import ndimage
from sklearn.model_selection import train_test_split, GroupKFold
from joblib import Parallel, delayed
from PIL import Image
from dipy.denoise.nlmeans import nlmeans
from dipy.denoise.noise_estimate import estimate_sigma
from kaggle_volclassif.utils import interpolate_volume
from skimage import exposure

# Pytorch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.nn.functional as F
import kornia
import kornia.augmentation as augmentation

## Settings & Data

In [55]:
# Seed every random number generator used in this notebook
def set_seed(seed=0):
    """Seed NumPy, Python's ``random`` module and PyTorch with the same value.

    Args:
        seed: Seed handed unchanged to each library (default 0).
    """
    # Same order as before: NumPy, then ``random``, then torch.
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)


set_seed()

# Training hyper-parameters
BATCH_SIZE = 4
N_EPOCHS = 20
LEARNING_RATE = 1e-4
PATIENCE = 3  # NOTE(review): not referenced by any of the cells visible here — confirm intended use

# Feature switches
AUGMENTATIONS = True
EXPERIMENTAL = False

# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Load metadata
train_df = pd.read_csv("../input/rsna-2022-cervical-spine-fracture-detection/train.csv")
train_bbox = pd.read_csv("../input/rsna-2022-cervical-spine-fracture-detection/train_bounding_boxes.csv")
test_df = pd.read_csv("../input/rsna-2022-cervical-spine-fracture-detection/test.csv")
ss = pd.read_csv("../input/rsna-2022-cervical-spine-fracture-detection/sample_submission.csv")


# Studies excluded from training. The original row labels are kept (gaps included),
# exactly as the previous in-place drop did.
bad_scans = ['1.2.826.0.1.3680043.20574','1.2.826.0.1.3680043.29952']
train_df = train_df[~train_df['StudyInstanceUID'].isin(bad_scans)].copy()


# A 3-row sample submission switches the notebook into debug mode.
debug = len(ss) == 3
if debug:
    # Fix mismatch with test_images folder: rebuild test_df with one row per
    # (study, prediction type) pair for the three studies listed below.
    test_uids = ['1.2.826.0.1.3680043.22327','1.2.826.0.1.3680043.25399','1.2.826.0.1.3680043.5876']
    prediction_types = ['C1','C2','C3','C4','C5','C6','C7','patient_overall']

    # Build all records first and construct the frame once: DataFrame.append was removed in
    # pandas 2.0, and appending row by row copies the whole frame on every iteration.
    test_df = pd.DataFrame(
        [
            {'row_id': uid + '_' + target, 'StudyInstanceUID': uid, 'prediction_type': target}
            for uid in test_uids
            for target in prediction_types
        ],
        columns=['row_id','StudyInstanceUID','prediction_type'],
    )

    # Sample submission
    ss = pd.DataFrame(test_df['row_id'])
    ss['fractured'] = 0.5
    

# Augmentation pipeline handed to the training dataset; None when AUGMENTATIONS is off.
# RandomHorizontalFlip3D(same_on_batch=False, p=0.5, keepdim=True) is currently disabled.
augs = (
    transforms.Compose([
        augmentation.RandomRotation3D(
            (0,0,30), resample='bilinear', p=0.5, same_on_batch=False, keepdim=True
        ),
    ])
    if AUGMENTATIONS
    else None
)

### Custom Dataset

In [56]:
class RSNADataset(Dataset):
    """Dataset of pre-computed volumes stored as one ``<StudyInstanceUID>.pt`` file per study.

    The files live in two input folders. Each row of ``df_table`` is matched to the folder
    whose file listing contains its ``StudyInstanceUID``; the first folder wins when a UID is
    present in both.

    Args:
        subset: Tag stored on the instance; it is not used for any logic in this class.
        df_table: Frame with a ``StudyInstanceUID`` column and the eight target columns.
            Its index is reset to 0..n-1 internally.
        transform: Optional callable applied to the loaded float32 volume.
    """

    # Initialise
    def __init__(self, subset='train', df_table = train_df, transform=None):
        super().__init__()

        self.subset = subset
        self.df_table = df_table.reset_index(drop=True)
        self.transform = transform
        self.targets = ['C1','C2','C3','C4','C5','C6','C7','patient_overall']

        # Image paths
        self.volume_dir1 = '../input/rsna-3d-train-tensors-first-half/train_volumes'  # <=1000 patient
        self.volume_dir2 = '../input/rsna-3d-train-tensors-second-half/train_volumes' # >1000 patient

        # Identify which rows have their volume in each of the two folders
        fh_list = self._list_volume_uids(self.volume_dir1)
        sh_list = self._list_volume_uids(self.volume_dir2)
        self.df_table_fh = self.df_table[self.df_table['StudyInstanceUID'].isin(fh_list)]
        self.df_table_sh = self.df_table[self.df_table['StudyInstanceUID'].isin(sh_list)]

        # Populate labels (row order matches self.df_table)
        self.labels = self.df_table[self.targets].values

    @staticmethod
    def _list_volume_uids(volume_dir):
        """Return the file stems (name without ``.pt``) of every ``*.pt`` file in ``volume_dir``."""
        return [
            os.path.splitext(os.path.basename(path))[0]
            for path in glob(os.path.join(volume_dir, "*.pt"))
        ]

    # Get item in position given by index
    def __getitem__(self, index):
        """Load the volume for row ``index`` and return ``(volume.unsqueeze(0), labels[index])``.

        Raises:
            IndexError: if the row has no volume file in either folder (this includes
                out-of-range indices).
        """
        if index in self.df_table_fh.index:
            table, volume_dir = self.df_table_fh, self.volume_dir1
        elif index in self.df_table_sh.index:
            table, volume_dir = self.df_table_sh, self.volume_dir2
        else:
            # Same exception type as before, but with a message that names the actual problem.
            raise IndexError(
                f"row {index}: no volume file found in {self.volume_dir1} or {self.volume_dir2}"
            )

        # Row labels are unique (subset of a 0..n-1 index), so a direct label lookup
        # replaces the former boolean-mask scan over the whole frame.
        patient = table.at[index, 'StudyInstanceUID']
        path = os.path.join(volume_dir, f"{patient}.pt")
        vol = torch.load(path).to(torch.float32)

        # Data augmentations
        if self.transform:
            vol = self.transform(vol)

        # Prepend a singleton dimension (presumably the channel axis — confirm against the model input)
        return vol.unsqueeze(0), self.labels[index]

    # Length of dataset
    def __len__(self):
        return len(self.df_table)

In [57]:
# Re-number the rows 0..n-1 (rows were dropped above, leaving gaps in the index).
# drop=True discards the old index directly instead of materialising it as an 'index'
# column and deleting that column in place afterwards.
train_df = train_df.reset_index(drop=True)
train_df.head(2)

In [58]:
N_FOLDS = 5

# Give every row a validation-fold id in the 'split' column.
# GroupKFold.split yields *positional* indices, so the ids are collected in a positional
# array and attached as a column afterwards. Using them as `.loc` labels is only correct
# when the index happens to be exactly 0..n-1.
fold_ids = np.full(len(train_df), np.nan)  # float column, as produced by the previous code
split = GroupKFold(N_FOLDS)
for k, (_, test_idx) in enumerate(split.split(train_df, groups=train_df.StudyInstanceUID)):
    fold_ids[test_idx] = k
train_df['split'] = fold_ids

train_df.head(2)

### Training Setup

In [59]:
def gc_collect():
    """Run Python's garbage collector, then ask PyTorch to empty its CUDA cache."""
    # Collect unreachable Python objects first, then release PyTorch's cached CUDA memory.
    gc.collect()
    torch.cuda.empty_cache()
    
def return_model():
    """Build a new EfficientNet3D ("efficientnet-b2") with 1 input channel and 8 output classes."""
    return EfficientNet3D.from_name(
        "efficientnet-b2",
        override_params={'num_classes': 8},
        in_channels=1,
    )

In [60]:
# Element-wise BCE on logits; the reduction is done by the weighted loss function
loss_fn = nn.BCEWithLogitsLoss(reduction='none')

# Per-target weights for negative ('-') and positive ('+') labels: eight entries each, the
# last one 7x the others, and every positive weight twice its negative counterpart.
# NOTE(review): presumably ordered C1..C7, patient_overall to match the dataset targets — confirm.
competition_weights = {
    sign: torch.tensor([scale] * 7 + [7 * scale], dtype=torch.float, device=device)
    for sign, scale in (('-', 1), ('+', 2))
}

def competiton_loss_row_norm(y_hat, y):
    """Weighted BCE-with-logits loss, normalised per row and averaged over rows.

    Args:
        y_hat: Logits; reduced along axis 1, so at least 2-D.
        y: Binary labels with the same layout as ``y_hat``.

    Returns:
        Scalar tensor: mean over rows of ``sum(w * bce) / sum(w)``, where ``w`` is the
        positive weight where ``y`` is 1 and the negative weight where ``y`` is 0.
    """
    pos_w = competition_weights['+']
    neg_w = competition_weights['-']

    elem_loss = loss_fn(y_hat, y.to(y_hat.dtype))
    elem_w = y * pos_w + (1 - y) * neg_w

    # Normalise each row by its own total weight before averaging
    row_loss = (elem_loss * elem_w).sum(axis=1) / elem_w.sum(axis=1)
    return row_loss.mean()

In [61]:
def evaluate_model(dl_valid, model):
    """Return the mean per-batch competition loss of ``model`` over ``dl_valid``.

    Args:
        dl_valid: Iterable of ``(volumes, labels)`` batches.
        model: Module to evaluate; batches are moved to the global ``device`` before the call.

    Returns:
        Python float: sum of the batch losses divided by the number of batches.

    Note:
        Switches the model to eval mode and does not switch it back.
    """
    val_loss_acc = 0
    valid_count = 0
    model.eval()
    with torch.no_grad():
        for val_imgs, val_labels in dl_valid:
            val_imgs = val_imgs.to(device)
            val_labels = val_labels.to(device)

            val_preds = model(val_imgs)
            # Fixed: this used to call `competition_weights`, which is a dict of weight
            # tensors and not callable; the loss function is `competiton_loss_row_norm`.
            val_L = competiton_loss_row_norm(val_preds, val_labels)

            val_loss_acc += val_L.item()
            valid_count += 1

    return val_loss_acc/valid_count

In [62]:
def save_model(model, fold):
    """Save ``model.state_dict()`` to ``'{fold}/5_EffV2_model.pt'`` under the working directory.

    Args:
        model: Module whose state dict is written.
        fold: Fold identifier; it becomes the directory component of the path.
    """
    path = f'{fold}/5_EffV2_model.pt'
    # The '/' makes `{fold}` a directory, and torch.save does not create missing parent
    # directories, so create it here.
    # NOTE(review): if a flat file name (fold "out of 5") was intended instead, confirm and rename.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(model.state_dict(), path)

In [67]:
def train_model(dl_train, dl_valid, fold):
    """Train a fresh model for N_EPOCHS and checkpoint whenever the validation loss improves.

    Args:
        dl_train: Iterable of ``(volumes, labels)`` training batches.
        dl_valid: Iterable of validation batches, passed to ``evaluate_model``.
        fold: Fold identifier, used in the log line and passed to ``save_model``.

    Returns:
        ``(loss_hist, val_loss_hist)``: per-epoch mean training loss and validation loss.
    """
    loss_hist = []
    val_loss_hist = []
    # Start at +inf so the first epoch always produces a checkpoint; the former 1.0
    # threshold saved nothing when the validation loss never dropped below 1.0.
    best_valid = float('inf')

    model = return_model()
    model.to(device)
    optimizer = optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=N_EPOCHS)

    # NOTE(review): PATIENCE is defined in the settings cell but no early stopping is
    # implemented here — confirm whether it should be.
    for epoch in tqdm(range(N_EPOCHS)):
        # evaluate_model() leaves the model in eval mode, so training mode has to be
        # re-enabled at the start of every epoch.
        model.train()

        loss_acc = 0
        train_count = 0

        for imgs, labels in dl_train:
            imgs = imgs.to(device)
            labels = labels.to(device)

            preds = model(imgs)
            L = competiton_loss_row_norm(preds, labels)
            L.backward()

            optimizer.step()
            optimizer.zero_grad()

            loss_acc += L.detach().item()
            train_count += 1

        scheduler.step()
        loss_acc /= train_count
        val_loss = evaluate_model(dl_valid, model)

        # append History
        loss_hist.append(loss_acc)
        val_loss_hist.append(val_loss)

        print(f'{fold}/5 | {epoch}/{N_EPOCHS} | train_loss {loss_acc} | valid_loss : {val_loss}')

        if best_valid > val_loss:
            best_valid = val_loss
            save_model(model, fold)

    # The histories used to be built and then discarded; return them so callers can inspect them.
    return loss_hist, val_loss_hist
            

In [None]:
def start_model_train_eval():
    """Train and validate one model per fold, for folds 0 .. N_FOLDS-1."""
    for fold in range(N_FOLDS):
        # `@fold` in the query strings resolves to this loop variable, so the queries
        # must stay in this function's frame.
        train_rows = train_df.query('split != @fold')
        valid_rows = train_df.query('split == @fold')

        # Augmentations are applied to the training rows only
        dl_train = DataLoader(
            dataset=RSNADataset(subset='train', df_table=train_rows, transform=augs),
            batch_size=BATCH_SIZE,
            shuffle=True,
        )
        dl_valid = DataLoader(
            dataset=RSNADataset(subset='valid', df_table=valid_rows, transform=None),
            batch_size=BATCH_SIZE,
            shuffle=False,
        )

        # Free memory left over from the previous fold before building the next model
        gc_collect()

        train_model(dl_train, dl_valid, fold)


start_model_train_eval()