In [100]:
import os
import glob
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms, models
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from sklearn.model_selection import KFold
from torchvision.models import resnet18, ResNet18_Weights

In [101]:
# Set the root directory for your Kaggle files
rd = './kaggle-files'

# Load the main CSV file
df = pd.read_csv(f'{rd}/train.csv')
df = df.fillna(-100)  # Use -100 to indicate missing labels

# Map the labels to integers for multi-class classification
label2id = {'Normal/Mild': 0, 'Moderate': 1, 'Severe': 2}
df.replace(label2id, inplace=True)

# Load the coordinates data
coordinates_df = pd.read_csv(f'{rd}/dfc_updated.csv')
# Keep only rows where 'slice_number' is not NaN
coordinates_df = coordinates_df.dropna(subset=['slice_number'])
coordinates_df['slice_number'] = coordinates_df['slice_number'].astype(int)

# Load the series descriptions
series_description_df = pd.read_csv(f'{rd}/train_series_descriptions.csv')
series_description_df['series_description'] = series_description_df['series_description'].str.replace('T2/STIR', 'T2_STIR')

# Define constants
SERIES_DESCRIPTIONS = ['Sagittal T1', 'Sagittal T2_STIR', 'Axial T2']

In [102]:
class LumbarSpineDataset(Dataset):
    def __init__(self, df, coordinates_df, series_description_df, root_dir, transform=None):
        self.df = df
        self.coordinates_df = coordinates_df
        self.series_description_df = series_description_df
        self.root_dir = root_dir  # The root directory where images are stored
        self.transform = transform

        # Get the list of study_ids
        self.study_ids = self.df['study_id'].unique()

        # List of label columns (assuming all columns except 'study_id' are labels)
        self.label_columns = [col for col in df.columns if col != 'study_id']

        # Prepare a mapping for images and annotations
        self.study_image_paths = self._prepare_image_paths()

        # Create a mapping from study_id to labels
        self.labels_dict = self._prepare_labels()

    def _prepare_image_paths(self):
        study_image_paths = {}
        for study_id in self.study_ids:
            study_image_paths[study_id] = {}
            for series_description in SERIES_DESCRIPTIONS:
                series_description_clean = series_description.replace('/', '_')
                image_dir = os.path.join(self.root_dir, 'cvt_png', str(study_id), series_description_clean)
                if os.path.exists(image_dir):
                    # Get all images in the directory
                    image_paths = sorted(glob.glob(os.path.join(image_dir, '*.png')))
                    study_image_paths[study_id][series_description] = image_paths
                else:
                    # Handle missing series
                    study_image_paths[study_id][series_description] = []
        return study_image_paths

    def _prepare_labels(self):
        labels_dict = {}
        for idx, row in self.df.iterrows():
            study_id = row['study_id']
            labels = []
            for col in self.label_columns:
                label = row[col]
                if pd.isnull(label) or label == -100:
                    label = -100  # Use -100 for missing labels (ignore_index)
                else:
                    label = int(label)
                labels.append(label)
            labels_dict[study_id] = labels
        return labels_dict

    def __len__(self):
        return len(self.study_ids)

    def __getitem__(self, idx):
        study_id = self.study_ids[idx]
        images = {}
        annotations = {}

        # Load images for each series description
        for series_description in SERIES_DESCRIPTIONS:
            image_paths = self.study_image_paths[study_id][series_description]
            series_images = []
            for img_path in image_paths:
                img = Image.open(img_path).convert('L')  # Convert to grayscale
                if self.transform:
                    img = self.transform(img)  # Shape: [1, H, W]
                    img = img.squeeze(0)  # Remove the channel dimension, resulting in [H, W]
                series_images.append(img)
            if series_images:
                series_tensor = torch.stack(series_images, dim=0)  # Shape: [num_slices, H, W]
            else:
                series_tensor = torch.zeros((1, 512, 512))  # Placeholder tensor
            images[series_description] = series_tensor  # Shape: [num_slices, H, W]

        # Get labels for the study_id
        labels = self.labels_dict[study_id]
        labels_tensor = torch.tensor(labels, dtype=torch.long)  # Use long dtype for CrossEntropyLoss

        # Generate attention masks, default to zeros if no annotations
        attention_masks = {}
        for series_description in SERIES_DESCRIPTIONS:
            series_tensor = images[series_description]
            num_slices = series_tensor.shape[0]
            masks = []
            for slice_idx in range(num_slices):
                image_shape = series_tensor[slice_idx].shape  # Get (H, W)
                mask = torch.zeros(image_shape, dtype=torch.float32)  # Default to zero mask
                # If annotations exist, generate the attention mask
                study_annotations = self.coordinates_df[self.coordinates_df['study_id'] == study_id]
                for _, row in study_annotations.iterrows():
                    if row['series_description'] == series_description:
                        x_pixel = int(row['x_scaled'] * image_shape[1])
                        y_pixel = int(row['y_scaled'] * image_shape[0])
                        sigma = 5  # Adjust sigma
                        y_grid, x_grid = torch.meshgrid(
                            torch.arange(image_shape[0], dtype=torch.float32),
                            torch.arange(image_shape[1], dtype=torch.float32),
                            indexing='ij'
                        )
                        gauss = torch.exp(-((x_grid - x_pixel) ** 2 + (y_grid - y_pixel) ** 2) / (2 * sigma ** 2))
                        mask = torch.maximum(mask, gauss)
                masks.append(mask)
            attention_masks[series_description] = torch.stack(masks, dim=0)  # Shape: [num_slices, H, W]

        sample = {
            'study_id': study_id,
            'images': images,
            'labels': labels_tensor,
            'attention_masks': attention_masks
        }

        return sample

In [103]:
# Define any transformations if needed
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Adjust mean and std if necessary
])

# Instantiate the dataset
train_dataset = LumbarSpineDataset(
    df=df,
    coordinates_df=coordinates_df,
    series_description_df=series_description_df,
    root_dir='./rsna_output',  # Adjust the path as needed
    transform=transform
)

# Create a DataLoader
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=1,  # Adjust batch size as needed
    shuffle=True,
    num_workers=4,  # Adjust based on your system
    pin_memory=True
)


In [104]:
# Define the ResNet feature extractor
class ResNetFeatureExtractor(nn.Module):
    def __init__(self, in_channels=10):
        super(ResNetFeatureExtractor, self).__init__()
        # Load ResNet18 with the new weights argument
        resnet = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

        # Modify the first convolutional layer to accept in_channels
        resnet.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # Extract layers up to layer4 (exclude avgpool and fc layers)
        self.features = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4
        )

    def forward(self, x):
        #print(f"Input shape to ResNetFeatureExtractor: {x.shape}")
        x = self.features(x)
        return x

