In [None]:
# Environment Setup
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchmetrics.classification import MulticlassAUROC, MulticlassAccuracy
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split,GroupShuffleSplit
from captum.attr import IntegratedGradients, Occlusion
import warnings
import re
from glob import glob
from tensorflow.keras.utils import to_categorical
import random

import numpy as np
import torch

# Silence every warning for the remainder of the notebook session.
warnings.filterwarnings('ignore')

# Pick the compute backend in priority order: Apple MPS, then CUDA, then CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(f"Using device: {device}")

In [None]:

import os
# Load the feature arrays from the "myData.npz" cache when it exists; otherwise
# rebuild them from the individual .npy files under ./savefiles3 and write the cache.
if os.path.isfile("myData.npz"):
    # Fast path: reuse the arrays saved by the else-branch below.
    print("✅ Loading preprocessed features...")
    with np.load("myData.npz") as data:
        myData = data['myData']  # reported below as "Mel spectrograms"
        myData2 = data['myData2']  # reported below as "Topological features"
        myY = data['myY']  # labels (one-hot encoded by the else-branch)
        myActors = data['myActors']  # '<actor>_<dataset>' identifiers (see else-branch)
        myDatasets = data['myDatasets']  # dataset name per entry

    print(f"Mel spectrograms shape: {myData.shape}")
    print(f"Topological features shape: {myData2.shape}")
    print(f"Labels shape: {myY.shape}")
    print(f"Number of unique actors: {len(np.unique(myActors))}")
    print(f"Number of unique datasets: {len(np.unique(myDatasets))}")
else:
    # Directory scanned by findFilesFromPattern (read as a global inside that function).
    folder = './savefiles3'

    def findFilesFromPattern(pattern):
        """Load every file in ``folder`` whose name matches ``<pattern>_<dataset>_<actor>_<emotion>_<i>_<j>.npy``.

        The regular expression is matched from the start of the file name.
        Returns a dict keyed by ``'<dataset>_<actor>_<emotion>_<j // 2>_<j % 2>'``;
        each value holds the loaded array (``'data'``) plus the parsed
        ``'dataset'``, ``'actor'``, ``'emotion'`` and ``j`` (stored as ``'type'``).
        """
        pattern = re.compile(pattern + r'_(.*?)_(.*?)_(.*?)_(\d+)_(\d+)\.npy')
        heatmaps_dict = {}

        for filename in os.listdir(folder):
            match = pattern.match(filename)
            if match:
                dataset, actor, emotion, i, j = map(str, match.groups())
                i, j = int(i), int(j)
                filepath = os.path.join(folder, filename)
                data = np.load(filepath)

                # NOTE(review): `i` is parsed but not part of the key, so two files that
                # differ only in `i` would overwrite each other — confirm this cannot happen.
                heatmaps_dict[f'{dataset}_{actor}_{emotion}_{j // 2}_{j%2}'] = {'data': data, 'dataset': dataset, 'actor': actor, 'emotion':emotion, 'type': j}

        return heatmaps_dict

    # One dictionary per heat-map family, distinguished by file-name prefix.
    melwasserstein = findFilesFromPattern('wassersteinHeat')
    meltimeeuclid = findFilesFromPattern('timeMetricHeat')
    meleuclid = findFilesFromPattern('euclideanHeat')

    def load_spectrograms(prefixes, path='./savefiles3'):
        """Load all ``<prefix>_*.npy`` files under ``path`` for the given prefixes.

        The matching paths of all prefixes are concatenated, sorted, and loaded
        in that order; the result is a list of arrays.
        """
        patterns = []
        for prefix in prefixes:
            patterns.append(os.path.join(path, f"{prefix}_*.npy"))
        my_globs = glob(patterns[0])
        for pattern in patterns[1:]:
            my_globs = my_globs + glob(pattern)
        file_list = sorted(my_globs)
        return [np.load(file) for i, file in enumerate(file_list)]

    myRaw = load_spectrograms(["savee", 'tess', 'radvess', 'cremad'])
    # Diagnostic counts: all entries, then entries with even / odd 'type'.
    print(len(melwasserstein))
    print(len([melwasserstein[key]['data'] for key in sorted(melwasserstein.keys()) if melwasserstein[key]['type'] % 2 == 0]))
    print(len([melwasserstein[key]['data'] for key in sorted(melwasserstein.keys()) if melwasserstein[key]['type'] % 2 == 1]))
    print(np.array([[meleuclid[key]['data'] for key in sorted(meleuclid.keys()) if meleuclid[key]['type'] == 0]]).shape)

    print(len(myRaw))

    # Wrap the list in an extra leading axis, then move that axis to the end.
    # NOTE(review): presumably every spectrogram has the same 2-D shape, giving
    # (samples, height, width, 1) — confirm.
    myData = np.array([myRaw])
    print('finish data')
    myData = myData.astype('float32')
    myData = np.transpose(myData, (1, 2, 3, 0))
    # Emotion name -> 1-based id; 1 is subtracted below to obtain 0-based labels.
    myEmotionMap = {
        'neutral': 1, 'happy':2, 'sad':3, 'angry':4, 'fearful':5, 'disgust':6, 'calm':7, 'surprised':8
    }
    # Labels, actors and datasets are taken from the even-'type' Wasserstein entries
    # in sorted-key order.
    # NOTE(review): myRaw is ordered by sorted file path while these are ordered by
    # sorted dictionary key; the code assumes both orders line up — confirm.
    myY = np.array(
        [myEmotionMap[melwasserstein[key]['emotion']] -1 for key in sorted(melwasserstein.keys()) if melwasserstein[key]['type'] % 2 == 0]
    )
    myActors = np.array(
        [melwasserstein[key]['actor'] + '_' + melwasserstein[key]['dataset']  for key in sorted(melwasserstein.keys()) if melwasserstein[key]['type'] % 2 == 0]
    )
    myDatasets = np.array(
        [melwasserstein[key]['dataset']  for key in sorted(melwasserstein.keys()) if melwasserstein[key]['type'] % 2 == 0]
    )
    print(np.unique(myActors))

    print(np.unique(myY))

    # NOTE(review): myEmotionMap has 8 entries but only 6 classes are encoded here;
    # labels 6 and 7 ('calm', 'surprised') would not fit — confirm the files only
    # contain the first six emotions.
    myY = to_categorical(myY, num_classes=6)

    # Six stacked lists, in this order: Euclidean even/odd 'type', time-metric
    # even/odd 'type', Wasserstein even/odd 'type'.
    myData2 = np.array([
                        [meleuclid[key]['data'] for key in sorted(meleuclid.keys()) if meleuclid[key]['type'] % 2 == 0],
                        [meleuclid[key]['data'] for key in sorted(meleuclid.keys()) if meleuclid[key]['type'] % 2 == 1],
                        [meltimeeuclid[key]['data'] for key in sorted(meltimeeuclid.keys()) if meltimeeuclid[key]['type'] % 2 == 0],
                        [meltimeeuclid[key]['data'] for key in sorted(meltimeeuclid.keys()) if meltimeeuclid[key]['type'] % 2 == 1],
                        [melwasserstein[key]['data'] for key in sorted(melwasserstein.keys()) if melwasserstein[key]['type'] % 2 == 0],
                        [melwasserstein[key]['data'] for key in sorted(melwasserstein.keys()) if melwasserstein[key]['type'] % 2 == 1]
                        ])
    print('finish data')
    myData2 = myData2.astype('float32')
    print(myData2.shape)
    # Move the leading axis (the six stacked lists) to the last position.
    myData2 = np.transpose(myData2, (1, 2, 3, 0))
    print(myData2.shape)
    # Cache everything so later runs take the fast path above.
    np.savez_compressed(
        "myData.npz",
        myData=myData,
        myData2=myData2,
        myY=myY,
        myActors=myActors,
        myDatasets=myDatasets,
    )

In [None]:

def visualize_topological_channels(data, labels, actors, num_samples=3):
    """Plot and save a grid of topological feature channels for each emotion.

    For every emotion class, up to ``num_samples`` randomly chosen samples are
    drawn (one row per sample, one column per channel). Each figure is written
    to ``topology_<emotion>.png`` and then shown. Emotions without any sample
    are skipped instead of producing an empty figure.

    Args:
        data: array indexed as ``data[sample][:, :, channel]`` with six
            channels on the last axis.
        labels: per-sample class scores (e.g. one-hot); the class id is the
            ``argmax`` over axis 1.
        actors: unused; kept so existing calls keep working.
        num_samples: maximum number of rows (samples) per emotion.
    """
    emotions = ['neutral', 'happy', 'sad', 'angry', 'fearful', 'disgust']
    channel_names = [
        'Euclidean (H0)',
        'Euclidean (H1)',
        'Temporal (H0)',
        'Temporal (H1)',
        'Wasserstein (H0)',
        'Wasserstein (H1)'
    ]
    n_channels = len(channel_names)

    label_ids = np.argmax(labels, axis=1)

    for emotion_id, emotion in enumerate(emotions):
        emotion_samples = data[label_ids == emotion_id]

        # Size the grid by the rows actually drawn so no empty axes are left over.
        n_rows = min(num_samples, len(emotion_samples))
        if n_rows <= 0:
            print(f"No samples available for {emotion}; skipping.\n")
            continue

        indices = np.random.choice(len(emotion_samples), n_rows, replace=False)

        # squeeze=False keeps `axes` two-dimensional even for a single row.
        fig, axes = plt.subplots(
            n_rows, n_channels, figsize=(18, 3 * n_rows), squeeze=False
        )
        fig.suptitle(f'Topological Features: {emotion.upper()}', fontsize=16, y=1.00)

        for row, idx in enumerate(indices):
            sample = emotion_samples[idx]
            for col in range(n_channels):
                ax = axes[row, col]
                ax.imshow(sample[:, :, col], cmap='viridis', aspect='auto')
                if row == 0:
                    ax.set_title(channel_names[col], fontsize=10)
                ax.axis('off')

        plt.tight_layout()
        plt.savefig(f'topology_{emotion}.png', dpi=150, bbox_inches='tight')
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
        print(f"Saved: topology_{emotion}.png\n")

