In [None]:
from pathlib import Path

import numpy as np
import torch

from sainomore.xai import ElissabethWatcher

from data import LetterAssembler

In [None]:
# Identifier handed to ElissabethWatcher.load; fill in before running the notebook.
model_id: str = ""

In [None]:
# Data assembler over the quotes corpus; provides tokenization and sampling.
assembler = LetterAssembler(Path("quotes.txt"))
# Sizes derived from the data. NOTE(review): `config` is not referenced in the
# visible cells — confirm it is still needed.
config = dict(
    context_length=assembler.context_length,
    characters=assembler.vocab_size,
)

In [None]:
# Load the model identified by `model_id` onto the CPU.
# NOTE(review): presumably `model_id` must be non-empty for this to succeed.
watcher = ElissabethWatcher.load(
    model_id,
    on_cpu=True,
)

## Model-based analysis

In [None]:
# Alphabet projection of the `v` component only (q and k disabled) for token ids 0, 82, 83.
fig, ax = watcher.plot_alphabet_projection(
    q=False, k=False, v=True,
    n=4,
    tokens=torch.tensor([0, 82, 83]),
    positions=3,
    transpose=True,
    reduce_dims={0: 0},
    figsize=(10, 2),
)
# To export with a transparent background:
# fig.savefig(f"copying_values_{model_id}.pdf", facecolor=(0, 0, 0, 0), bbox_inches="tight")

In [None]:
# Weight matrices at indices 0 and 2 of the P_Q transform (layer 0, level 0, weighting 0).
# Optional for the second matrix: log_cmap=(0.1, 1.0)
for transform_idx in (0, 2):
    fig, ax = watcher.plot_parameter_matrix(
        f"layers.0.levels.0.weightings.0.P_Q.transform.{transform_idx}.weight",
        figsize=(10, 5),
    )

In [None]:
# Layer-0 W_H and W_O parameters. W_H is plotted as-is (reduce_dims={2: 0} and
# append_dims=(0, 1) can be passed here too); W_O uses both options.
fig, ax = watcher.plot_parameter_matrix("layers.0.W_H", figsize=(2, 2))
fig, ax = watcher.plot_parameter_matrix(
    "layers.0.W_O",
    figsize=(2, 5),
    reduce_dims={2: 0},
    append_dims=(0, 1),
)

In [None]:
# Unembedding weight matrix.
fig, ax = watcher.plot_parameter_matrix("unembedding.weight", figsize=(10, 5))

In [None]:
# Probe the `q` component of layer 0, length 0, weighting 0.
fig, ax = watcher.plot_qkv_probing(
    which="q", layer=0, length=0, weighting=0,
    reduce_dims={2: 0}, append_dims=(2,),
    figsize=(25, 2),
)

In [None]:
# Probe the `k` component of layer 0, length 0, weighting 0.
fig, ax = watcher.plot_qkv_probing(
    which="k", layer=0, length=0, weighting=0,
    reduce_dims={2: 0}, append_dims=(2,),
    figsize=(20, 5),
)

In [None]:
# Probe the `v` component of layer 0, length 0, weighting 0.
fig, ax = watcher.plot_qkv_probing(
    which="v", layer=0, length=0, weighting=0,
    reduce_dims={2: 0}, append_dims=(2,),
    figsize=(20, 2),
)

In [None]:
import json  # NOTE(review): unused in the visible cells; move to the import cell or drop.

# Mean number of characters per line of quotes.txt. Iterating a text file keeps
# the trailing "\n" on each line, so strip it to avoid overcounting every line by one.
with open("quotes.txt", encoding="utf-8") as f:
    lengths = [len(line.rstrip("\n")) for line in f]
np.mean(lengths)

## Example-based analysis

