In [None]:
import pandas as pd
import os
os.environ['OMP_NUM_THREADS'] = '4'
#os.environ['OMP_NUM_THREADS'] = '1'
from sklearn.decomposition import PCA, FastICA, NMF
from sklearn.preprocessing import StandardScaler, MinMaxScaler, OneHotEncoder
from contrastive import CPCA
import seaborn as sns
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
from scipy.linalg import eigh
from numpy.linalg import pinv, norm
from scipy.stats import kurtosis
from itertools import combinations
import time
from sklearn.impute import SimpleImputer
import plotly.express as px
import plotly.graph_objects as go
import kaleido
from sklearn.cluster import KMeans
import warnings
from sklearn.exceptions import ConvergenceWarning

In [None]:
# cPCA with Frobenius error

sns.set_theme(style="whitegrid")

# Create the output directory for plots. exist_ok=True replaces the former
# "check os.path.exists, then create" sequence, which could race.
output_dir = "output_plots"
os.makedirs(output_dir, exist_ok=True)

# Assemble datasets into a dictionary.
# Variables autismcombinedcommon, twitch, accent, dim512, and dim1024 must be defined
# by an earlier cell.
datasets = {
    "Combined Autism": autismcombinedcommon,  # dummy columns: Group_Adult, Group_Child; Group_Adolescent is implied when neither is set
    "Twitch": twitch,
    "Accent": accent,                          # dummy columns: language_US, language_FR, language_IT, language_GE, language_UK; language_ES is implied when none is set
    "Dim512": dim512,
    "Dim1024": dim1024
}

####################################
# Helper functions for group labeling
####################################
def get_autism_group(row):
    """Return the autism age-group label encoded by a row's dummy columns.

    The first of Group_Adult / Group_Child whose value equals 1 wins; when neither
    is set, the row belongs to the implicit "Group_Adolescent" group.
    """
    for group_column in ("Group_Adult", "Group_Child"):
        if row.get(group_column, 0) == 1:
            return group_column
    return "Group_Adolescent"

def get_accent_group(row):
    """Return the accent/language label encoded by a row's dummy columns.

    Columns are checked in a fixed priority order and the first one equal to 1 is
    returned; with none set, the row belongs to the implicit "language_ES" group.
    """
    priority = ("language_US", "language_FR", "language_IT", "language_GE", "language_UK")
    return next((column for column in priority if row.get(column, 0) == 1), "language_ES")

####################################
# Custom regularized pseudo-inverse using SVD (with adjustable tolerance)
####################################
def regularized_pinv(A, tol=1e-1):
    """Pseudo-inverse of A via SVD, zeroing singular values that are not above tol.

    Parameters:
        A (np.ndarray): Matrix to invert.
        tol (float): Singular values <= tol are treated as zero (their reciprocal is dropped).

    Returns:
        np.ndarray: V @ diag(1/s, masked) @ U^T, shaped like A transposed.
    """
    left, singular, right_h = np.linalg.svd(A, full_matrices=False)
    above_tol = singular > tol
    inv_singular = np.zeros_like(singular)
    inv_singular[above_tol] = 1.0 / singular[above_tol]
    return (right_h.T * inv_singular) @ left.T

####################################
# Function to tune over candidate n_components with fixed alpha.
####################################
def tune_cpca_n_components(X_target, X_background, fixed_alpha=2.0, max_components=None):
    """Fit contrastive PCA at a fixed alpha and choose the component count by Frobenius error.

    The contrastive covariance ``cov(X_target) - fixed_alpha * cov(X_background)`` is
    eigen-decomposed and its eigenvectors are sorted by descending eigenvalue. For each
    n in 1..max_components, the target is projected onto the top-n eigenvectors,
    reconstructed with their pseudo-inverse, and scored with the Frobenius norm of the
    residual. The n with the smallest error is selected; ties go to the smaller n.

    Note: ``eigh`` returns orthonormal eigenvectors, so every added component can only
    shrink the residual. The selected n is therefore normally ``max_components``.

    Parameters:
        X_target (np.ndarray): Target samples, one observation per row.
        X_background (np.ndarray): Background samples with the same columns.
        fixed_alpha (float): Weight applied to the background covariance.
        max_components (int | None): Largest n to evaluate. Defaults to the rank of the
            target covariance matrix.

    Returns:
        tuple: (best_n, best_eigvecs, best_eigvals, best_error, fro_errors, n_components_list)
            - best_eigvecs: (n_features, best_n) top eigenvectors.
            - best_eigvals: their eigenvalues.
            - fro_errors / n_components_list: the error (np.inf if pinv failed) for each n tried.

    Raises:
        ValueError: If max_components is smaller than 1, so no candidate can be evaluated.
    """
    # Compute each covariance once; the target one is reused for the default rank.
    cov_target = np.cov(X_target, rowvar=False)
    cov_background = np.cov(X_background, rowvar=False)
    if max_components is None:
        max_components = np.linalg.matrix_rank(cov_target)
    if max_components < 1:
        raise ValueError(
            f"max_components must be >= 1 to tune cPCA, got {max_components} "
            "(is the target covariance rank zero?)")

    contrastive_cov = cov_target - fixed_alpha * cov_background

    # eigh: the contrastive covariance is symmetric. It returns ascending eigenvalues,
    # so reorder to descending.
    eigvals, eigvecs = eigh(contrastive_cov)
    order = np.argsort(eigvals)[::-1]
    eigvals = eigvals[order]
    eigvecs = eigvecs[:, order]

    best_n = 1
    best_error = np.inf
    best_eigvecs = None
    best_eigvals = None
    fro_errors = []
    n_components_list = []

    for n in range(1, max_components + 1):
        basis = eigvecs[:, :n]
        n_components_list.append(n)
        try:
            # Reconstruct target from its projection onto the top-n directions.
            X_reconstructed = (X_target @ basis) @ pinv(basis, rcond=1e-5)
        except np.linalg.LinAlgError:
            fro_errors.append(np.inf)
            continue
        error = norm(X_target - X_reconstructed, 'fro')
        fro_errors.append(error)
        if error < best_error:
            best_n = n
            best_error = error
            best_eigvecs = basis
            best_eigvals = eigvals[:n]
    return best_n, best_eigvecs, best_eigvals, best_error, fro_errors, n_components_list

