In [1]:
import torch
from datasets import load_dataset

### 1 定义数据集

In [2]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self,split):
        self.dataset=load_dataset(path='lansinuote/ChnSentiCorp', split=split)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self,i):
        text=self.dataset[i]['text']
        label=self.dataset[i]['label']

        return text,label

dataset=Dataset('train')
len(dataset),dataset[0]

dataset_infos.json:   0%|          | 0.00/960 [00:00<?, ?B/s]

(…)-00000-of-00001-02f200ca5f2a7868.parquet:   0%|          | 0.00/2.16M [00:00<?, ?B/s]

(…)-00000-of-00001-405befbaa3bcf1a2.parquet:   0%|          | 0.00/276k [00:00<?, ?B/s]

(…)-00000-of-00001-5372924f059fe767.parquet:   0%|          | 0.00/275k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/9600 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1200 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1200 [00:00<?, ? examples/s]

(9600,
 ('选择珠江花园的原因就是方便，有电动扶梯直接到达海边，周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般，但还算整洁。 泳池在大堂的屋顶，因此很小，不过女儿倒是喜欢。 包的早餐是西式的，还算丰富。 服务吗，一般',
  1))

### 2 加载字典和分词工具

In [3]:
from transformers import BertTokenizer

token = BertTokenizer.from_pretrained('bert-base-chinese')

token

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/110k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/269k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/624 [00:00<?, ?B/s]

BertTokenizer(name_or_path='bert-base-chinese', vocab_size=21128, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

### 3 数据加载器

In [None]:
def collate_fn(data):
    sents=[i[0] for i in data]
    labels=[i[1] for i in data]

    #编码
    data=token.batch_encode_plus(batch_text_or_text_pairs=sents,
                                truncation=True,
                                padding='max_length',
                                max_length=500,
                                return_tensors='pt',
                                return_length=True,
                                )

    #input_ids:编码之后的数字
    #attention_mask:是补零的位置是0,其他位置是1
    input_ids=data['input_ids']
    attention_mask = data['attention_mask']
    token_type_ids = data['token_type_ids']

    labels=torch.LongTensor(labels)

    return input_ids, attention_mask, token_type_ids, labels

#数据加载器
loader=torch.utils.data.DataLoader(dataset=dataset,
                                  batch_size=16,
                                  collate_fn=collate_fn,
                                  shuffle=True,
                                  drop_last=True)

#取第一个批次的数据
for i,(input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    break

print(len(loader))
input_ids.shape, attention_mask.shape, token_type_ids.shape, labels

600


(torch.Size([16, 500]),
 torch.Size([16, 500]),
 torch.Size([16, 500]),
 tensor([1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0]))

### 4 模型

In [None]:
from transformers import BertModel

#加载预训练模型
pretrained=BertModel.from_pretrained('bert-base-chinese')
#不训练,不需要计算梯度
'''在训练过程中这些参数不会被更新。这样做的目的是使用预训练模型的特征提取能力，而不对其进行微调'''
for param in pretrained.parameters():
    param.requires_grad_(False)

#模型试算
out = pretrained(input_ids=input_ids,
                 attention_mask=attention_mask,
                 token_type_ids=token_type_ids)
out.last_hidden_state.shape

model.safetensors:   0%|          | 0.00/412M [00:00<?, ?B/s]

torch.Size([16, 500, 768])

### 5 定义下游任务模型

In [None]:
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc=torch.nn.Linear(768,2)

    def forward(self,input_ids,attention_mask,token_type_ids):
        with torch.no_grad(): #只需要进行特征提取而不需要更新模型参数
            out = pretrained(input_ids=input_ids,
                 attention_mask=attention_mask,
                 token_type_ids=token_type_ids)

        out=self.fc(out.last_hidden_state[:,0])
        #将输入张量的每一行（即每个样本的输出）转换为一个概率分布，所有元素的和为1
        out=out.softmax(dim=1)

        return out

model=Model()

model(input_ids=input_ids,
      attention_mask=attention_mask,
      token_type_ids=token_type_ids).shape

torch.Size([16, 2])

### 6 训练

In [None]:
from transformers import AdamW

optimizer=AdamW(model.parameters(),lr=5e-4)
criterion=torch.nn.CrossEntropyLoss()

model.train()

for i, (input_ids, attention_mask, token_type_ids,
        labels) in enumerate(loader):
    out = model(input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids)
    
    #计算损失
    loss=criterion(out,labels)
    #反向传播和优化
    loss.backward()#计算损失的梯度
    optimizer.step()#更新模型参数
    optimizer.zero_grad()#清空优化器的梯度

    #打印训练信息
    if i%5==0:
        out=out.argmax(dim=1) #找到每个样本预测结果中概率最大的类别索引
        accuracy=(out==labels).sum().item()/len(labels)
        #item() 方法用于将包含单个值的张量转换为 Python 标量
        print(i,loss.item(),accuracy)
    
    #当处理到第 300 个批次时，终止训练循环
    if i == 300:
        break



0 0.7590072751045227 0.375
5 0.6746047735214233 0.5625
10 0.6309472918510437 0.6875
15 0.6731126308441162 0.6875
20 0.6028311848640442 0.75
25 0.6080811619758606 0.625
30 0.6647576689720154 0.5
35 0.520226001739502 1.0
40 0.5909832119941711 0.6875
45 0.5008869171142578 0.9375
50 0.4827878475189209 0.9375
55 0.5563780069351196 0.875
60 0.48446062207221985 0.875
65 0.47542521357536316 0.875
70 0.4827728569507599 0.9375
75 0.49847108125686646 0.8125
80 0.4473411440849304 0.875
85 0.4774356186389923 0.875
90 0.4438154101371765 0.9375
95 0.43533143401145935 0.9375
100 0.48415639996528625 0.875
105 0.5001518726348877 0.75
110 0.5333747863769531 0.75
115 0.47413820028305054 0.875
120 0.4978610575199127 0.875
125 0.5136285424232483 0.875
130 0.5017361640930176 0.8125
135 0.48214778304100037 0.8125
140 0.4277268052101135 0.9375
145 0.5093130469322205 0.875
150 0.4964030683040619 0.8125
155 0.6243466138839722 0.6875
160 0.48617786169052124 0.75
165 0.4418548345565796 0.9375
170 0.403619349002838

### 7 测试

In [None]:
def test():
    model.eval()
    correct=0
    total=0
    
    loader_test=torch.utils.data.DataLoader(dataset=Dataset('validation'),
                                  batch_size=32,
                                  collate_fn=collate_fn,
                                  shuffle=True,
                                  drop_last=True)

    for i, (input_ids, attention_mask, token_type_ids,labels) in enumerate(loader_test):
        if i==5:
            break
        print(i)

        with torch.no_grad():
            out = model(input_ids=input_ids,
                 attention_mask=attention_mask,
                 token_type_ids=token_type_ids)
        out=out.argmax(dim=1)
        correct+=(out==labels).sum().item()
        total+=len(labels)

    print(correct/total)
        
test()

0
1
2
3
4
0.85625
