In [None]:
from alphatoe import models, plot, interpretability, game
import pandas as pd
import torch
from pytorch_memlab import LineProfiler, MemReporter
from showmethetypes import SMTT
import einops
import matplotlib.pyplot as plt
from matplotlib.ticker import LogFormatter
import numpy as np
import tqdm

In [None]:
# Type-inspection helper from `showmethetypes`, configured for torch; later cells
# call it as `tt(obj)` to display a summary of tensors.
tt = SMTT("torch")

In [None]:
# Load the model checkpoint through the project's `interpretability` helper.
# The path is relative, so it resolves against the notebook's working directory.
model = interpretability.load_model(
    "../scripts/models/prob all 8 layer control-20230718-185339"
)

In [None]:
# Read the stats CSV that shares the checkpoint's run name and preview its first rows.
# NOTE(review): `games` is not referenced again anywhere in this notebook.
games = pd.read_csv("../data/prob all 8 layer control-20230718-185339_stats.csv")
games.head()

- Sort games by game length [X]
- Batch inference for games of each length [X]
- Extract activations with hooks [X]
- Train an autoencoder on data reconstruction (Anthropic has tips here) [X]
- Find good metrics and start looking at the data (Anthropic has tips here)

In [None]:
# Generate the full list of games starting from a fresh `game.Board()`.
# NOTE(review): presumably every legal game — the memory estimate in the inference
# cell cites 255168 games; confirm against `game.generate_all_games`.
all_games = game.generate_all_games([game.Board()])

In [None]:
# Display how many games were generated.
len(all_games)

In [None]:
# Bucket every game's move list by the number of moves played (5 through 9).
# A game of any other length raises KeyError, which doubles as a sanity check.
#
# The loop variable is deliberately NOT called `game`: that name is the imported
# `alphatoe.game` module, and rebinding it at notebook scope would break any later
# (re-)run of the cell that calls `game.generate_all_games` / `game.Board`.
games_len_dict = {5: [], 6: [], 7: [], 8: [], 9: []}
for played_game in all_games:
    games_len_dict[len(played_game.moves_played)].append(played_game.moves_played)

In [None]:
# Sanity check: the bucket sizes should add up to the total number of games.
s = sum(len(bucket) for bucket in games_len_dict.values())
print(s)

In [None]:
# One stacked integer tensor per game length. Each row is the token 10 followed by
# that game's moves (NOTE(review): 10 is presumably the start-of-game token — confirm).
games_len_tensors = {
    n_moves: torch.stack([torch.tensor([10] + moves) for moves in move_lists])
    for n_moves, move_lists in games_len_dict.items()
}

In [None]:
# Display which game lengths have a tensor.
games_len_tensors.keys()

In [None]:
# Display the device holding the 9-move tensor before running inference.
games_len_tensors[9].device

In [None]:
# Activation hook
def neuron_activations(seq, local_model=None):
    """Run one forward pass and return the first block's post-MLP activations.

    A temporary forward hook is attached to ``blocks[0].mlp.hook_post``; the hook
    clones that module's output during the forward pass. The hook is always
    removed again, whether or not the forward pass succeeds.

    Args:
        seq: Input passed unchanged to the model's forward call.
        local_model: Model to probe. Defaults to the notebook-level ``model``.

    Returns:
        A clone of the output of ``blocks[0].mlp.hook_post`` for ``seq``, on
        whatever device the model produced it.

    Raises:
        RuntimeError: If the forward pass finished without the hooked module
            being called, so there is nothing to return.
        Exception: Anything raised by the forward pass propagates unchanged.
    """
    if local_model is None:
        local_model = model
    hook_point = local_model.blocks[0].mlp.hook_post

    # Capture into a local list rather than an attribute on the module: nothing is
    # left pinned on the model afterwards, and a stale tensor from an earlier call
    # can never be returned by mistake.
    captured = []

    def hook(module, input, output):
        captured.append(output.clone())

    # Register outside the `try` so `handle` is always bound when `finally` runs.
    handle = hook_point.register_forward_hook(hook)
    try:
        with torch.inference_mode():
            _ = local_model(seq)
    finally:
        handle.remove()

    if not captured:
        raise RuntimeError(
            "blocks[0].mlp.hook_post was not called during the forward pass"
        )
    # If the module ran more than once, keep the most recent output.
    return captured[-1]

In [None]:
# a = neuron_activations(games_len_tensors[5])[:, -1]

In [None]:
# a.numel() * a.element_size()

In [None]:
# b = neuron_activations(games_len_tensors[9])
# reporter = MemReporter()
# reporter.report()

In [None]:
# Release unused cached GPU memory held by PyTorch's allocator before the big inference pass.
torch.cuda.empty_cache()

