In [1]:
%load_ext autoreload

In [1]:
from transformer_tests import *

model = CustomHookedTransformer.from_pretrained("attn-only-2l", device=device)

Loaded pretrained model attn-only-2l into HookedTransformer


In [27]:
torch.manual_seed(120)
seq_len = 3
nbatch = 4000

def head_ablation_hook(attn_result, hook):
    # attn_result[:, 3, :, :] = 0.0
    return attn_result

rep_tokens, str_tokens, pattern_store, logits, lgts = \
run_sequence(model, seq_len, head_ablation_hook, nbatch=nbatch)

In [28]:
lgts['mean_logit_diff'], lgts['p1mean'], lgts['p2mean']

(tensor(2.6467), tensor(0.0649), tensor(0.0132))

In [30]:
str_tokens

['BOS',
 ' alpha',
 ' epsilon',
 ' beta',
 ' gamma',
 ' epsilon',
 ' delta',
 ' alpha',
 ' epsilon',
 ' beta']

In [31]:
str_tokens_plot = ['BOS',
 ' [A]',
 ' [X]',
 ' [B]',
 ' [C]',
 ' [X]',
 ' [D]',
 ' [A]',
 ' [X]',
 ' [B]']

In [32]:
cv.attention.attention_patterns(tokens=str_tokens_plot, attention=pattern_store[1])

In [37]:
len(str_tokens_plot), len(pattern_store[1][6][-2])

(10, 10)

In [41]:
str_tokens_plot = ['BOS',
 '[A]',
 '[X]',
 '[B]',
 '[C] ',
 '[X] ',
 '[D] ',
 '[A]  ',
 '[X]  ',
 '[B]  ']

In [52]:
fig = px.line(pattern_store[1][6][-2][:-1])
fig.update_layout(xaxis_title="Input token", yaxis_title="Attention weight",
                  xaxis=dict(tickvals=np.arange(10), ticktext=str_tokens_plot[:-1]), showlegend=False)
fig_to_json(fig, json_dir, "triplet_attn_pattern_line2")
fig

### How does the model complete triplet sequences?

In [73]:
torch.manual_seed(129)
setattr(model.cfg, "custom_type", "None")
seq_len = 48
batch = 1000
tokens = generate_random_token_sequence(model, seq_len, batch, prefix_flag=False)
logits, cache = model.run_with_cache(tokens, remove_batch_dim=False)
# each token is different
str_tokens = 4*[' alpha', ' beta', ' gamma', ' delta', ' epsilon', ' zeta', ' eta', ' theta', ' iota', ' kappa', ' lambda', ' mu']
attn_patterns_random = get_attn_patterns(cache)

In [71]:
cv.attention.attention_patterns(tokens=str_tokens, attention=attn_patterns_random[1])


In [76]:
px.line(attn_patterns_random[1][6][-1])

In [56]:
setattr(model.cfg, "custom_type", "position_only")
input_string = 'Mary and John went to the store to get some milk'
str_tokens = ['<|BOS|>', 'Mary', ' and', ' John', ' went', ' to',
              ' the', ' store', ' to', ' get', ' some', ' milk']
tokens = model.to_tokens(input_string, prepend_bos=True)
logits, cache = model.run_with_cache(tokens, remove_batch_dim=False)
attn_patterns = get_attn_patterns(cache)
fig = cv.attention.attention_patterns(str_tokens, attn_patterns[0])
fig

### Testing longer sequences

In [67]:
seq_lens = [2, 3, 4, 5, 6, 7]

for seq_len in seq_lens:
    setattr(model.cfg, "custom_type", "None")

    torch.manual_seed(120)
    nbatch = 4000

    def head_ablation_hook(attn_result, hook):
        # attn_result[:, 3, :, :] = 0.0
        return attn_result

    rep_tokens, str_tokens, pattern_store, logits, lgts = \
    run_sequence(model, seq_len, head_ablation_hook, nbatch=nbatch)

    print(seq_len, lgts['mean_logit_diff'], lgts['p1mean'], lgts['p2mean'])

