In [2]:
import os
import json
import glob
import re
from typing import Dict, List

import torch
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import chart_studio
from chart_studio import plotly as py

from IPython.display import display, clear_output, HTML
from utils.data_processing import (
    load_edge_scores_into_dictionary,
    compute_weighted_jaccard_similarity,
    compute_weighted_jaccard_similarity_to_reference,
    compute_ewma_weighted_jaccard_similarity,
)

from utils.component_evaluation import plot_head_circuit_scores

from utils.visualization import plot_graph_metric

In [3]:
#!pip install chart_studio nltk

In [4]:
# ignore the following if not using chart_studio. if you do want to publish graphs, simple include upload=True in the plotly function

# load API key from local file chart_studio_api_key.txt - should be username and api key separated by a comma
with open("chart_studio_api_key.txt") as f:
    username, api_key = f.read().strip().split(",")
    # strip leading whitespace
    username = username.strip()
    api_key = api_key.strip()

chart_studio.tools.set_credentials_file(username=username, api_key=api_key)

In [55]:
TASK = 'ioi'
PERFORMANCE_METRIC = 'logit_diff'
MODEL_NAME = 'pythia-160m-seed1'

## IOI Graph Metrics

In [59]:
folder_path = f'results/graphs/{MODEL_NAME}/{TASK}'
df = load_edge_scores_into_dictionary(folder_path)
df = df[df['checkpoint'] >= 4000]

perf_metrics = torch.load(f'results/backup/{MODEL_NAME}/nmh_backup_metrics.pt')
perf_metric_dict = {checkpoint: perf_metrics[checkpoint]['logit_diff'] for checkpoint in perf_metrics.keys()}
clear_output()

In [60]:
# Summarize the total number of edges in the graph at each checkpoint, filtered by in_circuit
subgraph_df = df[df['in_circuit'] == True]

# Group by checkpoint and sum the number of edges
subgraph_df = subgraph_df.groupby('checkpoint').size().reset_index(name='num_edges')

### Graph Size

In [61]:
plot_graph_metric(subgraph_df, 'num_edges', perf_metric_dict, f'Graph Size for {MODEL_NAME}', y_range=1000, x_axis_col='checkpoint', log_x=True)

### Graph Similarity

In [62]:
weighted_jaccard_results = compute_weighted_jaccard_similarity(df)
plot_graph_metric(weighted_jaccard_results, 'jaccard_similarity', perf_metric_dict, f'Graph Size for {MODEL_NAME}', y_range=1, x_axis_col='checkpoint_2', log_x=True)

In [63]:
comparison_checkpoint = 5000 
jaccard_reference_results = compute_weighted_jaccard_similarity_to_reference(df, comparison_checkpoint)
plot_graph_metric(jaccard_reference_results, 'jaccard_similarity', perf_metric_dict, f'Weighted Jaccard Similarity to Checkpoint {comparison_checkpoint} for {MODEL_NAME}', y_range=1, x_axis_col='checkpoint', log_x=True)

In [64]:
comparison_checkpoint = 143000

jaccard_reference_results = compute_weighted_jaccard_similarity_to_reference(df, comparison_checkpoint)
plot_graph_metric(jaccard_reference_results, 'jaccard_similarity', perf_metric_dict, f'Weighted Jaccard Similarity to Checkpoint {comparison_checkpoint} for {MODEL_NAME}', y_range=1, x_axis_col='checkpoint', log_x=True)

In [65]:
jaccard_reference_results = compute_ewma_weighted_jaccard_similarity(df, alpha=0.1)
plot_graph_metric(jaccard_reference_results, 'ewma_change_rate', perf_metric_dict, f'Exponential Weighted Average Graph Change Rate for {MODEL_NAME}', y_range=1, x_axis_col='checkpoint_2', log_x=True)

## Circuit Components

In [57]:
components_over_time = torch.load(f'results/components/{MODEL_NAME}/components_over_time.pt')
heads_over_time = torch.load(f'results/components/{MODEL_NAME}/heads_over_time.pt')

In [58]:
ckpts = list(components_over_time.keys())
ckpts.sort()
#ckpts

### NMH Metrics

In [59]:
copy_scores = dict()
filtered_copy_scores = dict()
io_attns = dict()
io_s1_attn_ratio = dict()
copy_suppression_scores = dict()
for ckpt in ckpts:
    if components_over_time[ckpt]['direct_effect_scores'] is not None:
        copy_scores[ckpt] = components_over_time[ckpt]['direct_effect_scores']['copy_scores']
        filtered_copy_scores[ckpt] = components_over_time[ckpt]['direct_effect_scores']['copy_scores']
        io_attns[ckpt] = components_over_time[ckpt]['direct_effect_scores']['io_attn_scores']
        io_s1_attn_ratio[ckpt] = components_over_time[ckpt]['direct_effect_scores']['io_attn_scores'] / components_over_time[ckpt]['direct_effect_scores']['s1_attn_scores']
        copy_suppression_scores[ckpt] = components_over_time[ckpt]['direct_effect_scores']['copy_suppression_scores']