# Plot only when the topological feature array was created by an earlier cell.
if 'myData2' in locals():
    print("Generating topological feature visualizations...")
    visualize_topological_channels(
        data=myData2,
        labels=myY,
        actors=myActors,
        num_samples=2,
    )

In [None]:
class MelSpectrogramCNN(nn.Module):
    """CNN classifier for single-channel image-like inputs.

    ``features`` stacks seven 3x3 convolution stages (Conv -> ReLU -> BatchNorm)
    with alternating max / average 2x2 pooling, a dropout layer before the
    last pooling, and a final global average pool. ``classifier`` flattens
    the 32 resulting channels and applies a three-layer MLP that ends in
    ``num_classes`` logits.
    """

    def __init__(self, num_classes=6):
        super().__init__()

        # Stages 1-6 differ only in input channels and pooling type.
        layers = []
        in_channels = 1
        for pool_cls in [nn.MaxPool2d, nn.AvgPool2d] * 3:
            layers.extend([
                nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(32),
                pool_cls(2),
            ])
            in_channels = 32

        # Stage 7 adds dropout before pooling, followed by global average pooling.
        layers.extend([
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Dropout(0.2),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((1, 1)),
        ])
        self.features = nn.Sequential(*layers)

        # MLP head: two hidden blocks, then the output layer.
        head = [nn.Flatten()]
        for n_in, n_out in ((32, 64), (64, 64)):
            head.extend([nn.Linear(n_in, n_out), nn.ReLU(), nn.Dropout(0.2)])
        head.append(nn.Linear(64, num_classes))
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits for a batch ``x``."""
        return self.classifier(self.features(x))

class TopologicalCNN(nn.Module):
    """CNN classifier for six-channel image-like inputs.

    ``features`` has four 3x3 convolution stages (Conv -> ReLU -> BatchNorm)
    with max / average / max / max 2x2 pooling, dropout before the last
    pooling, and a final global average pool. ``classifier`` is a three-layer
    MLP; the activations just before its last linear layer can be returned
    as a 64-dimensional embedding.
    """

    def __init__(self, num_classes=6):
        super().__init__()

        # Stages 1-3 differ only in input channels and pooling type.
        layers = []
        in_channels = 6
        for pool_cls in (nn.MaxPool2d, nn.AvgPool2d, nn.MaxPool2d):
            layers.extend([
                nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(32),
                pool_cls(2),
            ])
            in_channels = 32

        # Stage 4 adds dropout before pooling, followed by global average pooling.
        layers.extend([
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Dropout(0.2),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((1, 1)),
        ])
        self.features = nn.Sequential(*layers)

        head = [nn.Flatten()]
        for n_in, n_out in ((32, 64), (64, 64)):
            head.extend([nn.Linear(n_in, n_out), nn.ReLU(), nn.Dropout(0.2)])
        head.append(nn.Linear(64, num_classes))
        self.classifier = nn.Sequential(*head)

    def forward(self, x, return_embedding=False):
        """Return logits, or the penultimate embedding if ``return_embedding`` is set."""
        pooled = self.features(x)
        # Everything except the final linear layer produces the embedding.
        embedding = self.classifier[:-1](pooled)
        logits = self.classifier[-1](embedding)
        if return_embedding:
            return embedding
        return logits


class CombinedFusionModel(nn.Module):
    """Two-branch classifier that fuses a 1-channel and a 6-channel input.

    ``mel_branch`` (seven convolution stages) and ``topo_branch`` (four
    convolution stages) each end in global average pooling and a linear layer
    producing a 64-dimensional embedding. The two embeddings are concatenated
    (128 values) and passed through an MLP ending in ``num_classes`` logits.
    """

    def __init__(self, num_classes=6):
        super().__init__()

        def conv_stage(in_channels, pool_cls, with_dropout=False):
            # Conv -> ReLU -> BatchNorm -> [Dropout] -> 2x2 pooling.
            stage = [
                nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(32),
            ]
            if with_dropout:
                stage.append(nn.Dropout(0.2))
            stage.append(pool_cls(2))
            return stage

        def embedding_tail():
            # Global pooling followed by the 32 -> 64 embedding projection.
            return [
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(),
                nn.Linear(32, 64),
                nn.ReLU(),
                nn.Dropout(0.2),
            ]

        # Branch for the single-channel input: seven stages, alternating pooling.
        mel_layers = conv_stage(1, nn.MaxPool2d)
        for pool_cls in (nn.AvgPool2d, nn.MaxPool2d, nn.AvgPool2d, nn.MaxPool2d, nn.AvgPool2d):
            mel_layers += conv_stage(32, pool_cls)
        mel_layers += conv_stage(32, nn.MaxPool2d, with_dropout=True)
        mel_layers += embedding_tail()
        self.mel_branch = nn.Sequential(*mel_layers)

        # Branch for the six-channel input: four stages.
        topo_layers = conv_stage(6, nn.MaxPool2d)
        topo_layers += conv_stage(32, nn.AvgPool2d)
        topo_layers += conv_stage(32, nn.MaxPool2d)
        topo_layers += conv_stage(32, nn.MaxPool2d, with_dropout=True)
        topo_layers += embedding_tail()
        self.topo_branch = nn.Sequential(*topo_layers)

        # Fusion head on the concatenated 128-dimensional embedding.
        head = []
        for n_in in (128, 64, 64):
            head += [nn.Linear(n_in, 64), nn.ReLU(), nn.Dropout(0.2)]
        head.append(nn.Linear(64, num_classes))
        self.classifier = nn.Sequential(*head)

    def forward(self, mel_input, topo_input):
        """Return class logits for a pair of input batches."""
        fused = torch.cat(
            [self.mel_branch(mel_input), self.topo_branch(topo_input)],
            dim=1,
        )
        return self.classifier(fused)

# Confirmation message printed once the cell defining the three model classes has run.
print("✅ Model architectures defined")

In [None]:
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

def stratified_group_split_3way(
    y,
    groups,
    myDatasets,
    val_size=0.2,
    test_size=0.2,
    random_state=42,
):
    """Split sample indices into train/val/test without sharing groups.

    Every group (e.g. a speaker) is assigned to exactly one split. Groups are
    labelled with their most frequent class and dataset. One group per dataset
    is reserved for TEST and a different one for TRAIN, so both splits contain
    every dataset. The remaining TEST and VAL slots are filled by stratified
    sampling on the group-level class label.

    Args:
        y: per-sample class labels.
        groups: per-sample group identifiers.
        myDatasets: per-sample dataset names.
        val_size: fraction of groups targeted for validation.
        test_size: fraction of groups targeted for test (at least one group
            per dataset is always used).
        random_state: integer seed for the forced picks and the stratified splits.

    Returns:
        ``(train_idx, val_idx, test_idx)`` arrays of sample positions.

    Raises:
        ValueError: if a dataset has fewer than two groups, or the requested
            sizes cannot be satisfied.
        RuntimeError: if a final sanity check (disjointness, dataset
            coverage) fails.
    """
    y = np.asarray(y)
    groups = np.asarray(groups)
    myDatasets = np.asarray(myDatasets)
    rng = np.random.default_rng(random_state)
    df = pd.DataFrame({"y": y, "group": groups, "dataset": myDatasets})

    # One row per group, labelled with its most frequent class and dataset.
    group_df = (
        df.groupby("group")
          .agg(
              y_group=("y", lambda s: s.value_counts().index[0]),
              dataset_group=("dataset", lambda s: s.value_counts().index[0]),
          )
          .reset_index()
    )

    datasets = group_df["dataset_group"].unique()
    n_groups = len(group_df)

    ds_counts = group_df["dataset_group"].value_counts()
    bad = ds_counts[ds_counts < 2].index.to_list()
    if bad:
        raise ValueError(f"Cannot guarantee dataset coverage in both train and test; datasets with <2 groups: {bad}")

    # Target group counts
    n_test = max(int(np.round(test_size * n_groups)), len(datasets))
    n_val = int(np.round(val_size * n_groups))

    if n_test + n_val >= n_groups:
        raise ValueError("Not enough groups to allocate train/val/test with these sizes.")

    # After taking test+val, train must still have at least one per dataset
    if (n_groups - n_test - n_val) < len(datasets):
        raise ValueError(
            "Not enough remaining groups to guarantee at least one group per dataset in TRAIN. "
            "Reduce val_size and/or test_size."
        )

    # -------------------------
    # Reserve one TEST group and one TRAIN group per dataset
    # -------------------------
    forced_test = []
    forced_train = []

    for ds in datasets:
        ds_groups = group_df[group_df["dataset_group"] == ds]

        # Draw the test group's class proportionally to the dataset's class mix.
        y_counts = ds_groups["y_group"].value_counts()
        labels = y_counts.index.to_numpy()
        probs = (y_counts / y_counts.sum()).to_numpy()
        chosen_label = rng.choice(labels, p=probs)

        candidates = ds_groups.loc[ds_groups["y_group"] == chosen_label, "group"].to_numpy()
        g_test = rng.choice(candidates)
        forced_test.append(g_test)

        # choose a different group for train
        remaining = ds_groups.loc[ds_groups["group"] != g_test, "group"].to_numpy()
        g_train = rng.choice(remaining)
        forced_train.append(g_train)

    # Order-preserving de-duplication on plain lists (pd.unique on a list is deprecated).
    forced_test = list(dict.fromkeys(forced_test))
    forced_train = list(dict.fromkeys(forced_train))

    # Defensive: if a group was reserved for both splits, swap the train pick
    # for another group of the same dataset.
    for g in set(forced_train) & set(forced_test):
        ds = group_df.loc[group_df["group"] == g, "dataset_group"].iloc[0]
        ds_groups = group_df.loc[group_df["dataset_group"] == ds, "group"].to_numpy()
        used = set(forced_test) | set(forced_train)
        repl = next((x for x in ds_groups if x not in used), None)
        if repl is None:
            repl = next((x for x in ds_groups if x not in forced_test), None)
        if repl is None:
            raise ValueError(
                f"Cannot reserve a TRAIN group for dataset {ds!r} that is not already in TEST."
            )
        forced_train = [repl if x == g else x for x in forced_train]

    forced_test_set = set(forced_test)
    forced_train_set = set(forced_train)

    # -------------------------
    # Fill remaining TEST slots (IMPORTANT: never sample forced_train into test)
    # -------------------------
    remaining_test_slots = n_test - len(forced_test_set)
    eligible_for_test = group_df[
        ~group_df["group"].isin(list(forced_test_set | forced_train_set))
    ].reset_index(drop=True)

    extra_test = []
    if remaining_test_slots > 0:
        rem_n = len(eligible_for_test)
        if remaining_test_slots > rem_n:
            # Filling test would require stealing the reserved train-coverage groups;
            # fail loudly instead.
            raise ValueError(
                "Not enough eligible groups to fill test without stealing reserved train-coverage groups. "
                "Reduce test_size."
            )

        if remaining_test_slots == rem_n:
            extra_test = eligible_for_test["group"].to_list()
        else:
            # Pass the absolute count: a float fraction is re-rounded by
            # scikit-learn and may not come back to exactly this number.
            sss = StratifiedShuffleSplit(
                n_splits=1,
                test_size=remaining_test_slots,
                random_state=random_state,
            )
            _, extra_idx = next(sss.split(np.arange(rem_n), eligible_for_test["y_group"]))
            extra_test = eligible_for_test.loc[extra_idx, "group"].to_list()

    test_set = forced_test_set | set(extra_test)

    # -------------------------
    # Build remaining pool for TRAIN/VAL
    #   - Keep forced_train always in TRAIN (never in VAL)
    # -------------------------
    remaining_after_test = group_df[~group_df["group"].isin(list(test_set))].reset_index(drop=True)
    remaining_groups = set(remaining_after_test["group"].to_list())

    split_pool = remaining_after_test[
        ~remaining_after_test["group"].isin(list(forced_train_set))
    ].reset_index(drop=True)

    n_val = min(n_val, len(split_pool))
    if n_val <= 0:
        val_set = set()
    elif n_val == len(split_pool):
        # The whole pool goes to VAL; StratifiedShuffleSplit rejects a test
        # share of 100 %, so handle this case directly.
        val_set = set(split_pool["group"].to_list())
    else:
        sss_inner = StratifiedShuffleSplit(
            n_splits=1,
            test_size=n_val,
            random_state=random_state + 1,
        )
        _, val_idx_g = next(sss_inner.split(np.arange(len(split_pool)), split_pool["y_group"]))
        val_set = set(split_pool.loc[val_idx_g, "group"].to_list())

    # Everything not in TEST or VAL is TRAIN (this always includes forced_train).
    train_set = remaining_groups - val_set

    # -------------------------
    # Sanity checks
    # -------------------------
    if train_set & test_set:
        raise RuntimeError("Train/Test overlap detected (groups).")
    if val_set & test_set:
        raise RuntimeError("Val/Test overlap detected (groups).")
    if train_set & val_set:
        raise RuntimeError("Train/Val overlap detected (groups).")

    expected_ds = set(group_df["dataset_group"].unique())
    got_test = set(group_df.loc[group_df["group"].isin(list(test_set)), "dataset_group"].unique())
    got_train = set(group_df.loc[group_df["group"].isin(list(train_set)), "dataset_group"].unique())

    if expected_ds - got_test:
        raise RuntimeError(f"Dataset coverage constraint failed in TEST: missing {expected_ds - got_test}")
    if expected_ds - got_train:
        raise RuntimeError(f"Dataset coverage constraint failed in TRAIN: missing {expected_ds - got_train}")

    # -------------------------
    # Map groups -> sample indices
    # -------------------------
    train_idx = np.where(df["group"].isin(list(train_set)))[0]
    val_idx = np.where(df["group"].isin(list(val_set)))[0]
    test_idx = np.where(df["group"].isin(list(test_set)))[0]

    return train_idx, val_idx, test_idx




# Split by speaker into train / val / test and convert each split to tensors.
if 'myY' in locals():
    groups = myActors
    train_idx, val_idx, test_idx = stratified_group_split_3way(
        y=np.argmax(myY, axis=1),
        groups=groups,
        myDatasets=myDatasets,
        val_size=0.2,
        test_size=0.2,
    )

    print(f"Train speakers: {len(np.unique(myActors[train_idx]))}")
    print(f"Train datasets: {len(np.unique(myDatasets[train_idx]))}")
    print(f"Val speakers: {len(np.unique(myActors[val_idx]))}")
    print(f"Test speakers: {len(np.unique(myActors[test_idx]))}")
    print(f"Test datasets: {len(np.unique(myDatasets[test_idx]))}")

    print(f"\nTrain samples: {len(train_idx)}")
    print(f"Val samples: {len(val_idx)}")
    print(f"Test samples: {len(test_idx)}")

    # Slice every array with the same three index sets.
    X_train, X_val, X_test = (myData[part] for part in (train_idx, val_idx, test_idx))
    X_train2, X_val2, X_test2 = (myData2[part] for part in (train_idx, val_idx, test_idx))
    y_train, y_val, y_test = (myY[part] for part in (train_idx, val_idx, test_idx))

    # Disabled alternative: hold out the 'radvess' dataset as the test set and
    # group-split the remaining data into train / val.
    # myRavdessMask = [x == 'radvess' for x in myDatasets]
    # myNoRavdessMask = [x != 'radvess' for x in myDatasets]
    # sgss = GroupShuffleSplit(
    #     n_splits=1,
    #     test_size=0.2,
    #     random_state=42
    # )
    # test_idx = myRavdessMask
    # train_idx, val_idx = next(sgss.split(myData[myNoRavdessMask], groups=groups[myNoRavdessMask]))
    # X_train, X_val, X_test = myData[myNoRavdessMask][train_idx], myData[myNoRavdessMask][val_idx], myData[myRavdessMask]
    # X_train2, X_val2, X_test2 = myData2[myNoRavdessMask][train_idx], myData2[myNoRavdessMask][val_idx], myData2[myRavdessMask]
    # y_train, y_val, y_test = myY[myNoRavdessMask][train_idx], myY[myNoRavdessMask][val_idx], myY[myRavdessMask]

    # Channels-last arrays -> channels-first float tensors.
    X_train_tensor, X_val_tensor, X_test_tensor = (
        torch.tensor(np.transpose(arr, (0, 3, 1, 2)), dtype=torch.float32)
        for arr in (X_train, X_val, X_test)
    )
    X_train2_tensor, X_val2_tensor, X_test2_tensor = (
        torch.tensor(np.transpose(arr, (0, 3, 1, 2)), dtype=torch.float32)
        for arr in (X_train2, X_val2, X_test2)
    )
    # One-hot rows -> integer class ids.
    y_train_tensor, y_val_tensor, y_test_tensor = (
        torch.tensor(np.argmax(onehot, axis=1), dtype=torch.long)
        for onehot in (y_train, y_val, y_test)
    )

    print(f"\nTensor shapes:")
    print(f"Train mel: {X_train_tensor.shape}")
    print(f"Train topo: {X_train2_tensor.shape}")
    print(f"Test mel: {X_test_tensor.shape}")
    print(f"Test topo: {X_test2_tensor.shape}")

In [None]:
# Shape of the mel data without the 'radvess' dataset. The mask is built inline
# because `myNoRavdessMask` is only assigned in commented-out code.
myData[myDatasets != 'radvess'].shape

In [None]:
# Distinct speaker groups outside the 'radvess' dataset. The mask is built inline
# because `myNoRavdessMask` is only assigned in commented-out code.
np.unique(groups[myDatasets != 'radvess'])

In [None]:
# List every distinct actor/speaker identifier in the training split.
# `train_idx` from the active split indexes the full arrays, so no dataset mask is applied.
for i in (np.unique(myActors[train_idx])):
    print(i)

In [None]:
# List every distinct actor/speaker identifier that landed in the test split.
for i in (np.unique(myActors[test_idx])):
    print(i)

In [None]:
# List every distinct actor/speaker identifier that landed in the validation split.
for i in (np.unique(myActors[val_idx])):
    print(i)

In [None]:
def train_model(model, train_loader, val_loader, num_epochs=30, lr=1e-3,
                model_name="model", device=device, num_classes=6):
    """Train ``model`` with Adam and cross-entropy, validating after every epoch.

    Each batch is a tuple whose last element holds the integer class labels;
    all preceding elements are passed to the model as positional inputs (one
    for the single-input models, two for the combined model). Whenever the
    validation accuracy improves, the weights are saved to
    ``best_<model_name>.pth``.

    Args:
        model: network already placed on ``device``.
        train_loader: iterable of training batches.
        val_loader: iterable of validation batches.
        num_epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        model_name: suffix used for the checkpoint file name.
        device: device the batches and metrics are moved to.
        num_classes: number of classes for the AUROC / top-3 metrics.

    Returns:
        ``(model, history)``. ``model`` holds the last-epoch weights (not
        necessarily the best ones); ``history`` maps ``'train_loss'``,
        ``'val_loss'``, ``'val_auc'`` and ``'val_acc'`` to per-epoch lists.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    auroc = MulticlassAUROC(num_classes=num_classes, average='macro').to(device)
    top3acc = MulticlassAccuracy(num_classes=num_classes, top_k=3).to(device)

    # Start below any reachable accuracy so the first epoch always writes a checkpoint.
    best_acc = float("-inf")
    history = {'train_loss': [], 'val_loss': [], 'val_auc': [], 'val_acc': []}

    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0.0
        for batch in train_loader:
            *inputs, y_batch = batch
            inputs = [t.to(device) for t in inputs]
            y_batch = y_batch.to(device)

            optimizer.zero_grad()
            outputs = model(*inputs)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Validation
        model.eval()
        val_preds, val_labels = [], []
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                *inputs, y_val = batch
                inputs = [t.to(device) for t in inputs]
                y_val = y_val.to(device)
                outputs = model(*inputs)

                val_loss += criterion(outputs, y_val).item()
                val_preds.append(torch.softmax(outputs, dim=1))
                val_labels.append(y_val)

        val_preds = torch.cat(val_preds)
        val_labels = torch.cat(val_labels)

        val_auc = auroc(val_preds, val_labels).item()
        val_top3 = top3acc(val_preds, val_labels).item()
        # Calling a metric also accumulates its internal state; clear it so
        # nothing carries over (or piles up in memory) between epochs.
        auroc.reset()
        top3acc.reset()

        y_pred = torch.argmax(val_preds, dim=1)
        accuracy = (y_pred == val_labels).float().mean().item()

        # Average the summed batch losses; max(..., 1) guards against an empty loader.
        avg_train_loss = train_loss / max(len(train_loader), 1)
        avg_val_loss = val_loss / max(len(val_loader), 1)

        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['val_auc'].append(val_auc)
        history['val_acc'].append(accuracy)

        print(f"Epoch {epoch+1}/{num_epochs} - "
              f"train_loss: {avg_train_loss:.4f} - "
              f"val_loss: {avg_val_loss:.4f} - "
              f"val_auc: {val_auc:.4f} - "
              f"val_top3: {val_top3:.4f} - "
              f"val_acc: {accuracy:.4f}")

        # Save best model
        if accuracy > best_acc:
            best_acc = accuracy
            torch.save(model.state_dict(), f"best_{model_name}.pth")
            print(f"✅ Saved best model: {accuracy:.4f}")

    return model, history


