In [None]:
# ======================================
# Part 1: Data Preparation and Dataset Class
# ======================================

# ======================================
# Import Necessary Libraries
# ======================================
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import h5py
import matplotlib.pyplot as plt
import random
import os
from torchvision import models

# Install necessary packages if not already installed
try:
    from fvcore.nn import FlopCountAnalysis, parameter_count
except ImportError:
    !pip install fvcore -q
    from fvcore.nn import FlopCountAnalysis, parameter_count

from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# ======================================
# Reproducibility
# ======================================
def set_seed(seed=42):
    """Seed the PyTorch, NumPy and Python random number generators.

    Args:
        seed (int): Value handed to every generator.
    """
    # Same order as before: torch first, then NumPy, then the stdlib.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

    # The CUDA generators are only seeded when a GPU is usable.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(42)

# ======================================
# Device Configuration
# ======================================
# Use the GPU when CUDA is usable; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

# ======================================
# Data Preparation
# ======================================
# Location of the TCIR HDF5 archive
data_path = '/kaggle/input/tcir-cpac-io-sh-h5-file/TCIR-CPAC_IO_SH.h5'

# Read the per-sample metadata table stored under the "info" key.
# Nothing below can run without it, so report the problem and re-raise.
try:
    data_info = pd.read_hdf(data_path, key="info", mode='r')
except Exception as e:
    print(f"Error loading HDF5 file: {e}")
    raise

# Keep only the rows whose 'data_set' value is 'SH'
is_sh_row = data_info['data_set'].isin(['SH'])
data_info_filtered = data_info[is_sh_row]

# Undersample the low-intensity samples: rows at or below the 35th
# percentile of Vmax are randomly thinned to 30 % of their count, while
# every row above the threshold is kept.
vmax_column = data_info_filtered['Vmax']
low_vmax_threshold = np.percentile(vmax_column.values, 35)
low_vmax_indices = data_info_filtered[
    vmax_column <= low_vmax_threshold].index
remaining_indices = data_info_filtered[
    vmax_column > low_vmax_threshold].index
undersample_ratio = 0.3
# int(0 * ratio) is already 0, so an empty low-Vmax group needs no special case
undersample_size = int(len(low_vmax_indices) * undersample_ratio)

if undersample_size <= 0:
    # Nothing to draw: keep only the rows above the threshold
    balanced_indices = remaining_indices
else:
    undersample_indices = np.random.choice(
        low_vmax_indices, undersample_size, replace=False)
    balanced_indices = np.concatenate(
        (undersample_indices, remaining_indices))

# reset_index() without drop=True turns the original row labels into a
# regular column; TCIRLazyDataset reads that column (as 'index') to find
# each sample's row in the HDF5 matrix.
balanced_rows = data_info_filtered.loc[balanced_indices]
data_info_balanced = balanced_rows.reset_index()

# Training pipeline: random rotation for augmentation, then a centre crop
# and a resize to the 224 x 224 network input.
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomRotation(degrees=(0, 360), fill=0),
    transforms.CenterCrop(size=(152, 152)),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Validation/test pipeline: identical but without the random rotation.
