In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from transformers import BertTokenizer
from transformers import BertModel

import pandas as pd
import numpy as np
from tqdm.notebook import tqdm

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# 数据处理

In [4]:
class THUCNews(Dataset):
    def __init__(self, dataset_path, tokenizer, sample=10000):
        df = pd.read_csv(dataset_path).dropna().sample(sample).reset_index(drop=True)
        self.labels = df['label']
        self.n_classes = len(df['label'].unique())
        self.texts = [tokenizer(title, padding='max_length', max_length=32,
                                return_tensors='pt', truncation=True)
                      for title in tqdm(df['title'])]
        
    def __len__(self):
        return len(self.labels)
        
    def __getitem__(self, idx):
        return self.texts[idx], np.array(self.labels[idx])

In [5]:
# 预训练模型 bert-base-chinese
model_path = '../../models/bert-base-chinese/'
# bert tokenizer
tokenizer = BertTokenizer.from_pretrained(model_path)

# 构造train dataset和valid dataset
train_dataset = THUCNews('../data/THUCNews/train.csv', tokenizer, sample=100000)

valid_dataset = THUCNews('../data/THUCNews/valid.csv', tokenizer, sample=100000)

  0%|          | 0/100000 [00:00<?, ?it/s]

  0%|          | 0/100000 [00:00<?, ?it/s]

# 构造模型

In [6]:
class BertClassifier(nn.Module):
    def __init__(self, n_classes, dropout=0.5):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_path)
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(768, n_classes)
        self.relu = nn.ReLU()
        
    def forward(self, input_ids, atention_mask):
        _, pooled_output = self.bert(input_ids, atention_mask, return_dict=False)
        output = self.dropout(pooled_output)
        output = self.linear(output)
        output = self.relu(output)
        return output
    
model = BertClassifier(n_classes=train_dataset.n_classes)
model = model.to(device)

Some weights of the model checkpoint at ../../models/bert-base-chinese/ were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# 训练模型

In [7]:
train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=256, shuffle=False)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

for epoch in range(20):
    total_loss_train = 0
    total_acc_train = 0
    for train_input, train_label in train_dataloader:
        input_ids = train_input['input_ids'].squeeze(1).to(device)
        attention_mask = train_input['attention_mask'].squeeze(1).to(device)
        train_label = train_label.to(device)
        output = model(input_ids, attention_mask)
        loss = criterion(output, train_label)
        acc = (output.argmax(dim=1) == train_label).sum().item()

        model.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss_train += loss.item()
        total_acc_train += acc
    
    total_loss_valid = 0
    total_acc_valid = 0
    with torch.no_grad():
        for valid_input, valid_label in valid_dataloader:
            input_ids = valid_input['input_ids'].squeeze(1).to(device)
            attention_mask = valid_input['attention_mask'].squeeze(1).to(device)
            valid_label = valid_label.to(device)
            output = model(input_ids, attention_mask)
            loss = criterion(output, valid_label)
            acc = (output.argmax(dim=1) == valid_label).sum().item()
            
            total_loss_valid += loss.item()
            total_acc_valid += acc
        
    print(f'Epochs:{epoch + 1}|Train Loss:{total_loss_train / len(train_dataset): .4f}|Train Accuracy:{total_acc_train / len(train_dataset): .4f}|Val Loss:{total_loss_valid / len(valid_dataset): .4f}|Val Accuracy:{total_acc_valid / len(valid_dataset): .4f}')

Epochs:1|Train Loss: 0.0023|Train Accuracy: 0.8468|Val Loss: 0.0011|Val Accuracy: 0.9241
Epochs:2|Train Loss: 0.0008|Train Accuracy: 0.9441|Val Loss: 0.0009|Val Accuracy: 0.9328
Epochs:3|Train Loss: 0.0005|Train Accuracy: 0.9649|Val Loss: 0.0009|Val Accuracy: 0.9355
Epochs:4|Train Loss: 0.0003|Train Accuracy: 0.9796|Val Loss: 0.0010|Val Accuracy: 0.9351


KeyboardInterrupt: 