In [None]:
import json
import os
import shutil
import warnings

import matplotlib.pyplot as plt
import pandas as pd

# Render figure text with LaTeX when a TeX toolchain is available.
# With text.usetex on, matplotlib runs `latex` (and `dvipng` for PNG output,
# which the inline backend uses). If either is missing, drawing any figure
# raises, so fall back to matplotlib's built-in text rendering instead of
# breaking Restart & Run All. The serif font family is kept either way.
_HAVE_TEX = all(shutil.which(tool) is not None for tool in ('latex', 'dvipng'))
if not _HAVE_TEX:
    warnings.warn("LaTeX toolchain (latex/dvipng) not found; rendering text without usetex.")
plt.rcParams['text.usetex'] = _HAVE_TEX
plt.rcParams['font.family'] = 'serif'

def load_metrics(run_paths):
    """Load each run's ``metrics.json`` into a dict keyed by run name.

    Parameters
    ----------
    run_paths : dict[str, str]
        Maps a run name to the directory expected to contain ``metrics.json``.

    Returns
    -------
    dict
        ``{run_name: parsed_json}`` for every run whose ``metrics.json`` exists.
        Runs without the file are skipped with a ``UserWarning`` rather than
        raising, so a partially finished sweep can still be inspected.
    """
    all_data = {}
    for run_name, run_path in run_paths.items():
        metrics_file = os.path.join(run_path, 'metrics.json')
        if not os.path.isfile(metrics_file):
            # Make the skip visible: a silently missing run is easy to miss
            # in the plots and summary table built from this dict.
            warnings.warn(
                f"No metrics.json for run {run_name!r} at {metrics_file}; skipping.",
                stacklevel=2,
            )
            continue
        with open(metrics_file, 'r', encoding='utf-8') as f:
            all_data[run_name] = json.load(f)
    return all_data

# Runs to compare: each entry maps a run name to the directory holding its
# metrics.json, e.g. {'config_1': './experiments/run_20231001_123455/config_1', ...}.
# Point RUN_DIR at a different sweep to compare another set of configs.
RUN_DIR = './experiments/run_20250107_140754'
CONFIG_NAMES = ('config_1', 'config_2')

run_paths = {config: f'{RUN_DIR}/{config}' for config in CONFIG_NAMES}

metrics_dict = load_metrics(run_paths)

# Characters that LaTeX treats specially in text mode. When text.usetex is on,
# a run name such as 'config_1' in a title would otherwise make LaTeX fail
# ("Missing $ inserted") when the figure is drawn. Backslash, ~ and ^ are
# not handled here because they need commands rather than a simple prefix.
_LATEX_ESCAPES = {
    '&': r'\&', '%': r'\%', '$': r'\$', '#': r'\#',
    '_': r'\_', '{': r'\{', '}': r'\}',
}


def _tex_safe(text):
    """Return ``text`` escaped for LaTeX if ``text.usetex`` is enabled, else unchanged."""
    if not plt.rcParams['text.usetex']:
        return text
    return ''.join(_LATEX_ESCAPES.get(ch, ch) for ch in text)


# One figure per run. The left panel shows agreement metrics (vs ground truth
# and vs the previous record); the right panel shows cluster-quality indices.
# Both are plotted against the row index of the run's metrics.json.
for run_name, m in metrics_dict.items():
    df = pd.DataFrame(m)
    title_name = _tex_safe(run_name)
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    axes[0].plot(df['nmi_true'], label='NMI vs GT')
    axes[0].plot(df['ari_true'], label='ARI vs GT')
    axes[0].plot(df['nmi_prev'], label='NMI vs Prev')
    axes[0].legend()
    axes[0].set_title(f'{title_name} Clustering Metrics')

    axes[1].plot(df['silhouette'], label='Silhouette')
    axes[1].plot(df['dbi'], label='DBI')
    axes[1].legend()
    axes[1].set_title(f'{title_name} Cluster Quality')

    for ax in axes:
        # Row index of metrics.json; presumably one row per iteration/epoch
        # (TODO confirm against the code that writes metrics.json).
        ax.set_xlabel('Index')

    fig.suptitle(f'Metrics for {title_name}')
    plt.show()
    # Release the figure explicitly so many runs don't accumulate open figures.
    plt.close(fig)


In [None]:
# Best value of each metric per run, taken over all records in metrics.json.
# NMI, ARI and silhouette are better when higher. The Davies–Bouldin index
# (DBI) is better when lower, so its minimum is taken.
# Dict order fixes the column order of the summary table.
BEST_AGG = {
    'nmi_true': 'max',
    'ari_true': 'max',
    'silhouette': 'max',
    'dbi': 'min',
}

best_metrics = {}
for run_name, m in metrics_dict.items():
    df = pd.DataFrame(m)
    best = df.agg(BEST_AGG)  # Series indexed by metric column
    best_metrics[run_name] = {f'best_{col}': best[col] for col in BEST_AGG}

# One row per run, one column per best_* metric.
best_metrics_df = pd.DataFrame(best_metrics).T
# Bare last expression so Jupyter renders a rich table instead of print() text.
best_metrics_df