In [1]:
from utils import set_size, pgf_with_latex
import sys

# Make the repo root and src/ importable so `experiments.*` and `src.*` resolve
# (assumes the notebook is run from a subdirectory of the repo root).
sys.path.insert(0, "..")
sys.path.insert(0, "../src")

from experiments.utils import *

from src.constants import *
from src.dataset import get_data_loader
from src.utils import parse_dict, load_config, iterate_models, set_seed

# `jax` is used directly later (jax.nn.softmax); import it explicitly rather than
# relying on it leaking out of one of the star imports above.
import jax
import jax.numpy as jnp

import _pickle as pickle
import math
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as plticker
import numpy as np
import os
import seaborn as sns
import pandas as pd
import timeit

from itertools import product

# Penzai
from penzai import pz

import IPython

# Use penzai's treescope renderer as the default pretty-printer in this kernel.
pz.ts.register_as_default()

# Optional automatic array visualization extras:
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()
pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer())

In [2]:
# Seed every RNG up front, before the data loader below draws any batches.
seed = 0
set_seed(seed)

In [3]:
# Colour-blind-safe palette first, then layer the LaTeX/PGF rc settings on top
# (plt.rcParams and mpl.rcParams are the same global object).
plt.style.use("seaborn-v0_8-colorblind")
mpl.rcParams.update(pgf_with_latex)

In [4]:
# Run to analyse. Set the LEARNER_PATH environment variable to point at a different run
# instead of editing this cell.
# NOTE(review): the default is an absolute, machine-specific path.
learner_path = os.environ.get(
    "LEARNER_PATH",
    "/Users/chanb/research/ualberta/icl/cc_results/paper_experiments/results/synthetic-transformer-p_relevant/dataset_size_1024-p_relevant_context_0.99-seed_0-09-18-24_22_41_46-26651528-70d9-40f2-998a-1959fe28a94e",
)

# Other runs previously analysed with this notebook:
#   /Users/chanb/research/ualberta/simple_icl/experiments/results/high_prob_0.99/transformer-09-02-24_13_39_27-bf1eeb21-928a-4829-b008-f82e5369e4d0
#   /Users/chanb/research/ualberta/icl/cc_results/paper_experiments/results/synthetic-transformer-p_relevant/dataset_size_16384-p_relevant_context_0.99-seed_0-09-18-24_22_42_11-43f5227d-b0c5-4f3a-984f-b45ad9a8e238
#   /Users/chanb/research/ualberta/icl/cc_results/paper_experiments/results/synthetic-transformer-p_relevant/dataset_size_1048576-p_relevant_context_0.99-seed_0-09-21-24_16_36_00-11712db5-c996-4adb-bd63-2c9d977c2f30
#   /Users/chanb/research/ualberta/icl/cc_results/paper_experiments/omniglot_models-attention/dataset_size_10000-p_relevant_context_0.9-input_noise_std_0.0-seed_0-09-25-24_22_02_30-e889951d-e10c-4e8a-96d7-b646b4e50249
#   /Users/chanb/research/ualberta/icl/cc_results/paper_experiments/omniglot_models-attention/dataset_size_1000000-p_relevant_context_0.9-input_noise_std_0.0-seed_0-09-25-24_22_04_21-97704873-1c50-4b22-83e3-64b89b5fd30a

In [5]:
# First checkpoint yielded by iterate_models is treated as the "initial" model;
# exhausting the iterator leaves the last checkpoint in params_next / checkpoint_step_next.
model_iter = iterate_models(learner_path)

params_init, model, checkpoint_step_init = next(model_iter)

# Default to the first checkpoint so these names are always bound, even when the run
# saved only a single checkpoint (the loop body would otherwise never execute).
params_next, checkpoint_step_next = params_init, checkpoint_step_init
for params_next, _, checkpoint_step_next in model_iter:
    pass

In [6]:
num_samples = 20

# Cache evaluation batches in the run's parent (experiment-group) directory, so every model in
# the same group is analysed on the same inputs, while batches from one group (e.g. omniglot)
# are never reused for another (e.g. synthetic).
samples_path = os.path.join(os.path.dirname(learner_path), "samples.pkl")
if os.path.isfile(samples_path):
    # NOTE(review): pickle can execute arbitrary code; only load caches this notebook wrote.
    with open(samples_path, "rb") as samples_file:
        samples = pickle.load(samples_file)
