# In [None]:
import torch
from datasets import load_dataset,load_from_disk
from transformers import BertTokenizer, BertModel
from transformers import AdamW

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# 1. Define the dataset
class Dataset(torch.utils.data.Dataset):
    """Text-only view of one split of the ``seamew/ChnSentiCorp`` dataset.

    Only records whose ``text`` field has a length greater than 30 are kept.
    """

    def __init__(self, split):
        """Load the requested split and drop the short records.

        Args:
            split: Split name forwarded to ``load_dataset``.
        """
        raw = load_dataset(path='seamew/ChnSentiCorp', split=split)
        # Offline alternative: load_from_disk('data/ChnSentiCorp')[split]
        self.dataset = raw.filter(lambda record: len(record['text']) > 30)

    def __len__(self):
        """Return the number of records left after filtering."""
        return len(self.dataset)

    def __getitem__(self, i):
        """Return the ``text`` field of record ``i``; other fields are ignored."""
        record = self.dataset[i]
        return record['text']

# Training split, built when the module is executed.
dataset = Dataset('train')
# len(dataset), dataset[0]

# 2. Load the vocabulary and the tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

# 3. Define the batch collation function
def collate_fn(data):
    """Tokenize a batch of texts and mask the token at position 15.

    Args:
        data: Batch of raw text strings, as produced by the dataset.

    Returns:
        A tuple ``(input_ids, attention_mask, token_type_ids, labels)``.
        The first three are the tokenizer's ``'pt'`` tensors, padded or
        truncated to 30 tokens, with column 15 of ``input_ids`` overwritten
        by the mask token id. ``labels`` holds the original ids of that
        column, one per sample.
    """
    mask_position = 15
    max_length = 30

    # Calling the tokenizer directly is the documented replacement for the
    # deprecated ``batch_encode_plus``. The unused ``length`` output is no
    # longer requested.
    encoded = tokenizer(data,
                        truncation=True,
                        padding='max_length',
                        max_length=max_length,
                        return_tensors='pt')

    input_ids = encoded['input_ids']
    attention_mask = encoded['attention_mask']
    token_type_ids = encoded['token_type_ids']

    # Save the true ids as targets before they are overwritten.
    labels = input_ids[:, mask_position].clone()
    # ``mask_token_id`` avoids rebuilding the whole vocabulary dict through
    # ``get_vocab()`` on every batch just to look up a single id.
    input_ids[:, mask_position] = tokenizer.mask_token_id

    return input_ids, attention_mask, token_type_ids, labels


# 4. Data loader
# Shuffled batches of 16 training texts, encoded by collate_fn; the final
# incomplete batch is dropped.
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=16,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

# for i, (input_ids, attention_mask, token_type_ids,
#         labels) in enumerate(loader):
#     break

# 5. Load the pretrained model
# `pretrained` is a module-level object, not a submodule of the downstream
# model, so moving that model to `device` does not move the encoder. Place it
# on `device` here; otherwise batches moved to a GPU would be fed to an
# encoder that is still on the CPU.
pretrained = BertModel.from_pretrained('bert-base-chinese').to(device)

# The encoder is not trained, so its parameters need no gradients.
for param in pretrained.parameters():
    param.requires_grad_(False)

# 6. Define the downstream task model
class Model(torch.nn.Module):
    """Linear head that scores every vocabulary entry for the token at index 15."""

    def __init__(self):
        """Create the projection layer and attach its separately registered bias."""
        super().__init__()
        vocab_size = tokenizer.vocab_size
        # The projection is created without a bias; a parameter registered on
        # this module is then assigned as its bias, so the same tensor is
        # reachable as both ``self.bias`` and ``self.decoder.bias``.
        self.decoder = torch.nn.Linear(768, vocab_size, bias=False)
        self.bias = torch.nn.Parameter(torch.zeros(vocab_size))
        self.decoder.bias = self.bias

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return vocabulary logits for sequence position 15 of each sample."""
        # The pretrained encoder is not trained: run it without autograd.
        with torch.no_grad():
            encoded = pretrained(input_ids=input_ids,
                                 attention_mask=attention_mask,
                                 token_type_ids=token_type_ids)

        hidden_at_mask = encoded.last_hidden_state[:, 15]
        return self.decoder(hidden_at_mask)

# 7. Build the downstream model (it calls the pretrained encoder in its
# forward pass) and move it to `device`.
model = Model().to(device)

# The AdamW class shipped with transformers is deprecated in favour of
# torch.optim.AdamW. eps and weight_decay are passed explicitly to mirror the
# deprecated class's defaults (1e-6 and 0.0), because PyTorch's defaults differ.
# NOTE(review): the two implementations apply eps slightly differently, so
# results are expected to be close rather than bit-identical.
optimizer = torch.optim.AdamW(model.parameters(),
                              lr=5e-4,
                              eps=1e-6,
                              weight_decay=0.0)
criterion = torch.nn.CrossEntropyLoss()


# Training
def train():
    """Train the downstream model for two epochs and save its weights.

    Uses the module-level ``model``, ``loader``, ``optimizer``, ``criterion``
    and ``device``. Every 50 batches it prints the epoch, batch index, loss
    and accuracy of the current batch. After training, the model's state
    dict is written to ``model/model.pt``.
    """
    import os  # local import: only needed to prepare the checkpoint directory

    save_path = 'model/model.pt'
    # Create the target directory before training, so a missing or
    # unwritable directory fails here rather than after both epochs.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    model.train()
    for epoch in range(2):
        for i, (input_ids, attention_mask, token_type_ids,
                labels) in enumerate(loader):

            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            token_type_ids = token_type_ids.to(device)
            labels = labels.to(device)

            # Forward pass
            out = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids)

            # Compute the loss
            loss = criterion(out, labels)
            # Backward pass: compute gradients
            loss.backward()
            # Update the parameters
            optimizer.step()
            # Reset the gradients for the next batch
            optimizer.zero_grad()

            if i % 50 == 0:
                predictions = out.argmax(dim=1)
                accuracy = (predictions == labels).sum().item() / len(labels)

                print(epoch, i, loss.item(), accuracy)

    torch.save(model.state_dict(), save_path)


# Evaluation
def test():
    """Evaluate the checkpoint saved at ``model/model.pt`` on the test split.

    Builds a fresh ``Model``, loads the saved state dict into it and scores
    at most the first 15 shuffled batches of 32 samples. The running
    accuracy is printed after each batch.
    """
    test_model = Model()  # fresh model instance
    # map_location lets the checkpoint load on `device` whichever device it
    # was saved from.
    test_model.load_state_dict(
        torch.load('model/model.pt', map_location=device))  # load the weights
    test_model.to(device)

    test_model.eval()
    correct = 0
    total = 0

    loader_test = torch.utils.data.DataLoader(dataset=Dataset('test'),
                                              batch_size=32,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)

    for i, (input_ids, attention_mask, token_type_ids,
            labels) in enumerate(loader_test):

        # Stop after 15 batches, before moving unused tensors to the device.
        if i == 15:
            break

        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            # Run the model loaded from the checkpoint, not the module-level
            # training model.
            out = test_model(input_ids=input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids)

        predictions = out.argmax(dim=1)
        correct += (predictions == labels).sum().item()
        total += len(labels)

        # print(tokenizer.decode(input_ids[0]))
        # print(tokenizer.decode(labels[0]), tokenizer.decode(labels[0]))

        print(f"第{i}批数据:", correct / total)


# Script entry point: run training, then the test routine.
if __name__ == "__main__":
    print('start...')
    train()
    test()
    print("end...")