# In [None]:
import json
import os
import re
import pandas as pd

# Root directory holding one sub-directory of JSON result files per model.
input_dir = "../output"

# Model identifiers to aggregate; only the part after the last "/" is used
# as the name of the sub-directory under `input_dir`.
models = [
    "meta-llama/Llama-2-7b-hf",
    "mistralai/Mistral-7B-v0.3",
    "cygu/llama-2-7b-logit-watermark-distill-kgw-k1-gamma0.25-delta2"
]


# Destination of the aggregated table (written next to the model folders).
output_csv = os.path.join(input_dir, "metrics.csv")

# Pattern for result file names of the form
# `output_...[align=<int>][sep][dataset=<name>].json`.
# Group 1 captures the digits of the optional `align=` field and group 2 the
# value of the optional `dataset=` field.
# NOTE(review): both fields are optional, so a name in which they are not
# directly followed by `.json` (e.g. another alphanumeric field in between)
# still matches but can leave a group as None -- confirm against the actual
# file naming scheme.
filename_pattern = re.compile(
    r"output_.*?"
    r"(?:align=(\d+))?[^a-zA-Z0-9]*"
    r"(?:dataset=([a-zA-Z0-9_\-]+))?[^a-zA-Z0-9]*"
    r"\.json$"
)
# Accumulates one dict per successfully parsed result file.
rows = []


def parse_row(json_file):
    """Build one flat metrics row from a single result JSON file.

    The document must contain the keys ``config``, ``metrics`` and
    ``watermark``. The entries ``metrics_dipper_text_lex60_order0``,
    ``metrics_dipper_text_lex20_order0``, ``ppl`` and ``mean_seq_rep_3`` are
    optional; the corresponding columns are ``None`` when they are absent.

    NOTE: besides its argument, this function reads the module-level names
    ``align``, ``dataset`` and ``model_suffix``; they must be assigned before
    the call (the driver loop sets them for every file).

    Args:
        json_file: Path of the JSON file to read.

    Returns:
        A dict mapping column names to values for this file.

    Raises:
        KeyError: If a required top-level key, or a ``config`` entry needed
            for the watermark's ``config_id``, is missing.
        ValueError: If ``watermark`` is not one of the supported schemes.
    """
    # Only the read needs the file handle; release it before building the row.
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    config = data["config"]
    metrics = data["metrics"]
    metrics_60 = data.get("metrics_dipper_text_lex60_order0", {})
    metrics_20 = data.get("metrics_dipper_text_lex20_order0", {})
    ppl = data.get("ppl", {})
    watermark = data["watermark"]

    # The identifier groups runs that share a watermark configuration; each
    # scheme contributes the hyper-parameters that distinguish its runs.
    if watermark == "gaussmark":
        config_id = f"watermark=gaussmark_sigma={config['sigma']}_param={config['target_param_name']}"
    elif watermark == "mb":
        config_id = f"align={align}_gamma={config['gamma']}_delta={config['delta']}_nclusters={config['n_clusters']}"
    elif watermark == "mb2":
        config_id = f"watermark=mb2_delta={config['delta']}"
    elif watermark == "mb3":
        config_id = f"watermark=mb3_delta={config['delta']}"
    elif watermark == "distilled":
        config_id = "watermark=distilled"
    else:
        # Without this branch `config_id` would be unbound and the failure
        # would surface as an opaque UnboundLocalError further down.
        raise ValueError(f"Unsupported watermark type: {watermark!r}")

    return {
        "config_id": config_id,
        "dataset": dataset,
        "n_clusters": config.get("n_clusters"),
        "auroc": metrics.get("auroc"),
        "best_f1_score": metrics.get("best_f1_score"),
        "tpr_1_fpr": metrics.get("tpr_1_fpr"),
        "tpr_0.1_fpr": metrics.get("tpr_0.1_fpr"),
        "auroc_lex60": metrics_60.get("auroc"),
        "f1_lex60": metrics_60.get("best_f1_score"),
        "tpr1_lex60": metrics_60.get("tpr_1_fpr"),
        "auroc_lex20": metrics_20.get("auroc"),
        "f1_lex20": metrics_20.get("best_f1_score"),
        "tpr1_lex20": metrics_20.get("tpr_1_fpr"),
        "ppl_mean": ppl.get("mean"),
        "mean_seq_rep_3": data.get("mean_seq_rep_3"),
        "model": model_suffix,
    }


# Collect one row per matching result file, for every configured model.
for model in models:
    model_suffix = model.split("/")[-1]
    model_dir = os.path.join(input_dir, model_suffix)

    # A model without an output directory should not abort the aggregation
    # of the remaining models.
    if not os.path.isdir(model_dir):
        print(f"Skipping missing directory: {model_dir}")
        continue

    # sorted() makes the processing order independent of the file system.
    for json_file in sorted(os.listdir(model_dir)):
        if not json_file.endswith(".json"):
            continue
        match = filename_pattern.match(json_file)
        if not match:
            continue

        # These must stay module-level names: parse_row() reads `align`,
        # `dataset` and `model_suffix` as globals rather than as parameters.
        align = int(match.group(1)) if match.group(1) is not None else None
        dataset = match.group(2)  # None when the name has no dataset= field
        filepath = os.path.join(model_dir, json_file)
        try:
            rows.append(parse_row(filepath))
        except Exception as e:
            # Best effort: report the broken file and keep going.
            print(f"Error processing {filepath}: {e}")

# Convert to DataFrame
df = pd.DataFrame(rows)

if df.empty:
    # Without rows the frame has no columns and groupby() would raise KeyError.
    print(f"No result rows collected under {input_dir}; nothing written.")
else:
    # Average all numeric metric columns per (config_id, dataset, model).
    # dropna=False keeps groups whose dataset is None; with the default
    # (dropna=True) those rows would silently vanish from the output.
    grouped_df = (
        df.groupby(["config_id", "dataset", "model"], dropna=False)
        .mean(numeric_only=True)
        .reset_index()
    )

    # Write to CSV
    grouped_df.to_csv(output_csv, index=False)
    print(f"Saved grouped averages to: {output_csv}")