In [1]:
!pip install bertviz

Collecting bertviz
  Downloading bertviz-1.4.0-py3-none-any.whl (157 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m157.6/157.6 kB[0m [31m15.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: bertviz
Successfully installed bertviz-1.4.0
[0m

In [19]:
from math import sqrt

from bertviz.transformers_neuron_view import BertModel
from bertviz.neuron_view import show

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoConfig

In [5]:
# Hugging Face checkpoint name shared by the tokenizer, model and config cells.
ckpt = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(ckpt)
# BertModel here is the class imported from bertviz.transformers_neuron_view,
# which is the model object later handed to `show`.
model = BertModel.from_pretrained(ckpt)

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

100%|██████████| 433/433 [00:00<00:00, 345864.34B/s]
100%|██████████| 440473133/440473133 [00:19<00:00, 23090560.30B/s]


In [7]:
# NOTE(review): 'flikes' looks like a typo for 'flies'. The recorded outputs in
# this notebook (7 token ids) were produced from this exact string, so re-run
# the dependent cells if you correct it.
text = 'time flikes like an arrow'
# Render the bertviz neuron view for `text`, with layer=0 and head=8 selected.
show(model, 'bert', tokenizer, text, display_mode='light', layer=0, head=8)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

## Scaled Dot-Product Attention

In [8]:
# Encode the sentence as PyTorch tensors without special tokens, so the ids
# correspond only to the sentence's own word pieces.
inputs = tokenizer(
    text,
    add_special_tokens=False,
    return_tensors='pt',
)
inputs

{'input_ids': tensor([[ 2051, 13109, 17339,  2015,  2066,  2019,  8612]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}

In [10]:
# Load the checkpoint's configuration; its vocab_size and hidden_size are used
# to size the embedding layer in the next step.
config = AutoConfig.from_pretrained(ckpt)
config

BertConfig {
  "_name_or_path": "bert-base-uncased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.20.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [11]:
# nn.Embedding starts from randomly initialised weights, so fix the global
# seed first; otherwise every rerun produces different embeddings and scores.
torch.manual_seed(0)
# Lookup table holding one hidden_size-dimensional vector per vocabulary id.
# This is a freshly created layer, not the weights of the loaded `model`.
token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
token_emb

Embedding(30522, 768)

In [13]:
# Look up one embedding vector per token id.
inputs_emb = token_emb(inputs['input_ids'])
inputs_emb.shape  # (batch size, sequence length, hidden dim)

torch.Size([1, 7, 768])

In [17]:
# Self-attention in its simplest form: the same embeddings act as query, key
# and value (no learned projections yet).
query, key, value = inputs_emb, inputs_emb, inputs_emb
# Scale the dot products by sqrt(dim_k), the size of the key vectors.
dim_k = key.shape[-1]
scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
scores.shape

torch.Size([1, 7, 7])

In [18]:
# Show the raw scores. The diagonal is far larger than the other entries because
# query and key are the same tensor, so each token is dotted with itself.
scores

tensor([[[ 2.6471e+01,  1.2825e+00,  4.4241e-01,  7.3064e-01, -7.1260e-02,
          -9.8635e-01,  8.4713e-01],
         [ 1.2825e+00,  2.5512e+01,  6.5703e-01, -1.7611e+00,  3.1429e-01,
           7.1619e-03,  1.4178e+00],
         [ 4.4241e-01,  6.5703e-01,  2.5662e+01, -1.7411e+00, -1.8633e-01,
           2.7245e-02,  2.4185e-02],
         [ 7.3064e-01, -1.7611e+00, -1.7411e+00,  2.7245e+01,  1.5673e-01,
           4.8511e-01,  6.8352e-01],
         [-7.1260e-02,  3.1429e-01, -1.8633e-01,  1.5673e-01,  2.6166e+01,
           3.4168e-01,  3.8949e-01],
         [-9.8635e-01,  7.1619e-03,  2.7245e-02,  4.8511e-01,  3.4168e-01,
           2.7943e+01, -4.6105e-01],
         [ 8.4713e-01,  1.4178e+00,  2.4185e-02,  6.8352e-01,  3.8949e-01,
          -4.6105e-01,  2.8038e+01]]], grad_fn=<DivBackward0>)

In [21]:
# Softmax over the last axis turns each query's row of scores into a
# probability distribution over the keys.
weights = F.softmax(scores, dim=-1)
# Check normalisation along the same axis the softmax used (dim=-1). Summing
# over dim 1 would add up columns instead, which only looks like 1 here because
# the scores are symmetric and dominated by the diagonal.
weights.sum(dim=-1)

tensor([[1., 1., 1., 1., 1., 1., 1.]], grad_fn=<SumBackward1>)

In [23]:
# Combine the value vectors using the attention weights: one weighted sum of
# values per query position.
attn_outputs = weights.bmm(value)
attn_outputs.size()

torch.Size([1, 7, 768])

In [24]:
def scaled_dot_prod_attn(query, key, value):
    """Compute scaled dot-product attention.

    Accepts any number of leading batch dimensions (for example
    ``(batch, seq, dim)`` or ``(batch, heads, seq, dim)``); the attention is
    always computed over the last two axes.

    Args:
        query: Tensor of shape ``(..., q_len, dim_k)``.
        key: Tensor of shape ``(..., k_len, dim_k)``.
        value: Tensor of shape ``(..., k_len, dim_v)``.

    Returns:
        Tensor of shape ``(..., q_len, dim_v)``: for every query position, the
        softmax-weighted sum of the value vectors.
    """
    dim_k = query.size(-1)
    # Transpose only the last two axes so extra leading dims (e.g. heads) work.
    scores = torch.matmul(query, key.transpose(-2, -1)) / sqrt(dim_k)
    # Normalise over the key axis so each query's weights sum to 1.
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, value)

In [25]:
# Run the function on the same query/key/value tensors used in the
# step-by-step cells.
scaled_dot_prod_attn(query, key, value)

tensor([[[-0.0520, -0.9558, -0.3015,  ...,  1.0993,  0.0760, -0.9341],
         [ 0.7414, -0.3783,  0.5303,  ..., -0.7745, -0.7273,  0.5555],
         [-1.2966,  1.1547,  0.8478,  ..., -1.0117, -1.2110,  0.5836],
         ...,
         [-0.0670,  2.0728, -0.0061,  ...,  0.1919,  0.5736, -0.3407],
         [ 0.0669,  1.9025, -0.3545,  ..., -0.4316, -0.2005, -1.1991],
         [-1.3470, -0.3744,  1.9080,  ...,  1.5640, -0.3873,  0.9812]]],
       grad_fn=<BmmBackward0>)

<hr>

## Multi-headed Attention