else:
    samples = []
    config_dict, config = load_config(learner_path)
    # One sequence per batch, with evaluation-time overrides of the dataset settings.
    config_dict["batch_size"] = 1
    config_dict["dataset_kwargs"]["flip_label"] = 1
    config_dict["dataset_kwargs"]["p_high"] = 0.5
    config = parse_dict(config_dict)
    print(config)

    train_data_loader, train_dataset = get_data_loader(config)
    for batch in train_data_loader:
        labels = np.argmax(batch["target"], axis=-1)
        # Keep only sequences whose last (query) label also occurs earlier in the context.
        if np.any(labels[:, :-1] == labels[:, [-1]]):
            samples.append(batch)

        if len(samples) == num_samples:
            break

    with open(samples_path, "wb") as samples_file:
        pickle.dump(samples, samples_file)

In [7]:
# Logits at or below this value are treated as masked-out attention positions.
# NOTE(review): assumes the model fills masked logits with a value far below this — confirm.
MASKED_LOGIT_THRESHOLD = -1e10


def _split_token_rows(per_token, label_row):
    """Stack labels, even-position values and NaN-padded odd-position values into one array.

    Assumes an odd-length token sequence (so the odd positions are one shorter than the even
    ones) whose even positions line up with `label_row` — TODO confirm against the tokenizer.
    """
    return jnp.vstack(
        (
            label_row,
            per_token[::2],
            jnp.concatenate((per_token[1::2], jnp.array([jnp.nan])), axis=-1),
        )
    )


results = []

for batch in samples:
    # Label index per position of the (single) sequence in the batch.
    labels_seq = np.argmax(batch["target"][0], axis=-1)
    result = {"label": np.argmax(batch["target"], axis=-1)}

    for params, checkpoint_step in zip(
        (params_init, params_next),
        (checkpoint_step_init, checkpoint_step_next),
    ):
        _, aux = model.get_attention(
            params[CONST_MODEL],
            batch,
            eval=True,
        )
        intermediates = aux["gpt"]["intermediates"]

        result[checkpoint_step] = dict()
        for block in intermediates:
            if not block.startswith("GPTBlock_"):
                continue
            block_aux = intermediates[block]

            # axis=1 -> query
            # axis=2 -> key
            attention_logits = block_aux["SelfAttentionModule_0"]["attention"][0][0]
            # Detect masked positions on the logits: after softmax every value lies in
            # [0, 1], so a "<= -1e10" test on the probabilities can never match.
            masked = attention_logits <= MASKED_LOGIT_THRESHOLD
            self_attention_map = jnp.where(
                masked, jnp.nan, jax.nn.softmax(attention_logits)
            )

            attention_score = jnp.sum(block_aux["attention"][0][0], axis=-1)
            input_vector = jnp.sum(block_aux["input"][0][0], axis=-1)
            block_output = jnp.sum(block_aux["block_out"][0][0], axis=-1)

            result[checkpoint_step][block] = dict(
                self_attention_map=self_attention_map,
                attention_score=attention_score,
                input_vector=_split_token_rows(input_vector, labels_seq),
                block_output=_split_token_rows(block_output, labels_seq),
            )

        pred, _ = model.forward(
            params[CONST_MODEL],
            batch,
            eval=True,
        )
        result[checkpoint_step]["prediction"] = pred

    results.append(result)

    if len(results) == num_samples:
        break

In [None]:
# Inspect the most recent sample explicitly instead of relying on the `result` loop variable
# leaking out of the extraction cell.
pz.ts.display(results[-1])

In [None]:
# Number of samples that were run through both checkpoints (shown as the cell output).
num_results = len(results)
num_results

In [None]:
# Grid layout: each sample gets one column and three rows (GPTBlock_0 attention,
# GPTBlock_1 attention, softmaxed prediction), all taken at the final checkpoint.
# Samples wrap onto a new group of three rows every `num_cols` samples.
assert results, "no results to plot"
num_cols = min(10, len(results))
num_row_groups = math.ceil(len(results) / num_cols)  # ceil: a partial last group still gets rows
num_rows = 3 * num_row_groups

fig, axes = plt.subplots(
    num_rows,
    num_cols,
    figsize=set_size(500, 1, (num_rows, num_cols), use_golden_ratio=False),
    layout="constrained",
    squeeze=False,  # keep `axes` 2-D even when there is a single column
)

# Every panel shows probabilities, so pin one [0, 1] colour scale; otherwise each imshow
# autoscales on its own and the shared colorbar would only describe the last panel.
color_range = dict(vmin=0.0, vmax=1.0)

