In [1]:
from transformer_tests import *

# Load the pretrained "attn-only-2l" checkpoint. `device` is not defined in this
# notebook; presumably it is provided by the star import above — TODO confirm.
model = CustomHookedTransformer.from_pretrained("attn-only-2l", device=device)

Loaded pretrained model attn-only-2l into HookedTransformer


## Sequences of the form ... [A][B] ... [A][C] ... [A] &rarr; ?


#### Test the batch size required for stable averaging

In [3]:
# Sweep the batch size (500, 1000, ..., 5000) and collect the compute_lgts
# result for each, so the stability of the batch-averaged statistics can be
# compared across batch sizes.
torch.manual_seed(200)
batch_list = np.arange(500, 5001, 500)
lgts_list = []

for batch in batch_list:
    abac_args = dict(seq_len=12, ab_pos=2, ac_pos=9, batch=batch, prefix=True)
    token_sequence, str_tokens = genseq('abac', model, **abac_args)
    logits, cache = model.run_with_cache(token_sequence, remove_batch_dim=False)

    # Positions immediately after ab_pos and ac_pos.
    tokens_select = torch.tensor(
        [abac_args['ab_pos'] + 1, abac_args['ac_pos'] + 1]
    )
    lgts = compute_lgts(logits, token_sequence, tokens_select, ipos=-1)
    lgts_list.append(lgts)

    # Lightweight progress indicator.
    print(batch)


500
1000
1500
2000
2500
3000
3500
4000
4500
5000


In [5]:
# Per-batch-size means as reported by compute_lgts.
probs1 = [entry['p1mean'] for entry in lgts_list]
probs2 = [entry['p2mean'] for entry in lgts_list]

# For each batch size: mean of the indicator (probs1 > probs2), i.e. the
# fraction of entries where the first probability beats the second.
p1_gtr_p2 = [
    (entry['probs1'] > entry['probs2']).type(torch.float64).mean()
    for entry in lgts_list
]

df = pd.DataFrame(p1_gtr_p2, index=batch_list)

In [6]:
# Fraction of entries with probs1 > probs2, plotted against batch size.
# The frame has a single unnamed column, so the axes are titled explicitly and
# the legend is hidden; this matches the other plotting cells in the notebook.
fig = px.line(df, markers=True)
fig.update_layout(xaxis_title='Batch size',
                  yaxis_title='Fraction with probs1 > probs2',
                  showlegend=False)
fig

A batch size of around 2000 looks good enough.

#### Calculate how the probabilities change as the position of the [A][C] tokens moves

In [117]:
# Sweep the `ac_pos` argument of genseq from 4 to 10 at a fixed batch of 2000.
# For every setting, keep both the compute_lgts result and the attention
# patterns extracted from the cache.
torch.manual_seed(200)
ac_pos_list = [4, 5, 6, 7, 8, 9, 10]
lgts_list_acpos = []
attn_patterns_list = []

for ac_pos in ac_pos_list:
    abac_args = dict(seq_len=12, ab_pos=2, ac_pos=ac_pos, batch=2000, prefix=True)
    token_sequence, str_tokens = genseq('abac', model, **abac_args)
    logits, cache = model.run_with_cache(token_sequence, remove_batch_dim=False)

    # Positions immediately after ab_pos and ac_pos.
    tokens_select = torch.tensor(
        [abac_args['ab_pos'] + 1, abac_args['ac_pos'] + 1]
    )
    lgts = compute_lgts(logits, token_sequence, tokens_select, ipos=-1)
    lgts_list_acpos.append(lgts)

    attn_patterns = get_attn_patterns(cache)
    attn_patterns_list.append(attn_patterns)


# Fraction of entries where probs1 exceeds probs2, one value per ac_pos.
p1_gtr_p2 = [
    (entry['probs1'] > entry['probs2']).type(torch.float64).mean()
    for entry in lgts_list_acpos
]

probs1 = [entry['p1mean'] for entry in lgts_list_acpos]
probs2 = [entry['p2mean'] for entry in lgts_list_acpos]

# One row per ac_pos; the two mean probabilities become the two columns.
df = pd.DataFrame(
    np.vstack((probs1, probs2)).T,
    index=ac_pos_list,
    columns=['p([B])', 'p([C])'],
)


In [15]:
# Directory that receives the exported figure JSON; later cells reuse it.
json_dir = '../blog/static'

# The frame is indexed by the swept ac_pos values; name that axis for the plot.
df.index.name = 'Position of second [A] token'

fig = px.line(df, markers=True)
fig.update_layout(yaxis_title='Probability')
fig_to_json(fig, json_dir, 'position_of_second_a_token')
fig.show()

In [118]:
# `cache` and `str_tokens` are not created in this cell: when the notebook runs
# top to bottom they are whatever the ac_pos sweep loop left behind, i.e. the
# values from its final iteration.
attn_patterns = get_attn_patterns(cache)
# Render the patterns stored under index 1 — presumably layer 1; TODO confirm.
fig = cv.attention.attention_patterns(str_tokens, attn_patterns[1])
fig

