In [143]:
import torch as t
from torch import einsum
from einops import rearrange, reduce, repeat
import bert_tests
import math

In [41]:
def raw_attention_pattern(
        token_activations,  # Tensor[batch_size, seq_length, hidden_size(768)],
        num_heads,
        project_query,      # nn.Module, (Tensor[..., 768]) -> Tensor[..., 768],
        project_key,        # nn.Module, (Tensor[..., 768]) -> Tensor[..., 768]
): # -> Tensor[batch_size, head_num, key_token: seq_length, query_token: seq_length]:
    Q = project_query(token_activations)
    K = project_key(token_activations)
    Q = rearrange(Q, 'b seqlen (headnum headsize) -> b headnum seqlen headsize', headnum=num_heads)
    K = rearrange(K, 'b seqlen (headnum headsize) -> b headnum seqlen headsize', headnum=num_heads)
    headsize = K.shape[-1]
    dot_prod = einsum('bhql,bhkl-> bhkq', Q, K) / math.sqrt(headsize)
    return dot_prod

bert_tests.test_attention_pattern_fn(raw_attention_pattern)

attention pattern raw MATCH!!!!!!!!
 SHAPE (2, 12, 3, 3) MEAN: 0.005254 STD: 0.106 VALS [-0.05107 -0.06011 0.043 0.01762 -0.05089 0.07877 0.2465 0.2506 0.02985 0.01162...]


In [45]:
def bert_attention(
        token_activations, #: Tensor[batch_size, seq_length, hidden_size (768)],
        num_heads: int,
        attention_pattern, #: Tensor[batch_size,num_heads, seq_length, seq_length],
        project_value, #: function( (Tensor[..., 768]) -> Tensor[..., 768] ),
        project_output, #: function( (Tensor[..., 768]) -> Tensor[..., 768] )
): # -> Tensor[batch_size, seq_length, hidden_size]
    attention_pattern = attention_pattern.softmax(dim=2)
    V = project_value(token_activations)
    V = rearrange(V, 'b seqlen (headnum headsize) -> b headnum seqlen headsize', headnum=num_heads)
    attention = einsum('bhkq,bhkl->bhql', attention_pattern, V)
    attention = rearrange(attention, 'b headnum seqlen hiddensize -> b seqlen (headnum hiddensize)')
    ans = project_output(attention)
    return ans
bert_tests.test_attention_fn(bert_attention)

attention MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.003646 STD: 0.1161 VALS [0.01265 0.2171 -0.1186 0.001848 0.01751 -0.2293 -0.03623 0.01134 0.09451 -0.05312...]


In [113]:
class MultiHeadedSelfAttention(t.nn.Module):
    def __init__(self, num_heads: int, hidden_size: int):
        super().__init__()
        self.project_query = t.nn.Linear(hidden_size, hidden_size)
        self.project_key = t.nn.Linear(hidden_size, hidden_size)
        self.project_value = t.nn.Linear(hidden_size, hidden_size)
        self.project_output = t.nn.Linear(hidden_size, hidden_size)
        self.num_heads = num_heads
        self.hidden_size = hidden_size

    def forward(self, x): # Tensor[batch_size, seq_length, hidden_size]
        raw_attention = raw_attention_pattern(x, self.num_heads, self.project_query, self.project_key)
        attention = bert_attention(x, self.num_heads, raw_attention, self.project_value, self.project_output)
        return attention

bert_tests.test_bert_attention(MultiHeadedSelfAttention)

torch.Size([2, 3, 768]) torch.Size([2, 3, 768])
bert MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.001554 STD: 0.1736 VALS [-0.08316 -0.09165 -0.03188 -0.03013 0.1001 0.09549 -0.1046 0.07742 0.0424 0.05553...]


In [54]:
def bert_mlp(token_activations, #: torch.Tensor[batch_size,seq_length,768],
             linear_1: t.nn.Module, linear_2: t.nn.Module
             ): #-> torch.Tensor[batch_size, seq_length, 768]
    return linear_2(t.nn.functional.gelu(linear_1(token_activations)))

bert_tests.test_bert_mlp(bert_mlp)

bert mlp MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.003054 STD: 0.1041 VALS [0.1262 0.01134 0.06912 0.05845 0.06832 0.06498 -0.07017 -0.1155 -0.004871 0.2145...]


In [55]:
class BertMLP(t.nn.Module):
    def __init__(self, input_size: int, intermediate_size: int):
        super().__init__()
        self.linear1 = t.nn.Linear(input_size, intermediate_size)
        self.linear2 = t.nn.Linear(intermediate_size, input_size)

    def forward(self, x):
        return bert_mlp(x, self.linear1, self.linear2)

