## Setup

### Imports

In [184]:
import os
import json
import glob
import torch
import re
import einops
import pandas as pd
from functools import partial
from torch import Tensor
from torchtyping import TensorType as TT
from jaxtyping import Float

from transformers import AutoModelForCausalLM

import transformer_lens
import transformer_lens.utils as tl_utils
from transformer_lens import HookedTransformer, ActivationCache
import transformer_lens.patching as patching

import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
#import seaborn as sns
import matplotlib.pyplot as plt
from utils.data_utils import generate_data_and_caches
from utils.data_processing import (
    load_edge_scores_into_dictionary,
    read_json_file,
    get_ckpts,
    load_metrics,
    compute_ged,
    compute_weighted_ged,
    compute_gtd,
    compute_jaccard_similarity_to_reference,
    compute_jaccard_similarity,
    aggregate_metrics_to_tensors_step_number,
    get_ckpts
)
from utils.metrics import compute_logit_diff, _logits_to_mean_logit_diff
from utils.visualization import plot_attention_heads, imshow_p

### Parameters

In [137]:
# What is analysed: the task whose circuit graphs are loaded and the metric
# used to index the stored performance results.
TASK = 'ioi'
PERFORMANCE_METRIC = 'logit_diff'

# Which model is loaded (see load_model) and where weights are cached on disk.
BASE_MODEL = "pythia-160m"
VARIANT = None
CACHE = "model_cache"

# Number of IOI prompts to generate.
IOI_DATASET_SIZE = 70

# Use the GPU whenever torch can see one, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The notebook only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fa31690bc10>

### Functions

In [None]:
def convert_head_names_to_tuple(head_name):
    """Turn an attention-head node name into a ``(layer, head)`` pair of ints.

    Every ``'a'`` and ``'h'`` character is removed from ``head_name``; the
    remainder must consist of two integer literals separated by exactly one
    ``'.'`` (for example ``'a5.h8'`` becomes ``(5, 8)``).

    Raises:
        ValueError: if the stripped name does not split into exactly two
            parts on ``'.'`` or a part is not a valid integer.
    """
    digits_only = head_name.replace('a', '').replace('h', '')
    layer_text, head_text = digits_only.split('.')
    return int(layer_text), int(head_text)

In [90]:
def check_copy_circuit(model, layer, head, ioi_dataset, verbose=False, neg=False, k=5):
    """Measure how often one head's OV circuit "copies" the name tokens.

    The residual stream after block 0 (``blocks.0.hook_resid_post``) is pushed
    through the value and output weights of head ``layer.head`` and then
    through ``model.unembed``. For every prompt and each of the ``IO``, ``S1``
    and ``S2`` positions, the head scores a hit when the name at that position
    (``" " + prompt["IO"]`` or ``" " + prompt["S"]``) is among the ``k``
    highest-scoring decoded tokens.

    Args:
        model: object exposing ``run_with_cache``, ``blocks[layer].attn``
            (with ``W_V``, ``b_V``, ``W_O``), ``unembed`` and
            ``tokenizer.decode``.
        layer: index of the block that owns the head.
        head: index of the head inside that block.
        ioi_dataset: object exposing ``toks``, ``ioi_prompts``, ``word_idx``,
            ``sentences`` and ``N``.
        verbose: when True, print the prompt, the target name and the ranked
            top-k tokens for every miss.
        neg: when True, flip the sign of the head output before unembedding.
        k: how many top-scoring tokens count as a hit (default 5, the value
            that was previously hard-coded).

    Returns:
        Percentage (0-100) of the ``3 * ioi_dataset.N`` checked positions
        whose name appears in the top ``k`` tokens.
    """
    # The whole model is run once; only the residual stream after block 0 is
    # read back from the cache.
    _, cache = model.run_with_cache(ioi_dataset.toks.long())

    # Optional sign adjustment.
    sign = -1 if neg else 1

    # NOTE: the residual stream is used as-is; no layer norm is applied before
    # the value projection.
    resid_after_block0 = cache["blocks.0.hook_resid_post"]

    attn = model.blocks[layer].attn

    # Value projection of the chosen head, plus its bias.
    v = torch.einsum("eab,bc->eac", resid_after_block0, attn.W_V[head])
    v = v + attn.b_V[head].unsqueeze(0).unsqueeze(0)

    # Output projection of the chosen head (sign-adjusted).
    o = sign * torch.einsum("sph,hd->spd", v, attn.W_O[head])

    # Map the head output to vocabulary scores. Only the unembedding is
    # applied here; there is no final layer norm in this path.
    logits = model.unembed(o)

    n_right = 0

    for seq_idx, prompt in enumerate(ioi_dataset.ioi_prompts):
        for word in ["IO", "S1", "S2"]:
            # Both subject positions are compared against the single "S" name.
            name = "S" if "S" in word else word

            # Rank the vocabulary once per position and reuse the decoded
            # tokens for both the hit test and the verbose report.
            top_token_ids = torch.topk(
                logits[seq_idx, ioi_dataset.word_idx[word][seq_idx]], k
            ).indices
            pred_tokens = [model.tokenizer.decode(token) for token in top_token_ids]

            if " " + prompt[name] in pred_tokens:
                n_right += 1
            elif verbose:
                print("-------")
                print("Seq: " + ioi_dataset.sentences[seq_idx])
                print("Target: " + prompt[name])
                print(
                    " ".join(
                        f"({i+1}):{decoded}" for i, decoded in enumerate(pred_tokens)
                    )
                )

    percent_right = (n_right / (ioi_dataset.N * 3)) * 100
    print(
        f"Copy circuit for head {layer}.{head} (sign={sign}) : Top {k} accuracy: {percent_right}%"
    )
    return percent_right