In [105]:
# Define the main model
class MultiSeriesSpineModel(nn.Module):
    def __init__(self, num_conditions=25, num_classes=3):
        super(MultiSeriesSpineModel, self).__init__()
        self.num_conditions = num_conditions
        self.num_classes = num_classes

        # Feature extractors for each MRI series
        self.cnn_sagittal_t1 = ResNetFeatureExtractor(in_channels=10)
        self.cnn_sagittal_t2_stir = ResNetFeatureExtractor(in_channels=10)
        self.cnn_axial_t2 = ResNetFeatureExtractor(in_channels=10)

        # Define attention layers for each series
        self.attention_sagittal_t1 = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()
        )
        self.attention_sagittal_t2_stir = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()
        )
        self.attention_axial_t2 = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()
        )

        # Define the final classification layers
        combined_feature_size = 512 * 3  # Since we're concatenating features from three models

        self.fc1 = nn.Linear(combined_feature_size, 512)
        self.fc2 = nn.Linear(512, num_conditions * num_classes)  # Output layer

    def forward(self, sagittal_t1, sagittal_t2_stir, axial_t2):
        # The tensors are of shape [batch_size, in_channels, H, W]
        #print(f"sagittal_t1 shape before CNN: {sagittal_t1.shape}")
        #print(f"sagittal_t2_stir shape before CNN: {sagittal_t2_stir.shape}")
        #print(f"axial_t2 shape before CNN: {axial_t2.shape}")

        features_sagittal_t1 = self.cnn_sagittal_t1(sagittal_t1)  # Shape: [batch_size, 512, H, W]
        features_sagittal_t2_stir = self.cnn_sagittal_t2_stir(sagittal_t2_stir)
        features_axial_t2 = self.cnn_axial_t2(axial_t2)

        # Generate attention maps (learned by the model)
        attention_map_t1 = self.attention_sagittal_t1(features_sagittal_t1)  # Shape: [batch_size, 1, H, W]
        attention_map_t2_stir = self.attention_sagittal_t2_stir(features_sagittal_t2_stir)
        attention_map_axial = self.attention_axial_t2(features_axial_t2)

        # Apply attention
        attended_features_t1 = features_sagittal_t1 * attention_map_t1  # Element-wise multiplication
        attended_features_t2_stir = features_sagittal_t2_stir * attention_map_t2_stir
        attended_features_axial = features_axial_t2 * attention_map_axial

        # Global average pooling
        features_sagittal_t1 = F.adaptive_avg_pool2d(attended_features_t1, (1, 1)).view(attended_features_t1.size(0), -1)
        features_sagittal_t2_stir = F.adaptive_avg_pool2d(attended_features_t2_stir, (1, 1)).view(attended_features_t2_stir.size(0), -1)
        features_axial_t2 = F.adaptive_avg_pool2d(attended_features_axial, (1, 1)).view(attended_features_axial.size(0), -1)

        # Concatenate features
        combined_features = torch.cat([features_sagittal_t1, features_sagittal_t2_stir, features_axial_t2], dim=1)

        # Pass through final classification layers
        x = F.relu(self.fc1(combined_features))
        x = self.fc2(x)  # Shape: [batch_size, num_conditions * num_classes]
        x = x.view(-1, self.num_conditions, self.num_classes)  # Reshape to [batch_size, num_conditions, num_classes]

        return x, [attention_map_t1, attention_map_t2_stir, attention_map_axial]


In [106]:
def resample_slices(image_tensor, target_slices=10):
    """
    Resample the number of slices to match the target number of slices.
    Assumes image_tensor has shape [num_slices, H, W]
    """
    #print("Shape of image_tensor in resample_slices:", image_tensor.shape)
    current_slices = image_tensor.shape[0]

    if current_slices == target_slices:
        return image_tensor  # No need to resample

    # If more slices, downsample to the target number
    if current_slices > target_slices:
        indices = torch.linspace(0, current_slices - 1, target_slices).long()
        return image_tensor[indices]

    # If fewer slices, upsample by interpolation
    # Add channel dimension
    image_tensor = image_tensor.unsqueeze(0)  # Shape: [1, num_slices, H, W]
    image_tensor = image_tensor.unsqueeze(0)  # Shape: [1, 1, num_slices, H, W]
    image_tensor_resized = F.interpolate(
        image_tensor,
        size=(target_slices, image_tensor.shape[3], image_tensor.shape[4]),
        mode='trilinear',
        align_corners=False
    )
    image_tensor_resized = image_tensor_resized.squeeze(0).squeeze(0)  # Shape: [target_slices, H, W]
    return image_tensor_resized

