# FastText实现文本分类

在本教程中，我们将在MindSpore中使用`MindRecord`加载并构建文本数据集，用户可以从教程中了解到如何：

- 创建迭代数据集
- 将文本转换为向量
- 对数据进行shuffle等操作

此外，本教程使用N-Gram，即N元语法模型来判断语句单词的构成顺序。N-Gram可以按照字节顺序，将文本内容进行大小为N的划窗操作，最终形成长度为N的字节片段序列。实践中经常使用二元或三元模型，本例通过将`ngram`参数设定为2，将二元模型应用在文本分类案例中。

## 数据处理

点击下载[文本分类AG_NEWS数据集](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/middleclass/ag_news_csv.tgz) ，在教程的同级目录下新建`data`文件夹，将下载好的数据集存放在`data`中。
目录如下：
```
project
│  text_sentiment_ngrams_tutorial.ipynb      
└─data
   │   train.csv
   │   test.csv
```
在进行其他操作之前，需要先安装`sklearn`和`spacy`工具包，并导入所需要的库并进行参数设置。

In [1]:
import csv
import os
import re
import argparse
import ast
import html

import spacy
import numpy as np
from mindspore import nn
from mindspore import context
import mindspore.ops as ops
from mindspore import dataset as ds
from mindspore.mindrecord import FileWriter
import mindspore.common.dtype as mstype
from mindspore import Tensor,Model,ParameterTuple
from mindspore.context import ParallelMode
import mindspore.dataset.transforms.c_transforms as deC
from mindspore.common.initializer import XavierUniform
from sklearn.feature_extraction import FeatureHasher
from sklearn.metrics import accuracy_score, classification_report

In [2]:
parser = argparse.ArgumentParser()
parser.add_argument('--ngram', type=int, default=2, required=False)
parser.add_argument('--max_len', type=int, required=False, help='max length sentence in dataset')
parser.add_argument('--bucket', type=ast.literal_eval, default=[64, 128, 467], help='bucket sequence length.')
parser.add_argument('--test_bucket', type=ast.literal_eval, default=[64, 128, 467], help='bucket sequence length.')
parser.add_argument('--feature_size', type=int, default=10000000, help='hash feature size')
parser.add_argument('--device_target', type=str, default="GPU", choices=['Ascend', 'GPU'])
args = parser.parse_known_args()[0]
context.set_context(
    mode=context.GRAPH_MODE,
    save_graphs=False,
    device_target=args.device_target)

### 读取数据

定义数据预处理函数，填充至训练集与测试集。

