In [1]:
#This is Good Practioce for the moment

!rm -rf /opt/conda/lib/python3.10/site-packages/fsspec*
!pip install fsspec==2024.6.0 --force-reinstall --no-deps
!pip install opencv-python

Collecting fsspec==2024.6.0
  Using cached fsspec-2024.6.0-py3-none-any.whl.metadata (11 kB)
Using cached fsspec-2024.6.0-py3-none-any.whl (176 kB)
Installing collected packages: fsspec
Successfully installed fsspec-2024.6.0
Collecting opencv-python
  Using cached opencv_python-4.10.0.84-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Using cached opencv_python-4.10.0.84-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (62.5 MB)
Installing collected packages: opencv-python
Successfully installed opencv-python-4.10.0.84


In [2]:
#install also to vizualize figures
!sudo apt-get update
!sudo apt-get install -y libgl1-mesa-glx
!sudo apt-get install -y libglib2.0-0

Get:1 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Hit:2 http://archive.ubuntu.com/ubuntu jammy InRelease                         
Get:3 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]        
Get:4 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease [1581 B]
Get:5 http://archive.ubuntu.com/ubuntu jammy-backports InRelease [127 kB]
Get:6 http://security.ubuntu.com/ubuntu jammy-security/main amd64 Packages [2267 kB]
Get:7 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  Packages [999 kB]
Get:8 http://archive.ubuntu.com/ubuntu jammy-updates/restricted amd64 Packages [3114 kB]
Get:9 http://security.ubuntu.com/ubuntu jammy-security/multiverse amd64 Packages [44.7 kB]
Get:10 http://security.ubuntu.com/ubuntu jammy-security/universe amd64 Packages [1150 kB]
Get:11 http://security.ubuntu.com/ubuntu jammy-security/restricted amd64 Packages [3030 kB]
Get:12 http://archive.ubuntu.com/ubuntu jamm

In [None]:
import os
import glob
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, f1_score
import numpy as np


In [4]:
# Set the root directory for your Kaggle files
rd = './kaggle-files'

# Load the main CSV file
df = pd.read_csv(f'{rd}/train.csv')
df = df.fillna(-100)  # Use -100 to indicate missing labels

# Map the labels to integers for multi-class classification
label2id = {'Normal/Mild': 0, 'Moderate': 1, 'Severe': 2}
df.replace(label2id, inplace=True)

# Load the coordinates data
coordinates_df = pd.read_csv(f'{rd}/dfc_updated.csv')
# Keep only rows where 'slice_number' is not NaN
coordinates_df = coordinates_df.dropna(subset=['slice_number'])
coordinates_df['slice_number'] = coordinates_df['slice_number'].astype(int)

# Load the series descriptions
series_description_df = pd.read_csv(f'{rd}/train_series_descriptions.csv')
series_description_df['series_description'] = series_description_df['series_description'].str.replace('T2/STIR', 'T2_STIR')

# Define constants
SERIES_DESCRIPTIONS = ['Sagittal T1', 'Sagittal T2_STIR', 'Axial T2']


In [5]:
class LumbarSpineDataset(Dataset):
    def __init__(self, df, coordinates_df, series_description_df, root_dir, transform=None):
        self.df = df
        self.coordinates_df = coordinates_df
        self.series_description_df = series_description_df
        self.root_dir = root_dir  # The root directory where images are stored
        self.transform = transform

        # Get the list of study_ids
        self.study_ids = self.df['study_id'].unique()

        # List of label columns (assuming all columns except 'study_id' are labels)
        self.label_columns = [col for col in df.columns if col != 'study_id']

        # Prepare a mapping for images and annotations
        self.study_image_paths = self._prepare_image_paths()

        # Create a mapping from study_id to labels
        self.labels_dict = self._prepare_labels()

    def _prepare_image_paths(self):
        study_image_paths = {}
        for study_id in self.study_ids:
            study_image_paths[study_id] = {}
            for series_description in SERIES_DESCRIPTIONS:
                series_description_clean = series_description.replace('/', '_')
                image_dir = os.path.join(self.root_dir, 'cvt_png', str(study_id), series_description_clean)
                if os.path.exists(image_dir):
                    # Get all images in the directory
                    image_paths = sorted(glob.glob(os.path.join(image_dir, '*.png')))
                    study_image_paths[study_id][series_description] = image_paths
                else:
                    # Handle missing series
                    study_image_paths[study_id][series_description] = []
        return study_image_paths

    def _prepare_labels(self):
        labels_dict = {}
        for idx, row in self.df.iterrows():
            study_id = row['study_id']
            labels = []
            for col in self.label_columns:
                label = row[col]
                if pd.isnull(label) or label == -100:
                    label = -100  # Use -100 for missing labels (ignore_index)
                else:
                    label = int(label)
                labels.append(label)
            labels_dict[study_id] = labels
        return labels_dict

    def __len__(self):
        return len(self.study_ids)

    def __getitem__(self, idx):
        study_id = self.study_ids[idx]
        images = {}
        annotations = {}

        # Load images for each series description
        for series_description in SERIES_DESCRIPTIONS:
            image_paths = self.study_image_paths[study_id][series_description]
            series_images = []
            for img_path in image_paths:
                img = Image.open(img_path).convert('L')  # Convert to grayscale
                if self.transform:
                    img = self.transform(img)
                series_images.append(img)
            if series_images:
                # Stack images along the depth dimension
                series_tensor = torch.stack(series_images, dim=0)  # Shape: [num_slices, 1, H, W]
            else:
                # Handle missing images
                series_tensor = torch.zeros((1, 1, 512, 512))  # Placeholder tensor
            images[series_description] = series_tensor

        # Get labels for the study_id
        labels = self.labels_dict[study_id]
        labels_tensor = torch.tensor(labels, dtype=torch.long)  # Use long dtype for CrossEntropyLoss

        # Get annotations for the study_id (if needed)
        study_annotations = self.coordinates_df[self.coordinates_df['study_id'] == study_id]
        for _, row in study_annotations.iterrows():
            condition = row['condition']
            level = row['level']
            x = row['x_scaled']
            y = row['y_scaled']
            series_description = row['series_description']
            slice_number = int(row['slice_number'])
            key = f"{condition}_{level}"
            if key not in annotations:
                annotations[key] = {}
            if series_description not in annotations[key]:
                annotations[key][series_description] = []
            annotations[key][series_description].append({
                'x': x,
                'y': y,
                'slice_number': slice_number
            })

        # Return a dictionary containing images, labels, and annotations
        sample = {
            'study_id': study_id,
            'images': images,
            'labels': labels_tensor,
            'annotations': annotations
        }

        return sample


In [39]:
# Define any transformations if needed
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Adjust mean and std if necessary
])

# Instantiate the dataset
train_dataset = LumbarSpineDataset(
    df=df,
    coordinates_df=coordinates_df,
    series_description_df=series_description_df,
    root_dir='./rsna_output',  # Adjust the path as needed
    transform=transform
)

