In [1]:
import torch
from transformers import AutoModel, AutoTokenizer
import torch.nn as nn
import torch.optim as optim
import joblib
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold
from colorama import Fore, Style
from transformers import logging

from utils import (set_seed, read_data, get_collate_fn, evaluate, predict)

# 忽略警告
logging.set_verbosity_error()

In [2]:
USE_AUGMENT = False

SEED = 2022
set_seed(SEED)

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [4]:
train_data = read_data('datasets/KUAKE-QQR_train.json')
valid_data = read_data('datasets/KUAKE-QQR_dev.json')

if USE_AUGMENT:  # 是否开启数据增强
    augment_data = pd.read_csv('extra_data/augment_data.csv')
    all_data = pd.concat([train_data, valid_data, augment_data], axis=0)
    all_data = all_data.drop_duplicates()
else:
    all_data = pd.concat([train_data, valid_data], axis=0)

In [5]:
model_ckpt = "hfl/chinese-electra-180g-large-discriminator"
token = AutoTokenizer.from_pretrained(model_ckpt)
print(token.model_input_names)

['input_ids', 'token_type_ids', 'attention_mask']


In [6]:
class CustomModel(nn.Module):
    """下游训练任务模型"""

    def __init__(self, pretrained_model_name):
        super(CustomModel, self).__init__()
        self.pretrained = AutoModel.from_pretrained(pretrained_model_name)
        self.fc = torch.nn.Linear(self.pretrained.config.hidden_size, 3)

    def forward(self, input_ids, attention_mask, token_type_ids):
        out = self.pretrained(input_ids=input_ids, attention_mask=attention_mask,
                              token_type_ids=token_type_ids).last_hidden_state
        # output.shape=[batch_size, class_num]
        out = self.fc(torch.mean(out, dim=1))
        out = out.softmax(dim=1)
        return out

In [7]:
# 模型训练
def train_and_evaluate(model, dataloader_train, dataloader_valid, best, criterion, optimizer, scheduler_lr=None,
                       device=torch.device('cpu')):
    model.train()

    for idx, (input_ids, attention_mask, token_type_ids, labels) in enumerate(dataloader_train, start=1):
        # 数据设备切换
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

        loss = criterion(out, labels)  # 每个step的损失值
        loss.backward()
        optimizer.step()
        if scheduler_lr is not None:
            scheduler_lr.step()

        if idx % 20 == 0:
            valid_acc = evaluate(model, dataloader_valid, device)
            train_step_acc = accuracy_score(labels.cpu().numpy(), torch.argmax(out.cpu(), dim=1).numpy())  # 评估指标
            if best[1] < valid_acc:
                best.pop()
                best.append(valid_acc)
                torch.save(model.state_dict(), 'models/save_model_{}.pkl'.format(best[0]))
                print('| step {:5d} | loss {:9.6f} | train_step_acc {:9.6f} | valid_acc {:9.6f} |'.format(idx,
                                                                                                          loss.item(),
                                                                                                          train_step_acc,
                                                                                                          valid_acc))


In [8]:
skfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)

for fold, (trn_ind, val_ind) in enumerate(skfold.split(all_data, all_data['label'].values)):
    print(Fore.RED + '#' * 50 + str(fold) + '#' * 50)
    print(Style.RESET_ALL, end='')
    train_data, valid_data = all_data.iloc[trn_ind].values.tolist(), all_data.iloc[val_ind].values.tolist()
    dataloader_train = torch.utils.data.DataLoader(train_data,
                                                   shuffle=True,
                                                   batch_size=64,
                                                   collate_fn=get_collate_fn(token))
    dataloader_valid = torch.utils.data.DataLoader(valid_data,
                                                   batch_size=64,
                                                   collate_fn=get_collate_fn(token))
    # *****************************************************************************
    model = CustomModel(model_ckpt)
    model = model.to(device)

    optimizer_adamw = optim.AdamW(model.parameters(), lr=1e-5)
    criterion_cross_entropy = torch.nn.CrossEntropyLoss()
    # *****************************************************************************
    best_acc_list = [fold, 0.0]
    for epoch in range(5):
        print('-' * 40 + str(epoch) + '-' * 40)
        train_and_evaluate(model, dataloader_train, dataloader_valid, best_acc_list, criterion_cross_entropy,
                           optimizer_adamw, None, device)
    print(best_acc_list)

[31m##################################################0##################################################
[0m----------------------------------------0----------------------------------------
| step    20 | loss  0.924986 | train_step_acc  0.625000 | valid_acc  0.599699 |
| step    40 | loss  0.915998 | train_step_acc  0.593750 | valid_acc  0.712651 |
| step    60 | loss  0.677429 | train_step_acc  0.906250 | valid_acc  0.793373 |
| step    80 | loss  0.707501 | train_step_acc  0.828125 | valid_acc  0.824398 |
| step   120 | loss  0.646340 | train_step_acc  0.906250 | valid_acc  0.843976 |
| step   140 | loss  0.651439 | train_step_acc  0.906250 | valid_acc  0.852108 |
| step   160 | loss  0.724603 | train_step_acc  0.828125 | valid_acc  0.854819 |
| step   180 | loss  0.633448 | train_step_acc  0.906250 | valid_acc  0.858434 |
----------------------------------------1----------------------------------------
| step    20 | loss  0.679236 | train_step_acc  0.859375 | valid_acc  0.85963

In [9]:
k_fold_predict = np.zeros([1596, 3])

for i in range(5):
    model_predict = CustomModel(model_ckpt)
    model_predict.load_state_dict(torch.load('models/save_model_{}.pkl'.format(i)))
    model_predict = model_predict.to(device)
    result_i = predict('datasets/KUAKE-QQR_test.json', token, model_predict, device)
    k_fold_predict += (result_i / 5)

In [10]:
joblib.dump(k_fold_predict, 'predict/electra_base.pkl')

['predict/electra_base.pkl']