# Check Attention Layers on MNIST Dataset

## Imports

In [None]:
from jaxl.constants import *
from jaxl.datasets import get_dataset
from jaxl.datasets.wrappers import (
    ContextDataset,
    StandardSupervisedDataset,
    FixedLengthContextDataset,
    RepeatedContextDataset,
)
from jaxl.models import load_config, load_model, get_model, get_activation
from jaxl.models.modules import GPTModule
from jaxl.plot_utils import set_size
from jaxl.utils import parse_dict, get_device

import _pickle as pickle
import copy
import jax
import jax.random as jrandom
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import seaborn as sns
import torchvision.datasets as torch_datasets

from collections import OrderedDict
from functools import partial
from orbax.checkpoint import PyTreeCheckpointer, CheckpointManager
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader, SequentialSampler
from types import SimpleNamespace

## Setup

In [None]:
# Pick the compute device for this session; swap the two assignments below to
# run on CPU instead.
# device = "cpu"
device = "gpu:0"
# NOTE(review): get_device presumably configures JAX to use the named device;
# confirm against jaxl.utils.get_device.
get_device(device)

In [None]:
# Figure width (in points) handed to set_size when laying out plots.
doc_width_pt = 750.0

# Which run to inspect: <project>/<ablation>/<run>. Uncomment a different
# entry inside the parentheses to switch runs.
project_name = "icl-mnist"
ablation_name = "include_query_class-random_label"
run_name = (
    # "default-03-17-24_13_01_35-af1cc36f-8698-4a61-a16a-4d2c726a22b9"
    # "variable_len-03-18-24_10_21_12-2ac5a55c-b1cf-448b-945b-e6b8f821431f"
    "variable_len-include_query_class-03-18-24_10_21_18-31a5371d-c10d-4595-a26e-1b0158a4b0d4"
)

# Filesystem layout rooted at the local checkout.
base_path = "/home/bryanpu1/projects/icl/jaxl/"
data_path = os.path.join(base_path, "data")
log_path = os.path.join(base_path, "jaxl/logs")
learner_path = os.path.join(log_path, project_name, ablation_name, run_name)

# The experiment name is the run name without its last 8 dash-separated
# fields (the timestamp- and UUID-looking suffix of the run names above).
run_name_fields = run_name.split("-")
exp_name = "-".join(run_name_fields[:-8])

## Experiment Configuration

In [None]:
# Load the configuration stored with the selected run, both as a raw dict and
# as an attribute-accessible object.
config_dict, config = load_config(learner_path)

# True when the run's dataset wrapper is the fixed-length context variant.
wrapper_type = config.learner_config.dataset_config.dataset_wrapper.type
fixed_length = wrapper_type == "FixedLengthContextDataset"

In [None]:
# Show the loaded configuration as this cell's output.
config

## Load Train Dataset and Model

In [None]:
# Build the training dataset from the run's dataset config and data seed.
learner_config = config.learner_config
train_dataset = get_dataset(
    learner_config.dataset_config,
    learner_config.seeds.data_seed,
)

# Walk the dataset one sample per batch, in index order, in the main process.
sequential_sampler = SequentialSampler(train_dataset)
data_loader = DataLoader(
    train_dataset,
    sampler=sequential_sampler,
    batch_size=1,
    num_workers=0,
    drop_last=False,
)

In [None]:
# Load the parameters and model object for the run stored at learner_path,
# sized by the dataset's input and output dimensions.
# NOTE(review): the final argument (-1) presumably selects the latest
# checkpoint; confirm against jaxl.models.load_model.
params, model = load_model(
    train_dataset.input_dim, train_dataset.output_dim, learner_path, -1
)

In [None]:
# Upper bound on how many batches from data_loader the tokenisation loop
# processes before stopping.
num_tasks = 1
# NOTE(review): max_label is not referenced in the visible cells; CONST_AUTO
# comes from the jaxl.constants star import.
max_label = CONST_AUTO

In [None]:
# Show the configuration again for reference before building the GPT module.
# NOTE(review): this repeats the output of the earlier `config` cell.
config

In [None]:
# Stand-alone GPT module sized from the run's model config, so that it can be
# applied directly to the saved GPT parameters.
# NOTE(review): widening_factor is hard-coded to 1 rather than read from the
# config; confirm it matches the trained model.
model_config = config.model_config
gpt = GPTModule(
    num_blocks=model_config.num_blocks,
    num_heads=model_config.num_heads,
    embed_dim=model_config.embed_dim,
    widening_factor=1,
)

In [None]:
# List the keys stored under params[CONST_MODEL_DICT][CONST_MODEL].
params[CONST_MODEL_DICT][CONST_MODEL].keys()

In [None]:

# Result accumulators; nothing in this cell appends to them.
all_preds, all_labels, all_outputs = [], [], []
num_query_class_in_context = []