In [181]:
def residual_stack_to_logit_diff(
    residual_stack: Float[Tensor, "... batch d_model"],
    cache: ActivationCache,
    logit_diff_directions: Float[Tensor, "batch d_model"],
) -> Float[Tensor, "..."]:
    '''
    Average logit difference (correct minus incorrect answer) attributable to
    each component in a stack of residual-stream contributions.

    The stack is first rescaled with ``cache.apply_ln_to_stack`` (``layer=-1``,
    ``pos_slice=-1``), then every batch element is dotted with its own
    logit-difference direction and the result is averaged over the batch,
    leaving only the leading ``...`` dimensions.
    '''
    normalised_stack = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)

    # Contract over both the batch and d_model axes: this is the sum over the
    # batch of the per-example dot products.
    summed_projection = einops.einsum(
        normalised_stack, logit_diff_directions,
        "... batch d_model, batch d_model -> ..."
    )

    # Dividing the batch sum by the batch size (second-to-last axis of the
    # input stack) gives the mean.
    n_examples = residual_stack.size(-2)
    return summed_projection / n_examples

In [127]:
def load_model(BASE_MODEL, VARIANT, CHECKPOINT, CACHE, device):
    """Load one training checkpoint as a ``HookedTransformer``.

    Args:
        BASE_MODEL: model name handed to ``HookedTransformer.from_pretrained``.
        VARIANT: when truthy, the Hugging Face model identifier whose weights
            are loaded (at revision ``step{CHECKPOINT}``) and wrapped; when
            falsy, ``BASE_MODEL`` itself is loaded at ``CHECKPOINT``.
        CHECKPOINT: training step of the checkpoint to load.
        CACHE: directory passed as ``cache_dir`` to every ``from_pretrained`` call.
        device: device the Hugging Face source model is moved to (only used
            on the ``VARIANT`` path).

    Returns:
        The loaded model with ``use_split_qkv_input``, ``use_attn_result`` and
        ``use_hook_mlp_in`` switched on in its config.
    """
    if VARIANT:
        # Variant weights come from the Hugging Face hub and are wrapped
        # without layer-norm folding or weight centering.
        hf_source = AutoModelForCausalLM.from_pretrained(
            VARIANT, revision=f"step{CHECKPOINT}", cache_dir=CACHE
        ).to(device)
        model = HookedTransformer.from_pretrained(
            BASE_MODEL,
            hf_model=hf_source,
            center_unembed=False,
            center_writing_weights=False,
            fold_ln=False,
            cache_dir=CACHE,
        )
    else:
        # Base model: load the requested checkpoint with layer norm folded
        # and the writing/unembedding weights centered.
        model = HookedTransformer.from_pretrained(
            BASE_MODEL,
            checkpoint_value=CHECKPOINT,
            center_unembed=True,
            center_writing_weights=True,
            fold_ln=True,
            refactor_factored_attn_matrices=False,
            cache_dir=CACHE,
        )

    # Enable the extra hook points on the config.
    for flag_name in ("use_split_qkv_input", "use_attn_result", "use_hook_mlp_in"):
        setattr(model.cfg, flag_name, True)
    return model

## Retrieve & Process Data

### Circuit Data