In [62]:
class LayerNorm(t.nn.Module):
    def __init__(self, normalized_dim: int):
        super().__init__()
        self.weight = t.nn.Parameter(t.ones(normalized_dim))
        self.bias = t.nn.Parameter(t.zeros(normalized_dim))

    def forward(self, input):
        input = input - input.mean(dim=-1, keepdim=True).detach()
        input = input / input.std(dim=-1, keepdim=True, unbiased=False)
        return input * self.weight + self.bias
bert_tests.test_layer_norm(LayerNorm)

layer norm MATCH!!!!!!!!
 SHAPE (20, 10) MEAN: -5.96e-10 STD: 1.003 VALS [-1.876 0.9704 -0.2068 0.07342 0.6658 1.202 -0.8645 0.569 -1.36 0.8267...]


In [66]:
class ResNet(t.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.m = module

    def forward(self, x):
        return self.m(x) + x

class BertBlock(t.nn.Module):
    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 num_heads: int,
                 dropout: float):
        super().__init__()
        self.layers = t.nn.Sequential(
            ResNet(MultiHeadedSelfAttention(num_heads, hidden_size)),
            LayerNorm(hidden_size),
            ResNet(t.nn.Sequential(BertMLP(hidden_size, intermediate_size), t.nn.Dropout(dropout))),
            LayerNorm(hidden_size)
        )

    def forward(self, x):
        return self.layers(x)
bert_tests.test_bert_block(BertBlock)

bert MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -4.139e-10 STD: 1 VALS [0.007132 -0.04372 0.6502 -0.5972 -1.097 0.7267 0.1275 -0.6035 -0.2226 0.2145...]


In [72]:
class Embedding(t.nn.Module):
    def __init__(self, vocab_size, embed_size):
        super().__init__()
        self.embedding = t.nn.Parameter(t.randn((vocab_size, embed_size)))

    def forward(self, x):
        return self.embedding[x]

bert_tests.test_embedding(Embedding)

embedding MATCH!!!!!!!!
 SHAPE (2, 3, 5) MEAN: -0.2095 STD: 0.8819 VALS [-0.8435 0.0199 -0.7648 1.023 -1.396 -0.8435 0.0199 -0.7648 1.023 -1.396...]


In addition to embedding the token itself, Bert also explicitly represents the token’s position in the sentence (“positional embedding”) and the token type (“token type embedding”).

So in total Bert stores three (learned) embedding matrices:
The token embedding matrix [28996, embedding_size]
The token_type embedding matrix [2, embedding_size]
The positional embedding matrix [512, embedding_size]

(Not all transformers learn their positional embeddings, some use hardcoded sinusoidal embeddings.)

To represent a token (word or word piece), it:
Looks up the input_id in the token_embedding matrix to get a vector [embedding_size]
Looks up the token_type of the token in the token_type embedding matrix. The token_type is always either 0 or 1, indicating whether the word belongs to “sentence A” or “sentence B”. This is relevant for tasks like paraphrasing (first phrasing, second phrasing), question answering (question, answer), or similar. For language modelling, this doesn’t matter at all – just assume all tokens have token_type = 0.
Looks up the position of the token in the sentence (e.g. “dog” has position 2, “is” has position 3) in the positional embedding matrix. The max sentence length is 512.
Adds up the three embeddings and applies layer norm, then dropout.


In [94]:
def bert_embedding(
        input_ids, # [batch, seqlen]
        token_type_ids, # [batch, seqlen]
        position_embedding: Embedding,
        token_embedding: Embedding,
        token_type_embedding: Embedding,
        layer_norm: LayerNorm,
        dropout: t.nn.Dropout):

    word_embeddings = token_embedding(input_ids) # [batch, embedding_size]
    type_embeddings = token_type_embedding(token_type_ids)
    positions = repeat(t.arange(0, input_ids.shape[-1], device=token_type_ids.device), 'n -> b n', b=input_ids.shape[0])
    position_embeddings = position_embedding(t.squeeze(positions))
    # summed_embedding = einsum('bse,bse,bse->bse', word_embeddings, type_embeddings, position_embeddings)
    summed_embedding = word_embeddings + type_embeddings + position_embeddings
    return dropout(layer_norm(summed_embedding))

bert_tests.test_bert_embedding_fn(bert_embedding)