#io_attns = {ckpt: components_over_time[ckpt]['direct_effect_scores']['io_attn_scores'] for ckpt in ckpts}
#io_s1_attn_ratio = {ckpt: components_over_time[ckpt]['direct_effect_scores']['io_attn_scores'] / components_over_time[ckpt]['direct_effect_scores']['s1_attn_scores'] for ckpt in ckpts}

In [60]:
all_nmh = set()
for ckpt in ckpts:
    all_nmh.update(heads_over_time[ckpt]['nmh'])

In [61]:
# note that most, but not all of these are formally 'NMHs'; if attention to S1 exceeds attention to IO, they are not NMHs
all_heads_copy_score = plot_head_circuit_scores(MODEL_NAME, copy_scores, show_legend=False, title= f'Copy Score Across Checkpoints ({MODEL_NAME})', disable_title=True)


In [62]:
all_heads_io_attn = plot_head_circuit_scores(MODEL_NAME, io_s1_attn_ratio, title= f'IO:S1 Attn Ratio Across Checkpoints ({MODEL_NAME})', limit_to_list=all_nmh, range_y=[0, 20])

### Copy Suppression Metrics

In [63]:
copy_suppression_scores_df = plot_head_circuit_scores(MODEL_NAME, copy_suppression_scores, show_legend=False, title= f'Copy Suppression Scores Across Checkpoints ({MODEL_NAME})', disable_title=True)

### S2I Metrics

In [64]:
components_over_time[143000]['s2i_scores'].keys()

KeyError: 143000

In [None]:
components_over_time[143000]['s2i_scores']['s2i_ablated_logit_diff_deltas'].keys() #['copy_scores']

dict_keys(['token_same_pos_oppo', 'token_oppo_pos_same', 'token_oppo_pos_oppo'])

In [None]:
pos_signal_importance = dict()
for ckpt in ckpts:
    if components_over_time[ckpt]['s2i_scores'] is not None:
        pos_signal_importance[ckpt] = components_over_time[ckpt]['s2i_scores']['s2i_ablated_logit_diff_deltas']['token_same_pos_oppo']

In [65]:
pos_signal_df = plot_head_circuit_scores(MODEL_NAME, pos_signal_importance, show_legend=False, title= f'S2I Pos Signal Ablation Logit Diff Change % Across Checkpoints ({MODEL_NAME})', disable_title=True)

In [66]:
pos_signal_io_attn_change = dict()
for ckpt in ckpts:
    if components_over_time[ckpt]['s2i_scores'] is not None:
        pos_signal_io_attn_change[ckpt] = components_over_time[ckpt]['s2i_scores']['s2i_io_attention_deltas']['token_same_pos_oppo']

In [67]:
pos_signal_io_attn_df = plot_head_circuit_scores(MODEL_NAME, pos_signal_io_attn_change, show_legend=False, title= f'Effect of S2I Pos Signal Ablation On NMH IO Attn ({MODEL_NAME})', disable_title=True)

### Tertiary Component Scores

In [69]:
components_over_time[137000]['tertiary_head_scores'].keys()

dict_keys(['induction_scores', 'prev_token_scores', 'duplicate_token_scores'])

In [70]:
induction_scores = dict()
for ckpt in ckpts:
    if components_over_time[ckpt]['tertiary_head_scores'] is not None:
        induction_scores[ckpt] = components_over_time[ckpt]['tertiary_head_scores']['induction_scores']

In [71]:
induction_df = plot_head_circuit_scores(MODEL_NAME, induction_scores, show_legend=False, title= f'Induction Scores Across Checkpoints ({MODEL_NAME})', disable_title=True)

In [72]:
prev_token_scores = dict()
for ckpt in ckpts:
    if components_over_time[ckpt]['tertiary_head_scores'] is not None:
        prev_token_scores[ckpt] = components_over_time[ckpt]['tertiary_head_scores']['prev_token_scores']

In [73]:
prev_token_df = plot_head_circuit_scores(MODEL_NAME, prev_token_scores, show_legend=False, title= f'Prev Token Scores Across Checkpoints ({MODEL_NAME})', disable_title=True)

In [74]:
duplicate_token_scores = dict()
for ckpt in ckpts:
    if components_over_time[ckpt]['tertiary_head_scores'] is not None:
        duplicate_token_scores[ckpt] = components_over_time[ckpt]['tertiary_head_scores']['duplicate_token_scores']

In [75]:
duplicate_token_df = plot_head_circuit_scores(MODEL_NAME, duplicate_token_scores, show_legend=False, title= f'Duplicate Token Scores Across Checkpoints ({MODEL_NAME})', disable_title=True)