# Bickley Data Experiment

In [None]:
### Import relevant libraries
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import font_manager
import os
from tqdm import tqdm
import json
import pickle
import seaborn as sns
from helper_functions.bickley_funcs import (align_clusters, dynamic_isoperimetric_score, flatten_timeseries, 
                                            calc_graph_dynamic_isoperim, calc_graph_dynamic_isoperim_4, calc_embed_isoperim)
from helper_functions.bickley_funcs import (create_dataset, create_views, plot_clustering)

# Experiment configuration. The complete dict is written next to the figures
# as sim_params.json so every output folder records the settings that made it.
sim_params = {
    "delete_kernels": False,
    "generate_data": True,
    "animate": True,  # animation generation increases the run time
    "evd_solver": "arpack",  # 'arpack' / 'randomized' / 'svd'
    "ad_methods": [
        "ad", "alternating_roseland", "ffbb", "fbfb", "lead", "forward_only",
        "ncca", "kcca_impute", "kcca", "nystrom", "adm_plus", "backward_only",
    ],
    # reduced alternative: 'ad_methods': ['adm_plus', 'backward_only']
    "embed_dim": 20,
    "clusters": 9,
    "t": 1,
    "scale": 1,
    "sigma_mode": "scale",  # 'median' , 'scale'
    "Nr": 3000,  # number of samples in the reference set
    "N": 4000,  # number of total samples
    "lag": 200,  # lag between trajectories for short_traj mode
    "traj_len": 1,  # for short_traj mode
    # Times for view generation (multi_frame): the first two are for the
    # embedding views, the remaining ones are for evaluating the dynamic
    # isoperimetry score.
    "times": [199, 400, 250, 300],
    "views": "multi_frame",  # 'frame' , 'traj' , 'short_traj', 'multi_frame'
    "cmap": "custom",
    "reps": 10,
}

# Output folder name encodes the main settings; any '.' in the assembled path
# (e.g. from a fractional scale) is replaced by 'p'.
fig_str = "figures/bickley"
run_tag = (
    f"{sim_params['N']}_lag_{sim_params['lag']}_{sim_params['views']}"
    f"_tlen_{sim_params['traj_len']}_scale_{sim_params['scale']}_final_run"
)
figures_path = f"{fig_str}_{run_tag}".replace('.', 'p')
os.makedirs(figures_path, exist_ok=True)

# Short aliases for the two sample counts.
Nr = sim_params["Nr"]  # number of samples in the reference set
N_d = sim_params["N"]

# Persist the configuration alongside the results.
with open(f"{figures_path}/sim_params.json", 'w') as params_file:
    json.dump(sim_params, params_file, indent=4)

In [None]:
# Typeface used for figure text; any other installed serif font name can be
# substituted here.
font_name = "Times New Roman"

# Shared FontProperties object (size 18, family taken from font_name).
font_properties = font_manager.FontProperties(size=18, family=font_name)

In [None]:
from helper_functions.AD_funcs import Create_Transition_Mat, Create_Asym_Tran_Kernel
def compute_kernels(s1_full, s1_ref, s2_full, s2_ref, s3_full, s4_full):
    """Compute all kernel matrices used by one repetition of the experiment.

    Every kernel is built with ``scale=sim_params['scale']`` and
    ``mode=sim_params['sigma_mode']``.

    Args:
        s1_full, s2_full: full sample sets of the first and second view.
        s1_ref, s2_ref: reference subsets of the first and second view.
        s3_full, s4_full: additional sample sets; only used when
            ``sim_params['views'] == 'multi_frame'``.

    Returns:
        Tuple ``(A1, A2, K1, K2, K1_ref, K2_ref, K3, K4)`` where
        ``A1``/``A2`` are the first outputs of ``Create_Asym_Tran_Kernel``
        for (full, ref) of each view, ``K*`` are the first outputs of
        ``Create_Transition_Mat``, and ``K3``/``K4`` are ``None`` unless the
        views mode is ``'multi_frame'``.
    """
    kernel_opts = {'scale': sim_params['scale'], 'mode': sim_params['sigma_mode']}

    def _asymmetric(full_set, ref_set):
        # Only the first of the three returned values is kept.
        matrix, _, _ = Create_Asym_Tran_Kernel(full_set, ref_set, **kernel_opts)
        return matrix

    def _transition(samples):
        # Only the first of the two returned values is kept.
        matrix, _ = Create_Transition_Mat(samples, **kernel_opts)
        return matrix

    A1 = _asymmetric(s1_full, s1_ref)
    A2 = _asymmetric(s2_full, s2_ref)
    K1 = _transition(s1_full)
    K2 = _transition(s2_full)
    K1_ref = _transition(s1_ref)
    K2_ref = _transition(s2_ref)

    if sim_params['views'] == 'multi_frame':
        K3 = _transition(s3_full)
        K4 = _transition(s4_full)
    else:
        K3 = K4 = None

    return A1, A2, K1, K2, K1_ref, K2_ref, K3, K4