2 tensor(6.1266) tensor(0.0644) tensor(0.0005)
3 tensor(2.6467) tensor(0.0649) tensor(0.0132)
4 tensor(1.3655) tensor(0.0559) tensor(0.0292)
5 tensor(0.3211) tensor(0.0524) tensor(0.0441)
6 tensor(-0.1406) tensor(0.0473) tensor(0.0494)
7 tensor(-0.3511) tensor(0.0462) tensor(0.0557)


#### Try again with longer sequence so probs go to 0

In [77]:
def generate_tokens_arb_seq3(model, seq_len, batch):
    '''
    Generates a sequence of repeated random tokens
    '''
    prefix = (torch.ones(batch, 1) * model.tokenizer.bos_token_id).long() # tensor([[1]])
    first_sequence = torch.randint(0, model.cfg.d_vocab, (batch, seq_len), dtype=torch.int64)
    second_sequence = torch.randint(0, model.cfg.d_vocab, (batch, seq_len), dtype=torch.int64)
    third_sequence = torch.randint(0, model.cfg.d_vocab, (batch, 25), dtype=torch.int64)
    fourth_sequence = torch.randint(0, model.cfg.d_vocab, (batch, 50), dtype=torch.int64)
    fifth_sequence = torch.randint(0, model.cfg.d_vocab, (batch, 50), dtype=torch.int64)
    second_sequence[:, 1:-1] = first_sequence[:, 1:-1]
    # token_sequence = torch.cat([prefix, third_sequence,
    #                             second_sequence, first_sequence, 
    #                             second_sequence, first_sequence, 
    #                             second_sequence, first_sequence, 
    #                             second_sequence, first_sequence,
    #                             third_sequence, 
    #                             first_sequence], dim=-1).to(device)
    token_sequence = torch.cat([prefix,
                                fourth_sequence, # filler
                                first_sequence,
                                # fifth_sequence, # filler
                                second_sequence,
                                third_sequence, # filler
                                second_sequence], dim=-1).to(device)

    alphabet = [' epsilon', ' zeta', ' eta', ' theta',
                ' iota', ' kappa', ' lambda', ' mu',
                ' nu', ' xi', ' omicron', ' pi', ' rho',
                ' sigma', ' tau', ' upsilon']

    randoms = alphabet[:seq_len-2]
    seq1 = [' alpha'] + randoms + [' beta']
    seq2 = [' gamma'] + randoms + [' delta']
    str_tokens = ['BOS'] + seq1 + seq2 + seq1 + seq2 + seq1

    ipos = -2

    return token_sequence, str_tokens, ipos


In [80]:
seq_lens = [2, 3, 4, 5, 6, 7]
seq_lens = [5, 6, 7]
seq_lens = [7]

for seq_len in seq_lens:
    setattr(model.cfg, "custom_type", "None")

    torch.manual_seed(120)
    nbatch = 1000

    def head_ablation_hook(attn_result, hook):
        # attn_result[:, 4, :, :] = 0.0
        return attn_result

    rep_tokens, str_tokens, ipos = generate_tokens_arb_seq3(model, seq_len, nbatch)

    # Run with hooks
    ilayer = 0

    pattern_store = torch.zeros((model.cfg.n_layers,
                                model.cfg.n_heads, len(rep_tokens[0]), len(rep_tokens[0])),
                                device=model.cfg.device)

    def pattern_hook(pattern, hook):
        avg_pattern = einops.reduce(pattern, "batch head_index p1 p2 -> head_index p1 p2", "mean")
        pattern_store[hook.layer()] = avg_pattern

    # We make a boolean filter on activation names, that's true only on attention pattern names
    pattern_hook_names_filter = lambda name: name.endswith("pattern")


    logits = model.run_with_hooks(
        rep_tokens,
        return_type='logits', # For efficiency, we don't need to calculate the logits
        fwd_hooks=[(utils.get_act_name("pattern", ilayer), head_ablation_hook),
                (pattern_hook_names_filter, pattern_hook)]
    )

    lgts = {}
    output_tokens = rep_tokens.index_select(1, torch.tensor([seq_len+50,50+seq_len*2])).T
    output_logits = logits[:, ipos, :]
    output_probs = logits[:, ipos, :].softmax(-1)

    lgts['logits1'] = output_logits[np.arange(0, output_logits.shape[0]), output_tokens[0]]
    lgts['logits2'] = output_logits[np.arange(0, output_logits.shape[0]), output_tokens[1]]
    lgts['probs1'] = output_probs[np.arange(0, output_probs.shape[0]), output_tokens[0]]
    lgts['probs2'] = output_probs[np.arange(0, output_probs.shape[0]), output_tokens[1]]

    lgts['mean_logit_diff'] = (lgts['logits1'] - lgts['logits2']).mean()
    lgts['p1mean'] = lgts['probs1'].mean()
    lgts['p2mean'] = lgts['probs2'].mean()
    lgts['pmean'] = output_probs.mean()

    maxes = output_probs.argmax(-1)
    lgts['nmax0'] = (maxes == output_tokens[0]).type(torch.float64).mean()
    lgts['nmax16'] = (maxes == 16).type(torch.float64).mean()

    print(seq_len, lgts['mean_logit_diff'], lgts['p1mean'], lgts['p2mean'])


