### Libraries

In [22]:
!pip install -qU ../input/for-pydicom/python_gdcm-3.0.14-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl ../input/for-pydicom/pylibjpeg-1.4.0-py3-none-any.whl --find-links frozen_packages --no-index

In [23]:
!pip install -q kaggle_vol3d_classify -f ../input/cervical-spine-fracture-detection-npz-3d-volumes/frozen_packages --no-index

In [24]:
# ---- efficientNet3D offline  ---
import torch.nn as nn
import sys
sys.path.append('../input/efficientnetpyttorch3d/EfficientNet-PyTorch-3D')
from efficientnet_pytorch_3d import EfficientNet3D

In [51]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.patches as patches
import seaborn as sns
sns.set(style='darkgrid', font_scale=1.6)
import cv2
import os
from os import listdir
import re
import gc
import random
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
from tqdm.auto import tqdm
from pprint import pprint
from time import time
import itertools
from skimage import measure
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import nibabel as nib
from glob import glob
import warnings
#warnings.filterwarnings("ignore", category=DeprecationWarning)
#warnings.filterwarnings("ignore", category=UserWarning)
#warnings.filterwarnings("ignore", category=FutureWarning)
import zipfile
from scipy import ndimage
from sklearn.model_selection import train_test_split
from joblib import Parallel, delayed
from PIL import Image
from dipy.denoise.nlmeans import nlmeans
from dipy.denoise.noise_estimate import estimate_sigma
from kaggle_volclassif.utils import interpolate_volume
from skimage import exposure
from sklearn.model_selection import GroupKFold

# Pytorch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.nn.functional as F
import kornia
import kornia.augmentation as augmentation

from sklearn.model_selection import GroupKFold
from torch.cuda.amp import GradScaler, autocast
from tqdm.notebook import tqdm

In [106]:
# Set random seeds
def set_seed(seed=0):
    """Seed the NumPy, Python ``random`` and PyTorch generators with ``seed``.

    Args:
        seed: Integer seed handed to each library's seeding function.
    """
    # Same three seeders, same order; each one takes the seed as its only argument.
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)
set_seed()  # seed every RNG with the default (0) before anything stochastic runs

# Global run configuration read by the cells below.
BATCH_SIZE = 16  # batch size used by every DataLoader in this notebook
LEARNING_RATE = 0.0001  # NOTE(review): not referenced in the visible cells; the scheduler is driven by ONE_CYCLE_MAX_LR
N_EPOCHS = 20  # re-assigned (to the same value) in the training cells further down
PATIENCE = 3  # NOTE(review): no early-stopping logic in the visible cells reads this
EXPERIMENTAL = False  # NOTE(review): not referenced in the visible cells
AUGMENTATIONS = True  # toggles construction of the `augs` transform pipeline

# Config device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

## Data

In [28]:
# ---- Train ----
# Builds `train_df`: the segmentation table joined with the fracture labels, plus a 5-fold `split` column.

train_df = pd.read_csv("../input/rsna-2022-cervical-spine-fracture-detection/train.csv")
df_train_slices = pd.read_csv('../input/vertebrae-detection-checkpoints/train_segmented.csv')
c1c7 = [f'C{i}' for i in range(1, 8)]
# Binarise the C1..C7 columns of the segmentation table at 0.5
# (presumably per-slice vertebra-presence scores — confirm against the CSV).
df_train_slices[c1c7] = (df_train_slices[c1c7] > 0.5).astype(int)
# Join on StudyInstanceUID; columns of train.csv whose names collide with the
# segmentation table (C1..C7) receive the '_fracture' suffix.
train_df = df_train_slices.set_index('StudyInstanceUID').join(train_df.set_index('StudyInstanceUID'), rsuffix='_fracture').reset_index().copy()
# Drop one study by id and renumber rows 0..n-1 so the positional indices
# returned by the splitter below are also valid `.loc` labels.
train_df = train_df.query('StudyInstanceUID != "1.2.826.0.1.3680043.20574"').reset_index(drop=True)

