In [1]:
!pip install pydicom




In [2]:
import os
# Ask CUDA to launch kernels synchronously so that a device-side error is
# reported at the call that caused it. Set before torch is imported below.
# NOTE(review): debugging aid that typically slows GPU execution — confirm it
# should stay enabled for full training runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

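Expected layout of the converted PNG slices (one folder per study, one sub-folder per series, slices named by their integer index):

./rsna_output/cvt_png/
    {study_id}/
        {series_id}/
            1.png
            2.png
            ...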

In [3]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import KFold
from tqdm import tqdm
import copy
import random
from PIL import Image

# Seed every random number generator used in this notebook with one shared value
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Run on the GPU when PyTorch can see one, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Using device: {device}')

Using device: cuda


In [4]:
# Define constants
SERIES_DESCRIPTIONS = ['Sagittal T1', 'Sagittal T2_STIR', 'Axial T2']
CONDITIONS = [
    'spinal_canal_stenosis',
    'left_neural_foraminal_narrowing',
    'right_neural_foraminal_narrowing',
    'left_subarticular_stenosis',
    'right_subarticular_stenosis'
]
LEVELS = ['l1_l2', 'l2_l3', 'l3_l4', 'l4_l5', 'l5_s1']
# One label column per (condition, level) pair, condition-major
LABELS = [f'{condition}_{level}' for condition in CONDITIONS for level in LEVELS]

# Set the root directory for your Kaggle files
rd = './kaggle-files'

# Load the main CSV file
df = pd.read_csv(f'{rd}/train.csv')

missing_label_columns = [column for column in LABELS if column not in df.columns]
if missing_label_columns:
    raise KeyError(f"train.csv is missing label columns: {missing_label_columns}")

# Map the labels to integers for multi-class classification.
# Missing labels become -100, the value the loss is configured to ignore.
# Each column is mapped and cast explicitly: DataFrame.replace on object
# columns relies on pandas' implicit downcasting, which is deprecated, and
# would otherwise leave the labels as dtype=object.
label2id = {'Normal/Mild': 0, 'Moderate': 1, 'Severe': 2}
for column in LABELS:
    mapped = df[column].map(label2id)
    unexpected = df.loc[mapped.isna() & df[column].notna(), column].unique()
    if len(unexpected):
        raise ValueError(f"Unexpected label value(s) in column {column!r}: {list(unexpected)}")
    df[column] = mapped.fillna(-100).astype('int64')


# Load the coordinates data; rows without a slice number are dropped
coordinates_df = pd.read_csv(f'{rd}/dfc_updated.csv')
coordinates_df = coordinates_df.dropna(subset=['slice_number'])
coordinates_df['slice_number'] = coordinates_df['slice_number'].astype(int)

# Normalize series descriptions so they match SERIES_DESCRIPTIONS
coordinates_df['series_description'] = coordinates_df['series_description'].str.replace(
    'T2/STIR', 'T2_STIR', regex=False)

# Load the series descriptions
series_description_df = pd.read_csv(f'{rd}/train_series_descriptions.csv')
series_description_df['series_description'] = series_description_df['series_description'].str.replace(
    'T2/STIR', 'T2_STIR', regex=False)

# Verify loaded data
print("Main DataFrame Columns:", df.columns.tolist())
print("Coordinates DataFrame Columns:", coordinates_df.columns.tolist())
print("Series Description DataFrame Columns:", series_description_df.columns.tolist())

Main DataFrame Columns: ['study_id', 'spinal_canal_stenosis_l1_l2', 'spinal_canal_stenosis_l2_l3', 'spinal_canal_stenosis_l3_l4', 'spinal_canal_stenosis_l4_l5', 'spinal_canal_stenosis_l5_s1', 'left_neural_foraminal_narrowing_l1_l2', 'left_neural_foraminal_narrowing_l2_l3', 'left_neural_foraminal_narrowing_l3_l4', 'left_neural_foraminal_narrowing_l4_l5', 'left_neural_foraminal_narrowing_l5_s1', 'right_neural_foraminal_narrowing_l1_l2', 'right_neural_foraminal_narrowing_l2_l3', 'right_neural_foraminal_narrowing_l3_l4', 'right_neural_foraminal_narrowing_l4_l5', 'right_neural_foraminal_narrowing_l5_s1', 'left_subarticular_stenosis_l1_l2', 'left_subarticular_stenosis_l2_l3', 'left_subarticular_stenosis_l3_l4', 'left_subarticular_stenosis_l4_l5', 'left_subarticular_stenosis_l5_s1', 'right_subarticular_stenosis_l1_l2', 'right_subarticular_stenosis_l2_l3', 'right_subarticular_stenosis_l3_l4', 'right_subarticular_stenosis_l4_l5', 'right_subarticular_stenosis_l5_s1']
Coordinates DataFrame Column

  df.replace(label2id, inplace=True)


In [5]:
class Normalize3D(object):
    """Per-channel `(x - mean) / std` for tensors laid out as [C, D, H, W]."""

    def __init__(self, mean, std):
        """
        Args:
            mean (list or tuple): Value subtracted from each channel.
            std (list or tuple): Value each channel is divided by.
        """
        # Store both statistics as [C, 1, 1, 1] so they broadcast over D, H and W.
        broadcast_shape = (-1, 1, 1, 1)
        self.mean = torch.tensor(mean).view(*broadcast_shape)
        self.std = torch.tensor(std).view(*broadcast_shape)

    def __call__(self, tensor):
        """
        Args:
            tensor (torch.Tensor): Volume of size [C, D, H, W].

        Returns:
            torch.Tensor: The centred and scaled volume.
        """
        centred = tensor - self.mean
        return centred / self.std

In [6]:
class LumbarSpine3DDataset(Dataset):
    """
    Study-level dataset.

    Each item is a dict with:
        'scan'   -- float tensor [len(SERIES_DESCRIPTIONS), target_slices, H, W],
                    one channel per series description (all zeros when the
                    series or its PNG folder is missing),
        'labels' -- long tensor with one entry per name in LABELS,
        'coords' -- float tensor with an (x, y) pair per condition and level.
    """

    def __init__(self, df, coordinates_df, series_description_df, root_dir, transform=None,
                 target_slices=10, image_size=(512, 512)):
        """
        Args:
            df (DataFrame): DataFrame containing labels.
            coordinates_df (DataFrame): DataFrame containing coordinates.
            series_description_df (DataFrame): DataFrame containing series descriptions.
            root_dir (str): Root directory for MRI scans.
            transform (callable, optional): Optional transform to be applied on a sample.
            target_slices (int): Number of slices to resample or pad to.
            image_size (tuple): (height, width) every slice is resized to.
        """
        self.df = df
        self.coordinates_df = coordinates_df
        self.series_description_df = series_description_df
        self.root_dir = root_dir
        self.transform = transform
        self.target_slices = target_slices
        self.image_size = tuple(image_size)
        self.labels = LABELS

        # Generate a list of study IDs
        self.study_ids = self.df['study_id'].unique()

    def __len__(self):
        return len(self.study_ids)

    def __getitem__(self, idx):
        study_id = self.study_ids[idx]
        study_data = self.df[self.df['study_id'] == study_id]
        labels = study_data[self.labels].values.flatten()

        # Cast explicitly: if the label columns are stored as dtype=object,
        # torch.tensor cannot convert the array directly.
        labels = torch.tensor(np.asarray(labels, dtype=np.int64), dtype=torch.long)

        # Load scans for each series description
        scans = []
        for series in SERIES_DESCRIPTIONS:
            series_id = self.get_series_id(study_id, series)
            # A missing series, or a series without readable slices, becomes a zero channel
            scan = None if series_id is None else self.load_scan(study_id, series_id)
            if scan is None:
                scans.append(self._blank_channel())
                continue

            # Resample or pad to target_slices
            scan = self.resample_slices(scan, target_slices=self.target_slices)

            # Convert to tensor and add channel dimension
            scan = torch.from_numpy(scan).float()  # Shape: [slices, H, W]
            scans.append(scan.unsqueeze(0))        # Shape: [1, slices, H, W]

        # Concatenate scans along the channel dimension
        # Resulting shape: [channels, slices, H, W]
        scan = torch.cat(scans, dim=0)

        # Get coordinates
        coords = self.get_coordinates(study_id)
        coords = torch.tensor(coords).float()

        # Apply transforms
        if self.transform:
            scan = self.transform(scan)

        sample = {'scan': scan, 'labels': labels, 'coords': coords}
        return sample

    def _blank_channel(self):
        """Return an all-zero [1, target_slices, H, W] placeholder channel."""
        height, width = self.image_size
        return torch.zeros((1, self.target_slices, height, width))

    def load_scan(self, study_id, series_id):
        """
        Load all PNG slices for a given study_id and series_id.
        Args:
            study_id (str/int): The study identifier.
            series_id (str/int): The series identifier.
        Returns:
            np.ndarray: 3D array of shape [slices, H, W] or None if not found.
        """
        series_dir = os.path.join(self.root_dir, str(study_id), str(series_id))
        if not os.path.isdir(series_dir):
            return None

        # Slices are ordered by the integer in their file name. Entries whose
        # stem is not an integer (e.g. '.DS_Store') are skipped instead of
        # aborting the whole series.
        numbered_files = []
        for file_name in os.listdir(series_dir):
            stem = os.path.splitext(file_name)[0]
            try:
                numbered_files.append((int(stem), file_name))
            except ValueError:
                continue
        numbered_files.sort(key=lambda entry: entry[0])

        slices = []
        for _, slice_file in numbered_files:
            slice_data = self.load_slice(os.path.join(series_dir, slice_file))
            if slice_data is None:
                continue
            slices.append(slice_data)
        if not slices:
            return None
        volume = np.stack(slices, axis=0)  # Shape: [slices, H, W]
        return volume

    def load_slice(self, slice_path):
        """
        Load a single PNG slice.
        Args:
            slice_path (str): Path to the PNG file.
        Returns:
            np.ndarray: 2D float32 array of shape image_size, or None if failed.
        """
        try:
            # The context manager closes the underlying file once the pixels
            # have been copied into the converted image.
            with Image.open(slice_path) as source:
                img = source.convert('L')  # Convert to 8-bit grayscale
            height, width = self.image_size
            img = img.resize((width, height))  # PIL expects (width, height)
            # 8-bit data cast to float32 cannot contain NaN, so no NaN scrub is needed.
            return np.array(img).astype(np.float32)
        except Exception as e:
            # Best effort: report the unreadable slice and let the caller skip it.
            print(f"Error loading slice {slice_path}: {e}")
            return None

    def resample_slices(self, volume, target_slices=10):
        """
        Resample or pad the number of slices to target_slices.
        Args:
            volume (np.ndarray): 3D array of shape [slices, H, W].
            target_slices (int): Desired number of slices.
        Returns:
            np.ndarray: Resampled 3D array.
        """
        current_slices = volume.shape[0]
        if current_slices == target_slices:
            return volume
        if current_slices > target_slices:
            # Keep evenly spaced slices, always including the first and the last
            indices = np.linspace(0, current_slices - 1, target_slices).astype(int)
            return volume[indices]
        # Too few slices: append zero slices at the end
        pad_width = target_slices - current_slices
        padding = ((0, pad_width), (0, 0), (0, 0))
        return np.pad(volume, padding, mode='constant', constant_values=0)

    def get_series_id(self, study_id, series_description):
        """
        Get the series_id for a given study_id and series_description.
        Args:
            study_id (str/int): The study identifier.
            series_description (str): The series description.
        Returns:
            str/int or None: The first matching series_id or None if not found.
        """
        series_info = self.series_description_df[
            (self.series_description_df['study_id'] == study_id) &
            (self.series_description_df['series_description'] == series_description)
        ]
        if series_info.empty:
            return None
        return series_info.iloc[0]['series_id']

    def get_coordinates(self, study_id):
        """
        Extract coordinates for all conditions and levels for the study.
        Args:
            study_id (str/int): The study identifier.
        Returns:
            list: Flat list [x, y, x, y, ...] taken from the 'x_scaled' and
                'y_scaled' columns, ordered condition-major then by level;
                (0.0, 0.0) where no annotation exists.
        """
        study_coords = self.coordinates_df[self.coordinates_df['study_id'] == study_id]
        coords = []
        for condition in CONDITIONS:
            # 'spinal_canal_stenosis' -> 'Spinal Canal Stenosis'
            condition_name = condition.replace('_', ' ').title()
            for level in LEVELS:
                level_name = level.upper().replace('_', '/')  # Convert 'l1_l2' to 'L1/L2'
                coord_entry = study_coords[
                    (study_coords['condition'] == condition_name) &
                    (study_coords['level'] == level_name)
                ]
                if not coord_entry.empty:
                    first_match = coord_entry.iloc[0]
                    coords.extend([first_match['x_scaled'], first_match['y_scaled']])
                else:
                    # If no coordinate, fill with zeros
                    coords.extend([0.0, 0.0])
        return coords  # List of coordinates for all conditions and levels

In [7]:
class SEBlock(nn.Module):
    """
    Channel gate for 5-D feature maps [batch, channel, D, H, W].

    Each channel is averaged over D, H and W, passed through a two-layer
    bottleneck ending in a sigmoid, and the resulting per-channel factor
    rescales the input.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        bottleneck = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, bottleneck),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels, _, _, _ = x.shape
        pooled = self.avg_pool(x).view(batch, channels)        # [batch, channel]
        gate = self.fc(pooled).view(batch, channels, 1, 1, 1)  # one factor per channel
        return x * gate.expand_as(x)

class BasicBlock3D(nn.Module):
    """
    Residual unit: two 3x3x3 convolutions (each followed by batch norm), an
    SEBlock channel gate on the result, then the skip connection and a ReLU.
    """

    # Ratio between a block's output channels and `planes`
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        """
        Args:
            in_planes (int): Channels of the incoming feature map.
            planes (int): Channels produced by both convolutions.
            stride (int): Stride of the first convolution.
            downsample (nn.Module, optional): Applied to the input before it is
                added back, when given.
        """
        super().__init__()
        shared_conv_args = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv3d(in_planes, planes, stride=stride, **shared_conv_args)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(planes, planes, stride=1, **shared_conv_args)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample
        self.stride = stride
        self.se = SEBlock(planes)

    def forward(self, x):
        # Main path: conv -> bn -> relu -> conv -> bn -> channel gate
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.se(out)

        # Skip path: the input itself, or its projection when one was supplied
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)

class ResNet3D(nn.Module):
    """
    ResNet-style backbone for volumes [batch, in_channels, slices, H, W].

    The stem (conv1 + maxpool) reduces H and W by 4 while keeping the slice
    count; stages 2-4 each halve all three spatial dimensions. A global
    average pool and a linear layer produce `num_classes` outputs.
    """

    def __init__(self, block, layers, num_classes=512, in_channels=3):
        """
        Args:
            block (type): Residual block class. It must expose an `expansion`
                attribute and accept (in_planes, planes, stride, downsample).
            layers (sequence of int): Number of blocks in each of the four stages.
            num_classes (int): Size of the output vector.
            in_channels (int): Channels of the input volume (default 3, the
                value that was previously hard-coded).
        """
        super(ResNet3D, self).__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv3d(in_channels, 64, kernel_size=7, stride=(1, 2, 2),
                               padding=(3, 3, 3), bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=(1, 2, 2), padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        """
        Build one stage of `blocks` residual blocks.

        Only the first block receives `stride`. It also receives a 1x1x1
        projection on the skip path whenever the stride or the channel count
        changes. `self.in_planes` is updated to the stage's output channels.
        """
        downsample = None
        if stride != 1 or self.in_planes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv3d(self.in_planes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.in_planes, planes, stride, downsample))
        self.in_planes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_planes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        # Input x: [batch_size, in_channels, slices, H, W]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)  # [batch, 64, slices, H/4, W/4]
        x = self.layer2(x)  # [batch, 128, slices/2, H/8, W/8]
        x = self.layer3(x)  # [batch, 256, slices/4, H/16, W/16]
        x = self.layer4(x)  # [batch, 512, slices/8, H/32, W/32]

        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x  # [batch_size, num_classes]

def resnet18_3d(num_classes=512):
    """Build a ResNet3D with two BasicBlock3D units in each of its four stages."""
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet3D(BasicBlock3D, blocks_per_stage, num_classes=num_classes)

class CoordAttention3DResNet(nn.Module):
    """
    3D ResNet feature extractor (512 features) followed by a linear head.

    In training mode, when coordinates are supplied, the features are first
    rescaled by a coordinate-driven gate (CoordAttentionModule).
    """

    def __init__(self, num_classes, coord_dim):
        """
        Args:
            num_classes (int): Number of outputs of the final linear layer.
            coord_dim (int): Length of the coordinate vector fed to the gate.
        """
        super().__init__()
        feature_dim = 512
        self.resnet3d = ResNet3D(BasicBlock3D, [2, 2, 2, 2], num_classes=feature_dim)
        self.fc = nn.Linear(feature_dim, num_classes)
        self.coord_attention = CoordAttentionModule(feature_dim, coord_dim)

    def forward(self, x, coords=None):
        features = self.resnet3d(x)  # [batch_size, 512]
        # NOTE(review): the gate is skipped in eval mode even when coords are
        # passed, so `fc` sees gated features while training and ungated
        # features otherwise — confirm this asymmetry is intended.
        apply_gate = self.training and coords is not None
        if apply_gate:
            features = self.coord_attention(features, coords)
        return self.fc(features)

class CoordAttentionModule(nn.Module):
    """Turns a coordinate vector into per-feature factors in (0, 1) and applies them."""

    def __init__(self, feature_dim, coord_dim):
        """
        Args:
            feature_dim (int): Length of the feature vector being rescaled.
            coord_dim (int): Length of the coordinate vector.
        """
        super().__init__()
        gate_layers = [
            nn.Linear(coord_dim, feature_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feature_dim, feature_dim),
            nn.Sigmoid(),
        ]
        self.attention_fc = nn.Sequential(*gate_layers)

    def forward(self, x, coords):
        gate = self.attention_fc(coords)  # [batch_size, feature_dim]
        return x * gate


In [8]:
def _run_phase(model, dataloader, criterion, optimizer, num_classes, phase):
    """
    Run one pass over `dataloader` and return the mean per-sample loss.

    In the 'train' phase the model is updated and receives the coordinates;
    in any other phase it is evaluated without gradients and without
    coordinates. Batches whose labels are all `criterion.ignore_index` are
    skipped, because the mean-reduced loss of such a batch is NaN and would
    poison the epoch average.
    """
    is_train = phase == 'train'
    model.train(is_train)

    running_loss = 0.0
    samples_seen = 0

    for batch in tqdm(dataloader, desc=f'{phase.capitalize()} Progress'):
        scans = batch['scan'].to(device)               # [batch_size, channels, slices, H, W]
        labels = batch['labels'].to(device).view(-1)   # [batch_size * num_labels]

        if not (labels != criterion.ignore_index).any():
            continue

        with torch.set_grad_enabled(is_train):
            if is_train:
                optimizer.zero_grad()
                coords = batch['coords'].to(device)    # [batch_size, coord_dim]
                outputs = model(scans, coords)
            else:
                outputs = model(scans)

            # One row of class logits per (sample, label) pair
            outputs = outputs.view(-1, num_classes)
            loss = criterion(outputs, labels)

            if is_train:
                loss.backward()
                optimizer.step()

        running_loss += loss.item() * scans.size(0)
        samples_seen += scans.size(0)

    return running_loss / max(samples_seen, 1)


def train_model(dataset, num_classes, num_epochs=25, k_folds=5, batch_size=2):
    """
    Train the model using K-Fold cross-validation.

    For every fold a fresh CoordAttention3DResNet is trained with Adam,
    ReduceLROnPlateau and early stopping on the validation loss; the weights
    with the lowest validation loss are written to 'model_fold_{n}.pth'.

    Args:
        dataset (Dataset): The dataset to train on.
        num_classes (int): Number of classes per label.
        num_epochs (int): Number of training epochs.
        k_folds (int): Number of K-Folds.
        batch_size (int): Batch size for training.
    Returns:
        dict: Best validation loss for each fold, keyed by 1-based fold number.
    """
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)
    fold_performance = {}

    # Pinned host memory only helps when batches are copied to a CUDA device
    pin_memory = device.type == 'cuda'

    # Split over plain indices rather than over the Dataset object itself
    sample_indices = np.arange(len(dataset))

    for fold, (train_idx, val_idx) in enumerate(kfold.split(sample_indices)):
        print(f'\nFold {fold + 1}/{k_folds}')
        print('--------------------------------')

        loaders = {
            'train': DataLoader(Subset(dataset, train_idx), batch_size=batch_size, shuffle=True,
                                num_workers=4, pin_memory=pin_memory),
            'val': DataLoader(Subset(dataset, val_idx), batch_size=batch_size, shuffle=False,
                              num_workers=4, pin_memory=pin_memory),
        }

        # Initialize the model
        coord_dim = len(CONDITIONS) * len(LEVELS) * 2  # 2 coordinates per condition per level
        num_labels = len(LABELS)
        model = CoordAttention3DResNet(num_classes=num_labels * num_classes, coord_dim=coord_dim)
        model.to(device)

        # Define loss function and optimizer
        criterion = nn.CrossEntropyLoss(ignore_index=-100)
        optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
        # `verbose` is deprecated on the scheduler; learning-rate changes are reported below
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                         patience=2, factor=0.5)

        # Early stopping parameters
        early_stopping_patience = 5
        best_val_loss = np.inf
        epochs_no_improve = 0
        best_model_wts = copy.deepcopy(model.state_dict())

        for epoch in range(num_epochs):
            print(f'\nEpoch {epoch + 1}/{num_epochs}')
            print('-' * 10)

            # Each epoch has a training and validation phase
            for phase in ['train', 'val']:
                epoch_loss = _run_phase(model, loaders[phase], criterion, optimizer,
                                        num_classes, phase)
                print(f'{phase.capitalize()} Loss: {epoch_loss:.4f}')

                if phase != 'val':
                    continue

                previous_lr = optimizer.param_groups[0]['lr']
                scheduler.step(epoch_loss)
                current_lr = optimizer.param_groups[0]['lr']
                if current_lr != previous_lr:
                    print(f'Learning rate reduced from {previous_lr:.2e} to {current_lr:.2e}.')

                # Keep a copy of the best weights seen so far
                if epoch_loss < best_val_loss:
                    best_val_loss = epoch_loss
                    best_model_wts = copy.deepcopy(model.state_dict())
                    epochs_no_improve = 0
                    print(f'Validation loss decreased to {best_val_loss:.4f}. Saving model...')
                else:
                    epochs_no_improve += 1
                    print(f'No improvement in validation loss for {epochs_no_improve} epochs.')

            # Check early stopping condition
            if epochs_no_improve >= early_stopping_patience:
                print(f'\nEarly stopping triggered after {early_stopping_patience} epochs without improvement.')
                break

        print(f'\nBest Validation Loss for Fold {fold + 1}: {best_val_loss:.4f}')

        # Load best model weights
        model.load_state_dict(best_model_wts)

        # Save the best model for this fold
        torch.save(model.state_dict(), f'model_fold_{fold + 1}.pth')
        print(f'Saved best model for Fold {fold + 1}.')

        # Record fold performance
        fold_performance[fold + 1] = best_val_loss

    return fold_performance


In [9]:
def weighted_log_loss(outputs, targets, severity_weights, ignore_index=-100):
    """
    Calculate the weighted log loss.

    Rows whose target equals `ignore_index` are removed before any indexing
    takes place. Using a value such as -100 as a gather or weight index
    raises an out-of-range error, so masking afterwards is too late.

    Args:
        outputs (torch.Tensor): Logits from the model, shape [N, C].
        targets (torch.Tensor): Ground truth labels, shape [N].
        severity_weights (torch.Tensor): Weights for each class, shape [C].
        ignore_index (int): Label to ignore.
    Returns:
        torch.Tensor: Mean of `-log p(target) * weight[target]` over the rows
            that are not ignored; a zero scalar (still attached to the
            autograd graph of `outputs`) when every row is ignored.
    """
    targets = targets.view(-1)
    keep = targets != ignore_index

    if not keep.any():
        # Nothing to score: return 0 instead of the NaN an empty mean would give
        return outputs.sum() * 0.0

    kept_targets = targets[keep]

    # Log-probabilities of the kept rows, then the entry of each row's own target
    log_probs = F.log_softmax(outputs[keep], dim=1)                            # [M, C]
    target_log_probs = log_probs.gather(1, kept_targets.unsqueeze(1)).squeeze(1)  # [M]

    # Per-row weight chosen by the target class
    weights = severity_weights[kept_targets]                                   # [M]

    return (-target_log_probs * weights).mean()


In [10]:
# Define transformations using the custom Normalize3D.
# load_slice returns 8-bit grayscale intensities as floats in [0, 255], so
# centring on 127.5 and dividing by 127.5 maps them to [-1, 1]. The previous
# statistics of 0.5 only suit inputs already scaled to [0, 1].
PIXEL_MEAN = 127.5
PIXEL_STD = 127.5
transform = transforms.Compose([
    Normalize3D(mean=[PIXEL_MEAN] * len(SERIES_DESCRIPTIONS),
                std=[PIXEL_STD] * len(SERIES_DESCRIPTIONS))
])

# Create the dataset
root_dir = './rsna_output/cvt_png'  # Adjust this path to where your PNG images are stored
dataset = LumbarSpine3DDataset(df, coordinates_df, series_description_df, root_dir,
                               transform=transform, target_slices=10)

# Verify a few samples to ensure data is loaded correctly
for i in range(3):
    sample = dataset[i]
    print(f"Sample {i}:")
    print(f"Scan shape: {sample['scan'].shape}")  # Expected: [3, 10, 512, 512]
    print(f"Labels: {sample['labels']}")
    print(f"Coordinates: {sample['coords']}")
    print("\n")


Sample 0:
Scan shape: torch.Size([3, 10, 512, 512])
Labels: tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
        0])
Coordinates: tensor([258.2655, 182.3717, 256.4572, 236.5714, 258.4243, 297.4546, 268.2336,
        341.8619, 282.7328, 387.1717, 261.4276, 168.0283, 255.0954, 226.8269,
        250.5045, 289.6601, 248.6726, 335.4562, 262.8008, 385.9431, 259.4264,
        170.3403, 255.5105, 221.2467, 249.6367, 280.9637,   0.0000,   0.0000,
          0.0000,   0.0000, 286.6023, 257.9768, 289.5676, 254.0232, 281.6602,
        252.0463, 276.7181, 251.0579, 287.5907, 258.9652, 232.4620, 253.7994,
        233.4401, 251.3543, 228.5499, 249.8873, 235.8851, 252.8214, 233.4401,
        258.2006])


Sample 1:
Scan shape: torch.Size([3, 10, 512, 512])
Labels: tensor([0, 0, 1, 2, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 2, 0, 0, 1, 1, 1,
        0])
Coordinates: tensor([  0.0000,   0.0000,   0.0000,   0.0000, 310.7159, 257.5971,   0.0000,
          0.0000, 345.368

In [None]:
# Run configuration (these names are reused by the ensemble cells further down)
num_classes = 3  # severity grades per label: Normal/Mild, Moderate, Severe
k_folds = 2      # cross-validation folds; one checkpoint is written per fold
num_epochs = 5   # upper bound on epochs per fold

# Train one model per fold and collect each fold's best validation loss
fold_performance = train_model(
    dataset,
    num_classes=num_classes,
    num_epochs=num_epochs,
    k_folds=k_folds,
    batch_size=2,
)

# Report the best validation loss of every fold
for fold, loss in fold_performance.items():
    print(f'Fold {fold}, Best Validation Loss: {loss:.4f}')


Fold 1/2
--------------------------------





Epoch 1/5
----------


Train Progress: 100%|██████████| 494/494 [05:58<00:00,  1.38it/s]


Train Loss: 0.6086


Val Progress: 100%|██████████| 494/494 [01:53<00:00,  4.36it/s]


Val Loss: 0.6379
Validation loss decreased to 0.6379. Saving model...

Epoch 2/5
----------


Train Progress: 100%|██████████| 494/494 [05:58<00:00,  1.38it/s]


Train Loss: 0.5939


Val Progress: 100%|██████████| 494/494 [01:53<00:00,  4.35it/s]


Val Loss: 0.6363
Validation loss decreased to 0.6363. Saving model...

Epoch 3/5
----------


Train Progress: 100%|██████████| 494/494 [05:59<00:00,  1.37it/s]


Train Loss: 0.5918


Val Progress: 100%|██████████| 494/494 [01:53<00:00,  4.34it/s]


Val Loss: 0.6634
No improvement in validation loss for 1 epochs.

Epoch 4/5
----------


Train Progress: 100%|██████████| 494/494 [05:59<00:00,  1.37it/s]


Train Loss: 0.5900


Val Progress: 100%|██████████| 494/494 [01:53<00:00,  4.34it/s]


Val Loss: 0.6705
No improvement in validation loss for 2 epochs.

Epoch 5/5
----------


Val Progress: 100%|██████████| 494/494 [01:53<00:00,  4.35it/s]s]


Val Loss: 0.6136
Validation loss decreased to 0.6136. Saving model...

Epoch 2/5
----------


Train Progress: 100%|██████████| 494/494 [05:59<00:00,  1.37it/s]


Train Loss: 0.6042


Val Progress: 100%|██████████| 494/494 [01:53<00:00,  4.36it/s]


Val Loss: 0.6217
No improvement in validation loss for 1 epochs.

Epoch 3/5
----------


Train Progress: 100%|██████████| 494/494 [05:59<00:00,  1.38it/s]


Train Loss: 0.6031


Val Progress: 100%|██████████| 494/494 [01:53<00:00,  4.35it/s]


Val Loss: 0.6488
No improvement in validation loss for 2 epochs.

Epoch 4/5
----------


Train Progress:  40%|████      | 199/494 [02:24<03:33,  1.38it/s]

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class EnsembleModel(nn.Module):
    """
    Averages the raw outputs of several identically structured models.

    Args:
        model_class (callable): Zero-argument factory returning a new,
            uninitialised member model.
        model_paths (iterable of str): One saved state dict per member.
        device (torch.device): Device the members are loaded onto.
    """

    def __init__(self, model_class, model_paths, device):
        super().__init__()
        self.models = nn.ModuleList()
        for checkpoint_path in model_paths:
            member = model_class()
            # Restore the member's weights from its checkpoint
            state = torch.load(checkpoint_path, map_location=device)
            member.load_state_dict(state)
            member.to(device)
            member.eval()  # members start in evaluation mode
            self.models.append(member)
        self.device = device

    def forward(self, x):
        # Run every member on the same input, then average across members
        member_outputs = [member(x) for member in self.models]
        stacked = torch.stack(member_outputs, dim=0)
        return stacked.mean(dim=0)


In [None]:
# Relies on k_folds, num_classes, num_epochs, CONDITIONS, LEVELS and LABELS from earlier cells
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# One checkpoint per fold, as written by train_model.
# NOTE(review): the training output captured above stops part-way through a
# fold — confirm that every listed checkpoint file exists before running this.
model_paths = [f'model_fold_{i+1}.pth' for i in range(k_folds)]

# Must match the architecture arguments used inside train_model
num_labels = len(LABELS)
coord_dim = 2 * len(CONDITIONS) * len(LEVELS)  # an (x, y) pair per condition and level

# EnsembleModel calls the factory with no arguments, so bind the arguments in a lambda
ensemble_model = EnsembleModel(
    lambda: CoordAttention3DResNet(coord_dim=coord_dim, num_classes=num_labels * num_classes),
    model_paths,
    device,
)

ensemble_model.to(device)
ensemble_model.eval()

# Optionally save the ensemble model's state_dict
torch.save(ensemble_model.state_dict(), f'ensemble_3d_resnet_model_F{k_folds}_E{num_epochs}.pth')
