# [Bert-base-Chinese 微调](https://blog.csdn.net/qq_43668800/article/details/131921617)

In [1]:
import torch
from torch.optim import AdamW
from datasets import load_dataset
from transformers import BertModel, BertTokenizer

## 模型

In [2]:
# Pick the compute device, favouring CUDA whenever PyTorch can see a GPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('device=', device)

device= cpu


In [3]:
# Load the pretrained Chinese BERT encoder and place it on the selected device.
pretrained = BertModel.from_pretrained('bert-base-chinese').to(device)

# The encoder is not trained here, so switch off gradient tracking on every
# one of its weights.
for weight in pretrained.parameters():
    weight.requires_grad = False

In [4]:
# Downstream task model: a linear classification head on top of the frozen encoder.
class Model(torch.nn.Module):
    """Two-class classifier that feeds BERT's first-token embedding to a linear layer.

    The encoder is the module-level ``pretrained`` object; it is not registered
    as a submodule, so ``self.fc`` is the only parameterised layer owned by
    this module (and the only content of its ``state_dict``).
    """

    def __init__(self):
        super().__init__()
        # 768 input features (the encoder's hidden vector) -> 2 output classes.
        self.fc = torch.nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return raw class scores (logits), one row of 2 values per sequence.

        The scores are intentionally NOT passed through softmax:
        ``torch.nn.CrossEntropyLoss`` applies log-softmax itself, so
        normalising here as well would squash the gradients. ``argmax`` over
        dim 1 gives the same prediction for logits as for probabilities; call
        ``.softmax(dim=1)`` on the result when probabilities are needed.

        Args:
            input_ids: token id tensor forwarded to the encoder.
            attention_mask: padding mask tensor forwarded to the encoder.
            token_type_ids: segment id tensor forwarded to the encoder.
        """
        # The encoder is only used for feature extraction; skip building its graph.
        with torch.no_grad():
            out = pretrained(input_ids=input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids)
        # [:, 0] keeps the hidden state of the first token of every sequence.
        return self.fc(out.last_hidden_state[:, 0])


# Instantiate the downstream classifier.
model = Model()
# Move it to the selected device as well
model.to(device)

Model(
  (fc): Linear(in_features=768, out_features=2, bias=True)
)

## 数据集

In [5]:
# Dataset wrapper around the ChnSentiCorp corpus
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that yields ``(text, label)`` pairs for one split."""

    def __init__(self, split):
        # The loaded corpus is indexed by split name; keep only the requested one.
        all_splits = load_dataset('lansinuote/ChnSentiCorp')
        self.dataset = all_splits[split]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        record = self.dataset[i]
        return record['text'], record['label']


# Training split, consumed by the training DataLoader.
dataset = Dataset('train')

In [6]:
# Load the vocabulary and tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

In [7]:
def collate_fn(data):
    """Collate a list of ``(text, label)`` samples into tensors on ``device``.

    Returns a 4-tuple ``(input_ids, attention_mask, token_type_ids, labels)``.
    """
    sents = [sample[0] for sample in data]
    labels = [sample[1] for sample in data]

    # Encode the whole batch; every sentence is truncated or padded to 500 tokens.
    encoded = tokenizer.batch_encode_plus(batch_text_or_text_pairs=sents,
                                          truncation=True,
                                          padding='max_length',
                                          max_length=500,
                                          return_tensors='pt',
                                          return_length=True)
    label_tensor = torch.LongTensor(labels).to(device)

    # input_ids: the encoded token ids
    # attention_mask: 0 at padded positions, 1 everywhere else
    return (encoded['input_ids'].to(device),
            encoded['attention_mask'].to(device),
            encoded['token_type_ids'].to(device),
            label_tensor)


# Training data loader: shuffled batches of 16, incomplete final batch dropped.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn,
    drop_last=True,
)

# Pull a single batch so the unpacked tensors are available for inspection.
for i, (input_ids,
        attention_mask,
        token_type_ids,
        labels) in enumerate(loader):
    break

# print(len(loader))
# print(input_ids.shape, attention_mask.shape, token_type_ids.shape, labels)

NameError: name 'token' is not defined

## 训练

In [None]:
import time
import datetime

start_time = datetime.datetime.now()
print('Start time:', start_time)


