In [1]:
import os
import glob
import re
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import pydicom
from PIL import Image

## Define Constants

In [2]:
# Define constants
SERIES_DESCRIPTIONS = ['Sagittal T1', 'Sagittal T2_STIR', 'Axial T2']
IMG_SIZE = (512, 512)
TARGET_SLICES = 10
rd = '/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification'  # Adjust this path to your test data root directory
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Helper functions
def atoi(text):
    return int(text) if text.isdigit() else text

def natural_keys(text):
    """
    Alphanumeric sorting helper function.
    """
    return [atoi(c) for c in re.split(r'(\d+)', text)]

## Define the Test Dataset Class:

In [4]:
class LumbarSpineTestDataset(Dataset):
    def __init__(self, df, study_ids, transform=None):
        self.df = df
        self.study_ids = study_ids
        self.transform = transform

    def get_img_paths(self, study_id, series_description):
        pdf = self.df[self.df['study_id'] == study_id]
        pdf_series = pdf[pdf['series_description'] == series_description]
        image_paths = []
        for idx, row in pdf_series.iterrows():
            series_id = row['series_id']
            paths = glob.glob(f'{rd}/test_images/{study_id}/{series_id}/*.dcm')
            paths = sorted(paths, key=natural_keys)
            image_paths.extend(paths)
        return image_paths

    def read_dcm_image(self, path):
        dicom_data = pydicom.dcmread(path)
        image = dicom_data.pixel_array.astype(np.float32)
        # Normalize the image
        image = (image - image.min()) / (image.max() - image.min() + 1e-6)
        # Convert to PIL Image and grayscale
        image = Image.fromarray((image * 255).astype(np.uint8)).convert('L')
        return image

    def __len__(self):
        return len(self.study_ids)

    def __getitem__(self, idx):
        study_id = self.study_ids[idx]
        images = {}
        for series_description in SERIES_DESCRIPTIONS:
            image_paths = self.get_img_paths(study_id, series_description)
            series_images = []
            for img_path in image_paths:
                try:
                    img = self.read_dcm_image(img_path)
                    if self.transform:
                        img = self.transform(img)
                        img = img.squeeze(0)  # Remove the channel dimension, resulting in [H, W]
                    series_images.append(img)
                except Exception as e:
                    print(f"Error loading image {img_path}: {e}")
            if series_images:
                series_tensor = torch.stack(series_images, dim=0)  # Shape: [num_slices, H, W]
                # Resample slices to TARGET_SLICES
                series_tensor = resample_slices(series_tensor, target_slices=TARGET_SLICES)
            else:
                # If no images, create a tensor of zeros
                series_tensor = torch.zeros((TARGET_SLICES, IMG_SIZE[0], IMG_SIZE[1]))
            images[series_description] = series_tensor  # Shape: [TARGET_SLICES, H, W]
        sample = {
            'study_id': study_id,
            'images': images
        }
        return sample


## Define the Resampling Function:

In [5]:

# Define preprocessing transformations
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])



In [6]:
def resample_slices(image_tensor, target_slices=10):
    current_slices = image_tensor.shape[0]
    if current_slices == target_slices:
        return image_tensor  # No need to resample
    if current_slices > target_slices:
        indices = torch.linspace(0, current_slices - 1, target_slices).long()
        return image_tensor[indices]
    # If fewer slices, upsample
    image_tensor = image_tensor.unsqueeze(0).unsqueeze(0)  # Shape: [1, 1, num_slices, H, W]
    image_tensor_resized = F.interpolate(
        image_tensor,
        size=(target_slices, image_tensor.shape[3], image_tensor.shape[4]),
        mode='trilinear',
        align_corners=False
    )
    image_tensor_resized = image_tensor_resized.squeeze(0).squeeze(0)  # Shape: [target_slices, H, W]
    return image_tensor_resized

## Prepare the Test DataLoader:

In [7]:
# Read test_series_descriptions.csv
test_df = pd.read_csv(f'{rd}/test_series_descriptions.csv')

# Replace 'T2/STIR' with 'T2_STIR' in series descriptions
test_df['series_description'] = test_df['series_description'].str.replace('T2/STIR', 'T2_STIR')

study_ids = test_df['study_id'].unique()

