<center><h1>BERT</h1> </center>

<center><p><a href="http://arxiv.org/abs/1810.04805">BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding</a></p></center>

<img src="https://production-media.paperswithcode.com/methods/new_BERT_Overall.jpg" width="1000"/>

[Code: https://github.com/codertimo/BERT-pytorch](https://github.com/codertimo/BERT-pytorch)


In [1]:
import math

import torch
from torch import nn

# Model Architecture

## Embeddings

### Token Embedding

In [2]:
class TokenEmbedding(nn.Embedding):
    """Learned lookup table from token ids to vectors; row 0 is the padding_idx."""

    def __init__(self, vocab_size, embed_size=768):
        """
        Build the token embedding table.
        :param vocab_size: number of rows, i.e. the size of the source vocabulary.
        :param embed_size: dimensionality of every embedding vector. Default: 768
        """
        super().__init__(
            num_embeddings=vocab_size,
            embedding_dim=embed_size,
            padding_idx=0,
        )

### Positional Embedding

In [3]:
class PositionalEmbedding(nn.Module):
    def __init__(self, embed_size=768, max_len=512):
        """
        Fixed (non-learned) sinusoidal positional embeddings, precomputed for
        every position up to ``max_len`` and stored as a buffer.
        :param embed_size: the size of each embedding vector (may be odd). Default: 768
        :param max_len: max length of the sequence. Default: 512
        """
        super().__init__()
        self.max_len = max_len

        pe = torch.zeros(max_len, embed_size)
        # (max_len, 1) column of positions.
        position = torch.arange(0, max_len).unsqueeze(1).float()
        # 10000^(-2i / embed_size) for every even feature index 2i.
        div_term = torch.exp(
            torch.arange(0, embed_size, 2).float()
            * -(math.log(10000.0) / embed_size)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd embed_size there is one fewer odd column than even column;
        # trim div_term so the assignment does not fail with a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: embed_size // 2])
        # Leading axis of size 1 so the table broadcasts over the batch.
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        Positional Embeddings for the first ``x.size(1)`` positions.
        :param x: (batch_size, seq_length, ...) — only its second dimension is used;
            BERTEmbedding passes the (batch_size, seq_length) token-id tensor.
        :return: (1, seq_length, embed_size), not requiring grad.
        :raises ValueError: if seq_length exceeds max_len.
        """
        seq_length = x.size(1)
        if seq_length > self.max_len:
            # Slicing would silently return only max_len rows; fail loudly instead.
            raise ValueError(
                f"sequence length {seq_length} exceeds max_len {self.max_len}"
            )
        return self.pe[:, :seq_length].requires_grad_(False)

### Segment Embedding

We differentiate the sentences in two ways. First, we separate them with a special token ([SEP]). Second, we add a learned embedding to every token indicating whether it belongs to sentence A or sentence B.

* 2 is the number of segments in the input (sentence A and sentence B, labelled 1 and 2).
* 1 additional index (0) is reserved for padding, so the embedding table has 3 rows.

In [4]:
class SegmentEmbedding(nn.Embedding):
    """Learned embedding of the segment label (sentence A / B); row 0 is the padding_idx."""

    def __init__(self, embed_size=768):
        """
        Build the segment embedding table.
        :param embed_size: dimensionality of every embedding vector. Default: 768
        """
        num_segments = 2
        # One extra row for the padding label 0.
        super().__init__(
            num_embeddings=num_segments + 1,
            embedding_dim=embed_size,
            padding_idx=0,
        )

### BERT Embeddings

In [5]:
class BERTEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_size=768, dropout=0.1, max_len=512):
        """
        Input representation of BERT: sum of token, position and segment
        embeddings, followed by dropout.
        :param vocab_size: dictionary size of the source vocabulary.
        :param embed_size: the size of each embedding vector. Default: 768
        :param dropout: probability of an element to be zeroed. Default: 0.1
        :param max_len: max length of the sequence. Default: 512
        """
        super().__init__()
        self.token = TokenEmbedding(vocab_size, embed_size)
        self.position = PositionalEmbedding(embed_size, max_len)
        self.segment = SegmentEmbedding(embed_size)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, sequence, segment_label):
        """
        Embed a batch of token ids.
        :param sequence: (batch_size, seq_length) token ids
        :param segment_label: (batch_size, seq_length) segment labels
        :return: (batch_size, seq_length, embed_size)
        """
        token_vectors = self.token(sequence)
        # Positional table has a leading axis of size 1 and broadcasts over the batch.
        position_vectors = self.position(sequence)
        segment_vectors = self.segment(segment_label)
        summed = token_vectors + position_vectors + segment_vectors
        return self.dropout(summed)

## Sublayers

### Multi-Head Attention

In [6]:
class MultiHeadedAttention(nn.Module):
    def __init__(self, h=12, d_model=768, dropout=0.1):
        """
        Multi-Head Self-Attention Mechanism.
        :param h: the number of heads; must divide d_model. Default: 12
        :param d_model: the size of each embedding vector. Default: 768
        :param dropout: probability of an attention weight to be zeroed. Default: 0.1
        """
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        # Input projections for query, key and value, in that order.
        self.linear_layers = nn.ModuleList(
            [nn.Linear(d_model, d_model) for _ in range(3)]
        )
        self.output_layer = nn.Linear(d_model, d_model)
        # Attention probabilities from the most recent forward pass (for inspection).
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """
        Project the inputs, attend independently in each head, then merge heads.
        :param query: (batch_size, seq_length, d_model)
        :param key: (batch_size, seq_length, d_model)
        :param value: (batch_size, seq_length, d_model)
        :param mask: (batch_size, 1, seq_length) or None
        :return: (batch_size, seq_length, d_model)
        """
        batch_size = query.size(0)
        if mask is not None:
            # (batch, 1, seq) -> (batch, 1, 1, seq) so it broadcasts over heads.
            mask = mask.unsqueeze(1)

        def split_heads(projection, tensor):
            # (batch, seq, d_model) -> (batch, h, seq, d_k)
            projected = projection(tensor)
            return projected.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)

        q_heads, k_heads, v_heads = (
            split_heads(projection, tensor)
            for projection, tensor in zip(self.linear_layers, (query, key, value))
        )

        context, self.attn = attention(
            q_heads, k_heads, v_heads, mask=mask, dropout=self.dropout
        )

        # (batch, h, seq, d_k) -> (batch, seq, h * d_k)
        merged = context.transpose(1, 2).contiguous()
        merged = merged.view(batch_size, -1, self.h * self.d_k)
        return self.output_layer(merged)

### Feed Forward Networks

In [7]:
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model=768, d_ff=3072, dropout=0.1):
        """
        Two-layer MLP applied independently at every position:
        W_2(dropout(GELU(W_1 x))).
        :param d_model: the size of each embedding vector. Default: 768
        :param d_ff: dimension of the inner layer. Default: 3072
        :param dropout: probability of an element to be zeroed. Default: 0.1
        """
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = GELU()

    def forward(self, x):
        """
        :param x: (batch_size, seq_length, d_model)
        :return: (batch_size, seq_length, d_model)
        """
        inner = self.activation(self.w_1(x))
        inner = self.dropout(inner)
        return self.w_2(inner)

### Sublayer Connection

In [8]:
class SublayerConnection(nn.Module):
    def __init__(self, d_model=768, dropout=0.1):
        """
        Residual connection around a sub-layer, using pre-layer normalization:
        x + dropout(sublayer(norm(x))).
        :param d_model: the size of each embedding vector. Default: 768
        :param dropout: probability of an element to be zeroed. Default: 0.1
        """
        super().__init__()
        self.norm = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """
        :param x: (batch_size, seq_length, d_model)
        :param sublayer: callable mapping a normalized tensor to one of the same
            shape (attention or feed-forward network).
        :return: (batch_size, seq_length, d_model)
        """
        residual = x
        transformed = sublayer(self.norm(x))
        return residual + self.dropout(transformed)

## BERT

### Transformer Block

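A standard Transformer encoder layer, structurally identical to the one in the original Transformer except that this implementation uses pre-layer normalization (see `SublayerConnection`).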

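We use a dropout probability of $0.1$ on all layers. In this implementation dropout is also applied to the output of the feed-forward $W_2$ linear layer (inside the residual sublayer connection), and once more to the output of each Transformer block.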

In [9]:
class TransformerBlock(nn.Module):
    def __init__(self, d_model=768, d_ff=3072, n_heads=12, dropout=0.1):
        """
        One encoder layer: self-attention then feed-forward, each wrapped in a
        residual sublayer connection, with a final dropout on the block output.
        :param d_model: the size of each embedding vector. Default: 768
        :param d_ff: dimension of the inner layer. Default: 3072
        :param n_heads: the number of heads. Default: 12
        :param dropout: probability of an element to be zeroed. Default: 0.1
        """
        super().__init__()
        self.self_attn = MultiHeadedAttention(n_heads, d_model, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.sublayers = nn.ModuleList(
            [SublayerConnection(d_model, dropout) for _ in range(2)]
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        """
        :param x: (batch_size, seq_length, d_model)
        :param mask: (batch_size, 1, seq_length)
        :return: (batch_size, seq_length, d_model)
        """
        attention_sublayer, feed_forward_sublayer = self.sublayers
        # Self-attention: query, key and value are all the (normalized) input.
        x = attention_sublayer(x, lambda normed: self.self_attn(normed, normed, normed, mask))
        x = feed_forward_sublayer(x, self.feed_forward)
        return self.dropout(x)

### BERT

Transformer Encoder.

In [10]:
class BERT(nn.Module):
    def __init__(
        self,
        vocab_size,
        hidden=768,
        n_layers=12,
        n_heads=12,
        dropout=0.1,
        max_len=512,
    ):
        """
        BERT encoder: input embeddings followed by a stack of Transformer encoder layers.
        :param vocab_size: dictionary size of the vocabulary.
        :param hidden: the size of each embedding vector. Default: 768
        :param n_layers: the number of encoder layers (Transformer blocks). Default: 12
        :param n_heads: the number of heads. Default: 12
        :param dropout: probability of an element to be zeroed. Default: 0.1
        :param max_len: max length of the sequence. Default: 512
        """
        super().__init__()
        self.hidden = hidden
        self.embedding = BERTEmbedding(vocab_size, hidden, dropout, max_len)
        # In all cases the feed-forward/filter size is 4H,
        # i.e. 3072 for H = 768 and 4096 for H = 1024.
        feed_forward_size = 4 * hidden
        self.transformer_blocks = nn.ModuleList(
            TransformerBlock(hidden, feed_forward_size, n_heads, dropout)
            for _ in range(n_layers)
        )

    def forward(self, x, segment_label):
        """
        :param x: (batch_size, seq_length) token ids; id 0 is padding.
        :param segment_label: (batch_size, seq_length)
        :return: (batch_size, seq_length, hidden)
        """
        # (batch_size, seq_length) -> (batch_size, 1, seq_length); True marks
        # non-padding tokens. It is broadcast over query positions downstream.
        padding_mask = (x > 0).unsqueeze(1)

        hidden_states = self.embedding(x, segment_label)
        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, padding_mask)
        return hidden_states

## Pre-training BERT

### Next Sentence Prediction (NSP)

2-class classification model: is_next, is_not_next

In [11]:
class NextSentencePrediction(nn.Module):
    def __init__(self, hidden=768):
        """
        Next Sentence Prediction head: a 2-way (is_next / is_not_next) classifier
        applied to the hidden state at the first sequence position.
        :param hidden: the size of each embedding vector. Default: 768
        """
        super().__init__()
        self.linear = nn.Linear(hidden, 2)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """
        Next Sentence Prediction.
        :param x: (batch_size, seq_length, hidden)
        :return: (batch_size, 2) log-probabilities
        """
        # Only position 0 is classified (presumably the [CLS] token; depends on the tokenizer).
        return self.softmax(self.linear(x[:, 0]))

### Masked Language Model (MLM)

N-class classification over the vocabulary at every position, with $N$ = `vocab_size`.

In [12]:
class MaskedLanguageModel(nn.Module):
    def __init__(self, hidden, vocab_size):
        """
        Masked Language Model head: projects each position's hidden state onto
        the vocabulary and returns log-probabilities.
        :param hidden: the size of each embedding vector (required, no default).
        :param vocab_size: dictionary size of the vocabulary.
        """
        super().__init__()
        self.linear = nn.Linear(hidden, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """
        :param x: (batch_size, seq_length, hidden)
        :return: (batch_size, seq_length, vocab_size) log-probabilities
        """
        logits = self.linear(x)
        return self.softmax(logits)

### BERT Language Model

Next Sentence Prediction Model + Masked Language Model

In [13]:
class BERTLM(nn.Module):
    def __init__(self, bert: BERT, vocab_size):
        """
        BERT encoder with the two pre-training heads (NSP and MLM) on top.
        :param bert: the BERT encoder; its ``hidden`` attribute sizes both heads.
        :param vocab_size: dictionary size of the vocabulary.
        """
        super().__init__()
        self.bert = bert
        self.next_sentence = NextSentencePrediction(self.bert.hidden)
        self.mask_lm = MaskedLanguageModel(self.bert.hidden, vocab_size)

    def forward(self, x, segment_label):
        """
        :param x: (batch_size, seq_length)
        :param segment_label: (batch_size, seq_length)
        :return: NSP log-probs (batch_size, 2), MLM log-probs (batch_size, seq_length, vocab_size)
        """
        encoded = self.bert(x, segment_label)
        nsp_log_probs = self.next_sentence(encoded)
        mlm_log_probs = self.mask_lm(encoded)
        return nsp_log_probs, mlm_log_probs

# Utils

## Attention

In [14]:
def attention(query, key, value, mask=None, dropout=None):
    """
    Scaled Dot-Product Attention: softmax(Q K^T / sqrt(d_k)) V.
    :param query: (batch_size, head_num, seq_length, d_k)
    :param key: (batch_size, head_num, seq_length, d_k)
    :param value: (batch_size, head_num, seq_length, d_k)
    :param mask: broadcastable to the score shape, e.g. (batch_size, 1, 1, seq_length);
        positions where mask == 0 receive zero attention weight. Default: None
    :param dropout: a callable (e.g. an ``nn.Dropout`` module) applied to the
        attention probabilities, or None to skip it. Default: None
    :return: (batch_size, head_num, seq_length, d_k), (batch_size, head_num, seq_length, seq_length)
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Use the most negative finite value of the scores' dtype instead of a
        # hard-coded -1e9, which cannot be represented in float16 and makes
        # masked_fill raise under half precision. A fully masked row still gets
        # uniform, finite probabilities after softmax.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    p_attn = scores.softmax(dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

## Layer Normalization

In [15]:
class LayerNorm(nn.Module):
    """
    Layer normalization over the last dimension with learnable gain ``a_2`` and
    bias ``b_2``. Unlike ``nn.LayerNorm``, this uses the unbiased standard
    deviation and adds ``eps`` to the std rather than to the variance.
    """

    def __init__(self, features, eps=1e-6):
        """
        :param features: size of the last (normalized) dimension.
        :param eps: added to the standard deviation for numerical stability. Default: 1e-6
        """
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        spread = x.std(-1, keepdim=True) + self.eps
        return self.a_2 * centered / spread + self.b_2

## GELU

GELU (Gaussian Error Linear Unit) activation function, tanh approximation:

$$\text{GELU}(x)=0.5x\left[1+\tanh\left(\sqrt{\dfrac2\pi}(x+0.044715x^3)\right)\right]$$

BERT uses GELU instead of ReLU.

In [16]:
class GELU(nn.Module):
    """GELU activation, tanh approximation: 0.5 x (1 + tanh(sqrt(2/pi) (x + 0.044715 x^3)))."""

    def forward(self, x):
        inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))

# Summary

## Data

In [17]:
# Seed the RNG so the dummy inputs (and the model initialisations below) are reproducible.
SEED = 42
torch.manual_seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

vocab_size = 30000
batch_size = 2
seq_length = 500

# Token ids start at 1 because id 0 is the padding index (no padding in this dummy batch).
bert_input = torch.randint(1, vocab_size, size=(batch_size, seq_length)).to(device)
# Segment labels are 1 (sentence A) or 2 (sentence B); 0 is reserved for padding.
segment_label = torch.randint(1, 3, size=(batch_size, seq_length)).to(device)

## BERT<sub>BASE</sub>

### Model Architecture

In [18]:
from torchkeras import summary

# BERT-BASE: 12 layers, hidden size 768, 12 attention heads.
bert_base_config = dict(
    hidden=768,
    n_layers=12,
    n_heads=12,
    dropout=0.1,
    max_len=512,
)
net = BERT(vocab_size=vocab_size, **bert_base_config).to(device)
summary(net, input_data_args=[bert_input, segment_label]);

--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
TokenEmbedding-1                      [-1, 500, 768]           23,040,000
PositionalEmbedding-2                 [-1, 500, 768]                    0
SegmentEmbedding-3                    [-1, 500, 768]                2,304
Dropout-4                             [-1, 500, 768]                    0
LayerNorm-5                           [-1, 500, 768]                1,536
Linear-6                              [-1, 500, 768]              590,592
Linear-7                              [-1, 500, 768]              590,592
Linear-8                              [-1, 500, 768]              590,592
Dropout-9                         [-1, 12, 500, 500]                    0
Linear-10                             [-1, 500, 768]              590,592
Dropout-11                            [-1, 500, 768]                    0
LayerNorm-12                         

### Pre-training

In [19]:
# Attach the NSP and MLM pre-training heads to the encoder built above.
pre_train_net = BERTLM(bert=net, vocab_size=vocab_size).to(device)

summary(pre_train_net, input_data_args=[bert_input, segment_label]);

--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
TokenEmbedding-1                      [-1, 500, 768]           23,040,000
PositionalEmbedding-2                 [-1, 500, 768]                    0
SegmentEmbedding-3                    [-1, 500, 768]                2,304
Dropout-4                             [-1, 500, 768]                    0
LayerNorm-5                           [-1, 500, 768]                1,536
Linear-6                              [-1, 500, 768]              590,592
Linear-7                              [-1, 500, 768]              590,592
Linear-8                              [-1, 500, 768]              590,592
Dropout-9                         [-1, 12, 500, 500]                    0
Linear-10                             [-1, 500, 768]              590,592
Dropout-11                            [-1, 500, 768]                    0
LayerNorm-12                         

## BERT<sub>LARGE</sub>

### Model Architecture

In [20]:
# BERT-LARGE: 24 layers, hidden size 1024, 16 attention heads.
bert_large_config = dict(
    hidden=1024,
    n_layers=24,
    n_heads=16,
    dropout=0.1,
    max_len=512,
)
net = BERT(vocab_size=vocab_size, **bert_large_config).to(device)

summary(net, input_data_args=[bert_input, segment_label]);

--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
TokenEmbedding-1                     [-1, 500, 1024]           30,720,000
PositionalEmbedding-2                [-1, 500, 1024]                    0
SegmentEmbedding-3                   [-1, 500, 1024]                3,072
Dropout-4                            [-1, 500, 1024]                    0
LayerNorm-5                          [-1, 500, 1024]                2,048
Linear-6                             [-1, 500, 1024]            1,049,600
Linear-7                             [-1, 500, 1024]            1,049,600
Linear-8                             [-1, 500, 1024]            1,049,600
Dropout-9                         [-1, 16, 500, 500]                    0
Linear-10                            [-1, 500, 1024]            1,049,600
Dropout-11                           [-1, 500, 1024]                    0
LayerNorm-12                         

### Pre-training

In [21]:
# Attach the NSP and MLM pre-training heads to the encoder built above.
pre_train_net = BERTLM(bert=net, vocab_size=vocab_size).to(device)

summary(pre_train_net, input_data_args=[bert_input, segment_label]);

--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
TokenEmbedding-1                     [-1, 500, 1024]           30,720,000
PositionalEmbedding-2                [-1, 500, 1024]                    0
SegmentEmbedding-3                   [-1, 500, 1024]                3,072
Dropout-4                            [-1, 500, 1024]                    0
LayerNorm-5                          [-1, 500, 1024]                2,048
Linear-6                             [-1, 500, 1024]            1,049,600
Linear-7                             [-1, 500, 1024]            1,049,600
Linear-8                             [-1, 500, 1024]            1,049,600
Dropout-9                         [-1, 16, 500, 500]                    0
Linear-10                            [-1, 500, 1024]            1,049,600
Dropout-11                           [-1, 500, 1024]                    0
LayerNorm-12                         