# 代码构建部分

In [116]:
import torch
import torch.nn as nn
import math

In [168]:
class Self_Attention(nn.Module):
    def __init__(self, d_model, n_heads, dropout):
        super(Self_Attention, self).__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.output_LayerNorm = nn.LayerNorm(d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.Dropout1 = nn.Dropout(dropout)
        self.Dropout2 = nn.Dropout(dropout)

    def forward(self, q, k, v, mask):
        residual = q  # 保存残差
        batch_size, seq_len, d_model = q.size()
        
        # 多头注意力计算（与你的逻辑一致）
        q = self.W_q(q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_k(k).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        v = self.W_v(v).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        
        attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.d_k)
        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(1)
            attn = attn.masked_fill(mask == 0, -1e9)
        
        scores = torch.softmax(attn, dim=-1)
        scores = self.Dropout1(scores)
        
        output = torch.matmul(scores, v).transpose(2,1).contiguous().view(batch_size, seq_len, d_model)
        output = self.W_o(output)
        output = self.Dropout2(output)
        
        # 补全残差连接（BERT的核心步骤）
        output = output + residual  # 关键：输出 + 输入残差
        output = self.output_LayerNorm(output)  # 再做LayerNorm
        
        return output

In [169]:
import torch.nn.functional as F

class EncoderLayer(nn.Module):
    def __init__(self, d_model, d_ff, n_heads, dropout):
        super(EncoderLayer, self).__init__()
        self.attention = Self_Attention(d_model=d_model, n_heads=n_heads, dropout=dropout)
        # 前馈网络添加激活函数（BERT用GELU）
        self.intermediate = nn.Linear(d_model, d_ff)
        self.intermediate_act_fn = nn.GELU()  # 新增：激活函数
        self.output = nn.Linear(d_ff, d_model)
        self.output_LayerNorm = nn.LayerNorm(d_model)  # 更清晰的命名
        self.dropout = nn.Dropout(dropout)  # BERT在前馈网络后也会加dropout

    def forward(self, x, mask=None):
        # 1. 自注意力层（假设Self_Attention已包含残差+Norm）
        attn_output = self.attention(x, x, x, mask)
        
        # 2. 前馈网络
        ff_output = self.intermediate(attn_output)
        ff_output = self.intermediate_act_fn(ff_output)  # 应用激活函数
        ff_output = self.dropout(ff_output)  # 新增：dropout增强泛化性
        ff_output = self.output(ff_output)
        
        # 3. 前馈网络的残差+Norm
        ff_output = ff_output + attn_output  # 残差连接（基准是自注意力的输出）
        ff_output = self.output_LayerNorm(ff_output)  # LayerNorm
        
        return ff_output

In [170]:
class Embedding(nn.Module):
    def __init__(self, d_model, max_len, vocab_size, dropout, pad_token_id=0):
        super(Embedding, self).__init__()
        self.LayerNorm = nn.LayerNorm(d_model, eps=1e-12)
        self.position_embeddings = nn.Embedding(max_len, d_model)
        self.word_embeddings = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id) 
        self.token_type_embeddings = nn.Embedding(2, d_model)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "position_ids", torch.arange(max_len).expand((1, -1)), persistent=False
        )
    def forward(self, x, token_type_ids):
        seq_length = x.size(1)
        position_ids = self.position_ids[:, 0 : seq_length + 0]
        inputs_embeds = self.word_embeddings(x)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = inputs_embeds + token_type_embeddings
        position_embeddings = self.position_embeddings(position_ids)
        embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

In [188]:
class BertPooler(nn.Module):
    """对应bert.pooler.dense参数"""
    def __init__(self, d_model):
        super().__init__()
        self.dense = nn.Linear(d_model, d_model)  # weight: [768,768], bias: [768]
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # hidden_states: [batch_size, seq_len, d_model]
        # 取[CLS] token的输出（第一个位置）
        cls_token = hidden_states[:, 0, :]  # [batch_size, d_model]
        pooled_output = self.dense(cls_token)  # 线性变换
        pooled_output = self.activation(pooled_output)  # tanh激活
        return pooled_output

In [236]:
class BertPredictionHeadTransform(nn.Module):
    """对应cls.predictions.transform参数"""
    def __init__(self, d_model):
        super().__init__()
        self.dense = nn.Linear(d_model, d_model)  # weight: [768,768], bias: [768]
        self.LayerNorm = nn.LayerNorm(d_model, eps=1e-12)  # gamma: [768], beta: [768]

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = F.gelu(hidden_states)  # BERT用GELU激活
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    """对应cls.predictions参数（MLM任务头）"""
    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.transform = BertPredictionHeadTransform(d_model)
        # 预测头：映射到词表大小（复用词嵌入矩阵的转置，这里先定义为独立参数）
        self.decoder = nn.Linear(d_model, vocab_size, bias=False)  # 实际BERT中会复用embedding权重
        self.bias = nn.Parameter(torch.zeros(vocab_size))  # cls.predictions.bias: [30522]

    def forward(self, hidden_states):
        # hidden_states: [batch_size, seq_len, d_model]（编码器输出）
        hidden_states = self.transform(hidden_states)  # 特征变换
        logits = self.decoder(hidden_states) + self.bias  # 计算词表logits
        return logits


