## 2.2 Scaled dot-product attention

In [1]:
from transformers import BertModel  # Let's use a BERT model

In [2]:
checkpoint = 'bert-base-uncased'  # uncased BERT-base weights
model = BertModel.from_pretrained(checkpoint)

In [3]:
encoder_stack = model.encoder.layer  # the stack of encoder layers, indexable from 0
len(encoder_stack)  # base BERT stacks 12 encoders

12

In [4]:
first_encoder = model.encoder.layer[0]
first_encoder  # self-attention sub-block followed by the feed-forward (intermediate + output) sub-block

BertLayer(
  (attention): BertAttention(
    (self): BertSelfAttention(
      (query): Linear(in_features=768, out_features=768, bias=True)
      (key): Linear(in_features=768, out_features=768, bias=True)
      (value): Linear(in_features=768, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (output): BertSelfOutput(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (intermediate): BertIntermediate(
    (dense): Linear(in_features=768, out_features=3072, bias=True)
  )
  (output): BertOutput(
    (dense): Linear(in_features=3072, out_features=768, bias=True)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [5]:
first_encoder_attention = model.encoder.layer[0].attention
first_encoder_attention  # query/key/value projections, then the output projection with LayerNorm and dropout

BertAttention(
  (self): BertSelfAttention(
    (query): Linear(in_features=768, out_features=768, bias=True)
    (key): Linear(in_features=768, out_features=768, bias=True)
    (value): Linear(in_features=768, out_features=768, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (output): BertSelfOutput(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

## 2.3 Multi-headed attention

In [6]:
from transformers import BertModel, BertTokenizer
from bertviz import head_view
import torch
import pandas as pd

In [7]:
# Load a vanilla BERT-base checkpoint together with its matching tokenizer.
checkpoint = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(checkpoint)
model = BertModel.from_pretrained(checkpoint)

In [8]:
text = "My friend told me about this class and I love it so far! She was right."

tokens = tokenizer.encode(text)  # list of token ids, framed by [CLS] (101) and [SEP] (102)
inputs = torch.tensor([tokens])  # wrapping in a list adds the batch dimension: (20,) -> (1, 20)
inputs

tensor([[ 101, 2026, 2767, 2409, 2033, 2055, 2023, 2465, 1998, 1045, 2293, 2009,
         2061, 2521,  999, 2016, 2001, 2157, 1012,  102]])

In [9]:
# Single inference-only forward pass that keeps the per-layer attention weights.
# torch.no_grad(): no gradients are needed, so skip building the autograd graph.
# Read `.attentions` by name instead of positional `[2]`: indexing a ModelOutput skips None
# fields, so the position of the attentions shifts if hidden states are also returned.
with torch.no_grad():
    outputs = model(inputs, output_attentions=True, return_dict=True)
attention = outputs.attentions  # tuple, one tensor per encoder layer: (batch, heads, seq_len, seq_len)

In [10]:
# Last encoder layer, first (only) item in the batch, averaged over all heads -> (seq_len, seq_len)
final_attention = attention[-1][0].mean(dim=0)

In [11]:
# Tabulate the head-averaged last-layer attention with word-piece tokens on both axes
# (rows = query tokens, columns = key tokens). Compute the labels once and reuse them.
token_labels = tokenizer.convert_ids_to_tokens(tokens)

# Convert the whole tensor to a float64 NumPy array up front. This replaces the element-wise
# `.applymap(float)`: DataFrame.applymap is deprecated since pandas 2.1, and it avoids relying
# on how pandas treats a raw torch tensor.
attention_df = pd.DataFrame(
    final_attention.detach().cpu().numpy().astype(float),
    index=token_labels,
    columns=token_labels,
).round(3)

attention_df  # each row sums to 1 (up to rounding); sums across columns do not


Unnamed: 0,[CLS],my,friend,told,me,about,this,class,and,i,love,it,so,far,!,she,was,right,.,[SEP]
[CLS],0.092,0.028,0.019,0.011,0.012,0.022,0.05,0.087,0.031,0.023,0.023,0.031,0.007,0.028,0.067,0.057,0.065,0.124,0.104,0.12
my,0.021,0.023,0.014,0.01,0.013,0.021,0.028,0.015,0.014,0.012,0.01,0.023,0.011,0.009,0.016,0.022,0.021,0.019,0.312,0.388
friend,0.018,0.009,0.129,0.009,0.005,0.008,0.008,0.012,0.009,0.005,0.009,0.006,0.004,0.005,0.009,0.023,0.01,0.006,0.314,0.401
told,0.01,0.004,0.013,0.084,0.004,0.011,0.005,0.005,0.005,0.002,0.008,0.005,0.005,0.003,0.006,0.008,0.004,0.003,0.351,0.464
me,0.024,0.013,0.01,0.011,0.017,0.016,0.018,0.011,0.014,0.01,0.01,0.014,0.007,0.008,0.014,0.009,0.006,0.005,0.347,0.436
about,0.019,0.01,0.007,0.018,0.01,0.079,0.021,0.012,0.012,0.006,0.014,0.019,0.008,0.008,0.012,0.005,0.003,0.005,0.32,0.412
this,0.026,0.014,0.003,0.004,0.01,0.015,0.069,0.02,0.011,0.01,0.011,0.018,0.006,0.008,0.012,0.005,0.003,0.004,0.331,0.421
class,0.028,0.01,0.007,0.006,0.006,0.015,0.029,0.096,0.01,0.009,0.013,0.019,0.006,0.009,0.015,0.01,0.005,0.005,0.312,0.39
and,0.031,0.016,0.006,0.007,0.012,0.009,0.013,0.009,0.08,0.013,0.01,0.01,0.008,0.009,0.024,0.014,0.012,0.011,0.316,0.386
i,0.023,0.014,0.008,0.005,0.011,0.011,0.019,0.012,0.021,0.029,0.014,0.013,0.008,0.014,0.019,0.012,0.009,0.008,0.334,0.414


In [12]:
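# Clark et al. (2019), "What Does BERT Look At? An Analysis of BERT's Attention":
# https://nlp.stanford.edu/pubs/clark2019what.pdf
# Layer index 2 (the 3rd layer) appears to contain a head that attends to the previous token
# Layer index 6 (the 7th layer) appears to contain a head related to pronouns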

In [13]:
# bertviz wants string tokens, so map the id tensor back to word pieces first
tokens_as_list = tokenizer.convert_ids_to_tokens(inputs[0].tolist())
head_view(attention, tokens_as_list)  # no layer/heads restriction

<IPython.core.display.Javascript object>

In [14]:
# Head 3-1 (1-indexed layer-head, i.e. layer index 2, head index 0) attends to the previous token.
# Reuse the string tokens computed for the full head view instead of converting the ids again.
head_view(attention, tokens_as_list, layer=2, heads=[0])

<IPython.core.display.Javascript object>

In [15]:
# Head 8-10 (1-indexed layer-head, i.e. layer index 7, head index 9) links direct objects to
# their verbs: the object attends to the verb, e.g. "me" -> "told".
# Reuse the string tokens computed for the full head view instead of converting the ids again.
head_view(attention, tokens_as_list, layer=7, heads=[9])

<IPython.core.display.Javascript object>

In [16]:
# Layer index 7 (8th encoder), batch item 0, head index 9 (10th head): the direct-object head.
eight_ten = attention[7][0, 9]

In [17]:
# Tabulate the single-head (layer index 7, head index 9) attention matrix with word-piece tokens
# on both axes (rows = query tokens, columns = key tokens). Compute the labels once and reuse them.
token_labels = tokenizer.convert_ids_to_tokens(tokens)

# Convert the whole tensor to a float64 NumPy array up front. This replaces the element-wise
# `.applymap(float)`: DataFrame.applymap is deprecated since pandas 2.1, and it avoids relying
# on how pandas treats a raw torch tensor.
attention_df = pd.DataFrame(
    eight_ten.detach().cpu().numpy().astype(float),
    index=token_labels,
    columns=token_labels,
).round(3)

attention_df  # each row sums to 1 (up to rounding); columns do not. Row "me" puts most weight on "told".


Unnamed: 0,[CLS],my,friend,told,me,about,this,class,and,i,love,it,so,far,!,she,was,right,.,[SEP]
[CLS],0.007,0.004,0.005,0.002,0.001,0.001,0.003,0.004,0.002,0.001,0.002,0.003,0.001,0.005,0.006,0.005,0.01,0.039,0.033,0.867
my,0.031,0.03,0.027,0.009,0.004,0.002,0.001,0.004,0.006,0.001,0.001,0.002,0.001,0.002,0.033,0.002,0.002,0.004,0.05,0.788
friend,0.022,0.128,0.024,0.002,0.002,0.0,0.001,0.001,0.004,0.001,0.001,0.001,0.0,0.002,0.016,0.001,0.002,0.001,0.025,0.765
told,0.035,0.072,0.014,0.013,0.005,0.002,0.002,0.002,0.002,0.0,0.0,0.0,0.0,0.001,0.009,0.001,0.0,0.0,0.034,0.808
me,0.01,0.01,0.005,0.683,0.015,0.007,0.001,0.001,0.001,0.0,0.0,0.0,0.0,0.0,0.003,0.0,0.001,0.003,0.004,0.255
about,0.017,0.015,0.025,0.222,0.024,0.015,0.011,0.02,0.005,0.001,0.001,0.0,0.0,0.001,0.004,0.001,0.001,0.006,0.012,0.618
this,0.005,0.002,0.008,0.223,0.03,0.452,0.073,0.046,0.003,0.001,0.002,0.001,0.001,0.002,0.001,0.0,0.0,0.007,0.001,0.143
class,0.012,0.002,0.004,0.074,0.02,0.204,0.339,0.138,0.004,0.001,0.008,0.007,0.002,0.004,0.001,0.002,0.001,0.018,0.006,0.154
and,0.03,0.008,0.001,0.077,0.019,0.091,0.017,0.013,0.084,0.008,0.002,0.001,0.001,0.001,0.009,0.001,0.002,0.003,0.022,0.61
i,0.022,0.016,0.008,0.181,0.018,0.021,0.003,0.007,0.363,0.033,0.004,0.001,0.002,0.002,0.01,0.0,0.001,0.001,0.004,0.302
