In [1]:
import math
import numpy as np
import pandas as pd

import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore.train.model import Model
from mindspore.nn.metrics import Accuracy
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor

from src.config import cfg
from src.textcnn import TextCNN
from src.dataset import MovieReview

MindSpore version 1.1.1 and "topi" wheel package version 0.6.0 does not match, reference to the match info on: https://www.mindspore.cn/install




In [2]:
# Configure the MindSpore execution context: graph mode, on the device
# (target and id) named in the project config.
context.set_context(
    mode=context.GRAPH_MODE,
    device_target=cfg.device_target,
    device_id=cfg.device_id,
)

In [3]:
# Dataset helper over the data at cfg.data_path; `split=0.9` is passed through
# to MovieReview (presumably the train fraction -- confirm in src/dataset.py).
instance = MovieReview(
    root_dir=cfg.data_path,
    maxlen=cfg.word_len,
    split=0.9,
)

# Training dataset built by the helper.
dataset = instance.create_train_dataset(
    batch_size=cfg.batch_size,
    epoch_size=cfg.epoch_size,
)

# Size of the training dataset; reused below as the per-epoch step count for
# the learning-rate schedule, the checkpoint interval and the time monitor.
batch_num = dataset.get_dataset_size()

In [4]:
# Per-step learning-rate schedule handed to the optimizer: one entry per
# training step, grouped into three phases of whole epochs
# (warm-up -> constant -> shrink).
#
# The epoch index must be the OUTER loop so that every one of the `batch_num`
# steps inside an epoch gets the same value. With the step loop on the outside
# the list cycles through all epoch values on consecutive steps instead of
# ramping epoch by epoch.
warmup_epochs = math.floor(cfg.epoch_size / 5)
shrink_epochs = math.floor(cfg.epoch_size * 3 / 5)
normal_epochs = cfg.epoch_size - math.floor(cfg.epoch_size / 5) - math.floor(cfg.epoch_size * 2 / 5)

# Linear ramp up to 1e-3 over the warm-up epochs. When warmup_epochs == 0 the
# comprehension is empty, so the division below is never evaluated.
warm_up = [1e-3 / warmup_epochs * (i + 1)
           for i in range(warmup_epochs)
           for _ in range(batch_num)]

# Constant base rate for the middle epochs.
normal_run = [1e-3 for _ in range(normal_epochs * batch_num)]

# Decay phase: 1e-3/16, 1e-3/32, 1e-3/48, ... one value per epoch.
# NOTE(review): this phase spans floor(3/5 * epochs) while `normal_epochs`
# reserves only floor(2/5 * epochs) for it, so the list can be longer than
# epoch_size * batch_num. Presumably the surplus tail is never consumed --
# confirm whether 3/5 or 2/5 is intended.
shrink = [1e-3 / (16 * (i + 1))
          for i in range(shrink_epochs)
          for _ in range(batch_num)]

learning_rate = warm_up + normal_run + shrink

In [5]:
# Instantiate the TextCNN model; the vocabulary size comes from the dataset
# helper, all other dimensions from the project config.
net = TextCNN(
    vocab_len=instance.get_dict_len(),
    word_len=cfg.word_len,
    num_classes=cfg.num_classes,
    vec_length=cfg.vec_length,
)

In [6]:
# Optional warm start: when cfg.pre_trained is set, load the checkpoint at
# cfg.checkpoint_path and copy its parameters into `net` before training.
if cfg.pre_trained:
    param_dict = load_checkpoint(
        cfg.checkpoint_path
    )
    load_param_into_net(
        net,
        param_dict,
    )

In [7]:
# Adam over the parameters of `net` that require gradients, driven by the
# per-step `learning_rate` list built above.
opt = nn.Adam(
    filter(lambda param: param.requires_grad, net.get_parameters()),
    learning_rate=learning_rate,
    weight_decay=cfg.weight_decay,
)

# Softmax cross-entropy loss constructed with sparse=True.
loss = nn.SoftmaxCrossEntropyWithLogits(
    sparse=True,
)

In [8]:
# Bundle network, loss, optimizer and an accuracy metric (reported under the
# key 'acc') into a high-level Model wrapper.
model = Model(
    net,
    loss_fn=loss,
    optimizer=opt,
    metrics={'acc': Accuracy()},
)

