> __Purpose:__ This NB tests a CNN in the agglomerative model clustering procedure. With later finetuning. Uses the previously developed PyTorch code

In [1]:
import pandas as pd
import pickle
import numpy as np
from sklearn.model_selection import KFold
from sklearn.preprocessing import LabelEncoder
from sklearn.cross_decomposition import CCA
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns

np.random.seed(42) 

from moments_engr import *
from agglo_model_clust import *
from DNN_FT_funcs import *

In [2]:
path1 = 'C:\\Users\\kdmen\\Box\\Meta_Gesture_2024\\saved_datasets\\filtered_datasets\\$BStand_EMG_df.pkl'

with open(path1, 'rb') as file:
    raw_userdef_data_df = pickle.load(file)  # (204800, 19)

print(raw_userdef_data_df.shape)
raw_userdef_data_df.head()

(204800, 19)


Unnamed: 0,Participant,Gesture_ID,Gesture_Num,EMG1,EMG2,EMG3,EMG4,EMG5,EMG6,EMG7,EMG8,EMG9,EMG10,EMG11,EMG12,EMG13,EMG14,EMG15,EMG16
0,P102,pan,1,-0.362743,-0.801651,-0.383077,-0.195299,-0.203047,-0.464472,-0.276292,-0.026736,-0.87387,-1.036152,-0.58093,-0.719494,-0.502255,-1.750091,-0.127847,-0.094192
1,P102,pan,1,-0.351553,-0.775334,-0.382545,-0.154773,-0.131977,-0.295204,-0.125822,0.089679,-0.816215,-2.082635,-0.006283,-0.139439,-0.367764,-0.208084,-0.111811,-0.039009
2,P102,pan,1,-0.380825,-0.762588,-0.398388,-0.085411,0.017528,-0.205675,-0.068451,0.117076,-0.668221,-3.403064,-0.52603,-0.478294,-0.300443,0.203266,0.1133,0.004728
3,P102,pan,1,-0.366795,-0.765464,-0.374423,-0.073225,0.183172,0.009277,-0.058907,0.080977,-0.424416,-3.709413,-0.570894,-0.775155,-0.14471,-0.619539,0.146499,0.199975
4,P102,pan,1,-0.245578,-0.761283,-0.303976,-0.081947,0.224996,0.103319,-0.003929,0.041526,-0.01653,-4.07515,-0.12771,2.682791,-0.14175,-0.208404,-0.035642,0.172662


In [3]:
# STEP 1: Train a classification model on every single individual user

userdef_df = raw_userdef_data_df.groupby(['Participant', 'Gesture_ID', 'Gesture_Num']).apply(create_feature_vectors)
#output is df with particpant, gesture_ID, gesture_num and feature (holds 80 len vector)
userdef_df = userdef_df.reset_index(drop=True)

#convert Gesture_ID to numerical with new Gesture_Encoded column
label_encoder = LabelEncoder()
userdef_df['Gesture_Encoded'] = label_encoder.fit_transform(userdef_df['Gesture_ID'])

label_encoder2 = LabelEncoder()
userdef_df['Cluster_ID'] = label_encoder2.fit_transform(userdef_df['Participant'])

print(userdef_df.shape)
userdef_df.head()

(3200, 6)


Unnamed: 0,Participant,Gesture_ID,Gesture_Num,feature,Gesture_Encoded,Cluster_ID
0,P004,close,1,"[[6.079045311063784], [-7.551458873254243], [-...",0,0
1,P004,close,10,"[[5.994789910363704], [-7.978871468164499], [-...",0,0
2,P004,close,2,"[[6.010193380499154], [-7.7063875553339], [-20...",0,0
3,P004,close,3,"[[5.8212078257286874], [-7.463908156909893], [...",0,0
4,P004,close,4,"[[5.974675085061773], [-7.945111601415482], [-...",0,0


In [4]:
all_participants = userdef_df['Participant'].unique()
# Shuffle the participants
np.random.shuffle(all_participants)
# Split into two groups
#train_participants = all_participants[:24]  # First 24 participants
test_participants = all_participants[24:]  # Remaining 8 participants

