Import the Libraries

In [None]:
from sklearn.model_selection import KFold
from collections import OrderedDict
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
import timm
from timm.utils import ModelEmaV2
from transformers import get_cosine_schedule_with_warmup
import albumentations as A
from sklearn.model_selection import KFold
import re
import pydicom

In [None]:
import os
import gc
import sys
from PIL import Image
import cv2
import math, random
import numpy as np
import pandas as pd
from glob import glob
from tqdm import tqdm
import matplotlib.pyplot as plt

Config

In [None]:
# Root of the competition data on Kaggle.
rd = '/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification'
# Results directory of the training-baseline dataset (plain string: the
# previous f-prefix had no placeholders).
OUTPUT_DIR = '/kaggle/input/rsna2024-lsdc-training-baseline/rsna24-results'
# Use a torch.device (not a bare string) so this matches the device setup
# performed in the later cell.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Configuration
# os.cpu_count() returns None when the CPU count cannot be determined, and
# DataLoader requires an int, so fall back to 0 (load in the main process).
N_WORKERS = os.cpu_count() or 0
USE_AMP = True                 # enable fp16 autocast during inference
SEED = 8620
IMG_SIZE = (512, 512)          # square per-slice resolution
IN_CHANS = 42                  # 3 MRI series x 14 slices each
N_LABELS = 25                  # 5 conditions x 5 spinal levels
N_CLASSES = 3 * N_LABELS       # 3 severity logits per label
N_FOLDS = 5
MODEL_NAME = "edgenext_base.in21k_ft_in1k"
BATCH_SIZE = 1
data_dir = '/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification'

In [None]:
# Device setup: first CUDA device when one is visible, otherwise the CPU.
device = (torch.device('cuda:0') if torch.cuda.is_available()
          else torch.device('cpu'))

# Load DataFrames
# Series descriptions for the test studies.
df = pd.read_csv(f'{data_dir}/test_series_descriptions.csv')
# Distinct study ids, in order of first appearance, as Python scalars.
study_ids = df['study_id'].unique().tolist()

# Submission template: the first column is the row id; the remaining column
# names are the per-class probability columns filled by the prediction cell.
sample_sub = pd.read_csv(f'{data_dir}/sample_submission.csv')
LABELS = list(sample_sub.columns[1:])

In [None]:
# Conditions and Levels
# The prediction cell iterates these with conditions as the outer loop and
# levels as the inner loop to build submission row ids.
CONDITIONS = [
    'spinal_canal_stenosis',
    # foraminal narrowing, one entry per side
    'left_neural_foraminal_narrowing',
    'right_neural_foraminal_narrowing',
    # subarticular stenosis, one entry per side
    'left_subarticular_stenosis',
    'right_subarticular_stenosis',
]

# Intervertebral levels, listed top to bottom.
LEVELS = ['l1_l2', 'l2_l3', 'l3_l4', 'l4_l5', 'l5_s1']

In [None]:
# Helper functions
# Capturing group so re.split keeps the digit runs in its output.
_NATURAL_SPLIT_RE = re.compile(r'(\d+)')


def atoi(text):
    """Return ``int(text)`` if *text* is a run of decimal digits, else *text*.

    ``str.isdecimal`` is used instead of ``isdigit``: characters such as '²'
    satisfy ``isdigit`` but make ``int()`` raise ValueError. ``isdecimal``
    accepts exactly the characters that ``\\d`` matches in the split below.
    """
    return int(text) if text.isdecimal() else text


def natural_keys(text):
    """Sort key that orders embedded numbers numerically ('2.dcm' < '10.dcm').

    ``re.split`` with a capturing group alternates non-digit and digit
    chunks (starting with a possibly empty non-digit chunk). Keys for
    different strings therefore compare str-to-str and int-to-int at each
    position.
    """
    return [atoi(chunk) for chunk in _NATURAL_SPLIT_RE.split(text)]

Dataset and DataLoader

In [None]:
import glob
import cv2
import pydicom
import numpy as np
from torch.utils.data import Dataset, DataLoader
import albumentations as A