val_test_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.CenterCrop(size=(152, 152)),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Dataset Class
class TCIRLazyDataset(Dataset):
    """Dataset that reads one sample at a time from an HDF5 file.

    The HDF5 file is opened on the first item access instead of in
    ``__init__``, and the open handle is dropped whenever the dataset is
    pickled. An h5py handle cannot be pickled and should not be shared
    between processes, so each ``DataLoader`` worker opens its own handle
    the first time it reads a sample.

    Args:
        hdf5_file (str): Path of the HDF5 file; its ``'matrix'`` dataset is
            indexed as ``[row, :, :, channels]``.
        data_info (pd.DataFrame): One row per sample with a 0..N-1 index, an
            ``'index'`` column (row in the HDF5 matrix) and a ``'Vmax'``
            column (regression target).
        channels (sequence of int, optional): Channels to read from the last
            axis of the matrix. Defaults to ``[0, 1, 3]``.
        transform (callable, optional): Applied to the ``(C, H, W)`` float
            tensor of each sample.
    """

    def __init__(self, hdf5_file, data_info, channels=None,
                 transform=None):
        self.hdf5_file = hdf5_file
        self.data_info = data_info
        # Copy into a fresh list so instances never share (or mutate) a
        # default argument or the caller's sequence.
        self.channels = [0, 1, 3] if channels is None else list(channels)
        self.transform = transform
        # Per-channel divisor; values are clipped to it before dividing.
        self.channel_norm_values = {0: 350, 1: 275, 3: 4.35}
        # Opened lazily by _open_file().
        self.hf = None

    def _open_file(self):
        """Return the HDF5 handle, opening the file on first use."""
        if self.hf is None:
            try:
                self.hf = h5py.File(self.hdf5_file, 'r')
            except Exception as e:
                print(f"Error opening HDF5 file: {e}")
                raise
        return self.hf

    def __getstate__(self):
        # Never ship an open handle to another process; the copy reopens
        # the file on its first __getitem__.
        state = self.__dict__.copy()
        state['hf'] = None
        return state

    def __len__(self):
        return len(self.data_info)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        Raises:
            Exception: Any error from opening or reading the file is
                printed and re-raised unchanged.
        """
        try:
            hdf5_index = self.data_info.at[idx, 'index']
            matrix = self._open_file()['matrix']
            data_matrix = matrix[hdf5_index, :, :, self.channels]
        except Exception as e:
            print(f"Error accessing data at index {idx}: {e}")
            raise

        # Normalize each channel: cap at its norm value, divide by it and
        # replace NaN/inf with finite numbers.
        image = np.empty_like(data_matrix, dtype=np.float32)
        for i, ch in enumerate(self.channels):
            norm_value = self.channel_norm_values.get(ch, 1.0)
            channel_data = np.clip(data_matrix[:, :, i], None, norm_value)
            image[:, :, i] = np.nan_to_num(channel_data / norm_value)

        # (H, W, C) -> (C, H, W) before the transform pipeline
        image = torch.tensor(image).permute(2, 0, 1)
        if self.transform:
            image = self.transform(image)

        # Regression target (Vmax)
        label = torch.tensor(
            self.data_info.at[idx, 'Vmax'], dtype=torch.float32)
        return image, label

    def __del__(self):
        # getattr guards against a partially constructed instance.
        hf = getattr(self, 'hf', None)
        if hf:
            hf.close()
            self.hf = None

# Split Dataset
# Shuffle the row positions once and carve them into 70 % train,
# 15 % validation and the remainder for testing.
full_dataset_size = len(data_info_balanced)
indices = list(range(full_dataset_size))
np.random.shuffle(indices)

train_size = int(0.7 * full_dataset_size)
val_size = int(0.15 * full_dataset_size)
test_size = full_dataset_size - (train_size + val_size)

val_end = train_size + val_size
train_indices = indices[:train_size]
val_indices = indices[train_size:val_end]
test_indices = indices[val_end:]

# One metadata frame per partition, each re-indexed from 0 so the dataset
# can look rows up by position.
train_data_info, val_data_info, test_data_info = (
    data_info_balanced.iloc[rows].reset_index(drop=True)
    for rows in (train_indices, val_indices, test_indices)
)

# Only the training set gets the augmenting transform.
train_dataset = TCIRLazyDataset(
    data_path, train_data_info, transform=train_transform)
val_dataset = TCIRLazyDataset(
    data_path, val_data_info, transform=val_test_transform)
test_dataset = TCIRLazyDataset(
    data_path, test_data_info, transform=val_test_transform)

# Settings shared by all three loaders; only the training loader shuffles.
loader_options = dict(batch_size=32, num_workers=3, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

n_train, n_val, n_test = (
    len(train_dataset), len(val_dataset), len(test_dataset))
print(f'Total Samples - Train: {n_train}, '
      f'Validation: {n_val}, Test: {n_test}')

# ======================================
# Part 2: Model Definitions (Including CBAM)
# ======================================

# ======================================
# CBAM Attention Module
# ======================================
# Channel Attention Module
class ChannelAttention(nn.Module):
    """Channel gate of CBAM.

    Average- and max-pooled channel descriptors pass through one shared
    bottleneck MLP; the two results are summed and squashed with a sigmoid
    to produce a per-channel weight.

    Args:
        in_channels (int): Number of channels of the input feature map.
        reduction (int): Bottleneck ratio of the shared MLP.
    """

    def __init__(self, in_channels, reduction=8):
        super().__init__()
        bottleneck = in_channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, in_channels, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return channel weights of shape ``(B, C, 1, 1)`` for 4-D ``x``."""
        batch, channels, _, _ = x.size()
        flat_shape = (batch, channels)

        avg_descriptor = self.avg_pool(x).view(*flat_shape)
        avg_logits = self.fc(avg_descriptor)

        max_descriptor = self.max_pool(x).view(*flat_shape)
        max_logits = self.fc(max_descriptor)

        gate = self.sigmoid(avg_logits + max_logits)
        return gate.view(batch, channels, 1, 1)