In [3]:
class FastTextDataPreProcess():
    """FastText数据预处理"""
    def __init__(self, train_path, test_file, max_length,class_num,ngram, train_feature_dict,
                 buckets, test_feature_dict, test_bucket,feature_size):
        self.train_path = train_path
        self.test_path = test_file
        self.max_length = max_length
        self.class_num = class_num
        self.train_feature_dict = train_feature_dict
        self.test_feature_dict = test_feature_dict
        self.test_bucket = test_bucket
        self.feature_size = feature_size
        self.buckets = buckets
        self.ngram = ngram
        self.text_greater = '>'
        self.text_less = '<'
        self.word2vec = dict()
        self.vec2words = dict()
        self.non_str = '\\'
        self.end_string = ['.', '?', '!']
        self.word2vec['PAD'] = 0
        self.vec2words[0] = 'PAD'
        self.word2vec['UNK'] = 1
        self.vec2words[1] = 'UNK'
        self.str_html = re.compile(r'<[^>]+>')

    def common_block(self, _pair_sen, spacy_nlp):
        """数据通用模块"""
        label_idx = int(_pair_sen[0]) - 1
        if len(_pair_sen) == 3:
            src_tokens = self.input_preprocess(src_text1=_pair_sen[1],
                                               src_text2=_pair_sen[2],
                                               spacy_nlp=spacy_nlp,
                                               train_mode=True)
            src_tokens_length = len(src_tokens)
        elif len(_pair_sen) == 2:
            src_tokens = self.input_preprocess(src_text1=_pair_sen[1],
                                               src_text2=None,
                                               spacy_nlp=spacy_nlp,
                                               train_mode=True)
            src_tokens_length = len(src_tokens)
        elif len(_pair_sen) == 4:
            if _pair_sen[2]:
                sen_o_t = _pair_sen[1] + ' ' + _pair_sen[2]
            else:
                sen_o_t = _pair_sen[1]
            src_tokens = self.input_preprocess(src_text1=sen_o_t,
                                               src_text2=_pair_sen[3],
                                               spacy_nlp=spacy_nlp,
                                               train_mode=True)
            src_tokens_length = len(src_tokens)
        return src_tokens, src_tokens_length, label_idx

    def load(self):
        """数据读取"""
        train_dataset_list = []
        test_dataset_list = []
        spacy_nlp = spacy.load('en_core_web_sm', disable=['parser', 'tagger', 'ner','lemmatizer'])
        spacy_nlp.add_pipe('sentencizer')
        print("开始处理训练数据")
        with open(self.train_path, 'r', newline='', encoding='utf-8') as src_file:
            reader = csv.reader(src_file, delimiter=",", quotechar='"')
            for _, _pair_sen in enumerate(reader):
                src_tokens, src_tokens_length, label_idx = self.common_block(_pair_sen=_pair_sen,
                                                                             spacy_nlp=spacy_nlp)
                train_dataset_list.append([src_tokens, src_tokens_length, label_idx])

        print("开始处理测试数据")
        with open(self.test_path, 'r', newline='', encoding='utf-8') as test_file:
            reader2 = csv.reader(test_file, delimiter=",", quotechar='"')
            for _, _test_sen in enumerate(reader2):
                label_idx = int(_test_sen[0]) - 1
                if len(_test_sen) == 3:
                    src_tokens = self.input_preprocess(src_text1=_test_sen[1],
                                                       src_text2=_test_sen[2],
                                                       spacy_nlp=spacy_nlp,
                                                       train_mode=False)
                    src_tokens_length = len(src_tokens)
                elif len(_test_sen) == 2:
                    src_tokens = self.input_preprocess(src_text1=_test_sen[1],
                                                       src_text2=None,
                                                       spacy_nlp=spacy_nlp,
                                                       train_mode=False)
                    src_tokens_length = len(src_tokens)
                elif len(_test_sen) == 4:
                    if _test_sen[2]:
                        sen_o_t = _test_sen[1] + ' ' + _test_sen[2]
                    else:
                        sen_o_t = _test_sen[1]
                    src_tokens = self.input_preprocess(src_text1=sen_o_t,
                                                       src_text2=_test_sen[3],
                                                       spacy_nlp=spacy_nlp,
                                                       train_mode=False)
                    src_tokens_length = len(src_tokens)

                test_dataset_list.append([src_tokens, src_tokens_length, label_idx])
                
        # 填充训练数据
        train_dataset_list_length = len(train_dataset_list)
        test_dataset_list_length = len(test_dataset_list)
        for l in range(train_dataset_list_length):
            bucket_length = self._get_bucket_length(train_dataset_list[l][0], self.buckets)
            while len(train_dataset_list[l][0]) < bucket_length:
                train_dataset_list[l][0].append(self.word2vec['PAD'])
            train_dataset_list[l][1] = len(train_dataset_list[l][0])
        # 填充测试数据
        for j in range(test_dataset_list_length):
            test_bucket_length = self._get_bucket_length(test_dataset_list[j][0], self.test_bucket)
            while len(test_dataset_list[j][0]) < test_bucket_length:
                test_dataset_list[j][0].append(self.word2vec['PAD'])
            test_dataset_list[j][1] = len(test_dataset_list[j][0])

        train_example_data = []
        test_example_data = []
        for idx in range(train_dataset_list_length):
            train_example_data.append({
                "src_tokens": train_dataset_list[idx][0],
                "src_tokens_length": train_dataset_list[idx][1],
                "label_idx": train_dataset_list[idx][2],
            })
            for key in self.train_feature_dict:
                if key == train_example_data[idx]['src_tokens_length']:
                    self.train_feature_dict[key].append(train_example_data[idx])
        for h in range(test_dataset_list_length):
            test_example_data.append({
                "src_tokens": test_dataset_list[h][0],
                "src_tokens_length": test_dataset_list[h][1],
                "label_idx": test_dataset_list[h][2],
            })
            for key in self.test_feature_dict:
                if key == test_example_data[h]['src_tokens_length']:
                    self.test_feature_dict[key].append(test_example_data[h])
        print("train vocab size is ", len(self.word2vec))

        return self.train_feature_dict, self.test_feature_dict

    def input_preprocess(self, src_text1, src_text2, spacy_nlp, train_mode):
        """数据处理函数"""
        src_text1 = src_text1.strip()
        if src_text1 and src_text1[-1] not in self.end_string:
            src_text1 = src_text1 + '.'

        if src_text2:
            src_text2 = src_text2.strip()
            sent_describe = src_text1 + ' ' + src_text2
        else:
            sent_describe = src_text1
        if self.non_str in sent_describe:
            sent_describe = sent_describe.replace(self.non_str, ' ')

        sent_describe = html.unescape(sent_describe)

        if self.text_less in sent_describe and self.text_greater in sent_describe:
            sent_describe = self.str_html.sub('', sent_describe)

        doc = spacy_nlp(sent_describe)
        bows_token = [token.text for token in doc]

        try:
            tagged_sent_desc = '<p> ' + ' </s> '.join([s.text for s in doc.sents]) + ' </p>'
        except ValueError:
            tagged_sent_desc = '<p> ' + sent_describe + ' </p>'
        doc = spacy_nlp(tagged_sent_desc)
        ngrams = self.generate_gram([token.text for token in doc], num=self.ngram)

        bo_ngrams = bows_token + ngrams

        if train_mode is True:
            for ngms in bo_ngrams:
                idx = self.word2vec.get(ngms)
                if idx is None:
                    idx = len(self.word2vec)
                    self.word2vec[ngms] = idx
                    self.vec2words[idx] = ngms

        processed_out = [self.word2vec[ng] if ng in self.word2vec else self.word2vec['UNK'] for ng in bo_ngrams]

        return processed_out

    def _get_bucket_length(self, x, bts):
        x_len = len(x)
        for index in range(1, len(bts)):
            if bts[index - 1] < x_len <= bts[index]:
                return bts[index]
        return bts[0]

    def generate_gram(self, words, num=2):
        return [' '.join(words[i: i + num]) for i in range(len(words) - num + 1)]

