In [1]:
import os
import glob
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms, models
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from sklearn.model_selection import KFold

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import models
from sklearn.model_selection import KFold
from tqdm import tqdm
import pandas as pd
import numpy as np
import traceback




2024-09-30 18:42:42.870386: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-09-30 18:42:42.884103: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-30 18:42:42.901847: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-30 18:42:42.907330: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-30 18:42:42.920099: I tensorflow/core/platform/cpu_feature_guar

In [2]:


# Define constants
SERIES_DESCRIPTIONS = ['Sagittal T1', 'Sagittal T2_STIR', 'Axial T2']
CONDITIONS = [
    'spinal_canal_stenosis', 
    'left_neural_foraminal_narrowing', 
    'right_neural_foraminal_narrowing',
    'left_subarticular_stenosis',
    'right_subarticular_stenosis'
]
LEVELS = [
    'l1_l2',
    'l2_l3',
    'l3_l4',
    'l4_l5',
    'l5_s1',
]
LABELS = [f'{condition}_{level}' for condition in CONDITIONS for level in LEVELS]

# Set the root directory for your Kaggle files
rd = './kaggle-files'

# Load the main CSV file
df = pd.read_csv(f'{rd}/train.csv')

df = df.fillna(-100)  # Use -100 to indicate missing labels

# Map the labels to integers for multi-class classification
label2id = {'Normal/Mild': 0, 'Moderate': 1, 'Severe': 2}
df.replace(label2id, inplace=True)

# Load the coordinates data
coordinates_df = pd.read_csv(f'{rd}/dfc_updated.csv')
# Keep only rows where 'slice_number' is not NaN
coordinates_df = coordinates_df.dropna(subset=['slice_number'])
coordinates_df['slice_number'] = coordinates_df['slice_number'].astype(int)

# Load the series descriptions
series_description_df = pd.read_csv(f'{rd}/train_series_descriptions.csv')
series_description_df['series_description'] = series_description_df['series_description'].str.replace('T2/STIR', 'T2_STIR')

  df.replace(label2id, inplace=True)