# Spatial Attention Module
class SpatialAttention(nn.Module):
    """Spatial gate of CBAM.

    The channel-wise mean and max maps are stacked into a 2-channel image,
    convolved down to a single channel and squashed with a sigmoid.

    Args:
        kernel_size (int): Size of the square convolution kernel; the
            padding is ``kernel_size // 2``.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        half_kernel = kernel_size // 2
        self.conv = nn.Conv2d(
            2, 1, kernel_size, padding=half_kernel, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a single-channel attention map for ``x``."""
        channel_mean = torch.mean(x, dim=1, keepdim=True)
        channel_max = torch.max(x, dim=1, keepdim=True)[0]
        pooled_maps = torch.cat([channel_mean, channel_max], dim=1)
        return self.sigmoid(self.conv(pooled_maps))

# CBAM Module
class CBAM(nn.Module):
    """Channel attention followed by spatial attention.

    Args:
        in_channels (int): Number of channels of the input feature map.
        reduction (int): Bottleneck ratio passed to the channel attention.
        kernel_size (int): Kernel size passed to the spatial attention.
    """

    def __init__(self, in_channels, reduction=8, kernel_size=7):
        super().__init__()
        self.channel_attention = ChannelAttention(in_channels, reduction)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        """Return ``x`` re-weighted per channel, then per spatial location."""
        channel_gate = self.channel_attention(x)
        channel_refined = x * channel_gate
        # The spatial gate is computed from the channel-refined features.
        spatial_gate = self.spatial_attention(channel_refined)
        return channel_refined * spatial_gate

# ======================================
# Function to Create Models Dynamically
# ======================================
def create_model(architecture, num_channels=3, use_cbam=False):
    """
    Build a single-output regression network on a torchvision backbone.

    The backbone is created with its default pretrained weights. Its first
    convolution is then replaced by a newly initialised one that accepts
    ``num_channels`` inputs, and its classification layer is replaced by a
    Linear(256) -> ReLU -> Dropout(0.5) -> Linear(1) head.

    Args:
        architecture (str): One of 'resnet18', 'resnet34', 'resnet50',
                            'densenet121', 'densenet169', 'densenet201'.
        num_channels (int): Number of input channels.
        use_cbam (bool): Wrap the backbone so a CBAM module follows each
                         residual stage / dense block.

    Returns:
        model (nn.Module): The constructed model.

    Raises:
        ValueError: If ``architecture`` is not one of the supported names.
    """
    pretrained_weights = {
        'resnet18': models.ResNet18_Weights.DEFAULT,
        'resnet34': models.ResNet34_Weights.DEFAULT,
        'resnet50': models.ResNet50_Weights.DEFAULT,
        'densenet121': models.DenseNet121_Weights.DEFAULT,
        'densenet169': models.DenseNet169_Weights.DEFAULT,
        'densenet201': models.DenseNet201_Weights.DEFAULT
    }

    if architecture not in pretrained_weights:
        raise ValueError(f"Architecture {architecture} not supported.")

    def regression_head(in_features):
        # Shared head: in_features -> 256 -> 1
        return nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, 1)
        )

    build_backbone = getattr(models, architecture)
    weights = pretrained_weights[architecture]

    if 'resnet' in architecture:
        base_model = build_backbone(weights=weights)
        # New stem convolution for the requested number of input channels
        base_model.conv1 = nn.Conv2d(
            num_channels, 64, kernel_size=7, stride=2,
            padding=3, bias=False)
        base_model.fc = regression_head(base_model.fc.in_features)

        if not use_cbam:
            return base_model
        # CBAM after each residual stage
        return ResNet_CBAM_Modified(base_model)

    if 'densenet' in architecture:
        base_model = build_backbone(weights=weights)
        # New stem convolution with the same output width as the old one
        stem_width = base_model.features.conv0.out_channels
        base_model.features.conv0 = nn.Conv2d(
            num_channels, stem_width,
            kernel_size=7, stride=2, padding=3, bias=False)
        base_model.classifier = regression_head(
            base_model.classifier.in_features)

        if not use_cbam:
            return base_model

        # Layers per dense block; growth rate and compression are the same
        # for all three supported DenseNets.
        dense_block_configs = {
            'densenet121': (6, 12, 24, 16),
            'densenet169': (6, 12, 32, 32),
            'densenet201': (6, 12, 48, 32),
        }
        if architecture not in dense_block_configs:
            raise ValueError(f"Unknown architecture {architecture}")
        growth_rate = 32
        compression = 0.5
        block_config = dense_block_configs[architecture]

        # CBAM after each dense block
        return DenseNet_CBAM_Modified(
            base_model, growth_rate, block_config, compression)

    raise ValueError(f"Architecture {architecture} not supported.")