from collections import defaultdict

def evaluate_model(model, test_loader, model_path, device=device):
    """Load a checkpoint and evaluate the model on the test set.

    Each batch must be ``(X, y, dataset_names)`` for single-input models or
    ``(X, X2, y, dataset_names)`` for the two-input model. Prints overall
    AUC / top-3 accuracy / accuracy, accuracy per dataset and the
    classification report, and shows a confusion-matrix heat map.

    Args:
        model: network already placed on ``device``.
        test_loader: iterable of test batches (see above).
        model_path: path of the ``state_dict`` checkpoint to load.
        device: device used for inference and metrics.

    Returns:
        dict with keys ``'auc'``, ``'top3_acc'``, ``'acc'``,
        ``'acc_per_dataset'``, ``'predictions'``, ``'labels'`` and ``'report'``.

    Raises:
        ValueError: if a batch has an unexpected length or the loader is empty.
    """
    # map_location lets a checkpoint saved on one device be loaded on another.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    class_labels = ['neutral', 'happy', 'sad', 'angry', 'fearful', 'disgust']
    num_classes = len(class_labels)
    class_ids = list(range(num_classes))

    auroc = MulticlassAUROC(num_classes=num_classes, average='macro').to(device)
    top3acc = MulticlassAccuracy(num_classes=num_classes, top_k=3).to(device)

    all_preds, all_labels, all_probs = [], [], []

    # Per-dataset tracking
    dataset_correct = defaultdict(int)
    dataset_total = defaultdict(int)

    with torch.no_grad():
        for batch in test_loader:
            if len(batch) not in (3, 4):
                raise ValueError("Unexpected batch format")

            # Last two elements are labels and dataset names; the rest are model inputs.
            *inputs, y_batch, dataset_batch = batch
            outputs = model(*[t.to(device) for t in inputs])

            probs = torch.softmax(outputs, dim=1)
            preds = torch.argmax(probs, dim=1).cpu()
            y_cpu = y_batch.cpu()

            # Store global results
            all_preds.extend(preds.numpy())
            all_labels.extend(y_cpu.numpy())
            all_probs.append(probs)

            # Per-dataset accuracy
            for pred, label, ds in zip(preds, y_cpu, dataset_batch):
                dataset_total[ds] += 1
                dataset_correct[ds] += int(pred == label)

    if not all_probs:
        raise ValueError("test_loader produced no batches")

    all_probs = torch.cat(all_probs)
    all_labels_tensor = torch.tensor(np.asarray(all_labels), device=device)

    test_auc = auroc(all_probs, all_labels_tensor).item()
    test_top3 = top3acc(all_probs, all_labels_tensor).item()
    test_acc = (torch.argmax(all_probs, dim=1) == all_labels_tensor).float().mean().item()

    print(f"\nTest Results:")
    print(f"AUC: {test_auc:.4f}")
    print(f"Top-3 Accuracy: {test_top3:.4f}")
    print(f"Accuracy: {test_acc:.4f}")

    # Per-dataset accuracy
    print("\nAccuracy per dataset:")
    dataset_acc = {}
    for ds in dataset_total:
        acc = dataset_correct[ds] / dataset_total[ds]
        dataset_acc[ds] = acc
        print(f"{ds}: {acc:.4f}")

    # Classification report. Passing `labels` keeps all six classes in the
    # report even when one of them is missing from the test split.
    report = classification_report(
        all_labels,
        all_preds,
        labels=class_ids,
        target_names=class_labels,
        output_dict=True,
    )
    print("\nClassification Report:")
    print(report)

    # Confusion matrix; `labels` fixes it to six rows and columns.
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    cm_df = pd.DataFrame(cm, index=class_labels, columns=class_labels)

    plt.figure(figsize=(8, 6))
    sns.heatmap(cm_df, annot=True, fmt='d', cmap='Blues', cbar=True)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title('Confusion Matrix')
    plt.tight_layout()
    plt.show()

    return {
        'auc': test_auc,
        'top3_acc': test_top3,
        'acc': test_acc,
        'acc_per_dataset': dataset_acc,
        'predictions': all_preds,
        'labels': all_labels,
        'report': report
    }

