In [1]:
from bert.bert_model import BertModel
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import math
from transformers import BertTokenizer
input_path="./data/bert_output_data2.json"
class MyDataset(Dataset):
    def __init__(self, file_path):
        self.data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                data = json.loads(line)
                self.data.append(data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return {
            "input_ids": torch.tensor(item["input_ids"]),
            "input_mask": torch.tensor(item["input_mask"]),
            "segment_ids": torch.tensor(item["segment_ids"]),
            "masked_lm_ids": torch.tensor(item["masked_lm_ids"]),
            "masked_lm_positions": torch.tensor(item["masked_lm_positions"]),
            "masked_lm_weights": torch.tensor(item["masked_lm_weights"]),
            "next_sentence_labels": torch.tensor(item["next_sentence_labels"])
        }

class BertConfig:
    def __init__(self, vocab_size, hidden_size=144, num_hidden_layers=3, num_attention_heads=12,
                 intermediate_size=512, hidden_act='gelu', hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1, max_position_embeddings=512,
                 type_vocab_size=2, initializer_range=0.02):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range



In [2]:

train_dataset = MyDataset(input_path)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
for data in train_loader:
    data
    break
print(data["input_ids"].shape)
print(data["input_mask"].shape)
input_ids=data["input_ids"]
token_type_ids=data["segment_ids"]
attention_mask=data["input_mask"]
masked_lm_positions=data["masked_lm_positions"]
masked_lm_ids=data["masked_lm_ids"]
masked_lm_weights=data["masked_lm_weights"]
next_sentence_labels=data["next_sentence_labels"]
data

torch.Size([32, 512])
torch.Size([32, 512])


{'input_ids': tensor([[  101,  4263,  2209,  ...,     0,     0,     0],
         [  101,  1462,   103,  ...,  1469,   821,   102],
         [  101,  7216,   677,  ...,     0,     0,     0],
         ...,
         [  101,   782,  1920,  ...,  4852,   833,   102],
         [  101, 19688,  6371,  ...,   100,   100,   102],
         [  101,  4638,  3175,  ...,  2821,   103,   102]]),
 'input_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1]]),
 'segment_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1]]),
 'masked_lm_ids': tensor([[8205,  782,  807, 6848, 1999,  868, 4638, 8024,  517,  511, 1068,  704,
          1059, 6375, 4638, 2798,  83

In [3]:
def gather_indexes(sequence_tensor, positions):
    sequence_shape = list(sequence_tensor.shape)
    batch_size = sequence_shape[0]
    seq_length = sequence_shape[1]
    width = sequence_shape[2]

    flat_offsets = torch.arange(0, batch_size, dtype=torch.int64).reshape(-1, 1) * seq_length
    flat_positions = (positions + flat_offsets).reshape(-1)
    flat_sequence_tensor = sequence_tensor.reshape(batch_size * seq_length, width)
    output_tensor = flat_sequence_tensor[flat_positions]
    return output_tensor

In [4]:
input_=torch.rand(2,6,10)
positions=torch.tensor([0,3,5])#torch.arange(10)
print(input_)
gather_indexes(input_,positions)

tensor([[[0.4812, 0.7262, 0.2435, 0.5518, 0.9686, 0.1407, 0.0784, 0.4719,
          0.6008, 0.7062],
         [0.4333, 0.2983, 0.3437, 0.6014, 0.0281, 0.1652, 0.1080, 0.0771,
          0.3824, 0.0508],
         [0.9324, 0.5528, 0.4446, 0.5430, 0.5504, 0.4381, 0.9909, 0.2223,
          0.0389, 0.1210],
         [0.6169, 0.9299, 0.5685, 0.5774, 0.2957, 0.0931, 0.8956, 0.4808,
          0.1888, 0.9418],
         [0.0100, 0.8559, 0.0202, 0.4111, 0.6520, 0.4036, 0.6835, 0.6983,
          0.5808, 0.2298],
         [0.0159, 0.4673, 0.8373, 0.5593, 0.6574, 0.9926, 0.4005, 0.3090,
          0.4683, 0.7394]],

        [[0.6759, 0.8836, 0.5332, 0.9862, 0.3775, 0.1172, 0.7190, 0.9611,
          0.8943, 0.3026],
         [0.2646, 0.3314, 0.2111, 0.9224, 0.3554, 0.5549, 0.1032, 0.0603,
          0.7090, 0.5561],
         [0.1240, 0.1333, 0.3225, 0.9056, 0.3678, 0.5768, 0.1364, 0.9830,
          0.1813, 0.1304],
         [0.0517, 0.8949, 0.9011, 0.5888, 0.8939, 0.4046, 0.2699, 0.6114,
          0.463

tensor([[0.4812, 0.7262, 0.2435, 0.5518, 0.9686, 0.1407, 0.0784, 0.4719, 0.6008,
         0.7062],
        [0.6169, 0.9299, 0.5685, 0.5774, 0.2957, 0.0931, 0.8956, 0.4808, 0.1888,
         0.9418],
        [0.0159, 0.4673, 0.8373, 0.5593, 0.6574, 0.9926, 0.4005, 0.3090, 0.4683,
         0.7394],
        [0.6759, 0.8836, 0.5332, 0.9862, 0.3775, 0.1172, 0.7190, 0.9611, 0.8943,
         0.3026],
        [0.0517, 0.8949, 0.9011, 0.5888, 0.8939, 0.4046, 0.2699, 0.6114, 0.4632,
         0.8878],
        [0.6340, 0.9508, 0.8339, 0.5182, 0.5434, 0.1386, 0.5247, 0.9030, 0.6434,
         0.1408]])

In [6]:
tokenizer = BertTokenizer.from_pretrained('/Users/wangaijun/pythoncode/github/model/bert-base-chinese')
vocab_words = list(tokenizer.vocab.keys())
config=BertConfig(len(vocab_words))
config.vocab_size
bert_model = BertModel(config)

pooled_output, sequence_output, encoded_layers=bert_model(input_ids,token_type_ids,attention_mask)
print(pooled_output.shape,sequence_output.shape,len(encoded_layers))
print(bert_model)

torch.Size([32, 144]) torch.Size([32, 512, 144]) 3
BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(21128, 144)
    (position_embeddings): Embedding(512, 144)
    (token_type_embeddings): Embedding(2, 144)
    (LayerNorm): LayerNorm((144,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-2): 3 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=144, out_features=144, bias=True)
            (key): Linear(in_features=144, out_features=144, bias=True)
            (value): Linear(in_features=144, out_features=144, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=144, out_features=144, bias=True)
            (LayerNorm): LayerNorm((144,), eps=1e-12, elementwise_affine=True)
            (dropout)

# 2 mask预测损失

In [7]:
def get_masked_lm_output(bert_config, input_tensor, output_weights, positions, label_ids, label_weights):
    input_tensor = gather_indexes(input_tensor, positions)
    sequential = nn.Sequential(
        nn.Linear(bert_config.hidden_size, bert_config.hidden_size),
        nn.LayerNorm(bert_config.hidden_size),
        nn.ReLU()
    )
    input_tensor = sequential(input_tensor)  # 使用sequential处理input_tensor
    output_bias = nn.Parameter(torch.zeros(bert_config.vocab_size))
    logits = torch.matmul(input_tensor, output_weights.transpose(0, 1)) + output_bias
    log_probs = nn.functional.log_softmax(logits, dim=-1)

    label_ids = label_ids.reshape(-1)
    label_weights = label_weights.reshape(-1)
    one_hot_labels = torch.nn.functional.one_hot(label_ids, num_classes=bert_config.vocab_size).float()
    per_example_loss = -torch.sum(log_probs * one_hot_labels, dim=-1)
    numerator = torch.sum(label_weights * per_example_loss)
    denominator = torch.sum(label_weights) + 1e-5
    loss = numerator / denominator

    return loss, per_example_loss, log_probs

In [9]:
masked_lm_loss, _, _ = get_masked_lm_output(config, sequence_output, bert_model.embeddings.word_embeddings.weight,
                                            masked_lm_positions, masked_lm_ids, masked_lm_weights)
masked_lm_loss.item()

33.317161560058594

In [12]:
def get_next_sentence_output(bert_config, input_tensor, labels):
    output_weights = nn.Parameter(torch.randn(2, bert_config.hidden_size))
    output_bias = nn.Parameter(torch.zeros(2))
    logits = torch.matmul(input_tensor, output_weights.transpose(0, 1)) + output_bias
    log_probs = nn.functional.log_softmax(logits, dim=-1)
    labels = labels.reshape(-1)
    one_hot_labels = torch.nn.functional.one_hot(labels, num_classes=2).float()
    per_example_loss = -torch.sum(one_hot_labels * log_probs, dim=-1)
    loss = torch.mean(per_example_loss)
    return loss, per_example_loss, log_probs
next_sentence_loss, _, _ = get_next_sentence_output(config, pooled_output, next_sentence_labels)

In [13]:
next_sentence_loss

tensor(4.5587, grad_fn=<MeanBackward0>)

In [14]:
loss = masked_lm_loss + next_sentence_loss
loss

tensor(37.8759, grad_fn=<AddBackward0>)

In [15]:
optimizer = optim.Adam(bert_model.parameters(), lr=5e-4)
loss.backward()
optimizer.step()

# 5 训练

In [16]:
EPOCHS=5
i=0
for epoch in range(EPOCHS):
    for data in train_loader:
        input_ids=data["input_ids"]
        token_type_ids=data["segment_ids"]
        attention_mask=data["input_mask"]
        masked_lm_positions=data["masked_lm_positions"]
        masked_lm_ids=data["masked_lm_ids"]
        masked_lm_weights=data["masked_lm_weights"]
        next_sentence_labels=data["next_sentence_labels"]
        optimizer.zero_grad()
        pooled_output, sequence_output, encoded_layers=bert_model(input_ids,token_type_ids,attention_mask)
        masked_lm_loss, _, _ = get_masked_lm_output(config, sequence_output, bert_model.embeddings.word_embeddings.weight,
                                            masked_lm_positions, masked_lm_ids, masked_lm_weights)
        next_sentence_loss, _, _ = get_next_sentence_output(config, pooled_output, next_sentence_labels)
        loss = masked_lm_loss + next_sentence_loss
        loss.backward()
        optimizer.step()
        print(f"embedding weight step {i}",bert_model.embeddings.word_embeddings.weight[:3,:5])
        if i>=3:
            break
        i+=1

embedding weight step 0 tensor([[-0.3895, -2.1538,  0.7998,  2.3261, -0.5966],
        [ 0.8771,  0.7261,  1.3902, -0.0718,  0.8855],
        [-1.0391, -0.7370, -1.7500, -0.3334, -0.2318]],
       grad_fn=<SliceBackward0>)
embedding weight step 1 tensor([[-0.3895, -2.1539,  0.7998,  2.3259, -0.5966],
        [ 0.8770,  0.7258,  1.3899, -0.0720,  0.8852],
        [-1.0391, -0.7372, -1.7500, -0.3335, -0.2318]],
       grad_fn=<SliceBackward0>)
embedding weight step 2 tensor([[-0.3895, -2.1539,  0.7998,  2.3258, -0.5966],
        [ 0.8770,  0.7256,  1.3897, -0.0721,  0.8850],
        [-1.0391, -0.7373, -1.7501, -0.3335, -0.2319]],
       grad_fn=<SliceBackward0>)
embedding weight step 3 tensor([[-0.3896, -2.1539,  0.7996,  2.3256, -0.5969],
        [ 0.8770,  0.7254,  1.3895, -0.0723,  0.8849],
        [-1.0391, -0.7376, -1.7501, -0.3336, -0.2322]],
       grad_fn=<SliceBackward0>)
embedding weight step 3 tensor([[-0.3897, -2.1540,  0.7994,  2.3254, -0.5971],
        [ 0.8770,  0.7252,  1

In [17]:
torch.save(bert_model.state_dict(), 'bert_model.pth')

In [18]:
# 模型加载
tokenizer = BertTokenizer.from_pretrained('/Users/wangaijun/pythoncode/github/model/bert-base-chinese')
vocab_words = list(tokenizer.vocab.keys())
config=BertConfig(len(vocab_words))
model = BertModel(config)

model.load_state_dict(torch.load('bert_model.pth'))
model.eval()  # 切换到评估模式

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(21128, 144)
    (position_embeddings): Embedding(512, 144)
    (token_type_embeddings): Embedding(2, 144)
    (LayerNorm): LayerNorm((144,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-2): 3 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=144, out_features=144, bias=True)
            (key): Linear(in_features=144, out_features=144, bias=True)
            (value): Linear(in_features=144, out_features=144, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=144, out_features=144, bias=True)
            (LayerNorm): LayerNorm((144,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
       

In [19]:
pooled_output, sequence_output, encoded_layers=model(input_ids,token_type_ids,attention_mask)

In [21]:
pooled_output.shape

torch.Size([32, 144])