In [None]:
%load_ext autoreload
%autoreload 2
from alphatoe import models, plot, interpretability, game
import pandas as pd
import torch
from pytorch_memlab import LineProfiler, MemReporter
from showmethetypes import SMTT
import einops
import matplotlib.pyplot as plt
from matplotlib.ticker import LogFormatter
import numpy as np
import tqdm
import random

In [None]:
# Sparse autoencoder (SAE) trained on MLP activations; weights saved 20 Nov 2023.
SAE_WEIGHTS_PATH = "./512_sparse_autoencoder_on_activations_20NOV2023_parameters.pt"
SAE_DIMS = (512, 512)  # constructor arguments for models.SparseAutoEncoder

autoenc = models.SparseAutoEncoder(*SAE_DIMS).cuda()
# NOTE: torch.load unpickles the file — only load trusted checkpoints.
autoenc.load_state_dict(torch.load(SAE_WEIGHTS_PATH))

In [None]:
# Pre-computed activation data saved by an earlier run (contents not inspected here).
# NOTE: torch.load unpickles the file — only load trusted files.
ACT_DATA_PATH = "./all_games_act_data.pt"
act_data = torch.load(ACT_DATA_PATH)

In [None]:
# Trained tic-tac-toe transformer checkpoint analysed throughout this notebook.
MODEL_PATH = "../scripts/models/prob all 8 layer control-20230718-185339"
model = interpretability.load_model(MODEL_PATH)

In [None]:
def neuron_activations(seq, layer: int = 0):
    """Return the ``mlp.hook_post`` output of one block of the notebook-global ``model``.

    A temporary forward hook is registered on ``model.blocks[layer].mlp.hook_post``,
    ``model(seq)`` is run, and the hook's output is captured. The hook is always
    removed, even if the forward pass raises.

    Args:
        seq: Input passed unchanged to ``model`` (presumably a batch of token
            sequences — confirm against the model's expected input).
        layer: Index into ``model.blocks``. Defaults to 0, the layer this
            function originally hard-coded.

    Returns:
        A clone of the tensor produced by ``hook_post`` during the forward pass
        (presumably (batch, pos, d_mlp) — TODO confirm).

    Raises:
        RuntimeError: if the forward pass completes without the hook firing.
        Any exception raised by ``model(seq)`` propagates after the hook is removed.
    """
    # Captured in a closure rather than on the module, so no stale tensor from a
    # previous call can be returned and the module does not keep it alive.
    captured = {}

    def hook(module, inputs, output):
        # Clone so the returned tensor does not alias the model's own output.
        captured["activations"] = output.clone()

    hook_point = model.blocks[layer].mlp.hook_post
    handle = hook_point.register_forward_hook(hook)
    try:
        model(seq)
    finally:
        # Always detach the hook so repeated calls never stack hooks.
        handle.remove()

    if "activations" not in captured:
        raise RuntimeError(
            f"hook_post of block {layer} was not called during the forward pass"
        )
    return captured["activations"]

In [None]:
# Enumerate games starting from a fresh `game.Board()` (presumably the empty board).
starting_boards = [game.Board()]
boards = game.generate_all_games(starting_boards)

In [None]:
# Keep only games with exactly 8 moves and prefix each move list with token 10
# (presumably a start-of-game token), so every row has length 9.
eight_move_games = torch.stack(
    [
        torch.tensor([10] + candidate.moves_played)
        for candidate in filter(lambda b: len(b.moves_played) == 8, boards)
    ]
)

In [None]:
# Expected: (number of 8-move games, 9) — start token plus 8 moves per row.
eight_move_games.size()

In [None]:
# Split the 8-move games on whether square `move` was played:
#   yes_move — `move` appears somewhere but is not the final move
#   no_move  — `move` never appears
# Games whose final move is `move` are deliberately left out of both lists.
move = 8
yes_move = [seq for seq in eight_move_games if move in seq and seq[-1] != move]
no_move = [seq for seq in eight_move_games if move not in seq]
all_moves = [[yes_move, no_move]]

