# ModernNLP: #3

# Exploring BERT's attention mechanism

* Discussed the paper "What does BERT look at? An Analysis of BERT’s Attention", Clark et al., ACL 2019 (https://www.aclweb.org/anthology/W19-4828.pdf).
* Experimented with a pre-trained BERT.
* Fine-tuned on new data for grammatical error correction (GEC).
* Created attention heatmaps based on the attention weights of each head of each layer.

### Further experiments
* Run the notebook for *humor detection* (not GEC) using the dataset added below.
* Perform attention analysis of the humor detection model.

> Authored by John Pavlopoulos & Vasiliki Kougia



In [None]:
%%capture
!pip install transformers

In [None]:
import torch
import numpy as np
import pandas as pd
import torch.nn as nn

import transformers
# Use the PyTorch AdamW; the copy shipped with transformers is deprecated in its favour
from torch.optim import AdamW
from tqdm.notebook import tqdm
from scipy.special import softmax
# Import only the metrics used below instead of a wildcard import
from sklearn.metrics import (accuracy_score, average_precision_score,
                             classification_report, f1_score, roc_auc_score)
from sklearn.model_selection import train_test_split as tts
from transformers import BertTokenizerFast, BertConfig, BertForSequenceClassification, AutoModel
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

In [None]:
# Define the device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Prepare the data

We originally ran this notebook on a dataset for grammatical error detection (we thank Katerina Korre for sharing the preprocessed dataset).
Because we cannot share that dataset, you can use a publicly available humor detection dataset instead. The code that downloads the humor detection dataset is given below, and the code for the GEC dataset is commented out. No further code changes are needed, since the task and the model are the same (BERT for binary text classification).

#### Download

In [None]:
# # Read GEC dataset
# data = pd.read_csv("GEC_sentences.csv")
# print("\nThere are", len(data), "sentences")

# # Assign 1 to the erroneous sentences and zero to the rest
# data["text"], data["label"] = data.sentence, data["type"].apply(lambda x: 1 if x=="erroneous" else 0)
# del data["sentence"]
# print(data.head())


There are 28350 sentences
        type                                               text  label
0    correct                              Dear Sir or Madam ,\n      0
1  erroneous  I am writing in order to express my disappoint...      1
2  erroneous  I saws the show 's advertisement hanging up of...      1
3  erroneous  The problems started in the box office , where...      1
4    correct  Moreover , the show was delayed forty - five m...      0


In [None]:
# Download humor detection data
# Paper: https://arxiv.org/abs/2004.12765
data = pd.read_csv("https://raw.githubusercontent.com/Moradnejad/ColBERT-Using-BERT-Sentence-Embedding-for-Humor-Detection/master/Data/dataset.csv")
print("\nThere are", len(data), "sentences")

# Use the standard text/label columns
# Create labels: 1 --> humorous, 0 --> not humorous
data["label"] = data["humor"].apply(int)
data.head()


There are 200000 sentences


Unnamed: 0,text,humor,label
0,"Joe biden rules out 2020 bid: 'guys, i'm not r...",False,0
1,Watch: darvish gave hitter whiplash with slow ...,False,0
2,What do you call a turtle without its shell? d...,True,1
3,5 reasons the 2016 election feels so personal,False,0
4,"Pasco police shot mexican migrant from behind,...",False,0


#### Split to training, validation and test



In [None]:
# Use a subset for quick experiments
subset_data = data[:10000]

# Split to train, val and test
train, test = tts(subset_data[["text", "label"]], random_state=42, test_size=0.1)
train, val = tts(train, random_state=42, test_size=test.shape[0])

#### Tokenize and encode with BERT tokenizer

In [None]:
# Construct a BERT tokenizer based on WordPiece
bert_tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

In [None]:
# A sanity check of the tokenizer
encoded_instance = bert_tokenizer.batch_encode_plus([train.iloc[0].text], padding=True)
print(encoded_instance)

{'input_ids': [[101, 6737, 2144, 11657, 2135, 1010, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1]]}


In [None]:
print("Original text:", train.iloc[0].text)
print("BERT BPEs:", bert_tokenizer.convert_ids_to_tokens(encoded_instance["input_ids"][0]))

Original text: Yours sincerately ,

BERT BPEs: ['[CLS]', 'yours', 'since', '##rate', '##ly', ',', '[SEP]']
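
The pieces printed above (labelled "BPEs" in the output) are WordPiece tokens: a word that is missing from the vocabulary is split greedily into the longest matching prefix, followed by continuation pieces marked with `##`. The sketch below illustrates that greedy longest-match-first rule in plain Python. The function name `wordpiece_split` and the tiny `toy_vocab` are hypothetical, chosen only so that the example reproduces the split shown above; the real tokenizer uses the full `bert-base-uncased` vocabulary and also handles lowercasing and punctuation.

```python
def wordpiece_split(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first split of a single word into WordPiece-style pieces."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        piece = None
        # Try the longest remaining substring first, then shrink it from the right
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry the ## prefix
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk]  # no prefix matched: the whole word becomes unknown
        pieces.append(piece)
        start = end
    return pieces

# Hypothetical toy vocabulary, NOT the real BERT vocabulary
toy_vocab = {"yours", "since", "sin", "##ce", "##rate", "##ly", ","}

print(wordpiece_split("yours", toy_vocab))
print(wordpiece_split("sincerately", toy_vocab))
```

Because the longest match wins, `since` is preferred over the shorter `sin` even though both are in the toy vocabulary.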


In [None]:
# Set max_len to the length of the longest training sentence, counted in BERT tokens (including [CLS] and [SEP])
max_len = max([len(bert_tokenizer.encode(s)) for s in train.text.to_list()])
print("The maximum sentence length in training based on BERT BPEs is", max_len)

The maximum sentence length in training based on BERT BPEs is 138


In [None]:
# Tokenize and encode sentences in each set
x_train = bert_tokenizer.batch_encode_plus(
    train.text.tolist(),
    max_length = max_len,
    padding=True,
    truncation=True
)
x_val = bert_tokenizer.batch_encode_plus(
    val.text.tolist(),
    max_length = max_len,
    padding=True,
    truncation=True
)
x_test = bert_tokenizer.batch_encode_plus(
    test.text.tolist(),
    max_length = max_len,
    padding=True,
    truncation=True
)
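
To make the role of `padding` and the resulting `attention_mask` concrete, here is a minimal plain-Python sketch of what batch encoding produces once the token ids are known. `pad_batch` is a hypothetical helper, not part of the `transformers` API; it assumes pad id 0 (the id of `[PAD]` in `bert-base-uncased`) and truncates naively, whereas the real tokenizer keeps the final `[SEP]` when it truncates.

```python
def pad_batch(sequences, max_length, pad_id=0):
    """Pad (or naively truncate) lists of token ids to a common length."""
    # Pad to the longest sequence in the batch, capped at max_length
    target = min(max(len(s) for s in sequences), max_length)
    input_ids, attention_mask = [], []
    for seq in sequences:
        seq = seq[:target]
        n_pad = target - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)
        # 1 marks a real token, 0 marks padding that the model should ignore
        attention_mask.append([1] * len(seq) + [0] * n_pad)
    return input_ids, attention_mask

batch = [
    [101, 6737, 2144, 11657, 2135, 1010, 102],  # ids from the sanity check above
    [101, 7592, 102],                           # made-up shorter sequence
]
ids, mask = pad_batch(batch, max_length=138)
print(ids[1])
print(mask[1])
```

The mask is what lets the model skip the padded positions when it computes attention.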

In [None]:
# Convert lists to tensors in order to feed them to our PyTorch model
train_seq = torch.tensor(x_train['input_ids'])
train_mask = torch.tensor(x_train['attention_mask'])
train_y = torch.tensor(train.label.tolist())

val_seq = torch.tensor(x_val['input_ids'])
val_mask = torch.tensor(x_val['attention_mask'])
val_y = torch.tensor(val.label.tolist())

test_seq = torch.tensor(x_test['input_ids'])
test_mask = torch.tensor(x_test['attention_mask'])
test_y = torch.tensor(test.label.tolist())

In [None]:
batch_size = 32

# Create a dataloader for each set

# TensorDataset: Creates a PyTorch dataset object to load data from
train_data = TensorDataset(train_seq, train_mask, train_y)
# RandomSampler: specify the sequence of indices/keys used in data loading
train_sampler = RandomSampler(train_data)
# DataLoader: a Python iterable over a dataset
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

val_data = TensorDataset(val_seq, val_mask, val_y)
val_sampler = SequentialSampler(val_data)
val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=batch_size)

test_data = TensorDataset(test_seq, test_mask, test_y)
test_sampler = SequentialSampler(test_data)
test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=1)
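
As a quick check of the sizes involved, the arithmetic below reproduces the two-step split and the number of training batches per epoch. It is a sketch that assumes the 10,000-sentence subset and `batch_size = 32` from the cells above; the variable names are illustrative.

```python
import math

n_total = 10_000            # size of subset_data
batch_size = 32

n_test = n_total // 10      # first split: test_size=0.1
n_val = n_test              # second split: test_size=test.shape[0]
n_train = n_total - n_test - n_val

# A DataLoader that keeps the last incomplete batch yields ceil(n / batch_size) batches
train_batches = math.ceil(n_train / batch_size)
print(n_train, n_val, n_test, train_batches)
```

The test loader uses `batch_size=1`, so it yields one sentence per step, which keeps the per-sentence attention tensors easy to index later.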

## Build and train the model

In [None]:
# Define which BERT model to use
# We will use BERT base pre-trained on uncased text
model_name = "bert-base-uncased"
# The BertForSequenceClassification class creates a model with BERT and a classifier on top
# The classifier is a linear layer with two outputs (two is the default, if you have more labels change the config)
# It uses the CrossEntropyLoss from PyTorch
# from_pretrained() is used to load pre-trained weights
model = BertForSequenceClassification.from_pretrained(model_name, output_attentions=True)
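
For reference, the loss that the comments above attribute to the classification head can be written out by hand. The sketch below computes softmax cross-entropy for a single example with two logits in plain Python; `cross_entropy` is an illustrative helper rather than the PyTorch implementation, and the logits are made-up numbers.

```python
import math

def cross_entropy(logits, label):
    """Softmax cross-entropy for one example: -log softmax(logits)[label]."""
    # Subtract the max logit for numerical stability (this does not change the result)
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

print(cross_entropy([0.0, 0.0], 1))   # uninformed classifier: log(2)
print(cross_entropy([-2.0, 3.0], 1))  # confident and correct: small loss
print(cross_entropy([-2.0, 3.0], 0))  # confident and wrong: large loss
```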

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [None]:
# Training method
def training():
  # Set to train mode
  model.train()
  total_loss = 0.0
  # Iterate through the training batches
  for batch in tqdm(train_dataloader, desc="Iteration"):
    # Push the batch to the selected device (GPU if available)
    batch = [r.to(device) for r in batch]
    sent_id, mask, labels = batch
    # Clear gradients
    model.zero_grad()
    # Get model outputs
    outputs = model(sent_id, attention_mask=mask, labels=labels)
    # Get loss
    loss = outputs.loss
    # Accumulate a Python float; adding the tensor itself would keep autograd history attached to the running total
    total_loss += loss.item()
    # Backward pass to calculate the gradients
    loss.backward()
    # Update parameters
    optimizer.step()
  # Compute the training loss of the epoch
  epoch_loss = total_loss / len(train_dataloader)

  return epoch_loss

In [None]:
# Evaluation method
def evaluate():  
  print("\nEvaluating...")  
  # Set to eval mode
  model.eval()
  total_loss, total_accuracy = 0, 0
  predictions, targets = [], []
  # Iterate through the validation batches
  for batch in val_dataloader:
    # Push the batch to gpu
    batch = [t.to(device) for t in batch]
    sent_id, mask, labels = batch
    # Save the gold labels to use them for evaluation
    targets.extend(labels.detach().cpu().numpy())
    # Deactivate autograd
    with torch.no_grad():
      # Get model outputs
      outputs = model(sent_id, attention_mask=mask, labels=labels)
      # Get loss
      loss = outputs.loss
      total_loss = total_loss + loss
      # Apply softmax to the output of the model
      output_probs = softmax(outputs.logits.detach().cpu().numpy(), axis=1)
      # Get the index with the largest probability as the predicted label
      predictions.extend(np.argmax(output_probs, axis=1))
  # Compute the validation loss of the epoch
  epoch_loss = total_loss / len(val_dataloader)

  return epoch_loss, targets, predictions
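
The prediction step in `evaluate()` applies a softmax and then takes the arg-max. A small plain-Python sketch of that step follows, with made-up logits; it also shows that the arg-max of the probabilities equals the arg-max of the raw logits, so the softmax matters only when the probabilities themselves are needed. The helper names are illustrative.

```python
import math

def softmax_row(logits):
    """Turn one row of logits into probabilities that sum to 1."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def argmax(values):
    """Index of the largest value."""
    return max(range(len(values)), key=lambda i: values[i])

batch_logits = [[1.2, -0.3], [-0.5, 2.0], [0.1, 0.4]]  # made-up logits for three sentences
probs = [softmax_row(row) for row in batch_logits]
preds_from_probs = [argmax(row) for row in probs]
preds_from_logits = [argmax(row) for row in batch_logits]
print(preds_from_probs, preds_from_logits)
```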

In [None]:
# Push model to gpu
model = model.to(device)
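# Define the optimizer and the learning rate
# weight_decay is set explicitly because its default differs between the transformers and torch.optim versions of AdamW
optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.0)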

best_val_loss = float('inf')
best_epoch = -1
train_losses=[]
val_losses=[]
epochs = 100
# Define the number of epochs to wait for early stopping
patience = 3

# Train the model
for epoch in range(epochs):     
  print('\n Epoch {:} / {:}'.format(epoch + 1, epochs))    
  train_loss = training()
  val_loss, val_targets, val_predictions = evaluate()

  train_losses.append(train_loss)
  val_losses.append(val_loss)

  print("\nTraining Loss:", train_loss)
  print("Validation Loss:", val_loss)
  # Calculate the validation F1 score for the current epoch
  f1 = f1_score(val_targets, val_predictions, average="binary")
  print("F1 score:", round(f1, 3))

  # Save the model with the best validation loss
  if val_loss < best_val_loss:
    best_val_loss = val_loss
    best_epoch = epoch
    torch.save(model.state_dict(), 'saved_weights.pt')

  # Early stopping
  if ((epoch - best_epoch) >= patience):
    print("No improvement in", patience, "epochs. Stopped training.")
    break



 Epoch 1 / 100


Evaluating...

Training Loss: tensor(0.4356, device='cuda:0', grad_fn=<DivBackward0>)
Validation Loss: tensor(0.4149, device='cuda:0')
F1 score: 0.806

 Epoch 2 / 100


Evaluating...

Training Loss: tensor(0.2851, device='cuda:0', grad_fn=<DivBackward0>)
Validation Loss: tensor(0.4015, device='cuda:0')
F1 score: 0.864

 Epoch 3 / 100


Evaluating...

Training Loss: tensor(0.1721, device='cuda:0', grad_fn=<DivBackward0>)
Validation Loss: tensor(0.4755, device='cuda:0')
F1 score: 0.83

 Epoch 4 / 100


Evaluating...

Training Loss: tensor(0.1164, device='cuda:0', grad_fn=<DivBackward0>)
Validation Loss: tensor(0.5378, device='cuda:0')
F1 score: 0.856

 Epoch 5 / 100


Evaluating...

Training Loss: tensor(0.0838, device='cuda:0', grad_fn=<DivBackward0>)
Validation Loss: tensor(0.6258, device='cuda:0')
F1 score: 0.806
No improvement in 3 epochs. Stopped training.
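
The stopping point in the log above follows directly from the patience rule. The sketch below replays that rule on the five validation losses printed above, in plain Python; `early_stopping_epoch` is a hypothetical helper that mirrors the loop's logic and is not part of the notebook.

```python
def early_stopping_epoch(val_losses, patience):
    """Return (best_epoch, stop_epoch) as 0-based indices; stop_epoch is None if never triggered."""
    best_loss, best_epoch = float("inf"), -1
    for epoch, loss in enumerate(val_losses):
        # Remember the epoch with the lowest validation loss so far
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        # Stop once `patience` epochs have passed without improvement
        if epoch - best_epoch >= patience:
            return best_epoch, epoch
    return best_epoch, None

val_losses = [0.4149, 0.4015, 0.4755, 0.5378, 0.6258]  # validation losses from the log above
best, stop = early_stopping_epoch(val_losses, patience=3)
print("best epoch:", best + 1, "| stopped after epoch:", stop + 1)
```

The best validation loss is reached in epoch 2, so that is the checkpoint left in `saved_weights.pt`, even though the F1 score keeps fluctuating in later epochs.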


In [None]:
# Save checkpoint to your drive
# Zip
!zip saved_weights.zip  saved_weights.pt
# Mount
from google.colab import drive
drive.mount('/content/gdrive')
# Copy to your drive folder
!cp -r saved_weights.zip /content/gdrive/MyDrive/

  adding: saved_weights.pt (deflated 7%)
Mounted at /content/gdrive


## Inference

#### Load the saved checkpoint

In [None]:
# Use this code to download the model saved in your drive 
# Add the id from the shareable link of the file 
# !gdown --id add_shareable_link_id
# !unzip saved_weights.zip

In [None]:
# Create the model
model_e = BertForSequenceClassification.from_pretrained("bert-base-uncased", output_attentions=True)
# Load pre-trained weights
checkpoint = torch.load("saved_weights.pt", map_location="cpu")
# Add them to the model
model_e.load_state_dict(checkpoint)
model_e = model_e.to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

#### Get predictions for test

In [None]:
# Predict for the test set and save the results
model_e.eval()
test_predictions = []
test_targets = []
test_attentions = []
test_inputs = []

for batch in test_dataloader:
  batch = [t.to(device) for t in batch]
  sent_id, mask, labels = batch
  # Get gold labels
  test_targets.extend(labels.detach().cpu().numpy())
  # Get input words
  test_inputs.append(bert_tokenizer.convert_ids_to_tokens(sent_id.detach().cpu().numpy()[0]))
  with torch.no_grad():
    # Get predictions
    outputs = model_e(sent_id, attention_mask=mask)
    # Apply softmax to the outputs
    output_probs = softmax(outputs.logits.detach().cpu().numpy(), axis=1)
    # Get the index with the highest probability as the predicted label
    test_predictions.extend(np.argmax(output_probs, axis=1))
    # Get attention weights
    # Attention weights from all layers are returned in a tuple
    # The weights from each layer are in a tensor with shape (batch_size, attention_heads, max_len, max_len)
    test_attentions.append(outputs.attentions)

#### Evaluate

In [None]:
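print("F1:", f1_score(test_targets, test_predictions, average="binary"))
print("ACC:", accuracy_score(test_targets, test_predictions))
# Note: AUPR and AUC are computed here from hard 0/1 predictions;
# pass the positive-class probabilities instead to get threshold-free scores
print("AUPR:", average_precision_score(test_targets, test_predictions))
print("AUC:", roc_auc_score(test_targets, test_predictions))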

F1: 0.8551842330762639
ACC: 0.831
AUPR: 0.8634273731466714
AUC: 0.8429668492459519
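
Note that `roc_auc_score` and `average_precision_score` receive hard 0/1 predictions above rather than probabilities. With binary scores the ROC curve has a single interior point, so the area under it reduces to the mean of the true-positive and true-negative rates (balanced accuracy). The plain-Python sketch below checks this on made-up labels; `rank_auc` is an illustrative pairwise implementation of AUC, not the scikit-learn one.

```python
def rank_auc(targets, scores):
    """AUC as the probability that a random positive outscores a random negative (ties count 0.5)."""
    pos = [s for t, s in zip(targets, scores) if t == 1]
    neg = [s for t, s in zip(targets, scores) if t == 0]
    wins = sum((1.0 if p > n else (0.5 if p == n else 0.0)) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

def balanced_accuracy(targets, predictions):
    """Mean of the true-positive rate and the true-negative rate."""
    tp = sum(1 for t, p in zip(targets, predictions) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(targets, predictions) if t == 0 and p == 0)
    n_pos = sum(targets)
    n_neg = len(targets) - n_pos
    return 0.5 * (tp / n_pos + tn / n_neg)

targets     = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]   # made-up gold labels
predictions = [1, 1, 1, 0, 0, 0, 0, 0, 1, 1]   # made-up hard predictions
print(rank_auc(targets, predictions), balanced_accuracy(targets, predictions))
```

To obtain a ranking-based AUC, one would collect the positive-class probability for each test sentence and pass that to the metric instead of the arg-max label.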


## Attention analysis
* WARNING: The GEC dataset is replaced by a humor detection one. 
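
One way to summarise the per-head heatmaps produced in this section is the entropy of each attention distribution: low entropy means the head focuses on a few tokens, high entropy means it spreads its attention broadly. Clark et al. report a similar per-head entropy measure. The plain-Python sketch below uses made-up attention rows; `attention_entropy` is an illustrative helper, not something defined in this notebook.

```python
import math

def attention_entropy(weights):
    """Shannon entropy (in nats) of one attention distribution; zero weights contribute nothing."""
    return -sum(w * math.log(w) for w in weights if w > 0)

focused = [0.97, 0.01, 0.01, 0.01]   # made-up row: almost all attention on one token
uniform = [0.25, 0.25, 0.25, 0.25]   # made-up row: attention spread evenly

print(attention_entropy(focused))
print(attention_entropy(uniform))    # equals log(4), the maximum for four tokens
```

Applied to the CLS row of every head, this gives a 12 x 12 grid of numbers that is easier to compare across layers than 144 separate heatmaps.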

In [None]:
# Get attention heatmaps
import matplotlib
import matplotlib.cm
import matplotlib.colors
from IPython.display import display, HTML

def colorize(words, color_array):
    # Map each weight in [0, 1] to a shade of red and wrap the word in a coloured <span>
    cmap = matplotlib.cm.Reds
    template = '<span class="barcode" style="color: black; background-color: {}">{}</span>'
    colored_string = ''
    for word, color in zip(words, color_array):
        color = matplotlib.colors.rgb2hex(cmap(color)[:3])
        colored_string += template.format(color, '&nbsp;' + word + '&nbsp;')
    return colored_string

#### What does the CLS token attend to?
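
Each head's attention for one sentence is a square matrix whose row *i* holds the weights with which token *i* attends to every token, so each row sums to 1. Since `[CLS]` is always the first token, its attention is row 0 of that matrix. The sketch below builds one such matrix from made-up scores in plain Python to illustrate the indexing; the token list and the numbers are hypothetical.

```python
import math

def softmax_row(scores):
    """Normalise one row of scores into attention weights that sum to 1."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

tokens = ["[CLS]", "hello", "world", "[SEP]"]   # hypothetical tokenised sentence
raw_scores = [                                  # made-up pre-softmax scores, one row per query token
    [0.1, 2.0, 0.3, 1.0],
    [1.5, 0.2, 0.2, 0.1],
    [0.0, 0.0, 3.0, 0.5],
    [0.4, 0.4, 0.4, 0.4],
]
head = [softmax_row(row) for row in raw_scores]  # one head: seq_len x seq_len

cls_attentions = head[0]                         # how [CLS] attends to every token
most_attended = tokens[max(range(len(tokens)), key=lambda j: cls_attentions[j])]
print(most_attended)
```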




In [None]:
# Select some sentences randomly
sent_index = [25, 30, 50]

for s in sent_index:
  print("*" * 100)
  # Get the sentence's words
  tokens = test_inputs[s]
  # For each layer...
  for l in range(12):
    print("\nLayer", l+1)
    attention = np.squeeze(test_attentions[s][l].detach().cpu().numpy(), axis=0)
    # and for each head
    for h, head in enumerate(attention):
      print("Head", h+1)
      # Get the attention for the cls token
      cls_attentions = head[0]
      display(HTML(colorize(tokens, cls_attentions)))
    

[Output omitted: CLS-token attention heatmaps for the three selected test sentences, one heatmap per head (1 to 12) in each layer (1 to 12). The coloured HTML is not preserved in this text export.]

#### Visualize attentions for specific types of grammatical errors

In [None]:
# Lack of Subject-Verb Agreement
sentence1 = "She are going to the park."
# Pronoun Disagreement
sentence2 = "Every girl brings their own lunch."

In [None]:
# Encode the first sentence
encoded_sentence1 = bert_tokenizer.batch_encode_plus([sentence1], padding=True)

# Give as input to the model and get the outputs
inputs = torch.tensor(encoded_sentence1["input_ids"]).to(device)
att = torch.tensor(encoded_sentence1["attention_mask"]).to(device)
outputs = model_e(inputs, attention_mask=att)

In [None]:
# Get the predictions
output_probs = softmax(outputs.logits.detach().cpu().numpy(), axis=1)
predictions = (np.argmax(output_probs, axis=1))
print(sentence1, ":", predictions[0])

She are going to the park. : 1


In [None]:
# Visualize the attention heatmaps for the CLS token
tokens = bert_tokenizer.convert_ids_to_tokens(inputs.detach().cpu().numpy()[0])
for l in range(12):
  print("\nLayer", l+1)
  attention = np.squeeze(outputs.attentions[l].detach().cpu().numpy(), axis=0)
  cls_attentions = []
  for h, head in enumerate(attention):
    print("Head", h+1)
    # Get the attention for the cls token
    cls_attentions = head[0]
    display(HTML(colorize(tokens, cls_attentions)))


[Output omitted: CLS-token attention heatmaps for sentence1, one heatmap per head (1 to 12) in each layer (1 to 12). The coloured HTML is not preserved in this text export.]

In [None]:
# Encode the second sentence
encoded_sentence2 = bert_tokenizer.batch_encode_plus([sentence2], padding=True)

# Give as input to the model and get the outputs
inputs = torch.tensor(encoded_sentence2["input_ids"]).to(device)
att = torch.tensor(encoded_sentence2["attention_mask"]).to(device)
outputs = model_e(inputs, attention_mask=att)

# Get the predictions
output_probs = softmax(outputs.logits.detach().cpu().numpy(), axis=1)
predictions = (np.argmax(output_probs, axis=1))
print(sentence2, ":", predictions[0])

Every girl brings their own lunch. : 0


In [None]:
# Visualize the attention heatmaps for the CLS token
tokens = bert_tokenizer.convert_ids_to_tokens(inputs.detach().cpu().numpy()[0])
for l in range(12):
  print("\nLayer", l+1)
  attention = np.squeeze(outputs.attentions[l].detach().cpu().numpy(), axis=0)
  cls_attentions = []
  for h, head in enumerate(attention):
    print("Head", h+1)
    # Get the attention for the cls token
    cls_attentions = head[0]
    display(HTML(colorize(tokens, cls_attentions)))


[Output omitted: CLS-token attention heatmaps for sentence2, one heatmap per head (1 to 12) in each layer (1 to 12). The coloured HTML is not preserved in this text export.]