# Create the test dataset and dataloader
test_dataset = LumbarSpineTestDataset(df=test_df, study_ids=study_ids, transform=transform)
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,  # Adjust based on your system
    pin_memory=True
)


## Define the Model Classes:


In [8]:
class ResNetFeatureExtractor(nn.Module):
    def __init__(self, in_channels=10, resnet_weights_path=None):
        super(ResNetFeatureExtractor, self).__init__()
        resnet = models.resnet18(pretrained=False)
        if resnet_weights_path:
            resnet.load_state_dict(torch.load(resnet_weights_path))
        # Modify the first convolutional layer to accept in_channels
        resnet.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # Extract layers up to layer4 (exclude avgpool and fc layers)
        self.features = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4
        )

    def forward(self, x):
        x = self.features(x)
        return x

# Define the main model
class MultiSeriesSpineModel(nn.Module):
    def __init__(self, num_conditions=25, num_classes=3):
        super(MultiSeriesSpineModel, self).__init__()
        self.num_conditions = num_conditions
        self.num_classes = num_classes

        # Feature extractors for each MRI series
        self.cnn_sagittal_t1 = ResNetFeatureExtractor(in_channels=10)
        self.cnn_sagittal_t2_stir = ResNetFeatureExtractor(in_channels=10)
        self.cnn_axial_t2 = ResNetFeatureExtractor(in_channels=10)

        # Define attention layers for each series
        self.attention_sagittal_t1 = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()
        )
        self.attention_sagittal_t2_stir = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()
        )
        self.attention_axial_t2 = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()
        )

        # Define the final classification layers
        combined_feature_size = 512 * 3  # Since we're concatenating features from three models

        self.fc1 = nn.Linear(combined_feature_size, 512)
        self.fc2 = nn.Linear(512, num_conditions * num_classes)  # Output layer

    def forward(self, sagittal_t1, sagittal_t2_stir, axial_t2):
        # The tensors are of shape [batch_size, in_channels, H, W]

        features_sagittal_t1 = self.cnn_sagittal_t1(sagittal_t1)  # Shape: [batch_size, 512, H, W]
        features_sagittal_t2_stir = self.cnn_sagittal_t2_stir(sagittal_t2_stir)
        features_axial_t2 = self.cnn_axial_t2(axial_t2)

        # Generate attention maps (learned by the model)
        attention_map_t1 = self.attention_sagittal_t1(features_sagittal_t1)  # Shape: [batch_size, 1, H, W]
        attention_map_t2_stir = self.attention_sagittal_t2_stir(features_sagittal_t2_stir)
        attention_map_axial = self.attention_axial_t2(features_axial_t2)

        # Apply attention
        attended_features_t1 = features_sagittal_t1 * attention_map_t1  # Element-wise multiplication
        attended_features_t2_stir = features_sagittal_t2_stir * attention_map_t2_stir
        attended_features_axial = features_axial_t2 * attention_map_axial

        # Global average pooling
        features_sagittal_t1 = F.adaptive_avg_pool2d(attended_features_t1, (1, 1)).view(attended_features_t1.size(0), -1)
        features_sagittal_t2_stir = F.adaptive_avg_pool2d(attended_features_t2_stir, (1, 1)).view(attended_features_t2_stir.size(0), -1)
        features_axial_t2 = F.adaptive_avg_pool2d(attended_features_axial, (1, 1)).view(attended_features_axial.size(0), -1)

        # Concatenate features
        combined_features = torch.cat([features_sagittal_t1, features_sagittal_t2_stir, features_axial_t2], dim=1)

        # Pass through final classification layers
        x = F.relu(self.fc1(combined_features))
        x = self.fc2(x)  # Shape: [batch_size, num_conditions * num_classes]
        x = x.view(-1, self.num_conditions, self.num_classes)  # Reshape to [batch_size, num_conditions, num_classes]

        return x, [attention_map_t1, attention_map_t2_stir, attention_map_axial]



## Load the Trained Model:


In [9]:
# Instantiate the model
num_conditions = 25  # Number of labels
num_classes = 3

resnet_weights_path = '/kaggle/input/rsna-chacha-pytorch-models/pytorch/default/9/resnet18-f37072fd.pth'# Get pretrained weights from local


model = MultiSeriesSpineModel(num_conditions=num_conditions, num_classes=num_classes)

