In [None]:
%load_ext autoreload
%autoreload 2
from alphatoe import models, plot, interpretability, game
import pandas as pd
import torch
from pytorch_memlab import LineProfiler, MemReporter
from showmethetypes import SMTT
import einops
import matplotlib.pyplot as plt
from matplotlib.ticker import LogFormatter
import numpy as np
import tqdm
import random

In [None]:
# Load the saved sparse autoencoder object and the cached activation data that
# `get_freqs` samples batches from.
# NOTE(review): torch.load unpickles whatever the file contains — only load trusted files.
autoenc = torch.load("./sparse_autoencoder_on_activations_07NOV2023.pt")
act_data = torch.load("./all_games_act_data.pt")

In [None]:
@torch.no_grad()
def get_freqs(num_batches=25, local_encoder=None, batch_size=2**14, device="cuda"):
    """Estimate per-feature activation frequency of an autoencoder on random batches.

    For each of `num_batches` iterations a random batch of up to `batch_size` rows is
    drawn from the module-level `act_data`, passed to
    `local_encoder.get_act_density`, and the result is accumulated. The accumulated
    scores are then divided by the total number of rows seen.

    Args:
        num_batches: Number of random batches to draw.
        local_encoder: Autoencoder exposing `W_in` and `get_act_density`. When None,
            the global `encoder` is used.
        batch_size: Rows sampled per batch (default matches the previous hard-coded
            2**14).
        device: Device the batches and the accumulator are placed on.

    Returns:
        A float32 tensor with one entry per column of `local_encoder.W_in`.
        Also prints the fraction of entries that are exactly zero ("Num dead").
    """
    if local_encoder is None:
        # NOTE(review): no global `encoder` is defined in the visible notebook, so this
        # fallback raises NameError unless one exists — prefer passing the model explicitly.
        local_encoder = encoder
    act_freq_scores = torch.zeros(
        local_encoder.W_in.shape[1], dtype=torch.float32, device=device
    )
    total = 0
    for _ in tqdm.trange(num_batches):
        # Slice the permutation first so only `batch_size` rows are gathered, rather
        # than materialising a fully shuffled copy of `act_data` on every iteration.
        batch_idx = torch.randperm(len(act_data))[:batch_size]
        tokens = act_data[batch_idx].to(device)

        # presumably a per-feature count of rows on which the feature fired — TODO confirm
        hidden = local_encoder.get_act_density(tokens)

        act_freq_scores += hidden
        total += tokens.shape[0]
    act_freq_scores /= total
    num_dead = (act_freq_scores == 0).float().mean()
    print("Num dead", num_dead)
    return act_freq_scores

In [None]:
# Activation-frequency scores of the loaded autoencoder (default 25 random batches).
freqs = get_freqs(local_encoder=autoenc)

In [None]:
# Load the trained model whose first-block MLP activations are analysed below.
# `model` is read as a global by the neuron_*_activations helpers.
model = interpretability.load_model(
    "../scripts/models/prob all 8 layer control-20230718-185339"
)

In [None]:
def _capture_hook_post(run_forward):
    """Run `run_forward()` and return what `model.blocks[0].mlp.hook_post` produced.

    A forward hook is attached to the hook point for the duration of the call and is
    always removed afterwards, including when registration or the forward pass raises.

    Args:
        run_forward: Zero-argument callable that triggers the forward pass.

    Returns:
        A clone of the hook point's output from this call.

    Raises:
        RuntimeError: If the forward pass never reached the hook point. Previously a
            stale tensor from an earlier call could be returned silently.
    """
    hook_point = model.blocks[0].mlp.hook_post
    captured = []

    def hook(module, input, output):
        result = output.clone()
        # Still exposed on the module, as before, for any code that reads the attribute.
        module.captured_activations = result
        captured.append(result)

    # Register outside the try block: if registration itself fails there is no handle
    # to remove, and the original error must not be masked by an unbound-name error.
    handle = hook_point.register_forward_hook(hook)
    try:
        run_forward()
    finally:
        handle.remove()

    if not captured:
        raise RuntimeError("hook_post of model.blocks[0].mlp was not called")
    return captured[-1]


