In [1]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn.functional as F
from alphatoe import plot, game, data
from transformer_lens import HookedTransformer, HookedTransformerConfig
import json
import einops
import circuitsvis as cv
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from functools import partial
from copy import copy
import tqdm
import pandas as pd
import numpy as np

In [3]:
weights = torch.load("../scripts/models/prob all 8 layer control-20230718-185339.pt")
with open("../scripts/models/prob all 8 layer control-20230718-185339.json", "r") as f:
    args = json.load(f)

In [4]:
model_cfg = HookedTransformerConfig(
    n_layers=args["n_layers"],
    n_heads=args["n_heads"],
    d_model=args["d_model"],
    d_head=args["d_head"],
    d_mlp=args["d_mlp"],
    act_fn=args["act_fn"],
    normalization_type=args["normalization_type"],
    d_vocab=11,
    d_vocab_out=10,
    n_ctx=10,
    init_weights=True,
    device=args["device"],
    seed=args["seed"],
)

In [5]:
model = HookedTransformer(model_cfg)
model.cfg.use_attn_result = True
model.load_state_dict(weights)

<All keys matched successfully>

In [15]:
def ablate_output(module, input, output):
    zeros = [
        0,
        2,
        4,
        6,
    ]  # NOTE: ints here refer to head indices. All included indices will be zeroed
    print(output.shape)
    for i in zeros:
        output[0, :, i] = 0
    return output

In [16]:
handle.remove()

NameError: name 'handle' is not defined

In [None]:
ablate_output

In [None]:
def ablate_all_but_one_head(head, seq):
    def hook(module, input, output):
        result = torch.zeros_like(output)
        result[:, :, head, :] = output[:, :, head, :]
        return result

    model.cfg.use_attn_result = True
    try:
        handle = model.blocks[0].attn.hook_result.register_forward_hook(hook)
        logits = model(torch.tensor(seq))
        handle.remove()
    except Exception as e:
        handle.remove()
        raise e

    return logits

In [None]:
def ablate_one_head(head, seq):
    def hook(module, input, output):
        result = output.clone()
        result[:, :, head, :] = 0
        return result

    model.cfg.use_attn_result = True
    try:
        handle = model.blocks[0].attn.hook_result.register_forward_hook(hook)
        logits = model.run_with_cache(torch.tensor(seq))
        handle.remove()
    except Exception as e:
        handle.remove()
        raise e

    return logits

In [None]:
try:
    handle.remove()
except:
    pass
handle = model.blocks[0].attn.hook_result.register_forward_hook(ablate_all)

In [None]:
seq = [10, 1, 2, 3, 4]
with torch.no_grad():
    logits, cache = model.run_with_cache(torch.tensor(seq))

In [None]:
transformer_lens
model.blocks[0].mlp.W_in

In [None]:
neuron_activations = copy(cache["post", 0][0])

In [None]:
logits, cache = ablate_one_head(2, seq)

In [None]:
ablation_activations = copy(cache["post", 0][0])

In [None]:
plot.lines(neuron_activations)

In [None]:
plot.lines(ablation_activations)

In [None]:
plot.lines(neuron_activations - ablation_activations)

In [None]:
vals, indices = torch.sort(neuron_activations - ablation_activations)

In [None]:
vals

In [None]:
plot.lines(vals)

In [None]:
effect_sizes = {}
for token in range(vals.shape[0]):
    for index, effect in zip(indices[token], vals[token]):
        effect_sizes[index.item()] = effect_sizes.setdefault(index.item(), 0.0) + abs(
            effect.item()
        )

In [None]:
effects = [(effect_sizes[index], index) for index in effect_sizes.keys()]
effects.sort()
effects = [(index, effect) for (effect, index) in effects]

In [None]:
print(effects)

In [None]:
plot.line(effects)

Which neurons strongly activate for game over states?

