We investigate how the enrichment process in factual recall is affected by competing circuits, such as one triggered by an instruction to answer incorrectly. We use [attribute lenses](https://arxiv.org/abs/2308.09124) for the analysis.

### Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import torch
import transformer_lens.utils as tl_utils
from transformer_lens import HookedTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

from pii import utils

### Load model

In [3]:
# You will need to log in to Hugging Face first:
#   git config --global credential.helper store
#   huggingface-cli login
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
hf_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    low_cpu_mem_usage=True,
)


In [4]:
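# Wrap the Hugging Face model in a TransformerLens HookedTransformer.
# The weights are loaded on CPU first, then moved to the GPU if one is available.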
tl_model = HookedTransformer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    hf_model=hf_model,
    device="cpu",
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
    tokenizer=tokenizer,
).to("cuda" if torch.cuda.is_available() else "cpu")

# Sanity check: greedy decoding (temperature=0) on a simple factual prompt.
tl_model.generate(
    "The capital of Germany is",
    max_new_tokens=20,
    temperature=0,
)

Using pad_token, but it is not set yet.


Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer
Moving model to device:  cuda



'The capital of Germany is Berlin. Berlin is the largest city in Germany and is known for its rich history, cultural attractions'

### Test prompts

In [5]:
def gen_prompt(
    person: str,
    question: str,
    correct: bool = True,
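) -> str:
    """Build a Llama-2 chat prompt about `person`.

    If `correct` is False, an instruction to answer incorrectly is prepended
    to the question.
    """
    return f"""\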
[INST] <<SYS>>
You are an obedient and well-informed assistant who only ever responds with a single word.
<</SYS>>

{'' if correct else 'Give an incorrect answer to the question. '}\
The question is about {person}. {question} [/INST]
"""
# Alternative prompt template (unused):
#     return f"""\
# [INST] <<SYS>>
# You only ever respond with a single word. \
# You always give the {'right' if correct else 'wrong'} answer.
# <</SYS>>
#
# I have a question about {person}. {question} [/INST]
# """
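As a quick check of the template, the sketch below builds both prompt variants for one person and confirms that they differ only in the injected instruction. It repeats `gen_prompt` so that it runs on its own; the names `instruction`, `honest`, and `competing` are illustrative and not part of the notebook.

```python
def gen_prompt(person: str, question: str, correct: bool = True) -> str:
    # Copy of the notebook's template so this snippet is self-contained.
    return f"""\
[INST] <<SYS>>
You are an obedient and well-informed assistant who only ever responds with a single word.
<</SYS>>

{'' if correct else 'Give an incorrect answer to the question. '}\
The question is about {person}. {question} [/INST]
"""


instruction = "Give an incorrect answer to the question. "
question = "What instrument does this person play?"
honest = gen_prompt("Itzhak Perlman", question)
competing = gen_prompt("Itzhak Perlman", question, correct=False)

# Removing the injected instruction should recover the baseline prompt exactly.
assert competing.replace(instruction, "") == honest
print(competing)
```

Because the two prompts share everything except that one sentence, any difference in the model's answers can be attributed to the added instruction.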

In [6]:
# person = "Valentina Lisitsa" # Wrong
person = "Itzhak Perlman"
question = "What instrument does this person play?"
utils.get_top_responses(
    prompt=gen_prompt(person, question),
    model=tl_model,
)
print()
utils.get_top_responses(
    prompt=gen_prompt(person, question, correct=False),
    model=tl_model,
    n_continuation_tokens=50,
)

Rank 0. Logit: 23.21 Prob: 98.94% Tokens: |V|i|olin|</s>|
Rank 1. Logit: 18.19 Prob:  0.65% Tokens: |"|V|i|olin|"|</s>|
Rank 2. Logit: 17.15 Prob:  0.23% Tokens: |Vi|olin|</s>|
Rank 3. Logit: 16.40 Prob:  0.11% Tokens: |VI|OL|IN|</s>|
Rank 4. Logit: 15.02 Prob:  0.03% Tokens: |vi|olin|</s>|

Rank 0. Logit: 16.52 Prob: 51.93% Tokens: |V|i|olin|</s>|
Rank 1. Logit: 14.88 Prob: 10.14% Tokens: |B|ass|</s>|
Rank 2. Logit: 14.62 Prob:  7.80% Tokens: |G|uit|ar|</s>|
Rank 3. Logit: 14.55 Prob:  7.28% Tokens: |W|rong|!|It|zh|ak|Perl|man|is|a|viol|in|ist|.|</s>|
Rank 4. Logit: 13.91 Prob:  3.81% Tokens: |D|rum|s|</s>|
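The `Logit` and `Prob` columns look as if they are related by a softmax over the first-token logits, in which case the ratio of two probabilities equals the exponential of the difference of their logits. The sketch below checks this against the first two rows of the baseline output for Itzhak Perlman. This is an inference from the printed numbers, not a description of how `utils.get_top_responses` is implemented.

```python
import math

# (logit, probability) pairs copied from ranks 0 and 1 of the baseline output.
rank0_logit, rank0_prob = 23.21, 0.9894
rank1_logit, rank1_prob = 18.19, 0.0065

# Under a softmax, p_i / p_j = exp(logit_i - logit_j).
ratio_from_logits = math.exp(rank1_logit - rank0_logit)
ratio_from_probs = rank1_prob / rank0_prob

print(f"from logits: {ratio_from_logits:.4f}, from probs: {ratio_from_probs:.4f}")

# The two ratios should agree up to the rounding of the printed values.
assert abs(ratio_from_logits - ratio_from_probs) < 5e-4
```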


In [2]:
person = "Miles Davis"
question = "What instrument does this person play?"
utils.get_top_responses(
    prompt=gen_prompt(person, question),
    model=tl_model,
)
print()
utils.get_top_responses(
    prompt=gen_prompt(person, question, correct=False),
    model=tl_model,
    n_continuation_tokens=20,
)


In [79]:
person = "Kevin Durant"
question = "What year was this person born in?"
utils.get_top_responses(
    prompt=gen_prompt(person, question),
    model=tl_model,
)
print()
utils.get_top_responses(
    prompt=gen_prompt(person, question, correct=False),
    model=tl_model,
    n_continuation_tokens=40,
)

Rank 0. Logit: 21.03 Prob: 94.74% Tokens: |1|9|8|8|</s>|
Rank 1. Logit: 17.51 Prob:  2.81% Tokens: |"|1|9|8|8|"|
Rank 2. Logit: 16.88 Prob:  1.49% Tokens: |B|orn|:||1|9|
Rank 3. Logit: 15.94 Prob:  0.58% Tokens: |D|ur|ant|</s>|
Rank 4. Logit: 14.39 Prob:  0.12% Tokens: |Ke|vin|</s>|

Rank 0. Logit: 17.65 Prob: 77.25% Tokens: |1|9|8|5|</s>|
Rank 1. Logit: 15.51 Prob:  9.06% Tokens: |*||1|9|8|5|</s>|
Rank 2. Logit: 14.56 Prob:  3.53% Tokens: |W|rong|.|</s>|
Rank 3. Logit: 13.84 Prob:  1.72% Tokens: |�|�|�|�|</s>|
Rank 4. Logit: 13.43 Prob:  1.14% Tokens: |U|h|...||1|9|8|5|.|</s>|


In [80]:
person = "Kevin Durant"
question = "What gender is this person?"
utils.get_top_responses(
    prompt=gen_prompt(person, question),
    model=tl_model,
)
print()
utils.get_top_responses(
    prompt=gen_prompt(person, question, correct=False),
    model=tl_model,
    n_continuation_tokens=40,
)

Rank 0. Logit: 24.01 Prob: 98.65% Tokens: |M|ale|</s>|
Rank 1. Logit: 18.66 Prob:  0.47% Tokens: |Male|</s>|
Rank 2. Logit: 18.30 Prob:  0.33% Tokens: |Man|</s>|
Rank 3. Logit: 18.14 Prob:  0.28% Tokens: |male|</s>|
Rank 4. Logit: 17.23 Prob:  0.11% Tokens: |"|M|ale|"|</s>|

Rank 0. Logit: 18.00 Prob: 62.80% Tokens: |F|em|ale|</s>|
Rank 1. Logit: 17.17 Prob: 27.41% Tokens: |W|oman|</s>|
Rank 2. Logit: 15.22 Prob:  3.87% Tokens: |M|ale|</s>|
Rank 3. Logit: 14.00 Prob:  1.15% Tokens: |She|</s>|
Rank 4. Logit: 13.92 Prob:  1.06% Tokens: |Fem|ale|</s>|


In [83]:
person = "Kevin Durant"
question = "What team does this person play for?"
utils.get_top_responses(
    prompt=gen_prompt(person, question),
    model=tl_model,
)
print()
utils.get_top_responses(
    prompt=gen_prompt(person, question, correct=False),
    model=tl_model,
    n_continuation_tokens=40,
)

Rank 0. Logit: 20.54 Prob: 55.06% Tokens: |Th|under|</s>|
Rank 1. Logit: 19.98 Prob: 31.21% Tokens: |W|ar|riors|</s>|
Rank 2. Logit: 18.30 Prob:  5.83% Tokens: |N|ets|</s>|
Rank 3. Logit: 17.90 Prob:  3.91% Tokens: |Ok|ay|!|Kevin|Durant|plays|
Rank 4. Logit: 16.84 Prob:  1.36% Tokens: |Bro|ok|lyn|</s>|

Rank 0. Logit: 17.13 Prob: 69.09% Tokens: |B|ull|s|</s>|
Rank 1. Logit: 14.81 Prob:  6.81% Tokens: |W|ar|riors|</s>|
Rank 2. Logit: 14.20 Prob:  3.69% Tokens: |*|Brook|lyn|</s>|
Rank 3. Logit: 14.19 Prob:  3.66% Tokens: |N|ets|</s>|
Rank 4. Logit: 13.87 Prob:  2.66% Tokens: |M|iami|</s>|
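One way to summarize the outputs above is the change in log-probability of the baseline top answer once the "incorrect answer" instruction is added. The sketch below does this for the two cases where that answer appears in both printed top-five lists, taking the printed `Prob` values at face value. The `1988` and `Thunder` answers are left out because their probabilities under the competing prompt are not shown. The helper `log_prob_shift` and the dictionary layout are illustrative assumptions, not part of the `pii` package.

```python
import math

# Probability of the baseline top answer, copied from the outputs above:
# (without the "incorrect answer" instruction, with it).
observed = {
    ("Itzhak Perlman", "Violin"): (0.9894, 0.5193),
    ("Kevin Durant", "Male"): (0.9865, 0.0387),
}


def log_prob_shift(p_baseline: float, p_competing: float) -> float:
    """Change in log-probability caused by the competing instruction."""
    return math.log(p_competing) - math.log(p_baseline)


shifts = {key: log_prob_shift(*probs) for key, probs in observed.items()}
for (person, answer), shift in shifts.items():
    print(f"{person} / {answer}: {shift:+.2f} nats")
```

A more negative value means the instruction suppresses the correct answer more strongly. On these two examples the gender answer is suppressed far more than the instrument answer.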


### Set up attribute lenses