In [3]:
class LumbarSpineDataset(Dataset):
    def __init__(self, df, coordinates_df, series_description_df, root_dir, transform=None):
        self.df = df
        self.coordinates_df = coordinates_df
        self.series_description_df = series_description_df
        self.root_dir = root_dir  # The root directory where images are stored
        self.transform = transform

        # Get the list of study_ids
        self.study_ids = self.df['study_id'].unique()

        # List of label columns
        self.label_columns = [col for col in df.columns if col != 'study_id']

        # Prepare a mapping for images and annotations
        self.study_image_paths = self._prepare_image_paths()

        # Create a mapping from study_id to labels
        self.labels_dict = self._prepare_labels()

    def _prepare_image_paths(self):
        study_image_paths = {}
        for study_id in self.study_ids:
            study_image_paths[study_id] = {}
            for series_description in SERIES_DESCRIPTIONS:
                series_description_clean = series_description.replace('/', '_')
                image_dir = os.path.join(self.root_dir, 'cvt_png', str(study_id), series_description_clean)
                if os.path.exists(image_dir):
                    # Get all images in the directory
                    image_paths = sorted(glob.glob(os.path.join(image_dir, '*.png')))
                    study_image_paths[study_id][series_description] = image_paths
                else:
                    # Handle missing series
                    study_image_paths[study_id][series_description] = []
        return study_image_paths

    def _prepare_labels(self):
        labels_dict = {}
        for idx, row in self.df.iterrows():
            study_id = row['study_id']
            labels = []
            for col in self.label_columns:
                label = row[col]
                if pd.isnull(label) or label == -100:
                    label = -100  # Use -100 for missing labels (ignore_index)
                else:
                    label = int(label)
                labels.append(label)
            labels_dict[study_id] = labels
        return labels_dict

    def __len__(self):
        return len(self.study_ids)

    def __getitem__(self, idx):
        study_id = self.study_ids[idx]
        images = {}
        annotations = {}

        # Load images for each series description
        for series_description in SERIES_DESCRIPTIONS:
            image_paths = self.study_image_paths[study_id][series_description]
            series_images = []
            for img_path in image_paths:
                img = Image.open(img_path).convert('L')  # Convert to grayscale
                if self.transform:
                    img = self.transform(img)  # Shape: [1, H, W]
                    img = img.squeeze(0)  # Remove the channel dimension, resulting in [H, W]
                series_images.append(img)
            if series_images:
                series_tensor = torch.stack(series_images, dim=0)  # Shape: [num_slices, H, W]
            else:
                series_tensor = torch.zeros((1, 512, 512))  # Placeholder tensor
            images[series_description] = series_tensor  # Shape: [num_slices, H, W]

        # Get labels for the study_id
        labels = self.labels_dict[study_id]
        labels_tensor = torch.tensor(labels, dtype=torch.long)  # Use long dtype for CrossEntropyLoss

        # Generate attention masks, default to zeros if no annotations
        attention_masks = {}
        for series_description in SERIES_DESCRIPTIONS:
            series_tensor = images[series_description]
            num_slices = series_tensor.shape[0]
            masks = []
            for slice_idx in range(num_slices):
                image_shape = series_tensor[slice_idx].shape  # Get (H, W)
                mask = torch.zeros(image_shape, dtype=torch.float32)  # Default to zero mask
                # If annotations exist, generate the attention mask
                study_annotations = self.coordinates_df[self.coordinates_df['study_id'] == study_id]
                for _, row in study_annotations.iterrows():
                    if row['series_description'] == series_description:
                        x_pixel = int(row['x_scaled'] * image_shape[1])
                        y_pixel = int(row['y_scaled'] * image_shape[0])
                        sigma = 5  # Adjust sigma
                        y_grid, x_grid = torch.meshgrid(
                            torch.arange(image_shape[0], dtype=torch.float32),
                            torch.arange(image_shape[1], dtype=torch.float32),
                            indexing='ij'
                        )
                        gauss = torch.exp(-((x_grid - x_pixel) ** 2 + (y_grid - y_pixel) ** 2) / (2 * sigma ** 2))
                        mask = torch.maximum(mask, gauss)
                masks.append(mask)
            attention_masks[series_description] = torch.stack(masks, dim=0)  # Shape: [num_slices, H, W]

        sample = {
            'study_id': study_id,
            'images': images,
            'labels': labels_tensor,
            'attention_masks': attention_masks
        }

        return sample

In [4]:


# Define any transformations if needed
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Adjust mean and std if necessary
])

# Instantiate the dataset
train_dataset = LumbarSpineDataset(
    df=df,
    coordinates_df=coordinates_df,
    series_description_df=series_description_df,
    root_dir='./rsna_output',  # Adjust the path as needed
    transform=transform
)


In [5]:
def resample_slices(image_tensor, target_slices=10):
    # Ensure the image tensor has at least 3 dimensions
    if image_tensor.dim() == 2:
        image_tensor = image_tensor.unsqueeze(0)  # Add slice dimension
    current_slices = image_tensor.shape[0]
    if current_slices == target_slices:
        return image_tensor  # No need to resample
    if current_slices > target_slices:
        indices = torch.linspace(0, current_slices - 1, target_slices).long()
        return image_tensor[indices]
    # If fewer slices, upsample
    image_tensor = image_tensor.unsqueeze(0).unsqueeze(0)  # Shape: [1, 1, num_slices, H, W]
    image_tensor_resized = F.interpolate(
        image_tensor,
        size=(target_slices, image_tensor.shape[3], image_tensor.shape[4]),
        mode='trilinear',
        align_corners=False
    )
    image_tensor_resized = image_tensor_resized.squeeze(0).squeeze(0)  # Shape: [target_slices, H, W]
    return image_tensor_resized

