In [2]:
%load_ext autoreload
%autoreload 2

import json
import pandas as pd
import os, sys
import wandb
import numpy as np

# Add the directory containing the utils folder to the system path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from utils.testing import *

In [3]:
# Name of the un-ablated reference run that is added to the table as the baseline.
BASELINE_RUN_NAME = "no-ablation-baseline"

# Look the baseline up in the run list requested with filters=None
# (presumably this disables the helper's default filtering -- TODO confirm in utils.testing).
baseline_run = next(
    (run for run in access_wandb_runs(filters=None) if run.name == BASELINE_RUN_NAME),
    None,
)
if baseline_run is None:
    # Appending None would only fail later, at the first `run.name` access in the
    # per-run loop, with an AttributeError that hides the real cause.
    raise LookupError(
        f"W&B run {BASELINE_RUN_NAME!r} was not found; cannot add the baseline row"
    )

# Runs returned with the helper's default arguments, followed by the baseline.
all_model_runs = list(access_wandb_runs())
all_model_runs.append(baseline_run)

# Runs of the 'ablation-sae' project; the SAE metrics are read from these.
all_sae_runs = access_wandb_runs(project='ablation-sae')

# One record per run, keyed by run name: {run name: {column label: value}}.
results_dict: dict[str, dict] = dict()

# Column label -> key looked up in a model run's summary metrics.
# A dot in the key separates a parent entry from the child inside it.
summary_mapping: dict[str, str] = {
    "IOI Edges": "acdc_ioi_edges",
    "Loss": "val.loss_clean",
}

# Column label -> key looked up in a model run's config.
config_mapping: dict[str, str] = {
    "K-Level": "k_neurons",
    "Ablation Type": "ablation_mask_level",
}

# Column label -> key looked up in the summary metrics of the matching SAE run.
sae_metrics: dict[str, str] = {
    "L0 Norm": "sparsity.l0",
    "CE Score": "model_performance_preservation.ce_loss_score",
}

def process_metric(run, metric_name: str, metric_key: str) -> object:
    """
    Look up a metric in a run's summary metrics, following dotted keys into nested dictionaries.

    Args:
        run: Object exposing a ``summaryMetrics`` mapping.
        metric_name: Display label of the metric. Kept for call-site compatibility;
            it does not take part in the lookup.
        metric_key: Key to read. Dots separate nesting levels, so ``"sparsity.l0"``
            reads ``summaryMetrics["sparsity"]["l0"]``. Any depth is supported.

    Returns:
        The stored value, or None when a level of the path is missing or an
        intermediate value cannot be indexed by key.
    """
    node = run.summaryMetrics
    for part in metric_key.split("."):
        try:
            node = node[part]
        except (KeyError, IndexError, TypeError):
            # Missing key, or a scalar/sequence where a nested mapping was expected.
            return None
    return node

# Index the SAE runs by name once so every model run is matched with one lookup
# instead of a full scan. If several SAE runs share a name the last one wins,
# which is also what the previous nested scan (no break on match) produced.
sae_runs_by_name = {sae_run.name: sae_run for sae_run in all_sae_runs}

for run in all_model_runs:
    run_name = run.name
    run_results = {}

    # Metrics stored in the run's summary.
    for column, metric_key in summary_mapping.items():
        run_results[column] = process_metric(run, column, metric_key)

    # Values stored in the run's config. json_config is a JSON string; parse it
    # once per run rather than once per config key.
    config = json.loads(run.json_config)
    for column, config_key in config_mapping.items():
        # Presence is checked against the parsed JSON, the value is read from
        # run.config; a key that was never logged yields None.
        run_results[column] = run.config[config_key] if config_key in config else None

    # SAE metrics come from the SAE run with the same name, when there is one.
    # Without a match the SAE columns are simply absent from this record.
    matching_sae_run = sae_runs_by_name.get(run_name)
    if matching_sae_run is not None:
        for column, metric_key in sae_metrics.items():
            run_results[column] = process_metric(matching_sae_run, column, metric_key)

    results_dict[run_name] = run_results


# One row per run name, one column per collected metric.
raw_df = pd.DataFrame.from_dict(results_dict, orient='index')

raw_df

Unnamed: 0,IOI Edges,Loss,K-Level,Ablation Type,L0 Norm,CE Score
fanciful-fog-78,51,1.991794,2.0,overall,4.403858,0.699882
earnest-moon-79,57,1.945575,2.0,layer-by-layer,4.920801,0.718273
super-violet-80,27,1.950913,1.0,overall,5.160132,0.672419
upbeat-glitter-83,30,1.945714,1.0,layer-by-layer,5.178028,0.714519
comfy-cherry-84,28,2.003018,4.0,overall,5.152246,0.689363
cosmic-leaf-81-part2,30,1.896251,2.0,overall,4.570484,0.707557
major-planet-86-part3,32,1.944271,4.0,layer-by-layer,4.034986,0.657748
light-morning-92,79,2.024968,8.0,overall,5.274683,0.658187
deft-dew-98-part2,35,1.934347,8.0,layer-by-layer,4.904907,0.738611
solar-vortex-103,36,1.866994,8.0,layer-by-layer,4.996948,0.72199