# ======================================
# Modified ResNet and DenseNet Classes with CBAM
# ======================================
class ResNet_CBAM_Modified(nn.Module):
    """ResNet wrapper that applies a CBAM module after each residual stage.

    Args:
        base_model (nn.Module): A torchvision-style ResNet exposing
            ``conv1``, ``bn1``, ``relu``, ``maxpool``, ``layer1``..``layer4``,
            ``avgpool`` and ``fc``.

    Raises:
        NotImplementedError: If the first block of ``layer1`` is neither a
            ``BasicBlock`` nor a ``Bottleneck``.
    """

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

        # The block type decides how wide each stage's output is.
        first_block = self.base_model.layer1[0]
        if isinstance(first_block, models.resnet.BasicBlock):
            expansion = 1
        elif isinstance(first_block, models.resnet.Bottleneck):
            expansion = first_block.expansion
        else:
            raise NotImplementedError('Unknown block type')

        # One CBAM per stage, sized to that stage's output channels
        self.cbam1 = CBAM(64 * expansion)
        self.cbam2 = CBAM(128 * expansion)
        self.cbam3 = CBAM(256 * expansion)
        self.cbam4 = CBAM(512 * expansion)

    def forward(self, x):
        """Run the stem, the four stage + CBAM pairs, pooling and the head."""
        net = self.base_model

        # Stem
        x = net.maxpool(net.relu(net.bn1(net.conv1(x))))

        # Residual stages, each followed by its attention module
        stage_pairs = (
            (net.layer1, self.cbam1),
            (net.layer2, self.cbam2),
            (net.layer3, self.cbam3),
            (net.layer4, self.cbam4),
        )
        for stage, attention in stage_pairs:
            x = attention(stage(x))

        # Global pooling and head
        pooled = torch.flatten(net.avgpool(x), 1)
        return net.fc(pooled)

class DenseNet_CBAM_Modified(nn.Module):
    """DenseNet wrapper that applies a CBAM module after each dense block.

    Args:
        base_model (nn.Module): A torchvision-style DenseNet exposing
            ``features`` (conv0, norm0, relu0, pool0, denseblock1-4,
            transition1-3, norm5) and ``classifier``.
        growth_rate (int): Channels added by each dense layer.
        block_config (sequence of int): Number of layers in each of the
            four dense blocks.
        compression (float): Channel reduction factor applied by each
            transition layer.
    """

    def __init__(self, base_model, growth_rate, block_config, compression):
        super().__init__()
        self.base_model = base_model

        # Work out how many channels leave each dense block: a block adds
        # (layers * growth_rate) channels, and the transition in front of
        # blocks 2-4 scales the running width by `compression`.
        width = base_model.features.conv0.out_channels
        block_widths = []
        for block_idx in range(4):
            if block_idx > 0:
                width = int(width * compression)
            width = width + block_config[block_idx] * growth_rate
            block_widths.append(width)

        # One CBAM per dense block, sized to that block's output channels
        self.cbam1 = CBAM(block_widths[0])
        self.cbam2 = CBAM(block_widths[1])
        self.cbam3 = CBAM(block_widths[2])
        self.cbam4 = CBAM(block_widths[3])

    def forward(self, x):
        """Run the stem, the dense block + CBAM pairs, pooling and the head."""
        layers = self.base_model.features

        # Stem
        x = layers.pool0(layers.relu0(layers.norm0(layers.conv0(x))))

        # Blocks 1-3: dense block -> CBAM -> transition
        x = layers.transition1(self.cbam1(layers.denseblock1(x)))
        x = layers.transition2(self.cbam2(layers.denseblock2(x)))
        x = layers.transition3(self.cbam3(layers.denseblock3(x)))

        # Block 4 has no transition after it
        x = self.cbam4(layers.denseblock4(x))

        # Final norm, ReLU, global average pooling and head
        x = nn.functional.relu(layers.norm5(x), inplace=True)
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        return self.base_model.classifier(torch.flatten(x, 1))

