In [1]:
!pip install bertviz

Collecting bertviz
  Downloading bertviz-1.4.0-py3-none-any.whl (157 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m157.6/157.6 kB[0m [31m13.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: bertviz
Successfully installed bertviz-1.4.0
[0m

In [2]:
from math import sqrt

from bertviz.transformers_neuron_view import BertModel
from bertviz.neuron_view import show
from bertviz import head_view

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoConfig, AutoModel

# Encoder

In [3]:
ckpt = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = BertModel.from_pretrained(ckpt)

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

100%|██████████| 433/433 [00:00<00:00, 395155.27B/s]
100%|██████████| 440473133/440473133 [00:18<00:00, 23714675.11B/s]


In [4]:
text = 'time flikes like an arrow'
show(model, 'bert', tokenizer, text, display_mode='light', layer=0, head=8)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

## Scaled Dot-Product Attention

In [5]:
inputs = tokenizer(text, return_tensors='pt', add_special_tokens=False)
inputs

{'input_ids': tensor([[ 2051, 13109, 17339,  2015,  2066,  2019,  8612]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}

In [6]:
config = AutoConfig.from_pretrained(ckpt)
config

BertConfig {
  "_name_or_path": "bert-base-uncased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.20.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [7]:
token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
token_emb

Embedding(30522, 768)

In [8]:
inputs_emb = token_emb(inputs.input_ids)
inputs_emb.size()  # bs, seq len, hidden_dim

torch.Size([1, 7, 768])

In [9]:
query = key = value = inputs_emb
dim_k = key.size(-1)
scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
scores.size()

torch.Size([1, 7, 7])

In [10]:
scores

tensor([[[ 2.7903e+01,  1.0821e+00, -1.8298e+00, -5.4426e-01,  8.3116e-01,
           8.0224e-01,  1.2557e+00],
         [ 1.0821e+00,  2.7273e+01,  4.3572e-01, -3.3205e-01,  5.9216e-01,
           5.9905e-02, -8.5838e-02],
         [-1.8298e+00,  4.3572e-01,  2.9096e+01,  3.4328e-01,  2.3644e-01,
          -5.5304e-01, -2.6988e-02],
         [-5.4426e-01, -3.3205e-01,  3.4328e-01,  2.8426e+01, -2.6737e+00,
           1.6197e+00,  1.1794e-01],
         [ 8.3116e-01,  5.9216e-01,  2.3644e-01, -2.6737e+00,  2.8379e+01,
          -4.3337e-02, -1.0949e+00],
         [ 8.0224e-01,  5.9905e-02, -5.5304e-01,  1.6197e+00, -4.3337e-02,
           2.6695e+01, -2.3412e-01],
         [ 1.2557e+00, -8.5838e-02, -2.6988e-02,  1.1794e-01, -1.0949e+00,
          -2.3412e-01,  2.8479e+01]]], grad_fn=<DivBackward0>)

In [11]:
weights = F.softmax(scores, dim=-1)
weights.sum(1)

tensor([[1., 1., 1., 1., 1., 1., 1.]], grad_fn=<SumBackward1>)

In [12]:
attn_outputs = torch.bmm(weights, value)
attn_outputs.shape

torch.Size([1, 7, 768])

In [13]:
def scaled_dot_prod_attn(query, key, value):
    dim_k = query.size(-1)
    scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
    weights = F.softmax(scores, dim=-1)
    return torch.bmm(weights, value)

In [14]:
scaled_dot_prod_attn(query, key, value)

tensor([[[ 0.0260,  0.3956, -0.3039,  ..., -0.4523, -1.6973,  2.2849],
         [-1.3075,  1.1852,  1.9812,  ...,  0.6425,  1.3456, -0.0645],
         [-0.9984, -1.3476, -0.8125,  ..., -0.0326,  0.0667, -1.3985],
         ...,
         [ 1.5969, -2.4145,  0.8188,  ..., -1.1758,  0.3238,  0.1500],
         [ 0.4395, -0.1840,  0.0310,  ...,  0.5084, -0.2811, -1.2429],
         [ 0.4268,  1.1873, -1.3772,  ..., -0.3203,  0.1119,  0.3547]]],
       grad_fn=<BmmBackward0>)

<hr>

## Multi-headed Attention

In [15]:
class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        self.q = nn.Linear(embed_dim, head_dim)
        self.k = nn.Linear(embed_dim, head_dim)
        self.v = nn.Linear(embed_dim, head_dim)
        
    def forward(self, hidden_state):
        attn_outputs = scaled_dot_prod_attn(
            self.q(hidden_state), self.k(hidden_state), self.v(hidden_state)
        )
        return attn_outputs

In [16]:
class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = config.hidden_size
        num_heads = config.num_attention_heads
        head_dim = embed_dim // num_heads
        self.heads = nn.ModuleList(
            [AttentionHead(embed_dim, head_dim) for _ in range(num_heads)]
        )
        self.out_layer = nn.Linear(embed_dim, embed_dim)
        
    def forward(self, hidden_state):
        x = torch.cat([h(hidden_state) for h in self.heads], dim=-1)
        return self.out_layer(x)

In [17]:
multihead_attn = MultiHeadAttention(config)
attn_output = multihead_attn(inputs_emb)
attn_output.shape

torch.Size([1, 7, 768])

In [18]:
model = AutoModel.from_pretrained(ckpt, output_attentions=True)

sent_a = 'time flies like an arrow'
sent_b = 'fruit flikes like a banana'

Downloading:   0%|          | 0.00/420M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [19]:
viz_inputs = tokenizer(sent_a, sent_b, return_tensors='pt')
viz_inputs

{'input_ids': tensor([[  101,  2051, 10029,  2066,  2019,  8612,   102,  5909, 13109, 17339,
          2015,  2066,  1037, 15212,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [20]:
attention = model(**viz_inputs).attentions
attention[0].shape

torch.Size([1, 12, 15, 15])

In [21]:
sent_b_start = (viz_inputs.token_type_ids == 0).sum(1)
tokens = tokenizer.convert_ids_to_tokens(viz_inputs.input_ids[0])
tokens

['[CLS]',
 'time',
 'flies',
 'like',
 'an',
 'arrow',
 '[SEP]',
 'fruit',
 'fl',
 '##ike',
 '##s',
 'like',
 'a',
 'banana',
 '[SEP]']

In [22]:
head_view(attention, tokens, sent_b_start, heads=[8])

<IPython.core.display.Javascript object>

<hr>

## Feed Forward Layer

In [23]:
class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.GELU(),
            nn.Linear(config.intermediate_size, config.hidden_size),
            nn.Dropout(config.hidden_dropout_prob)
        )

    def forward(self, x):
        return self.layers(x)

In [24]:
feed_forward = FeedForward(config)
ff_outputs = feed_forward(attn_outputs)
ff_outputs.shape

torch.Size([1, 7, 768])

In [25]:
class TransformerEncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(config.hidden_size)
        self.layer_norm_2 = nn.LayerNorm(config.hidden_size)
        self.attention = MultiHeadAttention(config)
        self.feed_forward = FeedForward(config)
        
    def forward(self, x):
        hidden_state = self.layer_norm_1(x)
        x += self.attention(hidden_state)  # apply w/skip connection
        x += self.feed_forward(self.layer_norm_2(x))  # apply w/skip connection
        return x

In [26]:
encoder_layer= TransformerEncoderLayer(config)
inputs_emb.shape, encoder_layer(inputs_emb).shape

(torch.Size([1, 7, 768]), torch.Size([1, 7, 768]))

In [27]:
inputs_emb[:, :, 0]

tensor([[ 0.2841, -1.2986, -1.1782, -0.7829,  1.5485,  0.6202,  0.4860]],
       grad_fn=<SelectBackward0>)

In [28]:
encoder_layer(inputs_emb)[:, :, 0]

tensor([[ 0.5194, -1.2731, -1.3564, -0.8815,  1.4916,  0.7036,  0.5418]],
       grad_fn=<SelectBackward0>)

<hr>

## Positional Embeddings

In [29]:
inputs

{'input_ids': tensor([[ 2051, 13109, 17339,  2015,  2066,  2019,  8612]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}

In [30]:
class PositionalEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_emb = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout()
        
    def forward(self, input_ids):
        seq_len = input_ids.size(1)
        position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
        token_emb = self.token_emb(input_ids)
        position_emb = self.position_emb(position_ids)
        embeddings = token_emb + position_emb
        embeddings = self.layer_norm(embeddings)
        return self.dropout(embeddings)

In [31]:
embed_layer = PositionalEmbeddings(config)
embed_layer(inputs.input_ids).shape

torch.Size([1, 7, 768])

In [32]:
embed_layer(inputs.input_ids)[:, :, 0]

tensor([[-0.6853, -1.5872,  2.1090,  2.9029,  0.0000, -1.2191,  0.5237]],
       grad_fn=<SelectBackward0>)

In [33]:
class TransformerEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embeds = PositionalEmbeddings(config)
        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        
    def forward(self, x):
        embeds = self.embeds(x)
        for layer in self.layers:
            x = layer(embeds)
        return x

In [34]:
encoder = TransformerEncoder(config)
encoder(inputs.input_ids).shape

torch.Size([1, 7, 768])

<hr>

## Adding Classification Head

In [37]:
class TransformerForSequenceClassification(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encoder = TransformerEncoder(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        
    def forward(self, x):
        x = self.encoder(x)[:, 0, :]  # only want hidden state of first ([CLS]) token
        x = self.dropout(x)
        return self.classifier(x)

In [38]:
config.num_labels = 3

enc_cls = TransformerForSequenceClassification(config)
enc_cls(inputs.input_ids)

tensor([[-0.1449, -0.9349, -0.3019]], grad_fn=<AddmmBackward0>)