# Create a DataLoader
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=1,  # Adjust batch size as needed
    shuffle=True,
    num_workers=4,  # Adjust based on your system
    pin_memory=True
)


In [40]:
# Define the ResNet feature extractor
class ResNetFeatureExtractor(nn.Module):
    def __init__(self, in_channels=10):
        super(ResNetFeatureExtractor, self).__init__()
        resnet = models.resnet18(pretrained=True)
        # Modify the first convolutional layer to accept in_channels
        resnet.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Extract layers up to layer4 (exclude avgpool and fc layers)
        self.features = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4
        )

    def forward(self, x):
        x = self.features(x)
        return x  # Output shape: [batch_size, 512, H, W]

# Define the main model
class MultiSeriesSpineModel(nn.Module):
    def __init__(self, num_conditions=25, num_classes=3):
        super(MultiSeriesSpineModel, self).__init__()

        # Feature extractors for each MRI series
        self.cnn_sagittal_t1 = ResNetFeatureExtractor(in_channels=10)
        self.cnn_sagittal_t2_stir = ResNetFeatureExtractor(in_channels=10)
        self.cnn_axial_t2 = ResNetFeatureExtractor(in_channels=10)

        # Define attention layers for each series
        self.attention_sagittal_t1 = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()
        )
        self.attention_sagittal_t2_stir = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()
        )
        self.attention_axial_t2 = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()
        )

        # Define the final classification layers
        combined_feature_size = 512 * 3  # Since we're concatenating features from three models

        self.fc1 = nn.Linear(combined_feature_size, 512)
        self.fc2 = nn.Linear(512, num_conditions * num_classes)  # Output layer

    def forward(self, sagittal_t1, sagittal_t2_stir, axial_t2):
        # Forward pass through each ResNet18 model
        features_sagittal_t1 = self.cnn_sagittal_t1(sagittal_t1)  # Shape: [batch_size, 512, H, W]
        features_sagittal_t2_stir = self.cnn_sagittal_t2_stir(sagittal_t2_stir)
        features_axial_t2 = self.cnn_axial_t2(axial_t2)

        # Generate attention maps
        attention_map_t1 = self.attention_sagittal_t1(features_sagittal_t1)  # Shape: [batch_size, 1, H, W]
        attention_map_t2_stir = self.attention_sagittal_t2_stir(features_sagittal_t2_stir)
        attention_map_axial = self.attention_axial_t2(features_axial_t2)

        # Apply attention
        features_sagittal_t1 = features_sagittal_t1 * attention_map_t1  # Element-wise multiplication
        features_sagittal_t2_stir = features_sagittal_t2_stir * attention_map_t2_stir
        features_axial_t2 = features_axial_t2 * attention_map_axial

        # Global average pooling
        features_sagittal_t1 = F.adaptive_avg_pool2d(features_sagittal_t1, (1, 1)).view(features_sagittal_t1.size(0), -1)
        features_sagittal_t2_stir = F.adaptive_avg_pool2d(features_sagittal_t2_stir, (1, 1)).view(features_sagittal_t2_stir.size(0), -1)
        features_axial_t2 = F.adaptive_avg_pool2d(features_axial_t2, (1, 1)).view(features_axial_t2.size(0), -1)

        # Concatenate features
        combined_features = torch.cat([features_sagittal_t1, features_sagittal_t2_stir, features_axial_t2], dim=1)

        # Pass through final classification layers
        x = F.relu(self.fc1(combined_features))
        x = self.fc2(x)  # Shape: [batch_size, num_conditions * num_classes]
        x = x.view(-1, num_conditions, num_classes)  # Reshape to [batch_size, num_conditions, num_classes]
        return x  # Return logits


In [41]:
# Resample slices function
def resample_slices(image_tensor, target_slices=10):
    """
    Resample the number of slices to match the target number of slices.
    """
    current_slices = image_tensor.shape[0]

    if current_slices == target_slices:
        return image_tensor  # No need to resample

    # If more slices, downsample to the target number
    if current_slices > target_slices:
        indices = torch.linspace(0, current_slices - 1, target_slices).long()
        return image_tensor[indices]

    # If fewer slices, upsample by interpolation
    image_tensor = image_tensor.permute(1, 0, 2, 3).unsqueeze(0)  # Shape: [1, channels, slices, H, W]
    image_tensor_resized = F.interpolate(
        image_tensor,
        size=(target_slices, image_tensor.shape[3], image_tensor.shape[4]),
        mode='trilinear',
        align_corners=False
    )
    return image_tensor_resized.squeeze(0).permute(1, 0, 2, 3)  # Shape: [slices, channels, H, W]


In [42]:
class EarlyStopping:
    def __init__(self, patience=5, verbose=False, delta=0.0, path='best_model.pt'):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
            verbose (bool): If True, prints a message for each validation loss improvement.
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
            path (str): Path to save the best model.
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path

        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.save_checkpoint(val_loss, model)
        elif val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.best_loss:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)


In [43]:
# Instantiate the model
num_conditions = len(train_dataset.label_columns)
num_classes = 3
model = MultiSeriesSpineModel(num_conditions=num_conditions, num_classes=num_classes)


# # Load the trained model's state_dict
# model_save_path = 'multi_series_spine_model.pth'
# model.load_state_dict(torch.load(model_save_path, map_location=device))

# Move the model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=-100)  # Use ignore_index to ignore missing labels
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)




In [44]:
from sklearn.model_selection import KFold

# Define transformations with data augmentation if desired
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Adjust mean and std if necessary
])

# Instantiate the full dataset
full_dataset = LumbarSpineDataset(
    df=df,
    coordinates_df=coordinates_df,
    series_description_df=series_description_df,
    root_dir='./rsna_output',  # Adjust the path as needed
    transform=transform
)

# Define number of conditions and classes
num_conditions = len(full_dataset.label_columns)  # 25 in your case
num_classes = 3  # 'Normal/Mild': 0, 'Moderate': 1, 'Severe': 2

# Initialize K-Fold
n_splits = 5
kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)


In [45]:
print('ok')

ok


