# Experiment 12 attention map visualization

In [1]:
import random

import torch
import matplotlib.pyplot as plt

from arithmetic_lm.model import TransformerDecoder, generate
from arithmetic_lm.tokenizer import CharTokenizer
from arithmetic_lm.interp import plot_attn_maps, plot_module
from arithmetic_lm.constants import PLOTS_DIR

import warnings

warnings.filterwarnings("ignore")

In [2]:
tokenizer = CharTokenizer()

In [3]:
def load_model(ckpt_path: str) -> tuple[torch.nn.Module, dict]:
    # load model
    ckpt = torch.load(ckpt_path, map_location="mps")
    model = TransformerDecoder(
        **ckpt["hyper_parameters"]["model_hparams"],
        # vocab_size=tokenizer.vocab_size,
    )
    # state dict has a prefix "model." in the key names
    model.load_state_dict({k[6:]: v for k, v in ckpt["state_dict"].items()})
    model.eval()
    return model, ckpt["hyper_parameters"]

In [4]:
ckpt_path = "../checkpoints/addition-generalize-to-longer/trans_dec_6layers_768embd_4head_cot/step670000-train_loss1.4532-val_loss1.4517.ckpt"

In [5]:
model, hparams = load_model(ckpt_path)

In [6]:
module_names = [
    "transformer_encoder.layers[0].self_attn",
    "transformer_encoder.layers[1].self_attn",
    "transformer_encoder.layers[2].self_attn",
    "transformer_encoder.layers[3].self_attn",
    "transformer_encoder.layers[4].self_attn",
    "transformer_encoder.layers[5].self_attn",
]
figsize = (14, 8)

In [19]:
def get_attention_map(name: str, cache: dict):
    def hook(module, inputs, output):
        # nn.MultiheadAttention outputs 2 tensors by default:
        # - the output of the last linear transformation with shape [bs, tgt_len, embed_dim]
        # - the attention map (weights) with shape [bs, n_heads, tgt_len, src_len]
        # keeps only last output, which is fine for our purposes
        cache[name] = output[1].detach()
        print("HOOK CALLED")

    return hook


def set_attn_kwargs_prehook(module, args, kwargs):
    """
    make sure self.attn module is called with need_weights=True and
    average_attn_weights=False so that we get per-head attention weights
    """
    kwargs["need_weights"] = True
    kwargs["average_attn_weights"] = False
    return args, kwargs


def generate_hooked(
    model: torch.nn.Module,
    prompt: torch.Tensor,
    stop_token: int,
    hook_config: dict[str, dict[str, callable]],
    decoder_prompt: torch.Tensor = None,
) -> str:
    model.eval()

    handles = []

    for module_name, hook_dict in hook_config.items():
        module = eval(f"model.{module_name}", {"model": model})

        if pre_hook := hook_dict.get("pre_hook"):
            handles.append(module.register_forward_pre_hook(pre_hook, with_kwargs=True))

        if hook := hook_dict.get("hook"):
            handles.append(module.register_forward_hook(hook))

    # HACK: encode, since just calling generate does not call
    # forward hook in the encoder for some weird reason (decoder hooks work fine)
    if model.enc_dec:
        model.encode(prompt)

    pred_tensor = generate(
        model,
        idx=decoder_prompt if model.enc_dec else prompt,
        encoder_source=prompt if model.enc_dec else None,
        max_new_tokens=100,
        stop_token=stop_token,
    )

    # remove hooks
    for handle in handles:
        handle.remove()

    return pred_tensor


def plot_head(
    ax: plt.Axes,
    map: torch.Tensor,
    title: str,
    cmap: str = "binary",
    xticks: list = None,
    yticks: list = None,
    colorbar: bool = False,
    alpha: float = 1.0,
):
    ax.imshow(map, cmap=cmap, interpolation="none", alpha=alpha)
    if yticks:
        ax.set_yticks(np.arange(len(yticks)) - 0.5)
        ax.set_yticklabels(yticks, va="top")
    if xticks:
        ax.set_xticks(np.arange(len(xticks)) - 0.5)
        ax.set_xticklabels(xticks, ha="left")
    ax.set_title(title)
    if colorbar:
        ax.figure.colorbar(ax.images[0], ax=ax, shrink=0.3)
    # grid
    ax.grid(which="both", color="k", linestyle=":", linewidth=0.5, alpha=0.5)
    ax.set_xlabel("source")
    ax.set_ylabel("target")


def plot_module(
    fig: plt.Figure,
    module_name: str,
    attn_map: torch.Tensor,
    ticks: list[str],
    plot_combined: bool = True,
):
    n_heads = attn_map.shape[1]
    axs = fig.subplots(1, n_heads + 1 if plot_combined else n_heads)
    fig.suptitle(module_name)

    # choose cmaps for combined attn map
    cmaps = ["Reds", "Blues", "Purples", "Greens", "Oranges"]

    for i in range(n_heads):
        plot_head(
            axs[i],
            attn_map[0, i],
            title=f"head {i}",
            xticks=ticks,
            yticks=ticks,
        )
        # combined attn map
        if plot_combined:
            plot_head(
                axs[-1],
                attn_map[0, i],
                title="combined",
                cmap=cmaps[i % len(cmaps)],
                alpha=0.5,
                xticks=ticks,
                yticks=ticks,
                colorbar=False,
            )
    # rotate yticks
    for ax in axs:
        ax.tick_params(axis="y", rotation=90)


