In [None]:
import torch
import torch.nn as nn
from datasets import load_dataset
from transformers import BertTokenizer,BertModel
from torch.utils.data import DataLoader,Dataset
from transformers import AdamW

In [None]:
class my_dataset(Dataset):
    """Text-only view over one split of the "seamew/Weibo" dataset.

    Only the "text" field of each record is exposed; no label is returned.
    """

    def __init__(self, split):
        # `split` is forwarded unchanged to `load_dataset`.
        self.data = load_dataset("seamew/Weibo", split=split)

    def __len__(self):
        """Return the number of records in the loaded split."""
        return len(self.data)

    def __getitem__(self, i):
        """Return the value stored under "text" for record `i`."""
        record = self.data[i]
        return record["text"]
# Build the dataset for the "train" split, then show its first sample as the cell output.
train_data = my_dataset(split="train")
train_data[0]

In [None]:
# Tokenizer loaded from the "bert-base-chinese" checkpoint; later cells use the
# module-level name `token` for encoding, decoding and the vocabulary size.
token=BertTokenizer.from_pretrained("bert-base-chinese")


In [None]:
def collate_fn(data, mask_position=15, max_length=32):
    """Tokenize a batch of texts and mask one fixed token position per sequence.

    Args:
        data: list of raw text strings, one per sample.
        mask_position: index (along the sequence axis) of the token that is
            replaced by the tokenizer's mask token. Defaults to 15.
        max_length: length every sequence is padded/truncated to. Defaults to 32.

    Returns:
        Tuple ``(input_ids, attention_mask, token_type_ids, labels)``.
        ``labels`` holds the original token ids found at ``mask_position``
        (one per sample); ``input_ids`` has that position overwritten with
        the mask token id.

    Note:
        The position is fixed, so for a text that tokenizes to fewer than
        ``mask_position + 1`` tokens the masked slot is a padding token and
        the label is the padding id.
    """
    encoded = token.batch_encode_plus(
        batch_text_or_text_pairs=data,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    token_type_ids = encoded["token_type_ids"]

    # Copy the target ids before the in-place overwrite below; without the
    # clone the labels would alias the masked column.
    labels = input_ids[:, mask_position].clone()

    # `mask_token_id` is a direct lookup; calling `get_vocab()` here would
    # rebuild the full vocabulary dict for every batch.
    input_ids[:, mask_position] = token.mask_token_id
    return input_ids, attention_mask, token_type_ids, labels
# Shuffled loader over the training set; `collate_fn` turns each list of
# texts into (input_ids, attention_mask, token_type_ids, labels).
loader = DataLoader(
    dataset=train_data,
    collate_fn=collate_fn,
    shuffle=True,
    batch_size=100,
)

# Fetch a single batch to inspect; later cells reuse these tensors for
# trial forward passes.
input_ids, attention_mask, token_type_ids, labels = next(iter(loader))

print(len(loader))  # number of batches per epoch
print(input_ids.shape)
print(token.decode(input_ids[0]))  # first sample, with the mask token in place
input_ids.shape, attention_mask.shape, token_type_ids.shape, labels.shape


In [None]:
# Load the pretrained encoder and freeze all of its parameters.
pre_model = BertModel.from_pretrained("bert-base-chinese")
pre_model.requires_grad_(False)

# Trial run of the encoder on the sample batch; the cell displays the shape
# of the last hidden state.
out = pre_model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids,
)
out.last_hidden_state.shape

In [None]:
class Model(nn.Module):
    """Linear vocabulary classifier on top of the module-level `pre_model`.

    The encoder is always called under `torch.no_grad()`; only the linear
    decoder weight and its bias belong to this module.
    """

    def __init__(self):
        super().__init__()
        vocab_size = token.vocab_size
        # Projection from the 768-wide hidden state to one score per vocabulary entry.
        self.decoder = nn.Linear(768, vocab_size, bias=False)
        # Zero-initialized bias, registered on the model and then attached
        # to the decoder so the linear layer applies it.
        self.bias = nn.Parameter(torch.zeros(vocab_size))
        self.decoder.bias = self.bias

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return the decoder output for sequence position 15 of each sample."""
        with torch.no_grad():
            encoded = pre_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )
        # Keep only the hidden state at index 15 along the sequence axis.
        hidden_at_position = encoded.last_hidden_state[:, 15]
        return self.decoder(hidden_at_position)
# Instantiate the model and run it once on the sample batch as a smoke test;
# the cell displays column 0 of the returned tensor.
model=Model()
model(input_ids=input_ids,
      attention_mask=attention_mask,
      token_type_ids=token_type_ids)[:,0]

In [30]:
# Use PyTorch's AdamW: the implementation exported by `transformers` is
# deprecated in favour of `torch.optim.AdamW`. `eps` and `weight_decay` are
# set explicitly to mirror the defaults of the transformers version
# (torch's own defaults are eps=1e-8, weight_decay=1e-2).
optim = torch.optim.AdamW(model.parameters(), lr=5e-4, eps=1e-6, weight_decay=0.0)
criterion = torch.nn.CrossEntropyLoss()

model.train()
for epoch in range(5):
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
        out = model(input_ids, attention_mask, token_type_ids)
        loss = criterion(out, labels)

        loss.backward()
        optim.step()
        optim.zero_grad()

        # Report loss and batch accuracy every 50 steps.
        if i % 50 == 0:
            # Separate name so `out` keeps referring to the logits.
            predictions = out.argmax(dim=1)
            accuracy = (predictions == labels).sum().item() / len(labels)
            print(epoch, i, loss.item(), accuracy)

0 0 2.994009017944336 0.54
0 50 2.8158719539642334 0.57
0 100 3.153811454772949 0.54
0 150 2.5734691619873047 0.68
0 200 3.027992010116577 0.59
0 250 2.60850191116333 0.6
0 300 2.639406442642212 0.65
0 350 2.411080837249756 0.68
0 400 2.022575616836548 0.74
0 450 1.7936979532241821 0.73
1 0 1.8559455871582031 0.71
1 50 1.4746702909469604 0.72
1 100 1.5755122900009155 0.72


KeyboardInterrupt: 