<a href="https://www.kaggle.com/code/lxytypk/5chinese-sentence-inference?scriptVersionId=227190166" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

### 1 定义数据集

In [1]:
import torch
import random
from datasets import load_dataset

class Dataset(torch.utils.data.Dataset):
    def __init__(self,split):
        dataset=load_dataset(path='lansinuote/ChnSentiCorp', split=split)
        
        def f(data):
            return len(data['text'])>40

        self.dataset=dataset.filter(f)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self,i):
        text=self.dataset[i]['text']

        #切分一句话为前半句和后半句
        sentence1=text[:20]
        sentence2=text[20:40]
        label=0

        #有一半的概率把后半句替换为一句无关的话
        if random.randint(0,1)==0:
            j=random.randint(0,len(self.dataset)-1)
            sentence2=self.dataset[j]['text'][20:40]
            label=1

        return sentence1,sentence2,label

dataset=Dataset('train')
sentence1, sentence2, label = dataset[0]

len(dataset), sentence1, sentence2, label

dataset_infos.json:   0%|          | 0.00/960 [00:00<?, ?B/s]

(…)-00000-of-00001-02f200ca5f2a7868.parquet:   0%|          | 0.00/2.16M [00:00<?, ?B/s]

(…)-00000-of-00001-405befbaa3bcf1a2.parquet:   0%|          | 0.00/276k [00:00<?, ?B/s]

(…)-00000-of-00001-5372924f059fe767.parquet:   0%|          | 0.00/275k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/9600 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1200 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1200 [00:00<?, ? examples/s]

Filter:   0%|          | 0/9600 [00:00<?, ? examples/s]

(8001, '选择珠江花园的原因就是方便，有电动扶梯直', '配置的还要便宜。用过很多笔记本电脑，这个', 1)

### 2 加载字典和分词工具

In [2]:
from transformers import BertTokenizer

token=BertTokenizer.from_pretrained('bert-base-chinese')
token

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/110k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/269k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/624 [00:00<?, ?B/s]

BertTokenizer(name_or_path='bert-base-chinese', vocab_size=21128, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

### 3 数据加载器

In [3]:
def collate_fn(data):
    sents=[i[:2] for i in data]
    labels=[i[2] for i in data]
    #编码
    data = token.batch_encode_plus(batch_text_or_text_pairs=sents,
                                   truncation=True,
                                   padding='max_length',
                                   max_length=45,
                                   return_tensors='pt',
                                   return_length=True,
                                   add_special_tokens=True)

    #input_ids:编码之后的数字
    #attention_mask:是补零的位置是0,其他位置是1
    input_ids = data['input_ids']
    attention_mask = data['attention_mask']
    token_type_ids = data['token_type_ids']
    labels=torch.LongTensor(labels)

    return input_ids, attention_mask, token_type_ids, labels


#数据加载器
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=8,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

for i, (input_ids, attention_mask, token_type_ids,
        labels) in enumerate(loader):
    break

print(len(loader))
print(token.decode(input_ids[0]))
input_ids.shape, attention_mask.shape, token_type_ids.shape, labels

1000
[CLS] 笔 记 本 还 是 不 错 的 ， 要 是 我 自 己 用 ， linu [SEP] x 也 凑 合 了 ， 无 奈 是 给 家 里 人 买 的 ， 只 能 改 [UNK] [SEP] [PAD] [PAD] [PAD] [PAD]


(torch.Size([8, 45]),
 torch.Size([8, 45]),
 torch.Size([8, 45]),
 tensor([0, 0, 0, 0, 1, 1, 0, 1]))

### 4 模型试算

In [4]:
from transformers import BertModel

#加载预训练模型
pretrained = BertModel.from_pretrained('bert-base-chinese')

#不训练,不需要计算梯度
for param in pretrained.parameters():
    param.requires_grad_(False)

#模型试算
out = pretrained(input_ids=input_ids,
           attention_mask=attention_mask,
           token_type_ids=token_type_ids)

out.last_hidden_state.shape

model.safetensors:   0%|          | 0.00/412M [00:00<?, ?B/s]

torch.Size([8, 45, 768])

### 5 定义下游任务模型

In [5]:
#定义下游任务模型
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc=torch.nn.Linear(768,2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        with torch.no_grad():
            out = pretrained(input_ids=input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids)

        out = self.fc(out.last_hidden_state[:, 0])
        out=out.softmax(dim=1)

        return out


model = Model()

model(input_ids=input_ids,
      attention_mask=attention_mask,
      token_type_ids=token_type_ids).shape

torch.Size([8, 2])

### 6 训练

In [6]:
from transformers import AdamW

#训练
optimizer = AdamW(model.parameters(), lr=5e-4)
criterion = torch.nn.CrossEntropyLoss()

model.train()
for i, (input_ids, attention_mask, token_type_ids,
        labels) in enumerate(loader):
    out = model(input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids)

    loss = criterion(out, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if i % 5 == 0:
        out = out.argmax(dim=1)
        accuracy = (out == labels).sum().item() / len(labels)

        print(i, loss.item(), accuracy)

    if i == 300:
        break



0 0.7448517084121704 0.375
5 0.6085245013237 0.875
10 0.5618289113044739 0.75
15 0.47911709547042847 1.0
20 0.5362741351127625 0.75
25 0.42467349767684937 1.0
30 0.5011657476425171 0.75
35 0.45691585540771484 0.875
40 0.4079835116863251 1.0
45 0.47950437664985657 0.875
50 0.4913110136985779 0.875
55 0.5584380626678467 0.625
60 0.5110815167427063 0.75
65 0.3372158706188202 1.0
70 0.34630680084228516 1.0
75 0.5314846634864807 0.75
80 0.5558391809463501 0.75
85 0.5062395334243774 0.75
90 0.4464564323425293 0.875
95 0.4681287109851837 0.75
100 0.4517415463924408 0.875
105 0.4239656925201416 0.875
110 0.36756274104118347 1.0
115 0.3529762327671051 1.0
120 0.3612954616546631 1.0
125 0.6456762552261353 0.625
130 0.5186907052993774 0.75
135 0.4549063742160797 0.875
140 0.4864486753940582 0.75
145 0.4023735225200653 0.875
150 0.5235679745674133 0.75
155 0.495903342962265 0.75
160 0.3758128881454468 1.0
165 0.342488169670105 1.0
170 0.5804436802864075 0.625
175 0.4100106358528137 1.0
180 0.39372

### 7 测试

In [7]:
#测试
def test():
    model.eval()
    correct = 0
    total = 0

    loader_test = torch.utils.data.DataLoader(dataset=Dataset('test'),
                                              batch_size=32,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)

    for i, (input_ids, attention_mask, token_type_ids,
            labels) in enumerate(loader_test):

        if i == 5:
            break

        print(i)

        with torch.no_grad():
            out = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids)

        out = out.argmax(dim=1)
        correct += (out == labels).sum().item()
        total += len(labels)

    print(correct / total)


test()

Filter:   0%|          | 0/1200 [00:00<?, ? examples/s]

0
1
2
3
4
0.88125
