In [24]:
import numpy as np
import torch
from alphatoe import plot, game, interpretability
from transformer_lens import HookedTransformer, HookedTransformerConfig
import json
import einops
import circuitsvis as cv
from plotly.subplots import make_subplots
import plotly.graph_objects as go

import matplotlib.pyplot as plt
from importlib import reload
from copy import copy
import pandas as pd
from showmethetypes import SMTT
from itertools import permutations

In [25]:
tt = SMTT()

In [26]:
model = interpretability.load_model(
    "../scripts/models/prob all 8 layer control-20230718-185339"
)

In [27]:
def get_head_attention(seq: torch.Tensor) -> torch.Tensor:
    def hook(module, input, output):
        module.attention = output.clone()

    try:
        handle = model.blocks[0].hook_attn_out.register_forward_hook(hook)
        _ = model(seq)
        attention = model.blocks[0].hook_attn_out.attention
    except Exception as e:
        handle.remove()
        raise e

    return attention

In [28]:
def modify_pre_mlp_residuals(seq: torch.Tensor, vec) -> torch.Tensor:
    def hook(module, input, output):
        return output + vec

    try:
        handle = model.blocks[0].hook_attn_out.register_forward_hook(hook)
        out = model(seq)
        handle.remove()
    except Exception as e:
        handle.remove()
        raise e

    return out

In [29]:
def get_activation_vector(move: int):
    moves = list(range(0, 9))
    moves.remove(move)
    z_seqs = [[10, move, snd] for snd in moves]
    other_seqs = [[[10, fst, snd] for fst in moves if fst != snd] for snd in moves]
    all_acts = []
    flatten_acts = []
    for i, z_seq in enumerate(z_seqs):
        z_act = get_head_attention((torch.tensor(z_seq)))[0, -1]
        z_acts = []
        for other_seq in other_seqs[i]:
            other_act = get_head_attention((torch.tensor(other_seq)))[0, -1]

            act_diff = other_act - z_act
            z_acts.append(act_diff)
            flatten_acts.append(act_diff)

        all_acts.append(z_acts)
    flat_acts = torch.stack(flatten_acts)
    # return torch.norm(flat_acts, dim=1)
    return flat_acts.mean(0)

In [30]:
list(permutations(range(3), 2))

