This code is created using the following [notebook](https://www.kaggle.com/code/samu2505/rsna-pytorch-train-lb-0-84-cv-0-54). 

In [1]:
!python -m pip install -q lightning

In [2]:
import os, gc, sys, copy, pickle
from pathlib import Path
import glob
from tqdm.auto import tqdm
tqdm.pandas()

import math
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from joblib import Parallel, delayed
import multiprocessing as mp

import albumentations as A
import torch
import torch.nn as nn
import torch.optim as optim
import torch.cuda.amp as amp
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms


In [3]:
import warnings
warnings.filterwarnings("ignore")
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [4]:
# Weights and biases login

import wandb
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("WANDB_API_KEY")
wandb.login(key=secret_value_0)

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [5]:
def seeding(SEED):
    np.random.seed(SEED)
    random.seed(SEED)
    os.environ['PYTHONHASHSEED'] = str(SEED)
    torch.manual_seed(SEED)
    if torch.cuda.is_available(): 
        torch.cuda.manual_seed(SEED)
        torch.cuda.manual_seed_all(SEED)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = True
#     os.environ['TF_CUDNN_DETERMINISTIC'] = str(SEED)
#     tf.random.set_seed(SEED)
#     keras.utils.set_random_seed(seed=SEED)
    print('seeding done!!!')

def flush():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

# Read the CSV files

In [6]:
DATA_PATH = Path("/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification")
os.listdir(DATA_PATH)

['sample_submission.csv',
 'train_images',
 'train_series_descriptions.csv',
 'train.csv',
 'train_label_coordinates.csv',
 'test_series_descriptions.csv',
 'test_images']

In [7]:
train_main = pd.read_csv(DATA_PATH/"train.csv")
train_desc = pd.read_csv(DATA_PATH/"train_series_descriptions.csv")
train_label_coordinates = pd.read_csv(DATA_PATH/"train_label_coordinates.csv")

In [8]:
train_main

Unnamed: 0,study_id,spinal_canal_stenosis_l1_l2,spinal_canal_stenosis_l2_l3,spinal_canal_stenosis_l3_l4,spinal_canal_stenosis_l4_l5,spinal_canal_stenosis_l5_s1,left_neural_foraminal_narrowing_l1_l2,left_neural_foraminal_narrowing_l2_l3,left_neural_foraminal_narrowing_l3_l4,left_neural_foraminal_narrowing_l4_l5,...,left_subarticular_stenosis_l1_l2,left_subarticular_stenosis_l2_l3,left_subarticular_stenosis_l3_l4,left_subarticular_stenosis_l4_l5,left_subarticular_stenosis_l5_s1,right_subarticular_stenosis_l1_l2,right_subarticular_stenosis_l2_l3,right_subarticular_stenosis_l3_l4,right_subarticular_stenosis_l4_l5,right_subarticular_stenosis_l5_s1
0,4003253,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Moderate,...,Normal/Mild,Normal/Mild,Normal/Mild,Moderate,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild
1,4646740,Normal/Mild,Normal/Mild,Moderate,Severe,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Moderate,...,Normal/Mild,Normal/Mild,Normal/Mild,Severe,Normal/Mild,Normal/Mild,Moderate,Moderate,Moderate,Normal/Mild
2,7143189,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,...,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild
3,8785691,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Moderate,...,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild
4,10728036,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,...,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Moderate,Normal/Mild
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1970,4282019580,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Moderate,...,Normal/Mild,Normal/Mild,Normal/Mild,Moderate,Normal/Mild,Normal/Mild,Normal/Mild,Moderate,Moderate,Moderate
1971,4283570761,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,...,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild
1972,4284048608,Normal/Mild,Normal/Mild,Normal/Mild,Severe,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,...,Normal/Mild,Normal/Mild,Normal/Mild,Severe,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Severe,Normal/Mild
1973,4287160193,Normal/Mild,Moderate,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Normal/Mild,Moderate,...,Normal/Mild,Severe,Moderate,Moderate,Normal/Mild,Normal/Mild,Normal/Mild,Moderate,Moderate,Normal/Mild


In [9]:
train_desc

Unnamed: 0,study_id,series_id,series_description
0,4003253,702807833,Sagittal T2/STIR
1,4003253,1054713880,Sagittal T1
2,4003253,2448190387,Axial T2
3,4646740,3201256954,Axial T2
4,4646740,3486248476,Sagittal T1
...,...,...,...
6289,4287160193,1507070277,Sagittal T2/STIR
6290,4287160193,1820446240,Axial T2
6291,4290709089,3274612423,Sagittal T2/STIR
6292,4290709089,3390218084,Axial T2


In [10]:
train_label_coordinates

Unnamed: 0,study_id,series_id,instance_number,condition,level,x,y
0,4003253,702807833,8,Spinal Canal Stenosis,L1/L2,322.831858,227.964602
1,4003253,702807833,8,Spinal Canal Stenosis,L2/L3,320.571429,295.714286
2,4003253,702807833,8,Spinal Canal Stenosis,L3/L4,323.030303,371.818182
3,4003253,702807833,8,Spinal Canal Stenosis,L4/L5,335.292035,427.327434
4,4003253,702807833,8,Spinal Canal Stenosis,L5/S1,353.415929,483.964602
...,...,...,...,...,...,...,...
48687,4290709089,4237840455,11,Left Neural Foraminal Narrowing,L1/L2,219.465940,97.831063
48688,4290709089,4237840455,12,Left Neural Foraminal Narrowing,L2/L3,205.340599,140.207084
48689,4290709089,4237840455,12,Left Neural Foraminal Narrowing,L3/L4,202.724796,181.013624
48690,4290709089,4237840455,12,Left Neural Foraminal Narrowing,L4/L5,202.933333,219.733333


In [11]:
# Define function to reshape a single row of the DataFrame
def reshape_row(row):
    data = {'study_id': [], 'condition': [], 'level': [], 'severity': []}
    
    for column, value in row.items():
        if column not in ['study_id', 'series_id', 'instance_number', 'x', 'y', 'series_description']:
            parts = column.split('_')
            condition = ' '.join([word.capitalize() for word in parts[:-2]])
            level = parts[-2].capitalize() + '/' + parts[-1].capitalize()
            data['study_id'].append(row['study_id'])
            data['condition'].append(condition)
            data['level'].append(level)
            data['severity'].append(value)
    
    return pd.DataFrame(data)

# Reshape the DataFrame for all rows
new_train_df = pd.concat([reshape_row(row) for _, row in train_main.iterrows()], ignore_index=True)

# Display the first few rows of the reshaped dataframe
new_train_df.head(5)

Unnamed: 0,study_id,condition,level,severity
0,4003253,Spinal Canal Stenosis,L1/L2,Normal/Mild
1,4003253,Spinal Canal Stenosis,L2/L3,Normal/Mild
2,4003253,Spinal Canal Stenosis,L3/L4,Normal/Mild
3,4003253,Spinal Canal Stenosis,L4/L5,Normal/Mild
4,4003253,Spinal Canal Stenosis,L5/S1,Normal/Mild


In [12]:
# Merge the dataframes on the common columns
merged_df = pd.merge(new_train_df, train_label_coordinates, on=['study_id', 'condition', 'level'], how='inner')
final_merged_df = pd.merge(merged_df, train_desc, on=['series_id','study_id'], how='inner')

# Create the row_id column
final_merged_df['row_id'] = (
    final_merged_df['study_id'].astype(str) + '_' +
    final_merged_df['condition'].str.lower().str.replace(' ', '_') + '_' +
    final_merged_df['level'].str.lower().str.replace('/', '_')
)

# Create the image_path column
final_merged_df['image_path'] = (
    f'{str(DATA_PATH)}/train_images/' + 
    final_merged_df['study_id'].astype(str) + '/' +
    final_merged_df['series_id'].astype(str) + '/' +
    final_merged_df['instance_number'].astype(str) + '.dcm'
)

final_merged_df['severity'] = final_merged_df['severity'].map(
    {'Normal/Mild': 'normal_mild', 'Moderate': 'moderate', 'Severe': 'severe'}
)

train_data = final_merged_df.copy()
# Display the updated dataframe
train_data.head(5)

Unnamed: 0,study_id,condition,level,severity,series_id,instance_number,x,y,series_description,row_id,image_path
0,4003253,Spinal Canal Stenosis,L1/L2,normal_mild,702807833,8,322.831858,227.964602,Sagittal T2/STIR,4003253_spinal_canal_stenosis_l1_l2,/kaggle/input/rsna-2024-lumbar-spine-degenerat...
1,4003253,Spinal Canal Stenosis,L2/L3,normal_mild,702807833,8,320.571429,295.714286,Sagittal T2/STIR,4003253_spinal_canal_stenosis_l2_l3,/kaggle/input/rsna-2024-lumbar-spine-degenerat...
2,4003253,Spinal Canal Stenosis,L3/L4,normal_mild,702807833,8,323.030303,371.818182,Sagittal T2/STIR,4003253_spinal_canal_stenosis_l3_l4,/kaggle/input/rsna-2024-lumbar-spine-degenerat...
3,4003253,Spinal Canal Stenosis,L4/L5,normal_mild,702807833,8,335.292035,427.327434,Sagittal T2/STIR,4003253_spinal_canal_stenosis_l4_l5,/kaggle/input/rsna-2024-lumbar-spine-degenerat...
4,4003253,Spinal Canal Stenosis,L5/S1,normal_mild,702807833,8,353.415929,483.964602,Sagittal T2/STIR,4003253_spinal_canal_stenosis_l5_s1,/kaggle/input/rsna-2024-lumbar-spine-degenerat...


In [13]:
train_data.iloc[0,:].image_path

'/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification/train_images/4003253/702807833/8.dcm'

In [14]:
# Define a function to check if a path exists
def check_exists(path):
    return os.path.exists(path)

# Define a function to check if a study ID directory exists
def check_study_id(row):
    study_id = row['study_id']
    path = f'{str(DATA_PATH)}/train_images/{study_id}'
    return check_exists(path)

# Define a function to check if a series ID directory exists
def check_series_id(row):
    study_id = row['study_id']
    series_id = row['series_id']
    path = f'{str(DATA_PATH)}/train_images/{study_id}/{series_id}'
    return check_exists(path)

# Define a function to check if an image file exists
def check_image_exists(row):
    image_path = row['image_path']
    return check_exists(image_path)

# Apply the functions to the train_data dataframe
train_data['study_id_exists'] = train_data.progress_apply(check_study_id, axis=1)
train_data['series_id_exists'] = train_data.progress_apply(check_series_id, axis=1)
train_data['image_exists'] = train_data.progress_apply(check_image_exists, axis=1)

# Filter train_data
train_data = train_data[(train_data['study_id_exists']) & (train_data['series_id_exists']) & (train_data['image_exists'])]

  0%|          | 0/48692 [00:00<?, ?it/s]

  0%|          | 0/48692 [00:00<?, ?it/s]

  0%|          | 0/48692 [00:00<?, ?it/s]

In [15]:
train_data.head()

Unnamed: 0,study_id,condition,level,severity,series_id,instance_number,x,y,series_description,row_id,image_path,study_id_exists,series_id_exists,image_exists
0,4003253,Spinal Canal Stenosis,L1/L2,normal_mild,702807833,8,322.831858,227.964602,Sagittal T2/STIR,4003253_spinal_canal_stenosis_l1_l2,/kaggle/input/rsna-2024-lumbar-spine-degenerat...,True,True,True
1,4003253,Spinal Canal Stenosis,L2/L3,normal_mild,702807833,8,320.571429,295.714286,Sagittal T2/STIR,4003253_spinal_canal_stenosis_l2_l3,/kaggle/input/rsna-2024-lumbar-spine-degenerat...,True,True,True
2,4003253,Spinal Canal Stenosis,L3/L4,normal_mild,702807833,8,323.030303,371.818182,Sagittal T2/STIR,4003253_spinal_canal_stenosis_l3_l4,/kaggle/input/rsna-2024-lumbar-spine-degenerat...,True,True,True
3,4003253,Spinal Canal Stenosis,L4/L5,normal_mild,702807833,8,335.292035,427.327434,Sagittal T2/STIR,4003253_spinal_canal_stenosis_l4_l5,/kaggle/input/rsna-2024-lumbar-spine-degenerat...,True,True,True
4,4003253,Spinal Canal Stenosis,L5/S1,normal_mild,702807833,8,353.415929,483.964602,Sagittal T2/STIR,4003253_spinal_canal_stenosis_l5_s1,/kaggle/input/rsna-2024-lumbar-spine-degenerat...,True,True,True


# Prepare data for training

In [16]:
label2id = {v: i for i, v in enumerate(train_data['severity'].unique())}
id2label = {v:k for k,v in label2id.items()}
train_data['target'] = train_data['severity'].map(label2id)
train_data = train_data.dropna(subset=['severity']).reset_index(drop=True)
train_data.head()

Unnamed: 0,study_id,condition,level,severity,series_id,instance_number,x,y,series_description,row_id,image_path,study_id_exists,series_id_exists,image_exists,target
0,4003253,Spinal Canal Stenosis,L1/L2,normal_mild,702807833,8,322.831858,227.964602,Sagittal T2/STIR,4003253_spinal_canal_stenosis_l1_l2,/kaggle/input/rsna-2024-lumbar-spine-degenerat...,True,True,True,0
1,4003253,Spinal Canal Stenosis,L2/L3,normal_mild,702807833,8,320.571429,295.714286,Sagittal T2/STIR,4003253_spinal_canal_stenosis_l2_l3,/kaggle/input/rsna-2024-lumbar-spine-degenerat...,True,True,True,0
2,4003253,Spinal Canal Stenosis,L3/L4,normal_mild,702807833,8,323.030303,371.818182,Sagittal T2/STIR,4003253_spinal_canal_stenosis_l3_l4,/kaggle/input/rsna-2024-lumbar-spine-degenerat...,True,True,True,0
3,4003253,Spinal Canal Stenosis,L4/L5,normal_mild,702807833,8,335.292035,427.327434,Sagittal T2/STIR,4003253_spinal_canal_stenosis_l4_l5,/kaggle/input/rsna-2024-lumbar-spine-degenerat...,True,True,True,0
4,4003253,Spinal Canal Stenosis,L5/S1,normal_mild,702807833,8,353.415929,483.964602,Sagittal T2/STIR,4003253_spinal_canal_stenosis_l5_s1,/kaggle/input/rsna-2024-lumbar-spine-degenerat...,True,True,True,0


In [17]:
train_data["target"].value_counts()

target
0    37626
1     7950
2     3081
Name: count, dtype: int64

In [18]:
train_data["target"].nunique()

3

In [19]:
import cv2
cv2.setNumThreads(0)
import PIL
import pydicom
import warnings

In [20]:
def load_dicom(path):
    dicom = pydicom.read_file(path)
    data = dicom.pixel_array
    data = data - np.min(data)
    if np.max(data) != 0:
        data = data / np.max(data)
    data = (data * 255).astype(np.uint8)
    return data

In [21]:
class CustomDataset(Dataset):
    def __init__(self, dataframe, transform=None, label_name='target'):
        self.dataframe = dataframe
        self.transform = transform
        self.label = dataframe.loc[:, label_name]

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, index):
        image_path = self.dataframe['image_path'][index]
        image = load_dicom(image_path)  # Define this function to load your DICOM images
        target = self.dataframe['target'][index]
        
        if self.transform:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
            image = self.transform(image=image)['image']
            image = image.transpose(2, 0, 1).astype(np.float32) / 255.

        return image, torch.tensor(target).float()
    
    def get_labels(self):
        return self.label