In [None]:
# Training loop with Cross-Validation and Early Stopping
for fold, (train_idx, val_idx) in enumerate(kf.split(full_dataset)):
    print(f'\n=== Fold {fold + 1}/{n_splits} ===')

    # Create subsets for training and validation
    train_subset = Subset(full_dataset, train_idx)
    val_subset = Subset(full_dataset, val_idx)

    # Create DataLoaders with increased batch size
    train_loader = DataLoader(
        dataset=train_subset,
        batch_size=1,  # Increased from 1
        shuffle=True,
        num_workers=4,  # Adjust based on your system
        pin_memory=True
    )

    val_loader = DataLoader(
        dataset=val_subset,
        batch_size=1,  # Increased from 1
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )

    # Instantiate a new model for each fold
    model = MultiSeriesSpineModel(num_conditions=num_conditions, num_classes=num_classes)
    model = model.to(device)

    # Define the loss function and optimizer
    criterion = nn.CrossEntropyLoss(ignore_index=-100)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)  # Using AdamW with higher lr

    # Define a learning rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

    # Initialize EarlyStopping
    early_stopping = EarlyStopping(patience=5, verbose=True, path=f'best_model_fold_{fold + 1}.pt')

    # Define number of epochs
    num_epochs = 50  # Increased epochs to allow early stopping

    # Lists to store epoch-wise loss for plotting
    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Fold {fold + 1} Epoch {epoch+1}/{num_epochs}", unit="batch")

        for batch in progress_bar:
            # Extract images and labels from the batch
            images = batch['images']
            labels = batch['labels']  # Tensor of shape [num_conditions]

            # Get the image tensors
            sagittal_t1 = images['Sagittal T1'].squeeze(0)  # Shape: [num_slices, 1, H, W]
            sagittal_t2_stir = images['Sagittal T2_STIR'].squeeze(0)
            axial_t2 = images['Axial T2'].squeeze(0)

            # Resample slices to 10
            sagittal_t1 = resample_slices(sagittal_t1, target_slices=10)
            sagittal_t2_stir = resample_slices(sagittal_t2_stir, target_slices=10)
            axial_t2 = resample_slices(axial_t2, target_slices=10)

            # Remove singleton channel dimension if present
            sagittal_t1 = sagittal_t1.squeeze(1)  # Shape: [10, H, W]
            sagittal_t2_stir = sagittal_t2_stir.squeeze(1)
            axial_t2 = axial_t2.squeeze(1)

            # Add batch dimension and move to device
            sagittal_t1 = sagittal_t1.unsqueeze(0).to(device)  # Shape: [1, 10, H, W]
            sagittal_t2_stir = sagittal_t2_stir.unsqueeze(0).to(device)
            axial_t2 = axial_t2.unsqueeze(0).to(device)

            # Prepare labels tensor and move to device
            labels_tensor = labels.unsqueeze(0).to(device)  # Shape: [1, num_conditions]

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(sagittal_t1, sagittal_t2_stir, axial_t2)  # Shape: [1, num_conditions, num_classes]

            # Reshape outputs and labels
            outputs = outputs.view(-1, num_classes)       # Shape: [num_conditions, num_classes]
            labels_tensor = labels_tensor.view(-1)        # Shape: [num_conditions]

            # Compute loss
            total_loss = criterion(outputs, labels_tensor)

            # Backward pass
            total_loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # Optimizer step
            optimizer.step()

            # Update loss
            epoch_loss += total_loss.item()

            # Update progress bar
            progress_bar.set_postfix({'Loss': f'{total_loss.item():.4f}'})

        # Calculate average training loss
        avg_train_loss = epoch_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        print(f"Fold {fold + 1} Epoch {epoch+1}/{num_epochs} Training Loss: {avg_train_loss:.4f}")

        # Validation phase
        model.eval()
        val_loss = 0.0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in val_loader:
                images = batch['images']
                labels = batch['labels']

                sagittal_t1 = images['Sagittal T1'].squeeze(0)
                sagittal_t2_stir = images['Sagittal T2_STIR'].squeeze(0)
                axial_t2 = images['Axial T2'].squeeze(0)

                sagittal_t1 = resample_slices(sagittal_t1, target_slices=10)
                sagittal_t2_stir = resample_slices(sagittal_t2_stir, target_slices=10)
                axial_t2 = resample_slices(axial_t2, target_slices=10)

                sagittal_t1 = sagittal_t1.squeeze(1)
                sagittal_t2_stir = sagittal_t2_stir.squeeze(1)
                axial_t2 = axial_t2.squeeze(1)

                sagittal_t1 = sagittal_t1.unsqueeze(0).to(device)
                sagittal_t2_stir = sagittal_t2_stir.unsqueeze(0).to(device)
                axial_t2 = axial_t2.unsqueeze(0).to(device)

                labels_tensor = labels.unsqueeze(0).to(device)

                outputs = model(sagittal_t1, sagittal_t2_stir, axial_t2)
                outputs = outputs.view(-1, num_classes)
                labels_tensor = labels_tensor.view(-1)

                loss = criterion(outputs, labels_tensor)
                val_loss += loss.item()

                _, preds = torch.max(outputs, 1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels_tensor.cpu().numpy())

        avg_val_loss = val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        print(f"Fold {fold + 1} Epoch {epoch+1}/{num_epochs} Validation Loss: {avg_val_loss:.4f}")

        # Update scheduler
        scheduler.step(avg_val_loss)

        # Check early stopping
        early_stopping(avg_val_loss, model)

        if early_stopping.early_stop:
            print(f"Early stopping at epoch {epoch+1} for fold {fold + 1}")
            break

    # Load the best model for the current fold
    model.load_state_dict(torch.load(f'best_model_fold_{fold + 1}.pt'))

    # Evaluation on the validation set
    model.eval()
    fold_val_loss = 0.0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in val_loader:
            images = batch['images']
            labels = batch['labels']

            sagittal_t1 = images['Sagittal T1'].squeeze(0)
            sagittal_t2_stir = images['Sagittal T2_STIR'].squeeze(0)
            axial_t2 = images['Axial T2'].squeeze(0)

            sagittal_t1 = resample_slices(sagittal_t1, target_slices=10)
            sagittal_t2_stir = resample_slices(sagittal_t2_stir, target_slices=10)
            axial_t2 = resample_slices(axial_t2, target_slices=10)

            sagittal_t1 = sagittal_t1.squeeze(1)
            sagittal_t2_stir = sagittal_t2_stir.squeeze(1)
            axial_t2 = axial_t2.squeeze(1)

            sagittal_t1 = sagittal_t1.unsqueeze(0).to(device)
            sagittal_t2_stir = sagittal_t2_stir.unsqueeze(0).to(device)
            axial_t2 = axial_t2.unsqueeze(0).to(device)

            labels_tensor = labels.unsqueeze(0).to(device)

            outputs = model(sagittal_t1, sagittal_t2_stir, axial_t2)
            outputs = outputs.view(-1, num_classes)
            labels_tensor = labels_tensor.view(-1)

            loss = criterion(outputs, labels_tensor)
            fold_val_loss += loss.item()

            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels_tensor.cpu().numpy())

    avg_fold_val_loss = fold_val_loss / len(val_loader)
    print(f"Fold {fold + 1} Best Validation Loss: {avg_fold_val_loss:.4f}")

    # Calculate evaluation metrics
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='macro')
    print(f"Fold {fold + 1} Validation Accuracy: {accuracy:.4f}")
    print(f"Fold {fold + 1} Validation F1 Score: {f1:.4f}")

    # Store metrics
    fold_accuracies.append(accuracy)
    fold_f1_scores.append(f1)
    fold_val_losses.append(avg_fold_val_loss)

    # Plot training and validation loss for the current fold
    plt.figure(figsize=(10,5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'Fold {fold + 1} Loss Curve')
    plt.legend()
    plt.show()

    # Clean up for the next fold
    del model
    torch.cuda.empty_cache()

# After all folds
print("\n=== Cross-Validation Results ===")
print(f"Average Validation Loss: {np.mean(fold_val_losses):.4f} ± {np.std(fold_val_losses):.4f}")
print(f"Average Validation Accuracy: {np.mean(fold_accuracies):.4f} ± {np.std(fold_accuracies):.4f}")
print(f"Average Validation F1 Score: {np.mean(fold_f1_scores):.4f} ± {np.std(fold_f1_scores):.4f}")


=== Fold 1/5 ===


Fold 1 Epoch 1/50: 100%|██████████| 1580/1580 [05:17<00:00,  4.98batch/s, Loss=0.2833]

Fold 1 Epoch 1/50 Training Loss: 0.6250





Fold 1 Epoch 1/50 Validation Loss: 1.0048
Validation loss decreased (1.004845 --> 1.004845).  Saving model ...


Fold 1 Epoch 2/50: 100%|██████████| 1580/1580 [05:21<00:00,  4.92batch/s, Loss=0.3006]

Fold 1 Epoch 2/50 Training Loss: 0.5760





Fold 1 Epoch 2/50 Validation Loss: 0.9363
Validation loss decreased (0.936295 --> 0.936295).  Saving model ...


Fold 1 Epoch 3/50:   0%|          | 0/1580 [00:00<?, ?batch/s]

In [None]:
# Training loop
num_epochs = 10  # Define the number of epochs
model.train()

for epoch in range(num_epochs):
    epoch_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")

    for batch in progress_bar:
        # Extract images and labels from the batch
        images = batch['images']
        labels = batch['labels']  # Tensor of shape [num_conditions]

        # Get the image tensors
        sagittal_t1 = images['Sagittal T1'].squeeze(0)  # Shape: [num_slices, 1, H, W]
        sagittal_t2_stir = images['Sagittal T2_STIR'].squeeze(0)
        axial_t2 = images['Axial T2'].squeeze(0)

        # Resample slices to 10
        sagittal_t1 = resample_slices(sagittal_t1, target_slices=10)
        sagittal_t2_stir = resample_slices(sagittal_t2_stir, target_slices=10)
        axial_t2 = resample_slices(axial_t2, target_slices=10)

        # Remove singleton channel dimension if present
        sagittal_t1 = sagittal_t1.squeeze(1)  # Shape: [10, H, W]
        sagittal_t2_stir = sagittal_t2_stir.squeeze(1)
        axial_t2 = axial_t2.squeeze(1)

        # Add batch dimension and move to device
        sagittal_t1 = sagittal_t1.unsqueeze(0).to(device)  # Shape: [1, 10, H, W]
        sagittal_t2_stir = sagittal_t2_stir.unsqueeze(0).to(device)
        axial_t2 = axial_t2.unsqueeze(0).to(device)

        # Prepare labels tensor and move to device
        labels_tensor = labels.unsqueeze(0).to(device)  # Shape: [1, num_conditions]

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(sagittal_t1, sagittal_t2_stir, axial_t2)  # Shape: [1, num_conditions, num_classes]

        # Reshape outputs and labels
        outputs = outputs.view(-1, num_classes)       # Shape: [num_conditions, num_classes]
        labels_tensor = labels_tensor.view(-1)        # Shape: [num_conditions]

        # Compute loss
        total_loss = criterion(outputs, labels_tensor)

        # Backward pass
        total_loss.backward()

        # Optimizer step
        optimizer.step()

        # Update loss
        epoch_loss += total_loss.item()

        # Update progress bar
        progress_bar.set_postfix({'Loss': f'{total_loss.item():.4f}'})

    # Epoch summary
    avg_loss = epoch_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}] Average Loss: {avg_loss:.4f}")


