In [3]:
import torch
import random
from datasets import load_dataset
from transformers import BertTokenizer, BertModel
from transformers import AdamW


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 1.定义数据集
class Dataset(torch.utils.data.Dataset):
    def __init__(self, split):
        dataset = load_dataset(path='seamew/ChnSentiCorp', split=split)

        def f(data):
            return len(data['text']) > 40

        self.dataset = dataset.filter(f)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        text = self.dataset[i]['text']

        # 切分一句话为前半句和后半句
        sentence1 = text[:20]
        sentence2 = text[20:40]
        label = 0  # 有关为0，无关为1

        #有一半的概率把后半句替换为一句无关的话
        if random.randint(0, 1) == 0:
            j = random.randint(0, len(self.dataset) - 1)
            sentence2 = self.dataset[j]['text'][20:40]
            label = 1

        return sentence1, sentence2, label


dataset = Dataset('train')
# sentence1, sentence2, label = dataset[0]
# len(dataset), sentence1, sentence2, label

# 2. 加载字典和分词工具
token = BertTokenizer.from_pretrained('bert-base-chinese')

# 3. 定义批处理函数
def collate_fn(data):
    sents = [i[:2] for i in data]
    labels = [i[2] for i in data]

    # 批处理增强编码函数
    data = token.batch_encode_plus(batch_text_or_text_pairs=sents,
                                   truncation=True,
                                   padding='max_length',
                                   max_length=45,
                                   return_tensors='pt',
                                   return_length=True,
                                   add_special_tokens=True)

    input_ids = data['input_ids']
    attention_mask = data['attention_mask']
    token_type_ids = data['token_type_ids']
    labels = torch.LongTensor(labels)

    #print(data['length'], data['length'].max())

    return input_ids, attention_mask, token_type_ids, labels


# 4.数据加载器
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=8,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

# for i, (input_ids, attention_mask, token_type_ids,
#         labels) in enumerate(loader):
#     break


# 5.加载预训练模型
pretrained = BertModel.from_pretrained('bert-base-chinese')

# 不训练,不需要计算梯度
for param in pretrained.parameters():
    param.requires_grad_(False)

# 6.定义下游任务模型
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        with torch.no_grad():
            out = pretrained(input_ids=input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids)

        out = self.fc(out.last_hidden_state[:, 0])
        out = out.softmax(dim=1)

        return out

# 7. 加载整个模型：预训练+下游任务
model = Model().to(device)

optimizer = AdamW(model.parameters(), lr=5e-4)
criterion = torch.nn.CrossEntropyLoss()

# 训练
def train():
    model.train()
    for i, (input_ids, attention_mask, token_type_ids,
            labels) in enumerate(loader):

        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)
        labels = labels.to(device)

        # 正向传播
        out = model(input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids)

        # 计算损失
        loss = criterion(out, labels)
        # 反向传播，计算梯度
        loss.backward()
        # 参数更新
        optimizer.step()
        # 梯度归零
        optimizer.zero_grad()

        if i % 5 == 0:
            out = out.argmax(dim=1)
            accuracy = (out == labels).sum().item() / len(labels)
            print(i, loss.item(), accuracy)

        if i == 300:
            break
    torch.save(model.state_dict(), 'model/infer.pt')

# 测试
def test():
    test_model = Model()  # 定义模型
    test_model.load_state_dict(torch.load('model/infer.pt'))
    test_model.to(device)

    model.eval()
    correct = 0
    total = 0

    loader_test = torch.utils.data.DataLoader(dataset=Dataset('test'),
                                              batch_size=32,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)

    for i, (input_ids, attention_mask, token_type_ids,
            labels) in enumerate(loader_test):

        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)
        labels = labels.to(device)

        if i == 10:
            break

        with torch.no_grad():
            out = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids)

        pred = out.argmax(dim=1)

        correct += (pred == labels).sum().item()
        total += len(labels)

        print(f"第{i}批数据:", correct / total)

if __name__ ==  "__main__":
    print("start...")
    train()
    test()
    print('end...')
    

Found cached dataset chn_senti_corp (/home/codespace/.cache/huggingface/datasets/seamew___chn_senti_corp/default/0.0.0/1f242195a37831906957a11a2985a4329167e60657c07dc95ebe266c03fdfb85)
Loading cached processed dataset at /home/codespace/.cache/huggingface/datasets/seamew___chn_senti_corp/default/0.0.0/1f242195a37831906957a11a2985a4329167e60657c07dc95ebe266c03fdfb85/cache-d23ef8c490aad355.arrow
Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPre

start...
0 0.7801969051361084 0.25
5 0.7143212556838989 0.375
10 0.4878574013710022 1.0
15 0.5834335684776306 0.75
20 0.5357387065887451 0.75
25 0.5131847858428955 0.75
30 0.4711150527000427 0.875
35 0.4557431936264038 0.875
40 0.512300968170166 0.75
45 0.5064978003501892 0.625
50 0.396485298871994 0.875
55 0.5364752411842346 0.75
60 0.5696155428886414 0.75
65 0.45890069007873535 0.875
70 0.4444369673728943 0.875
75 0.6914003491401672 0.5
80 0.4821853041648865 0.875
85 0.4341668486595154 0.75
90 0.3415139317512512 1.0
95 0.3460024893283844 1.0
100 0.3879810571670532 0.875
105 0.37142595648765564 1.0
110 0.34169188141822815 1.0
115 0.4674598276615143 0.875
120 0.4514293968677521 0.875
125 0.37706851959228516 1.0
130 0.3896777331829071 1.0
135 0.4709225296974182 0.75
140 0.43113186955451965 0.875
145 0.3441063165664673 1.0
150 0.36928948760032654 1.0
155 0.4450352191925049 0.875
160 0.3400208353996277 1.0
165 0.45745861530303955 0.75
170 0.3964010775089264 0.875
175 0.484200119972229 0.8

Found cached dataset chn_senti_corp (/home/codespace/.cache/huggingface/datasets/seamew___chn_senti_corp/default/0.0.0/1f242195a37831906957a11a2985a4329167e60657c07dc95ebe266c03fdfb85)


  0%|          | 0/2 [00:00<?, ?ba/s]

0
第0批数据: 0.84375
1
第1批数据: 0.859375
2
第2批数据: 0.875
3
第3批数据: 0.875
4
第4批数据: 0.88125
end...
