In [None]:
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib import cm
import torch
import os
import glob
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from tqdm import tqdm
import numpy as np
from tensorboard.backend.event_processing import event_accumulator
from itertools import product
from utils.utils import standardize, read_scalars, make_path

# Location of the training run whose checkpoints are visualised.
root = "/home/kitouni/projects/Grok/grokking-squared/runs/"
name = "modular-addition11-mlp/0429-1407"
directory = os.path.join(root, name)

# Scalars are read from the first file matching "events*" in the run directory.
scalars = read_scalars(glob.glob(os.path.join(directory, "events*"))[0])

# Embedding dumps, ordered by the integer that prefixes each file name.
repr_files = glob.glob(os.path.join(directory, "weights/*.embd"))
repr_files = sorted(repr_files, key=lambda x: int(os.path.basename(x).split(".")[0]))

# glob returns paths in arbitrary order, so the checkpoints are sorted with the
# same key as the embeddings: index i must refer to the same training step in
# both lists.
# NOTE(review): assumes *.ckpt files are named like the *.embd files
# ("<int>.ckpt") - confirm against the training script.
model_files = glob.glob(os.path.join(directory, "weights/*.ckpt"))
model_files = sorted(model_files, key=lambda x: int(os.path.basename(x).split(".")[0]))

# Full pickled model; its weights are later overwritten per checkpoint.
model = torch.load(os.path.join(directory, "model.pt"))
model.eval()


# Logged curves; the rest of the file reads entry [i][1] of each as the value
# to display for frame i.
loss_train, loss_test, acc_train, acc_test = scalars[
    "loss/train"], scalars["loss/test"], scalars["acc/train"], scalars["acc/test"]

def apply_model(x, file):
    """Load the checkpoint ``file`` into the global ``model`` and classify ``x``.

    Side effect: the weights of the module-level ``model`` are replaced by
    the ones stored in ``file``.

    Args:
        x: Input passed unchanged to ``model`` (presumably a float tensor of
            shape (n_samples, n_features) - TODO confirm).
        file: Path to a state dict readable by ``torch.load``.

    Returns:
        1-D NumPy array with, for every row of the model output, the index
        of its largest entry along axis 1.
    """
    # map_location makes loading independent of the device the checkpoint was
    # saved from; load_state_dict copies the values into the existing
    # parameters.
    model.load_state_dict(torch.load(file, map_location="cpu"))
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        logits = model(x)
    # .cpu() mirrors the other tensor -> NumPy conversions in this file and
    # is required when the output lives on an accelerator.
    return logits.detach().cpu().numpy().argmax(axis=1).reshape(-1)

