In [None]:
import os
import re
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Root of the increase_n results for ONE graph type. Its last path component
# (ER, SF or NWS) is used as the graph-type label in the plot title and file name.
# Other options:
#   base_dir = '/share/amine.mcharrak/cyclic_data_final/experiment_results/increase_n/ER'
#   base_dir = '/share/amine.mcharrak/cyclic_data_final/experiment_results/increase_n/SF'
base_dir = '/share/amine.mcharrak/cyclic_data_final/experiment_results/increase_n/NWS'

# normpath strips a trailing separator, so '.../NWS/' still yields 'NWS'
# (a plain split('/')[-1] would return '' in that case).
graph_type = os.path.basename(os.path.normpath(base_dir))

# Extracts the sample size n from result file names such as
# 'experiment_results_n500.json'.
n_pattern = re.compile(r'experiment_results_n(\d+)\.json')

# Metric to plot; one of "SHD_cyclic", "CSS", "Cycle_KLD", "Cycle_F1".
metric = "Cycle_F1"

# The aggregation function is applied twice: first over the runs inside one
# JSON file, then over the seed folders. Every metric currently uses the median.
# To aggregate a metric differently, change its entry here (e.g. np.max).
# Metrics not listed fall back to the median.
aggregation_funcs = {
    "CSS": np.median,
    "SHD_cyclic": np.median,
    "Cycle_KLD": np.median,
    "Cycle_F1": np.median,
}
aggregation_func = aggregation_funcs.get(metric, np.median)

def _clean_metric_value(raw, metric_name):
    """Convert one raw metric entry into the value to aggregate.

    Returns a finite float, or None when the entry must be skipped. An entry is
    skipped when it is missing, cannot be converted to float, or is non-finite
    (inf or NaN). For CSS, which is plotted on a log scale, it is also skipped
    when it is not strictly positive; otherwise its natural log is returned.
    """
    if raw is None:
        return None
    try:
        value = float(raw)
    except (TypeError, ValueError, OverflowError):
        return None
    # isfinite rejects NaN as well as +/-inf. json.load accepts the NaN literal,
    # and a single NaN would make np.median of all of a method's runs NaN.
    if not np.isfinite(value):
        return None
    if metric_name == "CSS":
        if value <= 0:
            return None
        value = np.log(value)
    return value


records = []

# Each sub-folder of base_dir holds one seed. sorted() makes the traversal,
# and therefore row and legend order, independent of filesystem listing order.
for folder in sorted(os.listdir(base_dir)):
    folder_path = os.path.join(base_dir, folder)
    if not os.path.isdir(folder_path):
        continue
    # One JSON file per sample size inside each seed folder. fullmatch rejects
    # names that only start like a result file (e.g. '...n100.json.json').
    for fname in sorted(os.listdir(folder_path)):
        m_n = n_pattern.fullmatch(fname)
        if not m_n:
            continue
        sample_size = int(m_n.group(1))
        json_path = os.path.join(folder_path, fname)
        try:
            with open(json_path, 'r') as f:
                results = json.load(f)
        except (OSError, ValueError) as ex:
            # ValueError also covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Error reading {json_path}: {ex}")
            continue

        # Group this file's usable metric values by method.
        method_dict = {}
        for res in results:
            val = _clean_metric_value(res.get(metric), metric)
            if val is None:
                continue
            method_dict.setdefault(res.get("method"), []).append(val)

        # One record per (seed folder, n, method); seeds are combined later.
        # Every list here holds at least one value, since lists are only created
        # on append.
        for mth, vals in method_dict.items():
            records.append({
                "n": sample_size,
                "method": mth,
                "agg_value": aggregation_func(vals),
            })

# One row per (seed folder, n, method) with that seed's aggregated value.
df = pd.DataFrame(records)
if df.empty:
    raise ValueError("No valid data found.")

# Combine seeds: for each (n, method), aggregate the per-seed values.
# NaN rows are dropped first, and the aggregation callable is applied to a plain
# array. Passing a NumPy reduction such as np.median straight to .agg() makes
# recent pandas substitute its own NaN-skipping GroupBy.median, with a
# FutureWarning that the callable will later be used directly. Doing this
# explicitly keeps the NaN-skipping result and honours any callable.
df_grouped = (
    df.dropna(subset=['agg_value'])
    .groupby(['n', 'method'])['agg_value']
    .agg(lambda vals: aggregation_func(vals.to_numpy()))
    .reset_index()
)

# Sorted list of the sample sizes present in the aggregated data.
unique_ns = sorted(df_grouped['n'].unique())

# One line per method: x = sample size n, y = metric aggregated across seeds.
fig, ax = plt.subplots(figsize=(8, 6))
methods = df_grouped['method'].unique()
colors = plt.cm.tab10.colors  # reused cyclically if there are more methods than colors

for idx, method_name in enumerate(methods):
    curve = df_grouped.loc[df_grouped['method'] == method_name].sort_values(by='n')
    ax.plot(
        curve['n'],
        curve['agg_value'],
        marker='o',
        linestyle='-',
        color=colors[idx % len(colors)],
        label=method_name,
    )

# CSS values were log-transformed at load time, so the axis label reflects that.
y_label = f"log({metric})" if metric == "CSS" else metric

ax.set_title(f"{metric} vs Sample Size for {graph_type}")
ax.set_xlabel("Sample Size (n)")
ax.set_ylabel(y_label)
ax.grid(True)
ax.legend()
plt.tight_layout()

# Save the figure as a PDF under ~/pcax/examples/compare_cyclic_plots, then
# close it (the figure is written to disk, not left open).
output_dir = os.path.expanduser('~/pcax/examples/compare_cyclic_plots')
os.makedirs(output_dir, exist_ok=True)
pdf_name = f'compare_cyclic_increase_n_{graph_type}_{metric}.pdf'
plt.savefig(os.path.join(output_dir, pdf_name))
plt.close()
