In [1]:
import os

import mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn, context

from mindnlp.transforms import PadTransform
from mindnlp.models import BertModel, BertConfig
from mindnlp.transforms.tokenizers import BertTokenizer

from mindnlp.engine import Trainer, Evaluator
from mindnlp.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp.metrics import Accuracy

In [2]:
# prepare dataset
class SentimentDataset:
    """Random-access view over a tab-separated sentiment file.

    The file is expected to hold one header row followed by one record per
    line, each record being an integer label, a tab, and the text. Indexing
    returns a ``(label, text)`` tuple.
    """

    def __init__(self, path):
        """Store ``path`` and eagerly load every record into memory."""
        self.path = path
        self._labels, self._text_a = [], []
        self._load()

    def _load(self):
        """Parse the file at ``self.path`` into the label and text lists.

        The first line is treated as a header and skipped. Blank lines are
        ignored, so the result does not depend on whether the file ends with
        a newline. Only the first tab separates label from text, so tabs
        inside the text are kept.

        Raises:
            ValueError: if a non-blank record contains no tab, or its label
                cannot be converted to ``int``.
        """
        with open(self.path, "r", encoding="utf-8") as f:
            # Skip the header row; the default keeps an empty file from raising.
            next(f, None)
            for raw_line in f:
                line = raw_line.rstrip("\n")
                if not line.strip():
                    # Blank or whitespace-only line (e.g. trailing newline padding).
                    continue
                label, text_a = line.split("\t", 1)
                self._labels.append(int(label))
                self._text_a.append(text_a)

    def __getitem__(self, index):
        """Return the ``(label, text)`` pair stored at ``index``."""
        return self._labels[index], self._text_a[index]

    def __len__(self):
        """Return the number of loaded records."""
        return len(self._labels)

In [3]:
# download dataset
# Fetch the emotion-detection corpus archive and unpack it relative to the
# current working directory. Per the recorded tar listing, extraction yields
# data/train.tsv, data/dev.tsv, data/test.tsv, data/infer.tsv and data/vocab.txt.
# NOTE(review): `-O` rewrites emotion_detection.tar.gz on every run of this
# cell, so re-running the notebook re-downloads the archive.
!wget https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz -O emotion_detection.tar.gz
!tar xvf emotion_detection.tar.gz

--2023-04-21 17:41:41--  https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz
Connecting to 172.20.106.122:7890... connected.
Proxy request sent, awaiting response... 200 OK
Length: 1710581 (1.6M) [application/x-gzip]
Saving to: ‘emotion_detection.tar.gz’


2023-04-21 17:41:43 (2.57 MB/s) - ‘emotion_detection.tar.gz’ saved [1710581/1710581]

data/
data/test.tsv
data/infer.tsv
data/dev.tsv
data/train.tsv
data/vocab.txt


In [4]:
# Fetch the bert-base-chinese vocabulary into ./vocab.txt; this is the file the
# later Vocab.from_file("vocab.txt") call reads.
# NOTE(review): the dataset archive also ships data/vocab.txt (see the tar
# listing); the notebook uses this separately downloaded file instead —
# presumably so token ids match the 'bert-base-chinese' pretrained weights; confirm.
!wget https://download.mindspore.cn/toolkits/mindnlp/models/bert/bert-base-chinese/vocab.txt -O vocab.txt

--2023-04-21 17:41:43--  https://download.mindspore.cn/toolkits/mindnlp/models/bert/bert-base-chinese/vocab.txt
Connecting to 172.20.106.122:7890... connected.
Proxy request sent, awaiting response... 200 OK
Length: 109540 (107K) [text/plain]
Saving to: ‘vocab.txt’


2023-04-21 17:41:45 (218 KB/s) - ‘vocab.txt’ saved [109540/109540]



In [5]:
def process_dataset(source, tokenizer, pad_value, max_seq_len=64, batch_size=32, shuffle=True):
    """Turn a (label, text) source into batches of (label, input_ids).

    Args:
        source: object handed to GeneratorDataset; it is read as two columns,
            "label" and "text_a".
        tokenizer: operation applied to the "text_a" column before padding.
        pad_value: fill value given to PadTransform.
        max_seq_len: length given to PadTransform.
        batch_size: number of rows per batch.
        shuffle: forwarded to GeneratorDataset.

    Returns:
        The batched dataset whose columns are "label" (cast to int32) and
        "input_ids" (the tokenized, padded former "text_a" column).
    """
    raw_columns = ["label", "text_a"]
    model_columns = ["label", "input_ids"]

    pipeline = GeneratorDataset(source, column_names=raw_columns, shuffle=shuffle)

    # Per-column operations.
    # NOTE(review): whether PadTransform also truncates inputs longer than
    # max_seq_len is not visible here — confirm against mindnlp.
    pad_to_max_len = PadTransform(max_seq_len, pad_value=pad_value)
    label_to_int32 = transforms.TypeCast(mindspore.int32)

    # Tokenize + pad the text, cast the label, then give the text column the
    # name the model's construct() uses for its first argument.
    return (
        pipeline
        .map(operations=[tokenizer, pad_to_max_len], input_columns="text_a")
        .map(operations=[label_to_int32], input_columns="label")
        .rename(input_columns=raw_columns, output_columns=model_columns)
        .batch(batch_size)
    )