7 tensor(-0.7710) tensor(0.0414) tensor(0.0590)


In [None]:
6 tensor(0.4474) tensor(0.0477) tensor(0.0432)


In [8]:
5 tensor(-1.8811) tensor(0.0270) tensor(0.0749)
6 tensor(-1.2793) tensor(0.0375) tensor(0.0599)
7 tensor(-0.6719) tensor(0.0415) tensor(0.0587)

tensor(0.0002)

torch.Size([2, 8, 67, 67])

In [45]:
cv.attention.attention_patterns([str(x) for x in np.arange(pattern_store.shape[-1])], pattern_store[1])

In [81]:
px.line(pattern_store[1][6][-2])

In [None]:
0.348, 0.043

In [60]:
pattern_store[1][6][-2][6], pattern_store[1][6][-2][62]

(tensor(0.3078), tensor(0.4594))

In [54]:
pattern_store[1][6][-2][6], pattern_store[1][6][-2][12]

(tensor(0.3227), tensor(0.4325))

In [None]:
2 tensor(11.1258) tensor(0.4308) tensor(0.0006)
3 tensor(4.5519) tensor(0.2179) tensor(0.0262)
4 tensor(2.2751) tensor(0.1442) tensor(0.0536)

In [64]:
setattr(model.cfg, "custom_type", "None")

torch.manual_seed(120)
seq_len = 2
nbatch = 4000

def head_ablation_hook(attn_result, hook):
    # attn_result[:, 3, :, :] = 0.0
    return attn_result

rep_tokens, str_tokens, pattern_store, logits, lgts = \
run_sequence(model, seq_len, head_ablation_hook, nbatch=nbatch)


In [65]:
lgts['mean_logit_diff'], lgts['p1mean'], lgts['p2mean']

(tensor(6.1266), tensor(0.0644), tensor(0.0005))

In [63]:
cv.attention.attention_patterns(tokens=str_tokens, attention=pattern_store[1])

In [39]:
setattr(model.cfg, "custom_type", "position_only")
input_string = 'Mary and John went to the store to get some milk Mary and John went to the store to get some milk'
input_string = input_string * 2
tokens = model.to_tokens(input_string, prepend_bos=True)
logits, cache = model.run_with_cache(tokens, remove_batch_dim=False)
attn_patterns = get_attn_patterns(cache)
fig = cv.attention.attention_patterns([str(x) for x in np.arange(attn_patterns[0].shape[-1])], attn_patterns[1])
fig

# Patching with toy model

In [82]:
seq_len = 3
nbatch = 400
torch.manual_seed(110)
rep_tokens, str_tokens, ipos = generate_tokens_arb_seq2(model, seq_len, nbatch)


In [83]:
clean_tokens = rep_tokens
corrupted_tokens = torch.clone(rep_tokens)
corrupted_tokens[:, -seq_len] = corrupted_tokens[:, -2*seq_len]
answer_token_indices = torch.stack((clean_tokens[:, seq_len], clean_tokens[:, seq_len * 2])).T

