In [1]:
import os; os.environ["ACCELERATE_DISABLE_RICH"] = "1"
import sys
from pathlib import Path
import torch as t
from torch import Tensor
import numpy as np
import pandas as pd 
import einops
from tqdm.notebook import tqdm
import plotly.express as px
import webbrowser
import re
import itertools
from jaxtyping import Float, Int, Bool
from typing import List, Optional, Callable, Tuple, Dict, Literal, Set, Union
from functools import partial
from IPython.display import display, HTML
import circuitsvis as cv
from pathlib import Path
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, ActivationCache
from transformer_lens.components import Embed, Unembed, LayerNorm, MLP
# !git clone https://github.com/callummcdougall/path_patching.git

sys.path.append('./path_patching/')  # replace <repo-name> with the name of the cloned repository
from path_patching import Node, IterNode, path_patch, act_patch

import fns_alejo_exploration
import importlib

importlib.reload(fns_alejo_exploration)
from ioi_dataset import NAMES, IOIDataset
from fns_alejo_exploration import logits_to_ave_logit_diff, topk_predictions_from_prompt, \
    visualize_selected_heads, visualize_selected_head
t.set_grad_enabled(False)

from plotly_utils import imshow, line, scatter, bar
import part3_indirect_object_identification.tests as tests

device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

In [2]:
t.cuda.empty_cache()

# Set up

In [3]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)
model.set_use_attn_result(True)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
# Prompt templates; the `{}` slot is filled with a name (names carry a leading space).
prompt_format = [
    "When John and Mary went to the shops,{} gave the bag to",
    "When Tom and James went to the park,{} gave the ball to",
    "When Dan and Sid went to the shops,{} gave an apple to",
    "After Martin and Amy went to the park,{} gave a drink to",
]
# The two names that already appear in each template, in order of appearance.
name_pairs = [
    (" John", " Mary"),
    (" Tom", " James"),
    (" Dan", " Sid"),
    (" Martin", " Amy"),
]

# One extra name per template that does not appear in the template itself.
name_c = [" David", " Charlie", " Sam", " Alex"]


# Define 8 prompts, in 4 groups of 2 (with adjacent prompts having answers swapped)
# names[::-1] means the slot is filled with the second name first, then the first name.
prompts = [
    prompt.format(name)
    for (prompt, names) in zip(prompt_format, name_pairs) for name in names[::-1]
]

# Each template filled with its extra name, repeated twice so the list lines up
# index-for-index with `prompts`.
abc_prompts = [
    prompt.format(name) for (prompt, name) in zip(prompt_format, name_c) for _ in range(2)
]


# Define the answers for each prompt, in the form (correct, incorrect)
# For each pair this yields (first, second) and then (second, first), matching `prompts`.
answers = [names[::i] for names in name_pairs for i in (1, -1)]
# Define the answer tokens (same shape as the answers)
answer_tokens = t.concat([
    model.to_tokens(names, prepend_bos=False).T for names in answers
])


clean_tokens = model.to_tokens(prompts, prepend_bos=True).to(device)
# Swap every adjacent pair (0<->1, 2<->3, ...): the flipped batch is the clean batch
# with each prompt replaced by its partner, whose answers are swapped.
flipped_indices = [i+1 if i % 2 == 0 else i-1 for i in range(len(clean_tokens))]
flipped_tokens = clean_tokens[flipped_indices]
abc_tokens = model.to_tokens(abc_prompts, prepend_bos=True).to(device)

# Logits and full activation caches for the three input variants.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
flipped_logits, flipped_cache = model.run_with_cache(flipped_tokens)
abc_logits, abc_cache = model.run_with_cache(abc_tokens)

# All three are scored against the *clean* answer tokens. These values are later captured
# as default arguments by the ioi_metric_* functions.
clean_logit_diff = logits_to_ave_logit_diff(clean_logits, answer_tokens)
flipped_logit_diff = logits_to_ave_logit_diff(flipped_logits, answer_tokens)
abc_logit_diff = logits_to_ave_logit_diff(abc_logits, answer_tokens)

# Note: "\n" and "Flipped string 0:" are adjacent string literals, so they are
# concatenated into a single argument to print.
print(
    "Clean string 0:    ", model.to_string(clean_tokens[0]), "\n"
    "Flipped string 0:", model.to_string(flipped_tokens[0])
)
print(f"Clean logit diff: {clean_logit_diff:.4f}")
print(f"Flipped logit diff: {flipped_logit_diff:.4f}")


def print_topk_predictions(logits, k=5):
    """Print the k most probable tokens at the final position of `logits`.

    `logits[-1]` is softmaxed over its last dimension, so `logits` should be the
    per-position logits of a single prompt (callers in this notebook pass e.g.
    `clean_logits[0]`). Entries are printed on one line, tab-separated, as
    `prob: 'token'`, using the global `model` to decode token ids.
    """
    final_probs = t.softmax(logits[-1], dim=-1)
    best = t.topk(final_probs, k=k)
    entries = []
    for prob, token_id in zip(best.values, best.indices):
        entries.append(f"{prob.item():.3f}: '{model.to_string(token_id)}'")
    print("\t".join(entries))

def print_prompt_completions(prompts: List[str], k=5):
    """Print each prompt followed by the model's k most probable next tokens.

    Every prompt is tokenized on its own (with BOS prepended) and run through the
    global `model`. One line is printed per prompt, in the form
    `prompt<TAB><TAB>token(prob)<TAB>token(prob)...`.
    """
    for text in prompts:
        token_ids = model.to_tokens(text, prepend_bos=True).to(device)
        # Distribution over the vocabulary at the last position of the only batch element.
        next_token_probs = t.softmax(model(token_ids)[0, -1], dim=-1)
        best = t.topk(next_token_probs, k=k)
        completions = [
            f"{model.to_string(token_id)}({prob.item():.3f})"
            for prob, token_id in zip(best.values, best.indices)
        ]
        print(text + "\t\t" + "\t".join(completions))

