In [None]:
import baukit
from baukit import Widget, Property, Trigger, show

import torch, numpy as np
import os, re, json
from matplotlib import cm, pyplot as plt
from matplotlib.colors import TwoSlopeNorm
from transformers import AutoTokenizer

# Include prompt creation helper functions
from utils.prompt_utils import *
from attentionVisualizationWidget import TokenVizWidget, AttnHeadSelectorWidget

# Notes on Using the TokenVizWidget

Usage:
- Use the selectors to pick an attention head. Hover over a token to display its attention pattern and tooltip.
- Click on a token to lock its attention pattern. Click the same token to unlock. 
- If no token is hover-selected or clicked, the default display is the contents of the `default_matrix` array.

Logistics:
- For this example, we have attention data from GPT-J. The model has 28 layers with 16 heads per layer.<br>
- There are two widgets: the attention-head selector widget (`AttnHeadSelectorWidget`) and the token visualization widget (`TokenVizWidget`).<br>
- The `TokenVizWidget` requires the tokenized prompt text, an attention matrix (`attn_matrix`) of size `(n_layers, n_heads, n_tokens, n_tokens)`, and a default matrix (`default_matrix`) of size `(n_layers, n_heads, n_tokens)` to build.<br> 
- Unless you have a specific default you're interested in, you can just create an `ndarray` of zeros to fill the slot.

- We hook up the two widgets by passing `current_layer=ahw.prop('current_layer'), current_head=ahw.prop('current_head')` to the `TokenVizWidget` constructor, and by registering `update_attn_matrix` as a listener with `tvw.on('current_layer', update_attn_matrix)` and `tvw.on('current_head', update_attn_matrix)`.


# Google Review Restaurant Extraction

In [None]:
# Load Tokenizer
# Gradients are switched off globally; this notebook only visualizes
# pre-computed attention data.
torch.set_grad_enabled(False)
model_name = r"EleutherAI/gpt-j-6B"
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')

# Load pre-computed Attention & Prompt Data
attn_matrix = np.load("data/restaurant_attention.npy")
# Use a context manager so the JSON file handle is closed as soon as parsing
# finishes instead of lingering until garbage collection.
with open('data/restaurant_prompt.json', 'r') as prompt_file:
    prompt_data = json.load(prompt_file)

prompt_text = create_prompt("If you're in the mood for some quick fast food, Wendy's is always a solid choice.", prompt_data=prompt_data)
token_ids = tokenizer(prompt_text)['input_ids']
# Decode each id on its own; repr(...)[1:-1] drops the surrounding quotes while
# keeping escape sequences (e.g. newlines) visible as text.
text_tokens = [repr(tokenizer.decode(x))[1:-1] for x in token_ids]

# Init config vars
N_TOKENS = len(text_tokens)
N_LAYERS = 28
N_HEADS = 16

# Build the default display matrix: all zeros, one row of per-token values
# for every (layer, head) pair.
default_matrix = np.zeros((N_LAYERS,N_HEADS,N_TOKENS))

def update_attn_matrix():
    """Push the attention pattern of the currently selected head to the token widget.

    Reads the selected layer and head from the selector widget ``ahw``, slices
    ``attn_matrix`` at that position, assigns the slice (as nested lists) to
    ``tvw.token_attn``, and then refreshes the default display through
    ``update_default_matrix``.
    """
    layer_idx = ahw.prop('current_layer').value
    head_idx = ahw.prop('current_head').value
    tvw.token_attn = attn_matrix[layer_idx, head_idx].tolist()
    # Keep the default display and its colors in step with the new selection.
    update_default_matrix()

def update_default_matrix():
    """Refresh the token widget's default display for the selected head.

    Reads the selected layer and head from the selector widget ``ahw``, takes
    the matching row of ``default_matrix``, and assigns both the colors
    produced by ``tvw.color_sample`` (with the ``cm.bwr`` colormap) and the raw
    values to the token widget ``tvw``.
    """
    layer_idx = ahw.prop('current_layer').value
    head_idx = ahw.prop('current_head').value
    head_default = default_matrix[layer_idx, head_idx, :]
    # Colors are assigned before the values, same order as before.
    tvw.colors_matrix = tvw.color_sample(head_default, cm.bwr)
    tvw.default_display = head_default.tolist()
    
    
# Initialize Widgets
# head_groupings = {'repeat_token':[(6,2), (3,9), (8,7)],'induction':[(4,0),(8,1),(16,7)], 'prev_token':[(2,11),(3,5)], 'prev_answer':[(13,13), (13,2)]}
ahw = AttnHeadSelectorWidget(n_layers=N_LAYERS, n_heads=N_HEADS)
# The token widget starts on layer 0 / head 0 and is handed the selector's
# layer/head properties (presumably so the two widgets share the selection --
# confirm against the widget implementation).
tvw = TokenVizWidget(
    text=text_tokens,
    token_attn=attn_matrix[0, 0].tolist(),
    default_display=default_matrix[0, 0].tolist(),
    current_layer=ahw.prop('current_layer'),
    current_head=ahw.prop('current_head'),
)

# Setup python-side listeners to update attn matrix & token dependence matrix
for prop_name in ('current_layer', 'current_head'):
    tvw.on(prop_name, update_attn_matrix)

print("Ready!")

In [None]:
# Display Widgets: render the head selector followed by the token visualization
show([ahw, tvw])