####################################
# Function to select target and background splits.
####################################
def select_cpca_split(df, dataset_name):
    """Split df into (target, background) frames for cPCA.

    "Combined Autism" and "Accent" are split on an indicator column: rows where it
    equals 1 are the target, all other rows the background. Any other dataset gets a
    reproducible random 30% target sample (random_state=42), with the remainder as
    background.
    """
    indicator_columns = {"Combined Autism": "Group_Adult", "Accent": "language_US"}
    indicator = indicator_columns.get(dataset_name)
    if indicator is None:
        sampled = df.sample(frac=0.3, random_state=42)
        return sampled, df.drop(sampled.index)
    return df[df[indicator] == 1], df[df[indicator] != 1]

####################################
# Visualization Functions
####################################
def _group_color_mapping(group_labels, custom_colors):
    """Turn group labels into plot-ready values, a stable level order, and colors.

    Returns:
        values (np.ndarray): Object array with one label per row.
        levels (list): Labels that actually occur. They follow the category order when
            group_labels is categorical, otherwise the order of first appearance.
        color_map (dict | None): {label: color}, assigned by position in the full
            category list. A label therefore keeps its color regardless of which labels
            occur or in which order rows appear. None when custom_colors is empty/None.
    """
    values = np.asarray(group_labels, dtype=object)
    categories = getattr(group_labels, "categories", None)  # pd.Categorical
    if categories is None and isinstance(getattr(group_labels, "dtype", None), pd.CategoricalDtype):
        categories = group_labels.cat.categories  # categorical pd.Series
    if categories is None:
        categories = pd.unique(values)
    categories = list(categories)
    present = set(values.tolist())
    levels = [label for label in categories if label in present]
    color_map = None
    if custom_colors:
        color_map = {label: custom_colors[i % len(custom_colors)] for i, label in enumerate(categories)}
    return values, levels, color_map


def visualize_cpca(dataset_name, cpca_result, cpca_eigvals, error_type='Frobenius', group_labels=None, custom_colors=None):
    """Plot a cPCA projection and save every figure to output_dir.

    Plots produced:
        - 3 or more components: an interactive 3D scatter of the first three.
        - Exactly 2 components: a 2D scatter.
        - Exactly 1 component: a histogram.
        - Always: a scree plot of ``cpca_eigvals``.
        - 2 or more components: a seaborn pairplot of up to the first 5 components.

    When group_labels is given, every plot is colored by group. The level order and
    colors come from the label categories, so the same group gets the same color in
    all plots.

    Parameters:
        dataset_name (str): Used in titles and file names.
        cpca_result (np.ndarray): (n_samples, n_components) projection.
        cpca_eigvals (array-like): Eigenvalues of the selected components.
        error_type (str): Label identifying the scoring metric in titles/file names.
        group_labels (pd.Categorical | pd.Series | None): One label per row of cpca_result.
        custom_colors (list[str] | None): Colors assigned to categories in category order.
    """
    n_components = cpca_result.shape[1]

    group_values = levels = color_map = None
    plotly_group_kwargs = {}
    if group_labels is not None:
        group_values, levels, color_map = _group_color_mapping(group_labels, custom_colors)
        plotly_group_kwargs = {
            'color': 'group',
            'category_orders': {'group': levels},
            'color_discrete_sequence': custom_colors,
        }
        if color_map is not None:
            plotly_group_kwargs['color_discrete_map'] = color_map

    def _projection_frame(k, separator):
        # First k components as a DataFrame, plus the group column when labels exist.
        frame = pd.DataFrame(cpca_result[:, :k], columns=[f'cPCA{separator}{i+1}' for i in range(k)])
        if group_values is not None:
            frame['group'] = group_values
        return frame

    # 3D / 2D scatter, or a histogram for a single component.
    if n_components >= 3:
        fig = px.scatter_3d(_projection_frame(3, ' '), x='cPCA 1', y='cPCA 2', z='cPCA 3',
                            title=f'3D Scatter Plot for {dataset_name} (cPCA {error_type})',
                            **plotly_group_kwargs)
        fig.write_image(os.path.join(output_dir, f"{dataset_name}_cPCA_3D_{error_type}.png"))
        fig.show()
    elif n_components == 2:
        fig = px.scatter(_projection_frame(2, ' '), x='cPCA 1', y='cPCA 2',
                         title=f'2D Scatter Plot for {dataset_name} (cPCA {error_type})',
                         **plotly_group_kwargs)
        fig.write_image(os.path.join(output_dir, f"{dataset_name}_cPCA_2D_{error_type}.png"))
        fig.show()
    elif n_components == 1:
        fig = px.histogram(_projection_frame(1, ' '), x='cPCA 1',
                           title=f'Histogram for {dataset_name} (cPCA - 1 Component {error_type})',
                           **plotly_group_kwargs)
        fig.write_image(os.path.join(output_dir, f"{dataset_name}_cPCA_histogram_{error_type}.png"))
        fig.show()

    # Scree Plot
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, len(cpca_eigvals) + 1), cpca_eigvals, marker='o')
    plt.xlabel("Component")
    plt.ylabel("Eigenvalue")
    plt.title(f"Scree Plot for {dataset_name} (cPCA {error_type})")
    plt.grid(True)
    plt.savefig(os.path.join(output_dir, f"{dataset_name}_cPCA_scree_{error_type}.png"), bbox_inches='tight')
    plt.show()

    # 2D Pairwise Scatter Plots with Coloring
    if n_components >= 2:
        n = min(n_components, 5)
        pair_df = _projection_frame(n, '')
        if group_values is not None:
            pairplot = sns.pairplot(pair_df, hue='group', hue_order=levels, diag_kind="kde",
                                    palette=color_map if color_map is not None else custom_colors)
        else:
            pairplot = sns.pairplot(pair_df, diag_kind="kde")
        plt.suptitle(f"{dataset_name}_cPCA_pairplot_{error_type}", y=1.02)
        pairplot.fig.savefig(os.path.join(output_dir, f"{dataset_name}_cPCA_pairplot_{error_type}.png"), bbox_inches='tight')
        plt.show()