def ioi_metric_denoising(
    logits: Float[Tensor, "batch seq d_vocab"],
    answer_tokens: Int[Tensor, "batch 2"] = answer_tokens,
    flipped_logit_diff: float = flipped_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
) -> float:
    '''
    Linear function of logit diff for denoising experiments (patching clean activations
    into a flipped run), calibrated so that it equals 0 when performance is the same as
    on the flipped input, and 1 when performance is the same as on the clean input.

    Args:
        logits: logits of the patched forward pass.
        answer_tokens: token ids of the (correct, incorrect) answers for each prompt.
        flipped_logit_diff: average logit diff of the unpatched flipped run.
        clean_logit_diff: average logit diff of the unpatched clean run.

    Returns:
        The normalised metric as a Python float (the result of `.item()`).

    Note: the three defaults are evaluated once, when this function is defined. If the
    module-level baselines are recomputed later, re-run this definition (or pass the new
    values explicitly).
    '''
    patched_logit_diff = logits_to_ave_logit_diff(logits, answer_tokens)
    # Distance between the two baselines; the metric is the fraction of it that was recovered.
    baseline_gap = clean_logit_diff - flipped_logit_diff
    return ((patched_logit_diff - flipped_logit_diff) / baseline_gap).item()

def ioi_metric_noising(
    logits: Float[Tensor, "batch seq d_vocab"],
    answer_tokens: Int[Tensor, "batch 2"] = answer_tokens,
    flipped_logit_diff: float = flipped_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
) -> float:
    '''
    Linear function of logit diff for noising experiments (patching corrupted activations
    into a clean run), calibrated so that it equals 0 when performance is the same as on
    the clean input, and -1 when performance is the same as on the flipped input.

    Args:
        logits: logits of the patched forward pass.
        answer_tokens: token ids of the (correct, incorrect) answers for each prompt.
        flipped_logit_diff: average logit diff of the corrupted baseline run. Callers may
            pass another baseline here (e.g. the ABC logit diff).
        clean_logit_diff: average logit diff of the unpatched clean run.

    Returns:
        The normalised metric as a Python float (the result of `.item()`).

    Note: the three defaults are evaluated once, when this function is defined. If the
    module-level baselines are recomputed later, re-run this definition (or pass the new
    values explicitly).
    '''
    patched_logit_diff = logits_to_ave_logit_diff(logits, answer_tokens)
    # Same denominator as the denoising metric, but the numerator is measured from the
    # clean baseline, so an unchanged clean run scores exactly 0.
    baseline_gap = clean_logit_diff - flipped_logit_diff
    return ((patched_logit_diff - clean_logit_diff) / baseline_gap).item()


Clean string 0:     <|endoftext|>When John and Mary went to the shops, Mary gave the bag to 
Flipped string 0: <|endoftext|>When John and Mary went to the shops, John gave the bag to
Clean logit diff: 3.5519
Flipped logit diff: -3.5519


In [5]:
# Attention heads grouped by role name. Each entry is a (layer, head) pair; the cells
# below unpack them as `for layer, head in ...`.
# NOTE(review): the groupings presumably follow the published IOI circuit for GPT-2 small —
# confirm against the source they were copied from.
CIRCUIT = {
    "name mover": [(9, 9), (10, 0), (9, 6)],
    "backup name mover": [(10, 10), (10, 6), (10, 2), (10, 1), (11, 2), (9, 7), (9, 0), (11, 9)],
    "negative name mover": [(10, 7), (11, 10)],
    "s2 inhibition": [(7, 3), (7, 9), (8, 6), (8, 10)],
    "induction": [(5, 5), (5, 8), (5, 9), (6, 9)],
    "duplicate token": [(0, 1), (0, 10), (3, 0)],
    "previous token": [(2, 2), (4, 11)],
}

# Derived groups, built by list concatenation of the base groups above.
CIRCUIT["positive name mover"] = CIRCUIT["name mover"] + CIRCUIT["backup name mover"]
CIRCUIT["all name mover"] = CIRCUIT["positive name mover"] + CIRCUIT["negative name mover"]
CIRCUIT["duplicate detector"] = CIRCUIT["duplicate token"] + CIRCUIT["induction"]
# Note: "all" does not include the "previous token" heads.
CIRCUIT["all"] = CIRCUIT["duplicate detector"] + CIRCUIT["s2 inhibition"] + CIRCUIT["all name mover"] 

# IOI Circuit in the Wild

In [6]:
from datasets import load_dataset

# openwebtext_string = load_dataset('stas/openwebtext-10k')['train']['text']
# openwebtext_tokens = model.to_tokens(openwebtext_string, prepend_bos=True)

conll_data = load_dataset('conll2003')
wiki_data = load_dataset('wikitext', 'wikitext-103-v1')
# glue_data = load_dataset('glue', 'wnli')

