In [None]:

import os, pickle
import pandas as pd
import torch
import seaborn as sns
import matplotlib.pyplot as plt

# Apply seaborn's "whitegrid" style globally so the figures drawn after this
# point share the same look.
sns.set(style="whitegrid")


In [None]:

def process_results(results_dir, perturb_type):
    """Summarise every pickled experiment result stored in ``results_dir``.

    Each file name has its last two characters dropped and is then split on
    ``'-'``: field 2 is read as the 1-based run number and field 3 as the
    method name. The run number is translated to a perturbation rate through
    a fixed 11-entry table.

    Parameters
    ----------
    results_dir : str
        Directory to scan. Every regular file in it is loaded with ``pickle``;
        sub-directories are skipped.
    perturb_type : str
        Label copied into the ``perturb_type`` column of the summary frame and
        the ``perturb`` column of the flat frame.

    Returns
    -------
    tuple of (pandas.DataFrame, pandas.DataFrame)
        * Summary frame: one row per file, holding every attribute of
          ``results['args']`` plus ``run_num``, ``method``, ``perturb_type``,
          ``rate``, ``power``, ``null_rejection_rate``,
          ``watermark_median_n_perms`` and ``null_median_n_perms``.
        * Flat frame: one row per individual p-value with columns ``rate``,
          ``type`` (``'watermark'`` or ``'null'``), ``pval``, ``perturb`` and
          ``method``.

    Raises
    ------
    IndexError
        A file name has fewer than four ``'-'``-separated fields.
    ValueError
        Field 2 of a file name is not an integer.
    KeyError
        The run number is outside 1..11, or the stored args have no ``alpha``.
    """
    rates = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    run2rate = {i + 1: r for i, r in enumerate(rates)}
    summary_rows, flat_pvals = [], []

    # os.listdir order is arbitrary; sort (lexicographically by name) so the
    # row order of the returned frames is reproducible across runs/machines.
    for fname in sorted(os.listdir(results_dir)):
        path = os.path.join(results_dir, fname)
        # Skip sub-directories (e.g. a notebook checkpoint folder), which
        # cannot be parsed or unpickled.
        if not os.path.isfile(path):
            continue

        # Drops the last two characters before splitting -- presumably a
        # two-character suffix such as ".p"; TODO confirm against the writer.
        parts = fname[:-2].split('-')
        run_num = int(parts[2])
        method = parts[3]
        rate = run2rate[run_num]

        # SECURITY: pickle.load can execute arbitrary code. Only point this
        # function at result files you produced yourself.
        with open(path, 'rb') as fh:
            results = pickle.load(fh)

        # Copy the namespace dict so the extra keys added below do not leak
        # back into the unpickled args object.
        info = dict(vars(results['args']))
        info['run_num'] = run_num
        info['method'] = method
        info['perturb_type'] = perturb_type
        info['rate'] = rate

        wm_pvals = torch.Tensor([float(p) for p in results['watermark']['pvals']])
        null_pvals = torch.Tensor([float(p) for p in results['null']['pvals']])
        wm_tim = torch.Tensor(results['watermark']['tim'])
        null_tim = torch.Tensor(results['null']['tim'])

        # Fraction of p-values at or below the significance level stored in args.
        alpha = info['alpha']
        info['power'] = (wm_pvals <= alpha).float().mean().item()
        info['null_rejection_rate'] = (null_pvals <= alpha).float().mean().item()
        # Note: torch's median returns the lower of the two middle values for
        # even-length input (it does not average them).
        info['watermark_median_n_perms'] = wm_tim.median().item()
        info['null_median_n_perms'] = null_tim.median().item()

        summary_rows.append(info)

        # One flat record per p-value, watermark rows first, then null rows.
        for kind, pvals in (('watermark', wm_pvals), ('null', null_pvals)):
            flat_pvals.extend(
                {'rate': rate, 'type': kind, 'pval': p, 'perturb': perturb_type, 'method': method}
                for p in pvals.tolist()
            )

    return pd.DataFrame(summary_rows), pd.DataFrame(flat_pvals)


