## Setup

### Imports

In [1]:
import os
import json
import glob
import torch
import re
import pandas as pd

from transformers import AutoModelForCausalLM, AutoTokenizer

import transformer_lens
import transformer_lens.utils as tl_utils
from transformer_lens import HookedTransformer
import transformer_lens.patching as patching

import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
#import seaborn as sns
import matplotlib.pyplot as plt
from utils.data_utils import generate_data_and_caches
from utils.data_processing import (
    load_edge_scores_into_dictionary,
    read_json_file,
    get_ckpts,
    load_metrics,
    compute_ged,
    compute_weighted_ged,
    compute_gtd,
    compute_jaccard_similarity_to_reference,
    compute_jaccard_similarity,
    aggregate_metrics_to_tensors_step_number,
    get_ckpts
)

### Parameters

In [87]:
# What to analyse: the task name and the performance metric looked up for it.
TASK = 'ioi'
PERFORMANCE_METRIC = 'logit_diff'

# Which model to load. A falsy VARIANT makes load_model use the stock
# BASE_MODEL checkpoints; CACHE is handed to the loaders as cache_dir.
BASE_MODEL = "pythia-160m"
VARIANT = None
CACHE = "model_cache"
CHECKPOINT = 10000

# Intended number of IOI prompts to generate.
IOI_DATASET_SIZE = 70

# Inference only: switch autograd off globally for the rest of the notebook.
torch.set_grad_enabled(False)

device = "cuda" if torch.cuda.is_available() else "cpu"

### Functions

In [None]:
def convert_head_names_to_tuple(head_name):
    """Turn an attention-head node name into a ``(layer, head)`` integer pair.

    Every ``'a'`` and ``'h'`` character is removed from ``head_name``; what is
    left must be two integers separated by a single ``'.'`` (for example
    ``'a9.h4'`` becomes ``(9, 4)``).

    Raises:
        ValueError: if the stripped name does not split into exactly two
            parts on ``'.'`` or a part is not an integer.
    """
    digits_only = head_name.translate(str.maketrans('', '', 'ah'))
    layer_str, head_str = digits_only.split('.')
    return int(layer_str), int(head_str)

In [90]:
def check_copy_circuit(model, layer, head, ioi_dataset, verbose=False, neg=False, k=5):
    """Measure how often a head's value/output path maps a name token to itself.

    The layer-0 residual stream (``blocks.0.hook_resid_post``) of every prompt
    is multiplied by the head's ``W_V`` (plus ``b_V``) and ``W_O`` and handed
    straight to ``model.unembed``. This function does not call any layer-norm
    module itself, neither before the value projection nor before unembedding.

    For each prompt, the positions listed under ``"IO"``, ``"S1"`` and ``"S2"``
    in ``ioi_dataset.word_idx`` are checked: a check counts as correct when the
    corresponding name (prefixed with a space) is among the ``k`` decoded
    tokens with the largest scores at that position.

    Args:
        model: model exposing ``run_with_cache``, ``blocks[layer].attn`` with
            ``W_V``/``b_V``/``W_O``, ``unembed`` and ``tokenizer.decode``.
        layer: index of the block whose attention head is tested.
        head: index of the head inside that block.
        ioi_dataset: object providing ``toks``, ``ioi_prompts``, ``word_idx``,
            ``sentences`` and ``N``.
        verbose: when True, print the sentence, the target name and the top-k
            tokens for every check that fails.
        neg: when True, negate the head output before unembedding.
        k: how many top tokens to inspect per position (defaults to 5, the
            value that used to be hard-coded).

    Returns:
        Percentage (0-100) of the ``ioi_dataset.N * 3`` checks that were
        correct. The same figure is also printed.
    """
    resid_name = "blocks.0.hook_resid_post"

    # Only this one activation is read below, so restrict caching to it rather
    # than storing every hook output of the forward pass.
    _, cache = model.run_with_cache(ioi_dataset.toks.long(), names_filter=resid_name)
    z_0 = cache[resid_name]

    # sign adjustment, optional
    sign = -1 if neg else 1

    attn = model.blocks[layer].attn

    # value projection of the head, plus its bias (broadcast over batch and position)
    v = torch.einsum("eab,bc->eac", z_0, attn.W_V[head])
    v = v + attn.b_V[head].unsqueeze(0).unsqueeze(0)

    # output projection of the head
    o = sign * torch.einsum("sph,hd->spd", v, attn.W_O[head])

    # map the head output directly to vocabulary scores
    logits = model.unembed(o)

    n_right = 0

    for seq_idx, prompt in enumerate(ioi_dataset.ioi_prompts):
        for word in ["IO", "S1", "S2"]:
            # "S1" and "S2" both refer to the prompt's single "S" name
            name = "S" if "S" in word else word

            top_tokens = torch.topk(
                logits[seq_idx, ioi_dataset.word_idx[word][seq_idx]], k
            ).indices
            pred_tokens = [model.tokenizer.decode(token) for token in top_tokens]

            if " " + prompt[name] in pred_tokens:
                n_right += 1
            elif verbose:
                print("-------")
                print("Seq: " + ioi_dataset.sentences[seq_idx])
                print("Target: " + prompt[name])
                print(
                    " ".join(
                        f"({i+1}):{decoded}" for i, decoded in enumerate(pred_tokens)
                    )
                )

    percent_right = (n_right / (ioi_dataset.N * 3)) * 100
    print(
        f"Copy circuit for head {layer}.{head} (sign={sign}) : Top {k} accuracy: {percent_right}%"
    )
    return percent_right

