# Visualizing the Attention Maps

This experiment visualizes the attention maps in GPT-2. I want to see
- which tokens the different heads of a single layer pay attention to
- which tokens the heads at the same index in different layers pay attention to.

In [8]:
import torch as t
import pandas as pd
from transformers import AutoModel, AutoTokenizer, GPT2Model, GPT2TokenizerFast, GPT2LMHeadModel
from datasets import load_dataset
import altair as alt

# Pick the best available accelerator: CUDA first, then Apple-silicon MPS,
# and fall back to the CPU. The checks are chained so that a later check
# cannot overwrite an earlier positive result (previously the MPS check
# reset a detected "cuda" back to "cpu").
if t.cuda.is_available():
    device = "cuda"
elif t.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [9]:
# Checkpoint identifier shared by the tokenizer and the model.
model_id = "openai-community/gpt2-medium"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# output_attentions=True is requested so that the forward pass output
# carries the per-layer attention weights read later via `.attentions`.
model = AutoModel.from_pretrained(
    model_id,
    output_attentions=True,
)

In [10]:
# Sample text whose attention patterns are inspected below.
txt = "The cat ate the mat. And then sat on a rat. But what he wanted was a hat."
# Plain encoding of the text (no tensor conversion requested).
tokens = tokenizer(txt)
# Same text encoded with PyTorch tensors; only the token ids are kept.
encoded_pt = tokenizer(txt, return_tensors="pt")
inputs = encoded_pt.input_ids
inputs.shape

torch.Size([1, 21])

In [11]:
# Forward pass for inspection only: no_grad() skips building the autograd
# graph, so the returned attention tensors do not track gradients.
with t.no_grad():
    attns = model(inputs).attentions
# `attns` is indexed per layer; take the first entry.
first_layer_attn = attns[0]
first_layer_attn.shape  # B, N_h, L_s, L_s

torch.Size([1, 16, 21, 21])

## Attention Map Functions 

In [12]:
def mtx2df(m, max_row, max_col, row_tokens, col_tokens):
    """Flatten a 2-D matrix into a long-format DataFrame, one record per cell.

    Only cells whose row index is below ``max_row`` and whose column index is
    below ``max_col`` are emitted; zero-valued cells are kept. Each record
    holds the row index, the column index, the cell value converted to
    ``float``, and one label per axis made of the index zero-padded to three
    digits followed by the matching token (``<blank>`` when the token list
    has no entry at that index).

    Args:
        m: object exposing ``.shape`` and ``m[r, c]`` indexing.
        max_row: exclusive upper bound on the emitted row indices.
        max_col: exclusive upper bound on the emitted column indices.
        row_tokens: sequence of labels for the rows.
        col_tokens: sequence of labels for the columns.

    Returns:
        ``pd.DataFrame`` with columns ``row``, ``column``, ``value``,
        ``row_token`` and ``col_token``, in row-major cell order.
    """

    def _label(idx, toks):
        # Fall back to a placeholder when there is no token at this index.
        tok = toks[idx] if idx < len(toks) else "<blank>"
        return "%.3d %s" % (idx, tok)

    records = []
    for i in range(m.shape[0]):
        for j in range(m.shape[1]):
            if i < max_row and j < max_col:
                records.append(
                    (
                        i,
                        j,
                        float(m[i, j]),
                        _label(i, row_tokens),
                        _label(j, col_tokens),
                    )
                )

    return pd.DataFrame(
        records,
        columns=["row", "column", "value", "row_token", "col_token"],
    )


def attn_map(attn, layer, head, row_tokens, col_tokens, max_dim=30):
    """Build an interactive Altair heatmap for one attention head.

    Args:
        attn: attention weights, indexed as ``attn[0, head]``; presumably a
            torch tensor laid out as (batch, heads, seq, seq) -- confirm
            against the caller.
        layer: accepted for interface compatibility; not used in the body.
        head: index of the head to plot.
        row_tokens: labels for the heatmap rows (y axis).
        col_tokens: labels for the heatmap columns (x axis).
        max_dim: exclusive bound applied to both row and column indices.

    Returns:
        A 400x400 interactive Altair rect chart, colored by cell value, with
        a tooltip showing every column of the underlying frame.
    """
    # Only the first element along the leading axis is visualized.
    head_weights = attn[0, head].data
    cells = mtx2df(head_weights, max_dim, max_dim, row_tokens, col_tokens)

    # Axis titles are blanked; the tick labels already carry index and token.
    x_axis = alt.X("col_token", axis=alt.Axis(title=""))
    y_axis = alt.Y("row_token", axis=alt.Axis(title=""))
    tooltip_fields = ["row", "column", "value", "row_token", "col_token"]

    heatmap = alt.Chart(data=cells).mark_rect().encode(
        x=x_axis,
        y=y_axis,
        color="value",
        tooltip=tooltip_fields,
    )
    return heatmap.properties(height=400, width=400).interactive()

In [17]:
def visualize_layer(attn, layer, ntokens, row_tokens, col_tokens, heads=(0, 2, 4, 6)):
    """Lay out the attention heatmaps of selected heads of one layer side by side.

    Args:
        attn: attention weights for one layer; ``attn.shape[1]`` is taken as
            the number of heads (presumably (batch, heads, seq, seq) --
            confirm against the caller).
        layer: layer index; shown in the chart title and forwarded to
            ``attn_map``.
        ntokens: forwarded to ``attn_map`` as ``max_dim``, i.e. the exclusive
            bound on the plotted row/column indices.
        row_tokens: labels for the heatmap rows.
        col_tokens: labels for the heatmap columns.
        heads: indices of the heads to draw, left to right. The default
            reproduces the previously hard-coded selection (0, 2, 4, 6).

    Returns:
        An Altair chart titled ``Layer {layer}`` containing one heatmap per
        requested head in a single row.

    Raises:
        ValueError: if ``heads`` is empty or contains an index outside
            ``[0, attn.shape[1])``.
    """
    n_heads = attn.shape[1]
    selected = list(heads)
    if not selected:
        raise ValueError("heads must contain at least one head index")
    out_of_range = [h for h in selected if not 0 <= h < n_heads]
    if out_of_range:
        raise ValueError(
            f"head indices {out_of_range} are out of range for {n_heads} heads"
        )

    # Build charts only for the heads that are actually displayed, instead of
    # rendering every head and discarding most of them.
    charts = [
        attn_map(
            attn,
            layer,
            h,
            row_tokens=row_tokens,
            col_tokens=col_tokens,
            max_dim=ntokens,
        )
        for h in selected
    ]
    return alt.vconcat(alt.hconcat(*charts)).properties(title=f"Layer {layer}")

In [18]:
# Label the axes with token strings rather than raw token ids, and take the
# sequence length from the input tensor instead of hard-coding it.
token_labels = tokenizer.convert_ids_to_tokens(inputs.flatten().tolist())
visualize_layer(first_layer_attn, 0, inputs.shape[1], token_labels, token_labels)