# Visualizing the Attention Maps with Mixtral 8x7B

This experiment visualizes the attention maps of Mixtral 8x7B. I want to see
- which tokens the different heads of a single layer pay attention to
- which tokens the same heads in different layers pay attention to.

In [2]:
import torch as t
import pandas as pd
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset
import altair as alt

import os
from dotenv import load_dotenv
from huggingface_hub import login

In [3]:
# Load variables from a local .env file (if one exists) into the process
# environment so the Hugging Face token does not have to be hard-coded here.
load_dotenv()
# os.getenv returns None when HF_TOKEN is not set in the environment.
HF_TOKEN = os.getenv("HF_TOKEN")
# Authenticate this session with the Hugging Face Hub using that token.
login(HF_TOKEN)

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /Users/kj3moraes/.cache/huggingface/token
Login successful


In [8]:
model_id = "mistralai/Mixtral-8x7B-v0.1"

# 4-bit quantization settings (with double quantization and fp16 compute) so
# the checkpoint takes far less memory than the full-precision weights.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=t.float16,
)

# The quantization config only takes effect when it is handed to
# from_pretrained; without it the model is loaded unquantized.
# output_attentions=True makes the forward pass return per-layer attention maps.
model = AutoModel.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    output_attentions=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, add_prefix_space=True)

Downloading shards:   0%|          | 0/19 [00:00<?, ?it/s]

model-00001-of-00019.safetensors:   0%|          | 0.00/4.89G [00:00<?, ?B/s]

## Attention Map Functions 

In [None]:
def mtx2df(m, max_row, max_col, row_tokens, col_tokens):
    """Flatten a 2-D matrix into a long-format DataFrame, one record per cell.

    Args:
        m: Matrix-like object exposing ``.shape`` and ``m[r, c]`` indexing.
        max_row: Only rows whose index is below this value are kept.
        max_col: Only columns whose index is below this value are kept.
        row_tokens: Row labels, looked up by row index.
        col_tokens: Column labels, looked up by column index.

    Returns:
        A DataFrame with the columns ``row``, ``column``, ``value``,
        ``row_token`` and ``col_token``. An index that has no entry in the
        corresponding token list is labelled ``"<blank>"``.
    """

    def token_at(tokens, idx):
        # The matrix may be larger than the token list; use a placeholder then.
        return "%s" % tokens[idx] if len(tokens) > idx else "<blank>"

    records = []
    for r in range(m.shape[0]):
        # Indices only grow, so the first row past the limit ends the scan.
        if not r < max_row:
            break
        for c in range(m.shape[1]):
            if not c < max_col:
                break
            cell = (
                r,
                c,
                float(m[r, c]),
                token_at(row_tokens, r),
                token_at(col_tokens, c),
            )
            records.append(cell)

    return pd.DataFrame(
        records,
        columns=["row", "column", "value", "row_token", "col_token"],
    )

def visualize_head(attn, head, row_tokens, col_tokens, max_dim=30):
    """Build an interactive heat map for a single attention head.

    Args:
        attn: Attention weights of one layer, indexed as ``attn[0, head]`` to
            obtain a 2-D matrix (presumably (batch, heads, query, key) —
            TODO confirm against the model's output).
        head: Index of the head to plot.
        row_tokens: Token strings used to label the rows (y axis).
        col_tokens: Token strings used to label the columns (x axis).
        max_dim: Only the first ``max_dim`` rows and columns are plotted.

    Returns:
        An Altair chart with one rectangle per (row, column) cell, coloured by
        the attention value.
    """
    df = mtx2df(attn[0, head].data, max_dim, max_dim, row_tokens, col_tokens)

    # A nominal axis has one band per *distinct* value and sorts those values
    # alphabetically. Plotting on the raw token text would therefore merge
    # repeated tokens into one row/column and scramble the sequence order.
    # Prefixing each label with its zero-padded position keeps every position
    # unique and makes alphabetical order equal to positional order.
    pad = max(3, len(str(max_dim)))
    df["row_label"] = [
        f"{int(r):0{pad}d} {tok}" for r, tok in zip(df["row"], df["row_token"])
    ]
    df["col_label"] = [
        f"{int(c):0{pad}d} {tok}" for c, tok in zip(df["column"], df["col_token"])
    ]

    return (
        alt.Chart(data=df)
        .mark_rect()
        .encode(
            x=alt.X("col_label:N", axis=alt.Axis(title="")),
            y=alt.Y("row_label:N", axis=alt.Axis(title="")),
            color="value",
            tooltip=["row", "column", "value", "row_token", "col_token"],
        )
        .properties(title=f"Head {head}", height=200, width=200)
        .interactive()
    )

def visualize_layer(attn, layer, heads, ntokens, row_words, col_words):
    """Lay out the heat maps of several heads of one layer side by side.

    Args:
        attn: Attention weights of the layer, passed through to
            ``visualize_head`` unchanged.
        layer: Layer number, used only in the chart title.
        heads: Iterable of head indices to plot, in display order.
        ntokens: Maximum number of rows/columns shown per head.
        row_words: Token labels for the rows of every head chart.
        col_words: Token labels for the columns of every head chart.

    Returns:
        A horizontally concatenated Altair chart titled ``Layer <layer>``.
    """
    per_head = []
    for head_idx in heads:
        head_chart = visualize_head(
            attn,
            head_idx,
            row_tokens=row_words,
            col_tokens=col_words,
            max_dim=ntokens,
        )
        per_head.append(head_chart)

    combined = alt.hconcat(*per_head)
    return combined.properties(title=f"Layer {layer}")

    
def visualize_model_attns(
    model,
    tokenizer,
    text: str,
    view_layers: "list | None" = None,
    view_heads: "list | None" = None,
):
    """Plot attention heat maps for the chosen layers and heads of a model.

    The text is tokenized, run through the model once, and the returned
    attention maps are rendered as one row of head charts per layer.

    Args:
        model: Model to run the text through. Its output must expose an
            ``attentions`` sequence indexable by layer number.
        tokenizer: Tokenizer matching the model.
        text (str): Text to be visualized.
        view_layers (list, optional): The layers to display. Every element
            must satisfy 0 <= l < N_LAYERS. Defaults to no layers.
        view_heads (list, optional): The heads to display. Every element
            must satisfy 0 <= h < N_HEADS. Defaults to no heads.

    Returns:
        A vertically concatenated Altair chart with one row per layer.
    """
    # None replaces the mutable [] defaults, which are shared between calls.
    view_layers = [] if view_layers is None else view_layers
    view_heads = [] if view_heads is None else view_heads

    tokens = tokenizer.encode(text, return_tensors="pt")
    n_tokens = tokens.size(-1)
    words = tokenizer.convert_ids_to_tokens(tokens[0])

    # The input ids must live on the same device as the model weights.
    device = getattr(model, "device", None)
    if device is not None:
        tokens = tokens.to(device)

    # Inference only: skip autograd bookkeeping to save memory. Attentions are
    # requested explicitly so the call does not depend on the model's config.
    with t.no_grad():
        attns = model(tokens, output_attentions=True).attentions

    layer_maps = [
        visualize_layer(attns[layer_num], layer_num, view_heads, n_tokens, words, words)
        for layer_num in view_layers
    ]

    return alt.vconcat(*layer_maps)