In [101]:
def load_model(BASE_MODEL, VARIANT, CHECKPOINT, CACHE, device):
    """Build a bfloat16 ``HookedTransformer`` for one training checkpoint.

    Args:
        BASE_MODEL: model name given to ``HookedTransformer.from_pretrained``.
        VARIANT: when falsy, the stock ``BASE_MODEL`` checkpoint is loaded
            with weight processing enabled (``fold_ln``, ``center_unembed``
            and ``center_writing_weights``). Otherwise it names a Hugging Face
            model that is loaded at revision ``step{CHECKPOINT}`` and wrapped
            with that processing disabled.
        CHECKPOINT: training step to load.
        CACHE: directory passed as ``cache_dir`` to both loaders.
        device: device the Hugging Face source model is moved to; it is only
            used on the ``VARIANT`` path.

    Returns:
        The model, with ``use_split_qkv_input``, ``use_attn_result`` and
        ``use_hook_mlp_in`` set to True on its config.
    """
    if VARIANT:
        hf_source = AutoModelForCausalLM.from_pretrained(
            VARIANT, revision=f"step{CHECKPOINT}", cache_dir=CACHE
        )
        hf_source = hf_source.to(device).to(torch.bfloat16)

        model = HookedTransformer.from_pretrained(
            BASE_MODEL,
            hf_model=hf_source,
            center_unembed=False,
            center_writing_weights=False,
            fold_ln=False,
            dtype=torch.bfloat16,
            cache_dir=CACHE,
        )
    else:
        model = HookedTransformer.from_pretrained(
            BASE_MODEL,
            checkpoint_value=CHECKPOINT,
            center_unembed=True,
            center_writing_weights=True,
            fold_ln=True,
            refactor_factored_attn_matrices=False,
            dtype=torch.bfloat16,
            cache_dir=CACHE,
        )

    # Switch on the extra hook points on the model config.
    for flag in ("use_split_qkv_input", "use_attn_result", "use_hook_mlp_in"):
        setattr(model.cfg, flag, True)
    return model

## Retrieve & Process Data

### Circuit Data

In [66]:
# Build the graph folder from the notebook parameters so the path cannot drift
# away from BASE_MODEL when a different model is selected.
folder_path = f'results/graphs/{BASE_MODEL}/{TASK}'
df = load_edge_scores_into_dictionary(folder_path)

Processing file 1/153: results/graphs/pythia-160m/ioi/57000.json
Processing file 2/153: results/graphs/pythia-160m/ioi/141000.json
Processing file 3/153: results/graphs/pythia-160m/ioi/95000.json
Processing file 4/153: results/graphs/pythia-160m/ioi/107000.json
Processing file 5/153: results/graphs/pythia-160m/ioi/34000.json
Processing file 6/153: results/graphs/pythia-160m/ioi/6000.json
Processing file 7/153: results/graphs/pythia-160m/ioi/37000.json
Processing file 8/153: results/graphs/pythia-160m/ioi/39000.json
Processing file 9/153: results/graphs/pythia-160m/ioi/104000.json
Processing file 10/153: results/graphs/pythia-160m/ioi/59000.json
Processing file 11/153: results/graphs/pythia-160m/ioi/67000.json
Processing file 12/153: results/graphs/pythia-160m/ioi/111000.json
Processing file 13/153: results/graphs/pythia-160m/ioi/16.json
Processing file 14/153: results/graphs/pythia-160m/ioi/76000.json
Processing file 15/153: results/graphs/pythia-160m/ioi/1.json
Processing file 16/153:

### Performance Data

In [92]:
directory_path = 'results'
perf_metrics = load_metrics(directory_path)