In [107]:
# Early Stopping class
class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.counter = 0
        self.early_stop = False
        self.best_loss = float('inf')

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Save model when validation loss decreases.'''
        self.best_loss = val_loss
        torch.save(model.state_dict(), 'checkpoint.pth')


In [108]:
# Instantiate the model
num_conditions = len(train_dataset.label_columns)
num_classes = 3

# Move the model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Weight for the attention loss
lambda_attention = 0.1  # Adjust the weighting factor as needed

In [109]:
def train_k_fold(train_dataset, k_folds=5, num_epochs=10, model_save_path=False):
    kfold = KFold(n_splits=k_folds, shuffle=True)

    for fold, (train_ids, val_ids) in enumerate(kfold.split(train_dataset)):
        print(f'Fold {fold+1}/{k_folds}')

        # Create data loaders for this fold
        train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
        val_subsampler = torch.utils.data.SubsetRandomSampler(val_ids)

        train_loader = DataLoader(train_dataset, batch_size=1, sampler=train_subsampler, num_workers=4, pin_memory=True)
        val_loader = DataLoader(train_dataset, batch_size=1, sampler=val_subsampler, num_workers=4, pin_memory=True)

        # Initialize the model and move it to the correct device
        model = MultiSeriesSpineModel(num_conditions=len(train_dataset.label_columns), num_classes=3)
        model = model.to(device)

        if model_save_path:
            # Load the trained model's state_dict
            model.load_state_dict(torch.load(model_save_path, map_location=device))

        # Initialize optimizer and loss functions
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
        classification_criterion = nn.CrossEntropyLoss(ignore_index=-100)
        attention_criterion = nn.MSELoss()

        early_stopping = EarlyStopping(patience=3)

        for epoch in range(num_epochs):
            model.train()
            total_train_loss = 0.0
            total_val_loss = 0.0
    
            # Training loop
            for batch in tqdm(train_loader, desc=f"Fold {fold+1} Epoch {epoch+1}/{num_epochs}", unit="batch"):
                # Extract images, labels, and attention masks
                images = batch['images']
                labels = batch['labels'].to(device)  # Move labels to the same device as model
                attention_masks = batch['attention_masks']
    
                # Remove batch dimension before resampling
                sagittal_t1 = resample_slices(images['Sagittal T1'].squeeze(0), target_slices=10).to(device)
                sagittal_t2_stir = resample_slices(images['Sagittal T2_STIR'].squeeze(0), target_slices=10).to(device)
                axial_t2 = resample_slices(images['Axial T2'].squeeze(0), target_slices=10).to(device)
    
                # Stack slices into the channel dimension
                sagittal_t1 = sagittal_t1.unsqueeze(0)  # Add batch dimension: [1, num_slices, H, W]
                sagittal_t1 = sagittal_t1.reshape(1, -1, 512, 512)  # Now, [batch_size, channels, H, W]
    
                sagittal_t2_stir = sagittal_t2_stir.unsqueeze(0).reshape(1, -1, 512, 512)
                axial_t2 = axial_t2.unsqueeze(0).reshape(1, -1, 512, 512)

                # Move attention masks to the same device
                mask_t1 = attention_masks['Sagittal T1'].to(device)  # Shape: [batch_size, num_slices, H, W]
                mask_t2_stir = attention_masks['Sagittal T2_STIR'].to(device)
                mask_axial = attention_masks['Axial T2'].to(device)
                
                # Combine masks across slices (max over slices)
                gt_mask_t1 = torch.max(mask_t1, dim=1)[0].unsqueeze(1)  # Shape: [batch_size, 1, H, W]
                gt_mask_t2_stir = torch.max(mask_t2_stir, dim=1)[0].unsqueeze(1)
                gt_mask_axial = torch.max(mask_axial, dim=1)[0].unsqueeze(1)
                
                # Forward pass
                outputs, attention_maps = model(sagittal_t1, sagittal_t2_stir, axial_t2)
                
                # Reshape outputs and labels
                outputs = outputs.view(-1, 3)
                labels_tensor = labels.view(-1)
                
                # Compute classification loss
                classification_loss = classification_criterion(outputs, labels_tensor)
                
                # Compute attention loss
                attention_loss = 0.0
                for att_map, gt_mask in zip(attention_maps, [gt_mask_t1, gt_mask_t2_stir, gt_mask_axial]):
                    #print("Shape of att_map:", att_map.shape)
                    #print("Shape of gt_mask before interpolation:", gt_mask.shape)
                    # Resize ground truth mask to match attention map size
                    gt_mask_resized = F.interpolate(
                        gt_mask,
                        size=att_map.shape[-2:],  # Match the size of the attention map
                        mode='bilinear',
                        align_corners=False
                    )
                    # Compute attention loss
                    attention_loss += attention_criterion(att_map, gt_mask_resized)


                # Total loss
                total_loss = classification_loss + lambda_attention * attention_loss

                # Zero gradients
                optimizer.zero_grad()

                # Backpropagation and optimization
                total_loss.backward()
                optimizer.step()

                total_train_loss += total_loss.item()

            avg_train_loss = total_train_loss / len(train_loader)

            print(f"Fold {fold+1} Epoch [{epoch+1}/{num_epochs}] Avg Train Loss: {avg_train_loss:.4f}")

            # Early stopping
            early_stopping(avg_train_loss, model)
            if early_stopping.early_stop:
                print(f"Early stopping triggered for Fold {fold+1}!")
                break

        # Load the last best checkpoint for this fold
        model.load_state_dict(torch.load('checkpoint.pth'))
        print(f"Completed Fold {fold+1}/{k_folds}. Model saved.")

In [None]:

# Example usage
train_k_fold(train_dataset, k_folds=5, num_epochs=10, model_save_path=False)

Fold 1/5


Fold 1 Epoch 1/10: 100%|██████████| 1580/1580 [41:04<00:00,  1.56s/batch] 


Fold 1 Epoch [1/10] Avg Train Loss: 0.6742


Fold 1 Epoch 2/10: 100%|██████████| 1580/1580 [41:33<00:00,  1.58s/batch] 


Fold 1 Epoch [2/10] Avg Train Loss: 0.5741


Fold 1 Epoch 3/10: 100%|██████████| 1580/1580 [41:20<00:00,  1.57s/batch] 


Fold 1 Epoch [3/10] Avg Train Loss: 0.5270


Fold 1 Epoch 4/10: 100%|██████████| 1580/1580 [42:10<00:00,  1.60s/batch] 


Fold 1 Epoch [4/10] Avg Train Loss: 0.5001


Fold 1 Epoch 5/10:  85%|████████▍ | 1338/1580 [35:01<02:43,  1.48batch/s] 

In [None]:
# Save the trained model's state_dict
model_save_path = 'spine_model_with_attention.pth'
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")

In [2]:
# Set the root directory for your Kaggle files
rd = './kaggle-files'

# Load the main CSV file
df = pd.read_csv(f'{rd}/train.csv')
df = df.fillna(-100)  # Use -100 to indicate missing labels

# Map the labels to integers for multi-class classification
label2id = {'Normal/Mild': 0, 'Moderate': 1, 'Severe': 2}
df.replace(label2id, inplace=True)

# Load the coordinates data
coordinates_df = pd.read_csv(f'{rd}/dfc_updated.csv')
# Keep only rows where 'slice_number' is not NaN
coordinates_df = coordinates_df.dropna(subset=['slice_number'])
coordinates_df['slice_number'] = coordinates_df['slice_number'].astype(int)

# Load the series descriptions
series_description_df = pd.read_csv(f'{rd}/train_series_descriptions.csv')
series_description_df['series_description'] = series_description_df['series_description'].str.replace('T2/STIR', 'T2_STIR')

# Define constants
SERIES_DESCRIPTIONS = ['Sagittal T1', 'Sagittal T2_STIR', 'Axial T2']



In [3]:
print(df.columns)
print(coordinates_df.columns)
print(series_description_df.columns)

Index(['study_id', 'spinal_canal_stenosis_l1_l2',
       'spinal_canal_stenosis_l2_l3', 'spinal_canal_stenosis_l3_l4',
       'spinal_canal_stenosis_l4_l5', 'spinal_canal_stenosis_l5_s1',
       'left_neural_foraminal_narrowing_l1_l2',
       'left_neural_foraminal_narrowing_l2_l3',
       'left_neural_foraminal_narrowing_l3_l4',
       'left_neural_foraminal_narrowing_l4_l5',
       'left_neural_foraminal_narrowing_l5_s1',
       'right_neural_foraminal_narrowing_l1_l2',
       'right_neural_foraminal_narrowing_l2_l3',
       'right_neural_foraminal_narrowing_l3_l4',
       'right_neural_foraminal_narrowing_l4_l5',
       'right_neural_foraminal_narrowing_l5_s1',
       'left_subarticular_stenosis_l1_l2', 'left_subarticular_stenosis_l2_l3',
       'left_subarticular_stenosis_l3_l4', 'left_subarticular_stenosis_l4_l5',
       'left_subarticular_stenosis_l5_s1', 'right_subarticular_stenosis_l1_l2',
       'right_subarticular_stenosis_l2_l3',
       'right_subarticular_stenosis_l3_l4',
 

In [4]:
class LumbarSpineDataset(Dataset):
    def __init__(self, df, coordinates_df, series_description_df, root_dir, transform=None):
        self.df = df
        self.coordinates_df = coordinates_df
        self.series_description_df = series_description_df
        self.root_dir = root_dir  # The root directory where images are stored
        self.transform = transform

        # Get the list of study_ids
        self.study_ids = self.df['study_id'].unique()

        # List of label columns (assuming all columns except 'study_id' are labels)
        self.label_columns = [col for col in df.columns if col != 'study_id']

        # Prepare a mapping for images and annotations
        self.study_image_paths = self._prepare_image_paths()

        # Create a mapping from study_id to labels
        self.labels_dict = self._prepare_labels()

    def _prepare_image_paths(self):
        study_image_paths = {}
        for study_id in self.study_ids:
            study_image_paths[study_id] = {}
            for series_description in SERIES_DESCRIPTIONS:
                series_description_clean = series_description.replace('/', '_')
                image_dir = os.path.join(self.root_dir, 'cvt_png', str(study_id), series_description_clean)
                if os.path.exists(image_dir):
                    # Get all images in the directory
                    image_paths = sorted(glob.glob(os.path.join(image_dir, '*.png')))
                    study_image_paths[study_id][series_description] = image_paths
                else:
                    # Handle missing series
                    study_image_paths[study_id][series_description] = []
        return study_image_paths

    def _prepare_labels(self):
        labels_dict = {}
        for idx, row in self.df.iterrows():
            study_id = row['study_id']
            labels = []
            for col in self.label_columns:
                label = row[col]
                if pd.isnull(label) or label == -100:
                    label = -100  # Use -100 for missing labels (ignore_index)
                else:
                    label = int(label)
                labels.append(label)
            labels_dict[study_id] = labels
        return labels_dict

    def __len__(self):
        return len(self.study_ids)

    def __getitem__(self, idx):
        study_id = self.study_ids[idx]
        images = {}
        annotations = {}

        # Load images for each series description
        for series_description in SERIES_DESCRIPTIONS:
            image_paths = self.study_image_paths[study_id][series_description]
            series_images = []
            for img_path in image_paths:
                img = Image.open(img_path).convert('L')  # Convert to grayscale
                if self.transform:
                    img = self.transform(img)
                series_images.append(img)
            if series_images:
                series_tensor = torch.stack(series_images, dim=0)  # Shape: [num_slices, 1, H, W]
            else:
                series_tensor = torch.zeros((1, 1, 512, 512))  # Placeholder tensor
            images[series_description] = series_tensor

        # Get labels for the study_id
        labels = self.labels_dict[study_id]
        labels_tensor = torch.tensor(labels, dtype=torch.long)  # Use long dtype for CrossEntropyLoss

        # Generate attention masks, default to zeros if no annotations
        attention_masks = {}
        for series_description in SERIES_DESCRIPTIONS:
            series_tensor = images[series_description]
            num_slices = series_tensor.shape[0]
            masks = []
            for slice_idx in range(num_slices):
                image_shape = series_tensor[slice_idx].shape[-2:]  # Get (H, W)
                mask = torch.zeros(image_shape, dtype=torch.float32)  # Default to zero mask
                # If annotations exist, generate the attention mask
                study_annotations = self.coordinates_df[self.coordinates_df['study_id'] == study_id]
                for _, row in study_annotations.iterrows():
                    if row['series_description'] == series_description and int(row['slice_number']) == slice_idx:
                        x_pixel = int(row['x_scaled'] * image_shape[1])
                        y_pixel = int(row['y_scaled'] * image_shape[0])
                        sigma = 5  # Adjust sigma
                        y_grid, x_grid = torch.meshgrid(
                            torch.arange(image_shape[0], dtype=torch.float32),
                            torch.arange(image_shape[1], dtype=torch.float32),
                            indexing='ij'
                        )
                        gauss = torch.exp(-((x_grid - x_pixel) ** 2 + (y_grid - y_pixel) ** 2) / (2 * sigma ** 2))
                        mask = torch.maximum(mask, gauss)
                masks.append(mask)
            attention_masks[series_description] = torch.stack(masks, dim=0) # Shape: [num_slices, 1, H, W]

        sample = {
            'study_id': study_id,
            'images': images,
            'labels': labels_tensor,
            'attention_masks': attention_masks
        }

        return sample

In [5]:
# Define any transformations if needed
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Adjust mean and std if necessary
])

# Instantiate the dataset
train_dataset = LumbarSpineDataset(
    df=df,
    coordinates_df=coordinates_df,
    series_description_df=series_description_df,
    root_dir='./rsna_output',  # Adjust the path as needed
    transform=transform
)

# Create a DataLoader
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=1,  # Adjust batch size as needed
    shuffle=True,
    num_workers=4,  # Adjust based on your system
    pin_memory=True
)


In [79]:
# Define the ResNet feature extractor

class ResNetFeatureExtractor(nn.Module):
    def __init__(self, in_channels=10):
        super(ResNetFeatureExtractor, self).__init__()
        # Load ResNet18 with the new weights argument
        resnet = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

        # Modify the first convolutional layer to accept in_channels
        resnet.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # Extract layers up to layer4 (exclude avgpool and fc layers)
        self.features = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4
        )

    def forward(self, x):
        print(f"Input shape to ResNetFeatureExtractor: {x.shape}")
        x = self.features(x)
        return x


# Define the main model
class MultiSeriesSpineModel(nn.Module):
    def __init__(self, num_conditions=25, num_classes=3):
        super(MultiSeriesSpineModel, self).__init__()
        self.num_conditions = num_conditions
        self.num_classes = num_classes

        # Feature extractors for each MRI series
        self.cnn_sagittal_t1 = ResNetFeatureExtractor(in_channels=10)
        self.cnn_sagittal_t2_stir = ResNetFeatureExtractor(in_channels=10)
        self.cnn_axial_t2 = ResNetFeatureExtractor(in_channels=10)

        # Define attention layers for each series
        self.attention_sagittal_t1 = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()
        )
        self.attention_sagittal_t2_stir = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()
        )
        self.attention_axial_t2 = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()
        )

        # Define the final classification layers
        combined_feature_size = 512 * 3  # Since we're concatenating features from three models

        self.fc1 = nn.Linear(combined_feature_size, 512)
        self.fc2 = nn.Linear(512, num_conditions * num_classes)  # Output layer

    def forward(self, sagittal_t1, sagittal_t2_stir, axial_t2):
        # Forward pass through each ResNet18 model
        
        # Resample the slices if needed (assuming this has already been done)
        
        # Pool along the slice dimension to reduce [batch_size, slices, channels, H, W] to [batch_size, channels, H, W]
        sagittal_t1 = torch.mean(sagittal_t1, dim=1)  # Use mean pooling (or max pooling) across the slice dimension
        sagittal_t2_stir = torch.mean(sagittal_t2_stir, dim=1)
        axial_t2 = torch.mean(axial_t2, dim=1)
    
        # Now the tensors are of shape [batch_size, channels, height, width]
        print(f"sagittal_t1 shape before CNN: {sagittal_t1.shape}")
        features_sagittal_t1 = self.cnn_sagittal_t1(sagittal_t1) # Shape: [batch_size, 512, H, W]
        features_sagittal_t2_stir = self.cnn_sagittal_t2_stir(sagittal_t2_stir)
        features_axial_t2 = self.cnn_axial_t2(axial_t2)
    
        # Generate attention maps (learned by the model)
        attention_map_t1 = self.attention_sagittal_t1(features_sagittal_t1)  # Shape: [batch_size, 1, H, W]
        attention_map_t2_stir = self.attention_sagittal_t2_stir(features_sagittal_t2_stir)
        attention_map_axial = self.attention_axial_t2(features_axial_t2)
    
        # Apply attention
        attended_features_t1 = features_sagittal_t1 * attention_map_t1  # Element-wise multiplication
        attended_features_t2_stir = features_sagittal_t2_stir * attention_map_t2_stir
        attended_features_axial = features_axial_t2 * attention_map_axial
    
        # Global average pooling
        features_sagittal_t1 = F.adaptive_avg_pool2d(attended_features_t1, (1, 1)).view(attended_features_t1.size(0), -1)
        features_sagittal_t2_stir = F.adaptive_avg_pool2d(attended_features_t2_stir, (1, 1)).view(attended_features_t2_stir.size(0), -1)
        features_axial_t2 = F.adaptive_avg_pool2d(attended_features_axial, (1, 1)).view(attended_features_axial.size(0), -1)
    
        # Concatenate features
        combined_features = torch.cat([features_sagittal_t1, features_sagittal_t2_stir, features_axial_t2], dim=1)
    
        # Pass through final classification layers
        x = F.relu(self.fc1(combined_features))
        x = self.fc2(x)  # Shape: [batch_size, num_conditions * num_classes]
        x = x.view(-1, self.num_conditions, self.num_classes)  # Reshape to [batch_size, num_conditions, num_classes]
    
        return x, [attention_map_t1, attention_map_t2_stir, attention_map_axial]


In [80]:
def resample_slices(image_tensor, target_slices=10):
    """
    Resample the number of slices to match the target number of slices.
    """
    current_slices = image_tensor.shape[0]

    if current_slices == target_slices:
        return image_tensor  # No need to resample

    # If more slices, downsample to the target number
    if current_slices > target_slices:
        indices = torch.linspace(0, current_slices - 1, target_slices).long()
        return image_tensor[indices]

    # If fewer slices, upsample by interpolation
    image_tensor = image_tensor.permute(1, 0, 2, 3).unsqueeze(0)  # Shape: [1, channels, slices, H, W]
    image_tensor_resized = F.interpolate(
        image_tensor,
        size=(target_slices, image_tensor.shape[3], image_tensor.shape[4]),
        mode='trilinear',
        align_corners=False
    )
    return image_tensor_resized.squeeze(0).permute(1, 0, 2, 3)  # Shape: [slices, channels, H, W]


In [81]:
# Early Stopping class
class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.counter = 0
        self.early_stop = False
        self.best_loss = float('inf')

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Save model when validation loss decreases.'''
        self.best_loss = val_loss
        torch.save(model.state_dict(), 'checkpoint.pth')