# Grouping by StudyInstanceUID keeps all rows of a study in the same fold.
split = GroupKFold(5)
for k, (_, test_idx) in enumerate(split.split(train_df, groups=train_df.StudyInstanceUID)):
    train_df.loc[test_idx, 'split'] = k

In [29]:
# ----- TEST ----
# Builds `test_df`: test.csv joined with the (study, slice) pairs found on disk.

test_df = pd.read_csv(f'../input/rsna-2022-cervical-spine-fracture-detection/test.csv')

if test_df.iloc[0].row_id == '1.2.826.0.1.3680043.10197_C1':
    # test_images and test.csv are inconsistent in the dev dataset, fixing labels for the dev run.
    test_df = pd.DataFrame({
        "row_id": ['1.2.826.0.1.3680043.22327_C1', '1.2.826.0.1.3680043.25399_C1', '1.2.826.0.1.3680043.5876_C1'],
        "StudyInstanceUID": ['1.2.826.0.1.3680043.22327', '1.2.826.0.1.3680043.25399', '1.2.826.0.1.3680043.5876'],
        "prediction_type": ["C1", "C1", "patient_overall"]}
    )


# Enumerate every file two levels below test_images/ and split each path into
# (StudyInstanceUID, Slice) with a regex.
# NOTE(review): the dots in the pattern are unescaped (they match any character),
# and a path that does not match makes `[0]` raise IndexError.
test_slices = glob('../input/rsna-2022-cervical-spine-fracture-detection/test_images/*/*')
test_slices = [re.findall('../input/rsna-2022-cervical-spine-fracture-detection/test_images/(.*)/(.*).dcm', s)[0] for s in test_slices]
df_test_slices = pd.DataFrame(data=test_slices, columns=['StudyInstanceUID', 'Slice'])
    
# Join on StudyInstanceUID, so each test.csv row is repeated once per slice found for its study.
test_df = test_df.set_index('StudyInstanceUID').join(df_test_slices.set_index('StudyInstanceUID')).reset_index()
test_df.sample(2)

In [30]:
# Optional augmentation pipeline; the fold loop passes `augs` to the training
# dataset as its `transform` (None disables augmentation).
if AUGMENTATIONS:
    augs = transforms.Compose([
        # Random 3D rotation, applied with probability 0.5.
        # NOTE(review): a 3-element `degrees` is presumably read as per-axis
        # (yaw, pitch, roll) limits, i.e. only the last axis rotates — confirm in the kornia docs.
        augmentation.RandomRotation3D((0,0,30), resample='bilinear', p=0.5, same_on_batch=False, keepdim=True),
        #augmentation.RandomHorizontalFlip3D(same_on_batch=False, p=0.5, keepdim=True),
        ])
else:
    augs=None

### Torch Dataset

