In [1]:
import os
import json
import glob
import re
import sys
from typing import Dict, List

import torch
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import chart_studio
from chart_studio import plotly as py

from IPython.display import display, clear_output, HTML

sys.path.append("..")
from utils.data_processing import (
    load_edge_scores_into_dictionary,
    compute_weighted_jaccard_similarity,
    compute_weighted_jaccard_similarity_to_reference,
    compute_ewma_weighted_jaccard_similarity,
    generate_in_circuit_df_files
)
from utils.result_plotting import plot_head_circuit_scores, plot_graph_metric

In [2]:
# Optional Chart Studio setup: only needed to publish figures (pass upload=True to the
# plotting helpers). If the credentials file is absent, this cell is skipped so the
# notebook still runs end to end.
# Expected file content: "username,api_key" (whitespace around either part is ignored).
CHART_STUDIO_KEY_PATH = "../auth/chart_studio_api_key.txt"

if os.path.exists(CHART_STUDIO_KEY_PATH):
    with open(CHART_STUDIO_KEY_PATH) as f:
        # Split on the first comma only; a missing comma still raises ValueError.
        username, api_key = f.read().strip().split(",", 1)
    # Trim surrounding whitespace (both ends) from each field.
    username = username.strip()
    api_key = api_key.strip()
    chart_studio.tools.set_credentials_file(username=username, api_key=api_key)
else:
    print(f"No Chart Studio credentials at {CHART_STUDIO_KEY_PATH}; skipping (upload=True will not work).")

In [22]:
# Analysis configuration: the task, performance metric and model these plots describe.
TASK = 'ioi'
PERFORMANCE_METRIC = 'logit_diff'
MODEL_NAME = 'pythia-410m'

# Per-task figure directory, passed as output_path to the plotting helpers below.
# Created up front; an existing directory is not an error.
OUTPUT_DIR = "../results/plots/component_metrics/" + TASK + "/"
os.makedirs(OUTPUT_DIR, exist_ok=True)

## Circuit Components

In [23]:
# Load the precomputed per-checkpoint component scores and head sets for MODEL_NAME.
# map_location='cpu' lets results that may have been saved from GPU tensors load on
# machines without CUDA; nothing in this notebook needs a GPU.
# NOTE(review): torch.load unpickles its input; only load trusted result files.
COMPONENTS_DIR = f'../results/components/{MODEL_NAME}'
components_over_time = torch.load(f'{COMPONENTS_DIR}/components_over_time.pt', map_location='cpu')
heads_over_time = torch.load(f'{COMPONENTS_DIR}/heads_over_time.pt', map_location='cpu')

In [24]:
# Checkpoint ids present in the component results, in ascending order.
ckpts = sorted(components_over_time.keys())

In [25]:
# Gather the direct-effect metrics for each checkpoint, keyed by checkpoint id.
# Checkpoints whose 'direct_effect_scores' entry is None are skipped.
copy_scores = {}
filtered_copy_scores = {}
io_attns = {}
io_s1_attn_ratio = {}
copy_suppression_scores = {}

for ckpt in ckpts:
    direct_effect = components_over_time[ckpt]['direct_effect_scores']
    if direct_effect is None:
        continue
    copy_scores[ckpt] = direct_effect['copy_scores']
    # NOTE(review): identical to copy_scores; no filtering is applied here. Confirm intent.
    filtered_copy_scores[ckpt] = direct_effect['copy_scores']
    io_attns[ckpt] = direct_effect['io_attn_scores']
    # IO attention divided by S1 attention.
    # NOTE(review): no guard against zero S1 attention.
    io_s1_attn_ratio[ckpt] = direct_effect['io_attn_scores'] / direct_effect['s1_attn_scores']
    copy_suppression_scores[ckpt] = direct_effect['copy_suppression_scores']

### Name Mover Head (NMH) Metrics

In [26]:
# Every head listed under 'nmh' at any checkpoint, merged into one set.
all_nmh = set().union(*(heads_over_time[ckpt]['nmh'] for ckpt in ckpts))

In [27]:

# Copy score across checkpoints for all heads (no head filter), log-scaled x axis.
all_heads_copy_score = plot_head_circuit_scores(
    copy_scores,
    title=f'Copy Score Across Checkpoints ({MODEL_NAME})',
    disable_title=True,
    show_legend=False,
    log_x=True,
    output_path=OUTPUT_DIR,
)

In [28]:
# IO:S1 attention ratio across checkpoints, limited to heads in all_nmh, with the
# y range set to [0, 20].
# NOTE(review): despite the name, this plot is restricted to NMH heads, not all heads.
all_heads_io_attn = plot_head_circuit_scores(
    io_s1_attn_ratio,
    title=f'IO:S1 Attn Ratio Across Checkpoints ({MODEL_NAME})',
    log_x=True,
    range_y=[0, 20],
    limit_to_list=all_nmh,
    output_path=OUTPUT_DIR,
)