for batch_i, samples in enumerate(data_loader):
    # Only the first num_tasks batches are processed.
    if batch_i >= num_tasks:
        break

    context_inputs, context_outputs, queries, one_hot_labels = samples
    model_params = params[CONST_MODEL_DICT][CONST_MODEL]

    # Tokenise the query together with its context examples.
    query_array = queries.numpy()
    context = {
        CONST_CONTEXT_INPUT: context_inputs.numpy(),
        CONST_CONTEXT_OUTPUT: context_outputs.numpy(),
    }
    in_tokens, _, _ = model.tokenize(
        model_params,
        query_array,
        context,
        eval=True,
    )

    # Run the GPT stack on the tokens. With capture_intermediates=True the
    # call returns a second value, kept here as `latents`, holding the
    # captured intermediate outputs.
    out_tokens, latents = gpt.apply(
        model_params[CONST_GPT],
        in_tokens,
        eval=True,
        capture_intermediates=True,
    )

In [None]:
# Show which intermediate outputs were captured for every GPT block.
# The block index is iterated explicitly so that `block_i` is defined inside
# this cell instead of relying on a value left behind by another cell.
for block_i in range(config.model_config.num_blocks):
    print(latents["intermediates"]["GPTBlock_{}".format(block_i)].keys())

In [None]:
# Token sequence per stage: the tokenised input first, then the first captured
# output of each block's SelfAttention_0 "out" sub-module, in block order.
all_tokens = [in_tokens]
captured = latents["intermediates"]
for block_i in range(config.model_config.num_blocks):
    attention = captured["GPTBlock_{}".format(block_i)]["SelfAttention_0"]
    all_tokens.append(attention["out"]["__call__"][0])

In [None]:
# Visualise the first few dataset samples: context images fill a
# nrows x ncols grid (titled with their arg-max label) and the query image
# sits in the extra right-hand column of the bottom row.
num_examples = 4
nrows, ncols = 2, 8
for example_i in range(num_examples):
    ci, co, q, l = train_dataset[example_i]

    fig, axes = plt.subplots(
        nrows,
        ncols + 1,
        figsize=set_size(doc_width_pt, 0.95, (nrows, ncols), False),
        layout="constrained",
    )

    # Context examples, laid out row-major.
    for idx, (img, label) in enumerate(zip(ci, co)):
        panel = axes[divmod(idx, ncols)]
        panel.imshow(img)
        panel.set_title(np.argmax(label))
        panel.axis('off')

    # Extra column: blank on the top row, query image on the bottom row.
    axes[0, -1].axis('off')
    query_panel = axes[1, -1]
    query_panel.axis('off')
    query_panel.imshow(q[0])
    query_panel.set_title(np.argmax(l, axis=-1))

    plt.show()
    plt.close()

In [None]:
def cosine_distance(a, b, eps=1e-8):
    """Return the pairwise cosine similarity between the rows of ``a`` and ``b``.

    Each row is scaled to unit L2 norm before the matrix product, so entry
    ``[i, j]`` is the cosine of the angle between ``a[i]`` and ``b[j]`` and
    lies in ``[-1, 1]``. Without this scaling the product would be a raw dot
    product whose magnitude depends on the vector norms.

    The function name is kept for compatibility; the returned value is a
    similarity (larger means more aligned), not ``1 - similarity``.

    Args:
        a: array whose last axis is the feature dimension.
        b: array with the same feature dimension as ``a``.
        eps: lower bound applied to each row norm so that all-zero rows give
            a similarity of 0 instead of NaN.

    Returns:
        The matrix of cosine similarities, ``a_unit @ b_unit.T``.
    """
    jnp = jax.numpy
    a_norm = jnp.linalg.norm(a, axis=-1, keepdims=True)
    b_norm = jnp.linalg.norm(b, axis=-1, keepdims=True)
    a_unit = a / jnp.maximum(a_norm, eps)
    b_unit = b / jnp.maximum(b_norm, eps)
    return a_unit @ b_unit.T

# Apply cosine_distance independently to every sample along the leading axis.
dist = jax.vmap(cosine_distance)

# Compare each entry of all_tokens with the entry that follows it. The loop
# targets use their own names so the notebook-level `in_tokens` / `out_tokens`
# are not overwritten, which would corrupt a re-run of the cell that builds
# all_tokens.
for plot_i, (layer_in, layer_out) in enumerate(zip(all_tokens[:-1], all_tokens[1:])):
    for sample in dist(layer_in, layer_out):
        # Last row: scores of the final token in the earlier sequence against
        # every token in the later one.
        print(sample[-1])
        ax = sns.heatmap(sample, linewidth=0.5)
        ax.set_title("Layer {}".format(plot_i))
        ax.set_xlabel("output tokens")
        ax.set_ylabel("input tokens")
        plt.show()
        # Release the figure, as the example-plotting cell does, so figures do
        # not pile up on non-inline backends.
        plt.close()