# 代码构建部分

In [1]:
import torch
import torch.nn as nn

In [None]:
class Self_Attention(nn.Module):
    """Skeleton of the attention sub-layer: projections, output norm, dropout.

    Only the sub-modules are created here; ``forward`` is still a stub.

    NOTE(review): query, key and output projections are defined, but there is
    no value projection yet, although ``forward`` accepts ``v`` -- confirm
    whether one should be added when ``forward`` is implemented.
    """

    def __init__(self, d_model, dropout):
        """Create the sub-modules.

        Args:
            d_model: Feature size used for every linear layer and the LayerNorm.
            dropout: Drop probability handed to ``nn.Dropout``.
        """
        super().__init__()
        # Registration order is kept as-is so state_dict key order is stable.
        self.output_LayerNorm = nn.LayerNorm(normalized_shape=d_model)
        self.W_o = nn.Linear(in_features=d_model, out_features=d_model)
        self.W_q = nn.Linear(in_features=d_model, out_features=d_model)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model)
        self.Dropout = nn.Dropout(p=dropout)

    def forward(self, q, k, v, mask):
        """Placeholder: ignores every argument and returns ``None``."""
        return None

In [None]:
class EncoderLayer(nn.Module):
    """Skeleton of one encoder layer: attention, LayerNorm and two linear layers.

    Only the sub-modules are created here; ``forward`` is still a stub.
    """

    def __init__(self, d_model, d_ff, dropout):
        """Create the sub-modules.

        Args:
            d_model: Feature size of the layer input and output.
            d_ff: Width of the intermediate linear layer.
            dropout: Drop probability forwarded to the attention sub-layer.
        """
        super().__init__()
        self.attention = Self_Attention(d_model, dropout)
        self.LayerNorm = nn.LayerNorm(normalized_shape=d_model)
        # d_model -> d_ff expansion followed by the d_ff -> d_model projection.
        self.intermediate = nn.Linear(in_features=d_model, out_features=d_ff)
        self.output = nn.Linear(in_features=d_ff, out_features=d_model)

    def forward(self):
        """Placeholder: takes no inputs and returns ``None``."""
        return None

In [2]:
class Embedding(nn.Module):
    """Sum of word, token-type and position embeddings, then LayerNorm and dropout."""

    def __init__(self, d_model, max_len, vocab_size, dropout, pad_token_id=0):
        """Create the embedding tables and the normalisation/dropout layers.

        Args:
            d_model: Embedding dimension.
            max_len: Number of rows in the position-embedding table.
            vocab_size: Number of rows in the word-embedding table.
            dropout: Drop probability applied to the normalised embeddings.
            pad_token_id: Index passed as ``padding_idx`` to the word embeddings.
        """
        super().__init__()
        # Registration order is kept so state_dict key order is unchanged.
        self.LayerNorm = nn.LayerNorm(d_model, eps=1e-12)
        self.position_embeddings = nn.Embedding(max_len, d_model)
        self.word_embeddings = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
        self.token_type_embeddings = nn.Embedding(2, d_model)
        # Honour the configured rate; a bare nn.Dropout() would silently use p=0.5.
        self.dropout = nn.Dropout(dropout)
        # Non-persistent buffer: follows the module across devices but is not
        # written to the state_dict.
        self.register_buffer(
            "position_ids", torch.arange(max_len).expand((1, -1)), persistent=False
        )

    def forward(self, x, token_type_ids=None):
        """Embed ``x`` and return the normalised, dropped-out embeddings.

        Args:
            x: Integer index tensor; dimension 1 is taken as the sequence length.
            token_type_ids: Integer tensor indexing the 2-row token-type table.
                When omitted, all zeros with the shape of ``x`` are used.

        Returns:
            Tensor of embeddings with a trailing dimension of ``d_model``.
        """
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(x)

        seq_length = x.size(1)
        position_ids = self.position_ids[:, :seq_length]

        # Same left-to-right summation order as before: (word + type) + position.
        embeddings = (
            self.word_embeddings(x)
            + self.token_type_embeddings(token_type_ids)
            + self.position_embeddings(position_ids)
        )
        embeddings = self.LayerNorm(embeddings)
        return self.dropout(embeddings)