def get_macro_f1(res):
    """Return the macro-averaged F1 score of an evaluation result as a float.

    Reads ``res["report"]["macro avg"]["f1-score"]``.
    """
    macro_row = res["report"]["macro avg"]
    return float(macro_row["f1-score"])

def summarize_with_f1(all_results):
    """Aggregate per-run metrics into ``(mean, sample std)`` pairs.

    Args:
        all_results: list of result dicts with ``"auc"``, ``"acc"`` and
            ``"top3_acc"`` entries plus a ``"report"`` readable by
            ``get_macro_f1``.

    Returns:
        dict with keys ``"auc"``, ``"acc"``, ``"top3_acc"`` and ``"macro_f1"``;
        the standard deviation uses ``ddof=1``.
    """
    def _mean_std(values):
        arr = np.array(values, dtype=float)
        return (arr.mean(), arr.std(ddof=1))

    return {
        "auc": _mean_std([run["auc"] for run in all_results]),
        "acc": _mean_std([run["acc"] for run in all_results]),
        "top3_acc": _mean_std([run["top3_acc"] for run in all_results]),
        "macro_f1": _mean_std([get_macro_f1(run) for run in all_results]),
    }

def summarize_acc_per_dataset(all_results):
    """Return ``{dataset: (mean accuracy, sample std)}`` across runs.

    The dataset names are taken from the first result; every other result is
    expected to contain the same keys under ``"acc_per_dataset"``.
    """
    dataset_names = list(all_results[0]["acc_per_dataset"].keys())
    per_dataset = {
        name: np.array([run["acc_per_dataset"][name] for run in all_results], dtype=float)
        for name in dataset_names
    }
    return {name: (accs.mean(), accs.std(ddof=1)) for name, accs in per_dataset.items()}

from torch.utils.data import Dataset

class TensorDatasetWithNames(Dataset):
    """Dataset yielding ``(X[idx], y[idx], names[idx])`` triples.

    The length is that of ``y``.
    """

    def __init__(self, X, y, names):
        self.X, self.y, self.names = X, y, names

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        features = self.X[idx]
        target = self.y[idx]
        name = self.names[idx]
        return features, target, name


class TensorDatasetWithNamesComb(Dataset):
    """Dataset yielding ``(X[idx], X2[idx], y[idx], names[idx])`` tuples.

    The length is that of ``y``.
    """

    def __init__(self, X, X2, y, names):
        self.X, self.X2 = X, X2
        self.y, self.names = y, names

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        first = self.X[idx]
        second = self.X2[idx]
        target = self.y[idx]
        name = self.names[idx]
        return first, second, target, name

def set_seed(seed: int):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU and all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)


# Confirmation message printed once the training / evaluation helpers in this cell are defined.
print("✅ Training functions defined")

In [None]:
# Inspect the dataset name of every test sample (the array later handed to the test Dataset wrapper).
myDatasets[test_idx]

In [None]:
# Inspect the integer class labels of the test split.
y_test_tensor

In [None]:


def run_mel_experiment(seed: int):
    """Train and evaluate ``MelSpectrogramCNN`` once for the given seed.

    Seeds the global RNGs, builds train/val/test loaders from the notebook's
    mel tensors (the test set also carries ``myDatasets[test_idx]`` as names),
    trains a freshly constructed model for 60 epochs, prints and returns
    whatever ``evaluate_model`` returns.

    NOTE(review): ``evaluate_model`` is handed ``best_model_mel_seed{seed}.pth``;
    presumably ``train_model`` writes ``best_<model_name>.pth`` — confirm.
    """
    set_seed(seed)

    batch = 256
    worker_opts = {"num_workers": 4, "pin_memory": True}

    train_set = TensorDataset(X_train_tensor, y_train_tensor)
    val_set = TensorDataset(X_val_tensor, y_val_tensor)
    test_set = TensorDatasetWithNames(X_test_tensor, y_test_tensor, myDatasets[test_idx])

    # Dedicated generator so the training shuffle order is tied to the seed.
    shuffle_rng = torch.Generator()
    shuffle_rng.manual_seed(seed)

    train_loader = DataLoader(
        train_set, batch_size=batch, shuffle=True, generator=shuffle_rng, **worker_opts
    )
    val_loader = DataLoader(val_set, batch_size=batch, shuffle=False, **worker_opts)
    test_loader = DataLoader(test_set, batch_size=batch, shuffle=False)

    # A new model per seed so runs do not share weights.
    net = MelSpectrogramCNN().to(device)
    net, _history = train_model(
        net,
        train_loader,
        val_loader,
        num_epochs=60,
        model_name=f"model_mel_seed{seed}",
        device=device,
    )

    results = evaluate_model(net, test_loader, f"best_model_mel_seed{seed}.pth", device)

    print(results)

    return results



# Run the mel experiment once per seed; results are appended in seed order and
# consumed by the summary cell that follows.
seeds = [0, 1, 2, 3, 4]
all_results = []

for s in seeds:
    # Banner separating the (long) training logs of consecutive seeds.
    print(f"\n{'='*20} Seed {s} {'='*20}")
    all_results.append(run_mel_experiment(s))