In [5]:
# Prepare data
data_splits = prepare_data(
    userdef_df, 'feature', 'Gesture_Encoded', 
    all_participants, test_participants, 
    training_trials_per_gesture=8, finetuning_trials_per_gesture=3,
)

In [6]:
data_splits.keys()

dict_keys(['train', 'intra_subject_test', 'novel_trainFT', 'cross_subject_test'])

In [18]:
features_df = pd.DataFrame(data_splits['train']['features'])
# Create a new column 'features' that contains all 80 columns as lists
features_df['features'] = features_df.apply(lambda row: row.tolist(), axis=1)
# Keep only the new combined column
features_df = features_df[['features']]
# Combine with labels and participant_ids into a single DataFrame
train_df = pd.concat([features_df, pd.Series(data_splits['train']['labels'], name='Gesture_Encoded'), pd.Series(data_splits['train']['participant_ids'], name='participant_ids')], axis=1)
label_encoder = LabelEncoder()
train_df['Cluster_ID'] = label_encoder.fit_transform(train_df['participant_ids'])

features_df = pd.DataFrame(data_splits['intra_subject_test']['features'])
# Create a new column 'features' that contains all 80 columns as lists
features_df['features'] = features_df.apply(lambda row: row.tolist(), axis=1)
# Keep only the new combined column
features_df = features_df[['features']]
# Combine with labels and participant_ids into a single DataFrame
test_df = pd.concat([features_df, pd.Series(data_splits['intra_subject_test']['labels'], name='Gesture_Encoded'), pd.Series(data_splits['intra_subject_test']['participant_ids'], name='participant_ids')], axis=1)
label_encoder = LabelEncoder()
test_df['Cluster_ID'] = label_encoder.fit_transform(test_df['participant_ids'])

# ENTIRELY WITHHOLDING CROSS CLUSTER DATASET (NOVEL TEST SUBJECTS) FOR NOW. 
#test_df
#features_df = pd.DataFrame(data_splits['train']['features'])
## Create a new column 'features' that contains all 80 columns as lists
#features_df['features'] = features_df.apply(lambda row: row.tolist(), axis=1)
## Keep only the new combined column
#features_df = features_df[['features']]
## Combine with labels and participant_ids into a single DataFrame
#train_df = pd.concat([features_df, pd.Series(data_splits['train']['labels'], name='Gesture_Encoded'), pd.Series(data_splits['train']['participant_ids'], name='participant_ids')], axis=1)

In [19]:
data_dfs_dict = {'train':train_df, 'test':test_df}

# Need to update Cluster_ID col at the end of each round, for both dfs

In [20]:
train_df.head()

Unnamed: 0,features,Gesture_Encoded,participant_ids,Cluster_ID
0,"[4.0728441780827485, -6.216633410388578, -15.0...",0,P128,22
1,"[4.842766138855771, -7.933026853320965, -18.38...",0,P128,22
2,"[4.343725560839635, -6.20391996380322, -15.560...",0,P128,22
3,"[4.972292120731378, -6.996600730998998, -17.60...",0,P128,22
4,"[4.942130684716379, -6.513267267205743, -16.81...",0,P128,22


In [29]:
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import StratifiedKFold
import torch
from sklearn.metrics import accuracy_score


