In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import warnings 
warnings.filterwarnings('ignore')

from matplotlib.patches import Rectangle 
import os 
import re 
import random 
import matplotlib.pyplot as plt 
import plotly 
import plotly.graph_objects as go 
from plotly.subplots import make_subplots
import plotly.express as px 
from pydicom import dcmread 
from tqdm import tqdm 
import multiprocessing as mp 
import seaborn as sns 
import datetime

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory


# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# Root of the competition data; the training table is read from it.
PATH = "../input/rsna-2022-cervical-spine-fracture-detection"

train_dataframe = pd.read_csv(os.path.join(PATH, 'train.csv'))

In [None]:
# Render the training table (StudyInstanceUID plus the C1-C7 label columns used below).
train_dataframe

In [None]:
# Column dtypes and non-null counts.
train_dataframe.info()

In [None]:
# First rows of the training table.
train_dataframe.head()

In [None]:
# (number of rows, number of columns)
train_dataframe.shape

In [None]:
# nb_scans: number of files in each study's DICOM folder (reuses PATH instead of
# repeating the dataset path). nb_fractures: count of fractured vertebrae C1-C7.
train_dataframe['nb_scans'] = train_dataframe['StudyInstanceUID'].apply(
    lambda uid: len(os.listdir(os.path.join(PATH, 'train_images', uid))))
train_dataframe['nb_fractures'] = train_dataframe[[f'C{i}' for i in range(1, 8)]].sum(axis = 1)

In [None]:
# Box plot of slice counts, one point per row of train_dataframe.
fig = px.box(
    train_dataframe,
    y="nb_scans",
    points="all",
    title='Nb scans per patient',
    color_discrete_sequence=["goldenrod"],
)
fig.show()

In [None]:
# Number of patients with a fracture at each vertebra C1-C7; the most frequently
# fractured vertebra is highlighted.
# range's end is exclusive, so 8 is needed to include C7 (the original stopped at C6).
vertebrae = [f'C{i}' for i in range(1, 8)]
dict_c = {v: int((train_dataframe[v] == 1).sum()) for v in vertebrae}

x = list(dict_c.keys())
y = list(dict_c.values())

max_index = int(np.argmax(y))
colors = ['lightblue'] * len(x)
colors[max_index] = 'goldenrod'

fig = go.Figure(data = [go.Bar(x = x, y = y, marker_color = colors)])
fig.update_layout(title_text = 'Fractured vertebrae location counts')

In [None]:
# How many patients have exactly n fractured vertebrae (patients with none excluded).
fracture_counts = train_dataframe.loc[train_dataframe['nb_fractures'] != 0, 'nb_fractures'].value_counts()
dict_nb_fractures = dict(fracture_counts)

x = list(dict_nb_fractures.keys())
y = list(dict_nb_fractures.values())

# Size the colour list from the data (not a hard-coded 6) and highlight the
# most common group explicitly instead of relying on value_counts' ordering.
colors = ['lightblue'] * len(x)
if y:
    colors[int(np.argmax(y))] = 'goldenrod'

fig = go.Figure(data = [go.Bar(x = x, y = y, marker_color = colors)])

fig.update_layout(title_text = 'Fracture counts distribution')

In [None]:
# Distribution of slice counts, split by the exact number of fractured vertebrae.
# Groups come from the data, so a 7-fracture group is not silently dropped.
fracture_groups = sorted(train_dataframe.loc[train_dataframe["nb_fractures"] > 0, "nb_fractures"].unique())

fig = go.Figure()
for n in fracture_groups:
    scans = train_dataframe.loc[train_dataframe["nb_fractures"] == n, 'nb_scans']
    # Groups 3 and 5 keep the goldenrod highlight used by the original figure.
    color = 'goldenrod' if n in (3, 5) else 'lightblue'
    fig.add_trace(go.Box(y = scans, name = str(n), marker_color = color))

fig.update_layout(title_text = 'Scan counts per nb fracture', showlegend = False)
# Rows are filtered with ==, so the axis is the exact fracture count, not "at least".
fig.update_xaxes(title = 'Number of fractured vertebrae')
fig.update_yaxes(title = 'Nb scans')
fig.show()

# READ DICOM FILES

**What is a DICOM file?**

A DICOM file is an image saved in the Digital Imaging and Communications in Medicine (DICOM) format. It contains an image from a medical scan, such as a CT scan, an MRI or an ultrasound, together with header metadata. DICOM files may also include identification data for patients to link the image to a specific individual. 

In [None]:
# Read slice 1.dcm of one example study and display its full header.
ds = dcmread(os.path.join(PATH, 'train_images', '1.2.826.0.1.3680043.10001/1.dcm'))
ds

# EXTRACTING METADATA FROM THE DICOM FILES

In [None]:
def extract_file(file_path):
    """Read one DICOM file and flatten its header into a dict of strings.

    The dict starts with ``'image_id'``, the name of the file's parent folder
    (the study folder under ``train_images/`` in this dataset). It then holds
    ``str(tag) -> str(value)`` for every file-meta element, followed by every
    dataset element except Pixel Data (tag (7FE0,0010)).
    """
    ds = dcmread(file_path)
    observation_dict = {'image_id': file_path.split(sep="/")[-2]}

    meta = ds.file_meta
    for tag in list(meta._dict.keys()):
        observation_dict[str(tag)] = str(meta[tag].value)

    # Skip the pixel payload: only header fields are wanted here.
    pixel_data_tag = (0x7fe0, 0x0010)
    for tag in list(ds._dict.keys()):
        if tag == pixel_data_tag:
            continue
        observation_dict[str(tag)] = str(ds[tag].value)

    return observation_dict

In [None]:
# Flattened header of one slice: keys are str(tag), values are str(value).
extract_file('../input/rsna-2022-cervical-spine-fracture-detection/train_images/1.2.826.0.1.3680043.10001/1.dcm')

