In [None]:
%load_ext autoreload
%autoreload 2
from alphatoe import models, plot, interpretability, game
import pandas as pd
import torch
from pytorch_memlab import LineProfiler, MemReporter
from showmethetypes import SMTT
import einops
import matplotlib.pyplot as plt
from matplotlib.ticker import LogFormatter
import numpy as np
import tqdm
import random

In [None]:
# SMTT instance from showmethetypes; used below as `tt(obj)` on tensors and tuples.
tt = SMTT()

In [None]:
# Generate the games starting from a single default-constructed Board
# (presumably every reachable game — confirm against alphatoe.game).
games = game.generate_all_games([game.Board()])

In [None]:
# One row per game that has exactly five moves: token 10 followed by the
# game's moves. The comprehension variable is named `played` so it does not
# read like the imported `game` module.
five_move_games = torch.stack(
    [
        torch.tensor([10] + played.moves_played)
        for played in games
        if len(played.moves_played) == 5
    ]
)

In [None]:
# Display the stacked five-move tensor.
five_move_games

In [None]:
# Load the model under study through alphatoe's interpretability helper.
# The path is relative, so it resolves against the current working directory.
model = interpretability.load_model(
    "../scripts/models/prob all 8 layer control-20230718-185339"
)

In [None]:
# Sparse autoencoder (constructor arguments 512 and 1024), moved to the GPU
# and then restored from a saved state dict.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
autoenc = models.SparseAutoEncoder(512, 1024).cuda()
autoenc.load_state_dict(
    torch.load("./sparse_autoencoder_on_activations_07NOV2023_parameters.pt")
)

In [None]:
def neuron_activations(seq):
    """Run ``model`` on ``seq`` and return what ``blocks[0].mlp.hook_post`` emitted.

    A forward hook is attached to the hook point for one forward pass and is
    always removed afterwards, including when the forward pass raises.

    Args:
        seq: Input passed unchanged to the module-level ``model``.

    Returns:
        A clone of the output of ``model.blocks[0].mlp.hook_post``. The same
        tensor is also left on the hook point as ``captured_activations``.

    Raises:
        Whatever hook registration or ``model(seq)`` raises, unchanged.
    """
    hook_point = model.blocks[0].mlp.hook_post

    def hook(module, input, output):
        # Store a clone rather than the live output tensor.
        module.captured_activations = output.clone()

    # Register outside the ``try``: if registration itself fails there is no
    # handle to remove, and the original error must surface as-is instead of
    # being masked by a NameError on an unbound ``handle`` during cleanup.
    handle = hook_point.register_forward_hook(hook)
    try:
        _ = model(seq)
    finally:
        handle.remove()

    return hook_point.captured_activations

In [None]:
# Autoencoder activations for the five-move games. `[:, -1]` keeps only the
# last index along the second axis (presumably the final sequence position — confirm).
activations = autoenc.get_activations(neuron_activations(five_move_games))[:, -1]

In [None]:
# Inspect the selected activations with the SMTT helper.
tt(activations)

In [None]:
"""
if act on 995 > 0 and 600 == 0 or vice versa, keep game
"""

# Sort game indices by which of autoencoder features 995 and 600 fire:
#   on_995  – 995 above 0.5 while 600 is exactly zero
#   on_600  – 600 strictly positive (below 2) while 995 is exactly zero
#   on_both – 600 above 2 and 995 above 0.5
# Games matching none of these are dropped.
on_995 = []
on_600 = []
on_both = []
for i in range(activations.shape[0]):
    act_995 = activations[i][995]
    act_600 = activations[i][600]
    if act_995 > 0.5 and act_600 == 0:
        on_995.append(i)
    # Strict lower bound: with `0 <=`, games where both features are zero
    # would be filed under on_600 even though feature 600 is off.
    elif 0 < act_600 < 2 and act_995 == 0:
        on_600.append(i)
    elif act_600 > 2 and act_995 > 0.5:
        on_both.append(i)

