# BERT Source Code Implementation and Walkthrough (PyTorch)

# About This Article

[Paper](https://arxiv.org/abs/1810.04805): https://arxiv.org/abs/1810.04805

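The BERT paper mostly describes how the model is trained and says relatively little about the architecture itself. That is understandable, because the architecture is simple: it is just a stack of Transformer encoder layers. Even so, many readers cannot picture exactly what BERT looks like without seeing code, so this article builds a BERT model and trains it with the two tasks described in the paper, MLM (masked language modeling) and NSP (next sentence prediction).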

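This article assumes you already know the Transformer, so it uses PyTorch's built-in Transformer modules (`nn.Transformer` and its building block `nn.TransformerEncoderLayer`) directly.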

Related articles:

[A detailed guide to nn.Transformer in PyTorch, with a black-box explanation of the Transformer](https://blog.csdn.net/zhaohongfei_358/article/details/126019181): https://blog.csdn.net/zhaohongfei_358/article/details/126019181

[A line-by-line analysis and implementation of the Transformer, with a German-to-English translation project](https://blog.csdn.net/zhaohongfei_358/article/details/126085246): https://blog.csdn.net/zhaohongfei_358/article/details/126085246

# Environment Setup

Import the packages used in this article:

```python
import math
import tqdm
import random

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset
```

```python
torch.__version__
```

```
'1.11.0+cu115'
```

# BERT Model Definition

The original BERT is pretrained on two tasks, so we leave the datasets and the training loop for later and define the BERT model first.

## BERT Embedding

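In the Transformer, a token is encoded as token embedding + position embedding. BERT adds a third embedding, the segment embedding, which encodes which segment of the text a token belongs to. In the original paper, BERT's input is a pair of sentences, and this embedding tells the model whether a token comes from the first sentence or the second. Let's define these three embeddings first: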

```python
class TokenEmbedding(nn.Embedding):
    def __init__(self, vocab_size, embed_size):
        super().__init__(vocab_size, embed_size, padding_idx=0)
```

Token Embedding is simply an `nn.Embedding`, the same as in the Transformer. `padding_idx=0` reserves token id 0 for padding.
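What `padding_idx=0` buys us: the row for token id 0 starts out as zeros and receives no gradient, so padding positions never update the embedding table. A small sketch to check this (the vocabulary size and embedding dimension are arbitrary):

```python
import torch
from torch import nn

# Same construction as TokenEmbedding: id 0 is reserved for padding
emb = nn.Embedding(num_embeddings=10, embedding_dim=4, padding_idx=0)

ids = torch.tensor([[3, 5, 0, 0]])  # two real tokens followed by two padding tokens
emb(ids).sum().backward()

pad_row = emb.weight[0].detach()  # embedding vector of the padding token
pad_grad = emb.weight.grad[0]     # its gradient after the backward pass
print(pad_row)   # all zeros
print(pad_grad)  # all zeros: padding never updates the table
```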

```python
class PositionalEmbedding(nn.Module):

    def __init__(self, d_model, max_len=512):
        super().__init__()

        # Compute the positional encodings once in log space.
        # pe is registered as a buffer below: it is saved with the model but never trained.
        pe = torch.zeros(max_len, d_model).float()

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Shape [1, max_len, d_model] so that it broadcasts over the batch dimension
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: [batch_size, seq_len] token ids; only the sequence length is used
        return self.pe[:, :x.size(1)]
```

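Position Embedding is also the same as in the Transformer: a fixed sinusoidal encoding. This is a simplification relative to the official BERT release, which learns its position embeddings (up to 512 positions) as trainable parameters.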
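To stay closer to the official release, the fixed table can be swapped for a trainable one. A minimal sketch, assuming the same interface as `PositionalEmbedding` above (the class name `LearnedPositionalEmbedding` is made up for this example). It takes token ids and returns `[1, seq_len, d_model]`, so it could replace `PositionalEmbedding` inside `BERTEmbedding` without other changes:

```python
import torch
from torch import nn


class LearnedPositionalEmbedding(nn.Module):
    """Hypothetical alternative to PositionalEmbedding: one trainable vector per position."""

    def __init__(self, d_model, max_len=512):
        super().__init__()
        self.max_len = max_len
        self.pe = nn.Embedding(max_len, d_model)

    def forward(self, x):
        # x: [batch_size, seq_len] token ids; only the length matters
        seq_len = x.size(1)
        if seq_len > self.max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_len {self.max_len}")
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)  # [1, seq_len]
        return self.pe(positions)  # [1, seq_len, d_model], broadcasts over the batch


pos = LearnedPositionalEmbedding(d_model=8, max_len=16)
x = torch.randint(1, 100, (2, 10))
print(pos(x).shape)  # torch.Size([1, 10, 8])
```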

```python
class SegmentEmbedding(nn.Embedding):
    def __init__(self, embed_size=512):
        super().__init__(3, embed_size, padding_idx=0)
```

Segment Embedding is also an `nn.Embedding`, but note that its vocabulary has only 3 entries: 0 is padding, 1 marks the first sentence, and 2 marks the second sentence.
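For a sentence pair, the labeling looks like the sketch below. The `[CLS] A [SEP] B [SEP]` layout follows the paper's input format; the ids chosen for `[CLS]`, `[SEP]` and padding, and the helper `build_pair`, are assumptions for illustration only:

```python
import torch

# Hypothetical special-token ids; a real vocabulary defines its own.
PAD, CLS, SEP = 0, 1, 2


def build_pair(tokens_a, tokens_b, max_len):
    """Build [CLS] A [SEP] B [SEP] plus segment labels (1 = sentence A, 2 = sentence B, 0 = padding)."""
    ids = [CLS] + tokens_a + [SEP] + tokens_b + [SEP]
    # [CLS], sentence A and its [SEP] belong to segment 1; sentence B and the final [SEP] to segment 2
    segments = [1] * (len(tokens_a) + 2) + [2] * (len(tokens_b) + 1)
    if len(ids) > max_len:
        raise ValueError("pair is longer than max_len")
    pad = max_len - len(ids)
    return torch.tensor([ids + [PAD] * pad]), torch.tensor([segments + [0] * pad])


ids, seg = build_pair([10, 11, 12], [20, 21], max_len=10)
print(ids.tolist())  # [[1, 10, 11, 12, 2, 20, 21, 2, 0, 0]]
print(seg.tolist())  # [[1, 1, 1, 1, 1, 2, 2, 2, 0, 0]]
```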

With these three embedding classes defined, the BERT Embedding class is straightforward: it adds the three embeddings together and applies dropout:

```python
class BERTEmbedding(nn.Module):

    def __init__(self, vocab_size, embed_size, dropout=0.1):
        """
        :param vocab_size: size of the token vocabulary
        :param embed_size: embedding (word vector) size
        :param dropout: dropout rate
        """
        super().__init__()
        self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)
        self.position = PositionalEmbedding(d_model=self.token.embedding_dim)
        self.segment = SegmentEmbedding(embed_size=self.token.embedding_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.embed_size = embed_size

    def forward(self, sequence, segment_label):
        # token and segment: [batch_size, seq_len, embed_size]; position: [1, seq_len, embed_size]
        x = self.token(sequence) + self.position(sequence) + self.segment(segment_label)
        return self.dropout(x)
```
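`PositionalEmbedding` returns a tensor of shape `[1, seq_len, embed_size]`, while the other two embeddings return `[batch_size, seq_len, embed_size]`; the sum works through broadcasting. A shape-only sketch with random tensors standing in for the three embeddings:

```python
import torch

batch_size, seq_len, d_model = 2, 12, 16

token = torch.randn(batch_size, seq_len, d_model)    # stands in for self.token(sequence)
position = torch.randn(1, seq_len, d_model)          # stands in for self.position(sequence)
segment = torch.randn(batch_size, seq_len, d_model)  # stands in for self.segment(segment_label)

x = token + position + segment  # position broadcasts along the batch dimension
print(x.shape)  # torch.Size([2, 12, 16])
```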

## BERT

```python
class BERT(nn.Module):

    def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
        """
        :param vocab_size: vocab_size of total words
        :param hidden: BERT model hidden size
        :param n_layers: numbers of Transformer blocks(layers)
        :param attn_heads: number of attention heads
        :param dropout: dropout rate
        """

        super().__init__()
        self.hidden = hidden
        self.n_layers = n_layers
        self.attn_heads = attn_heads

        # The paper uses a feed-forward hidden size of 4 * hidden_size
        feed_forward_hidden = hidden * 4

        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden)

        # Stack of TransformerEncoder layers; the paper states that BERT uses the GELU activation.
        # A fresh layer is built on every iteration so that each layer has its own weights.
        self.transformer_blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=hidden, nhead=attn_heads, dim_feedforward=feed_forward_hidden,
                                       dropout=dropout, activation=F.gelu, batch_first=True)
            for _ in range(n_layers)])

    def forward(self, x, segment_info):
        # Padding mask of shape [batch_size, seq_len]. nn.TransformerEncoderLayer expects
        # True at the positions attention must ignore, i.e. the padding tokens (id 0).
        padding_mask = (x == 0)

        # embedding the indexed sequence to sequence of vectors
        x = self.embedding(x, segment_info)

        # running over multiple transformer blocks
        for transformer in self.transformer_blocks:
            x = transformer(x, src_key_padding_mask=padding_mask)

        return x
```
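Building the stack as `nn.ModuleList([transformer_encoder for _ in range(n_layers)])`, with a single `transformer_encoder` created before the comprehension, is a subtle bug: a module inserted into `nn.ModuleList` several times is still one module, so every "layer" shares the same weights. A small check (sizes are arbitrary; `make_layer` is a helper made up for this example). `nn.TransformerEncoder` avoids the problem by deep-copying the layer it is given:

```python
from torch import nn

hidden, heads, n_layers = 32, 4, 3


def make_layer():
    return nn.TransformerEncoderLayer(d_model=hidden, nhead=heads, dim_feedforward=hidden * 4,
                                      batch_first=True)


def count(module):
    # Module.parameters() yields each shared parameter only once
    return sum(p.numel() for p in module.parameters())


layer = make_layer()
shared = nn.ModuleList([layer for _ in range(n_layers)])           # the same object n_layers times
separate = nn.ModuleList([make_layer() for _ in range(n_layers)])  # independent layers
stacked = nn.TransformerEncoder(make_layer(), num_layers=n_layers)  # deep copies, also independent

print(shared[0] is shared[1])                      # True
print(count(shared) == count(layer))               # True: only one layer's worth of weights
print(count(separate) == n_layers * count(layer))  # True
print(count(stacked) == n_layers * count(layer))   # True
```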

```python
# BERT-LARGE sizes from the paper: L=24 layers, H=1024 hidden, A=16 heads
bert = BERT(vocab_size=20000, hidden=1024, n_layers=24, attn_heads=16)
```

```python
x = torch.tensor([[1,2,3,4,5,5,4,3,2,1, 0,0]])
segment_info = torch.tensor([[1,1,1,1,1,2,2,2,2,2, 0,0]])
bert(x, segment_info).size()
```


```
torch.Size([1, 12, 1024])
```
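`nn.TransformerEncoderLayer` takes its padding mask as `src_key_padding_mask`, shape `[batch_size, seq_len]`, where `True` marks positions to ignore (the opposite convention of a `(x > 0)` mask). A quick sanity check of that convention, with arbitrary sizes: the outputs at the real positions should not depend on whatever sits in the padded slots.

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model, n_real, n_pad = 16, 5, 3

layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dim_feedforward=64, batch_first=True)
layer.eval()  # disable dropout so the two runs are comparable

real = torch.randn(1, n_real, d_model)
padded = torch.cat([real, torch.randn(1, n_pad, d_model)], dim=1)  # random junk in the pad slots

# True marks positions that attention must ignore
padding_mask = torch.tensor([[False] * n_real + [True] * n_pad])

out_real = layer(real)
out_padded = layer(padded, src_key_padding_mask=padding_mask)

# Outputs at the real positions match the unpadded run
print(torch.allclose(out_padded[:, :n_real], out_real, atol=1e-5))  # True
```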

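# Model

Below, the same BERT model is built again without `nn.TransformerEncoderLayer`: multi-head attention, layer normalization, GELU, the feed-forward network and the Transformer block are all written by hand.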

## Attention

```python
class Attention(nn.Module):
    """
    Compute 'Scaled Dot Product Attention'
    """

    def forward(self, query, key, value, mask=None, dropout=None):
        # scores: [batch_size, heads, seq_len, seq_len]
        scores = torch.matmul(query, key.transpose(-2, -1)) \
                 / math.sqrt(query.size(-1))

        # Positions where mask is 0/False (padding keys) get a very negative score,
        # so softmax gives them (numerically) zero weight
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        p_attn = F.softmax(scores, dim=-1)

        if dropout is not None:
            p_attn = dropout(p_attn)

        return torch.matmul(p_attn, value), p_attn
```
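To see what the mask does, the sketch below repeats the computation of `Attention.forward` (as a plain function, without dropout) on random tensors, using a mask built the same way as in `BERT.forward`. The last key is padding, so it receives zero weight, and every row of attention weights still sums to 1:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)


def scaled_dot_product(query, key, value, mask=None):
    # Same computation as Attention.forward above, without dropout
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    return p_attn @ value, p_attn


# [batch_size, heads, seq_len, d_k]
q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))

tokens = torch.tensor([[7, 3, 9, 0]])  # the last position is padding
# [1, 1, 4, 4], broadcast over the 2 heads
mask = (tokens > 0).unsqueeze(1).repeat(1, tokens.size(1), 1).unsqueeze(1)

out, p_attn = scaled_dot_product(q, k, v, mask)
print(out.shape)              # torch.Size([1, 2, 4, 8])
print(p_attn[..., -1].max())  # zero: no query attends to the padding key
print(p_attn.sum(-1))         # every row sums to 1
```

Note that this mask only hides padding *keys*; the padded query row still produces an output, which is simply ignored downstream.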

```python
class MultiHeadedAttention(nn.Module):
    """
    Take in model size and number of heads.
    """

    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0

        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h

        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
        self.output_linear = nn.Linear(d_model, d_model)
        self.attention = Attention()

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        #    [batch_size, seq_len, d_model] -> [batch_size, h, seq_len, d_k]
        query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
                             for l, x in zip(self.linear_layers, (query, key, value))]

        # 2) Apply attention on all the projected vectors in batch.
        x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear.
        #    [batch_size, h, seq_len, d_k] -> [batch_size, seq_len, d_model]
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)

        return self.output_linear(x)
```
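The reshapes in steps 1) and 3) are exact inverses: splitting into heads only reinterprets the last dimension as `h` slices of size `d_k`, and merging puts them back. A small sketch with arbitrary sizes:

```python
import torch

batch_size, seq_len, h, d_k = 2, 5, 4, 8
d_model = h * d_k
x = torch.randn(batch_size, seq_len, d_model)

# Split, as in step 1): [batch, seq, d_model] -> [batch, h, seq, d_k]
heads = x.view(batch_size, -1, h, d_k).transpose(1, 2)

# Merge, as in step 3): [batch, h, seq, d_k] -> [batch, seq, d_model]
merged = heads.transpose(1, 2).contiguous().view(batch_size, -1, h * d_k)

print(heads.shape)             # torch.Size([2, 4, 5, 8])
print(torch.equal(merged, x))  # True
# Head i works on feature slice [i * d_k, (i + 1) * d_k) of every token
print(torch.equal(heads[:, 1], x[..., d_k:2 * d_k]))  # True
```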

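## Embedding

The embedding layer is the `BERTEmbedding` class defined above (token + position + segment embeddings), so it is reused here unchanged.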

## Transformer

```python
class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    Note for code simplicity the norm is first as opposed to last.
    """

    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        return x + self.dropout(sublayer(self.norm(x)))
```
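The docstring's note matters: this class normalizes *before* the sublayer (pre-LN), whereas the original Transformer and the official BERT normalize *after* the residual addition (post-LN). Post-LN is also what `nn.TransformerEncoderLayer` does with its default `norm_first=False`. A minimal sketch of the post-LN ordering, using `nn.LayerNorm` for brevity (`PostNormSublayerConnection` is a hypothetical name, not part of the code above):

```python
import torch
from torch import nn


class PostNormSublayerConnection(nn.Module):
    """Hypothetical post-LN variant: LayerNorm(x + Dropout(sublayer(x)))."""

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # Residual addition first, then normalization
        return self.norm(x + self.dropout(sublayer(x)))


block = PostNormSublayerConnection(size=16, dropout=0.1)
ff = nn.Linear(16, 16)  # stands in for the attention or feed-forward sublayer
x = torch.randn(2, 5, 16)
y = block(x, ff)
print(y.shape)  # torch.Size([2, 5, 16])
```

Because the LayerNorm comes last, every output vector is normalized (zero mean per token with the default affine parameters), while in the pre-LN version the residual stream itself is never normalized.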

```python
class LayerNorm(nn.Module):
    "Construct a layernorm module (See citation for details)."

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
```
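This hand-written LayerNorm is close to, but not identical with, `nn.LayerNorm`: `x.std` uses the unbiased estimator (dividing by `d - 1`) and `eps` is added to the standard deviation, whereas `nn.LayerNorm` uses the biased variance with `eps` inside the square root. A sketch comparing the two (sizes arbitrary; `custom_layer_norm` restates the formula above with gain 1 and bias 0):

```python
import math
import torch
from torch import nn

torch.manual_seed(0)
d = 8
x = torch.randn(4, d)


def custom_layer_norm(x, eps=1e-6):
    # Formula of the LayerNorm class above with a_2 = 1 and b_2 = 0
    mean = x.mean(-1, keepdim=True)
    std = x.std(-1, keepdim=True)  # unbiased: divides by d - 1
    return (x - mean) / (std + eps)


ref = nn.LayerNorm(d, elementwise_affine=False)(x)  # biased variance, eps inside the sqrt
custom = custom_layer_norm(x)

# Apart from eps, the outputs differ by the constant factor sqrt((d - 1) / d), about 0.935 for d = 8
scale = math.sqrt((d - 1) / d)
print(torch.allclose(custom, ref * scale, atol=1e-3))  # True
```

With the learnable gain `a_2`, the model can absorb this constant factor, so the difference matters mainly when loading weights trained with one variant into the other.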

```python
class GELU(nn.Module):
    """
    Paper Section 3.4, last paragraph notice that BERT used the GELU instead of RELU
    """

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
```
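This class implements the tanh approximation of GELU, while the first version of `BERT` above passes `F.gelu`, which by default computes the exact form x·Φ(x) through the error function. The two agree closely; a sketch measuring the gap (range and number of points are arbitrary):

```python
import math
import torch
import torch.nn.functional as F

x = torch.linspace(-6, 6, steps=1201)

# tanh approximation, same formula as GELU.forward above
approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
# exact GELU: x * Phi(x), computed by F.gelu by default
exact = F.gelu(x)

max_err = (approx - exact).abs().max().item()
print(f"max abs difference on [-6, 6]: {max_err:.2e}")
```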

```python
class PositionwiseFeedForward(nn.Module):
    "Implements FFN equation."

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = GELU()

    def forward(self, x):
        return self.w_2(self.dropout(self.activation(self.w_1(x))))
```

```python
class TransformerBlock(nn.Module):
    """
    Bidirectional Encoder = Transformer (self-attention)
    Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
    """

    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
        """
        :param hidden: hidden size of transformer
        :param attn_heads: head sizes of multi-head attention
        :param feed_forward_hidden: feed_forward_hidden, usually 4*hidden_size
        :param dropout: dropout rate
        """

        super().__init__()
        self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden)
        self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))
        x = self.output_sublayer(x, self.feed_forward)
        return self.dropout(x)
```

## BERT

```python
class BERT(nn.Module):
    """
    BERT model : Bidirectional Encoder Representations from Transformers.
    """

    def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
        """
        :param vocab_size: vocab_size of total words
        :param hidden: BERT model hidden size
        :param n_layers: numbers of Transformer blocks(layers)
        :param attn_heads: number of attention heads
        :param dropout: dropout rate
        """

        super().__init__()
        self.hidden = hidden
        self.n_layers = n_layers
        self.attn_heads = attn_heads

        # paper noted they used 4*hidden_size for ff_network_hidden_size
        self.feed_forward_hidden = hidden * 4

        # embedding for BERT, sum of positional, segment, token embeddings
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden)

        # multi-layers transformer blocks, deep network
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(hidden, attn_heads, self.feed_forward_hidden, dropout) for _ in range(n_layers)])

    def forward(self, x, segment_info):
        # attention masking for padded token:
        # bool tensor of shape [batch_size, 1, seq_len, seq_len], True = real token that may be attended to
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)

        # embedding the indexed sequence to sequence of vectors
        x = self.embedding(x, segment_info)

        # running over multiple transformer blocks
        for transformer in self.transformer_blocks:
            x = transformer(x, mask)

        return x
```

```python
bert = BERT(vocab_size=20000, hidden=1024, n_layers=24, attn_heads=16)
```


```python
x = torch.tensor([[1,2,3,4,5,5,4,3,2,1, 0,0]])
segment_info = torch.tensor([[1,1,1,1,1,2,2,2,2,2, 0,0]])
bert(x, segment_info).size()
```


```
torch.Size([1, 12, 1024])
```

```python
sum([param.nelement() for param in bert.parameters()])
```

```
322792448
```
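The count can be reproduced by hand, which is a useful check that every layer has its own weights. A sketch of the arithmetic for the hand-written `BERT` above (`bert_param_count` is a helper made up for this example; the sinusoidal position encoding is a buffer and contributes no parameters):

```python
def bert_param_count(vocab_size, hidden, n_layers, n_segments=3):
    """Trainable parameters of the hand-written BERT above."""
    ff = 4 * hidden
    embeddings = vocab_size * hidden + n_segments * hidden      # token + segment tables
    attention = 4 * (hidden * hidden + hidden)                  # Q, K, V and output projections (weight + bias)
    feed_forward = (hidden * ff + ff) + (ff * hidden + hidden)  # w_1 and w_2
    layer_norms = 2 * (2 * hidden)                              # two LayerNorms per block, gain + bias each
    return embeddings + n_layers * (attention + feed_forward + layer_norms)


print(bert_param_count(vocab_size=20000, hidden=1024, n_layers=24))  # 322792448
```

That is about 323M parameters, in the same range as the 340M the paper reports for BERT-LARGE. Most of the gap comes from the vocabulary size: 20,000 tokens here versus the paper's 30,000-token WordPiece vocabulary, a difference of about 10M parameters at hidden size 1024.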