In [None]:
# Batch size for the inference loop: a quarter of the number of 9-move games.
# Length buckets smaller than this are run in a single forward pass.
length = games_len_tensors[9].shape[0]
batchy_size = length // 4

In [None]:
# Inference loop: collect activations for every game, one length bucket at a time.
# Open question: will there be a difference across game lengths?
# Running everything at once may not fit in memory. Upper bound:
# 512 neurons * 255168 games * 32 bit floats * 10 seq len = 5.22 gigabytes???
# Buckets of at least `batchy_size` games are therefore processed in slices, and
# each result is moved to the CPU as soon as it is produced.

all_acts = []
for i, (key, bucket_games) in enumerate(games_len_tensors.items()):
    print(i)
    n_games = bucket_games.shape[0]
    if n_games >= batchy_size:
        for j in tqdm.trange(0, n_games, batchy_size, desc=f"Batch {i}"):
            acts = neuron_activations(bucket_games[j : j + batchy_size])
            all_acts.append(acts.to("cpu"))
    else:
        acts = neuron_activations(bucket_games)
        all_acts.append(acts.to("cpu"))
    # `acts` still refers to the tensor returned by the model (before the CPU copy).
    print(acts.device)
    torch.cuda.empty_cache()

In [None]:
# Print pytorch_memlab's memory report to check what is still allocated after inference.
reporter = MemReporter()
reporter.report()

In [None]:
# Re-check the device of the 9-move input tensor after inference.
games_len_tensors[9].device

In [None]:
# Inspect each collected activation chunk with the type printer.
for act in all_acts:
    tt(act)

In [None]:
# Total number of rows after flattening: sum over chunks of (dim 0 * dim 1).
out = sum(chunk.shape[0] * chunk.shape[1] for chunk in all_acts)
print(out)


In [None]:
# Merge the batch and sequence axes of every chunk, then concatenate all chunks
# into a single 2-D tensor with one activation vector per row.
act_data = torch.cat(
    tuple(
        einops.rearrange(chunk, "batch seq dim -> (batch seq) dim")
        for chunk in all_acts
    ),
    dim=0,
)

In [None]:
# Cache the flattened activations in the working directory so the training cells
# can reload them without repeating inference.
torch.save(act_data, "all_games_act_data.pt")

Note that since there are a lot of repeated phrases in the input, we'll have lots of identical activations. Not sure yet how that will change things, though.

In [None]:
# Inspect the flattened activation tensor.
tt(act_data)

In [None]:
# Sparse autoencoder constructed with arguments (512, 1024) and moved to the GPU.
# NOTE(review): presumably (input width, dictionary size), with 512 matching the
# neuron count in the inference cell's memory estimate — confirm against
# `models.SparseAutoEncoder`.
autoenc = models.SparseAutoEncoder(512, 1024).to("cuda")

# Reconstruction loss and optimizer used by the training loop below.
loss_fn = torch.nn.functional.mse_loss
optimizer = torch.optim.Adam(autoenc.parameters(), lr=1e-4, weight_decay=1e-3)

In [None]:
# Reload the cached activations, so training can start here without re-running inference.
# NOTE: `torch.load` unpickles the file — only load files you produced yourself.
act_data = torch.load("all_games_act_data.pt")

In [None]:
# Inspect the reloaded activation tensor.
tt(act_data)

In [None]:
# Display the mean of the first activation row as a quick scale check.
act_data[0].mean()

In [None]:
# Scratch check of `loss_fn` with reduction="none": it returns the element-wise
# squared error (all ones for zeros vs. ones) instead of a single scalar.
test = loss_fn(torch.zeros(2,2), torch.ones(2,2), reduction="none")
test

In [None]:
# Train the sparse autoencoder on the cached activations.
# Objective per batch: MSE reconstruction loss + lam * (regularisation term
# returned by the autoencoder as the first element of its output).
epochs = 10
batch_size = 2**15
lam = 1e-7
# One [mse_loss, sparse_loss] pair of plain Python floats per optimisation step.
losses = []
for epoch in range(epochs):
    for batch in range(0, act_data.shape[0], batch_size):
        dat = act_data[batch : batch + batch_size].to("cuda")

        reg, guess = autoenc(dat)
        mse_loss = loss_fn(guess, dat)
        sparse_loss = lam * reg
        loss = mse_loss + sparse_loss

        # Log detached floats. Appending the live `sparse_loss` object would keep
        # a device tensor (and its autograd history) alive for every step and
        # cannot be plotted directly by matplotlib.
        losses.append([float(mse_loss), float(sparse_loss)])

        optimizer.zero_grad()
        loss.backward()
        print(losses[-1])
        optimizer.step()

# Per-element reconstruction error of the final batch, using the reconstruction
# computed before that batch's optimiser step. Only the last value was ever kept,
# so it is computed once here rather than on every step.
with torch.no_grad():
    last_loss = loss_fn(guess, dat, reduction="none")