def visualize_cpca_errors(dataset_name, fro_errors_dict, n_components_list):
    """Plot Frobenius reconstruction error against the number of cPCA components.

    Parameters:
        dataset_name (str): Used in the title and output file name.
        fro_errors_dict (dict): Maps alpha -> error. Each value is drawn as follows:
            - A scalar is drawn as a dashed horizontal reference line (the previous behavior).
            - A sequence must hold one error per entry of n_components_list (for example
              the ``fro_errors`` list returned by tune_cpca_n_components). It is drawn
              as an error curve.
        n_components_list (list[int]): Component counts on the x axis.

    Raises:
        ValueError: If a per-component error sequence does not match
            n_components_list in length.
    """
    plt.figure(figsize=(12, 5))
    for alpha, errors in fro_errors_dict.items():
        if np.ndim(errors) == 0:
            plt.plot(n_components_list, [errors] * len(n_components_list),
                     label=f'Frobenius α={alpha}', linestyle='--')
            continue
        errors = list(errors)
        if len(errors) != len(n_components_list):
            raise ValueError(
                f"alpha={alpha}: got {len(errors)} errors for {len(n_components_list)} component counts")
        plt.plot(n_components_list, errors, label=f'Frobenius α={alpha}', marker='o')
    plt.xlabel("Number of Components")
    plt.ylabel("Frobenius Reconstruction Error")
    plt.title(f"cPCA Errors vs Components for {dataset_name} (Frobenius)")
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(output_dir, f"{dataset_name}_cPCA_errors_Frobenius.png"), bbox_inches='tight')
    plt.show()

####################################
# New: Visualization Function for cPCA Heatmap of Feature Loadings
####################################
def visualize_cpca_heatmap(dataset_name, cpca_loadings, feature_names, error_type='Frobenius'):
    """
    Visualizes a heatmap of cPCA loadings (at most the first 5 components) and saves it.

    Parameters:
        dataset_name (str): Name of the dataset.
        cpca_loadings (np.ndarray): The eigenvectors (loadings) computed by cPCA.
                                    Expected shape is (n_features, n_components); a 1-D
                                    vector of length n_features is treated as one component.
        feature_names (list or pd.Index): Names of the features (after dropping dummy columns).
        error_type (str): A label to indicate which error type was used (e.g., 'Frobenius').
    """
    loadings = np.asarray(cpca_loadings)
    if loadings.ndim == 1:
        loadings = loadings[:, np.newaxis]
    # Transpose so that rows are components and columns are features.
    loadings_matrix = loadings.T  # (n_components, n_features)
    n_to_plot = min(loadings_matrix.shape[0], 5)  # Plot only the first 5 components for clarity

    plt.figure(figsize=(max(10, len(feature_names) * 0.4), 6))
    sns.heatmap(
        loadings_matrix[:n_to_plot, :],
        cmap='coolwarm',
        center=0,  # loadings are signed: pin zero to the neutral color of the diverging map
        xticklabels=feature_names,
        yticklabels=[f'cPCA {i+1}' for i in range(n_to_plot)]
    )
    plt.title(f'{dataset_name} (cPCA {error_type}) Feature Loadings')
    plt.xlabel("Features")
    plt.ylabel("Component")

    filename = os.path.join(output_dir, f"{dataset_name}_cPCA_heatmap_{error_type}.png")
    plt.savefig(filename, bbox_inches='tight')
    plt.show()