In [None]:
# Mel-spectrogram runs: each entry is (mean, std with ddof=1) across seeds.
summary = summarize_with_f1(all_results)
# Same aggregation applied to the per-dataset accuracies.
ds_summary = summarize_acc_per_dataset(all_results)
print(f"AUC: {summary['auc'][0]:.4f} ± {summary['auc'][1]:.4f}")
print(f"Macro-F1: {summary['macro_f1'][0]:.4f} ± {summary['macro_f1'][1]:.4f}")
print(f"Acc: {summary['acc'][0]:.4f} ± {summary['acc'][1]:.4f}")
print(f"Top-3: {summary['top3_acc'][0]:.4f} ± {summary['top3_acc'][1]:.4f}")
# One line per dataset key: mean ± std of accuracy.
for ds, (mu, sd) in ds_summary.items():
    print(f"{ds}: {mu:.4f} ± {sd:.4f}")

In [None]:
def run_topo_experiment(seed: int):
    """Train and evaluate ``TopologicalCNN`` once for the given seed.

    Same pipeline as the mel experiment but on the topological tensors
    (``X_*2_tensor``) and with ``lr=1e-4`` passed to ``train_model``.
    Returns whatever ``evaluate_model`` returns.

    NOTE(review): ``evaluate_model`` is handed ``best_model_topo_seed{seed}.pth``;
    presumably ``train_model`` writes ``best_<model_name>.pth`` — confirm.
    """
    set_seed(seed)

    batch = 256
    worker_opts = {"num_workers": 4, "pin_memory": True}

    train_set = TensorDataset(X_train2_tensor, y_train_tensor)
    val_set = TensorDataset(X_val2_tensor, y_val_tensor)
    test_set = TensorDatasetWithNames(X_test2_tensor, y_test_tensor, myDatasets[test_idx])

    # Dedicated generator so the training shuffle order is tied to the seed.
    shuffle_rng = torch.Generator()
    shuffle_rng.manual_seed(seed)

    train_loader = DataLoader(
        train_set, batch_size=batch, shuffle=True, generator=shuffle_rng, **worker_opts
    )
    val_loader = DataLoader(val_set, batch_size=batch, shuffle=False, **worker_opts)
    test_loader = DataLoader(test_set, batch_size=batch, shuffle=False)

    # A new model per seed so runs do not share weights.
    net = TopologicalCNN().to(device)
    net, _history = train_model(
        net,
        train_loader,
        val_loader,
        num_epochs=60,
        lr=1e-4,
        model_name=f"model_topo_seed{seed}",
        device=device,
    )

    return evaluate_model(net, test_loader, f"best_model_topo_seed{seed}.pth", device)



# Run the topological experiment once per seed; results are appended in seed
# order and consumed by the summary cell that follows.
seeds = [0, 1, 2, 3, 4]
all_results_topo = []

for s in seeds:
    # Banner separating the training logs of consecutive seeds.
    print(f"\n{'='*20} TOPO Seed {s} {'='*20}")
    all_results_topo.append(run_topo_experiment(s))

In [None]:
# Topological runs: each entry is (mean, std with ddof=1) across seeds.
# NOTE: `summary` / `ds_summary` are rebound here, replacing the values from
# the previous summary cell.
summary = summarize_with_f1(all_results_topo)
ds_summary = summarize_acc_per_dataset(all_results_topo)
print(f"AUC: {summary['auc'][0]:.4f} ± {summary['auc'][1]:.4f}")
print(f"Macro-F1: {summary['macro_f1'][0]:.4f} ± {summary['macro_f1'][1]:.4f}")
print(f"Acc: {summary['acc'][0]:.4f} ± {summary['acc'][1]:.4f}")
print(f"Top-3: {summary['top3_acc'][0]:.4f} ± {summary['top3_acc'][1]:.4f}")
# One line per dataset key: mean ± std of accuracy.
for ds, (mu, sd) in ds_summary.items():
    print(f"{ds}: {mu:.4f} ± {sd:.4f}")

In [None]:
def run_fusion_experiment(seed: int):
    """Train and evaluate ``CombinedFusionModel`` once for the given seed.

    Each training/validation sample is ``(mel, topo, label)``; test samples
    additionally carry ``myDatasets[test_idx]`` as names. Returns whatever
    ``evaluate_model`` returns.

    NOTE(review): ``evaluate_model`` is handed
    ``best_model_combined_seed{seed}.pth``; make sure ``train_model`` saves
    under that name (presumably ``best_<model_name>.pth``).
    """
    set_seed(seed)

    batch = 256
    worker_opts = {"num_workers": 4, "pin_memory": True}

    train_set = TensorDataset(X_train_tensor, X_train2_tensor, y_train_tensor)
    val_set = TensorDataset(X_val_tensor, X_val2_tensor, y_val_tensor)
    test_set = TensorDatasetWithNamesComb(
        X_test_tensor, X_test2_tensor, y_test_tensor, myDatasets[test_idx]
    )

    # Dedicated generator so the training shuffle order is tied to the seed.
    shuffle_rng = torch.Generator()
    shuffle_rng.manual_seed(seed)

    train_loader = DataLoader(
        train_set, batch_size=batch, shuffle=True, generator=shuffle_rng, **worker_opts
    )
    val_loader = DataLoader(val_set, batch_size=batch, shuffle=False, **worker_opts)
    test_loader = DataLoader(test_set, batch_size=batch, shuffle=False)

    # A new model per seed so runs do not share weights.
    net = CombinedFusionModel().to(device)
    net, _history = train_model(
        net,
        train_loader,
        val_loader,
        num_epochs=60,
        model_name=f"model_combined_seed{seed}",  # seed-specific name avoids overwriting
        device=device,
    )

    return evaluate_model(net, test_loader, f"best_model_combined_seed{seed}.pth", device)


# ---- Run the fusion experiment once per seed; the next cell summarizes ----
seeds = [0, 1, 2, 3, 4]   # N=5
all_results_comb = []

for s in seeds:
    # Banner separating the training logs of consecutive seeds.
    print(f"\n{'='*20} FUSION Seed {s} {'='*20}")
    all_results_comb.append(run_fusion_experiment(s))

In [None]:
# Fusion runs: each entry is (mean, std with ddof=1) across seeds.
# NOTE: `summary` / `ds_summary` are rebound here, replacing the values from
# the previous summary cell.
summary = summarize_with_f1(all_results_comb)
ds_summary = summarize_acc_per_dataset(all_results_comb)
print(f"AUC: {summary['auc'][0]:.4f} ± {summary['auc'][1]:.4f}")
print(f"Macro-F1: {summary['macro_f1'][0]:.4f} ± {summary['macro_f1'][1]:.4f}")
print(f"Acc: {summary['acc'][0]:.4f} ± {summary['acc'][1]:.4f}")
print(f"Top-3: {summary['top3_acc'][0]:.4f} ± {summary['top3_acc'][1]:.4f}")
# One line per dataset key: mean ± std of accuracy.
for ds, (mu, sd) in ds_summary.items():
    print(f"{ds}: {mu:.4f} ± {sd:.4f}")

In [None]:
# Interpretability: Integrated Gradients per Channel
from pathlib import Path