In [66]:
# Circuit graphs are stored per model and per task. The path is built from the
# config constants so that changing BASE_MODEL or TASK loads the matching
# graphs instead of silently re-using the pythia-160m ones.
folder_path = f'results/graphs/{BASE_MODEL}/{TASK}'
df = load_edge_scores_into_dictionary(folder_path)

Processing file 1/153: results/graphs/pythia-160m/ioi/57000.json
Processing file 2/153: results/graphs/pythia-160m/ioi/141000.json
Processing file 3/153: results/graphs/pythia-160m/ioi/95000.json
Processing file 4/153: results/graphs/pythia-160m/ioi/107000.json
Processing file 5/153: results/graphs/pythia-160m/ioi/34000.json
Processing file 6/153: results/graphs/pythia-160m/ioi/6000.json
Processing file 7/153: results/graphs/pythia-160m/ioi/37000.json
Processing file 8/153: results/graphs/pythia-160m/ioi/39000.json
Processing file 9/153: results/graphs/pythia-160m/ioi/104000.json
Processing file 10/153: results/graphs/pythia-160m/ioi/59000.json
Processing file 11/153: results/graphs/pythia-160m/ioi/67000.json
Processing file 12/153: results/graphs/pythia-160m/ioi/111000.json
Processing file 13/153: results/graphs/pythia-160m/ioi/16.json
Processing file 14/153: results/graphs/pythia-160m/ioi/76000.json
Processing file 15/153: results/graphs/pythia-160m/ioi/1.json
Processing file 16/153:

### Performance Data

In [92]:
directory_path = 'results'
perf_metrics = load_metrics(directory_path)

ckpts = get_ckpts(schedule="exp_plus_detail")
#pythia_evals = aggregate_metrics_to_tensors_step_number("results/pythia-evals/pythia-v1")

# Filter everything before 1000 steps. The explicit .copy() makes the result an
# independent frame, so the column assignment below does not write into a
# slice of the unfiltered data (which triggers pandas' SettingWithCopyWarning).
df = df[df['checkpoint'] >= 1000].copy()

# Split each "source->target" edge label into its two endpoints.
df[['source', 'target']] = df['edge'].str.split('->', expand=True)
len(df['target'].unique())

445

In [93]:
# Look the metric up through the config constants (BASE_MODEL instead of a
# hard-coded model name) and convert each stored value to a plain Python number.
perf_metric = [x.item() for x in perf_metrics[BASE_MODEL][TASK][PERFORMANCE_METRIC]]

# zip into dictionary with ckpts as key
# (zip stops at the shorter sequence, so ckpts and perf_metric are expected to
# have the same length and order)
perf_metric_dict = dict(zip(ckpts, perf_metric))


## Experiments

### Dataset Setup

In [128]:
# The model loaded here is the one handed to the dataset generator below.
initial_model = load_model(BASE_MODEL, VARIANT, 143000, CACHE, device)

# Use the dataset size declared in the parameters cell rather than a second
# hard-coded literal.
size = IOI_DATASET_SIZE
ioi_dataset, abc_dataset = generate_data_and_caches(initial_model, size, verbose=True)

# One row per prompt: column 0 is the indirect-object token id, column 1 the
# subject token id. The ids are built as integer tensors directly instead of
# round-tripping through float32.
answer_tokens = torch.stack(
    (
        torch.as_tensor(ioi_dataset.io_tokenIDs, dtype=torch.long),
        torch.as_tensor(ioi_dataset.s_tokenIDs, dtype=torch.long),
    ),
    dim=1,
).to(device)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-160m into HookedTransformer


### Get Experimental Candidates

In [129]:
# Training step of the checkpoint analysed below: it is passed to load_model
# and used to select that checkpoint's rows from the circuit data.
EXPERIMENTAL_CHECKPOINT = 143000
# Minimum copy score (a percentage, as returned by check_copy_circuit) a head
# must reach to be selected for ablation.
COPY_SCORE_THRESHOLD = 75.0

In [182]:
# Load the checkpoint under study. reset_hooks() is called only after the model
# exists: calling it first raises NameError on a fresh kernel, because
# `experimental_model` is not bound until this assignment has run.
experimental_model = load_model(BASE_MODEL, VARIANT, EXPERIMENTAL_CHECKPOINT, CACHE, device)
experimental_model.reset_hooks()
orig_logits, orig_cache = experimental_model.run_with_cache(ioi_dataset.toks.long())