# ======================================
# Part 3: Training, Evaluation, and Results Compilation
# ======================================

# ======================================
# Training and Evaluation Functions
# ======================================
def train_and_evaluate(model, train_loader, val_loader,
                       num_epochs=100, patience=15,
                       checkpoint_path='best_model.pth'):
    """
    Trains the model and evaluates it on the validation set.

    Uses SmoothL1 loss, AdamW (lr=1e-4, weight_decay=1e-5) and a
    ReduceLROnPlateau scheduler driven by the validation loss. The weights
    with the lowest validation loss are written to ``checkpoint_path`` and
    restored before returning. If no epoch ever improves on the initial
    best (for example NaN losses or ``num_epochs=0``), nothing is written
    by this run; in that case the checkpoint file is NOT loaded, so a
    stale file left by an earlier run cannot be picked up, and the model
    is returned as trained.

    Args:
        model (nn.Module): The model to train.
        train_loader (DataLoader): DataLoader for training data.
        val_loader (DataLoader): DataLoader for validation data.
        num_epochs (int): Maximum number of epochs.
        patience (int): Early stopping patience.
        checkpoint_path (str): Path to save the best model.

    Returns:
        train_losses (list): List of training losses per epoch.
        val_losses (list): List of validation losses per epoch.
        model (nn.Module): The best trained model.
    """
    model.to(device)
    criterion = nn.SmoothL1Loss()
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=1e-4, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.1, patience=5)

    best_val_loss = float('inf')
    early_stop_counter = 0
    # True only once THIS run has written checkpoint_path
    checkpoint_saved = False
    train_losses, val_losses = [], []

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images).squeeze(1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch loss is a per-sample mean
            running_loss += loss.item() * images.size(0)

        train_loss = running_loss / len(train_loader.dataset)
        train_losses.append(train_loss)

        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images).squeeze(1)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * images.size(0)

        val_loss /= len(val_loader.dataset)
        val_losses.append(val_loss)

        scheduler.step(val_loss)
        print(f"Epoch [{epoch + 1}/{num_epochs}], "
              f"Train Loss: {train_loss:.3f}, "
              f"Val Loss: {val_loss:.3f}")

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            early_stop_counter = 0
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
        else:
            early_stop_counter += 1
            if early_stop_counter >= patience:
                print("Early stopping triggered")
                break

    # Restore the best weights, but only if this run actually saved them
    if checkpoint_saved:
        model.load_state_dict(torch.load(
            checkpoint_path, map_location=device, weights_only=True))
    else:
        print(f"Warning: validation loss never improved, so "
              f"'{checkpoint_path}' was not written by this run; "
              f"returning the model as trained")
    return train_losses, val_losses, model

def evaluate_model(model, test_loader):
    """
    Runs the model over the test set and scores its predictions.

    The model is switched to eval mode and moved to the module-level
    ``device`` before inference.

    Args:
        model (nn.Module): The trained model.
        test_loader (DataLoader): DataLoader for test data.

    Returns:
        mae (float): Mean Absolute Error.
        rmse (float): Root Mean Squared Error.
        r2 (float): R^2 score.
        predictions (np.array): Model predictions.
        actuals (np.array): Actual labels.
    """
    model.eval()
    model.to(device)

    predictions, actuals = [], []
    with torch.no_grad():
        for images, labels in test_loader:
            batch_outputs = model(images.to(device)).squeeze(1)
            # Labels never leave the CPU; only predictions are moved back.
            predictions.extend(batch_outputs.cpu().numpy())
            actuals.extend(labels.numpy())

    predictions, actuals = np.array(predictions), np.array(actuals)

    mae = mean_absolute_error(actuals, predictions)
    mse = mean_squared_error(actuals, predictions)
    rmse = np.sqrt(mse)
    r2 = r2_score(actuals, predictions)
    return mae, rmse, r2, predictions, actuals