In [None]:
# Human-readable column names for the keys produced by extract_file:
# 'image_id' plus str(tag) strings of the form '(gggg, eeee)'.
mapper_dict = { 
    'image_id' : 'image_id', 
     '(0002, 0001)' : "File Meta Information Version", 
     '(0002, 0002)' : "Media Storage SOP Class UID", 
     '(0002, 0003)' : "Media Storage SOP Instance UID", 
     '(0002, 0010)' : "Transfer Syntax UID", 
     '(0002, 0012)' : "Implementation Class UID", 
     '(0002, 0013)' : "Implementation Version Name", 

     '(0008, 0018)' : "SOPInstanceUID", 
     '(0008, 0023)' : "Date of Creation", 
     '(0008, 0033)' : "Time of Creation", 

     '(0010, 0010)' : "Patient Name", 
     '(0010, 0020)' : "Patient ID", 
     
     '(0018, 0050)' : "Slice Thickness", 
     
     '(0020, 000d)' : "Study Instance UID", 
     '(0020, 000e)' : "Series Instance UID", 
     '(0020, 0013)' : "Instance Number", 
     '(0020, 0032)' : "Image Position (Patient)", 
     '(0020, 0037)' : "Image Orientation (Patient)", 

     '(0028, 0002)' : "Samples per Pixel", 
     '(0028, 0004)' : "Photometric Interpretation", 
     '(0028, 0010)' : "Rows", 
     '(0028, 0011)' : "Columns", 
     '(0028, 0030)' : "Pixel Spacing", 
     '(0028, 0100)' : "Bits Allocated", 
     '(0028, 0101)' : "Bits Stored", 
     '(0028, 0102)' : "High Bit", 
     '(0028, 0103)' : "Pixel Representation", 
     '(0028, 1050)' : "Window Center", 
     '(0028, 1051)' : "Window Width", 
     '(0028, 1052)' : "Rescale Intercept", 
     '(0028, 1053)' : "Rescale Slope"} 

In [None]:
def meta_information_one_folder(folder):
    """Extract DICOM header metadata for every file in ``os.path.join(PATH, folder)``.

    One row per file (built with ``extract_file``). Tag columns are renamed via
    ``mapper_dict``; tags missing from the mapper keep their raw tag string
    instead of becoming NaN column names. The table is written to
    ``dicom_metadata.csv`` in the working directory and returned.

    Args:
        folder: path relative to ``PATH``, e.g. ``'train_images/<StudyInstanceUID>'``.

    Returns:
        pandas.DataFrame whose cells are the strings produced by ``extract_file``
        (NaN where a file lacks a tag that another file has).
    """
    folder_path = os.path.join(PATH, folder)
    folder_filenames = os.listdir(folder_path)

    print(f'Extracting metadata from folder {folder}')
    # Build the frame once from a list of dicts. DataFrame.append copied the whole
    # frame on every row (quadratic) and was removed in pandas 2.0.
    records = [extract_file(os.path.join(folder_path, filename))
               for filename in tqdm(folder_filenames)]
    metadata = pd.DataFrame(records)

    # Index.map(dict) maps unknown keys to NaN; fall back to the raw tag instead.
    metadata.columns = metadata.columns.map(lambda column: mapper_dict.get(column, column))
    metadata.to_csv("dicom_metadata.csv", index = False)

    return metadata

In [None]:
# Per-slice metadata table for one study (also saved to dicom_metadata.csv).
metadata = meta_information_one_folder('train_images/1.2.826.0.1.3680043.10001')

In [None]:
# Print only the header fields whose value is not the same on every slice.
varying_columns = [column for column in metadata.columns
                   if len(set(metadata[column])) != 1]
for column in varying_columns:
    print(column)

# PREPROCESSING METADATA

In [None]:
def _dicom_datetime(date_str, time_str):
    """Combine a DICOM DA ('YYYYMMDD') and TM ('HHMMSS[.ffffff]') string.

    Returns 'YYYY-MM-DD HH:MM:SS', or None when either part cannot be parsed.
    """
    try:
        date_part = datetime.datetime.strptime(str(date_str).strip(), '%Y%m%d').date()
        # TM may omit trailing components (HH / HHMM) and may carry a fraction;
        # legacy values can contain ':' separators.
        hhmmss = str(time_str).strip().replace(':', '').split('.')[0].ljust(6, '0')
        time_part = datetime.datetime.strptime(hhmmss, '%H%M%S').time()
    except ValueError:
        return None
    return datetime.datetime.combine(date_part, time_part).strftime('%Y-%m-%d %H:%M:%S')


# 'Time of Creation' is a DICOM TM string (HHMMSS.ffffff), not a Unix timestamp,
# so it is parsed together with 'Date of Creation' (no eval, no fromtimestamp).
# Writing to a new column keeps the raw date and makes the cell safe to re-run.
creation_dates = (metadata['Date of Creation'] if 'Date of Creation' in metadata.columns
                  else [None] * len(metadata))
metadata['Creation Datetime'] = [
    _dicom_datetime(d, t) for d, t in zip(creation_dates, metadata['Time of Creation'])
]

display_columns = ['Media Storage SOP Instance UID', 'SOPInstanceUID', 'Time of Creation',
                   'Date of Creation', 'Creation Datetime', 'Instance Number', 'Image Position (Patient)']
metadata[[c for c in display_columns if c in metadata.columns]].head()

# SOME OBSERVATIONS

In [None]:
def get_random_files_from_patient(path, k=9):
    """Return up to ``k`` distinct file names sampled at random from directory ``path``.

    ``random.sample`` raises ValueError when asked for more items than exist,
    so the sample size is capped at the number of files in the folder.
    """
    filenames = os.listdir(path)
    return random.sample(filenames, min(k, len(filenames)))

def rescale_image(dicom_file):
    """Flatten a DICOM slice and apply its rescale slope/intercept.

    Returns ``(raw, rescaled)``, where ``raw`` is the flattened ``pixel_array``
    and ``rescaled = raw * RescaleSlope + RescaleIntercept`` (plotted as HU below).
    """
    raw = dicom_file.pixel_array.flatten()
    slope, intercept = dicom_file.RescaleSlope, dicom_file.RescaleIntercept
    return raw, raw * slope + intercept