Epoch 1/10:  38%|███▊      | 760/1975 [02:13<04:13,  4.80batch/s, Loss=0.3511]

In [None]:
# Save the trained model's state_dict
model_save_path = 'multi_series_spine_model.pth'
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")



### Objective:
We want to plot an annotated axial slice image from our dataset. The annotations come from the `coordinates_df`, which contains x, y coordinates and additional information about the study, including the `series_id`, `instance_number`, `condition`, and `level`. These annotations represent the specific slices and the associated condition-level severity we're trying to classify/estimate.

### Key Points:
1. **Data Sources**:
   - **`df`**: This contains the labels for `condition` and `level` across different spinal areas for each `study_id`.
   - **`coordinates_df`**: This contains the x, y coordinates, `series_id`, `instance_number`, `condition`, and `level` related to each `study_id`.
   - **`series_description_df`**: This maps the `series_id` to its respective `series_description` (e.g., 'Axial T2', 'Sagittal T1').

2. **Image Path Mapping**:
   - From `coordinates_df`, we need to extract the `study_id`, `series_id`, and `instance_number` to locate the corresponding axial image. 
   - The image path is generated using:
     ```python
     image_path = f'./rsna_output/cvt_png/{study_id}/{series_description}/{instance_number:03d}.png'
     ```
     where `series_description` is derived from the `series_id` using the `series_description_df`.

3. **DataLoader Responsibilities**:
   - The DataLoader needs to provide the required information (`study_id`, `series_id`, `instance_number`, `x`, `y`) to correctly map images and annotations.
   - For slices without annotations, the model should focus on 'no annotation' data.

### Process Flow:
1. **Fetch Image and Annotations**:
   - For each study (`study_id`), find the `x`, `y` coordinates from `coordinates_df`.
   - Get the corresponding `series_id` and map it to a `series_description` using `series_description_df`.
   - Locate the slice image using `series_description` and `instance_number`.

2. **Plotting**:
   - Display the axial slice image with a bounding box drawn around the `x`, `y` coordinates for the annotation.
   - Display the label for the corresponding `condition` and `level`.



In [None]:
# from torchviz import make_dot
# from PIL import Image
# import matplotlib.pyplot as plt

# # Create dummy data to simulate model input
# batch_size = 2
# dummy_sagittal_t1 = torch.randn(batch_size, 10, 512, 512)  # 10 slices for Sagittal T1
# dummy_sagittal_t2_stir = torch.randn(batch_size, 10, 512, 512)  # 10 slices for Sagittal T2/STIR
# dummy_axial_t2 = torch.randn(batch_size, 10, 512, 512)  # 10 slices for Axial T2

# # Pass through the model to get a forward pass
# condition_pred, coord_pred = model(dummy_sagittal_t1, dummy_sagittal_t2_stir, dummy_axial_t2)

# # Create the computational graph
# dot = make_dot((condition_pred, coord_pred), params=dict(model.named_parameters()))

# # Render to a file and display it
# dot.render("model_diagram", format="png")  # Save as PNG

# # Load and display the image
# img = Image.open("model_diagram.png")
# plt.figure(figsize=(10, 10))  # Increase the figure size for better clarity
# plt.imshow(img)
# plt.axis('off')  # Hide axes for clarity
# plt.show()

In [1]:
#This is Good Practioce for the moment

