In [2]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn.functional as F
from alphatoe import plot, game
from transformer_lens import HookedTransformer, HookedTransformerConfig
import json
import einops
import circuitsvis as cv
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from functools import partial
from copy import copy
import tqdm
import pandas as pd
import numpy as np

In [3]:
weights = torch.load("../scripts/models/prob all 8 layer control-20230718-185339.pt")
with open("../scripts/models/prob all 8 layer control-20230718-185339.json", "r") as f:
    args = json.load(f)

In [4]:
model_cfg = HookedTransformerConfig(
    n_layers=args["n_layers"],
    n_heads=args["n_heads"],
    d_model=args["d_model"],
    d_head=args["d_head"],
    d_mlp=args["d_mlp"],
    act_fn=args["act_fn"],
    normalization_type=args["normalization_type"],
    d_vocab=11,
    d_vocab_out=10,
    n_ctx=10,
    init_weights=True,
    device=args["device"],
    seed=args["seed"],
)

In [5]:
model = HookedTransformer(model_cfg)
model.cfg.use_attn_result = True
model.load_state_dict(weights)

<All keys matched successfully>

In [6]:
def ablate_output(module, input, output):
    zeros = [
        0,
        2,
        4,
        6,
    ]  # NOTE: ints here refer to head indices. All included indices will be zeroed
    print(output.shape)
    for i in zeros:
        output[0, :, i] = 0
    return output

In [21]:
handle.remove()

NameError: name 'handle' is not defined

In [None]:
ablate_output

In [20]:
def ablate_all_but_one_head(head, seq):
    def hook(module, input, output):
        result = torch.zeros_like(output)
        result[:, :, head, :] = output[:, :, head, :]
        return result

    model.cfg.use_attn_result = True
    try:
        handle = model.blocks[0].attn.hook_result.register_forward_hook(hook)
        logits = model(torch.tensor(seq))
        handle.remove()
    except Exception as e:
        handle.remove()
        raise e

    return logits

In [22]:
def ablate_one_head(head, seq):
    def hook(module, input, output):
        result = output.clone()
        result[:, :, head, :] = 0
        return result

    model.cfg.use_attn_result = True
    try:
        handle = model.blocks[0].attn.hook_result.register_forward_hook(hook)
        logits = model.run_with_cache(torch.tensor(seq))
        handle.remove()
    except Exception as e:
        handle.remove()
        raise e

    return logits

In [14]:
try:
    handle.remove()
except:
    pass
handle = model.blocks[0].attn.hook_result.register_forward_hook(ablate_all)

In [14]:
seq = [10, 1, 2, 3, 4]
with torch.no_grad():
    logits, cache = model.run_with_cache(torch.tensor(seq))

In [11]:
transformer_lens
model.blocks[0].mlp.W_in

Parameter containing:
tensor([[ 0.0096,  0.0409, -0.0163,  ...,  0.0177,  0.1423, -0.0354],
        [-0.0476, -0.0478,  0.0218,  ..., -0.0857,  0.0150,  0.1479],
        [ 0.0601,  0.1086, -0.0511,  ..., -0.0158, -0.0597, -0.0972],
        ...,
        [ 0.0269, -0.0254, -0.0021,  ..., -0.0670,  0.0562,  0.1216],
        [-0.1848,  0.1034,  0.0765,  ..., -0.0215, -0.1346, -0.0804],
        [-0.2362, -0.0332, -0.1650,  ...,  0.0590, -0.0217,  0.0781]],
       device='cuda:0', requires_grad=True)

In [19]:
neuron_activations = copy(cache["post", 0][0])

In [32]:
logits, cache = ablate_one_head(2, seq)

In [33]:
ablation_activations = copy(cache["post", 0][0])

In [35]:
plot.lines(neuron_activations)

In [36]:
plot.lines(ablation_activations)

In [37]:
plot.lines(neuron_activations - ablation_activations)

In [40]:
vals, indices = torch.sort(neuron_activations - ablation_activations)

In [41]:
vals

tensor([[-0.1606, -0.1243, -0.1105,  ...,  0.1659,  0.1688,  0.2383],
        [-0.2793, -0.2397, -0.2168,  ...,  0.5147,  0.5775,  0.5860],
        [-0.2193, -0.2094, -0.1954,  ...,  0.4017,  0.4086,  0.4488],
        [-0.2083, -0.1866, -0.1684,  ...,  0.3301,  0.3438,  0.3690],
        [-0.1395, -0.1324, -0.1318,  ...,  0.2541,  0.2771,  0.3005]],
       device='cuda:0')

In [44]:
plot.lines(vals)