In [22]:
def get_transforms(height, width):
    train_tsfm = A.Compose([
        A.Resize(height=height, width=width, interpolation=cv2.INTER_CUBIC, p=1.0), # also INTER_LANCZOS4
        # Geometric augmentations
        A.Perspective(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Rotate(limit=(-30, 30), p=0.5), 
        
        A.CenterCrop(height=height, width=width, p=1.0),
    ])
    
    valid_tsfm = A.Compose([
        A.Resize(height=height, width=width),
    ])
    return {"train": train_tsfm, "eval": valid_tsfm}


In [23]:
def get_dataloaders(data, cfg, split="train"):
    img_size = cfg['img_size']
    height, width = img_size, img_size
    tsfm = get_transforms(height=height, width=width)
    if split == 'train':
        tr_tsfm = tsfm['train']
        ds = CustomDataset(data, transform=tr_tsfm)
        labels = ds.get_labels()
#         class_weights = torch.tensor(compute_class_weight(class_weight="balanced", classes=np.unique(labels), y=labels))
        class_weights = torch.tensor([1, 2, 4])
        samples_weights = class_weights[labels]
#         print(class_weights)
        sampler = WeightedRandomSampler(weights=samples_weights, 
                                        num_samples=len(samples_weights), 
                                        replacement=True)

        dls = DataLoader(ds, 
                         batch_size=cfg['batch_size'], 
                         sampler=sampler, 
                         num_workers=os.cpu_count(), 
                         pin_memory=True, 
                         drop_last=True)
        
    elif split == 'valid' or split == 'test':
        eval_tsfm = tsfm['eval']
        ds = CustomDataset(data, transform=eval_tsfm)
        dls = DataLoader(ds, 
                         batch_size=2*cfg['batch_size'], 
                         sampler=SequentialSampler(ds),
                         num_workers=os.cpu_count(), 
                         pin_memory=True,
                         drop_last=False)
    else:
        raise Exception("Split should be 'train' or 'valid' or 'test'!!!")
    return dls

In [24]:
CONFIG = dict(
    project_name = "RSNA-2024-Lumbar-Spine-Classification-Torch-RZoro",
    artifact_name = "rsnaEffNetModel",
    load_kernel = None,
    load_last = True,
    n_folds = 5,
    backbone = "efficientnet_b0.ra_in1k", # efficientnet_b0.ra_in1k, efficientnet_b2.ra_in1k, efficientnet_b5.sw_in12k
    img_size = 384,
    n_slice_per_c = 16,
    in_chans = 3,

    drop_rate = 0.,
    drop_rate_last = 0.3,
    drop_path_rate = 0.,
    p_mixup = 0.5,
    p_rand_order_v1 = 0.2,
    lr = 8e-5, # 1e-3, 8e-4, 5e-4, 4e-4

    out_dim = 3,
    epochs = 50,
    batch_size = 32,
#     patience = 7,
    device = torch.device("cuda:0") if torch.cuda.is_available() else "cpu",
    seed = 2024,
    log_wandb = True,
    with_clip = True,
)

CONFIG['patience'] = math.ceil(0.2 * CONFIG['epochs'])

seeding(CONFIG['seed'])

seeding done!!!


In [25]:
from sklearn import model_selection

kfold = model_selection.StratifiedKFold(n_splits=CONFIG['n_folds'], shuffle=True, random_state=CONFIG['seed'])
x = train_data.index.values
y = train_data['target'].values.astype(int)
# g = train_data['series2id'].values.astype(int)

train_data['fold'] = -1
for fold, (tr_idx, val_idx) in enumerate(kfold.split(x,y)):
    train_data.loc[val_idx, 'fold'] = fold
    
train_data['fold'].value_counts()

fold
1    9732
0    9732
4    9731
3    9731
2    9731
Name: count, dtype: int64

# Prepare Model for training

In [26]:
import timm

import lightning as pl
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint, StochasticWeightAveraging
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

from torch.utils.data import WeightedRandomSampler, SequentialSampler
from sklearn.utils.class_weight import compute_class_weight

In [27]:
# class TimmModel(nn.Module):
#     def __init__(self, backbone, pretrained=False):
#         super(TimmModel, self).__init__()

#         self.encoder = timm.create_model(
#             backbone,
#             num_classes=CONFIG["out_dim"],
#             features_only=False,
#             drop_rate=CONFIG["drop_rate"],
#             drop_path_rate=CONFIG["drop_path_rate"],
#             pretrained=pretrained
#         )

#         if 'efficient' in backbone:
#             hdim = self.encoder.conv_head.out_channels
#             self.encoder.classifier = nn.Identity()
#         elif 'convnext' in backbone:
#             hdim = self.encoder.head.fc.in_features
#             self.encoder.head.fc = nn.Identity()


#         self.lstm = nn.LSTM(hdim, 256, num_layers=2, dropout=CONFIG["drop_rate"], bidirectional=True, batch_first=True)
#         self.head = nn.Sequential(
#             nn.Linear(512, 256),
#             nn.BatchNorm1d(256),
#             nn.Dropout(CONFIG["drop_rate_last"]),
#             nn.LeakyReLU(0.1),
#             nn.Linear(256, CONFIG["out_dim"]),
#         )

#     def forward(self, x):
#         feat = self.encoder(x)
#         feat, _ = self.lstm(feat)
#         feat = self.head(feat)
#         return feat

class TimmModel(nn.Module):
    def __init__(self, backbone, pretrained=False):
        super(TimmModel, self).__init__()

        self.encoder = timm.create_model(
            backbone,
            num_classes=CONFIG["out_dim"],
            features_only=False,
            drop_rate=CONFIG["drop_rate"],
            drop_path_rate=CONFIG["drop_path_rate"],
            pretrained=pretrained
        )

        if 'efficient' in backbone:
            hdim = self.encoder.conv_head.out_channels
            self.encoder.classifier = nn.Identity()
        elif 'convnext' in backbone:
            hdim = self.encoder.head.fc.in_features
            self.encoder.head.fc = nn.Identity()


#         self.lstm = nn.LSTM(hdim, 256, num_layers=2, dropout=CONFIG["drop_rate"], bidirectional=True, batch_first=True)
        self.head = nn.Sequential(
            nn.Linear(hdim, 256),
            nn.BatchNorm1d(256),
            nn.Dropout(CONFIG["drop_rate_last"]),
            
            nn.Linear(256, 256),
            nn.BatchNorm1d(256),
            nn.Dropout(CONFIG["drop_rate_last"]),
            
            nn.LeakyReLU(0.1),
            nn.Linear(256, CONFIG["out_dim"]),
        )

    def forward(self, x):
        feat = self.encoder(x)
#         feat, _ = self.lstm(feat)
        feat = self.head(feat)
        return feat

In [28]:
class LumbarLightningModel(pl.LightningModule):
    def __init__(self, pretrained=False, lr=8e-4):
        self.save_hyperparameters()
        super().__init__()
        self.model = TimmModel(backbone=CONFIG["backbone"], pretrained=pretrained)
        class_weights = torch.tensor([1, 2, 4], dtype=torch.float32)
        self.loss_fn = nn.CrossEntropyLoss(weight=class_weights)
    
    def forward(self, images):
        return self.model(images)
    
    def shared_step(self, batch):
        images, labels = batch[0], batch[1]
        logits = self.forward(images)
        loss = self.loss_fn(logits, labels.to(torch.int64))
        return loss
    
    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.log("valid_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=CONFIG['lr'])
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG["epochs"], eta_min=0)
#         scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=10)
        return [optimizer], [scheduler]

# Training in Pytorch

In [29]:
from collections import Counter, defaultdict

class MetricMonitor:
    def __init__(self, float_precision=4):
        self.float_precision = float_precision
        self.reset()

    def reset(self):
        self.metrics = defaultdict(lambda: {"val": 0, "count": 0, "avg": 0})

    def update(self, metric_name, val):
        metric = self.metrics[metric_name]

        metric["val"] += val
        metric["count"] += 1
        metric["avg"] = metric["val"] / metric["count"]

    def __str__(self):
        return " | ".join(
            [
                "{metric_name}: {avg:.{float_precision}f}".format(
                    metric_name=metric_name, avg=metric["avg"], float_precision=self.float_precision
                )
                for (metric_name, metric) in self.metrics.items()
            ]
        )

In [30]:
def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform(m.weight)
        m.bias.data.fill_(0.01)
        

def shared_step(model, batch, criterion):
    image, target = batch[0], batch[1]
    image = image.to(CONFIG["device"], non_blocking=True)
    target = target.to(CONFIG["device"], non_blocking=True)
    logits = model.forward(image.to(torch.float32))
    loss = criterion(logits.view(-1, CONFIG["out_dim"]), target.view(-1).to(torch.int64))

    return {
        "loss": loss
    }


def train(train_loader, model, criterion, optimizer, epoch, scaler):
    metric_monitor = MetricMonitor()
    model.train()
    stream = tqdm(train_loader)
    train_loss = 0
    for i, batch in enumerate(stream, start=1):
        optimizer.zero_grad(set_to_none=True)
        
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            outputs = shared_step(model, batch, criterion)
            loss =  outputs['loss']
        
        metric_monitor.update("Loss", loss)
        train_loss += loss.detach().float()
        _train_metrics = {
            "train/step_loss": loss,
        }
        if CONFIG['log_wandb']:
            wandb.log(_train_metrics)
            
        # backward pass, with gradient scaling
        scaler.scale(loss).backward()
        
        # clip the gradient
        if CONFIG['with_clip']:
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
#             nn.utils.clip_grad_norm_(model.parameters(), 1.0)
#             nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0, norm_type=2)
            
        scaler.step(optimizer)
        scaler.update()
        
        stream.set_description(
            "Epoch: {epoch}. Train.      {metric_monitor}".format(epoch=epoch, metric_monitor=metric_monitor)
        )
        
    total_train_loss = train_loss / len(train_loader)
    
    flush()
    return _train_metrics, total_train_loss


def validate(val_loader, model, criterion, epoch):
    metric_monitor = MetricMonitor()
    model.eval()
    stream = tqdm(val_loader)
    valid_loss = 0
    
    with torch.no_grad():
        for i, batch in enumerate(stream, start=1):
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                outputs = shared_step(model, batch, criterion)
                loss =  outputs['loss']

            metric_monitor.update("Loss", loss)
            valid_loss += loss.detach().float()
            _valid_metrics = {
                    "valid/step_loss": loss,
                }
            
            if CONFIG['log_wandb']:
                wandb.log(_valid_metrics)
            
            stream.set_description(
                "Epoch: {epoch}. Validation. {metric_monitor}".format(epoch=epoch, metric_monitor=metric_monitor)
            )
            
    total_valid_loss = valid_loss / len(val_loader)
    flush()
    return _valid_metrics, total_valid_loss

In [31]:
def train_and_validate(model, train_dataset, val_dataset, fold=0):
    if CONFIG['log_wandb']:
        run = wandb.init(
            project=CONFIG["project_name"],
            resume="allow",
        )
        artifact = wandb.Artifact(f"{CONFIG['artifact_name']}_{fold}", type="model")
    
    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            DEVICE_IDS = list(range(torch.cuda.device_count()))
            print(f"\nUsing {len(DEVICE_IDS)} GPUs to train ...\n")
            model = nn.DataParallel(model, device_ids=DEVICE_IDS)
            
    model = model.to(CONFIG["device"])
    model.apply(init_weights)
    train_loader = get_dataloaders(train_dataset, CONFIG, split="train")
    valid_loader = get_dataloaders(val_dataset, CONFIG, split="valid")
    
#     criterion = nn.CrossEntropyLoss().to(device)
    
    # weighted cross entropy loss
    class_weights = torch.tensor([1, 2, 4], dtype=torch.float32)
    criterion = nn.CrossEntropyLoss(weight=class_weights).to(CONFIG["device"])
    
    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"])
    scaler = torch.cuda.amp.GradScaler()
    
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG["epochs"], eta_min=0)
    
    best_metric = np.inf
    loss_min = np.inf
    es = 0
    ES_RATIO = 0.25 if CONFIG["epochs"] < 10 else 0.20
    weights_file = "rsna_2024_lumbar_spine_fold_{fold}_epoch_{epoch}.pth"
    for epoch in range(1, CONFIG["epochs"] + 1):
        scheduler.step()
        lr = optimizer.param_groups[0]['lr']
        _train_metrics, train_loss = train(train_loader, model, criterion, optimizer, epoch, scaler)
        _valid_metrics, val_loss = validate(valid_loader, model, criterion, epoch)
        
        _train_metrics["train/loss"] = train_loss
        _valid_metrics["valid/loss"] = val_loss
        if CONFIG['log_wandb']:
            wandb.log({"learning_rate": lr})
            wandb.log({**_train_metrics, **_valid_metrics})
        
        if val_loss < best_metric:
            print(f"Best metric: ({best_metric:.6f} --> {val_loss:.6f}). Saving model ...")
            if torch.cuda.device_count() > 2:
                torch.save(model.module.state_dict(), weights_file.format(fold=fold, epoch=epoch))
            else:
                torch.save(model.state_dict(), weights_file.format(fold=fold, epoch=epoch))
            best_metric = val_loss
            if CONFIG['log_wandb']:
                if epoch == 1:
                    artifact.add_file(weights_file.format(fold=fold, epoch=epoch))
                    run.log_artifact(artifact)
                else:
                    draft_artifact = wandb.Artifact(f"{CONFIG['artifact_name']}_{fold}", type="model")
                    draft_artifact.add_file(weights_file.format(fold=fold, epoch=epoch))
                    run.log_artifact(draft_artifact)
                
            es = 0
            
        else:
            es += 1
            
        if es > math.ceil(ES_RATIO*CONFIG["epochs"]):
            print(f"Early stopping on epoch {epoch} ...")
            break
    
    if CONFIG['log_wandb']:
        wandb.config = CONFIG
        wandb.finish()
        
    del model, train_loader, valid_loader
    flush()