def neuron_posembed_activations(seq):
    """Return first-block MLP `hook_post` output when `seq` is fed straight to that MLP.

    Only `model.blocks[0].mlp` is run; the rest of the model is bypassed.
    NOTE(review): the name suggests `seq` is a positional embedding — confirm with callers.
    """
    return _capture_hook_post(lambda: model.blocks[0].mlp(seq))


def neuron_activations(seq):
    """Return first-block MLP `hook_post` output from a full forward pass of `model` on `seq`."""
    return _capture_hook_post(lambda: model(seq))

In [None]:
# Per-game statistics table. Columns used below: "first win condition",
# "steps till end state" and "moves played".
games = pd.read_csv("../data/prob all 8 layer control-20230718-185339_stats.csv")
games.head()

In [None]:
# Distinct values of the win-condition column, in order of first appearance.
end_game_types = list(games["first win condition"].unique())
print(end_game_types)

In [None]:
# Keep only games whose "steps till end state" is not 9.
non_9_move_games = games[games["steps till end state"] != 9]

In [None]:
# game_types x game x moves
# One list of token sequences per win condition. The last entry of `end_game_types`
# is skipped (presumably the "no winner" category — TODO confirm).
# Each sequence is the parsed move list with token 10 prepended
# (presumably the start-of-game token — TODO confirm).
# NOTE(review): `eval` executes arbitrary code from the CSV; only use on trusted data.
game_kinds = []
for game_type in end_game_types[:-1]:
    is_this_kind = non_9_move_games["first win condition"] == game_type
    moves_of_kind = non_9_move_games[is_this_kind]["moves played"]
    game_kinds.append([[10] + eval(move) for move in moves_of_kind])

In [None]:
# 30 seconds
# For every win condition, sample `game_count` games (with replacement, via
# random.choice) and record the first-block MLP activations at the final position.
# NOTE(review): `random` is not seeded, so the sample differs between runs.
game_count = 1_000
all_activations = []
# The loop variable is deliberately NOT called `games`: that name holds the stats
# DataFrame, and rebinding it here broke re-running the earlier cells.
# no_grad avoids keeping an autograd graph for tensors that are detached below anyway.
with torch.no_grad():
    for kind_games in game_kinds:
        kind_activations = []
        for _ in range(game_count):
            data = torch.tensor(random.choice(kind_games))
            kind_activations.append(neuron_activations(data)[0][-1])
        all_activations.append(torch.stack(kind_activations))
all_activations = torch.cat(all_activations)
# Transposed so that neurons run along the first axis for the heatmap below.
all_activations = all_activations.detach().cpu().T

In [None]:
# Heatmap of the sampled neuron activations (first axis = neurons, second = games).
# Size and dpi are set on ONE figure: two separate plt.figure() calls left an empty
# 10x10 figure behind and drew the heatmap on a second, default-sized figure.
plt.figure(figsize=(10, 10), dpi=500)
plt.imshow(all_activations, cmap="jet", aspect="auto", interpolation="none")
# colorbar
plt.colorbar()

plt.xlabel("Games, sorted by end-state")

plt.ylabel("Neurons")
# title
plt.title(
    "Neuron activations Across 1,000 games sorted by end-state in dMLP=512 Model",
    fontsize=12,
)
plt.gcf().set_facecolor("white")

In [None]:
# Undo the transpose applied for plotting so that iterating yields one row per game.
# NOTE(review): not idempotent — running this cell twice flips the tensor back again.
all_activations = all_activations.T

In [None]:
# Re-save only the autoencoder's parameters (state_dict) so they can be loaded into a
# freshly constructed model instead of relying on the pickled object.
torch.save(
    autoenc.state_dict(), "./sparse_autoencoder_on_activations_07NOV2023_parameters.pt"
)

In [None]:
# Fresh sparse autoencoder on the GPU.
# presumably 512 = MLP width and 1024 = number of SAE features (matches the plot titles) — TODO confirm
autoenc_cool = models.SparseAutoEncoder(512, 1024).cuda()

In [None]:
# Populate the fresh autoencoder with the parameters saved above.
autoenc_cool.load_state_dict(
    torch.load("./sparse_autoencoder_on_activations_07NOV2023_parameters.pt")
)