def plot_attn_maps(
    model: torch.nn.Module,
    tokenizer,
    a: int,
    b: int,
    module_names: list[str],
    savepath: str,
    pad_zeros: int = 0,
    filler_tokens_prompt: int = 0,
    save: bool = False,
    figsize: tuple[int, int] = (8, 8),
    reverse_ops: bool = False,
    reverse_ans: bool = False,
    figtitle_prefix: str = "",
) -> dict[str, torch.Tensor]:
    astr = str(a)
    bstr = str(b)

    if reverse_ops:
        astr = astr[::-1]
        bstr = bstr[::-1]

    prompt_str = (
        f"${'.' * filler_tokens_prompt}{astr.zfill(pad_zeros)}+{bstr.zfill(pad_zeros)}="
    )
    # prompt_str = "\n" + prompt_str
    print("prompt:", repr(prompt_str), f"{len(astr)}+{len(bstr)}")
    true_ans = str(a + b)
    if reverse_ans:
        true_ans = true_ans[::-1]
    print("true_ans:", true_ans)

    prompt = torch.tensor([tokenizer.encode(prompt_str)])
    stop_token_id = tokenizer.encode("$")[0]

    attn_maps = {}

    # generate answer
    pred_tensor = generate_hooked(
        model,
        prompt=prompt,
        stop_token=stop_token_id,
        hook_config={
            mn: {
                "hook": get_attention_map(mn, attn_maps),
                "pre_hook": set_attn_kwargs_prehook,
            }
            for mn in module_names
        },
    )

    pred_answer_str = tokenizer.decode(pred_tensor[0].tolist())
    pred_answer_num = "".join(c for c in pred_answer_str if c.isdigit())
    print("pred_answer:", pred_answer_str)

    for mn, matts in attn_maps.items():
        print(mn, matts.shape)

    # tokens for easier visualization
    ticks = list(prompt_str + pred_answer_str)
    ticks[0] = "\\n" if ticks[0] == "\n" else ticks[0]

    # for each module, in a subfigure plot heads as subplots
    fig = plt.figure(layout="constrained", figsize=figsize)
    fig.suptitle(
        f"{figtitle_prefix} Attention maps for prompt: {repr(prompt_str).replace('$', '\$')}, [{len(astr)}+{len(bstr)}]"
        f"\n predicted answer: {repr(pred_answer_str).replace('$', '\$')} ({'correct' if pred_answer_num == true_ans else 'incorrect, true: ' + true_ans})",
    )

    subfigs = fig.subfigures(len(attn_maps), 1, hspace=0, wspace=0)
    for i, (module_name, attn_map) in enumerate(attn_maps.items()):
        plot_module(subfigs[i], module_name, attn_map, ticks)

    if save:
        plt.savefig(savepath, dpi=90)
    plt.show()

    return attn_maps


attn_maps = plot_attn_maps(
    model=model,
    tokenizer=tokenizer,
    a=123,
    b=456,
    module_names=module_names,
    figsize=figsize,
    savepath=str(""),
    figtitle_prefix="(CoT + finetuned)",
    reverse_ops=False,
    reverse_ans=False,
    save=False,
)

prompt: '$123+456=' 3+3
true_ans: 579
pred_answer: 579$


ValueError: Number of rows must be a positive integer, not 0

<Figure size 1400x800 with 0 Axes>

In [7]:
# def eval_answer(model: torch.nn.Module, tokenizer, prompt: str, answer: str) -> bool:
#     """Return whether the model predicts the correct answer."""

#     prompt_tokens = torch.tensor(tokenizer.encode(prompt))
#     stop_token_id = tokenizer.encode("$")[0]

#     pred_ans = generate(
#         model, idx=prompt_tokens, max_new_tokens=20, stop_token=stop_token_id
#     )

#     pred_ans = tokenizer.decode(pred_ans[0])
#     pred_ans = pred_ans.strip("$")
#     return pred_ans == answer

In [8]:
# # find failure cases
# while True:
#     a = random.randint(10**5, 10**6)
#     b = random.randint(10**5, 10**6)
#     prompt = f"${a}+{b}="
#     true_ans = str(a + b)
#     if not eval_answer(model_before, tokenizer, prompt, true_ans) and not eval_answer(
#         model_after, tokenizer, prompt, true_ans
#     ):
#         break

# print(f"prompt: {prompt}")

In [9]:
subdir = PLOTS_DIR / "exp_15"
subdir.mkdir(exist_ok=True)

In [10]:
a, b = 123456, 678901
savepath = subdir / f"exp15_attention_maps_{a}+{b}_cot_finetuned.png"
kwargs = dict(
    tokenizer=tokenizer,
    a=a,
    b=b,
    module_names=module_names,
    figsize=figsize,
    # save=True,
)
attn_maps = plot_attn_maps(
    model=model,
    savepath=str(savepath),
    figtitle_prefix="(CoT + finetuned)",
    reverse_ops=False,
    reverse_ans=False,
    **kwargs,
)

prompt: '$123456+678901=' 6+6
true_ans: 802357
pred_answer: 802357$


ValueError: Number of rows must be a positive integer, not 0

<Figure size 1400x800 with 0 Axes>