bert embedding MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: 8.278e-10 STD: 1 VALS [-0.1558 -0.906 1.358 -0.1096 0.02568 -0.9749 -0.2617 0.05282 -2.021 -0.3563...]


max_position_embeddings is the maximum sentence length, 512 above
type_vocab_size is the number of token_types, 2 above
Use your Embedding to store the embedding matrices.
Initialise the embeddings in the order token, position, token_type.
Test your module with bert_tests.test_bert_embedding

In [103]:
class BertEmbedding(t.nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int, max_position_embeddings: int, type_vocab_size: int, dropout: float):
        super().__init__()
        self.token_embedding = Embedding(vocab_size, hidden_size)
        self.position_embedding = Embedding(max_position_embeddings, hidden_size)
        self.token_type_embedding = Embedding(type_vocab_size, hidden_size)
        self.layernorm = LayerNorm(hidden_size)
        self.dropout = t.nn.Dropout(dropout)

    def forward(self, input_ids, token_type_ids):
        return bert_embedding(input_ids, token_type_ids, self.position_embedding, self.token_embedding, self.token_type_embedding, self.layernorm, self.dropout)

bert_tests.test_bert_embedding(BertEmbedding)

bert embedding MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -7.14e-09 STD: 1 VALS [-1.084 -0.9279 0.962 -0.7941 -2.865 1.584 0.5704 1.182 0.2184 0.7135...]


Make a nn.Module called Bert, with the following methods:
__init__(
self, vocab_size: int, hidden_size: int,
max_position_embeddings: int, type_vocab_size: int,
dropout: float, intermediate_size: int, num_heads: int,
num_layers: int
)
forward(self, input_ids)

- hidden_size is the embedding size, 768
- dropout probability is the same for both the embedding and the BertMLP
- num_layers is the number of BertBlocks
- Assume token_type_ids are zero. Make sure your code is GPU-friendly, e.g. place token_type_ids on the same device as input_ids.
- Use the architecture diagram on page 1 to guide you.
- In addition to the N encoding blocks, we will add a Linear layer [hidden_size, hidden_size], GELU, LayerNorm, and an unembedding layer (Linear[hidden_size, vocab_size]) at the end.
- If you think the whole thing works, test it with bert_tests.test_bert.

Note: don’t actually compute the softmax at the end. This allows you to use log_softmax later instead, which preserves the differences between very small probabilities, such as 1e-30 and 1e-35 which would be destroyed by limited precision under a softmax.

In [116]:
class Bert(t.nn.Module):
    def __init__(self,
                 vocab_size: int,
                 hidden_size: int,
                 max_position_embeddings: int,
                 type_vocab_size: int,
                 dropout: float,
                 intermediate_size: int,
                 num_heads: int,
                 num_layers: int):
        super().__init__()
        self.embed = BertEmbedding(vocab_size, hidden_size, max_position_embeddings, type_vocab_size, dropout)
        self.layers = t.nn.Sequential(
            *[BertBlock(hidden_size, intermediate_size, num_heads, dropout) for _ in range(num_layers)],
            t.nn.Linear(hidden_size, hidden_size),
            t.nn.GELU(),
            LayerNorm(hidden_size),
            t.nn.Linear(hidden_size, vocab_size)
            )

    def forward(self, input_ids):
        token_type_ids = t.zeros_like(input_ids, device=input_ids.device)
        ans = self.embed(input_ids, token_type_ids)
        ans2 = self.layers(ans)
        print(ans.shape, ans2.shape)
        return ans2

bert_tests.test_bert(Bert)

torch.Size([1, 4, 768]) torch.Size([1, 4, 768])
torch.Size([1, 4, 768]) torch.Size([1, 4, 768])
torch.Size([1, 4, 768]) torch.Size([1, 4, 768])
torch.Size([1, 4, 768]) torch.Size([1, 4, 768])
torch.Size([1, 4, 768]) torch.Size([1, 4, 768])
torch.Size([1, 4, 768]) torch.Size([1, 4, 768])
torch.Size([1, 4, 768]) torch.Size([1, 4, 768])
torch.Size([1, 4, 768]) torch.Size([1, 4, 768])
torch.Size([1, 4, 768]) torch.Size([1, 4, 768])
torch.Size([1, 4, 768]) torch.Size([1, 4, 768])
torch.Size([1, 4, 768]) torch.Size([1, 4, 768])
torch.Size([1, 4, 768]) torch.Size([1, 4, 768])
torch.Size([1, 4, 768]) torch.Size([1, 4, 28996])
bert MATCH!!!!!!!!
 SHAPE (1, 4, 28996) MEAN: 0.003031 STD: 0.5765 VALS [-0.5742 -0.4321 0.1186 -0.7165 -0.5262 0.4967 1.223 0.3165 -0.3247 -0.5717...]


