In [None]:
import json
import os
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt

# Base path to the directories for each seed.
# expanduser() is needed because open() and np.load() do not expand a leading '~';
# without it the JSON/.npy lookups below fail with FileNotFoundError.
base_folder_path = os.path.expanduser('~/maplecg_nfs_public/watermark_arxiv/main_results/')
# base_folder_path = '../../main_results/'

# Seed numbers you wish to use
seeds = [41, 42, 43]

# Algorithms whose results are collected (single source of truth for all four metrics)
algos = ["original", "retraining", "finetune", "ga", "gdiff", "KL", "tv", "scrub"]

# One list of per-seed scalar values per algorithm, for each metric
wtm_results = {algo: [] for algo in algos}
knowmem_results = {algo: [] for algo in algos}
mia_results = {algo: [] for algo in algos}
rouge_results = {algo: [] for algo in algos}

matplotlib.rcParams.update({'font.size': 20})

# Number of trailing diagonal entries of the loaded watermark array that are averaged per seed
num_last_elements = 1


def _mean_std(values):
    """Return (mean, population std) of `values`; (nan, nan) when nothing was loaded."""
    if len(values) == 0:
        # Same result np.mean/np.std would give on an empty list, minus the RuntimeWarning.
        return (float('nan'), float('nan'))
    return (np.mean(values), np.std(values))


for seed in seeds:
    folder_path = os.path.join(base_folder_path, f'seed_{seed}/results_remove-1class')
    watermarked_folder_path = os.path.join(base_folder_path, f'seed_{seed}/watermarked_results_remove-1class')

    for algo in algos:
        filepath_knowmem = os.path.join(folder_path, f'eval/knowmem/10/{algo}/aggregated.json')
        filepath_mia = os.path.join(folder_path, f'eval/mia_{algo}.json')
        filepath_rouge = os.path.join(folder_path, f'eval/rouge_{algo}.csv')
        filepath_wtm = os.path.join(watermarked_folder_path, f'watermark_verify/{algo}_q.npy')

        # WaterDrum metrics: mean of the last `num_last_elements` diagonal entries
        try:
            wtm_matrix = np.load(filepath_wtm)
            diagonal = np.diagonal(wtm_matrix).tolist()
            wtm_results[algo].append(np.mean(diagonal[-num_last_elements:]))
            # Alternative (all but the trailing entries):
            # wtm_results[algo].append(np.mean(diagonal[:-num_last_elements]))
        except FileNotFoundError:
            print(f"File not found: {filepath_wtm}")

        # KnowMem metrics
        try:
            with open(filepath_knowmem, 'r') as file:
                knowmem_payload = json.load(file)
            knowmem_results[algo].append(knowmem_payload['KnowMem Forget']['mean_rougeL_recall'])
        except FileNotFoundError:
            print(f"File not found: {filepath_knowmem}")

        # MIA metrics
        try:
            with open(filepath_mia, 'r') as file:
                mia_payload = json.load(file)
            mia_results[algo].append(mia_payload['forget_holdout_Min-40%'])
        except FileNotFoundError:
            print(f"File not found: {filepath_mia}")

        # ROUGE metrics: reduce the column to one scalar per seed so that the
        # std below is taken across seeds, like every other metric.
        try:
            rouge_table = pd.read_csv(filepath_rouge)
            rouge_results[algo].append(float(rouge_table['ROUGE Forget'].mean()))
        except FileNotFoundError:
            print(f"File not found: {filepath_rouge}")

# Calculate mean and std (across seeds) for each algo
knowmem_stats = {algo: _mean_std(values) for algo, values in knowmem_results.items()}
mia_stats = {algo: _mean_std(values) for algo, values in mia_results.items()}
rouge_stats = {algo: _mean_std(values) for algo, values in rouge_results.items()}
wtm_stats = {algo: _mean_std(values) for algo, values in wtm_results.items()}