In [112]:
class RSNADataset(torch.utils.data.Dataset):
    """Dataset of pre-computed 3D volumes stored as ``<StudyInstanceUID>.pt`` tensors.

    The volumes live in two directories ("first half" / "second half"). Every
    row of ``df_table`` is resolved to whichever directory contains the file
    for its ``StudyInstanceUID``.

    Args:
        subset: Free-form tag stored on the instance (e.g. ``'train'`` / ``'valid'``).
        df_table: Frame with a ``StudyInstanceUID`` column and, for labelled
            data, the ``C1..C7`` and ``C1_fracture..C7_fracture`` columns.
            NOTE: the default is bound to the global ``train_df`` as it exists
            when this class is defined, not when it is instantiated.
        transform: Optional callable applied to the loaded volume.

    Each item is ``(volume, frac_targets, vert_targets)`` when ``df_table`` has
    a ``C1_fracture`` column, otherwise just ``volume``. The volume gets a new
    leading axis via ``unsqueeze(0)``.
    """

    VERTEBRA_COLUMNS = ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7']
    FRACTURE_COLUMNS = ['C1_fracture', 'C2_fracture', 'C3_fracture', 'C4_fracture',
                        'C5_fracture', 'C6_fracture', 'C7_fracture']

    def __init__(self, subset = 'train', df_table = train_df, transform = None):
        super().__init__()

        self.subset = subset
        self.df_table = df_table.reset_index(drop = True)
        self.transform = transform

        # Volume directories
        self.volume_dir1 = '../input/rsna-3d-train-tensors-first-half/train_volumes'
        self.volume_dir2 = '../input/rsna-3d-train-tensors-second-half/train_volumes'

        # Study ids available in each directory. The second-half listing used
        # to glob the first-half directory again, so second-half studies could
        # never be resolved.
        fh_ids = self._list_study_ids(self.volume_dir1)
        sh_ids = self._list_study_ids(self.volume_dir2)

        # All rows of the table whose study lives in the respective directory.
        self.df_table_fh = self.df_table[self.df_table['StudyInstanceUID'].isin(fh_ids)]
        self.df_table_sh = self.df_table[self.df_table['StudyInstanceUID'].isin(sh_ids)]

        # study id -> directory; the first half wins if a study is in both.
        self._volume_dir_by_study = {uid: self.volume_dir2 for uid in sh_ids}
        self._volume_dir_by_study.update({uid: self.volume_dir1 for uid in fh_ids})

    @staticmethod
    def _list_study_ids(volume_dir):
        """Return the set of file stems of the ``*.pt`` files in ``volume_dir``."""
        return {
            os.path.splitext(os.path.basename(path))[0]
            for path in glob(os.path.join(volume_dir, "*.pt"))
        }

    def __getitem__(self, idx):
        # `.iloc` raises IndexError for an out-of-range idx, which is what
        # plain `iter(dataset)` relies on to stop.
        row = self.df_table.iloc[idx]
        patient = row['StudyInstanceUID']

        volume_dir = self._volume_dir_by_study.get(patient)
        if volume_dir is None:
            # Deliberately not an IndexError: that would silently end iteration.
            raise FileNotFoundError(
                f"No volume file '{patient}.pt' in {self.volume_dir1} or {self.volume_dir2}"
            )

        # NOTE: torch.load unpickles its input; only point it at trusted files.
        vol = torch.load(os.path.join(volume_dir, f"{patient}.pt")).to(torch.float32)

        if self.transform:
            vol = self.transform(vol)
        vol = vol.unsqueeze(0)

        if 'C1_fracture' in self.df_table:
            frac_targets = torch.as_tensor(row[self.FRACTURE_COLUMNS].astype('float32').values)
            vert_targets = torch.as_tensor(row[self.VERTEBRA_COLUMNS].astype('float32').values)
            # Keep a fracture label only where the matching vertebra flag is set on this row.
            frac_targets = frac_targets * vert_targets

            # Labelled data
            return vol, frac_targets, vert_targets

        # Unlabelled data: volume only
        return vol

    def __len__(self):
        return len(self.df_table)
        

### Loss & Weights