### Copy Suppression Metrics

In [41]:
# Copy-suppression scores across checkpoints for all heads (no head filter).
copy_suppression_scores_df = plot_head_circuit_scores(
    copy_suppression_scores,
    title=f'Copy Suppression Scores Across Checkpoints ({MODEL_NAME})',
    log_x=True,
    show_legend=False,
    disable_title=True,
    output_path=OUTPUT_DIR,
)

### S-Inhibition (S2I) Head Metrics

In [42]:
# Logit-diff change from the S2I 'token_same_pos_oppo' ablation, keyed by checkpoint
# (only checkpoints with S2I scores).
pos_signal_importance = {
    ckpt: components_over_time[ckpt]['s2i_scores']['s2i_ablated_logit_diff_deltas']['token_same_pos_oppo']
    for ckpt in ckpts
    if components_over_time[ckpt]['s2i_scores'] is not None
}

In [43]:
# Plot how the S2I positional-signal ablation changes logit diff over training.
pos_signal_df = plot_head_circuit_scores(
    pos_signal_importance,
    title=f'S2I Pos Signal Ablation Logit Diff Change % Across Checkpoints ({MODEL_NAME})',
    disable_title=True,
    show_legend=False,
    log_x=True,
    output_path=OUTPUT_DIR,
)

In [44]:
# IO-attention change under the S2I 'token_same_pos_oppo' ablation, keyed by checkpoint
# (only checkpoints with S2I scores).
pos_signal_io_attn_change = {
    ckpt: components_over_time[ckpt]['s2i_scores']['s2i_io_attention_deltas']['token_same_pos_oppo']
    for ckpt in ckpts
    if components_over_time[ckpt]['s2i_scores'] is not None
}

In [45]:
# Plot the effect of the S2I positional-signal ablation on IO attention over training.
pos_signal_io_attn_df = plot_head_circuit_scores(
    pos_signal_io_attn_change,
    title=f'Effect of S2I Pos Signal Ablation On NMH IO Attn ({MODEL_NAME})',
    log_x=True,
    disable_title=True,
    show_legend=False,
    output_path=OUTPUT_DIR,
)

### Tertiary Component Scores

In [46]:
# Peek at which tertiary-head score types were saved. Use the latest checkpoint that
# actually has tertiary scores instead of a hard-coded checkpoint id, which may be
# absent or None for other models/runs. Raises ValueError if no checkpoint has them.
sample_tertiary_ckpt = max(
    ckpt for ckpt in ckpts
    if components_over_time[ckpt]['tertiary_head_scores'] is not None
)
components_over_time[sample_tertiary_ckpt]['tertiary_head_scores'].keys()

dict_keys(['induction_scores', 'prev_token_scores', 'duplicate_token_scores'])

In [47]:
# Induction scores keyed by checkpoint, for checkpoints with tertiary-head scores.
induction_scores = {
    ckpt: components_over_time[ckpt]['tertiary_head_scores']['induction_scores']
    for ckpt in ckpts
    if components_over_time[ckpt]['tertiary_head_scores'] is not None
}

In [48]:
# Induction scores across checkpoints for all heads (no head filter).
induction_df = plot_head_circuit_scores(
    induction_scores,
    title=f'Induction Scores Across Checkpoints ({MODEL_NAME})',
    disable_title=True,
    log_x=True,
    show_legend=False,
    output_path=OUTPUT_DIR,
)

In [49]:
# Previous-token scores keyed by checkpoint, for checkpoints with tertiary-head scores.
prev_token_scores = {
    ckpt: components_over_time[ckpt]['tertiary_head_scores']['prev_token_scores']
    for ckpt in ckpts
    if components_over_time[ckpt]['tertiary_head_scores'] is not None
}

In [50]:
# Previous-token scores across checkpoints for all heads (no head filter).
prev_token_df = plot_head_circuit_scores(
    prev_token_scores,
    title=f'Prev Token Scores Across Checkpoints ({MODEL_NAME})',
    disable_title=True,
    log_x=True,
    show_legend=False,
    output_path=OUTPUT_DIR,
)

In [51]:
# Duplicate-token scores keyed by checkpoint, for checkpoints with tertiary-head scores.
duplicate_token_scores = {
    ckpt: components_over_time[ckpt]['tertiary_head_scores']['duplicate_token_scores']
    for ckpt in ckpts
    if components_over_time[ckpt]['tertiary_head_scores'] is not None
}

In [52]:
# Duplicate-token scores across checkpoints for all heads (no head filter).
duplicate_token_df = plot_head_circuit_scores(
    duplicate_token_scores,
    title=f'Duplicate Token Scores Across Checkpoints ({MODEL_NAME})',
    disable_title=True,
    log_x=True,
    show_legend=False,
    output_path=OUTPUT_DIR,
)