In [9]:
# Checkpointing: the save interval is half of the total step count
# (epoch_size * batch_num / 2, truncated to int); the number of retained
# files comes from the project config.
config_ck = CheckpointConfig(
    save_checkpoint_steps=int(cfg.epoch_size*batch_num/2),
    keep_checkpoint_max=cfg.keep_checkpoint_max,
)

# Timing callback, given the per-epoch step count.
time_cb = TimeMonitor(
    data_size=batch_num,
)

# Checkpoints are written under ./ckpt with the "train_textcnn" prefix.
ckpt_save_dir = "./ckpt"
ckpoint_cb = ModelCheckpoint(
    prefix="train_textcnn",
    directory=ckpt_save_dir,
    config=config_ck,
)

# Loss-reporting callback.
loss_cb = LossMonitor()

In [10]:
# Train for cfg.epoch_size epochs with the timing, checkpoint and loss
# callbacks attached, then print a completion marker.
model.train(
    cfg.epoch_size,
    dataset,
    callbacks=[time_cb, ckpoint_cb, loss_cb],
)
print("train success")

epoch: 1 step: 596, loss is 0.04297541
epoch time: 52296.017 ms, per step time: 87.745 ms
epoch: 2 step: 596, loss is 0.0065871133
epoch time: 4298.849 ms, per step time: 7.213 ms
epoch: 3 step: 596, loss is 0.0002644311
epoch time: 4260.524 ms, per step time: 7.149 ms
epoch: 4 step: 596, loss is 0.0017103986
epoch time: 4296.318 ms, per step time: 7.209 ms
train success


In [11]:
# Validation
# Rebuild the dataset helper and the training dataset with the same arguments
# as the training section, and refresh `batch_num` from it.
# NOTE(review): if MovieReview splits randomly, re-instantiating it here may
# produce a different train/test partition than the one used for training --
# confirm in src/dataset.py.
instance = MovieReview(
    root_dir=cfg.data_path,
    maxlen=cfg.word_len,
    split=0.9,
)
dataset = instance.create_train_dataset(
    batch_size=cfg.batch_size,
    epoch_size=cfg.epoch_size,
)
batch_num = dataset.get_dataset_size()

In [12]:
# Checkpoint file to evaluate. The name uses the "train_textcnn" prefix and
# ./ckpt directory configured for ModelCheckpoint; the "4_596" suffix matches
# the last epoch/step reported by the training log (presumably the final
# checkpoint of that run -- adjust if epoch_size or the dataset size changes).
checkpoint_path = "./ckpt/train_textcnn-4_596.ckpt"

In [13]:
# Evaluate a saved checkpoint on the test split.

# Test dataset built by the dataset helper.
dataset = instance.create_test_dataset(batch_size=cfg.batch_size)

# Build the evaluation network BEFORE the optimizer. The optimizer is created
# from `net.get_parameters()`, so constructing it first would bind it to
# whatever `net` was previously defined (the trained network from the cells
# above), or raise NameError on a fresh kernel.
net = TextCNN(vocab_len=instance.get_dict_len(), word_len=cfg.word_len,
              num_classes=cfg.num_classes, vec_length=cfg.vec_length)
opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()),
              learning_rate=0.001, weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)

# Prefer the notebook-level `checkpoint_path`; fall back to the path from the
# project config when it is set to None.
if checkpoint_path is not None:
    param_dict = load_checkpoint(checkpoint_path)
    print("load checkpoint from [{}].".format(checkpoint_path))
else:
    param_dict = load_checkpoint(cfg.checkpoint_path)
    print("load checkpoint from [{}].".format(cfg.checkpoint_path))

# Copy the checkpoint weights into the fresh network and switch it out of
# training mode before wrapping it for evaluation.
load_param_into_net(net, param_dict)
net.set_train(False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc': Accuracy()})

# Run evaluation; the result is a dict keyed by metric name ('acc').
acc = model.eval(dataset)
print("accuracy: ", acc)

load checkpoint from [./ckpt/train_textcnn-4_596.ckpt].
accuracy:  {'acc': 0.76171875}