In [32]:
def weighted_loss(y_pred_logit, y, reduction='mean', verbose=False):
    """Class-weighted binary cross-entropy on logits, normalised per sample.

    Each element's BCE is multiplied by a positive or negative weight depending
    on its target, divided by the sum of the weights of its row, and summed
    over the last axis.

    Args:
        y_pred_logit: Logits, 2-D with targets on the last axis.
        y: Float targets with the same shape as ``y_pred_logit``.
        reduction: ``'mean'`` averages the per-sample losses; any other value
            returns the per-sample vector.
        verbose: Print every intermediate tensor.

    Returns:
        A scalar tensor for ``reduction='mean'``, else one loss per row.

    When the last axis has 8 entries the first one is weighted 7 (negative) /
    14 (positive) and the rest 1 / 2; for any other width all entries use 1 / 2.
    NOTE(review): the verbose message suggests column 0 is ``patient_overall``
    in the 8-wide layout — confirm against the caller.
    """
    n_targets = y_pred_logit.shape[-1]
    # Build the weights next to the logits rather than on a notebook-global
    # device, so the function works wherever its inputs live.
    weights_device = y_pred_logit.device

    if n_targets == 8:
        neg_weights = torch.tensor([7., 1, 1, 1, 1, 1, 1, 1], device=weights_device)
        pos_weights = torch.tensor([14., 2, 2, 2, 2, 2, 2, 2], device=weights_device)
    else:
        neg_weights = torch.ones(n_targets, device=weights_device)
        pos_weights = torch.ones(n_targets, device=weights_device) * 2.

    loss = torch.nn.functional.binary_cross_entropy_with_logits(
        y_pred_logit,
        y,
        reduction='none',
    )

    if verbose:
        print('loss', loss)

    # Select the positive weight where y == 1 and the negative weight where y == 0.
    pos_weights = y * pos_weights.unsqueeze(0)
    neg_weights = (1 - y) * neg_weights.unsqueeze(0)
    all_weights = pos_weights + neg_weights

    if verbose:
        print('all weights', all_weights)

    loss = loss * all_weights
    if verbose:
        print('weighted loss', loss)

    # Per-row sum of weights, so each sample's loss is a weighted average.
    norm = torch.sum(all_weights, dim=1).unsqueeze(1)
    if verbose:
        print('normalization factors', norm)

    loss = loss / norm
    if verbose:
        print('normalized loss', loss)

    loss = torch.sum(loss, dim=1)
    if verbose:
        print('summed up over patient_overall-C1-C7 loss', loss)

    if reduction == 'mean':
        return torch.mean(loss)
    return loss

In [34]:
def filter_nones(b):
    """Collate a batch with PyTorch's default collate after dropping ``None`` samples."""
    kept = list(filter(lambda sample: sample is not None, b))
    return torch.utils.data.default_collate(kept)

def save_model(name, model):
    """Write ``model``'s state dict to ``<name>.pt``."""
    checkpoint_path = f'{name}.pt'
    weights = model.state_dict()
    torch.save(weights, checkpoint_path)

### Model

In [52]:
class efficientNet3d(nn.Module):
    """3D EfficientNet backbone with two 7-way linear heads.

    The backbone's own classifier (``_fc``) is replaced by a square linear
    layer, and its output feeds both heads.

    Args:
        model_name: EfficientNet variant passed to ``EfficientNet3D.from_name``.
            Defaults to the original ``"efficientnet-b2"``.

    ``forward`` returns ``(fracture_logits, vertebrae_logits)``.
    """

    def __init__(self, model_name="efficientnet-b2"):
        super().__init__()
        self.net = EfficientNet3D.from_name(model_name, override_params={'num_classes': 7}, in_channels=1)
        n_features = self.net._fc.in_features
        self.net._fc = nn.Linear(in_features=n_features, out_features = n_features, bias=True)

        # Head width follows the backbone instead of a hard-coded 1408, so other
        # variants work too. The Sequential wrappers keep the state-dict keys
        # ("nn_fracture.0.*", "nn_vertebrae.0.*") of existing checkpoints valid.
        self.nn_fracture = torch.nn.Sequential(
            torch.nn.Linear(n_features, 7)
        )

        self.nn_vertebrae = torch.nn.Sequential(
            torch.nn.Linear(n_features, 7)
        )

    def forward(self, x):
        x = self.net(x)
        return self.nn_fracture(x), self.nn_vertebrae(x)

### Train 3D Model

In [81]:
# ----- Eval Pipeline
PREDICT_MAX_BATCHES = 1e9