### 生成预处理数据

现在调用上一步定义好的`FastTextDataPreProcess`函数，获取训练与测试的预处理数据，以便于下一步使用`mindspore.dataset.MindRecord`接口进一步转换数据格式。

In [4]:
train_feature_dicts = {}
# 通过循环将bucket中的长度都加载到空字典
for i in args.bucket:
    train_feature_dicts[i] = []
test_feature_dicts = {}
for i in args.test_bucket:
    test_feature_dicts[i] = []
data_path = "./data/"
# 读取bucket的test和train数据进行处理
g_d = FastTextDataPreProcess(train_path=os.path.join(data_path, "train.csv"),
                             test_file=os.path.join(data_path, "test.csv"),
                             max_length=args.max_len,
                             ngram=args.ngram,
                             class_num=True,
                             train_feature_dict=train_feature_dicts,
                             buckets=args.bucket,
                             test_feature_dict=test_feature_dicts,
                             test_bucket=args.test_bucket,
                             feature_size=args.feature_size)
train_data_example, test_data_example = g_d.load()

开始处理训练数据
开始处理测试数据
train vocab size is  1071957


### 完成MindRecord转换

接下来我们通过定义`write_to_mindrecord`方法来将预处理后的基本数据转换为MindRecord格式，该方法提供两个参数：