In [None]:
# Tensor.to() returns a new tensor rather than moving in place, so the result must be
# assigned; the bare call previously had no effect beyond displaying the tensor.
all_activations = all_activations.to("cuda")

In [None]:
# Encode each sampled activation vector with the autoencoder, one row at a time, and
# stack the results into a single tensor (one row per game).
feature_rows = []
for activation in tqdm.tqdm(all_activations):
    feature_rows.append(autoenc_cool.get_activations(activation.to("cuda")))
all_features = torch.stack(feature_rows)

In [None]:
# Inspection helper from showmethetypes, configured for torch
# (presumably prints the type/shape structure of its argument — TODO confirm).
tt = SMTT("torch")

In [None]:
# Inspect the stacked feature tensor.
tt(all_features)

In [None]:
# Detach, move to CPU and transpose so features run along the first axis for plotting.
# NOTE(review): not idempotent — re-running this cell transposes the tensor back.
all_features = all_features.detach().cpu().T

In [None]:
# Heatmap of autoencoder feature activations (first axis = features, second = games).
# Size and dpi are set on ONE figure: two separate plt.figure() calls left an empty
# 10x10 figure behind and drew the heatmap on a second, default-sized figure.
plt.figure(figsize=(10, 10), dpi=500)
plt.imshow(all_features, cmap="jet", aspect="auto", interpolation="none")
# colorbar
plt.colorbar()

plt.xlabel("Games, sorted by end-state")
# One tick at the start of each block of 1,000 games (8 win conditions x 1,000 samples),
# labelled with the win condition and rotated 45 degrees.
plt.xticks(
    ticks=[0, 1000, 2000, 3000, 4000, 5000, 6000, 7000],
    labels=["LC", "TL -> BR", "TR", "MC", "BL -> TR", "RC", "MR", "BR"],
    rotation=45,
)
plt.ylabel("Features")
# title
plt.title(
    "Features Across 1,000 games sorted by end-state in SAE=1024",
    fontsize=12,
)
plt.gcf().set_facecolor("white")

In [1]:
# plot.imshow_div(all_features, width=100, height=100)

In [None]:
# Autoencoder feature index -> board square (0-8). Used below to copy column `key` of
# the feature activations into column `value` of `move_indices`.
activation_indices = {
    245: 4,
    496: 3,
    566: 0,
    600: 8,
    631: 1,
    639: 7,
    811: 2,
    830: 5,
    931: 6,
}

# Win-condition name -> the three board squares that make up that line.
win_type_to_moves = {
    "left column": {0, 3, 6},
    "top left -> bottom right": {0, 4, 8},
    "top row": {0, 1, 2},
    "middle column": {1, 4, 7},
    "bottom left -> top right": {6, 4, 2},
    "right column": {2, 5, 8},
    "middle row": {3, 4, 5},
    "bottom row": {6, 7, 8},
}

In [None]:
# Reload the stats table: the name `games` was rebound by the sampling loop above,
# so the DataFrame has to be read again before it can be used here.
games = pd.read_csv("../data/prob all 8 layer control-20230718-185339_stats.csv")
games.head()

In [None]:
# Keep only the move list and the win condition of each game.
two_col_games = games[["moves played", "first win condition"]]

In [None]:
# First 100,000 games of the table.
# NOTE(review): despite the name, no filter on game length is applied here, and this
# rebinds the differently-filtered `non_9_move_games` defined earlier.
non_9_move_games = two_col_games.iloc[:100_000]

In [None]:
# Sanity check: number of games selected.
len(non_9_move_games)

In [None]:
# One token tensor per selected game: token 10 followed by the parsed move list.
# Iterating the column directly avoids the hard-coded range(100_000) and per-row
# label lookups, which raised KeyError whenever fewer rows (or a non-default index)
# were present.
# NOTE(review): `eval` executes arbitrary code from the CSV; only use on trusted data.
all_moves = [
    torch.tensor([10] + eval(moves_played))
    for moves_played in non_9_move_games["moves played"]
]

In [None]:
# Inspect the list of token tensors.
tt(all_moves)

