# BERT

In [1]:
import torch
from torch import nn
import d2l

## BERT Encoder

![jupyter](../images/12/bert-input.svg)

In [3]:
#@save
class BERTEncoder(nn.Module):
    """Encoder that sums token, segment and positional embeddings and then
    runs the result through a stack of `d2l.EncoderBlock` layers.

    Args:
        vocab_size: Number of rows in the token embedding table.
        num_hiddens: Embedding width; also passed to every encoder block.
        norm_shape: Forwarded unchanged to each `d2l.EncoderBlock`.
        ffn_num_hiddens: Forwarded unchanged to each `d2l.EncoderBlock`.
        num_heads: Forwarded unchanged to each `d2l.EncoderBlock`.
        num_layers: Number of encoder blocks to stack.
        dropout: Forwarded unchanged to each `d2l.EncoderBlock`.
        max_len: Number of positions in the learnable positional
            embedding table; inputs are expected to be no longer than this.
    """

    def __init__(self, vocab_size, num_hiddens,
                 norm_shape, ffn_num_hiddens, num_heads,
                 num_layers, dropout, max_len=1000):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
        # Two rows only: segment ids are expected to be 0 or 1.
        self.segment_embedding = nn.Embedding(2, num_hiddens)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            # NOTE(review): the four leading `num_hiddens` arguments and the
            # trailing `True` follow the positional signature of
            # `d2l.EncoderBlock`, which is defined outside this file --
            # presumably key/query/value/hidden sizes and `use_bias`; confirm
            # against the installed d2l version.
            self.blks.add_module(
                f"{i}",
                d2l.EncoderBlock(num_hiddens, num_hiddens, num_hiddens,
                                 num_hiddens, norm_shape, ffn_num_hiddens,
                                 num_heads, dropout, True))
        # In BERT, positional embeddings are learnable, thus we create a
        # parameter of positional embeddings that are long enough
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, num_hiddens))

    def forward(self, tokens, segments, valid_lens):
        """Encode a batch of token ids.

        Args:
            tokens: Integer tensor of token ids, indexed into
                `token_embedding`.
            segments: Integer tensor of segment ids with the same shape as
                `tokens`, indexed into `segment_embedding`.
            valid_lens: Passed through unchanged to every encoder block.

        Returns:
            Tensor of shape (batch size, sequence length, `num_hiddens`).
        """
        # Shape of `X` remains unchanged in the following code snippet:
        # (batch size, max sequence length, `num_hiddens`)
        X = self.token_embedding(tokens) + self.segment_embedding(segments)
        # Slice the parameter itself rather than `.data`: `.data` is detached
        # from the autograd graph, which would leave the positional
        # embeddings without gradients and therefore never trained.
        X = X + self.pos_embedding[:, :X.shape[1], :]
        for blk in self.blks:
            X = blk(X, valid_lens)
        return X

## Masked Language Modeling