In [None]:
%load_ext autoreload
%autoreload 2
from alphatoe import models, plot, interpretability, game
import pandas as pd
import torch
from pytorch_memlab import LineProfiler, MemReporter
from showmethetypes import SMTT
import einops
import matplotlib.pyplot as plt
from matplotlib.ticker import LogFormatter
import numpy as np
import tqdm
import random

In [None]:
# Instantiate the sparse autoencoder, move it onto the GPU, then restore the
# trained parameters from the checkpoint on disk.
autoenc = models.SparseAutoEncoder(512, 512)
autoenc = autoenc.cuda()
autoenc.load_state_dict(
    torch.load("./512_sparse_autoencoder_on_activations_20NOV2023_parameters.pt")
)

In [None]:
# Load the previously saved activation data from disk.
# NOTE(review): `act_data` is not referenced by any other cell visible in this
# notebook — confirm it is still needed before keeping this load.
act_data = torch.load("./all_games_act_data.pt")

In [None]:
# Load the trained model; the helper functions below hook into
# `model.blocks[0].mlp.hook_post` to read its first-block MLP activations.
model = interpretability.load_model(
    "../scripts/models/prob all 8 layer control-20230718-185339"
)

In [None]:
def neuron_posembed_activations(seq):
    """Return the ``hook_post`` output of block 0's MLP when applied directly to ``seq``.

    Only ``model.blocks[0].mlp`` is executed; the rest of the model is bypassed.
    NOTE(review): presumably ``seq`` is an embedding-like tensor that the MLP
    accepts (the name suggests positional embeddings) — confirm against callers.

    Args:
        seq: Input forwarded unchanged to ``model.blocks[0].mlp``.

    Returns:
        A clone of the tensor emitted by ``model.blocks[0].mlp.hook_post``.

    Raises:
        RuntimeError: If the forward pass completes without ``hook_post`` firing.
        Exception: Anything raised by the forward pass propagates unchanged,
            after the temporary hook has been removed.
    """
    hook_point = model.blocks[0].mlp.hook_post
    captured = []

    def hook(module, input, output):
        result = output.clone()
        # Also mirror the value on the module, for code that reads the attribute.
        module.captured_activations = result
        captured.append(result)

    # Register before entering ``try`` so that a failed registration can never
    # reach ``handle.remove()`` with ``handle`` unbound.
    handle = hook_point.register_forward_hook(hook)
    try:
        _ = model.blocks[0].mlp(seq)
    finally:
        # Always detach the hook, whether or not the forward pass succeeded.
        handle.remove()

    if not captured:
        # Avoid silently returning a stale value left over from an earlier call.
        raise RuntimeError("hook_post did not fire during the MLP forward pass")
    return captured[-1]


def neuron_activations(seq):
    """Return the ``hook_post`` output of block 0's MLP for a full forward pass on ``seq``.

    The whole ``model`` is run on ``seq``; a temporary forward hook on
    ``model.blocks[0].mlp.hook_post`` records that module's output.

    Args:
        seq: Input passed unchanged to ``model(...)``.

    Returns:
        A clone of the tensor emitted by ``model.blocks[0].mlp.hook_post``.

    Raises:
        RuntimeError: If the forward pass completes without ``hook_post`` firing.
        Exception: Anything raised by the forward pass propagates unchanged,
            after the temporary hook has been removed.
    """
    hook_point = model.blocks[0].mlp.hook_post
    captured = []

    def hook(module, input, output):
        result = output.clone()
        # Also mirror the value on the module, for code that reads the attribute.
        module.captured_activations = result
        captured.append(result)

    # Register before entering ``try`` so that a failed registration can never
    # reach ``handle.remove()`` with ``handle`` unbound.
    handle = hook_point.register_forward_hook(hook)
    try:
        _ = model(seq)
    finally:
        # Always detach the hook, whether or not the forward pass succeeded.
        handle.remove()

    if not captured:
        # Avoid silently returning a stale value left over from an earlier call.
        raise RuntimeError("hook_post did not fire during the model forward pass")
    return captured[-1]