In [66]:
effect_sizes = {}
for token in range(vals.shape[0]):
    for index, effect in zip(indices[token], vals[token]):
        effect_sizes[index.item()] = effect_sizes.setdefault(index.item(), 0.0) + abs(
            effect.item()
        )

In [75]:
effects = [(effect_sizes[index], index) for index in effect_sizes.keys()]
effects.sort()
effects = [(index, effect) for (effect, index) in effects]

In [76]:
print(effects)

[(5, 0.0), (92, 0.0), (114, 0.0), (123, 0.0), (150, 0.0), (152, 0.0), (160, 0.0), (165, 0.0), (244, 0.0), (247, 0.0), (314, 0.0), (332, 0.0), (380, 0.0), (438, 0.0), (443, 0.0), (433, 0.002846095710992813), (151, 0.0029787421226501465), (254, 0.0045240819454193115), (320, 0.0065285563468933105), (87, 0.017875888384878635), (14, 0.021632302552461624), (166, 0.02217673510313034), (102, 0.022608399391174316), (354, 0.02576705813407898), (451, 0.03469875454902649), (89, 0.0356471985578537), (61, 0.03628448024392128), (250, 0.04294384643435478), (489, 0.04538116231560707), (352, 0.04665052890777588), (297, 0.046697378158569336), (200, 0.04777367413043976), (300, 0.048993855714797974), (365, 0.05231852829456329), (120, 0.05418532341718674), (355, 0.05762887001037598), (421, 0.05819614231586456), (427, 0.061831504106521606), (154, 0.06213250756263733), (431, 0.06426942348480225), (346, 0.06640011072158813), (25, 0.068764328956604), (344, 0.06900292634963989), (187, 0.07100927829742432), (182,

In [77]:
plot.line(effects)

Which neurons strongly activate for game over states?

In [5]:
game_overs = torch.tensor(
    [
        [10, 0, 3, 1, 4, 2],
        [10, 3, 0, 4, 1, 5],
        [10, 6, 0, 7, 1, 8],
        [10, 0, 1, 3, 4, 6],
        [10, 1, 0, 4, 3, 7],
        [10, 2, 0, 5, 3, 8],
        [10, 0, 1, 4, 2, 8],
        [10, 2, 1, 4, 0, 6],
        [10, 2, 1, 4, 3, 6],
    ]
)
game_goings = torch.tensor(
    [
        [10, 0, 3, 1, 4],
        [10, 3, 0, 4, 1],
        [10, 6, 0, 7, 1],
        [10, 0, 1, 3, 4],
        [10, 1, 0, 4, 3],
        [10, 2, 0, 5, 3],
        [10, 0, 1, 4, 2],
        [10, 2, 1, 4, 0],
    ]
)

In [156]:
for board in game_overs:
    bd = game.play_game(list(board) + [9])
    bd.get_winner()
    print("--------------------")

| X | X | X |
| O | O |   |
|   |   |   |
--------------------
| O | O |   |
| X | X | X |
|   |   |   |
--------------------
| O | O |   |
|   |   |   |
| X | X | X |
--------------------
| X | O |   |
| X | O |   |
| X |   |   |
--------------------
| O | X |   |
| O | X |   |
|   | X |   |
--------------------
| O |   | X |
| O |   | X |
|   |   | X |
--------------------
| X | O | O |
|   | X |   |
|   |   | X |
--------------------
| O | O | X |
|   | X |   |
| X |   |   |
--------------------


In [6]:
with torch.no_grad():
    game_over_logits_cache = [
        model.run_with_cache(torch.tensor(seq)) for seq in tqdm.tqdm(game_overs)
    ]
game_over_logits, game_over_cache = map(list, zip(*game_over_logits_cache))

  0%|          | 0/9 [00:00<?, ?it/s]

  model.run_with_cache(torch.tensor(seq)) for seq in tqdm.tqdm(game_overs)
100%|██████████| 9/9 [00:00<00:00, 22.47it/s]


In [7]:
game_over_activations = torch.stack([cache["post", 0][0] for cache in game_over_cache])
game_over_acts, game_over_indices = torch.sort(game_over_activations)

In [25]:
activation_differences, difference_indices = torch.sort(
    game_over_activations[:, -1, ...] - game_over_activations[:, -2, ...]
)

In [26]:
over_acts = game_over_acts[:, -1]
over_indices = game_over_indices[:, -1]
pre_over_acts = game_over_acts[:, -2]
pre_over_indices = game_over_indices[:, -2]

In [27]:
print(over_acts.shape)
print(pre_over_acts.shape)
print(activation_differences.shape)

torch.Size([1, 512])
torch.Size([1, 512])
torch.Size([1, 512])


In [18]:
game_over_logits[0].shape

torch.Size([1, 6, 10])

In [8]:
unembedded_logits = game_over_logits[0][0] @ model.W_U.T

In [9]:
label = torch.tensor([0.0] * 9 + [1.0]).to("cuda")
loss_fn = F.cross_entropy
og_loss = loss_fn(game_over_logits[0][0, -1, :], label)

In [10]:
print(og_loss)

tensor(0.0001, device='cuda:0')


In [11]:
def ablate_one_neuron(neuron, seq):
    def hook(module, input, output):
        result = output.clone()
        result[:, :, neuron] = 0
        return result

    model.cfg.use_attn_result = True
    try:
        handle = model.blocks[0].mlp.hook_post.register_forward_hook(hook)
        logits = model.run_with_cache(torch.tensor(seq))
        handle.remove()
    except Exception as e:
        handle.remove()
        raise e

    return logits

In [12]:
def neuron_ablated_logits_losses(data):
    modified_logits = []
    modified_losses = []
    for i in range(512):
        logits = ablate_one_neuron(i, data)[0]
        neuron_losses = []
        for i in range(game_overs.shape[0]):
            loss = loss_fn(logits[i, -1, :], label)
            neuron_losses.append(loss)
        modified_losses.append(torch.tensor(neuron_losses).to("cpu").detach().numpy())
        modified_logits.append(logits.to("cpu").detach().numpy())
    modified_losses = torch.tensor(np.array(modified_losses))
    return modified_logits, modified_losses

In [13]:
modified_logits, modified_losses = neuron_ablated_logits_losses(game_overs)

  logits = model.run_with_cache(torch.tensor(seq))


In [32]:
plot.lines(modified_losses.T, log_y=False)

In [15]:
modified_losses.T[4, 325:328]

tensor([3.3379e-06, -0.0000e+00, 4.1842e-05])

In [16]:
vals, indx = torch.sort(modified_losses.T, descending=True)

In [17]:
top_vals = vals[:, :7]
top_indx = indx[:, :7]

In [18]:
idx_count = {}
for i in range(top_indx.shape[0]):
    for j in range(top_indx.shape[1]):
        key = top_indx[i, j].item()
        idx_count.setdefault(key, [])
        idx_count[key] += [(i, top_vals[i, j].item())]

In [19]:
for key in idx_count:
    if len(idx_count[key]) > 1:
        print("%3s - %d: " % (str(key), len(idx_count[key])))
        for game, val in idx_count[key]:
            print("    %d %.3f" % (game, val))

195 - 7: 
    0 0.386
    1 0.940
    4 0.004
    5 0.007
    6 0.179
    7 0.049
    8 0.038
117 - 5: 
    0 0.184
    2 1.736
    6 0.071
    7 0.033
    8 0.006
342 - 7: 
    0 0.034
    1 0.397
    3 0.088
    4 0.011
    6 0.021
    7 0.059
    8 0.031
267 - 2: 
    0 0.029
    1 0.124
239 - 7: 
    1 0.028
    2 0.006
    4 0.021
    5 0.015
    6 0.053
    7 0.008
    8 0.006
165 - 5: 
    1 0.017
    2 0.024
    3 0.027
    7 0.015
    8 0.012
219 - 2: 
    2 0.018
    7 0.008
  2 - 2: 
    3 0.037
    4 0.004
492 - 5: 
    3 0.011
    5 0.010
    6 0.019
    7 0.017
    8 0.012
 71 - 3: 
    3 0.011
    4 0.001
    8 0.006
447 - 3: 
    4 0.002
    5 0.001
    6 0.027


Getting logit ablation effects on MLP output.
Which neurons are most effected by a head ablation?

In [243]:
# get normal activations with no ablation
# ablate each head and get activations
# sort activations by neuron index within each ablated dataset
# Which neurons are affected most by which head
# see if they line up with our previous data

In [20]:
logits, cache = model.run_with_cache(game_overs)
original_acts = cache["post", 0][:, -1, :]

In [23]:
ablated_logits = []
ablated_activations = []
for i in range(8):
    logits, cache = ablate_one_head(i, game_overs)
    ablated_logits.append(logits)
    ablated_activations.append(cache["post", 0][:, -1])
ablated_logits = torch.stack(ablated_logits)
ablated_activations = torch.stack(ablated_activations)


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



In [24]:
ablated_activations.shape

torch.Size([8, 9, 512])

In [25]:
ablated_diffs = ablated_activations - original_acts

In [26]:
ablated_vals, ablated_indices = torch.sort(ablated_diffs, descending=True)

In [30]:
ablated_vals.shape

torch.Size([8, 9, 512])

In [38]:
# steps till end-state
# who won
# win condition
# What conditions are rotations of others
# All rotations of a game
# Flips - middle row and middle column
# Index in the train-test set
# Whether or not it was trained on or tested on

In [6]:
all_games = game.generate_all_games([game.Board()])

In [76]:
game.Board().win_condition

In [18]:
columns = [
    "moves played",
    "steps till end state",
    "winner",
    "win condition",
    "game rotations",
    "horizontal flip",
    "vertical flip",
    "index in training",
    "train or test",
]
df = pd.DataFrame(columns=columns)

In [19]:
df["moves played"] = [game.moves_played for game in all_games]

In [20]:
df["winner"] = [game.get_winner() for game in all_games]

In [21]:
df["steps till end state"] = [len(game.moves_played) for game in all_games]

In [22]:
df["win condition"] = [game.win_condition for game in all_games]

In [23]:
df["game rotations"] = [get_rotate_game_state_indices(game) for game in all_games]

In [None]:
df["horizontal flip"]

In [17]:
df.head()

Unnamed: 0,moves played,steps till end state,winner,win condition,game rotations,game flips,index in training,train or test
0,"[0, 1, 3, 2, 6]",5,X,left column,"[399, 1439, 1040]",,,
1,"[0, 1, 3, 4, 6]",5,X,left column,"[396, 1438, 1043]",,,
2,"[0, 1, 3, 5, 6]",5,X,left column,"[398, 1437, 1041]",,,
3,"[0, 1, 3, 7, 6]",5,X,left column,"[395, 1436, 1044]",,,
4,"[0, 1, 3, 8, 6]",5,X,left column,"[397, 1435, 1042]",,,


In [24]:
index_lookup = {tuple(mp): i for i, mp in zip(df.index, df.get("moves played"))}
# dict[tuple(moves_played), int]
rotation_ref = {
    0: 2,
    1: 5,
    2: 8,
    3: 1,
    4: 4,
    5: 7,
    6: 0,
    7: 3,
    8: 6,
}


def get_rotate_game_state_indices(game) -> list[int]:
    # rotates by 90 more deg clockwise each rot
    rot1 = [rotation_ref[move] for move in game.moves_played]
    rot2 = [rotation_ref[move] for move in rot1]
    rot3 = [rotation_ref[move] for move in rot2]
    return [index_lookup[tuple(rot)] for rot in [rot1, rot2, rot3]]

In [None]:
horizontal_flip_ref = {
    0: 6,
    1: 7,
    2: 8,
    3: 3,
    4: 4,
    5: 5,
    6: 0,
    7: 1,
    8: 2,
}
vertical_flip_ref = {
    0: 2,
    1: 1,
    2: 0,
    3: 5,
    4: 4,
    5: 3,
    6: 8,
    7: 7,
    8: 6,
}

In [81]:
for game in all_games:
    try:
        print(game.win_condition)
    except Exception as e:
        print(game.moves_played, e)

left column
left column
left column
left column
left column
top left -> bottom right
top left -> bottom right
top left -> bottom right
top left -> bottom right
top left -> bottom right
left column
left column
left column
left column
left column
top left -> bottom right
top left -> bottom right
top left -> bottom right
top left -> bottom right
top left -> bottom right
left column
left column
left column
left column
left column
top left -> bottom right
top left -> bottom right
top left -> bottom right
top left -> bottom right
top left -> bottom right
left column
left column
left column
left column
left column
top left -> bottom right
top left -> bottom right
top left -> bottom right
top left -> bottom right
top left -> bottom right
top row
top row
top row
top row
top row
top row
top row
top row
top row
top row
top left -> bottom right
top left -> bottom right
top left -> bottom right
top left -> bottom right
top left -> bottom right
top left -> bottom right
top left -> bottom right
top l

In [72]:
g = all_games[0]
g.lose_condition = None

In [43]:
print(df["moves_played"])

0                     [0, 1, 3, 2, 6]
1                     [0, 1, 3, 4, 6]
2                     [0, 1, 3, 5, 6]
3                     [0, 1, 3, 7, 6]
4                     [0, 1, 3, 8, 6]
                     ...             
255163    [8, 7, 6, 5, 4, 2, 1, 3, 0]
255164    [8, 7, 6, 5, 4, 2, 3, 0, 1]
255165    [8, 7, 6, 5, 4, 2, 3, 1, 0]
255166    [8, 7, 6, 5, 4, 3, 1, 0, 2]
255167    [8, 7, 6, 5, 4, 3, 1, 2, 0]
Name: moves_played, Length: 255168, dtype: object