def animate(reducer, name, nskip=1):
    """Save an mp4 showing the embeddings and the model's predictions over time.

    Every frame loads one embedding file and the checkpoint at the same
    index, colours a 100x100 grid spanning the embeddings' bounding box by
    the class ``apply_model`` predicts there, and draws the embeddings
    (rescaled to the unit square) on top as numbered markers.

    Args:
        reducer: Object exposing ``fit_transform``; it is fitted on every
            frame (see the review note in ``update``).
        name: Base name of the output file (``<name>.mp4``); also printed.
        nskip: Stride over ``repr_files``; every ``nskip``-th file becomes a
            frame.
    """
    n_repr = torch.load(repr_files[0]).cpu().numpy().shape[0]

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))

    # Regular grid over the unit square; row k is the (x, y) position of
    # heat-map cell k.
    _linspace = np.linspace(0.0, 1.0, 100)
    grid = torch.tensor(list(product(*[_linspace]*2))).float()
    grid_xy = grid.numpy()

    # Axes.scatter needs both x and y. The colour array is a placeholder that
    # update() replaces; vmin/vmax pin the colour scale so the colour bar is
    # the same in every frame.
    # NOTE(review): assumes predictions are class ids in [0, n_repr), as the
    # original placeholder colours did - confirm.
    heatmap = ax.scatter(grid_xy[:, 0], grid_xy[:, 1],
                         c=np.zeros(grid.shape[0]), s=20, cmap=cm.jet,
                         vmin=0, vmax=max(n_repr - 1, 1),
                         alpha=0.9, zorder=0, marker="s")
    plt.colorbar(heatmap, fraction=0.046, pad=0.04)

    sc = ax.scatter(*np.ones((2, n_repr)), c=np.arange(n_repr),
                    cmap=cm.viridis, s=100)
    text = ax.text(0., 1.01, "", transform=ax.transAxes, fontdict={"size": 18})
    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, 1.05)

    # Replace the default markers by the two-digit index of each embedding.
    sc.set_paths([make_path(f"${m:02d}$") for m in range(n_repr)])

    def update(i):
        """Redraw heat map, markers and title for frame ``i``."""
        file = repr_files[i]
        representation = torch.load(file).cpu().numpy()
        # NOTE(review): the reducer output is discarded and the raw
        # representation is plotted. The grid has to be mapped back into the
        # model's input space, which a fitted reducer cannot provide, so the
        # call is kept only to preserve existing behaviour - confirm it can go.
        reducer.fit_transform(representation)

        lo = representation.min(0)
        span = representation.max(0) - lo

        # Stretch the unit-square grid onto the bounding box of the
        # embeddings. Explicit tensor conversion avoids relying on implicit
        # torch/NumPy mixing.
        grid_to_repr = grid * torch.as_tensor(span) + torch.as_tensor(lo)
        output = apply_model(grid_to_repr, model_files[i])
        # The collection is colour-mapped from its array at draw time, so the
        # predictions must go through set_array; colours assigned with
        # set_color would be overridden, and integer input to a colormap
        # indexes its lookup table instead of being normalised.
        heatmap.set_array(output)

        # Rescale the first two embedding dimensions to [0, 1]; a constant
        # dimension would divide by zero, so its span is replaced by 1.
        safe_span = np.where(span[:2] == 0, 1, span[:2])
        sc.set_offsets((representation[:, :2] - lo[:2]) / safe_span)

        title = os.path.basename(file).split(".")[0]
        title += f"\nLoss: {loss_train[i][1]:.2e}|{loss_test[i][1]:.2e}"
        title += f" Acc: {acc_train[i][1]:.2f}|{acc_test[i][1]:.2f}"

        text.set_text(title)
        return heatmap, sc, text

    print(f"now animating {name}")
    range_obj = range(0, len(repr_files), nskip)
    animation = FuncAnimation(
        fig, update, frames=range_obj, repeat=False, blit=False)
    fname = f"/home/kitouni/projects/Grok/grokking-squared/{name}.mp4"
    # The context manager closes the progress bar even if saving fails.
    with tqdm(range_obj) as tbar:
        animation.save(fname,
                       writer="ffmpeg", progress_callback=lambda *_: tbar.update())
    # Close this figure explicitly rather than whichever one is current.
    plt.close(fig)
    print(f"animation saved to {fname}")


if __name__ == "__main__":
    # ``args`` was never defined in this file; build it from the command line.
    import argparse

    parser = argparse.ArgumentParser(
        description="Animate embeddings and model predictions of a run.")
    parser.add_argument("--nskip", type=int, default=1,
                        help="use every nskip-th checkpoint as a frame")
    parser.add_argument("--which", choices=("tsne", "pca", "both"),
                        default="both", help="reducer(s) to animate with")
    # parse_known_args tolerates extra argv entries, e.g. the ones a notebook
    # kernel injects when this cell is run interactively.
    args, _ = parser.parse_known_args()
    if args.nskip < 1:
        # range() in animate() rejects a zero step and a negative one would
        # yield no frames.
        parser.error("--nskip must be >= 1")

    nskip = args.nskip
    which = ["tsne", "pca"] if args.which == "both" else [args.which]

    # Prefix shared by both output file names.
    run_tag = "-".join(name.split("/")[-3:-1])

    if "tsne" in which:
        tsne = TSNE(n_components=2, perplexity=40, n_iter=1000,
                    init="pca", learning_rate="auto", )
        animate(tsne, run_tag + "_tsne", nskip=nskip)
    if "pca" in which:
        pca = PCA()
        animate(pca, run_tag + f"_pca{nskip}", nskip=nskip)
