In [1]:
import os
import glob
import re
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import pydicom
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from tqdm import tqdm
from PIL import Image
import cv2




  warn_and_log(
  warn_and_log(
  warn_and_log(
  warn_and_log(
  warn_and_log(


In [2]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define constants
SERIES_DESCRIPTIONS = ['Sagittal T1', 'Sagittal T2_STIR', 'Axial T2']
IMG_SIZE = (512, 512)
TARGET_SLICES = 10
rd = '/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification'



In [3]:
# Helper functions for natural sorting
def atoi(text):
    return int(text) if text.isdigit() else text

def natural_keys(text):
    """
    Alphanumeric sorting helper function.
    """
    return [atoi(c) for c in re.split(r'(\d+)', text)]

# Function to resample slices
def resample_slices(volume, target_slices=10):
    """
    Resample or pad the number of slices to target_slices.
    Args:
        volume (np.ndarray): 3D array of shape [slices, H, W].
        target_slices (int): Desired number of slices.
    Returns:
        np.ndarray: Resampled 3D array.
    """
    current_slices = volume.shape[0]
    if current_slices == target_slices:
        return volume
    elif current_slices > target_slices:
        indices = np.linspace(0, current_slices - 1, target_slices).astype(int)
        return volume[indices]
    else:
        # Pad with zeros
        pad_width = target_slices - current_slices
        padding = ((0, pad_width), (0, 0), (0, 0))
        return np.pad(volume, padding, mode='constant', constant_values=0)

# Define preprocessing transformations
transform = transforms.Compose([
    transforms.ToTensor(),  # Converts PIL Image to Tensor
    # Note: We will apply Normalize3D after stacking the scans
])

# Define the test dataset class
# Define the test dataset class
class LumbarSpineTestDataset(Dataset):
    def __init__(self, df, study_ids, transform=None, target_slices=10):
        self.df = df
        self.study_ids = study_ids
        self.transform = transform
        self.target_slices = target_slices

    def get_img_paths(self, study_id, series_description):
        pdf = self.df[self.df['study_id'] == study_id]
        pdf_series = pdf[pdf['series_description'] == series_description]
        image_paths = []
        for idx, row in pdf_series.iterrows():
            series_id = row['series_id']
            paths = glob.glob(f'{rd}/test_images/{study_id}/{series_id}/*.dcm')
            paths = sorted(paths, key=natural_keys)
            image_paths.extend(paths)
        return image_paths

    def read_dcm_image(self, path):
        dicom_data = pydicom.dcmread(path)
        image = dicom_data.pixel_array.astype(np.float32)
        # Normalize the image to [0, 255]
        image = (image - image.min()) / (image.max() - image.min() + 1e-6) * 255.0
        # Resize image to 512x512 using the same interpolation as training
        image = cv2.resize(image, (512, 512), interpolation=cv2.INTER_CUBIC)
        # Convert to PIL Image in grayscale
        image = Image.fromarray(image.astype(np.uint8)).convert('L')
        return image

    def __len__(self):
        return len(self.study_ids)

    def __getitem__(self, idx):
        study_id = self.study_ids[idx]
        scans = []
        for series_description in SERIES_DESCRIPTIONS:
            image_paths = self.get_img_paths(study_id, series_description)
            if not image_paths:
                # If no images, create a zero tensor
                scan = torch.zeros((self.target_slices, IMG_SIZE[0], IMG_SIZE[1]))
                scans.append(scan)
                continue
            # Read images
            series_images = []
            for img_path in image_paths:
                try:
                    img = self.read_dcm_image(img_path)
                    img = transform(img)  # Apply transform
                    img = img.squeeze(0)  # Remove channel dimension
                    series_images.append(img)
                except Exception as e:
                    print(f"Error loading image {img_path}: {e}")
            if series_images:
                series_tensor = torch.stack(series_images, dim=0)  # Shape: [num_slices, H, W]
            else:
                series_tensor = torch.zeros((1, IMG_SIZE[0], IMG_SIZE[1]))
            # Resample slices to TARGET_SLICES
            series_tensor = resample_slices(series_tensor.numpy(), target_slices=self.target_slices)
            series_tensor = torch.from_numpy(series_tensor)
            scans.append(series_tensor)
        # Stack scans along the channel dimension
        scan = torch.stack(scans, dim=0)  # Shape: [3, TARGET_SLICES, H, W]
        # Apply Normalize3D after stacking
        if self.transform:
            scan = self.transform(scan)  # Apply any additional transforms
        sample = {
            'study_id': study_id,
            'scan': scan
        }
        return sample
    
class Normalize3D(object):
    def __init__(self, mean, std):
        """
        Args:
            mean (list or tuple): Mean values for each channel.
            std (list or tuple): Standard deviation values for each channel.
        """
        self.mean = torch.tensor(mean).view(-1, 1, 1, 1)  # Shape: [C, 1, 1, 1]
        self.std = torch.tensor(std).view(-1, 1, 1, 1)    # Shape: [C, 1, 1, 1]
    
    def __call__(self, tensor):
        """
        Args:
            tensor (torch.Tensor): Tensor image of size [C, D, H, W] to be normalized.
        
        Returns:
            torch.Tensor: Normalized tensor.
        """
        return (tensor - self.mean) / self.std




In [4]:
# Read test_series_descriptions.csv
test_df = pd.read_csv(f'{rd}/test_series_descriptions.csv')

# Replace 'T2/STIR' with 'T2_STIR' in series descriptions
test_df['series_description'] = test_df['series_description'].str.replace('T2/STIR', 'T2_STIR')

study_ids = test_df['study_id'].unique()

# Create the test dataset and dataloader
test_dataset = LumbarSpineTestDataset(
    df=test_df,
    study_ids=study_ids,
    transform=transforms.Compose([
        Normalize3D(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Apply Normalize3D
    ]),
    target_slices=TARGET_SLICES
)
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,  # Adjust based on your system
    pin_memory=True
)

# Define your label names consistent with training
CONDITIONS = [
    'spinal_canal_stenosis', 
    'left_neural_foraminal_narrowing', 
    'right_neural_foraminal_narrowing',
    'left_subarticular_stenosis',
    'right_subarticular_stenosis'
]

LEVELS = [
    'l1_l2',
    'l2_l3',
    'l3_l4',
    'l4_l5',
    'l5_s1',
]

LABELS = [f'{condition}_{level}' for condition in CONDITIONS for level in LEVELS]



In [5]:
class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1, 1)
        return x * y.expand_as(x)

class BasicBlock3D(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super(BasicBlock3D, self).__init__()
        self.conv1 = nn.Conv3d(in_planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample
        self.stride = stride
        self.se = SEBlock(planes)

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out = self.se(out)  # Apply SE block

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class ResNet3D(nn.Module):
    def __init__(self, block, layers, num_classes=512):
        super(ResNet3D, self).__init__()
        self.in_planes = 64
        # Updated to accept 3 channels instead of 1
        self.conv1 = nn.Conv3d(3, 64, kernel_size=7, stride=(1, 2, 2), padding=(3, 3, 3), bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=(1, 2, 2), padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.in_planes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv3d(self.in_planes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.in_planes, planes, stride, downsample))
        self.in_planes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_planes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        # Input x: [batch_size, 3, slices, H, W]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)  # [batch, 64, slices, H/4, W/4]
        x = self.layer2(x)  # [batch, 128, slices/2, H/8, W/8]
        x = self.layer3(x)  # [batch, 256, slices/4, H/16, W/16]
        x = self.layer4(x)  # [batch, 512, slices/8, H/32, W/32]

        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x  # [batch_size, num_classes]

class CoordAttentionModule(nn.Module):
    def __init__(self, feature_dim, coord_dim):
        super(CoordAttentionModule, self).__init__()
        self.attention_fc = nn.Sequential(
            nn.Linear(coord_dim, feature_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feature_dim, feature_dim),
            nn.Sigmoid()
        )

    def forward(self, x, coords):
        attention_weights = self.attention_fc(coords)  # [batch_size, feature_dim]
        x = x * attention_weights  # Element-wise multiplication
        return x

class CoordAttention3DResNet(nn.Module):
    def __init__(self, num_classes, coord_dim):
        super(CoordAttention3DResNet, self).__init__()
        self.resnet3d = ResNet3D(BasicBlock3D, [2, 2, 2, 2], num_classes=512)
        self.fc = nn.Linear(512, num_classes)
        self.coord_attention = CoordAttentionModule(512, coord_dim)

    def forward(self, x, coords=None):
        x = self.resnet3d(x)  # [batch_size, 512]
        if self.training and coords is not None:
            x = self.coord_attention(x, coords)
        x = self.fc(x)
        return x


In [6]:
# Now, define the EnsembleModel class
class EnsembleModel(nn.Module):
    def __init__(self, model_class, num_models, device):
        super(EnsembleModel, self).__init__()
        self.models = nn.ModuleList()
        for _ in range(num_models):
            # Initialize the model with necessary parameters
            model = model_class()
            model.to(device)
            model.eval()  # Set model to evaluation mode
            self.models.append(model)
        self.device = device

    def forward(self, x):
        outputs_list = []
        for model in self.models:
            outputs = model(x)
            outputs_list.append(outputs)
        # Stack outputs and take mean over the ensemble dimension
        outputs = torch.stack(outputs_list, dim=0)
        avg_outputs = torch.mean(outputs, dim=0)
        return avg_outputs

In [7]:
# Instantiate the ensemble model
num_conditions = 25  # Number of labels
num_classes_per_label = 3  # Number of classes per label
total_num_classes = num_conditions * num_classes_per_label
coord_dim = len(CONDITIONS) * len(LEVELS) * 2  # As per your training setup

# Set the number of folds and epochs (adjust according to your saved model)
k_folds = 2  # Number of models in the ensemble
num_epochs = 5  # Number of epochs used during training

# Define the model class used in the ensemble
def model_class():
    return CoordAttention3DResNet(num_classes=total_num_classes, coord_dim=coord_dim)

# Instantiate the ensemble model
ensemble_model = EnsembleModel(
    model_class=model_class,
    num_models=k_folds,
    device=device
)

# Load the ensembled model weights
ensemble_model_path = '/kaggle/input/rsna-chacha-pytorch-models/pytorch/default/17/ensemble_3d_resnet_model_F2_E5.pth'
ensemble_model.load_state_dict(torch.load(ensemble_model_path, map_location=device))
ensemble_model.to(device)
ensemble_model.eval()

  ensemble_model.load_state_dict(torch.load(ensemble_model_path, map_location=device))


EnsembleModel(
  (models): ModuleList(
    (0-1): 2 x CoordAttention3DResNet(
      (resnet3d): ResNet3D(
        (conv1): Conv3d(3, 64, kernel_size=(7, 7, 7), stride=(1, 2, 2), padding=(3, 3, 3), bias=False)
        (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (maxpool): MaxPool3d(kernel_size=3, stride=(1, 2, 2), padding=1, dilation=1, ceil_mode=False)
        (layer1): Sequential(
          (0): BasicBlock3D(
            (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
            (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU(inplace=True)
            (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
            (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (se): SEBlock(
              (avg_p

In [8]:
# Initialize lists for row names and predictions
row_names = []
predictions = []
num_classes = 3

# Perform predictions
with torch.no_grad():
    for batch in tqdm(test_loader):
        study_id = batch['study_id'][0]
        scan = batch['scan']  # Shape: [1, 3, slices, H, W]
        scan = scan.to(device)
        outputs = ensemble_model(scan)  # Shape: [batch_size, total_num_classes]
        outputs = outputs.squeeze(0)  # Shape: [total_num_classes]
        outputs = outputs.view(num_conditions, num_classes)  # Shape: [num_conditions, num_classes]

        # Apply softmax to get probabilities
        probs = F.softmax(outputs, dim=1)  # Shape: [num_conditions, num_classes]

        # Ensure the predictions are aligned with LABELS
        pred_per_study = probs.cpu().numpy()  # Shape: [num_conditions, num_classes]

        # Generate row names and collect predictions
        for label in LABELS:
            row_names.append(f'{study_id}_{label}')
        predictions.append(pred_per_study)



100%|██████████| 1/1 [00:09<00:00,  9.41s/it]


In [9]:
# Stack predictions
predictions = np.vstack(predictions)  # Shape: [num_studies * num_conditions, num_classes]

# Create the submission DataFrame
submission = pd.DataFrame({
    'row_id': row_names,
    'normal_mild': predictions[:, 0],
    'moderate': predictions[:, 1],
    'severe': predictions[:, 2]
})

# Save to CSV
submission.to_csv('submission.csv', index=False)

print("Submission file 'submission.csv' has been generated.")

Submission file 'submission.csv' has been generated.


In [10]:
submission

Unnamed: 0,row_id,normal_mild,moderate,severe
0,44036939_spinal_canal_stenosis_l1_l2,0.940144,0.03757,0.022285
1,44036939_spinal_canal_stenosis_l2_l3,0.834173,0.115013,0.050813
2,44036939_spinal_canal_stenosis_l3_l4,0.658272,0.216996,0.124732
3,44036939_spinal_canal_stenosis_l4_l5,0.620966,0.192703,0.186331
4,44036939_spinal_canal_stenosis_l5_s1,0.934442,0.0425,0.023058
5,44036939_left_neural_foraminal_narrowing_l1_l2,0.949889,0.042989,0.007123
6,44036939_left_neural_foraminal_narrowing_l2_l3,0.875555,0.109058,0.015388
7,44036939_left_neural_foraminal_narrowing_l3_l4,0.619475,0.335656,0.044869
8,44036939_left_neural_foraminal_narrowing_l4_l5,0.499927,0.428991,0.071083
9,44036939_left_neural_foraminal_narrowing_l5_s1,0.445183,0.41856,0.136257


In [11]:
print(num_classes)

3
