# In [6]:
import os
import re
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ---------------------------------------------------------------------------
# Script configuration: which result tree to read, which metric to plot, and
# how per-folder values are aggregated. Exactly one `base_dir` and one
# `metric` should be active; the commented lines are the alternatives.
# ---------------------------------------------------------------------------

# Base directory for the experiment results (choose one for ER, SF, NWS)
base_dir = '/share/amine.mcharrak/cyclic_data_final/experiment_results/compare_ER/ER'
#base_dir = '/share/amine.mcharrak/cyclic_data_final/experiment_results/compare_SF/SF'
#base_dir = '/share/amine.mcharrak/cyclic_data_final/experiment_results/compare_NWS/NWS'

# Use the last path component as the graph type (it ends up in the output
# file name). normpath strips a trailing slash first, so '.../ER/' still
# yields 'ER' rather than an empty string.
graph_type = os.path.basename(os.path.normpath(base_dir))

# Regex to extract d (nodes), type (ER, SF, or NWS), and e (edges) from folder
# names of the form "<d><type><e>_<anything>".
pattern = re.compile(r'(\d+)(ER|SF|NWS)(\d+)_.*')

# Choose the metric.
# metric = "SHD_cyclic"
metric = "CSS"
# metric = "Cycle_KLD"
#metric = "Cycle_F1"

# Aggregation function per metric: CSS is reduced with the minimum, the other
# listed metrics with the median.
aggregation_funcs = {
    "CSS": np.min,
    "SHD_cyclic": np.median,
    "Cycle_KLD": np.median,
    "Cycle_F1": np.median,
}
# A metric that is not listed above falls back to the median.
aggregation_func = aggregation_funcs.get(metric, np.median)

# Mapping for SF and NWS folder names: edge value -> ratio. Edge values that
# are not listed here are handled by the round(e / d) fallback where the
# ratio is determined.
e_to_ratio_mapping = {50: 1, 65: 2, 79: 3, 92: 4, 97: 4}

# One entry per (result folder, method): the aggregated metric value together
# with the folder's node count, edge number and ratio group.
records = []

# Traverse folders in the base directory. The listing is sorted so that the
# order of `records` (and therefore anything derived from first appearance,
# such as per-method colors) does not depend on filesystem ordering.
for folder in sorted(os.listdir(base_dir)):
    folder_path = os.path.join(base_dir, folder)
    if not os.path.isdir(folder_path):
        continue
    m = pattern.match(folder)
    if not m:
        # Folder name does not follow "<d><type><e>_..."; not a result folder.
        continue
    d_val = int(m.group(1))
    folder_type = m.group(2)  # "ER", "SF", or "NWS"
    e_val = int(m.group(3))

    if d_val == 0:
        # e / d is undefined for a zero node count; skip instead of crashing.
        print(f"Skipping {folder_path}: node count is 0")
        continue

    # Determine ratio: computed as round(e / d), except that SF/NWS folders
    # use the explicit mapping when their edge value is listed in it.
    ratio = round(e_val / d_val)
    if folder_type in ("SF", "NWS"):
        ratio = e_to_ratio_mapping.get(e_val, ratio)

    # Locate JSON file. Sorting makes the choice deterministic when a folder
    # contains more than one; the first name in sorted order is used.
    json_files = sorted(f for f in os.listdir(folder_path) if f.endswith('.json'))
    if not json_files:
        continue
    json_path = os.path.join(folder_path, json_files[0])
    try:
        with open(json_path, 'r') as f:
            results = json.load(f)
    except (OSError, ValueError) as ex:
        # ValueError covers json.JSONDecodeError and decoding errors.
        print(f"Error reading {json_path}: {ex}")
        continue

    # Group results by method.
    # NOTE(review): `results` is assumed to be a list of dict-like entries
    # (each is queried with .get) -- confirm against the result writer.
    method_dict = {}
    for res in results:
        mth = res.get("method")
        val = res.get(metric)
        if val is None:
            continue
        try:
            val = float(val)
        except (TypeError, ValueError, OverflowError):
            # Not convertible to a float; ignore this entry.
            continue
        # Drop inf *and* NaN: a single NaN would otherwise turn the whole
        # min/median for this method into NaN.
        if not np.isfinite(val):
            continue
        # For CSS, transform the value to its logarithm; non-positive values
        # have no real logarithm and are skipped.
        if metric == "CSS":
            if val <= 0:
                continue
            val = np.log(val)
        method_dict.setdefault(mth, []).append(val)

    # Compute aggregated value for each method using the appropriate function.
    # Every list in method_dict has at least one element by construction.
    for mth, vals in method_dict.items():
        records.append({
            "d": d_val,
            "e": e_val,
            "method": mth,
            "agg_value": aggregation_func(vals),
            "ratio": ratio,
        })

# Create a DataFrame with one row per collected record.
df = pd.DataFrame(records)
if df.empty:
    raise ValueError("No valid data found.")

# Determine unique ratio values.
unique_ratios = sorted(df['ratio'].unique())
k = len(unique_ratios)

# Create subplots: one row per unique ratio.
fig, axes = plt.subplots(nrows=k, ncols=1, figsize=(8, 4 * k), sharex=True)
if k == 1:
    # With a single row, subplots returns a bare Axes rather than an array.
    axes = [axes]

methods = df['method'].unique()
colors = plt.cm.tab10.colors

# Set y-axis label text without mentioning the aggregation function.
y_label = f"log({metric})" if metric == "CSS" else metric

# First line handle seen for each method, across all panels. Used to build a
# legend that also lists methods with no data in the first panel.
legend_handles = {}

# For each ratio, plot d (nodes) on the x-axis and the aggregated value on the y-axis.
for ax, r in zip(axes, unique_ratios):
    df_r = df[df['ratio'] == r]
    for i, mth in enumerate(methods):
        df_mth = df_r[df_r['method'] == mth].sort_values(by='d')
        if df_mth.empty:
            continue
        # The color index follows the method's position in `methods`, so a
        # method keeps the same color in every panel.
        (line,) = ax.plot(df_mth['d'], df_mth['agg_value'], marker='o', linestyle='-',
                          color=colors[i % len(colors)], label=mth)
        legend_handles.setdefault(mth, line)
    ax.set_title(f'Ratio (mapped for SF/NWS or computed as e/d) ≈ {r}')
    ax.set_ylabel(y_label)
    ax.grid(True)

axes[-1].set_xlabel('Number of Nodes (d)')
# Single legend on the top panel, in `methods` order; labels are taken from
# the line artists themselves.
axes[0].legend(handles=[legend_handles[mth] for mth in methods if mth in legend_handles])
plt.tight_layout()
output_dir = os.path.expanduser('~/pcax/examples/compare_cyclic_plots')
os.makedirs(output_dir, exist_ok=True)
# Save and close this figure explicitly rather than pyplot's implicit
# "current figure".
fig.savefig(os.path.join(output_dir, f'compare_cyclic_{graph_type}_{metric}.pdf'))
plt.close(fig)
