# BertViz Interactive Demo
## **Scroll down for pre-loaded visualizations** 👇👇👇👇👇👇


In [None]:
!pip3 install torchmetrics==0.4.1
!pip3 install transformers==4.8.2
!pip3 install pytorch_lightning==1.3.8

In [None]:
!pip install bertviz

In [13]:
# Load model and retrieve attention weights from a single forward pass.

import torch

from bertviz import head_view, model_view
from transformers import BertTokenizer, BertModel
from transformers import T5ForConditionalGeneration, AutoTokenizer

MODEL_NAME = 'jenspt/byt5_ft_all_clean_data_lr_1e4'
# The tokenizer is taken from the base ByT5 checkpoint, not from MODEL_NAME.
TOKENIZER_NAME = 'google/byt5-small'
SOURCE_TEXT = "hamster smuglere hedder conni"      # encoder input
TARGET_TEXT = "hamstersmugleren hedder connie"     # decoder input

# output_attentions=True makes the forward pass return the per-layer attention tensors
# (encoder_attentions, decoder_attentions, cross_attentions) used by the views below.
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, output_attentions=True)
model.eval()  # inference only: make sure dropout is disabled
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

# Token ids of the source sentence (batch of one) for the encoder.
encoder_input_ids = tokenizer(SOURCE_TEXT, return_tensors="pt", add_special_tokens=True).input_ids

# Token ids of the target sentence (batch of one), fed directly as decoder input.
# NOTE(review): no decoder start token is prepended (the ids are not shifted right), so
# decoder position i sees target token i itself. That is fine for inspecting attention,
# but it is not the exact teacher-forcing layout used during training.
decoder_input_ids = tokenizer(TARGET_TEXT, return_tensors="pt", add_special_tokens=True).input_ids

# No gradients are needed to look at attention; skipping autograd saves memory and time.
with torch.no_grad():
    outputs = model(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids)

# Token strings (first and only batch element) used as labels in the visualizations.
encoder_text = tokenizer.convert_ids_to_tokens(encoder_input_ids[0])
decoder_text = tokenizer.convert_ids_to_tokens(decoder_input_ids[0])

In [19]:
# Model view: a grid of layers (rows) x heads (columns) for encoder self-attention,
# decoder self-attention and encoder-decoder cross-attention.
# NOTE(review): bertviz applies include_layers to every attention stack passed in, so
# the list may only name layers that exist in the shallowest stack — presumably the
# decoder of this checkpoint has exactly 4 layers (0-3); confirm before widening it.
model_view(
    encoder_attention=outputs.encoder_attentions,
    decoder_attention=outputs.decoder_attentions,
    cross_attention=outputs.cross_attentions,
    encoder_tokens=encoder_text,
    decoder_tokens=decoder_text,
    include_layers=[0, 1, 2, 3],
)

Output hidden; open in https://colab.research.google.com to view.

In [14]:
# Head view: per-head attention for one layer at a time, using the token labels and
# the encoder / decoder / cross-attention tensors from the forward-pass cell.
# Only layers 0-3 are offered in the layer drop-down.
head_view(
    encoder_tokens=encoder_text,
    decoder_tokens=decoder_text,
    encoder_attention=outputs.encoder_attentions,
    decoder_attention=outputs.decoder_attentions,
    cross_attention=outputs.cross_attentions,
    include_layers=[0, 1, 2, 3],
)

Output hidden; open in https://colab.research.google.com to view.

# Model View
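The model view gives a bird's-eye view of attention across the selected layers (rows) and heads (columns) of the model. In this notebook we are showing a fine-tuned ByT5 encoder-decoder model (`jenspt/byt5_ft_all_clean_data_lr_1e4`), restricted to layers 0–3 via `include_layers`. Encoder self-attention, decoder self-attention and cross-attention are all passed to the view. Layers and heads are zero-indexed.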

## Usage
* **Click** on any **cell** for a detailed view of attention for the associated attention head.
* Then **hover** over any **token** on the left side of the detail view to filter the attention from that token.
* The lines show the attention from each token (left) to every other token (right). Darker lines indicate higher attention weights.  

# Head View
The attention-head view visualizes attention in one or more heads in a particular layer in the model.

## Usage
* **Hover** over any **token** on the left/right side of the visualization to filter attention from/to that token. The colors correspond to different attention heads.
* **Double-click** on any of the **colored tiles** at the top to filter to the corresponding attention head.
* **Single-click** on any of the **colored tiles** to toggle selection of the corresponding attention head. 
* **Click** on the **Layer** drop-down to change the model layer (zero-indexed).
* The lines show the attention from each token (left) to every other token (right). Darker lines indicate higher attention weights. When multiple heads are selected, the attention weights are overlaid on one another. 

In [18]:
# First line: the source/target pair visualized above, joined by " - ".
# Remaining lines: the source with one word wrapped in the sentinel tokens
# <extra_id_0> / <extra_id_1>, followed by " - " and the matching target
# fragment (empty in one case).
# NOTE(review): presumably these illustrate a span-masking input/target format — confirm.
print(
    "hamster smuglere hedder conni -  hamstersmugleren hedder connie",
    "<extra_id_0> hamster <extra_id_1> smuglere hedder conni - hamstersmugleren",
    "hamster <extra_id_0> smuglere <extra_id_1> hedder conni - ",
    "hamster smuglere <extra_id_0> hedder <extra_id_1> conni - hedder",
    "hamster smuglere hedder <extra_id_0> conni <extra_id_1> - connie",
    sep="\n",
)

hamster smuglere hedder conni -  hamstersmugleren hedder connie
<extra_id_0> hamster <extra_id_1> smuglere hedder conni - hamstersmugleren
hamster <extra_id_0> smuglere <extra_id_1> hedder conni - 
hamster smuglere <extra_id_0> hedder <extra_id_1> conni - hedder
hamster smuglere hedder <extra_id_0> conni <extra_id_1> - connie