# Residual-stream direction of each answer token; correct minus incorrect gives
# the direction along which the logit difference is measured.
answer_residual_directions = experimental_model.tokens_to_residual_directions(answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)
logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
print("Logit difference directions shape:", logit_diff_directions.shape)

# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream: Float[Tensor, "batch seq d_model"] = orig_cache["resid_post", -1]
print(f"Final residual stream shape: {final_residual_stream.shape}")
scaled_residual_stream = orig_cache.apply_ln_to_stack(final_residual_stream, layer=-1)
print(f"Scaled residual stream shape: {scaled_residual_stream.shape}")

# Pick, for every prompt, the residual stream at that prompt's "end" position.
scaled_final_token_residual_stream: Float[Tensor, "batch d_model"] = scaled_residual_stream[
    torch.arange(final_residual_stream.size(0)), ioi_dataset.word_idx["end"]
]
print(f"Final token residual stream shape: {scaled_final_token_residual_stream.shape}")

# Mean over the batch of the per-prompt dot products. The divisor is the actual
# batch size rather than a hard-coded 70, so the cell stays correct when the
# dataset size changes.
average_logit_diff = einops.einsum(
    scaled_final_token_residual_stream, logit_diff_directions,
    "batch d_model, batch d_model ->"
) / final_residual_stream.size(0)

print(f"Calculated average logit diff: {average_logit_diff:.10f}")
print(f"Original logit difference:     {_logits_to_mean_logit_diff(orig_logits, ioi_dataset).item():.10f}")

# Per-head contribution to the logit difference, arranged as (layer, head).
# NOTE(review): pos_slice=-1 takes the last sequence position, whereas the
# calculation above indexes ioi_dataset.word_idx["end"]; confirm these select
# the same token for every prompt (e.g. that no prompt is padded).
per_head_residual, labels = orig_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, orig_cache, logit_diff_directions)
per_head_logit_diffs = einops.rearrange(per_head_logit_diffs, "(layer head_index) -> layer head_index", layer=experimental_model.cfg.n_layers, head_index=experimental_model.cfg.n_heads)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-160m into HookedTransformer
Answer residual directions shape: torch.Size([70, 2, 768])
Logit difference directions shape: torch.Size([70, 768])
Final residual stream shape: torch.Size([70, 21, 768])
Scaled residual stream shape: torch.Size([70, 21, 768])
Final token residual stream shape: torch.Size([70, 768])
Calculated average logit diff: 4.0878019333
Original logit difference:     4.1340785027


In [185]:
# Sanity check: one row per layer and one column per head, as laid out by the
# rearrange that produced this tensor.
per_head_logit_diffs.shape

torch.Size([12, 12])

In [186]:
# Heatmap of each head's contribution to the average logit difference
# (layers on the y axis, heads on the x axis). The values are raw logit
# differences, not patching results and not percentages, so the title, colour
# label and tick suffix say so.
imshow_p(
    per_head_logit_diffs,
    title="Per-head contribution to logit difference",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff"},
    coloraxis=dict(colorbar_ticksuffix = ""),
    border=True,
    width=600,
    margin={"r": 100, "l": 100}
)

In [131]:
# Edges that feed the logits node directly and are marked as part of the circuit.
candidate_nmh = df.loc[(df['target'] == 'logits') & (df['in_circuit'] == True)]

# Source nodes of those edges at the experimental checkpoint, converted to
# (layer, head) tuples. Names starting with 'm' and the 'input' node are
# skipped (presumably non-attention nodes; confirm against the graph schema).
candidate_list = [
    convert_head_names_to_tuple(source_name)
    for source_name in candidate_nmh.loc[
        candidate_nmh['checkpoint'] == EXPERIMENTAL_CHECKPOINT, 'source'
    ].unique().tolist()
    if not (source_name[0] == 'm' or source_name == 'input')
]

In [132]:
# Copy score for every candidate head, stored as (layer, head, copy_score)
# triples. Thresholding happens in a later cell.
NMHs = [
    (
        layer,
        head,
        check_copy_circuit(experimental_model, layer, head, ioi_dataset, verbose=False, neg=False),
    )
    for layer, head in candidate_list
]