def display_images(files_list, graph_indexes = np.arange(9),
                   folder = '../input/rsna-2022-cervical-spine-fracture-detection/train_images/1.2.826.0.1.3680043.10001'):
    """Show each DICOM slice next to its raw-pixel and rescaled (HU) histograms.

    One row per file: column 0 is the image, column 1 the distribution of raw
    ``pixel_array`` values, column 2 the distribution after ``rescale_image``.
    The original drew 9 files onto 3 rows (``idx // 3``), so images overwrote
    each other and histograms piled up.

    Args:
        files_list: file names inside ``folder``.
        graph_indexes: unused; kept for backward compatibility.
        folder: directory the file names are relative to (defaults to the
            example study used elsewhere in this notebook).
    """
    n_rows = max(len(files_list), 1)
    # squeeze=False keeps axs 2-D even when a single file is shown.
    fig, axs = plt.subplots(n_rows, 3, figsize=(20, 4 * n_rows), squeeze=False)
    for row, file in enumerate(files_list):
        ds = dcmread(os.path.join(folder, file))

        axs[row, 0].imshow(ds.pixel_array, cmap=plt.get_cmap('gray'))
        axs[row, 0].axis("off")

        image, rescaled_image = rescale_image(ds)

        # sns.distplot is deprecated; histplot with a KDE overlay replaces it.
        sns.histplot(image, kde=True, stat="density", ax=axs[row, 1])
        sns.histplot(rescaled_image, kde=True, stat="density", ax=axs[row, 2])
        axs[row, 1].set_title("Raw pixel array distributions")
        axs[row, 2].set_title("HU unit distributions")

    fig.tight_layout()
    plt.show()

In [None]:
# Pick random slices from the example study and plot them with their pixel/HU histograms.
files_to_display = get_random_files_from_patient('../input/rsna-2022-cervical-spine-fracture-detection/train_images/1.2.826.0.1.3680043.10001')
display_images(files_to_display)

# BOUNDING BOXES 

In [None]:
# Bounding-box annotation table; preview its first rows.
bb_csv_path = '../input/rsna-2022-cervical-spine-fracture-detection/train_bounding_boxes.csv'
bb_train = pd.read_csv(bb_csv_path)
bb_train.head()

In [None]:
# (number of bounding-box rows, number of columns)
bb_train.shape

In [None]:
# Missing values per column.
bb_train.isnull().sum()

In [None]:
!pip install monai


In [None]:
!pip install -qU "python_gdcm" pydicom pylibjpeg

In [None]:
import os
import gc
from monai.transforms import LoadImaged, EnsureChannelFirstd, ResampleToMatchd, Orientationd, Compose
import monai
import numpy as np
from tqdm import tqdm
import multiprocessing
from ipywidgets import interactive, widgets, fixed
from matplotlib import animation, rc; rc('animation', html='jshtml')
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Study folders of DICOM slices, and the matching <study>.nii segmentation masks.
image_dir = "../input/rsna-2022-cervical-spine-fracture-detection/train_images/"
mask_dir = "../input/rsna-2022-cervical-spine-fracture-detection/segmentations/"
image_list, mask_list = os.listdir(image_dir), os.listdir(mask_dir)

In [None]:
# Load the DICOM series ("image") and NIfTI mask ("seg"), move the channel axis
# first, and reorient both to RAS so image and mask voxels line up.
_volume_keys = ("image", "seg")
transform = Compose([
    LoadImaged(reader=("PydicomReader", "nibabelreader"), keys=_volume_keys),
    EnsureChannelFirstd(keys=_volume_keys),
    Orientationd(keys=_volume_keys, axcodes="RAS"),
])

