# In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os

# ==============================================================================
# 1. CONFIGURATION
# Please update the paths below with your actual file locations
# ==============================================================================

# The results file generated by the training run.
# Replace the YOUR_RUN_ID placeholder with the real run directory name. The path
# is relative, so it is resolved against the current working directory.
# Columns read from it further down: 'canary_id' (merge key), 'epoch', and the
# metric columns 'mia_score', 'counterfactual_score', 'contextual_score'.
RESULTS_PATH = "wikipedia/experiments/run_YOUR_RUN_ID/results/canary_details_full.csv"

# The original configuration file containing repetition counts.
# Columns read from it further down: 'canary_id', 'repetitions', 'type'.
CANARIES_PATH = "memorization/canaries.csv"

# ==============================================================================
# 2. DATA LOADING AND MERGING
# ==============================================================================

# Fail fast when an input file is missing. Printing an error and carrying on
# would only surface later as a confusing NameError (or, inside a notebook,
# silently reuse a stale DataFrame left over from a previous run).
for _path, _label in ((RESULTS_PATH, "Results"), (CANARIES_PATH, "Canaries")):
    if not os.path.exists(_path):
        raise FileNotFoundError(f"[ERROR] {_label} file not found: {_path}")

# Load Results
df_results = pd.read_csv(RESULTS_PATH)
print("[INFO] Results loaded successfully.")

# Load Canary Configuration (to get repetition counts)
df_config = pd.read_csv(CANARIES_PATH)
print("[INFO] Canary configuration loaded.")

# MERGE: Join results with repetition info based on 'canary_id'.
# Only the columns needed for plotting are taken from the config.
df_config_clean = df_config[['canary_id', 'repetitions', 'type']]
merged_df = pd.merge(df_results, df_config_clean, on='canary_id', how='left')

# A left join leaves NaN in 'repetitions'/'type' for result rows whose
# canary_id is absent from the config. Those rows cannot be assigned to a
# panel or a frequency level, so report and drop them instead of crashing on
# the string test below.
unmatched = merged_df[['repetitions', 'type']].isna().any(axis=1)
if unmatched.any():
    print(
        f"[WARN] Dropping {int(unmatched.sum())} result rows without a "
        "matching canary configuration."
    )
    merged_df = merged_df[~unmatched].copy()

# Data Cleaning and Formatting for Plotting
# Rename entropy types for better readability (case-sensitive substring test).
is_high_entropy = merged_df['type'].astype(str).str.contains('high', regex=False)
merged_df['Entropy Type'] = is_high_entropy.map(
    {True: 'High Entropy (Random)', False: 'Low Entropy (Natural Lang.)'}
)


def _repetition_label(value):
    """Format a repetition count as a category label such as "5x".

    NaNs introduced by the merge upcast the column to float; a whole float
    such as 5.0 must still map to "5x" rather than "5.0x", otherwise it would
    match none of the categories below.
    """
    if isinstance(value, float) and value.is_integer():
        value = int(value)
    return f"{value}x"


# Convert repetitions to string format for categorical plotting (e.g., "1x", "5x")
repetition_labels = merged_df['repetitions'].map(_repetition_label)

# Define sort order for the legend
sort_order = ["1x", "5x", "20x"]

# Labels outside sort_order become NaN in the categorical and vanish from the
# plot; make that visible instead of silent.
unexpected_labels = sorted(set(repetition_labels.unique()) - set(sort_order))
if unexpected_labels:
    print(
        f"[WARN] Repetition levels {unexpected_labels} are not in "
        f"{sort_order} and will be omitted from the plot."
    )

merged_df['Repetitions'] = pd.Categorical(repetition_labels, categories=sort_order, ordered=True)

print(f"[INFO] Data ready for plotting. Total rows processed: {len(merged_df)}")

# ==============================================================================
# 3. PLOTTING
# ==============================================================================
sns.set_theme(style="whitegrid", context="paper", font_scale=1.2)

# Sequential colour palette (light -> dark) to represent frequency intensity:
# the lowest repetition level is lightest, the highest is darkest. The number
# of colours follows the hue categories so the two cannot drift apart.
freq_levels = list(merged_df['Repetitions'].cat.categories)
freq_palette = sns.color_palette("viridis_r", len(freq_levels))

# 3x2 grid: one row per metric, one column per entropy type.
fig, axes = plt.subplots(3, 2, figsize=(16, 16), sharex=True)

# Map internal metric names to readable titles
metrics_map = {
    'mia_score': 'MIA Recall (Recollection)',
    'counterfactual_score': 'Counterfactual Memorization',
    'contextual_score': 'Contextual Memorization'
}

metrics_list = list(metrics_map.keys())

# (value in the 'Entropy Type' column, short title suffix), left to right.
entropy_panels = (
    ('High Entropy (Random)', 'High Entropy'),
    ('Low Entropy (Natural Lang.)', 'Low Entropy'),
)

for row, metric in enumerate(metrics_list):
    for col, (entropy_value, entropy_title) in enumerate(entropy_panels):
        ax = axes[row, col]
        sns.lineplot(
            data=merged_df[merged_df['Entropy Type'] == entropy_value],
            x='epoch', y=metric, hue='Repetitions', style='Repetitions',
            markers=True, dashes=False, palette=freq_palette, ax=ax,
            linewidth=2.5, markersize=8, err_style="band"
        )
        ax.set_title(f"{metrics_map[metric]}\n{entropy_title}", fontsize=14, fontweight='bold')
        # Only the left column carries a Y-label, for a cleaner look.
        ax.set_ylabel("Score" if col == 0 else "", fontsize=12)

# Final Layout Adjustments
for ax in axes.flat:
    ax.set_xlabel("Training Epochs", fontsize=12)
    ax.set_ylim(-0.1, 1.1)  # Fixed Y-limits for comparability
    ax.grid(True, linestyle='--', alpha=0.6)
    # Skip the legend on panels with no plotted data; calling legend() there
    # only produces a "no artists with labels" warning.
    handles, _labels = ax.get_legend_handles_labels()
    if handles:
        ax.legend(title="Frequency", loc='upper left')

fig.tight_layout()
fig.subplots_adjust(top=0.93)  # Leave room for the suptitle
fig.suptitle("Impact of Exposure Frequency on Memorization Metrics", fontsize=18, fontweight='bold')

# Save, then show. Saving first guarantees the file is written from the fully
# rendered figure regardless of what the backend does with it once shown.
fig.savefig("frequency_analysis_plot.png", dpi=300, bbox_inches='tight')
print("Graph saved as 'frequency_analysis_plot.png'")
plt.show()