In [None]:
import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader,Dataset
from transformers import BertTokenizer,BertModel,AdamW
import random

In [None]:
class my_dataset(Dataset):
    """Sentence-pair dataset built from the ``seamew/Weibo`` corpus.

    Every item is a tuple ``(sentence1, sentence2, label)``:

    * ``sentence1`` is the first 20 characters of a text.
    * ``sentence2`` is, with probability 1/2, characters ``[20:40]`` of the
      same text (``label == 0``); otherwise it is characters ``[20:40]`` of a
      *different*, randomly chosen text (``label == 1``).

    Items are sampled with the global ``random`` module, so repeated indexing
    of the same position may return different pairs.
    """

    def __init__(self,split):
        """Load ``split`` of the corpus and keep only texts longer than 20 characters."""
        dataset=load_dataset("seamew/Weibo",split=split)
        def f(data):
            # NOTE: texts of 21-39 characters pass this filter, so for them
            # ``sentence2`` is shorter than 20 characters.
            return len(data["text"])>20
        self.dataset=dataset.filter(f)

    def __len__(self):
        """Return the number of texts that survived the length filter."""
        return len(self.dataset)

    def __getitem__(self,i):
        """Return ``(sentence1, sentence2, label)`` for the text at index ``i``.

        ``label`` is 0 when ``sentence2`` continues ``sentence1`` and 1 when it
        was taken from another text. If the dataset holds fewer than two texts
        no negative sample can be drawn, so the positive pair is always
        returned.
        """
        text=self.dataset[i]["text"]
        sentence1 = text[:20]
        sentence2 = text[20:40]
        label = 0
        n = len(self.dataset)
        if n > 1 and random.randint(0,1) == 0 :
            # Draw uniformly from the n-1 indices other than ``i``: sample from
            # a range one element shorter and shift past ``i``. Picking ``i``
            # itself would yield a true continuation labelled as a negative.
            own = i % n
            j = random.randint(0, n - 2)
            if j >= own:
                j += 1
            sentence2 = self.dataset[j]["text"][20:40]
            label = 1
        return sentence1,sentence2,label
# Build the training split and preview one sampled pair together with the
# number of texts that passed the length filter.
train_data = my_dataset("train")
(sentence1, sentence2, label) = train_data[0]
(len(train_data), sentence1, sentence2, label)

In [None]:
# Tokenizer matching the "bert-base-chinese" checkpoint; used by collate_fn.
token = BertTokenizer.from_pretrained("bert-base-chinese")
def collate_fn(data):
    """Turn a list of ``(sentence1, sentence2, label)`` samples into tensors.

    The sentence pairs are encoded with the module-level ``token`` tokenizer,
    truncated / padded to exactly 45 positions, and returned as PyTorch
    tensors.

    Returns:
        ``(input_ids, attention_mask, token_type_ids, labels)`` where
        ``labels`` is a ``LongTensor`` built from the third element of each
        sample.
    """
    sentence_pairs = [sample[:2] for sample in data]
    targets = [sample[2] for sample in data]

    encoded = token.batch_encode_plus(
        batch_text_or_text_pairs=sentence_pairs,
        add_special_tokens=True,
        truncation=True,
        padding="max_length",
        max_length=45,
        return_length=True,
        return_tensors="pt",
    )

    return (
        encoded["input_ids"],
        encoded["attention_mask"],
        encoded["token_type_ids"],
        torch.LongTensor(targets),
    )
# Shuffled batches of 64 samples; a trailing incomplete batch is dropped.
loader = DataLoader(
    train_data,
    batch_size=64,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn,
)

# Fetch only the first batch so later cells have tensors to inspect
# (``i`` is bound to 0, the index of that batch).
i, (input_ids, attention_mask, token_type_ids, labels) = next(enumerate(loader))

print(len(loader))
print(token.decode(input_ids[0]))
(input_ids.shape, attention_mask.shape, token_type_ids.shape)

In [None]:
# Pretrained encoder used as a fixed feature extractor: gradient tracking is
# switched off for every one of its parameters.
pre_model = BertModel.from_pretrained("bert-base-chinese")
for param in pre_model.parameters():
    param.requires_grad = False

# Smoke test on the batch fetched above; show the shape of the hidden states.
out = pre_model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids,
)
out.last_hidden_state.shape

In [None]:
# Downstream task
class Model(nn.Module):
    """Two-class linear head on top of the module-level ``pre_model`` encoder.

    ``pre_model`` is called under ``torch.no_grad()``, so only ``self.fc``
    receives gradients.
    """

    def __init__(self):
        super().__init__()
        # 768 input features -> 2 classes.
        self.fc=nn.Linear(768,2)

    def forward(self,input_ids,attention_mask,token_type_ids):
        """Return unnormalised class scores (logits) of shape ``(batch, 2)``.

        The scores are deliberately NOT passed through softmax:
        ``nn.CrossEntropyLoss`` applies ``log_softmax`` itself and expects raw
        logits, so normalising here would apply softmax twice during training.
        ``argmax`` over the logits gives the same prediction as over
        probabilities; call ``.softmax(dim=1)`` on the result when actual
        probabilities are needed.
        """
        with torch.no_grad():
            out = pre_model(input_ids=input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids)
        # Classify from the hidden state at sequence position 0
        # (the [CLS] position for BERT-style inputs).
        return self.fc(out.last_hidden_state[:,0])
# Instantiate the classifier and check the output shape on the sample batch.
model = Model()
model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids,
).shape


In [None]:
# Optimise the parameters registered on ``model`` with AdamW and a
# cross-entropy objective.
optim = AdamW(model.parameters(), lr=5e-4)
criteration = nn.CrossEntropyLoss()

model.train()
for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    out = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
    )
    loss = criteration(out, labels)

    # Backpropagate, update, then clear gradients for the next batch.
    loss.backward()
    optim.step()
    optim.zero_grad()

    # Every fifth batch: report step index, loss and batch accuracy.
    if not i % 5:
        out = out.argmax(dim=1)
        accuracy = (out == labels).sum().item() / labels.size(0)
        print(i, loss.item(), accuracy)

    # Stop once batch index 300 has been processed (301 batches in total).
    if i >= 300:
        break