class RSNA24TestDataset(Dataset):
    """Test-time dataset that stacks slices from three MRI series per study.

    Each item is ``(x, study_id_str)``. ``x`` holds ``N_SLICES`` slices from
    each series in ``SERIES_DESCRIPTIONS`` (in that order), stacked along the
    channel axis (``3 * N_SLICES`` channels). It is optionally passed through
    ``transform`` and returned channels-first.

    Args:
        df: Series table; the columns read here are ``study_id``,
            ``series_id`` and ``series_description``.
        study_ids: Study ids to serve; one item per id, in this order.
        phase: Stored as ``self.phase``; not used by this class.
        transform: Optional callable invoked as ``transform(image=x)['image']``.
    """

    # Slices sampled from each series.
    N_SLICES = 14
    # Series stacked into the channel axis, in channel order.
    SERIES_DESCRIPTIONS = ('Sagittal T1', 'Sagittal T2/STIR', 'Axial T2')

    def __init__(self, df, study_ids, phase='test', transform=None):
        self.df = df
        self.study_ids = study_ids
        self.transform = transform
        self.phase = phase

    def __len__(self):
        return len(self.study_ids)

    def get_img_paths(self, study_id, series_desc):
        """Return DICOM paths of every series of ``study_id`` with ``series_desc``.

        Files within a series are ordered with ``natural_keys`` (so '2.dcm'
        precedes '10.dcm'). Several matching series are concatenated in
        ``df`` row order.
        """
        series_df = self.df[(self.df['study_id'] == study_id) &
                            (self.df['series_description'] == series_desc)]
        img_paths = []
        for _, row in series_df.iterrows():
            paths = sorted(glob.glob(f'{data_dir}/test_images/{study_id}/{row["series_id"]}/*.dcm'),
                           key=natural_keys)
            img_paths.extend(paths)
        return img_paths

    def read_dcm_image(self, src_path):
        """Read one DICOM file, min-max scale it to [0, 255] and resize it.

        Returns:
            np.ndarray: uint8 image resized to ``IMG_SIZE`` (assumes
            ``pixel_array`` is a single 2-D slice -- TODO confirm).
        """
        dicom_data = pydicom.dcmread(src_path)
        image = dicom_data.pixel_array
        norm_img = (image - image.min()) / (image.max() - image.min() + 1e-6) * 255
        resized_img = cv2.resize(norm_img, IMG_SIZE, interpolation=cv2.INTER_CUBIC)
        # Bicubic interpolation can overshoot [0, 255]; clip before the uint8
        # cast, otherwise out-of-range values wrap around (e.g. -1 -> 255).
        return np.clip(resized_img, 0, 255).astype(np.uint8)

    def load_series_images(self, study_id, series_desc, start_idx):
        """Sample ``N_SLICES`` evenly spaced slices from one series.

        With ``n`` files, sample ``k`` (k = 1..N_SLICES) reads file index
        ``round(k * n / N_SLICES - 0.5)``, clamped to ``[0, n - 1]``. A series
        without files, or a slice that fails to load, leaves zero channels.

        Args:
            study_id: Study to load.
            series_desc: Value matched against ``series_description``.
            start_idx: Unused; kept for backward compatibility.

        Returns:
            np.ndarray: uint8 array of shape (IMG_SIZE[0], IMG_SIZE[1], N_SLICES).
        """
        n_slices = self.N_SLICES
        images = np.zeros((IMG_SIZE[0], IMG_SIZE[1], n_slices), dtype=np.uint8)
        img_paths = self.get_img_paths(study_id, series_desc)

        if not img_paths:
            print(f'{study_id}: {series_desc} has no images')
            return images

        n_paths = len(img_paths)
        step = n_paths / float(n_slices)
        # Algebraically equal to ``step`` (the first sample position).
        mid_point = n_paths / 2.0 - (n_slices / 2.0 - 1.0) * step
        # The last sample lands exactly on n_paths. np.arange's stop is
        # exclusive, so extend it slightly or the final channel is dropped
        # (or not, depending on float noise). Cap at n_slices against overshoot.
        positions = np.arange(mid_point, n_paths + 1e-4, step)[:n_slices]

        for j, pos in enumerate(positions):
            # Python's round() rounds halves to even, so round(n - 0.5) can be
            # n; clamp to a valid index.
            idx = min(n_paths - 1, max(0, int(round(pos - 0.5))))
            try:
                images[..., j] = self.read_dcm_image(img_paths[idx])
            except Exception as e:
                # Best effort: a single unreadable slice leaves a zero channel.
                print(f'Failed to load {series_desc} for {study_id}: {e}')

        return images

    def __getitem__(self, idx):
        study_id = self.study_ids[idx]
        n_slices = self.N_SLICES
        channels = n_slices * len(self.SERIES_DESCRIPTIONS)
        x = np.zeros((IMG_SIZE[0], IMG_SIZE[1], channels), dtype=np.uint8)

        # Each series fills its own contiguous block of N_SLICES channels.
        for k, series_desc in enumerate(self.SERIES_DESCRIPTIONS):
            start = k * n_slices
            x[..., start:start + n_slices] = self.load_series_images(study_id, series_desc, start)

        if self.transform is not None:
            x = self.transform(image=x)['image']

        x = x.transpose(2, 0, 1)  # HWC -> CHW for PyTorch
        return x, str(study_id)

# Test-time preprocessing: resize to IMG_SIZE, then Normalize with
# mean = std = 0.5 (applied to every channel).
transforms_test = A.Compose(
    [
        A.Resize(IMG_SIZE[0], IMG_SIZE[1]),
        A.Normalize(mean=0.5, std=0.5),
    ]
)

