In [2]:
%load_ext autoreload

In [110]:
from transformer_tests import *

# Load the pretrained "attn-only-2l" checkpoint. `CustomHookedTransformer` and
# `device` are not defined in this notebook; presumably both come from the
# wildcard import of `transformer_tests` above -- confirm.
model = CustomHookedTransformer.from_pretrained(
    "attn-only-2l",
    device=device,
)

Loaded pretrained model attn-only-2l into HookedTransformer


In [111]:
# --- Setup: random sequences generated with prefix_flag=True -----------------
# `attention_dir` is set once and left untouched for every run in this cell.
# `custom_type` is reassigned before each forward pass below, so the value
# written here never reaches a model call.
model.cfg.custom_type = "token_only"
model.cfg.attention_dir = "causal"

# Example sentence with a hand-written tokenisation. The `tokens` produced from
# it are replaced by the random batch below before being used, and later cells
# redefine `str_tokens`.
input_string = 'Mary and John went to the store to get some milk'
str_tokens = [
    '<|BOS|>', 'Mary', ' and', ' John', ' went', ' to',
    ' the', ' store', ' to', ' get', ' some', ' milk',
]
tokens = model.to_tokens(input_string)

# Fixed seed so the random batch is the same on every run.
torch.manual_seed(129)
seq_len, batch = 12, 500
tokens = generate_random_token_sequence(model, seq_len, batch, prefix_flag=True)

# One cached forward pass per `custom_type` setting on the same token batch.
# (Presumably the setting controls whether attention sees token and/or
# positional information -- confirm in transformer_tests.)
model.cfg.custom_type = "standard"
logits, cache_total = model.run_with_cache(tokens, remove_batch_dim=False)

model.cfg.custom_type = "position_only"
logits, cache_posn = model.run_with_cache(tokens, remove_batch_dim=False)

model.cfg.custom_type = "token_only"
logits, cache_token = model.run_with_cache(tokens, remove_batch_dim=False)

# Attention patterns, then attention scores, for each of the three runs.
attn_total, attn_posn, attn_token = (
    get_attn_patterns(run_cache)
    for run_cache in (cache_total, cache_posn, cache_token)
)
attn_scores_total, attn_scores_posn, attn_scores_token = (
    get_attn_scores(run_cache)
    for run_cache in (cache_total, cache_posn, cache_token)
)

# Stack element [0] of each result along a new leading axis ordered
# (standard, position_only, token_only).
patterns_stacked = torch.stack([attn_total[0], attn_posn[0], attn_token[0]])
scores_stacked = torch.stack(
    [attn_scores_total[0], attn_scores_posn[0], attn_scores_token[0]]
)

In [112]:
# Axis labels for the plots: 'BOS' followed by placeholders '[A]' .. '[L]'.
str_tokens = ['BOS'] + [f'[{letter}]' for letter in 'ABCDEFGHIJKL']

In [19]:
# Heatmap of patterns_stacked[0][3]: index 0 is the "standard" run in the
# stacking order; index 3 is presumably the head index -- confirm.
str_labels = str_tokens
# Whitespace-stripped copy of the labels; the figure below does not use it.
y_str_labels = [label.strip() for label in str_labels]

fig = px.imshow(
    patterns_stacked[0][3],
    x=str_labels,
    y=str_labels,
    color_continuous_scale='blues',
).update_layout(width=500, height=500)

# Export the figure under json_dir, then display it as the cell output.
fig_to_json(fig, json_dir, 'attn_pattern_prev_token_head')
fig

In [113]:
# Heatmap of patterns_stacked[1][3]: index 1 is the "position_only" run in the
# stacking order; index 3 is presumably the head index -- confirm.
str_labels = str_tokens
# Whitespace-stripped copy of the labels; the figure below does not use it.
y_str_labels = [label.strip() for label in str_labels]

fig = px.imshow(
    patterns_stacked[1][3],
    x=str_labels,
    y=str_labels,
    color_continuous_scale='blues',
).update_layout(width=500, height=500)

# Export the figure under json_dir, then display it as the cell output.
fig_to_json(fig, json_dir, 'attn_pattern_prev_token_head_posn_only')
fig

### Test without BOS

In [122]:
# --- Setup: random sequences generated with prefix_flag=False ----------------
# Same pipeline as the prefixed setup cell, but the generator is called with
# prefix_flag=False (presumably: no BOS token prepended -- confirm in
# generate_random_token_sequence).
#
# NOTE(review): this cell rebinds `tokens`, the three caches, and
# `patterns_stacked` / `scores_stacked`. Any cell executed after it (including
# the line-plot cells further down) reads these no-prefix results.

# Fixed seed so the random batch is the same on every run.
torch.manual_seed(129)
seq_len, batch = 12, 500
tokens = generate_random_token_sequence(model, seq_len, batch, prefix_flag=False)

# One cached forward pass per `custom_type` setting on the same token batch.
model.cfg.custom_type = "standard"
logits, cache_total = model.run_with_cache(tokens, remove_batch_dim=False)

