In [None]:
import mindspore
from mindspore import nn
from mindspore.dataset import text, GeneratorDataset, transforms

from mindnlp import load_dataset
from mindnlp.engine import Trainer, Evaluator
from mindnlp.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp.metrics import Accuracy
from mindnlp.transformers import AutoModelForSequenceClassification
from mindnlp.transformers import AutoTokenizer
from mindnlp.transformers.models.bert.modeling_bert import BertDualForSequenceClassification

In [None]:
# Tokenizer for the 'bert-base-cased' checkpoint.
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
# Two-label sequence classifier loaded from the same checkpoint; in this cell
# only its config is passed on.
real_model = AutoModelForSequenceClassification.from_pretrained(
    'bert-base-cased',
    num_labels=2,
)
# Dual model instantiated from the classifier's config.
# NOTE(review): only the config is handed over here, not the loaded weights --
# confirm whether the dual model is expected to start from pretrained weights.
model = BertDualForSequenceClassification(real_model.config)

In [None]:
# Load the IMDB train and test splits, then bind each split to its own name.
imdb_ds = load_dataset('imdb', split=['train', 'test'])
imdb_train, imdb_test = imdb_ds['train'], imdb_ds['test']

In [None]:
def process_dataset(dataset, tokenizer, max_seq_len=256, batch_size=32, shuffle=False):
    """Tokenize, cast labels and batch a text-classification dataset.

    Args:
        dataset: Dataset with a ``text`` column and a ``label`` column that
            supports ``shuffle``, ``map``, ``batch`` and ``padded_batch``.
        tokenizer: Callable tokenizer whose result provides ``input_ids``,
            ``token_type_ids`` and ``attention_mask``; it must also expose
            ``pad_token_id``.
        max_seq_len: Maximum sequence length; longer inputs are truncated.
        batch_size: Number of samples per batch. Also used as the shuffle
            buffer size when ``shuffle`` is true.
        shuffle: Whether to shuffle the dataset before tokenizing.

    Returns:
        The dataset after shuffling (optional), tokenization, label casting
        and batching.
    """
    # The padding strategy depends on the device: on Ascend every sample is
    # padded to ``max_seq_len`` by the tokenizer and batched as-is; on other
    # targets samples are padded while batching instead.
    is_ascend = mindspore.get_context('device_target') == 'Ascend'

    # Build the tokenizer arguments once instead of duplicating the call.
    tokenizer_kwargs = {'truncation': True, 'max_length': max_seq_len}
    if is_ascend:
        tokenizer_kwargs['padding'] = 'max_length'

    # The parameter is deliberately not called ``text`` so it does not shadow
    # the ``text`` module imported from ``mindspore.dataset``.
    def tokenize(sample_text):
        tokenized = tokenizer(sample_text, **tokenizer_kwargs)
        return tokenized['input_ids'], tokenized['token_type_ids'], tokenized['attention_mask']

    if shuffle:
        dataset = dataset.shuffle(batch_size)

    # Tokenize the "text" column into the three model input columns.
    dataset = dataset.map(operations=[tokenize], input_columns="text",
                          output_columns=['input_ids', 'token_type_ids', 'attention_mask'])
    # Cast the labels to int32 and expose them under the name "labels".
    dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32),
                          input_columns="label", output_columns="labels")

    if is_ascend:
        # Samples already share the same length, so plain batching suffices.
        return dataset.batch(batch_size)

    # Pad each batch; a pad shape of ``None`` leaves the target length to
    # ``padded_batch``.
    pad_info = {
        'input_ids': (None, tokenizer.pad_token_id),
        'token_type_ids': (None, 0),
        'attention_mask': (None, 0),
    }
    return dataset.padded_batch(batch_size, pad_info=pad_info)

In [None]:
# Split the original training set 70/30 into training and validation subsets.
# The result is bound to a new name instead of rebinding `imdb_train`, so
# re-running this cell always splits the full training set rather than an
# already reduced subset.
imdb_train_subset, imdb_val = imdb_train.split([0.7, 0.3])

# Only the training subset is shuffled; validation and test keep their order.
dataset_train = process_dataset(imdb_train_subset, tokenizer, shuffle=True)
dataset_val = process_dataset(imdb_val, tokenizer)
dataset_test = process_dataset(imdb_test, tokenizer)

In [None]:
# Adam over the model's trainable parameters.
optimizer = nn.Adam(
    model.trainable_params(),
    learning_rate=2e-5,
)
# Accuracy metric instance.
metric = Accuracy()
# Callbacks writing checkpoints into the 'checkpoint' directory: a periodic
# one (epochs=1, at most two files kept) and one tracking the best model.
ckpoint_cb = CheckpointCallback(
    save_path='checkpoint',
    ckpt_name='dual_bert_imdb_finetune',
    epochs=1,
    keep_checkpoint_max=2,
)
best_model_cb = BestModelCallback(
    save_path='checkpoint',
    ckpt_name='dual_bert_imdb_finetune_best',
    auto_load=True,
)

In [None]:
# Train on the training split, evaluating against the validation split with
# the accuracy metric; both checkpoint callbacks are attached.
trainer = Trainer(
    network=model,
    train_dataset=dataset_train,
    eval_dataset=dataset_val,
    metrics=metric,
    epochs=13,
    optimizer=optimizer,
    callbacks=[ckpoint_cb, best_model_cb],
    jit=False,
)
trainer.run(tgt_columns="labels")

# Evaluate the trained network on the held-out test split.
evaluator = Evaluator(
    network=model,
    eval_dataset=dataset_test,
    metrics=metric,
)
evaluator.run(tgt_columns="labels")