In [1]:
import torch as t
from torch.nn.functional import gelu, softmax, dropout
from torch import einsum
from einops import rearrange, reduce, repeat
import bert_tests

import numpy as np

# Step 1: Raw Attention Scores

Let’s first make the part of the attention layer that computes the raw attention scores (pre-softmax) between each pair of tokens. Write a function to do this, called `raw_attention_pattern`.

Type signature of `raw_attention_pattern`:
```
token_activations: Tensor[batch_size, seq_length, hidden_size (which is 768)], 
num_heads: int, 
project_query: function(Tensor[..., 768] -> Tensor[..., num_heads*head_size]), 
project_key: function(Tensor[..., 768] -> Tensor[..., num_heads*head_size])
-> 
Tensor[batch_size, head_num, key_token: seq_length, query_token: seq_length]
```
If the dimensions of project_query and project_key functions don’t make sense, reread the general guidelines above. 

- Gotcha #1: remember the “divide by sqrt(head_size)” from the Illustrated Transformer!
- Gotcha #2: "raw attention pattern" means pre-softmax (otherwise known as "attention score").

Test your function with `bert_tests.test_attention_pattern_fn`.

In [2]:
def raw_attention_pattern(
        token_activations,  # Tensor[batch_size, seq_length, hidden_size],
        num_heads,
        project_query,      # callable, (Tensor[..., hidden_size]) -> Tensor[..., num_heads*head_size],
        project_key,        # callable, (Tensor[..., hidden_size]) -> Tensor[..., num_heads*head_size]
): # -> Tensor[batch_size, head_num, key_token: seq_length, query_token: seq_length]:
    """Compute pre-softmax attention scores for every (key, query) token pair.

    Args:
        token_activations: 3-D tensor [batch_size, seq_length, hidden_size].
        num_heads: number of attention heads the projections are split into.
        project_query: callable producing the query projection; its last
            dimension must be divisible by ``num_heads``.
        project_key: callable producing the key projection, same width as the
            query projection.

    Returns:
        Tensor [batch_size, num_heads, seq_length (key), seq_length (query)]
        holding ``q . k / sqrt(head_size)`` for each head.

    Raises:
        ValueError: if the projected width is not divisible by ``num_heads``.
    """
    queries = project_query(token_activations)
    keys = project_key(token_activations)

    batch_size, seq_length, projected_size = queries.shape
    if projected_size % num_heads != 0:
        raise ValueError(
            f"projected size {projected_size} is not divisible by num_heads {num_heads}"
        )
    # The scale must use the per-head width of the *projected* vectors; the
    # input hidden size only coincides with it when the projection keeps the
    # width unchanged.
    head_size = projected_size // num_heads

    # Split heads: [b, s, h*d] -> [b, s, h, d] -> [b, h, s, d]
    queries = queries.reshape(batch_size, seq_length, num_heads, head_size).permute(0, 2, 1, 3)
    keys = keys.reshape(batch_size, seq_length, num_heads, head_size).permute(0, 2, 1, 3)

    # Output axes are ordered (key, query), matching the documented return shape.
    scores = einsum('bhql,bhkl->bhkq', queries, keys)
    return scores / head_size ** 0.5
# Run the bert_tests check against raw_attention_pattern.
bert_tests.test_attention_pattern_fn(raw_attention_pattern)


attention pattern raw MATCH!!!!!!!!
 SHAPE (2, 12, 3, 3) MEAN: -0.01208 STD: 0.1096 VALS [0.05786 0.0006444 0.0845 0.01998 -0.02516 -0.05008 -0.0319 -0.04448 0.09316 0.06063...]


