# Setup

In [1]:
from attention_data import AttentionData
import os
import openai
import torch as t
from transformer_lens import HookedTransformer
%pip install python-dotenv
from dotenv import load_dotenv

# Read a local .env file (if present) so its variables show up in os.environ
load_dotenv()

# Set API Keys
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
assert OPENAI_API_KEY, "OPENAI_API_KEY environment variable is missing from .env"
openai.api_key = OPENAI_API_KEY

# Nothing in this notebook needs gradients, so turn autograd off globally
# to save computation time
t.set_grad_enabled(False)

# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
# Get a dataset

%pip install datasets > /dev/null
from datasets import load_dataset
dataset = load_dataset("stas/openwebtext-10k", split="train", trust_remote_code=True)


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [3]:
# Get a model

# Load the pretrained GPT-2 small weights wrapped in a HookedTransformer
MODEL_NAME = "gpt2-small"
model = HookedTransformer.from_pretrained(MODEL_NAME)

Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
# Make an AttentionData instance

# Slice the rows first and then pick the column: dataset["text"][:100] can load
# the entire text column just to keep its first 100 entries.
text_batch = dataset[:100]["text"]  # Speed is sensitive to the number of samples

attention_data = AttentionData(
    model=model,
    text_batch=text_batch,
    openai_model="gpt-3.5-turbo-1106",
    openai_api_key=OPENAI_API_KEY,
)

# Usage
(Note: attention on the first token currently seems unusually high; this might be a bug.)

In [5]:
# Head under study: layer 10, head 7 (L10H7), which was studied closely here:
# https://arxiv.org/pdf/2310.04625.pdf
layer, head = 10, 7

In [6]:
# Ask for a description of the selected head based on 20 samples;
# the call hands back a (prompt, description) pair
prompt, description = attention_data.describe_head(
    num_samples=20,
    layer=layer,
    head=head,
)

Creating new samples for layer 10 head 7


Token indices sequence length is longer than the specified maximum sequence length for this model (5989 > 1024). Running this sequence through the model will result in indexing errors


Making API call to gpt-3.5-turbo-1106...

Based on the attention patterns provided, the attention head seems to pay attent
ion to specific contextually relevant tokens in each generation step. For instan
ce, in the first example, the attention is focused on tokens related to the subj
ect matter, such as "Former" and "Bush" in the context of "Stephen Hadley," indi
cating an emphasis on names and titles. In other examples, attention is directed
 towards indicative phrases like "Germany coach Joachim Low" and specific termin
ology like "Mueller" and "Russia Probe." Additionally, attention is also drawn t
o punctuation marks and conjunctions, potentially reflecting the importance of s
entence structure and coherence in generating text. Overall, the attention head 
appears to prioritize tokens that are semantically or syntactically significant 
for the given context, aligning with the transformer's mechanism of attending to
 relevant information for language generation.


In [19]:
# "Multiple" = an attention score expressed relative to the average attention
# pattern value of its row. Example: in a row with 10 tokens, a multiple of 2
# means the attention score was 0.2.

ranked_multiples = attention_data.get_ranked_multiples(
    layer=layer,
    head=head,
    display=True,
    num_multiples=10,
)


