In [None]:
import torch
from datasets import load_dataset, load_from_disk
from transformers import BertTokenizer, BertModel
from transformers import AdamW

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 1. Dataset definition
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(text, label)`` pairs from ChnSentiCorp."""

    def __init__(self, split):
        """Load one split of the ``seamew/ChnSentiCorp`` dataset.

        Args:
            split: Split name, passed straight through to ``load_dataset``.
        """
        # Offline alternative: load_from_disk('data/ChnSentiCorp')[split]
        self.dataset = load_dataset(path='seamew/ChnSentiCorp', split=split)

    def __len__(self):
        """Return the number of rows in the loaded split."""
        return len(self.dataset)

    def __getitem__(self, i):
        """Return the ``(text, label)`` tuple stored at row ``i``."""
        row = self.dataset[i]
        return row['text'], row['label']


# Training split, instantiated when the module is executed.
dataset = Dataset('train')
# print(len(dataset), dataset[0])


# 2. Load the vocabulary and tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
# print(tokenizer)


# 3. Batch collation function
def collate_fn(data):
    """Tokenize a batch of ``(text, label)`` samples into model-ready tensors.

    Args:
        data: Sequence of ``(text, label)`` pairs, one per sample in the batch.

    Returns:
        A tuple ``(input_ids, attention_mask, token_type_ids, labels)``.
        The first three are the tokenizer's PyTorch tensors, truncated or
        padded so every row is exactly 500 tokens long; ``labels`` is an
        int64 tensor with one entry per sample.
    """
    sents = [sample[0] for sample in data]
    labels = [sample[1] for sample in data]

    # Batch encoding through the tokenizer's __call__ (the replacement for the
    # deprecated batch_encode_plus). The per-sample lengths are not requested
    # because nothing downstream reads them.
    encoded = tokenizer(sents,
                        truncation=True,
                        padding='max_length',
                        max_length=500,
                        return_tensors='pt')

    return (encoded['input_ids'],
            encoded['attention_mask'],
            encoded['token_type_ids'],
            torch.tensor(labels, dtype=torch.long))


# 4. Training data loader: shuffled batches of 16 samples, tokenized by
#    collate_fn; drop_last discards a trailing incomplete batch.
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=16,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

# for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
#     break



# 5. Load the pretrained language model
pre_model = BertModel.from_pretrained('bert-base-chinese').to(device)
# print(pre_model.parameters)

# The pretrained model is not trained, so none of its parameters need
# gradients. Module.requires_grad_ sets the flag on every parameter at once.
pre_model.requires_grad_(False)


# 6. Downstream task head
class Model(torch.nn.Module):
    """Two-class linear classifier on top of the frozen ``pre_model`` encoder.

    ``pre_model`` is read from module scope rather than registered as a
    submodule, so the only trainable parameters of this module are those of
    the linear layer ``fc``.
    """

    def __init__(self):
        super().__init__()
        # Projects a 768-dimensional hidden state onto 2 class scores.
        self.fc = torch.nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return raw (unnormalized) class logits of shape ``(batch, 2)``.

        No softmax is applied here: the training loss in this file is
        ``torch.nn.CrossEntropyLoss``, which expects unnormalized logits and
        applies log-softmax itself. Feeding it probabilities would apply
        softmax twice, bounding the loss and flattening the gradients.
        ``argmax`` over the logits gives the same predicted class as
        ``argmax`` over their softmax. Callers that need probabilities can
        apply ``.softmax(dim=1)`` to the result.

        Args:
            input_ids: Token id tensor from the tokenizer.
            attention_mask: Attention mask tensor from the tokenizer.
            token_type_ids: Segment id tensor from the tokenizer.
        """
        # The pretrained language model is not trained, so skip autograd here.
        with torch.no_grad():
            encoded = pre_model(input_ids=input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids)

        # last_hidden_state is (batch, seq_len, 768), e.g. [16, 500, 768].
        # Only the linear layer is trained; it reads the hidden state at
        # sequence position 0 (the [CLS] token for BERT tokenizers).
        return self.fc(encoded.last_hidden_state[:, 0])


# 7. Instantiate the full model: pretrained encoder + downstream head
model = Model().to(device)
# print("model out: ", model(input_ids=input_ids,
#         attention_mask=attention_mask,
#         token_type_ids=token_type_ids).shape)
# model out: torch.Size([16, 2])


# Only the parameters registered on Model (its linear layer) are optimized;
# pre_model is a module-level global, not a submodule of Model.
# NOTE(review): this AdamW is the one imported from transformers, which is
# deprecated upstream in favour of torch.optim.AdamW -- confirm against the
# installed transformers version before upgrading.
optimizer = AdamW(model.parameters(), lr=5e-4)
criterion = torch.nn.CrossEntropyLoss()


# 8. Training
def train(max_batches=100, log_every=5):
    """Train the classification head on the first batches of ``loader``.

    Uses the module-level ``model``, ``loader``, ``criterion``, ``optimizer``
    and ``device``. The previous implementation checked ``i == 100`` after the
    optimisation step and therefore trained on 101 batches, although it was
    documented as using the first 100; the limit is now checked before a
    batch is processed.

    Args:
        max_batches: Number of batches to train on. ``None`` iterates over
            the whole loader.
        log_every: Print loss and batch accuracy every ``log_every`` batches
            (batch indices 0, ``log_every``, ...). A falsy value disables
            logging.
    """
    model.train()
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
        # Stop before spending any device work on a batch past the limit.
        if max_batches is not None and i >= max_batches:
            break

        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)
        labels = labels.to(device)

        # Forward pass
        out = model(input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids)

        # Loss, backward pass, parameter update, then reset the gradients so
        # they do not accumulate into the next batch.
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if log_every and i % log_every == 0:
            preds = out.argmax(dim=1)
            accuracy = (preds == labels).sum().item() / len(labels)

            print(f'第{i}批数据:', loss.item(), accuracy)
    # NOTE: the trained model is not persisted here.


# 9. Evaluation
def test(max_batches=5):
    """Evaluate ``model`` on a few validation batches and return the accuracy.

    Builds a shuffled loader over the ``'validation'`` split (batch size 32,
    incomplete last batch dropped) and prints the running accuracy after each
    batch. Because the loader shuffles, the sampled batches differ between
    runs.

    Args:
        max_batches: Number of batches to evaluate. ``None`` iterates over
            the whole loader.

    Returns:
        The accuracy over all evaluated samples as a float, or ``None`` when
        no batch was evaluated.
    """
    model.eval()
    correct = 0
    total = 0

    # Validation data loader
    loader_test = torch.utils.data.DataLoader(dataset=Dataset('validation'),
                                              batch_size=32,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)

    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
        # Only the first max_batches batches are evaluated; check the limit
        # before moving anything to the device.
        if max_batches is not None and i >= max_batches:
            break
        print(f'第{i}批数据:')

        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            out = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids)

        # Predicted class = index of the larger of the two class scores.
        preds = out.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += len(labels)

        print('准确率：', correct / total)

    return correct / total if total else None


# Script entry point: train the classification head, then evaluate it.
if __name__ == "__main__":
    train()
    test()
    print('end...')