class BertOnlyNSPHead(nn.Module):
    """对应cls.seq_relationship参数（NSP任务头）"""
    def __init__(self, d_model):
        super().__init__()
        self.seq_relationship = nn.Linear(d_model, 2)  # weight: [2,768], bias: [2]

    def forward(self, pooled_output):
        # pooled_output: [batch_size, d_model]（池化层输出）
        logits = self.seq_relationship(pooled_output)  # 二分类logits
        return logits
class BertPreTrainingHeads(nn.Module):
    """整合MLM和NSP任务头，对应cls整体"""
    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.predictions = BertLMPredictionHead(d_model, vocab_size)  # MLM头
        self.seq_relationship = BertOnlyNSPHead(d_model)  # NSP头

    def forward(self, hidden_states, pooled_output):
        mlm_logits = self.predictions(hidden_states)  # [batch_size, seq_len, vocab_size]
        nsp_logits = self.seq_relationship(pooled_output)  # [batch_size, 2]
        return hidden_states, pooled_output

In [237]:
class Bert(nn.Module):
    def __init__(self, d_model, max_len, vocab_size, d_ff, n_heads, num_hidden_layers, dropout=0.1):
        super(Bert, self).__init__()
        self.embeddings = Embedding(d_model, max_len, vocab_size, dropout)
        self.encoder = nn.ModuleList([EncoderLayer(d_model=d_model, d_ff=d_ff, n_heads=n_heads, dropout=dropout) for _ in range(num_hidden_layers)])
        self.pooler = BertPooler(d_model)
        self.cls = BertPreTrainingHeads(d_model, vocab_size)
    def forward(self, input_ids, token_type_ids, attention_mask=None):
        embedding_output = self.embeddings(input_ids, token_type_ids)
        
        encoder_output = embedding_output
        for layer in self.encoder:
            encoder_output = layer(encoder_output, attention_mask)
        pooled_output = self.pooler(encoder_output)
        mlm_logits, nsp_logits = self.cls(encoder_output, pooled_output)
        return mlm_logits, nsp_logits

    def load_dict(self, safetensor_path):
        from safetensors.torch import save_file, load_file
        loaded_state = load_file(safetensor_path)
        self.embeddings.LayerNorm.weight = torch.nn.Parameter(loaded_state['bert.embeddings.LayerNorm.gamma'])
        self.embeddings.LayerNorm.bias = torch.nn.Parameter(loaded_state['bert.embeddings.LayerNorm.beta'])
        self.embeddings.position_embeddings.weight = torch.nn.Parameter(loaded_state['bert.embeddings.position_embeddings.weight'])
        self.embeddings.word_embeddings.weight = torch.nn.Parameter(loaded_state['bert.embeddings.word_embeddings.weight'])
        self.embeddings.token_type_embeddings.weight = torch.nn.Parameter(loaded_state['bert.embeddings.token_type_embeddings.weight'])
        for i in range(12):
            self.encoder[i].attention.output_LayerNorm.weight = torch.nn.Parameter(loaded_state[f'bert.encoder.layer.{i}.attention.output.LayerNorm.gamma'])
            self.encoder[i].attention.output_LayerNorm.bias = torch.nn.Parameter(loaded_state[f'bert.encoder.layer.{i}.attention.output.LayerNorm.beta'])
            self.encoder[i].attention.W_o.weight = torch.nn.Parameter(loaded_state[f'bert.encoder.layer.{i}.attention.output.dense.weight'])
            self.encoder[i].attention.W_o.bias = torch.nn.Parameter(loaded_state[f'bert.encoder.layer.{i}.attention.output.dense.bias'])
            self.encoder[i].attention.W_q.weight = torch.nn.Parameter(loaded_state[f'bert.encoder.layer.{i}.attention.self.query.weight'])
            self.encoder[i].attention.W_q.bias = torch.nn.Parameter(loaded_state[f'bert.encoder.layer.{i}.attention.self.query.bias'])
            self.encoder[i].attention.W_k.weight = torch.nn.Parameter(loaded_state[f'bert.encoder.layer.{i}.attention.self.key.weight'])
            self.encoder[i].attention.W_k.bias = torch.nn.Parameter(loaded_state[f'bert.encoder.layer.{i}.attention.self.key.bias'])
            self.encoder[i].attention.W_v.weight = torch.nn.Parameter(loaded_state[f'bert.encoder.layer.{i}.attention.self.value.weight'])
            self.encoder[i].attention.W_v.bias = torch.nn.Parameter(loaded_state[f'bert.encoder.layer.{i}.attention.self.value.bias'])
            self.encoder[i].intermediate.weight = torch.nn.Parameter(loaded_state[f'bert.encoder.layer.{i}.intermediate.dense.weight'])
            self.encoder[i].intermediate.bias = torch.nn.Parameter(loaded_state[f'bert.encoder.layer.{i}.intermediate.dense.bias'])
            self.encoder[i].output.weight = torch.nn.Parameter(loaded_state[f'bert.encoder.layer.{i}.output.dense.weight'])
            self.encoder[i].output.bias = torch.nn.Parameter(loaded_state[f'bert.encoder.layer.{i}.output.dense.bias'])
            self.encoder[i].output_LayerNorm.weight = torch.nn.Parameter(loaded_state[f'bert.encoder.layer.{i}.output.LayerNorm.gamma'])
            self.encoder[i].output_LayerNorm.bias = torch.nn.Parameter(loaded_state[f'bert.encoder.layer.{i}.output.LayerNorm.beta'])
        self.pooler.dense.weight = torch.nn.Parameter(loaded_state['bert.pooler.dense.weight'])
        self.pooler.dense.bias = torch.nn.Parameter(loaded_state['bert.pooler.dense.bias'])
        self.cls.predictions.transform.dense.weight = torch.nn.Parameter(loaded_state['cls.predictions.transform.dense.weight'])
        self.cls.predictions.transform.dense.bias = torch.nn.Parameter(loaded_state['cls.predictions.transform.dense.bias'])
        self.cls.predictions.transform.LayerNorm.weight = torch.nn.Parameter(loaded_state['cls.predictions.transform.LayerNorm.gamma'])
        self.cls.predictions.transform.LayerNorm.bias = torch.nn.Parameter(loaded_state['cls.predictions.transform.LayerNorm.beta'])
        self.cls.predictions.decoder.weight = torch.nn.Parameter(loaded_state['bert.embeddings.word_embeddings.weight'])
        self.cls.predictions.bias = torch.nn.Parameter(loaded_state['cls.predictions.bias'])
        self.cls.seq_relationship.seq_relationship.weight = torch.nn.Parameter(loaded_state['cls.seq_relationship.weight'])
        self.cls.seq_relationship.seq_relationship.bias = torch.nn.Parameter(loaded_state['cls.seq_relationship.bias'])
        self.eval()