In [82]:
# Instantiate the model
num_conditions = len(train_dataset.label_columns)
num_classes = 3

# Move the model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Weight for the attention loss
lambda_attention = 0.1  # Adjust the weighting factor as needed


In [83]:
def train_k_fold(train_dataset, k_folds=5, num_epochs=10, model_save_path=False):
    kfold = KFold(n_splits=k_folds, shuffle=True)

    for fold, (train_ids, val_ids) in enumerate(kfold.split(train_dataset)):
        print(f'Fold {fold+1}/{k_folds}')

        # Create data loaders for this fold
        train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
        val_subsampler = torch.utils.data.SubsetRandomSampler(val_ids)

        train_loader = DataLoader(train_dataset, batch_size=1, sampler=train_subsampler, num_workers=4, pin_memory=True)
        val_loader = DataLoader(train_dataset, batch_size=1, sampler=val_subsampler, num_workers=4, pin_memory=True)

        # Initialize the model and move it to the correct device
        model = MultiSeriesSpineModel(num_conditions=len(train_dataset.label_columns), num_classes=3)

        # Always move the model to the correct device
        model = model.to(device)

        if model_save_path:
            # Load the trained model's state_dict
            model.load_state_dict(torch.load(model_save_path, map_location=device))

        # Initialize optimizer and loss functions
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
        classification_criterion = nn.CrossEntropyLoss(ignore_index=-100)
        attention_criterion = nn.MSELoss()

        early_stopping = EarlyStopping(patience=3)

        for epoch in range(num_epochs):
            model.train()
            total_train_loss = 0.0
            total_val_loss = 0.0

            # Training loop
            for batch in tqdm(train_loader, desc=f"Fold {fold+1} Epoch {epoch+1}/{num_epochs}", unit="batch"):
                # Extract images, labels, and attention masks
                images = batch['images']
                labels = batch['labels'].to(device)  # Move labels to the same device as model
                attention_masks = batch['attention_masks']

                # Move all images to the same device
                sagittal_t1 = resample_slices(images['Sagittal T1'].squeeze(0), target_slices=10).unsqueeze(0).to(device)
                sagittal_t2_stir = resample_slices(images['Sagittal T2_STIR'].squeeze(0), target_slices=10).unsqueeze(0).to(device)
                axial_t2 = resample_slices(images['Axial T2'].squeeze(0), target_slices=10).unsqueeze(0).to(device)

                # Move attention masks to the same device
                mask_t1 = attention_masks['Sagittal T1'].to(device)  # Shape: [num_slices, H, W]
                mask_t2_stir = attention_masks['Sagittal T2_STIR'].to(device)
                mask_axial = attention_masks['Axial T2'].to(device)
                
                # Combine masks across slices (max over slices)
                gt_mask_t1 = torch.max(mask_t1, dim=0)[0].unsqueeze(0).unsqueeze(0)  # Shape: [1, 1, H, W]
                gt_mask_t2_stir = torch.max(mask_t2_stir, dim=0)[0].unsqueeze(0).unsqueeze(0)
                gt_mask_axial = torch.max(mask_axial, dim=0)[0].unsqueeze(0).unsqueeze(0)
                
                # Forward pass
                outputs, attention_maps = model(sagittal_t1, sagittal_t2_stir, axial_t2)
                
                # Reshape outputs and labels
                outputs = outputs.view(-1, 3)
                labels_tensor = labels.view(-1)
                
                # Compute classification loss
                classification_loss = classification_criterion(outputs, labels_tensor)
                
                # Compute attention loss
                attention_loss = 0.0
                for att_map, gt_mask in zip(attention_maps, [gt_mask_t1, gt_mask_t2_stir, gt_mask_axial]):
                    print("Shape of att_map:", att_map.shape)
                    print("Shape of gt_mask before interpolation:", gt_mask.shape)
                    # Resize ground truth mask to match attention map size
                    gt_mask_resized = F.interpolate(
                        gt_mask,
                        size=att_map.shape[-2:],  # Match the size of the attention map
                        mode='bilinear',
                        align_corners=False
                    ).to(device)
                
                    # Compute attention loss
                    attention_loss += attention_criterion(att_map, gt_mask_resized)


                # Total loss
                total_loss = classification_loss + 0.1 * attention_loss

                # Zero gradients
                optimizer.zero_grad()

                # Backpropagation and optimization
                total_loss.backward()
                optimizer.step()

                total_train_loss += total_loss.item()

            # Validation loop

            avg_train_loss = total_train_loss / len(train_loader)
            avg_val_loss = total_val_loss / len(val_loader)

            print(f"Fold {fold+1} Epoch [{epoch+1}/{num_epochs}] Avg Train Loss: {avg_train_loss:.4f}, Avg Val Loss: {avg_val_loss:.4f}")

            # Early stopping
            early_stopping(avg_val_loss, model)
            if early_stopping.early_stop:
                print(f"Early stopping triggered for Fold {fold+1}!")
                break

        # Load the last best checkpoint for this fold
        model.load_state_dict(torch.load('checkpoint.pth'))
        print(f"Completed Fold {fold+1}/{k_folds}. Model saved.")