!rm -rf /opt/conda/lib/python3.10/site-packages/fsspec*
!pip install fsspec==2024.6.0 --force-reinstall --no-deps
!pip install opencv-python

Collecting fsspec==2024.6.0
  Using cached fsspec-2024.6.0-py3-none-any.whl.metadata (11 kB)
Using cached fsspec-2024.6.0-py3-none-any.whl (176 kB)
Installing collected packages: fsspec
Successfully installed fsspec-2024.6.0
Collecting opencv-python
  Using cached opencv_python-4.10.0.84-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Using cached opencv_python-4.10.0.84-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (62.5 MB)
Installing collected packages: opencv-python
Successfully installed opencv-python-4.10.0.84


In [2]:
#install also to vizualize figures
!sudo apt-get update
!sudo apt-get install -y libgl1-mesa-glx
!sudo apt-get install -y libglib2.0-0

Get:1 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Hit:2 http://archive.ubuntu.com/ubuntu jammy InRelease                         
Get:3 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]        
Get:4 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease [1581 B]
Get:5 http://archive.ubuntu.com/ubuntu jammy-backports InRelease [127 kB]
Get:6 http://security.ubuntu.com/ubuntu jammy-security/main amd64 Packages [2267 kB]
Get:7 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  Packages [999 kB]
Get:8 http://archive.ubuntu.com/ubuntu jammy-updates/restricted amd64 Packages [3114 kB]
Get:9 http://security.ubuntu.com/ubuntu jammy-security/multiverse amd64 Packages [44.7 kB]
Get:10 http://security.ubuntu.com/ubuntu jammy-security/universe amd64 Packages [1150 kB]
Get:11 http://security.ubuntu.com/ubuntu jammy-security/restricted amd64 Packages [3030 kB]
Get:12 http://archive.ubuntu.com/ubuntu jamm

In [None]:
import os
import glob
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, f1_score
import numpy as np


In [4]:
# Set the root directory for your Kaggle files
rd = './kaggle-files'

# Load the main CSV file
df = pd.read_csv(f'{rd}/train.csv')
df = df.fillna(-100)  # Use -100 to indicate missing labels

# Map the labels to integers for multi-class classification
label2id = {'Normal/Mild': 0, 'Moderate': 1, 'Severe': 2}
df.replace(label2id, inplace=True)

# Load the coordinates data
coordinates_df = pd.read_csv(f'{rd}/dfc_updated.csv')
# Keep only rows where 'slice_number' is not NaN
coordinates_df = coordinates_df.dropna(subset=['slice_number'])
coordinates_df['slice_number'] = coordinates_df['slice_number'].astype(int)

# Load the series descriptions
series_description_df = pd.read_csv(f'{rd}/train_series_descriptions.csv')
series_description_df['series_description'] = series_description_df['series_description'].str.replace('T2/STIR', 'T2_STIR')

# Define constants
SERIES_DESCRIPTIONS = ['Sagittal T1', 'Sagittal T2_STIR', 'Axial T2']


In [5]:
class LumbarSpineDataset(Dataset):
    def __init__(self, df, coordinates_df, series_description_df, root_dir, transform=None):
        self.df = df
        self.coordinates_df = coordinates_df
        self.series_description_df = series_description_df
        self.root_dir = root_dir  # The root directory where images are stored
        self.transform = transform

        # Get the list of study_ids
        self.study_ids = self.df['study_id'].unique()

        # List of label columns (assuming all columns except 'study_id' are labels)
        self.label_columns = [col for col in df.columns if col != 'study_id']

        # Prepare a mapping for images and annotations
        self.study_image_paths = self._prepare_image_paths()

        # Create a mapping from study_id to labels
        self.labels_dict = self._prepare_labels()

    def _prepare_image_paths(self):
        study_image_paths = {}
        for study_id in self.study_ids:
            study_image_paths[study_id] = {}
            for series_description in SERIES_DESCRIPTIONS:
                series_description_clean = series_description.replace('/', '_')
                image_dir = os.path.join(self.root_dir, 'cvt_png', str(study_id), series_description_clean)
                if os.path.exists(image_dir):
                    # Get all images in the directory
                    image_paths = sorted(glob.glob(os.path.join(image_dir, '*.png')))
                    study_image_paths[study_id][series_description] = image_paths
                else:
                    # Handle missing series
                    study_image_paths[study_id][series_description] = []
        return study_image_paths

    def _prepare_labels(self):
        labels_dict = {}
        for idx, row in self.df.iterrows():
            study_id = row['study_id']
            labels = []
            for col in self.label_columns:
                label = row[col]
                if pd.isnull(label) or label == -100:
                    label = -100  # Use -100 for missing labels (ignore_index)
                else:
                    label = int(label)
                labels.append(label)
            labels_dict[study_id] = labels
        return labels_dict

    def __len__(self):
        return len(self.study_ids)

    def __getitem__(self, idx):
        study_id = self.study_ids[idx]
        images = {}
        annotations = {}

        # Load images for each series description
        for series_description in SERIES_DESCRIPTIONS:
            image_paths = self.study_image_paths[study_id][series_description]
            series_images = []
            for img_path in image_paths:
                img = Image.open(img_path).convert('L')  # Convert to grayscale
                if self.transform:
                    img = self.transform(img)
                series_images.append(img)
            if series_images:
                # Stack images along the depth dimension
                series_tensor = torch.stack(series_images, dim=0)  # Shape: [num_slices, 1, H, W]
            else:
                # Handle missing images
                series_tensor = torch.zeros((1, 1, 512, 512))  # Placeholder tensor
            images[series_description] = series_tensor

        # Get labels for the study_id
        labels = self.labels_dict[study_id]
        labels_tensor = torch.tensor(labels, dtype=torch.long)  # Use long dtype for CrossEntropyLoss

        # Get annotations for the study_id (if needed)
        study_annotations = self.coordinates_df[self.coordinates_df['study_id'] == study_id]
        for _, row in study_annotations.iterrows():
            condition = row['condition']
            level = row['level']
            x = row['x_scaled']
            y = row['y_scaled']
            series_description = row['series_description']
            slice_number = int(row['slice_number'])
            key = f"{condition}_{level}"
            if key not in annotations:
                annotations[key] = {}
            if series_description not in annotations[key]:
                annotations[key][series_description] = []
            annotations[key][series_description].append({
                'x': x,
                'y': y,
                'slice_number': slice_number
            })

        # Return a dictionary containing images, labels, and annotations
        sample = {
            'study_id': study_id,
            'images': images,
            'labels': labels_tensor,
            'annotations': annotations
        }

        return sample


In [39]:
# Define any transformations if needed
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Adjust mean and std if necessary
])

# Instantiate the dataset
train_dataset = LumbarSpineDataset(
    df=df,
    coordinates_df=coordinates_df,
    series_description_df=series_description_df,
    root_dir='./rsna_output',  # Adjust the path as needed
    transform=transform
)

# Create a DataLoader
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=1,  # Adjust batch size as needed
    shuffle=True,
    num_workers=4,  # Adjust based on your system
    pin_memory=True
)