"Layer 10 Head 7, Top 10 / 4639 Multiples","Layer 10 Head 7, Top 10 / 4639 Multiples","Layer 10 Head 7, Top 10 / 4639 Multiples","Layer 10 Head 7, Top 10 / 4639 Multiples","Layer 10 Head 7, Top 10 / 4639 Multiples","Layer 10 Head 7, Top 10 / 4639 Multiples","Layer 10 Head 7, Top 10 / 4639 Multiples","Layer 10 Head 7, Top 10 / 4639 Multiples","Layer 10 Head 7, Top 10 / 4639 Multiples","Layer 10 Head 7, Top 10 / 4639 Multiples",Unnamed: 10_level_0,Unnamed: 11_level_0,Unnamed: 12_level_0,Unnamed: 13_level_0,Unnamed: 14_level_0,Unnamed: 15_level_0,Unnamed: 16_level_0,Unnamed: 17_level_0,Unnamed: 18_level_0,Unnamed: 19_level_0,Unnamed: 20_level_0,Unnamed: 21_level_0,Unnamed: 22_level_0,Unnamed: 23_level_0,Unnamed: 24_level_0,Unnamed: 25_level_0,Unnamed: 26_level_0,Unnamed: 27_level_0,Unnamed: 28_level_0,Unnamed: 29_level_0,Unnamed: 30_level_0,Unnamed: 31_level_0
Token,Multiple of Avg. score,Pattern,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1
Image,29.2,Image,copyright,Reg,Has,lett,Image,caption,Not,a,great,day,to,be,under,canvas,in,Glen,arm,",",County,Ant,rim,,,Northern,Ireland,was,the,w,ett
Two,29.2,Two,worlds,collide,in,Mario,+,Rabb,ids,®,Kingdom,Battle,!,,,This,game,is,available,for,purchase,exclusively,on,Nintendo,e,Shop,for,Nintendo,Switch,.,
Em,29.0,Em,ma,M,æ,rs,k,is,the,first,container,ship,in,the,E,-,class,of,eight,owned,by,the,A,.,P,.,M,oller,-,Ma,ers
Anim,28.7,Anim,als,are,dying,in,unnecessary,agony,because,of,a,lack,of,understanding,over,how,stunning,stops,them,feeling,pain,when,their,throats,are,cut,",",research,shows,.,
FBI,28.7,FBI,Hand,out,Gun,man,who,killed,a,security,agent,at,LA,X,was,carrying,a,note,expressing,',dis,app,ointment,in,the,government,"',",reports,claim,.,
A,28.6,A,Texas,woman,has,filed,a,lawsuit,against,three,police,officers,in,Victoria,",",claiming,that,they,brutally,beat,her,and,broke,her,ribs,without,a,good,reason,.,
The,28.0,The,Reds,announced,that,they,signed,free,agent,outfielder,Ryan,Lud,wick,to,a,two,-,year,deal,with,a,mutual,option,for,2015,(,Twitter,link,).,The,B
Metal,28.0,Metal,Gear,Solid,5,:,The,Phantom,Pain,was,our,game,of,the,year,in,2015,",",however,the,foundations,for,it,were,laid,by,standalone,introductory,chapter,Ground,Zer
Em,28.0,Em,ma,M,æ,rs,k,is,the,first,container,ship,in,the,E,-,class,of,eight,owned,by,the,A,.,P,.,M,oller,-,Ma,
Image,28.0,Image,copyright,Reg,Has,lett,Image,caption,Not,a,great,day,to,be,under,canvas,in,Glen,arm,",",County,Ant,rim,,,Northern,Ireland,was,the,w,


In [20]:
# Look at the top occurrences for a particular token.
# The token string is taken from the fifth-ranked entry ([4]) of the ranking
# computed above; element [0] of an entry is the token.
example_str_token = ranked_multiples[4][0]

# Keep the per-token result under its own name. Overwriting ranked_multiples
# here would (presumably) leave it with only the 2 filtered entries, so
# re-running this cell could no longer read index 4 above.
token_ranked_multiples = attention_data.get_ranked_multiples(
    head=head,
    layer=layer,
    num_multiples=2,
    str_token=example_str_token,
    display=True,
)

"Layer 10 Head 7, Top 2 / 4639 Multiplesfor 'FBI'","Layer 10 Head 7, Top 2 / 4639 Multiplesfor 'FBI'","Layer 10 Head 7, Top 2 / 4639 Multiplesfor 'FBI'","Layer 10 Head 7, Top 2 / 4639 Multiplesfor 'FBI'","Layer 10 Head 7, Top 2 / 4639 Multiplesfor 'FBI'","Layer 10 Head 7, Top 2 / 4639 Multiplesfor 'FBI'","Layer 10 Head 7, Top 2 / 4639 Multiplesfor 'FBI'","Layer 10 Head 7, Top 2 / 4639 Multiplesfor 'FBI'","Layer 10 Head 7, Top 2 / 4639 Multiplesfor 'FBI'","Layer 10 Head 7, Top 2 / 4639 Multiplesfor 'FBI'",Unnamed: 10_level_0,Unnamed: 11_level_0,Unnamed: 12_level_0,Unnamed: 13_level_0,Unnamed: 14_level_0,Unnamed: 15_level_0,Unnamed: 16_level_0,Unnamed: 17_level_0,Unnamed: 18_level_0,Unnamed: 19_level_0,Unnamed: 20_level_0,Unnamed: 21_level_0,Unnamed: 22_level_0,Unnamed: 23_level_0,Unnamed: 24_level_0,Unnamed: 25_level_0,Unnamed: 26_level_0,Unnamed: 27_level_0,Unnamed: 28_level_0,Unnamed: 29_level_0,Unnamed: 30_level_0,Unnamed: 31_level_0
Token,Multiple of Avg. score,Pattern,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1
FBI,28.7,FBI,Hand,out,Gun,man,who,killed,a,security,agent,at,LA,X,was,carrying,a,note,expressing,',dis,app,ointment,in,the,government,"',",reports,claim,.,
FBI,23.1,FBI,Hand,out,Gun,man,who,killed,a,security,agent,at,LA,X,was,carrying,a,note,expressing,',dis,app,ointment,in,the,government,,,,,