# kernels_dict = dict()
# kernels_dict['K1'] = K1
# kernels_dict['K2'] = K2
# kernels_dict['K1_ref'] = K1_ref
# kernels_dict['K2_ref'] = K2_ref
# kernels_dict['A1'] = A1
# kernels_dict['A2'] = A2


# with open(f"{figures_path}/kernels_dict.pkl", 'wb') as fp:
#     pickle.dump(kernels_dict, fp)
#     print('dictionary saved successfully to file')

In [None]:
from helper_functions.bickley_funcs import method_analysis, evaluate_metrics
# Main experiment loop: for every repetition, generate data and views, build
# the kernels once, run every configured method, evaluate it, and store the
# embeddings/clusterings of the repetition in a pickle file.

# Which precomputed kernels each method receives as (K1, K2) in method_analysis:
#   methods in asym_and_ref_methods -> (A1, K2_ref)
#   methods in full_and_ref_methods -> (K1, K2_ref)
#   methods in full_kernel_methods  -> (K1, K2)
#   "lead"                          -> (A1, A2)
asym_and_ref_methods = {"forward_only", "forward_only_slow", "alternating_roseland", "ffbb", "fbfb",
                        "ncca", "kcca", "nystrom", "adm_plus", "backward_only"}
full_and_ref_methods = {"kcca_impute"}
full_kernel_methods = {"ad", "dm", "ad_svd"}

results = []
n_methods = len(sim_params['ad_methods'])
for rep in tqdm(range(sim_params['reps'])):
    embed_dict = dict()
    # create data for rep
    dataset, c = create_dataset(sim_params, figures_path)
    # create views; the reference sets are the first Nr rows of views 1 and 2
    s1_full, s2_full, s3_full, s4_full = create_views(dataset, sim_params)
    s1_ref = s1_full[:Nr, :]
    s2_ref = s2_full[:Nr, :]
    # rep figures path
    rep_path = f'{figures_path}/rep_{rep}'
    os.makedirs(rep_path, exist_ok=True)
    # Reset per repetition; filled with the clustering of the first method
    # below and passed to all later methods of the same repetition.
    c_align = None
    A1, A2, K1, K2, K1_ref, K2_ref, K3, K4 = compute_kernels(s1_full, s1_ref, s2_full, s2_ref, s3_full, s4_full)
    for i, method in enumerate(sim_params['ad_methods']):
        # Double-quoted f-string: reusing the outer quote type inside the
        # braces is a SyntaxError on Python versions before 3.12.
        print(f"now processing method {method}, number {i+1} out of {n_methods}")

        # Select the kernel pair for this method, then make a single call.
        if method in asym_and_ref_methods:
            kernel_1, kernel_2 = A1, K2_ref
        elif method in full_and_ref_methods:
            kernel_1, kernel_2 = K1, K2_ref
        elif method in full_kernel_methods:
            kernel_1, kernel_2 = K1, K2
        elif method == "lead":
            kernel_1, kernel_2 = A1, A2
        else:
            raise ValueError(f"invalid method: {method}")

        embed, c = method_analysis(s1_ref, s1_full, s2_ref, s2_full, method=method,
                                   sim_params=sim_params, figures_path=rep_path,
                                   K1=kernel_1, K2=kernel_2, dataset=dataset,
                                   font_properties=font_properties, c_align=c_align)
        if i == 0:
            c_align = c
        # save clustering and embedding to dictionary
        embed_dict[method] = {'embed': embed, 'clustering': c, 's1': s1_full, 's2': s2_full}
        # evaluate metrics (always with the full kernels K1..K4, regardless of
        # which pair the method itself was given)
        new_line = evaluate_metrics(s1_full, s2_full, c, embed, sim_params, K1, K2, K3, K4, method, rep)
        results.append(new_line)
    # save rep dictionary
    with open(f"{rep_path}/embedding_dictionary.pkl", 'wb') as fp:
        pickle.dump(embed_dict, fp)
        print('dictionary saved successfully to file')