Found cached dataset conll2003 (/root/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset wikitext (/root/.cache/huggingface/datasets/wikitext/wikitext-103-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


  0%|          | 0/3 [00:00<?, ?it/s]

In [7]:
wiki_string = wiki_data['train']['text']
wiki_string = [sentence for sentence in wiki_string if len(sentence) > 150 and len(sentence) < 400] 
wiki_string = sorted(wiki_string, key=len)

In [8]:
from torch.utils.data import DataLoader, Dataset

class StringDataset(Dataset):
    """A list of strings exposed as a torch `Dataset`, with tensor-based indexing.

    Only the first `max_len` strings of `data` are kept.

    Supported indices for `dataset[idx]`:
      * int or slice        -> forwarded to the underlying list (a str, or a list of str)
      * 0-d integer tensor  -> a single str
      * 1-D integer tensor  -> list of the selected strings, in index order
      * 1-D boolean tensor  -> list of the strings where the mask is True
    """

    def __init__(self, data: List[str], max_len: int = 10_000):
        self.data = data[:max_len]

    def __getitem__(self, idx):
        if not t.is_tensor(idx):
            return self.data[idx]

        if idx.dtype == t.bool:
            # Turn the mask into the positions of its True entries.
            idx = idx.nonzero(as_tuple=False).flatten()
        elif idx.ndim == 0:
            # A scalar tensor (e.g. one element of a topk result) cannot be iterated,
            # so it is treated like a plain int and yields a single string.
            return self.data[int(idx)]

        # .tolist() gives plain Python ints regardless of the tensor's device.
        return [self.data[i] for i in idx.flatten().tolist()]

    def __len__(self):
        return len(self.data)


dataset = StringDataset(wiki_string)
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)


In [9]:
def select_prompts_by_pattern(dataloader: DataLoader,
                              top_k=20,
                              mode: Literal['non-bos-sum', 'entropy', 'high-attn'] = 'entropy',
                              heads: List[Tuple[int, int]] = CIRCUIT['s2 inhibition'],
                              template_prompt: Optional[str] = None,
                              attn_threshold: float = 0.7,
                              min_heads: int = 2,
    ) -> Tuple[Float[Tensor, 'top_k'], Int[Tensor, 'top_k']]:
    """Score every prompt by the attention patterns of `heads` and return the top scorers.

    Each batch from `dataloader` is tokenized with BOS prepended and run through the global
    `model`, caching only the attention patterns. Attention to key position 0 (the BOS
    token) is dropped before scoring.

    Args:
        dataloader: yields batches of prompt strings.
        top_k: number of prompts to return; capped at the number of scored prompts.
        mode: scoring rule.
            'non-bos-sum': mean non-BOS attention over heads, query and key positions.
            'entropy': sum of p*log(p) over heads and positions (the negative entropy, so
                a higher score means more concentrated attention).
            'high-attn': 1.0 if at least `min_heads` heads put more than `attn_threshold`
                attention on the same (query, key) pair, else 0.0.
        heads: (layer, head) pairs to score. The default is captured at definition time.
        template_prompt: currently unused.
        attn_threshold: attention threshold for 'high-attn'.
        min_heads: number of agreeing heads required for 'high-attn'.

    Returns:
        (top_scores, top_idx): the best scores and the positions of the corresponding
        prompts in dataloader order.

    Raises:
        ValueError: if `mode` is not one of the supported values.
    """
    all_scores = []

    for batch in tqdm(dataloader):
        tokens = model.to_tokens(batch, prepend_bos=True)
        _, cache = model.run_with_cache(tokens, names_filter=lambda n: 'pattern' in n)
        attn_patterns = t.stack([cache['pattern', layer][:, head] for layer, head in heads]) # (head, batch, pos, pos)
        non_bos_attn = attn_patterns[..., 1:]

        if mode == 'non-bos-sum':
            score = non_bos_attn.mean(dim=[0, 2, 3])

        elif mode == 'entropy':
            # xlogy(p, p) is p*log(p) with the convention 0*log(0) = 0. A plain
            # p * log(p) is NaN wherever attention is exactly 0 (e.g. masked future
            # positions), which would make every score NaN.
            score = t.xlogy(non_bos_attn, non_bos_attn).sum(dim=[0, 2, 3])

        elif mode == 'high-attn':
            n_heads, n_prompts = non_bos_attn.shape[:2]
            # Flatten (query, key) into one axis so heads are compared pair by pair.
            per_pair_attn = non_bos_attn.reshape(n_heads, n_prompts, -1)
            heads_above = (per_pair_attn > attn_threshold).int().sum(0)
            score = t.any(heads_above >= min_heads, dim=-1).float()

        else:
            raise ValueError(f'Unknown mode: {mode}')

        all_scores.append(score)

    all_scores = t.cat(all_scores)
    # topk raises if k exceeds the number of elements, so cap it for small datasets.
    top_scores, top_idx = t.topk(all_scores, k=min(top_k, len(all_scores)))
    return top_scores, top_idx

In [10]:
from functools import reduce

def select_prompts_by_ioi_similarity(dataset: Dataset,
                                     circuits: List[List[Tuple[int, int]]],
                                     prob_threshold: float = 0.5,
                                     min_heads: int = 2,
                                     batch_size: int = 64,
                                     ):
    """Return the prompts on which every given head group shows concentrated attention.

    A prompt matches a group when at least `min_heads` of the group's heads put more than
    `prob_threshold` attention on the same (query, key) pair, ignoring attention to key
    position 0 (the BOS token). A prompt is kept only if it matches *all* groups.

    Args:
        dataset: dataset of prompt strings; it must support indexing with a boolean mask
            tensor, since the result is `dataset[mask]`.
        circuits: list of head groups, each a list of (layer, head) pairs. With an empty
            list there is nothing to fail, so every prompt is returned.
        prob_threshold: attention probability a head must exceed.
        min_heads: number of heads in a group that must agree on the same pair.
        batch_size: number of prompts per forward pass.

    Returns:
        Whatever `dataset[mask]` returns for the boolean match mask.
    """
    all_matches = []

    for batch in tqdm(DataLoader(dataset, batch_size=batch_size, shuffle=False)):
        tokens = model.to_tokens(batch, prepend_bos=True)
        _, cache = model.run_with_cache(tokens, names_filter=lambda n: 'pattern' in n)

        # Start from "everything matches" and AND in one group at a time.
        match = t.ones(len(batch), dtype=t.bool)

        for circuit in circuits:
            attn_patterns = t.stack([cache['pattern', layer][:, head] for layer, head in circuit]) # (head, batch, pos, pos)
            # Drop the BOS key column, then flatten (query, key) into a single axis.
            non_bos_attn = attn_patterns[..., 1:].reshape(*attn_patterns.shape[:2], -1)
            heads_above = (non_bos_attn > prob_threshold).int().sum(0)
            circuit_match = t.any(heads_above >= min_heads, dim=-1)
            # The mask is kept on the CPU; the cache tensors may live on another device.
            match = match & circuit_match.cpu()

        all_matches.append(match)

    all_matches = t.cat(all_matches)
    return dataset[all_matches]

In [11]:
# Pick up to 60 prompts by the 'high-attn' score and look them up in the dataset.
# NOTE(review): the variables are named `active_name_mover_*`, but the heads scored here are
# CIRCUIT['s2 inhibition'] — confirm which head group was intended.
active_name_mover_idx = select_prompts_by_pattern(dataloader, mode='high-attn', heads=CIRCUIT['s2 inhibition'], top_k=60)[1]
active_name_mover_prompts = dataset[active_name_mover_idx]

  0%|          | 0/157 [00:00<?, ?it/s]

In [12]:
print('\n'.join(active_name_mover_prompts[:60]))


 Due to popularity of this song in Italy , the song was chosen by the Lega Calcio as a theme soundtrack for the Serie A and Serie B 2007 – 08 season . 

 Osborne Reef is an artificial reef off the coast of Fort Lauderdale , Florida , constructed of concrete jacks in a 50 feet ( 15 m ) diameter circle . 

 At Lord 's , 25 , 26 , 27 August . The Australians ( 610 / 5 declared ) defeated the Gentlemen of England ( 245 and 284 ) by an innings and 81 runs . 

 Loaded weight : 30 @,@ 384 lb ( 13 @,@ 782 kg ) CAS mission : 47 @,@ 094 lb ( 21 @,@ 361 kg ) Anti @-@ armor mission : 42 @,@ 071 lb ( 19 @,@ 083 kg ) 

 However , there have been a number of five @-@ by @-@ fours ( at least 4 points , 4 rebounds , 4 assists , 4 steals , and 4 blocks ) in the playoffs . 

 He is the only player in Duke history to record at least 2 @,@ 000 points , 500 rebounds , 400 assists , 250 3 @-@ pointers , and 200 steals in a career . 

 Dark matter and dark energy are the current leading topics in astronomy ,

In [14]:
def get_tokens_and_answers(prompts: List[str], answers: List[Tuple[str, str]]) -> Tuple[Int[Tensor, 'batch seq'], Int[Tensor, 'batch 2'], Int[Tensor, 'batch']]:
    """Tokenize prompts and their (correct, incorrect) answers for logit-diff evaluation.

    Args:
        prompts: prompt strings; tokenized together with BOS prepended.
        answers: one (correct, incorrect) pair of answer strings per prompt. Each answer
            must tokenize to exactly one token.

    Returns:
        tokens: token ids of the prompts.
        answer_tokens: token ids of the answers, one (correct, incorrect) row per prompt.
        answer_pos: for each prompt, the index of its last token (BOS included), i.e.
            the position whose logits predict the answer.
        All three tensors are moved to the global `device`.

    Raises:
        AssertionError: if the number of answer pairs differs from the number of prompts,
            or if an answer does not tokenize to a single token.
    """
    assert len(answers) == len(prompts), (
        f"Expected one answer pair per prompt, got {len(answers)} pairs for {len(prompts)} prompts"
    )
    tokens = model.to_tokens(prompts, prepend_bos=True)

    pair_tokens = []
    for answer in answers:
        pair = model.to_tokens(answer, prepend_bos=False)  # (2, n_tokens)
        # Checked per pair so the failure names the offending answers. A single check on
        # the concatenated shape cannot tell which pair was multi-token.
        assert pair.shape == (2, 1), f"Each answer must be a single token, got shape {tuple(pair.shape)} for {answer!r}"
        pair_tokens.append(pair)
    answer_tokens = t.cat(pair_tokens, dim=1).T # (batch, 2)
    assert answer_tokens.shape == (len(prompts), 2)

    # to_str_tokens is called per prompt, so the length is the unpadded one.
    answer_pos = t.tensor([len(model.to_str_tokens(prompt)) - 1  for prompt in prompts])
    return tokens.to(device), answer_tokens.to(device), answer_pos.to(device)

def my_logit_diff(logits: Float[Tensor, 'batch seq vocab'], answer_tokens: Int[Tensor, 'batch 2'], answer_pos: Int[Tensor, 'batch']):
    """Average (correct - incorrect) answer logit, read at each prompt's own answer position.

    For prompt `b`, the two logits are taken from `logits[b, answer_pos[b]]` at the token
    ids in `answer_tokens[b]`. Column 0 of `answer_tokens` is the correct answer and
    column 1 the incorrect one. The mean difference over the batch is returned as a float.
    """
    n_prompts = len(answer_pos)
    # Column vectors broadcast against the (batch, 2) answer ids to pick two logits per prompt.
    rows = t.arange(n_prompts).unsqueeze(-1)
    positions = answer_pos.unsqueeze(-1)
    picked = logits[rows, positions, answer_tokens]
    assert picked.shape == (logits.shape[0], 2)
    correct, incorrect = picked[..., 0], picked[..., 1]
    return (correct - incorrect).mean().item()

# soft_match_prompts = [
#     "Compared to other metals, mercury is a poor conductor of heat, but a fair",
#     "Osborne Reef is an artificial",
#     "Local bus services operate in Highlbury (three routes), Hyde Park (five",
#     "Nathaniel Hart was the second son of Thomas",
#     "Jack Markov was the third son of Michael"
# ]
# soft_match_answers = [' conductor', ' reef', ' routes', ' Hart', ' Mark']
# soft_match_tokens, soft_match_ans, soft_match_pos = get_tokens_and_answers(soft_match_prompts, soft_match_answers)
# soft_match_logits, soft_match_cache = model.run_with_cache(soft_match_tokens)

# Hand-written prompts (not IOI templates) whose next token repeats or parallels an earlier
# word, each with a (correct, incorrect) answer pair.
clean_prompts = [
    "Jack Hammer completed 230 passes, 350 rebounds, and 120",
    "Dark matter and dark",
    "John Conduitt , Newton 's assistant at the Royal Mint and husband of Newton 's",
    "The Wall of Remembrance has the names of the 400 casualties ( killed and",
    "Even though he started as a Series A player, he soon moved to play at Series",
]
# NOTE(review): 'pass' is the only answer without a leading space, and the prompt says
# "passes" — confirm this is the intended comparison token.
clean_answers = [(' assists', 'pass'), (' energy', ' matter'), (' son', ' assistant'),
                 (' wounded', ' killed'), (' B', ' A')
                 ]

# WARNING: this rebinds `clean_tokens`, `clean_logits` and `clean_cache`, which the setup
# cell defined for the 8 IOI prompts. Cells that run after this one and expect the IOI
# versions (e.g. together with `flipped_tokens` or `answer_tokens`) will see these
# 5-prompt tensors instead.
clean_tokens, clean_ans, clean_pos = get_tokens_and_answers(clean_prompts, clean_answers)
clean_logits, clean_cache = model.run_with_cache(clean_tokens)

# Displayed as the cell output: mean (correct - incorrect) logit over the five prompts.
my_logit_diff(clean_logits, clean_ans, clean_pos)

# acronym_prompts = [
#     "The World Spectacle Circus (",
#     "The Brazilian Balllet Association (",
#     "The International Society of Chemistry (",    
# ]

# print_prompt_completions(soft_match_prompts)
# print_prompt_completions(clean_prompts)

6.059764862060547

In [15]:
from copy import deepcopy

# Build a copy of `clean_cache` in which every attention head attends only to key
# position 0 (the BOS token, since the inputs were tokenized with prepend_bos=True).
# deepcopy keeps the original cache untouched; the copy is then edited in place.
bos_cache = deepcopy(clean_cache)
for layer in range(clean_cache.model.cfg.n_layers):
    # Attention pattern: all probability mass on key position 0.
    bos_cache['pattern', layer].fill_(0)
    bos_cache['pattern', layer][..., 0] = 1
    # Recompute each head's output `z` as pattern @ v so that it is consistent with the
    # edited pattern. The slice assignment overwrites every element of z.
    bos_cache['z', layer].fill_(0)
    bos_cache['z', layer][:] = einops.einsum(bos_cache['pattern', layer],
                                          bos_cache['v', layer],
                                          'b h p_q p_k, b p_k h d_h -> b p_q h d_h')
# Note: the loop variable `layer` stays defined at notebook level after this cell.

In [23]:
# results = path_patch(
#     model=model,
#     orig_input=clean_tokens,
#     new_input='this arg is not used',
#     new_cache=bos_cache,
#     sender_nodes=IterNode('z'),
#     receiver_nodes=Node('resid_post', 11),
#     patching_metric=partial(my_logit_diff, answer_tokens=clean_ans, answer_pos=clean_pos),
# )

results = act_patch(
    model=model,
    orig_input=clean_tokens,
    new_input='this arg is not used',
    new_cache=bos_cache,
    patching_nodes=IterNode('z'),
    patching_metric=partial(my_logit_diff, answer_tokens=clean_ans, answer_pos=clean_pos),
)

  0%|          | 0/144 [00:00<?, ?it/s]

In [24]:
imshow(results['z'] - results['z'].median())

In [18]:
imshow(results['z'] - results['z'].median())

In [49]:
visualize_selected_heads(model, active_name_mover_prompts, CIRCUIT['s2 inhibition'] + CIRCUIT['name mover'], max_seq_len=90, idx=2)

In [33]:
ioi_like_prompts = select_prompts_by_ioi_similarity(dataset, [CIRCUIT['name mover'], CIRCUIT['s2 inhibition']])

  0%|          | 0/157 [00:00<?, ?it/s]

In [51]:
print('\n'.join(ioi_like_prompts[:30]))

 A sequence ( x0 , x1 , x2 , … ) has a limit x if the distance | x − xn | becomes arbitrarily small as n increases . The statement that 0 @.@ 999 … = 

 US 56 / US 283 in Dodge City . US 50 / US 56 travels concurrently to Kinsley . US 50 / US 283 travels concurrently to west @-@ southwest of Wright . 

 Osborne Reef is an artificial reef off the coast of Fort Lauderdale , Florida , constructed of concrete jacks in a 50 feet ( 15 m ) diameter circle . 

 At Lord 's , 25 , 26 , 27 August . The Australians ( 610 / 5 declared ) defeated the Gentlemen of England ( 245 and 284 ) by an innings and 81 runs . 

 " Traitement de la <unk> du col de l ’ uterus par <unk> <unk> " ( in French ) . Paris : Boletín de la Société de <unk> et d ’ <unk> de París . 1923 . 

 Loaded weight : 30 @,@ 384 lb ( 13 @,@ 782 kg ) CAS mission : 47 @,@ 094 lb ( 21 @,@ 361 kg ) Anti @-@ armor mission : 42 @,@ 071 lb ( 19 @,@ 083 kg ) 

 The World Snowboard Federation ( WSF ) has a more elaborate classification system 

In [69]:
visualize_selected_heads(model, ioi_like_prompts, CIRCUIT['name mover'] + CIRCUIT['s2 inhibition'], max_seq_len=90, idx=24)

In [59]:
prompt = 'In 1920, Jack married Linda, and in 1923 he graduated from college'
visualize_selected_heads(model, [prompt], CIRCUIT['name mover'] + CIRCUIT['s2 inhibition'])

In [60]:
print_topk_predictions(model('In 1920, Jack married Linda, and in')[0], k=10)

0.079: ' 1929'	0.058: ' 1927'	0.047: ' 1926'	0.046: ' 1928'	0.040: ' 1933'	0.040: ' 1930'	0.039: ' 1925'	0.038: ' 1923'	0.037: ' 1921'	0.036: ' 1931'


In [12]:
# Score the dataset with the name-mover heads and collect the best-scoring prompts.
top_max_attn_scores, top_max_attn_idx = select_prompts_by_pattern(dataloader, mode='high-attn', heads=CIRCUIT['name mover'])
# Convert the indices to plain ints before indexing. `top_max_attn_idx[i]` is a 0-d tensor,
# and StringDataset.__getitem__ iterates over tensor indices, which raises
# "iteration over a 0-d tensor" for scalars.
top_prompts = [dataset[i] for i in top_max_attn_idx.tolist()]
top_prompts

  0%|          | 0/157 [00:00<?, ?it/s]

[' Tourist and Heritage Railways in Victoria are governed by provisions in the Tourist and Heritage Railways Act 2010 which commenced on 1 October 2011 . \n',
 ' Le banquet céleste ( " The heavenly banquet " ) , organ ( 1928 , a recomposition of a section from his unpublished orchestral piece Le banquet <unk> ) \n',
 ' The " water " engo are <unk> ( " brooding " , but a pun on naga @-@ ame " long rain " ) , <unk> ( " a river of tears " ) and <unk> ( " is soaked " ) . \n',
 ' Missouri ex rel . Gaines v. Canada , 305 U.S. 337 ( 1938 ) <unk> that provide a school to white students must provide in @-@ state education to blacks \n',
 ' “ Tōkyō mutan ” ( <unk> , Tokyo dreams ) . [ S ] Ginza Nikon Salon ( Ginza , Tokyo ) , September 2007 ; Osaka Nikon Salon ( Osaka ) , October 2007 . \n',
 ' On second evening I asked General Burston as an old friend of Rowell to endeavour to induce a proper frame of mind but Burston met with no success . \n',
 ' US 56 / US 283 in Dodge City . US 50 / US 56 tr

In [13]:
visualize_selected_heads(model, top_prompts, CIRCUIT['name mover'], max_seq_len=90, idx=12)

In [173]:
CIRCUIT['backup name mover']

[(10, 10), (10, 6), (10, 2), (10, 1), (11, 2), (9, 7), (9, 0), (11, 9)]

In [31]:
visualize_selected_heads(model, top_prompts, CIRCUIT['s2 inhibition'], max_seq_len=60, idx=18)

In [4]:
visualize_selected_heads(model, top_prompts, CIRCUIT['all'], max_seq_len=90, idx=3)

NameError: name 'model' is not defined

In [None]:
print(top_max_attn_scores)
visualize_selected_head(model, [dataset[top_max_attn_idx[0]]], CIRCUIT['s2 inhibition'][0], max_seq_len=90)

In [21]:
# Sanity check: recompute the 'high-attn' score of select_prompts_by_pattern by hand on a
# single IOI-style prompt, using the S2-inhibition heads.
heads = CIRCUIT['s2 inhibition']
prompt = 'When John and Mary went to the shop, John gave a bottle to'
_, cache = model.run_with_cache(prompt, names_filter=lambda n: 'pattern' in n)
attn_patterns = t.stack([cache['pattern', layer][:, head] for layer, head in heads])

n_heads, batch = attn_patterns.shape[:2]
# Drop attention to key position 0 (BOS) and flatten (query, key) into one axis.
non_bos_max_attn = attn_patterns[..., 1:].reshape(n_heads, batch, -1)
score = t.any((non_bos_max_attn > 0.7).int().sum(0) >= 2, dim=-1) # At least 2 heads have more than 0.7 attn to the same token
score = score.float()
score

tensor([1.], device='cuda:0')

In [143]:
prompt

' Elizabeth I of England : The aging Queen of England . Already close to death , she is killed by a poisonous gas device constructed by Otto Von Doom . \n'

In [78]:
# NOTE(review): this cell does not run on a fresh kernel in notebook order — `conll_string`
# is only defined in the next cell, and `top_sum_idx` is not defined anywhere in this
# notebook. Confirm where `top_sum_idx` was meant to come from.
visualize_selected_heads(model, [conll_string[top_sum_idx[0]]], CIRCUIT['s2 inhibition'])

In [71]:
# conll_string['train']['tokens'][:50]

def cat_string_conll(dataset: List[List[str]]):
    """Join tokenized sentences back into plain strings.

    Words are separated by a single space, except that the tokens
    '.', ',', '?', '!', ':', ';', "'s" and "n't" are attached directly to the preceding
    word. The first token of a sentence is always emitted as-is.

    Args:
        dataset: list of sentences, each a list of word tokens.

    Returns:
        One string per input sentence, in the same order. An empty token list yields an
        empty string, so the output stays aligned with the input.
    """
    attach_to_previous = {'.', ',', '?', '!', ':', ';', "'s", "n't"}

    out = []
    for sentence in dataset:
        pieces = []
        for position, word in enumerate(sentence):
            # No separator before the first token or before attaching tokens.
            if position > 0 and word not in attach_to_previous:
                pieces.append(' ')
            pieces.append(word)
        out.append(''.join(pieces))
    return out

# Detokenize the CoNLL training sentences into plain strings.
conll_string = cat_string_conll(conll_data['train']['tokens'])
# Sort by character length and drop the 7000 shortest strings.
conll_string = sorted(conll_string, key=len)[7000:] # remove short sentences
print(conll_string[0])

Wall Street ponders Rubin's role if Clinton wins.


# Lazy patching

Mercury is a heavy , silvery @-@ white liquid metal . Compared to other metals , it is a poor conductor of heat , but a fair conductor of electricity .
Local bus services are operated by Webberbus ( seven routes ) , First Somerset & Avon ( three routes ) , and Quantock Motor Services ( two routes ) . 
Osborne Reef is an artificial reef off the coast of Fort Lauderdale , Florida , constructed of concrete jacks in a 50 feet ( 15 m ) diameter circle .
Nathaniel Hart was one of seven children , the second son of Colonel Thomas Hart , a veteran of the Revolutionary War , and his wife Susanna ( Gray ) Hart .


Compared to other metals, mercury is a poor conductor of heat, but a fair		 conductor(0.333)	 amount(0.237)	 bit(0.083)	 one(0.060)	 number(0.020)
Local bus services operate in Highlbury (three routes), Hyde Park (five		 routes(0.527)	),(0.298)	)(0.074)	,(0.027)	 and(0.010)
Nathaniel Hart, the second son of Colonel Thomas		The(0.067)	A(0.035)	This(0.019)	In(0.016)	I(0.014)
Nathaniel Hart was one of seven children, the second son of Colonel Thomas		 Hart(0.705)	 and(0.047)	 H(0.021)	,(0.012)	 ((0.006)


In [97]:

print(bos_cache['pattern', layer].shape)
bos_cache['v', layer].shape

torch.Size([8, 12, 15, 15])


torch.Size([8, 15, 12, 64])

In [100]:
from copy import deepcopy

# Same construction as the earlier `bos_cache` cell: copy `clean_cache` and make every
# head attend only to key position 0 (BOS), then recompute `z` to match.
# NOTE(review): `clean_cache` is rebound by the "in the wild" prompts cell earlier in the
# notebook, while the patching cells below use the IOI prompts — confirm that
# `clean_cache` is the IOI cache when this cell runs (the shapes printed above,
# batch 8 / 15 positions, suggest it was).
bos_cache = deepcopy(clean_cache)
for layer in range(clean_cache.model.cfg.n_layers):
    # Attention pattern: all probability mass on key position 0.
    bos_cache['pattern', layer].fill_(0)
    bos_cache['pattern', layer][..., 0] = 1
    # Head output z = pattern @ v; the slice assignment overwrites every element.
    bos_cache['z', layer].fill_(0)
    bos_cache['z', layer][:] = einops.einsum(bos_cache['pattern', layer],
                                          bos_cache['v', layer],
                                          'b h p_q p_k, b p_k h d_h -> b p_q h d_h')

In [103]:
results_flipped = act_patch(
    model=model,
    orig_input=prompts,
    new_cache=flipped_cache,
    patching_nodes=IterNode('z'),
    patching_metric=ioi_metric_noising,
    verbose=True,
)

results_bos = act_patch(
    model=model,
    orig_input=prompts,
    new_cache=bos_cache,
    patching_nodes=IterNode('z'),
    patching_metric=ioi_metric_noising,
    verbose=True,
)

imshow(t.stack([results_flipped['z'], results_bos['z']]), facet_col=0,
       facet_labels=['Patched hook_z with flipped S2', 'Patched pattern with only attn to BOS'])

  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=12, head=12)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=12, head=12)