In [84]:
# Example usage
train_k_fold(train_dataset, k_folds=5, num_epochs=10, model_save_path =  False)

Fold 1/5


Fold 1 Epoch 1/10:   0%|          | 0/1580 [00:00<?, ?batch/s]

sagittal_t1 shape before CNN: torch.Size([1, 1, 512, 512])
Input shape to ResNetFeatureExtractor: torch.Size([1, 1, 512, 512])


Fold 1 Epoch 1/10:   0%|          | 0/1580 [00:08<?, ?batch/s]


RuntimeError: Given groups=1, weight of size [64, 10, 7, 7], expected input[1, 1, 512, 512] to have 10 channels, but got 1 channels instead

In [26]:
# Save the trained model's state_dict
model_save_path = 'multi_series_spine_model_w_attentio_v4.pth'
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")


Model saved to multi_series_spine_model_w_attentio_v4.pth


In [None]:
# Training loop
num_epochs = 10
model.train()

for epoch in range(num_epochs):
    epoch_classification_loss = 0.0
    epoch_attention_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")

    for batch_idx, batch in enumerate(progress_bar):
        # Extract images, labels, and attention masks from the batch
        images = batch['images']
        labels = batch['labels']  # Tensor of shape [num_conditions]
        attention_masks = batch['attention_masks']

        # Get the image tensors
        sagittal_t1 = images['Sagittal T1'].squeeze(0)  # Shape: [num_slices, 1, H, W]
        sagittal_t2_stir = images['Sagittal T2_STIR'].squeeze(0)
        axial_t2 = images['Axial T2'].squeeze(0)

        # Resample slices to 10
        sagittal_t1 = resample_slices(sagittal_t1, target_slices=10)
        sagittal_t2_stir = resample_slices(sagittal_t2_stir, target_slices=10)
        axial_t2 = resample_slices(axial_t2, target_slices=10)

        # Remove singleton channel dimension if present
        sagittal_t1 = sagittal_t1.squeeze(1)  # Shape: [10, H, W]
        sagittal_t2_stir = sagittal_t2_stir.squeeze(1)
        axial_t2 = axial_t2.squeeze(1)

        # Add batch dimension and move to device
        sagittal_t1 = sagittal_t1.unsqueeze(0).to(device)  # Shape: [1, 10, H, W]
        sagittal_t2_stir = sagittal_t2_stir.unsqueeze(0).to(device)
        axial_t2 = axial_t2.unsqueeze(0).to(device)

        # Prepare labels tensor and move to device
        labels_tensor = labels.unsqueeze(0).to(device)  # Shape: [1, num_conditions]

        # Prepare attention masks and move to device
        mask_t1 = attention_masks['Sagittal T1'].unsqueeze(0).to(device)       # Shape: [1, num_slices, 1, H, W]
        mask_t2_stir = attention_masks['Sagittal T2_STIR'].unsqueeze(0).to(device)
        mask_axial = attention_masks['Axial T2'].unsqueeze(0).to(device)

        # Forward pass
        outputs, attention_maps = model(sagittal_t1, sagittal_t2_stir, axial_t2)

        # Reshape outputs and labels
        outputs = outputs.view(-1, num_classes)
        labels_tensor = labels_tensor.view(-1)

        # Compute classification loss
        classification_loss = classification_criterion(outputs, labels_tensor)

        # Compute attention loss
        attention_loss = 0.0
        for att_map, gt_mask in zip(attention_maps, [mask_t1, mask_t2_stir, mask_axial]):
            # Combine per-slice masks into a single mask
            gt_mask_combined = torch.max(gt_mask, dim=1)[0]  # Shape: [batch_size, 1, 1, H, W]
            gt_mask_combined = gt_mask_combined.squeeze(2)    # Now shape: [batch_size, 1, H, W]

            # Resize ground truth mask to match attention map size
            gt_mask_resized = F.interpolate(
                gt_mask_combined,
                size=att_map.shape[-2:],
                mode='bilinear',
                align_corners=False
            )

            # Compute attention loss
            attention_loss += attention_criterion(att_map, gt_mask_resized)

            # # Debugging: Print shapes on first iteration
            # if epoch == 0 and batch_idx == 0:
            #     print(f"gt_mask shape: {gt_mask.shape}")
            #     print(f"gt_mask_combined shape: {gt_mask_combined.shape}")
            #     print(f"gt_mask_resized shape: {gt_mask_resized.shape}")
            #     print(f"att_map shape: {att_map.shape}")

        # Total loss
        total_loss = classification_loss + lambda_attention * attention_loss

        # Zero gradients
        optimizer.zero_grad()

        # Backward pass
        total_loss.backward()

        # Optimizer step
        optimizer.step()

        # Update losses
        epoch_classification_loss += classification_loss.item()
        epoch_attention_loss += attention_loss.item()

        # Update progress bar
        progress_bar.set_postfix({
            'Cls Loss': f'{classification_loss.item():.4f}',
            'Att Loss': f'{attention_loss.item():.4f}'
        })

    # Epoch summary
    avg_classification_loss = epoch_classification_loss / len(train_loader)
    avg_attention_loss = epoch_attention_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}] Avg Cls Loss: {avg_classification_loss:.4f}, Avg Att Loss: {avg_attention_loss:.4f}")


