In [2]:
import torch
from transformers import GPT2Tokenizer, GPT2Model
from circuitsvis import attention
from IPython.display import display

# Load model and tokenizer
model_name = "gpt2"
# output_attentions=True makes each forward pass also return per-layer
# attention weights in `outputs.attentions`.
model = GPT2Model.from_pretrained(model_name, output_attentions=True)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model.eval()  # inference mode (disables dropout) so attention maps are deterministic

# Define sentences to visualize attention
sentences = [
    "A goose (PL: geese) is a bird of any of several waterfowl species in the family Anatidae.",
    "This group comprises the genera Anser (the grey geese and white geese) and Branta (the black geese).",
    "Some other birds, mostly related to the shelducks, have 'goose' as part of their names."
]

# Zero-based indices of the transformer layers to inspect.
layers_to_inspect = [1, 8, 10]

# Validate up front: a negative index would silently pick a layer counted from
# the end, and an out-of-range one would only fail after the first forward pass.
num_layers = model.config.n_layer
invalid_layers = [i for i in layers_to_inspect if not 0 <= i < num_layers]
if invalid_layers:
    raise ValueError(
        f"Layer indices {invalid_layers} out of range for a model with {num_layers} layers"
    )

# Loop through each sentence and visualize attention maps for the chosen layers
for sentence in sentences:
    print(f"\nAnalyzing sentence: '{sentence}'")
    inputs = tokenizer(sentence, return_tensors="pt")

    # Gradients are not needed for visualization; no_grad avoids building the
    # autograd graph, saving memory and time.
    with torch.no_grad():
        outputs = model(**inputs)
    attentions = outputs.attentions  # one tensor per layer

    # Token labels depend only on the sentence, so compute them once here rather
    # than once per layer. Indexing [0] (instead of squeeze()) drops the batch
    # dimension but still yields a list when the sentence is a single token.
    token_ids = inputs["input_ids"][0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(token_ids)

    for layer_index in layers_to_inspect:
        # Drop the batch dimension; the printed shape below is
        # (num_heads, seq_len, seq_len). .cpu() keeps this working if the
        # model is moved to an accelerator.
        attention_data = attentions[layer_index][0].cpu().numpy()

        # Check the shape of the attention data
        print(f"Layer {layer_index} attention data shape: {attention_data.shape}")

        # Build the per-head attention visualization and render it inline in Jupyter
        html_vis = attention.attention_heads(tokens=tokens, attention=attention_data)
        display(html_vis)



Analyzing sentence: 'A goose (PL: geese) is a bird of any of several waterfowl species in the family Anatidae.'
Layer 1 attention data shape: (12, 25, 25)


Layer 8 attention data shape: (12, 25, 25)


Layer 10 attention data shape: (12, 25, 25)



Analyzing sentence: 'This group comprises the genera Anser (the grey geese and white geese) and Branta (the black geese).'
Layer 1 attention data shape: (12, 27, 27)


Layer 8 attention data shape: (12, 27, 27)


Layer 10 attention data shape: (12, 27, 27)



Analyzing sentence: 'Some other birds, mostly related to the shelducks, have 'goose' as part of their names.'
Layer 1 attention data shape: (12, 23, 23)


Layer 8 attention data shape: (12, 23, 23)


Layer 10 attention data shape: (12, 23, 23)
