In [1]:
from transformers import BertTokenizer, RobertaModel

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

import pandas as pd
from tqdm.notebook import tqdm

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
class MyDataset(Dataset):
    def __init__(self, tokenizer, sample=None):
        df = pd.read_csv('../../../datasets/THUCNews/train.csv')
        self.label2class = dict(zip(df['label'], df['class']))
        
        if sample:
            df = df.sample(sample).reset_index(drop=True)
        self.texts = [tokenizer(text, max_length=32, padding='max_length', truncation=True, return_tensors='pt')
                      for text in tqdm(df['title'])]
        self.labels = list(df['label'].values)
        
    def __getitem__(self, idx):
        return self.texts[idx].input_ids, torch.LongTensor([self.labels[idx]])

    def __len__(self):
        return len(self.labels)
    
tokenizer = BertTokenizer.from_pretrained('../../../models/bert-base-chinese/')
dataset = MyDataset(tokenizer, sample=100000)

  0%|          | 0/100000 [00:00<?, ?it/s]

In [3]:

class RobertaChinese(nn.Module):
    def __init__(self, num_class):
        super().__init__()
        self.roberta = RobertaModel.from_pretrained('../../../models/roberta-base/')
        self.roberta.embeddings.word_embeddings = nn.Embedding(tokenizer.vocab_size, 768, padding_idx=0) 
        self.fc = nn.Linear(768, num_class)
        
    def forward(self, x):
        pooler_output = self.roberta(x).pooler_output
        return self.fc(pooler_output)

model = RobertaChinese(num_class=len(dataset.label2class)).to(device)
model

Some weights of the model checkpoint at ../../../models/roberta-base/ were not used when initializing RobertaModel: ['lm_head.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


RobertaChinese(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0): RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((7

In [4]:
dataloader = DataLoader(dataset, batch_size=256, shuffle=True)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

for epoch in range(20):
    total_loss = 0
    total_acc = 0
    for x, y in tqdm(dataloader):
        x = x.to(device)
        y = y.to(device)
        x.squeeze_(1)
        y = y.reshape(-1)
        output = model(x)    
        loss = criterion(output, y)
        acc = (output.argmax(1) == y).sum().item()
        
        model.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_acc += acc
        
    print(f'Epochs:{epoch + 1}|Train Loss:{total_loss / len(dataset): .6f}|Train Accuracy:{total_acc / len(dataset): .6f}')

  0%|          | 0/391 [00:00<?, ?it/s]

Epochs:1|Train Loss: 0.005025|Train Accuracy: 0.602520


  0%|          | 0/391 [00:00<?, ?it/s]

Epochs:2|Train Loss: 0.002805|Train Accuracy: 0.779370


  0%|          | 0/391 [00:00<?, ?it/s]

Epochs:3|Train Loss: 0.002113|Train Accuracy: 0.834510


  0%|          | 0/391 [00:00<?, ?it/s]

Epochs:4|Train Loss: 0.001656|Train Accuracy: 0.870420


  0%|          | 0/391 [00:00<?, ?it/s]

Epochs:5|Train Loss: 0.001238|Train Accuracy: 0.905100


  0%|          | 0/391 [00:00<?, ?it/s]

Epochs:6|Train Loss: 0.000858|Train Accuracy: 0.936770


  0%|          | 0/391 [00:00<?, ?it/s]

Epochs:7|Train Loss: 0.000551|Train Accuracy: 0.961360


  0%|          | 0/391 [00:00<?, ?it/s]

Epochs:8|Train Loss: 0.000336|Train Accuracy: 0.977530


  0%|          | 0/391 [00:00<?, ?it/s]

Epochs:9|Train Loss: 0.000237|Train Accuracy: 0.983570


  0%|          | 0/391 [00:00<?, ?it/s]

Epochs:10|Train Loss: 0.000200|Train Accuracy: 0.985670


  0%|          | 0/391 [00:00<?, ?it/s]

Epochs:11|Train Loss: 0.000160|Train Accuracy: 0.988710


  0%|          | 0/391 [00:00<?, ?it/s]

Epochs:12|Train Loss: 0.000128|Train Accuracy: 0.990500


  0%|          | 0/391 [00:00<?, ?it/s]

Epochs:13|Train Loss: 0.000152|Train Accuracy: 0.987740


  0%|          | 0/391 [00:00<?, ?it/s]

Epochs:14|Train Loss: 0.000117|Train Accuracy: 0.990800


  0%|          | 0/391 [00:00<?, ?it/s]

Epochs:15|Train Loss: 0.000105|Train Accuracy: 0.991500


  0%|          | 0/391 [00:00<?, ?it/s]

Epochs:16|Train Loss: 0.000101|Train Accuracy: 0.991520


  0%|          | 0/391 [00:00<?, ?it/s]

Epochs:17|Train Loss: 0.000089|Train Accuracy: 0.992760


  0%|          | 0/391 [00:00<?, ?it/s]

Epochs:18|Train Loss: 0.000093|Train Accuracy: 0.992580


  0%|          | 0/391 [00:00<?, ?it/s]

Epochs:19|Train Loss: 0.000061|Train Accuracy: 0.995030


  0%|          | 0/391 [00:00<?, ?it/s]

Epochs:20|Train Loss: 0.000097|Train Accuracy: 0.991910


In [21]:
model.eval()

df_test = pd.read_csv('../../../datasets/THUCNews/test.csv').sample(10000).reset_index(drop=True)
predicts = []
for text in tqdm(df_test['title']):
    input_ids = tokenizer(text, max_length=32, padding='max_length', truncation=True, return_tensors='pt').input_ids
    input_ids = input_ids.to(device)
    predicts.append(nn.functional.softmax(model(input_ids), dim=1).argmax(dim=1).item())

  0%|          | 0/10000 [00:00<?, ?it/s]

In [31]:
df_test['predict'] = predicts

In [32]:
df_test

Unnamed: 0,title,class,label,file,predict
0,防守科比最后一投非他莫属 勒布朗韦德亦自叹不如,体育,8,88639.txt,8
1,组图：杨采妮裸色长裙出镜 与陈数代言美钻,娱乐,5,167306.txt,5
2,Avaya中国完成对北电初步整合,科技,10,633926.txt,10
3,客场3连败!鲁能列中超倒数第1 这样的战绩想卫冕,体育,8,74986.txt,8
4,公布房价成本请非诚勿扰 需公权力介入,房产,12,266176.txt,12
...,...,...,...,...,...
9995,食品饮料：在估值优势和短期政策中寻找平衡,股票,3,708871.txt,3
9996,屋内的动静则随着光线折射(图),家居,9,227786.txt,9
9997,2010年南京中考数学不难70%都是基础题,教育,1,322847.txt,1
9998,1200万像素千元机 三星ES73新品套机促销,科技,10,489890.txt,10


In [29]:
print('10000条测试数据准确率：')
print((df_test['predict'] == df_test['label']).sum() / len(df_test['label']))

10000条测试数据准确率：
0.8142