In [111]:
CIRCUIT['s2 inhibition']

[(7, 3), (7, 9), (8, 6), (8, 10)]

In [110]:
# Path patching from every head's `z` into the query inputs of the name-mover heads.
# NOTE(review): the variable name and the first facet label say "flipped", but the
# corrupted input passed here is `abc_tokens`, not `flipped_tokens` — confirm which was
# intended and rename or relabel accordingly.
results_path_flipped = path_patch(
    model,
    orig_input=clean_tokens,
    new_input=abc_tokens,
    sender_nodes=IterNode('z'),
    receiver_nodes=[Node('q', layer=layer, head=head) for layer, head in CIRCUIT['name mover']],
    patching_metric=ioi_metric_noising,
    verbose=True,
)

# Same experiment, but with the BOS-only cache supplied as the source of new activations.
# NOTE(review): both `new_input` and `new_cache` are passed; how path_patch combines them
# is not visible in this notebook — confirm that `new_cache` takes precedence.
results_path_bos = path_patch(
    model,
    orig_input=clean_tokens,
    new_input=abc_tokens,
    new_cache=bos_cache,
    sender_nodes=IterNode('z'),
    receiver_nodes=[Node('q', layer=layer, head=head) for layer, head in CIRCUIT['name mover']],
    patching_metric=ioi_metric_noising,
    verbose=True,
)