def train_and_cv_DNN_cluster_model(train_df, model_type, cluster_ids, 
                                   cluster_column='Cluster_ID', feature_column='features', 
                                   target_column='Gesture_Encoded', n_splits=3, bs=32, 
                                   lr=0.001, criterion=nn.CrossEntropyLoss()):
    """
    Perform k-fold cross-validation for models trained on each cluster in the dataset.
    
    Parameters:
    - userdef_df (DataFrame): The input dataframe with cluster, feature, and target data.
    - model (str or sklearn model object): The model to train. If string, it must be one of:
      ['LogisticRegression', 'SVC', 'RF', 'GradientBoosting', 'KNN'].
    - cluster_ids (list): List of cluster IDs to process.
    - cluster_column (str): Column name representing the cluster IDs.
    - feature_column (str): Column name containing feature arrays.
    - target_column (str): Column name for target labels.
    - n_splits (int): Number of cross-validation splits.
    
    Returns:
    - avg_val_accuracy (float): The average validation accuracy across all folds and clusters.
    """
    
    # Get the model object if a string is provided
    if isinstance(model_type, str):
        # Select model
        if model_type == 'CNN':
            model = CNNModel(input_dim, num_classes).to('cpu')
        elif model_type == 'RNN':
            model = RNNModel(input_dim, num_classes).to('cpu')
        else:
            raise ValueError(f"Unsupported model: {model_type}. Only CNNs and RNNs are supported.")
    else:
        # Assuming a model object was passed in
        base_model = model_type

    total_val_accuracy = 0
    num_folds_processed = 0
    clus_model_lst = []
    for cluster in cluster_ids:
        
        #######################################################################
        #######################################################################
        #######################################################################
        
        # Filter data for the current cluster
        cluster_data = train_df[train_df[cluster_column] == cluster]
        #X = np.array([x.flatten() for x in cluster_data[feature_column]])
        X = np.array([x for x in cluster_data[feature_column]])
        y = np.array(cluster_data[target_column])

        # Stratified K-Fold for validation splits
        ## IDK IF THIS WILL WORK WITH PYTORCH FORMATTED DATA...
        skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
        cluster_val_accuracy = 0
        for idx, (train_idx, val_idx) in enumerate(skf.split(X, y)):
            X_train, X_val = X[train_idx], X[val_idx]
            y_train, y_val = y[train_idx], y[val_idx]
            
            # Convert to PyTorch tensors
            X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
            X_val_tensor = torch.tensor(X_val, dtype=torch.float32)
            y_train_tensor = torch.tensor(y_train, dtype=torch.long)
            y_val_tensor = torch.tensor(y_val, dtype=torch.long)
            # Create TensorDataset
            train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
            val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
            # Create DataLoader
            train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size=bs, shuffle=False)
            
            fold_model = base_model.__class__(**base_model.init_params)
            optimizer = torch.optim.Adam(fold_model.parameters(), lr=lr)
            
            # Now train your fold_model
            fold_model.train()  # Ensure the model is in training mode
            for epoch in range(num_epochs):
                # Training loop with your X_train and y_train
                for X_batch, y_batch in train_loader:  # Assuming you have a DataLoader
                    optimizer.zero_grad()
                    outputs = fold_model(X_batch)
                    loss = criterion(outputs, y_batch)
                    loss.backward()
                    optimizer.step()

            # Evaluate the model on the validation set
            fold_model.eval()
            fold_predictions = []
            fold_true_labels = []
            with torch.no_grad():
                # Validation loop
                predictions = []
                for X_val in val_loader:  # Assuming you have a validation DataLoader
                    outputs = fold_model(X_val)
                    _, preds = torch.max(outputs, 1)
                    fold_predictions.extend(preds.cpu().numpy())
                    fold_true_labels.extend(y_val.cpu().numpy())
            # Predict on the validation set and calculate accuracy
            #y_pred = fold_model.predict(X_val)
            #cluster_val_accuracy += accuracy_score(y_val, y_pred)
            # Calculate accuracy for the current fold
            fold_accuracy = accuracy_score(fold_true_labels, fold_predictions)
            cluster_val_accuracy += fold_accuracy

            if idx==0:
                # No great way to save the models... I really only want to use 1...
                ## So for now I'll just save the first kfold's model...
                ## Consistently biased but hopefully the val splits are all roughly equivalent
                clus_model_lst.append(fold_model)

        # REWRITE THIS!!!
        # Average accuracy for this cluster
        cluster_val_accuracy /= n_splits
        # I think this really ought to append not add...
        ## TOTAL maintains the acc of the entire process (across all clusters)
        total_val_accuracy += cluster_val_accuracy
        num_folds_processed += 1
        
        #######################################################################
        #######################################################################
        #######################################################################

    # Overall average accuracy across all clusters
    avg_val_accuracy = total_val_accuracy / num_folds_processed
    #print(f"\nOverall Average Validation Accuracy: {avg_val_accuracy:.4f}")
    
    return clus_model_lst