In [None]:
# Per-game statistics exported alongside the model checkpoint.
games = pd.read_csv("../data/prob all 8 layer control-20230718-185339_stats.csv")
end_game_types = list(games["first win condition"].unique())
# Keep only games whose "steps till end state" is not 9.
non_9_move_games = games[games["steps till end state"] != 9]

# Nested list indexed as [end-state type][game][move]. Every move list gets a
# leading 10 (presumably the start-of-game token — TODO confirm). The last
# entry of `end_game_types` is deliberately left out by the `[:-1]` slice.
# NOTE(review): `eval` executes whatever text the CSV holds in "moves played";
# acceptable for a trusted local file, otherwise consider `ast.literal_eval`.
game_kinds = [
    [
        [10] + eval(moves)
        for moves in non_9_move_games.loc[
            non_9_move_games["first win condition"] == win_condition,
            "moves played",
        ]
    ]
    for win_condition in end_game_types[:-1]
]

In [None]:
# 30 seconds
# For every end-state kind, sample `game_count` games (with replacement) and
# record the block-0 MLP activations at the final sequence position.
# NOTE(review): sampling uses the global `random` state and no seed is set in
# the visible cells, so the drawn games differ between runs.
game_count = 1_000
per_kind_activations = []
# Activations are only inspected, never back-propagated through, so skip
# building autograd graphs for the thousands of forward passes.
with torch.no_grad():
    # `kind_games` (not `games`) so the `games` DataFrame loaded earlier is
    # not overwritten by this loop.
    for kind_games in game_kinds:
        kind_activations = []
        for _ in range(game_count):
            data = torch.tensor(random.choice(kind_games))
            # [0] picks the first batch entry, [-1] the last position.
            kind_activations.append(neuron_activations(data)[0][-1])
        per_kind_activations.append(torch.stack(kind_activations))
all_activations = torch.cat(per_kind_activations)
# Transposed so rows are neurons and columns are sampled games.
all_activations = all_activations.detach().cpu().T

In [None]:
# Transpose back to one row per sampled game and move the tensor to the GPU.
# `Tensor.to` returns a new tensor rather than modifying in place, so the
# result has to be assigned for the move to take effect.
# NOTE: this cell flips the orientation on every execution — run it exactly
# once after the sampling cell above.
all_activations = all_activations.T.to("cuda")

In [None]:
# Encode each row of `all_activations` with the sparse autoencoder and stack
# the results, one row per input row.
# Run under no_grad: the features are only plotted, so there is no reason to
# keep an autograd graph alive for every encoded row.
with torch.no_grad():
    all_features = torch.stack(
        [
            autoenc.get_activations(activation.to("cuda"))
            for activation in tqdm.tqdm(all_activations)
        ]
    )

In [None]:
# Detach from autograd, copy to host memory and transpose so the stacked
# per-input rows become columns (the plot below labels x as games and y as
# features).
# NOTE: re-running this cell transposes again and flips the orientation back.
all_features = all_features.detach().cpu().T

In [None]:
# Heatmap of autoencoder features (rows) against sampled games (columns).
# Size and resolution go on a single figure: two separate plt.figure() calls
# would leave an empty figure behind and drop `figsize` for the real plot.
plt.figure(figsize=(10, 10), dpi=500)
plt.imshow(all_features, cmap="jet", aspect="auto", interpolation="none")
plt.colorbar()

plt.xlabel("Games, sorted by end-state")
# One tick at the start of each end-state group, labels rotated 45 degrees.
# NOTE(review): the tick positions assume 1,000 sampled games per group and
# eight groups in this label order — confirm against `game_count` and
# `end_game_types`.
plt.xticks(
    ticks=[0, 1000, 2000, 3000, 4000, 5000, 6000, 7000],
    labels=["LC", "TL -> BR", "TR", "MC", "BL -> TR", "RC", "MR", "BR"],
    rotation=45,
)
plt.ylabel("Features")
# NOTE(review): the title says "SAE=1024" while the checkpoint loaded above is
# the "512_sparse_autoencoder..." file built with SparseAutoEncoder(512, 512)
# — confirm which size is correct.
plt.title(
    "Features Across 1,000 games sorted by end-state in SAE=1024",
    fontsize=12,
)
plt.gcf().set_facecolor("white")