# One study per batch, in the original order; the prediction cell indexes
# element [0] of each batch.
test_ds = RSNA24TestDataset(df, study_ids, transform=transforms_test)
test_dl = DataLoader(
    test_ds,
    batch_size=1,
    shuffle=False,
    num_workers=N_WORKERS,
    pin_memory=True,
    drop_last=False,
)

Define Model

In [None]:
import torch.nn as nn
import timm

class RSNA24Model(nn.Module):
    """Wrap a timm backbone for RSNA 2024 Lumbar Spine Degenerative Classification.

    The backbone is stored as ``self.model``, so state-dict keys are prefixed
    with ``model.``.

    Args:
        model_name (str): timm architecture name passed to ``timm.create_model``.
        in_c (int): Number of input channels. Default 42.
        n_classes (int): Number of output logits. Default 75.
        pretrained (bool): Forwarded to timm; load pretrained weights. Default True.
        features_only (bool): Forwarded to timm's ``features_only``. Default False.
    """

    def __init__(self, model_name, in_c=42, n_classes=75, pretrained=True, features_only=False):
        super().__init__()
        backbone_kwargs = dict(
            pretrained=pretrained,
            features_only=features_only,
            in_chans=in_c,
            num_classes=n_classes,
            global_pool='avg',
        )
        self.model = timm.create_model(model_name, **backbone_kwargs)

    def forward(self, x):
        """Run the wrapped backbone.

        Args:
            x (torch.Tensor): Input of shape (batch_size, in_c, height, width).

        Returns:
            torch.Tensor: Backbone output, (batch_size, n_classes) for the
            classification head configured above.
        """
        out = self.model(x)
        return out
        

In [None]:
# Peek at the first rows of the test series table.
df.head(n=5)

Load Models, Predict, and Write Submission

In [None]:
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm

# Constants
CKPT_PATHS = sorted([
    "/kaggle/input/rsna-2024-edgenext-base/model_fold-0.pt",
    "/kaggle/input/rsna-2024-edgenext-base/model_fold-1.pt",
    "/kaggle/input/rsna-2024-edgenext-base/model_fold-2.pt",
])

# ``device`` may be a str or a torch.device depending on which config cell
# ran last; normalise it once here.
_device = torch.device(device)
_on_cuda = _device.type == 'cuda'
# fp16 weights only on GPU. CPU convolutions generally lack fp16 kernels, and
# the input must share the weights' dtype when autocast is off.
_model_dtype = torch.half if _on_cuda else torch.float32

# Load Models
models = []
for cp in CKPT_PATHS:
    print(f'Loading checkpoint: {cp}...')
    model = RSNA24Model(MODEL_NAME, IN_CHANS, N_CLASSES, pretrained=False)
    try:
        # map_location lets GPU-saved checkpoints load on a CPU-only host.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(cp, map_location='cpu')
        model.load_state_dict(state_dict)
    except Exception as e:
        print(f'Error loading {cp}: {e}')
        continue

    model.eval()
    model.to(device=_device, dtype=_model_dtype)
    models.append(model)

if not models:
    # Without this, every prediction stays zero and an invalid submission
    # would be written silently.
    raise RuntimeError('No model checkpoints could be loaded; aborting inference.')

# Autocast for mixed precision (only ever enabled on CUDA).
autocast = torch.autocast(device_type='cuda', dtype=torch.half,
                          enabled=USE_AMP and _on_cuda)

# Predictions and Submission
y_preds = []
row_names = []

with torch.no_grad():
    for x, study_id in tqdm(test_dl, leave=True):
        # The loop assumes batch_size == 1, as test_dl is built.
        x = x.to(device=_device, dtype=_model_dtype)
        pred_per_study = np.zeros((N_LABELS, 3))

        # Row ids: conditions outer, levels inner -- matching the label order
        # of the model's N_LABELS x 3 output layout.
        row_names.extend(
            f'{study_id[0]}_{cond}_{level}' for cond in CONDITIONS for level in LEVELS
        )

        with autocast:
            for model in models:
                logits = model(x)[0]
                # One softmax over each consecutive triple of logits, computed
                # in fp32 and copied to the host once per model.
                probs = logits.float().reshape(N_LABELS, 3).softmax(dim=1)
                pred_per_study += probs.cpu().numpy() / len(models)

        y_preds.append(pred_per_study)

# Combine and save predictions: (n_studies * N_LABELS, 3), aligned with row_names.
y_preds = np.concatenate(y_preds, axis=0) if y_preds else np.zeros((0, 3))

submission_df = pd.DataFrame({
    'row_id': row_names,
    **{label: y_preds[:, i] for i, label in enumerate(LABELS)}
})

submission_df.to_csv('submission.csv', index=False)

# Verify Submission
print(pd.read_csv('submission.csv').head(3))
