## This is a customized prediction method

In [1]:
import os
import glob
import re
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import pydicom
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from tqdm import tqdm
from PIL import Image
import cv2

  warn_and_log(
  warn_and_log(
  warn_and_log(
  warn_and_log(
  warn_and_log(


In [2]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define constants
SERIES_DESCRIPTIONS = ['Sagittal T1', 'Sagittal T2_STIR', 'Axial T2']
IMG_SIZE = (512, 512)
TARGET_SLICES = 10
rd = '/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification'  

In [3]:
# Helper functions for natural sorting
def atoi(text):
    return int(text) if text.isdigit() else text

def natural_keys(text):
    """
    Alphanumeric sorting helper function.
    """
    return [atoi(c) for c in re.split(r'(\d+)', text)]

# Function to resample slices
def resample_slices(image_tensor, target_slices=10):
    current_slices = image_tensor.shape[0]
    if current_slices == target_slices:
        return image_tensor  # No need to resample
    if current_slices > target_slices:
        indices = torch.linspace(0, current_slices - 1, target_slices).long()
        return image_tensor[indices]
    # If fewer slices, upsample
    image_tensor = image_tensor.unsqueeze(0).unsqueeze(0)  # Shape: [1, 1, num_slices, H, W]
    image_tensor_resized = F.interpolate(
        image_tensor,
        size=(target_slices, image_tensor.shape[3], image_tensor.shape[4]),
        mode='trilinear',
        align_corners=False
    )
    image_tensor_resized = image_tensor_resized.squeeze(0).squeeze(0)  # Shape: [target_slices, H, W]
    return image_tensor_resized

# Define preprocessing transformations consistent with training
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])


In [4]:
# Define the test dataset class
class LumbarSpineTestDataset(Dataset):
    def __init__(self, df, study_ids, transform=None):
        self.df = df
        self.study_ids = study_ids
        self.transform = transform

    def get_img_paths(self, study_id, series_description):
        pdf = self.df[self.df['study_id'] == study_id]
        pdf_series = pdf[pdf['series_description'] == series_description]
        image_paths = []
        for idx, row in pdf_series.iterrows():
            series_id = row['series_id']
            paths = glob.glob(f'{rd}/test_images/{study_id}/{series_id}/*.dcm')
            paths = sorted(paths, key=natural_keys)
            image_paths.extend(paths)
        return image_paths

    def read_dcm_image(self, path):
        dicom_data = pydicom.dcmread(path)
        image = dicom_data.pixel_array.astype(np.float32)
        # Normalize the image to [0, 255]
        image = (image - image.min()) / (image.max() - image.min() + 1e-6) * 255.0
        # Resize image to 512x512 using the same interpolation as training
        image = cv2.resize(image, (512, 512), interpolation=cv2.INTER_CUBIC)
        # Convert to PIL Image in grayscale
        image = Image.fromarray(image.astype(np.uint8)).convert('L')
        return image

    def __len__(self):
        return len(self.study_ids)

    def __getitem__(self, idx):
        study_id = self.study_ids[idx]
        images = {}
        for series_description in SERIES_DESCRIPTIONS:
            image_paths = self.get_img_paths(study_id, series_description)
            if not image_paths:
                images[series_description] = torch.zeros((10, IMG_SIZE[0], IMG_SIZE[1]))
                continue

            if series_description == 'Axial T2':
                series_images = []
                for img_path in image_paths:
                    try:
                        img = self.read_dcm_image(img_path)
                        if self.transform:
                            img = self.transform(img)
                            img = img.squeeze(0)  # Remove channel dimension
                        series_images.append(img)
                    except Exception as e:
                        print(f"Error loading image {img_path}: {e}")
                if series_images:
                    series_tensor = torch.stack(series_images, dim=0)  # [num_slices, H, W]
                else:
                    series_tensor = torch.zeros((1, IMG_SIZE[0], IMG_SIZE[1]))
                # Resample to TARGET_SLICES (10 slices)
                series_tensor = resample_slices(series_tensor, target_slices=10)
                images[series_description] = series_tensor

            elif series_description in ['Sagittal T1', 'Sagittal T2_STIR']:
                series_images = []
                allimgs = image_paths
                len_imgs = len(allimgs)
                if len_imgs == 0:
                    images[series_description] = torch.zeros((10, IMG_SIZE[0], IMG_SIZE[1]))
                    continue
                step = len_imgs / 10.0
                st = len_imgs / 2.0 - 4.0 * step
                end = len_imgs + 0.0001
                indices = []
                for i in np.arange(st, end, step):
                    ind2 = max(0, int((i - 0.5001).round()))
                    indices.append(ind2)
                selected_imgs = []
                for ind in indices:
                    if ind >= len_imgs:
                        ind = len_imgs - 1
                    img_path = allimgs[ind]
                    try:
                        img = self.read_dcm_image(img_path)
                        if self.transform:
                            img = self.transform(img)
                            img = img.squeeze(0)  # Remove channel dimension
                        selected_imgs.append(img)
                    except Exception as e:
                        print(f"Error loading image {img_path}: {e}")
                if selected_imgs:
                    series_tensor = torch.stack(selected_imgs, dim=0)  # [10, H, W]
                else:
                    series_tensor = torch.zeros((10, IMG_SIZE[0], IMG_SIZE[1]))
                images[series_description] = series_tensor
            else:
                images[series_description] = torch.zeros((10, IMG_SIZE[0], IMG_SIZE[1]))
        sample = {
            'study_id': study_id,
            'images': images
        }
        return sample


In [5]:
# Read test_series_descriptions.csv
test_df = pd.read_csv(f'{rd}/test_series_descriptions.csv')

# Replace 'T2/STIR' with 'T2_STIR' in series descriptions
test_df['series_description'] = test_df['series_description'].str.replace('T2/STIR', 'T2_STIR')

study_ids = test_df['study_id'].unique()

# Create the test dataset and dataloader
test_dataset = LumbarSpineTestDataset(df=test_df, study_ids=study_ids, transform=transform)
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,  # Adjust based on your system
    pin_memory=True
)


In [6]:
# Define the ResNet feature extractor
class ResNetFeatureExtractor(nn.Module):
    def __init__(self, in_channels=10):
        super(ResNetFeatureExtractor, self).__init__()
        resnet = models.resnet18(pretrained=False)
        # Modify the first convolutional layer to accept in_channels
        resnet.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Extract layers up to layer4 (exclude avgpool and fc layers)
        self.features = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4
        )

    def forward(self, x):
        x = self.features(x)
        return x

# Define the main model
class MultiSeriesSpineModel(nn.Module):
    def __init__(self, num_conditions=25, num_classes=3):
        super(MultiSeriesSpineModel, self).__init__()
        self.num_conditions = num_conditions
        self.num_classes = num_classes

        # Feature extractors for each MRI series
        self.cnn_sagittal_t1 = ResNetFeatureExtractor(in_channels=10)
        self.cnn_sagittal_t2_stir = ResNetFeatureExtractor(in_channels=10)
        self.cnn_axial_t2 = ResNetFeatureExtractor(in_channels=10)

        # Define attention layers for each series
        self.attention_sagittal_t1 = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()
        )
        self.attention_sagittal_t2_stir = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()
        )
        self.attention_axial_t2 = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()
        )

        # Define the final classification layers
        combined_feature_size = 512 * 3  # Since we're concatenating features from three models

        self.fc1 = nn.Linear(combined_feature_size, 512)
        self.fc2 = nn.Linear(512, num_conditions * num_classes)  # Output layer

    def forward(self, sagittal_t1, sagittal_t2_stir, axial_t2):
        # The tensors are of shape [batch_size, in_channels, H, W]
        features_sagittal_t1 = self.cnn_sagittal_t1(sagittal_t1)  # Shape: [batch_size, 512, H, W]
        features_sagittal_t2_stir = self.cnn_sagittal_t2_stir(sagittal_t2_stir)
        features_axial_t2 = self.cnn_axial_t2(axial_t2)

        # Generate attention maps (learned by the model)
        attention_map_t1 = self.attention_sagittal_t1(features_sagittal_t1)  # Shape: [batch_size, 1, H, W]
        attention_map_t2_stir = self.attention_sagittal_t2_stir(features_sagittal_t2_stir)
        attention_map_axial = self.attention_axial_t2(features_axial_t2)

        # Apply attention
        attended_features_t1 = features_sagittal_t1 * attention_map_t1  # Element-wise multiplication
        attended_features_t2_stir = features_sagittal_t2_stir * attention_map_t2_stir
        attended_features_axial = features_axial_t2 * attention_map_axial

        # Global average pooling
        features_sagittal_t1 = F.adaptive_avg_pool2d(attended_features_t1, (1, 1)).view(attended_features_t1.size(0), -1)
        features_sagittal_t2_stir = F.adaptive_avg_pool2d(attended_features_t2_stir, (1, 1)).view(attended_features_t2_stir.size(0), -1)
        features_axial_t2 = F.adaptive_avg_pool2d(attended_features_axial, (1, 1)).view(attended_features_axial.size(0), -1)

        # Concatenate features
        combined_features = torch.cat([features_sagittal_t1, features_sagittal_t2_stir, features_axial_t2], dim=1)

        # Pass through final classification layers
        x = F.relu(self.fc1(combined_features))
        x = self.fc2(x)  # Shape: [batch_size, num_conditions * num_classes]
        x = x.view(-1, self.num_conditions, self.num_classes)  # Reshape to [batch_size, num_conditions, num_classes]

        return x, [attention_map_t1, attention_map_t2_stir, attention_map_axial]