In [None]:
# Balanced sample per entry of `all_moves`: `games_taken` move-present games
# followed by `games_taken` move-absent games. Entries with too few games on
# either side are skipped. `random` is imported in the first cell.
games_taken = 2000
# Fixed seed so the hard-coded row indices used in later cells refer to the same
# games on every Restart & Run All. A local Random leaves global state untouched.
SAMPLE_SEED = 0
sample_rng = random.Random(SAMPLE_SEED)

balanced_groups = [
    torch.cat(
        [
            # Entries are already 1-D tensors: stack them directly rather than
            # re-wrapping with torch.tensor (which copies and emits a UserWarning).
            torch.stack(sample_rng.sample(present, games_taken)),
            torch.stack(sample_rng.sample(absent, games_taken)),
        ]
    )
    for present, absent in all_moves
    if len(present) >= games_taken and len(absent) >= games_taken
]
if not balanced_groups:
    raise ValueError(
        f"No entry of all_moves has at least {games_taken} games on both sides; "
        "lower games_taken."
    )
eight_move_present_or_not_games = torch.cat(balanced_groups)

In [None]:
# SAE feature activations at the last two token positions of every sampled game.
# With N = len(eight_move_present_or_not_games):
#   rows [0, N)   — position -2 (the state before the final move is played)
#   rows [N, 2N)  — position -1 (after the final move)
# The model + autoencoder run once and both positions are sliced from that one
# result. no_grad: inference only, so no autograd graph is built for the batch.
with torch.no_grad():
    sae_acts = autoenc.get_activations(
        neuron_activations(eight_move_present_or_not_games)
    )
act = torch.cat([sae_acts[:, -2], sae_acts[:, -1]], dim=0)
del sae_acts  # only the two sliced positions are needed; free the full tensor

In [None]:
# Overlaid histograms of one SAE feature across the four row groups of `act`.
# Row layout (n_per_group = games_taken rows each), following the sampling and `act` cells:
#   [0, n)    position -2, move present      ('gno, mp')
#   [n, 2n)   position -2, move not present  ('gno, mnp')
#   [2n, 3n)  position -1, move present      ('go, mp')
#   [3n, end) position -1, move not present  ('go, mnp')
FEATURE_IDX = 314  # SAE feature under inspection
n_per_group = games_taken
hist_groups = [  # same plotting order as before, so colours are unchanged
    ("go, mp", slice(2 * n_per_group, 3 * n_per_group)),
    ("gno, mp", slice(0, n_per_group)),
    ("gno, mnp", slice(n_per_group, 2 * n_per_group)),
    ("go, mnp", slice(3 * n_per_group, None)),
]

fig, ax = plt.subplots()
for label, rows in hist_groups:
    ax.hist(
        act[rows, FEATURE_IDX].detach().cpu().numpy(),
        bins=100,
        alpha=0.5,  # overlapping groups stay visible
        label=label,
    )
ax.set(
    title=f"SAE feature {FEATURE_IDX} activations by group",
    xlabel="activation",
    ylabel="count",
)
ax.legend();

In [None]:
# Group labels follow the row order of `act`: position -2 rows (presumably "game
# not over") then position -1 rows ("game over"), each split into move-present /
# move-not-present halves.
plot.imshow_comp_acts(
    act,
    groups=["gno, mp", "gno, mnp", "go, mp", "go, mnp"],
)

In [None]:
from alphatoe import game

In [None]:
# Game not over: replay a sampled game without its final move.
# `% 4000` folds an `act` row index (act holds two positions per game) back to a
# game index — presumably 4000 == 2 * games_taken.
gno_row = 1325
game.play_game(eight_move_present_or_not_games[gno_row % 4000][:-1])

In [None]:
# Game over: replay a sampled game in full.
# NOTE(review): 473 < 4000, so as an `act` row it lies in the position -2 half;
# confirm this is the intended row for the game-over example.
go_row = 473
game.play_game(eight_move_present_or_not_games[go_row % 4000])