In [30]:
def DNN_agglo_merge_procedure(data_dfs_dict, model_type, n_splits=2):
    """
    Parameters:
    - model (str or sklearn model object): The model to train. If string, it must be one of:
      ['LogisticRegression', 'SVC', 'RF', 'GradientBoosting', 'KNN', 'XGBoost'].
    """
    
    #device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Unique gestures and number of classes
    unique_gestures = np.unique(data_dfs_dict['train']['Gesture_Encoded'])
    num_classes = len(unique_gestures)
    input_dim = len(data_dfs_dict['train']['features'].iloc[0])
    
    # Select model
    if model_type == 'CNN':
        model = CNNModel(input_dim, num_classes).to('cpu')
    elif model_type == 'RNN':
        model = RNNModel(input_dim, num_classes).to('cpu')
        
    train_df = data_dfs_dict['train']
    test_df = data_dfs_dict['test']
        
    # Data structures for logging cluster merging procedure
    merge_log = []  # List of tuples: [(cluster1, cluster2, distance, new_cluster), ...]
    unique_clusters_log = []  # List of lists: [list of unique clusters at each step]
    # Dictionary to store self-performance over iterations
    intra_cluster_performance = {}
    cross_cluster_performance = {}
    # Simulate cluster merging and model performance tracking
    iterations = 0
    # Main loop for cluster merging
    while len(train_df['Cluster_ID'].unique()) > 1:
        print(f"{len(train_df['Cluster_ID'].unique())} Clusters Remaining")
        # Log the current state of clusters
        unique_clusters_log.append(sorted(train_df['Cluster_ID'].unique()))

        # userdef_df continuously changes wrt cluster ID
        ## So... do train_df and test_df continuously change? Cluster ID changes... 
        ## but that... may or may not affect stratificiation in a meaningful way (I'm not using cluster metadata...)
        # These 2 should be the same... it's stratified...
        current_train_cluster_ids = sorted(train_df['Cluster_ID'].unique())
        current_test_cluster_ids = sorted(test_df['Cluster_ID'].unique())
        if current_train_cluster_ids == current_test_cluster_ids:
            current_cluster_ids = current_train_cluster_ids
        else:
            raise ValueError("Train/test Cluster ID lists not the same length... Stratify failed")

        # Train models with logging for specified clusters
        ## UPDATED TO DNN VERSION HERE!
        clus_model_lst = train_and_cv_DNN_cluster_model(train_df, model, current_cluster_ids, n_splits=n_splits)
        # Pairwise test models with logging for specified clusters
        sym_acc_arr = test_models_on_clusters(test_df, clus_model_lst, current_cluster_ids)

        for idx, cluster_id in enumerate(current_cluster_ids):
            cross_acc_sum = 0
            cross_acc_count = 0

            for idx2, cluster_id2 in enumerate(current_cluster_ids):
                if cluster_id not in intra_cluster_performance:
                    intra_cluster_performance[cluster_id] = []  # Initialize list

                if idx == idx2:  # Diagonal, so intra-cluster
                    # Ensure the logic assumption holds
                    if cluster_id != cluster_id2:
                        raise ValueError("This code isn't working as expected...")
                    intra_cluster_performance[cluster_id].append((iterations, sym_acc_arr[idx, idx2]))
                else:  # Non-diagonal, so cross-cluster
                    cross_acc_sum += sym_acc_arr[idx, idx2]
                    cross_acc_count += 1

            # Calculate average cross-cluster accuracy
            if cross_acc_count > 0:
                avg_cross_acc = cross_acc_sum / cross_acc_count
            else:
                avg_cross_acc = None  # Handle the case where no cross-cluster pairs exist
            # Append the average cross-cluster accuracy to all relevant clusters
            if cluster_id not in cross_cluster_performance:
                cross_cluster_performance[cluster_id] = []  # Initialize list
            cross_cluster_performance[cluster_id].append((iterations, avg_cross_acc))

        masked_diag_array = sym_acc_arr.copy()
        np.fill_diagonal(masked_diag_array, 0.0)
        similarity_score = np.max(masked_diag_array)
        max_index = np.unravel_index(np.argmax(masked_diag_array), masked_diag_array.shape)
        row_idx_to_merge = max_index[0]
        col_idx_to_merge = max_index[1]
        # Get actual cluster IDs to merge
        row_cluster_to_merge = current_cluster_ids[row_idx_to_merge]
        col_cluster_to_merge = current_cluster_ids[col_idx_to_merge]

        # Create a new cluster ID for the merged cluster
        new_cluster_id = max(current_cluster_ids) + 1
        #print(f"MERGE: {row_cluster_to_merge, col_cluster_to_merge} @ {similarity_score*100:.2f}. New cluster: {new_cluster_id}")
        # Log the merge
        merge_log.append((iterations, row_cluster_to_merge, col_cluster_to_merge, similarity_score, new_cluster_id))
        # Update the DataFrame with the new merged cluster
        #userdef_df.loc[userdef_df['Cluster_ID'].isin([row_cluster_to_merge, col_cluster_to_merge]), 'Cluster_ID'] = new_cluster_id
        train_df.loc[userdef_df['Cluster_ID'].isin([row_cluster_to_merge, col_cluster_to_merge]), 'Cluster_ID'] = new_cluster_id
        test_df.loc[userdef_df['Cluster_ID'].isin([row_cluster_to_merge, col_cluster_to_merge]), 'Cluster_ID'] = new_cluster_id
        
        # Remove merged clusters from tracking (mark end with None)
        intra_cluster_performance[row_cluster_to_merge].append((iterations, None))
        intra_cluster_performance[col_cluster_to_merge].append((iterations, None))
        cross_cluster_performance[row_cluster_to_merge].append((iterations, None))
        cross_cluster_performance[col_cluster_to_merge].append((iterations, None))

        iterations += 1
    
    return merge_log, intra_cluster_performance, cross_cluster_performance