def compute_channel_importance_ig(model, test_loader, device, save_dir="attr_summary"):
    """
    Compute per-channel attribution using Integrated Gradients.

    Runs Captum ``IntegratedGradients`` with an all-zeros baseline on every
    batch of ``test_loader``, using each sample's own label as the target.
    ``|attribution|`` is averaged over samples to give one map per channel,
    and each map is averaged over its pixels to give one score per channel.

    Side effects: writes ``ig_channel_<c>.png`` (one min-max normalised heatmap
    per channel), ``channel_importance_ig.png`` and ``per_channel_scores_ig.npy``
    into ``save_dir`` (created if missing), prints the scores and shows the bar
    chart.

    Args:
        model: network whose forward takes one input batch and returns class scores.
        test_loader: iterable of batches whose first two items are ``(X, y)``;
            extra items (e.g. sample names) are ignored.
        device: device the batches are moved to before attribution.
        save_dir: output directory for figures and scores.

    Returns:
        NumPy array with one mean ``|IG|`` score per input channel.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    ig = IntegratedGradients(model)
    model.eval()

    acc_global = None
    count_global = 0

    for batch in test_loader:
        # Index instead of tuple-unpacking so loaders that also yield names work.
        X_batch = batch[0].to(device)
        y_batch = batch[1].to(device)
        baseline = torch.zeros_like(X_batch)

        # Compute IG
        attr = ig.attribute(X_batch, baselines=baseline, target=y_batch)
        batch_attr = attr.abs().sum(dim=0).detach().cpu()  # summed over the batch dimension

        acc_global = batch_attr if acc_global is None else acc_global + batch_attr
        count_global += X_batch.size(0)

    if acc_global is None:
        # Previously this fell through to `None / 1` and raised an opaque TypeError.
        raise ValueError("test_loader yielded no batches; nothing to attribute")

    mean_attr_global = acc_global / max(count_global, 1)
    C, H, W = mean_attr_global.shape

    # Per-channel importance scores
    per_channel_scores = mean_attr_global.reshape(C, -1).mean(dim=1).numpy()

    print("\nPer-Channel Importance (Integrated Gradients):")
    known_names = [
        'Euclidean (H0)',
        'Euclidean (H1)',
        'Temporal (H0)',
        'Temporal (H1)',
        'Wasserstein (H0)',
        'Wasserstein (H1)'
    ]
    # Exactly one label per channel, so inputs with C != 6 still plot.
    channel_names = [
        known_names[c] if c < len(known_names) else f"Channel {c}"
        for c in range(C)
    ]
    for c, (name, score) in enumerate(zip(channel_names, per_channel_scores)):
        print(f"Channel {c} ({name}): {score:.6f}")

    # Save heatmaps
    def save_heat(heat2d, title, path):
        # Min-max normalise each map; the epsilon guards a constant map.
        h = (heat2d - heat2d.min()) / (heat2d.max() - heat2d.min() + 1e-8)
        plt.figure(figsize=(4.5, 4))
        plt.imshow(h, cmap="jet")
        plt.title(title)
        plt.axis("off")
        plt.colorbar()
        plt.tight_layout()
        plt.savefig(path, dpi=300, bbox_inches='tight')
        plt.close()

    for c in range(C):
        heat = mean_attr_global[c].numpy()
        save_heat(heat, f"IG Attribution — {channel_names[c]}",
                  save_dir / f"ig_channel_{c}.png")

    # Bar chart
    plt.figure(figsize=(10, 5))
    bars = plt.bar(range(C), per_channel_scores)
    plt.xlabel("Channel", fontsize=12)
    plt.ylabel("Mean |IG Attribution|", fontsize=12)
    plt.title("Per-Channel Importance (Integrated Gradients)", fontsize=14)
    plt.xticks(range(C), [f"Ch{c}\n{name}" for c, name in enumerate(channel_names)],
               rotation=45, ha='right', fontsize=9)

    # Color bars by metric type: consecutive channel pairs share a colour.
    palette = ['#1f77b4', '#ff7f0e', '#2ca02c']
    for c, bar in enumerate(bars):
        bar.set_color(palette[(c // 2) % len(palette)])

    plt.tight_layout()
    plt.savefig(save_dir / "channel_importance_ig.png", dpi=300, bbox_inches='tight')
    plt.show()

    np.save(save_dir / "per_channel_scores_ig.npy", per_channel_scores)
    return per_channel_scores

# Plain (X, y) test loader without dataset names; it is reused by the
# occlusion and embedding cells further down.
test_dataset_topo  = TensorDataset(
    X_test2_tensor,
    y_test_tensor
)

test_loader_topo = DataLoader(
    test_dataset_topo,
    batch_size=256,
    shuffle=False
)

# Freshly constructed network; its weights are replaced by the checkpoint below.
model_topo = TopologicalCNN().to(device)
# Run IG analysis on topological
print("\n" + "="*60)
print("Interpretability Analysis: Integrated Gradients")
print("="*60)
# NOTE(review): presumably written by the topological training run for seed 0 — confirm.
model_topo.load_state_dict(torch.load("best_model_topo_seed0.pth"))
channel_scores_ig = compute_channel_importance_ig(model_topo, test_loader_topo, device)

In [None]:
# Occlusion analysis per channel
def compute_channel_importance_occlusion(model, test_loader, device,
                                         patch_size=8, stride=4,
                                         save_dir="attr_summary"):
    """
    Compute per-channel importance using Occlusion method.

    Takes the first (up to) five samples of the first batch of ``test_loader``
    and runs Captum ``Occlusion`` on each, with the sample's own label as the
    target, a ``(1, patch_size, patch_size)`` window (one channel at a time),
    ``(1, stride, stride)`` strides and a zero baseline. For every sample and
    channel the absolute attribution map is saved as
    ``occ_sample<i>_ch<c>.png`` in ``save_dir`` (created if missing).

    Args:
        model: network whose forward takes one input batch and returns class scores.
        test_loader: iterable of batches whose first two items are ``(X, y)``.
        device: device the samples are moved to.
        patch_size: spatial size of the occlusion window.
        stride: spatial stride of the occlusion window.
        save_dir: output directory for the heatmaps.

    Returns:
        List of 2-D NumPy heatmaps ordered sample-major, channel-minor.
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    occ = Occlusion(model)
    model.eval()

    # Sample a few examples for visualization
    first_batch = next(iter(test_loader))
    X_sample = first_batch[0][:5].to(device)  # First 5 samples
    y_sample = first_batch[1][:5].to(device)

    # Labels follow the H0/H1 naming used for these channels elsewhere in the
    # notebook (IG analysis, ablation configs); the earlier "(clean)/(noise)"
    # labels here did not match.
    known_names = [
        'Euclidean (H0)',
        'Euclidean (H1)',
        'Temporal (H0)',
        'Temporal (H1)',
        'Wasserstein (H0)',
        'Wasserstein (H1)'
    ]

    C = X_sample.shape[1]
    channel_names = [
        known_names[c] if c < len(known_names) else f"Channel {c}"
        for c in range(C)
    ]
    attr_per_channel = []

    for sample_idx in range(len(X_sample)):
        x = X_sample[sample_idx:sample_idx+1]
        target = int(y_sample[sample_idx])

        # One occlusion pass already yields attributions for every channel, so
        # it is computed once per sample rather than once per channel.
        attr = occ.attribute(
            x,
            target=target,
            sliding_window_shapes=(1, patch_size, patch_size),
            strides=(1, stride, stride),
            baselines=0
        )
        sample_heat = attr[0].abs().detach().cpu().numpy()  # indexed by channel below

        for c in range(C):
            # Extract channel c attribution
            heat = sample_heat[c]
            attr_per_channel.append(heat)

            # Save individual sample heatmap
            plt.figure(figsize=(5, 4))
            plt.imshow(heat, cmap='jet')
            plt.title(f"Sample {sample_idx} - {channel_names[c]}")
            plt.colorbar()
            plt.axis('off')
            plt.tight_layout()
            plt.savefig(save_dir / f"occ_sample{sample_idx}_ch{c}.png", dpi=150)
            plt.close()

    print("\n✅ Occlusion analysis complete. Heatmaps saved to", save_dir)
    return attr_per_channel


# Run Occlusion analysis
# The guard skips this cell silently when `model_topo` / `test_loader_topo`
# have not been created yet (at cell top level, locals() is the notebook namespace).
if 'model_topo' in locals() and 'test_loader_topo' in locals():
    print("\n" + "="*60)
    print("Interpretability Analysis: Occlusion")
    print("="*60)
    # Reload the seed-0 checkpoint so the analysis does not depend on whatever
    # weights `model_topo` currently holds.
    model_topo.load_state_dict(torch.load("best_model_topo_seed0.pth"))
    occlusion_attrs = compute_channel_importance_occlusion(
        model_topo, test_loader_topo, device, patch_size=8, stride=4
    )

In [None]:
# Extract and visualize embeddings
def visualize_embeddings(model, test_loader, device):
    """
    Collect model embeddings for the whole loader and plot a 2-D UMAP projection.

    The model is called as ``model(X, return_embedding=True)`` on every batch
    (the original notes the result as 64-dimensional). The projection is drawn
    as one scatter series per class index 0-5, saved to ``embedding_umap.png``
    and shown.

    Returns:
        ``(X_2d, y_emb)``: the 2-D UMAP coordinates and the matching labels.
    """
    import umap.umap_ as umap

    model.eval()
    emb_chunks = []
    label_chunks = []

    with torch.no_grad():
        for inputs, targets in test_loader:
            latent = model(inputs.to(device), return_embedding=True)  # (B, 64) per the original
            emb_chunks.append(latent.cpu().numpy())
            label_chunks.append(targets.numpy())

    X_emb = np.concatenate(emb_chunks, axis=0)
    y_emb = np.concatenate(label_chunks, axis=0)

    # UMAP projection
    print("Computing UMAP projection...")
    reducer = umap.UMAP(random_state=42, n_neighbors=15, min_dist=0.1)
    X_2d = reducer.fit_transform(X_emb)

    # (class name, colour) in class-index order
    class_styles = [
        ('neutral', '#1f77b4'),
        ('happy', '#ff7f0e'),
        ('sad', '#2ca02c'),
        ('angry', '#d62728'),
        ('fearful', '#9467bd'),
        ('disgust', '#8c564b'),
    ]

    plt.figure(figsize=(10, 8))
    for class_idx, (name, color) in enumerate(class_styles):
        members = (y_emb == class_idx)
        plt.scatter(X_2d[members, 0], X_2d[members, 1], s=20, alpha=0.6,
                    label=name, color=color)

    plt.title("UMAP Projection of Topological Embeddings", fontsize=14)
    plt.xlabel("UMAP 1", fontsize=12)
    plt.ylabel("UMAP 2", fontsize=12)
    plt.legend(loc='best', fontsize=10)
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig('embedding_umap.png', dpi=300, bbox_inches='tight')
    plt.show()

    print("\n✅ Embedding visualization saved: embedding_umap.png")
    return X_2d, y_emb


# Visualize topological embeddings
# The guard skips this cell silently when `model_topo` / `test_loader_topo`
# have not been created yet (at cell top level, locals() is the notebook namespace).
if 'model_topo' in locals() and 'test_loader_topo' in locals():
    print("\n" + "="*60)
    print("Embedding Visualization")
    print("="*60)
    # Reload the seed-0 checkpoint so the plot does not depend on whatever
    # weights `model_topo` currently holds.
    model_topo.load_state_dict(torch.load("best_model_topo_seed0.pth"))
    X_2d, y_2d = visualize_embeddings(model_topo, test_loader_topo, device)