In [61]:
game_overs = torch.tensor(
    [
        [10, 0, 3, 1, 4, 2],
        [10, 3, 0, 4, 1, 5],
        [10, 6, 0, 7, 1, 8],
        [10, 0, 1, 3, 4, 6],
        [10, 1, 0, 4, 3, 7],
        [10, 2, 0, 5, 3, 8],
        [10, 0, 1, 4, 2, 8],
        [10, 2, 1, 4, 0, 6],
        [10, 2, 1, 4, 3, 6],
    ]
)
game_goings = torch.tensor(
    [
        [10, 0, 3, 1, 4],
        [10, 3, 0, 4, 1],
        [10, 6, 0, 7, 1],
        [10, 0, 1, 3, 4],
        [10, 1, 0, 4, 3],
        [10, 2, 0, 5, 3],
        [10, 0, 1, 4, 2],
        [10, 2, 1, 4, 0],
    ]
)

In [None]:
for board in game_overs:
    bd = game.play_game(list(board) + [9])
    bd.get_winner()
    print("--------------------")

In [63]:
with torch.no_grad():
    game_over_logits_cache = [
        model.run_with_cache(torch.tensor(seq)) for seq in tqdm.tqdm(game_overs)
    ]
game_over_logits, game_over_cache = map(list, zip(*game_over_logits_cache))

  model.run_with_cache(torch.tensor(seq)) for seq in tqdm.tqdm(game_overs)
100%|██████████| 9/9 [00:00<00:00, 116.27it/s]


In [None]:
game_over_activations = torch.stack([cache["post", 0][0] for cache in game_over_cache])
game_over_acts, game_over_indices = torch.sort(game_over_activations)

In [None]:
activation_differences, difference_indices = torch.sort(
    game_over_activations[:, -1, ...] - game_over_activations[:, -2, ...]
)

In [None]:
over_acts = game_over_acts[:, -1]
over_indices = game_over_indices[:, -1]
pre_over_acts = game_over_acts[:, -2]
pre_over_indices = game_over_indices[:, -2]

In [None]:
print(over_acts.shape)
print(pre_over_acts.shape)
print(activation_differences.shape)

In [None]:
game_over_logits[0].shape

In [None]:
unembedded_logits = game_over_logits[0][0] @ model.W_U.T

In [64]:
label = torch.tensor([0.0] * 9 + [1.0]).to("cuda")
loss_fn = F.cross_entropy
og_loss = loss_fn(game_over_logits[0][0, -1, :], label)

In [None]:
print(og_loss)

In [65]:
def ablate_one_neuron(neuron, seq):
    def hook(module, input, output):
        result = output.clone()
        result[:, :, neuron] = 0
        return result

    model.cfg.use_attn_result = True
    try:
        handle = model.blocks[0].mlp.hook_post.register_forward_hook(hook)
        logits = model.run_with_cache(torch.tensor(seq))
        handle.remove()
    except Exception as e:
        handle.remove()
        raise e

    return logits

In [66]:
def neuron_ablated_logits_losses(data):
    modified_logits = []
    modified_losses = []
    for i in range(512):
        logits = ablate_one_neuron(i, data)[0]
        neuron_losses = []
        for i in range(game_overs.shape[0]):
            loss = loss_fn(logits[i, -1, :], label)
            neuron_losses.append(loss)
        modified_losses.append(torch.tensor(neuron_losses).to("cpu").detach().numpy())
        modified_logits.append(logits.to("cpu").detach().numpy())
    modified_losses = torch.tensor(np.array(modified_losses))
    return modified_logits, modified_losses

In [67]:
modified_logits, modified_losses = neuron_ablated_logits_losses(game_overs)

  logits = model.run_with_cache(torch.tensor(seq))


In [68]:
plot.lines(modified_losses.T, log_y=False)

In [None]:
modified_losses.T[4, 325:328]

In [None]:
vals, indx = torch.sort(modified_losses.T, descending=True)

In [None]:
top_vals = vals[:, :7]
top_indx = indx[:, :7]

In [None]:
idx_count = {}
for i in range(top_indx.shape[0]):
    for j in range(top_indx.shape[1]):
        key = top_indx[i, j].item()
        idx_count.setdefault(key, [])
        idx_count[key] += [(i, top_vals[i, j].item())]