In [31]:
merge_log, intra_cluster_performance, cross_cluster_performance = DNN_agglo_merge_procedure(data_dfs_dict, "CNN", n_splits=2)

24 Clusters Remaining


AttributeError: 'CNNModel' object has no attribute 'init_params'

In [None]:
merge_log

## INTRA CLUSTER RESULTS

In [None]:
# Visualization
plt.figure(figsize=(12, 6))

for cluster_id in intra_cluster_performance:
    # Extract valid iterations and performance
    data = intra_cluster_performance[cluster_id]
    valid_iterations = [it for it, perf in data if perf is not None]
    valid_performance = [perf for it, perf in data if perf is not None]
    if valid_iterations[0]==0:
        continue
    #print(valid_iterations)
    #print(valid_performance)
    #print()
    plt.plot(valid_iterations, valid_performance, label=f"Cluster {cluster_id}")

plt.xlabel("Iteration", fontsize=18)
plt.ylabel("Intra-Cluster Test Accuracy", fontsize=18)
plt.title(f"{model_str} Intra-Cluster Acc: Merged Clusters Only", fontsize=18)
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

In [None]:
# Visualization
plt.figure(figsize=(12, 6))

for cluster_id in intra_cluster_performance:
    # Extract valid iterations and performance
    data = intra_cluster_performance[cluster_id]
    valid_iterations = [it for it, perf in data if perf is not None]
    valid_performance = [perf for it, perf in data if perf is not None]
    if valid_iterations[0]!=0:
        continue
    #print(valid_iterations)
    #print(valid_performance)
    #print()
    plt.plot(valid_iterations, valid_performance, label=f"Cluster {cluster_id}")

plt.xlabel("Iteration", fontsize=18)
plt.ylabel("Intra-Cluster Test Accuracy", fontsize=18)
plt.title(f"{model_str} Intra-Cluster Acc: Original Clusters Only", fontsize=18)
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

## CROSS CLUSTER RESULTS

In [None]:
# Visualization
plt.figure(figsize=(12, 6))

for cluster_id in cross_cluster_performance:
    # Extract valid iterations and performance
    data = cross_cluster_performance[cluster_id]
    valid_iterations = [it for it, perf in data if perf is not None]
    valid_performance = [perf for it, perf in data if perf is not None]
    if valid_iterations[0]==0:
        continue
    #print(valid_iterations)
    #print(valid_performance)
    #print()
    plt.plot(valid_iterations, valid_performance, label=f"Cluster {cluster_id}")

