In [1]:
import numpy as np
import torch
from alphatoe import plot, game, interpretability
from transformer_lens import HookedTransformer, HookedTransformerConfig
import json
import einops
import circuitsvis as cv
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pandas as pd
import torch.nn.functional as F

import matplotlib.pyplot as plt
from importlib import reload
from tqdm import tqdm

In [2]:
# Checkpoint for the model under study; the path is relative to the notebook's directory.
MODEL_PATH = "../scripts/models/prob all 8 layer control-20230718-185339"
model = interpretability.load_model(MODEL_PATH)


In [3]:
def get_head_attention(seq: torch.Tensor) -> torch.Tensor:
    """Run ``model`` on ``seq`` and return what block 0's ``hook_attn_out`` emitted.

    Despite the name, the returned tensor is the output of
    ``model.blocks[0].hook_attn_out``, not an attention pattern.

    A temporary forward hook is attached for the duration of one forward
    pass and is always detached afterwards, on success as well as on
    failure, so repeated calls do not accumulate hooks on the model.

    Args:
        seq: Input passed unchanged to ``model(seq)``.

    Returns:
        A clone of the tensor produced by ``hook_attn_out`` during this call.

    Raises:
        RuntimeError: If the forward pass finished without the hook firing.
        Any exception raised by ``model(seq)`` propagates unchanged.
    """
    captured = []

    def hook(module, input, output):
        # Clone so later in-place edits by the model cannot alter the capture.
        captured.append(output.clone())

    # Register outside the ``try`` so a failed registration surfaces as-is
    # instead of tripping over an unbound ``handle`` during cleanup.
    handle = model.blocks[0].hook_attn_out.register_forward_hook(hook)
    try:
        model(seq)
    finally:
        handle.remove()

    if not captured:
        raise RuntimeError("hook_attn_out did not fire during the forward pass")
    return captured[-1]

def get_value_attention(seq: torch.Tensor) -> torch.Tensor:
    """Run ``model`` on ``seq`` and return what block 0's ``attn.hook_v`` emitted.

    A temporary forward hook is attached to ``model.blocks[0].attn.hook_v``
    for one forward pass and is always detached afterwards, on success as
    well as on failure, so repeated calls do not accumulate hooks.

    Args:
        seq: Input passed unchanged to ``model(seq)``.

    Returns:
        A clone of the tensor produced by ``hook_v`` during this call.

    Raises:
        RuntimeError: If the forward pass finished without the hook firing.
        Any exception raised by ``model(seq)`` propagates unchanged.
    """
    captured = []

    def hook(module, input, output):
        # Clone so later in-place edits by the model cannot alter the capture.
        captured.append(output.clone())

    # Register outside the ``try`` so a failed registration surfaces as-is
    # instead of tripping over an unbound ``handle`` during cleanup.
    handle = model.blocks[0].attn.hook_v.register_forward_hook(hook)
    try:
        model(seq)
    finally:
        handle.remove()

    if not captured:
        raise RuntimeError("attn.hook_v did not fire during the forward pass")
    return captured[-1]



In [4]:
# Squares 0, 4, 8 are the top-left -> bottom-right diagonal win.
# Token 10 is interleaved between them (presumably a filler/placeholder token -- TODO confirm).
moves = [10, 0, 10, 4, 10, 8]
seq = torch.tensor(moves)

# hook_v output for batch 0, last position, index 2 on the next axis
# (presumably the head axis -- TODO confirm).
value_acts = get_value_attention(seq)
value_acts[0, -1, 2]

tensor([-0.2119,  0.0599, -0.4458,  0.6086, -0.2821,  0.1315, -0.0441,  0.1153,
        -0.1475,  0.7113,  0.3539, -0.2458,  0.3536,  0.3840, -0.4710,  0.6980],
       device='cuda:0', grad_fn=<SelectBackward0>)

In [5]:
# Positional embedding for the single sequence (batch dimension dropped).
pos = model.pos_embed(seq.unsqueeze(0))[0]
# Zero rows 0, 2 and 4 -- the positions holding token 10 in `seq`.
pos[[0, 2, 4]] = 0

In [6]:
# Index-2 slice of block 0's W_V (referred to as "head two" in this notebook).
head_two_wv = model.blocks[0].attn.W_V[2]

# Project the masked positional embedding through that value matrix.
pos_value = torch.matmul(pos, head_two_wv)
pos_value.shape

torch.Size([6, 16])