data：AG_NEWS数据集的路径。

path：定义生成MindRecord格式文件路径。

In [5]:
def write_to_mindrecord(data, path, shared_num=1):
    """生成MindRecord"""
    if not os.path.isabs(path):
        path = os.path.abspath(path)

    writer = FileWriter(path, shared_num)
    data_schema = {
        "src_tokens": {"type": "int32", "shape": [-1]},
        "src_tokens_length": {"type": "int32", "shape": [-1]},
        "label_idx": {"type": "int32", "shape": [-1]}
    }
    writer.add_schema(data_schema, "fasttext")
    for item in data:
        item['src_tokens'] = np.array(item['src_tokens'], dtype=np.int32)
        item['src_tokens_length'] = np.array(item['src_tokens_length'], dtype=np.int32)
        item['label_idx'] = np.array(item['label_idx'], dtype=np.int32)
        writer.write_raw_data([item])
    writer.commit()

遍历原始数据集，将所有数据全部写为MindRecord数据格式。

In [6]:
# 通过循环来将文件转换成拼接的MindRecord文件
print("Writing train data to MindRecord file.....")
for i in args.bucket:
    write_to_mindrecord(train_data_example[i], './train/train_dataset_bs_' + str(i) + '.mindrecord', 1)
print("Writing test data to MindRecord file.....")
for k in args.test_bucket:
    write_to_mindrecord(test_data_example[k], './test/test_dataset_bs_' + str(k) + '.mindrecord', 1)

Writing train data to MindRecord file.....
Writing test data to MindRecord file.....


### 生成统一数据集

经过`write_to_mindrecord`，现在我们已经得到了全部数据的MindRecord格式的数据集，接下来进一步调用`load_dataset`方法，实现如下功能：

1. 循环遍历所有MindRecord文件。
2. 将读取的数据通过`batch_per_bucket`合并到统一数据集。

In [7]:
def load_dataset(dataset_path,batch_size,epoch_count=1, rank_size=1,rank_id=0,bucket=None, shuffle=True):
    """数据集读取"""

    def batch_per_bucket(bucket_length, input_file):
        input_file = input_file + 'train/train_dataset_bs_' + str(bucket_length) + '.mindrecord'
        if not input_file:
            raise FileNotFoundError("input file parameter must not be empty.")

        data_set = ds.MindDataset(input_file,
                                  columns_list=['src_tokens', 'src_tokens_length', 'label_idx'],
                                  shuffle=shuffle,
                                  num_shards=rank_size,
                                  shard_id=rank_id,
                                  num_parallel_workers=4)
        ori_dataset_size = data_set.get_dataset_size()
        print(f"Dataset size: {ori_dataset_size}")
        repeat_count = epoch_count

        data_set = data_set.rename(input_columns=['src_tokens', 'src_tokens_length', 'label_idx'],
                                   output_columns=['src_token_text', 'src_tokens_text_length', 'label_idx_tag'])
        data_set = data_set.batch(batch_size, drop_remainder=False)
        data_set = data_set.repeat(repeat_count)
        return data_set

    for i, _ in enumerate(bucket):
        bucket_len = bucket[i]
        ds_per = batch_per_bucket(bucket_len, dataset_path)
        if i == 0:
            data_set = ds_per
        else:
            data_set = data_set + ds_per
    data_set = data_set.shuffle(data_set.get_dataset_size())
    data_set.channel_name = 'fasttext'

    return data_set

### 生成训练数据

通过`load_dataset`生成训练数据，其中的四个参数为：

dataset：文件存取路径。

batch_size：设定训练的batch。

epoch_count：设定训练进行的epoch。

bucket：数据中bucket的拼接长度。

In [8]:
preprocessed_data = load_dataset(dataset_path="",
                                     batch_size=512,
                                     epoch_count=1,
                                     bucket=[64,128,467])