In [None]:
# Save the trained model's state_dict
model_save_path = 'multi_series_spine_model_w_attentio_v2.pth'
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")


In [None]:
for epoch in range(num_epochs):
    epoch_classification_loss = 0.0
    epoch_attention_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")

    for batch_idx, batch in enumerate(progress_bar):
        # Extract images, labels, and attention masks from the batch
        images = batch['images']
        labels = batch['labels']  # Tensor of shape [num_conditions]
        attention_masks = batch['attention_masks']

        # Get the image tensors
        sagittal_t1 = images['Sagittal T1'].squeeze(0)  # Shape: [num_slices, 1, H, W]
        sagittal_t2_stir = images['Sagittal T2_STIR'].squeeze(0)
        axial_t2 = images['Axial T2'].squeeze(0)

        # Resample slices to 10
        sagittal_t1 = resample_slices(sagittal_t1, target_slices=10)
        sagittal_t2_stir = resample_slices(sagittal_t2_stir, target_slices=10)
        axial_t2 = resample_slices(axial_t2, target_slices=10)

        # Remove singleton channel dimension if present
        sagittal_t1 = sagittal_t1.squeeze(1)  # Shape: [10, H, W]
        sagittal_t2_stir = sagittal_t2_stir.squeeze(1)
        axial_t2 = axial_t2.squeeze(1)

        # Add batch dimension and move to device
        sagittal_t1 = sagittal_t1.unsqueeze(0).to(device)  # Shape: [1, 10, H, W]
        sagittal_t2_stir = sagittal_t2_stir.unsqueeze(0).to(device)
        axial_t2 = axial_t2.unsqueeze(0).to(device)

        # Prepare labels tensor and move to device
        labels_tensor = labels.unsqueeze(0).to(device)  # Shape: [1, num_conditions]

        # Prepare attention masks and move to device
        mask_t1 = attention_masks['Sagittal T1'].unsqueeze(0).to(device)       # Shape: [1, num_slices, 1, H, W]
        mask_t2_stir = attention_masks['Sagittal T2_STIR'].unsqueeze(0).to(device)
        mask_axial = attention_masks['Axial T2'].unsqueeze(0).to(device)

        # Forward pass
        outputs, attention_maps = model(sagittal_t1, sagittal_t2_stir, axial_t2)

        # Reshape outputs and labels
        outputs = outputs.view(-1, num_classes)
        labels_tensor = labels_tensor.view(-1)

        # Compute classification loss
        classification_loss = classification_criterion(outputs, labels_tensor)

        # Compute attention loss
        attention_loss = 0.0
        for att_map, gt_mask in zip(attention_maps, [mask_t1, mask_t2_stir, mask_axial]):
            # Combine per-slice masks into a single mask
            gt_mask_combined = torch.max(gt_mask, dim=1)[0]  # Shape: [batch_size, 1, 1, H, W]
            gt_mask_combined = gt_mask_combined.squeeze(2)    # Now shape: [batch_size, 1, H, W]

            # Resize ground truth mask to match attention map size
            gt_mask_resized = F.interpolate(
                gt_mask_combined,
                size=att_map.shape[-2:],
                mode='bilinear',
                align_corners=False
            )

            # Compute attention loss
            attention_loss += attention_criterion(att_map, gt_mask_resized)

            # # Debugging: Print shapes on first iteration
            # if epoch == 0 and batch_idx == 0:
            #     print(f"gt_mask shape: {gt_mask.shape}")
            #     print(f"gt_mask_combined shape: {gt_mask_combined.shape}")
            #     print(f"gt_mask_resized shape: {gt_mask_resized.shape}")
            #     print(f"att_map shape: {att_map.shape}")

        # Total loss
        total_loss = classification_loss + lambda_attention * attention_loss

        # Zero gradients
        optimizer.zero_grad()

        # Backward pass
        total_loss.backward()

        # Optimizer step
        optimizer.step()

        # Update losses
        epoch_classification_loss += classification_loss.item()
        epoch_attention_loss += attention_loss.item()

        # Update progress bar
        progress_bar.set_postfix({
            'Cls Loss': f'{classification_loss.item():.4f}',
            'Att Loss': f'{attention_loss.item():.4f}'
        })

    # Epoch summary
    avg_classification_loss = epoch_classification_loss / len(train_loader)
    avg_attention_loss = epoch_attention_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}] Avg Cls Loss: {avg_classification_loss:.4f}, Avg Att Loss: {avg_attention_loss:.4f}")


