In [19]:
from math import sqrt

#from bertviz import head_view
#from bertviz.transformers_neuron_view import BertModel
#from bertviz.neuron_view import show
import torch
from torch import nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, AutoTokenizer

In [2]:
# Name of the pretrained checkpoint; reused below wherever a tokenizer, config
# or model is loaded with from_pretrained.
model_ckpt = 'bert-base-uncased'
# Tokenizer matching that checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
# bertviz neuron-view model; disabled together with the bertviz imports.
#model = BertModel.from_pretrained(model_ckpt)

In [3]:
# Example sentence used throughout the attention walkthrough.
text = 'time flies like an arrow'
# bertviz neuron-view call, left disabled (the bertviz imports are commented out).
#show(
#    model, 
#    'bert',
#    tokenizer,
#    text,
#    display_mode='light',
#    layer=0,
#    head=8)

In [4]:
# Tokenize into PyTorch tensors. add_special_tokens=False keeps the sequence to
# the sentence's own tokens (no [CLS]/[SEP]), so there is one id per word here.
inputs = tokenizer(text, return_tensors='pt', add_special_tokens=False)
inputs.input_ids

tensor([[ 2051, 10029,  2066,  2019,  8612]])

In [5]:
# Load the checkpoint's configuration (vocab_size, hidden_size, ...).
config = AutoConfig.from_pretrained(model_ckpt)
# A freshly constructed embedding table with the checkpoint's dimensions.
# NOTE: nn.Embedding is randomly initialised, so these are not the pretrained
# BERT embeddings, and the numbers shown below vary between runs unless a seed
# is set first.
token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
token_emb

Embedding(30522, 768)

In [6]:
# Look up one embedding vector per token id.
input_embeds = token_emb(inputs.input_ids)
input_embeds.size() # (batch, tokens, hidden): one batch of our 5 tokens, each a 768-dim vector

torch.Size([1, 5, 768])

In [7]:
# Plain self-attention on the raw embeddings: query, key and value are all the
# same tensor (no learned projections yet).
query = key = value = input_embeds
# Size of the last (feature) dimension of the keys.
dim_k = key.size(-1)
# Dot every query with every key, then divide by sqrt(dim_k) (the "scaled"
# part of scaled dot-product attention).
scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
scores.size()

torch.Size([1, 5, 5])

In [8]:
# Inspect the raw scores. The diagonal dominates because query and key are the
# same tensor, so each diagonal entry is a token's embedding dotted with itself.
scores

tensor([[[28.5866, -1.1095, -0.8704,  1.4623,  0.8429],
         [-1.1095, 25.5054,  0.3590,  2.7437,  0.1270],
         [-0.8704,  0.3590, 29.2243,  0.9720, -0.6094],
         [ 1.4623,  2.7437,  0.9720, 28.7171,  0.2261],
         [ 0.8429,  0.1270, -0.6094,  0.2261, 28.0545]]],
       grad_fn=<DivBackward0>)

In [9]:
# Normalise each row of scores (over the key positions, the last axis) into
# attention weights.
weights = F.softmax(scores, dim=-1)
weights

tensor([[[1.0000e+00, 1.2680e-13, 1.6106e-13, 1.6598e-12, 8.9342e-13],
         [2.7624e-12, 1.0000e+00, 1.1997e-11, 1.3024e-10, 9.5132e-12],
         [8.5119e-14, 2.9104e-13, 1.0000e+00, 5.3725e-13, 1.1050e-13],
         [1.4568e-12, 5.2473e-12, 8.9225e-13, 1.0000e+00, 4.2318e-13],
         [1.5211e-12, 7.4348e-13, 3.5598e-13, 8.2088e-13, 1.0000e+00]]],
       grad_fn=<SoftmaxBackward0>)

In [10]:
# Sanity check: each row of attention weights should sum to 1 after softmax.
weights.sum(dim=-1)

tensor([[1., 1., 1., 1., 1.]], grad_fn=<SumBackward1>)

In [11]:
# Attention output: for every query position, the weighted sum of the value
# vectors using the attention weights.
attn_outputs = torch.bmm(weights, value)
attn_outputs.shape

torch.Size([1, 5, 768])

In [12]:
# wrap it up
def scaled_dot_product_attention(query, key, value):
    """Compute scaled dot-product attention.

    The last two axes of each argument are treated as (sequence, features);
    any leading axes are batch-like and are handled by ``torch.matmul``
    broadcasting, so unbatched 2-D, batched 3-D and multi-head 4-D inputs
    are all accepted.

    Args:
        query: Tensor of shape (..., q_len, dim_k).
        key: Tensor of shape (..., k_len, dim_k).
        value: Tensor of shape (..., k_len, dim_v).

    Returns:
        Tensor of shape (..., q_len, dim_v): for each query position, the
        softmax-weighted sum of the value vectors.
    """
    dim_k = key.size(-1)
    # matmul with negative axis indices is equivalent to bmm + transpose(1, 2)
    # for 3-D inputs, but does not hard-code the tensor rank.
    scores = torch.matmul(query, key.transpose(-2, -1)) / sqrt(dim_k)
    # Normalise over the key positions so each query's weights sum to 1.
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, value)