In [None]:
def create_animation(img, seg, seg_rev=False, fps=10):
    """Animate each slice of ``img`` side by side with the matching ``seg`` slice.

    Args:
        img: stack of image slices indexed along the first axis.
        seg: mask slices; each must match the image slice's height, because
            the pair is concatenated along axis 1.
        seg_rev: reverse the mask slice order before pairing.
        fps: frames per second; the frame interval is ``1000 // fps`` ms.

    Returns:
        matplotlib.animation.FuncAnimation over all slices.
    """
    masks = seg[::-1] if seg_rev else seg
    frames = np.stack(
        [np.concatenate([img[k], masks[k]], axis=1) for k in range(len(img))],
        axis=0,
    )
    gc.collect()

    fig = plt.figure(figsize=(5,5), dpi=160)
    im = plt.imshow(frames[0], cmap='bone')
    plt.axis('off')

    def animate_func(i):
        im.set_array(frames[i])
        return [im]

    # Close the static figure so only the animation is rendered in the notebook.
    plt.close()

    return animation.FuncAnimation(fig, animate_func, frames = frames.shape[0], interval = 1000//fps)

In [None]:
def load_volume_and_mask(study_uid):
    """Load a study's CT volume and segmentation and prepare them for ``create_animation``.

    Both go through ``transform``. The channel axis is dropped and the three
    spatial axes are reversed, so slicing along the first axis walks through
    the last axis of the RAS-oriented volume. The image is min-max scaled to
    uint8 0-255, and the mask is binarised (any label > 0 becomes 255).
    """
    data = {"image": os.path.join(image_dir, study_uid),
            "seg": os.path.join(mask_dir, f"{study_uid}.nii")}
    output = transform(data)

    img = output["image"].numpy().transpose([0, 3, 2, 1])[0]
    seg = output["seg"].numpy().transpose([0, 3, 2, 1])[0]

    img = (img-np.min(img))/(np.max(img)-np.min(img)+1e-6)
    img = (img*255).astype(np.uint8)
    seg = np.where(seg>0, 255, 0).astype(np.uint8)
    return img, seg


# The helper replaces the copy-pasted preprocessing for each example study.
img_sample = "1.2.826.0.1.3680043.25704"
mask_sample = f"{img_sample}.nii"
img, seg = load_volume_and_mask(img_sample)

create_animation(img, seg, fps=30)

In [None]:
# Second example study, preprocessed the same way: min-max to uint8 image, binary mask.
img_sample = "1.2.826.0.1.3680043.1363"
mask_sample = f"{img_sample}.nii"
data = dict(image=os.path.join(image_dir, img_sample), seg=os.path.join(mask_dir, mask_sample))
output = transform(data)

# Drop the channel axis and reverse the spatial axes (slices first).
img, seg = (output[key].numpy().transpose([0, 3, 2, 1])[0] for key in ("image", "seg"))

lo, hi = np.min(img), np.max(img)
img = ((img - lo) / (hi - lo + 1e-6) * 255).astype(np.uint8)
seg = ((seg > 0) * 255).astype(np.uint8)

create_animation(img, seg, fps=30)

In [None]:
!pip install -q monai
!pip install -q git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git

In [None]:
import os 
import re 
import gc 
import cv2 
import wandb 
from PIL import Image 
import random 
import math 
import shutil 
from glob import glob 
from tqdm import tqdm 
from pprint import pprint 
from time import time 
import warnings 
import pandas as pd 
import numpy as np 
import seaborn as sns 
import matplotlib as mpl 
import matplotlib.patches as patches 
import matplotlib.pyplot as plt 
import matplotlib.image as mpimg 
from matplotlib.offsetbox import AnnotationBbox, OffsetImage 
from matplotlib.colors import ListedColormap, LinearSegmentedColormap 
from matplotlib.patches import Rectangle 
from IPython.display import display_html 
plt.rcParams.update({'font.size': 16})

warnings.filterwarnings("ignore")
# Silence wandb's console output inside the notebook.
os.environ["WANDB_SILENT"] = "true"
CONFIG = {'competition': 'RSNA_SpineFructure', '_wandb_kernel': 'aot'}


class clr:
    """ANSI escape sequences: ``S`` switches to bold bright-blue text, ``E`` resets."""
    S = '\033[1m' + '\033[94m'
    E = '\033[0m'


# Blue-to-red palette shared by this section's plots.
my_colors = [
    "#5EAFD9", "#449DD1", "#3977BB", "#2D51A5", "#5C4C8F",
    "#8B4679", "#C53D4C", "#E23836", "#FF4633", "#FF5746",
]
CMAP1 = ListedColormap(my_colors)

print(clr.S + "Notebook Color Schemes:" + clr.E)
sns.palplot(sns.color_palette(my_colors))
plt.show()

In [None]:
import torch 
from torch.utils.data import TensorDataset, DataLoader, Dataset 
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim 
from torch.optim import lr_scheduler 
from torch.utils.data.sampler import SubsetRandomSampler, RandomSampler, SequentialSampler 
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, CosineAnnealingLR 
import torchvision 
import torchvision.transforms as transforms 
from warmup_scheduler import GradualWarmupScheduler 
import albumentations 

from sklearn.model_selection import GroupKFold, train_test_split, StratifiedKFold 
from sklearn.metrics import roc_auc_score, cohen_kappa_score, confusion_matrix 

from monai.transforms import Randomizable, apply_transform 
from monai.transforms import Compose, Resize, ScaleIntensity, ToTensor, RandAffine 
from monai.networks.nets import densenet 

In [None]:
# Authenticate wandb with the API key stored as the Kaggle secret "wandb".
# wandb.login(key=...) keeps the key out of a shell command line, where
# `!wandb login $secret` could expose it (process list, shell history).
from kaggle_secrets import UserSecretsClient 
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb")

wandb.login(key=secret_value_0)

In [None]:
def set_seed(seed = 0):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Seeds ``random``, NumPy's global generator and PyTorch (CPU and all CUDA
    devices). The cuDNN autotuner is also turned off: with ``benchmark=True``
    cuDNN may pick different convolution algorithms between runs even when
    ``deterministic`` is set.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
def show_values_on_bars(axs, h_v = "v", space = 0.4):
    """Annotate every bar of a bar chart with its (integer, comma-grouped) value.

    Args:
        axs: a single Axes or a NumPy array of Axes (e.g. from plt.subplots).
        h_v: "v" for vertical bars (label centred above the bar) or "h" for
            horizontal bars (label to the right of the bar).
        space: horizontal gap between a horizontal bar's end and its label.

    Raises:
        ValueError: if ``h_v`` is neither "v" nor "h".
    """
    if h_v not in ("v", "h"):
        raise ValueError(f"h_v must be 'v' or 'h', got {h_v!r}")

    def _show_on_single_plot(ax):
        if h_v == "v":
            for p in ax.patches:
                _x = p.get_x() + p.get_width() / 2
                _y = p.get_y() + p.get_height()
                value = int(p.get_height())
                ax.text(_x, _y, format(value, ','), ha = "center")
        else:
            # Horizontal bars: the value is the bar's width, and the label sits
            # at the vertical centre of the bar (the original used get_x()).
            for p in ax.patches:
                _x = p.get_x() + p.get_width() + float(space)
                _y = p.get_y() + p.get_height() / 2
                value = int(p.get_width())
                ax.text(_x, _y, format(value, ','), ha = "left", va = "center")

    # The original attached `else` to the for-loop (for-else), so a single Axes
    # was never annotated and an array of Axes crashed after the loop.
    if isinstance(axs, np.ndarray):
        for ax in axs.flat:
            _show_on_single_plot(ax)
    else:
        _show_on_single_plot(axs)
            
def atoi(text):
    """Return ``int(text)`` when ``text`` is all digits, otherwise ``text`` unchanged."""
    if text.isdigit():
        return int(text)
    return text

def natural_keys(text):
    """Sort key that orders embedded numbers numerically ("2.png" before "10.png").

    The string is split into alternating non-digit / digit chunks, and the
    digit chunks are converted to ``int`` so they compare by value.
    """
    chunks = re.split(r'(\d+)', text)
    return list(map(atoi, chunks))

