In [1]:
import mindspore
from mindspore.dataset import GeneratorDataset, transforms
from mindnlp.transforms import NezhaTokenizer, PadTransform

  from tqdm.autonotebook import tqdm


In [1]:
!wget https://raw.githubusercontent.com/JackHCC/Chinese-Text-Classification-PyTorch/master/THUCNews/data/test.txt
!wget https://raw.githubusercontent.com/JackHCC/Chinese-Text-Classification-PyTorch/master/THUCNews/data/dev.txt
!wget https://raw.githubusercontent.com/JackHCC/Chinese-Text-Classification-PyTorch/master/THUCNews/data/train.txt

--2023-06-26 20:04:12--  https://raw.githubusercontent.com/JackHCC/Chinese-Text-Classification-PyTorch/master/THUCNews/data/test.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... 

connected.
HTTP request sent, awaiting response... 200 OK
Length: 551596 (539K) [text/plain]
Saving to: ‘test.txt’


2023-06-26 20:04:24 (53.5 KB/s) - ‘test.txt’ saved [551596/551596]

--2023-06-26 20:04:24--  https://raw.githubusercontent.com/JackHCC/Chinese-Text-Classification-PyTorch/master/THUCNews/data/dev.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 551313 (538K) [text/plain]
Saving to: ‘dev.txt’


2023-06-26 20:04:27 (576 KB/s) - ‘dev.txt’ saved [551313/551313]

--2023-06-26 20:04:27--  https://raw.githubusercontent.com/JackHCC/Chinese-Text-Classification-PyTorch/master/THUCNews/data/train.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercon

In [9]:
# prepare dataset
class SentimentDataset:
    """Sentiment Dataset"""

    def __init__(self, path):
        self.path = path
        self._labels, self._text_a = [], []
        self._load()

    def _load(self):
        with open(self.path, "r", encoding="utf-8") as f:
            dataset = f.read()
        lines = dataset.split("\n")
        for line in lines[:-1]:
            text_a, label = line.split("\t")
            self._labels.append(int(label))
            self._text_a.append(text_a)

    def __getitem__(self, index):
        return self._labels[index], self._text_a[index]

    def __len__(self):
        return len(self._labels)

In [10]:
def process_dataset(source, tokenizer, pad_value, max_seq_len=64, batch_size=32, shuffle=True):
    column_names = ["label", "text_a"]
    rename_columns = ["label", "input_ids"]
    
    dataset = GeneratorDataset(source, column_names=column_names, shuffle=shuffle)
    # transforms
    pad_op = PadTransform(max_seq_len, pad_value=pad_value)
    type_cast_op = transforms.TypeCast(mindspore.int32)
    
    # map dataset
    dataset = dataset.map(operations=[tokenizer, pad_op], input_columns="text_a")
    dataset = dataset.map(operations=[type_cast_op], input_columns="label")
    # rename dataset
    dataset = dataset.rename(input_columns=column_names, output_columns=rename_columns)
    # batch dataset
    dataset = dataset.batch(batch_size)

    return dataset

In [7]:
tokenizer = NezhaTokenizer.from_pretrained('nezha-cn-base')
pad_value = tokenizer.token_to_id('[PAD]')

In [11]:
dataset_train = process_dataset(SentimentDataset("train.txt"), tokenizer, pad_value)
dataset_val = process_dataset(SentimentDataset("dev.txt"), tokenizer, pad_value)
dataset_test = process_dataset(SentimentDataset("test.txt"), tokenizer, pad_value, shuffle=False)

In [12]:
import json
from mindnlp.models.nezha import NezhaConfig, NezhaForSequenceClassification

with open("../ckpt_ms/nezha-cn-base/config.json") as f:
    config = json.load(f)
config = NezhaConfig(**config)
config.num_labels = 10
model = NezhaForSequenceClassification(config)

import mindspore as ms

params_dict = ms.load_checkpoint("nezha_classfication_epoch_2.ckpt")
params_not_load = ms.load_param_into_net(model, params_dict)

In [7]:
from mindnlp._legacy.amp import auto_mixed_precision
from mindspore import nn, ops
model = auto_mixed_precision(model, 'O1')

loss_fn = nn.CrossEntropyLoss()
optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)

def forward(input_ids, label):
    outputs = model(input_ids)
    loss = loss_fn(outputs[0], label)
    return loss

grad_fn = ops.value_and_grad(forward, None, model.trainable_params())

def train_step(input_ids, label):
    loss, grads = grad_fn(input_ids, label)
    optimizer(grads)
    return loss



In [8]:
import os
from tqdm import tqdm
from mindspore.train.serialization import save_checkpoint
os.environ["CUDA_VISIBLE_DEVICES"] = '2'

total = dataset_train.get_dataset_size()
for epoch in range(3):
    with tqdm(total=total) as progress:
        progress.set_description(f'Epoch {epoch}')
        loss_total = 0
        cur_step_nums = 0
        for batch, (label, data) in enumerate(dataset_train.create_tuple_iterator()):
            loss = train_step(data, label)
            loss_total += loss
            cur_step_nums += 1
            progress.set_postfix(loss=loss_total/cur_step_nums)
            progress.update(1)
        save_checkpoint(model, f"nezha_classfication_epoch_{epoch}.ckpt")

Epoch 0: 100%|██████████| 5625/5625 [24:31<00:00,  3.82it/s, loss=0.27862355]
Epoch 1: 100%|██████████| 5625/5625 [24:42<00:00,  3.80it/s, loss=0.11460448] 
Epoch 2: 100%|██████████| 5625/5625 [24:44<00:00,  3.79it/s, loss=0.06515992]   


In [16]:
from mindspore import Tensor

def predict(text, label=None):
    label_map = {0: '财经', 1: '房产', 2: '股票',
                 3: '教育', 4: '科技', 5: '社会',
                 6: '时政', 7: '运动', 8: '游戏', 9:'娱乐' 
                }
    text_tokenized = Tensor([tokenizer.encode(text).ids])
    logits = model(text_tokenized)
    predict_label = logits[0].asnumpy().argmax()
    info = f"inputs: '{text}', predict: '{label_map[predict_label]}'"
    if label is not None:
        info += f" , label: '{label_map[label]}'"
    print(info)

In [14]:
from mindnlp.metrics import Accuracy
from tqdm import tqdm
metric = Accuracy()
cur_step_nums = 0
acc_total = 0
total = dataset_val.get_dataset_size()
with tqdm(total=total) as progress:
    for label, data in dataset_val.create_tuple_iterator():
        cur_step_nums += 1
        pred = model(data)[0]
        metric.update(pred, label)
        acc_total += metric.eval()
        progress.set_postfix(acc=acc_total/cur_step_nums)
        progress.update(1)

100%|██████████| 313/313 [00:41<00:00,  7.63it/s, acc=0.941]


In [17]:
predict("原神是一款由米哈游开发的开放世界冒险")

inputs: '原神是一款由米哈游开发的开放世界冒险', predict: '游戏'
