# Explore Third-Party Models
Since we don't have time to train a transformer ourselves, this notebook lets you explore pretrained models published by third parties.

In [None]:
import torch, torch.nn.functional as F
from transformers import pipeline

from bertviz import head_view, neuron_view, model_view
import matplotlib.pyplot as plt
import seaborn as sns

## GPT-2
Released by OpenAI in 2019, GPT-2 is an early predecessor of GPT-3.5, the model behind ChatGPT, which took the world by storm when it launched in 2022.

References:
- Radford, Alec, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, and Ilya Sutskever. Language Models Are Unsupervised Multitask Learners. February 14, 2019. https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf
- Original blog post: https://openai.com/index/better-language-models/
- https://en.wikipedia.org/wiki/GPT-2
- https://github.com/openai/gpt-2

In [None]:
from transformers import GPT2Tokenizer, GPT2Model

In [None]:
# Load the pretrained "gpt2" checkpoint: its tokenizer and the base transformer.
# output_attentions=True asks the model to also return the attention weights of
# every layer, which the visualizations below rely on.
model_id = "gpt2"
model = GPT2Model.from_pretrained(model_id, output_attentions=True)
tokenizer = GPT2Tokenizer.from_pretrained(model_id)

### Visualize Attention

In [None]:
# Prepare the input sentence. Keep it relatively short: every visualization
# below draws token-by-token attention, so long inputs can make them slow.
sentence = "The robot must obey the orders given it by human beings."

# Token ids as a PyTorch tensor with a leading batch dimension (hence [0] below).
inputs = tokenizer.encode(sentence, return_tensors='pt')

# Human-readable tokens for the single sequence, used as plot labels.
tokens = tokenizer.convert_ids_to_tokens(
    inputs[0]
)

In [None]:
# Run a forward pass to collect the attention weights.
# Gradients are never needed for visualization, so autograd is disabled; this
# avoids building a computation graph and keeps memory use down.
with torch.no_grad():
    outputs = model(inputs)

# One attention tensor per layer (returned because of output_attentions=True).
attentions = outputs.attentions

In [None]:
# Interactive per-head attention view (needs a notebook front end such as Jupyter).
head_view(attention=attentions, tokens=tokens)

In [None]:
# Alternative, non-interactive visualization: a heatmap of a single attention head.
# Each element of `attentions` holds one layer's weights, indexed below as
# [batch, head, query_position, key_position]; rows of the heatmap are the
# attending tokens and columns the attended-to tokens.
layer = 0  # depth (valid range: 0 to len(attentions) - 1, i.e. 0 to 11 for GPT-2 small)
head  = 0  # width (valid range: 0 to number of heads - 1, i.e. 0 to 11 for GPT-2 small)

# Validate explicitly: a negative index would silently select a layer/head
# counted from the end, and the plot title would then be wrong.
if not 0 <= layer < len(attentions):
    raise ValueError(f"layer must be between 0 and {len(attentions) - 1}, got {layer}")
num_heads = attentions[layer].shape[1]
if not 0 <= head < num_heads:
    raise ValueError(f"head must be between 0 and {num_heads - 1}, got {head}")

attention_matrix = attentions[layer][0, head].detach().numpy()

# GPT-2's tokenizer marks a leading space with 'Ġ'; show it as a real space.
clean_tokens = [t.replace('Ġ', ' ') for t in tokens]

plt.figure(figsize=(5, 4))
sns.heatmap(attention_matrix, xticklabels=clean_tokens, yticklabels=clean_tokens, cmap='viridis')
plt.title(f"Attention Heatmap: Layer {layer}, Head {head}")
plt.show()

In [None]:
# Interactive overview of every layer and head at once (needs e.g. Jupyter).
model_view(attention=attentions, tokens=tokens, display_mode="light")

In [None]:
# This is a very cool visualization of attention and of the queries and keys
# that make it up; however, it can take some time to run.
#
# bertviz's neuron view needs its own GPT-2 classes. Import them under aliases
# and keep them in separate variables, so they neither shadow the transformers
# GPT2Model/GPT2Tokenizer classes nor overwrite the `model`/`tokenizer` objects
# used by the other cells (which would silently break re-running earlier cells).
from bertviz.transformers_neuron_view import (
    GPT2Model as NeuronGPT2Model,
    GPT2Tokenizer as NeuronGPT2Tokenizer,
)

model_type = 'gpt2'
model_version = 'gpt2'
neuron_model = NeuronGPT2Model.from_pretrained(model_version)
neuron_tokenizer = NeuronGPT2Tokenizer.from_pretrained(model_version)
neuron_view.show(neuron_model, model_type, neuron_tokenizer, sentence, display_mode='dark')

### Make predictions

In [None]:
# Continue a prompt with GPT-2 through the high-level text-generation pipeline.
prompt = "The data scientist decided to"

generator = pipeline('text-generation', model='gpt2')
results = generator(prompt, num_return_sequences=1)

# `results` is a list of dicts (one per requested sequence); print the only one.
print(results[0]['generated_text'])

### Look at prediction probability