In [4]:
# Processing for table: build a display-ready copy and leave raw_df untouched,
# so this cell can be re-run safely.
results_df = raw_df.copy()

# K-Level as a whole-number string; runs without a k value get '-'.
# Missing values are detected directly instead of going through a 0 sentinel,
# so a genuine k of 0 would not be mistaken for "missing".
results_df['K-Level'] = results_df['K-Level'].map(
    lambda k: '-' if pd.isna(k) else str(int(k))
)

# Convert the loss to perplexity (exp of the loss), rounded to 2 decimal places.
# The column is overwritten and then renamed so it keeps its position.
results_df['Loss'] = np.exp(results_df['Loss']).round(2)
results_df = results_df.rename(columns={'Loss': 'Perplexity'})

# Ablation Type: runs without an ablation level are labelled 'Baseline';
# 'overall' becomes 'Global' and 'layer-by-layer' becomes 'Layerwise'.
results_df['Ablation Type'] = (
    results_df['Ablation Type']
    .fillna('Baseline')
    .replace({'overall': 'Global', 'layer-by-layer': 'Layerwise'})
)

# Round the SAE metrics to 2 decimal places.
results_df[['L0 Norm', 'CE Score']] = results_df[['L0 Norm', 'CE Score']].round(2)

results_df

Unnamed: 0,IOI Edges,Perplexity,K-Level,Ablation Type,L0 Norm,CE Score
fanciful-fog-78,51,7.33,2,Global,4.4,0.7
earnest-moon-79,57,7.0,2,Layerwise,4.92,0.72
super-violet-80,27,7.04,1,Global,5.16,0.67
upbeat-glitter-83,30,7.0,1,Layerwise,5.18,0.71
comfy-cherry-84,28,7.41,4,Global,5.15,0.69
cosmic-leaf-81-part2,30,6.66,2,Global,4.57,0.71
major-planet-86-part3,32,6.99,4,Layerwise,4.03,0.66
light-morning-92,79,7.58,8,Global,5.27,0.66
deft-dew-98-part2,35,6.92,8,Layerwise,4.9,0.74
solar-vortex-103,36,6.47,8,Layerwise,5.0,0.72


In [37]:
# Prepare the rows before converting to LaTeX

# Column order: Ablation Type, K-Level, IOI Edges, L0 Norm, CE Score, Perplexity
latex_df = results_df[['Ablation Type', 'K-Level', 'IOI Edges', 'L0 Norm', 'CE Score', 'Perplexity']]

# When several runs share the same K-Level and Ablation Type, keep only the one
# with the lowest perplexity: after the ascending sort, drop_duplicates keeps the
# first row of every (K-Level, Ablation Type) pair.
latex_df = latex_df.sort_values(by=['Perplexity'], ascending=True).drop_duplicates(subset=['K-Level', 'Ablation Type'])

# Sort by K-Level in descending order but keep the baseline at the top.
# Split into baseline and non-baseline rows.
baseline_df = latex_df[latex_df['Ablation Type'] == 'Baseline']
other_df = latex_df[latex_df['Ablation Type'] != 'Baseline']

# Sort non-baseline rows by K-Level first, then Ablation Type
# (Global comes before Layerwise alphabetically).
# K-Level holds strings at this point, so it is compared by numeric value;
# a plain string sort would place '16' between '1' and '2'. Values that are not
# numbers (such as '-') become NaN and sort last.
other_df = other_df.sort_values(
    by=['K-Level', 'Ablation Type'],
    ascending=[False, True],
    key=lambda column: pd.to_numeric(column, errors='coerce') if column.name == 'K-Level' else column,
)

# Concatenate baseline back on top
latex_df = pd.concat([baseline_df, other_df])

latex_df

Unnamed: 0,Ablation Type,K-Level,IOI Edges,L0 Norm,CE Score,Perplexity
no-ablation-baseline,Baseline,-,79,7.22,0.63,5.73
amber-waterfall-109,Global,8,70,5.22,0.65,6.97
solar-vortex-103,Layerwise,8,36,5.0,0.72,6.47
solar-smoke-111,Global,4,41,4.9,0.67,6.73
eager-moon-110,Layerwise,4,30,4.01,0.65,6.58
cosmic-leaf-81-part2,Global,2,30,4.57,0.71,6.66
usual-silence-107,Layerwise,2,54,4.9,0.7,6.55
helpful-dust-112-part2,Global,1,38,5.48,0.66,6.49
distinctive-yogurt-117,Layerwise,1,36,5.33,0.69,6.58


In [41]:
# Caption and label of the generated LaTeX table. The table header marks columns
# with both $\downarrow$ and $\uparrow$ and the best value per column is set in
# bold, so the caption explains all three conventions.
caption = r"Impact of Self-Ablation on Interpretability and Performance. Lower \textit{k}-values in the kWTA mechanism generally improve interpretability, with a minimal increase in perplexity compared to the regular transformer baseline. Arrows indicate the optimization direction: $\downarrow$ denotes that lower values are better and $\uparrow$ denotes that higher values are better. The best value in each column is shown in bold."
label = "tab:ablation_results"