def evaluate_model(model, ds, max_batches = PREDICT_MAX_BATCHES, shuffle = False):
    """Compute mean per-batch losses of ``model`` over ``ds`` without gradients.

    Args:
        model: Module returning ``(fracture_logits, vertebrae_logits)``.
        ds: Dataset yielding ``(X, y_frac, y_vert)`` samples.
        max_batches: Stop after this many batches.
        shuffle: Whether the evaluation loader shuffles (deterministically).

    Returns:
        ``(frac_loss, vert_loss, combined_loss)``, each the mean over the
        evaluated batches. The combined loss is
        ``FRAC_LOSS_WEIGHTS * frac_loss + vert_loss``; ``FRAC_LOSS_WEIGHTS`` is
        the constant defined in the training cell.
    """
    model = model.to(device)

    # Private generator: the shuffle stays reproducible without reseeding the
    # global RNG. Reseeding globally here made every training epoch after an
    # evaluation replay the same shuffle and augmentation sequence.
    loader_rng = torch.Generator()
    loader_rng.manual_seed(42)
    dl_test = torch.utils.data.DataLoader(ds, batch_size = BATCH_SIZE, shuffle = shuffle,
                                          generator = loader_rng)

    frac_losses = []
    vert_losses = []
    sum_losses = []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for i, (X, y_frac, y_vert) in enumerate(dl_test):
                if i >= max_batches:
                    break

                with autocast():
                    y_frac_pred, y_vert_pred = model(X.to(device))
                    frac_loss = weighted_loss(y_frac_pred, y_frac.to(device)).item()
                    vert_loss = torch.nn.functional.binary_cross_entropy_with_logits(y_vert_pred, y_vert.to(device)).item()

                loss = FRAC_LOSS_WEIGHTS * frac_loss + vert_loss

                frac_losses.append(frac_loss)
                vert_losses.append(vert_loss)
                sum_losses.append(loss)
    finally:
        # Restore the caller's mode so an enclosing training loop keeps training.
        model.train(was_training)

    return np.mean(frac_losses), np.mean(vert_losses), np.mean(sum_losses)

In [94]:
# ----- Train Pipeline
N_EPOCHS = 20
ONE_CYCLE_MAX_LR = 1e-4
ONE_CYCLE_PCT_START = 0.3
FRAC_LOSS_WEIGHTS = 2.  # multiplier on the fracture loss in the combined objective
EFFNET_MAX_TRAIN_BATCHES = 4000  # upper bound on optimiser steps per epoch