In [None]:
# Reload GPT-2, this time with the language-modeling head so the forward pass
# also produces next-token logits (a slightly different class from before).
from transformers import GPT2Tokenizer, GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

In [None]:
# Tokenize the prompt into PyTorch tensors. The result is a mapping of model
# inputs that is unpacked into the model call with **inputs.
prompt = "The data scientist decided to take his data and"
inputs = tokenizer(text=prompt, return_tensors="pt")

In [None]:
# Score every vocabulary entry as the token that follows the prompt and print
# the most likely candidates.
with torch.no_grad():
    outputs = model(**inputs)
    # logits are indexed [batch, position, vocabulary]; the last position
    # scores the token that would come right after the prompt
    next_token_logits = outputs.logits[0, -1, :]

# Softmax turns the raw logits into a probability distribution.
probs = F.softmax(next_token_logits, dim=-1)

# Keep the n largest probabilities together with their token ids.
n = 5
top_n = torch.topk(probs, n)

header = f"{'Token':<15} | {'Probability':<10}"
print(f"Prompt: {prompt}\n")
print(header)
print("-" * 30)

for candidate_prob, candidate_id in zip(top_n.values, top_n.indices):
    candidate_text = tokenizer.decode([candidate_id])
    percent = candidate_prob.item() * 100
    print(f"{candidate_text:<15} | {percent:.2f}%")

## BERT
A foundational transformer-based natural language processing model released by Google researchers in 2018.

See:
- Devlin, Jacob, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. “BERT: Pre-Training of Deep Bidirectional Transformers for Language Understanding.” arXiv.Org, October 11, 2018. https://arxiv.org/abs/1810.04805v2.
- https://en.wikipedia.org/wiki/BERT_(language_model)

In [None]:
from transformers import AutoTokenizer, AutoModel

In [None]:
# Load the pretrained "bert-base-uncased" tokenizer and encoder.
# output_attentions=True makes the model also return its attention weights,
# which the visualizations below need.
model_name = "bert-base-uncased"
model = AutoModel.from_pretrained(model_name, output_attentions=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)

### Visualize Attention

In [None]:
# Tokenize an example sentence. Its pronoun "it" could refer to either "cat" or
# "mat", which makes it interesting to inspect in the attention views.
sentence = "The cat sat on the mat because it was comfortable."
inputs = tokenizer(text=sentence, return_tensors="pt")

In [None]:
# Compute the attention weights.
# The model returns a ModelOutput object whose fields are read by attribute
# (e.g. outputs.attentions), not a plain tuple. `attentions` is a tuple with
# one tensor per layer (12 for bert-base-uncased).
# Gradients are not needed for visualization, so autograd is disabled.
with torch.no_grad():
    outputs = model(**inputs)
attentions = outputs.attentions

# Readable tokens (including special tokens such as [CLS]/[SEP]) for labels.
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

In [None]:
# Interactive per-head attention view (needs a notebook front end such as Jupyter).
head_view(attention=attentions, tokens=tokens)

In [None]:
# Alternative, non-interactive visualization: a heatmap of a single attention head.
# Each element of `attentions` holds one layer's weights, indexed below as
# [batch, head, query_position, key_position]; rows of the heatmap are the
# attending tokens and columns the attended-to tokens.
layer = 0  # depth (valid range: 0 to len(attentions) - 1, i.e. 0 to 11 for bert-base)
head  = 0  # width (valid range: 0 to number of heads - 1, i.e. 0 to 11 for bert-base)

# Validate explicitly: a negative index would silently select a layer/head
# counted from the end, and the plot title would then be wrong.
if not 0 <= layer < len(attentions):
    raise ValueError(f"layer must be between 0 and {len(attentions) - 1}, got {layer}")
num_heads = attentions[layer].shape[1]
if not 0 <= head < num_heads:
    raise ValueError(f"head must be between 0 and {num_heads - 1}, got {head}")

attention_matrix = attentions[layer][0, head].detach().numpy()

plt.figure(figsize=(5, 4))
sns.heatmap(attention_matrix, xticklabels=tokens, yticklabels=tokens, cmap='viridis')
plt.title(f"Attention Heatmap: Layer {layer}, Head {head}")
plt.show()

### Make predictions
BERT is not an autoregressive model. It is bidirectional: it uses the context both before and after each position, so we mark the position to predict with the `[MASK]` token.

In [None]:
from transformers import pipeline

# 1. Initialize the fill-mask pipeline.
fill_mask = pipeline("fill-mask", model="bert-base-uncased")

# 2. Define a sentence containing a single [MASK] token to predict.
text = "The data scientist visualized his data using [MASK]."

# 3. Get predictions. Request the number of candidates explicitly rather than
#    relying on the pipeline's default, so the output matches the comment below.
top_k = 5
results = fill_mask(text, top_k=top_k)

# 4. Display the top candidates. With a single mask, `results` is a list of
#    dicts holding the score, the predicted token and the completed sentence.
for res in results:
    print(f"Score: {res['score']:.4f} | Word: {res['token_str']} | Sentence: {res['sequence']}")