model.cfg.custom_type = "position_only"
logits, cache_posn = model.run_with_cache(tokens, remove_batch_dim=False)

model.cfg.custom_type = "token_only"
logits, cache_token = model.run_with_cache(tokens, remove_batch_dim=False)

# Attention patterns, then attention scores, for each of the three runs.
attn_total, attn_posn, attn_token = (
    get_attn_patterns(run_cache)
    for run_cache in (cache_total, cache_posn, cache_token)
)
attn_scores_total, attn_scores_posn, attn_scores_token = (
    get_attn_scores(run_cache)
    for run_cache in (cache_total, cache_posn, cache_token)
)

# Stack element [0] of each result along a new leading axis ordered
# (standard, position_only, token_only).
patterns_stacked = torch.stack([attn_total[0], attn_posn[0], attn_token[0]])
scores_stacked = torch.stack(
    [attn_scores_total[0], attn_scores_posn[0], attn_scores_token[0]]
)



In [128]:
# Twelve placeholder labels '[A]' .. '[L]' with no leading 'BOS' entry.
str_tokens = [f'[{letter}]' for letter in 'ABCDEFGHIJKL']
str_labels = str_tokens
# Whitespace-stripped copy of the labels; the figure below does not use it.
y_str_labels = [label.strip() for label in str_labels]

# Heatmap of patterns_stacked[0][3] (index 0 = "standard" run in the stacking
# order; index 3 presumably the head index -- confirm).
fig = px.imshow(
    patterns_stacked[0][3],
    x=str_labels,
    y=str_labels,
    color_continuous_scale='blues',
).update_layout(width=500, height=500)

# This figure is only displayed; the export call is left commented out.
# fig_to_json(fig, json_dir, 'attn_pattern_prev_token_head_posn_only')
fig

### Line plots

In [24]:
# Attention scores at indices [3, 8] of each stacked run (presumably head 3,
# query position 8 -- confirm). After the transpose, each run is one column,
# named in the stacking order (standard, position_only, token_only).
df = pd.DataFrame(
    scores_stacked[:, 3, 8, :].T,
    columns=['Total', 'Position-only', 'Token-only'],
)

# Plot only the first nine rows (indices 0-8) with a fixed y-range.
fig = (
    px.line(df.iloc[:9], markers=True)
    .update_yaxes(range=[-5, 21])
    .update_layout(yaxis_title='Attention score', xaxis_title='Position in Sequence')
)
fig_to_json(fig, json_dir, 'attn_scores_prev_token_head')
fig

In [25]:
# Attention pattern at indices [3, 8] of each stacked run (presumably head 3,
# query position 8 -- confirm). After the transpose, each run is one column,
# named in the stacking order (standard, position_only, token_only).
df = pd.DataFrame(
    patterns_stacked[:, 3, 8, :].T,
    columns=['Total', 'Position-only', 'Token-only'],
)

# Plot only the first nine rows (indices 0-8); the y-range is padded slightly
# beyond [0, 1].
fig = (
    px.line(df.iloc[:9], markers=True)
    .update_yaxes(range=[-0.1, 1.1])
    .update_layout(yaxis_title='Attention pattern', xaxis_title='Position in Sequence')
)
fig_to_json(fig, json_dir, 'attn_pattern_prev_token_head_lastpos')
fig

In [26]:
# Attention scores at indices [0, 8] of each stacked run (presumably head 0,
# query position 8 -- confirm). After the transpose, each run is one column,
# named in the stacking order (standard, position_only, token_only).
df = pd.DataFrame(
    scores_stacked[:, 0, 8, :].T,
    columns=['Total', 'Position-only', 'Token-only'],
)

# Plot only the first nine rows (indices 0-8) with a fixed y-range.
fig = (
    px.line(df.iloc[:9], markers=True)
    .update_yaxes(range=[-5, 20])
    .update_layout(yaxis_title='Attention scores', xaxis_title='Position in Sequence')
)
fig_to_json(fig, json_dir, 'attn_scores_lbt_head_lastpos')
fig

In [27]:
# Attention pattern at indices [0, 8] of each stacked run (presumably head 0,
# query position 8 -- confirm). After the transpose, each run is one column,
# named in the stacking order (standard, position_only, token_only).
df = pd.DataFrame(
    patterns_stacked[:, 0, 8, :].T,
    columns=['Total', 'Position-only', 'Token-only'],
)

# Plot only the first nine rows (indices 0-8); the y-range is padded slightly
# beyond [0, 1].
fig = (
    px.line(df.iloc[:9], markers=True)
    .update_yaxes(range=[-0.1, 1.1])
    .update_layout(yaxis_title='Attention pattern', xaxis_title='Position in Sequence')
)
fig_to_json(fig, json_dir, 'attn_pattern_lbt_head_lastpos')
fig