def save_dataset_artifact(run_name, artifact_name, path, data_type = "dataset"):
    """Upload one file to W&B as an artifact inside a short-lived run.

    Opens a run named ``run_name`` in the 'RSNA_SpineFructure' project (with
    ``CONFIG``), adds ``path`` to an artifact ``artifact_name`` of type
    ``data_type``, logs it, finishes the run and prints a confirmation.
    """
    wandb.init(project = 'RSNA_SpineFructure', name = run_name, config = CONFIG)

    artifact = wandb.Artifact(name = artifact_name, type = data_type)
    artifact.add_file(path)
    wandb.log_artifact(artifact)

    wandb.finish()
    print("Artifact has been saved successfully.")
    
def create_wandb_plot(x_data = None, y_data = None, x_name = None, y_name = None, title = None, log = None, plot = "line"):
    """Log a W&B line, bar or scatter chart built from paired ``x_data`` / ``y_data``.

    Args:
        x_data, y_data: equal-length iterables zipped into table rows.
        x_name, y_name: column names, also used as the chart axes.
        title: chart title.
        log: key under which the chart is logged.
        plot: "line", "bar" or "scatter".

    Raises:
        ValueError: for any other ``plot`` value. The original silently logged
            nothing in that case.
    """
    if plot not in ("line", "bar", "scatter"):
        raise ValueError(f"Unsupported plot type {plot!r}; expected 'line', 'bar' or 'scatter'")

    data = [[label, val] for (label, val) in zip(x_data, y_data)]
    table = wandb.Table(data = data, columns = [x_name, y_name])
    # wandb.plot.line / .bar / .scatter all take (table, x, y, title=...).
    chart = getattr(wandb.plot, plot)(table, x_name, y_name, title = title)
    wandb.log({log : chart})

        
def create_wandb_hist(x_data = None, x_name = None, title = None, log = None):
    """Log a W&B histogram of ``x_data`` (single column ``x_name``) under key ``log``."""
    table = wandb.Table(data = [[value] for value in x_data], columns = [x_name])
    histogram = wandb.plot.histogram(table, x_name, title = title)
    wandb.log({log : histogram})

In [None]:
# Seed random / NumPy / torch before the data split and training below.
set_seed(0)

In [None]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(clr.S+"Device:"+clr.E, DEVICE)

DF_SIZE = 0.03  # fraction of unique studies sampled for this experiment
N_SPLITS = 5  # GroupKFold folds, grouped by StudyInstanceUID
KERNEL_TYPE = 'densenet121_baseline'  # prefix of the log / checkpoint file names
IMG_RESIZE = 100  # in-plane size passed to Resize
STACK_RESIZE = 50  # slice-axis size passed to Resize
use_amp = False  # NOTE(review): not referenced by the visible training code
NUM_WORKERS = 1  # NOTE(review): not passed to the DataLoaders in run_train
BATCH_SIZE = 2 
LR = 0.05  # Adam base lr; the warm-up scheduler in run_train multiplies it
OUT_DIM = 8  # one output per entry of target_cols
EPOCHS = 2 

In [None]:
# Prediction targets: seven per-vertebra labels plus the patient-level label (OUT_DIM = 8).
target_cols = ['C1', 'C2', 'C3', 
               'C4', 'C5', 'C6', 'C7', 
               'patient_overall']

In [None]:
# Per-column loss weights, ordered like target_cols: '-' applies where the target
# is 0, '+' where it is 1. The last column (patient_overall) weighs 7x a vertebra.
competition_weights = { 
     '-' : torch.tensor([1,1,1,1,1,1,1,7], dtype = torch.float, device = DEVICE), 
     '+' : torch.tensor([2,2,2,2,2,2,2,14], dtype = torch.float, device = DEVICE)}

In [None]:
# Toy batch of two exams. Despite the name, `logits` holds probabilities in (0, 1):
# the loss below feeds them directly to log(p) and log(1 - p).
logits = torch.tensor([[0.2221, 0.1037, 0.0739, 0.1112, 0.1026, 0.0902, 0.1597, 0.1365], 
                       [0.1702, 0.0952, 0.0815, 0.1262, 0.1185, 0.1097, 0.1675, 0.1312]], device = DEVICE)
print(clr.S+"Prediction:"+clr.E, "\n", logits)

# First exam: no fracture at all; second: C1 fractured (patient_overall = 1).
targets = torch.tensor([[0., 0., 0. , 0. , 0. , 0. , 0. , 0.], 
                        [1. , 0. , 0. , 0. , 0. , 0. , 0. , 1.]], device = DEVICE)
print(clr.S+"Target:"+clr.E, "\n", targets)

In [None]:
# Per-element weight: the '+' weight where the target is 1, the '-' weight where it
# is 0. The original added (1 - targets) and the '-' weights instead of
# multiplying them, inflating every weight.
weights = targets * competition_weights['+'] + (1 - targets) * competition_weights['-']
print(clr.S+"Weights:"+clr.E, "\n", weights)

In [None]:
# Element-wise weighted binary cross-entropy, computed one entry at a time.
L = torch.zeros(targets.shape, device=DEVICE)

w, y, p = weights, targets, logits

n_rows, n_cols = L.shape
for flat in range(n_rows * n_cols):
    i, j = divmod(flat, n_cols)
    positive_term = y[i, j] * math.log(p[i, j])
    negative_term = (1 - y[i, j]) * math.log(1 - p[i, j])
    L[i, j] = -w[i, j] * (positive_term + negative_term)

print(clr.S+"LOSSES:"+clr.E, "\n", L)

In [None]:
# Per-exam loss: weighted sum of element losses divided by the sum of the weights.
Exams_Loss = torch.div(torch.sum(L, dim = 1), torch.sum(w, dim = 1))
print(clr.S+"Exam Losses:"+clr.E, "\n", Exams_Loss)