# 查看模型的参数

In [158]:
model = Bert(
    d_model=768,
    max_len=512,
    vocab_size=30522,
    d_ff=3072,
    n_heads=12,
    num_hidden_layers=12,
)

In [159]:
state_dict = model.state_dict()
for key, value in state_dict.items():
    print(f"{key}: {value.shape}")

embeddings.LayerNorm.weight: torch.Size([768])
embeddings.LayerNorm.bias: torch.Size([768])
embeddings.position_embeddings.weight: torch.Size([512, 768])
embeddings.word_embeddings.weight: torch.Size([30522, 768])
embeddings.token_type_embeddings.weight: torch.Size([2, 768])
encoder.0.attention.output_LayerNorm.weight: torch.Size([768])
encoder.0.attention.output_LayerNorm.bias: torch.Size([768])
encoder.0.attention.W_o.weight: torch.Size([768, 768])
encoder.0.attention.W_o.bias: torch.Size([768])
encoder.0.attention.W_q.weight: torch.Size([768, 768])
encoder.0.attention.W_q.bias: torch.Size([768])
encoder.0.attention.W_k.weight: torch.Size([768, 768])
encoder.0.attention.W_k.bias: torch.Size([768])
encoder.0.attention.W_v.weight: torch.Size([768, 768])
encoder.0.attention.W_v.bias: torch.Size([768])
encoder.0.LayerNorm.weight: torch.Size([768])
encoder.0.LayerNorm.bias: torch.Size([768])
encoder.0.intermediate.weight: torch.Size([3072, 768])
encoder.0.intermediate.bias: torch.Size([30

# 验证自己写的代码输出内容

In [238]:
input_ids = torch.zeros((2, 128), dtype=torch.long)  # 批次2，序列长度128
attention_mask = torch.ones(2, 128, dtype=torch.float32)        # 无padding，全1掩码
# attention_mask[0][-1]=0
# attention_mask[1][-1]=0
token_type_ids = torch.zeros(2, 128, dtype=torch.long)  

In [239]:
model = Bert(
    d_model=768,
    max_len=512,
    vocab_size=30522,
    d_ff=3072,
    n_heads=12,
    num_hidden_layers=12,
)
safetensor_path = "../model/bert/model.safetensors"
model.load_dict(safetensor_path)

In [240]:
embedding_output = model.embeddings(input_ids, token_type_ids)
encoder_output = model(input_ids, token_type_ids, attention_mask)
# print(embedding_output[0][0][:2])
print(encoder_output[0].shape)  # MLM logits
print(encoder_output[0][0][0][:10])
print(encoder_output[1])  # NSP logits

torch.Size([2, 128, 768])
tensor([-0.3210, -0.1248, -0.2172,  0.5433, -0.5840, -0.5502,  0.2830,  0.3881,
         0.0953, -1.1299], grad_fn=<SliceBackward0>)
tensor([[-0.1242, -0.3776, -0.8278,  ..., -0.8794, -0.6126, -0.0377],
        [-0.1242, -0.3776, -0.8278,  ..., -0.8794, -0.6126, -0.0377]],
       grad_fn=<TanhBackward0>)


# 验证官方输出内容

In [241]:
import torch
from transformers import BertTokenizer, BertModel
model_path = "../model/bert"
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertModel.from_pretrained(model_path)
encoder_output = model(input_ids, attention_mask, token_type_ids)

In [242]:
input_ids = torch.zeros((2, 128), dtype=torch.long)  # 批次2，序列长度128
attention_mask = torch.ones(2, 128, dtype=torch.float32)        # 无padding，全1掩码
token_type_ids = torch.zeros(2, 128, dtype=torch.long)          # 单段落，全0
# attention_mask[0][-1]=0
# attention_mask[1][-1]=0
sequence_output, pooled_output = model(input_ids, attention_mask, token_type_ids)
encoder_output = model(input_ids, attention_mask, token_type_ids)
print(encoder_output[0].shape)  # MLM logits
print(encoder_output[0])
print(encoder_output[1])

torch.Size([2, 128, 768])
tensor([[[-0.3210, -0.1248, -0.2172,  ..., -0.5346,  0.7774, -0.2815],
         [-0.2491, -0.1476, -0.1710,  ..., -0.5708,  0.7885, -0.1987],
         [-0.2603, -0.1823, -0.1663,  ..., -0.5687,  0.7649, -0.2008],
         ...,
         [-0.3194, -0.2450, -0.0910,  ..., -0.5497,  0.8502, -0.1324],
         [-0.3259, -0.2308, -0.0936,  ..., -0.5407,  0.8343, -0.1311],
         [-0.3107, -0.1982, -0.0991,  ..., -0.5395,  0.8027, -0.1485]],

        [[-0.3210, -0.1248, -0.2172,  ..., -0.5346,  0.7774, -0.2815],
         [-0.2491, -0.1476, -0.1710,  ..., -0.5708,  0.7885, -0.1987],
         [-0.2603, -0.1823, -0.1663,  ..., -0.5687,  0.7649, -0.2008],
         ...,
         [-0.3194, -0.2450, -0.0910,  ..., -0.5497,  0.8502, -0.1324],
         [-0.3259, -0.2308, -0.0936,  ..., -0.5407,  0.8343, -0.1311],
         [-0.3107, -0.1982, -0.0991,  ..., -0.5395,  0.8027, -0.1485]]],
       grad_fn=<NativeLayerNormBackward0>)
tensor([[-0.1242, -0.3776, -0.8278,  ..., -0.87

In [202]:
embedding_output = model.embeddings(input_ids=input_ids,token_type_ids=token_type_ids)
encoder_output = model(input_ids=input_ids,token_type_ids=token_type_ids,attention_mask=attention_mask)
print(embedding_output[0][0][:2])
print(encoder_output[0].shape)
print(encoder_output[0][0][0][:10])

TypeError: Embedding.forward() got an unexpected keyword argument 'input_ids'

# 官方模型框架

In [39]:
from safetensors.torch import save_file, load_file
loaded_state = load_file("../model/bert/model.safetensors")

In [40]:
for key, value in loaded_state.items():
    print(key, value.shape)

bert.embeddings.LayerNorm.beta torch.Size([768])
bert.embeddings.LayerNorm.gamma torch.Size([768])
bert.embeddings.position_embeddings.weight torch.Size([512, 768])
bert.embeddings.token_type_embeddings.weight torch.Size([2, 768])
bert.embeddings.word_embeddings.weight torch.Size([30522, 768])
bert.encoder.layer.0.attention.output.LayerNorm.beta torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.gamma torch.Size([768])
bert.encoder.layer.0.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.0.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.self.key.bias torch.Size([768])
bert.encoder.layer.0.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.self.query.bias torch.Size([768])
bert.encoder.layer.0.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.self.value.bias torch.Size([768])
bert.encoder.layer.0.attention.self.value.weight torch.Size([768, 768])
bert.encoder.l