In [None]:
for key in idx_count:
    if len(idx_count[key]) > 1:
        print("%3s - %d: " % (str(key), len(idx_count[key])))
        for game, val in idx_count[key]:
            print("    %d %.3f" % (game, val))

Getting logit ablation effects on MLP output.
Which neurons are most effected by a head ablation?

In [None]:
# get normal activations with no ablation
# ablate each head and get activations
# sort activations by neuron index within each ablated dataset
# Which neurons are affected most by which head
# see if they line up with our previous data

In [None]:
logits, cache = model.run_with_cache(game_overs)
original_acts = cache["post", 0][:, -1, :]

In [None]:
ablated_logits = []
ablated_activations = []
for i in range(8):
    logits, cache = ablate_one_head(i, game_overs)
    ablated_logits.append(logits)
    ablated_activations.append(cache["post", 0][:, -1])
ablated_logits = torch.stack(ablated_logits)
ablated_activations = torch.stack(ablated_activations)

In [None]:
ablated_activations.shape

In [None]:
ablated_diffs = ablated_activations - original_acts

In [None]:
ablated_vals, ablated_indices = torch.sort(ablated_diffs, descending=True)

In [None]:
ablated_vals.shape

In [6]:
# steps till end-state
# who won
# win condition
# What conditions are rotations of others
# All rotations of a game
# Flips - middle row and middle column
# Index in the train-test set
# Whether or not it was trained on or tested on

In [44]:
all_games = game.generate_all_games([game.Board()])

In [45]:
columns = [
    "moves played",
    "steps till end state",
    "winner",
    "rotation 1",
    "rotation 2",
    "rotation 3",
    "horizontal flip",
    "vertical flip",
    "training index",
    "train or test",
]
df = pd.DataFrame(columns=columns)

In [46]:

# dict[tuple(moves_played), int]
rotation_ref = {
    0: 2,
    1: 5,
    2: 8,
    3: 1,
    4: 4,
    5: 7,
    6: 0,
    7: 3,
    8: 6,
}
horizontal_flip_ref = {
    0: 6,
    1: 7,
    2: 8,
    3: 3,
    4: 4,
    5: 5,
    6: 0,
    7: 1,
    8: 2,
}
vertical_flip_ref = {
    0: 2,
    1: 1,
    2: 0,
    3: 5,
    4: 4,
    5: 3,
    6: 8,
    7: 7,
    8: 6,
}


def get_rotate_game_state_indices(game) -> list[int]:
    # rotates by 90 more deg clockwise each rot
    rot1 = [rotation_ref[move] for move in game.moves_played]
    rot2 = [rotation_ref[move] for move in rot1]
    rot3 = [rotation_ref[move] for move in rot2]
    return [index_lookup[tuple(rot)] for rot in [rot1, rot2, rot3]]


def get_horizontal_flip_game_state_indices(game) -> list[int]:
    return [index_lookup[tuple([horizontal_flip_ref[i] for i in game.moves_played])]]


def get_vertical_flip_game_state_indices(game) -> list[int]:
    return [index_lookup[tuple([vertical_flip_ref[i] for i in game.moves_played])]]

In [47]:
df["moves played"] = [game.moves_played for game in all_games]
index_lookup = {tuple(mp): i for i, mp in zip(df.index, df.get("moves played"))}
df["winner"] = [game.get_winner() for game in all_games]
df["steps till end state"] = [len(game.moves_played) for game in all_games]
df["first win condition"] = [game.win_conditions[0] if len(game.win_conditions) > 0 else None for game in all_games]
df["second win condition"] = [game.win_conditions[1] if len(game.win_conditions) > 1 else None for game in all_games]
game_rots = [get_rotate_game_state_indices(game) for game in all_games]
df["rotation 1"] = [rots[0] for rots in game_rots]
df["rotation 2"] = [rots[1] for rots in game_rots]
df["rotation 3"] = [rots[2] for rots in game_rots]
df["horizontal flip"] = [
    get_horizontal_flip_game_state_indices(game)[0] for game in all_games
]
df["vertical flip"] = [get_vertical_flip_game_state_indices(game)[0] for game in all_games]