####################################
# MAIN cPCA REDUCTION FUNCTION
####################################
def _filter_target_outliers(X_target, X_background, fixed_alpha, max_components,
                            label, n_required, keep_rule):
    """Drop target rows flagged as outliers by a preliminary cPCA pass.

    A first cPCA is fitted on the unfiltered target and the target is projected onto
    it. ``keep_rule(projection)`` then returns a boolean mask of rows to keep.

    The unfiltered target is returned instead (with a warning) in two cases:
        - the first pass yielded fewer than ``n_required`` components, so the rule
          would index past the end;
        - filtering would leave fewer than 2 rows, too few for a covariance estimate.
    """
    _, first_eigvecs, _, _, _, _ = tune_cpca_n_components(
        X_target, X_background, fixed_alpha=fixed_alpha, max_components=max_components)
    X_target_proj = X_target @ first_eigvecs

    print(f"Original {label} count: {X_target.shape[0]}")
    if X_target_proj.shape[1] < n_required:
        warnings.warn(
            f"Outlier rule for {label} rows needs {n_required} cPCA components but only "
            f"{X_target_proj.shape[1]} were selected; skipping outlier filtering.")
        return X_target

    X_target_filtered = X_target[keep_rule(X_target_proj)]
    print(f"Filtered {label} count: {X_target_filtered.shape[0]}")
    if X_target_filtered.shape[0] < 2:
        warnings.warn(
            f"Outlier filtering left {X_target_filtered.shape[0]} {label} rows; "
            "using the unfiltered target instead.")
        return X_target
    return X_target_filtered


def apply_cpca_reduction(dataset_name, df):
    """Run Frobenius-scored cPCA on one dataset, plot the results and record a summary row.

    Steps:
        1. Drop the dummy group columns ("language_*" for Accent, "Group_*" otherwise)
           from the features.
        2. Split target/background with select_cpca_split.
        3. For "Combined Autism" and "Accent", remove target outliers found by a
           preliminary cPCA pass.
        4. Tune the component count at fixed alpha = 2.6.
        5. Project and reconstruct the full dataset.
        6. Plot, and append the metrics to the module-level ``comparison_results`` list.
    """
    print(f"\n--- Processing {dataset_name} ---")
    # Exclude dummy columns from cPCA computation.
    dummy_prefix = "language_" if dataset_name == "Accent" else "Group_"
    dummy_cols = [col for col in df.columns if col.startswith(dummy_prefix)]
    feature_df = df.drop(columns=dummy_cols)
    X_full = feature_df.values

    target_df, background_df = select_cpca_split(df, dataset_name)
    X_target = target_df.drop(columns=dummy_cols).values
    X_background = background_df.drop(columns=dummy_cols).values

    fixed_alpha = 2.6
    max_components = np.linalg.matrix_rank(np.cov(X_target, rowvar=False))

    # Dataset-specific outlier rules. Each entry holds:
    #   - a label used in messages,
    #   - the number of components the rule indexes,
    #   - a mask of target rows to keep.
    # The thresholds were chosen by visual inspection of the first-pass projection.
    outlier_rules = {
        "Combined Autism": ("adult", 4, lambda proj: proj[:, 3] > -0.5),
        "Accent": ("US speaker", 3, lambda proj: (proj[:, 0] > -0.1) & (proj[:, 2] > 0.3)),
    }
    rule = outlier_rules.get(dataset_name)
    if rule is None:
        X_fit = X_target
        status = "cPCA (target)"
    else:
        label, n_required, keep_rule = rule
        X_fit = _filter_target_outliers(X_target, X_background, fixed_alpha, max_components,
                                        label, n_required, keep_rule)
        status = "Filtered cPCA (target)"

    best_n, best_eigvecs, best_eigvals, cpca_fro_error, _, _ = tune_cpca_n_components(
        X_fit, X_background, fixed_alpha=fixed_alpha, max_components=max_components)
    print(f"  ✅ {status} selected {best_n} components with alpha = {fixed_alpha} and target error = {cpca_fro_error:.4f}")

    # Project the full dataset using the learned eigenvectors.
    cpca_result_full = X_full @ best_eigvecs
    X_reconstructed_full = cpca_result_full @ pinv(best_eigvecs, rcond=1e-5)
    full_error = norm(X_full - X_reconstructed_full, 'fro')
    full_relative_error = full_error / norm(X_full, 'fro')
    print(f"  Full dataset reconstruction error: {full_error:.4f}")

    # Derive group labels for visualization.
    if dataset_name == "Accent":
        expected = ["language_US", "language_FR", "language_ES", "language_IT", "language_GE", "language_UK"]
        group_labels = df.apply(get_accent_group, axis=1)
        group_labels = pd.Categorical(group_labels, categories=expected)
        custom_colors = px.colors.qualitative.Plotly[:len(expected)]
    elif dataset_name == "Combined Autism":
        group_labels = df.apply(get_autism_group, axis=1)
        group_labels = pd.Categorical(group_labels, categories=["Group_Adult", "Group_Child", "Group_Adolescent"])
        custom_colors = px.colors.qualitative.Plotly[:3]
    else:
        group_labels, custom_colors = None, None

    # Visualize the full dataset projection.
    visualize_cpca(dataset_name, cpca_result_full, best_eigvals, error_type='Frobenius', group_labels=group_labels, custom_colors=custom_colors)

    # Visualize the cPCA loadings as a heatmap.
    visualize_cpca_heatmap(dataset_name, best_eigvecs, feature_df.columns, error_type='Frobenius')

    # Append results for later summary.
    comparison_results.append({
        "Dataset": dataset_name,
        "cPCA Best Alpha (Frobenius)": round(fixed_alpha, 4),
        "cPCA Components (Frobenius)": best_n,
        "cPCA Frobenius Error (target)": round(cpca_fro_error, 4),
        "cPCA Full Error": round(full_error, 4),
        "cPCA Full Relative Error": round(full_relative_error, 4)
    })
    print(f"✅ Completed {dataset_name}")