In [None]:

# Load each perturbation type from its own results directory. Every call
# returns a (per-run summary frame, per-p-value flat frame) pair.
df_del, flat_del = process_results(results_dir="results/exp-del", perturb_type="deletion")
df_sub, flat_sub = process_results(results_dir="results/exp-sub", perturb_type="substitution")
df_ins, flat_ins = process_results(results_dir="results/exp-ins", perturb_type="insertion")

# Stack the three attacks into single frames with a fresh 0..n-1 index.
df_all = pd.concat((df_del, df_sub, df_ins), ignore_index=True)
flat_all = pd.concat((flat_del, flat_sub, flat_ins), ignore_index=True)


In [None]:

# Mean of each per-run metric for every (method, rate) pair, taken over all
# rows of df_all that share that pair.
avg_df = (
    df_all
    .groupby(['method', 'rate'])[
        ['power', 'null_rejection_rate', 'watermark_median_n_perms', 'null_median_n_perms']
    ]
    .mean()
    .reset_index()
)

# Power
power_fig, power_ax = plt.subplots(figsize=(9, 6))
sns.lineplot(data=avg_df, x='rate', y='power', hue='method', marker='o', ax=power_ax)
power_ax.set_title("Power vs. Perturbation Rate (Averaged Over Attacks)")
power_ax.set_xlabel("Perturbation Rate")
power_ax.set_ylabel("Power")
power_ax.grid(True)
power_fig.tight_layout()
power_fig.savefig("avg_power_by_method.png", dpi=600)
plt.show()

# Null rejection rate, with a dashed reference line drawn at 0.05
null_fig, null_ax = plt.subplots(figsize=(9, 6))
sns.lineplot(data=avg_df, x='rate', y='null_rejection_rate', hue='method', marker='s', ax=null_ax)
null_ax.axhline(0.05, linestyle='--', color='black', label='α = 0.05')
null_ax.set_title("Null Rejection Rate vs. Perturbation Rate (Averaged Over Attacks)")
null_ax.set_xlabel("Perturbation Rate")
null_ax.set_ylabel("Null Rejection Rate")
# Rebuild the legend so the reference line is listed next to the methods.
null_ax.legend()
null_ax.grid(True)
null_fig.tight_layout()
null_fig.savefig("avg_nullrate_by_method.png", dpi=600)
plt.show()

# Permutations: watermark and null series share one axes; the null series is
# drawn dashed with 'x' markers and left out of the legend.
perms_fig, perms_ax = plt.subplots(figsize=(9, 6))
sns.lineplot(data=avg_df, x='rate', y='watermark_median_n_perms', hue='method', marker='o', ax=perms_ax)
sns.lineplot(
    data=avg_df, x='rate', y='null_median_n_perms', hue='method',
    marker='x', linestyle='--', legend=False, ax=perms_ax,
)
perms_ax.set_title("Median # of Permutations vs. Perturbation Rate (Averaged Over Attacks)")
perms_ax.set_xlabel("Perturbation Rate")
perms_ax.set_ylabel("Median Permutations")
perms_ax.grid(True)
perms_fig.tight_layout()
perms_fig.savefig("avg_perms_by_method.png", dpi=600)
plt.show()


In [None]:

# Median p-value for every (method, rate, type) group, computed over all
# individual p-values in flat_all that fall in that group.
avg_flat = (
    flat_all
    .groupby(['method', 'rate', 'type'])['pval']
    .median()
    .reset_index()
)

# Colour encodes the method; line style encodes the 'type' column.
pval_fig, pval_ax = plt.subplots(figsize=(9, 6))
sns.lineplot(data=avg_flat, x='rate', y='pval', hue='method', style='type', marker='o', ax=pval_ax)
pval_ax.set_title("Median p-value vs. Perturbation Rate (Averaged Over Attacks)")
pval_ax.set_xlabel("Perturbation Rate")
pval_ax.set_ylabel("Median p-value")
pval_ax.grid(True)
pval_fig.tight_layout()
pval_fig.savefig("avg_pval_by_method.png", dpi=600)
plt.show()
