In [1]:
import os

import mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn, context
from mindnlp.engine.callbacks import CheckpointCallback
from mindnlp.transforms import PadTransform
from mindnlp.models import BertModel, BertConfig
from mindnlp.transforms.tokenizers import BertTokenizer
from mindnlp.engine import Trainer, Accuracy

In [2]:
# set to GPU mode
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

# prepare dataset
class SentimentDataset:
    """Sentiment Dataset"""

    def __init__(self, path):
        self.path = path
        self._labels, self._text_a = [], []
        self._load()

    def _load(self):
        with open(self.path, "r", encoding="utf-8") as f:
            dataset = f.read()
        lines = dataset.split("\n")
        for line in lines[1:-1]:
            label, text_a = line.split("\t")
            self._labels.append(int(label))
            self._text_a.append(text_a)

    def __getitem__(self, index):
        return self._labels[index], self._text_a[index]

    def __len__(self):
        return len(self._labels)

# please download the dataset from "https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz"
# and extract it to the "data" folder
column_names = ["label", "text_a"]
dataset_train = GeneratorDataset(source=SentimentDataset("data/train.tsv"),
                                 column_names=column_names, shuffle=True)
dataset_val = GeneratorDataset(source=SentimentDataset("data/dev.tsv"),
                               column_names=column_names, shuffle=True)
dataset_test = GeneratorDataset(source=SentimentDataset("data/test.tsv"),
                                column_names=column_names, shuffle=False)

vocab_path = os.path.join("data", "vocab.txt")
vocab = text.Vocab.from_file(vocab_path)
vocab_size = len(vocab.vocab())

pad_value_text = vocab.tokens_to_ids('[PAD]')
tokenizer = BertTokenizer(vocab=vocab)
pad_op_text = PadTransform(max_length=64, pad_value=pad_value_text)
type_cast_op = transforms.TypeCast(mindspore.int32)

dataset_train = dataset_train.map(
    operations=[tokenizer, pad_op_text], input_columns="text_a")
dataset_train = dataset_train.map(
    operations=[type_cast_op], input_columns="label")

dataset_val = dataset_val.map(
    operations=[tokenizer, pad_op_text], input_columns="text_a")
dataset_val = dataset_val.map(operations=[type_cast_op], input_columns="label")

# rename the columns because the model's construct function requires the parameter input_ids
rename_columns = ["label", "input_ids"]
dataset_train = dataset_train.rename(
    input_columns=column_names, output_columns=rename_columns)
dataset_val = dataset_val.rename(
    input_columns=column_names, output_columns=rename_columns)

dataset_train = dataset_train.batch(32)
dataset_val = dataset_val.batch(32)

In [3]:
# define model
class BertForSequenceClassification(nn.Cell):
    """Bert Model for classification tasks"""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.bert = BertModel(config)
        # load the pre-trained parameters
        mindspore.load_param_into_net(self.bert, state_dict)
        self.classifier = nn.Dense(config.hidden_size, self.num_labels)

    def construct(self, input_ids, attention_mask=None, token_type_ids=None,
                  position_ids=None, head_mask=None):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask
        )
        pooled_output = outputs[1]
        logits = self.classifier(pooled_output)
        return logits

# please download the pre-trained model from "https://download.mindspore.cn/toolkits/mindnlp/models/bert/bert-base-chinese.ckpt"
# and put it to the "checkpoints" folder
model_path = os.path.join("checkpoints/bert-base-chinese.ckpt")
state_dict = mindspore.load_checkpoint(model_path)

In [4]:
# set bert config and define parameters for training
config = BertConfig(vocab_size=vocab_size, num_labels=3)
model_instance = BertForSequenceClassification(config)

model_instance.set_train(True)

loss = nn.CrossEntropyLoss(ignore_index=pad_value_text)
optimizer = nn.Adam(model_instance.trainable_params(), learning_rate=1e-5)

metric = Accuracy()

# define callbacks to save checkpoints
ckpoint_cb = CheckpointCallback(
    save_path='sentimentbert_ckpt', epochs=1, keep_checkpoint_max=5)

trainer = Trainer(network=model_instance, train_dataset=dataset_train,
                  eval_dataset=dataset_val, metrics=metric,
                  epochs=5, loss_fn=loss, optimizer=optimizer, callbacks=[ckpoint_cb])

In [5]:
# start training
trainer.run(tgt_columns="label")


The train will start from the checkpoint saved in sentimentbert_ckpt.



Epoch 0: 100%|██████████| 302/302 [01:04<00:00,  4.65it/s, loss=0.3034695] 


Checkpoint: BertForSequenceClassification_epoch_0.ckpt has been saved in epoch:0.


Evaluate: 100%|██████████| 34/34 [00:09<00:00,  3.47it/s]


Evaluate Score: {'Accuracy': 0.7583333333333333}


Epoch 1: 100%|██████████| 302/302 [00:44<00:00,  6.82it/s, loss=0.13073428]


Checkpoint: BertForSequenceClassification_epoch_1.ckpt has been saved in epoch:1.


Evaluate: 100%|██████████| 34/34 [00:01<00:00, 23.31it/s]


Evaluate Score: {'Accuracy': 0.7916666666666666}


Epoch 2: 100%|██████████| 302/302 [00:44<00:00,  6.84it/s, loss=0.08923956] 


Checkpoint: BertForSequenceClassification_epoch_2.ckpt has been saved in epoch:2.


Evaluate: 100%|██████████| 34/34 [00:01<00:00, 23.17it/s]


Evaluate Score: {'Accuracy': 0.825}


Epoch 3: 100%|██████████| 302/302 [00:44<00:00,  6.83it/s, loss=0.06826735] 


Checkpoint: BertForSequenceClassification_epoch_3.ckpt has been saved in epoch:3.


Evaluate: 100%|██████████| 34/34 [00:01<00:00, 24.03it/s]


Evaluate Score: {'Accuracy': 0.825}


Epoch 4: 100%|██████████| 302/302 [00:44<00:00,  6.81it/s, loss=0.049864933]


Checkpoint: BertForSequenceClassification_epoch_4.ckpt has been saved in epoch:4.


Evaluate: 100%|██████████| 34/34 [00:01<00:00, 23.26it/s]

Evaluate Score: {'Accuracy': 0.825}