In [None]:
def get_custom_loss(logits, targets, weights_dict = None):
    """Weighted binary log loss averaged per exam (row), as in the demo above.

    Args:
        logits: (batch, 8) tensor of predicted *probabilities*. Despite the
            name, values go straight into log(p) / log(1 - p), so they must
            already lie in [0, 1].
        targets: 0/1 tensor with the same shape.
        weights_dict: optional mapping with '+' (target == 1) and '-'
            (target == 0) per-column weight tensors. Defaults to the
            notebook-level ``competition_weights``.

    Returns:
        1-D tensor, one value per row: sum_j w_ij * BCE_ij / sum_j w_ij.
        Like the original element-wise loop, the result carries no autograd
        history.
    """
    if weights_dict is None:
        weights_dict = competition_weights
    eps = 1e-8

    with torch.no_grad():
        w = targets * weights_dict['+'] + (1 - targets) * weights_dict['-']
        y, p = targets, logits
        # Vectorised form of the former per-element Python loop with math.log.
        L = -w * (y * torch.log(p + eps) + (1 - y) * torch.log(1 - p + eps))
        return torch.div(torch.sum(L, dim = 1), torch.sum(w, dim = 1))

In [None]:
# Seed both generators: the study subsample below is drawn with Python's
# `random`, which np.random.seed alone does not affect.
np.random.seed(0)
random.seed(0)

df = pd.read_csv("../input/rsna-2022-cervical-spine-fracture-detection/train.csv")

# Keep a random DF_SIZE fraction of studies for a quick experiment.
instances = df.StudyInstanceUID.unique().tolist()
instances = random.sample(instances, k=int(len(instances)*DF_SIZE))
df = df[df["StudyInstanceUID"].isin(instances)].reset_index(drop=True)
print(clr.S+"Dataframe size:"+clr.E, df.shape)

# Assign a validation fold to every row; grouping by StudyInstanceUID keeps all
# rows of a study in the same fold.
kfold = GroupKFold(n_splits=N_SPLITS)
df['fold'] = -1

for k, (_, valid_i) in enumerate(kfold.split(df, groups=df.StudyInstanceUID)):
    df.loc[valid_i, 'fold'] = k

print(clr.S+"K Folds Count:"+clr.E)
df["fold"].value_counts()

In [None]:
class RSNADataset(Dataset, Randomizable):
    """Per-study dataset of stacked PNG slices for the RSNA cervical-spine task.

    Each item loads every file under ``<image_dir>/<StudyInstanceUID>/`` in
    natural (numeric) file-name order, converts BGR to RGB, stacks the slices
    into a float32 array of shape (3, H, W, n_slices) and applies ``transform``.

    Args:
        csv: DataFrame with a ``StudyInstanceUID`` column (and ``target_cols``
            unless ``mode == "test"``).
        mode: "test" returns only the image; any other value also returns targets.
        transform: optional transform run through MONAI's ``apply_transform``;
            randomizable transforms are re-seeded per item.
        image_dir: root folder holding one sub-folder per study.
    """

    def __init__(self, csv, mode, transform = None,
                 image_dir = "../input/rsna-fracture-detection/zip_png_images"):
        self.csv = csv
        self.mode = mode
        self.transform = transform
        self.image_dir = image_dir

    def __len__(self):
        return self.csv.shape[0]

    def randomize(self) -> None:
        """Draw a fresh uint32 seed from ``self.R`` for this item's transforms."""
        MAX_SEED = np.iinfo(np.uint32).max + 1
        self.seed = self.R.randint(MAX_SEED, dtype = "uint32")

    def __getitem__(self, index):
        self.randomize()

        dt = self.csv.iloc[index, :]
        study_dir = os.path.join(self.image_dir, str(dt.StudyInstanceUID))
        study_paths = glob(os.path.join(study_dir, "*"))
        study_paths.sort(key = natural_keys)
        if not study_paths:
            # np.stack would otherwise fail with an unhelpful message.
            raise FileNotFoundError(f"No slice images found in {study_dir}")

        study_images = []
        for path in study_paths:
            bgr = cv2.imread(path)
            if bgr is None:
                raise IOError(f"Could not read image {path}")
            study_images.append(bgr[:, :, ::-1])  # BGR -> RGB

        # (H, W, 3) slices -> (H, W, n_slices, 3) -> (3, H, W, n_slices)
        stacked_image = np.stack([img.astype(np.float32) for img in study_images], axis = 2).transpose(3,0,1,2)

        if self.transform:
            if isinstance(self.transform, Randomizable):
                self.transform.set_random_state(seed = self.seed)
            stacked_image = apply_transform(self.transform, stacked_image)

        if self.mode == "test":
            return {"image" : stacked_image}

        # A row selected with iloc has object dtype (mixed columns); convert the
        # labels explicitly to float32 before building the tensor.
        targets = torch.tensor(dt[target_cols].to_numpy(dtype = np.float32)).float()
        return {"image" : stacked_image, "targets" : targets}

In [None]:
def data_to_device(data, device = None):
    """Move a batch dict produced by ``RSNADataset`` / ``DataLoader`` to ``device``.

    Keys are looked up by name instead of relying on the dict's value order.

    Args:
        data: dict with an "image" tensor and, for train/valid batches, "targets".
        device: target device; defaults to the notebook-level ``DEVICE``.

    Returns:
        ``(image, targets)``; ``targets`` is None for test-mode batches.
    """
    if device is None:
        device = DEVICE
    image = data["image"].to(device)
    targets = data.get("targets")
    if targets is not None:
        targets = targets.to(device)
    return image, targets

In [None]:
def _make_volume_transforms():
    """ScaleIntensity -> Resize to (IMG_RESIZE, IMG_RESIZE, STACK_RESIZE) -> ToTensor."""
    return Compose([ScaleIntensity(), Resize((IMG_RESIZE, IMG_RESIZE, STACK_RESIZE)), ToTensor()])

# Train and validation use identical, deterministic pipelines, kept as separate
# objects so either can be changed independently.
train_transforms = _make_volume_transforms()
valid_transforms = _make_volume_transforms()

In [None]:
# Smoke test: push six studies through the dataset and a DataLoader and print
# each batch's image shape and targets.
sample_df = df.head(6)

dataset = RSNADataset(csv = sample_df, mode = "train", transform = train_transforms)
dataloader = DataLoader(dataset, batch_size = 3, shuffle = False)