# One row per (method, repetition); the gap is the difference between the two
# dynamic graph isoperimetric score columns.
results_df = pd.DataFrame(results)
results_df['generalization_gap'] = results_df['dynamic_graph_isoperim_score_4'] - results_df['dynamic_graph_isoperim_score']
results_df.to_csv(f'{figures_path}/results.csv')

In [None]:
# Aggregate the per-repetition rows of results_df by method.
grouped_by_method = results_df.groupby('method')

# Mean of every column per method.
avg_results_df = grouped_by_method.mean()
avg_results_df.to_csv(f'{figures_path}/average_results.csv')

# Mean and standard deviation per method for the selected metrics only.
stats_columns = [
    'silhouette_score_embed',
    'dynamic_graph_isoperim_score',
    'dynamic_graph_isoperim_score_4',
    'generalization_gap',
]
stats_df = grouped_by_method[stats_columns].agg(['mean', 'std']).reset_index()
stats_df.to_csv(f'{figures_path}/stats_results.csv')

## Post Analysis

In [None]:
# compute confusion matrices for each method's clustering compared to clustering achieved by ADM
from sklearn.metrics import confusion_matrix
from helper_functions.clustering_funcs import align_clusters_different_rep
from helper_functions.bickley_funcs import plot_hit_or_miss, plot_wrapper
# Settings for the post analysis: which methods get summary plots, which
# method's clustering is the reference, and whether the mean confusion
# matrices are annotated with numbers.
methods_to_plot = ['ad', 'ncca', 'kcca_impute', 'adm_plus', 'nystrom', 'forward_only']
reference_method = 'ad'
annotate = False

cm_path = f'{figures_path}/confusion_matrices'
os.makedirs(cm_path, exist_ok=True)

# method name -> list with one row-normalized confusion matrix per repetition
cm_dict = dict()
for rep in tqdm(range(sim_params['reps'])):
    rep_path = f'{figures_path}/rep_{rep}'
    summary_path = f'{rep_path}/summary'
    os.makedirs(summary_path, exist_ok=True)

    # Reload the embeddings/clusterings pickled for this repetition.
    with open(f"{rep_path}/embedding_dictionary.pkl", 'rb') as fp:
        embed_dict = pickle.load(fp)
    print('Dictionary Loaded Successfully')

    reference_entry = embed_dict[reference_method]
    c_ad = reference_entry['clustering']
    if rep == 0:
        # The first repetition defines the labelling all later ones are aligned to.
        c_align = c_ad
        s1_align = reference_entry['s1']
        s2_align = reference_entry['s2']
    else:
        c_ad = align_clusters_different_rep(s1=s1_align, label1=c_align, s2=reference_entry['s1'],
                                            label2=c_ad, metric='kde_correlation')

    for method in sim_params['ad_methods']:
        method_entry = embed_dict[method]
        # Align this method's labels to the reference clustering before comparing.
        c_curr = align_clusters(c_ad, method_entry['clustering'])

        # Rows are normalized over the reference ("true") labels.
        cm = confusion_matrix(y_true=c_ad, y_pred=c_curr, normalize='true')
        cm_dict.setdefault(method, []).append(cm)

        if method not in methods_to_plot:
            continue
        # Two summary figures per plotted method, drawn on the second view.
        for plot_type in ('hit-or-miss', 'scatter'):
            plot_wrapper(data=method_entry['s2'], c=c_curr, c_ref=c_ad, figures_path=summary_path, method=method,
                         font_properties=font_properties, sim_params=sim_params, cmap='custom', plot_type=plot_type)

# Plot, for every method, the confusion matrix averaged over repetitions and
# save it as a PDF under cm_path.

