In [1]:
import os
import sys
sys.path.append('.')
import argparse
import torch
from transformers import BertTokenizer, BertForSequenceClassification, AlbertForSequenceClassification, \
    BertForTokenClassification, AlbertForTokenClassification

from cblue.data import STSDataProcessor, STSDataset, QICDataset, QICDataProcessor, QQRDataset, \
    QQRDataProcessor, QTRDataset, QTRDataProcessor, CTCDataset, CTCDataProcessor, EEDataset, EEDataProcessor
from cblue.trainer import STSTrainer, QICTrainer, QQRTrainer, QTRTrainer, CTCTrainer, EETrainer
from cblue.utils import init_logger, seed_everything
from cblue.models import ZenConfig, ZenNgramDict, Z
enForSequenceClassification, ZenForTokenClassification

In [2]:
TASK_DATASET_CLASS = {
    'ee': (EEDataset, EEDataProcessor),
    'ctc': (CTCDataset, CTCDataProcessor),
    'sts': (STSDataset, STSDataProcessor),
    'qqr': (QQRDataset, QQRDataProcessor),
    'qtr': (QTRDataset, QTRDataProcessor),
    'qic': (QICDataset, QICDataProcessor)
}

TASK_TRAINER = {
    'ee': EETrainer,
    'ctc': CTCTrainer,
    'sts': STSTrainer,
    'qic': QICTrainer,
    'qqr': QQRTrainer,
    'qtr': QTRTrainer
}

MODEL_CLASS = {
    'bert': (BertTokenizer, BertForSequenceClassification),
    'roberta': (BertTokenizer, BertForSequenceClassification),
    'albert': (BertTokenizer, AlbertForSequenceClassification),
    'zen': (BertTokenizer, ZenForSequenceClassification)
}

TOKEN_MODEL_CLASS = {
    'bert': (BertTokenizer, BertForTokenClassification),
    'roberta': (BertTokenizer, BertForTokenClassification),
    'albert': (BertTokenizer, AlbertForTokenClassification),
    'zen': (BertTokenizer, ZenForTokenClassification)
}

In [4]:
parser = argparse.ArgumentParser()
args = parser.parse_args([])

In [17]:
args.task_name = "ee"
args.model_type = "bert"
args.seed = 14
args.model_dir, args.model_name = "data/model_data", "chinese-bert-wwm-ext"
args.data_dir="CBLUEDatasets"
args.max_length = 128
args.output_dir="data/output"

In [6]:
tokenizer_class, model_class = MODEL_CLASS[args.model_type]
dataset_class, data_processor_class = TASK_DATASET_CLASS[args.task_name]
trainer_class = TASK_TRAINER[args.task_name]

if args.task_name == 'ee':
    tokenizer_class, model_class = TOKEN_MODEL_CLASS[args.model_type]

In [10]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
args.device = device
seed_everything(args.seed)

In [18]:
logger = init_logger(os.path.join(args.output_dir, f'{args.task_name}_{args.model_name}.log'))

# train

In [19]:
args.do_train = True
if args.do_train:
    tokenizer = tokenizer_class.from_pretrained(os.path.join(args.model_dir, args.model_name))

    # compatible with 'ZEN' model
    ngram_dict = None
    if args.model_type == 'zen':
        ngram_dict = ZenNgramDict(os.path.join(args.model_dir, args.model_name), tokenizer=tokenizer)

    data_processor = data_processor_class(root=args.data_dir)
    train_samples = data_processor.get_train_sample()
    eval_samples = data_processor.get_dev_sample()

    if args.task_name == 'ee' or args.task_name == 'ctc':
        train_dataset = dataset_class(train_samples, data_processor, tokenizer, mode='train',
                                      model_type=args.model_type, ngram_dict=ngram_dict, max_length=args.max_length)
        eval_dataset = dataset_class(eval_samples, data_processor, tokenizer, mode='eval',
                                     model_type=args.model_type, ngram_dict=ngram_dict, max_length=args.max_length)
    else:
        train_dataset = dataset_class(train_samples, data_processor, mode='train')
        eval_dataset = dataset_class(eval_samples, data_processor, mode='eval')

    model = model_class.from_pretrained(os.path.join(args.model_dir, args.model_name),
                                        num_labels=data_processor.num_labels)

    trainer = trainer_class(args=args, model=model, data_processor=data_processor,
                            tokenizer=tokenizer, train_dataset=train_dataset, eval_dataset=eval_dataset,
                            logger=logger, model_class=model_class, ngram_dict=ngram_dict)

    global_step, best_step = trainer.train()

Some weights of BertForTokenClassification were not initialized from the model checkpoint at data/model_data/chinese-bert-wwm-ext and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


AttributeError: 'Namespace' object has no attribute 'train_batch_size'