In [84]:
def get_logit_diff(logits, answer_token_indices=answer_token_indices):
    if len(logits.shape)==3:
        # Get final logits only
        logits = logits[:, -2, :]
    correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))
    incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))
    return (correct_logits - incorrect_logits).mean()

clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices).item()
print(f"Clean logit diff: {clean_logit_diff:.4f}")

corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()
print(f"Corrupted logit diff: {corrupted_logit_diff:.4f}")

Clean logit diff: 4.9987
Corrupted logit diff: -4.3181


In [85]:
def ioi_metric(logits, answer_token_indices=answer_token_indices):
    return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (CLEAN_BASELINE  - CORRUPTED_BASELINE)


In [86]:
CLEAN_BASELINE = clean_logit_diff
CORRUPTED_BASELINE = corrupted_logit_diff
def ioi_metric(logits, answer_token_indices=answer_token_indices):
    return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (CLEAN_BASELINE  - CORRUPTED_BASELINE)

print(f"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}")
print(f"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}")

Clean Baseline is 1: 1.0000
Corrupted Baseline is 0: 0.0000


In [87]:
attn_head_out_all_pos_act_patch_results = patching.get_act_patch_attn_head_out_all_pos(model, corrupted_tokens, clean_cache, ioi_metric)


  0%|          | 0/16 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [88]:
attn_head_out_all_pos_act_patch_results

tensor([[ 3.8196e-01,  2.7722e-02,  1.7210e-03,  5.7585e-01, -2.2249e-03,
          1.9580e-03, -6.7935e-03, -9.2724e-04],
        [-4.6523e-04, -1.7757e-03,  1.7705e-03,  1.7271e-03,  1.0182e-03,
          6.7267e-04,  9.2941e-01,  6.8933e-02]])

imshow(attn_head_out_all_pos_act_patch_results)

In [96]:
imshow(attn_head_out_all_pos_act_patch_results,
       yaxis="Layer",
       xaxis="Head",
       title="attn_head_out Activation Patching (All Pos)")

In [90]:
cv.attention.attention_patterns(tokens=str_tokens, attention=clean_cache[''])

KeyError: 'hook_'

In [91]:
clean_tokens[0]

tensor([    1, 19056,  1655, 35935,  8452,  1655, 19575, 19056,  1655, 35935,
         8452,  1655, 19575, 19056,  1655, 35935,  8452,  1655, 19575, 19056,
         1655, 35935,  8452,  1655, 19575, 19056,  1655, 35935])

In [92]:
corrupted_tokens[0]

tensor([    1, 19056,  1655, 35935,  8452,  1655, 19575, 19056,  1655, 35935,
         8452,  1655, 19575, 19056,  1655, 35935,  8452,  1655, 19575, 19056,
         1655, 35935,  8452,  1655, 19575,  8452,  1655, 35935])

In [93]:
answer_token_indices[0]

tensor([35935, 19575])

In [94]:
def get_attn_patch(seq_len, nbatch):
    # seq_len = 4
    # nbatch = 100
    torch.manual_seed(110)
    rep_tokens, str_tokens, ipos = generate_tokens_arb_seq(model, seq_len, nbatch)

    clean_tokens = rep_tokens
    corrupted_tokens = torch.clone(rep_tokens)
    corrupted_tokens[:, seq_len*2 + 1] = corrupted_tokens[:, seq_len + 1]
    answer_token_indices = torch.stack((clean_tokens[:, seq_len], clean_tokens[:, seq_len * 2])).T
    
    def get_logit_diff(logits, answer_token_indices=answer_token_indices):
        if len(logits.shape)==3:
            # Get final logits only
            logits = logits[:, -2, :]
        correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))
        incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))
        return (correct_logits - incorrect_logits).mean()

    clean_logits, clean_cache = model.run_with_cache(clean_tokens)
    corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

    clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices).item()
    # print(f"Clean logit diff: {clean_logit_diff:.4f}")

    corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()
    # print(f"Corrupted logit diff: {corrupted_logit_diff:.4f}")
    
    CLEAN_BASELINE = clean_logit_diff
    CORRUPTED_BASELINE = corrupted_logit_diff
    def ioi_metric(logits, answer_token_indices=answer_token_indices):
        return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (CLEAN_BASELINE  - CORRUPTED_BASELINE)

    # print(f"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}")
    # print(f"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}")


    attn_head_out_all_pos_act_patch_results = patching.get_act_patch_attn_head_out_all_pos(model, corrupted_tokens, clean_cache, ioi_metric)
    
    return attn_head_out_all_pos_act_patch_results


