# In [1]:
import os
import re
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Root folder holding the increase_d results for ONE graph type.
# To analyse another graph type, point base_dir at its folder instead, e.g.:
#base_dir = '/share/amine.mcharrak/cyclic_data_final/experiment_results/increase_d/ER'
#base_dir = '/share/amine.mcharrak/cyclic_data_final/experiment_results/increase_d/SF'
base_dir = '/share/amine.mcharrak/cyclic_data_final/experiment_results/increase_d/NWS'
# The graph type is the last path component. normpath() strips a trailing
# slash first, so editing base_dir to end with '/' does not produce an empty
# graph_type (which would end up in the plot title and output filename).
graph_type = os.path.basename(os.path.normpath(base_dir))

# Regex to extract d (nodes), type (ER, SF, or NWS), and e (edges) from folder
# names: group 1 = d, group 2 = type, group 3 = e, followed by '_' and any suffix.
pattern = re.compile(r'(\d+)(ER|SF|NWS)(\d+)_.*')

# Choose the metric (key looked up in every result record of the JSON files).
metric = "SHD_cyclic"
#metric = "CSS"
#metric = "Cycle_KLD"
#metric = "Cycle_F1"

# Aggregation function per metric. The same function is applied twice:
# within one seed folder, and then across seed folders.
aggregation_funcs = {
    "CSS": np.median,
    "SHD_cyclic": np.median,
    "Cycle_KLD": np.median,
    "Cycle_F1": np.median,
}
# Unknown metrics fall back to the median as well.
aggregation_func = aggregation_funcs.get(metric, np.median)

# One entry per (seed folder, method): {"d": ..., "method": ..., "agg_value": ...}.
records = []

# Each subfolder corresponds to one seed for a given d value.
# os.listdir() returns entries in arbitrary order; sorting keeps the traversal
# (and therefore the content order of `records`) reproducible between runs.
for folder in sorted(os.listdir(base_dir)):
    folder_path = os.path.join(base_dir, folder)
    if not os.path.isdir(folder_path):
        continue
    m = pattern.match(folder)
    if not m:
        # Folder name does not follow the "<d><type><e>_..." convention.
        continue
    # Extract d value from the folder name.
    d_val = int(m.group(1))
    # (folder_type and e_val are available if needed)
    folder_type = m.group(2)  # "ER", "SF", or "NWS"
    e_val = int(m.group(3))

    # Locate the JSON file. When a folder holds several, take the
    # lexicographically first one so the choice does not depend on the
    # filesystem's listing order.
    json_files = sorted(f for f in os.listdir(folder_path) if f.endswith('.json'))
    if not json_files:
        continue
    json_path = os.path.join(folder_path, json_files[0])
    try:
        with open(json_path, 'r') as f:
            results = json.load(f)
    except (OSError, ValueError) as ex:
        # OSError: unreadable file; ValueError covers json.JSONDecodeError
        # and UnicodeDecodeError. Skip this seed, keep processing the others.
        print(f"Error reading {json_path}: {ex}")
        continue

    # The loop below needs a list of per-run dicts; anything else would
    # fail on .get(), so report it and move on to the next folder.
    if not isinstance(results, list):
        print(f"Error reading {json_path}: expected a list of result records")
        continue

    # Group results by method.
    method_dict = {}
    for res in results:
        if not isinstance(res, dict):
            continue
        mth = res.get("method")
        val = res.get(metric)
        if val is None:
            continue
        try:
            val = float(val)
        except (TypeError, ValueError, OverflowError):
            # Non-numeric or unrepresentable entry: ignore it.
            continue
        # Drop +/-inf AND NaN: a single NaN would turn the whole median
        # for this method into NaN.
        if not np.isfinite(val):
            continue
        # For CSS, use the logarithm (only defined for positive values).
        if metric == "CSS":
            if val <= 0:
                continue
            val = np.log(val)
        method_dict.setdefault(mth, []).append(val)

    # Compute aggregated value for each method using the chosen aggregation
    # function. Every list in method_dict has at least one value, because
    # keys are only created when a value is appended.
    for mth, vals in method_dict.items():
        records.append({
            "d": d_val,
            "method": mth,
            "agg_value": aggregation_func(vals),
        })

# Turn the collected per-folder aggregates into a table; stop if there are none.
df = pd.DataFrame(records)
if df.empty:
    raise ValueError("No valid data found.")

# Collapse the seed folders: one aggregated value per (d, method) pair.
values_by_d_and_method = df.groupby(['d', 'method'])['agg_value']
df_grouped = values_by_d_and_method.agg(aggregation_func).reset_index()

# Distinct d values present in the data, in ascending order.
unique_ds = sorted(df_grouped['d'].unique())

# Plot one line per method: d on the x-axis, aggregated metric on the y-axis.
fig, ax = plt.subplots(figsize=(8, 6))
methods = df_grouped['method'].unique()
colors = plt.cm.tab10.colors
n_colors = len(colors)

for color_idx, method_name in enumerate(methods):
    is_method = df_grouped['method'] == method_name
    method_rows = df_grouped.loc[is_method].sort_values(by='d')
    ax.plot(
        method_rows['d'],
        method_rows['agg_value'],
        marker='o',
        linestyle='-',
        color=colors[color_idx % n_colors],  # palette wraps around when exhausted
        label=method_name,
    )

# CSS values were log-transformed upstream, so the axis label says so.
y_label = f"log({metric})" if metric == "CSS" else metric
ax.set(
    xlabel="Number of Nodes (d)",
    ylabel=y_label,
    title=f"{metric} vs Number of Nodes for {graph_type}",
)
ax.grid(True)
ax.legend()

# Write the figure as a PDF into the (possibly newly created) output folder.
fig.tight_layout()
output_dir = os.path.expanduser('~/pcax/examples/compare_cyclic_plots')
os.makedirs(output_dir, exist_ok=True)
output_name = f'compare_cyclic_increase_d_{graph_type}_{metric}.pdf'
fig.savefig(os.path.join(output_dir, output_name))
plt.close(fig)