In [93]:
# Axis labels for the 13 positions of the heatmap. Repeated names are padded
# with whitespace so that every label is a distinct string (presumably so that
# plotly does not merge identical category labels — TODO confirm). The same
# padded labels are used on both axes; the previous stripped copy
# (`y_str_labels`) was never passed to the plot and has been dropped.
str_labels = ['<BOS>', 'Random', '[A]', '[B]', ' Random', '  Random',
              '   Random', '    Random', '     Random', '      Random', ' [A]', '[C]', '  [A]']

# File stem for the exported figure. This cell used to pass `name` without ever
# defining it, which raises NameError on a fresh kernel.
# NOTE(review): the stem below is a placeholder — confirm the name the blog expects.
name = 'induction_head_attention_pattern'

# Heatmap of entry [1][6] of the attention patterns (presumably layer 1,
# head 6 — TODO confirm).
fig = px.imshow(attn_patterns[1][6], x=str_labels, y=str_labels,
                color_continuous_scale='blues')

fig.update_layout(width=600, height=600)
fig_to_json(fig, json_dir, name)

In [105]:
# For every swept ac_pos, take the last row of entry [1][6] of the stored
# attention patterns (presumably layer 1, head 6 — TODO confirm).
lines = [pattern_set[1][6][-1] for pattern_set in attn_patterns_list]

# 13 position labels: four named slots, 'pos 4' .. 'pos 11', and the final [A].
x_str_labels = (
    ['<BOS>', 'Random', '[A]', '[B]']
    + [f'pos {idx}' for idx in range(4, 12)]
    + ['[A]  ']
)

# Rows start out as one per ac_pos; transpose so positions run down the index.
df = pd.DataFrame(np.vstack(lines), index=ac_pos_list, columns=x_str_labels).T
df = df.rename_axis(index='Position in Sequence',
                    columns='Position of second [A]')

fig = px.line(df, markers=True)
fig.update_layout(yaxis_title='Attention pattern')
fig_to_json(fig, json_dir, 'position_of_second_a_token_attention')
fig.show()


### What causes this recency bias in the induction head?

In [109]:

def cv_fig_to_html(fig, html_dir, name):
    """Write the string form of ``fig`` to ``<html_dir>/<name>.html``.

    Args:
        fig: Any object whose ``str()`` is the markup to save.
        html_dir: Existing directory to write into.
        name: File stem; ``.html`` is appended.

    The file is written as UTF-8 explicitly so that non-ASCII characters in the
    markup do not depend on the platform's default encoding.
    """
    with open(f"{html_dir}/{name}.html", "w", encoding="utf-8") as f:
        f.write(str(fig))


# Output directory for the exported HTML snippets.
html_dir = '../blog/layouts/shortcodes'


In [111]:
# Switch the model config's custom_type to "position_only" and inspect the
# attention patterns on a sentence that is repeated twice.
model.cfg.custom_type = "position_only"

input_string = 'Mary and John went to the store to get some milk Mary and John went to the store to get some milk'

# Hand-written display labels: BOS, then the sentence twice. Only the first
# 'Mary' has no leading space.
_sentence_tail = [' and', ' John', ' went', ' to', ' the',
                  ' store', ' to', ' get', ' some', ' milk']
str_tokens = ['<|BOS|>', 'Mary', *_sentence_tail, ' Mary', *_sentence_tail]

tokens = model.to_tokens(input_string, prepend_bos=True)
logits, cache = model.run_with_cache(tokens, remove_batch_dim=False)
attn_patterns = get_attn_patterns(cache)

# Visualise the patterns stored under index 1 and export them as HTML.
fig = cv.attention.attention_patterns(str_tokens, attn_patterns[1])
cv_fig_to_html(fig, html_dir, "induction-head-pos-dependence")
fig

In [None]:
# Random-token run with a batch of 1000. This cell was left with an incomplete
# `seq_len = ` assignment (a SyntaxError); 48 matches the 48 display labels
# built below and the value used by the following cell.
# NOTE(review): the next cell repeats this run with batch = 500 and overwrites
# every name set here — confirm whether this variant is still needed.
torch.manual_seed(129)
setattr(model.cfg, "custom_type", "None")
seq_len = 48
batch = 1000
tokens = generate_random_token_sequence(model, seq_len, batch, prefix_flag=False)
logits, cache = model.run_with_cache(tokens, remove_batch_dim=False)
# each token is different
# Display labels only: 12 Greek-letter names repeated 4 times (48 labels).
str_tokens = 4*[' alpha', ' beta', ' gamma', ' delta', ' epsilon', ' zeta', ' eta', ' theta', ' iota', ' kappa', ' lambda', ' mu']
attn_patterns_random = get_attn_patterns(cache)

In [2]:
# Random-token run: 500 sequences of length 48, with custom_type set to the
# string "None" on the model config.
torch.manual_seed(129)
model.cfg.custom_type = "None"