In [95]:
attn_head_out_all_pos_act_patch_results = patching.get_act_patch_attn_head_out_all_pos(model, corrupted_tokens, clean_cache, ioi_metric)

  0%|          | 0/16 [00:00<?, ?it/s]

In [97]:
attn_patch_list = []
seq_lens = [2, 3, 4, 5, 6, 7, 8]
nbatchs = [200, 200, 400, 400, 500, 500, 500]

for seq_len, nbatch in zip(seq_lens, nbatchs):
    attn_patch = get_attn_patch(seq_len, nbatch)
    attn_patch_list.append(attn_patch)



  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

In [99]:
v.shape

torch.Size([14, 8])

In [100]:
attn_head_out_all_pos_act_patch_results.shape

torch.Size([2, 8])

In [101]:
import plotly.io as pio
pio.renderers

pio.renderers.default = "vscode"


In [130]:
v = torch.cat(attn_patch_list)[2:-4]

In [131]:
w = np.column_stack((v[::2], v[1::2]))

In [132]:
w.shape

(4, 16)

In [147]:
fig = px.imshow(w, zmin=0, zmax=1,
          color_continuous_midpoint=0.0,
          color_continuous_scale="Blues",
          labels={"x":"Attention Head", "y":"Sequence Length"}) #.show(renderer)
xlabels = ['L0H0', 'L0H1', 'L0H2', 'L0H3', 'L0H4', 'L0H5', 'L0H6', 'L0H7',
           'L1H0', 'L1H1', 'L1H2', 'L1H3', 'L1H4', 'L1H5', 'L1H6', 'L1H7']
ylabels = ['3', '4', '5', '6']
# ylabels = ['Seq. 2 L0', 'Seq. 2 L1',
#            'Seq. 3 L0', 'Seq. 3 L1',
#            'Seq. 4 L0', 'Seq. 4 L1',
#            'Seq. 5 L0', 'Seq. 5 L1',
#            'Seq. 6 L0', 'Seq. 6 L1']

fig.update_layout(xaxis=dict(tickvals=np.arange(16), ticktext=xlabels),
                  yaxis=dict(tickvals=np.arange(4), ticktext=ylabels))

fig_to_json(fig, json_dir, "attn_patch_heatmap")
fig

In [119]:
# fig = imshow(v[2:-2],
#        yaxis="Layer",
#        xaxis="Head",
#        title="attn_head_out Activation Patching (All Pos)")

fig = px.imshow(utils.to_numpy(v[2:-2]),
          color_continuous_midpoint=0.0,
          color_continuous_scale="RdBu",
          labels={"x":"Attention Head", "y":"Sequence # and Layer"}) #.show(renderer)

ylabels = ['Seq. 2 L0', 'Seq. 2 L1',
           'Seq. 3 L0', 'Seq. 3 L1',
           'Seq. 4 L0', 'Seq. 4 L1',
           'Seq. 5 L0', 'Seq. 5 L1',
           'Seq. 6 L0', 'Seq. 6 L1']

fig.update_layout(xaxis=dict(tickvals=np.arange(8), ticktext=[0, 1, 2, 3, 4, 5, 6, 7]),
                  yaxis=dict(tickvals=np.arange(10), ticktext=ylabels))

In [None]:
ALL_HEAD_LABELS = [f"L{i}H{j}" for i in range(model.cfg.n_layers) for j in range(model.cfg.n_heads)]

attn_head_out_act_patch_results = patching.get_act_patch_attn_head_out_by_pos(model, corrupted_tokens, clean_cache, ioi_metric)
attn_head_out_act_patch_results = einops.rearrange(attn_head_out_act_patch_results, "layer pos head -> (layer head) pos")


  0%|          | 0/208 [00:00<?, ?it/s]

