In [None]:
import torch
from matplotlib import pyplot as plt
import numpy as np

from sainomore.xai import ElissabethWatcher, get_alphabet_projection

from data import cyclic

In [None]:
# Identifier handed to `ElissabethWatcher.load` in the next cell.
# NOTE(review): the empty string looks like a placeholder -- fill in the id of
# the model to inspect before running the notebook.
model_id = ""

In [None]:
# Build the watcher for `model_id`; every plot below goes through this object.
# `on_cpu=True` presumably keeps the model on the CPU -- confirm against the
# sainomore documentation.
watcher = ElissabethWatcher.load(model_id, on_cpu=True)

In [None]:
# Disabled experiment: fetch the model's state dict, overwrite "layers.0.W_O"
# with a tensor of ones (the "layers.0.W_H" override is additionally disabled)
# and load the modified dict back into the model. Uncomment to run.
# state_dict = watcher.model.state_dict()
# # state_dict["layers.0.W_H"] = torch.tensor([[0, 1, 1, 1, 1]])
# state_dict["layers.0.W_O"] = torch.tensor([[[1, 1, 1, 1, 1, 1]]])
# watcher.model.load_state_dict(state_dict)

## Model-based

In [None]:
# Plot two parameters of layer 0. Each call rebinds `fig, ax`, so the handles
# of the second plot ("layers.0.W_O") are the ones left in scope afterwards.
w_h_options = {"figsize": (3, 3)}
w_o_options = {"reduce_dims": {2: 0}, "append_dims": (0, 1), "figsize": (10, 3)}

fig, ax = watcher.plot_parameter_matrix("layers.0.W_H", **w_h_options)
fig, ax = watcher.plot_parameter_matrix("layers.0.W_O", **w_o_options)

In [None]:
# Plot the embedding and unembedding weights with identical figure settings.
# `fig, ax` end up bound to the last plot ("unembedding.weight").
for parameter_name in ("embedding.weight", "unembedding.weight"):
    fig, ax = watcher.plot_parameter_matrix(parameter_name, figsize=(4, 4))

## Example-based

In [None]:
# Draw one example from `cyclic` (first argument 1, length 5, 6 characters).
x, y = cyclic(1, length=5, characters=6)

# Show the raw input, then only its non-zero entries.
print(x)
nonzero_entries = x[x != 0]
print(nonzero_entries)

# Compare the target with the model output at batch index 0, last position.
last_output = watcher.model(x)[0, -1]
print(y, last_output)

In [None]:
# Bar charts of the values returned by `watcher.get_values` for token ids 0..5.
#
# The ids are stored in `tokens` instead of `x` on purpose: `x` already holds
# the sample drawn from `cyclic`, and the plotting cells further down index it
# as `x[0]`. Rebinding `x` to `torch.arange(6)` here would hand those cells a
# 0-dim tensor on a fresh "Restart & Run All".
tokens = torch.arange(6)
v = watcher.get_values(tokens)

# 4 x 2 grid addressed by the first two axes of `v`
# (assumes `v` is at least 4 x 2 x 1 x 6 -- TODO confirm against `get_values`).
fig, ax = plt.subplots(4, 2, sharex=True, sharey=True, figsize=(5, 6))
for i in range(4):
    for j in range(2):
        bar_heights = v[i, j, 0]
        # Computed once per panel instead of once per bar; the tallest bar is
        # drawn in red, all others in orange.
        peak = int(bar_heights.argmax())
        ax[i, j].bar(
            tokens,
            bar_heights.numpy(),
            color=["red" if k == peak else "orange" for k in range(6)],
            width=.75,
        )
        ax[i, j].set_xticks(tokens)
        ax[i, j].set_xticklabels(tokens.numpy(), fontsize=15)
        # Eight ticks from -5 to 2, of which only -5, 0 and 2 are labelled.
        ax[i, j].set_yticks(torch.linspace(-5, 2, 8))
        ax[i, j].set_yticklabels(["-5", "", "", "", "", "0", "", "2"])
        ax[i, j].set_ylim(-5, 2)
        ax[i, j].grid()
# Column headers: left column is v^[1], right column is v^[2].
ax[0, 0].set_title("$v^{[1]}$", fontsize=20)
ax[0, 1].set_title("$v^{[2]}$", fontsize=20)
# Uncomment to export the figure with a transparent background.
# plt.savefig(
#     "cyclic_arctic_values.pdf",
#     facecolor=(0, 0, 0, 0),
#     bbox_inches="tight",
# )

In [None]:
# Plot the values over time for the first entry of `x`, using the entries
# themselves as x-axis labels.
example = x[0]
fig, ax = watcher.plot_values_time(
    example,
    x_axis=example.numpy(),
    # Alternative: tuple(torch.where(W_H.abs()[0] > 5)[0].numpy())
    project_heads=False,
    # Optional: reduce_dims={0: 0}
    figsize=(10, 5),
)

In [None]:
# Plot the ISS over time for the first entry of `x`, labelled by its entries.
iss_time_options = {
    "x_axis": x[0].numpy(),
    # Alternative: tuple(torch.where(W_H.abs()[0] > 5)[0].numpy())
    "project_heads": False,
    # Optional: "reduce_dims": {0: 0}
    "figsize": (10, 5),
}
fig, ax = watcher.plot_iss_time(x[0], **iss_time_options)

In [None]:
# Attention matrices for the sequence 0, 1, ..., 5.
# Alternatives tried for the input: the same sequence with `.flip(0)`, or `x[0]`.
attention_input = torch.arange(6)
attention_labels = [str(k) for k in range(6)]

fig, ax = watcher.plot_attention_matrices(
    attention_input,
    xlabels=attention_labels,
    show_example=False,
    # total=True,
    cmap="seismic",
    share_cmap=False,
    log_cmap=False,
    causal_mask=False,
    only_kernels=None,
    # value_direction=0,
    # all_but_first_value=False,
    # Alternative: tuple(torch.where(W_H.abs()[0] > 5)[0].numpy())
    project_heads=False,
    center_zero=True,
    # index_selection=((-1, torch.arange(95, 100)), (-2, torch.arange(95, 100))),
    cmap_example="tab10",
    figsize=(9, 10),
)
# Uncomment to export the figure with a transparent background.
# fig.savefig(
#     f"cyclic_attention_{model_id}.pdf",
#     facecolor=(0, 0, 0, 0),
#     bbox_inches="tight",
# )