def train_model(ds_train, ds_eval, num):
    """Train one fold of the 3D EfficientNet and checkpoint the best epoch.

    Args:
        ds_train: Dataset yielding ``(X, y_frac, y_vert)`` for training.
        ds_eval: Dataset of the held-out fold, scored after every epoch.
        num: Fold number; used in log lines and in the checkpoint name
            ``'<num>-fold effNet_model.pt'``.

    Returns:
        The best (lowest) validation combined loss seen across epochs.
    """
    torch.manual_seed(42)
    name = f'{num}-fold effNet_model'
    dl_train = torch.utils.data.DataLoader(ds_train, batch_size = BATCH_SIZE,
                                           shuffle = True)
    steps_per_epoch = min(EFFNET_MAX_TRAIN_BATCHES, len(dl_train))

    model = efficientNet3d().to(device)
    optimizer = torch.optim.Adam(model.parameters())
    # The schedule must span every step taken below (N_EPOCHS * steps_per_epoch);
    # OneCycleLR raises once it is stepped past its configured total.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr = ONE_CYCLE_MAX_LR, epochs = N_EPOCHS,
                                                    steps_per_epoch = steps_per_epoch,
                                                    pct_start = ONE_CYCLE_PCT_START)
    scaler = GradScaler()

    best_loss = 1e9
    for idx in range(N_EPOCHS):
        # Re-enter training mode every epoch: evaluation may leave the model in eval mode.
        model.train()

        frac_losses = []
        vert_losses = []
        sum_losses = []

        for batch_idx, (X, y_frac, y_vert) in enumerate(dl_train):
            if batch_idx >= steps_per_epoch:
                break

            with autocast():
                y_frac_pred, y_vert_pred = model(X.to(device))

                frac_loss = weighted_loss(y_frac_pred, y_frac.to(device))
                vert_loss = torch.nn.functional.binary_cross_entropy_with_logits(y_vert_pred, y_vert.to(device))
                loss = FRAC_LOSS_WEIGHTS * frac_loss + vert_loss

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            # Store plain floats: keeping the tensors would retain each step's
            # autograd graph, and np.mean cannot consume CUDA tensors.
            frac_losses.append(frac_loss.item())
            vert_losses.append(vert_loss.item())
            sum_losses.append(loss.item())

        frac_loss_mean, vert_loss_mean, sum_loss_mean = np.mean(frac_losses), np.mean(vert_losses), np.mean(sum_losses)
        val_frac_loss_mean, val_vert_loss_mean, val_sum_loss_mean = evaluate_model(model, ds_eval, max_batches = PREDICT_MAX_BATCHES)

        print(f'{idx}/Epochs {best_loss} in Update {num}/5 SCORE!!')
        print(f'===================== TRAIN ========================')
        print(f'Train_frac_loss : {frac_loss_mean}, Train_vert_loss : {vert_loss_mean}, Train_sum_loss : {sum_loss_mean}')
        print(f'===================== VALID ========================')
        print(f'Valid_frac_loss : {val_frac_loss_mean}, Valid_vert_loss : {val_vert_loss_mean}, Valid_sum_loss : {val_sum_loss_mean}')

        # Select the checkpoint on held-out loss; the training loss keeps
        # falling even once the model starts to overfit.
        if best_loss > val_sum_loss_mean:
            best_loss = val_sum_loss_mean
            print("============================")
            print(f"{idx}/Epochs {best_loss} in Update {num}/5 3D Model !!")
            print("============================")
            save_model(name, model)

    return best_loss

In [95]:
def gc_collect():
    """Run Python's garbage collector, then release PyTorch's cached CUDA memory."""
    # Collect first so tensors freed by the GC can be released by the cache flush.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

In [101]:
# ----- FOLD LINE
# Cross-validation driver: train one model per fold of the `split` column.
effnet_models = list()  # NOTE(review): never appended to in the visible cells
N_EPOCHS = 20  # same value as set in the earlier cells

for fold in range(5):
    # Free Python garbage and cached CUDA memory left over from the previous fold.
    gc_collect()    
    
    # Train on every other fold (with augmentation), validate on the held-out one (no transform).
    ds_train = RSNADataset(subset = 'train', df_table = train_df.query('split != @fold'), transform = augs)
    ds_eval = RSNADataset(subset = 'valid', df_table = train_df.query('split == @fold'))
    
    train_model(ds_train, ds_eval, fold)

In [113]:
# Scratch check: rebuild the training dataset with fold 3 held out.
ds_train = RSNADataset(subset = 'train', df_table = train_df.query('split != 3'), transform = augs)

In [114]:
# Scratch check: shuffling loader over `ds_train` with the global batch size.
dl_train = torch.utils.data.DataLoader(ds_train, batch_size = BATCH_SIZE, 
                                           shuffle = True)

In [115]:
# Scratch check: pull one batch to smoke-test the dataset and collation.
# NOTE(review): as the cell's last expression this displays the entire batch.
next(iter(dl_train))

In [None]:
# Scratch check: fetch the first sample straight from the dataset (no batching).
# The closing parenthesis was missing, which made this cell a SyntaxError.
temp = next(iter(ds_train))

In [None]:
# Size of the first element of the sample (the volume, when the dataset returns a labelled tuple).
temp[0].size()