In [None]:
attn_head_out_act_patch_results.shape

torch.Size([16, 13])

In [None]:
imshow(attn_head_out_act_patch_results)

In [None]:
imshow(attn_head_out_act_patch_results, 
    yaxis="Head Label", 
    xaxis="Pos", 
    x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
    y=ALL_HEAD_LABELS,
    title="attn_head_out Activation Patching By Pos")

In [None]:
imshow(attn_head_out_all_pos_act_patch_results,
       yaxis="Layer",
       xaxis="Head",
       title="attn_head_out Activation Patching (All Pos)")

In [None]:
prompts = ['When John and Mary went to the shops, John gave the bag to', 'When John and Mary went to the shops, Mary gave the bag to', 'When Tom and James went to the park, James gave the ball to', 'When Tom and James went to the park, Tom gave the ball to', 'When Dan and Sid went to the shops, Sid gave an apple to', 'When Dan and Sid went to the shops, Dan gave an apple to', 'After Martin and Amy went to the park, Amy gave a drink to', 'After Martin and Amy went to the park, Martin gave a drink to']
answers = [(' Mary', ' John'), (' John', ' Mary'), (' Tom', ' James'), (' James', ' Tom'), (' Dan', ' Sid'), (' Sid', ' Dan'), (' Martin', ' Amy'), (' Amy', ' Martin')]

clean_tokens = model.to_tokens(prompts)
# Swap each adjacent pair, with a hacky list comprehension
corrupted_tokens = clean_tokens[
    [(i+1 if i%2==0 else i-1) for i in range(len(clean_tokens)) ]
    ]
print("Clean string 0", model.to_string(clean_tokens[0]))
print("Corrupted string 0", model.to_string(corrupted_tokens[0]))

answer_token_indices = torch.tensor([[model.to_single_token(answers[i][j]) for j in range(2)] for i in range(len(answers))], device=model.cfg.device)
print("Answer token indices", answer_token_indices)

Clean string 0 <|BOS|>When John and Mary went to the shops, John gave the bag to
Corrupted string 0 <|BOS|>When John and Mary went to the shops, Mary gave the bag to
Answer token indices tensor([[ 6221,  2436],
        [ 2436,  6221],
        [ 6098,  5330],
        [ 5330,  6098],
        [ 5518, 23953],
        [23953,  5518],
        [ 8470, 21462],
        [21462,  8470]])


### Induction bug

In [406]:
torch.manual_seed(130)
seq_len = 4
nbatch = 1000

def generate_tokens_arb_seq_old2(model, seq_len, batch):
    '''
    Generates a sequence of repeated random tokens
    of the form: a, b, c, d, x, b, c, y, a, b, c, d, x, b, c, y

    Outputs are:
        rep_tokens: [batch, 1+2*seq_len]
    '''
    prefix = (torch.ones(batch, 1) * model.tokenizer.bos_token_id).long() # tensor([[1]])
    first_sequence = torch.randint(0, model.cfg.d_vocab, (batch, seq_len), dtype=torch.int64)
    second_sequence = torch.randint(0, model.cfg.d_vocab, (batch, seq_len), dtype=torch.int64)
    second_sequence[:, 1:-1] = first_sequence[:, 1:-1]
    # second_sequence[:, 1:-1] = first_sequence[:, 1:-1]
    # We want a sequence of the form: a, b, c, d, x, b, c, y, a, b, c, d, x, b, c, y
    token_sequence = torch.cat([prefix, first_sequence, second_sequence, first_sequence, second_sequence, first_sequence], dim=-1).to(device)
    return token_sequence

rep_tokens = generate_tokens_arb_seq_old2(model, seq_len, nbatch)
alphabet = [' epsilon', ' zeta', ' eta', ' theta', ' iota', ' kappa', ' lambda', ' mu', ' nu', ' xi', ' omicron', ' pi', ' rho', ' sigma', ' tau', ' upsilon']
randoms = alphabet[:seq_len-2]
seq1 = [' alpha'] + randoms + [' beta']
seq2 = [' gamma'] + randoms + [' delta']
str_tokens = ['BOS'] + seq1 + seq2 + seq1 + seq2 + seq1