[(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]

In [31]:
def get_positional_vector(l: int):
    moves = list(range(0, 9))
    seqs = torch.tensor(
        list(map(lambda lis: [10] + list(lis), permutations(moves, l - 1)))
    )
    tt(seqs)

    all_acts = get_head_attention(seqs)[:, -1]
    # return torch.norm(flat_acts, dim=1)
    tt(all_acts)
    return all_acts.mean(0)

In [32]:
tt = SMTT("torch")

In [33]:
tt = SMTT()

In [34]:
pos_vec = get_positional_vector(6)
tt(pos_vec)

Tensor (dtype: torch.int64)
    |  (device: cpu)
    |__dim_0 (15120)
    |__dim_1 (6)
Tensor (dtype: torch.float32)
    |  (device: cuda:0)
    |__dim_0 (15120)
    |__dim_1 (128)
Tensor (dtype: torch.float32)
    |  (device: cuda:0)
    |__dim_0 (128)


In [35]:
zero_act = get_activation_vector(0)
one_act = get_activation_vector(1)
two_act = get_activation_vector(2)
three_act = get_activation_vector(3)
four_act = get_activation_vector(4)
five_act = get_activation_vector(5)
six_act = get_activation_vector(6)
seven_act = get_activation_vector(7)
eight_act = get_activation_vector(8)

In [36]:
tt(zero_act)

Tensor (dtype: torch.float32)
    |  (device: cuda:0)
    |__dim_0 (128)


In [37]:
seq = [10]
vec = zero_act + one_act + two_act
out = model(torch.tensor(seq))
add_vec = modify_pre_mlp_residuals(torch.tensor(seq), vec)
subtract_vec = modify_pre_mlp_residuals(torch.tensor(seq), -vec)
print("Normal output", torch.nn.functional.softmax(out[0, -1], dim=0))
print("Subtracting 0, adding 1", add_vec[0, -1])
print("Adding 0, subtracting 1", subtract_vec[0, -1])

Normal output tensor([0.1111, 0.1112, 0.1113, 0.1109, 0.1113, 0.1112, 0.1110, 0.1107, 0.1113,
        0.0000], device='cuda:0', grad_fn=<SoftmaxBackward0>)
Subtracting 0, adding 1 tensor([ 38.1596,  38.6384,  40.7392,  -1.4404,  -2.3295,  -1.6009,  -3.1859,
         -1.2897,  -1.2053, -90.7778], device='cuda:0',
       grad_fn=<SelectBackward0>)
Adding 0, subtracting 1 tensor([-24.3372, -16.9424, -20.2947,  25.0872,  25.0004,  22.9896,  24.7819,
         23.7868,  22.5200, -65.9743], device='cuda:0',
       grad_fn=<SelectBackward0>)


In [38]:
pos_six = model.pos_embed(torch.tensor(range(6)))[-1, -1]

In [39]:
tt(pos_six)

Tensor (dtype: torch.float32)
    |  (device: cuda:0)
    |__dim_0 (128)


In [40]:
seq = [10]
vec = zero_act + one_act + two_act + pos_vec + pos_six
out = model(torch.tensor(seq))
add_vec = modify_pre_mlp_residuals(torch.tensor(seq), vec)
subtract_vec = modify_pre_mlp_residuals(torch.tensor(seq), -vec)
print("Normal output", torch.nn.functional.softmax(out[0, -1], dim=0))
print("Subtracting 0, adding 1", add_vec[0, -1])
print("Adding 0, subtracting 1", subtract_vec[0, -1])

Normal output tensor([0.1111, 0.1112, 0.1113, 0.1109, 0.1113, 0.1112, 0.1110, 0.1107, 0.1113,
        0.0000], device='cuda:0', grad_fn=<SoftmaxBackward0>)
Subtracting 0, adding 1 tensor([  47.3994,   46.4811,   48.6058,   -2.3001,   -2.9395,   -0.8019,
          -4.1709,   -2.7387,   -1.7201, -113.8362], device='cuda:0',
       grad_fn=<SelectBackward0>)
Adding 0, subtracting 1 tensor([-16.6672, -11.0680, -13.0832,  16.5233,  15.0852,  13.5191,  14.5390,
         12.9728,  12.5428, -34.2667], device='cuda:0',
       grad_fn=<SelectBackward0>)


In [41]:
seq = [10, 0, 10, 1, 10, 2]
vec = zero_act + one_act + two_act + pos_vec + pos_six
out = model(torch.tensor(seq))
add_vec = modify_pre_mlp_residuals(torch.tensor(seq), vec)
subtract_vec = modify_pre_mlp_residuals(torch.tensor(seq), -vec)
print("Normal output", torch.nn.functional.softmax(out[0, -1], dim=0))
print("Subtracting 0, adding 1", add_vec[0, -1])
print("Adding 0, subtracting 1", subtract_vec[0, -1])

tensor([[[  16.4406,   16.4415,   16.4427,   16.4385,   16.4421,   16.4415,
            16.4397,   16.4371,   16.4422, -117.7469],
         [ -73.5781,   13.6657,   13.6675,   13.6603,   13.6619,   13.6677,
            13.6644,   13.6623,   13.6662,  -41.8547],
         [ -28.0465,   19.5765,   14.9150,   16.1395,   14.9450,   16.9700,
            15.4190,   17.5065,   15.6091,  -87.9268],
         [ -30.0472,  -29.6935,   13.7019,   14.8522,   14.6196,   16.0425,
            13.7595,   14.6526,   15.0848,  -41.4792],
         [ -19.3795,  -10.7714,   14.2134,   15.0274,   14.2680,   15.9633,
            14.2001,   17.0041,   14.0319,  -64.8916],
         [ -33.0726,  -23.2774,  -34.1688,   10.0674,   10.0670,   10.6562,
             9.0185,   10.3589,   10.4061,   18.4578]]], device='cuda:0',
       grad_fn=<AddBackward0>)

[10, 0, 1, 2] - zero_act - one_act - 2_act
[10, 1, 2]
[10, 2, 1]
[10, 1, 2, 3]
[10, 2, 1, 3]
[10, 1, 3, 2]