In [None]:
class Bert(nn.Module):
    """Partial BERT re-implementation: embeddings plus one encoder-layer skeleton.

    Only the embedding weights can currently be loaded from a checkpoint, and
    ``forward`` is still a stub.
    """

    def __init__(self, d_model, max_len, vocab_size, d_ff=3072, dropout=0.1):
        """Create the embedding module and the encoder layer.

        Args:
            d_model: Hidden size shared by the embeddings and the encoder layer.
            max_len: Number of positions in the position-embedding table.
            vocab_size: Number of rows in the word-embedding table.
            d_ff: Width of the encoder's intermediate linear layer. Defaults to
                3072 so that callers which omit it keep working.
            dropout: Drop probability forwarded to the sub-modules.
        """
        super().__init__()
        self.embeddings = Embedding(d_model, max_len, vocab_size, dropout)
        self.encoder_layer = EncoderLayer(d_model=d_model, d_ff=d_ff, dropout=dropout)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Placeholder: ignores its arguments and returns ``None``."""
        pass

    def load_dict(self, safetensor_path):
        """Load the embedding weights from a safetensors file and switch to eval mode.

        Args:
            safetensor_path: Path of the ``.safetensors`` checkpoint to read.

        Raises:
            KeyError: If one of the expected ``bert.embeddings.*`` entries is
                missing from the checkpoint.
        """
        # Imported locally so the module can be imported without safetensors.
        from safetensors.torch import load_file

        loaded_state = load_file(safetensor_path)
        emb = self.embeddings
        # The checkpoint names the LayerNorm scale/shift "gamma"/"beta".
        emb.LayerNorm.weight = nn.Parameter(loaded_state['bert.embeddings.LayerNorm.gamma'])
        emb.LayerNorm.bias = nn.Parameter(loaded_state['bert.embeddings.LayerNorm.beta'])
        emb.position_embeddings.weight = nn.Parameter(
            loaded_state['bert.embeddings.position_embeddings.weight']
        )
        emb.word_embeddings.weight = nn.Parameter(
            loaded_state['bert.embeddings.word_embeddings.weight']
        )
        emb.token_type_embeddings.weight = nn.Parameter(
            loaded_state['bert.embeddings.token_type_embeddings.weight']
        )
        self.eval()

# 查看模型的参数

In [None]:
# Build the custom Bert: 768-dim hidden size, 512 positions,
# 30522-entry vocabulary and a 3072-wide intermediate layer.
model = Bert(
    d_model=768,
    max_len=512,
    vocab_size=30522,
    d_ff=3072
)

In [5]:
# Print every state_dict entry name together with the shape of its tensor.
state_dict = model.state_dict()
for key, value in state_dict.items():
    print(f"{key}: {value.shape}")

embeddings.LayerNorm.weight: torch.Size([768])
embeddings.LayerNorm.bias: torch.Size([768])
embeddings.position_embeddings.weight: torch.Size([512, 768])
embeddings.word_embeddings.weight: torch.Size([30522, 768])
embeddings.token_type_embeddings.weight: torch.Size([2, 768])


# 验证自己写的代码输出内容

In [6]:
# Dummy inputs used to compare the custom model against the official one.
input_ids = torch.zeros((2, 128), dtype=torch.long)  # batch size 2, sequence length 128
attention_mask = torch.ones(2, 128, dtype=torch.float32)        # no padding, all-ones mask
token_type_ids = torch.zeros(2, 128, dtype=torch.long)  # all zeros

In [7]:
# Rebuild the custom model, then load the pretrained embedding weights into it.
model = Bert(
    d_model=768,
    max_len=512,
    vocab_size=30522,
    d_ff=3072,  # required constructor argument; same value as the earlier instantiation
)
safetensor_path = "../model/bert/model.safetensors"
model.load_dict(safetensor_path)

In [8]:
# Run only the embedding sub-module and print the first two features of the
# first position of the first sample.
embedding_output = model.embeddings(input_ids, token_type_ids)
print(embedding_output[0][0][:2])

tensor([ 0.2977, -0.6056], grad_fn=<SliceBackward0>)


# 验证官方输出内容

In [9]:
import torch
from transformers import BertTokenizer, BertModel
# Load the tokenizer and the pretrained BertModel from the local directory.
# Note: this rebinds `model`, replacing the custom Bert instance created earlier.
model_path = "../model/bert"
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertModel.from_pretrained(model_path)

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
input_ids = torch.zeros((2, 128), dtype=torch.long)  # batch size 2, sequence length 128
attention_mask = torch.ones(2, 128, dtype=torch.float32)  # no padding, all-ones mask
token_type_ids = torch.zeros(2, 128, dtype=torch.long)  # single segment, all zeros
# Run the forward pass once and reuse the result.
output = model(input_ids, attention_mask, token_type_ids)
# Index by position instead of tuple-unpacking: unpacking a ModelOutput iterates
# its keys (strings), whereas [0]/[1] work for both tuple and ModelOutput returns.
sequence_output, pooled_output = output[0], output[1]

In [11]:
# Run only the official model's embedding sub-module on the same inputs and
# print the first two features of the first position of the first sample.
embedding_output = model.embeddings(input_ids=input_ids,token_type_ids=token_type_ids)
print(embedding_output[0][0][:2])

tensor([ 0.2977, -0.6056], grad_fn=<SliceBackward0>)


# 官方模型框架

In [12]:
from safetensors.torch import save_file, load_file
# Read the safetensors checkpoint; the result is iterated as name/tensor pairs below.
loaded_state = load_file("../model/bert/model.safetensors")

In [18]:
# List every entry stored in the checkpoint with the shape of its tensor.
for key, value in loaded_state.items():
    print(key, value.shape)

bert.embeddings.LayerNorm.beta torch.Size([768])
bert.embeddings.LayerNorm.gamma torch.Size([768])
bert.embeddings.position_embeddings.weight torch.Size([512, 768])
bert.embeddings.token_type_embeddings.weight torch.Size([2, 768])
bert.embeddings.word_embeddings.weight torch.Size([30522, 768])
bert.encoder.layer.0.attention.output.LayerNorm.beta torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.gamma torch.Size([768])
bert.encoder.layer.0.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.0.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.self.key.bias torch.Size([768])
bert.encoder.layer.0.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.self.query.bias torch.Size([768])
bert.encoder.layer.0.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.self.value.bias torch.Size([768])
bert.encoder.layer.0.attention.self.value.weight torch.Size([768, 768])
bert.encoder.l