In [None]:
# Display the final batch's per-element errors sorted along the last dimension
# (`Tensor.sort` returns both the sorted values and their indices).
last_loss.sort()

In [None]:
# Training curves on a log scale: one line per loss component.
# Entries are coerced with float() so the plot works whether `losses` holds
# Python floats, NumPy scalars or single-element tensors.
loss_history = np.array(
    [[float(mse), float(sparse)] for mse, sparse in losses], dtype=float
).reshape(-1, 2)

fig, ax = plt.subplots()
ax.plot(loss_history[:, 0], label="MSE (reconstruction)")
ax.plot(loss_history[:, 1], label="Sparsity penalty (lam * reg)")
ax.set_yscale("log")
ax.set_xlabel("Training step")
ax.set_ylabel("Loss")
ax.set_title("Sparse autoencoder training losses")
ax.legend();

Great. Now that we've got an autoencoder, what do we do with it?

In [None]:
@torch.no_grad()
def get_freqs(num_batches=25, local_encoder=None, batch_size=2**14):
    """Estimate how often each autoencoder feature fires on random activation rows.

    For each of ``num_batches`` rounds, ``batch_size`` rows are drawn at random
    (without replacement within a round) from the notebook-level ``act_data``,
    passed to ``local_encoder.get_act_density``, and the results are accumulated.
    The accumulated scores are finally divided by the number of rows seen.

    Args:
        num_batches: Number of random batches to sample.
        local_encoder: Autoencoder to evaluate. Defaults to the notebook-level
            ``autoenc``.
        batch_size: Rows sampled per batch.

    Returns:
        A float32 tensor with one entry per column of ``local_encoder.W_in``,
        located on the same device as ``W_in``.
        NOTE(review): interpreting the entries as firing fractions assumes
        ``get_act_density`` returns per-feature firing counts summed over the
        batch — confirm.

    Side effects:
        Prints "Num dead" followed by the fraction (not the count) of features
        whose score is exactly zero.
    """
    if local_encoder is None:
        # The notebook never defines `encoder`; the trained model is `autoenc`.
        local_encoder = autoenc
    # Follow the encoder's device instead of assuming the default CUDA device.
    device = local_encoder.W_in.device
    act_freq_scores = torch.zeros(
        local_encoder.W_in.shape[1], dtype=torch.float32, device=device
    )
    total = 0
    for _ in tqdm.trange(num_batches):
        # Slice the permutation *before* indexing: this selects the same rows as
        # permuting the whole dataset and truncating, without copying all of it.
        sample_idx = torch.randperm(len(act_data))[:batch_size]
        tokens = act_data[sample_idx].to(device)

        hidden = local_encoder.get_act_density(tokens)

        act_freq_scores += hidden
        total += tokens.shape[0]
    act_freq_scores /= total
    num_dead = (act_freq_scores == 0).float().mean()
    print("Num dead", num_dead)
    return act_freq_scores

In [None]:
# Estimate per-feature firing frequencies for the trained autoencoder.
freqs = get_freqs(local_encoder=autoenc)

In [None]:
# Inspect the frequency tensor.
tt(freqs)

In [None]:
# Feature 112's sampled frequency scaled by the dataset size, i.e. an extrapolated
# firing count over all rows of `act_data`.
print(freqs[112]*act_data.shape[0])

In [None]:
# Histogram of estimated firing counts per autoencoder feature, on log-spaced bins.
# Counts are the sampled frequencies scaled up to the full dataset size.
n_activations = act_data.shape[0]
x = interpretability.numpy(freqs) * n_activations
x = x[np.isfinite(x)]

fig, ax = plt.subplots()
fig.set_size_inches(10, 6)
# NOTE: the bins start at 5, so features estimated to fire fewer than 5 times
# (including dead features) fall outside the histogram range and are not drawn.
ax.hist(x, bins=np.logspace(np.log10(5), np.log10(10000000), 100))
ax.set_xscale("log")

tick_positions = [1, 10, 100, 1000, 10000, 100000, 1000000]
tick_labels = ['1', '10', '100', '1k', '10k', '100k', '1M']
ax.set_xticks(tick_positions)
ax.set_xticklabels(tick_labels)

# Take the denominator from the data so the label cannot drift from the dataset.
ax.set_xlabel(f"Number of Times Fired Out of {n_activations:,}")
ax.set_ylabel("Count of Features(neuron acts)");

In [None]:
# Save the trained autoencoder to the working directory.
# NOTE: this pickles the whole module object, so loading it later requires the
# same `models.SparseAutoEncoder` class to be importable; saving
# `autoenc.state_dict()` would be the more portable alternative.
torch.save(autoenc, "sparse_autoencoder_on_activations_02NOV2023.pt")