In [None]:
# Show the indices collected in on_both.
on_both

In [None]:
# zip stops at the shortest of the three lists, so at most
# min(len(on_995), len(on_600), len(on_both)) games are shown here.
# With the 995/600 branches commented out, only `k` (an on_both index) is
# used; `i` and `j` are bound but unused.
for i, j, k in zip(on_995, on_600, on_both):
    # print("995")
    # game.play_game(list(five_move_games[i]))
    # print("600")
    # game.play_game(list(five_move_games[j]))
    print("both")
    game.play_game(list(five_move_games[k]))

In [None]:
# Pass every game indexed by on_600 to game.play_game, printing a blank line
# before each one.
for i in on_600:
    print()
    game.play_game(list(five_move_games[i]))

In [None]:
# Pass every game indexed by on_both to game.play_game, printing a blank line
# before each one.
for i in on_both:
    print()
    game.play_game(list(five_move_games[i]))

In [None]:
# Number of game indices collected in on_both.
len(on_both)

In [None]:
# Histograms of autoencoder features 995 and 600 over the rows of
# `activations`, side by side. Each panel is titled and labelled so the two
# distributions can be told apart when the notebook is skimmed.
fig, axes = plt.subplots(1, 2)
for ax, feature in zip(axes, (995, 600)):
    ax.hist(interpretability.numpy(activations[:, feature]))
    ax.set_title(f"feature {feature}")
    ax.set_xlabel("activation")
    ax.set_ylabel("count")

In [None]:
# One row per game that has exactly eight moves: token 10 followed by the
# game's moves. The comprehension variable is named `played` so it does not
# read like the imported `game` module.
eight_move_games = torch.stack(
    [
        torch.tensor([10] + played.moves_played)
        for played in games
        if len(played.moves_played) == 8
    ]
)

In [None]:
# all_moves = []
# for move in range(9):
#     yes_move = []
#     no_move = []
#     for game in eight_move_games:
#         if move in game and game[-1] != move:
#             yes_move.append(game)
#         elif move not in game:
#             no_move.append(game)
#     all_moves.append([yes_move, no_move])

In [None]:
# Split the eight-move games by `move`:
#   yes_move – the sequence contains `move` and does not end with it
#   no_move  – the sequence does not contain `move` at all
# Sequences whose last element is `move` go into neither list.
#
# The loop variable is deliberately not called `game`: a module-level `for`
# leaves its variable bound afterwards, which would rebind the imported
# `game` module to a tensor and break later `game.play_game(...)` calls.
all_moves = []
move = 8
yes_move = []
no_move = []
for moves in eight_move_games:
    if move in moves and moves[-1] != move:
        yes_move.append(moves)
    elif move not in moves:
        no_move.append(moves)
all_moves.append([yes_move, no_move])

In [None]:
# Print, per entry of all_moves, how many games contain the move and how many
# do not (second number).
# NOTE: `i` is the position in all_moves, not necessarily the move itself —
# when all_moves holds a single entry built for move = 8, this prints "move 0".
for i, l in enumerate(all_moves):
    print(f"move {i} present in {len(l[0])}", len(l[1]))

In [None]:
# For every [with_move, without_move] pair in all_moves, take up to the first
# `games_taken` games from each side and stack them: rows with the move come
# first, followed by rows without it.
# The entries are already tensors, so they are stacked directly;
# re-wrapping each one in torch.tensor() copy-constructs it and triggers a
# UserWarning per row.
games_taken = 2000
eight_move_present_or_not_games = torch.cat(
    [
        torch.stack(with_move[:games_taken] + without_move[:games_taken])
        for with_move, without_move in all_moves
    ]
)

In [None]:
# `l` is whatever an earlier module-level `for ... l in ...` loop left bound
# (comprehension variables do not leak), i.e. the last entry of all_moves.
tt(torch.stack([torch.tensor(g) for g in l[0][:games_taken]]))

