# Attention in LLM Transformer Architecture


### Some references

* https://towardsdatascience.com/openai-gpt-2-understanding-language-generation-through-visualization-8252f683b2f8
* https://jalammar.github.io/illustrated-gpt2/ (very detailed and somewhat technical)

### Setup

In [None]:
from transformers import AutoTokenizer, AutoModel, utils, AutoModelForCausalLM
from bertviz import model_view, head_view
import torch

#### Set up the GPT-2 models (base and language-model head) and tokenizer

In [None]:

# Shared Hugging Face download cache (path is specific to this environment).
cache_dir = '/Commjhub/HF_cache'

# Silence routine transformers warnings; must run before the models load.
utils.logging.set_verbosity_error()

model_name = 'gpt2'

# GPT-2 with its language-model head, used below for next-token prediction.
gpt2 = AutoModelForCausalLM.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    return_dict_in_generate=True,
)

# Bare GPT-2 transformer, configured to also return attention weights
# so they can be visualized with bertviz.
model = AutoModel.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    output_attentions=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
# Reuse the end-of-sequence id as the padding id
# (presumably because GPT-2 ships without a pad token -- confirm).
tokenizer.pad_token_id = tokenizer.eos_token_id



#### Input data to process

In [None]:
# Two sentences that differ only in their subject, for comparing attention.
input_text1, input_text2 = (
    'The dog on the ship ran',
    'The motor on the ship ran',
)

* Tokenize input text

In [None]:
# Encode the first sentence into a PyTorch tensor of token ids.
inputs = tokenizer.encode(text=input_text1, return_tensors='pt')

* Process the input text using the model and retrieve attention weights

In [None]:
# Run the base model. Gradients are not needed for visualization, so
# no_grad avoids building an autograd graph.
with torch.no_grad():
    outputs = model(inputs)

# Per-layer attention weights, returned because the model was loaded with
# output_attentions=True. The named field is used instead of outputs[-1]
# so the result does not depend on the layout of the output tuple.
attention = outputs.attentions

# Convert the ids of the first (only) encoded sequence to token strings.
tokens = tokenizer.convert_ids_to_tokens(inputs[0].tolist())

* Check tokenization 
  - Note that the GPT-2 tokenizer displays a leading space as the `Ġ` character. This is an artifact of its byte-level BPE, which maps every byte to a printable Unicode character (the space byte becomes `Ġ`). It does not affect how the model works or the visualizations below.

In [None]:
# Display the token strings (a leading space is shown as `Ġ`).
tokens

In [None]:
# Render every layer and head of the attention weights with bertviz.
model_view(attention=attention, tokens=tokens)

In [None]:
# Render the per-head attention view with bertviz.
head_view(attention=attention, tokens=tokens)

In [None]:
# Number of most-likely next tokens to report.
top_k = 5

# Run GPT-2 with its language-model head; inference_mode disables autograd.
with torch.inference_mode():
    outputs = gpt2(inputs)

# Logits at the last position of the first (only) sequence score every
# vocabulary entry as the candidate next token.
next_token_logits = outputs.logits[0, -1, :]

# Normalize the logits over the vocabulary into probabilities.
next_token_probs = torch.softmax(next_token_logits, dim=-1)

topk_next_tokens = torch.topk(next_token_probs, top_k)

# Convert to plain Python ints/floats so decode() and the percent format
# receive ordinary numbers rather than 0-d tensors.
topk_ids = topk_next_tokens.indices.tolist()
topk_probs = topk_next_tokens.values.tolist()
for idx, prob in zip(topk_ids, topk_probs):
    print(f"{tokenizer.decode(idx): <20}{prob:.1%}")