In [None]:
# Autoencoder feature activations for every selected game: run the model, capture the
# first-block MLP activations, then encode them. Gradients are not needed.
acts_on_all_games = []
with torch.no_grad():
    for moves in tqdm.tqdm(all_moves):
        game_features = autoenc_cool.get_activations(neuron_activations(moves))
        acts_on_all_games.append(game_features)
        # Release cached GPU blocks after each game.
        torch.cuda.empty_cache()

In [None]:
# Keep only the features at the last position of the first batch entry for each game,
# stacked into one row per game.
all_act_tensor = torch.stack([acts[0, -1] for acts in acts_on_all_games])

In [None]:
# Inspect the per-game feature tensor.
tt(all_act_tensor)

In [None]:
# One row per game, one column per board square; filled in the next cell.
# NOTE(review): 100_000 is hard-coded and must match the number of games selected above.
move_indices = torch.zeros(100_000, 9)

In [None]:
# Copy the activation of each hand-picked feature into the column of its board square.
for feature_index, square in activation_indices.items():
    move_indices[:, square] = all_act_tensor[:, feature_index]

In [None]:
# Inspect the per-square activation table.
tt(move_indices)

- All moves played
- winner moves played
- moves played by winner

In [None]:
# Show only the first few games: displaying the whole list would dump the repr of
# every game tensor into the cell output.
all_moves[:5]

In [None]:
# Boolean mask over the 9 squares: which mapped features are active (> 0) in game 0.
# NOTE(review): the name `game` also refers to the `alphatoe.game` module imported at the top.
game = move_indices[0, list(range(9))] > 0

In [None]:
# Display the per-square activation table.
move_indices

In [None]:
# Ground truth for game 0: which of the 9 squares appear among its moves
# (the leading token at position 0 is skipped).
game_cmp = torch.tensor([i in all_moves[0][1:] for i in range(9)])

In [None]:
# True when the active-feature mask matches the played-squares mask exactly.
torch.equal(game, game_cmp)

In [None]:
# NOTE: a stray bare `move` expression was removed from this cell. No variable named
# `move` is bound at this point in a top-to-bottom run (it is only assigned by a loop
# further down), so the cell raised NameError on Restart & Run All.

In [None]:
# Same comparison as above for a single chosen game: active-feature mask vs. the
# squares that were actually played.
game_index = 1
game = move_indices[game_index] > 0
game_cmp = torch.tensor([i in all_moves[game_index][1:] for i in range(9)])
isqual = torch.equal(game, game_cmp)
print(game, game_cmp, isqual, sep="\n")

In [None]:
# Are the acts above 0?
# are the acts above 0 also the ones that correspond to the moves played?
# One boolean per game: True when the active-feature mask equals the played-squares mask.
is_goods = []
for game_index, moves in enumerate(all_moves):
    game = move_indices[game_index] > 0
    game_cmp = torch.tensor([i in moves[1:] for i in range(9)])
    is_goods.append(torch.equal(game, game_cmp))

In [None]:
# Number of games checked.
len(is_goods)

In [None]:
# Number of games whose active features match the played squares exactly.
sum(is_goods)

In [None]:
# For every (game, square) whose mapped feature is active, record whether that square
# was actually played in the game. Inactive features are not part of this check.
is_off_properly = []
for game_index in range(len(all_moves)):
    game = move_indices[game_index, list(range(9))] > 0
    for move, act in enumerate(game):
        if act == 0:
            continue
        is_off_properly.append(move in all_moves[game_index][1:])
# Total active features, how many were correct, and the resulting fraction.
n_active = len(is_off_properly)
n_correct = sum(is_off_properly)
print(n_active)
print(n_correct)
print(n_correct / n_active)

# 99% Monosemantic babyyyyyyy

There are two things we can check:
- Are the features monosemantic? (Checking the negative case: making sure a feature isn't present when it shouldn't be.)
- Do the features completely cover the features we thought of? (Is there a one-to-one correspondence between the features and moves?)

In [None]:
Two current avenues:
- What features do we need to include to get our 64% inference ability up to 99%?
- What happens if we ablate the features of the autoencoder? Does the result actually match what we would predict?