In [7]:
# Index-2 slice of block 0's W_O (referred to as "head two" in this notebook).
head_two_wo = model.blocks[0].attn.W_O[2]

# Map the value-space projection through the output matrix and keep
# only the row for the final sequence position.
pos_out = torch.matmul(pos_value, head_two_wo)
pos_out_last = pos_out[-1]

In [8]:
def get_head_attention(seq: torch.Tensor) -> torch.Tensor:
    """Run ``model`` on ``seq`` and return what block 0's ``hook_attn_out`` emitted.

    Despite the name, the returned tensor is the output of
    ``model.blocks[0].hook_attn_out``, not an attention pattern.

    A temporary forward hook is attached for one forward pass and is
    detached in a ``finally`` clause, so it is removed on both the success
    and the failure path.

    Args:
        seq: Input passed unchanged to ``model(seq)``.

    Returns:
        A clone of the tensor produced by ``hook_attn_out`` during this call.

    Raises:
        RuntimeError: If the forward pass finished without the hook firing.
        Any exception raised by hook registration or by ``model(seq)``
        propagates unchanged.
    """
    captured = []

    def hook(module, input, output):
        # Clone so later in-place edits by the model cannot alter the capture.
        captured.append(output.clone())

    # Register outside the ``try``: if registration itself fails there is no
    # handle to remove, and the original error must not be masked.
    handle = model.blocks[0].hook_attn_out.register_forward_hook(hook)
    try:
        model(seq)
    finally:
        handle.remove()

    if not captured:
        raise RuntimeError("hook_attn_out did not fire during the forward pass")
    return captured[-1]

def modify_pre_mlp_residuals(seq: torch.Tensor, vec) -> torch.Tensor:
    """Run ``model`` on ``seq`` with block 0's ``hook_attn_out`` output replaced by ``vec``.

    A temporary forward hook on ``model.blocks[0].hook_attn_out`` returns
    ``vec`` in place of the module's real output (a forward hook that
    returns a value overrides the output). The hook is detached in a
    ``finally`` clause, so it is removed on both the success and the
    failure path.

    Args:
        seq: Input passed unchanged to ``model(seq)``.
        vec: Replacement for the ``hook_attn_out`` output. It is handed to
            the rest of the model as-is; any broadcasting against the
            model's internal tensors is up to the model.

    Returns:
        Whatever ``model(seq)`` returns with the replacement in effect.

    Raises:
        Any exception raised by hook registration or by ``model(seq)``
        propagates unchanged.
    """
    def hook(module, input, output):
        return vec

    # Register outside the ``try``: if registration itself fails there is no
    # handle to remove, and the original error must not be masked.
    handle = model.blocks[0].hook_attn_out.register_forward_hook(hook)
    try:
        return model(seq)
    finally:
        handle.remove()

def get_activation_vector(move: int) -> torch.Tensor:
    """Mean last-position activation difference between other first moves and ``move``.

    For every second move ``snd`` in 0-8 other than ``move``:

    * the reference sequence is ``[10, move, snd]``;
    * the comparison sequences are ``[10, fst, snd]`` for every ``fst`` in
      0-8 that is neither ``move`` nor ``snd``.

    Each sequence is passed through ``get_head_attention`` and the entry at
    ``[0, -1]`` (first batch element, last position) is kept. The result is
    the mean over all ``comparison - reference`` differences.

    Args:
        move: The first move to characterise; must be an integer in 0-8.

    Returns:
        The mean difference tensor, with the same shape as a single
        ``get_head_attention(...)[0, -1]`` slice.

    Raises:
        ValueError: If ``move`` is not one of 0-8.
    """
    moves = list(range(9))
    moves.remove(move)  # raises ValueError for a move outside 0-8

    diffs = []
    for snd in moves:
        reference = get_head_attention(torch.tensor([10, move, snd]))[0, -1]
        for fst in moves:
            if fst == snd:
                continue
            other = get_head_attention(torch.tensor([10, fst, snd]))[0, -1]
            diffs.append(other - reference)

    return torch.stack(diffs).mean(0)

In [9]:
# One mean-difference vector per board square, computed in order 0..8.
(
    zero_act,
    one_act,
    two_act,
    three_act,
    four_act,
    five_act,
    six_act,
    seven_act,
    eight_act,
) = [get_activation_vector(square) for square in range(9)]