plt.xlabel("Iteration", fontsize=18)
plt.ylabel("Cross-Cluster Test Accuracy", fontsize=18)
plt.title(f"{model_str} Cross-Cluster Acc: Merged Clusters Only", fontsize=18)
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

In [None]:
# Visualization
plt.figure(figsize=(12, 6))

for cluster_id in cross_cluster_performance:
    # Extract valid iterations and performance
    data = cross_cluster_performance[cluster_id]
    valid_iterations = [it for it, perf in data if perf is not None]
    valid_performance = [perf for it, perf in data if perf is not None]
    if valid_iterations[0]!=0:
        continue
    #print(valid_iterations)
    #print(valid_performance)
    #print()
    plt.plot(valid_iterations, valid_performance, label=f"Cluster {cluster_id}")

plt.xlabel("Iteration", fontsize=18)
plt.ylabel("Cross-Cluster Test Accuracy", fontsize=18)
plt.title(f"{model_str} Cross-Cluster Acc: Original Clusters Only", fontsize=18)
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

In [None]:
# Visualization
plt.figure(figsize=(12, 6))

n = 5

# Reduce number of lines (e.g., top n longest-lived clusters)
longest_clusters = sorted(cross_cluster_performance.keys(), key=lambda k: len(cross_cluster_performance[k]), reverse=True)[:n]

for cluster_id in longest_clusters:
    # Extract valid iterations and performance
    data = cross_cluster_performance[cluster_id]
    valid_iterations = [it for it, perf in data if perf is not None]
    valid_performance = [perf for it, perf in data if perf is not None]
    plt.plot(valid_iterations, valid_performance, label=f"Cluster {cluster_id}")

plt.xlabel("Iteration")
plt.ylabel("Cross-Cluster Test Accuracy")
plt.title(f"Cross-Cluster Test Accuracy: {n} Longest-Lasting Clusters")
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

In [None]:
# Visualization
plt.figure(figsize=(12, 6))

n = 5  # Number of clusters to plot

# Find clusters with the highest final accuracies
highest_final_accuracy_clusters = sorted(
    cross_cluster_performance.keys(), 
    key=lambda k: max([perf for it, perf in cross_cluster_performance[k] if perf is not None], default=0), 
    reverse=True
)[:n]

# Plot the performance curves for these clusters
for cluster_id in highest_final_accuracy_clusters:
    # Extract valid iterations and performance
    data = cross_cluster_performance[cluster_id]
    valid_iterations = [it for it, perf in data if perf is not None]
    valid_performance = [perf for it, perf in data if perf is not None]
    plt.plot(valid_iterations, valid_performance, label=f"Cluster {cluster_id}")

plt.xlabel("Iteration")
plt.ylabel("Cross-Cluster Test Accuracy")
plt.title(f"Cross-Cluster Test Accuracy: {n} Clusters with Highest Final Accuracy")
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()


## Intra-Cluster Test Accuracy Merge Tracking

In [None]:
# Visualization with Merge Log and Connections
plt.figure(figsize=(14, 8))

# Dictionary to track the last valid point for each cluster
last_points = {}

# Plot original clusters
for cluster_id in cross_cluster_performance:
    data = cross_cluster_performance[cluster_id]
    valid_iterations = [it for it, perf in data if perf is not None]
    valid_performance = [perf for it, perf in data if perf is not None]
    if valid_iterations[0] != 0:
        continue

    # Plot original cluster performance
    plt.plot(valid_iterations, valid_performance, label=f"Cluster {cluster_id}")
    last_points[cluster_id] = (valid_iterations[-1], valid_performance[-1])  # Store the last point