# Load the trained model's state_dict
model_save_path = '/kaggle/input/rsna-chacha-pytorch-models/pytorch/default/9/checkpoint.pth'
model.load_state_dict(torch.load(model_save_path, map_location=device))

# Move the model to device
model = model.to(device)
model.eval()




  model.load_state_dict(torch.load(model_save_path, map_location=device))


MultiSeriesSpineModel(
  (cnn_sagittal_t1): ResNetFeatureExtractor(
    (features): Sequential(
      (0): Conv2d(10, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
   

In [10]:
CONDITIONS = [
    'spinal_canal_stenosis', 
    'left_neural_foraminal_narrowing', 
    'right_neural_foraminal_narrowing',
    'left_subarticular_stenosis',
    'right_subarticular_stenosis'
]

LEVELS = [
    'l1_l2',
    'l2_l3',
    'l3_l4',
    'l4_l5',
    'l5_s1',
]

LABELS = [f'{condition}_{level}' for condition in CONDITIONS for level in LEVELS]

row_names = []
predictions = []



In [11]:
with torch.no_grad():
    for batch in tqdm(test_loader):
        study_id = batch['study_id'][0]
        images = batch['images']

        # Process sagittal_t1
        sagittal_t1 = images['Sagittal T1']  # Shape: [num_slices, H, W]
        sagittal_t1 = sagittal_t1.unsqueeze(0)  # Add batch dimension: [1, num_slices, H, W]
        sagittal_t1 = sagittal_t1.reshape(1, -1, 512, 512).to(device)  # Shape: [1, num_slices, H, W]

        # Process sagittal_t2_stir
        sagittal_t2_stir = images['Sagittal T2_STIR']
        sagittal_t2_stir = sagittal_t2_stir.unsqueeze(0).reshape(1, -1, 512, 512).to(device)

        # Process axial_t2
        axial_t2 = images['Axial T2']
        axial_t2 = axial_t2.unsqueeze(0).reshape(1, -1, 512, 512).to(device)

        # Now pass these tensors to the model
        outputs, _ = model(sagittal_t1, sagittal_t2_stir, axial_t2)

        # outputs shape: [1, num_conditions, num_classes]
        outputs = outputs.squeeze(0)  # Shape: [num_conditions, num_classes]

        # Apply softmax to get probabilities
        probs = F.softmax(outputs, dim=1)  # Shape: [num_conditions, num_classes]
        pred_per_study = probs.cpu().numpy()  # Shape: [num_conditions, num_classes]

        # Generate row names and collect predictions
        for label in LABELS:
            row_names.append(f'{study_id}_{label}')
        predictions.append(pred_per_study)


100%|██████████| 1/1 [00:04<00:00,  4.56s/it]


In [12]:
# Stack predictions
predictions = np.vstack(predictions)  # Shape: [num_studies * num_conditions, num_classes]

# Flatten predictions
predictions = predictions.reshape(-1, 3)  # Shape: [num_studies * num_conditions, num_classes]

submission = pd.DataFrame({
    'row_id': row_names,
    'normal_mild': predictions[:, 0],
    'moderate': predictions[:, 1],
    'severe': predictions[:, 2]
})

submission.to_csv('submission.csv', index=False)

In [13]:
submission

Unnamed: 0,row_id,normal_mild,moderate,severe
0,44036939_spinal_canal_stenosis_l1_l2,0.544477,0.272434,0.183088
1,44036939_spinal_canal_stenosis_l2_l3,0.483547,0.317369,0.199083
2,44036939_spinal_canal_stenosis_l3_l4,0.439831,0.316598,0.243571
3,44036939_spinal_canal_stenosis_l4_l5,0.436395,0.300468,0.263137
4,44036939_spinal_canal_stenosis_l5_s1,0.540835,0.264725,0.19444
5,44036939_left_neural_foraminal_narrowing_l1_l2,0.6032,0.270725,0.126075
6,44036939_left_neural_foraminal_narrowing_l2_l3,0.511741,0.333944,0.154316
7,44036939_left_neural_foraminal_narrowing_l3_l4,0.448163,0.359459,0.192378
8,44036939_left_neural_foraminal_narrowing_l4_l5,0.408756,0.347851,0.243393
9,44036939_left_neural_foraminal_narrowing_l5_s1,0.403958,0.326679,0.269364