In [20]:
# Candidate replacements for block 0's attn-out, built from the positional
# term and the per-square vectors for the squares played in `seq` (0, 4, 8).
vec = pos_out_last - zero_act - four_act - eight_act
no_pos_vec = zero_act + four_act + eight_act

# Last-position model output under each replacement: the combined vector,
# the content-only vector (negated), and the positional term alone.
for patch in (vec, -no_pos_vec, pos_out_last):
    print(modify_pre_mlp_residuals(seq, patch)[0, -1])


tensor([-21.5328,  10.4674,   8.3215,  10.3605, -20.5882,   8.9510,   8.2554,
         10.0498, -25.0590,   4.7009], device='cuda:0',
       grad_fn=<SelectBackward0>)
tensor([-21.0829,  10.5153,   8.8922,  10.4121, -19.8378,   9.0687,   8.7273,
         10.9109, -23.8312,   1.3638], device='cuda:0',
       grad_fn=<SelectBackward0>)
tensor([ 1.3368,  2.4447, -0.5984,  1.6671,  0.1068,  2.8515,  0.1042,  0.7622,
        -5.2992, -3.0658], device='cuda:0', grad_fn=<SelectBackward0>)


In [11]:
def zero_content_embedding(seq: torch.Tensor) -> torch.Tensor:
    """Return block 0's ``hook_attn_out`` output with ``hook_embed``'s output zeroed.

    Two temporary forward hooks are attached for one forward pass:

    * one on ``model.hook_embed`` that replaces its output with zeros of
      the same shape and dtype;
    * one on ``model.blocks[0].hook_attn_out`` that records its output.

    Both hooks are detached afterwards, on success as well as on failure.

    Args:
        seq: Input passed unchanged to ``model(seq)``.

    Returns:
        A clone of the tensor produced by ``hook_attn_out`` during this call.

    Raises:
        RuntimeError: If the forward pass finished without the recording
            hook firing.
        Any exception raised by hook registration or by ``model(seq)``
        propagates unchanged.
    """
    captured = []

    def zero_embed_hook(module, input, output):
        return torch.zeros_like(output)

    def record_hook(module, input, output):
        # Clone so later in-place edits by the model cannot alter the capture.
        captured.append(output.clone())

    # Track handles as they are created so cleanup removes exactly the hooks
    # that were registered, even if the second registration fails.
    handles = []
    try:
        handles.append(model.hook_embed.register_forward_hook(zero_embed_hook))
        handles.append(
            model.blocks[0].hook_attn_out.register_forward_hook(record_hook)
        )
        model(seq)
    finally:
        for handle in handles:
            handle.remove()

    if not captured:
        raise RuntimeError("hook_attn_out did not fire during the forward pass")
    return captured[-1]


In [12]:
# Block-0 attn-out at the last position of `seq` when hook_embed's output is zeroed.
zero_content_attn_out = zero_content_embedding(seq)
zero_content = zero_content_attn_out[0, -1]

In [13]:
# NOTE(review): this subtracts the vectors for squares 0, 2, 8 while `seq`
# plays squares 0, 4, 8 -- confirm that using two_act (not four_act) is intended.
vec = zero_content - zero_act - two_act - eight_act
# Same content terms, with the zero-content term scaled by 10.
strong_vec = 10*zero_content - zero_act - two_act - eight_act

for patch in (vec, strong_vec):
    print(modify_pre_mlp_residuals(seq, patch)[0, -1])

tensor([-32.4057,  14.9805, -25.8678,  17.5184,  18.7365,  14.3305,  15.0063,
         11.6420, -30.4062,  -3.6964], device='cuda:0',
       grad_fn=<SelectBackward0>)
tensor([-50.1820,  18.9277, -14.2695,  58.7326,  70.8523,  17.0660,  36.7576,
        -13.9224, -51.8323, -87.3968], device='cuda:0',
       grad_fn=<SelectBackward0>)


In [15]:
# Reload the checkpoint to get a fresh model, then compute the unpatched
# last-position output for `seq` as the baseline for the runs above.
model = interpretability.load_model(
    "../scripts/models/prob all 8 layer control-20230718-185339"
)
# `seq` is still the tensor built from moves = [10, 0, 10, 4, 10, 8]
baseline_out = model(seq)
baseline_out[0, -1]

tensor([-32.7298,  10.3781,   8.3182,   9.1393, -27.6169,   9.5415,   7.9608,
          9.1062, -29.3370,  17.7020], device='cuda:0',
       grad_fn=<SelectBackward0>)