# Handle merged clusters and connect original clusters
for iteration, cluster1, cluster2, _, new_cluster in merge_log:
    # Plot and connect the merged clusters
    for cluster in [cluster1, cluster2]:
        if cluster in cross_cluster_performance:
            data = cross_cluster_performance[cluster]
            merge_perf = next((perf for it, perf in data if it == iteration), None)
            if merge_perf is not None:
                plt.scatter(iteration, merge_perf, color='red', marker='^')#, label=f"Merge {cluster} → {new_cluster}")
                
        if cluster in last_points:  # If it's an original cluster
            last_iteration, last_perf = last_points[cluster]

            # Connect to the newly merged cluster
            if new_cluster in cross_cluster_performance:
                new_data = cross_cluster_performance[new_cluster]
                valid_iterations = [it for it, perf in new_data if perf is not None and it >= iteration]
                valid_performance = [perf for it, perf in new_data if perf is not None and it >= iteration]

                if valid_iterations:
                    # Draw a line connecting the original cluster to the new merged cluster
                    plt.plot(
                        [last_iteration, valid_iterations[0]],
                        [last_perf, valid_performance[0]],
                        linestyle='--', color='gray'
                    )

                    # Continue plotting the merged cluster's performance
                    plt.plot(valid_iterations, valid_performance, linestyle='--')

                # Update the last points for the newly merged cluster
                if valid_iterations:
                    last_points[new_cluster] = (valid_iterations[-1], valid_performance[-1])

# Add labels, legend, and formatting
plt.xlabel("Iteration", fontsize=18)
plt.ylabel("Cross-Cluster Test Accuracy", fontsize=18)
plt.title(f"{model_str} Cross-Cluster Acc with Merge Connections", fontsize=18)
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()


In [None]:
# Visualization with Merge Log and Connections
plt.figure(figsize=(14, 8))

# Dictionary to track the last valid point for each cluster
last_points = {}

# Plot original clusters
for cluster_id in intra_cluster_performance:
    data = intra_cluster_performance[cluster_id]
    valid_iterations = [it for it, perf in data if perf is not None]
    valid_performance = [perf for it, perf in data if perf is not None]
    if valid_iterations[0] != 0:
        continue

    # Plot original cluster performance
    plt.plot(valid_iterations, valid_performance, label=f"Cluster {cluster_id}")
    last_points[cluster_id] = (valid_iterations[-1], valid_performance[-1])  # Store the last point

# Handle merged clusters and connect original clusters
for iteration, cluster1, cluster2, _, new_cluster in merge_log:
    # Plot and connect the merged clusters
    for cluster in [cluster1, cluster2]:
        if cluster in intra_cluster_performance:
            data = intra_cluster_performance[cluster]
            merge_perf = next((perf for it, perf in data if it == iteration), None)
            if merge_perf is not None:
                plt.scatter(iteration, merge_perf, color='red', marker='^')#, label=f"Merge {cluster} → {new_cluster}")
                
        if cluster in last_points:  # If it's an original cluster
            last_iteration, last_perf = last_points[cluster]

            # Connect to the newly merged cluster
            if new_cluster in intra_cluster_performance:
                new_data = intra_cluster_performance[new_cluster]
                valid_iterations = [it for it, perf in new_data if perf is not None and it >= iteration]
                valid_performance = [perf for it, perf in new_data if perf is not None and it >= iteration]

                if valid_iterations:
                    # Draw a line connecting the original cluster to the new merged cluster
                    plt.plot(
                        [last_iteration, valid_iterations[0]],
                        [last_perf, valid_performance[0]],
                        linestyle='--', color='gray'
                    )

                    # Continue plotting the merged cluster's performance
                    plt.plot(valid_iterations, valid_performance, linestyle='--')

                # Update the last points for the newly merged cluster
                if valid_iterations:
                    last_points[new_cluster] = (valid_iterations[-1], valid_performance[-1])

# Add labels, legend, and formatting
plt.xlabel("Iteration", fontsize=18)
plt.ylabel("Intra-Cluster Test Accuracy", fontsize=18)
plt.title(f"{model_str} Intra-Cluster Acc with Merge Connections", fontsize=18)
#plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()


In [None]:
intra_mean_lst, cross_mean_lst, ratio_lst = compute_performance_ratios(intra_cluster_performance, cross_cluster_performance)


In [None]:
plt.plot(intra_mean_lst[31:], label="Intra Mean")
plt.plot(cross_mean_lst[31:], label="Cross Mean")
plt.plot((np.array(ratio_lst)/10)[31:], label="Ratio/10")
plt.xlabel("Iteration", fontsize=18)
plt.ylabel("Mean Accuracy | Ratio/10", fontsize=18)
plt.title(f"{model_str} Summary Statistic Trends", fontsize=18)
plt.legend(loc='upper left')