In [40]:
# Define the ResNet feature extractor
class ResNetFeatureExtractor(nn.Module):
    def __init__(self, in_channels=10):
        super(ResNetFeatureExtractor, self).__init__()
        resnet = models.resnet18(pretrained=True)
        # Modify the first convolutional layer to accept in_channels
        resnet.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Extract layers up to layer4 (exclude avgpool and fc layers)
        self.features = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4
        )

    def forward(self, x):
        x = self.features(x)
        return x  # Output shape: [batch_size, 512, H, W]

# Define the main model
class MultiSeriesSpineModel(nn.Module):
    def __init__(self, num_conditions=25, num_classes=3):
        super(MultiSeriesSpineModel, self).__init__()

        # Feature extractors for each MRI series
        self.cnn_sagittal_t1 = ResNetFeatureExtractor(in_channels=10)
        self.cnn_sagittal_t2_stir = ResNetFeatureExtractor(in_channels=10)
        self.cnn_axial_t2 = ResNetFeatureExtractor(in_channels=10)

        # Define attention layers for each series
        self.attention_sagittal_t1 = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()
        )
        self.attention_sagittal_t2_stir = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()
        )
        self.attention_axial_t2 = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()
        )

        # Define the final classification layers
        combined_feature_size = 512 * 3  # Since we're concatenating features from three models

        self.fc1 = nn.Linear(combined_feature_size, 512)
        self.fc2 = nn.Linear(512, num_conditions * num_classes)  # Output layer

    def forward(self, sagittal_t1, sagittal_t2_stir, axial_t2):
        # Forward pass through each ResNet18 model
        features_sagittal_t1 = self.cnn_sagittal_t1(sagittal_t1)  # Shape: [batch_size, 512, H, W]
        features_sagittal_t2_stir = self.cnn_sagittal_t2_stir(sagittal_t2_stir)
        features_axial_t2 = self.cnn_axial_t2(axial_t2)

        # Generate attention maps
        attention_map_t1 = self.attention_sagittal_t1(features_sagittal_t1)  # Shape: [batch_size, 1, H, W]
        attention_map_t2_stir = self.attention_sagittal_t2_stir(features_sagittal_t2_stir)
        attention_map_axial = self.attention_axial_t2(features_axial_t2)

        # Apply attention
        features_sagittal_t1 = features_sagittal_t1 * attention_map_t1  # Element-wise multiplication
        features_sagittal_t2_stir = features_sagittal_t2_stir * attention_map_t2_stir
        features_axial_t2 = features_axial_t2 * attention_map_axial

        # Global average pooling
        features_sagittal_t1 = F.adaptive_avg_pool2d(features_sagittal_t1, (1, 1)).view(features_sagittal_t1.size(0), -1)
        features_sagittal_t2_stir = F.adaptive_avg_pool2d(features_sagittal_t2_stir, (1, 1)).view(features_sagittal_t2_stir.size(0), -1)
        features_axial_t2 = F.adaptive_avg_pool2d(features_axial_t2, (1, 1)).view(features_axial_t2.size(0), -1)

        # Concatenate features
        combined_features = torch.cat([features_sagittal_t1, features_sagittal_t2_stir, features_axial_t2], dim=1)

        # Pass through final classification layers
        x = F.relu(self.fc1(combined_features))
        x = self.fc2(x)  # Shape: [batch_size, num_conditions * num_classes]
        x = x.view(-1, num_conditions, num_classes)  # Reshape to [batch_size, num_conditions, num_classes]
        return x  # Return logits


In [41]:
# Resample slices function
def resample_slices(image_tensor, target_slices=10):
    """
    Resample the number of slices to match the target number of slices.
    """
    current_slices = image_tensor.shape[0]

    if current_slices == target_slices:
        return image_tensor  # No need to resample

    # If more slices, downsample to the target number
    if current_slices > target_slices:
        indices = torch.linspace(0, current_slices - 1, target_slices).long()
        return image_tensor[indices]

    # If fewer slices, upsample by interpolation
    image_tensor = image_tensor.permute(1, 0, 2, 3).unsqueeze(0)  # Shape: [1, channels, slices, H, W]
    image_tensor_resized = F.interpolate(
        image_tensor,
        size=(target_slices, image_tensor.shape[3], image_tensor.shape[4]),
        mode='trilinear',
        align_corners=False
    )
    return image_tensor_resized.squeeze(0).permute(1, 0, 2, 3)  # Shape: [slices, channels, H, W]


In [42]:
class EarlyStopping:
    def __init__(self, patience=5, verbose=False, delta=0.0, path='best_model.pt'):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
            verbose (bool): If True, prints a message for each validation loss improvement.
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
            path (str): Path to save the best model.
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path

        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.save_checkpoint(val_loss, model)
        elif val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.best_loss:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)


In [43]:
# Instantiate the model
num_conditions = len(train_dataset.label_columns)
num_classes = 3
model = MultiSeriesSpineModel(num_conditions=num_conditions, num_classes=num_classes)


# # Load the trained model's state_dict
# model_save_path = 'multi_series_spine_model.pth'
# model.load_state_dict(torch.load(model_save_path, map_location=device))

# Move the model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=-100)  # Use ignore_index to ignore missing labels
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)




In [44]:
from sklearn.model_selection import KFold

# Define transformations with data augmentation if desired
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Adjust mean and std if necessary
])

# Instantiate the full dataset
full_dataset = LumbarSpineDataset(
    df=df,
    coordinates_df=coordinates_df,
    series_description_df=series_description_df,
    root_dir='./rsna_output',  # Adjust the path as needed
    transform=transform
)

# Define number of conditions and classes
num_conditions = len(full_dataset.label_columns)  # 25 in your case
num_classes = 3  # 'Normal/Mild': 0, 'Moderate': 1, 'Severe': 2

# Initialize K-Fold
n_splits = 5
kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)


In [45]:
print('ok')

ok