In [None]:
# Random sample (without replacement) of `games_taken` games from l[0].
# NOTE: the trailing comma makes `c` a 1-tuple wrapping the stacked tensor.
c = (torch.stack([torch.tensor(g) for g in random.sample(l[0], games_taken)]),)

In [None]:
# Inspect the 1-tuple `c` with the SMTT helper.
tt(c)

In [None]:
# Inspect the concatenated present / not-present game tensor.
tt(eight_move_present_or_not_games)

In [None]:
# Balanced random sample: for every [with_move, without_move] pair that has at
# least `games_taken` games on both sides, draw `games_taken` games from each
# (without replacement). Rows with the move come first, then rows without it.
#
# A dedicated, seeded generator keeps the draw identical across kernel
# restarts, so hard-coded row indices into this tensor keep pointing at the
# same games. `random` is already imported in the notebook's import cell.
games_taken = 2000
sample_rng = random.Random(0)

eight_move_present_or_not_games = torch.cat(
    [
        # The sampled entries are already tensors, so stack them directly
        # instead of copy-constructing each one with torch.tensor().
        torch.stack(
            sample_rng.sample(with_move, games_taken)
            + sample_rng.sample(without_move, games_taken)
        )
        for with_move, without_move in all_moves
        if len(with_move) >= games_taken and len(without_move) >= games_taken
    ]
)

In [None]:
# Inspect the randomly sampled present / not-present game tensor.
tt(eight_move_present_or_not_games)

In [None]:
# acts = autoenc.get_activations(neuron_activations(eight_move_present_or_not_games))

In [None]:
# tt(acts)

In [None]:
# assert false

In [None]:
# Inspect a single row (the first game) of the sampled tensor.
tt(eight_move_present_or_not_games[0])

In [None]:
# Run the model and the autoencoder once and reuse the result for both slices;
# evaluating the same expression twice repeats the whole forward pass.
eight_move_acts = autoenc.get_activations(
    neuron_activations(eight_move_present_or_not_games)
)
# Rows: index -2 along the second axis for every game, followed by index -1
# for every game (presumably the last two sequence positions — confirm), so
# `act` has twice as many rows as there are games.
act = torch.cat([eight_move_acts[:, -2], eight_move_acts[:, -1]], dim=0)

In [None]:
# Inspect the concatenated activations.
tt(act)

In [None]:
# Figure size and resolution have to be configured before the figure is
# created; set after plotting they only affect later figures.
plt.rcParams["figure.figsize"] = (20, 10)
plt.rcParams["figure.dpi"] = 600

fig, ax = plt.subplots()
# Transposed so that each image column is one row of `act`.
ax.imshow(
    interpretability.numpy(act).T, aspect="auto", cmap="Greys", interpolation="none"
)
# Label the centre of each block of 2000 columns and mark the block boundaries.
ax.set_xticks([1000, 3000, 5000, 7000])
ax.set_xticklabels(["gno, mp", "gno, mnp", "go, mp", "go, mnp"])
for boundary in (2000, 4000, 6000):
    ax.axvline(x=boundary, color="r")

In [None]:
# Count the eight-move games for which game.play_game reports "O" as the
# winner, printing an empty line for each one found.
acc = 0
for g in eight_move_games:
    is_o_win = game.play_game(g).winner == "O"
    if not is_o_win:
        continue
    print()
    acc += 1

In [None]:
# Number of distinct eight-move games. Tensors hash by identity, so putting
# the rows in a set never merges equal rows and len(set(...)) just counts
# rows; torch.unique with dim=0 compares row contents instead.
len(torch.unique(eight_move_games, dim=0))

In [None]:
# Show the count accumulated in `acc`.
acc

In [None]:
# Total number of eight-move games (rows of the tensor).
len(eight_move_games)

In [None]:
# Plot the transposed activations with the project's plot helper
# (rendering details are defined in alphatoe.plot).
plot.imshow_div(act.T, aspect="auto", width=1000, height=1000)