In [7]:
# Define the EnsembleModel class
class EnsembleModel(nn.Module):
    def __init__(self, model_class, num_models, device):
        super(EnsembleModel, self).__init__()
        self.models = nn.ModuleList()
        for _ in range(num_models):
            model = model_class()
            model.to(device)
            model.eval()
            self.models.append(model)
        self.device = device

    def forward(self, sagittal_t1, sagittal_t2_stir, axial_t2):
        outputs_list = []
        for model in self.models:
            outputs, _ = model(sagittal_t1, sagittal_t2_stir, axial_t2)
            outputs_list.append(outputs)
        # Stack outputs and take mean over the ensemble dimension
        outputs = torch.stack(outputs_list, dim=0)
        avg_outputs = torch.mean(outputs, dim=0)
        return avg_outputs

In [8]:
# Instantiate the ensemble model
num_conditions = 25  # Number of labels
num_classes = 3
k_folds = 3  # Number of models in the ensemble


In [9]:
ensemble_model = EnsembleModel(
    model_class=lambda: MultiSeriesSpineModel(num_conditions=num_conditions, num_classes=num_classes),
    num_models=k_folds,
    device=device
)

ensemble_model.load_state_dict(torch.load('/kaggle/input/rsna-chacha-pytorch-models/pytorch/default/20/ensemble_model_F3_E10.pth', map_location=device))
ensemble_model.to(device)
ensemble_model.eval()

  ensemble_model.load_state_dict(torch.load('/kaggle/input/rsna-chacha-pytorch-models/pytorch/default/20/ensemble_model_F3_E10.pth', map_location=device))