####################################
# Process all datasets.
####################################
# apply_cpca_reduction appends one summary row per dataset to this module-level list.
comparison_results = []
for dataset_name, df in datasets.items():
    apply_cpca_reduction(dataset_name, df)

# Tabulate, display and persist the per-dataset summary.
results_df = pd.DataFrame(comparison_results)
print("\n--- Final cPCA Summary ---", results_df, sep="\n")
results_df.to_csv("cpca_results_summary.csv", index=False)
print("\nAll datasets processed and results saved.")

In [None]:
# cPCA with L1 error

sns.set_theme(style="whitegrid")

# Create the output directory for plots. exist_ok=True replaces the former
# "check os.path.exists, then create" sequence, which could race.
output_dir = "output_plots"
os.makedirs(output_dir, exist_ok=True)

# Assemble datasets into a dictionary.
# Variables autismcombinedcommon, twitch, accent, dim512, and dim1024 must be defined
# by an earlier cell.
datasets = {
    "Combined Autism": autismcombinedcommon,  # dummy columns: Group_Adult, Group_Child; Group_Adolescent is implied when neither is set
    "Twitch": twitch,
    "Accent": accent,                          # dummy columns: language_US, language_FR, language_IT, language_GE, language_UK; language_ES is implied when none is set
    "Dim512": dim512,
    "Dim1024": dim1024
}

####################################
# Helper functions for group labeling for visualization
####################################
def get_autism_group(row):
    """Return the autism age-group label encoded by a row's dummy columns.

    The first of Group_Adult / Group_Child whose value equals 1 wins; when neither
    is set, the row belongs to the implicit "Group_Adolescent" group.
    """
    for group_column in ("Group_Adult", "Group_Child"):
        if row.get(group_column, 0) == 1:
            return group_column
    return "Group_Adolescent"

def get_accent_group(row):
    """Return the accent/language label encoded by a row's dummy columns.

    Columns are checked in a fixed priority order and the first one equal to 1 is
    returned; with none set, the row belongs to the implicit "language_ES" group.
    """
    priority = ("language_US", "language_FR", "language_IT", "language_GE", "language_UK")
    return next((column for column in priority if row.get(column, 0) == 1), "language_ES")

####################################
# Custom regularized pseudo-inverse using SVD (with adjustable tolerance)
####################################
def regularized_pinv(A, tol=1e-1):
    """Pseudo-inverse of A via SVD, zeroing singular values that are not above tol.

    Parameters:
        A (np.ndarray): Matrix to invert.
        tol (float): Singular values <= tol are treated as zero (their reciprocal is dropped).

    Returns:
        np.ndarray: V @ diag(1/s, masked) @ U^T, shaped like A transposed.
    """
    left, singular, right_h = np.linalg.svd(A, full_matrices=False)
    above_tol = singular > tol
    inv_singular = np.zeros_like(singular)
    inv_singular[above_tol] = 1.0 / singular[above_tol]
    return (right_h.T * inv_singular) @ left.T