In [6]:
# Build the vocabulary from the vocab.txt file downloaded in the previous cell.
vocab = text.Vocab.from_file("vocab.txt")
# Size of the vocabulary mapping. Not referenced by any active code in the
# cells shown here (only by a commented-out line in the config cell).
vocab_size = len(vocab.vocab())

# Id of the '[PAD]' token; passed to process_dataset as the padding value.
pad_value = vocab.tokens_to_ids('[PAD]')
# Tokenizer built on the same vocabulary object, so token ids and the pad id
# come from one mapping.
tokenizer = BertTokenizer(vocab=vocab)

In [7]:
# Build batched pipelines for the three splits using process_dataset's default
# max_seq_len and batch_size. Train and dev use the default shuffle=True; only
# the test split is built with shuffle=False.
dataset_train = process_dataset(SentimentDataset("data/train.tsv"), tokenizer, pad_value)
dataset_val = process_dataset(SentimentDataset("data/dev.tsv"), tokenizer, pad_value)
dataset_test = process_dataset(SentimentDataset("data/test.tsv"), tokenizer, pad_value, shuffle=False)

In [8]:
# define model
class BertForSequenceClassification(nn.Cell):
    """BERT encoder followed by a single dense classification layer.

    Args:
        config: configuration object providing ``num_labels`` and
            ``hidden_size``; it is also forwarded to
            ``BertModel.from_pretrained``.
    """

    def __init__(self, config):
        # nn.Cell's constructor takes (auto_prefix, flags), not a model
        # config, so nothing is forwarded to it.
        super().__init__()
        self.num_labels = config.num_labels
        self.config = config
        # Encoder loaded from the 'bert-base-chinese' pretrained weights.
        self.bert = BertModel.from_pretrained('bert-base-chinese', config=config)

        # Maps a hidden_size-wide vector to one logit per label.
        self.classifier = nn.Dense(config.hidden_size, self.num_labels)

    def construct(self, input_ids, attention_mask=None, token_type_ids=None,
                  position_ids=None, head_mask=None):
        """Run the encoder and classify its pooled output.

        All arguments are passed straight through to ``self.bert``.

        Returns:
            The output of the dense classifier (one logit per label) applied
            to element 1 of the encoder's output.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask
        )
        # Element 1 is used as the pooled sequence representation —
        # presumably the pooled [CLS] vector; confirm against BertModel.
        pooled_output = outputs[1]
        logits = self.classifier(pooled_output)
        return logits

In [9]:
# set bert config and define parameters for training
# Start from the 'bert-base-chinese' pretrained configuration and set
# num_labels=3, which sizes the classifier layer of BertForSequenceClassification.
config = BertConfig.from_pretrained('bert-base-chinese', num_labels=3)
# The vocabulary size is not passed explicitly here; only num_labels is
# overridden on top of the pretrained configuration.
model_instance = BertForSequenceClassification(config)

# Put the network in training mode before building the optimizer and trainer.
model_instance.set_train(True)

# Cross-entropy loss on the classifier logits; Adam over all trainable
# parameters with a fixed learning rate of 2e-5.
loss = nn.CrossEntropyLoss()
optimizer = nn.Adam(model_instance.trainable_params(), learning_rate=2e-5)

# The same Accuracy object is reused later by the test-set Evaluator.
metric = Accuracy()

# define callbacks to save checkpoints
# CheckpointCallback: per the recorded training log, a 'sentiment_model'
# checkpoint is written every epoch and the stored-checkpoint limit is reported
# as reached from epoch 2 onward (keep_checkpoint_max=2).
# BestModelCallback: per the log, best_so_far.ckpt is saved whenever the
# evaluation accuracy improves, and with auto_load=True the best model is
# loaded back at the end of training.
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='sentiment_model', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', auto_load=True)

# Train for 10 epochs on dataset_train, evaluating on dataset_val with the
# accuracy metric.
# NOTE(review): jit=True presumably enables compiled execution of the training
# step — confirm against the mindnlp Trainer documentation.
trainer = Trainer(network=model_instance, train_dataset=dataset_train,
                  eval_dataset=dataset_val, metrics=metric,
                  epochs=10, loss_fn=loss, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb],
                  jit=True)



In [10]:
# start training
# tgt_columns="label" names the dataset column used as the training target
# (the column produced by process_dataset alongside "input_ids").
# NOTE(review): the recorded output starts with "The train will start from the
# checkpoint saved in checkpoint.", which suggests the state of the
# 'checkpoint' directory influences the run — confirm whether it must be
# cleared for a reproducible fresh run.
trainer.run(tgt_columns="label")


The train will start from the checkpoint saved in checkpoint.



Epoch 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 302/302 [01:05<00:00,  4.61it/s, loss=0.54870176]


Checkpoint: sentiment_model_epoch_0.ckpt has been saved in epoch:0.


Evaluate: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:10<00:00,  3.16it/s]


Evaluate Score: {'Accuracy': 0.9009259259259259}
---------------Best Model: best_so_far.ckpt has been saved in epoch: 0.---------------


Epoch 1: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 302/302 [00:45<00:00,  6.64it/s, loss=0.25947005]


Checkpoint: sentiment_model_epoch_1.ckpt has been saved in epoch:1.


Evaluate: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:01<00:00, 19.91it/s]


Evaluate Score: {'Accuracy': 0.9314814814814815}
---------------Best Model: best_so_far.ckpt has been saved in epoch: 1.---------------


Epoch 2: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 302/302 [00:46<00:00,  6.56it/s, loss=0.19326684]


The maximum number of stored checkpoints has been reached.
Checkpoint: sentiment_model_epoch_2.ckpt has been saved in epoch:2.


Evaluate: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:01<00:00, 19.70it/s]


Evaluate Score: {'Accuracy': 0.9564814814814815}
---------------Best Model: best_so_far.ckpt has been saved in epoch: 2.---------------


Epoch 3: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 302/302 [00:46<00:00,  6.54it/s, loss=0.15251678]


The maximum number of stored checkpoints has been reached.
Checkpoint: sentiment_model_epoch_3.ckpt has been saved in epoch:3.


Evaluate: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:01<00:00, 19.53it/s]


Evaluate Score: {'Accuracy': 0.962037037037037}
---------------Best Model: best_so_far.ckpt has been saved in epoch: 3.---------------


Epoch 4: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 302/302 [00:46<00:00,  6.45it/s, loss=0.118863754]


The maximum number of stored checkpoints has been reached.
Checkpoint: sentiment_model_epoch_4.ckpt has been saved in epoch:4.


Evaluate: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:01<00:00, 19.39it/s]


Evaluate Score: {'Accuracy': 0.9805555555555555}
---------------Best Model: best_so_far.ckpt has been saved in epoch: 4.---------------


Epoch 5: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 302/302 [00:46<00:00,  6.46it/s, loss=0.09339741]


The maximum number of stored checkpoints has been reached.
Checkpoint: sentiment_model_epoch_5.ckpt has been saved in epoch:5.


Evaluate: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:01<00:00, 19.52it/s]


Evaluate Score: {'Accuracy': 0.9842592592592593}
---------------Best Model: best_so_far.ckpt has been saved in epoch: 5.---------------


Epoch 6: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 302/302 [00:46<00:00,  6.47it/s, loss=0.07248868]


The maximum number of stored checkpoints has been reached.
Checkpoint: sentiment_model_epoch_6.ckpt has been saved in epoch:6.


Evaluate: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:01<00:00, 19.54it/s]


Evaluate Score: {'Accuracy': 0.9888888888888889}
---------------Best Model: best_so_far.ckpt has been saved in epoch: 6.---------------


Epoch 7: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 302/302 [00:46<00:00,  6.47it/s, loss=0.059156317]


The maximum number of stored checkpoints has been reached.
Checkpoint: sentiment_model_epoch_7.ckpt has been saved in epoch:7.


Evaluate: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:01<00:00, 19.16it/s]


Evaluate Score: {'Accuracy': 0.9833333333333333}


Epoch 8: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 302/302 [00:46<00:00,  6.52it/s, loss=0.049899463]


The maximum number of stored checkpoints has been reached.
Checkpoint: sentiment_model_epoch_8.ckpt has been saved in epoch:8.


Evaluate: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:01<00:00, 19.57it/s]


Evaluate Score: {'Accuracy': 0.9944444444444445}
---------------Best Model: best_so_far.ckpt has been saved in epoch: 8.---------------


Epoch 9: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 302/302 [00:46<00:00,  6.47it/s, loss=0.042433485]


The maximum number of stored checkpoints has been reached.
Checkpoint: sentiment_model_epoch_9.ckpt has been saved in epoch:9.


Evaluate: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 34/34 [00:01<00:00, 19.51it/s]


Evaluate Score: {'Accuracy': 0.9972222222222222}
---------------Best Model: best_so_far.ckpt has been saved in epoch: 9.---------------
Loading best model from checkpoint with ['Accuracy']: [0.9972222222222222]...
---------------The model is already load the best model from best_so_far.ckpt.---------------


In [11]:
# Score the trained network on the test split, reusing the Accuracy metric
# object defined with the training setup.
# NOTE(review): the training log reports that the best checkpoint
# (best_so_far.ckpt) was loaded at the end of training, so this presumably
# evaluates those weights — confirm.
# NOTE(review): the recorded test accuracy (~0.889) is well below the final dev
# accuracy (~0.997); worth checking for overfitting or train/dev overlap.
evaluator = Evaluator(network=model_instance, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="label")

Evaluate: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 33/33 [00:03<00:00,  8.68it/s]

Evaluate Score: {'Accuracy': 0.888996138996139}