EnsembleModel(
  (models): ModuleList(
    (0-2): 3 x MultiSeriesSpineModel(
      (cnn_sagittal_t1): ResNetFeatureExtractor(
        (features): Sequential(
          (0): Conv2d(10, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
          (4): Sequential(
            (0): BasicBlock(
              (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (relu): ReLU(inplace=True)
              (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
            (1)

In [10]:
# Define your label names consistent with training
CONDITIONS = [
    'spinal_canal_stenosis', 
    'left_neural_foraminal_narrowing', 
    'right_neural_foraminal_narrowing',
    'left_subarticular_stenosis',
    'right_subarticular_stenosis'
]

LEVELS = [
    'l1_l2',
    'l2_l3',
    'l3_l4',
    'l4_l5',
    'l5_s1',
]

LABELS = [f'{condition}_{level}' for condition in CONDITIONS for level in LEVELS]

# Initialize lists for row names and predictions
row_names = []
predictions = []



In [11]:
# Perform predictions
with torch.no_grad():
    for batch in tqdm(test_loader):
        study_id = batch['study_id'][0]
        images = batch['images']

        # Process images
        sagittal_t1 = images['Sagittal T1']  # Shape: [1, num_slices, H, W]
        sagittal_t2_stir = images['Sagittal T2_STIR']
        axial_t2 = images['Axial T2']

        # Resample slices if necessary (already done in dataset)
        # Stack slices into the channel dimension
        sagittal_t1 = sagittal_t1.reshape(1, -1, 512, 512).to(device)  # Shape: [batch_size, channels, H, W]
        sagittal_t2_stir = sagittal_t2_stir.reshape(1, -1, 512, 512).to(device)
        axial_t2 = axial_t2.reshape(1, -1, 512, 512).to(device)

        # Get ensemble prediction
        outputs = ensemble_model(sagittal_t1, sagittal_t2_stir, axial_t2)

        # outputs shape: [batch_size, num_conditions, num_classes]
        outputs = outputs.squeeze(0)  # Shape: [num_conditions, num_classes]

        # Apply softmax to get probabilities
        probs = F.softmax(outputs, dim=1)  # Shape: [num_conditions, num_classes]

        # Ensure the predictions are aligned with LABELS
        pred_per_study = probs.cpu().numpy()  # Shape: [num_conditions, num_classes]

        # Generate row names and collect predictions
        for label in LABELS:
            row_names.append(f'{study_id}_{label}')
        predictions.append(pred_per_study)

100%|██████████| 1/1 [00:04<00:00,  4.32s/it]


In [12]:
# Stack predictions
predictions = np.vstack(predictions)  # Shape: [num_studies * num_conditions, num_classes]

# Flatten predictions
predictions = predictions.reshape(-1, 3)  # Shape: [num_studies * num_conditions, num_classes]

# Create the submission DataFrame
submission = pd.DataFrame({
    'row_id': row_names,
    'normal_mild': predictions[:, 0],
    'moderate': predictions[:, 1],
    'severe': predictions[:, 2]
})

# Save to CSV
submission.to_csv('submission.csv', index=False)

print("Submission file 'submission.csv' has been generated.")

Submission file 'submission.csv' has been generated.


In [13]:
submission

Unnamed: 0,row_id,normal_mild,moderate,severe
0,44036939_spinal_canal_stenosis_l1_l2,0.326605,0.374253,0.299143
1,44036939_spinal_canal_stenosis_l2_l3,0.343004,0.325952,0.331044
2,44036939_spinal_canal_stenosis_l3_l4,0.36441,0.328353,0.307238
3,44036939_spinal_canal_stenosis_l4_l5,0.374329,0.285039,0.340632
4,44036939_spinal_canal_stenosis_l5_s1,0.323765,0.356296,0.319938
5,44036939_left_neural_foraminal_narrowing_l1_l2,0.338018,0.357842,0.304141
6,44036939_left_neural_foraminal_narrowing_l2_l3,0.355018,0.34565,0.299332
7,44036939_left_neural_foraminal_narrowing_l3_l4,0.336803,0.33971,0.323488
8,44036939_left_neural_foraminal_narrowing_l4_l5,0.369527,0.310421,0.320052
9,44036939_left_neural_foraminal_narrowing_l5_s1,0.324511,0.35408,0.321409