####################################
# Function: tune over candidate n_components with fixed alpha using L1 error.
####################################
def tune_cpca_n_components(X_target, X_background, fixed_alpha=2.0, max_components=None):
    """Fit contrastive PCA at a fixed alpha and choose the component count by L1 error.

    The contrastive covariance ``cov(X_target) - fixed_alpha * cov(X_background)`` is
    eigen-decomposed and its eigenvectors are sorted by descending eigenvalue. For each
    n in 1..max_components, the target is projected onto the top-n eigenvectors,
    reconstructed with their pseudo-inverse, and scored by the summed absolute residual
    (an entrywise L1 norm). The n with the smallest error is selected; ties go to the
    smaller n.

    Parameters:
        X_target (np.ndarray): Target samples, one observation per row.
        X_background (np.ndarray): Background samples with the same columns.
        fixed_alpha (float): Weight applied to the background covariance.
        max_components (int | None): Largest n to evaluate. Defaults to the rank of the
            target covariance matrix.

    Returns:
        tuple: (best_n, best_eigvecs, best_eigvals, best_error, L1_errors, n_components_list)
            - best_eigvecs: (n_features, best_n) top eigenvectors.
            - L1_errors / n_components_list: the error (np.inf if pinv failed) for each n tried.

    Raises:
        ValueError: If max_components is smaller than 1, so no candidate can be evaluated.
    """
    # Compute each covariance once; the target one is reused for the default rank.
    cov_target = np.cov(X_target, rowvar=False)
    cov_background = np.cov(X_background, rowvar=False)
    if max_components is None:
        max_components = np.linalg.matrix_rank(cov_target)
    if max_components < 1:
        raise ValueError(
            f"max_components must be >= 1 to tune cPCA, got {max_components} "
            "(is the target covariance rank zero?)")

    contrastive_cov = cov_target - fixed_alpha * cov_background

    # eigh: the contrastive covariance is symmetric. It returns ascending eigenvalues,
    # so reorder to descending.
    eigvals, eigvecs = eigh(contrastive_cov)
    order = np.argsort(eigvals)[::-1]
    eigvals = eigvals[order]
    eigvecs = eigvecs[:, order]

    best_n = 1
    best_error = np.inf
    best_eigvecs = None
    best_eigvals = None
    L1_errors = []
    n_components_list = []

    for n in range(1, max_components + 1):
        basis = eigvecs[:, :n]
        n_components_list.append(n)
        try:
            # Reconstruct target from its projection onto the top-n directions.
            X_reconstructed = (X_target @ basis) @ pinv(basis, rcond=1e-5)
        except np.linalg.LinAlgError:
            L1_errors.append(np.inf)
            continue
        error = np.sum(np.abs(X_target - X_reconstructed))
        L1_errors.append(error)
        if error < best_error:
            best_n = n
            best_error = error
            best_eigvecs = basis
            best_eigvals = eigvals[:n]
    return best_n, best_eigvecs, best_eigvals, best_error, L1_errors, n_components_list

####################################
# Function to select target and background splits.
####################################
def select_cpca_split(df, dataset_name):
    """Split df into (target, background) frames for cPCA.

    "Combined Autism" and "Accent" are split on an indicator column: rows where it
    equals 1 are the target, all other rows the background. Any other dataset gets a
    reproducible random 30% target sample (random_state=42), with the remainder as
    background.
    """
    indicator_columns = {"Combined Autism": "Group_Adult", "Accent": "language_US"}
    indicator = indicator_columns.get(dataset_name)
    if indicator is None:
        sampled = df.sample(frac=0.3, random_state=42)
        return sampled, df.drop(sampled.index)
    return df[df[indicator] == 1], df[df[indicator] != 1]

####################################
# Visualization Functions
####################################
def _group_color_mapping(group_labels, custom_colors):
    """Turn group labels into plot-ready values, a stable level order, and colors.

    Returns:
        values (np.ndarray): Object array with one label per row.
        levels (list): Labels that actually occur. They follow the category order when
            group_labels is categorical, otherwise the order of first appearance.
        color_map (dict | None): {label: color}, assigned by position in the full
            category list. A label therefore keeps its color regardless of which labels
            occur or in which order rows appear. None when custom_colors is empty/None.
    """
    values = np.asarray(group_labels, dtype=object)
    categories = getattr(group_labels, "categories", None)  # pd.Categorical
    if categories is None and isinstance(getattr(group_labels, "dtype", None), pd.CategoricalDtype):
        categories = group_labels.cat.categories  # categorical pd.Series
    if categories is None:
        categories = pd.unique(values)
    categories = list(categories)
    present = set(values.tolist())
    levels = [label for label in categories if label in present]
    color_map = None
    if custom_colors:
        color_map = {label: custom_colors[i % len(custom_colors)] for i, label in enumerate(categories)}
    return values, levels, color_map