Copy circuit for head 5.8 (sign=1) : Top 5 accuracy: 0.0%
Copy circuit for head 6.5 (sign=1) : Top 5 accuracy: 0.0%
Copy circuit for head 6.6 (sign=1) : Top 5 accuracy: 2.857142857142857%
Copy circuit for head 7.2 (sign=1) : Top 5 accuracy: 0.0%
Copy circuit for head 10.1 (sign=1) : Top 5 accuracy: 0.0%
Copy circuit for head 10.2 (sign=1) : Top 5 accuracy: 0.0%
Copy circuit for head 10.7 (sign=1) : Top 5 accuracy: 66.19047619047619%
Copy circuit for head 10.8 (sign=1) : Top 5 accuracy: 1.9047619047619049%
Copy circuit for head 10.10 (sign=1) : Top 5 accuracy: 0.0%
Copy circuit for head 10.11 (sign=1) : Top 5 accuracy: 7.142857142857142%
Copy circuit for head 11.8 (sign=1) : Top 5 accuracy: 0.0%
Copy circuit for head 11.9 (sign=1) : Top 5 accuracy: 0.0%
Copy circuit for head 9.9 (sign=1) : Top 5 accuracy: 4.761904761904762%
Copy circuit for head 9.8 (sign=1) : Top 5 accuracy: 0.9523809523809524%
Copy circuit for head 7.8 (sign=1) : Top 5 accuracy: 33.80952380952381%
Copy circuit for hea

In [171]:
# Heads whose copy score reaches the threshold are the ones that get ablated.
heads_to_ablate = [x[:2] for x in NMHs if x[2] >= COPY_SCORE_THRESHOLD]
head_labels = [f"L{l}H{h}" for l in range(experimental_model.cfg.n_layers) for h in range(experimental_model.cfg.n_heads)]

print(f"Heads to ablate: {heads_to_ablate}")
def ablate_top_head_hook(z: TT["batch", "pos", "head_index", "d_head"], hook, head_idx=0):
    """Zero-ablate one attention head.

    Sets every entry of head ``head_idx`` in ``z`` to 0 (in place) and returns
    ``z``. ``hook`` is accepted because the hook interface passes it, but it
    is not used.
    """
    z[:, :, head_idx, :] = 0
    return z

# Start from a clean hook state so that re-running this cell cannot stack
# duplicate ablation hooks on top of ones added by an earlier run.
experimental_model.reset_hooks()
for layer, head in heads_to_ablate:
    ablate_head_hook = partial(ablate_top_head_hook, head_idx=head)
    experimental_model.blocks[layer].attn.hook_z.add_hook(ablate_head_hook)

# .long() matches how the tokens are passed to the un-ablated run.
ablated_logits, ablated_cache = experimental_model.run_with_cache(ioi_dataset.toks.long())
print(f"Original IOI Metric: {_logits_to_mean_logit_diff(orig_logits, ioi_dataset).item():.4f}")
print(f"Post ablation IOI Metric: {_logits_to_mean_logit_diff(ablated_logits, ioi_dataset).item()}")

#experimental_model.reset_hooks()

Heads to ablate: [(8, 2), (8, 10)]
Original IOI Metric: 4.1341
Post ablation IOI Metric: 4.028079032897949


In [176]:
# Same per-head decomposition as for the un-ablated run, but read from the
# cache recorded while the selected heads were zeroed out.
per_head_ablated_residual, labels = ablated_cache.stack_head_results(
    layer=-1, pos_slice=-1, return_labels=True
)

# Project onto the logit-difference directions and lay the flat per-head
# vector out as a (layer, head) grid.
per_head_ablated_logit_diffs = residual_stack_to_logit_diff(
    per_head_ablated_residual, ablated_cache, logit_diff_directions
).reshape(experimental_model.cfg.n_layers, experimental_model.cfg.n_heads)

In [179]:
# (layer, head) pairs whose contribution is blanked out before plotting.
exclusions = [(6, 6), (7, 9), (8, 9), (9, 1), (9, 5)]

# Change in each head's contribution caused by the ablation. It is computed
# before the excluded heads are zeroed below.
delta = per_head_ablated_logit_diffs - per_head_logit_diffs

# Zero the excluded heads in place with one indexed assignment.
excluded_layers = [layer_idx for layer_idx, _ in exclusions]
excluded_heads = [head_idx for _, head_idx in exclusions]
per_head_ablated_logit_diffs[excluded_layers, excluded_heads] = 0

# Express every head's contribution as a fraction of the un-ablated model's
# mean logit difference.
baseline_logit_diff = _logits_to_mean_logit_diff(orig_logits, ioi_dataset).item()
plot_attention_heads(
    per_head_ablated_logit_diffs / baseline_logit_diff,
    title="Logit Diff Contribution From Backup Heads",
    top_n=15,
    range_x=[0, 0.5],
)

Total logit diff contribution above threshold: 0.05
