In [None]:
import json
from pathlib import Path

import numpy as np
import torch
from matplotlib import pyplot as plt

from sainomore.data import occurences
from sainomore.xai import ElissabethWatcher, get_alphabet_projection

In [None]:
# Identifier of the model inspected in this notebook. It is handed to
# `ElissabethWatcher.load` and is also embedded in the filename of the
# exported attention figure.
# NOTE(review): looks like an experiment-tracker run id — confirm its origin.
model_id = "kukzl2q4"

In [None]:
# Build the watcher used by every plotting cell below from `model_id`.
# `on_cpu=True` is passed to the loader; presumably this keeps the model on the
# CPU so no GPU is required — confirm against `ElissabethWatcher.load`.
watcher = ElissabethWatcher.load(model_id, on_cpu=True)

## Model-based

In [None]:
# Alphabet projection with only the `v` flag enabled (`q` and `k` are off).
projection_kwargs = dict(
    q=False,
    k=False,
    v=True,
    transpose=False,
    reduce_dims={0: 0},
)
fig, ax = watcher.plot_alphabet_projection(figsize=(70, 20), **projection_kwargs)
# Export is currently disabled; re-enable with:
# fig.savefig(f"copying_values_{model_id}.pdf", facecolor=(0, 0, 0, 0), bbox_inches="tight")

In [None]:
# Plot the `transform.0` and `transform.2` weights of `P_Q`, one figure per
# parameter. After the loop `fig`/`ax` refer to the last figure drawn.
p_q_weight_names = (
    "layers.0.levels.0.weightings.0.P_Q.transform.0.weight",
    "layers.0.levels.0.weightings.0.P_Q.transform.2.weight",
)
for parameter_name in p_q_weight_names:
    # A `log_cmap=(0.1, 1.0)` option was tried for the second matrix and is
    # currently disabled.
    fig, ax = watcher.plot_parameter_matrix(parameter_name, figsize=(10, 5))

In [None]:
# Plot `W_H` as-is and `W_O` with extra dimension handling; each entry pairs a
# parameter name with the keyword arguments specific to it.
# (For `W_H`, `reduce_dims={2: 0}` / `append_dims=(0, 1)` were tried and are
# currently disabled.)
head_and_output_plots = (
    ("layers.0.W_H", {}),
    ("layers.0.W_O", {"reduce_dims": {2: 0}, "append_dims": (0, 1)}),
)
for parameter_name, extra_kwargs in head_and_output_plots:
    fig, ax = watcher.plot_parameter_matrix(
        parameter_name,
        figsize=(10, 10),
        **extra_kwargs,
    )

In [None]:
# Plot the `unembedding.weight` parameter matrix.
fig, ax = watcher.plot_parameter_matrix("unembedding.weight", figsize=(10, 10))

In [None]:
# Probing plot for the "q" component at layer 0, length 0, weighting 0.
query_probe_target = dict(which="q", layer=0, length=0, weighting=0)
fig, ax = watcher.plot_qkv_probing(
    **query_probe_target,
    reduce_dims={2: 0},
    append_dims=(2,),
    figsize=(25, 2),
)

In [None]:
# Probing plot for the "k" component at layer 0, length 0, weighting 0.
fig, ax = watcher.plot_qkv_probing(
    which="k", layer=0, length=0, weighting=0,
    reduce_dims={2: 0}, append_dims=(2,),
    figsize=(20, 5),
)

In [None]:
# Probing plot for the "v" component at layer 0, length 0, weighting 0.
value_probe_target = {"which": "v", "layer": 0, "length": 0, "weighting": 0}
fig, ax = watcher.plot_qkv_probing(
    reduce_dims={2: 0},
    append_dims=(2,),
    figsize=(20, 5),
    **value_probe_target,
)

In [None]:
# Mean number of characters per line in `quotes.txt`.
# The file is streamed line by line instead of being read into a list first.
# Each length includes the line's trailing newline, because text-mode
# iteration keeps line terminators.
# NOTE(review): confirm that counting the newline is intended.
with open("quotes.txt", encoding="utf-8") as f:
    lengths = [len(line) for line in f]
# An empty file yields `nan` (NumPy emits a RuntimeWarning in that case).
np.mean(lengths)

## Example-based

In [None]:
# Generate the example used by the plots below and display it.
# The first positional argument (1) is presumably the number of samples —
# confirm against `sainomore.data.occurences`.
sample_spec = {"length": 100, "characters": 5, "occurences": 4}
x, y = occurences(1, **sample_spec)
x, y

In [None]:
# Index selection handed to `plot_iss_time`: a single (-1, arange(100)) pair.
# NOTE(review): 100 matches the `length=100` used when generating `x` — keep
# the two in sync.
iss_index_selection = ((-1, torch.arange(100)),)
fig, ax = watcher.plot_iss_time(
    x[0],
    append_dims=(1,),
    index_selection=iss_index_selection,
    project_heads=False,
    project_values=False,
    figsize=(10, 4),
)
# Disabled options for the call above:
#     x_axis=watcher.model(x).argmax(-1)[0, :]
#     reduce_dims={1: 0}
# Disabled alternative plot:
# fig, ax = watcher.plot_iss(
#     x[0], append_dims=(1,), project_heads=True, project_values=False, figsize=(20, 20),
# )

In [None]:
# Query/key plot over time for weighting 0 of the first example.
# Disabled options: `x_axis=x[0]`, and a head selection
# `project_heads=tuple(torch.where(W_H.abs()[0] > 5)[0].numpy())`
# (`W_H` is not defined in this notebook, so that variant needs it first).
fig, ax = watcher.plot_query_key_time(
    x[0],
    weighting=0, names=("query", "key"), cmap="tab10",
    project_heads=False, reduce_dims={0: 0},
    figsize=(20, 5),
)

In [None]:
# Values plot over time for the first example, which is also passed as `x_axis`.
# Disabled option: `project_heads=tuple(torch.where(W_H.abs()[0] > 5)[0].numpy())`
# (`W_H` is not defined in this notebook).
example = x[0]
fig, ax = watcher.plot_values_time(
    example,
    x_axis=example,
    project_heads=False,
    reduce_dims={0: 0},
    figsize=(20, 5),
)

In [None]:
# Attention-matrix figure for the first example, written to a PDF named after
# the model.
# Disabled variants: `log_cmap=(10, 1)` and
# `project_heads=tuple(torch.where(W_H.abs()[0] > 5)[0].numpy())`.
attention_plot_options = dict(
    show_example=True,
    total=True,
    cmap="seismic",
    share_cmap=False,
    log_cmap=False,
    causal_mask=True,
    only_kernels=None,
    project_heads=False,
    center_zero=True,
    cmap_example="tab20",
    figsize=(50, 10),
)
fig, ax = watcher.plot_attention_matrices(x[0], **attention_plot_options)

# An RGBA facecolor with alpha 0 gives the saved figure a transparent background.
attention_pdf = f"attention_{model_id}.pdf"
fig.savefig(attention_pdf, facecolor=(0, 0, 0, 0), bbox_inches="tight")