In [None]:
# Save the trained model's state_dict
model_save_path = 'multi_series_spine_model_w_attentio_40E.pth'
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")

In [None]:
for epoch in range(num_epochs):
    epoch_classification_loss = 0.0
    epoch_attention_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")

    for batch_idx, batch in enumerate(progress_bar):
        # Extract images, labels, and attention masks from the batch
        images = batch['images']
        labels = batch['labels']  # Tensor of shape [num_conditions]
        attention_masks = batch['attention_masks']

        # Get the image tensors
        sagittal_t1 = images['Sagittal T1'].squeeze(0)  # Shape: [num_slices, 1, H, W]
        sagittal_t2_stir = images['Sagittal T2_STIR'].squeeze(0)
        axial_t2 = images['Axial T2'].squeeze(0)

        # Resample slices to 10
        sagittal_t1 = resample_slices(sagittal_t1, target_slices=10)
        sagittal_t2_stir = resample_slices(sagittal_t2_stir, target_slices=10)
        axial_t2 = resample_slices(axial_t2, target_slices=10)

        # Remove singleton channel dimension if present
        sagittal_t1 = sagittal_t1.squeeze(1)  # Shape: [10, H, W]
        sagittal_t2_stir = sagittal_t2_stir.squeeze(1)
        axial_t2 = axial_t2.squeeze(1)

        # Add batch dimension and move to device
        sagittal_t1 = sagittal_t1.unsqueeze(0).to(device)  # Shape: [1, 10, H, W]
        sagittal_t2_stir = sagittal_t2_stir.unsqueeze(0).to(device)
        axial_t2 = axial_t2.unsqueeze(0).to(device)

        # Prepare labels tensor and move to device
        labels_tensor = labels.unsqueeze(0).to(device)  # Shape: [1, num_conditions]

        # Prepare attention masks and move to device
        mask_t1 = attention_masks['Sagittal T1'].unsqueeze(0).to(device)       # Shape: [1, num_slices, 1, H, W]
        mask_t2_stir = attention_masks['Sagittal T2_STIR'].unsqueeze(0).to(device)
        mask_axial = attention_masks['Axial T2'].unsqueeze(0).to(device)

        # Forward pass
        outputs, attention_maps = model(sagittal_t1, sagittal_t2_stir, axial_t2)

        # Reshape outputs and labels
        outputs = outputs.view(-1, num_classes)
        labels_tensor = labels_tensor.view(-1)

        # Compute classification loss
        classification_loss = classification_criterion(outputs, labels_tensor)

        # Compute attention loss
        attention_loss = 0.0
        for att_map, gt_mask in zip(attention_maps, [mask_t1, mask_t2_stir, mask_axial]):
            # Combine per-slice masks into a single mask
            gt_mask_combined = torch.max(gt_mask, dim=1)[0]  # Shape: [batch_size, 1, 1, H, W]
            gt_mask_combined = gt_mask_combined.squeeze(2)    # Now shape: [batch_size, 1, H, W]

            # Resize ground truth mask to match attention map size
            gt_mask_resized = F.interpolate(
                gt_mask_combined,
                size=att_map.shape[-2:],
                mode='bilinear',
                align_corners=False
            )

            # Compute attention loss
            attention_loss += attention_criterion(att_map, gt_mask_resized)

            # # Debugging: Print shapes on first iteration
            # if epoch == 0 and batch_idx == 0:
            #     print(f"gt_mask shape: {gt_mask.shape}")
            #     print(f"gt_mask_combined shape: {gt_mask_combined.shape}")
            #     print(f"gt_mask_resized shape: {gt_mask_resized.shape}")
            #     print(f"att_map shape: {att_map.shape}")

        # Total loss
        total_loss = classification_loss + lambda_attention * attention_loss

        # Zero gradients
        optimizer.zero_grad()

        # Backward pass
        total_loss.backward()

        # Optimizer step
        optimizer.step()

        # Update losses
        epoch_classification_loss += classification_loss.item()
        epoch_attention_loss += attention_loss.item()

        # Update progress bar
        progress_bar.set_postfix({
            'Cls Loss': f'{classification_loss.item():.4f}',
            'Att Loss': f'{attention_loss.item():.4f}'
        })

    # Epoch summary
    avg_classification_loss = epoch_classification_loss / len(train_loader)
    avg_attention_loss = epoch_attention_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}] Avg Cls Loss: {avg_classification_loss:.4f}, Avg Att Loss: {avg_attention_loss:.4f}")


