# Implementing MobileViT with MindSpore

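## Introduction to MobileViT

Lightweight convolutional neural networks (CNNs) are the de facto choice for mobile vision tasks. Their spatial inductive biases let them learn representations with fewer parameters across different vision tasks. However, these networks are spatially local. To learn global representations, self-attention-based Vision Transformers (ViTs) have been adopted. MobileViT was proposed to combine the strengths of CNNs and ViTs in a lightweight, low-latency network for mobile vision tasks. It is a lightweight, general-purpose Vision Transformer for mobile devices, and its authors take a different perspective: using transformers as convolutions to process information. According to the authors, it is the first lightweight ViT work to reach the performance level of lightweight CNNs.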

## How MobileViT Works


![image.png](image/表-1.png)

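**Table 1**: MobileViT architecture. Here, d is the input dimension of the transformer layer (as shown in Figure 1). By default, the kernel size n in the MobileViT block is set to 3, and the spatial dimensions of a patch (height h and width w) are set to 2.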


![image.png](image/图-1.png)

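**Figure 1**: MobileViT. Here, Conv-n×n in the MobileViT block denotes a standard n×n convolution, and MV2 refers to the MobileNetv2 block. MV2 blocks that perform downsampling are marked with ↓2.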

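MobileViT is inspired by the design philosophy of lightweight CNNs. Table 1 above gives the overall MobileViT architecture at different parameter budgets. The initial layer is a strided 3×3 standard convolution, followed by MobileNetv2 (MV2) blocks and MobileViT blocks. MobileViT uses Swish as its activation function. Following CNN models, n=3 is used in the MobileViT block. The spatial dimensions of feature maps are usually multiples of 2, and h, w ≤ n, so h=w=2 is set at all spatial levels. The MV2 blocks in the MobileViT network are mainly responsible for downsampling. These blocks use an expansion factor of 4, except in MobileViT-XXS, which uses an expansion factor of 2. The transformer layer in MobileViT takes a d-dimensional input, as shown in Figure 1. The output dimension of the first feed-forward layer in the transformer layer is set to 2d rather than 4d, which is the default in the standard transformer block of Vaswani et al.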
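
To make the 2d-versus-4d choice concrete, the following minimal sketch counts the parameters of a two-layer feed-forward block with biases. The helper name `ffn_param_count` is hypothetical and used only for illustration; d=64 is the first transformer width of the XXS configuration listed later in this tutorial.

```python
def ffn_param_count(d: int, hidden_mult: int) -> int:
    """Parameters of Dense(d -> hidden) + Dense(hidden -> d), biases included."""
    hidden = hidden_mult * d
    first = d * hidden + hidden    # weights + bias of the first layer
    second = hidden * d + d        # weights + bias of the second layer
    return first + second


d = 64
print("hidden = 2d:", ffn_param_count(d, 2))
print("hidden = 4d:", ffn_param_count(d, 4))
```

Under these assumptions, halving the hidden width roughly halves the feed-forward parameters of every transformer layer.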

![image.png](image/图-2.png)

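**Figure 2**: MobileViT shows generalization capabilities similar to CNNs. The final training and validation errors of MobileNetv2 and ResNet-50 are marked with * and ◦, respectively.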

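MobileViT brings several advantages:

(i) Better performance. For a given parameter budget, MobileViT models achieve better performance than existing lightweight CNNs across different mobile vision tasks.
(ii) Generalization capability. Generalization capability refers to the gap between training and evaluation metrics. For two models with similar training metrics, the one with better evaluation metrics generalizes better, because it predicts better on unseen data. Unlike previous ViT variants (with or without convolutions), which generalize poorly compared with CNNs even with extensive data augmentation, MobileViT shows better generalization (Figure 2).
(iii) Robustness. A good model should be robust to hyperparameters (such as data augmentation and L2 regularization), because tuning them costs time and resources. Unlike most ViT-based models, MobileViT models train with basic augmentation and are less sensitive to L2 regularization.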

## Importing Standard and Third-Party Libraries

In [2]:
import os
import math
import json
import time
import cv2
import argparse
import tempfile
import pathlib
import os.path
import numpy as np
from scipy import io
from PIL import Image
from enum import Enum
from collections import OrderedDict
from abc import ABCMeta, abstractmethod
from typing import Union, Optional, Dict, Tuple, List, Callable, Iterable


import mindspore
from mindspore import context
import mindspore.dataset as ds
from mindspore.train import Model
from mindspore import numpy as m_np
from mindspore.common import set_seed
from mindspore.nn.loss import LossBase
from mindspore.ops.operations import Add
from mindspore.ops import operations as P
from mindspore.dataset.vision import Inter
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindspore._checkparam import Validator
from mindspore.train.callback import Callback
import mindspore.dataset.vision.c_transforms as c_transforms
import mindspore.dataset.vision.py_transforms as p_transforms
from mindspore import nn, Tensor, ops, load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor

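## Model Structure

The following sections break down the structure of MobileViT. The modules involved can be called directly from the MindSpore Vision toolkit.

## 1) ConvNormActivation

ConvNormActivation is the most basic building block in convolutional networks. It consists of a convolution layer (standard or depthwise), a normalization layer (BN), and an activation function. Small units in the model that follow this pattern include Conv+BN+Swish and Conv+BN.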

In [3]:
class ConvNormActivation(nn.Cell):
    """
    Convolution/Depthwise fused with normalization and activation blocks definition.

    Args:
        in_planes (int): Input channel.
        out_planes (int): Output channel.
        kernel_size (int): Input kernel size. Default: 3.
        stride (int): Stride size for the first convolutional layer. Default: 1.
        groups (int): Number of channel groups: 1 for a standard convolution, the input channel
        count for a depthwise convolution. Default: 1.
        norm (nn.Cell, optional): Norm layer that will be stacked on top of the convolution
        layer. Default: nn.BatchNorm2d.
        activation (nn.Cell, optional): Activation function which will be stacked on top of the
        normalization layer (if not None), otherwise on top of the conv layer. Default: nn.ReLU.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> conv = ConvNormActivation(16, 256, kernel_size=1, stride=1, groups=1)
    """

    def __init__(self,
                 in_planes: int,
                 out_planes: int,
                 kernel_size: int = 3,
                 stride: int = 1,
                 groups: int = 1,
                 norm: Optional[nn.Cell] = nn.BatchNorm2d,
                 activation: Optional[nn.Cell] = nn.ReLU
                 ) -> None:
        super(ConvNormActivation, self).__init__()
        padding = (kernel_size - 1) // 2
        layers = [
            nn.Conv2d(
                in_planes,
                out_planes,
                kernel_size,
                stride,
                pad_mode='pad',
                padding=padding,
                group=groups
            )
        ]

        if norm:
            layers.append(norm(out_planes))
        if activation:
            layers.append(activation())

        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output
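
The padding rule `padding = (kernel_size - 1) // 2` used in the class keeps the spatial size unchanged at stride 1 for odd kernels and halves it (rounding up) at stride 2. A minimal sketch of the standard output-size formula, assuming dilation 1; `conv_output_size` is a hypothetical helper, not part of MindSpore:

```python
def conv_output_size(size: int, kernel_size: int, stride: int) -> int:
    """Output length along one spatial axis for padding = (kernel_size - 1) // 2."""
    padding = (kernel_size - 1) // 2
    return (size + 2 * padding - kernel_size) // stride + 1


for size, k, s in [(32, 3, 1), (32, 1, 1), (256, 3, 2), (7, 3, 2)]:
    print(f"size={size}, kernel={k}, stride={s} -> {conv_output_size(size, k, s)}")
```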


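## 2) Activation Function: Swish

Swish is an activation function proposed by Google on October 16, 2017. Its original form is $f(x) = x \cdot \mathrm{sigmoid}(x)$, and the Swish-B variant is $f(x) = x \cdot \mathrm{sigmoid}(b \cdot x)$. It is non-saturating, smooth, and non-monotonic. The experiments in Google's paper indicate that Swish and Swish-B perform very well, outperforming the best activation functions of the time on a range of datasets.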
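
The two formulas can be checked numerically with a few lines of plain Python. This is only a scalar sketch (the helper names `sigmoid` and `swish` are illustrative); the MindSpore cell defined next applies the same formula to tensors.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def swish(x: float, b: float = 1.0) -> float:
    """Swish-B: x * sigmoid(b * x); b = 1 gives the original Swish."""
    return x * sigmoid(b * x)


for x in (-5.0, -1.0, 0.0, 1.0, 5.0):
    print(f"swish({x:+.1f}) = {swish(x):+.4f}")
```

The value at -5 is closer to zero than the value at -1, which is the non-monotonic dip mentioned above; for large positive inputs the function approaches the identity.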

In [4]:
class Swish(nn.Cell):
    """
    swish activation function.

    Args:
        None

    Return:
        Tensor

    Example:
        >>> x = Tensor(((20, 16), (50, 50)), mindspore.float32)
        >>> Swish()(x)
    """

    def __init__(self) -> None:
        super(Swish, self).__init__()
        self.sigmoid = nn.Sigmoid()

    def construct(self, x) -> Tensor:
        """Swish construct."""
        return x * self.sigmoid(x)

## 3) Vision Transformer

![image.png](image/图-3.png)

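**Figure 3**: A standard Vision Transformer

A standard ViT model, shown in Figure 3, reshapes the input $\mathbf{X} \in \mathbb{R}^{H \times W \times C}$ into a sequence of flattened patches $\mathbf{X}_{f} \in \mathbb{R}^{N \times P C}$, projects it into a fixed d-dimensional space $\mathbf{X}_{p} \in \mathbb{R}^{N \times d}$, and then learns inter-patch representations with a stack of L transformer blocks. The computational cost of self-attention in a vision transformer is $O\left(N^{2} d\right)$. Here, C, H, and W are the channels, height, and width of the tensor, $P=w h$ is the number of pixels in a patch of height h and width w, and N is the number of patches. Because these models ignore the spatial inductive bias inherent in CNNs, they need more parameters to learn visual representations. For example, DPT (a ViT-based network) learns about 6× more parameters than DeepLabv3 (a CNN-based network) to deliver similar segmentation performance. In addition, compared with CNNs, these models exhibit substandard optimizability: they are sensitive to L2 regularization and need extensive data augmentation to prevent overfitting.

The core idea of MobileViT is to learn global representations with transformers as convolutions. This allows convolution-like properties to be built into the network implicitly, lets representations be learned with simple training recipes, and makes it easy to integrate MobileViT with downstream architectures.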
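
The quantities P, N, and the $N^{2} d$ attention term can be computed directly. The sketch below uses a hypothetical helper `vit_patch_stats`; the 224×224 input, 16×16 patches, and d=768 are illustrative assumptions, not values taken from this tutorial.

```python
def vit_patch_stats(height: int, width: int, patch_h: int, patch_w: int, d: int):
    """Return (P, N, N^2 * d) for an H x W input split into patch_h x patch_w patches."""
    if height % patch_h or width % patch_w:
        raise ValueError("height and width must be divisible by the patch size")
    pixels_per_patch = patch_h * patch_w                  # P = w * h
    num_patches = (height * width) // pixels_per_patch    # N = H * W / P
    attention_cost = num_patches ** 2 * d                 # grows as O(N^2 d)
    return pixels_per_patch, num_patches, attention_cost


print(vit_patch_stats(224, 224, 16, 16, 768))
print(vit_patch_stats(224, 224, 8, 8, 768))
```

Halving the patch side quadruples N, so the quadratic attention term grows sixteenfold.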

### 3.1) DropPath

In [5]:
class DropPath(nn.Cell):
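    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        drop_prob (float): Probability of dropping the path for a sample. Default: 0.0.
        seed (int): Seed for the uniform random generator. Default: 0.
    """

    def __init__(self, drop_prob=0.0, seed=0):
        super(DropPath, self).__init__()
        # The first argument is a drop probability (TransformerEncoder passes dpr[i]),
        # so the keep probability is its complement.
        self.keep_prob = 1 - drop_prob
        # The seed must be non-negative; min(seed, 0) would discard every positive seed.
        seed = max(seed, 0)
        self.rand = P.UniformReal(seed=seed)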
        self.shape = P.Shape()
        self.floor = P.Floor()

    def construct(self, x):
        if self.training:
            x_shape = self.shape(x)
            random_tensor = self.rand((x_shape[0], 1, 1))
            random_tensor = random_tensor + self.keep_prob
            random_tensor = self.floor(random_tensor)
            x = x / self.keep_prob
            x = x * random_tensor

        return x
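
The `floor(keep_prob + U)` step in `construct` turns a uniform sample U in [0, 1) into a 0/1 mask that equals 1 with probability `keep_prob`, and dividing by `keep_prob` keeps the expected output equal to the input. A plain-Python sketch of that per-sample multiplier, with the hypothetical helper name `drop_path_scale`:

```python
import math
import random


def drop_path_scale(batch_size: int, drop_prob: float, rng: random.Random):
    """Per-sample multipliers: 0 for a dropped path, 1 / keep_prob for a kept one."""
    keep_prob = 1.0 - drop_prob
    scales = []
    for _ in range(batch_size):
        # floor(keep_prob + U[0, 1)) equals 1 with probability keep_prob, else 0.
        mask = math.floor(keep_prob + rng.random())
        scales.append(mask / keep_prob)
    return scales


scales = drop_path_scale(20000, drop_prob=0.2, rng=random.Random(0))
print("mean multiplier:", sum(scales) / len(scales))
```

The printed mean should be close to 1, which is why no rescaling is needed at inference time.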


### 3.2) Attention, FeedForward, ResidualCell

In [6]:
class Attention(nn.Cell):
    """
    Attention layer implementation, Rearrange Input -> B x N x hidden size.

    Args:
        dim (int): The dimension of input features.
        num_heads (int): The number of attention heads. Default: 8.
        keep_prob (float): The keep rate, greater than 0 and less equal than 1. Default: 1.0.
        attention_keep_prob (float): The keep rate for attention. Default: 1.0.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ops = Attention(768, 12)
    """

    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0):
        super(Attention, self).__init__()
        Validator.check_equal_int(dim % num_heads, 0, 'dim should be divisible by num_heads.')
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = Tensor(head_dim ** -0.5)

        self.qkv = nn.Dense(dim, dim * 3)
        self.attn_drop = nn.Dropout(attention_keep_prob)
        self.out = nn.Dense(dim, dim)
        self.out_drop = nn.Dropout(keep_prob)

        self.mul = ops.Mul()
        self.reshape = ops.Reshape()
        self.transpose = ops.Transpose()
        self.unstack = ops.Unstack(axis=0)
        self.attn_matmul_v = ops.BatchMatMul()
        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)
        self.softmax = nn.Softmax(axis=-1)

    def construct(self, x):
        """Attention construct."""
        b, n, c = x.shape
        qkv = self.qkv(x)
        qkv = self.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))
        qkv = self.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = self.unstack(qkv)

        attn = self.q_matmul_k(q, k)
        attn = self.mul(attn, self.scale)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        out = self.attn_matmul_v(attn, v)
        out = self.transpose(out, (0, 2, 1, 3))
        out = self.reshape(out, (b, n, c))
        out = self.out(out)
        out = self.out_drop(out)

        return out
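
For a single head, the core of `construct` is softmax(q·kᵀ · head_dim^-0.5)·v. The following list-based sketch shows only that step, without the learned projections, multiple heads, or dropout; the function names are illustrative.

```python
import math


def softmax(row):
    m = max(row)                                  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(q, k, v):
    """q, k, v: lists of N vectors of length head_dim; returns (output, weights)."""
    scale = len(q[0]) ** -0.5
    weights = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        weights.append(softmax(scores))
    output = [
        [sum(w * vj[c] for w, vj in zip(row, v)) for c in range(len(v[0]))]
        for row in weights
    ]
    return output, weights


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [1.0, 0.0]]      # identical keys give uniform attention weights
v = [[2.0, 0.0], [4.0, 2.0]]
out, attn = scaled_dot_product_attention(q, k, v)
print(attn)
print(out)
```

With identical keys every query attends uniformly, so each output row is the mean of the value vectors.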

In [7]:
class FeedForward(nn.Cell):
    """
    Feed Forward layer implementation.

    Args:
        in_features (int): The dimension of input features.
        hidden_features (int): The dimension of hidden features. Default: None.
        out_features (int): The dimension of output features. Default: None
        activation (nn.Cell): Activation function applied after the first dense layer. Default: nn.GELU.
        keep_prob (float): The keep rate, greater than 0 and less equal than 1. Default: 1.0.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ops = FeedForward(768, 3072)
    """

    def __init__(self,
                 in_features: int,
                 hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None,
                 activation: nn.Cell = nn.GELU,
                 keep_prob: float = 1.0):
        super(FeedForward, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dense1 = nn.Dense(in_features, hidden_features)
        self.activation = activation()
        self.dense2 = nn.Dense(hidden_features, out_features)
        self.dropout = nn.Dropout(keep_prob)

    def construct(self, x):
        """Feed Forward construct."""
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.dense2(x)
        x = self.dropout(x)

        return x

In [8]:
class ResidualCell(nn.Cell):
    """
    Cell which implements Residual function:

    $$output = x + f(x)$$

    Args:
        cell (Cell): Cell needed to add residual block.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ops = ResidualCell(nn.Dense(3,4))
    """

    def __init__(self, cell):
        super(ResidualCell, self).__init__()
        self.cell = cell

    def construct(self, x):
        """ResidualCell construct."""
        return self.cell(x) + x


### 3.3) TransformerEncoder

In [9]:
class TransformerEncoder(nn.Cell):
    """
    Transformer Encoder module with multi-layer stacked of `TransformerEncoderLayer`, including multihead self
    attention and feedforward layer.

    Args:
        dim (int): The dimension of embedding.
        num_layers (int): The depth of transformer.
        num_heads (int): The number of attention heads.
        mlp_dim (int): The dimension of MLP hidden layer.
        keep_prob (float): The keep rate, greater than 0 and less equal than 1. Default: 1.0.
        attention_keep_prob (float): The keep rate for attention. Default: 1.0.
        drop_path_keep_prob (float): The keep rate for drop path. Default: 1.0.
        activation (nn.Cell): Activation function used in the feed-forward layers. Default: nn.GELU.
        norm (nn.Cell, optional): Norm layer applied before each attention and feed-forward
        sub-layer. Default: nn.LayerNorm.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ops = TransformerEncoder(768, 12, 12, 3072)
    """

    def __init__(self,
                 dim: int,
                 num_layers: int,
                 num_heads: int,
                 mlp_dim: int,
                 keep_prob: float = 1.,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: nn.Cell = nn.LayerNorm):
        super(TransformerEncoder, self).__init__()
        drop_path_rate = 1 - drop_path_keep_prob
        dpr = [i.item() for i in np.linspace(0, drop_path_rate, num_layers)]
        attn_seeds = [np.random.randint(1024) for _ in range(num_layers)]
        mlp_seeds = [np.random.randint(1024) for _ in range(num_layers)]

        layers = []
        for i in range(num_layers):
            normalization1 = norm((dim,))
            normalization2 = norm((dim,))
            attention = Attention(dim=dim,
                                  num_heads=num_heads,
                                  keep_prob=keep_prob,
                                  attention_keep_prob=attention_keep_prob)

            feedforward = FeedForward(in_features=dim,
                                      hidden_features=mlp_dim,
                                      activation=activation,
                                      keep_prob=keep_prob)

            if drop_path_rate > 0:
                layers.append(
                    nn.SequentialCell([
                        ResidualCell(nn.SequentialCell([normalization1,
                                                        attention,
                                                        DropPath(dpr[i], attn_seeds[i])])),
                        ResidualCell(nn.SequentialCell([normalization2,
                                                        feedforward,
                                                        DropPath(dpr[i], mlp_seeds[i])]))]))
            else:
                layers.append(
                    nn.SequentialCell([
                        ResidualCell(nn.SequentialCell([normalization1,
                                                        attention])),
                        ResidualCell(nn.SequentialCell([normalization2,
                                                        feedforward]))
                    ])
                )
        self.layers = nn.SequentialCell(layers)

    def construct(self, x):
        """Transformer construct."""
        return self.layers(x)
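
`TransformerEncoder` builds its per-layer drop-path rates with `np.linspace(0, drop_path_rate, num_layers)`, so the first layer never drops its path and the last layer drops it at the full rate. A dependency-free sketch of the same schedule (`drop_path_schedule` is a hypothetical name):

```python
def drop_path_schedule(drop_path_rate: float, num_layers: int):
    """Linearly spaced drop rates from 0 to drop_path_rate, like np.linspace(0, rate, n)."""
    if num_layers == 1:
        return [0.0]
    step = drop_path_rate / (num_layers - 1)
    return [i * step for i in range(num_layers)]


print(drop_path_schedule(0.3, 4))
```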


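## 4) MobileViTBlock

The MobileViT block, shown in Figure 1, aims to model both local and global information in an input tensor with fewer parameters. Formally, for a given input tensor $\mathbf{X} \in \mathbb{R}^{H \times W \times C}$, MobileViT applies an n×n standard convolution layer followed by a point-wise (1×1) convolution layer to produce $\mathbf{X}_{L} \in \mathbb{R}^{H \times W \times d}$. The n×n convolution encodes local spatial information, while the point-wise convolution projects the tensor into a high-dimensional space (d dimensions, where d>C) by learning linear combinations of the input channels.
With MobileViT, the goal is to model long-range non-local dependencies while having an effective receptive field of H×W. One widely studied approach to modeling long-range dependencies is dilated convolution. However, this approach requires careful selection of dilation rates; otherwise, weights are applied to padded zeros instead of the valid spatial region. Another promising solution is self-attention. Among self-attention methods, vision transformers (ViTs) with multi-head self-attention have proven effective for visual recognition tasks. However, ViTs are heavyweight and exhibit substandard optimizability, because they lack spatial inductive bias.
To enable MobileViT to learn global representations with spatial inductive bias, $X_L$ is unfolded into N non-overlapping flattened patches $\mathbf{X}_{U} \in \mathbb{R}^{P \times N \times d}$. Here, $P=wh$, $N=\frac{H W}{P}$ is the number of patches, and h≤n and w≤n are the height and width of a patch, respectively. For each $p \in\{1, \cdots, P\}$, inter-patch relationships are encoded by applying transformers to obtain $\mathbf{X}_{G} \in \mathbb{R}^{P \times N \times d}$: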

$$
\huge \mathbf{X}_{G}(p)=\operatorname{Transformer}\left(\mathbf{X}_{U}(p)\right), 1 \leq p \leq P
$$

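Unlike ViTs, which lose the spatial order of pixels, MobileViT loses neither the patch order nor the spatial order of pixels within each patch (Figure 1). Therefore, $\mathbf{X}_{G} \in \mathbb{R}^{P \times N \times d}$ can be folded to obtain $\mathbf{X}_{F} \in \mathbb{R}^{H \times W \times d}$. $X_F$ is then projected to a low C-dimensional space with a point-wise convolution and combined with **X** via concatenation. Another n×n convolution layer then fuses these concatenated features. Note that because $X_U(p)$ encodes local information from an n×n region using convolutions and $X_G(p)$ encodes global information across P patches for the p-th location, each pixel in $X_G$ can encode information from all pixels in X, as shown in Figure 4. The overall effective receptive field of MobileViT is therefore H×W.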

![image.png](image/图-4.png)

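**Figure 4**: Every pixel sees every other pixel in the MobileViT block. In this example, the red pixel attends to the blue pixels (pixels at the corresponding location in other patches) using transformers. Because the blue pixels have already encoded information about their neighboring pixels using convolutions, the red pixel can encode information from all pixels in the image. Here, each cell in the black and gray grids represents a patch and a pixel, respectively.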
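
The unfold and fold steps are pure index rearrangements, which a small single-channel example makes visible. The sketch below is an illustration under simplifying assumptions (one channel, nested lists, sizes divisible by the patch size); `unfold` and `fold` are hypothetical helpers whose index ordering mirrors the reshape and transpose sequence in the `MobileViTBlock` methods defined below.

```python
def unfold(grid, patch_h, patch_w):
    """Return X_U as P rows; row p holds pixel p of each of the N patches."""
    height, width = len(grid), len(grid[0])
    n_h, n_w = height // patch_h, width // patch_w
    unfolded = [[None] * (n_h * n_w) for _ in range(patch_h * patch_w)]
    for y in range(height):
        for x in range(width):
            p = (y % patch_h) * patch_w + (x % patch_w)   # position inside the patch
            n = (y // patch_h) * n_w + (x // patch_w)     # index of the patch
            unfolded[p][n] = grid[y][x]
    return unfolded


def fold(unfolded, height, width, patch_h, patch_w):
    """Inverse of unfold: put every pixel back at its original location."""
    n_w = width // patch_w
    grid = [[None] * width for _ in range(height)]
    for y in range(height):
        for x in range(width):
            p = (y % patch_h) * patch_w + (x % patch_w)
            n = (y // patch_h) * n_w + (x // patch_w)
            grid[y][x] = unfolded[p][n]
    return grid


grid = [[4 * y + x for x in range(4)] for y in range(4)]   # 4x4 map with values 0..15
x_u = unfold(grid, 2, 2)
print(x_u[0])                                              # top-left pixel of every patch
print(fold(x_u, 4, 4, 2, 2) == grid)
```

Each row of `x_u` is the sequence a transformer would process for one pixel position p, and folding restores the original layout exactly.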

In [10]:
class MobileViTBlock(nn.Cell):
    """
    This class defines the MobileViT module
    Args:
        in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H, W)`
        transformer_dim (int): Input dimension to the transformer unit
        ffn_dim (int): Dimension of the FFN block
        n_transformer_blocks (Optional[int]): Number of transformer blocks. Default: None
        head_dim (Optional[int]): Head dimension in the multi-head attention. Default: None
        attn_dropout (Optional[float]): Dropout in multi-head attention. Default: None
        dropout (Optional[float]): Dropout rate. Default: None
        ffn_dropout (Optional[float]): Dropout between FFN layers in transformer. Default: None
        patch_h (Optional[int]): Patch height for unfolding operation. Default: None
        patch_w (Optional[int]): Patch width for unfolding operation. Default: None
        conv_ksize (Optional[int]): Kernel size to learn local representations in MobileViT block. Default: None
        no_fusion (Optional[bool]): Do not combine the input and output feature maps. Default: None
    """

    def __init__(
            self,
            in_channels: int,
            transformer_dim: int,
            ffn_dim: int,
            n_transformer_blocks: Optional[int] = None,
            head_dim: Optional[int] = None,
            attn_dropout: Optional[float] = None,
            dropout: Optional[float] = None,
            ffn_dropout: Optional[float] = None,
            patch_h: Optional[int] = None,
            patch_w: Optional[int] = None,
            conv_ksize: Optional[int] = None,
            no_fusion: Optional[bool] = None,
    ) -> None:

        conv_1x1_out = ConvNormActivation(
            in_planes=transformer_dim,
            out_planes=in_channels,
            kernel_size=1,
            stride=1,
            activation=Swish
        )
        conv_3x3_out = None
        if not no_fusion:
            conv_3x3_out = ConvNormActivation(
                in_planes=2 * in_channels,
                out_planes=in_channels,
                kernel_size=conv_ksize,
                stride=1,
                activation=Swish
            )

        super(MobileViTBlock, self).__init__()
        self.local_rep = nn.SequentialCell(OrderedDict([
            ('conv_3x3',
             ConvNormActivation(in_planes=in_channels, out_planes=in_channels, kernel_size=conv_ksize, stride=1,
                                activation=Swish)),
            ('conv_1x1', ConvNormActivation(in_planes=in_channels, out_planes=transformer_dim,
                                            kernel_size=1, stride=1, norm=None, activation=None)),
        ]))
        num_heads = transformer_dim // head_dim
        global_rep = [
            TransformerEncoder(
                dim=transformer_dim,
                mlp_dim=ffn_dim,
                num_heads=num_heads,
                num_layers=n_transformer_blocks,
                attention_keep_prob=1 - attn_dropout,
                keep_prob=1 - dropout,
                drop_path_keep_prob=1 - ffn_dropout,
                activation=Swish
            )
        ]
        norm_layer = nn.LayerNorm((transformer_dim,))
        global_rep.append(norm_layer)

        self.global_rep = nn.SequentialCell(global_rep)
        self.conv_proj = conv_1x1_out
        self.fusion = conv_3x3_out
        self.patch_h = patch_h
        self.patch_w = patch_w
        self.patch_area = self.patch_w * self.patch_h
        self.resize_unfolding = nn.ResizeBilinear()
        self.resize_folding = nn.ResizeBilinear()
        self.concat_op = ops.Concat(axis=1)

    def unfolding(self, feature_map: Tensor) -> Tuple[Tensor, Dict]:
        """
        Reshape the input Tensor into a series of flattened blocks and project it into a fixed d-dimensional space

        Args:
            feature_map(Tensor):input feature map tensor

        Returns:
            Tuple,reshaped patches and properties of patches
        """
        patch_w, patch_h = self.patch_w, self.patch_h
        patch_area = int(patch_w * patch_h)
        batch_size, in_channels, orig_h, orig_w = feature_map.shape
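        # Round H and W up to the nearest multiple of the patch size with integer arithmetic,
        # which avoids creating a Tensor just to compute a static shape.
        new_h = (orig_h + patch_h - 1) // patch_h * patch_h
        new_w = (orig_w + patch_w - 1) // patch_w * patch_w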
        interpolate = False

        # Note: Padding can be done, but then it needs to be handled in attention function.
        if new_w != orig_w or new_h != orig_h:
            feature_map = self.resize_unfolding(feature_map, size=(new_h, new_w), align_corners=False)
            interpolate = True

        # number of patches along width and height
        num_patch_w = new_w // patch_w  # n_w
        num_patch_h = new_h // patch_h  # n_h
        num_patches = num_patch_h * num_patch_w  # N

        # [B, C, H, W] --> [B * C * n_h, p_h, n_w, p_w]
        reshaped_fm = feature_map.reshape(
            batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w
        )

        # [B * C * n_h, p_h, n_w, p_w] --> [B * C * n_h, n_w, p_h, p_w]
        transposed_fm = ops.transpose(reshaped_fm, (0, 2, 1, 3))

        # [B * C * n_h, n_w, p_h, p_w] --> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
        reshaped_fm = transposed_fm.reshape(
            batch_size, in_channels, num_patches, patch_area
        )

        # [B, C, N, P] --> [B, P, N, C]
        transposed_fm = ops.transpose(reshaped_fm, (0, 3, 2, 1))

        # [B, P, N, C] --> [BP, N, C]
        patches = transposed_fm.reshape(batch_size * patch_area, num_patches, -1)
        info_dict = {
            "orig_size": (orig_h, orig_w),
            "batch_size": batch_size,
            "interpolate": interpolate,
            "total_patches": num_patches,
            "num_patches_w": num_patch_w,
            "num_patches_h": num_patch_h,
        }
        return patches, info_dict

    def folding(self, patches: Tensor, info_dict: Dict) -> Tensor:
        """
        Fold the tensors according to the order of the patches and the spatial order of the pixels inside the patches

        Args:
            patches(Tensor):input patches tensor
            info_dict(dict):the properties of patches

        Returns:
            Tensor,The folded feature map tensor
        """

        # [BP, N, C] --> [B, P, N, C]
        patches = patches.view(
            info_dict["batch_size"], self.patch_area, info_dict["total_patches"], -1
        )
        batch_size = patches.shape[0]
        channels = patches.shape[3]
        num_patch_h = info_dict["num_patches_h"]
        num_patch_w = info_dict["num_patches_w"]

        # [B, P, N, C] --> [B, C, N, P]
        patches = ops.transpose(patches, (0, 3, 2, 1))

        # [B, C, N, P] --> [B*C*n_h, n_w, p_h, p_w]
        feature_map = patches.reshape(
            batch_size * channels * num_patch_h, num_patch_w, self.patch_h, self.patch_w
        )

        # [B*C*n_h, n_w, p_h, p_w] --> [B*C*n_h, p_h, n_w, p_w]
        feature_map = ops.transpose(feature_map, (0, 2, 1, 3))

        # [B*C*n_h, p_h, n_w, p_w] --> [B, C, H, W]
        feature_map = feature_map.reshape(
            batch_size, channels, num_patch_h * self.patch_h, num_patch_w * self.patch_w
        )
        if info_dict["interpolate"]:
            feature_map = self.resize_folding(feature_map, size=info_dict["orig_size"], align_corners=False)
        return feature_map

    def construct(self, x):
        res = x
        x = self.local_rep(x)
        patches, info_dict = self.unfolding(x)
        patches = self.global_rep(patches)
        x = self.folding(patches=patches, info_dict=info_dict)
        x = self.conv_proj(x)
        if self.fusion is not None:
            x = self.fusion(self.concat_op((res, x)))
        return x
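
Before splitting the feature map into patches, `unfolding` rounds the height and width up to the next multiple of the patch size and resizes the map only when that changes the shape. A small sketch of this check, with hypothetical helper names:

```python
def round_up_to_multiple(size: int, patch: int) -> int:
    """Smallest multiple of `patch` that is >= `size`."""
    return (size + patch - 1) // patch * patch


def needs_interpolation(height: int, width: int, patch_h: int, patch_w: int) -> bool:
    """True when the feature map must be resized before it can be split into patches."""
    new_h = round_up_to_multiple(height, patch_h)
    new_w = round_up_to_multiple(width, patch_w)
    return (new_h, new_w) != (height, width)


print(round_up_to_multiple(7, 2), needs_interpolation(7, 8, 2, 2))
print(round_up_to_multiple(8, 2), needs_interpolation(8, 8, 2, 2))
```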


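## 5) Inverted Residual Module: InvertedResidual

The inverted residual module is the MV2 block in Figure 1 and is mainly responsible for downsampling. These blocks are therefore shallow and narrow in the MobileViT network. The spatial-level parameter distribution of MobileViT in Figure 5 further shows that, across network configurations, MV2 blocks contribute very little to the total number of network parameters.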

![image.png](image/图-5.png)

**Figure 5**: Training efficiency. Here, the standard sampler refers to PyTorch's DistributedDataParallel sampler.

In [11]:
class InvertedResidual(nn.Cell):
    """
    Mobilenetv2 residual block definition.

    Args:
        in_channel (int): The input channel.
        out_channel (int): The output channel.
        stride (int): The Stride size for the first convolutional layer. Default: 1.
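        expand_ratio (int): The expansion ratio applied to the input channel.
        norm (nn.Cell, optional): The norm layer that will be stacked on top of the convolution
            layer. Default: None, in which case nn.BatchNorm2d is used.
        activation (nn.Cell, optional): The activation function. Default: None, in which case
            Swish is used.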
    Returns:
        Tensor, output tensor.

    Examples:
        >>> from mindvision.classification.models.backbones import InvertedResidual
        >>> InvertedResidual(3, 256, 1, 1)
    """

    def __init__(self,
                 in_channel: int,
                 out_channel: int,
                 stride: int,
                 expand_ratio: int,
                 norm: Optional[nn.Cell] = None,
                 activation: Optional[nn.Cell] = None
                 ) -> None:
        super(InvertedResidual, self).__init__()
        assert stride in [1, 2]

        if not norm:
            norm = nn.BatchNorm2d
        if not activation:
            activation = Swish

        hidden_dim = round(in_channel * expand_ratio)
        self.use_res_connect = stride == 1 and in_channel == out_channel

        layers: List[nn.Cell] = []
        if expand_ratio != 1:
            # pw
            layers.append(
                ConvNormActivation(in_channel, hidden_dim, kernel_size=1, norm=norm, activation=activation))
        layers.extend([
            # dw
            ConvNormActivation(
                hidden_dim,
                hidden_dim,
                stride=stride,
                groups=hidden_dim,
                norm=norm,
                activation=activation
            ),
            # pw-linear
            nn.Conv2d(hidden_dim, out_channel, kernel_size=1,
                      stride=1, has_bias=False),
            norm(out_channel)
        ])
        self.conv = nn.SequentialCell(layers)
        self.add = Add()

    def construct(self, x):
        identity = x
        x = self.conv(x)
        if self.use_res_connect:
            return self.add(identity, x)
        return x
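
The bookkeeping inside `InvertedResidual` (hidden width, residual condition, and the three convolutions) can be summarized without MindSpore. The sketch below counts convolution weights only, assuming bias-free convolutions and ignoring BatchNorm parameters; `mv2_block_summary` is a hypothetical helper.

```python
def mv2_block_summary(in_channel: int, out_channel: int, stride: int, expand_ratio: int):
    """Hidden width, residual flag, and conv weight count of one MV2 block."""
    hidden_dim = round(in_channel * expand_ratio)
    use_res_connect = stride == 1 and in_channel == out_channel
    weights = 0
    if expand_ratio != 1:
        weights += in_channel * hidden_dim      # 1x1 point-wise expansion
    weights += hidden_dim * 3 * 3               # 3x3 depthwise: one filter per channel
    weights += hidden_dim * out_channel         # 1x1 point-wise linear projection
    return hidden_dim, use_res_connect, weights


print(mv2_block_summary(16, 16, stride=1, expand_ratio=2))
print(mv2_block_summary(16, 24, stride=2, expand_ratio=2))
```

The residual connection is used only when the block changes neither the resolution nor the channel count, so downsampling blocks never add the identity.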


## 6) Configuration Helper: get_configuration

This function returns the parameter configuration for each model size.

In [12]:
def get_configuration(model_type: str = "xx_small") -> Dict:
    """
    Get the parameters of different specifications of the MobileViT model.

    Args:
        model_type (str): Model specification, one of "xx_small", "x_small" or "small".
            Default: "xx_small".

    Returns:
        Dict, parameters for the corresponding model specification.
    """
    head_dim = None
    num_heads = 4
    model_type = model_type.lower()

    if model_type == "xx_small":
        mv2_exp_mult = 2
        config = {
            "layer1": {
                "out_channels": 16,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 1,
                "stride": 1,
                "block_type": "mv2",
            },
            "layer2": {
                "out_channels": 24,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 3,
                "stride": 2,
                "block_type": "mv2",
            },
            "layer3": {
                "out_channels": 48,
                "transformer_channels": 64,
                "ffn_dim": 128,
                "transformer_blocks": 2,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "head_dim": head_dim,
                "num_heads": num_heads,
                "block_type": "mobilevit",
            },
            "layer4": {
                "out_channels": 64,
                "transformer_channels": 80,
                "ffn_dim": 160,
                "transformer_blocks": 4,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "head_dim": head_dim,
                "num_heads": num_heads,
                "block_type": "mobilevit",
            },
            "layer5": {
                "out_channels": 80,
                "transformer_channels": 96,
                "ffn_dim": 192,
                "transformer_blocks": 3,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "head_dim": head_dim,
                "num_heads": num_heads,
                "block_type": "mobilevit",
            },
            "last_layer_exp_factor": 4,
        }
    elif model_type == "x_small":
        mv2_exp_mult = 4
        config = {
            "layer1": {
                "out_channels": 32,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 1,
                "stride": 1,
                "block_type": "mv2",
            },
            "layer2": {
                "out_channels": 48,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 3,
                "stride": 2,
                "block_type": "mv2",
            },
            "layer3": {
                "out_channels": 64,
                "transformer_channels": 96,
                "ffn_dim": 192,
                "transformer_blocks": 2,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "head_dim": head_dim,
                "num_heads": num_heads,
                "block_type": "mobilevit",
            },
            "layer4": {
                "out_channels": 80,
                "transformer_channels": 120,
                "ffn_dim": 240,
                "transformer_blocks": 4,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "head_dim": head_dim,
                "num_heads": num_heads,
                "block_type": "mobilevit",
            },
            "layer5": {
                "out_channels": 96,
                "transformer_channels": 144,
                "ffn_dim": 288,
                "transformer_blocks": 3,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "head_dim": head_dim,
                "num_heads": num_heads,
                "block_type": "mobilevit",
            },
            "last_layer_exp_factor": 4,
        }
    elif model_type == "small":
        mv2_exp_mult = 4
        config = {
            "layer1": {
                "out_channels": 32,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 1,
                "stride": 1,
                "block_type": "mv2",
            },
            "layer2": {
                "out_channels": 64,
                "expand_ratio": mv2_exp_mult,
                "num_blocks": 3,
                "stride": 2,
                "block_type": "mv2",
            },
            "layer3": {
                "out_channels": 96,
                "transformer_channels": 144,
                "ffn_dim": 288,
                "transformer_blocks": 2,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "head_dim": head_dim,
                "num_heads": num_heads,
                "block_type": "mobilevit",
            },
            "layer4": {
                "out_channels": 128,
                "transformer_channels": 192,
                "ffn_dim": 384,
                "transformer_blocks": 4,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "head_dim": head_dim,
                "num_heads": num_heads,
                "block_type": "mobilevit",
            },
            "layer5": {
                "out_channels": 160,
                "transformer_channels": 240,
                "ffn_dim": 480,
                "transformer_blocks": 3,
                "patch_h": 2,
                "patch_w": 2,
                "stride": 2,
                "mv_expand_ratio": mv2_exp_mult,
                "head_dim": head_dim,
                "num_heads": num_heads,
                "block_type": "mobilevit",
            },
            "last_layer_exp_factor": 4,
        }
    else:
        raise NotImplementedError
    return config
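
The MobileViT layers in these configurations follow a regular pattern: `ffn_dim` is always twice `transformer_channels`. The sketch below checks this on values copied by hand from the `small` configuration above. The names `small_mit_layers` and `ffn_ratio` are illustrative only and are not part of the original code.

```python
# (transformer_channels, ffn_dim) pairs copied by hand from the "small" configuration.
# `small_mit_layers` is an illustrative name, not part of the original code.
small_mit_layers = {
    "layer3": {"transformer_channels": 144, "ffn_dim": 288},
    "layer4": {"transformer_channels": 192, "ffn_dim": 384},
    "layer5": {"transformer_channels": 240, "ffn_dim": 480},
}


def ffn_ratio(cfg):
    """Return ffn_dim / transformer_channels for one layer configuration."""
    return cfg["ffn_dim"] / cfg["transformer_channels"]


ratios = {name: ffn_ratio(cfg) for name, cfg in small_mit_layers.items()}
for name, ratio in ratios.items():
    print(f"{name}: ffn_dim is {ratio:.0f}x transformer_channels")
```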


## 7) GlobalAvgPooling

In [13]:
class GlobalAvgPooling(nn.Cell):
    """
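    Global average pooling over the spatial axes (H, W) of an NCHW tensor.

    Args:
        keep_dims (bool): Whether to keep the reduced axes with length 1. Default: False.

    Returns:
        Tensor, output tensor.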

    Examples:
        >>> GlobalAvgPooling()
    """

    def __init__(self,
                 keep_dims: bool = False
                 ) -> None:
        super(GlobalAvgPooling, self).__init__()
        self.mean = P.ReduceMean(keep_dims=keep_dims)

    def construct(self, x):
        x = self.mean(x, (2, 3))
        return x
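
`ReduceMean` over axes `(2, 3)` collapses the height and width of an `(N, C, H, W)` tensor, leaving one value per channel. Here is a minimal framework-free sketch of the same reduction with `keep_dims=False`, using nested Python lists as a stand-in for tensors. The helper `global_avg_pool` is illustrative and not part of the model code.

```python
def global_avg_pool(x):
    """Average each channel's H x W map; x is a nested list shaped (N, C, H, W)."""
    pooled = []
    for sample in x:
        channel_means = []
        for fmap in sample:
            values = [v for row in fmap for v in row]
            channel_means.append(sum(values) / len(values))
        pooled.append(channel_means)
    return pooled  # shape (N, C)


# One sample, two channels, 2 x 2 feature maps.
x = [[[[1.0, 2.0], [3.0, 4.0]],
      [[10.0, 10.0], [10.0, 10.0]]]]
print(global_avg_pool(x))  # [[2.5, 10.0]]
```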

## Building the complete MobileViT

The following code builds a complete MobileViT model.

In [14]:
class MobileViT(nn.Cell):
    """
    This class implements the `MobileViT architecture <https://arxiv.org/abs/2110.02178?context=cs.LG>`_.

    Args:
        num_classes (int): Number of classes. Default: 1000.
        classifier_dropout (float): Dropout rate of the classifier. Default: 0.1.
        image_channels (int): Number of input image channels. Default: 3.
        out_channels (int): Number of output channels of the first convolution. Default: 16.
        model_type (str): Model specification. Default: "xx_small".
    """

    def __init__(self,
                 num_classes: int = 1000,
                 classifier_dropout: float = 0.1,
                 image_channels: int = 3,
                 out_channels: int = 16,
                 model_type: str = "xx_small",
                 ):
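        # Initialize the Cell before assigning any attribute.
        super(MobileViT, self).__init__()
        self.mobilevit_config = get_configuration(model_type)
        # Dilation factor read by `_make_mit_layer` when `dilate=True`.
        self.dilation = 1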

        self.conv_1 = ConvNormActivation(
            in_planes=image_channels,
            out_planes=out_channels,
            kernel_size=3,
            stride=2,
            activation=Swish
        )
        in_channels = out_channels
        self.layer_1, out_channels = self._make_layer(
            input_channel=in_channels, cfg=self.mobilevit_config["layer1"]
        )
        in_channels = out_channels
        self.layer_2, out_channels = self._make_layer(
            input_channel=in_channels, cfg=self.mobilevit_config["layer2"]
        )
        in_channels = out_channels
        self.layer_3, out_channels = self._make_layer(
            input_channel=in_channels, cfg=self.mobilevit_config["layer3"]
        )
        in_channels = out_channels
        self.layer_4, out_channels = self._make_layer(
            input_channel=in_channels,
            cfg=self.mobilevit_config["layer4"],
            dilate=False,
        )
        in_channels = out_channels
        self.layer_5, out_channels = self._make_layer(
            input_channel=in_channels,
            cfg=self.mobilevit_config["layer5"],
            dilate=False,
        )
        in_channels = out_channels
        exp_channels = min(self.mobilevit_config["last_layer_exp_factor"] * in_channels, 960)
        self.conv_1x1_exp = ConvNormActivation(
            in_planes=in_channels,
            out_planes=exp_channels,
            kernel_size=1,
            stride=1,
            activation=Swish
        )
        if 0.0 < classifier_dropout < 1.0:
            self.classifier = nn.SequentialCell(OrderedDict([
                ('global_pool', GlobalAvgPooling(keep_dims=False)),
                ('dropout', nn.Dropout(keep_prob=1 - classifier_dropout)),
                ('fc', nn.Dense(in_channels=exp_channels, out_channels=num_classes, has_bias=True)),
            ]))
        else:
            self.classifier = nn.SequentialCell(OrderedDict([
                ('global_pool', GlobalAvgPooling(keep_dims=False)),
                ('fc', nn.Dense(in_channels=exp_channels, out_channels=num_classes, has_bias=True)),
            ]))

    def _make_layer(self, input_channel, cfg: Dict, dilate: Optional[bool] = False) -> Tuple[nn.SequentialCell, int]:
        """
        Generate a layer with MobileNetv2block or MobileViTblock according to the configuration information

        Args:
           input_channel(int):Input channel
           cfg(dict):parameters for the corresponding model's specification
           dilate(bool):Set whether to dilate

        Returns:
            Tuple,a SequentialCell and out channel
        """
        block_type = cfg.get("block_type", "mobilevit")
        if block_type.lower() == "mobilevit":
            return self._make_mit_layer(
                input_channel=input_channel, cfg=cfg, dilate=dilate
            )

        return self._make_mobilenet_layer(
            input_channel=input_channel, cfg=cfg
        )

    @staticmethod
    def _make_mobilenet_layer(input_channel: int, cfg: Dict) -> Tuple[nn.SequentialCell, int]:
        """
        Generate a layer with MobileNetv2 block

        Args:
           input_channel(int):Input channel
           cfg(dict):parameters for the corresponding model's specification

        Returns:
            Tuple,a SequentialCell and out channel
        """
        output_channels = cfg.get("out_channels")
        num_blocks = cfg.get("num_blocks", 2)
        expand_ratio = cfg.get("expand_ratio", 4)
        block = []
        for i in range(num_blocks):
            stride = cfg.get("stride", 1) if i == 0 else 1

            layer = InvertedResidual(
                in_channel=input_channel,
                out_channel=output_channels,
                stride=stride,
                expand_ratio=expand_ratio,
            )
            block.append(layer)
            input_channel = output_channels
        return nn.SequentialCell(block), input_channel

    def _make_mit_layer(self, input_channel, cfg: Dict, dilate: Optional[bool] = False) \
            -> Tuple[nn.SequentialCell, int]:
        """
        Generate a layer with MobileViTBlock

        Args:
           input_channel(int):Input channel
           cfg(dict):parameters for the corresponding model's specification
           dilate(bool):Set whether to dilate

        Returns:
            Tuple,a SequentialCell and out channel
        """
        block = []
        stride = cfg.get("stride", 1)
        if stride == 2:
            if dilate:
                self.dilation *= 2
                stride = 1
            layer = InvertedResidual(
                in_channel=input_channel,
                out_channel=cfg.get("out_channels"),
                stride=stride,
                expand_ratio=cfg.get("mv_expand_ratio", 4),
            )
            block.append(layer)
            input_channel = cfg.get("out_channels")
        head_dim = cfg.get("head_dim", 32)
        transformer_dim = cfg["transformer_channels"]
        ffn_dim = cfg.get("ffn_dim")
        if head_dim is None:
            num_heads = cfg.get("num_heads", 4)
            if num_heads is None:
                num_heads = 4
            head_dim = transformer_dim // num_heads
        block.append(
            MobileViTBlock(
                in_channels=input_channel,
                transformer_dim=transformer_dim,
                ffn_dim=ffn_dim,
                n_transformer_blocks=cfg.get("transformer_blocks", 1),
                patch_h=cfg.get("patch_h", 2),
                patch_w=cfg.get("patch_w", 2),
                dropout=0.05,
                ffn_dropout=0.0,
                attn_dropout=0.0,
                head_dim=head_dim,
                no_fusion=False,
                conv_ksize=3,
            )
        )
        return nn.SequentialCell(block), input_channel

    def construct(self, x: Tensor) -> Tensor:
        x = self.conv_1(x)
        x = self.layer_1(x)
        x = self.layer_2(x)
        x = self.layer_3(x)
        x = self.layer_4(x)
        x = self.layer_5(x)
        x = self.conv_1x1_exp(x)
        x = self.classifier(x)
        return x
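
To see how `construct` changes the tensor shape, the sketch below traces channel count and spatial size through the `small` model for a 256 x 256 input. It assumes padding that makes a stride-2 stage halve the resolution (rounding up) and the class default `out_channels=16` for `conv_1`. The helper `trace_small_shapes` is hypothetical and only does arithmetic; it does not build the network.

```python
def trace_small_shapes(image_size=256):
    """Trace (name, channels, height/width) after each stage of the `small` model."""
    # (stage name, output channels, stride), copied from the `small` configuration;
    # conv_1 uses the class default out_channels=16 with stride 2.
    stages = [
        ("conv_1", 16, 2),
        ("layer_1", 32, 1),
        ("layer_2", 64, 2),
        ("layer_3", 96, 2),
        ("layer_4", 128, 2),
        ("layer_5", 160, 2),
    ]
    size = image_size
    trace = []
    for name, channels, stride in stages:
        # Assumption: padding keeps ceil(size / stride) after each stage.
        size = (size + stride - 1) // stride
        trace.append((name, channels, size))
    # conv_1x1_exp: min(last_layer_exp_factor * in_channels, 960) channels, stride 1.
    trace.append(("conv_1x1_exp", min(4 * 160, 960), size))
    return trace


for name, channels, size in trace_small_shapes():
    print(f"{name:>12}: {channels:4d} x {size:3d} x {size:3d}")
```

Under these assumptions the classifier receives a 640-channel 8 x 8 feature map, which `GlobalAvgPooling` reduces to a 640-dimensional vector before the dense layer.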


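## Environment setup and data acquisition

First, import the required modules, configure the hyperparameters, and read the dataset. The MindSpore Vision toolkit provides an API for this part that can be called directly.
For details, see https://www.mindspore.cn/vision/docs/zh-CN/r0.1/index.html.
The complete ImageNet dataset can be downloaded from http://image-net.org.
Extract the dataset files into the directory structure below and read them through the MindSpore Vision API.
Make sure your dataset path has the following structure.
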
```text
.dataset/
├── train/  (1000 directories and 1281167 images)
│   ├── n04347754/
│   │   ├── 000001.jpg
│   │   ├── 000002.jpg
│   │   └── ....
│   └── n04347756/
│       ├── 000001.jpg
│       ├── 000002.jpg
│       └── ....
├── val/  (1000 directories and 50000 images)
│   ├── n04347754/
│   │   ├── 000001.jpg
│   │   ├── 000002.jpg
│   │   └── ....
│   └── n04347756/
│       ├── 000001.jpg
│       ├── 000002.jpg
│       └── ....
├── infer/  (directory holding the images used for inference)
└── imagenet_meta.json  (JSON file with the ImageNet class annotations)
```
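
Before starting a long training run it can help to confirm that the directory tree matches the layout above. The following is a minimal sketch using only the standard library. The function `check_imagenet_layout` is a hypothetical helper, not part of MindSpore Vision, and it only checks that `train/` and `val/` exist and contain class folders.

```python
import pathlib
import tempfile


def check_imagenet_layout(root):
    """Return a list of problems found under `root`; an empty list means the layout looks right."""
    root = pathlib.Path(root)
    problems = []
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing directory: {split}/")
            continue
        # Each class is expected to be a sub-directory named by its WordNet ID.
        if not any(p.is_dir() for p in split_dir.iterdir()):
            problems.append(f"no class folders in: {split}/")
    return problems


# Demonstration on a throwaway directory tree that has train/ but no val/.
with tempfile.TemporaryDirectory() as tmp:
    (pathlib.Path(tmp) / "train" / "n04347754").mkdir(parents=True)
    demo_problems = check_imagenet_layout(tmp)
    print(demo_problems)  # ['missing directory: val/']
```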

## Preparing for dataset processing

In [15]:
FILE_TYPE_ALIASES = {
    ".tbz": (".tar", ".bz2"),
    ".tbz2": (".tar", ".bz2"),
    ".tgz": (".tar", ".gz")
}

ARCHIVE_TYPE_SUFFIX = [".tar", ".zip"]

COMPRESS_TYPE_SUFFIX = [".bz2", ".gz"]


def check_file_exist(file_name: str):
    """Check whether the input file exists."""
    if not os.path.isfile(file_name):
        raise FileNotFoundError(f"File `{file_name}` does not exist.")


def check_file_valid(filename: str, extension: Tuple[str, ...]):
    """Check image file is valid through the extension."""
    return filename.lower().endswith(extension)


def check_dir_exist(dir_name: str) -> None:
    """Check whether the input directory exists."""
    if not os.path.isdir(dir_name):
        raise FileNotFoundError(f"Directory `{dir_name}` does not exist.")


def save_json_file(filename: str, data: Dict) -> None:
    """Save json file."""
    with open(filename, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=1)
        print("Json file dump success.")


def load_json_file(filename: str) -> Dict:
    """Load json file."""
    with open(filename, "r", encoding="utf-8") as f:
        return json.load(f)


def detect_file_type(filename: str):  # pylint: disable=inconsistent-return-statements
    """Detect file type by suffixes and return tuple(suffix, archive_type, compression)."""
    suffixes = pathlib.Path(filename).suffixes
    if not suffixes:
        raise RuntimeError(f"File `{filename}` has no suffixes that could be used to detect.")
    suffix = suffixes[-1]

    # Check if the suffix is a known alias.
    if suffix in FILE_TYPE_ALIASES:
        return suffix, FILE_TYPE_ALIASES[suffix][0], FILE_TYPE_ALIASES[suffix][1]

    # Check if the suffix is an archive type.
    if suffix in ARCHIVE_TYPE_SUFFIX:
        return suffix, suffix, None

    # Check if the suffix is a compression.
    if suffix in COMPRESS_TYPE_SUFFIX:
        # Check for suffix hierarchy.
        if len(suffixes) > 1:
            suffix2 = suffixes[-2]
            # Check if the suffix2 is an archive type.
            if suffix2 in ARCHIVE_TYPE_SUFFIX:
                return suffix2 + suffix, suffix2, suffix
        return suffix, None, suffix
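
`detect_file_type` relies on `pathlib.Path.suffixes`, which splits a file name at every dot after the stem. The short sketch below shows what that attribute returns for a few archive names. The file names are examples chosen for illustration.

```python
import pathlib

names = ("ILSVRC2012_devkit_t12.tar.gz", "images.tgz", "archive.zip", "v1.2.tar.gz")
suffix_table = {name: pathlib.Path(name).suffixes for name in names}

for name, suffixes in suffix_table.items():
    print(f"{name}: {suffixes}")
```

Because only the last one or two suffixes are inspected, an extra dot earlier in the name (as in `v1.2.tar.gz`) adds an entry to the list but does not change the detected archive and compression types.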


## Image I/O for reading and writing images

In [16]:
image_format = ('.JPEG', '.jpeg', '.PNG', '.png', '.JPG', '.jpg')
image_mode = ('1', 'L', 'RGB', 'RGBA', 'CMYK', 'YCbCr', 'LAB', 'HSV', 'I', 'F')


def imread(image, mode=None):
    """
    Read an image.

    Args:
        image (ndarray or str or Path): Ndarray, str or pathlib.Path.
        mode (str): Image mode.

    Returns:
        ndarray: Loaded image array.
    """
    # Validate the mode only when one is given; the default None means no conversion.
    if mode is not None:
        Validator.check_string(mode, image_mode)

    if isinstance(image, pathlib.Path):
        image = str(image)

    if isinstance(image, np.ndarray):
        pass
    elif isinstance(image, str):
        check_file_exist(image)
        image = Image.open(image)
        if mode:
            image = image.convert(mode)
        # Always return an ndarray, as documented.
        image = np.array(image)
    else:
        raise TypeError("Image must be a `ndarray`, `str` or Path object.")

    return image


def imwrite(image, image_path, auto_mkdir=True):
    """
    Write image to file.

    Args:
        image (ndarray): Image array to be written.
        image_path (str): Image file path to be written.
        auto_mkdir (bool): If the directory of `image_path` does not exist, create it automatically.

    """
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(image_path))
        if dir_name != '':
            dir_name = os.path.expanduser(dir_name)
            # The mode must be octal; decimal 777 would set unintended permission bits.
            os.makedirs(dir_name, mode=0o777, exist_ok=True)

    image = Image.fromarray(image)
    image.save(image_path)


def read_dataset(path: str) -> Tuple[List[str], List[int]]:
    """
    Get the path list and index list of images.
    """
    img_list = list()
    id_list = list()

    idx = 0
    if os.path.isdir(path):
        for img_name in os.listdir(path):
            if pathlib.Path(img_name).suffix in image_format:
                img_path = os.path.join(path, img_name)
                img_list.append(img_path)
                id_list.append(idx)
                idx += 1
    else:
        img_list.append(path)
        id_list.append(idx)
    return img_list, id_list


def label2index(path: str) -> Dict[str, int]:
    """
    Read images directory for getting label and its corresponding index.
    """
    label = sorted(i.name for i in os.scandir(path) if i.is_dir())

    if not label:
        raise ValueError(f"Cannot find any folder in {path}.")

    return dict((j, i) for i, j in enumerate(label))
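
The class index assigned by `label2index` depends only on the sorted folder names, so `train/` and `val/` receive the same mapping as long as they contain the same class folders. Here is a self-contained sketch of that behaviour on a temporary directory. `folder_to_index` is an illustrative restatement of the same idea, not a replacement for the function above.

```python
import os
import tempfile


def folder_to_index(path):
    """Map each sub-directory name under `path` to its rank in sorted order."""
    names = sorted(entry.name for entry in os.scandir(path) if entry.is_dir())
    return {name: index for index, name in enumerate(names)}


with tempfile.TemporaryDirectory() as tmp:
    # Create the class folders in arbitrary order.
    for wnid in ("n04347756", "n01440764", "n04347754"):
        os.mkdir(os.path.join(tmp, wnid))
    mapping = folder_to_index(tmp)
    print(mapping)
```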


## DatasetGenerator

A dataset generator for getting each image path and its corresponding label.

In [17]:
class DatasetGenerator:
    """ Dataset generator for getting image path and its corresponding label. """

    def __init__(self, image, label, image_id=None, mode=None):
        self.image = image
        self.label = label
        self.image_id = image_id
        self.mode = mode

    def __getitem__(self, item):
        """Get the image and label for each item."""
        if isinstance(self.image, list):
            image = imread(self.image[item], self.mode) if self.mode else np.fromfile(self.image[item], dtype="int8")
        else:
            image = self.image[item]

        label = self.label[item]

        if self.image_id:
            image_id = self.image_id[item]
            return image, image_id, label

        return image, label

    def __len__(self):
        """Get the size of the dataset."""
        return len(self.image)


## Dataset: the base class for building datasets compatible with MindSpore Vision

In [18]:
class Dataset:
    """
    Dataset is the base class for making dataset which are compatible with MindSpore Vision.
    """

    def __init__(self,
                 path: str,
                 split: str,
                 load_data: Union[Callable, Tuple],
                 transform: Optional[Callable],
                 target_transform: Optional[Callable],
                 batch_size: int,
                 repeat_num: int,
                 resize: Union[int, Tuple[int, int]],
                 shuffle: bool,
                 num_parallel_workers: Optional[int],
                 num_shards: int,
                 shard_id: int,
                 mr_file: Optional[str] = None,
                 columns_list: Tuple = ('image', 'label'),
                 mode: Optional[str] = None) -> None:
        ds.config.set_enable_shared_mem(False)
        self.path = os.path.expanduser(path)
        self.split = split

        if len(columns_list) == 3 and self.split != "infer":
            self.image, self.image_id, self.label = load_data()
        else:
            self.image, self.label = load_data(self.path) if self.split == "infer" else load_data()
            self.image_id = None
        self.transform = transform
        self.target_transform = target_transform
        self.batch_size = batch_size
        self.repeat_num = repeat_num
        self.resize = resize
        self.shuffle = shuffle
        self.num_parallel_workers = num_parallel_workers
        self.num_shards = num_shards
        self.shard_id = shard_id
        self.mode = mode
        self.mr_file = mr_file
        self.columns_list = columns_list
        if self.mr_file:
            self.dataset = ds.MindDataset(mr_file,
                                          columns_list=list(self.columns_list),
                                          num_parallel_workers=num_parallel_workers,
                                          shuffle=self.shuffle,
                                          num_shards=self.num_shards,
                                          shard_id=self.shard_id)
        else:
            if self.image_id:
                self.dataset = ds.GeneratorDataset(DatasetGenerator(self.image,
                                                                    self.label,
                                                                    self.image_id,
                                                                    mode=self.mode),
                                                   column_names=list(self.columns_list),
                                                   num_parallel_workers=num_parallel_workers,
                                                   shuffle=self.shuffle,
                                                   num_shards=self.num_shards,
                                                   shard_id=self.shard_id)
            else:
                self.dataset = ds.GeneratorDataset(DatasetGenerator(self.image,
                                                                    self.label,
                                                                    mode=self.mode),
                                                   column_names=list(self.columns_list),
                                                   num_parallel_workers=num_parallel_workers,
                                                   shuffle=self.shuffle,
                                                   num_shards=self.num_shards,
                                                   shard_id=self.shard_id)

    @property
    def get_path(self):
        """Get path in imagenet dataset which will be train or val folder."""

        return os.path.join(self.path, self.split)

    @property
    def index2label(self):
        """Get the mapping of indexes and labels."""
        raise NotImplementedError

    def default_transform(self):
        """Default data augmentation."""
        raise NotImplementedError

    def transforms(self):
        """Data augmentation."""
        if not self.dataset:
            raise ValueError("dataset is None")

        trans = self.transform if self.transform else self.default_transform()

        self.dataset = self.dataset.map(operations=trans,
                                        input_columns='image',
                                        num_parallel_workers=self.num_parallel_workers)
        if self.target_transform:
            self.dataset = self.dataset.map(operations=self.target_transform,
                                            input_columns='label',
                                            num_parallel_workers=self.num_parallel_workers)

    def run(self):
        """Dataset pipeline."""
        self.transforms()
        self.dataset = self.dataset.batch(self.batch_size, drop_remainder=True)
        self.dataset = self.dataset.repeat(self.repeat_num)

        return self.dataset
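
Because `run` batches with `drop_remainder=True` and then repeats, the number of steps per epoch is the floor of the sample count divided by the batch size, multiplied by `repeat_num`. The arithmetic can be sketched as follows, assuming a single device (no sharding). `steps_per_epoch` is a hypothetical helper used only for illustration.

```python
def steps_per_epoch(num_samples, batch_size, repeat_num=1, drop_remainder=True):
    """Number of batches produced by batch(...) followed by repeat(...)."""
    if drop_remainder:
        batches = num_samples // batch_size
    else:
        batches = -(-num_samples // batch_size)  # ceiling division
    return batches * repeat_num


# ImageNet-1k training split (1281167 images) with a batch size of 64.
train_steps = steps_per_epoch(1281167, 64)
dropped = 1281167 - train_steps * 64
print(train_steps, dropped)  # 20018 15
```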


class ParseDataset(metaclass=ABCMeta):
    """
    Parse dataset.
    """

    def __init__(self, path: str):
        self.path = os.path.expanduser(path)

    @abstractmethod
    def parse_dataset(self):
        """parse dataset from internet or compression file."""


## ImageNet dataset processing interface

In [None]:
__all__ = ["ImageNet", "ParseImageNet"]
class ImageNet(Dataset):
    """
    A source dataset that reads, parses and augments the IMAGENET dataset.

    The generated dataset has two columns :py:obj:`[image, label]`.
    The tensor of column :py:obj:`image` is a matrix of the float32 type.
    The tensor of column :py:obj:`label` is a scalar of the int32 type.

    Args:
        path (str): The root directory of the IMAGENET dataset or inference image.
        split (str): The dataset split, supports "train", "val" or "infer". Default: "train".
        num_parallel_workers (int, optional): The number of subprocess used to fetch the dataset
            in parallel. Default: None.
        transform (callable, optional):A function transform that takes in a image. Default: None.
        target_transform (callable, optional):A function transform that takes in a label. Default: None.
        batch_size (int): The batch size of dataset. Default: 64.
        repeat_num (int): The repeat num of dataset. Default: 1.
        shuffle (bool, optional): Whether to perform shuffle on the dataset. Default: None.
        num_shards (int, optional): The number of shards that the dataset will be divided. Default: None.
        shard_id (int, optional): The shard ID within num_shards. Default: None.
        resize (Union[int, tuple]): The output size of the resized image. If size is an integer, the smaller edge of the
            image will be resized to this value with the same image aspect ratio. If size is a sequence of length 2,
            it should be  (height, width). Default: 224.

    Raises:
        ValueError: If `split` is not 'train', 'val' or 'infer'.

    Examples:
        >>> from mindvision.classification.dataset import ImageNet
        >>> dataset = ImageNet("./data/imagenet/", "train")
        >>> dataset = dataset.run()
    """
    def default_transform(self):
        pass

    def __init__(self,
                 path: str,
                 split: str = "train",
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None,
                 batch_size: int = 64,
                 resize: Union[Tuple[int, int], int] = 224,
                 repeat_num: int = 1,
                 shuffle: Optional[bool] = None,
                 num_parallel_workers: int = 1,
                 num_shards: Optional[int] = None,
                 shard_id: Optional[int] = None,
                 ):
        Validator.check_string(split, ["train", "val", "infer"], "split")

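        # In "infer" mode the path holds images directly (no class subfolders),
        # so the labelled directory scan is skipped.
        if split != "infer":
            self.images_path, self.images_label = self.__load_data(os.path.join(os.path.expanduser(path), split))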
        load_data = read_dataset if split == "infer" else self.read_dataset

        super(ImageNet, self).__init__(path=path,
                                       split=split,
                                       load_data=load_data,
                                       transform=transform,
                                       target_transform=target_transform,
                                       batch_size=batch_size,
                                       repeat_num=repeat_num,
                                       resize=resize,
                                       shuffle=shuffle,
                                       num_parallel_workers=num_parallel_workers,
                                       num_shards=num_shards,
                                       shard_id=shard_id,
                                       )

    @property
    def index2label(self):
        """
        Get the mapping of indexes and labels.
        """
        parse_imagenet = ParseImageNet(self.path)
        if not os.path.exists(os.path.join(parse_imagenet.path, parse_imagenet.meta_file)):
            parse_imagenet.parse_devkit()

        wind2class_name = load_json_file(os.path.join(parse_imagenet.path, parse_imagenet.meta_file))['wnid2class_name']
        wind2class_name = sorted(wind2class_name.items(), key=lambda x: x[0])
        mapping = {}

        for index, (_, class_name) in enumerate(wind2class_name):
            mapping[index] = class_name[0]

        return mapping

    @staticmethod
    def __load_data(path):
        """Read each image and its corresponding label from directory."""
        check_dir_exist(path)

        available_label = set()
        images_label, images_path = [], []
        label_to_idx = label2index(path)

        # Iterate each file in the path
        for label in label_to_idx.keys():
            for file_name in os.listdir(os.path.join(path, label)):
                if check_file_valid(file_name, image_format):
                    images_path.append(os.path.join(path, label, file_name))
                    images_label.append(label_to_idx[label])
                    if label not in available_label:
                        available_label.add(label)

        empty_label = set(label_to_idx.keys()) - available_label

        if empty_label:
            raise ValueError(f"Found invalid file for the label {','.join(empty_label)}.")

        return images_path, images_label

    def read_dataset(self, *args):
        if not args:
            return self.images_path, self.images_label

        return np.fromfile(self.images_path[args[0]], dtype="int8"), self.images_label[args[0]]


class ParseImageNet(ParseDataset):
    """
    Parse ImageNet dataset and generate the json file (file name:imagenet_meta.json).
    """

    devkit_file = ["meta.mat", "ILSVRC2012_validation_ground_truth.txt"]
    meta_file = "imagenet_meta.json"

    def __parse_meta_mat(self, devkit_path):
        """Parse the mat file(meta.mat)."""
        metafile = os.path.join(devkit_path, "data", self.devkit_file[0])
        meta = io.loadmat(metafile, squeeze_me=True)['synsets']

        nums_children = list(zip(*meta))[4]
        meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]

        idcs, wnids, classes = list(zip(*meta))[:3]
        clssname = [tuple(clss.split(', ')) for clss in classes]
        idx2wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
        wnid2class = {wnid: clss for wnid, clss in zip(wnids, clssname)}
        return idx2wnid, wnid2class

    def __parse_groundtruth(self, devkit_path):
        """Parse ILSVRC2012_validation_ground_truth.txt."""
        val_gt = os.path.join(devkit_path, "data", self.devkit_file[1])
        with open(val_gt, "r") as f:
            val_idx2image = f.readlines()
        return [int(i) for i in val_idx2image]

    def parse_devkit(self):
        """Parse the devkit archive of the ImageNet2012 classification dataset and save meta info in json file."""

        with tempfile.TemporaryDirectory() as temp_dir:
            devkit_path = os.path.join(temp_dir, "ILSVRC2012_devkit_t12")
            idx2wnid, wnid2class = self.__parse_meta_mat(devkit_path)
            val_idcs = self.__parse_groundtruth(devkit_path)
            val_wnids = [idx2wnid[idx] for idx in val_idcs]

            # Generating imagenet_meta.json which saved the values of wnid2class and val_wnids
            dict_json = {"wnid2class": wnid2class, "val_wnids": val_wnids}
            save_json_file(os.path.join(self.path, self.meta_file), dict_json)

    # pylint: disable=unused-argument
    def parse_dataset(self, *args):
        """Parse the devkit archives of ImageNet dataset."""
        if not os.path.exists(os.path.join(self.path, self.meta_file)):
            self.parse_devkit()

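## Model training

We train the MobileViT model from scratch on the ImageNet-1k classification dataset, which provides 1.28 million training images and 50,000 validation images. The MobileViT network was trained on Ascend for 180 epochs with the SGD optimizer, label-smoothing cross-entropy loss (smoothing = 0.1), and a multi-scale sampler (S = {(160, 160), (192, 192), (256, 256), (288, 288), (320, 320)}). The learning rate is increased from 0.0002 to 0.002 over the first 3k iterations and then annealed to 0.0002 with a cosine schedule. We use an L2 weight decay of 0.01. We use basic data augmentation (random resized cropping and horizontal flipping) and evaluate performance with single-crop top-1 accuracy.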
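
The warmup-then-cosine learning rate schedule described above can be sketched in plain Python. This is an illustration under stated assumptions, not the training script's own scheduler: the function name `warmup_cosine_lr` is hypothetical, the warmup is assumed to be linear, and the total step count (180 epochs of 1000 steps) is an arbitrary placeholder.

```python
import math


def warmup_cosine_lr(step, total_steps, warmup_steps=3000,
                     min_lr=0.0002, max_lr=0.002):
    """Learning rate at `step` (0-based): linear warmup, then cosine annealing."""
    if step < warmup_steps:
        # Linear ramp from min_lr to max_lr.
        return min_lr + (max_lr - min_lr) * step / warmup_steps
    # Cosine decay from max_lr back down to min_lr over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Placeholder run length: 180 epochs x 1000 steps per epoch (an assumption).
total = 180 * 1000
for step in (0, 1500, 3000, total // 2, total):
    print(step, round(warmup_cosine_lr(step, total), 6))
```

The rate starts at 0.0002, reaches 0.002 at step 3000, and returns to 0.0002 at the final step.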

### First, define the loss monitor callback

In [20]:
class LossMonitor(Callback):
    """
    Loss Monitor for classification.

    Args:
        lr_init (Union[float, Iterable], optional): The learning rate schedule. Default: None.
        per_print_times (int): Every how many steps to print the log information. Default: 1.

    Examples:
        >>> from mindvision.engine.callback import LossMonitor
        >>> lr = [0.01, 0.008, 0.006, 0.005, 0.002]
        >>> monitor = LossMonitor(lr_init=lr, per_print_times=100)
    """

    def __init__(self,
                 lr_init: Optional[Union[float, Iterable]] = None,
                 per_print_times: int = 1):
        super(LossMonitor, self).__init__()
        self.lr_init = lr_init
        self.per_print_times = per_print_times
        self.last_print_time = 0

    # pylint: disable=unused-argument
    def epoch_begin(self, run_context):
        """
        Record time at the beginning of epoch.

        Args:
            run_context (RunContext): Context of the process running.
        """
        self.losses = []
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        """
        Print training info at the end of epoch.

        Args:
            run_context (RunContext): Context of the process running.
        """
        callback_params = run_context.original_args()
        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        per_step_mseconds = epoch_mseconds / callback_params.batch_num
        print(f"Epoch time: {epoch_mseconds:5.3f} ms, "
              f"per step time: {per_step_mseconds:5.3f} ms, "
              f"avg loss: {np.mean(self.losses):5.3f}", flush=True)

    # pylint: disable=unused-argument
    def step_begin(self, run_context):
        """
        Record time at the beginning of step.

        Args:
            run_context (RunContext): Context of the process running.
        """
        self.step_time = time.time()

    # pylint: disable=missing-docstring
    def step_end(self, run_context):
        """
        Print training info at the end of step.

        Args:
            run_context (RunContext): Context of the process running.
        """
        callback_params = run_context.original_args()
        step_mseconds = (time.time() - self.step_time) * 1000
        loss = callback_params.net_outputs

        if isinstance(loss, (tuple, list)):
            if isinstance(loss[0], mindspore.Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
                loss = loss[0]

        if isinstance(loss, mindspore.Tensor) and isinstance(loss.asnumpy(), np.ndarray):
            loss = np.mean(loss.asnumpy())

        self.losses.append(loss)
        cur_step_in_epoch = (callback_params.cur_step_num - 1) % callback_params.batch_num + 1

        # Boundary check.
        if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
            raise ValueError("Invalid loss, terminate training.")

        def print_info():
            lr_output = self.lr_init[callback_params.cur_step_num - 1] if isinstance(self.lr_init,
                                                                                     list) else self.lr_init
            print(f"Epoch:[{(callback_params.cur_epoch_num - 1):3d}/{callback_params.epoch_num:3d}], "
                  f"step:[{cur_step_in_epoch:5d}/{callback_params.batch_num:5d}], "
                  f"loss:[{loss:5.3f}/{np.mean(self.losses):5.3f}], "
                  f"time:{step_mseconds:5.3f} ms, "
                  f"lr:{lr_output:5.5f}", flush=True)

        if (callback_params.cur_step_num - self.last_print_time) >= self.per_print_times:
            self.last_print_time = callback_params.cur_step_num
            print_info()
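
`step_end` converts the global, 1-based step counter into a 1-based position within the current epoch with `(cur_step_num - 1) % batch_num + 1`. A small sketch of that arithmetic follows. `step_in_epoch` is an illustrative helper, and `batch_num = 5` is an arbitrary example value.

```python
def step_in_epoch(cur_step_num, batch_num):
    """Convert a 1-based global step counter into a 1-based step within the epoch."""
    return (cur_step_num - 1) % batch_num + 1


batch_num = 5  # illustrative number of steps per epoch
positions = [step_in_epoch(s, batch_num) for s in range(1, 12)]
print(positions)  # [1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1]
```

Subtracting 1 before the modulo and adding it back afterwards keeps the last step of each epoch at `batch_num` instead of wrapping to 0.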


### Define softmax cross-entropy with logits and label smoothing

In [21]:
class CrossEntropySmooth(LossBase):
    """
    Computes softmax cross entropy between logits and labels with label smoothing.

    Measures the distribution error between the probabilities of the input (computed with softmax function) and the
    target where the classes are mutually exclusive using cross entropy loss.

    In order to avoid nan loss in training, smoothing is applied to label.

    Typical input into this function is unnormalized scores denoted as x whose shape is (N, C),
    and the corresponding targets.

    Args:
        classes_num (int): Number of classes.
        sparse (bool): Specifies whether labels use sparse format or not. Default: True.
        reduction (str): Type of reduction to be applied to loss. The optional values are "mean", "sum", and "none".
            If "none", do not perform reduction. Default: "mean".
        smooth_factor (float): Label smoothing factor. Default: 0.0.

    Inputs:
        - **logits** (Tensor) - Tensor of shape (N, C). Data type must be float16 or float32.
        - **labels** (Tensor) - If `sparse` is True, a tensor of shape (N,) with type int32 or int64.
          If `sparse` is False, a tensor of shape (N, C) with the same type as `logits`.

    Outputs:
        Tensor. A scalar if `reduction` is "mean" or "sum"; a tensor of shape (N,) if `reduction` is "none".
    """

    def __init__(self, classes_num, sparse=True, reduction='mean', smooth_factor=0.):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = ops.OneHot()
        self.sparse = sparse
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (classes_num - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        if self.sparse:
            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss
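
The smoothed target built by `CrossEntropySmooth` can be illustrated outside MindSpore. The following is a minimal NumPy sketch, not the MindSpore implementation: the helper names `smooth_one_hot` and `smoothed_cross_entropy` are hypothetical, and the on/off values mirror `on_value` and `off_value` in the class above.

```python
import numpy as np


def smooth_one_hot(labels, classes_num, smooth_factor=0.1):
    """Build smoothed one-hot targets (hypothetical helper, mirrors on_value/off_value)."""
    on_value = 1.0 - smooth_factor
    off_value = smooth_factor / (classes_num - 1)
    targets = np.full((len(labels), classes_num), off_value, dtype=np.float64)
    targets[np.arange(len(labels)), labels] = on_value
    return targets


def smoothed_cross_entropy(logits, labels, smooth_factor=0.1):
    """Mean softmax cross-entropy against smoothed targets."""
    logits = np.asarray(logits, dtype=np.float64)
    # Numerically stable log-softmax.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    targets = smooth_one_hot(labels, logits.shape[1], smooth_factor)
    return float(-(targets * log_probs).sum(axis=1).mean())


labels = np.array([2, 0])
targets = smooth_one_hot(labels, classes_num=4, smooth_factor=0.1)
print(targets)
print(smoothed_cross_entropy([[0.1, 0.2, 3.0, 0.0], [2.0, 0.5, 0.1, 0.1]], labels))
```

Each row of `targets` still sums to 1: the true class receives `1 - smooth_factor` and the remaining mass is split evenly over the other classes.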


### Model training script

In [22]:
set_seed(1)

def mobilevit_train(args_opt):
    """mobilevit train"""

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
    context.set_context(enable_graph_kernel=False)

    # Data preprocessing
    if args_opt.model_type == 'small':
        img_transforms = ([
            c_transforms.Decode(),
            c_transforms.RandomResizedCrop(256),
            c_transforms.RandomHorizontalFlip(),
            c_transforms.AutoAugment(),
            p_transforms.RandomErasing(prob=0.25),
            c_transforms.ConvertColor(c_transforms.ConvertMode.COLOR_RGB2BGR),
            p_transforms.ToTensor(),
        ])
    else:
        img_transforms = ([
            c_transforms.Decode(),
            c_transforms.RandomResizedCrop(256),
            c_transforms.RandomHorizontalFlip(),
            c_transforms.ConvertColor(c_transforms.ConvertMode.COLOR_RGB2BGR),
            p_transforms.ToTensor(),
        ])

    # Dataset pipeline
    dataset = ImageNet(args_opt.data_url,
                       split="train",
                       shuffle=True,
                       transform=img_transforms,
                       num_parallel_workers=args_opt.num_parallel_workers,
                       resize=args_opt.resize,
                       batch_size=args_opt.batch_size)

    dataset_train = dataset.run()
    step_size = dataset_train.get_dataset_size()

    # Create model.
    network = MobileViT(model_type=args_opt.model_type, num_classes=args_opt.num_classes)

    # Define the decreasing learning rate
    lr = nn.cosine_decay_lr(min_lr=args_opt.min_lr,
                            max_lr=args_opt.max_lr,
                            total_step=args_opt.epoch_size * step_size,
                            step_per_epoch=step_size,
                            decay_epoch=args_opt.decay_epoch)

    # Define loss scale
    loss_scale = 1024.0
    loss_scale_manager = mindspore.FixedLossScaleManager(loss_scale, False)

    # Define optimizer.
    network_opt = nn.SGD(network.trainable_params(), lr, momentum=args_opt.momentum, weight_decay=args_opt.weight_decay,
                         nesterov=False, loss_scale=loss_scale)

    # Define loss function.
    network_loss = CrossEntropySmooth(sparse=True,
                                      reduction="mean",
                                      smooth_factor=0.1,
                                      classes_num=args_opt.num_classes)

    # Define checkpoint
    ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=args_opt.keep_checkpoint_max)
    ckpt_callback = ModelCheckpoint(prefix=args_opt.model_type, directory=args_opt.ckpt_save_dir, config=ckpt_config)

    # Define metrics.
    metrics = {'acc', "loss"}

    # Define timer
    time_cb = TimeMonitor(data_size=dataset_train.get_dataset_size())

    # Init the model.
    if args_opt.device_target == "Ascend":
        model = Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=metrics, amp_level="auto",
                      loss_scale_manager=loss_scale_manager)
    else:
        model = Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=metrics)

    # Begin to train.
    model.train(args_opt.epoch_size,
                dataset_train,
                callbacks=[time_cb, ckpt_callback, LossMonitor(lr)],
                dataset_sink_mode=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MobileViT train.')
    parser.add_argument('--epoch_size', type=int, default=1, help='Train epoch size.')
    parser.add_argument('--model_type', default='xx_small', type=str, metavar='model_type')
    parser.add_argument('--batch_size', type=int, default=64, help='Number of batch size.')
    parser.add_argument('--decay_epoch', type=int, default=150, help='Number of decay epochs.')
    parser.add_argument('--num_classes', type=int, default=1001, help='Number of classes.')
    parser.add_argument('--data_url', default=r"/home/ma-user/work/imagenet/imagenet2012",
                        help='Location of data.')
    parser.add_argument('--momentum', type=float, default=0.9, help='Momentum for the moving average.')
    parser.add_argument('--device_target', type=str, default="Ascend", choices=["Ascend", "GPU", "CPU"])
    parser.add_argument('--max_lr', type=float, default=0.1, help='Number of the maximum learning rate.')
    parser.add_argument('--num_parallel_workers', type=int, default=5, help='Number of parallel workers.')
    parser.add_argument('--min_lr', type=float, default=1e-5, help='Number of the minimum learning rate.')
    parser.add_argument('--resize', type=int, default=256, help='Resize the height and width of picture.')
    parser.add_argument('--weight_decay', type=float, default=4e-5, help='Weight decay (L2 penalty).')
    parser.add_argument('--keep_checkpoint_max', type=int, default=40, help='Max number of checkpoint files.')
    parser.add_argument('--ckpt_save_dir', type=str, default="./Mobilevit_Ckpt", help='Location of training outputs.')
    args = parser.parse_known_args()[0]
    mobilevit_train(args)


Epoch:[  0/  1], step:[20010/20018], loss:[5.058/5.648], time:68.497 ms, lr:0.10000
Epoch:[  0/  1], step:[20011/20018], loss:[5.039/5.648], time:67.510 ms, lr:0.10000
Epoch:[  0/  1], step:[20012/20018], loss:[5.329/5.648], time:67.852 ms, lr:0.10000
Epoch:[  0/  1], step:[20013/20018], loss:[5.415/5.648], time:73.427 ms, lr:0.10000
Epoch:[  0/  1], step:[20014/20018], loss:[5.086/5.648], time:68.896 ms, lr:0.10000
Epoch:[  0/  1], step:[20015/20018], loss:[5.125/5.648], time:68.054 ms, lr:0.10000
Epoch:[  0/  1], step:[20016/20018], loss:[5.158/5.648], time:68.832 ms, lr:0.10000
Epoch:[  0/  1], step:[20017/20018], loss:[4.867/5.648], time:69.334 ms, lr:0.10000
Epoch:[  0/  1], step:[20018/20018], loss:[5.370/5.648], time:519.246 ms, lr:0.10000
epoch time: 3786117.721 ms, per step time: 189.136 ms
Epoch time: 3786118.951 ms, per step time: 189.136 ms, avg loss: 5.648
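
The log shows `lr:0.10000` for every step of this one-epoch run. That is consistent with a cosine schedule whose value changes once per epoch rather than once per step. The sketch below reimplements such a schedule in plain Python, assuming the per-epoch formula described in the MindSpore documentation for `nn.cosine_decay_lr`; the function name `cosine_decay_lr_sketch` is hypothetical, and the small step counts are chosen only for illustration.

```python
import math


def cosine_decay_lr_sketch(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
    """Per-step learning rates; the value changes only at epoch boundaries (assumed formula)."""
    delta = 0.5 * (max_lr - min_lr)
    lrs = []
    for step in range(total_step):
        epoch = min(step // step_per_epoch, decay_epoch)
        lrs.append(min_lr + delta * (1.0 + math.cos(math.pi * epoch / decay_epoch)))
    return lrs


# Illustrative sizes: 4 steps per epoch, 3 epochs, decay over 2 epochs.
lrs = cosine_decay_lr_sketch(min_lr=1e-5, max_lr=0.1, total_step=12,
                             step_per_epoch=4, decay_epoch=2)
for step, lr in enumerate(lrs):
    print(f"step {step:2d}: lr={lr:.5f}")
```

With `decay_epoch=150` and `epoch_size=1`, as in the defaults above, the whole run stays in epoch 0, so under this formula the learning rate remains at `max_lr`.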


## Model evaluation

The MindSpore checkpoint can be downloaded from the following link:
https://download.mindspore.cn/vision/cyclegan/apple/mobilevit_xxs.ckpt
Place the downloaded checkpoint file in the following directory structure.
./src/
├── mobilevit_xxs.ckpt

In [43]:
def mobilevit_eval(args_opt):
    """mobilevit eval."""

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)

    # Data pipeline.
    img_transforms = ([
        c_transforms.Decode(),
        c_transforms.Resize((int(math.ceil(256 / 0.875)) // 32) * 32, interpolation=Inter.BILINEAR),
        c_transforms.CenterCrop(256),
        c_transforms.ConvertColor(c_transforms.ConvertMode.COLOR_RGB2BGR),
        c_transforms.RandomHorizontalFlip(),
        p_transforms.ToTensor(),
    ])

    dataset = ImageNet(args_opt.data_url,
                       split="val",
                       transform=img_transforms,
                       num_parallel_workers=args_opt.num_parallel_workers,
                       resize=args_opt.resize,
                       shuffle=True,
                       batch_size=args_opt.batch_size)

    dataset_eval = dataset.run()

    # Create model.
    network = MobileViT(model_type=args_opt.model_type, num_classes=args_opt.num_classes)

    # Load the pretrained model
    param_dict = load_checkpoint(args_opt.pretrained_model)
    load_param_into_net(network, param_dict)

    # Define loss function.
    network_loss = CrossEntropySmooth(sparse=True,
                                      reduction="mean",
                                      smooth_factor=0.1,
                                      classes_num=args_opt.num_classes)

    # Define eval metrics.
    eval_metrics = {'Top_1_Accuracy': nn.Top1CategoricalAccuracy(),
                    'Top_5_Accuracy': nn.Top5CategoricalAccuracy()}

    # Init the model.
    model = Model(network, network_loss, metrics=eval_metrics)
    result = model.eval(dataset_eval)
    print(result)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MobileViT eval.')
    parser.add_argument("--data_url", default=r"/home/ma-user/work/imagenet/imagenet2012", help="Location of data.")
    parser.add_argument('--pretrained_model', default=r"/home/ma-user/work/Ascend_mobilevit/Mobilevit_Ckpt/4/mobilevit_xxs_6-180_20018.ckpt", type=str, metavar='PATH')
    parser.add_argument('--model_type', default='xx_small', type=str, metavar='model_type')
    parser.add_argument('--batch_size', type=int, default=100, help='Number of batch size.')
    parser.add_argument('--smooth_factor', type=float, default=0.1, help='The smooth factor.')
    parser.add_argument('--num_classes', type=int, default=1001, help='Number of classes.')
    parser.add_argument('--device_target', type=str, default="Ascend", choices=["Ascend", "GPU", "CPU"])
    parser.add_argument('--num_parallel_workers', type=int, default=8, help='Number of parallel workers.')
    parser.add_argument('--resize', type=int, default=256, help='Resize the height and width of picture.')
    args = parser.parse_known_args()[0]
    mobilevit_eval(args)


{'Top_1_Accuracy': 0.62144, 'Top_5_Accuracy': 0.84314}
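
Top-1 and Top-5 accuracy count a sample as correct when the true label is among the 1 or 5 highest-scoring classes. The following NumPy sketch illustrates that definition on made-up scores; it is not the code behind `nn.Top1CategoricalAccuracy` or `nn.Top5CategoricalAccuracy`, and `topk_accuracy` is a hypothetical helper.

```python
import numpy as np


def topk_accuracy(scores, labels, k):
    """Fraction of samples whose true label is among the k highest scores."""
    scores = np.asarray(scores)
    labels = np.asarray(labels)
    # Indices of the k largest scores per row (order within the top k is irrelevant).
    topk = np.argsort(scores, axis=1)[:, -k:]
    hits = (topk == labels[:, None]).any(axis=1)
    return float(hits.mean())


scores = np.array([[0.1, 0.7, 0.2],
                   [0.5, 0.3, 0.2],
                   [0.2, 0.3, 0.5]])
labels = np.array([1, 1, 0])
print(topk_accuracy(scores, labels, k=1))
print(topk_accuracy(scores, labels, k=2))
```

Top-k accuracy can only grow with k, which is why the Top-5 figure reported above is higher than the Top-1 figure.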


## Model inference


### First, define the ImageNet data transforms for inference

In [31]:
def infer_transform(dataset, columns_list, resize):
    """
    Implements the inference transformation
    (Decode --> Resize --> CenterCrop --> ConvertColor --> RandomHorizontalFlip --> ToTensor).
    """

    crop_ratio = 0.875
    scale_size = int(math.ceil(resize / crop_ratio))
    scale_size = (scale_size // 32) * 32

    img_transforms = [
        c_transforms.Decode(),
        c_transforms.Resize([scale_size, scale_size]),
        c_transforms.CenterCrop(resize),
        c_transforms.ConvertColor(c_transforms.ConvertMode.COLOR_RGB2BGR),
        c_transforms.RandomHorizontalFlip(),
        p_transforms.ToTensor(),
    ]

    dataset = dataset.map(operations=img_transforms,
                          input_columns=columns_list[0],
                          num_parallel_workers=1)
    dataset = dataset.batch(100)
    return dataset
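
The resize target in `infer_transform` is the crop size divided by the crop ratio, rounded up and then rounded down to a multiple of 32. A short worked example, using a hypothetical helper `scale_size_for` that repeats the same arithmetic:

```python
import math


def scale_size_for(resize, crop_ratio=0.875):
    """Repeat the arithmetic used in infer_transform (hypothetical helper)."""
    scale_size = int(math.ceil(resize / crop_ratio))  # 256 / 0.875 = 292.57... -> 293
    return (scale_size // 32) * 32                    # 293 // 32 = 9 -> 9 * 32 = 288


print(scale_size_for(256))  # 288
print(scale_size_for(224))  # 256
```

With the default `resize=256`, the 256×256 center crop is therefore taken from an image resized to 288×288.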

### Next, define the functions that draw inference results on the image

In [26]:
class Color(Enum):
    """An enum that defines engine colors.

    Contains red, green, blue, cyan, yellow, magenta, white and black, each as a BGR tuple.
    """
    red = (0, 0, 255)
    green = (0, 255, 0)
    blue = (255, 0, 0)
    cyan = (255, 255, 0)
    yellow = (0, 255, 255)
    magenta = (255, 0, 255)
    white = (255, 255, 255)
    black = (0, 0, 0)


def color_val(color):
    """Convert various input to color tuples.

    Args:
        color (:obj:`Color`/str/tuple/int/ndarray): Color inputs

    Returns:
        tuple[int]: A tuple of 3 integers indicating BGR channels.
    """
    if isinstance(color, str):
        return Color[color].value
    if isinstance(color, Color):
        return color.value
    if isinstance(color, tuple):
        assert len(color) == 3
        for channel in color:
            assert 0 <= channel <= 255
        return color
    if isinstance(color, int):
        assert 0 <= color <= 255
        return color, color, color
    if isinstance(color, np.ndarray):
        assert color.ndim == 1 and color.size == 3
        assert np.all((color >= 0) & (color <= 255))
        color = color.astype(np.uint8)
        return tuple(color)
    raise TypeError(f'Invalid type for color: {type(color)}')


def imshow(img, win_name='', wait_time=0):
    """
    Show an image.

    Args:
        img (str or ndarray): The image to be displayed.
        win_name (str): The window name.
        wait_time (int): Value of waitKey param.
    """
    cv2.imshow(win_name, imread(img))
    if wait_time == 0:  # avoid hanging if the window was closed
        while True:
            ret = cv2.waitKey(1)

            closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1
            # if user closed window or if some key pressed
            if closed or ret != -1:
                break
    else:
        ret = cv2.waitKey(wait_time)


def show_result(img: str,
                result: Dict[int, float],
                text_color: str = 'green',
                font_scale: float = 0.5,
                row_width: int = 20,
                show: bool = False,
                win_name: str = '',
                wait_time: int = 0,
                out_file: Optional[str] = None) -> None:
    """Mark the results on the picture.

    Args:
        img (str): The image to be displayed.
        result (dict): The classification results to draw over `img`.
        text_color (str or tuple or :obj:`Color`): Color of texts.
        font_scale (float): Font scales of texts.
        row_width (int): width between each row of results on the image.
        show (bool): Whether to show the image. Default: False.
        win_name (str): The window name.
        wait_time (int): Value of waitKey param. Default: 0.
        out_file (str or None): The filename to write the image. Default: None.

    Returns:
        None
    """
    img = imread(img, mode="RGB")
    img = img.copy()

    # Write results on left-top of the image.
    x, y = 0, row_width
    text_color = color_val(text_color)
    for k, v in result.items():
        if isinstance(v, float):
            v = f'{v:.2f}'
        label_text = f'{k}: {v}'
        cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
                    font_scale, text_color)
        y += row_width

    # If out_file specified, do not show image in window.
    if out_file:
        show = False
        imwrite(img, out_file)

    if show:
        imshow(img, win_name, wait_time)


### Model inference script

In [45]:
def mobilevit_infer(args_opt):
    """mobilevit infer"""

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)

    # Data pipeline.
    dataset_analyse = ImageNet(args_opt.data_url,
                               split="val",
                               num_parallel_workers=8,
                               resize=args_opt.resize,
                               batch_size=args_opt.batch_size)

    # Create model.
    network = MobileViT(model_type=args_opt.model_type, num_classes=args_opt.num_classes)

    # Load the pretrained model
    param_dict = load_checkpoint(args_opt.pretrained_model)
    load_param_into_net(network, param_dict)

    # Init the model.
    model = Model(network)

    # read inference picture
    image_list, image_label = read_dataset(args_opt.infer_url)
    columns_list = ('image', 'label')
    dataset_infer = ds.GeneratorDataset(DatasetGenerator(image_list, image_label),
                                        column_names=list(columns_list),
                                        num_parallel_workers=args_opt.num_parallel_workers,
                                        python_multiprocessing=False)
    dataset_infer = infer_transform(dataset_infer, columns_list, args_opt.resize)

    # read data for inference
    for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
        image = image["image"]
        image = Tensor(image, mindspore.float32)
        prob = model.predict(image)
        label = np.argmax(prob.asnumpy(), axis=1)
        predict = dataset_analyse.index2label[int(label)]
        output = {int(label): predict}
        print(output)
        show_result(img=image_list[i], result=output, out_file=image_list[i])


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MobileViT infer.')
    parser.add_argument("--resize", type=int, default=256, help="Image resize.")
    parser.add_argument("--data_url", default="./dataset", help="Location of data.")
    parser.add_argument('--model_type', default='xx_small', type=str, metavar='model_type')
    parser.add_argument("--batch_size", type=int, default=100, help="Number of batch size.")
    parser.add_argument('--num_classes', type=int, default=1001, help='Number of classes.')
    parser.add_argument("--infer_url", default="./test_image.JPEG", help="Location of inference data.")
    parser.add_argument('--device_target', type=str, default="Ascend", choices=["Ascend", "GPU", "CPU"])
    parser.add_argument('--pretrained_model', default="./Ascend_mobilevit/Mobilevit_Ckpt/4/mobilevit_xxs_6-180_20018.ckpt", type=str, metavar='PATH')
    parser.add_argument("--num_parallel_workers", type=int, default=8, help="Number of parallel workers.")
    args = parser.parse_known_args()[0]
    mobilevit_infer(args)

{7: 'cock'}


![image.png](image/infer.jpg)
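
The inference loop calls `int(label)` on the `argmax` result, which works when the batch holds a single image. For a batch with several images, each row needs its own lookup. A minimal NumPy sketch, assuming a hypothetical `index2label` dictionary in place of `dataset_analyse.index2label` and made-up scores:

```python
import numpy as np

# Hypothetical stand-in for dataset_analyse.index2label.
index2label = {0: "tench", 1: "goldfish", 7: "cock"}


def decode_predictions(prob, index2label):
    """Map each row of class scores to an {index: label} dict."""
    indices = np.argmax(np.asarray(prob), axis=1)
    return [{int(i): index2label[int(i)]} for i in indices]


prob = np.zeros((2, 8))
prob[0, 7] = 0.9   # first image: class 7 scores highest
prob[1, 1] = 0.6   # second image: class 1 scores highest
print(decode_predictions(prob, index2label))
```

Returning one dictionary per image keeps the output format of the single-image case while extending it to larger batches.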

## Summary

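This case study explains the model proposed in the MobileViT paper in detail and walks the reader through the complete workflow of the algorithm. For the full code, see mindspore/course/application_example/mobilevit.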

## References

[1] Mehta S, Rastegari M. MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer[J]. 2021.