# Side-by-side heatmaps (one facet per experiment) of the per-head results.
imshow(t.stack([results_path_flipped['z'], results_path_bos['z']]), facet_col=0,
       facet_labels=['Path patch name movers with flipped S2', 'Path patch name movers with attn only to BOS'])

  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=12, head=12)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=12, head=12)


In [113]:
# Path patching from every head's `z` into the value inputs of the S2-inhibition heads.
# This reuses (and overwrites) the `results_path_*` names from the name-mover cell.
# NOTE(review): as in the name-mover cell, the name and label say "flipped" while the
# corrupted input is `abc_tokens` — confirm.
results_path_flipped = path_patch(
    model,
    orig_input=clean_tokens,
    new_input=abc_tokens,
    sender_nodes=IterNode('z'),
    receiver_nodes=[Node('v', layer=layer, head=head) for layer, head in CIRCUIT['s2 inhibition']],
    patching_metric=ioi_metric_noising,
    verbose=True,
)

# NOTE(review): this call passes `flipped_tokens` as `new_input` (the call above uses
# `abc_tokens`) together with `new_cache=bos_cache`; how path_patch combines the two is
# not visible in this notebook — confirm the inconsistency is intentional.
results_path_bos = path_patch(
    model,
    orig_input=clean_tokens,
    new_input=flipped_tokens,
    new_cache=bos_cache,
    sender_nodes=IterNode('z'),
    receiver_nodes=[Node('v', layer=layer, head=head) for layer, head in CIRCUIT['s2 inhibition']],
    patching_metric=ioi_metric_noising,
    verbose=True,
)