In [32]:
for fold in range(CONFIG["n_folds"]):
    model = TimmModel(backbone=CONFIG["backbone"], pretrained=True)
    train_ds = train_data[train_data['fold'] != fold].reset_index(drop=True)
    valid_ds = train_data[train_data['fold'] == fold].reset_index(drop=True)
    train_and_validate(model, train_ds, valid_ds, fold=fold)
    
    break
gc.collect()
flush()

model.safetensors:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

[34m[1mwandb[0m: Currently logged in as: [33mmandar4tech[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: wandb version 0.17.4 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
[34m[1mwandb[0m: Tracking run with wandb version 0.17.0
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20240712_055529-p8d03ge3[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mfancy-fog-14[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/mandar4tech/RSNA-2024-Lumbar-Spine-Classification-Torch-RZoro[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/mandar4tech/RSNA-2024-Lumbar-Spine-Classification-Torch-RZoro/runs/p8d03ge3[0m


  0%|          | 0/1216 [00:00<?, ?it/s]

  0%|          | 0/153 [00:00<?, ?it/s]

Best metric: (inf --> 0.779729). Saving model ...


  0%|          | 0/1216 [00:00<?, ?it/s]

  0%|          | 0/153 [00:00<?, ?it/s]

Best metric: (0.779729 --> 0.777265). Saving model ...


  0%|          | 0/1216 [00:00<?, ?it/s]

  0%|          | 0/153 [00:00<?, ?it/s]

Best metric: (0.777265 --> 0.719233). Saving model ...


  0%|          | 0/1216 [00:00<?, ?it/s]

  0%|          | 0/153 [00:00<?, ?it/s]

  0%|          | 0/1216 [00:00<?, ?it/s]

  0%|          | 0/153 [00:00<?, ?it/s]

  0%|          | 0/1216 [00:00<?, ?it/s]

  0%|          | 0/153 [00:00<?, ?it/s]

Best metric: (0.719233 --> 0.712740). Saving model ...


  0%|          | 0/1216 [00:00<?, ?it/s]

  0%|          | 0/153 [00:00<?, ?it/s]

  0%|          | 0/1216 [00:00<?, ?it/s]

  0%|          | 0/153 [00:00<?, ?it/s]

  0%|          | 0/1216 [00:00<?, ?it/s]

  0%|          | 0/153 [00:01<?, ?it/s]

  0%|          | 0/1216 [00:00<?, ?it/s]

  0%|          | 0/153 [00:00<?, ?it/s]

  0%|          | 0/1216 [00:00<?, ?it/s]

  0%|          | 0/153 [00:00<?, ?it/s]

  0%|          | 0/1216 [00:00<?, ?it/s]

  0%|          | 0/153 [00:00<?, ?it/s]

  0%|          | 0/1216 [00:00<?, ?it/s]

  0%|          | 0/153 [00:00<?, ?it/s]

  0%|          | 0/1216 [00:00<?, ?it/s]

  0%|          | 0/153 [00:00<?, ?it/s]

  0%|          | 0/1216 [00:00<?, ?it/s]

  0%|          | 0/153 [00:01<?, ?it/s]

  0%|          | 0/1216 [00:00<?, ?it/s]

  0%|          | 0/153 [00:01<?, ?it/s]

  0%|          | 0/1216 [00:00<?, ?it/s]

  0%|          | 0/153 [00:00<?, ?it/s]

Early stopping on epoch 17 ...


[34m[1mwandb[0m:                                                                                
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:   learning_rate ████▇▇▇▆▆▅▅▄▄▃▂▂▁
[34m[1mwandb[0m:      train/loss █▆▅▄▄▄▃▃▃▂▂▂▂▂▁▁▁
[34m[1mwandb[0m: train/step_loss █▄▃▄▄▃▄▅▄▄▂▃▅▄▂▃▃▃▃▂▂▄▂▅▃▂▂▃▂▂▂▃▃▂▁▂▂▁▂▂
[34m[1mwandb[0m:      valid/loss ▂▂▁▁▁▁▂▂▂▃▃▅▄▆▆█▇
[34m[1mwandb[0m: valid/step_loss ▅▃▃▃▄▃▄▄▂▃▃▃▂▃▃▅▄▄▃▇▃▅▅▅▆▄▁▄▃▃▅▃▃▄▃▆▄█▆▅
[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run summary:
[34m[1mwandb[0m:   learning_rate 6e-05
[34m[1mwandb[0m:      train/loss 0.38237
[34m[1mwandb[0m: train/step_loss 0.4302
[34m[1mwandb[0m:      valid/loss 0.99202
[34m[1mwandb[0m: valid/step_loss 0.29479
[34m[1mwandb[0m: 
[34m[1mwandb[0m: 🚀 View run [33mfancy-fog-14[0m at: [34m[4mhttps://wandb.ai/mandar4tech/RSNA-2024-Lumbar-Spine-Classification-Torch-RZoro/runs/p8d03ge3[0m
[34m[1mwandb[0m: ⭐️ View project at: [34m[4mhttps://wandb.ai/mandar4te