def visualize_cpca(dataset_name, cpca_result, cpca_eigvals, error_type='L1', group_labels=None, custom_colors=None):
    """Plot a cPCA projection and save every figure to output_dir.

    Plots produced:
        - 3 or more components: an interactive 3D scatter of the first three.
        - Exactly 2 components: a 2D scatter.
        - Exactly 1 component: a histogram.
        - Always: a scree plot of ``cpca_eigvals``.
        - 2 or more components: a seaborn pairplot of up to the first 5 components.

    When group_labels is given, every plot is colored by group. The level order and
    colors come from the label categories, so the same group gets the same color in
    all plots.

    Parameters:
        dataset_name (str): Used in titles and file names.
        cpca_result (np.ndarray): (n_samples, n_components) projection.
        cpca_eigvals (array-like): Eigenvalues of the selected components.
        error_type (str): Label identifying the scoring metric in titles/file names.
        group_labels (pd.Categorical | pd.Series | None): One label per row of cpca_result.
        custom_colors (list[str] | None): Colors assigned to categories in category order.
    """
    n_components = cpca_result.shape[1]

    group_values = levels = color_map = None
    plotly_group_kwargs = {}
    if group_labels is not None:
        group_values, levels, color_map = _group_color_mapping(group_labels, custom_colors)
        plotly_group_kwargs = {
            'color': 'group',
            'category_orders': {'group': levels},
            'color_discrete_sequence': custom_colors,
        }
        if color_map is not None:
            plotly_group_kwargs['color_discrete_map'] = color_map

    def _projection_frame(k, separator):
        # First k components as a DataFrame, plus the group column when labels exist.
        frame = pd.DataFrame(cpca_result[:, :k], columns=[f'cPCA{separator}{i+1}' for i in range(k)])
        if group_values is not None:
            frame['group'] = group_values
        return frame

    # 3D / 2D scatter, or a histogram for a single component.
    if n_components >= 3:
        fig = px.scatter_3d(_projection_frame(3, ' '), x='cPCA 1', y='cPCA 2', z='cPCA 3',
                            title=f'3D Scatter Plot for {dataset_name} (cPCA {error_type})',
                            **plotly_group_kwargs)
        fig.write_image(os.path.join(output_dir, f"{dataset_name}_cPCA_3D_{error_type}.png"))
        fig.show()
    elif n_components == 2:
        fig = px.scatter(_projection_frame(2, ' '), x='cPCA 1', y='cPCA 2',
                         title=f'2D Scatter Plot for {dataset_name} (cPCA {error_type})',
                         **plotly_group_kwargs)
        fig.write_image(os.path.join(output_dir, f"{dataset_name}_cPCA_2D_{error_type}.png"))
        fig.show()
    elif n_components == 1:
        fig = px.histogram(_projection_frame(1, ' '), x='cPCA 1',
                           title=f'Histogram for {dataset_name} (cPCA - 1 Component {error_type})',
                           **plotly_group_kwargs)
        fig.write_image(os.path.join(output_dir, f"{dataset_name}_cPCA_histogram_{error_type}.png"))
        fig.show()

    # Scree Plot
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, len(cpca_eigvals) + 1), cpca_eigvals, marker='o')
    plt.xlabel("Component")
    plt.ylabel("Eigenvalue")
    plt.title(f"Scree Plot for {dataset_name} (cPCA {error_type})")
    plt.grid(True)
    plt.savefig(os.path.join(output_dir, f"{dataset_name}_cPCA_scree_{error_type}.png"), bbox_inches='tight')
    plt.show()

    # 2D Pairwise Scatter Plots with Coloring
    if n_components >= 2:
        n = min(n_components, 5)
        pair_df = _projection_frame(n, '')
        if group_values is not None:
            pairplot = sns.pairplot(pair_df, hue='group', hue_order=levels, diag_kind="kde",
                                    palette=color_map if color_map is not None else custom_colors)
        else:
            pairplot = sns.pairplot(pair_df, diag_kind="kde")
        plt.suptitle(f"{dataset_name}_cPCA_pairplot_{error_type}", y=1.02)
        pairplot.fig.savefig(os.path.join(output_dir, f"{dataset_name}_cPCA_pairplot_{error_type}.png"), bbox_inches='tight')
        plt.show()

def visualize_cpca_errors(dataset_name, L1_errors_dict, n_components_list):
    """Plot L1 reconstruction error against the number of cPCA components.

    Parameters:
        dataset_name (str): Used in the title and output file name.
        L1_errors_dict (dict): Maps alpha -> error. Each value is drawn as follows:
            - A scalar is drawn as a dashed horizontal reference line (the previous behavior).
            - A sequence must hold one error per entry of n_components_list (for example
              the ``L1_errors`` list returned by tune_cpca_n_components). It is drawn
              as an error curve.
        n_components_list (list[int]): Component counts on the x axis.

    Raises:
        ValueError: If a per-component error sequence does not match
            n_components_list in length.
    """
    plt.figure(figsize=(12, 5))
    for alpha, errors in L1_errors_dict.items():
        if np.ndim(errors) == 0:
            plt.plot(n_components_list, [errors] * len(n_components_list),
                     label=f'L1 α={alpha}', linestyle='--')
            continue
        errors = list(errors)
        if len(errors) != len(n_components_list):
            raise ValueError(
                f"alpha={alpha}: got {len(errors)} errors for {len(n_components_list)} component counts")
        plt.plot(n_components_list, errors, label=f'L1 α={alpha}', marker='o')
    plt.xlabel("Number of Components")
    plt.ylabel("L1 Reconstruction Error")
    plt.title(f"cPCA Errors vs Components for {dataset_name} (L1)")
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(output_dir, f"{dataset_name}_cPCA_errors_L1.png"), bbox_inches='tight')
    plt.show()

####################################
# New: Visualization Function for cPCA Heatmap of Feature Loadings
####################################
def visualize_cpca_heatmap(dataset_name, cpca_loadings, feature_names, error_type='L1'):
    """
    Visualizes a heatmap of cPCA loadings (at most the first 5 components) and saves it.

    Parameters:
        dataset_name (str): Name of the dataset.
        cpca_loadings (np.ndarray): The eigenvectors (loadings) computed by cPCA.
                                    Expected shape is (n_features, n_components); a 1-D
                                    vector of length n_features is treated as one component.
        feature_names (list or pd.Index): Names of the features (after dropping dummy columns).
        error_type (str): A label to indicate which error type was used (e.g., 'L1').
    """
    loadings = np.asarray(cpca_loadings)
    if loadings.ndim == 1:
        loadings = loadings[:, np.newaxis]
    # Transpose so that rows are components and columns are features.
    loadings_matrix = loadings.T  # (n_components, n_features)
    n_to_plot = min(loadings_matrix.shape[0], 5)  # Plot only the first 5 components for clarity

    plt.figure(figsize=(max(10, len(feature_names) * 0.4), 6))
    sns.heatmap(
        loadings_matrix[:n_to_plot, :],
        cmap='coolwarm',
        center=0,  # loadings are signed: pin zero to the neutral color of the diverging map
        xticklabels=feature_names,
        yticklabels=[f'cPCA {i+1}' for i in range(n_to_plot)]
    )
    plt.title(f'{dataset_name} (cPCA {error_type}) Feature Loadings')
    plt.xlabel("Features")
    plt.ylabel("Component")

    filename = os.path.join(output_dir, f"{dataset_name}_cPCA_heatmap_{error_type}.png")
    plt.savefig(filename, bbox_inches='tight')
    plt.show()

