In [1]:
# Paths to fill in before running: the folder that holds the cr<N>.csv
# datasets, and the root folder under which trained models are written.
dataset_dir, bert_trained_model_dir = 'DATASET_DIR', 'MODEL_SAVE_DIR'

In [2]:
import os
import shutil

from simpletransformers.classification import ClassificationModel
import pandas as pd
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
from sklearn import model_selection

# Process-wide environment: expose only GPU 0 to this kernel and turn off the
# HuggingFace tokenizers' internal parallelism.
# NOTE(review): CUDA_VISIBLE_DEVICES is set after simpletransformers (and so
# presumably torch) is imported; this works only if CUDA has not been
# initialised yet — consider setting it before the imports.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0',
    'TOKENIZERS_PARALLELISM': 'false',
})

# Hyper-parameters shared by train_for_CR (batch sizes for training and
# evaluation, and the maximum tokenised sequence length).
train_batch_size = eval_batch_size = 16
max_seq_length = 72

In [3]:
def train_for_CR(CR_num, validation_ratio=0.3):
    """Fine-tune ``bert-base-chinese`` as a binary classifier on one CR dataset.

    Reads ``<dataset_dir>/cr{CR_num}.csv`` (no header row), splits it into
    train/validation sets with a fixed ``random_state=5``, trains a
    simpletransformers ``ClassificationModel`` with early stopping, evaluates
    it on the validation split and prints the metrics plus every
    misclassified example.

    Relies on the notebook-level globals ``dataset_dir``,
    ``bert_trained_model_dir``, ``train_batch_size``, ``eval_batch_size`` and
    ``max_seq_length``.

    Args:
        CR_num: Identifier substituted into the dataset file name
            (``cr{CR_num}.csv``) and the output folder (``CR/CR_{CR_num}``).
        validation_ratio: Fraction of rows held out for validation
            (default 0.3).

    Returns:
        The ``result`` object from ``model.eval_model`` — presumably a dict of
        metric name -> value, including the ``f1``, ``acc``, ``precision`` and
        ``recall`` entries requested here.
    """
    # load cr dataset
    CR_datasets_path = os.path.join(dataset_dir, 'cr{}.csv'.format(CR_num))

    # NOTE(review): with header=None the columns are 0 and 1; simpletransformers
    # presumably uses column 0 as the text and column 1 as the label — confirm.
    datasets_df = pd.read_csv(CR_datasets_path, sep=',', header=None)
    train_df, validation_df = model_selection.train_test_split(
        datasets_df, random_state=5, test_size=validation_ratio)

    print('Train dataset size: {}\n Validation dataset size: {}'.format(len(train_df), len(validation_df)))

    output_dir = os.path.join(bert_trained_model_dir, 'CR', 'CR_{}'.format(CR_num))

    # Evaluate about once per epoch: one evaluation per number of full
    # training batches. Clamped to 1 because a training split smaller than one
    # batch would otherwise give 0, which presumably disables step-based
    # evaluation (and with it early stopping) — TODO confirm.
    eval_every_steps = max(1, len(train_df) // train_batch_size)

    model_args = {
        'best_model_dir': os.path.join(output_dir, 'best_model'),
        'early_stopping_consider_epochs': False,
        'early_stopping_delta': 0.02,
        'early_stopping_patience': 3,
        'eval_batch_size': eval_batch_size,
        'evaluate_during_training': True,
        'evaluate_during_training_steps': eval_every_steps,
        'evaluate_during_training_verbose': True,
        'manual_seed': 666,
        'max_seq_length': max_seq_length,
        'num_train_epochs': 15,
        'overwrite_output_dir': True,
        'output_dir': output_dir,
        'save_eval_checkpoints': False,
        'save_model_every_epoch': False,
        'save_steps': -1,
        'train_batch_size': train_batch_size,
        'use_early_stopping': True,
        'use_multiprocessing': False,
        'use_multiprocessing_for_evaluation': False
    }
    model = ClassificationModel('bert', 'bert-base-chinese', num_labels=2, args=model_args)

    # Best-effort removal of artefacts from an earlier run. Each directory is
    # handled on its own so a missing best_model_dir cannot prevent output_dir
    # from being removed, and only filesystem errors are suppressed.
    for stale_dir in (model.args.best_model_dir, model.args.output_dir):
        try:
            shutil.rmtree(stale_dir)
        except FileNotFoundError:
            pass  # nothing left over to clean up
        except OSError as exc:
            print('Could not remove {}: {}'.format(stale_dir, exc))

    print('=============train for CR_{}============='.format(CR_num))

    # model training
    model.train_model(train_df, eval_df=validation_df)

    # model evaluation
    # NOTE(review): this scores the in-memory model after training stops
    # (presumably the last weights, not the checkpoint in best_model_dir), on
    # the same split that drove early stopping — metrics may be optimistic.
    result, model_outputs, wrong_predictions = model.eval_model(validation_df, f1=f1_score, acc=accuracy_score,
                                                                precision=precision_score, recall=recall_score)

    print('=============results for CR_{}============='.format(CR_num))
    print(result)
    for wrong_prediction in wrong_predictions:
        print(wrong_prediction)

    return result