In [12]:
def bert_attention(
        token_activations, #: Tensor[batch_size, seq_length, hidden_size (768)],
        num_heads: int,
        attention_pattern, #: Tensor[batch_size,num_heads, seq_length, seq_length],
        project_value, # nn.Module, (Tensor[..., 768]) -> Tensor[..., 768],
        project_output, # nn.Module, (Tensor[..., 768]) -> Tensor[..., 768],
): #-> Tensor[batch_size, seq_length, hidden_size]))
    """Mix value vectors according to a raw (pre-softmax) attention pattern.

    The pattern is indexed [batch, head, key, query]; it is normalised over
    the key axis, used to average the per-head value vectors, and the heads
    are then concatenated and passed through ``project_output``.
    """
    values = project_value(token_activations)
    batch_size, seq_length, width = values.shape

    # Split heads: [b, s, h*d] -> [b, h, s, d]
    per_head = values.reshape(batch_size, seq_length, num_heads, width // num_heads)
    per_head = per_head.permute(0, 2, 1, 3)

    # Normalise over dim 2 (the key axis) so each query's weights sum to 1.
    probs = softmax(attention_pattern, dim=2)
    mixed = einsum('bhkq,bhkd->bhqd', probs, per_head)

    # Merge heads: [b, h, q, d] -> [b, q, h*d]
    merged = mixed.transpose(1, 2).flatten(start_dim=2)
    return project_output(merged)
# Run the bert_tests check against bert_attention.
bert_tests.test_attention_fn(bert_attention)

tensor([[0.3720, 0.4221, 0.4499],
        [0.4408, 0.3297, 0.3393],
        [0.1872, 0.2482, 0.2107]])
torch.Size([2, 12, 3, 3]) torch.Size([2, 12, 3, 64])
torch.Size([2, 12, 3, 64])
torch.Size([2, 3, 768])
attention MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.004708 STD: 0.1157 VALS [-0.1737 -0.04187 -0.03834 0.02038 0.0409 -0.07649 -0.1073 0.04715 -0.04157 -0.01852...]


In [17]:
class MultiHeadedSelfAttention(t.nn.Module):
    """Multi-head self-attention layer built from four linear projections.

    Args:
        num_heads: number of attention heads.
        hidden_size: width of the incoming and outgoing token activations.
        head_size: width of each head's query/key/value vectors. Defaults to
            64, the value that was previously hard-coded, so existing callers
            get identical layer shapes.
    """

    def __init__(self, num_heads, hidden_size, head_size=64):
        super().__init__()
        attention_hidden_size = num_heads * head_size
        self.num_heads = num_heads
        self.head_size = head_size
        # Creation order is kept as before so seeded initialisation is unchanged.
        self.project_query = t.nn.Linear(hidden_size, attention_hidden_size)
        self.project_key = t.nn.Linear(hidden_size, attention_hidden_size)
        self.project_value = t.nn.Linear(hidden_size, attention_hidden_size)
        self.project_output = t.nn.Linear(attention_hidden_size, hidden_size)

    def forward(self, x):
        """Apply self-attention to ``x`` and return the projected result.

        Delegates to ``raw_attention_pattern`` for the pre-softmax scores and
        to ``bert_attention`` for the softmax, value mixing and output
        projection.
        """
        attention_pattern = raw_attention_pattern(x, self.num_heads, self.project_query, self.project_key)
        head_out = bert_attention(x, self.num_heads, attention_pattern, self.project_value, self.project_output)
        return head_out

# Run the bert_tests check against the MultiHeadedSelfAttention module.
bert_tests.test_bert_attention(MultiHeadedSelfAttention)

tensor([[0.3454, 0.3640, 0.3404],
        [0.3531, 0.3300, 0.3391],
        [0.3015, 0.3060, 0.3205]], grad_fn=<SelectBackward0>)
torch.Size([2, 12, 3, 3]) torch.Size([2, 12, 3, 64])
torch.Size([2, 12, 3, 64])
torch.Size([2, 3, 768])
bert MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.001554 STD: 0.1736 VALS [-0.08316 -0.09165 -0.03188 -0.03013 0.1001 0.09549 -0.1046 0.07742 0.0424 0.05553...]


In [19]:
def bert_mlp(
        token_activations,  #: torch.Tensor[batch_size,seq_length,768],
        linear_1,           #: nn.Module,
        linear_2,           #: nn.Module
):  # -> torch.Tensor[batch_size, seq_length, 768]
    """Feed-forward sub-layer: ``linear_1``, then GELU, then ``linear_2``."""
    hidden = linear_1(token_activations)
    activated = gelu(hidden)
    return linear_2(activated)
# Run the bert_tests check against bert_mlp.
bert_tests.test_bert_mlp(bert_mlp)

bert mlp MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.0001934 STD: 0.1044 VALS [-0.1153 0.1189 -0.0813 0.1021 0.0296 0.06182 0.0341 0.1446 0.2622 -0.08507...]


In [21]:
class BertMLP(t.nn.Module):
    """Module wrapper around ``bert_mlp`` that owns the two linear layers.

    ``linear1`` maps input_size -> intermediate_size and ``linear2`` maps
    intermediate_size back to input_size.
    """

    def __init__(self, input_size, intermediate_size):
        super().__init__()
        # Built in this order so parameter registration and RNG use match.
        expand = t.nn.Linear(in_features=input_size, out_features=intermediate_size)
        contract = t.nn.Linear(in_features=intermediate_size, out_features=input_size)
        self.linear1 = expand
        self.linear2 = contract

    def forward(self, x):
        """Run ``x`` through linear1, GELU and linear2 via ``bert_mlp``."""
        first, second = self.linear1, self.linear2
        return bert_mlp(x, first, second)

In [32]:
class LayerNorm(t.nn.Module):
    """Layer normalisation over the last dimension with learnable scale/shift.

    Args:
        size_of_normalized_dim: length of the last dimension of the input.
        eps: constant added to the variance before the square root for
            numerical stability. Defaults to 1e-5, the value that was
            previously hard-coded in ``forward``.
    """

    def __init__(self, size_of_normalized_dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = t.nn.Parameter(t.ones(size_of_normalized_dim))
        self.bias = t.nn.Parameter(t.zeros(size_of_normalized_dim))

    def forward(self, x):
        """Normalise ``x`` along its last axis, then apply weight and bias."""
        centered = x - x.mean(dim=-1, keepdim=True)
        # Biased (population) variance.
        variance = x.var(dim=-1, keepdim=True, unbiased=False)
        normalized = centered / (variance + self.eps).sqrt()
        return normalized * self.weight + self.bias

# Run the bert_tests check against the LayerNorm module.
bert_tests.test_layer_norm(LayerNorm)

layer norm MATCH!!!!!!!!
 SHAPE (20, 10) MEAN: 9.537e-09 STD: 1.003 VALS [-1.352 1.454 -0.5328 1.027 1.477 -0.1402 -1.172 -0.5576 -0.7403 0.5375...]


In [35]:
class BertBlock(t.nn.Module):
    """One transformer encoder block: self-attention then MLP, each followed
    by a residual connection and layer normalisation.

    Args:
        hidden_size: width of the token activations.
        intermediate_size: width of the MLP's hidden layer.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the MLP output.
    """

    def __init__(self, hidden_size, intermediate_size, num_heads, dropout):
        super().__init__()
        self.attention = MultiHeadedSelfAttention(num_heads, hidden_size)
        self.mlp = BertMLP(hidden_size, intermediate_size)
        self.layernorm1 = LayerNorm(hidden_size)
        self.layernorm2 = LayerNorm(hidden_size)
        self.dropout = t.nn.Dropout(dropout)

    def forward(self, x):
        """Return the block output for activations ``x`` (same shape as ``x``)."""
        out1 = self.layernorm1(self.attention(x) + x)
        # Dropout is applied to the MLP output only, before the residual add.
        # Wrapping the sum would also zero parts of the residual stream during
        # training. Eval-mode results are unchanged.
        out2 = self.layernorm2(self.dropout(self.mlp(out1)) + out1)
        return out2

# Run the bert_tests check against the BertBlock module.
bert_tests.test_bert_block(BertBlock)

tensor([[0.3195, 0.3199, 0.3158],
        [0.3304, 0.3499, 0.3461],
        [0.3501, 0.3302, 0.3381]], grad_fn=<SelectBackward0>)
torch.Size([2, 12, 3, 3]) torch.Size([2, 12, 3, 64])
torch.Size([2, 12, 3, 64])
torch.Size([2, 3, 768])
bert MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: 2.07e-09 STD: 1 VALS [0.007132 -0.04372 0.6502 -0.5972 -1.097 0.7267 0.1275 -0.6035 -0.2226 0.2145...]


In [38]:
class Embedding(t.nn.Module):
    """Lookup table with one learnable ``embed_size`` vector per vocabulary id.

    The table is stored as the parameter ``embedding`` with shape
    (vocab_size, embed_size), initialised from a standard normal.
    """

    def __init__(self, vocab_size, embed_size):
        super().__init__()
        initial_table = t.randn(vocab_size, embed_size)
        self.embedding = t.nn.Parameter(initial_table)

    def forward(self, x):
        """Index the table with ``x``; the result has shape ``x.shape + (embed_size,)``
        when ``x`` is an integer tensor."""
        table = self.embedding
        return table[x]

# Run the bert_tests check against the Embedding module.
bert_tests.test_embedding(Embedding)

embedding MATCH!!!!!!!!
 SHAPE (2, 3, 5) MEAN: -0.2095 STD: 0.8819 VALS [-0.8435 0.0199 -0.7648 1.023 -1.396 -0.8435 0.0199 -0.7648 1.023 -1.396...]


In [80]:
def bert_embedding(
        input_ids, # [batch, seqlen]
        token_type_ids, # [batch, seqlen]
        position_embedding: Embedding,
        token_embedding: Embedding,
        token_type_embedding: Embedding,
        layer_norm: LayerNorm,
        dropout: t.nn.Dropout):
    """Unimplemented stub for the BERT input-embedding step.

    The body is only ``pass``: every call currently returns None and none of
    the arguments are used. The commented-out lines below are an unfinished
    draft and are never executed.

    TODO: implement. Presumably the result should combine the token,
    token-type and position embeddings and then apply ``layer_norm`` and
    ``dropout`` -- confirm against the exercise description and its test.
    """
    pass
    # --- unfinished draft (not executed) ---
    # NOTE(review): `t.arange(end=batch)` enumerates batch indices; position
    #   indices presumably need to span `seqlen` instead -- confirm.
    # NOTE(review): the Embedding class in this notebook performs its lookup
    #   in forward(), so these objects should be called (`token_embedding(input_ids)`)
    #   rather than subscripted.
    # batch, seqlen = input_ids.shape
    # position_idxs = t.arange(end=batch)
    # token_embeds = token_embedding[input_ids]
    # token_type_embeds = token_type_embedding[token_type_ids]
    #
    # position_embeds = position_embedding[]