def get_best_values(df):
    """
    Return the best value of each reported metric column of ``df``.

    'CE Score' is best at its maximum; 'IOI Edges', 'L0 Norm' and 'Perplexity'
    are best at their minimum. The result maps each column name to that value.
    """
    # (column, True when a larger value is the better one)
    directions = (
        ('IOI Edges', False),
        ('L0 Norm', False),
        ('CE Score', True),
        ('Perplexity', False),
    )

    best = {}
    for column, prefer_larger in directions:
        series = df[column]
        best[column] = series.max() if prefer_larger else series.min()
    return best

def format_value(value, best_value, is_int=False, decimals=2):
    """
    Format one table cell for LaTeX, wrapping it in \\textbf{} when it is the best value.

    Args:
        value: The cell value.
        best_value: The best value of the column; an equal ``value`` is set in bold.
        is_int: When True, ``value`` is converted with ``int()`` before comparing and printing.
        decimals: Number of decimal places printed for float values, so that
            every cell of a column shows the same precision (5.0 -> "5.00").

    Returns:
        The formatted cell text. Missing values (None/NaN) are rendered as "-"
        instead of failing in ``int()`` or printing "nan".
    """
    if pd.isna(value):
        return "-"

    if is_int:
        value = int(value)

    # Only floats get a fixed number of decimals; every other type keeps its str() form.
    text = f"{value:.{decimals}f}" if isinstance(value, float) else f"{value}"
    return f"\\textbf{{{text}}}" if value == best_value else text

best_values = get_best_values(latex_df)

# Build LaTeX table: preamble, one line per row, then the closing lines.
header_lines = [
    "\\begin{table*}[t]",
    "    \\centering",
    f"    \\caption{{{caption}}}",
    f"    \\label{{{label}}}",
    "    \\begin{tabular}{cccccc}",  # six centred columns, no vertical rules
    "    \\toprule",
    "    \\multicolumn{2}{c}{Architecture} & \\multicolumn{1}{c}{ACDC} & \\multicolumn{2}{c}{SAE} & \\multicolumn{1}{c}{LM} \\\\",
    "    \\cmidrule(r){1-2} \\cmidrule(lr){3-3} \\cmidrule(lr){4-5} \\cmidrule(l){6-6}",
    "    Ablation Type & K-Level & IOI Edges $\\downarrow$ & L0 Norm $\\downarrow$ & CE Score $\\uparrow$ & Perplexity $\\downarrow$ \\\\",
    "    \\midrule",
]

body_lines = []
previous_group = None
for position, (_, row) in enumerate(latex_df.iterrows()):
    # A group is the baseline on its own, or all rows sharing one K-Level.
    # A rule is drawn wherever the group changes, so the separators do not depend
    # on every K-Level having both a Global and a Layerwise row, and no rule is
    # ever emitted directly before \bottomrule.
    group = (row['Ablation Type'] == 'Baseline', row['K-Level'])
    if position > 0 and group != previous_group:
        body_lines.append("    \\midrule")
    previous_group = group

    cells = [
        row['Ablation Type'],
        row['K-Level'],
        format_value(row['IOI Edges'], best_values['IOI Edges'], is_int=True),
        format_value(row['L0 Norm'], best_values['L0 Norm']),
        format_value(row['CE Score'], best_values['CE Score']),
        format_value(row['Perplexity'], best_values['Perplexity']),
    ]
    body_lines.append(f"    {' & '.join(map(str, cells))} \\\\")

footer_lines = [
    "    \\bottomrule",
    "    \\end{tabular}",
    "\\end{table*}",
]

latex_str = "\n".join(header_lines + body_lines + footer_lines)

print(latex_str)

\begin{table*}[t]
    \centering
    \caption{Impact of Self-Ablation on Interpretability and Performance. Lower \textit{k}-values in the kWTA mechanism generally improve interpretability, with a minimal increase in perplexity compared to the regular transformer baseline. Arrows indicate the optimization direction: $\downarrow$ denotes that lower values are better.}
    \label{tab:ablation_results}
    \begin{tabular}{cccccc}
    \toprule
    \multicolumn{2}{c}{Architecture} & \multicolumn{1}{c}{ACDC} & \multicolumn{2}{c}{SAE} & \multicolumn{1}{c}{LM} \\
    \cmidrule(r){1-2} \cmidrule(lr){3-3} \cmidrule(lr){4-5} \cmidrule(l){6-6}
    Ablation Type & K-Level & IOI Edges $\downarrow$ & L0 Norm $\downarrow$ & CE Score $\uparrow$ & Perplexity $\downarrow$ \\
    \midrule
    Baseline & - & 79 & 7.22 & 0.63 & \textbf{5.73} \\
    \midrule
    Global & 8 & 70 & 5.22 & 0.65 & 6.97 \\
    Layerwise & 8 & 36 & 5.0 & \textbf{0.72} & 6.47 \\
    \midrule
    Global & 4 & 41 & 4.9 & 0.67 & 6.73 \