# Early Stopping class
class EarlyStopping:
    def __init__(self, patience=5, delta=0, path='checkpoint.pth'):
        self.patience = patience
        self.delta = delta
        self.path = path
        self.best_score = None
        self.counter = 0
        self.early_stop = False
        self.best_loss = float('inf')

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Save model when validation loss decreases.'''
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            torch.save(model.state_dict(), self.path)
            print(f"Validation loss decreased ({self.best_loss:.4f}). Saving model to {self.path}")


    def save_checkpoint(self, val_loss, model):
        '''Save model when validation loss decreases.'''
        self.best_loss = val_loss
        torch.save(model.state_dict(), 'checkpoint.pth')

# Define the ResNet feature extractor
class ResNetFeatureExtractor(nn.Module):
    def __init__(self, in_channels=10):
        super(ResNetFeatureExtractor, self).__init__()
        resnet = models.resnet18(pretrained=True)

        # Modify the first convolutional layer to accept in_channels
        resnet.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # Extract layers up to layer4 (exclude avgpool and fc layers)
        self.features = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4
        )

    def forward(self, x):
        x = self.features(x)
        return x

# Define the main model
class MultiSeriesSpineModel(nn.Module):
    def __init__(self, num_conditions=25, num_classes=3):
        super(MultiSeriesSpineModel, self).__init__()
        self.num_conditions = num_conditions
        self.num_classes = num_classes

        # Feature extractors for each MRI series
        self.cnn_sagittal_t1 = ResNetFeatureExtractor(in_channels=10)
        self.cnn_sagittal_t2_stir = ResNetFeatureExtractor(in_channels=10)
        self.cnn_axial_t2 = ResNetFeatureExtractor(in_channels=10)

        # Define attention layers for each series
        self.attention_sagittal_t1 = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()
        )
        self.attention_sagittal_t2_stir = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()
        )
        self.attention_axial_t2 = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()
        )

        # Define the final classification layers
        combined_feature_size = 512 * 3  # Since we're concatenating features from three models

        self.fc1 = nn.Linear(combined_feature_size, 512)
        self.fc2 = nn.Linear(512, num_conditions * num_classes)  # Output layer

    def forward(self, sagittal_t1, sagittal_t2_stir, axial_t2):
        # The tensors are of shape [batch_size, in_channels, H, W]
        features_sagittal_t1 = self.cnn_sagittal_t1(sagittal_t1)  # Shape: [batch_size, 512, H, W]
        features_sagittal_t2_stir = self.cnn_sagittal_t2_stir(sagittal_t2_stir)
        features_axial_t2 = self.cnn_axial_t2(axial_t2)

        # Generate attention maps (learned by the model)
        attention_map_t1 = self.attention_sagittal_t1(features_sagittal_t1)  # Shape: [batch_size, 1, H, W]
        attention_map_t2_stir = self.attention_sagittal_t2_stir(features_sagittal_t2_stir)
        attention_map_axial = self.attention_axial_t2(features_axial_t2)

        # Apply attention
        attended_features_t1 = features_sagittal_t1 * attention_map_t1  # Element-wise multiplication
        attended_features_t2_stir = features_sagittal_t2_stir * attention_map_t2_stir
        attended_features_axial = features_axial_t2 * attention_map_axial

        # Global average pooling
        features_sagittal_t1 = F.adaptive_avg_pool2d(attended_features_t1, (1, 1)).view(attended_features_t1.size(0), -1)
        features_sagittal_t2_stir = F.adaptive_avg_pool2d(attended_features_t2_stir, (1, 1)).view(attended_features_t2_stir.size(0), -1)
        features_axial_t2 = F.adaptive_avg_pool2d(attended_features_axial, (1, 1)).view(attended_features_axial.size(0), -1)

        # Concatenate features
        combined_features = torch.cat([features_sagittal_t1, features_sagittal_t2_stir, features_axial_t2], dim=1)

        # Pass through final classification layers
        x = F.relu(self.fc1(combined_features))
        x = self.fc2(x)  # Shape: [batch_size, num_conditions * num_classes]
        x = x.view(-1, self.num_conditions, self.num_classes)  # Reshape to [batch_size, num_conditions, num_classes]

        return x, [attention_map_t1, attention_map_t2_stir, attention_map_axial]

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



In [6]:
import torch
import torch.nn.functional as F

def custom_loss(outputs, labels, any_severe_scalar=0.5):
    """
    Custom loss function that mirrors the Kaggle evaluation metric.

    Args:
        outputs (torch.Tensor): Model predictions of shape [batch_size, num_conditions * num_levels, num_classes=3].
        labels (torch.Tensor): Ground truth labels of shape [batch_size, num_conditions * num_levels] with values {0,1,2} or -100 for missing.
        any_severe_scalar (float): Scalar to weight the 'any_severe_spinal' loss component.

    Returns:
        tuple: (total_loss, per_condition_loss_dict)
    """
    # Define sample weights: 1 for normal/mild (0), 2 for moderate (1), 4 for severe (2)
    sample_weights = torch.ones_like(labels, dtype=torch.float, device=labels.device)
    sample_weights[labels == 0] = 1.0
    sample_weights[labels == 1] = 2.0
    sample_weights[labels == 2] = 4.0
    # Assign zero weight to ignored labels
    sample_weights[labels == -100] = 0.0

    # Flatten outputs and labels for easier processing
    batch_size, num_labels, num_classes = outputs.shape
    outputs_flat = outputs.view(-1, num_classes)  # [batch_size * num_conditions * num_levels, 3]
    labels_flat = labels.view(-1)                # [batch_size * num_conditions * num_levels]

    # Compute log probabilities
    log_probs = F.log_softmax(outputs_flat, dim=1)  # [N, 3]

    # Mask to filter out ignored labels
    valid_mask = labels_flat != -100
    labels_valid = labels_flat[valid_mask]          # [num_valid]
    log_probs_valid = log_probs[valid_mask, labels_valid.long()]  # [num_valid]

    # Get sample weights for valid labels
    weights_valid = sample_weights.view(-1)[valid_mask]  # [num_valid]

    # Compute weighted negative log likelihood for main labels
    loss_main = -(log_probs_valid * weights_valid).mean()

    # ----- Any Severe Spinal Prediction -----
    # Identify indices corresponding to 'spinal_canal_stenosis'
    spinal_condition = 'spinal_canal_stenosis'
    spinal_indices = [i for i, label in enumerate(LABELS) if label.startswith(spinal_condition)]

    if not spinal_indices:
        raise ValueError("No labels found for 'spinal_canal_stenosis' in LABELS.")

    # Extract outputs and labels for spinal condition
    outputs_spinal = outputs[:, spinal_indices, :]  # [batch_size, num_spinal, 3]
    labels_spinal = labels[:, spinal_indices]       # [batch_size, num_spinal]

    # Compute ground truth for 'any_severe_spinal': 1 if any spinal label is severe, else 0
    any_severe_label = (labels_spinal == 2).float().max(dim=1)[0]  # [batch_size]

    # Compute predicted probability for 'severe' class
    prob_severe = F.softmax(outputs_spinal, dim=2)[:, :, 2]  # [batch_size, num_spinal]

    # 'any_severe_pred' is the maximum probability of severe across spinal labels
    any_severe_pred = prob_severe.max(dim=1)[0]  # [batch_size]

    # Compute binary cross-entropy loss for 'any_severe_spinal'
    loss_any_severe = F.binary_cross_entropy(any_severe_pred, any_severe_label, reduction='mean')

    # ----- Per-Condition Losses -----
    # Initialize a dictionary to store per-condition losses
    per_condition_loss = {}

    for condition in ['spinal', 'foraminal', 'subarticular']:
        condition_indices = [i for i, label in enumerate(LABELS) if label.startswith(condition)]
        if not condition_indices:
            per_condition_loss[condition] = 0.0
            continue  # Skip if no labels for this condition

        outputs_condition = outputs[:, condition_indices, :]  # [batch_size, num_condition_labels, 3]
        labels_condition = labels[:, condition_indices]       # [batch_size, num_condition_labels]

        # Flatten for loss computation
        outputs_condition_flat = outputs_condition.view(-1, num_classes)
        labels_condition_flat = labels_condition.view(-1)     # [batch_size * num_condition_labels]

        # Compute log probabilities
        log_probs_condition = F.log_softmax(outputs_condition_flat, dim=1)  # [N, 3]

        # Mask to filter out ignored labels
        valid_mask_condition = labels_condition_flat != -100
        labels_valid_condition = labels_condition_flat[valid_mask_condition]          # [num_valid_condition]
        log_probs_valid_condition = log_probs_condition[valid_mask_condition, labels_valid_condition.long()]  # [num_valid_condition]

        # Get sample weights for valid labels
        condition_sample_weights = sample_weights.view(-1)[condition_indices]          # [num_condition_labels]
        weights_valid_condition = condition_sample_weights.view(-1)[valid_mask_condition]  # [num_valid_condition]

        # Compute weighted negative log likelihood
        if weights_valid_condition.numel() > 0:
            loss_condition = -(log_probs_valid_condition * weights_valid_condition).mean()
            per_condition_loss[condition] = loss_condition.item()
        else:
            per_condition_loss[condition] = 0.0

    # ----- Combine Losses -----
    total_loss = loss_main + any_severe_scalar * loss_any_severe
    per_condition_loss['total_loss'] = loss_main.item()
    per_condition_loss['any_severe_loss'] = loss_any_severe.item()

    return total_loss, per_condition_loss


In [7]:
def custom_collate_fn(batch):
    collated_batch = {}
    # Handle 'study_id' separately
    collated_batch['study_id'] = [item['study_id'] for item in batch]
    # Handle 'labels'
    labels_list = []
    for item in batch:
        labels = item['labels']
        if not isinstance(labels, torch.Tensor):
            labels = torch.tensor(labels, dtype=torch.long)
        if labels.dim() == 0:
            labels = labels.unsqueeze(0)
        labels_list.append(labels)
    collated_batch['labels'] = torch.stack(labels_list)
    # Handle 'images' and 'attention_masks'
    for key in ['images', 'attention_masks']:
        collated_batch[key] = {}
        sub_keys = batch[0][key].keys()
        for sub_key in sub_keys:
            items_list = []
            for item in batch:
                data = item[key][sub_key]
                if not isinstance(data, torch.Tensor):
                    data = torch.tensor(data)
                items_list.append(data)
            collated_batch[key][sub_key] = torch.stack(items_list)
    return collated_batch


In [8]:

# [Include your existing constants, Dataset class, model definitions, custom_loss, resample_slices, and EarlyStopping class here]

def train_one_epoch(model, device, train_loader, optimizer, any_severe_scalar, lambda_attention, writer, fold, epoch, scheduler=None):
    model.train()
    total_loss = 0.0
    train_losses = []
    per_condition_losses = {'spinal': 0.0, 'foraminal': 0.0, 'subarticular': 0.0}
    
    for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Fold {fold} Epoch {epoch}/Training", unit="batch")):
        try:
            # Extract images, labels, and attention masks
            images = batch['images']
            labels = batch['labels'].to(device)  # Shape: [batch_size, num_conditions]
            attention_masks = batch['attention_masks']

            # Process each series
            sagittal_t1 = images['Sagittal T1']  # Shape: [batch_size, num_slices, H, W]
            sagittal_t2_stir = images['Sagittal T2_STIR']
            axial_t2 = images['Axial T2']

            # Resample slices
            sagittal_t1 = [resample_slices(img.squeeze(0), target_slices=10) for img in sagittal_t1]
            sagittal_t2_stir = [resample_slices(img.squeeze(0), target_slices=10) for img in sagittal_t2_stir]
            axial_t2 = [resample_slices(img.squeeze(0), target_slices=10) for img in axial_t2]

            # Check shapes before stacking
            for img in sagittal_t1:
                if img.shape != (10, 512, 512):
                    print(f"Invalid image shape in sagittal_t1: {img.shape}")
                    raise ValueError("Invalid image shape in sagittal_t1")
            for img in sagittal_t2_stir:
                if img.shape != (10, 512, 512):
                    print(f"Invalid image shape in sagittal_t2_stir: {img.shape}")
                    raise ValueError("Invalid image shape in sagittal_t2_stir")
            for img in axial_t2:
                if img.shape != (10, 512, 512):
                    print(f"Invalid image shape in axial_t2: {img.shape}")
                    raise ValueError("Invalid image shape in axial_t2")

            # Stack slices into the channel dimension
            sagittal_t1 = torch.stack([img.reshape(-1, 512, 512) for img in sagittal_t1]).to(device)
            sagittal_t2_stir = torch.stack([img.reshape(-1, 512, 512) for img in sagittal_t2_stir]).to(device)
            axial_t2 = torch.stack([img.reshape(-1, 512, 512) for img in axial_t2]).to(device)

            # Move attention masks to the same device
            mask_t1 = attention_masks['Sagittal T1'].to(device)  # Shape: [batch_size, num_slices, H, W]
            mask_t2_stir = attention_masks['Sagittal T2_STIR'].to(device)
            mask_axial = attention_masks['Axial T2'].to(device)

            # Combine masks across slices (max over slices)
            gt_mask_t1 = torch.max(mask_t1, dim=1)[0].unsqueeze(1)  # Shape: [batch_size, 1, H, W]
            gt_mask_t2_stir = torch.max(mask_t2_stir, dim=1)[0].unsqueeze(1)
            gt_mask_axial = torch.max(mask_axial, dim=1)[0].unsqueeze(1)

            # Forward pass
            outputs, attention_maps = model(sagittal_t1, sagittal_t2_stir, axial_t2)

            # Compute the custom loss
            total_batch_loss, per_condition_loss = custom_loss(outputs, labels, any_severe_scalar=any_severe_scalar)

            # Compute attention loss
            attention_loss = 0.0
            attention_criterion = nn.MSELoss()
            for attention_map, gt_mask in zip(attention_maps, [gt_mask_t1, gt_mask_t2_stir, gt_mask_axial]):
                # Upsample the attention map to match the ground truth mask size
                attention_map_upsampled = F.interpolate(attention_map, size=gt_mask.shape[2:], mode='bilinear', align_corners=False)
                attention_loss += attention_criterion(attention_map_upsampled, gt_mask)

            # Total loss
            total_batch_loss = total_batch_loss + lambda_attention * attention_loss

            # Backpropagation and optimization
            optimizer.zero_grad()
            total_batch_loss.backward()
            optimizer.step()

            total_loss += total_batch_loss.item()
            train_losses.append(total_batch_loss.item())
            
            # Accumulate per-condition losses
            for condition in per_condition_loss:
                if condition in per_condition_losses:
                    per_condition_losses[condition] += per_condition_loss[condition]

            # Log batch loss to TensorBoard every 100 batches
            if (batch_idx + 1) % 100 == 0:
                writer.add_scalar(f'Fold{fold}/Train_Batch_Loss', total_batch_loss.item(), epoch * len(train_loader) + batch_idx)

        except Exception as e:
            print(f"Error processing batch during training: {e}")
            traceback.print_exc()
            continue  # Skip this batch

    avg_train_loss = total_loss / len(train_loader)
    writer.add_scalar(f'Fold{fold}/Train_Avg_Loss', avg_train_loss, epoch)
    
    # Log per-condition losses
    for condition, loss in per_condition_losses.items():
        avg_condition_loss = loss / len(train_loader)
        writer.add_scalar(f'Fold{fold}/Train_{condition}_Loss', avg_condition_loss, epoch)

    return avg_train_loss

def validate_one_epoch(model, device, val_loader, any_severe_scalar, lambda_attention, writer, fold, epoch):
    model.eval()
    total_loss = 0.0
    val_losses = []
    per_condition_losses = {'spinal': 0.0, 'foraminal': 0.0, 'subarticular': 0.0}
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(val_loader, desc=f"Fold {fold} Epoch {epoch}/Validation", unit="batch")):
            try:
                images = batch['images']
                labels = batch['labels'].to(device)
                attention_masks = batch['attention_masks']
        
                # Preprocess images
                sagittal_t1 = [resample_slices(img.squeeze(0), target_slices=10) for img in images['Sagittal T1']]
                sagittal_t2_stir = [resample_slices(img.squeeze(0), target_slices=10) for img in images['Sagittal T2_STIR']]
                axial_t2 = [resample_slices(img.squeeze(0), target_slices=10) for img in images['Axial T2']]

                # Check shapes before stacking
                for img in sagittal_t1:
                    if img.shape != (10, 512, 512):
                        print(f"Invalid image shape in sagittal_t1 (validation): {img.shape}")
                        raise ValueError("Invalid image shape in sagittal_t1 (validation)")
                for img in sagittal_t2_stir:
                    if img.shape != (10, 512, 512):
                        print(f"Invalid image shape in sagittal_t2_stir (validation): {img.shape}")
                        raise ValueError("Invalid image shape in sagittal_t2_stir (validation)")
                for img in axial_t2:
                    if img.shape != (10, 512, 512):
                        print(f"Invalid image shape in axial_t2 (validation): {img.shape}")
                        raise ValueError("Invalid image shape in axial_t2 (validation)")

                sagittal_t1 = torch.stack([img.reshape(-1, 512, 512) for img in sagittal_t1]).to(device)
                sagittal_t2_stir = torch.stack([img.reshape(-1, 512, 512) for img in sagittal_t2_stir]).to(device)
                axial_t2 = torch.stack([img.reshape(-1, 512, 512) for img in axial_t2]).to(device)
        
                mask_t1 = attention_masks['Sagittal T1'].to(device)  # Shape: [batch_size, num_slices, H, W]
                mask_t2_stir = attention_masks['Sagittal T2_STIR'].to(device)
                mask_axial = attention_masks['Axial T2'].to(device)
        
                # Combine masks across slices (max over slices)
                gt_mask_t1 = torch.max(mask_t1, dim=1)[0].unsqueeze(1)  # Shape: [batch_size, 1, H, W]
                gt_mask_t2_stir = torch.max(mask_t2_stir, dim=1)[0].unsqueeze(1)
                gt_mask_axial = torch.max(mask_axial, dim=1)[0].unsqueeze(1)
        
                # Forward pass
                outputs, attention_maps = model(sagittal_t1, sagittal_t2_stir, axial_t2)

                # Compute the custom loss
                total_batch_loss, per_condition_loss = custom_loss(outputs, labels, any_severe_scalar=any_severe_scalar)

                # Compute attention loss
                attention_loss = 0.0
                attention_criterion = nn.MSELoss()
                for attention_map, gt_mask in zip(attention_maps, [gt_mask_t1, gt_mask_t2_stir, gt_mask_axial]):
                    # Upsample the attention map to match the ground truth mask size
                    attention_map_upsampled = F.interpolate(
                        attention_map,
                        size=gt_mask.shape[2:],  # This will be [H, W]
                        mode='bilinear',
                        align_corners=False
                    )
                    attention_loss += attention_criterion(attention_map_upsampled, gt_mask)

                # Total loss
                total_batch_loss = total_batch_loss + lambda_attention * attention_loss

                total_loss += total_batch_loss.item()
                val_losses.append(total_batch_loss.item())
                
                # Accumulate per-condition losses
                for condition in per_condition_loss:
                    if condition in per_condition_losses:
                        per_condition_losses[condition] += per_condition_loss[condition]
                
                # Log batch validation loss to TensorBoard every 100 batches
                if (batch_idx + 1) % 100 == 0:
                    writer.add_scalar(f'Fold{fold}/Val_Batch_Loss', total_batch_loss.item(), epoch * len(val_loader) + batch_idx)

            except Exception as e:
                print(f"Error processing batch during validation: {e}")
                traceback.print_exc()
                continue  # Skip this batch

    avg_val_loss = total_loss / len(val_loader) if len(val_loader) > 0 else float('inf')
    writer.add_scalar(f'Fold{fold}/Val_Avg_Loss', avg_val_loss, epoch)
    
    # Log per-condition losses
    for condition, loss in per_condition_losses.items():
        avg_condition_loss = loss / len(val_loader)
        writer.add_scalar(f'Fold{fold}/Val_{condition}_Loss', avg_condition_loss, epoch)
    
    return avg_val_loss

def train_k_fold_with_custom_loss(
    train_dataset, 
    k_folds=5, 
    num_epochs=10, 
    any_severe_scalar=0.5, 
    lambda_attention=0.1,
    log_dir='./runs/spine_model_experiment'
):
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(log_dir=log_dir)
    
    kfold = KFold(n_splits=k_folds, shuffle=True)
    
    for fold, (train_ids, val_ids) in enumerate(kfold.split(train_dataset)):
        print(f'\nFold {fold+1}/{k_folds}')
        
        # Create data loaders for this fold with custom collate function
        train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
        val_subsampler = torch.utils.data.SubsetRandomSampler(val_ids)
        
        train_loader = DataLoader(
            train_dataset,
            batch_size=1,
            sampler=train_subsampler,
            num_workers=8,
            pin_memory=True,
            collate_fn=custom_collate_fn
        )
        
        val_loader = DataLoader(
            train_dataset,
            batch_size=1,
            sampler=val_subsampler,
            num_workers=8,
            pin_memory=True,
            collate_fn=custom_collate_fn
        )
        
        # Initialize the model and move it to the correct device
        model = MultiSeriesSpineModel(num_conditions=len(train_dataset.label_columns), num_classes=3)
        model = model.to(device)
        
        # Initialize optimizer, scheduler, and loss functions
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', 
                                                               factor=0.1, patience=2, verbose=True)
        
        # Initialize EarlyStopping with a unique path per fold
        checkpoint_path = f'best_model_fold_{fold+1}.pth'
        early_stopping = EarlyStopping(patience=3, path=checkpoint_path)
        
        for epoch in range(num_epochs):
            print(f'\nEpoch {epoch+1}/{num_epochs}')
            
            # Training step
            avg_train_loss = train_one_epoch(
                model, device, train_loader, optimizer, 
                any_severe_scalar, lambda_attention, 
                writer, fold+1, epoch+1, scheduler
            )
            
            # Validation step
            avg_val_loss = validate_one_epoch(
                model, device, val_loader, any_severe_scalar, 
                lambda_attention, writer, fold+1, epoch+1
            )
            
            print(f"Fold {fold+1} Epoch [{epoch+1}/{num_epochs}] Avg Train Loss: {avg_train_loss:.4f}, Avg Val Loss: {avg_val_loss:.4f}")
        
            # Early stopping based on validation loss
            early_stopping(avg_val_loss, model)
            if early_stopping.early_stop:
                print(f"Early stopping triggered for Fold {fold+1}!")
                break
        
        # Load the best model for this fold
        model.load_state_dict(torch.load(checkpoint_path))
        
        # Save the best model for this fold
        torch.save(model.state_dict(), f'best_model_fold_{fold+1}.pth')
        print(f"Best model for Fold {fold+1} saved to best_model_fold_{fold+1}.pth.")
    
    # Close the TensorBoard writer after training
    writer.close()




In [9]:
%load_ext tensorboard


In [10]:
from sagemaker.interactive_apps import tensorboard

region = "us-east-2"
app = tensorboard.TensorBoardApp(region)
print("Navigate to the following URL:")
print(
    app.get_app_url(
        #training_job_name="kaggle-rsna", # Optional. Specify the name of the job to track.
        open_in_default_web_browser=False           # Set to False to print the URL to terminal.
    )
)

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/sagemaker-user/.config/sagemaker/config.yaml
Navigate to the following URL:
https://us-east-2.console.aws.amazon.com/sagemaker/home?region=us-east-2#/tensor-board-landing


In [None]:
# Example usage
n_folds = 3
n_epochs = 10
train_k_fold_with_custom_loss(train_dataset, k_folds=n_folds, num_epochs=n_epochs)


Fold 1/3





Epoch 1/10


Fold 1 Epoch 1/Training:  22%|██▏       | 289/1316 [06:09<09:21,  1.83batch/s]  

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class EnsembleModel(nn.Module):
    def __init__(self, model_class, model_paths, device):
        super(EnsembleModel, self).__init__()
        self.models = nn.ModuleList()
        for path in model_paths:
            model = model_class()
            model.load_state_dict(torch.load(path, map_location=device))
            model.to(device)
            model.eval()  # Set model to evaluation mode
            self.models.append(model)
        self.device = device

    def forward(self, sagittal_t1, sagittal_t2_stir, axial_t2):
        outputs_list = []
        for model in self.models:
            outputs, _ = model(sagittal_t1, sagittal_t2_stir, axial_t2)
            outputs_list.append(outputs)
        # Stack outputs and take mean over the ensemble dimension
        outputs = torch.stack(outputs_list, dim=0)
        avg_outputs = torch.mean(outputs, dim=0)
        return avg_outputs


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
k_folds = n_folds
model_paths = [f'model_fold_{i+1}.pth' for i in range(k_folds)]

ensemble_model = EnsembleModel(
    model_class=lambda: MultiSeriesSpineModel(num_conditions=len(LABELS), num_classes=3),
    model_paths=model_paths,
    device=device
)

ensemble_model.to(device)
ensemble_model.eval()


In [None]:
torch.save(ensemble_model.state_dict(), f'ensemble_model_F{n_folds}_E{n_epochs}.pth')