# Side-by-side heatmaps (one facet per experiment) of the per-head results.
imshow(t.stack([results_path_flipped['z'], results_path_bos['z']]), facet_col=0,
       facet_labels=['Path patch s-inhibition with flipped S2', 'Path patch s-inhibition with attn only to BOS'])

  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=12, head=12)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=12, head=12)


In [109]:
results1 = path_patch(
    model,
    orig_input=clean_tokens,
    new_input=abc_tokens,
    sender_nodes=IterNode('z'), # This means iterate over all heads in all layers
    receiver_nodes=Node('resid_post', 11), # This is resid_post at layer 11
    patching_metric=ioi_metric_noising,
    verbose=True
)

results2 = path_patch(
    model,
    orig_input=clean_tokens,
    new_input=abc_tokens,
    new_cache=bos_cache,
    sender_nodes=IterNode('z'), # This means iterate over all heads in all layers
    receiver_nodes=Node('resid_post', 11), # This is resid_post at layer 11
    patching_metric=ioi_metric_noising,
    verbose=True
)


imshow(t.stack([results1['z'], results2['z']]), facet_col=0,)

  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=12, head=12)


  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=12, head=12)


In [86]:
visualize_selected_heads(model, prompts, heads=[(1, 10), (11, 8), (5, 9)])

