In [None]:
%load_ext autoreload
%autoreload 2
from alphatoe import models, plot, interpretability, game
import pandas as pd
import torch
from pytorch_memlab import LineProfiler, MemReporter
from showmethetypes import SMTT
import einops
import matplotlib.pyplot as plt
from matplotlib.ticker import LogFormatter
import numpy as np
import tqdm

In [None]:
# Inspector object from `showmethetypes`; called further down as `tt(obj)` to
# look at the structure of a value (presumably prints nested types/shapes — TODO confirm).
tt = SMTT()

In [None]:
# Build the sparse autoencoder and move it to the GPU.
# NOTE(review): the two 512s are presumably (input width, number of features) — confirm
# against models.SparseAutoEncoder.
autoenc = models.SparseAutoEncoder(512, 512)
autoenc = autoenc.to("cuda")

# Reconstruction criterion: plain mean-squared error between input and reconstruction.
loss_fn = torch.nn.functional.mse_loss

# Adam over all autoencoder parameters, with L2 weight decay.
optimizer = torch.optim.Adam(
    params=autoenc.parameters(),
    lr=1e-4,
    weight_decay=1e-3,
)

In [None]:
# Load the activation dataset from the working directory. It is sliced along
# dim 0 into training batches below.
# NOTE(review): torch.load unpickles the file — only load files from a trusted source.
act_data = torch.load("all_games_act_data.pt")

In [None]:
# Quick sanity check: mean value of the first entry of act_data (shown as cell output).
act_data[0].mean()

In [None]:
# Sanity check of `reduction="none"`: the loss comes back element-wise
# (a 2x2 tensor here) instead of being averaged to a scalar.
test = loss_fn(torch.zeros(2, 2), torch.ones(2, 2), reduction="none")
test

### Getting a sparser encoder! (actually following instructions)
- L0 around 10 or 20 on average across 1000 games
- feature density is mostly under 1%
- reconstruction loss stays low

In [None]:
# Train the autoencoder on act_data with an MSE reconstruction loss plus a
# weighted sparsity penalty.
epochs = 30
batch_size = 2**15
lam = 1e-7  # weight applied to the regularisation term returned by the model
losses = []  # one [mse_loss, sparse_loss, l0] row per optimisation step

for epoch in range(epochs):
    # Batches are consecutive slices of act_data; the order is the same every epoch.
    for batch in range(0, act_data.shape[0], batch_size):
        dat = act_data[batch : batch + batch_size].to("cuda")

        # The model returns (l0, reg, guess); `guess` is compared against the input
        # below, so it is the reconstruction.
        # NOTE(review): `l0` / `reg` are presumably the L0 count and the sparsity
        # penalty — confirm against models.SparseAutoEncoder.forward.
        l0, reg, guess = autoenc(dat)
        mse_loss = loss_fn(guess, dat)
        sparse_loss = lam * reg
        loss = mse_loss + sparse_loss

        losses.append([mse_loss.item(), sparse_loss.item(), l0.item()])
        optimizer.zero_grad()
        loss.backward()
        print(losses[-1])
        optimizer.step()

# Element-wise reconstruction error of the final batch, inspected in a later cell.
# Only the last value is ever read, so it is computed once here rather than on every
# step; `guess` and `dat` still hold the final batch (guess predates the last
# optimizer step, exactly as before).
with torch.no_grad():
    last_loss = loss_fn(guess, dat, reduction="none")

## It is worth noting that the 512 SAE seems to achieve the same loss and L0 as the 1024 SAE. Probably not worth looking into too much, but an interesting observation.

In [None]:
# Tensor.sort() is not in-place: it returns (values, indices) sorted along the last
# dimension, displayed here as cell output. `last_loss` itself is left unchanged.
last_loss.sort()

In [None]:
# Training curves on a log y-axis. Each row of `losses` is
# [mse_loss, sparse_loss, l0], so three lines are drawn; the legend names them
# in that column order.
fig, ax = plt.subplots()
ax.set_yscale("log")
loss_lines = ax.plot(range(len(losses)), losses)
ax.legend(loss_lines, ["MSE loss", "Sparsity loss (lam * reg)", "L0"])
ax.set_xlabel("Training step (batch)")
ax.set_ylabel("Value (log scale)")
ax.set_title("Sparse autoencoder training metrics");