In [None]:
# Training loop with Cross-Validation and Early Stopping
for fold, (train_idx, val_idx) in enumerate(kf.split(full_dataset)):
    print(f'\n=== Fold {fold + 1}/{n_splits} ===')

    # Create subsets for training and validation
    train_subset = Subset(full_dataset, train_idx)
    val_subset = Subset(full_dataset, val_idx)

    # Create DataLoaders with increased batch size
    train_loader = DataLoader(
        dataset=train_subset,
        batch_size=1,  # Increased from 1
        shuffle=True,
        num_workers=4,  # Adjust based on your system
        pin_memory=True
    )

    val_loader = DataLoader(
        dataset=val_subset,
        batch_size=1,  # Increased from 1
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )

    # Instantiate a new model for each fold
    model = MultiSeriesSpineModel(num_conditions=num_conditions, num_classes=num_classes)
    model = model.to(device)

    # Define the loss function and optimizer
    criterion = nn.CrossEntropyLoss(ignore_index=-100)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)  # Using AdamW with higher lr

    # Define a learning rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

    # Initialize EarlyStopping
    early_stopping = EarlyStopping(patience=5, verbose=True, path=f'best_model_fold_{fold + 1}.pt')

    # Define number of epochs
    num_epochs = 50  # Increased epochs to allow early stopping

    # Lists to store epoch-wise loss for plotting
    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Fold {fold + 1} Epoch {epoch+1}/{num_epochs}", unit="batch")

        for batch in progress_bar:
            # Extract images and labels from the batch
            images = batch['images']
            labels = batch['labels']  # Tensor of shape [num_conditions]

            # Get the image tensors
            sagittal_t1 = images['Sagittal T1'].squeeze(0)  # Shape: [num_slices, 1, H, W]
            sagittal_t2_stir = images['Sagittal T2_STIR'].squeeze(0)
            axial_t2 = images['Axial T2'].squeeze(0)

            # Resample slices to 10
            sagittal_t1 = resample_slices(sagittal_t1, target_slices=10)
            sagittal_t2_stir = resample_slices(sagittal_t2_stir, target_slices=10)
            axial_t2 = resample_slices(axial_t2, target_slices=10)

            # Remove singleton channel dimension if present
            sagittal_t1 = sagittal_t1.squeeze(1)  # Shape: [10, H, W]
            sagittal_t2_stir = sagittal_t2_stir.squeeze(1)
            axial_t2 = axial_t2.squeeze(1)

            # Add batch dimension and move to device
            sagittal_t1 = sagittal_t1.unsqueeze(0).to(device)  # Shape: [1, 10, H, W]
            sagittal_t2_stir = sagittal_t2_stir.unsqueeze(0).to(device)
            axial_t2 = axial_t2.unsqueeze(0).to(device)

            # Prepare labels tensor and move to device
            labels_tensor = labels.unsqueeze(0).to(device)  # Shape: [1, num_conditions]

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(sagittal_t1, sagittal_t2_stir, axial_t2)  # Shape: [1, num_conditions, num_classes]

            # Reshape outputs and labels
            outputs = outputs.view(-1, num_classes)       # Shape: [num_conditions, num_classes]
            labels_tensor = labels_tensor.view(-1)        # Shape: [num_conditions]

            # Compute loss
            total_loss = criterion(outputs, labels_tensor)

            # Backward pass
            total_loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # Optimizer step
            optimizer.step()

            # Update loss
            epoch_loss += total_loss.item()

            # Update progress bar
            progress_bar.set_postfix({'Loss': f'{total_loss.item():.4f}'})

        # Calculate average training loss
        avg_train_loss = epoch_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        print(f"Fold {fold + 1} Epoch {epoch+1}/{num_epochs} Training Loss: {avg_train_loss:.4f}")

        # Validation phase
        model.eval()
        val_loss = 0.0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in val_loader:
                images = batch['images']
                labels = batch['labels']

                sagittal_t1 = images['Sagittal T1'].squeeze(0)
                sagittal_t2_stir = images['Sagittal T2_STIR'].squeeze(0)
                axial_t2 = images['Axial T2'].squeeze(0)

                sagittal_t1 = resample_slices(sagittal_t1, target_slices=10)
                sagittal_t2_stir = resample_slices(sagittal_t2_stir, target_slices=10)
                axial_t2 = resample_slices(axial_t2, target_slices=10)

                sagittal_t1 = sagittal_t1.squeeze(1)
                sagittal_t2_stir = sagittal_t2_stir.squeeze(1)
                axial_t2 = axial_t2.squeeze(1)

                sagittal_t1 = sagittal_t1.unsqueeze(0).to(device)
                sagittal_t2_stir = sagittal_t2_stir.unsqueeze(0).to(device)
                axial_t2 = axial_t2.unsqueeze(0).to(device)

                labels_tensor = labels.unsqueeze(0).to(device)

                outputs = model(sagittal_t1, sagittal_t2_stir, axial_t2)
                outputs = outputs.view(-1, num_classes)
                labels_tensor = labels_tensor.view(-1)

                loss = criterion(outputs, labels_tensor)
                val_loss += loss.item()

                _, preds = torch.max(outputs, 1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels_tensor.cpu().numpy())

        avg_val_loss = val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        print(f"Fold {fold + 1} Epoch {epoch+1}/{num_epochs} Validation Loss: {avg_val_loss:.4f}")

        # Update scheduler
        scheduler.step(avg_val_loss)

        # Check early stopping
        early_stopping(avg_val_loss, model)

        if early_stopping.early_stop:
            print(f"Early stopping at epoch {epoch+1} for fold {fold + 1}")
            break

    # Load the best model for the current fold
    model.load_state_dict(torch.load(f'best_model_fold_{fold + 1}.pt'))

    # Evaluation on the validation set
    model.eval()
    fold_val_loss = 0.0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in val_loader:
            images = batch['images']
            labels = batch['labels']

            sagittal_t1 = images['Sagittal T1'].squeeze(0)
            sagittal_t2_stir = images['Sagittal T2_STIR'].squeeze(0)
            axial_t2 = images['Axial T2'].squeeze(0)

            sagittal_t1 = resample_slices(sagittal_t1, target_slices=10)
            sagittal_t2_stir = resample_slices(sagittal_t2_stir, target_slices=10)
            axial_t2 = resample_slices(axial_t2, target_slices=10)

            sagittal_t1 = sagittal_t1.squeeze(1)
            sagittal_t2_stir = sagittal_t2_stir.squeeze(1)
            axial_t2 = axial_t2.squeeze(1)

            sagittal_t1 = sagittal_t1.unsqueeze(0).to(device)
            sagittal_t2_stir = sagittal_t2_stir.unsqueeze(0).to(device)
            axial_t2 = axial_t2.unsqueeze(0).to(device)

            labels_tensor = labels.unsqueeze(0).to(device)

            outputs = model(sagittal_t1, sagittal_t2_stir, axial_t2)
            outputs = outputs.view(-1, num_classes)
            labels_tensor = labels_tensor.view(-1)

            loss = criterion(outputs, labels_tensor)
            fold_val_loss += loss.item()

            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels_tensor.cpu().numpy())

    avg_fold_val_loss = fold_val_loss / len(val_loader)
    print(f"Fold {fold + 1} Best Validation Loss: {avg_fold_val_loss:.4f}")

    # Calculate evaluation metrics
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='macro')
    print(f"Fold {fold + 1} Validation Accuracy: {accuracy:.4f}")
    print(f"Fold {fold + 1} Validation F1 Score: {f1:.4f}")

    # Store metrics
    fold_accuracies.append(accuracy)
    fold_f1_scores.append(f1)
    fold_val_losses.append(avg_fold_val_loss)

    # Plot training and validation loss for the current fold
    plt.figure(figsize=(10,5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'Fold {fold + 1} Loss Curve')
    plt.legend()
    plt.show()

    # Clean up for the next fold
    del model
    torch.cuda.empty_cache()

# After all folds
print("\n=== Cross-Validation Results ===")
print(f"Average Validation Loss: {np.mean(fold_val_losses):.4f} ± {np.std(fold_val_losses):.4f}")
print(f"Average Validation Accuracy: {np.mean(fold_accuracies):.4f} ± {np.std(fold_accuracies):.4f}")
print(f"Average Validation F1 Score: {np.mean(fold_f1_scores):.4f} ± {np.std(fold_f1_scores):.4f}")


=== Fold 1/5 ===


Fold 1 Epoch 1/50: 100%|██████████| 1580/1580 [05:17<00:00,  4.98batch/s, Loss=0.2833]

Fold 1 Epoch 1/50 Training Loss: 0.6250





Fold 1 Epoch 1/50 Validation Loss: 1.0048
Validation loss decreased (1.004845 --> 1.004845).  Saving model ...


Fold 1 Epoch 2/50: 100%|██████████| 1580/1580 [05:21<00:00,  4.92batch/s, Loss=0.3006]

Fold 1 Epoch 2/50 Training Loss: 0.5760





Fold 1 Epoch 2/50 Validation Loss: 0.9363
Validation loss decreased (0.936295 --> 0.936295).  Saving model ...


Fold 1 Epoch 3/50:   0%|          | 0/1580 [00:00<?, ?batch/s]

In [47]:
# Save the trained model's state_dict
model_save_path = 'multi_series_spine_model.pth'
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")

Model saved to multi_series_spine_model.pth


In [None]:
# Training loop
num_epochs = 10  # Define the number of epochs
model.train()

for epoch in range(num_epochs):
    epoch_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")

    for batch in progress_bar:
        # Extract images and labels from the batch
        images = batch['images']
        labels = batch['labels']  # Tensor of shape [num_conditions]

        # Get the image tensors
        sagittal_t1 = images['Sagittal T1'].squeeze(0)  # Shape: [num_slices, 1, H, W]
        sagittal_t2_stir = images['Sagittal T2_STIR'].squeeze(0)
        axial_t2 = images['Axial T2'].squeeze(0)

        # Resample slices to 10
        sagittal_t1 = resample_slices(sagittal_t1, target_slices=10)
        sagittal_t2_stir = resample_slices(sagittal_t2_stir, target_slices=10)
        axial_t2 = resample_slices(axial_t2, target_slices=10)

        # Remove singleton channel dimension if present
        sagittal_t1 = sagittal_t1.squeeze(1)  # Shape: [10, H, W]
        sagittal_t2_stir = sagittal_t2_stir.squeeze(1)
        axial_t2 = axial_t2.squeeze(1)

        # Add batch dimension and move to device
        sagittal_t1 = sagittal_t1.unsqueeze(0).to(device)  # Shape: [1, 10, H, W]
        sagittal_t2_stir = sagittal_t2_stir.unsqueeze(0).to(device)
        axial_t2 = axial_t2.unsqueeze(0).to(device)

        # Prepare labels tensor and move to device
        labels_tensor = labels.unsqueeze(0).to(device)  # Shape: [1, num_conditions]

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(sagittal_t1, sagittal_t2_stir, axial_t2)  # Shape: [1, num_conditions, num_classes]

        # Reshape outputs and labels
        outputs = outputs.view(-1, num_classes)       # Shape: [num_conditions, num_classes]
        labels_tensor = labels_tensor.view(-1)        # Shape: [num_conditions]

        # Compute loss
        total_loss = criterion(outputs, labels_tensor)

        # Backward pass
        total_loss.backward()

        # Optimizer step
        optimizer.step()

        # Update loss
        epoch_loss += total_loss.item()

        # Update progress bar
        progress_bar.set_postfix({'Loss': f'{total_loss.item():.4f}'})

    # Epoch summary
    avg_loss = epoch_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}] Average Loss: {avg_loss:.4f}")