In [127]:
my_bert = Bert(
    vocab_size=28996, hidden_size=768, max_position_embeddings=512,
    type_vocab_size=2, dropout=0.1, intermediate_size=3072,
    num_heads=12, num_layers=12
)
pretrained_bert = bert_tests.get_pretrained_bert()

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


embedding.token_embedding.weight embed.token_embedding.embedding embedding.token_embedding.weight
embedding.position_embedding.weight embed.position_embedding.embedding embedding.position_embedding.weight
embedding.token_type_embedding.weight embed.token_type_embedding.embedding embedding.token_type_embedding.weight
embedding.layer_norm.weight embed.layernorm.weight embed.layernorm. 
embedding.layer_norm.bias embed.layernorm.bias embed.layernorm. 
transformer.0.layer_norm.weight layers.0.layers.0.m.project_query.weight transformer.0.layer_norm.weight
transformer.0.layer_norm.bias layers.0.layers.0.m.project_query.bias transformer.0.layer_norm.bias
transformer.0.attention.pattern.project_query.weight layers.0.layers.0.m.project_key.weight transformer.0.attention.pattern.project_query.weight
transformer.0.attention.pattern.project_query.bias layers.0.layers.0.m.project_key.bias transformer.0.attention.pattern.project_query.bias
transformer.0.attention.pattern.project_key.weight layers.0.

In [154]:
import re
def convert(s):
    # embedding blocks
    key = re.sub(r"^embedding", r"embed", s)
    key = re.sub(r"^embed.(\w+_embedding).weight", r"embed.\1.embedding", key)
    key = re.sub(r"^embed.layer_norm.(\w+)$", r"embed.layernorm.\1", key)

    # layers
    key = re.sub(r"^transformer.(\d+)", r"layers.\1", key)
    key = re.sub(r"^layers.(\d+).layer_norm", r"layers.\1.layers.1", key)
    key = re.sub(r"^layers.(\d+).attention(.pattern)?", r"layers.\1.layers.0.m", key)
    key = re.sub(r"^layers.(\d+).residual.mlp(\d+)", r"layers.\1.layers.2.m.0.linear\2", key)
    key = re.sub(r"^layers.(\d+).residual.layer_norm", r"layers.\1.layers.3", key)
    key = re.sub(r"project_out", r"project_output", key)

    # end bit
    key = re.sub(r"^lm_head.mlp", r"layers.12", key)
    key = re.sub(r"^lm_head.unembedding", r"layers.15", key)
    key = re.sub(r"^lm_head.layer_norm", r"layers.14", key)
    return key

for (n1, p1), (n2, p2) in zip(pretrained_bert.named_parameters(), my_bert.named_parameters()):
    if n1.startswith('lm_head.unembedding.weight'):
        print(n1, n2, convert(n1))

mapped_params = {convert(k): v for k, v in pretrained_bert.state_dict().items()
                 if not k.startswith('classification_head')}
my_bert.load_state_dict(mapped_params)
bert_tests.test_same_output(my_bert, pretrained_bert, tol=0.3)

torch.Size([10, 20, 768]) torch.Size([10, 20, 768])
torch.Size([10, 20, 768]) torch.Size([10, 20, 768])
torch.Size([10, 20, 768]) torch.Size([10, 20, 768])
torch.Size([10, 20, 768]) torch.Size([10, 20, 768])
torch.Size([10, 20, 768]) torch.Size([10, 20, 768])
torch.Size([10, 20, 768]) torch.Size([10, 20, 768])
torch.Size([10, 20, 768]) torch.Size([10, 20, 768])
torch.Size([10, 20, 768]) torch.Size([10, 20, 768])
torch.Size([10, 20, 768]) torch.Size([10, 20, 768])
torch.Size([10, 20, 768]) torch.Size([10, 20, 768])
torch.Size([10, 20, 768]) torch.Size([10, 20, 768])
torch.Size([10, 20, 768]) torch.Size([10, 20, 768])
torch.Size([10, 20, 768]) torch.Size([10, 20, 28996])
comparing Berts MATCH!!!!!!!!
 SHAPE (10, 20, 28996) MEAN: -2.796 STD: 2.486 VALS [-4.243 -4.423 -4.44 -4.375 -4.438 -4.2 -4.534 -4.422 -4.317 -4.508...]