In [None]:
# Save the trained model's state_dict
model_save_path = 'multi_series_spine_model_w_attentio_50E.pth'
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")

In [None]:
for epoch in range(num_epochs):
    epoch_classification_loss = 0.0
    epoch_attention_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")

    for batch_idx, batch in enumerate(progress_bar):
        # Extract images, labels, and attention masks from the batch
        images = batch['images']
        labels = batch['labels']  # Tensor of shape [num_conditions]
        attention_masks = batch['attention_masks']

        # Get the image tensors
        sagittal_t1 = images['Sagittal T1'].squeeze(0)  # Shape: [num_slices, 1, H, W]
        sagittal_t2_stir = images['Sagittal T2_STIR'].squeeze(0)
        axial_t2 = images['Axial T2'].squeeze(0)

        # Resample slices to 10
        sagittal_t1 = resample_slices(sagittal_t1, target_slices=10)
        sagittal_t2_stir = resample_slices(sagittal_t2_stir, target_slices=10)
        axial_t2 = resample_slices(axial_t2, target_slices=10)

        # Remove singleton channel dimension if present
        sagittal_t1 = sagittal_t1.squeeze(1)  # Shape: [10, H, W]
        sagittal_t2_stir = sagittal_t2_stir.squeeze(1)
        axial_t2 = axial_t2.squeeze(1)

        # Add batch dimension and move to device
        sagittal_t1 = sagittal_t1.unsqueeze(0).to(device)  # Shape: [1, 10, H, W]
        sagittal_t2_stir = sagittal_t2_stir.unsqueeze(0).to(device)
        axial_t2 = axial_t2.unsqueeze(0).to(device)

        # Prepare labels tensor and move to device
        labels_tensor = labels.unsqueeze(0).to(device)  # Shape: [1, num_conditions]

        # Prepare attention masks and move to device
        mask_t1 = attention_masks['Sagittal T1'].unsqueeze(0).to(device)       # Shape: [1, num_slices, 1, H, W]
        mask_t2_stir = attention_masks['Sagittal T2_STIR'].unsqueeze(0).to(device)
        mask_axial = attention_masks['Axial T2'].unsqueeze(0).to(device)

        # Forward pass
        outputs, attention_maps = model(sagittal_t1, sagittal_t2_stir, axial_t2)

        # Reshape outputs and labels
        outputs = outputs.view(-1, num_classes)
        labels_tensor = labels_tensor.view(-1)

        # Compute classification loss
        classification_loss = classification_criterion(outputs, labels_tensor)

        # Compute attention loss
        attention_loss = 0.0
        for att_map, gt_mask in zip(attention_maps, [mask_t1, mask_t2_stir, mask_axial]):
            # Combine per-slice masks into a single mask
            gt_mask_combined = torch.max(gt_mask, dim=1)[0]  # Shape: [batch_size, 1, 1, H, W]
            gt_mask_combined = gt_mask_combined.squeeze(2)    # Now shape: [batch_size, 1, H, W]

            # Resize ground truth mask to match attention map size
            gt_mask_resized = F.interpolate(
                gt_mask_combined,
                size=att_map.shape[-2:],
                mode='bilinear',
                align_corners=False
            )

            # Compute attention loss
            attention_loss += attention_criterion(att_map, gt_mask_resized)

            # # Debugging: Print shapes on first iteration
            # if epoch == 0 and batch_idx == 0:
            #     print(f"gt_mask shape: {gt_mask.shape}")
            #     print(f"gt_mask_combined shape: {gt_mask_combined.shape}")
            #     print(f"gt_mask_resized shape: {gt_mask_resized.shape}")
            #     print(f"att_map shape: {att_map.shape}")

        # Total loss
        total_loss = classification_loss + lambda_attention * attention_loss

        # Zero gradients
        optimizer.zero_grad()

        # Backward pass
        total_loss.backward()

        # Optimizer step
        optimizer.step()

        # Update losses
        epoch_classification_loss += classification_loss.item()
        epoch_attention_loss += attention_loss.item()

        # Update progress bar
        progress_bar.set_postfix({
            'Cls Loss': f'{classification_loss.item():.4f}',
            'Att Loss': f'{attention_loss.item():.4f}'
        })

    # Epoch summary
    avg_classification_loss = epoch_classification_loss / len(train_loader)
    avg_attention_loss = epoch_attention_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}] Avg Cls Loss: {avg_classification_loss:.4f}, Avg Att Loss: {avg_attention_loss:.4f}")


In [None]:
# Save the trained model's state_dict
model_save_path = 'multi_series_spine_model_w_attentio_60E.pth'
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")