In [None]:
# Plot `act` with the four group labels via the project's plot helper
# (how the groups are applied is defined in alphatoe.plot).
plot.imshow_comp_acts(act, groups=['gno, mp', 'gno, mnp', 'go, mp', 'go, mnp'])

In [None]:
# Re-import so that `game` refers to the alphatoe module again, in case an
# earlier cell rebound the name (e.g. by using `game` as a loop variable).
from alphatoe import game

In [None]:
# 467 % 4000 == 467. The modulo presumably maps a row index of the doubled
# `act` tensor back to its row in the game tensor — confirm the 4000.
game.play_game(eight_move_present_or_not_games[467 % 4000])


In [None]:
# feature 600 is on
for i in range(1528, 1538):
    game.play_game(eight_move_present_or_not_games[i])
    print("")

In [None]:
# feature 600 is on
# NOTE: the row range is hard-coded, while eight_move_present_or_not_games is
# built with random sampling; these rows only show the same games if that
# sampling is reproduced.
for i in range(761, 767):
    game.play_game(eight_move_present_or_not_games[i])
    print("")

In [None]:
# feature 995 is on
# NOTE: the row range is hard-coded, while eight_move_present_or_not_games is
# built with random sampling; these rows only show the same games if that
# sampling is reproduced.
for i in range(722, 730):
    game.play_game(eight_move_present_or_not_games[i])
    print("")

In [None]:
# feature 995 is off
# NOTE: the row range is hard-coded, while eight_move_present_or_not_games is
# built with random sampling; these rows only show the same games if that
# sampling is reproduced.
for i in range(957, 974):
    game.play_game(eight_move_present_or_not_games[i])
    print("")

In [None]:
# Hook-point activations for the sequence containing only token 10.
# neuron_activations runs the model on its argument itself, so it must be
# given the token tensor directly; handing it the model's output would feed
# that output back into the model as if it were an input sequence.
neuron_activations(torch.tensor([10]))

In [None]:
# Inspect `act` again before filtering its features.
tt(act)

In [None]:
# Drop features that never fire: keep the entries of the last axis whose
# maximum over all other axes is positive. Flattening the leading axes first
# makes this work for any rank — a 2-D (rows, features) tensor as well as one
# with extra leading axes — whereas indexing `[:, :, mask]` requires exactly
# three axes.
alive_features = act.reshape(-1, act.shape[-1]).max(0).values > 0
non_zero_acts = act[..., alive_features]

In [None]:
# Inspect the activations after removing never-active features.
tt(non_zero_acts)

In [None]:
# Histograms of autoencoder features 995 and 600 over the rows of
# `activations`, side by side. Each panel is titled and labelled so the two
# distributions can be told apart when the notebook is skimmed.
fig, axes = plt.subplots(1, 2)
for ax, feature in zip(axes, (995, 600)):
    ax.hist(interpretability.numpy(activations[:, feature]))
    ax.set_title(f"feature {feature}")
    ax.set_xlabel("activation")
    ax.set_ylabel("count")

In [None]:
"""
- Clean up dead neurons
- Make 9 even split plots containing a bunch of games with and without particular moves
"""

In [None]:
# Heatmap of one slice of non_zero_acts (index 4 along its first axis),
# transposed, with two x tick labels at columns 2500 and 7500.
# NOTE(review): this assumes non_zero_acts has a leading axis to select from
# (e.g. one entry per move) and on the order of 10000 columns per slice, given
# the tick positions — confirm against how `act` is built.
# plt.subplot(9,1, 1)
plt.imshow(
    interpretability.numpy(non_zero_acts[4]).T,
    cmap="jet",
    aspect="auto",
    interpolation="none",
)
plt.xticks(ticks=[2500, 7500], labels=["move present", "move not present"])
# plt.subplot(9,1, 2)
# plt.imshow(non_zero_acts[1], cmap="jet", aspect="auto", interpolation="none")
# plt.subplot(9,1, 3)
# plt.imshow(non_zero_acts[2], cmap="jet", aspect="auto", interpolation="none")