# Visualizing Var Rename Activations

In [None]:
import os

import numpy as np
import torch

from iluvattnshun.nn import MultilayerTransformer
from iluvattnshun.utils import load_checkpoint, load_config_from_yaml
from iluvattnshun.viz import get_fig

from var_rename import VariableRenamingConfig, VariableRenamingPrompter

In [24]:
# Directory of the training run to visualize. Set the VAR_RENAME_RUN_DIR environment
# variable to point at a run on another machine; the default is the original
# author's local run, so the resulting paths are unchanged when it is not set.
RUN_DIR = os.environ.get(
    "VAR_RENAME_RUN_DIR",
    "/home/michael-lutz/iluvattnshun/logs/var_rename/run_221_20250623_143312",
)
# The run's YAML config carries the same name as the run directory.
# NOTE(review): true for the default run; confirm it holds for every run layout.
RUN_NAME = os.path.basename(os.path.normpath(RUN_DIR))
# Epoch number of the checkpoint file to load from RUN_DIR.
CKPT_EPOCH = 11

config_path = os.path.join(RUN_DIR, f"{RUN_NAME}.yaml")
ckpt_path = os.path.join(RUN_DIR, f"ckpt_epoch_{CKPT_EPOCH}.pt")

In [27]:
# Rebuild the trained model and its prompter from the run's saved config and checkpoint.
config = load_config_from_yaml(config_path, VariableRenamingConfig)

# Vocabulary size passed to the model constructor.
# NOTE(review): hard-coded; presumably it has to match the prompter's tokenizer and
# the checkpoint being loaded — confirm before pointing this notebook at another run.
VOCAB_SIZE = 39

# Four positions per rename (the prompts in this notebook look like "1>o;o>n;...").
# NOTE(review): this value is not consumed by any cell visible in this notebook.
max_seq_len = config.num_renames * 4

# loading the model
model = MultilayerTransformer(
    vocab_size=VOCAB_SIZE,
    d_model=config.dim_model,
    n_heads=config.num_heads,
    n_layers=config.num_layers,
)
load_checkpoint(ckpt_path, model)
# switch to evaluation mode before running inference
model.eval()

# loading the prompter
prompter = VariableRenamingPrompter(config)

In [None]:
# Sample one prompt, run the model on it, and print the prompt, the per-token
# prediction, and the true answer.

# Fixed seed so that Restart & Run All draws the same prompt every time.
# Set SEED = None to get a fresh random prompt on each run.
SEED = 0
rng = np.random.default_rng(SEED)
prompt, answer, metadata = prompter.get_prompt(rng)

# To inspect one specific example instead, uncomment the pair below; it overrides
# the sampled prompt/answer.
# prompt = "1>o;o>n;n>i;0>k;i>p;k>f;f>o;o>g;g>c;p>e;c>j;j>z;e>v;v>e;e>p;z>m;m>s;s>x;p>i;i>t;t>w;x>d;w>j;j>l;l>k;k>o;o>g;g>f;f>u;d>e;u>l;l>g;g>i;e>j;j>p;i>z;p>h;h>j;z>i;j>r;"
# answer = "1.1.1.1.1.1.0.0.1.1.0.0.0.0.0.0.0.0.1.1.0.0.0.0.1.1.1.1.1.1.0.0.0.0.0.0.1.1.1.1.1.1.0.0.1.1.1.1.1.1.1.1.1.1.1.1.1.1.0.0.1.1.1.1.1.1.0.0.0.0.1.1.0.0.0.0.1.1.0.0."

# unsqueeze(0) adds a leading batch axis of size 1
x = torch.tensor(prompter.tokenize(prompt)).unsqueeze(0)

# Inference only, so skip autograd bookkeeping. The module is called directly rather
# than through `.forward` so that any registered hooks run (assumes
# MultilayerTransformer is a torch.nn.Module — TODO confirm).
with torch.no_grad():
    logits, attn_weights, _ = model(x, return_attn_weights=True)

# greedy prediction: highest-scoring token at every position of the single batch item
pred = prompter.detokenize(logits[0].argmax(dim=-1).tolist())

print("prompt: ", prompt)
print("pred:   ", pred)
print("answer: ", answer)


prompt:  1>o;o>n;n>i;0>k;i>p;k>f;f>o;o>g;g>c;p>e;c>j;j>z;e>v;v>e;e>p;z>m;m>s;s>x;p>i;i>t;t>w;x>d;w>j;j>l;l>k;k>o;o>g;g>f;f>u;d>e;u>l;l>g;g>i;e>j;j>p;i>z;p>h;h>j;z>i;j>r;
pred:    1111111111110000111100000000000000001111000000001111111111110000000000001111111111110000111111111110111110111111111001001011101011110101000010110000000111100001
answer:  1.1.1.1.1.1.0.0.1.1.0.0.0.0.0.0.0.0.1.1.0.0.0.0.1.1.1.1.1.1.0.0.0.0.0.0.1.1.1.1.1.1.0.0.1.1.1.1.1.1.1.1.1.1.1.1.1.1.0.0.1.1.1.1.1.1.0.0.0.0.1.1.0.0.0.0.1.1.0.0.


In [29]:
# Convert every entry of attn_weights to a NumPy array without its leading batch
# axis, then build the attention figure labelled with the prompt's characters.
weights = []
for layer_attn in attn_weights:
    layer_attn_np = layer_attn.detach().cpu().numpy()
    weights.append(layer_attn_np[0])  # index 0 picks the only batch element

token_labels = list(prompt)  # one label per character of the prompt
plotly_fig = get_fig(weights, token_labels)

In [30]:
# Render the attention figure returned by get_fig(weights, list(prompt)).
plotly_fig.show()

In [None]:
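# Disabled alternative: plot the attention weights for every layer and head with matplotlib.
# To use it, uncomment this cell and add `import matplotlib.pyplot as plt` to the imports
# (`plt` is not imported anywhere in this notebook).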

# token_labels = list(prompt)
# num_layers = len(attn_weights)
# num_heads = config.num_heads
# total_heads = num_layers * num_heads

# rows = num_layers
# cols = num_heads
# fig, axes = plt.subplots(rows, cols, figsize=(cols * 5, rows * 5), squeeze=False)

# head_idx = 0
# for layer_idx, attn_weight_heads in enumerate(attn_weights):
#     for head_number in range(num_heads):
#         row = head_idx // cols
#         col = head_idx % cols
#         ax = axes[row][col]

#         attn_weight = attn_weight_heads[0, head_number].detach().cpu().numpy()
#         im = ax.imshow(attn_weight, cmap="viridis")

#         ax.set_title(f"L{layer_idx}H{head_number}", fontsize=20)

#         ax.set_xticks(range(len(token_labels)))
#         ax.set_xticklabels(token_labels, fontsize=7)
#         ax.set_yticks(range(len(token_labels)))
#         ax.set_yticklabels(token_labels, fontsize=7)

#         head_idx += 1

# # Hide any unused subplots
# for i in range(head_idx, rows * cols):
#     fig.delaxes(axes[i // cols][i % cols])

# fig.suptitle("Attention Weights by Layer and Head", fontsize=25)
# plt.tight_layout(rect=[0, 0, 1, 0.95])
# plt.show()