In [None]:
def generate(start: str, n_tokens: int, temperature: float = 0.1) -> str:
    """Autoregressively extend ``start`` by ``n_tokens`` tokens.

    Uses the notebook-level ``assembler`` to encode/decode text and
    ``watcher.model`` to produce logits for the next token.

    Args:
        start: Prompt text, encoded with ``assembler.to_tensor(start, fill=False)``.
        n_tokens: Number of tokens to append.
        temperature: Softmax temperature applied to the last-position logits
            before sampling. ``0`` selects the arg-max token (greedy decoding)
            instead of dividing by zero.

    Returns:
        The prompt followed by the generated continuation, decoded with
        ``assembler.translate``.
    """
    tokens = assembler.to_tensor(start, fill=False).unsqueeze(0)
    # Inference only: don't build an autograd graph across the generation loop.
    with torch.no_grad():
        for _ in range(n_tokens):
            # Logits of the last position; the model output is indexed as
            # (batch, position, vocab) here.
            logits = watcher.model(tokens)[:, -1, :]
            if temperature == 0:
                next_token = logits.argmax(dim=-1, keepdim=True)
            else:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            tokens = torch.cat((tokens, next_token), dim=1)
    return assembler.translate(tokens[0])

In [None]:
# Prompt mirrors the start of a data line ({"quote":"...). A single-quoted
# literal avoids the invalid "\{" escape, which also left a stray backslash in the prompt.
generate('{"quote":"Love ', 10, temperature=0.5)

In [None]:
# Draw a reproducible example: seed torch and numpy before sampling.
torch.random.manual_seed(62)
np.random.seed(62)
x, y = assembler.sample()
print(assembler.translate(x[0]))
# Arg-max model output for the first 25 positions; ids missing from `itos` print as " ".
# NOTE(review): if position i predicts token i+1, this line is offset by one
# character relative to the input above — confirm.
print("".join(
    assembler.itos.get(token_id, " ")
    for token_id in watcher.model(x[:, :25]).argmax(-1)[0].numpy()
))

In [None]:
# Hand-picked positions used to index the sampled sequence x[0].
time = torch.tensor([1, 2, 3, 4, 9, 94, 95, 96])
print(x[0][time])
list(assembler.translate(x[0][time[:-1]]))

In [None]:
# ISS values over time for x[0], restricted to the positions in `time` on the last axis.
# Labels: "{" first, then the characters at all but the last selected position, then "-".
fig, ax = watcher.plot_iss_time(
    x[0],
    x_axis=["{", *assembler.translate(x[0][time][:-1]), "-"],
    index_selection=((-1, time),),
    figsize=(20, 5),
)

In [None]:
# Token ids to visualize: a few singles plus the contiguous runs 13..51 and 56..83.
indices = [1, 2, 6, 7, 8, 9, 11, *range(13, 52), *range(56, 84)]

In [None]:
# Attention matrices over the token ids in `indices`. The labels cover token 0
# followed by each selected id (one more label than ids).
# Other settings previously tried here:
#   xlabels=["{"] + list(assembler.translate(x[0][time][:-1])) + ["-"]
#   total=True
#   value_direction=1
#   project_heads=tuple(torch.where(W_H.abs()[0] > 5)[0].numpy())  # W_H not defined in this notebook
#   index_selection=((-2, torch.arange(100)), (-1, torch.arange(100)))
#   index_selection=((-2, time), (-1, time))
fig, ax = watcher.plot_attention_matrices(
    torch.tensor(indices),
    xlabels=list(assembler.translate(torch.tensor([0] + indices))),
    show_example=False,
    cmap="hot_r",
    share_cmap=True,
    log_cmap=False,
    causal_mask=False,
    only_kernels=None,
    all_but_first_value=False,
    project_heads=True,
    center_zero=False,
    cmap_example="tab20",
    figsize=(50, 50),
)
# To export with a transparent background:
# fig.savefig(f"attention_{model_id}_all_tokens.pdf", facecolor=(0, 0, 0, 0), bbox_inches="tight")