# Patching within IOI  

I used the solu-8-pile model for this. I probably won't go deeper into it.

In [71]:
labels = [f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))]
print('--- Clean logit preds ---')
print_topk_predictions(clean_logits[0])
print('--- Flipped logit preds ---')
print_topk_predictions(flipped_logits[0])

--- Clean logit preds ---
0.694: ' John'	0.057: ' them'	0.044: ' the'	0.028: ' Mary'	0.014: ' her'
--- Flipped logit preds ---
0.699: ' Mary'	0.065: ' them'	0.047: ' the'	0.027: ' his'	0.025: ' John'


In [80]:
results = act_patch(
    model=model,
    orig_input=flipped_tokens,
    new_cache=clean_cache,
    patching_nodes=IterNode('z'),
    patching_metric=ioi_metric_denoising,
    verbose=True,
)

imshow(results['z'])

  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=12, head=12)


In [24]:
results = act_patch(
    model=model,
    orig_input=clean_tokens,
    new_cache=flipped_cache,
    patching_nodes=IterNode('z'),
    patching_metric=ioi_metric_noising,
    verbose=True,
)

imshow(results['z'])

  0%|          | 0/128 [00:00<?, ?it/s]

results['z'].shape = (layer=8, head=16)


In [None]:
help(path_patch)