In [22]:
# Draw a random grouping of 15 multiples from those that were larger than average

random_multiples = attention_data.get_random_multiples(
    layer=layer,
    head=head,
    display=True,
    num_multiples=15,
)

"Layer 10 Head 7, 15 / 4639 Random Multiples","Layer 10 Head 7, 15 / 4639 Random Multiples","Layer 10 Head 7, 15 / 4639 Random Multiples","Layer 10 Head 7, 15 / 4639 Random Multiples","Layer 10 Head 7, 15 / 4639 Random Multiples","Layer 10 Head 7, 15 / 4639 Random Multiples","Layer 10 Head 7, 15 / 4639 Random Multiples","Layer 10 Head 7, 15 / 4639 Random Multiples","Layer 10 Head 7, 15 / 4639 Random Multiples","Layer 10 Head 7, 15 / 4639 Random Multiples",Unnamed: 10_level_0,Unnamed: 11_level_0,Unnamed: 12_level_0,Unnamed: 13_level_0,Unnamed: 14_level_0,Unnamed: 15_level_0,Unnamed: 16_level_0,Unnamed: 17_level_0,Unnamed: 18_level_0,Unnamed: 19_level_0,Unnamed: 20_level_0,Unnamed: 21_level_0,Unnamed: 22_level_0,Unnamed: 23_level_0,Unnamed: 24_level_0,Unnamed: 25_level_0,Unnamed: 26_level_0,Unnamed: 27_level_0,Unnamed: 28_level_0,Unnamed: 29_level_0,Unnamed: 30_level_0,Unnamed: 31_level_0
Token,Multiple of Avg. score,Pattern,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1
Two,14.8,Two,worlds,collide,in,Mario,+,Rabb,ids,®,Kingdom,Battle,!,,,This,game,is,available,for,purchase,exclusively,,,,,,,,,
bridge,1.7,L,ANS,ING,",",MI,--,En,bridge,directors,say,there,are,no,areas,where,bare,Line,5,metal,is,exposed,to,Great,Lakes,water,but,admitted,during,a,Michigan
marijuana,2.5,Several,states,voted,to,legalize,marijuana,this,past,Election,Day,but,the,pot,business,still,has,a,gri,pe,—,reg,ulations,.,,,Though,decriminal,ized,on,
The,25.3,The,45,-,year,-,old,�,�,high,way,shooter,�,�,who,engaged,in,a,12,-,minute,shootout,with,California,Highway,Patrol,officers,earlier,,,
SHARE,9.2,SHARE,THIS,ARTICLE,Share,Tweet,Post,Email,,,Aust,rians,elected,a,Green,Party,-,backed,,,,,,,,,,,,,
Researchers,15.4,Researchers,from,the,Johns,Hopkins,Center,for,Gun,Policy,and,Research,",",part,of,the,Johns,Hopkins,Bloomberg,School,of,Public,Health,",",compared,,,,,,
legalize,1.8,Several,states,voted,to,legalize,marijuana,this,past,Election,Day,but,the,pot,business,still,has,a,gri,pe,—,reg,ulations,.,,,Though,,,,
The,12.5,The,way,forward,for,aut,ow,ork,ers,:,An,online,interview,with,Jerry,White,,,15,October,2015,,,,,,,,,,
Booster,1.7,Im,g,Project,Des,cript,on,Back,ers,Pledge,/,Goal,/,%,+,1,Button,Co,Mo,Booster,Board,,,,,,,,,,
It,12.9,It,'s,prom,season,at,high,schools,across,the,country,",",a,special,time,that,until,recently,has,been,,,,,,,,,,,
