In [None]:
import pandas as pd
import numpy as np
from preprocess import load_volcorr, preprocess_volcorr, load_total_df, preprocess_total_df, load_globalminds, preprocess_globalminds
from cluster_eval import evaluate_all
from visualize import visualize_by_config, cluster_heatmap

In [None]:
# Per-dataset configuration:
#   key -> (loader, preprocessor, display name passed to the plotting helpers,
#           column in the raw frame that holds each row's identifier).
DATASETS = {
    'volcorr': (load_volcorr, preprocess_volcorr, 'volcorr', 'exchange'),
    'total_df': (load_total_df, preprocess_total_df, 'total_df', 'exchange'),
    'globalminds': (load_globalminds, preprocess_globalminds, 'globalminds', 'country'),
}

# Directory that receives one ``stats_df_<key>.csv`` file per dataset.
OUTPUT_DIR = "/mnt/nas/project/crypto/data/sandbox"

# Screening criteria applied to every dataset's clustering statistics before
# visualization. These used to be copy-pasted once per dataset, and the
# 'volcorr' copy contained the typo ' tsne_dbscan' (leading space), which
# silently dropped all tsne_dbscan rows for that dataset.
SCREEN_METHODS = ('pca_kmeans', 'tsne_dbscan', 'umap_hdbscan')
SCREEN_MIN_K = 5
SCREEN_MIN_SIL = 0.3


def screen_stats_df(stats_df, methods=SCREEN_METHODS, min_k=SCREEN_MIN_K, min_sil=SCREEN_MIN_SIL):
    """Return the subset of clustering statistics worth visualizing.

    Keeps rows whose ``method`` is one of *methods*, whose ``k`` is strictly
    greater than *min_k*, and whose ``sil`` is strictly greater than *min_sil*.

    Args:
        stats_df: DataFrame with at least ``method``, ``k`` and ``sil`` columns.
        methods: Iterable of method names to keep.
        min_k: Exclusive lower bound on ``k``.
        min_sil: Exclusive lower bound on ``sil``.

    Returns:
        A new DataFrame; *stats_df* itself is not modified.
    """
    # A single combined mask selects the same rows as applying the three
    # filters one after another.
    mask = (
        stats_df['method'].isin(list(methods))
        & (stats_df['k'] > min_k)
        & (stats_df['sil'] > min_sil)
    )
    return stats_df[mask].copy()


for key, (load_fn, pre_fn, data_name, id_col) in DATASETS.items():
    df = load_fn()
    X_df = pre_fn(df)
    # Use the raw id column when present, else the preprocessed frame's index.
    # NOTE(review): df[id_col] has len(df) while X has len(X_df); confirm
    # pre_fn never drops or reorders rows, otherwise ids and X are misaligned.
    ids = df[id_col] if id_col in df.columns else X_df.index
    X = X_df.to_numpy()

    summary, stats_df, all_members = evaluate_all(X, ids.to_numpy())

    # Persist the full, unscreened statistics for this dataset.
    output_path = f"{OUTPUT_DIR}/stats_df_{key}.csv"
    stats_df.to_csv(output_path, index=False)

    screened_stats_df = screen_stats_df(stats_df)

    visualize_by_config(df, screened_stats_df, pre_fn, data_name, id_col=id_col)
    cluster_heatmap(df, screened_stats_df, pre_fn, data_name, id_col=id_col)