In [407]:
rep_tokens[0][1:].reshape(5, 4)

tensor([[ 3121, 24847,  8911,   215],
        [24022, 24847,  8911, 10192],
        [ 3121, 24847,  8911,   215],
        [24022, 24847,  8911, 10192],
        [ 3121, 24847,  8911,   215]])

In [408]:
# indices = {5: 12, 4: 1}
# indices = {6: 3, 7: 2, 13: 3, 14: 5, 15: 2}
indices = {2: 7, 3:6, 10:7, 11: 6, 17: 7, 18: 1, 19: 6}

# indices = {6: 13, 5: 14}
# indices = {2: 1, 10: 1, 1: 6, 9: 6}
# indices = {}

for key, value in indices.items():
  str_tokens[key] = str_tokens[value]
  rep_tokens[:, key] = torch.clone(rep_tokens[:, value])

# torch.manual_seed(200)
# rep_tokens[:, 15] = torch.randint(0, model.cfg.d_vocab, (1, nbatch), dtype=torch.int64).to(device)
# rep_tokens[:, 16] = torch.randint(0, model.cfg.d_vocab, (1, nbatch), dtype=torch.int64).to(device)
rep_tokens[:, 17] = torch.randint(0, model.cfg.d_vocab, (1, nbatch), dtype=torch.int64).to(device)
rep_tokens[:, 18] = torch.randint(0, model.cfg.d_vocab, (1, nbatch), dtype=torch.int64).to(device)

# str_tokens[5] = str_tokens[12]
# str_tokens[4] = str_tokens[1]
# str_tokens[12] = str_tokens[2]
# str_tokens[13] = str_tokens[2]
# str_tokens[14] = str_tokens[1]

# rep_tokens[:, 5] = torch.clone(rep_tokens[:, 12])
# rep_tokens[:, 4] = torch.clone(rep_tokens[:, 1])
# rep_tokens[:, 12] = rep_tokens[:, 2]
# rep_tokens[:, 13] = rep_tokens[:, 2]
# rep_tokens[:, 14] = rep_tokens[:, 1]


In [409]:
rep_tokens[0][1:].reshape(5, 4)

tensor([[ 3121,  8911, 24847,   215],
        [24022, 24847,  8911, 10192],
        [ 3121,  8911, 24847,   215],
        [24022, 24847,  8911, 10192],
        [23359, 30140, 24847,   215]])

In [410]:
# Run with hooks (this is where we write to the `induction_score_store` tensor`)
ihead = 3
ilayer = 0

pattern_store = torch.zeros((model.cfg.n_layers, model.cfg.n_heads, len(rep_tokens[0]), len(rep_tokens[0])), device=model.cfg.device)

def pattern_hook(pattern, hook):
    avg_pattern = einops.reduce(pattern, "batch head_index p1 p2 -> head_index p1 p2", "mean")
    pattern_store[hook.layer()] = avg_pattern

def head_ablation_hook(attn_result, hook, head_index_to_ablate=ihead):
    # attn_result[:, head_index_to_ablate, :, :] = 0.0
    # attn_result[:, 3, :, :] = 0.0
    # attn_result[:, 0, :, :] = 0.0

    # attn_result[:, head_index_to_ablate+1:, :, :] = 0.0
    # mask = attn_result[:, head_index_to_ablate, :, :] < 0.3
    # attn_result[:, head_index_to_ablate, :, :][mask] = 0
    # attn_result[:, head_index_to_ablate, :, :] = 0.0

    return attn_result


# We make a boolean filter on activation names, that's true only on attention pattern names
pattern_hook_names_filter = lambda name: name.endswith("pattern")


logits = model.run_with_hooks(
    rep_tokens,
    return_type='logits', # For efficiency, we don't need to calculate the logits
    fwd_hooks=[(utils.get_act_name("pattern", ilayer), head_ablation_hook),
              (pattern_hook_names_filter, pattern_hook)]
)


