In [30]:
import os
import sys
import gc
import ast
import cv2
import time
import timm
import pickle
import random
import pydicom
import argparse
import warnings
import numpy as np
import pandas as pd
from glob import glob
import nibabel
from PIL import Image
from tqdm import tqdm
import albumentations
from pylab import rcParams
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
from sklearn.model_selection import KFold, StratifiedKFold

import torch
import torch.nn as nn
import torch.optim as optim
import torch.cuda.amp as amp
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from monai.transforms import Resize
import monai.transforms as transforms

import warnings
warnings.filterwarnings("ignore")

**Define required directories**

In [48]:

data_dir = 'D:\RSNA-2022-cervical-spine-fracture-detection'
log_dir = './logs'
model_dir = './models'

os.makedirs(log_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

**Define Transformations for input images to monai's inputs**

In [49]:
img_size = [128,128,128]
# Resize the input image to spatial size
monai_img_size = Resize(img_size)
#translate_range, a sequence of positive floats, is used to generate the n shift parameters
translate_range = [int(x*y) for x,y in zip(img_size, (0.3, 0.3, 0.3))]


transform_train_data = transforms.Compose([
    transforms.RandFlipd(keys=["image","mask"], prob=0.5, spatial_axis=1),
    transforms.RandFlipd(["image","mask"], prob=0.5, spatial_axis=2),
    transforms.RandAffined(keys=["image","mask"], translate_range=translate_range, padding_mode='zeros', prob=0.7),
    transforms.RandGridDistortiond(keys=("image", "mask"), prob=0.5, distort_limit=(-0.01, 0.01), mode="nearest")
])

**Load DataFrames**

In [50]:
df_train = pd.read_csv(os.path.join(data_dir,'train.csv'))
df_train.head()

Unnamed: 0,StudyInstanceUID,patient_overall,C1,C2,C3,C4,C5,C6,C7
0,1.2.826.0.1.3680043.6200,1,1,1,0,0,0,0,0
1,1.2.826.0.1.3680043.27262,1,0,1,0,0,0,0,0
2,1.2.826.0.1.3680043.21561,1,0,1,0,0,0,0,0
3,1.2.826.0.1.3680043.12351,0,0,0,0,0,0,0,0
4,1.2.826.0.1.3680043.1363,1,0,0,0,0,1,0,0


In [51]:
mask_files = os.listdir(f"{data_dir}/segmentations")
print(f"Number of mask files: {len(mask_files)}")

df_mask = pd.DataFrame({"mask_path":mask_files})
df_mask["StudyInstanceUID"] = df_mask["mask_path"].apply(lambda x:x[:-4])
df_mask["mask_path"] = df_mask["mask_path"].apply(lambda x: os.path.join(data_dir,'segmentations',x))
df_mask.head()

Number of mask files: 87


Unnamed: 0,mask_path,StudyInstanceUID
0,D:\RSNA-2022-cervical-spine-fracture-detection...,1.2.826.0.1.3680043.10633
1,D:\RSNA-2022-cervical-spine-fracture-detection...,1.2.826.0.1.3680043.10921
2,D:\RSNA-2022-cervical-spine-fracture-detection...,1.2.826.0.1.3680043.11827
3,D:\RSNA-2022-cervical-spine-fracture-detection...,1.2.826.0.1.3680043.11988
4,D:\RSNA-2022-cervical-spine-fracture-detection...,1.2.826.0.1.3680043.12281


In [55]:
df = df_train.merge(df_mask, on='StudyInstanceUID', how='left')
df['image_folder'] = df['StudyInstanceUID'].apply(lambda x: os.path.join(data_dir,'train_images',x))
df['mask_path'].fillna('',inplace=True)
print(df.shape)
df.head()

(2019, 11)


Unnamed: 0,StudyInstanceUID,patient_overall,C1,C2,C3,C4,C5,C6,C7,mask_path,image_folder
0,1.2.826.0.1.3680043.6200,1,1,1,0,0,0,0,0,,D:\RSNA-2022-cervical-spine-fracture-detection...
1,1.2.826.0.1.3680043.27262,1,0,1,0,0,0,0,0,,D:\RSNA-2022-cervical-spine-fracture-detection...
2,1.2.826.0.1.3680043.21561,1,0,1,0,0,0,0,0,,D:\RSNA-2022-cervical-spine-fracture-detection...
3,1.2.826.0.1.3680043.12351,0,0,0,0,0,0,0,0,,D:\RSNA-2022-cervical-spine-fracture-detection...
4,1.2.826.0.1.3680043.1363,1,0,0,0,0,1,0,0,D:\RSNA-2022-cervical-spine-fracture-detection...,D:\RSNA-2022-cervical-spine-fracture-detection...


In [66]:
df_segments = df.query('mask_path != ""').reset_index(drop=True)
print(df_segments.shape)
df_segments.head()

(87, 11)


Unnamed: 0,StudyInstanceUID,patient_overall,C1,C2,C3,C4,C5,C6,C7,mask_path,image_folder
0,1.2.826.0.1.3680043.1363,1,0,0,0,0,1,0,0,D:\RSNA-2022-cervical-spine-fracture-detection...,D:\RSNA-2022-cervical-spine-fracture-detection...
1,1.2.826.0.1.3680043.25704,0,0,0,0,0,0,0,0,D:\RSNA-2022-cervical-spine-fracture-detection...,D:\RSNA-2022-cervical-spine-fracture-detection...
2,1.2.826.0.1.3680043.20647,0,0,0,0,0,0,0,0,D:\RSNA-2022-cervical-spine-fracture-detection...,D:\RSNA-2022-cervical-spine-fracture-detection...
3,1.2.826.0.1.3680043.31077,1,0,0,1,1,1,1,0,D:\RSNA-2022-cervical-spine-fracture-detection...,D:\RSNA-2022-cervical-spine-fracture-detection...
4,1.2.826.0.1.3680043.17960,0,0,0,0,0,0,0,0,D:\RSNA-2022-cervical-spine-fracture-detection...,D:\RSNA-2022-cervical-spine-fracture-detection...


**K-Fold Cross Validation**

In [80]:
# Split data into train and validation set with KFOLD Cross validation
k = 5
kf = KFold(n_splits=k, random_state=None)

df_segments['fold'] = -1
for fold, (train_idx, valid_idx) in enumerate(kf.split(df_segments)):
    df_segments.loc[valid_idx,'fold'] = fold

df_segments.head()

Unnamed: 0,StudyInstanceUID,patient_overall,C1,C2,C3,C4,C5,C6,C7,mask_path,image_folder,fold
70,1.2.826.0.1.3680043.15206,1,0,0,1,0,0,0,0,D:\RSNA-2022-cervical-spine-fracture-detection...,D:\RSNA-2022-cervical-spine-fracture-detection...,4
71,1.2.826.0.1.3680043.8330,1,0,0,0,0,1,1,1,D:\RSNA-2022-cervical-spine-fracture-detection...,D:\RSNA-2022-cervical-spine-fracture-detection...,4
72,1.2.826.0.1.3680043.24140,0,0,0,0,0,0,0,0,D:\RSNA-2022-cervical-spine-fracture-detection...,D:\RSNA-2022-cervical-spine-fracture-detection...,4
73,1.2.826.0.1.3680043.32436,1,0,0,1,0,0,0,0,D:\RSNA-2022-cervical-spine-fracture-detection...,D:\RSNA-2022-cervical-spine-fracture-detection...,4
74,1.2.826.0.1.3680043.28327,1,0,0,0,0,1,1,1,D:\RSNA-2022-cervical-spine-fracture-detection...,D:\RSNA-2022-cervical-spine-fracture-detection...,4