In [None]:
class FlexibleTopologicalCNN(nn.Module):
    """Small CNN classifier with a configurable number of input channels.

    ``features`` is four Conv(3x3, 32 filters) -> ReLU -> BatchNorm stages,
    each followed by a 2x2 pooling layer (max, avg, max, max; the last stage
    has Dropout(0.2) before its pool), then adaptive average pooling to 1x1.
    ``classifier`` flattens the 32 features and applies 32->64->64->num_classes
    linear layers with ReLU and Dropout(0.2) in between.
    """

    def __init__(self, num_channels=6, num_classes=6):
        super().__init__()
        self.num_channels = num_channels

        def conv_stage(in_channels, *tail):
            # Conv -> ReLU -> BatchNorm, then the stage-specific trailing layers.
            return [
                nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(32),
                *tail,
            ]

        # Layer order (and therefore the Sequential indices) must stay fixed so
        # state_dict keys keep matching.
        feature_layers = []
        feature_layers += conv_stage(num_channels, nn.MaxPool2d(2))
        feature_layers += conv_stage(32, nn.AvgPool2d(2))
        feature_layers += conv_stage(32, nn.MaxPool2d(2))
        feature_layers += conv_stage(32, nn.Dropout(0.2), nn.MaxPool2d(2))
        feature_layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.features = nn.Sequential(*feature_layers)

        head_layers = [
            nn.Flatten(),
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the classifier output for input batch ``x``."""
        return self.classifier(self.features(x))


def train_ablation_model(X_train, X_val, X_test, y_train, y_val, y_test,
                         channel_indices, model_name, device,
                         num_epochs=30, lr=1e-3, verbose=False):
    """Train ``FlexibleTopologicalCNN`` on a channel subset and score it on the test split.

    The inputs are sliced to ``channel_indices`` along dimension 1, a new model
    is trained with Adam and cross-entropy, and the weights from the epoch with
    the highest validation accuracy are restored before the test evaluation.

    Args:
        X_train, X_val, X_test: input tensors indexed as ``[:, channel, :, :]``.
        y_train, y_val, y_test: matching label tensors.
        channel_indices: list of channel indices to keep.
        model_name: accepted for call compatibility; not used by this function.
        device: device for the model and batches.
        num_epochs: number of training epochs.
        lr: Adam learning rate.
        verbose: if True, print the validation AUC every 10 epochs.

    Returns:
        Dict with ``'auc'`` and ``'acc'`` on the test split and ``'val_acc'``
        (the best validation accuracy seen).
    """
    X_train_sub = X_train[:, channel_indices, :, :]
    X_val_sub = X_val[:, channel_indices, :, :]
    X_test_sub = X_test[:, channel_indices, :, :]

    train_dataset = TensorDataset(X_train_sub, y_train)
    val_dataset = TensorDataset(X_val_sub, y_val)
    test_dataset = TensorDataset(X_test_sub, y_test)

    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=256)
    test_loader = DataLoader(test_dataset, batch_size=256)

    model = FlexibleTopologicalCNN(num_channels=len(channel_indices)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    auroc = MulticlassAUROC(num_classes=6, average='macro').to(device)

    def _predict(loader):
        """Return (softmax probabilities, labels) concatenated over ``loader``."""
        probs, labels = [], []
        with torch.no_grad():
            for X_batch, y_batch in loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                probs.append(torch.softmax(model(X_batch), dim=1))
                labels.append(y_batch)
        return torch.cat(probs), torch.cat(labels)

    best_val_acc = 0.0
    best_state = None

    for epoch in range(num_epochs):
        model.train()
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()

        model.eval()
        val_preds, val_labels = _predict(val_loader)

        val_auc = auroc(val_preds, val_labels).item()
        y_pred = torch.argmax(val_preds, dim=1)
        accuracy = (y_pred == val_labels).float().mean().item()

        # The first epoch always seeds the snapshot, so there is something to
        # restore even if validation accuracy never rises above 0.
        if best_state is None or accuracy > best_val_acc:
            best_val_acc = accuracy
            # Clone every tensor: state_dict() hands back references to the live
            # parameters, so a shallow dict copy would keep tracking later
            # optimizer steps and end up holding the last-epoch weights.
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }

        if verbose and (epoch + 1) % 10 == 0:
            print(f"  Epoch {epoch+1}/{num_epochs} - val_auc: {val_auc:.4f}")

    # best_state is only None when num_epochs == 0; then the untrained model is scored.
    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()

    test_preds, test_labels = _predict(test_loader)

    test_auc = auroc(test_preds, test_labels).item()
    test_acc = (torch.argmax(test_preds, dim=1) == test_labels).float().mean().item()

    return {
        'auc': test_auc,
        'acc': test_acc,
        'val_acc': best_val_acc
    }


# Status marker: shows in the cell output once the definitions above have run.
print("✅ Ablation training functions defined")

In [None]:
from tqdm import tqdm
# Run ablation experiments
# The whole study lives inside the guard: previously the config loop sat
# outside the `if`, so the trailing `else:` attached to the `for` loop and the
# "data not available" warning printed after every successful run, while the
# "complete" banner printed once per configuration.
if 'X_train2_tensor' in locals():
    print("="*60)
    print("Running Ablation Study")
    print("="*60)
    SEEDS = [0, 1, 2, 3, 4]

    # Define ablation configurations: display name -> channel indices to keep
    ablation_configs = {
        # Baseline
        # 'All 6 channels': [0, 1, 2, 3, 4, 5],

        # Individual channels
        'Euclidean H0': [0],
        'Euclidean H1': [1],
        'Temporal H0': [2],
        'Temporal H1': [3],
        'Wasserstein H0': [4],
        'Wasserstein H1': [5],

        # Metric groups
        'Euclidean only': [0, 1],
        'Temporal only': [2, 3],
        'Wasserstein only': [4, 5],

        # Leave-one-out (remove metric groups)
        'Without Euclidean': [2, 3, 4, 5],
        'Without Temporal': [0, 1, 4, 5],
        'Without Wasserstein': [0, 1, 2, 3],
    }

    ablation_results = {}

    for config_name, channels in ablation_configs.items():
        print(f"\n{config_name} (channels {channels})...")

        seed_results = []

        for seed in tqdm(SEEDS):
            # Seed random / NumPy / torch (CPU + CUDA) before each run.
            set_seed(seed)

            seed_results.append(
                train_ablation_model(
                    X_train2_tensor, X_val2_tensor, X_test2_tensor,
                    y_train_tensor, y_val_tensor, y_test_tensor,
                    channel_indices=channels,
                    model_name=f"ablation_{config_name.replace(' ', '_')}_seed{seed}",
                    device=device,
                    num_epochs=60,
                    lr=1e-4,
                    verbose=False
                )
            )

        # Aggregate metrics across seeds (sample std, ddof=1)
        aucs = np.array([r['auc'] for r in seed_results], dtype=float)
        accs = np.array([r['acc'] for r in seed_results], dtype=float)
        auc_std = aucs.std(ddof=1)
        acc_std = accs.std(ddof=1)

        ablation_results[config_name] = {
            'auc_mean': aucs.mean(),
            'auc_std': auc_std,
            'acc_mean': accs.mean(),
            'acc_std': acc_std,
            'num_seeds': len(SEEDS)
        }

        print(
            f"  ✅ Test AUC: {aucs.mean():.4f} ± {auc_std:.4f} | "
            f"Test Acc: {accs.mean():.4f} ± {acc_std:.4f}"
        )

    # Printed once, after every configuration has finished.
    print("\n" + "="*60)
    print("Ablation Study Complete (all configs, multi-seed)!")
    print("="*60)

else:
    print("⚠️ Training data not available. Please run previous cells first.")

In [None]:
# interpret_persistence_images.py
# Requirements:
#   pip install captum matplotlib numpy torch
#
# Assumptions:
#   - Your model: model(pi) -> logits [B, num_classes]
#   - Your dataloader yields: (pi, label, *optional_metadata)
#       pi: float tensor [B, 6, 32, 32]
#       label: long tensor [B]
#
# If you have a fusion model, wrap it so that forward(pi) returns logits.

from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from captum.attr import IntegratedGradients

import matplotlib.pyplot as plt


# Display names for the input channels, in channel-index order; the plotting
# helpers in this cell index this list by channel for tick labels and titles.
CHANNEL_NAMES = [
    "euclid_H0", "euclid_H1",
    "temporal_H0", "temporal_H1",
    "wasserstein_H0", "wasserstein_H1",
]


@dataclass
class IGResults:
    """Aggregated Integrated-Gradients statistics returned by ``compute_integrated_gradients``.

    In the shape comments, C/H/W are the input channel/height/width and K is
    ``num_classes``. Per-class entries are grouped by the true label.
    """
    # Means are normalized by sample count
    # Mean |attribution| map over all processed samples.
    mean_abs_attr: torch.Tensor        # [C,H,W]
    # Mean |attribution| map over the samples of each class.
    mean_abs_attr_by_class: torch.Tensor  # [K,C,H,W]
    # Mean input over the samples of each class.
    mean_pi_by_class: torch.Tensor        # [K,C,H,W]
    # Mean |attribution| per channel, averaged over samples and pixels.
    mean_abs_attr_channel: torch.Tensor   # [C]
    # Same per-channel mean, restricted to the samples of each class.
    mean_abs_attr_channel_by_class: torch.Tensor  # [K,C]
    # Number of processed samples per class.
    counts_by_class: torch.Tensor         # [K]
    # Total number of processed samples.
    num_samples: int
    # Number of classes K used for the per-class tensors.
    num_classes: int


@torch.no_grad()
def _infer_num_classes(model: nn.Module, device: torch.device, c=6, h=32, w=32) -> int:
    x = torch.zeros(1, c, h, w, device=device)
    y = model(x)
    return int(y.shape[-1])


def compute_integrated_gradients(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
    num_classes: Optional[int] = None,
    n_steps: int = 64,
    target_mode: str = "true",   # "true" or "pred"
    max_batches: Optional[int] = None,
) -> IGResults:
    """Accumulate Integrated-Gradients statistics over a dataloader.

    For every batch, Captum ``IntegratedGradients`` is run against an all-zeros
    baseline and ``|attribution|`` is accumulated globally, per channel and per
    true class, together with the per-class mean input.

    Args:
        model: network mapping an input batch to logits ``[B, num_classes]``;
            it is moved to ``device`` and put in eval mode.
        dataloader: yields batches whose first two items are ``(input, label)``
            with the input shaped ``[B, C, H, W]``; extra items are ignored.
        device: device used for the model, batches and accumulators.
        num_classes: number of classes; when None it is read from the model's
            output on the first batch.
        n_steps: number of IG integration steps.
        target_mode: ``"true"`` attributes towards the label, ``"pred"``
            towards the model's argmax prediction.
        max_batches: optional cap on the number of batches processed.

    Returns:
        ``IGResults`` with all tensors detached and on the CPU. If no batch is
        processed, zero-filled tensors of shape 6x32x32 are returned.

    Raises:
        ValueError: if ``target_mode`` is neither ``"true"`` nor ``"pred"``.
    """
    # Fail early: previously an unknown mode left `target` unbound and crashed
    # only once the first batch reached ig.attribute.
    if target_mode not in ("true", "pred"):
        raise ValueError(f"target_mode must be 'true' or 'pred', got {target_mode!r}")

    model = model.to(device).eval()

    ig = IntegratedGradients(model)

    def _new_sums(k: int, c: int, h: int, w: int) -> dict:
        """Allocate zeroed running sums for k classes and (c, h, w) inputs."""
        return {
            "attr": torch.zeros(c, h, w, device=device),
            "attr_cls": torch.zeros(k, c, h, w, device=device),
            "pi_cls": torch.zeros(k, c, h, w, device=device),
            "chan": torch.zeros(c, device=device),
            "chan_cls": torch.zeros(k, c, device=device),
            "counts": torch.zeros(k, device=device),
        }

    # Running sums, sized from the first batch instead of a fixed 6x32x32.
    sums = None
    total_samples = 0

    for b_idx, batch in enumerate(dataloader):
        if max_batches is not None and b_idx >= max_batches:
            break

        # Accept (pi, label) or (pi, label, meta...)
        pi = batch[0].to(device).float()
        y = batch[1].to(device).long()
        bsz = pi.shape[0]

        if num_classes is None:
            # Infer K from the real input shape rather than a fixed probe.
            with torch.no_grad():
                num_classes = int(model(pi[:1]).shape[-1])
        if sums is None:
            sums = _new_sums(num_classes, *pi.shape[1:])

        # Choose attribution target; the extra forward pass is only needed
        # when attributing towards the prediction.
        if target_mode == "pred":
            with torch.no_grad():
                target = torch.argmax(model(pi), dim=1)
        else:
            target = y

        pi_req = pi.detach().clone().requires_grad_(True)
        baseline = torch.zeros_like(pi_req)

        attr = ig.attribute(
            inputs=pi_req,
            baselines=baseline,
            target=target,
            n_steps=n_steps,
            internal_batch_size=min(32, bsz),
        )

        # Detach so the running sums never carry an autograd graph.
        abs_attr = attr.detach().abs()

        sums["attr"] += abs_attr.sum(dim=0)
        sums["chan"] += abs_attr.sum(dim=(0, 2, 3))

        for k in range(num_classes):
            idx = (y == k).nonzero(as_tuple=False).squeeze(1)
            if idx.numel() == 0:
                continue
            sums["counts"][k] += float(idx.numel())
            sums["attr_cls"][k] += abs_attr[idx].sum(dim=0)
            sums["pi_cls"][k] += pi[idx].sum(dim=0)
            sums["chan_cls"][k] += abs_attr[idx].sum(dim=(0, 2, 3))

        total_samples += bsz

    if sums is None:
        # No batch was processed: keep returning zero-filled results with the
        # default 6x32x32 layout.
        if num_classes is None:
            num_classes = _infer_num_classes(model, device)
        sums = _new_sums(num_classes, 6, 32, 32)

    n_pixels = sums["attr"].shape[-2] * sums["attr"].shape[-1]

    mean_abs_attr = sums["attr"] / max(total_samples, 1)
    mean_abs_attr_channel = sums["chan"] / max(total_samples * n_pixels, 1)

    mean_abs_attr_by_class = torch.zeros_like(sums["attr_cls"])
    mean_pi_by_class = torch.zeros_like(sums["pi_cls"])
    mean_abs_attr_channel_by_class = torch.zeros_like(sums["chan_cls"])

    for k in range(num_classes):
        ck = int(sums["counts"][k].item())
        if ck > 0:  # classes with no samples keep their zero rows
            mean_abs_attr_by_class[k] = sums["attr_cls"][k] / ck
            mean_pi_by_class[k] = sums["pi_cls"][k] / ck
            mean_abs_attr_channel_by_class[k] = sums["chan_cls"][k] / (ck * n_pixels)

    return IGResults(
        mean_abs_attr=mean_abs_attr.detach().cpu(),
        mean_abs_attr_by_class=mean_abs_attr_by_class.detach().cpu(),
        mean_pi_by_class=mean_pi_by_class.detach().cpu(),
        mean_abs_attr_channel=mean_abs_attr_channel.detach().cpu(),
        mean_abs_attr_channel_by_class=mean_abs_attr_channel_by_class.detach().cpu(),
        counts_by_class=sums["counts"].detach().cpu(),
        num_samples=total_samples,
        num_classes=num_classes,
    )


def plot_channel_importance(mean_abs_attr_channel: torch.Tensor, title: str = "Per-channel IG importance"):
    """Show a bar chart with one bar per channel, labelled from ``CHANNEL_NAMES``.

    Args:
        mean_abs_attr_channel: 1-D tensor with one value per channel.
        title: figure title.
    """
    heights = mean_abs_attr_channel.numpy()
    positions = np.arange(len(heights))

    fig, ax = plt.subplots()
    ax.bar(positions, heights)
    ax.set_xticks(positions)
    ax.set_xticklabels(CHANNEL_NAMES, rotation=35, ha="right")
    ax.set_ylabel("Mean |IG attribution| per pixel")
    ax.set_title(title)
    fig.tight_layout()
    plt.show()


def plot_attr_heatmaps(mean_abs_attr: torch.Tensor, max_channels: int = 6, title_prefix: str = "Mean |IG|"):
    """Show one heatmap figure per channel of ``mean_abs_attr``.

    Args:
        mean_abs_attr: tensor indexed ``[channel, y, x]``.
        max_channels: upper bound on how many leading channels are drawn.
        title_prefix: text placed before the channel name in each title.
    """
    n_shown = min(mean_abs_attr.shape[0], max_channels)
    for ch_idx in range(n_shown):
        fig, ax = plt.subplots()
        image = ax.imshow(mean_abs_attr[ch_idx].numpy(), origin="lower", aspect="auto")
        fig.colorbar(image, ax=ax)
        ax.set_title(f"{title_prefix}: {CHANNEL_NAMES[ch_idx]}")
        ax.set_xlabel("PI x-bin")
        ax.set_ylabel("PI y-bin")
        fig.tight_layout()
        plt.show()


def plot_class_prototypes(
    mean_pi_by_class: torch.Tensor,
    mean_abs_attr_by_class: torch.Tensor,
    class_names: List[str],
    class_idx: int,
    channels: Optional[List[int]] = None,
):
    """Show, per channel, the class's mean input and its mean |IG| map.

    Args:
        mean_pi_by_class: tensor indexed ``[class, channel, y, x]``.
        mean_abs_attr_by_class: tensor indexed the same way.
        class_names: display name per class index.
        class_idx: class to plot.
        channels: channel indices to plot; all channels when None.
    """
    def _show_panel(panel, make_heading):
        # The heading is built after the image is drawn, as a callable.
        plt.figure()
        plt.imshow(panel.numpy(), origin="lower", aspect="auto")
        plt.colorbar()
        plt.title(make_heading())
        plt.tight_layout()
        plt.show()

    selected = list(range(mean_pi_by_class.shape[1])) if channels is None else channels

    for c in selected:
        _show_panel(
            mean_pi_by_class[class_idx, c],
            lambda: f"Mean PI prototype — class={class_names[class_idx]} — {CHANNEL_NAMES[c]}",
        )
        _show_panel(
            mean_abs_attr_by_class[class_idx, c],
            lambda: f"Mean |IG| — class={class_names[class_idx]} — {CHANNEL_NAMES[c]}",
        )


def plot_difference_map(
    tensor_by_class: torch.Tensor,
    class_names: List[str],
    a: int,
    b: int,
    channel: int,
    title: str = "Difference map",
    vlim: Optional[float] = None,
):
    """Show the map ``tensor_by_class[a, channel] - tensor_by_class[b, channel]``.

    The colour scale is symmetric about zero so positive and negative
    differences are comparable.

    Args:
        tensor_by_class: tensor indexed ``[class, channel, y, x]``.
        class_names: display name per class index.
        a, b: class indices; the plot shows class ``a`` minus class ``b``.
        channel: channel index to plot.
        title: text placed before the class/channel description in the title.
        vlim: half-width of the colour range (``[-vlim, +vlim]``); when None
            the largest absolute difference is used.
    """
    diff = tensor_by_class[a, channel] - tensor_by_class[b, channel]

    # The previous fixed limits (vmin=200, vmax=-200) were inverted, which
    # matplotlib's Normalize rejects, and unrelated to the data's magnitude.
    if vlim is None:
        vlim = float(diff.abs().max())
    vlim = abs(float(vlim))
    if not vlim > 0:  # all-zero (or NaN) difference: fall back to a valid range
        vlim = 1.0

    plt.figure()
    plt.imshow(diff.numpy(), origin="lower", aspect="auto",
               vmin=-vlim, vmax=vlim)
    plt.colorbar()
    plt.title(f"{title}: {class_names[a]} - {class_names[b]} — {CHANNEL_NAMES[channel]}")
    plt.tight_layout()
    plt.show()


In [None]:

# Integrated Gradients runs on CUDA when available, otherwise on the CPU, as
# before. A dedicated name is used so the notebook-wide `device` is no longer
# rebound halfway through the cell.
# NOTE(review): presumably MPS is skipped on purpose for Captum — confirm.
ig_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_topo = TopologicalCNN().to(ig_device)
# map_location keeps the load working when the checkpoint was saved from a
# different device than ig_device.
# NOTE(review): this cell uses the seed-4 checkpoint while the earlier
# interpretability cells load seed 0 — confirm this is intended.
model_topo.load_state_dict(torch.load("best_model_topo_seed4.pth", map_location=ig_device))

# Plain (X, y) loader; batches are moved to ig_device inside the helper.
test_dataset_topo = TensorDataset(
    X_test2_tensor,
    y_test_tensor
)

test_loader_topo = DataLoader(
    test_dataset_topo,
    batch_size=256,
    shuffle=False
)

results = compute_integrated_gradients(
    model=model_topo,
    dataloader=test_loader_topo,
    device=ig_device,
    num_classes=6,
    n_steps=64,
    target_mode="true",
)

In [None]:
class_names = ["neutral", "happy", "sad", "angry", "fearful", "disgust"]  # adjust if needed

# Alternative view (disabled): differences between the per-class mean inputs.
# plot_difference_map(results.mean_pi_by_class, class_names, a=4, b=0, channel=2, title="PI prototype diff")
# plot_difference_map(results.mean_pi_by_class, class_names, a=2, b=0, channel=2, title="PI prototype diff")
# plot_difference_map(results.mean_pi_by_class, class_names, a=5, b=0, channel=2, title="PI prototype diff")

# Active view: differences between the per-class mean |IG| maps for channel 2,
# each taken against class 0. The title now names the plotted quantity; it
# previously reused the "PI prototype diff" label of the disabled calls above.
plot_difference_map(results.mean_abs_attr_by_class, class_names, a=4, b=0, channel=2, title="Mean |IG| diff")
plot_difference_map(results.mean_abs_attr_by_class, class_names, a=2, b=0, channel=2, title="Mean |IG| diff")
plot_difference_map(results.mean_abs_attr_by_class, class_names, a=5, b=0, channel=2, title="Mean |IG| diff")