In [61]:
# Heatmap of patterns_stacked[0][0]: index 0 is the "standard" run in the
# stacking order; the second 0 is presumably the head index -- confirm.
str_labels = str_tokens
# Whitespace-stripped copy of the labels; the figure below does not use it.
y_str_labels = [label.strip() for label in str_labels]

fig = px.imshow(
    patterns_stacked[0][0],
    x=str_labels,
    y=str_labels,
    color_continuous_scale='blues',
).update_layout(width=500, height=500)

# Export the figure under json_dir, then display it as the cell output.
fig_to_json(fig, json_dir, 'attn_pattern_lbt_head')
fig

### Induction head "bug"

In [140]:
# Induction test with a BOS prefix: each row is BOS, A, B, <random...>, A.
model.cfg.custom_type = "standard"
batch, seq_len = 1000, 10

# Labels for the ten non-BOS positions of each row.
str_tokens = ['A', 'B'] + ['Random'] * 7 + ['A']

# (batch, 1) column filled with the tokenizer's BOS id.
prefix = torch.full((batch, 1), model.tokenizer.bos_token_id, dtype=torch.long)

# Random token ids; the last position is overwritten with a copy of the first,
# so token A appears at both ends of every row.
main_sequence = torch.randint(0, model.cfg.d_vocab, (batch, seq_len), dtype=torch.int64)
main_sequence[:, -1] = main_sequence[:, 0]

# Prepend BOS and move the batch to the model's device.
token_sequence = torch.cat((prefix, main_sequence), dim=1).to(device)
logits, cache = model.run_with_cache(token_sequence, remove_batch_dim=False)

# Score positions 1 and 2 of each row against the logits at the final position.
# (The exact meaning of the returned keys is defined in compute_lgts -- confirm.)
tokens_select = torch.tensor([1, 2])
lgts = compute_lgts(logits, token_sequence, tokens_select, ipos=-1)
attn_patterns1 = get_attn_patterns(cache)

# Display the 'p2mean' entry next to the first row of the batch.
(lgts['p2mean'], token_sequence[0])


(tensor(0.0551),
 tensor([    1, 12430, 34076, 39428, 32536, 24913, 37765, 38643,  7524, 40743,
         12430]))

In [141]:
# Induction test without a BOS prefix: each row is A, B, <random...>, A.
setattr(model.cfg, "custom_type", "standard")
batch = 1000
seq_len = 10

# Random token ids; the last position is overwritten with a copy of the first,
# so token A appears at both ends of every row.
main_sequence = torch.randint(0, model.cfg.d_vocab, (batch, seq_len), dtype=torch.int64)
main_sequence[:, -1] = main_sequence[:, 0]

# No prefix is prepended in this variant. The batch is still moved to the
# model's device, matching the prefixed run, so `logits` and `token_sequence`
# are on the same device when handed to compute_lgts.
token_sequence = main_sequence.to(device)
logits, cache = model.run_with_cache(token_sequence, remove_batch_dim=False)

# Score positions 1 and 2 of each row against the logits at the final position.
# (The exact meaning of the returned keys is defined in compute_lgts -- confirm.)
tokens_select = torch.tensor([1, 2])
lgts = compute_lgts(logits, token_sequence, tokens_select, ipos=-1)
attn_patterns2 = get_attn_patterns(cache)

# Display the 'p1mean' entry next to the first row of the batch. Without the
# prefix, token B sits at position 1 rather than position 2.
lgts['p1mean'], token_sequence[0]

(tensor(7.2567e-06),
 tensor([20576, 12765, 26213, 39886, 11151, 28482, 19522, 35901, 19857, 20576]))

In [145]:
# Compare the attention pattern at index [1][6, -1] (presumably layer 1,
# head 6, last query position -- confirm) between the two induction runs:
#   attn_patterns1 -- run with a BOS prefix, so the first [A] is at position 1
#   attn_patterns2 -- run without a prefix, so the first [A] is at position 0
fig = go.Figure()
str_tokens = ['[A]', '[B]', 'Random', 'Random', 'Random', 'Random', 'Random', 'Random', 'Random', '[A]']
str_tokens_plot = ['BOS'] + str_tokens

# .tolist() yields plain Python floats regardless of the tensor's device.
fig.add_trace(go.Scatter(y=attn_patterns1[1][6, -1].tolist(),
                         name='[A] at position 1', mode="lines"))
# The leading None shifts the un-prefixed run right by one so both traces share
# the x-axis that includes the BOS slot.
fig.add_trace(go.Scatter(y=[None] + attn_patterns2[1][6, -1].tolist(),
                         name='[A] at position 0', mode="lines"))

# One tick per label; the tick count follows the label list.
fig.update_layout(yaxis_title='Attention Pattern', xaxis_title='Token',
                  xaxis=dict(tickvals=np.arange(len(str_tokens_plot)), ticktext=str_tokens_plot))
fig_to_json(fig, json_dir, 'attn_pattern_prev_token_head_bug')
fig