In [411]:
  lgts = {}
  output_tokens = rep_tokens.index_select(1, torch.tensor([seq_len, seq_len*2])).T
  output_tokens = rep_tokens.index_select(1, torch.tensor([seq_len, 2])).T
  output_logits = logits[:, ipos, :]
  output_probs = logits[:, ipos, :].softmax(-1)

  lgts['logits1'] = output_logits[np.arange(0, output_logits.shape[0]), output_tokens[0]]
  lgts['logits2'] = output_logits[np.arange(0, output_logits.shape[0]), output_tokens[1]]
  lgts['probs1'] = output_probs[np.arange(0, output_probs.shape[0]), output_tokens[0]]
  lgts['probs2'] = output_probs[np.arange(0, output_probs.shape[0]), output_tokens[1]]

  lgts['mean_logit_diff'] = (lgts['logits1'] - lgts['logits2']).mean()
  lgts['p1mean'] = lgts['probs1'].mean()
  lgts['p2mean'] = lgts['probs2'].mean()
  lgts['pmean'] = output_probs.mean()

  maxes = output_probs.argmax(-1)
  lgts['nmax0'] = (maxes == output_tokens[0]).type(torch.float64).mean()
  lgts['nmax16'] = (maxes == 16).type(torch.float64).mean()

In [404]:
output_tokens[:, 0]

tensor([ 215, 8911])

In [405]:
  lgts['p1mean'],   lgts['p2mean'],   lgts['mean_logit_diff']

(tensor(0.0408), tensor(0.0255), tensor(0.7218))

In [339]:
  lgts['p1mean'],   lgts['p2mean'],   lgts['mean_logit_diff']

(tensor(0.1862), tensor(0.0116), tensor(5.6141))

In [303]:
fig = px.line(pattern_store[1][6][-2][:-1])
fig.update_layout(xaxis_title="Input token", yaxis_title="Attention pattern",
                  xaxis=dict(tickvals=np.arange(17), ticktext=str_tokens_plot[:-1]), showlegend=False)
fig

In [278]:
a = logits.index_select(2, rep_tokens[:, -2])
a = rep_tokens.index_select(1, torch.tensor([seq_len, seq_len*2]))
b = logits[:, -2, :]
b.shape

torch.Size([500, 48262])

In [277]:
a = logits.index_select(2, rep_tokens[:, -2])
a = rep_tokens.index_select(1, torch.tensor([seq_len, seq_len*2]))
b = logits[:, -2, :]

c1 = b[np.arange(0, b.shape[0]), a[:, 0]]
c2 = b[np.arange(0, b.shape[0]), a[:, 1]]

(c1 - c2).mean(), c1.mean(), c2.mean()

# j1, j2 = rep_tokens[0][seq_len], rep_tokens[0][seq_len*2]

# logits[0][-2][j1], logits[0][-2][j2], j1, j2

(tensor(0.7084), tensor(9.0774), tensor(8.3690))

In [227]:
cv.attention.attention_patterns(tokens=str_tokens, attention=pattern_store[1])

In [345]:
str_tokens_plot = ['BOS',
                   '[A]', r'[X2]', r'[X1]', '[B]',
                   '[C]', r'[X1]', r'[X2]', '[D]',
                   '[A]', r'[X2]', r'[X1]', '[B]',
                   '[C]', r'[X1]', r'[X2]', '[D]',
                   '[X2]', r'[A]', r'[X1]']

In [346]:
len(pattern_store[1][6][-2][:-1]), len(str_tokens_plot)

(20, 20)

In [349]:
fig = px.line(pattern_store[1][6][-2][:-1], markers=True)
fig.update_layout(xaxis_title="Input token", yaxis_title="Attention pattern",
                  xaxis=dict(tickvals=np.arange(20), ticktext=str_tokens_plot), showlegend=False)
fig_to_json(fig, json_dir, "attn_pattern_induction_bug2")
fig

In [239]:
rep_tokens[0][1:].reshape(4, 4)

tensor([[ 3121, 24847,  8911,   215],
        [20945,  8911, 24847, 24994],
        [ 3121, 24847,  8911,   215],
        [ 8911, 20945, 24847, 24994]])