for res_i, curr_res in enumerate(results):
    row_base = 3 * (res_i // num_cols)
    curr_col = res_i % num_cols
    final_res = curr_res[checkpoint_step_next]

    for row_offset, block_name in enumerate(("GPTBlock_0", "GPTBlock_1")):
        ax = axes[row_base + row_offset, curr_col]
        attention_map = final_res[block_name]["self_attention_map"]
        # Index 0 of the leading axis (presumably the first attention head).
        im = ax.imshow(attention_map[0], **color_range)
        # One tick per position; a locator instance must not be shared between axes.
        ax.xaxis.set_major_locator(plticker.MultipleLocator(base=1.0))
        ax.yaxis.set_major_locator(plticker.MultipleLocator(base=1.0))
        if row_offset == 0:
            # Ground-truth label of the query (last position), to compare with y-hat below.
            ax.set_title("$y$: {}".format(curr_res["label"][..., -1].item()))
            ax.set_xticks([])
        if curr_col > 0:
            ax.set_yticks([])

    ax = axes[row_base + 2, curr_col]
    pred = final_res["prediction"]
    im = ax.imshow(jax.nn.softmax(pred), **color_range)
    ax.xaxis.set_major_locator(plticker.MultipleLocator(base=10.0))
    ax.set_yticks([])
    ax.set_title("$\\hat{{y}}: {}$".format(np.argmax(pred, axis=-1).item()))

# Hide the grid cells left unused in a partially filled last row group.
for empty_i in range(len(results), num_row_groups * num_cols):
    for row_offset in range(3):
        axes[3 * (empty_i // num_cols) + row_offset, empty_i % num_cols].set_axis_off()

# axis=1 -> query
# axis=2 -> key
fig.colorbar(im, ax=axes.ravel().tolist())
fig.savefig(
    "attention-{}.pdf".format(os.path.basename(learner_path)),
    dpi=600,
    format="pdf",
    bbox_inches="tight",
)

In [11]:
# Free the saved figure's memory; it is not needed past this point.
plt.close(fig=fig)

In [None]:
# Intentional stop so "Run All" does not continue into the cells below, which appear to be
# an exploratory variant of the figure above. An explicit raise (unlike `assert 0`) is not
# stripped by `python -O` and says why execution stopped.
raise RuntimeError("Intentional stop: the cells below are exploratory plots; run them manually.")

In [None]:
# Exploratory version of the attention figure: larger canvas, sample index in the title,
# default layout, no tick suppression. Same three-rows-per-sample grid, final checkpoint.
assert results, "no results to plot"
num_cols = min(10, len(results))  # a fixed 10 would give zero rows for < 10 results
num_row_groups = math.ceil(len(results) / num_cols)
num_rows = 3 * num_row_groups

fig, axes = plt.subplots(
    num_rows,
    num_cols,
    figsize=set_size(1500, 1, (num_rows, num_cols), use_golden_ratio=False),
    squeeze=False,  # keep `axes` 2-D even when there is a single column
)

# Shared [0, 1] colour scale so the single colorbar applies to every panel.
color_range = dict(vmin=0.0, vmax=1.0)

for res_i, curr_res in enumerate(results):
    row_base = 3 * (res_i // num_cols)
    curr_col = res_i % num_cols
    final_res = curr_res[checkpoint_step_next]

    for row_offset, block_name in enumerate(("GPTBlock_0", "GPTBlock_1")):
        ax = axes[row_base + row_offset, curr_col]
        # Index 0 of the leading axis (presumably the first attention head).
        im = ax.imshow(final_res[block_name]["self_attention_map"][0], **color_range)
        # Separate locator per axis: a locator instance binds to a single axis.
        ax.xaxis.set_major_locator(plticker.MultipleLocator(base=1.0))
        ax.yaxis.set_major_locator(plticker.MultipleLocator(base=1.0))
        if row_offset == 0:
            # Sample index and the query (last-position) ground-truth label.
            ax.set_title("sample: {} - {}".format(res_i, curr_res["label"][..., -1].item()))

    ax = axes[row_base + 2, curr_col]
    im = ax.imshow(jax.nn.softmax(final_res["prediction"]), **color_range)
    ax.xaxis.set_major_locator(plticker.MultipleLocator(base=1.0))
    ax.set_yticks([])

# Hide the grid cells left unused in a partially filled last row group.
for empty_i in range(len(results), num_row_groups * num_cols):
    for row_offset in range(3):
        axes[3 * (empty_i // num_cols) + row_offset, empty_i % num_cols].set_axis_off()

# axis=1 -> query
# axis=2 -> key
fig.colorbar(im, ax=axes.ravel().tolist())

In [15]:
# Free the exploratory figure's memory once it has been inspected.
plt.close(fig=fig)