# Training: only the parameters of `model` are handed to the optimizer.
optimizer = AdamW(model.parameters(), lr=5e-4)
# NOTE: CrossEntropyLoss applies log-softmax internally and expects raw logits.
criterion = torch.nn.CrossEntropyLoss()
model.train()
for i, (input_ids, attention_mask, token_type_ids,
        labels) in enumerate(loader):
    out = model(input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids)
    loss = criterion(out, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Report loss and batch accuracy every 5 steps.
    if i % 5 == 0:
        out = out.argmax(dim=1)
        accuracy = (out == labels).sum().item() / len(labels)
        print(i, loss.item(), accuracy)
    # Stop once step index 300 has been processed (at most 301 batches).
    if i == 300:
        break


end_time = datetime.datetime.now()
print('End time:', end_time)

consume_time = end_time - start_time
# timedelta.seconds is only the seconds component (0..86399) and silently
# drops whole days; total_seconds() is the full elapsed duration.
print('Consume time of second:', int(consume_time.total_seconds()))

## 测试

In [None]:
def test():
    """Evaluate ``model`` on the validation split and print its accuracy.

    Prints a running accuracy every 5 batches and the overall accuracy once
    all batches have been processed. Returns ``None``.
    """
    model.eval()
    correct = 0
    total = 0
    # Evaluation must cover every sample exactly once: no shuffling is needed,
    # and dropping the last partial batch would leave samples out of the
    # reported accuracy.
    loader_test = torch.utils.data.DataLoader(dataset=Dataset('validation'),
                                              batch_size=32,
                                              collate_fn=collate_fn,
                                              shuffle=False,
                                              drop_last=False)
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
        if i % 5 == 0 and total != 0:
            print(i, correct / total)
        with torch.no_grad():
            out = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids)
        out = out.argmax(dim=1)
        correct += (out == labels).sum().item()
        total += len(labels)
    # Guard against an empty split instead of raising ZeroDivisionError.
    if total == 0:
        print('Test finish: no samples evaluated')
        return
    print('Test finish:', correct / total)


# Run the evaluation defined above on the validation split.
test()

## 保存模型

In [None]:
import os

import torch

model_path = '/root/workspace/model/model_bert-base-chinese/emotion_classification.pt'
# torch.save does not create missing directories, so make sure the target
# folder exists before writing; otherwise the trained weights would be lost.
os.makedirs(os.path.dirname(model_path), exist_ok=True)
# Only the state_dict (the weights) is stored, not the module object itself.
torch.save(model.state_dict(), model_path)
print('模型已保存至:', model_path)

## 加载模型

In [None]:
my_model = Model()
# map_location remaps tensors that were saved on another device (e.g. a CUDA
# checkpoint opened on a CPU-only machine) onto the current `device`.
my_model.load_state_dict(torch.load(model_path, map_location=device))
# The frozen encoder was moved to `device`, so the classification head has to
# live there too or the forward pass mixes devices.
my_model.to(device)
my_model.eval()

In [None]:
def predict(the_model, text_arr):
    """Classify every text in ``text_arr`` and return the predicted class ids.

    Args:
        the_model: module called with ``input_ids``, ``attention_mask`` and
            ``token_type_ids`` keyword tensors that returns one row of class
            scores per text. It is switched to eval mode and moved to
            ``device`` in place.
        text_arr: sequence of sentences to classify.

    Returns:
        A list of ``int`` class indices, one per input text (empty list for
        empty input).
    """
    if len(text_arr) == 0:
        return []

    the_model.eval()
    # Model and inputs must share a device for the forward pass.
    the_model.to(device)

    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    x_encode = tokenizer.batch_encode_plus(
        # All sentences to encode (sentence pairs are accepted as well)
        batch_text_or_text_pairs=text_arr,
        # Truncate anything longer than max_length
        truncation=True,
        # Always pad shorter sentences up to max_length
        padding='max_length',
        add_special_tokens=True,
        max_length=30,
        # Accepts tf / pt / np (tensorflow, pytorch, numpy); a list by default
        return_tensors="pt",
        # token_type_ids: 0 for the first sentence and special tokens, 1 for the second
        return_token_type_ids=True,
        # attention_mask: 0 at padding, 1 elsewhere
        return_attention_mask=True,
    )

    # A bare `torch.no_grad()` call has no effect; it must be entered as a
    # context manager to actually disable gradient tracking.
    with torch.no_grad():
        y = the_model(input_ids=x_encode['input_ids'].to(device),
                      attention_mask=x_encode['attention_mask'].to(device),
                      token_type_ids=x_encode['token_type_ids'].to(device))
    # Tensor.tolist() works on any device, unlike .numpy(), which requires CPU.
    return y.argmax(dim=1).tolist()


# Sample reviews to run through the reloaded classifier.
texts = ['这家店不错，下次还来', '唉，有点失望，卫生如果能更好一些就不错了', '还行，跟网上描述的一致，有无线网，网速很快']
predict(my_model, texts)