In [None]:
import sys

# Make the parent directory (presumably the repository root) and `../src`
# importable from this notebook's working directory.
sys.path.insert(0, "..")
sys.path.insert(0, "../src")

from experiments.utils import *

from src.constants import *
from src.dataset import get_data_loader
from src.models import SimpleICLModel, SimpleICLModelLearnedIWPredictor
from src.utils import parse_dict, load_config, iterate_models, set_seed

import jax.numpy as jnp
# `np` is used by later cells but was never imported explicitly; it presumably
# leaked in through one of the star imports above. Import it directly, after the
# star imports, so the notebook no longer depends on that implicit export.
import numpy as np
import os

# Penzai
from penzai import pz

import IPython

# Use penzai's treescope renderer as the default IPython display for values.
pz.ts.register_as_default()

# Optional automatic array visualization extras:
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()
pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer())

In [None]:
# Single global seed for this analysis; `set_seed` (project helper) presumably
# seeds every RNG the data loader and model rely on.
seed: int = 42
set_seed(seed)

In [None]:
# Location of the trained run to analyse. The results root defaults to the
# original machine-specific absolute path, but can be overridden with the
# SIMPLE_ICL_LOG_DIR environment variable instead of editing this cell.
log_dir = os.environ.get(
    "SIMPLE_ICL_LOG_DIR",
    "/Users/chanb/research/ualberta/simple_icl/experiments/results",
)
experiment_name = "high_prob_0.99"
variant = "transformer-09-02-24_13_39_27-bf1eeb21-928a-4829-b008-f82e5369e4d0"

learner_path = os.path.join(log_dir, experiment_name, variant)

In [None]:
# Load the run's saved configuration, patch the raw dict so the loader yields a
# single sequence per batch (later cells only ever read batch index 0), and
# re-parse it; the `config` returned by `load_config` is superseded.
config_dict, config = load_config(learner_path)

config_dict["batch_size"] = 1
config = parse_dict(
    config_dict
)

In [None]:
# Build the training loader/dataset pair from the patched (batch_size = 1) config.
train_data_loader, train_dataset = get_data_loader(config)

In [None]:
# Walk the saved checkpoints of the run: the first one is the reference
# ("init") checkpoint, and the one `checkpoint_offset` steps later is compared
# against it in the attention cell below.
model_iter = iterate_models(
    learner_path
)

try:
    params_init, model, checkpoint_step_init = next(model_iter)
except StopIteration:
    raise RuntimeError(f"No checkpoints found under {learner_path!r}.") from None

# How many checkpoints to advance past the initial one. Only the last one reached
# is kept; the model object yielded alongside it is ignored (`model` from the
# first checkpoint is reused).
checkpoint_offset = 40
assert checkpoint_offset >= 1, "need at least one checkpoint after the initial one"

for num_advanced in range(checkpoint_offset):
    try:
        params_next, _unused_model, checkpoint_step_next = next(model_iter)
    except StopIteration:
        # Surface a clear message instead of a bare StopIteration traceback.
        raise RuntimeError(
            f"Only {num_advanced + 1} checkpoint(s) found under {learner_path!r}; "
            f"need at least {checkpoint_offset + 1}."
        ) from None

In [None]:
# Select the first training batch whose last target position has argmax label
# `target_label`. batch_size was forced to 1 in the config cell, so index 0 is
# the only sequence in each batch.
target_label = 1
for batch in train_data_loader:
    if np.argmax(batch["target"][0, -1], axis=-1) == target_label:
        break
else:
    # Without this, an exhausted loader would silently leave `batch` bound to its
    # last (non-matching) batch, or undefined if the loader yielded nothing.
    raise ValueError(
        f"No batch in train_data_loader has argmax target {target_label} "
        "at the last position."
    )

In [None]:
# Display the batch chosen by the search cell above.
batch

In [None]:
def _stack_by_parity(labels, per_position):
    """Arrange a 1-D per-position vector into a 3-row table aligned with ``labels``.

    Row 0 is ``labels``; row 1 holds the even positions (0, 2, 4, ...) of
    ``per_position``; row 2 holds the odd positions (1, 3, ...), right-padded with
    NaN so it has as many entries as row 1. ``labels`` must have that same length
    (presumably one label per interleaved token pair -- TODO confirm).

    Returns:
        A ``(3, ceil(len(per_position) / 2))`` array; integer labels are promoted
        to the float dtype of the other rows.
    """
    even = per_position[::2]
    odd = per_position[1::2]
    # An odd-length sequence has one more even position than odd positions, an
    # even-length one has equally many; pad by whatever the difference is rather
    # than assuming it is always exactly one.
    pad = jnp.full((even.shape[0] - odd.shape[0],), jnp.nan)
    return jnp.vstack((labels, even, jnp.concatenate((odd, pad), axis=-1)))


# Argmax target label at every position of the selected sequence (batch index
# 0). It does not depend on the checkpoint, so it is computed once here.
target_labels = np.argmax(batch["target"][0], axis=-1)

# result[checkpoint_step][block_name] -> per-block attention / activation summaries
# for the initial and the later checkpoint, on the same batch.
result = dict()
for params, checkpoint_step in zip(
    [params_init, params_next],
    [checkpoint_step_init, checkpoint_step_next],
):
    out, aux = model.get_attention(
        params[CONST_MODEL],
        batch,
        eval=True,
    )

    intermediates = aux["gpt"]["intermediates"]
    result[checkpoint_step] = dict()
    for block in intermediates:
        if not block.startswith("GPTBlock_"):
            continue
        block_intermediates = intermediates[block]

        # The trailing [0][0] presumably selects the first sown value and then
        # batch element 0. Per the original notes: axis=1 -> query,
        # axis=2 -> key (TODO confirm against the model's attention layout).
        self_attention_map = block_intermediates["SelfAttentionModule_0"]["attention"][0][0]
        # Entries <= -1e10 (presumably masked positions) are replaced with NaN.
        self_attention_map = jnp.where(
            self_attention_map <= -1e10, jnp.nan, self_attention_map
        )

        # Collapse the last axis so each position gets a single scalar.
        attention_score = jnp.sum(block_intermediates["attention"][0][0], axis=-1)
        input_vector = jnp.sum(block_intermediates["input"][0][0], axis=-1)
        block_output = jnp.sum(block_intermediates["block_out"][0][0], axis=-1)

        result[checkpoint_step][block] = dict(
            self_attention_map=self_attention_map,
            attention_score=_stack_by_parity(target_labels, attention_score),
            input_vector=_stack_by_parity(target_labels, input_vector),
            block_output=_stack_by_parity(target_labels, block_output),
        )

In [None]:
# Render the nested result[checkpoint_step][block] dict with penzai's `pz.ts.display`.
pz.ts.display(result)

In [None]:
# NOTE(review): no cell in this notebook defines `sample`; this line appears to
# be left over from a deleted cell and raised NameError on a fresh kernel
# (Restart & Run All). Only inspect it when it actually exists in the namespace.
if "sample" in globals():
    pz.ts.display(sample["outputs"])
else:
    print("`sample` is not defined in this notebook; this cell can be removed.")