Dataset size: 4780
Dataset size: 73255
Dataset size: 6706


## 定义模型

论文[Bag of Tricks for Efficient Text Classification](https://arxiv.org/pdf/1607.01759.pdf)中详细阐述了FastText模型的实现原理，模型结构如图所示：

![ ](images/fasttext.png)

FastText模型主要由输入层、隐藏层和输出层组成。其中输入是单词序列，通常以文本或句子的形式出现。输出层是词序列属于不同类别的概率。隐藏层是多个词向量的叠加平均。特征通过线性变换映射到隐藏层，再从隐藏层映射到标签。

下面定义FastText网络。

In [9]:
class FastText(nn.Cell):
    def __init__(self, vocab_size, embedding_dims, num_class):
        """定义FastText网络"""
        super(FastText, self).__init__()
        self.vocab_size = vocab_size
        self.embeding_dims = embedding_dims
        self.num_class = num_class
        self.embeding_func = nn.Embedding(vocab_size=self.vocab_size,
                                          embedding_size=self.embeding_dims,
                                          padding_idx=0, embedding_table='Zeros')
        self.fc = nn.Dense(self.embeding_dims, out_channels=self.num_class,
                           weight_init=XavierUniform(1)).to_float(mstype.float16)
        self.reducesum = ops.operations.ReduceSum()
        self.expand_dims = ops.operations.ExpandDims()
        self.squeeze = ops.operations.Squeeze(axis=1)
        self.cast = ops.operations.Cast()
        self.tile = ops.operations.Tile()
        self.realdiv = ops.operations.RealDiv()
        self.fill = ops.operations.Fill()
        self.log_softmax = nn.LogSoftmax(axis=1)
        
    def construct(self, src_tokens, src_token_length):
        """ FastText网络构建 """
        src_tokens = self.embeding_func(src_tokens)
        embeding = self.reducesum(src_tokens, 1)
        embeding = self.realdiv(embeding, src_token_length)
        embeding = self.cast(embeding, mstype.float16)
        classifier = self.fc(embeding)
        classifier = self.cast(classifier, mstype.float32)
        return classifier

## 启动实例

`AG_NEWS`数据集具有四个标签，因此类别数是四个。

```py
1 : World
2 : Sports
3 : Business
4 : Sci/Tec

```

在网络中，`vocab_size`为词汇数据的长度，其中包括单个单词和N元组。类的数量等于标签的数量，在`AG_NEWS`情况下为4。

In [10]:
fast_text_net = FastText(1383812, 16, 4)

## 用于生成批量的函数

由于文本条目的长度不同，所以使用自定义函数`batch_per_bucket`生成批量数据和偏移量。该函数传递到`mindspore.dataset.MindDataset`中的`inpufile`。`inputfile`的输入是张量文件，其大小为`batch_size`，函数将它们打包成一个小批量。

In [11]:
def batch_per_bucket(bucket_length, input_file):
        input_file = input_file + 'train/train_dataset_bs_' + str(bucket_length) + '.mindrecord'
        if not input_file:
            raise FileNotFoundError("input file parameter must not be empty.")

        data_set = ds.MindDataset(input_file,
                                  columns_list=['src_tokens', 'src_tokens_length', 'label_idx'],
                                  shuffle=shuffle,
                                  num_shards=rank_size,
                                  shard_id=rank_id,
                                  num_parallel_workers=4)
        ori_dataset_size = data_set.get_dataset_size()
        print(f"Dataset size: {ori_dataset_size}")
        repeat_count = epoch_count

        data_set = data_set.rename(input_columns=['src_tokens', 'src_tokens_length', 'label_idx'],
                                   output_columns=['src_token_text', 'src_tokens_text_length', 'label_idx_tag'])
        data_set = data_set.batch(batch_size, drop_remainder=False)
        data_set = data_set.repeat(repeat_count)
        return data_set

## 模型训练

我们在此处使用MindSpore数据集接口`MindDataset`加载`AG_NEWS`数据集，并将其发送到模型以进行训练/验证。

### 提供FastTextloss计算

我们已经在前面定义了一个完整的`FastText`网络，现在需要来为网络提供一个计算loss值的方法，这一过程由`FastTextNetWithLoss`类来实现。

In [12]:
class FastTextNetWithLoss(nn.Cell):
    """
   提供FastText的loss运算
    """
    def __init__(self,network, vocab_size, embedding_dims, num_class):
        super(FastTextNetWithLoss, self).__init__()
        self.fasttext = network
        self.loss_func = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        self.squeeze = ops.operations.Squeeze(axis=1)
        self.print = ops.operations.Print()

    def construct(self, src_tokens, src_tokens_lengths, label_idx):
        """
        带有loss的FastText网络
        """
        predict_score = self.fasttext(src_tokens, src_tokens_lengths)
        label_idx = self.squeeze(label_idx)
        predict_score = self.loss_func(predict_score, label_idx)

        return predict_score

### 创建网络计算loss值

在这一步中实例化`FastTextNetWithLoss`类。将定义好的网络`FastTextNet`、vocab的大小、embedding的数量和类别数放入到实例中。

In [13]:
net_with_loss = FastTextNetWithLoss(fast_text_net, 1383812, 16, 4)
net_with_loss.init_parameters_data()

{Parameter (name=fasttext.embeding_func.embedding_table, shape=(1383812, 16), dtype=Float32, requires_grad=True): Parameter (name=fasttext.embeding_func.embedding_table, shape=(1383812, 16), dtype=Float32, requires_grad=True),
 Parameter (name=fasttext.fc.weight, shape=(4, 16), dtype=Float32, requires_grad=True): Parameter (name=fasttext.fc.weight, shape=(4, 16), dtype=Float32, requires_grad=True),
 Parameter (name=fasttext.fc.bias, shape=(4,), dtype=Float32, requires_grad=True): Parameter (name=fasttext.fc.bias, shape=(4,), dtype=Float32, requires_grad=True)}

### 设置学习率和优化器

现在我们需要为`mindspore.nn.optim.Adam`优化器来定义一个学习率变化方式，以此来为优化器提供所需学习率参数。

In [15]:
from mindspore.nn.optim import Adam
from mindspore.nn import piecewise_constant_lr

learn_rate = 0.2
min_lr = 0.000001
decay_steps = preprocessed_data.get_dataset_size()
update_steps = 5 * preprocessed_data.get_dataset_size()
lr_step = [i+1 for i in range(update_steps)]
lr_list = [learn_rate - min_lr * i for i in range(update_steps)]
lr = Tensor(piecewise_constant_lr(lr_step,lr_list), dtype=mstype.float32)
print(type(lr))
optimizer = Adam(net_with_loss.trainable_params(), lr, beta1=0.9, beta2=0.999)

<class 'mindspore.common.tensor.Tensor'>


### 定义训练pipeline

当所有准备完毕后，我们要规划一次训练所需要的pipeline，于是定义了`TrainOneStepCell`类，该类主要实现以下方法：

set_sens：将获取值转为sens类型方便后续传入`tuple_to_array`转换。

construct：定义一次训练结算所需要的流程。

In [16]:
class FastTextTrainOneStepCell(nn.Cell):
    def __init__(self, network, optimizer, sens=1.0):
        super(FastTextTrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = ops.composite.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode not in ParallelMode.MODE_LIST:
            raise ValueError("Parallel mode does not support: ", self.parallel_mode)
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = None
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)

        self.hyper_map =  ops.composite.HyperMap()
        self.cast = ops.operations.Cast()

    def set_sens(self, value):
        self.sens = value

    def construct(self,
                  src_token_text,
                  src_tokens_text_length,
                  label_idx_tag):
        """定义执行运算."""
        weights = self.weights
        loss = self.network(src_token_text,
                            src_tokens_text_length,
                            label_idx_tag)
        grads = self.grad(self.network, weights)(src_token_text,
                                                 src_tokens_text_length,
                                                 label_idx_tag,
                                                 self.cast(ops.functional.tuple_to_array((self.sens,)),
                                                           mstype.float32))
        grads = self.hyper_map(ops.functional.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
        if self.reducer_flag:
            # 实现梯度消除
            grads = self.grad_reducer(grads)

        succ = self.optimizer(grads)
        return ops.functional.depend(loss, succ)

### 定义梯度

因为本次梯度所需格式的不同，需要通过`clip_grad`修饰器重新定义`_clip_grad`传入参数的类型，如下所示：

- clip_type为数字类型。

- clip_value为数字类型。

- grad为张量类型。

In [17]:
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
clip_grad =  ops.composite.MultitypeFuncGraph("clip_grad")
@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    if clip_type not in (0, 1):
        return grad
    dt = ops.functional.dtype(grad)
    if clip_type == 0:
        new_grad =  ops.composite.clip_by_value(grad, ops.functional.cast(ops.functional.tuple_to_array((-clip_value,)), dt),
                                   ops.functional.cast(ops.functional.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad, ops.functional.cast(ops.functional.tuple_to_array((clip_value,)), dt))
    return new_grad

### 进行模型训练

调用之前设定的`FastTextTrainOneStepCell`并迭代数据集，完成模型训练。

In [18]:
net_with_grads = FastTextTrainOneStepCell(net_with_loss, optimizer=optimizer)
net_with_grads.set_train(True)

FastTextTrainOneStepCell<
  (network): FastTextNetWithLoss<
    (fasttext): FastText<
      (embeding_func): Embedding<vocab_size=1383812, embedding_size=16, use_one_hot=False, embedding_table=Parameter (name=fasttext.embeding_func.embedding_table, shape=(1383812, 16), dtype=Float32, requires_grad=True), dtype=Float32, padding_idx=0>
      (fc): Dense<input_channels=16, output_channels=4, has_bias=True>
      (log_softmax): LogSoftmax<>
      >
    (loss_func): SoftmaxCrossEntropyWithLogits<>
    >
  (optimizer): Adam<
    (learning_rate): _IteratorLearningRate<>
    >
  >

In [20]:
# 进行epoch训练
for i in range(20):
    for d in preprocessed_data.create_dict_iterator():
        net_with_grads(d["src_token_text"],len(d["src_token_text"]),d["label_idx_tag"])
        # 输出loss值
        print(net_with_loss(d["src_token_text"],len(d["src_token_text"]),d["label_idx_tag"]))

1.3239299
1.2918508
1.236133
1.1651388
1.074889
1.1294309
0.9561551
0.9522176
0.91801494
0.8881521
0.80080545
0.7337659
0.6696707
0.63573897
0.5883118
0.23005332
0.4515081
0.20126605
0.4553006
0.21953695
0.15097088
0.22751673
0.299681
0.23459665
0.17367001
0.32614958
0.24170385
0.18644962
0.14626658
0.18693896
0.22911525
0.30018106
0.28360566
0.22088502
0.21194872
0.17272016
0.21119592
0.21003135
0.17690946
0.18701789
0.22161637
0.18359481
0.25332585
0.1607348
0.18905574
0.21450931
0.4525343
0.048400477
0.06543859
0.04598104
0.046952773
0.05878158
0.05802965
0.021141667
0.016563205
0.0599133
0.03379585
0.020350233
0.033926312
0.10194215
0.034460913
0.055590115
0.014893334
0.060085252
0.028355705
0.056327038
0.024952719
0.032113466
0.023740696
0.01511923
0.034571428
0.037790537
0.07907674
0.032159526
0.046872605
0.028533353
0.0076825884
0.0077427584
0.040141877
0.013469651
0.029853245
1.1512634
0.010118321
0.025405075
0.026934445
0.031721305
0.042373456
0.0452683
0.07718848
0.06898584
0

## 使用测试数据集评估模型

### 读取验证集

如同读取训练数据集一样，这里定义`load_infer_dataset`方法来读取测试数据集，其中入参分别为：

batch_size：测试集中的batch数量。

datafile：读取测试数据的路径。

bucket：数据中bucket的拼接长度。

In [21]:
def load_infer_dataset(batch_size, datafile, bucket):
    """测试数据加载"""
    def batch_per_bucket(bucket_length, input_file):
        input_file = input_file + 'test/test_dataset_bs_' + str(bucket_length) + '.mindrecord'
        if not input_file:
            raise FileNotFoundError("input file parameter must not be empty.")

        data_set = ds.MindDataset(input_file,
                                  columns_list=['src_tokens', 'src_tokens_length', 'label_idx'])
        type_cast_op = deC.TypeCast(mstype.int32)
        data_set = data_set.map(operations=type_cast_op, input_columns="src_tokens")
        data_set = data_set.map(operations=type_cast_op, input_columns="src_tokens_length")
        data_set = data_set.map(operations=type_cast_op, input_columns="label_idx")

        data_set = data_set.batch(batch_size, drop_remainder=False)
        return data_set
    for i, _ in enumerate(bucket):
        bucket_len = bucket[i]
        ds_per = batch_per_bucket(bucket_len, datafile)
        if i == 0:
            data_set = ds_per
        else:
            data_set = data_set + ds_per

    return data_set

### 定义验证方法

现在传入训练后的网络`network`，通过`FastTextInferCell`来完成我们的验证流程。

In [22]:
class FastTextInferCell(nn.Cell):

    def __init__(self, network):
        super(FastTextInferCell, self).__init__(auto_prefix=False)
        self.network = network
        self.argmax = ops.operations.ArgMaxWithValue(axis=1, keep_dims=True)
        self.log_softmax = nn.LogSoftmax(axis=1)

    def construct(self, src_tokens, src_tokens_lengths):
        prediction = self.network(src_tokens, src_tokens_lengths)
        predicted_idx = self.log_softmax(prediction)
        predicted_idx, _ = self.argmax(predicted_idx)

        return predicted_idx

### 读取数据并推理模型

最后，实例化`load_infer_dataset`和`FastTextInferCell`来模型推理：

In [23]:
load_test_data = load_infer_dataset(batch_size=512,
                                     datafile="",
                                     bucket=[64,128,467])
ft_infer = FastTextInferCell(fast_text_net)
predictions = []
target_sens = []
model = Model(ft_infer)
for batch in load_test_data.create_dict_iterator(output_numpy=True, num_epochs=1):
    target_sens.append(batch['label_idx'])
    src_tokens = Tensor(batch['src_tokens'], mstype.int32)
    src_tokens_length = Tensor(batch['src_tokens_length'], mstype.int32)
    predicted_idx = ft_infer(src_tokens,src_tokens_length)
    predictions.append(predicted_idx.asnumpy())

### 评估模型

计算模型的预测值与真实值之前的误差，输出模型的每个batch精度。

In [24]:
predictions = np.array(predictions).flatten()
merge_predictions = []
for prediction in predictions:
    merge_predictions.extend([prediction])
predictions = merge_predictions
target_sens = np.array(target_sens).flatten()
merge_target_sens = []
for target_sen in target_sens:
    merge_target_sens.extend([target_sen])
target_sens = merge_target_sens

for i in range(len(target_sens)):
    acc = accuracy_score(target_sens[i], predictions[i])
    print("Accuracy: ", acc)

Accuracy:  0.8404494382022472
Accuracy:  0.9140625
Accuracy:  0.912109375
Accuracy:  0.91796875
Accuracy:  0.923828125
Accuracy:  0.93359375
Accuracy:  0.9453125
Accuracy:  0.923828125
Accuracy:  0.90625
Accuracy:  0.9140625
Accuracy:  0.9375
Accuracy:  0.91796875
Accuracy:  0.923828125
Accuracy:  0.9050772626931567
Accuracy:  0.912109375
Accuracy:  0.9347826086956522