for k, data in enumerate(dataloader): 
    image, targets = data_to_device(data)
    print(clr.S + f"Batch: {k}" + clr.E, "\n" + 
          clr.S + "Image:" + clr.E, image.shape, "\n" + 
          clr.S + "Targets:" + clr.E, targets, "\n"+ 
          "="*50)

In [None]:
# Release the smoke-test objects before training.
del dataset, dataloader, image, targets 
gc.collect()

In [None]:
CRITERION = nn.BCEWithLogitsLoss(reduction = 'none')

def get_criterion(logits, target):
    """Element-wise BCE-with-logits over the flattened batch (no reduction)."""
    flat_logits, flat_target = logits.view(-1), target.view(-1)
    return CRITERION(flat_logits, flat_target)

In [None]:
class GradualWarmupSchedulerV2(GradualWarmupScheduler):
    """GradualWarmupScheduler with a linear warm-up over ``total_epoch`` epochs.

    During warm-up (``last_epoch <= total_epoch``) the lr ramps from ``base_lr``
    up to ``base_lr * multiplier`` (from 0 up to ``base_lr`` when
    ``multiplier == 1``). Afterwards the lr comes from ``after_scheduler``,
    re-based once on ``base_lr * multiplier``. Without one, it stays at
    ``base_lr * multiplier``.
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        super().__init__(optimizer, multiplier, total_epoch, after_scheduler)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if not self.after_scheduler:
                return [base_lr * self.multiplier for base_lr in self.base_lrs]
            if not self.finished:
                # Hand over once: the follow-up schedule starts from the warmed-up lr.
                self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                self.finished = True
            return self.after_scheduler.get_lr()

        if self.multiplier != 1.0:
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.)
                    for base_lr in self.base_lrs]
        return [base_lr * (float(self.last_epoch) / self.total_epoch)
                for base_lr in self.base_lrs]


In [None]:
def add_in_file(text, f):
    """Append ``text`` and a newline to a log file.

    Args:
        text: message to log (anything ``print`` accepts).
        f: a path (``str`` / ``os.PathLike``) of the log file to append to.
            Any other value, e.g. the open handle that callers pass, falls back
            to ``log_{KERNEL_TYPE}.txt`` in the working directory, which is what
            the original always used (it shadowed ``f`` and ignored it).
    """
    log_path = f if isinstance(f, (str, os.PathLike)) else f'log_{KERNEL_TYPE}.txt'
    # Reopen per call so each line is on disk immediately, even if the caller's
    # own handle is never flushed.
    with open(log_path, 'a+') as log_file:
        print(text, file = log_file)

In [None]:
def train_epoch(model, dataloader, optimizer, epoch, f):
    """Run one training epoch, log the mean losses and return the mean BCE loss.

    Args:
        model: network with one output per entry of ``target_cols``.
        dataloader: yields batch dicts understood by ``data_to_device``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index, used as the W&B ``step``.
        f: forwarded to ``add_in_file``.

    Returns:
        Mean of the element-wise ``get_criterion`` loss over every sample and
        target seen in the epoch (NaN for an empty loader).
    """
    print("Training...")
    add_in_file('Training...', f)
    start_time = time()

    model.train()
    train_losses, train_comp_losses = [], []

    for data in tqdm(dataloader):
        image, targets = data_to_device(data)

        optimizer.zero_grad()
        logits = model(image)
        loss = get_criterion(logits, targets)
        loss.sum().backward()
        optimizer.step()

        # NOTE(review): run_train's model head ends in Softmax, so these outputs
        # are probabilities; get_custom_loss expects that, but BCEWithLogitsLoss
        # in get_criterion applies another sigmoid. Confirm intent.
        comp_loss = get_custom_loss(logits, targets)

        train_losses.append(loss.detach().cpu().numpy())
        train_comp_losses.append(comp_loss.detach().cpu().numpy().mean())

        gc.collect()

    # Per-batch loss arrays differ in length when the last batch is smaller;
    # np.mean on that ragged list fails, so concatenate first. This also weights
    # every element equally.
    mean_train_loss = np.mean(np.concatenate(train_losses)) if train_losses else float('nan')
    mean_comp_loss = np.mean(train_comp_losses) if train_comp_losses else float('nan')

    total_time = round((time() - start_time)/60, 3)
    add_in_file('Train Mean Loss: {}'.format(mean_train_loss), f)
    add_in_file('Train Mean Comp Loss: {}'.format(mean_comp_loss), f)
    add_in_file('~~~ Train Time: {} mins ~~~'.format(total_time), f)

    wandb.log({"train_loss": mean_train_loss,
               "train_comp_loss": mean_comp_loss,}, step=epoch)

    print(clr.S+"Train Mean Loss:"+clr.E, mean_train_loss)
    print(clr.S+"Train Mean Comp Loss:"+clr.E, mean_comp_loss)
    print(clr.S+f"~~~ Train Time: {total_time} mins ~~~"+clr.E)

    return mean_train_loss

In [None]:
def valid_epoch(model, dataloader, epoch, f):
    """Evaluate ``model`` on ``dataloader``, log metrics and return the mean BCE loss.

    Computes the element-wise ``get_criterion`` loss over all collected outputs,
    the mean of the per-batch ``get_custom_loss`` values, and a ROC AUC over all
    target columns flattened together. Metrics go to the log file (via
    ``add_in_file``), to W&B (``step=epoch``) and to stdout.
    """
    print("Validation...")
    add_in_file('Validation...', f)
    start_time = time()

    model.eval()
    valid_preds, valid_targets, valid_comp_loss = [], [], []

    with torch.no_grad():
        for data in dataloader:
            image, targets = data_to_device(data)
            logits = model(image)

            comp_loss = get_custom_loss(logits, targets)

            valid_targets.append(targets.detach().cpu())
            valid_preds.append(logits.detach().cpu())
            valid_comp_loss.append(comp_loss.detach().cpu().numpy().mean())

            gc.collect()

    # Concatenate once and reuse for both the loss and the AUC.
    all_preds = torch.cat(valid_preds)
    all_targets = torch.cat(valid_targets)

    valid_losses = get_criterion(all_preds, all_targets).numpy()
    mean_valid_loss = np.mean(valid_losses)
    mean_comp_valid_loss = np.mean(valid_comp_loss)

    PREDS = all_preds.numpy().ravel()
    TARGETS = all_targets.numpy().ravel()
    try:
        auc = roc_auc_score(TARGETS, PREDS)
    except ValueError:
        # roc_auc_score raises when the targets hold a single class (possible on
        # tiny folds); report NaN instead of aborting the run.
        auc = float('nan')

    total_time = round((time() - start_time)/60, 3)
    add_in_file('Valid Mean Loss: {}'.format(mean_valid_loss), f)
    add_in_file('Valid Mean Comp Loss: {}'.format(mean_comp_valid_loss), f)
    add_in_file('Valid AUC: {}'.format(auc), f)
    add_in_file('~~~ Valid Time: {} mins ~~~'.format(total_time), f)

    wandb.log({"valid_loss": mean_valid_loss,
               "valid_comp_loss": mean_comp_valid_loss,
               "valid_auc": auc}, step=epoch)

    print(clr.S+"Valid Mean Loss:"+clr.E, mean_valid_loss)
    print(clr.S+"Valid Mean Comp Loss:"+clr.E, mean_comp_valid_loss)
    print(clr.S+"Valid AUC:"+clr.E, auc)
    print(clr.S+f"~~~ Validation Time: {total_time} mins ~~~"+clr.E)

    return mean_valid_loss

In [None]:
def run_train(fold):
    """Train and validate a 3-D DenseNet-121 on one GroupKFold split.

    Rows with ``df.fold == fold`` form the validation set and the rest train.
    Metrics go to W&B and to ``log_{KERNEL_TYPE}.txt``. The weights with the
    lowest mean validation loss are saved to ``{KERNEL_TYPE}_best_fold{fold}.pth``.
    """
    RUN_CONFIG = CONFIG.copy()
    params = dict(model="densenet121",
                  epochs=EPOCHS,
                  split=N_SPLITS,
                  batch=BATCH_SIZE, lr=LR,
                  img_size=IMG_RESIZE, stack_size=STACK_RESIZE,
                  data_size=DF_SIZE)
    RUN_CONFIG.update(params)
    # Log RUN_CONFIG; the original built it but passed the bare CONFIG, so the
    # hyper-parameters never reached W&B.
    wandb.init(project='RSNA_SpineFructure', config=RUN_CONFIG)

    train = df[df["fold"] != fold].reset_index(drop=True)
    valid = df[df["fold"] == fold].reset_index(drop=True)

    train_dataset = RSNADataset(csv=train, mode="train",
                                transform=train_transforms)
    valid_dataset = RSNADataset(csv=valid, mode="train",
                                transform=valid_transforms)

    trainloader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                             sampler=RandomSampler(train_dataset))
    validloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE)

    model = densenet.densenet121(spatial_dims=3, in_channels=3,
                                 out_channels=OUT_DIM)
    # NOTE(review): the head ends in Softmax (labels compete with each other)
    # while get_criterion uses BCEWithLogitsLoss, which applies its own sigmoid.
    # get_custom_loss needs probabilities, so the head is left unchanged. Confirm intent.
    model.class_layers.out = nn.Sequential(nn.Linear(in_features=1024, out_features=OUT_DIM),
                                           nn.Softmax(dim=1))
    model.to(DEVICE)
    wandb.watch(model, log_freq=100)

    optimizer = optim.Adam(model.parameters(), lr=LR)
    scheduler_cosine = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 2)
    scheduler_warmup = GradualWarmupSchedulerV2(optimizer, multiplier=10,
                                                total_epoch=1,
                                                after_scheduler=scheduler_cosine)

    # inf instead of a 1000 sentinel: the first finite validation loss is always saved.
    valid_loss_BEST = float('inf')
    model_file = f'{KERNEL_TYPE}_best_fold{fold}.pth'

    # The context manager closes the log handle even if training raises
    # (the original opened it and never closed it).
    with open(f'log_{KERNEL_TYPE}.txt', 'a') as f:
        for epoch in range(EPOCHS):
            add_in_file('======== Epoch: {}/{} ========'.format(epoch+1, EPOCHS), f)
            print("="*8, clr.S+f"Epoch {epoch}"+clr.E, "="*8)

            # Explicit epoch argument kept from the original warm-up recipe.
            scheduler_warmup.step(epoch-1)

            mean_train_loss = train_epoch(model, trainloader, optimizer, epoch, f)
            mean_valid_loss = valid_epoch(model, validloader, epoch, f)

            if mean_valid_loss < valid_loss_BEST:
                print('Saving model ...')
                add_in_file('Saving model => {}'.format(model_file), f)
                torch.save(model.state_dict(), model_file)
                valid_loss_BEST = mean_valid_loss

    torch.cuda.empty_cache()
    gc.collect()

    wandb.finish()

In [None]:
# Train and validate on fold 0 only.
run_train(fold=0)

In [None]:
# Print the training log stored in the input dataset. Note this is a copy under
# ../input, not the log_<KERNEL_TYPE>.txt that run_train writes to the working dir.
# `with` closes the file even if reading fails.
with open('../input/rsna-fracture-detection/log_densenet121_baseline.txt', 'r') as f:
    print(f.read())

In [None]:
# Upload the stored log and best checkpoint as W&B artifacts.
# NOTE(review): these paths point at copies under ../input, not at the files
# run_train just wrote to the working directory. Confirm this is intended.
for run_name, artifact_name, path, data_type in [
    ("save_logs", "logs", "../input/rsna-fracture-detection/log_densenet121_baseline.txt", "dataset"),
    ("save_model", "model", "../input/rsna-fracture-detection/densenet121_baseline_best_fold0.pth", "model"),
]:
    save_dataset_artifact(run_name=run_name, artifact_name=artifact_name,
                          path=path, data_type=data_type)