Great. Now we've got an autoencoder, what do we do with it?

In [None]:
@torch.no_grad()
def get_freqs(num_batches=25, local_encoder=None):
    """Estimate per-feature activation frequency on random samples of ``act_data``.

    Draws ``num_batches`` random batches (up to 2**14 rows each) from the
    notebook-level ``act_data``, accumulates ``local_encoder.get_act_density`` over
    them and divides by the number of rows seen.

    Args:
        num_batches: How many random batches to sample.
        local_encoder: Object exposing ``W_in`` and ``get_act_density``. Defaults to
            the notebook-level ``autoenc``.

    Returns:
        A 1-D float32 tensor with one entry per column of ``local_encoder.W_in``.
        NOTE(review): this is a firing *fraction* only if ``get_act_density`` returns
        per-feature counts summed over the batch — confirm in the model code.
    """
    if local_encoder is None:
        # The previous fallback referenced a global `encoder` that this notebook
        # never defines (NameError on a fresh kernel); use the trained model instead.
        local_encoder = autoenc

    # Work on whatever device the encoder weights live on instead of assuming CUDA.
    device = local_encoder.W_in.device
    act_freq_scores = torch.zeros(
        local_encoder.W_in.shape[1], dtype=torch.float32, device=device
    )

    sample_size = 2**14
    n_rows = len(act_data)
    total = 0
    for _ in tqdm.trange(num_batches):
        # Slice the permutation *before* indexing: this picks the same rows as
        # permuting the whole dataset and then truncating, without materialising
        # a shuffled copy of all of act_data on every iteration.
        sample_idx = torch.randperm(n_rows)[:sample_size]
        tokens = act_data[sample_idx].to(device)

        hidden = local_encoder.get_act_density(tokens)

        act_freq_scores += hidden
        total += tokens.shape[0]
    act_freq_scores /= total

    # Despite the label, this is the fraction (not the count) of features whose
    # accumulated score is exactly zero.
    num_dead = (act_freq_scores == 0).float().mean()
    print("Num dead", num_dead)
    return act_freq_scores

In [None]:
# Per-feature activation frequencies of the trained autoencoder, estimated from
# the default 25 random batches of act_data.
freqs = get_freqs(local_encoder=autoenc)

In [None]:
# Inspect the structure of `freqs` with the SMTT inspector created above.
tt(freqs)

In [None]:
# Spot check of feature 112: its sampled frequency scaled by the dataset size
# (presumably an estimate of how many rows of act_data it fires on — TODO confirm
# what get_act_density counts).
print(freqs[112] * act_data.shape[0])

In [None]:
# Histogram of how often each feature fires, on a log-scaled x-axis.
# `freqs` comes from random samples, so scaling by the dataset size gives an
# estimated count over all of act_data rather than an exact one.
n_samples = act_data.shape[0]
x = interpretability.numpy(freqs) * n_samples
x = x[np.isfinite(x)]  # drop NaN/inf entries before binning

fig, ax = plt.subplots()
fig.set_size_inches(10, 6)

# 100 log-spaced bin edges from 5 to 10M; values outside that range (including
# features that never fire) are not drawn.
ax.hist(x, bins=np.logspace(np.log10(5), np.log10(10000000), 100))
ax.set_xscale("log")

# Human-readable ticks at powers of ten.
tick_positions = [1, 10, 100, 1000, 10000, 100000, 1000000]
tick_labels = ["1", "10", "100", "1k", "10k", "100k", "1M"]
ax.set_xticks(tick_positions)
ax.set_xticklabels(tick_labels)

# Take the denominator from the data instead of hard-coding it, so the label
# cannot drift from the scaling applied to `x` above.
ax.set_xlabel(f"Number of Times Fired Out of {n_samples:,}")
ax.set_ylabel("Count of Features(neuron acts)")

In [None]:
# Persist the trained autoencoder's parameters (state_dict only, not the module
# object) to the working directory.
torch.save(
    autoenc.state_dict(), "./512_sparse_autoencoder_on_activations_20NOV2023_parameters.pt"
)