Epoch 1/10:  38%|███▊      | 760/1975 [02:13<04:13,  4.80batch/s, Loss=0.3511]

In [None]:
# Save the trained model's state_dict
model_save_path = 'multi_series_spine_model.pth'
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")



### Objective:
We want to plot an annotated axial slice image from our dataset. The annotations come from the `coordinates_df`, which contains x, y coordinates and additional information about the study, including the `series_id`, `instance_number`, `condition`, and `level`. These annotations represent the specific slices and the associated condition-level severity we're trying to classify/estimate.

### Key Points:
1. **Data Sources**:
   - **`df`**: This contains the labels for `condition` and `level` across different spinal areas for each `study_id`.
   - **`coordinates_df`**: This contains the x, y coordinates, `series_id`, `instance_number`, `condition`, and `level` related to each `study_id`.
   - **`series_description_df`**: This maps the `series_id` to its respective `series_description` (e.g., 'Axial T2', 'Sagittal T1').

2. **Image Path Mapping**:
   - From `coordinates_df`, we need to extract the `study_id`, `series_id`, and `instance_number` to locate the corresponding axial image. 
   - The image path is generated using:
     ```python
     image_path = f'./rsna_output/cvt_png/{study_id}/{series_description}/{instance_number:03d}.png'
     ```
     where `series_description` is derived from the `series_id` using the `series_description_df`.

3. **DataLoader Responsibilities**:
   - The DataLoader needs to provide the required information (`study_id`, `series_id`, `instance_number`, `x`, `y`) to correctly map images and annotations.
   - For slices without annotations, the model should focus on 'no annotation' data.

### Process Flow:
1. **Fetch Image and Annotations**:
   - For each study (`study_id`), find the `x`, `y` coordinates from `coordinates_df`.
   - Get the corresponding `series_id` and map it to a `series_description` using `series_description_df`.
   - Locate the slice image using `series_description` and `instance_number`.

2. **Plotting**:
   - Display the axial slice image with a bounding box drawn around the `x`, `y` coordinates for the annotation.
   - Display the label for the corresponding `condition` and `level`.



In [None]:
# from torchviz import make_dot
# from PIL import Image
# import matplotlib.pyplot as plt

# # Create dummy data to simulate model input
# batch_size = 2
# dummy_sagittal_t1 = torch.randn(batch_size, 10, 512, 512)  # 10 slices for Sagittal T1
# dummy_sagittal_t2_stir = torch.randn(batch_size, 10, 512, 512)  # 10 slices for Sagittal T2/STIR
# dummy_axial_t2 = torch.randn(batch_size, 10, 512, 512)  # 10 slices for Axial T2

# # Pass through the model to get a forward pass
# condition_pred, coord_pred = model(dummy_sagittal_t1, dummy_sagittal_t2_stir, dummy_axial_t2)

# # Create the computational graph
# dot = make_dot((condition_pred, coord_pred), params=dict(model.named_parameters()))

# # Render to a file and display it
# dot.render("model_diagram", format="png")  # Save as PNG

# # Load and display the image
# img = Image.open("model_diagram.png")
# plt.figure(figsize=(10, 10))  # Increase the figure size for better clarity
# plt.imshow(img)
# plt.axis('off')  # Hide axes for clarity
# plt.show()