####################################
# MAIN cPCA REDUCTION FUNCTION (using L1 error)
####################################
def apply_cpca_reduction(dataset_name, df):
    """Run L1-scored cPCA on one dataset, plot the results and record a summary row.

    Steps:
        1. Drop the dummy group columns ("language_*" for Accent, "Group_*" otherwise)
           from the features.
        2. Split target/background with select_cpca_split.
        3. Tune the component count at fixed alpha = 2.6.
        4. Project and reconstruct the full dataset.
        5. Plot, and append the metrics to the module-level ``comparison_results`` list.

    Reported full-dataset metrics:
        - "cPCA Full Error": the mean absolute reconstruction error.
        - "cPCA Full Relative Error": sum|X - X_hat| / sum|X|. This is equivalently the
          MAE divided by mean|X|. Both numerator and denominator are now sums; before,
          a mean was divided by a sum.
    """
    print(f"\n--- Processing {dataset_name} ---")
    # Exclude dummy columns from cPCA computation.
    dummy_prefix = "language_" if dataset_name == "Accent" else "Group_"
    dummy_cols = [col for col in df.columns if col.startswith(dummy_prefix)]
    feature_df = df.drop(columns=dummy_cols)
    X_full = feature_df.values

    # select_cpca_split already implements the Autism / Accent / random splits.
    target_df, background_df = select_cpca_split(df, dataset_name)
    X_target = target_df.drop(columns=dummy_cols).values
    X_background = background_df.drop(columns=dummy_cols).values

    # Tune cPCA over candidate n_components only, fixing alpha.
    fixed_alpha = 2.6
    max_components = np.linalg.matrix_rank(np.cov(X_target, rowvar=False))
    best_n, best_eigvecs, best_eigvals, cpca_L1_error, _, _ = tune_cpca_n_components(
        X_target, X_background, fixed_alpha=fixed_alpha, max_components=max_components)
    print(f"  ✅ cPCA (target) selected {best_n} components with alpha = {fixed_alpha} and target error = {cpca_L1_error:.4f}")

    # Project the full dataset using the learned eigenvectors.
    cpca_result_full = X_full @ best_eigvecs
    X_reconstructed_full = cpca_result_full @ pinv(best_eigvecs, rcond=1e-5)
    full_error = mean_absolute_error(X_full, X_reconstructed_full)
    # Relative L1 error: total absolute residual over total absolute signal.
    full_relative_error = np.sum(np.abs(X_full - X_reconstructed_full)) / np.sum(np.abs(X_full))
    print(f"  Full dataset reconstruction error: {full_error:.4f}")
    print(f"  Full dataset relative reconstruction error: {full_relative_error:.4f}")

    # For visualization: derive group labels.
    if dataset_name == "Accent":
        expected = ["language_US", "language_FR", "language_ES", "language_IT", "language_GE", "language_UK"]
        group_labels = df.apply(get_accent_group, axis=1)
        group_labels = pd.Categorical(group_labels, categories=expected)
        custom_colors = px.colors.qualitative.Plotly[:len(expected)]
    elif dataset_name == "Combined Autism":
        group_labels = df.apply(get_autism_group, axis=1)
        group_labels = pd.Categorical(group_labels, categories=["Group_Adult", "Group_Child", "Group_Adolescent"])
        custom_colors = px.colors.qualitative.Plotly[:3]
    else:
        group_labels, custom_colors = None, None

    # Visualize the full dataset projection.
    visualize_cpca(dataset_name, cpca_result_full, best_eigvals, error_type='L1', group_labels=group_labels, custom_colors=custom_colors)

    # Visualize the cPCA loadings as a heatmap.
    visualize_cpca_heatmap(dataset_name, best_eigvecs, feature_df.columns, error_type='L1')

    # Append results for later summary.
    comparison_results.append({
        "Dataset": dataset_name,
        "cPCA Best Alpha (L1)": round(fixed_alpha, 4),
        "cPCA Components (L1)": best_n,
        "cPCA L1 Error (target)": round(cpca_L1_error, 4),
        "cPCA Full Error": round(full_error, 4),
        "cPCA Full Relative Error": round(full_relative_error, 4)
    })
    print(f"✅ Completed {dataset_name}")

####################################
# Process all datasets.
####################################
# apply_cpca_reduction appends one summary row per dataset to this module-level list.
comparison_results = []
for dataset_name, df in datasets.items():
    apply_cpca_reduction(dataset_name, df)

# Tabulate, display and persist the per-dataset summary.
results_df = pd.DataFrame(comparison_results)
print("\n--- Final cPCA Summary ---", results_df, sep="\n")
results_df.to_csv("cpca_results_summary_L1.csv", index=False)
print("\nAll datasets processed and results saved.")