In [18]:

results = path_patch(
    model,
    clean_tokens,
    new_input=abc_tokens,
    sender_nodes=IterNode('z'),
    receiver_nodes=Node('resid_post', layer=7),
    patching_metric=partial(ioi_metric_noising, flipped_logit_diff=abc_logit_diff),
    verbose=True,
)

imshow(results['z'])

  0%|          | 0/128 [00:00<?, ?it/s]

results['z'].shape = (layer=8, head=16)


In [None]:

results = path_patch(
    model,
    clean_tokens,
    new_input=flipped_tokens,
    sender_nodes=IterNode('z'),
    receiver_nodes=Node('resid_post', layer=7),
    patching_metric=ioi_metric_noising,
    verbose=True,
)

imshow(results['z'])

  0%|          | 0/128 [00:00<?, ?it/s]

results['z'].shape = (layer=8, head=16)


In [17]:
# Heads of interest for the model used in this section, as (layer, head) pairs.
# They are stored under their own name instead of rebinding `CIRCUIT`: that name holds the
# dict of GPT-2 small head groups defined near the top of the notebook, and every cell
# that indexes it by group name (e.g. CIRCUIT['name mover']) would fail with a TypeError
# once it had been replaced by a list.
SOLU_CIRCUIT_HEADS = [(2, 15), (3, 15), (4, 2), (5, 0), (6, 3), (6, 10), (6, 13), (7, 3), (7, 7), (7, 8)]
visualize_selected_heads(model, prompts, SOLU_CIRCUIT_HEADS)

In [None]:
def plot_logit_contribution(cache: ActivationCache, head: Optional[Tuple[int, int]] = None):
    """ Plot the top logits that a given head boosts or downgrades.

    TODO: not implemented yet — the body consists of this docstring only, so calling the
    function does nothing and returns None.

    Args:
        cache: activation cache of the run to analyse.
        head: (layer, head) pair of the head to inspect. The original signature wrote
            `head=Tuple[int, int]`, which made the typing object the default *value*
            rather than the annotation. The parameter stays optional so existing calls
            that omit it keep working.
    """

In [16]:
all_results = []
for hook_point in ['k', 'q', 'v']:
    results = path_patch(
        model,
        clean_tokens,
        new_input=flipped_tokens,
        sender_nodes=IterNode('z'),
        receiver_nodes=Node(hook_point, layer=5, head=0),
        patching_metric=ioi_metric_noising,
        verbose=True,
    )
    all_results.append(results['z'])


imshow(t.stack(all_results), facet_col=0, facet_labels=['k', 'q', 'v'])

  0%|          | 0/128 [00:00<?, ?it/s]

results['z'].shape = (layer=8, head=16)


  0%|          | 0/128 [00:00<?, ?it/s]

results['z'].shape = (layer=8, head=16)


  0%|          | 0/128 [00:00<?, ?it/s]

results['z'].shape = (layer=8, head=16)


In [15]:
all_results = []
for hook_point in ['k', 'q', 'v']:
    results = path_patch(
        model,
        clean_tokens,
        new_input=flipped_tokens,
        sender_nodes=IterNode('z'),
        receiver_nodes=Node(hook_point, layer=6, head=3),
        patching_metric=ioi_metric_noising,
        verbose=True,
    )
    all_results.append(results['z'])


imshow(t.stack(all_results), facet_col=0, facet_labels=['k', 'q', 'v'])

  0%|          | 0/128 [00:00<?, ?it/s]

results['z'].shape = (layer=8, head=16)


  0%|          | 0/128 [00:00<?, ?it/s]

results['z'].shape = (layer=8, head=16)


  0%|          | 0/128 [00:00<?, ?it/s]

results['z'].shape = (layer=8, head=16)


In [14]:
clean_cache['result', 0].shape

torch.Size([8, 15, 16, 1024])

In [53]:
head_attrs.shape

torch.Size([128, 8, 50278])

In [59]:
head_logit_diff.shape

torch.Size([128, 1, 8])

In [62]:
# Per-head contribution to the answer logits at the final position.
# head_results: one row per head, taken at the last sequence position.
head_results = clean_cache.stack_head_results(apply_ln=True)[:, :, -1]
# Project every head's output onto the vocabulary with the unembedding matrix.
head_attrs = head_results @ model.W_U
# One index per prompt, derived from the answers rather than hard-coded, so the cell
# keeps working when the number of prompts changes.
batch_idx = t.arange(len(answer_tokens))
correct_ans, wrong_ans = t.unbind(answer_tokens, dim=-1)
# head_logit_diff = head_attrs[:, batch_idx, correct_ans]
# NOTE(review): despite its name, `head_logit_diff` currently holds only the logit of the
# *incorrect* answer (the correct-answer line above is commented out) — confirm intent.
head_logit_diff = head_attrs[:, batch_idx, wrong_ans]

# Un-flatten the head axis using the model's own layer count. A hard-coded layer count
# silently produces a wrongly shaped grid for any model with a different depth.
imshow(einops.rearrange(head_logit_diff.mean(-1), '(layer head) -> layer head', layer=model.cfg.n_layers))