In [1]:
from bert.bert_model import BertModel
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import math
from transformers import BertTokenizer
input_path="./data/bert_output_data2.json"
class MyDataset(Dataset):
    def __init__(self, file_path):
        self.data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                data = json.loads(line)
                self.data.append(data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return {
            "input_ids": torch.tensor(item["input_ids"]),
            "input_mask": torch.tensor(item["input_mask"]),
            "segment_ids": torch.tensor(item["segment_ids"]),
            "masked_lm_ids": torch.tensor(item["masked_lm_ids"]),
            "masked_lm_positions": torch.tensor(item["masked_lm_positions"]),
            "masked_lm_weights": torch.tensor(item["masked_lm_weights"]),
            "next_sentence_labels": torch.tensor(item["next_sentence_labels"])
        }

class BertConfig:
    def __init__(self, vocab_size, hidden_size=144, num_hidden_layers=3, num_attention_heads=12,
                 intermediate_size=512, hidden_act='gelu', hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1, max_position_embeddings=512,
                 type_vocab_size=2, initializer_range=0.02):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range



In [2]:

train_dataset = MyDataset(input_path)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
for data in train_loader:
    data
    break
print(data["input_ids"].shape)
print(data["input_mask"].shape)
input_ids=data["input_ids"]
token_type_ids=data["segment_ids"]
attention_mask=data["input_mask"]
masked_lm_positions=data["masked_lm_positions"]
masked_lm_ids=data["masked_lm_ids"]
masked_lm_weights=data["masked_lm_weights"]
next_sentence_labels=data["next_sentence_labels"]
data

torch.Size([32, 512])
torch.Size([32, 512])


{'input_ids': tensor([[ 101, 4777,  782,  ..., 1912,  769,  102],
         [ 101, 3688, 3813,  ...,  712, 2476,  102],
         [ 101, 1062, 1066,  ...,    0,    0,    0],
         ...,
         [ 101, 6432, 1168,  ...,    0,    0,    0],
         [ 101,  677,  103,  ...,    0,    0,    0],
         [ 101, 4638, 2356,  ..., 1171,  520,  102]]),
 'input_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 1, 1, 1]]),
 'segment_ids': tensor([[0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 1, 1, 1]]),
 'masked_lm_ids': tensor([[3136, 1187, 1447, 2157, 4777, 4955, 4060, 2533,  511,  511,  511, 1054,
           833, 1447, 3341, 4638, 6789, 3341, 2456,  704],
         [ 511

In [3]:
def gather_indexes(sequence_tensor, positions):
    sequence_shape = list(sequence_tensor.shape)
    batch_size = sequence_shape[0]
    seq_length = sequence_shape[1]
    width = sequence_shape[2]

    flat_offsets = torch.arange(0, batch_size, dtype=torch.int64).reshape(-1, 1) * seq_length
    flat_positions = (positions + flat_offsets).reshape(-1)
    flat_sequence_tensor = sequence_tensor.reshape(batch_size * seq_length, width)
    output_tensor = flat_sequence_tensor[flat_positions]
    return output_tensor

In [4]:
input_=torch.rand(2,6,10)
positions=torch.tensor([0,3,5])#torch.arange(10)
print(input_)
gather_indexes(input_,positions)

tensor([[[0.6192, 0.9202, 0.4370, 0.6178, 0.1265, 0.6543, 0.1944, 0.4210,
          0.7286, 0.3453],
         [0.2426, 0.7143, 0.2243, 0.8417, 0.8804, 0.9084, 0.6356, 0.0085,
          0.1740, 0.3323],
         [0.9957, 0.7029, 0.5443, 0.8433, 0.4333, 0.3006, 0.3595, 0.1784,
          0.7058, 0.1251],
         [0.1414, 0.8563, 0.4120, 0.5211, 0.9991, 0.1006, 0.6647, 0.1092,
          0.6681, 0.2601],
         [0.3065, 0.9432, 0.5592, 0.3568, 0.4359, 0.0126, 0.4728, 0.3069,
          0.4900, 0.6574],
         [0.5425, 0.8609, 0.2502, 0.6821, 0.8000, 0.2218, 0.9561, 0.6038,
          0.9912, 0.1466]],

        [[0.3784, 0.5543, 0.0117, 0.9622, 0.2962, 0.3833, 0.5392, 0.0209,
          0.2235, 0.2522],
         [0.9660, 0.6147, 0.1299, 0.3373, 0.6129, 0.8223, 0.4620, 0.8901,
          0.4853, 0.2207],
         [0.3283, 0.0312, 0.1678, 0.6374, 0.2692, 0.1102, 0.8474, 0.3824,
          0.9538, 0.2237],
         [0.8582, 0.1526, 0.0985, 0.8792, 0.8305, 0.0206, 0.9023, 0.0195,
          0.287

tensor([[0.6192, 0.9202, 0.4370, 0.6178, 0.1265, 0.6543, 0.1944, 0.4210, 0.7286,
         0.3453],
        [0.1414, 0.8563, 0.4120, 0.5211, 0.9991, 0.1006, 0.6647, 0.1092, 0.6681,
         0.2601],
        [0.5425, 0.8609, 0.2502, 0.6821, 0.8000, 0.2218, 0.9561, 0.6038, 0.9912,
         0.1466],
        [0.3784, 0.5543, 0.0117, 0.9622, 0.2962, 0.3833, 0.5392, 0.0209, 0.2235,
         0.2522],
        [0.8582, 0.1526, 0.0985, 0.8792, 0.8305, 0.0206, 0.9023, 0.0195, 0.2873,
         0.7935],
        [0.1993, 0.7979, 0.3765, 0.3066, 0.0966, 0.6415, 0.7462, 0.3115, 0.9905,
         0.2307]])

In [5]:
tokenizer = BertTokenizer.from_pretrained('/Users/wangaijun/pythoncode/github/model/bert-base-chinese')
vocab_words = list(tokenizer.vocab.keys())
config=BertConfig(len(vocab_words))
config.vocab_size
bert_model = BertModel(config)x

pooled_output, sequence_output, encoded_layers=bert_model(input_ids,token_type_ids,attention_mask)
print(pooled_output.shape,sequence_output.shape,len(encoded_layers))
print(bert_model)

torch.Size([32, 144]) torch.Size([32, 512, 144]) 3
BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(21128, 144)
    (position_embeddings): Embedding(512, 144)
    (token_type_embeddings): Embedding(2, 144)
    (LayerNorm): LayerNorm((144,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-2): 3 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=144, out_features=144, bias=True)
            (key): Linear(in_features=144, out_features=144, bias=True)
            (value): Linear(in_features=144, out_features=144, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=144, out_features=144, bias=True)
            (LayerNorm): LayerNorm((144,), eps=1e-12, elementwise_affine=True)
            (dropout)

# 2 mask预测损失

In [6]:
def get_masked_lm_output(bert_config, input_tensor, output_weights, positions, label_ids, label_weights):
    input_tensor = gather_indexes(input_tensor, positions)
    sequential = nn.Sequential(
        nn.Linear(bert_config.hidden_size, bert_config.hidden_size),
        nn.LayerNorm(bert_config.hidden_size),
        nn.ReLU()
    )
    input_tensor = sequential(input_tensor)  # 使用sequential处理input_tensor
    output_bias = nn.Parameter(torch.zeros(bert_config.vocab_size))
    logits = torch.matmul(input_tensor, output_weights.transpose(0, 1)) + output_bias
    log_probs = nn.functional.log_softmax(logits, dim=-1)

    label_ids = label_ids.reshape(-1)
    label_weights = label_weights.reshape(-1)
    one_hot_labels = torch.nn.functional.one_hot(label_ids, num_classes=bert_config.vocab_size).float()
    per_example_loss = -torch.sum(log_probs * one_hot_labels, dim=-1)
    numerator = torch.sum(label_weights * per_example_loss)
    denominator = torch.sum(label_weights) + 1e-5
    loss = numerator / denominator

    return loss, per_example_loss, log_probs

In [7]:
masked_lm_loss, _, _ = get_masked_lm_output(config, sequence_output, bert_model.embeddings.word_embeddings.weight,
                                            masked_lm_positions, masked_lm_ids, masked_lm_weights)
masked_lm_loss.item()

34.76128387451172

## 3 是否下一个句子损失

In [8]:
def get_next_sentence_output(bert_config, input_tensor, labels):
    output_weights = nn.Parameter(torch.randn(2, bert_config.hidden_size))
    output_bias = nn.Parameter(torch.zeros(2))
    logits = torch.matmul(input_tensor, output_weights.transpose(0, 1)) + output_bias
    log_probs = nn.functional.log_softmax(logits, dim=-1)
    labels = labels.reshape(-1)
    one_hot_labels = torch.nn.functional.one_hot(labels, num_classes=2).float()
    per_example_loss = -torch.sum(one_hot_labels * log_probs, dim=-1)
    loss = torch.mean(per_example_loss)
    return loss, per_example_loss, log_probs
next_sentence_loss, _, _ = get_next_sentence_output(config, pooled_output, next_sentence_labels)

In [9]:
next_sentence_loss

tensor(2.9306, grad_fn=<MeanBackward0>)

# 4 损失及优化器

In [10]:
loss = masked_lm_loss + next_sentence_loss
loss

tensor(37.6919, grad_fn=<AddBackward0>)

In [11]:
optimizer = optim.Adam(bert_model.parameters(), lr=5e-4)
loss.backward()
optimizer.step()

# 5 训练

In [12]:
EPOCHS=5
i=0
for epoch in range(EPOCHS):
    for data in train_loader:
        input_ids=data["input_ids"]
        token_type_ids=data["segment_ids"]
        attention_mask=data["input_mask"]
        masked_lm_positions=data["masked_lm_positions"]
        masked_lm_ids=data["masked_lm_ids"]
        masked_lm_weights=data["masked_lm_weights"]
        next_sentence_labels=data["next_sentence_labels"]
        optimizer.zero_grad()
        pooled_output, sequence_output, encoded_layers=bert_model(input_ids,token_type_ids,attention_mask)
        masked_lm_loss, _, _ = get_masked_lm_output(config, sequence_output, bert_model.embeddings.word_embeddings.weight,
                                            masked_lm_positions, masked_lm_ids, masked_lm_weights)
        next_sentence_loss, _, _ = get_next_sentence_output(config, pooled_output, next_sentence_labels)
        loss = masked_lm_loss + next_sentence_loss
        loss.backward()
        optimizer.step()
        print(f"embedding weight step {i}",bert_model.embeddings.word_embeddings.weight[:3,:5])
        if i>=3:
            break
        i+=1

embedding weight step 0 tensor([[-1.1905, -0.4920,  0.4736,  0.7318, -1.4121],
        [-2.1571,  1.3470,  0.6932, -0.5502, -0.8680],
        [ 0.7587, -1.1549, -0.5255, -0.2983,  0.1872]],
       grad_fn=<SliceBackward0>)
embedding weight step 1 tensor([[-1.1905, -0.4920,  0.4736,  0.7318, -1.4121],
        [-2.1571,  1.3470,  0.6932, -0.5502, -0.8680],
        [ 0.7587, -1.1549, -0.5255, -0.2983,  0.1872]],
       grad_fn=<SliceBackward0>)
embedding weight step 2 tensor([[-1.1905, -0.4920,  0.4736,  0.7318, -1.4121],
        [-2.1571,  1.3470,  0.6932, -0.5502, -0.8680],
        [ 0.7587, -1.1549, -0.5255, -0.2983,  0.1872]],
       grad_fn=<SliceBackward0>)
embedding weight step 3 tensor([[-1.1905, -0.4920,  0.4736,  0.7318, -1.4121],
        [-2.1571,  1.3470,  0.6932, -0.5502, -0.8680],
        [ 0.7587, -1.1549, -0.5255, -0.2983,  0.1872]],
       grad_fn=<SliceBackward0>)
embedding weight step 3 tensor([[-1.1905, -0.4920,  0.4736,  0.7318, -1.4121],
        [-2.1571,  1.3470,  0

# 6 模型保存与加载

In [14]:
torch.save(bert_model.state_dict(), 'bert_model.pth')

# 模型加载
tokenizer = BertTokenizer.from_pretrained('/Users/wangaijun/pythoncode/github/model/bert-base-chinese')
vocab_words = list(tokenizer.vocab.keys())
config=BertConfig(len(vocab_words))
model = BertModel(config)
model.load_state_dict(torch.load('bert_model.pth'))
model.eval()  # 切换到评估模式

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(21128, 144)
    (position_embeddings): Embedding(512, 144)
    (token_type_embeddings): Embedding(2, 144)
    (LayerNorm): LayerNorm((144,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-2): 3 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=144, out_features=144, bias=True)
            (key): Linear(in_features=144, out_features=144, bias=True)
            (value): Linear(in_features=144, out_features=144, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=144, out_features=144, bias=True)
            (LayerNorm): LayerNorm((144,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
       

In [16]:
pooled_output, sequence_output, encoded_layers=model(input_ids,token_type_ids,attention_mask)
pooled_output.shape

torch.Size([32, 144])