In [48]:
inds = data.gen_data(
    "all",
    split_ratio=args["train_test_split"],
    device=args["device"],
    seed=args["seed"],
    returns_inds=True,
)

Generating all possible games...
Generated 255168 games
Generated array of moves
torch.Size([255168, 10])
Generated data and labels
One hot encoded labels
torch.Size([255168, 10, 10])
torch.Size([255168, 10, 10])


In [49]:
indsi = [ind.item() for ind in inds]

In [50]:
training_order_table = {
    game_index: train_index for train_index, game_index in enumerate(indsi)
}

In [51]:
df["training index"] = [training_order_table[i] for i in range(len(all_games))]

In [52]:
df.head()

Unnamed: 0,moves played,steps till end state,winner,rotation 1,rotation 2,rotation 3,horizontal flip,vertical flip,training index,train or test,first win condition,second win condition
0,"[0, 1, 3, 2, 6]",5,X,399,1439,1040,1114,325,241912,,left column,
1,"[0, 1, 3, 4, 6]",5,X,396,1438,1043,1112,327,190522,,left column,
2,"[0, 1, 3, 5, 6]",5,X,398,1437,1041,1113,326,90275,,left column,
3,"[0, 1, 3, 7, 6]",5,X,395,1436,1044,1110,329,21994,,left column,
4,"[0, 1, 3, 8, 6]",5,X,397,1435,1042,1111,328,48696,,left column,


In [53]:
split = int(0.8 * len(all_games))
train_inds, test_inds = inds[:split], inds[split:]

In [54]:
def train_or_test(index):
    if training_order_table[index] <= split:
        return "train"
    else:
        return "test"

In [55]:
df["train or test"] = [train_or_test(i) for i in range(len(all_games))]

In [56]:
df.head()

Unnamed: 0,moves played,steps till end state,winner,rotation 1,rotation 2,rotation 3,horizontal flip,vertical flip,training index,train or test,first win condition,second win condition
0,"[0, 1, 3, 2, 6]",5,X,399,1439,1040,1114,325,241912,test,left column,
1,"[0, 1, 3, 4, 6]",5,X,396,1438,1043,1112,327,190522,train,left column,
2,"[0, 1, 3, 5, 6]",5,X,398,1437,1041,1113,326,90275,train,left column,
3,"[0, 1, 3, 7, 6]",5,X,395,1436,1044,1110,329,21994,train,left column,
4,"[0, 1, 3, 8, 6]",5,X,397,1435,1042,1111,328,48696,train,left column,


In [57]:
loss_fn = F.cross_entropy

In [58]:
target = torch.tensor([0.]*9 + [1.]).to('cuda')
with torch.no_grad():
    df["end move loss"] = [loss_fn(model(torch.tensor([10]+game))[0,-1], target).to('cpu').item() for game in tqdm.tqdm(df["moves played"])]



100%|██████████| 255168/255168 [04:25<00:00, 959.30it/s]


In [59]:
#save pandas dataframe
df.to_csv('../data/prob all 8 layer control-20230718-185339_stats.csv', index=False)

In [60]:
test_df = pd.read_csv("../data/prob all 8 layer control-20230718-185339_stats.csv")
test_df.head()

  test_df = pd.read_csv("../data/prob all 8 layer control-20230718-185339_stats.csv")


Unnamed: 0,moves played,steps till end state,winner,rotation 1,rotation 2,rotation 3,horizontal flip,vertical flip,training index,train or test,first win condition,second win condition,end move loss
0,"[0, 1, 3, 2, 6]",5,X,399,1439,1040,1114,325,241912,test,left column,,5e-06
1,"[0, 1, 3, 4, 6]",5,X,396,1438,1043,1112,327,190522,train,left column,,0.000114
2,"[0, 1, 3, 5, 6]",5,X,398,1437,1041,1113,326,90275,train,left column,,7e-06
3,"[0, 1, 3, 7, 6]",5,X,395,1436,1044,1110,329,21994,train,left column,,8e-06
4,"[0, 1, 3, 8, 6]",5,X,397,1435,1042,1111,328,48696,train,left column,,1.3e-05