ckpts = get_ckpts(schedule="exp_plus_detail")
#pythia_evals = aggregate_metrics_to_tensors_step_number("results/pythia-evals/pythia-v1")

# filter everything before 1000 steps; .copy() makes the result an independent
# frame, so adding columns below does not write into a slice of the unfiltered
# data (which would raise pandas' SettingWithCopyWarning)
df = df[df['checkpoint'] >= 1000].copy()

# split each "source->target" edge name into its two endpoints
df[['source', 'target']] = df['edge'].str.split('->', expand=True)
len(df['target'].unique())

445

In [93]:
# Look the metric up by BASE_MODEL rather than a hard-coded model name so this
# cell follows the notebook parameters; each entry is converted with .item().
perf_metric = [x.item() for x in perf_metrics[BASE_MODEL][TASK][PERFORMANCE_METRIC]]

# zip into dictionary with ckpts as key
# NOTE(review): zip() stops at the shorter input, so this assumes ckpts and
# perf_metric have the same length and ordering; confirm against load_metrics.
perf_metric_dict = dict(zip(ckpts, perf_metric))


## Experiments

### Dataset Setup

In [102]:
# The dataset is generated with the model at step 143000.
# NOTE(review): presumably the last training checkpoint; confirm.
initial_model = load_model(BASE_MODEL, VARIANT, 143000, CACHE, device)

# Take the dataset size from the parameters cell instead of repeating the literal.
size = IOI_DATASET_SIZE
ioi_dataset, abc_dataset = generate_data_and_caches(initial_model, size, verbose=True)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-160m into HookedTransformer


### Get Experimental Candidates

In [107]:
# Training step examined in this section: it selects both the model weights
# that are loaded and the circuit edges used to pick candidate heads.
EXPERIMENTAL_CHECKPOINT = 80_000

In [108]:
# Load the model at the checkpoint whose candidate heads are tested below.
experimental_model = load_model(
    BASE_MODEL=BASE_MODEL,
    VARIANT=VARIANT,
    CHECKPOINT=EXPERIMENTAL_CHECKPOINT,
    CACHE=CACHE,
    device=device,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-160m into HookedTransformer


In [109]:
# Edges that feed the logits directly and are marked as part of the circuit.
feeds_logits = df['target'] == 'logits'
is_in_circuit = df['in_circuit'] == True
candidate_nmh = df[feeds_logits & is_in_circuit]

# Distinct source nodes of those edges at the checkpoint under study.
at_checkpoint = candidate_nmh['checkpoint'] == EXPERIMENTAL_CHECKPOINT
source_nodes = candidate_nmh.loc[at_checkpoint, 'source'].unique().tolist()

# Drop the 'input' node and every node whose name starts with 'm'; the rest
# are parsed into (layer, head) tuples.
candidate_list = [
    convert_head_names_to_tuple(node)
    for node in source_nodes
    if not (node[0] == 'm' or node == 'input')
]

In [110]:
# Keep every head's score instead of overwriting it on each iteration, so the
# results can be inspected or plotted after the loop.
copy_scores = {}
for layer, head in candidate_list:
    copy_score = check_copy_circuit(experimental_model, layer, head, ioi_dataset, verbose=False, neg=False)
    copy_scores[(layer, head)] = copy_score

Copy circuit for head 9.4 (sign=1) : Top 5 accuracy: 20.0%
Copy circuit for head 8.10 (sign=1) : Top 5 accuracy: 100.0%
Copy circuit for head 8.9 (sign=1) : Top 5 accuracy: 0.0%
Copy circuit for head 8.2 (sign=1) : Top 5 accuracy: 100.0%
Copy circuit for head 9.6 (sign=1) : Top 5 accuracy: 22.380952380952383%
Copy circuit for head 9.7 (sign=1) : Top 5 accuracy: 45.714285714285715%
Copy circuit for head 9.8 (sign=1) : Top 5 accuracy: 13.80952380952381%
Copy circuit for head 11.10 (sign=1) : Top 5 accuracy: 0.0%
Copy circuit for head 10.11 (sign=1) : Top 5 accuracy: 17.61904761904762%
Copy circuit for head 10.7 (sign=1) : Top 5 accuracy: 70.47619047619048%
Copy circuit for head 10.1 (sign=1) : Top 5 accuracy: 0.0%
Copy circuit for head 9.9 (sign=1) : Top 5 accuracy: 10.476190476190476%
Copy circuit for head 7.2 (sign=1) : Top 5 accuracy: 0.0%
Copy circuit for head 6.5 (sign=1) : Top 5 accuracy: 0.0%