In [13]:
class AttentionHead(nn.Module):
    """A single self-attention head.

    Projects the incoming hidden states to queries, keys and values with
    three independent linear layers (``q``, ``k``, ``v``) and combines them
    with ``scaled_dot_product_attention``.

    Args:
        embed_dim: Feature size of the incoming hidden states.
        head_dim: Feature size of this head's projections and output.
    """

    def __init__(self, embed_dim, head_dim):
        super().__init__()
        # Register the projections in q, k, v order: this fixes both the
        # state_dict key order and the order in which random init is drawn.
        for proj_name in ("q", "k", "v"):
            setattr(self, proj_name, nn.Linear(embed_dim, head_dim))

    def forward(self, hidden_state):
        """Return the attention output for ``hidden_state`` attending to itself."""
        queries = self.q(hidden_state)
        keys = self.k(hidden_state)
        values = self.v(hidden_state)
        return scaled_dot_product_attention(queries, keys, values)

In [16]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention built from independent ``AttentionHead``s.

    Each head projects the input down to ``hidden_size // num_attention_heads``
    features; the head outputs are concatenated along the feature axis and
    mixed by a final linear layer.

    Args:
        config: Object exposing ``hidden_size`` and ``num_attention_heads``.

    Raises:
        ValueError: If ``hidden_size`` is not a multiple of
            ``num_attention_heads``.
    """

    def __init__(self, config):
        super().__init__()
        embed_dim = config.hidden_size
        n_heads = config.num_attention_heads
        # Without this check the floor division below silently truncates, the
        # concatenated heads come out narrower than embed_dim, and the failure
        # only shows up later as a shape mismatch inside output_linear.
        if embed_dim % n_heads != 0:
            raise ValueError(
                f"hidden_size ({embed_dim}) must be a multiple of "
                f"num_attention_heads ({n_heads})"
            )
        head_dim = embed_dim // n_heads
        self.heads = nn.ModuleList(
            [AttentionHead(embed_dim, head_dim) for _ in range(n_heads)]
        )
        self.output_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_state):
        """Run every head on ``hidden_state`` and mix the concatenated outputs."""
        # Concatenate along the feature axis: n_heads * head_dim == embed_dim.
        x = torch.cat([head(hidden_state) for head in self.heads],
                      dim=-1)
        x = self.output_linear(x)
        return x

In [18]:
# Build the multi-head layer from the BERT config and run it on the embeddings.
multihead_attn = MultiHeadAttention(config)
# NOTE: `attn_output` (singular) is a different variable from `attn_outputs`
# (plural), which holds the earlier un-projected attention result.
attn_output = multihead_attn(input_embeds)
attn_output.size() # (batch, tokens, hidden size)

torch.Size([1, 5, 768])

In [21]:
# Load the pretrained model; output_attentions=True makes the forward pass also
# return the attention weights.
mod = AutoModel.from_pretrained(model_ckpt, output_attentions=True)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [22]:
# Two sentences that both contain the words "flies like".
sent_a = 'time flies like an arrow'
sent_b = 'fruit flies like a banana'

# Encode the two sentences together as one sentence-pair input.
viz_inputs = tokenizer(sent_a, sent_b, return_tensors='pt')
# Forward pass; keep only the attention weights from the model output.
attention = mod(**viz_inputs).attentions
# Number of tokens whose token_type_id is 0. Presumably type 0 marks the first
# segment, making this the index where sentence B starts (verify against the
# tokenizer documentation).
sent_b_start = (viz_inputs.token_type_ids == 0).sum(dim=1)
# Token strings for the first (and only) sequence in the batch.
tokens = tokenizer.convert_ids_to_tokens(viz_inputs.input_ids[0])

In [23]:
#head_view(attention, tokens, sent_b_start, heads=[8])

In [24]:
class FeedForward(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Linear -> Dropout.

    Args:
        config: Object exposing ``hidden_size`` (input and output width),
            ``intermediate_size`` (width between the two linear layers) and
            ``hidden_dropout_prob`` (dropout probability on the output).
    """

    def __init__(self, config):
        super().__init__()
        hidden, inner = config.hidden_size, config.intermediate_size
        # Expand to the intermediate width, then project back.
        self.lin1 = nn.Linear(hidden, inner)
        self.lin2 = nn.Linear(inner, hidden)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        """Apply the block to the last dimension of ``x``; the shape is preserved."""
        expanded = self.gelu(self.lin1(x))
        return self.dropout(self.lin2(expanded))

In [25]:
# Run the multi-head attention output through the feed-forward layer.
# This uses `attn_output` (the MultiHeadAttention result) rather than
# `attn_outputs`, the un-projected single-head tensor from the earlier
# step-by-step walkthrough. Both are (1, 5, 768), so the size shown is the same.
ff = FeedForward(config)
ff_outputs = ff(attn_output)
ff_outputs.size()

torch.Size([1, 5, 768])