# In [None]:
import os
import numpy as np
import pandas as pd
import scipy.io as sio
from gtda.homology import VietorisRipsPersistence
from persim import wasserstein

# Compute one topological feature per ABIDE subject: the Wasserstein distance
# between the subject's H1 persistence diagram (Vietoris-Rips on a
# correlation-derived distance matrix) and the H1 diagram of the first
# successfully processed subject. Results are written to a CSV.

# Root directory holding one sub-directory per subject.
DATA_DIR = "./abide"

# Vietoris-Rips persistence on a precomputed distance matrix. Only the H1
# diagram is used below; H0 and H2 are still computed. NOTE(review): H2 is
# presumably the most expensive part -- consider dropping it from
# homology_dimensions if nothing else in the notebook needs it.
tda = VietorisRipsPersistence(metric="precomputed", homology_dimensions=[0, 1, 2])

# H1 diagram (birth/death pairs) of the first successfully processed subject;
# every later subject is compared against it.
reference_diagram = None
# Rows of [subject_id, wasserstein_distance] collected for the CSV.
wasserstein_features = []

# os.listdir returns entries in arbitrary order; sort so the reference subject
# and the row order do not depend on the filesystem.
subject_dirs = sorted(
    os.path.join(DATA_DIR, d)
    for d in os.listdir(DATA_DIR)
    if os.path.isdir(os.path.join(DATA_DIR, d))
)

total_subjects = len(subject_dirs)
print(f"Processing {total_subjects} subjects...")

for idx, subject_dir in enumerate(subject_dirs, start=1):
    subject_id = os.path.basename(subject_dir)
    print(f"\n[{idx}/{total_subjects}] Processing subject: {subject_id}")

    # Sorted so the chosen file is deterministic if several files match.
    file_list = sorted(
        f for f in os.listdir(subject_dir) if "AAL116_correlation_matrix.mat" in f
    )
    if not file_list:
        print(f"  No correlation matrix found for {subject_id}. Skipping.")
        continue

    file_path = os.path.join(subject_dir, file_list[0])
    try:
        mat_data = sio.loadmat(file_path)
        corr_matrix = np.asarray(mat_data['data'], dtype=float)
    except Exception as e:
        # Per-file boundary: report the failure and move on instead of
        # aborting the whole run.
        print(f"  Error processing {file_path} for subject {subject_id}: {e}")
        continue

    if corr_matrix.ndim != 2 or corr_matrix.shape[0] != corr_matrix.shape[1]:
        print(f"  Correlation matrix for {subject_id} is not square "
              f"(shape {corr_matrix.shape}). Skipping.")
        continue

    # Strongly (anti-)correlated regions are close: d = 1 - |r|.
    # - clip: |r| can exceed 1 by rounding, which would give negative distances;
    # - symmetrize: remove any floating-point asymmetry;
    # - zero the diagonal: with a precomputed metric, diagonal entries act as
    #   vertex birth times, so every region must enter the filtration at 0.
    distance_matrix = np.clip(1.0 - np.abs(corr_matrix), 0.0, None)
    distance_matrix = (distance_matrix + distance_matrix.T) / 2.0
    np.fill_diagonal(distance_matrix, 0.0)

    # NaN/inf (e.g. from regions with no signal) would silently corrupt the
    # persistence computation, so skip such subjects explicitly.
    if not np.isfinite(distance_matrix).all():
        print(f"  Correlation matrix for {subject_id} contains NaN/inf values. Skipping.")
        continue

    diagrams = tda.fit_transform([distance_matrix])
    current_diagram = diagrams[0]
    # Keep only homology dimension 1 (column 2) and its (birth, death) columns.
    current_diag_dim1 = current_diagram[current_diagram[:, 2] == 1][:, :2]

    if reference_diagram is None:
        # The first valid subject becomes the reference; its distance to itself is 0.
        reference_diagram = current_diag_dim1
        wasserstein_distance = 0.0
    else:
        wasserstein_distance = wasserstein(reference_diagram, current_diag_dim1, matching=False)

    wasserstein_features.append([subject_id, wasserstein_distance])

wasserstein_df = pd.DataFrame(wasserstein_features, columns=["Subject_ID", "Wasserstein_Distance"])
wasserstein_df.to_csv("wasserstein_features.csv", index=False)
print("\nWasserstein features saved to 'wasserstein_features.csv'")

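# NOTE: running this cell emits many scikit-learn deprecation warnings about
# `force_all_finite` being renamed to `ensure_all_finite`. They presumably come
# from giotto-tda's input validation, which still passes the old keyword to
# scikit-learn's `check_array` (renamed in scikit-learn 1.6) -- TODO confirm.
# They do not appear to affect the results; silence them with
# `warnings.filterwarnings("ignore", category=FutureWarning)` if needed.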