def get_model_complexity(model, input_size=(3, 224, 224)):
    """
    Counts the parameters and FLOPs of the model for a single input.

    Side effect: the model is moved to the CPU and put in eval mode, and
    is left that way when the function returns.

    Args:
        model (nn.Module): The model to analyze.
        input_size (tuple): Size of one input sample, without the batch
            dimension.

    Returns:
        total_params (int): Total number of parameters.
        total_flops (int): Total number of floating-point operations, or
            None when the FLOP analysis raised an error.
    """
    model.to('cpu')
    model.eval()

    # A single random sample (batch size 1) is enough to trace the model.
    sample_shape = (1, *input_size)
    dummy_input = torch.randn(*sample_shape)

    total_flops = None
    try:
        total_flops = FlopCountAnalysis(model, dummy_input).total()
    except Exception as e:
        print(f"Error calculating FLOPs: {e}")
        total_flops = None

    # The '' key holds the count for the whole model
    param_counts = parameter_count(model)
    total_params = param_counts['']
    return total_params, total_flops

# ======================================
# Model Architectures to Train
# ======================================
architectures = [
    'resnet18',
    'resnet34',
    'resnet50',
    'densenet121',
    'densenet169',
    'densenet201'
]

# One summary row per trained model
results = []

# Train every backbone twice: first with CBAM, then without
for architecture in architectures:
    for use_cbam in (True, False):
        cbam_label = 'with CBAM' if use_cbam else 'without CBAM'
        file_suffix = '_cbam' if use_cbam else ''
        model_name = f"{architecture.upper()} {cbam_label}"
        print(f"\nTraining {model_name}...\n")

        # Build the model; skip this configuration if it is rejected
        try:
            model = create_model(
                architecture, num_channels=3, use_cbam=use_cbam)
        except ValueError as ve:
            print(f"Error creating model {model_name}: {ve}")
            continue

        # Each configuration gets its own checkpoint file
        checkpoint_path = f"best_model_{architecture}{file_suffix}.pth"

        # Train with early stopping
        train_losses, val_losses, best_model = train_and_evaluate(
            model, train_loader, val_loader,
            num_epochs=100, patience=15,
            checkpoint_path=checkpoint_path)

        # Score on the held-out test set
        print(f"\nEvaluating {model_name} on test set...\n")
        mae, rmse, r2, predictions, actuals = evaluate_model(
            best_model, test_loader)
        print(f"{model_name} - Test MAE: {mae:.3f}, "
              f"RMSE: {rmse:.3f}, R2 Score: {r2:.3f}")

        # Model size and cost
        print(f"\nCalculating FLOPs and Parameters "
              f"for {model_name}...\n")
        total_params, total_flops = get_model_complexity(best_model)
        flops_display = (
            "N/A" if total_flops is None else f"{total_flops:,}")
        print(f"{model_name} - Total Params: "
              f"{total_params:,}, Total FLOPs: {flops_display}")

        # Persist the per-epoch loss curves
        epoch_numbers = range(1, len(train_losses) + 1)
        history_df = pd.DataFrame({
            'Epoch': epoch_numbers,
            'Train Loss': train_losses,
            'Val Loss': val_losses
        })
        history_csv_path = f"training_history_{architecture}{file_suffix}.csv"
        history_df.to_csv(history_csv_path, index=False)
        print(f"\nTraining history saved to '{history_csv_path}'")

        # Record the summary row
        summary_row = {
            'Model': model_name,
            'MAE': mae,
            'RMSE': rmse,
            'R2 Score': r2,
            'Total Params': total_params,
            'Total FLOPs': total_flops
        }
        results.append(summary_row)

# ======================================
# Compile Results into a Table
# ======================================
results_df = pd.DataFrame(results)
print("\nFinal Results:\n")
print(results_df)

# Write the summary table next to the per-model history files
results_csv_path = 'model_evaluation_results.csv'
results_df.to_csv(results_csv_path, index=False)
print("\nModel evaluation results saved to "
      "'model_evaluation_results.csv'")