seq_len, batch = 48, 500
tokens = generate_random_token_sequence(model, seq_len, batch, prefix_flag=False)
logits, cache = model.run_with_cache(tokens, remove_batch_dim=False)

# each token is different
# Display labels only: twelve Greek-letter names, repeated four times.
_greek_names = [' alpha', ' beta', ' gamma', ' delta', ' epsilon', ' zeta',
                ' eta', ' theta', ' iota', ' kappa', ' lambda', ' mu']
str_tokens = 4 * _greek_names

attn_patterns_random = get_attn_patterns(cache)

In [3]:
# Quick look at the last row of entry [1][6] of the random-token attention
# patterns (presumably layer 1, head 6, final query position — TODO confirm).
px.line(attn_patterns_random[1][6][-1])

In [17]:
# Row 7 of entry [1][6], restricted to its first 8 values, exported for the blog.
fig = px.line(attn_patterns_random[1][6][7][:8], markers=True).update_layout(
    xaxis_title="Position in Sequence",
    yaxis_title="Attention score",
    showlegend=False,
)
fig_to_json(fig, json_dir, "random_attention_pattern_ind_head_7")
fig

In [10]:
# Last row of entry [1][6] over the full sequence, exported for the blog.
fig = px.line(attn_patterns_random[1][6][-1], markers=True).update_layout(
    xaxis_title="Position in Sequence",
    yaxis_title="Attention score",
    showlegend=False,
)
fig_to_json(fig, json_dir, "random_attention_pattern_ind_head_48")
fig

### Same but ablate L0H4

In [116]:
# Reload the pretrained "attn-only-2l" model, replacing the instance whose
# cfg.custom_type was modified by earlier cells.
model = CustomHookedTransformer.from_pretrained("attn-only-2l", device=device)

Loaded pretrained model attn-only-2l into HookedTransformer


In [272]:
# Disabled sketch of a sweep over ablated heads; as written, patterns_list is
# never created by this cell.
# patterns_list = []
# for ablated_head in [0, 1, 2, 3, 4, 5, 6, 7]:

model.cfg.custom_type = "token_only"

torch.manual_seed(129)
nbatch, seq_len = 500, 50

# Layer whose "pattern" activation receives the ablation hook.
ilayer = 0


def head_ablation_hook(attn_result, hook):
    """Zero the hooked activation, in place, for every head index except 1.

    The hook is attached to the "pattern" activation of layer `ilayer`, so
    despite its name `attn_result` is that pattern tensor, indexed here as
    [batch, head, :, :].
    """
    # Head 1 is deliberately skipped (left un-ablated).
    for head in (0, 2, 3, 4, 5, 6, 7):
        attn_result[:, head, :, :] = 0.0
    return attn_result


tokens = generate_random_token_sequence(model, seq_len, nbatch, prefix_flag=False)

# Buffer for the batch-averaged pattern of every layer and head.
n_positions = len(tokens[0])
pattern_store = torch.zeros(
    (model.cfg.n_layers, model.cfg.n_heads, n_positions, n_positions),
    device=model.cfg.device,
)


def pattern_hook(pattern, hook):
    """Average the hooked pattern over the batch and store it for hook.layer()."""
    pattern_store[hook.layer()] = einops.reduce(
        pattern, "batch head_index p1 p2 -> head_index p1 p2", "mean"
    )


def pattern_hook_names_filter(name):
    """Return True only for activation names that end in "pattern"."""
    return name.endswith("pattern")


# The ablation hook is listed before the recording hook.
# NOTE(review): the original comment said logits were not needed "for
# efficiency", yet return_type='logits' still computes and returns them.
logits = model.run_with_hooks(
    tokens,
    return_type='logits',
    fwd_hooks=[
        (utils.get_act_name("pattern", ilayer), head_ablation_hook),
        (pattern_hook_names_filter, pattern_hook),
    ],
)

# patterns_list.append(pattern_store[1, 6, -1])

In [277]:
# Last row of pattern_store[1, 6] after the ablated run, exported for the blog.
fig = px.line(pattern_store[1, 6, -1], markers=True).update_layout(
    xaxis_title="Position in Sequence",
    yaxis_title="Attention score",
    showlegend=False,
)
fig_to_json(fig, json_dir, "all_heads_ablated_except_l0h2")
fig

In [75]:
# One column per entry of patterns_list, one row per position.
# NOTE(review): within this notebook patterns_list only appears in
# commented-out code, so it must be populated before this cell can run.
df = pd.DataFrame(np.column_stack(patterns_list))
px.line(df, markers=True)

In [70]:
# Same line plot of pattern_store[1, 6, -1]; this cell only builds the figure
# and neither exports nor displays it.
fig = px.line(pattern_store[1, 6, -1], markers=True).update_layout(
    xaxis_title="Position in Sequence",
    yaxis_title="Attention score",
    showlegend=False,
)