# Tick labels are the cluster indices 0 .. clusters-1; they are the same for
# every method, so they are built once outside the loop.
class_names = np.arange(sim_params['clusters'])

for method in sim_params['ad_methods']:
    # calculate mean and std confusion matrix over the repetition axis
    cm_array = np.array(cm_dict[method])
    mean_cm = np.mean(cm_array, axis=0)
    std_cm = np.std(cm_array, axis=0)

    fig, ax = plt.subplots(figsize=(8, 8))

    # Optional per-cell text "mean\n±std"; None leaves the cells unannotated.
    if annotate:
        annotations = np.array([[f"{mean:.2f}\n±{std:.2f}" for mean, std in zip(row_mean, row_std)]
                                for row_mean, row_std in zip(mean_cm, std_cm)])
    else:
        annotations = None

    # Draw on the axes created above explicitly instead of relying on
    # pyplot's notion of the "current" axes.
    sns.heatmap(mean_cm, annot=annotations, fmt="", cmap="Blues", xticklabels=class_names,
                yticklabels=class_names, vmin=0, vmax=1, ax=ax)

    ax.set_xlabel('Predicted Sets', fontsize=26)
    ax.set_ylabel("ADM Sets", fontsize=26)
    ax.tick_params(labelsize=20)
    # The last axes of the figure is the colorbar added by the heatmap.
    cax = ax.figure.axes[-1]
    cax.tick_params(labelsize=20)

    fig.savefig(f'{cm_path}/confusion_matrix_{method}.pdf', dpi=500, format='pdf', bbox_inches='tight')

In [None]:
# Notebook inspection cell: display the keys of the most recently assigned
# embed_dict (entries are stored under the method name when it is built).
embed_dict.keys()

## Deeptime Experiment code

In [None]:
# from deeptime.kernels import GaussianKernel
# from deeptime.decomposition import KernelCCA
# from deeptime.clustering import KMeans
#
# method = 'deeptime_kcca'
#
# # create kernel
# sigma = sim_params['scale']
# kernel = GaussianKernel(sigma)
#
# kcca_estimator = KernelCCA(kernel, n_eigs=9, epsilon=1e-3)
# kcca_model = kcca_estimator.fit((s1_full, s2_full)).fetch_model()
# embed = np.real(kcca_model.eigenvectors)
# # perform kmeans
# kmeans = KMeans(n_clusters=sim_params['clusters'], n_jobs=8).fit(embed).fetch_model()
# c = kmeans.transform(embed)
# c = align_clusters(c_align, c)
# if sim_params['animate']:
#     fig, ax = plt.subplots(1, 1, figsize=(6, 4))
#     ani = dataset.make_animation(c=c/sim_params['clusters'], agg_backend=False, interval=75, fig=fig, ax=ax, s=50)
#     ani.save(f'{figures_path}/{method}_clustering_animation.mp4', writer='ffmpeg', fps=30, dpi=300)
# plot_clustering(dataset, c, figures_path, method, cmap=sim_params['cmap'])

In [None]:
# grid1, grid2, score, _, _ = dynamic_isoperimetric_score(s1_full, s2_full, c, grid_size_x=50, grid_size_y=15, bounds=(0, -3, 20, 3))
# plt.figure(figsize=(8, 4))
# # grid1 = create_grid(s2_full, c, grid_size_x=50, grid_size_y=15, bounds=(0, -3, 20, 3))
# plt.imshow(grid2.T, origin='lower', cmap='Dark2')
# plt.colorbar(label='Cluster Label')
# plt.xlabel('X-axis')
# plt.ylabel('Y-axis')
# plt.show()
# print(f'isoperimetric score {score}')

In [None]:
# grid = np.meshgrid(np.linspace(0, 20, 150), np.linspace(-3, 3, 50))
# xy = np.dstack(grid).reshape(-1, 2)
# z = kcca_model.transform(xy).real
# 
# fig = plt.figure(figsize=(12, 10))
# gs = fig.add_gridspec(ncols=2, nrows=3)
# 
# for row in range(3):
#     for col in range(2):
#         ix = col + 2*row
#         ax = fig.add_subplot(gs[row, col])
#         ax.contourf(grid[0], grid[1], z[:, ix].reshape(grid[0].shape), levels=15)
#         ax.set_title(f"Eigenfunction {ix+1}")