print("Algorithm | ROUGE Mean | ROUGE Std | KnowMem Mean | KnowMem Std | MIA Mean | MIA Std | WaterDrum Mean | WaterDrum Std")
print("-----------------------------------------------------------")
for algo in algos:
    rouge_mean, rouge_std = rouge_stats[algo]
    knowmem_mean, knowmem_std = knowmem_stats[algo]
    mia_mean, mia_std = mia_stats[algo]
    wtm_mean, wtm_std = wtm_stats[algo]
    print(f"{algo:<10} | {rouge_mean:<12.4f} | {rouge_std:<9.4f} | {knowmem_mean:<12.4f} | {knowmem_std:<9.4f} | {mia_mean:<8.4f} | {mia_std:<6.4f} | {wtm_mean:<8.4f} | {wtm_std:<6.4f}")

In [None]:
def format():
    """Install the shared matplotlib rcParams used by the figures in this notebook.

    Sets a serif font family, label/title/tick/legend font sizes, dashed
    semi-transparent grid lines and a figure DPI of 300.

    NOTE(review): the name shadows Python's builtin `format`; it is kept so
    existing calls keep working.
    """
    overrides = (
        ('font.family', 'serif'),
        ('axes.labelsize', 24),
        ('axes.titlesize', 26),
        ('xtick.labelsize', 24),
        ('ytick.labelsize', 24),
        ('legend.fontsize', 27),
        ('grid.linestyle', '--'),
        ('grid.alpha', 0.5),
        ('figure.dpi', 300),
    )
    for key, value in overrides:
        plt.rcParams[key] = value


format()
def plot(mean_data, std_data, save_path=None):
    """Draw a grouped bar chart with error bars: one group per method, one bar per metric.

    Parameters
    ----------
    mean_data : pandas.DataFrame
        Bar heights; the index holds method keys (looked up in the module-level
        `legend_dict` for tick labels) and each column is one metric.
    std_data : pandas.DataFrame
        Error-bar sizes, read column by column with the same metric names as `mean_data`.
    save_path : str, optional
        When truthy, the figure is also written to this path at 300 dpi.

    Raises
    ------
    KeyError
        If an index entry of `mean_data` is missing from `legend_dict`.
    """
    figure, axis = plt.subplots(figsize=(18, 8))
    method_keys = mean_data.index
    metric_names = mean_data.columns
    tick_labels = [legend_dict[key] for key in method_keys]

    width = 0.2
    positions = np.arange(len(method_keys))

    for idx, name in enumerate(metric_names):
        # Centre the cluster of bars for each method around its tick position.
        shift = (idx - len(metric_names) / 2) * width + width / 2
        axis.bar(
            positions + shift,
            mean_data[name].values,
            width,
            label=name,
            yerr=std_data[name].values,
            edgecolor='black',
            linewidth=0.6,
            alpha=0.9,
            capsize=3,
        )

    axis.set_ylabel('Scores', labelpad=10)
    axis.set_xticks(positions)
    axis.set_xticklabels(tick_labels, rotation=0)
    # Legend sits just above the axes, laid out in a single row of up to four entries.
    axis.legend(loc='lower center', bbox_to_anchor=(0.5, 1.0), ncol=4)

    axis.grid(axis='y', linestyle='--', alpha=0.5)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.show()
    
# Maps each algorithm key to the label shown on the x-axis.
# Its key order also defines the left-to-right order of the bar groups.
legend_dict = {
    "original": "Original",
    "retraining": "Retraining",
    "finetune": "Finetune",
    "ga": "GA",
    "gdiff": "GDiff",
    "KL": "KL",
    # "dpo": "dpo",
    "tv": "TV",
    "scrub": "Scrub"
}


# wtm_stats["gdiff"]= (0,0)

# Row order and row labels both come from legend_dict, so every value is looked up
# by its algorithm key and cannot drift out of alignment with the index
# (a key missing from a stats dict raises KeyError instead of mislabelling bars).
methods = list(legend_dict)

# Column name -> {algo: (mean, std)}; insertion order fixes the bar order within a group.
metric_stats = {
    'ROUGE': rouge_stats,
    'KnowMem': knowmem_stats,
    'MIA': mia_stats,
    'WaterDrum': wtm_stats,
}

# Convert data to DataFrames (element 0 of each stats tuple is the mean, element 1 the std)
mean_data = pd.DataFrame(
    {metric: [stats[algo][0] for algo in methods] for metric, stats in metric_stats.items()},
    index=methods,
)

std_data = pd.DataFrame(
    {metric: [stats[algo][1] for algo in methods] for metric, stats in metric_stats.items()},
    index=methods,
)

# Plot the data
plot(mean_data, std_data)