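# ConvNeXt for Image Classification

## Model Overview

ConvNeXt was proposed jointly by Facebook AI Research (FAIR) and UC Berkeley as a convolutional neural network for the 2020s. It was first described in the paper [A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545).

ConvNeXt does not introduce a major innovation in overall architecture or design philosophy. Instead, it modernizes the classic ResNet by borrowing ideas from Transformer networks. The paper reports that the resulting network outperforms Swin Transformer models on several classification and recognition tasks. The figure below compares ConvNeXt with other classic networks on ImageNet-1K classification.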

<div align=center><img src="./images/conv1.png"></div>

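In the left half of the figure, ConvNeXt achieves higher classification accuracy than classic models such as ResNet and Swin-T on ImageNet-1K at a resolution of 224x224. Likewise, in the right half, ConvNeXt achieves higher accuracy than models such as ViT and Swin-T on ImageNet-22K at a resolution of 384x384. Because ConvNeXt improves on the accuracy of 2010s-era networks on both ImageNet-1K and ImageNet-22K, the paper calls it "a ConvNet for the 2020s".

### Network Features

The figure below summarizes all of the optimizations in ConvNeXt, together with the effect of each one on accuracy and GFLOPs. Starting from ResNet-50/ResNet-200, the paper borrows ideas from Swin Transformer in five areas, in order: macro design, ResNeXt-style depthwise convolution (ResNeXt-ify), inverted bottleneck, large kernel size, and various layer-wise micro designs. The models are trained and evaluated on ImageNet-1K. As the figure shows, at comparable GFLOPs ConvNeXt-T is 0.7 percentage points more accurate than Swin-T (82.0% vs. 81.3%).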

<div align=center><img src="./images/conv2.png"></div>

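The five changes are described in turn below. Before going into each of them, we first cover the improvements to the training recipe.

#### Training Techniques

The paper adopts a training recipe similar to that of Swin Transformer, with changes in four main areas:

(1) The number of training epochs is increased from 90 for ResNet to 300.

(2) The optimizer is changed from SGD to AdamW.

(3) Data augmentation now includes Mixup, CutMix, RandAugment, and Random Erasing.

(4) More regularization is added, for example stochastic depth, label smoothing, and EMA.

The figure below lists the detailed pre-training and fine-tuning hyperparameters.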

<div align=center><img src="./images/train.jpg"></div>

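With these training techniques, the accuracy of ResNet-50 rises by 2.7 percentage points, from 76.1% to 78.8%. This result shows that part of the performance gap between traditional ConvNets and Transformers comes from the training recipe.

The five optimizations to the model itself are described in detail below.

#### Macro design

The macro design involves two changes:

(1) Stage ratio:

Both ResNet and Swin-T have four stages. In Swin-T the blocks are stacked in a ratio of 1:1:3:1 across the stages, and in larger Swin Transformers such as Swin-L the ratio is 1:1:9:1, so the third stage of these Transformer networks holds most of the blocks. Following this ratio, ConvNeXt changes the number of blocks per stage in ResNet-50 from (3, 4, 6, 3) to (3, 3, 9, 3), which is also a 1:1:3:1 ratio. The change is shown in the figure below.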

<div align=center><img src="./images/macro.png"></div>

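This change improves accuracy by 0.6 percentage points, to 79.4%.

(2) Patchify stem:

The stem of Swin-T is a convolution with kernel size 4 and stride 4, whereas the stem of the classic ResNet-50 consists of a convolution with kernel size 7 and stride 2 followed by a max pooling layer with kernel size 3 and stride 2. ConvNeXt therefore replaces the stem with the same kernel-size-4, stride-4 convolution used in Swin-T. This change brings a further 0.1-point gain, raising accuracy to 79.5%.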
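
Both stems reduce a 224x224 input by the same overall factor of 4, which a quick calculation confirms. The following is a minimal sketch in plain Python; the helper `conv_out_size` and the padding values (3 for the 7x7 convolution, 1 for the max pooling) are assumptions based on the usual ResNet configuration and are not stated in this tutorial.

```python
def conv_out_size(size, kernel, stride, padding=0):
    """Spatial output size of a convolution or pooling layer (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


# ResNet-50 stem: 7x7 conv with stride 2 (assumed padding 3),
# then 3x3 max pooling with stride 2 (assumed padding 1).
resnet_stem = conv_out_size(conv_out_size(224, 7, 2, padding=3), 3, 2, padding=1)

# Patchify stem: a single 4x4 conv with stride 4 and no padding,
# so the input is cut into non-overlapping 4x4 patches.
patchify_stem = conv_out_size(224, 4, 4)

print(resnet_stem, patchify_stem)  # 56 56
```

The difference lies in how the reduction is reached: the ResNet stem uses two overlapping sliding windows, while the patchify stem uses one non-overlapping window.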

#### ResNeXt-ify

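This step adopts the core idea of ResNeXt, grouped convolution, and widens the network to compensate for the lost capacity. ConvNeXt sets the number of groups equal to the number of input channels (96 in the first stage), which makes the layer a depthwise convolution. Each kernel then processes a single channel and mixes information only along the spatial dimensions, an effect similar to that of self-attention. This change improves accuracy by a further 1.0 percentage point, to 80.5%.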
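
To see how much capacity is given up when the number of groups equals the number of channels, the weight counts of the two layer types can be compared. This is a minimal sketch; `conv_weight_count` is a hypothetical helper, and the 3x3 kernel with 96 channels is only an illustrative setting.

```python
def conv_weight_count(in_ch, out_ch, kernel, groups=1):
    """Number of weights (bias excluded) in a 2D convolution with the given groups."""
    assert in_ch % groups == 0 and out_ch % groups == 0
    # Each output channel only sees in_ch // groups input channels.
    return out_ch * (in_ch // groups) * kernel * kernel


dense = conv_weight_count(96, 96, 3)                 # ordinary 3x3 convolution
depthwise = conv_weight_count(96, 96, 3, groups=96)  # one filter per channel

print(dense, depthwise)  # 82944 864
```

The depthwise layer has 96 times fewer weights here, which is why the network can afford to be widened afterwards.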

#### Inverted bottleneck

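Both the MLP block in Transformers and the Inverted Bottleneck block in MobileNetV2 use an inverted-bottleneck structure that is narrow at both ends and wide in the middle. ConvNeXt therefore adopts a similar inverted bottleneck, shown in the figure below as the change from (a) to (b).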

<div align=center><img src="./images/invert.jpg"></div>

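After this inversion, the FLOPs of the depthwise convolution layer increase, but the FLOPs of the whole network decrease thanks to savings in the downsampling residual blocks. Accuracy also improves by 0.1 percentage points, to 80.6%.

#### Large kernel size

Classic CNNs typically use 3×3 kernels. Swin Transformer introduces a local-window mechanism that resembles a convolution kernel, but with a window size of at least 7x7. ConvNeXt experiments with a range of kernel sizes. Because the inverted bottleneck enlarges the channel dimension of the depthwise convolution layer, directly enlarging the kernel would increase the parameter count significantly. One preparatory step is therefore needed first: starting from the inverted bottleneck, the depthwise convolution layer is moved up to the start of the block, as shown in the figure below from (b) to (c).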

<div align=center><img src="./images/Large.jpg"></div>

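This change temporarily lowers accuracy to 79.9%. Kernel sizes from 3x3 to 11x11 were then tried, and at 7x7 accuracy returns to 80.6%. Larger kernels bring no clear further gain, and the same holds for ResNet-200, so the kernel size is fixed at 7x7.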
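
The reason for moving the depthwise convolution up before enlarging the kernel can be checked with a small calculation. This is a sketch only: `depthwise_weight_count` is a hypothetical helper, and the channel widths (96, expanded to 384) are those of the first ConvNeXt-T stage.

```python
def depthwise_weight_count(channels, kernel):
    """Weights (bias excluded) of a depthwise convolution: one k x k filter per channel."""
    return channels * kernel * kernel


dim = 96          # block input/output width
hidden = 4 * dim  # expanded width inside the inverted bottleneck

for kernel in (3, 7):
    middle = depthwise_weight_count(hidden, kernel)  # layout (b): depthwise conv on 384 channels
    first = depthwise_weight_count(dim, kernel)      # layout (c): depthwise conv on 96 channels
    print(kernel, middle, first)
```

With the depthwise layer placed first, a 7x7 kernel costs 4704 weights instead of 18816, so the large kernel is applied where the tensor is narrow.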

#### Various layer-wise Micro designs

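This part focuses on activation functions and normalization, and consists of the following five micro-level changes.

(1) Traditional CNNs usually use ReLU as the activation function, whereas current Transformer networks mostly use GELU. ConvNeXt therefore replaces ReLU with GELU.

(2) Each Swin Transformer block in Swin-T contains only one activation function. Inspired by this, ConvNeXt uses fewer activations: each block has a single activation, placed after the second layer.

(3) Similarly, ConvNeXt uses fewer normalization layers: each block has a single normalization layer, placed after the first layer.

(4) Besides using fewer normalization layers, ConvNeXt replaces BN (Batch Normalization) with LN (Layer Normalization).

(5) Following the Patch Merging module in Swin-T, ConvNeXt uses a separate downsampling layer to downsample the features.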
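
For change (1), the difference between the two activations is easiest to see numerically. The sketch below uses the exact GELU definition, $x\,\Phi(x)$ with $\Phi$ the standard normal CDF, written in plain Python; it is an illustration and not the MindSpore `nn.GELU` implementation, which may use an approximation.

```python
import math


def relu(x):
    """ReLU: zero for negative inputs, identity otherwise."""
    return max(0.0, x)


def gelu(x):
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(f"x={x:+.1f}  relu={relu(x):.4f}  gelu={gelu(x):.4f}")
```

Unlike ReLU, GELU is smooth and lets small negative values pass through slightly attenuated instead of cutting them to zero.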

Combining all of the changes above gives the ConvNeXt block shown in the figure below. The final accuracy reaches 82.0%, higher than the 81.3% of Swin-T.

<div align=center><img src="./images/Microdesign.png"></div>

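> This tutorial trains ConvNeXt on the ImageNet dataset and visualizes the test results. To save running time, we recommend running this experiment on Ascend.

## Building the Dataset

Before starting, make sure that Python and version 1.7.0 of the MindSpore Vision suite are installed locally.

### Data Preparation

This tutorial uses the [ImageNet dataset](https://image-net.org/), which has 1,000 classes of color images; this tutorial processes them at a resolution of 224x224. The training set contains 1,281,167 images and the validation set contains 50,000 images.
The dataset used in this example is a subset selected from ImageNet. It is downloaded and extracted automatically when the first code cell is run. Make sure your dataset directory is laid out as follows: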

```Text
.ImageNet/
    ├── ILSVRC2012_devkit_t12.tar.gz
    ├── train/
    ├── infer/
    └── val
```

### Data Loading and Processing

Before processing the data, define all the parameters required by this example.

In [1]:
import ast
import argparse

def parse_args():
    """
    Parse parameters.

    Returns:
        parsed parameters.
    """
    parser = argparse.ArgumentParser(description='config')
    parser.add_argument("--seed", default=0, type=int, help="seed for initializing training. ")
    parser.add_argument("--device_target", default="Ascend", choices=["GPU", "Ascend"], type=str)
    parser.add_argument("--device_id", default=0, type=int, help="device id")
    parser.add_argument("-a", "--arch", metavar="ARCH", default="convnext_tiny", help="model architecture")
    parser.add_argument("--in_chans", default=3, type=int)
    parser.add_argument("--num_classes", default=1000, type=int)
    parser.add_argument("--drop_path_rate", default=0.1, type=float)
    parser.add_argument("--amp_level", default="O1", choices=["O0", "O1", "O2", "O3"], help="AMP Level")
    parser.add_argument("--label_smoothing", type=float, help="label smoothing to use, default 0.1", default=0.1)
    parser.add_argument("--mix_up", default=0.8, type=float, help="mix up")
    parser.add_argument("--cutmix", default=1.0, type=float, help="cutmix")
    parser.add_argument("--run_modelarts", type=ast.literal_eval, default=False, help="whether run on modelarts")
    parser.add_argument("--pretrained", dest="pretrained", default="/home/ma-user/work/check/src/train_parallel0/convnext_tiny0-300_533.ckpt", type=str, help="use pre-trained model")
    parser.add_argument("--set", help="name of dataset", type=str, default="ImageNet")
    parser.add_argument('--data_url', default="/home/ma-user/work/imagenet2012", help='location of data.')
    parser.add_argument("--image_size", default=224, help="image Size.", type=int)
    parser.add_argument('--interpolation', type=str, default="bicubic")
    parser.add_argument('--auto_augment', type=str, default="rand-m9-mstd0.5-inc1")
    parser.add_argument("-j", "--num_parallel_workers", default=1, type=int, metavar="N",
                        help="number of data loading workers (default: 1)")
    parser.add_argument("--batch_size", default=300, type=int, metavar="N",
                        help="mini-batch size (default: 300), this is the total "
                             "batch size of all Devices on the current node when "
                             "using Data Parallel or Distributed Data Parallel")
    parser.add_argument("--re_prob", default=0.0, type=float, help="re prob")
    parser.add_argument('--re_mode', type=str, default="pixel")
    parser.add_argument("--re_count", default=1, type=int, help="re count")
    parser.add_argument("--mixup_prob", default=1., type=float, help="mixup prob")
    parser.add_argument("--switch_prob", default=0.5, type=float, help="switch prob")
    parser.add_argument("--mixup_mode", default='batch', type=str, help="mixup_mode")
    parser.add_argument("--optimizer", help="Which optimizer to use", default="adamw")
    parser.add_argument("--lr_scheduler", default="cosine_lr", help="schedule for the learning rate.")
    parser.add_argument("--warmup_length", default=20, type=int, help="number of warmup epochs")
    parser.add_argument("--warmup_lr", default=0.00000007, type=float, help="warm up learning rate")
    parser.add_argument("--base_lr", default=0.004, type=float, help="base learning rate")
    parser.add_argument("--epochs", default=300, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument("--min_lr", default=0.0000006, type=float, help="min learning rate")
    parser.add_argument("--start_epoch", default=0, type=int, metavar="N",
                        help="manual epoch number (useful on restarts)")
    parser.add_argument("--accumulation_step", default=1, type=int, help="accumulation step")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument("--beta", default=[0.9, 0.999], type=lambda x: [float(a) for a in x.split(",")],
                        help="beta for optimizer")
    parser.add_argument("--eps", default=1e-8, type=float)
    parser.add_argument("--wd", "--weight_decay", default=0.05, type=float, metavar="W",
                        help="weight decay (default: 0.05)", dest="weight_decay")
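    # type=bool would treat any non-empty string (even "False") as True, so parse the literal instead
    parser.add_argument("--is_dynamic_loss_scale", default=True, type=ast.literal_eval, help="is_dynamic_loss_scale ")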
    parser.add_argument("--loss_scale", default=1024, type=int, help="loss_scale")
    parser.add_argument("--with_ema", default=False, type=ast.literal_eval, help="training with ema")
    parser.add_argument("--save_every", default=20, type=int, help="save every ___ epochs(default:20)")
    parser.add_argument('--train_url', default="./", help='location of training outputs.')
    parser.add_argument('--resize', type=int, default=224, help='Resize the image.')
    parser.add_argument('--repeat_num', type=int, default=1, help='Number of repeat.')
    parser.add_argument('--ckpt_save_dir', type=str, default="./ConvNeXt", help='Location of training outputs.')
    parser.add_argument("--ema_decay", default=0.9999, type=float, help="ema decay")

    return parser.parse_known_args()[0]

if __name__ == '__main__':
    args = parse_args()
    print("parameter is finished.")

parameter is finished.


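This tutorial uses the ImageNet dataset. MindSpore offers two interfaces for loading it: mindspore.dataset.GeneratorDataset and mindspore.dataset.ImageFolderDataset. For performance reasons, this example loads the dataset with ImageFolderDataset, which converts the data into a format that MindSpore can consume.

Once ImageNet is loaded, data processing operators such as RandomCropDecodeResize and RandomHorizontalFlip are applied to it.

The create_dataset_imagenet function in this example handles both the training set and the validation set. The code is as follows.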

In [2]:
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as transforms
import mindspore.dataset.vision.py_transforms as py_transforms
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision.utils import Inter

from src.process_datasets.augment.auto_augment import pil_interp, rand_augment_transform
from src.process_datasets.augment.mixup import Mixup
from src.process_datasets.augment.random_erasing import RandomErasing


def create_dataset_imagenet(dataset_dir, args, repeat_num=1, training=True):
    """
    Create a training or evaluation ImageNet-2012 dataset for ConvNeXt.

    Args:
        dataset_dir (str): Path to the dataset.
        args: Parsed parameters returned by parse_args.
        repeat_num (int): Number of times to repeat the dataset. Default: 1
        training (bool): Whether the dataset is used for training or evaluation.

    Returns:
        dataset
    """
    # Load the dataset with ImageFolderDataset
    if training:
        data_set = ds.ImageFolderDataset(dataset_dir, num_parallel_workers=args.num_parallel_workers,
                                         shuffle=True)
    else:
        data_set = ds.ImageFolderDataset(dataset_dir, num_parallel_workers=args.num_parallel_workers,
                                         shuffle=False)
    image_size = 224
    # define map operations
    # BICUBIC: 3
    if training:
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        aa_params = dict(
            translate_const=int(image_size * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in mean]),
        )
        # Interpolation mode
        interpolation = args.interpolation
        # Data augmentation policy
        auto_augment = args.auto_augment
        assert auto_augment.startswith('rand')
        aa_params['interpolation'] = pil_interp(interpolation)

        transform_img = [
            transforms.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3),
                                              interpolation=Inter.PILCUBIC),
            transforms.RandomHorizontalFlip(prob=0.5),
            py_transforms.ToPIL()
        ]
        transform_img += [rand_augment_transform(auto_augment, aa_params)]
        transform_img += [
            py_transforms.ToTensor(),
            py_transforms.Normalize(mean=mean, std=std),
            RandomErasing(args.re_prob, mode=args.re_mode, max_count=args.re_count)
        ]
    else:
        mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
        std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
        # test transform complete
        transform_img = [
            transforms.Decode(),
            transforms.Resize(int(256 / 224 * image_size), interpolation=Inter.PILCUBIC),
            transforms.CenterCrop(image_size),
            transforms.Normalize(mean=mean, std=std),
            transforms.HWC2CHW()
        ]

    transform_label = C.TypeCast(mstype.int32)

    data_set = data_set.map(input_columns="image", num_parallel_workers=args.num_parallel_workers,
                            operations=transform_img)
    data_set = data_set.map(input_columns="label", num_parallel_workers=args.num_parallel_workers,
                            operations=transform_label)
    if (args.mix_up > 0. or args.cutmix > 0.) and not training:
        # if use mixup and not training(False), one hot val data label
        one_hot = C.OneHot(num_classes=args.num_classes)
        data_set = data_set.map(input_columns="label", num_parallel_workers=args.num_parallel_workers,
                                operations=one_hot)
    # apply batch operations
    data_set = data_set.batch(args.batch_size, drop_remainder=True,
                              num_parallel_workers=args.num_parallel_workers)

    if (args.mix_up > 0. or args.cutmix > 0.) and training:
        mixup_fn = Mixup(
            mixup_alpha=args.mix_up, cutmix_alpha=args.cutmix, cutmix_minmax=None,
            prob=args.mixup_prob, switch_prob=args.switch_prob, mode=args.mixup_mode,
            label_smoothing=args.label_smoothing, num_classes=args.num_classes)

        data_set = data_set.map(operations=mixup_fn, input_columns=["image", "label"],
                                num_parallel_workers=args.num_parallel_workers)

    # apply dataset repeat operation
    data_set = data_set.repeat(repeat_num)

    return data_set


if __name__ == '__main__':
    # Set the parameters needed for data loading and processing
    args = parse_args()
    train_dir = "/home/ma-user/work/imagenet2012/train"
    val_dir = "/home/ma-user/work/imagenet2012/val"
    # Load and process the data
    dataset_train = create_dataset_imagenet(train_dir, args, 1, True)
    dataset_eval = create_dataset_imagenet(val_dir, args, 1, False)
    print("train and eval dataset is finished.")

train and eval dataset is finished.


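## Building the Network

The ConvNeXt backbone consists mainly of the ConvNextLayerNorm and Block modules, which are described below.

### ConvNextLayerNorm Module

This module corresponds to the third and fourth micro design changes described above. Unlike many other networks, ConvNeXt not only uses fewer normalization layers but also replaces BatchNorm with LayerNorm. In this class, when norm_axis is -1 (channels-last data), the LayerNorm operator inherited from nn.LayerNorm is applied directly. When norm_axis is 1 (channels-first data), mindspore.ops.Transpose first converts the data to NHWC, LayerNorm is applied, and the result is transposed back to NCHW.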
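
Whatever the memory layout, the normalization itself is computed over the channel values at a single spatial position. The following plain-Python sketch shows that computation for one pixel; the function `layer_norm` and the 4-channel example values are hypothetical and stand in for what the MindSpore operator does on whole tensors.

```python
import math


def layer_norm(values, gamma=None, beta=None, epsilon=1e-6):
    """Normalize one vector of channel values to zero mean and unit variance."""
    n = len(values)
    gamma = gamma or [1.0] * n  # learnable scale, identity by default
    beta = beta or [0.0] * n    # learnable shift, zero by default
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n  # biased variance
    inv_std = 1.0 / math.sqrt(var + epsilon)
    return [g * (v - mean) * inv_std + b for v, g, b in zip(values, gamma, beta)]


pixel = [1.0, 2.0, 3.0, 4.0]  # hypothetical 4-channel feature at one position
normed = layer_norm(pixel)
print(normed)
```

Because the statistics come from a single sample, the result does not depend on the batch size, in contrast to BatchNorm.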

The following code defines the ConvNextLayerNorm class, which implements the ConvNeXt LayerNorm structure.

In [3]:
from mindspore import ops, nn


class ConvNextLayerNorm(nn.LayerNorm):
    """ConvNextLayerNorm"""
    def __init__(self, normalized_shape, epsilon, norm_axis=-1):
        super(ConvNextLayerNorm, self).__init__(normalized_shape=normalized_shape, epsilon=epsilon)
        assert norm_axis in (-1, 1), "ConvNextLayerNorm's norm_axis must be 1 or -1."
        self.norm_axis = norm_axis

    def construct(self, input_x):
        if self.norm_axis == -1:
            y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
        else:
            input_x = ops.Transpose()(input_x, (0, 2, 3, 1))
            y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
            y = ops.Transpose()(y, (0, 3, 1, 2))
        return y


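### Block Module

The Block follows the MLP block in Transformers and the Inverted Bottleneck block in MobileNetV2, using an inverted-bottleneck structure that is narrow at both ends and wide in the middle. In addition, whereas classic CNNs use $3\times3$ kernels and ReLU, this module uses a large $7\times7$ kernel and GELU.

The following code defines the Block class.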

In [4]:
import numpy as np

import mindspore.nn as nn
from mindspore import Parameter, Tensor, ops
from mindspore import dtype as mstype

from src.models.layers.drop_path import DropPath2D
from src.models.layers.identity import Identity


class Block(nn.Cell):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Dense -> GELU -> Dense; Permute back
    We use (2) as we find it slightly faster in PyTorch
    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """

    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, group=dim, has_bias=True)  # depthwise conv
        self.norm = ConvNextLayerNorm((dim,), epsilon=1e-6)
        self.pwconv1 = nn.Dense(dim, 4 * dim)  # pointwise/1x1 convs, implemented with Dense layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Dense(4 * dim, dim)
        self.gamma = Parameter(Tensor(layer_scale_init_value * np.ones((dim)), dtype=mstype.float32),
                               requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath2D(drop_path, 2) if drop_path > 0. else Identity()

    def construct(self, x):
        """Block construct"""
        downsample = x
        x = self.dwconv(x)
        x = ops.Transpose()(x, (0, 2, 3, 1))
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = ops.Transpose()(x, (0, 3, 1, 2))
        x = downsample + self.drop_path(x)
        return x

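#### DropPath2D Layer

The Block above uses a DropPath2D layer. With probability drop_prob, this layer randomly drops the residual branch of a sample during training, which helps to prevent overfitting. The surviving values are rescaled by 1/keep_prob, where keep_prob = 1 - drop_prob.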
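
The per-sample behaviour can be sketched without MindSpore. In the example below, `drop_path` is a hypothetical plain-Python stand-in that works on a list of per-sample feature lists; it illustrates the idea and is not the DropPath2D class defined in this tutorial.

```python
import random


def drop_path(batch, drop_prob, training=True, rng=random):
    """Zero out whole samples with probability drop_prob and rescale the rest."""
    if not training or drop_prob == 0.0:
        return batch  # identity at inference time
    keep_prob = 1.0 - drop_prob
    out = []
    for sample in batch:
        keep = rng.random() < keep_prob  # one decision per sample, not per element
        scale = 1.0 / keep_prob if keep else 0.0
        out.append([v * scale for v in sample])
    return out


rng = random.Random(0)
batch = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
print(drop_path(batch, drop_prob=0.5, rng=rng))
```

Rescaling the kept samples by 1/keep_prob keeps the expected value of the branch output unchanged, so no correction is needed at inference time.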

The following code defines DropPath2D.

In [5]:
import numpy as np

from mindspore import Tensor, nn, ops
from mindspore import dtype as mstype


class DropPath2D(nn.Cell):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks)."""
    def __init__(self, drop_prob, ndim):
        super(DropPath2D, self).__init__()
        self.drop = nn.Dropout(keep_prob=1 - drop_prob)
        shape = (1,) + (1,) * (ndim + 1)
        self.ndim = ndim
        self.mask = Tensor(np.ones(shape), dtype=mstype.float32)

    def construct(self, x):
        if not self.training:
            return x
        mask = ops.Tile()(self.mask, (x.shape[0],) + (1,) * (self.ndim + 1))
        out = self.drop(mask)
        out = out * x
        return out

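### Building the ConvNeXt Backbone

The ConvNeXt backbone consists of four downsample_layers interleaved with four stages. The first downsampling layer (the stem) is an nn.Conv2d followed by a ConvNextLayerNorm, and the other three are a ConvNextLayerNorm followed by an nn.Conv2d. Each stage is made up of Block modules: the depths parameter sets the number of Blocks in each stage, and the dims parameter sets the number of feature channels in each stage.

One detail worth noting in the backbone code is the init_weights function. It initializes the weights of every nn.Dense and nn.Conv2d in the network with a TruncatedNormal distribution whose sigma (standard deviation) is 0.02, and sets the biases of the Dense layers to zero.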

In [6]:
import mindspore.nn as nn
from mindspore.common import initializer as weight_init

def init_weights(self):
    """init_weights"""
    for _, cell in self.cells_and_names():
        if isinstance(cell, (nn.Dense, nn.Conv2d)):
            cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(sigma=0.02),
                                                         cell.weight.shape,
                                                         cell.weight.dtype))
            if isinstance(cell, nn.Dense) and cell.bias is not None:
                cell.bias.set_data(weight_init.initializer(weight_init.Zero(),
                                                           cell.bias.shape,
                                                           cell.bias.dtype))

The following code defines the ConvNeXt backbone.

In [7]:
import numpy as np

import mindspore.nn as nn
from mindspore.common import initializer as weight_init


class ConvNeXt(nn.Cell):
    r""" ConvNeXt
        A MindSpore impl of : `A ConvNet for the 2020s`  -
          https://arxiv.org/pdf/2201.03545.pdf

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """

    def __init__(self, in_chans, num_classes, depths, dims, drop_path_rate=0.,
                 layer_scale_init_value=1e-6, head_init_scale=1.):
        super().__init__()

        self.downsample_layers = nn.CellList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.SequentialCell(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4, has_bias=True),
            ConvNextLayerNorm((dims[0],), epsilon=1e-6, norm_axis=1)
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.SequentialCell(
                ConvNextLayerNorm((dims[i],), epsilon=1e-6, norm_axis=1),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2, has_bias=True),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.CellList()  # 4 feature resolution stages, each consisting of multiple residual blocks
        dp_rates = [x for x in np.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            blocks = []
            for j in range(depths[i]):
                blocks.append(Block(dim=dims[i], drop_path=dp_rates[cur + j],
                                    layer_scale_init_value=layer_scale_init_value))
            stage = nn.SequentialCell(blocks)
            self.stages.append(stage)
            cur += depths[i]

        self.norm = ConvNextLayerNorm((dims[-1],), epsilon=1e-6)  # final norm layer
        self.head = nn.Dense(dims[-1], num_classes)

        self.init_weights()
        self.head.weight.set_data(self.head.weight * head_init_scale)
        self.head.bias.set_data(self.head.bias * head_init_scale)

    def init_weights(self):
        """init_weights"""
        for _, cell in self.cells_and_names():
            if isinstance(cell, (nn.Dense, nn.Conv2d)):
                cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(sigma=0.02),
                                                             cell.weight.shape,
                                                             cell.weight.dtype))
                if isinstance(cell, nn.Dense) and cell.bias is not None:
                    cell.bias.set_data(weight_init.initializer(weight_init.Zero(),
                                                               cell.bias.shape,
                                                               cell.bias.dtype))

    def forward_features(self, x):
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        return self.norm(x.mean([-2, -1]))  # global average pooling, (N, C, H, W) -> (N, C)

    def construct(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

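## Model Implementation

This example passes depths=[3, 3, 9, 3] and dims=[96, 192, 384, 768] to the ConvNeXt class to build the tiny variant, convnext_tiny, which is described in detail below.

### convnext_tiny Model

convnext_tiny has five convolutional structures, one average pooling layer, one LayerNorm layer, and one fully connected layer. Taking ImageNet as an example:

+ **stem**: The input image is $224\times224$ with 3 channels. It first passes through a convolution with 96 kernels of size $4\times4$ and stride 4, followed by a LayerNorm layer. The output feature map is $56\times56$ with 96 channels.
+ **res2**: The input feature map is $56\times56$ with 96 channels. It passes through 3 stacked ConvNeXt Blocks with the structure $[d7\times7, 96; 1\times1, 384; 1\times1, 96]$. The output feature map is $56\times56$ with 96 channels.
+ **res3**: The input feature map is $56\times56$ with 96 channels. It passes through a downsampling layer followed by 3 stacked ConvNeXt Blocks with the structure $[d7\times7, 192; 1\times1, 768; 1\times1, 192]$. The output feature map is $28\times28$ with 192 channels.
+ **res4**: The input feature map is $28\times28$ with 192 channels. It passes through a downsampling layer followed by 9 stacked ConvNeXt Blocks with the structure $[d7\times7, 384; 1\times1, 1536; 1\times1, 384]$. The output feature map is $14\times14$ with 384 channels.
+ **res5**: The input feature map is $14\times14$ with 384 channels. It passes through a downsampling layer followed by 3 stacked ConvNeXt Blocks with the structure $[d7\times7, 768; 1\times1, 3072; 1\times1, 768]$. The output feature map is $7\times7$ with 768 channels.
+ **average pool & LayerNorm & fc**: The input has 768 channels, and the number of output channels equals the number of classes.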
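
The feature-map sizes listed above follow directly from the strides of the stem and the downsampling layers. The sketch below traces them in plain Python; `convnext_stage_shapes` is a hypothetical helper written for this illustration and is not part of the tutorial's source tree.

```python
def convnext_stage_shapes(image_size=224, dims=(96, 192, 384, 768)):
    """Return the (channels, height, width) output shape of each of the four stages."""
    shapes = []
    size = image_size
    for i, channels in enumerate(dims):
        # Stage 0 follows the 4x4/stride-4 stem; later stages follow a 2x2/stride-2 conv.
        stride = 4 if i == 0 else 2
        size //= stride
        shapes.append((channels, size, size))
    return shapes


for name, shape in zip(("res2", "res3", "res4", "res5"), convnext_stage_shapes()):
    print(name, shape)
```

The Blocks inside a stage keep both the resolution and the channel count unchanged, so only the stem and the three downsampling layers appear in this calculation.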

The following example code builds the convnext_tiny model. Calling the convnext_tiny function returns the model:

In [8]:
def convnext_tiny(args, **kwargs):
    """convnext_tiny"""
    model = ConvNeXt(in_chans=3, num_classes=1000,
                     depths=[3, 3, 9, 3], dims=[96, 192, 384, 768],
                     drop_path_rate=0.1, **kwargs)
    return model

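## Model Training

This section instantiates the convnext_tiny network, defines the AdamWeightDecay optimizer and the SoftTargetCrossEntropy loss function, and trains the network through the model.train interface. Training uses mixed precision, which combines single-precision and half-precision data to speed up training while retaining the accuracy of single-precision training. During training the loss values are printed, and the checkpoint (ckpt) file with the highest evaluation accuracy is saved. The most important functions used during training are described below.

### Parameter Grouping

During training, the get_param_groups function splits the parameters into groups. Most classic networks simply collect all trainable parameters into a list and pass that list to the optimizer. In this ConvNeXt example, the parameters are split into two groups: one without weight decay, such as parameters whose names end in bias, and one with weight decay, namely parameters whose names end in weight.

The following code implements the parameter grouping.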

In [9]:
def get_param_groups(network):
    """ get param groups """
    decay_params = []
    no_decay_params = []
    for x in network.trainable_params():
        parameter_name = x.name
        if parameter_name.endswith(".weight"):
            # Dense or Conv's weight using weight decay
            decay_params.append(x)
        else:
            # biases do not use weight decay
            # normalization gamma/beta and other names not ending in ".weight" also land here
            no_decay_params.append(x)

    return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
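
To make the effect of the name-based rule concrete, the sketch below applies the same `endswith(".weight")` test to a handful of parameter names. The names are hypothetical examples chosen for illustration and were not read from a real ConvNeXt network; `split_by_weight_decay` is likewise a stand-in that works on strings instead of MindSpore parameters.

```python
def split_by_weight_decay(param_names):
    """Split parameter names into (decay, no_decay) using the '.weight' suffix rule."""
    decay, no_decay = [], []
    for name in param_names:
        (decay if name.endswith(".weight") else no_decay).append(name)
    return decay, no_decay


# Hypothetical parameter names, for illustration only.
names = [
    "stages.0.0.dwconv.weight", "stages.0.0.dwconv.bias",
    "stages.0.0.norm.gamma", "stages.0.0.norm.beta",
    "stages.0.0.gamma", "head.weight", "head.bias",
]
decay, no_decay = split_by_weight_decay(names)
print("decay:", decay)
print("no decay:", no_decay)
```

Under this rule, only the convolution and dense weights are decayed; biases, LayerNorm parameters, and the layer-scale gamma are left untouched.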

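### Learning Rate

The learning rate schedule has two parts. While the current epoch is below the number of warm-up epochs, the learning rate follows a linear warm-up; afterwards it follows cosine decay. The implementation is as follows.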

In [10]:
import numpy as np

def _warmup_lr(warmup_lr, base_lr, warmup_length, epoch):
    """Linear warmup"""
    return epoch / warmup_length * (base_lr - warmup_lr) + warmup_lr


def cosine_lr(args, batch_num):
    """Get cosine lr"""
    learning_rate = []

    def _lr_adjuster(epoch):
        if epoch < args.warmup_length:
            lr = _warmup_lr(args.warmup_lr, args.base_lr, args.warmup_length, epoch)
        else:
            e = epoch - args.warmup_length
            es = args.epochs - args.warmup_length
            lr = 0.5 * (1 + np.cos(np.pi * e / es)) * args.base_lr

        return lr

    for epoch in range(args.epochs):
        for batch in range(batch_num):
            learning_rate.append(_lr_adjuster(epoch + batch / batch_num))
    learning_rate = np.clip(learning_rate, args.min_lr, max(learning_rate))
    return learning_rate
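
A few sample points help to check the shape of this schedule. The sketch below re-implements the per-epoch rule in plain Python with the default values from parse_args (300 epochs, 20 warm-up epochs, warmup_lr 7e-8, base_lr 4e-3, min_lr 6e-7); `lr_at_epoch` is a hypothetical helper, and the lower clipping to min_lr mirrors the np.clip call above.

```python
import math


def lr_at_epoch(epoch, epochs=300, warmup_length=20,
                warmup_lr=7e-8, base_lr=4e-3, min_lr=6e-7):
    """Learning rate at a (possibly fractional) epoch: linear warm-up, then cosine decay."""
    if epoch < warmup_length:
        lr = epoch / warmup_length * (base_lr - warmup_lr) + warmup_lr
    else:
        progress = (epoch - warmup_length) / (epochs - warmup_length)
        lr = 0.5 * (1.0 + math.cos(math.pi * progress)) * base_lr
    return max(lr, min_lr)  # never go below min_lr


for epoch in (0, 10, 20, 160, 299):
    print(epoch, lr_at_epoch(epoch))
```

Note that with these defaults warmup_lr is smaller than min_lr, so the very first steps are clipped up to min_lr.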

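### Training

Before training, set up the loss function, optimizer, and callbacks according to the parameters given in the paper. For demonstration purposes, this example runs 5 epochs with sink_size set to 10, so each epoch processes 10 steps of data. The code is as follows.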

In [11]:
import os

from mindspore import Model
from mindspore import context
from mindspore import nn
from mindspore.common import set_seed
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor

from src.models.loss import get_criterion, NetWithLoss, get_train_one_step
from src.process_datasets.imagenet import ImageNet
from src.models.convnext import convnext_tiny
from src.utils.cell import cast_amp
from src.utils.optimizer import get_optimizer


def train(args):
    """Train ConvNext model."""
    # Set the random seed, target device, and execution mode
    set_seed(args.seed)
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    # Rank of the current device (single-device training)
    rank = 0
    # Build the model and cast it to fp16 for mixed precision
    net = convnext_tiny(args)
    cast_amp(net, args)
    # Get the loss function
    criterion = get_criterion(args)
    # Connect the loss to the network
    net_with_loss = NetWithLoss(net, criterion)
    # Get the number of batches in the training dataset
    data = ImageNet(args, True)
    batch_num = data.train_dataset.get_dataset_size()
    # Get the AdamW optimizer
    optimizer = get_optimizer(args, net, batch_num)
    # Wrap the network so that it can be trained
    net_with_loss = get_train_one_step(args, net_with_loss, optimizer)
    eval_network = nn.WithEvalCell(net, criterion, args.amp_level in ["O2", "O3", "auto"])
    eval_indexes = [0, 1, 2]
    model = Model(net_with_loss, metrics={"acc", "loss"},
                  eval_network=eval_network,
                  eval_indexes=eval_indexes)
    # Configure how many checkpoints to keep and set up the time monitor
    config_ck = CheckpointConfig(save_checkpoint_steps=data.train_dataset.get_dataset_size(),
                                 keep_checkpoint_max=args.save_every)
    time_cb = TimeMonitor(data_size=data.train_dataset.get_dataset_size())
    # Set the directory where checkpoints are saved
    ckpt_save_dir = "./ckpt_" + str(rank)
    if args.run_modelarts:
        ckpt_save_dir = "/cache/ckpt_" + str(rank)
    ckpoint_cb = ModelCheckpoint(prefix=args.arch + str(rank), directory=ckpt_save_dir,
                                 config=config_ck)
    # Set up the loss monitor
    loss_cb = LossMonitor()
    # Start training
    print("begin train")
    model.train(5, data.train_dataset,
                callbacks=[time_cb, ckpoint_cb, loss_cb],
                dataset_sink_mode=True,
                sink_size=10)
    print("train success")


if __name__ == '__main__':
    train(args)

=> using amp_level O1
=> When using train_wrapper, using optimizer adamw
=> Get LR from epoch: 0
=> Start step: 0
=> Total step: 1281000
=> Accumulation step:1
learning_rate 0.004
=> Using DynamicLossScaleUpdateCell
begin train
epoch: 1 step: 10, loss is 7.037727355957031
epoch time: 137322.846 ms, per step time: 13732.285 ms
epoch: 2 step: 10, loss is 7.037711143493652
epoch time: 29620.601 ms, per step time: 2962.060 ms
epoch: 3 step: 10, loss is 7.011240005493164
epoch time: 29268.201 ms, per step time: 2926.820 ms
epoch: 4 step: 10, loss is 7.029295921325684
epoch time: 29534.745 ms, per step time: 2953.474 ms
epoch: 5 step: 10, loss is 6.974303245544434
epoch time: 28968.114 ms, per step time: 2896.811 ms
train success


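## Model Evaluation

The model with the highest validation accuracy from above is evaluated on the ImageNet validation set. This step mainly uses the Model, ImageNet, convnext_tiny, load_checkpoint, and load_param_into_net interfaces. To simplify running in the cloud and debugging in a terminal, load_checkpoint, load_param_into_net, and related calls are wrapped in the pretrained function, shown below.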

In [12]:
import os

from src.models.convnext import convnext_tiny
from src.utils.cell import cast_amp

from mindspore.train.serialization import load_checkpoint, load_param_into_net


def pretrained(args, model):
    """Load pretrained weights if args.pretrained is given"""
    # When running in the cloud (ModelArts), sync the ckpt file to local storage first
    if args.run_modelarts:
        print('Syncing data.')
        local_data_path = '/cache/weight'
        name = args.pretrained.split('/')[-1]
        path = "/".join(args.pretrained.split("/")[:-1])
        # sync_data is not defined in this tutorial; it is assumed to be provided by the project's ModelArts utilities
        sync_data(path, local_data_path, threads=128)
        args.pretrained = os.path.join(local_data_path, name)
        print("=> loading pretrained weights from '{}'".format(args.pretrained))
        param_dict = load_checkpoint(args.pretrained)
        # Drop classifier-head parameters whose shape does not match num_classes
        for key, value in param_dict.copy().items():
            if 'head' in key:
                if value.shape[0] != args.num_classes:
                    print(f'==> removing {key} with shape {value.shape}')
                    param_dict.pop(key)
        # Load the trained weights into the network
        load_param_into_net(model, param_dict)
    # When running in a terminal, load the ckpt file from the local path
    elif os.path.isfile(args.pretrained):
        print("=> loading pretrained weights from '{}'".format(args.pretrained))
        param_dict = load_checkpoint(args.pretrained)
        for key, value in param_dict.copy().items():
            if 'head' in key:
                if value.shape[0] != args.num_classes:
                    print(f'==> removing {key} with shape {value.shape}')
                    param_dict.pop(key)
        load_param_into_net(model, param_dict)
    else:
        print("=> no pretrained weights found at '{}'".format(args.pretrained))

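The evaluation workflow is roughly as follows: define the network with convnext_tiny, load the ImageNet dataset, load the parameters from the [ckpt file](https://download.mindspore.cn/vision/convnext/convnext_tiny0-300_533.ckpt) into the network, set the loss function and evaluation metrics, and finally compile and evaluate the model. This tutorial uses Top_1_Accuracy and Top_5_Accuracy as the evaluation metrics.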
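
As a reminder of what the two metrics measure, the sketch below computes top-k accuracy on a toy batch in plain Python. The function `top_k_accuracy` and the example scores are hypothetical; the tutorial itself relies on the MindSpore metric classes.

```python
def top_k_accuracy(logits, labels, k=1):
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    correct = 0
    for scores, label in zip(logits, labels):
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        correct += label in ranked[:k]
    return correct / len(labels)


# Toy batch: 3 samples, 3 classes (made-up scores).
logits = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
labels = [1, 1, 0]
print(top_k_accuracy(logits, labels, k=1))
print(top_k_accuracy(logits, labels, k=2))
```

Top-5 accuracy on ImageNet is the same computation with k=5 over the 1,000 class scores, so it is always at least as high as Top-1.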

The evaluation code is as follows.

In [17]:
from mindspore import Model
from mindspore import context
from mindspore import nn
from mindspore.common import set_seed
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.models.loss import get_criterion, NetWithLoss, get_train_one_step
from src.models.convnext import convnext_tiny
from src.utils.cell import cast_amp
from src.utils.optimizer import get_optimizer
from src.process_datasets.imagenet import ImageNet


def pretrained(args, model):
    """Load pretrained weights from a fixed checkpoint path"""
    print("=> loading pretrained weights from '{}'".format("/home/ma-user/work/check/src/train_parallel0/ckpt_0/convnext_tiny0-300_533.ckpt"))
    param_dict = load_checkpoint("/home/ma-user/work/check/src/train_parallel0/ckpt_0/convnext_tiny0-300_533.ckpt")
    for key, value in param_dict.copy().items():
        if 'head' in key:
            if value.shape[0] != args.num_classes:
                print(f'==> removing {key} with shape {value.shape}')
                param_dict.pop(key)
    load_param_into_net(model, param_dict)


def convnext_eval(args):
    """Evaluate the ConvNeXt model."""
    set_seed(args.seed)
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    # Build the model and cast it to fp16
    net = convnext_tiny(args)
    cast_amp(net, args)
    # Build the loss function
    criterion = get_criterion(args)
    # Connect the loss function to the network
    net_with_loss = NetWithLoss(net, criterion)
    # Load the trained weights into the network
    pretrained(args, net)
    # Load the validation dataset
    data = ImageNet(args, training=False)
    batch_num = data.val_dataset.get_dataset_size()
    # Build the optimizer
    optimizer = get_optimizer(args, net, batch_num)
    # Wrap the network and optimizer into a single-step training cell
    net_with_loss = get_train_one_step(args, net_with_loss, optimizer)
    eval_network = nn.WithEvalCell(net, criterion, args.amp_level in ["O2", "O3", "auto"])
    eval_indexes = [0, 1, 2]
    # Register the evaluation metrics: loss, Top1-Acc, and Top5-Acc
    eval_metrics = {'Loss': nn.Loss(),
                    'Top1-Acc': nn.Top1CategoricalAccuracy(),
                    'Top5-Acc': nn.Top5CategoricalAccuracy()}
    model = Model(net_with_loss, metrics=eval_metrics,
                  eval_network=eval_network,
                  eval_indexes=eval_indexes)
    # Run the evaluation
    print("=> begin eval")
    results = model.eval(data.val_dataset)
    print(f"=> eval results:{results}")
    print("=> eval success")


if __name__ == '__main__':
    convnext_eval(args)

=> using amp_level O1
=> loading pretrained weights from '/home/ma-user/work/check/src/train_parallel0/ckpt_0/convnext_tiny0-300_533.ckpt'
=> When using train_wrapper, using optimizer adamw
=> Get LR from epoch: 0
=> Start step: 0
=> Total step: 49800
=> Accumulation step:1
learning_rate 0.004
=> Using DynamicLossScaleUpdateCell
=> begin eval
=> eval results:{'Loss': 0.7911256961075657, 'Top1-Acc': 0.8230522088353414, 'Top5-Acc': 0.9593172690763052}
=> eval success


```text
{'Top1-Acc': 0.8230522088353414, 'Top5-Acc': 0.9593172690763052}
```

## Model Inference

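Inference is straightforward. First, the images to be classified are read through a dataset interface (the code below uses `ds.ImageFolderDataset` on the `infer` subfolder of the ImageNet directory), and then transforms such as Decode, Resize, and CenterCrop are applied to them.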
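
To make the geometry of these transforms concrete, the sketch below works through a resize to 256 on the shorter side, a 224x224 center crop, and the per-channel normalization, using the same mean and std values as the inference code. The helper names and the 500x375 input size are hypothetical, and the exact rounding used by the MindSpore operators may differ from `round` here.

```python
# Same per-channel statistics as in the inference code (0-255 pixel range).
MEAN = [0.485 * 255, 0.456 * 255, 0.406 * 255]
STD = [0.229 * 255, 0.224 * 255, 0.225 * 255]


def resize_shorter_side(width, height, target):
    """Size after scaling the shorter side to `target`, keeping aspect ratio."""
    scale = target / min(width, height)
    return round(width * scale), round(height * scale)


def center_crop_box(width, height, crop):
    """(left, top, right, bottom) of a centered crop x crop window."""
    left = (width - crop) // 2
    top = (height - crop) // 2
    return left, top, left + crop, top + crop


def normalize_pixel(rgb):
    """Subtract the channel mean and divide by the channel std."""
    return [(value - m) / s for value, m, s in zip(rgb, MEAN, STD)]


# A hypothetical 500x375 input image.
resized = resize_shorter_side(500, 375, 256)
box = center_crop_box(resized[0], resized[1], 224)
print(resized, box)
print(normalize_pixel([255, 255, 255]))
```

The crop discards a border of equal size on opposite sides, so the network always sees a 224x224 window centered on the resized image.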

After the images are preprocessed, the trained parameters are loaded into the network and `Model.predict` is called on each image. The code is shown below.
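
The step from the network output to a readable label is an argmax over the class scores followed by a dictionary lookup. The sketch below shows that step in plain Python. It assumes the network returns unnormalized scores (logits), which the tutorial does not state; the three-entry mapping and the helper names are likewise only illustrative.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]  # shift for numerical stability
    total = sum(exps)
    return [e / total for e in exps]


def predict_label(logits, mapping):
    """Return (class index, class name, probability) of the best-scoring class."""
    index = max(range(len(logits)), key=lambda i: logits[i])
    return index, mapping[index], softmax(logits)[index]


# Toy mapping and scores, for illustration only.
mapping = {0: 'tench', 1: 'goldfish', 2: 'great white shark'}
logits = [1.2, 3.4, 0.3]

index, name, confidence = predict_label(logits, mapping)
print({index: name}, round(confidence, 3))
```

The argmax does not need the softmax, since softmax preserves the ordering of the scores; it is only computed here to attach a confidence to the prediction.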

In [19]:
import os
from enum import Enum
import pathlib
from typing import Dict, Optional
from PIL import Image
from scipy import io
import cv2
import numpy as np

import mindspore as ms
import mindspore.common.dtype as mstype
from mindspore.train import Model
import mindspore.dataset as ds
from mindspore.common import set_seed
from mindspore import context
import mindspore.dataset.vision.c_transforms as transforms
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision.utils import Inter

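from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.utils.cell import cast_amp
from src.models.convnext import convnext_tiny
from src.models.loss import get_criterion, NetWithLoss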


def pretrained(args, model):
    """Load pretrained weights from a fixed checkpoint path"""
    print("=> loading pretrained weights from '{}'".format("/home/ma-user/work/check/src/train_parallel0/ckpt_0/convnext_tiny0-300_533.ckpt"))
    param_dict = load_checkpoint("/home/ma-user/work/check/src/train_parallel0/ckpt_0/convnext_tiny0-300_533.ckpt")
    for key, value in param_dict.copy().items():
        if 'head' in key:
            if value.shape[0] != args.num_classes:
                print(f'==> removing {key} with shape {value.shape}')
                param_dict.pop(key)
    load_param_into_net(model, param_dict)


def infer(args):
    """Run inference with the ConvNeXt model."""
    set_seed(args.seed)
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    # get model
    net = convnext_tiny(args)
    cast_amp(net, args)
    criterion = get_criterion(args)
    NetWithLoss(net, criterion)
    pretrained(args, net)
    # Read data for inference
    dataset_infer = ds.ImageFolderDataset(os.path.join(args.data_url, "infer"), shuffle=True)
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    # test transform complete
    transform_img = [
        transforms.Decode(),
        transforms.Resize(int(256 / 224 * 224), interpolation=Inter.PILCUBIC),
        transforms.CenterCrop(224),
        transforms.Normalize(mean=mean, std=std),
        transforms.HWC2CHW()
        ]
    transform_label = C.TypeCast(mstype.int32)
    dataset_infer = dataset_infer.map(input_columns="image", num_parallel_workers=args.num_parallel_workers,
                                      operations=transform_img)
    dataset_infer = dataset_infer.map(input_columns="label", num_parallel_workers=args.num_parallel_workers,
                                      operations=transform_label)
    one_hot = C.OneHot(num_classes=args.num_classes)
    dataset_infer = dataset_infer.map(input_columns="label", num_parallel_workers=args.num_parallel_workers,
                                      operations=one_hot)
    # apply batch operations
    dataset_infer = dataset_infer.batch(1, drop_remainder=True,
                                        num_parallel_workers=args.num_parallel_workers)
    model = Model(net)
    for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
        print('i is :', i)
        image = image["image"]
        image = ms.Tensor(image)
        prob = model.predict(image)
        print("predict is finished.")
        label = np.argmax(prob.asnumpy(), axis=1)
        mapping = index2label(args)
        output = {int(label): mapping[int(label)]}
        print(output)
        show_result(img="/home/ma-user/work/imagenet2012/infer/n01440764/ILSVRC2012_test_00000293.jpg",
                    result=output,
                    out_file="/home/ma-user/work/imagenet2012/infer/ILSVRC2012_test_00000.JPEG")


class Color(Enum):
    """Define colors as an enum (values are in BGR order)."""
    red = (0, 0, 255)
    green = (0, 255, 0)
    blue = (255, 0, 0)
    cyan = (255, 255, 0)
    yellow = (0, 255, 255)
    magenta = (255, 0, 255)
    white = (255, 255, 255)
    black = (0, 0, 0)


def check_file_exist(file_name: str):
    """check_file_exist."""
    if not os.path.isfile(file_name):
        raise FileNotFoundError(f"File `{file_name}` does not exist.")


def color_val(color):
    """color_val."""
    if isinstance(color, str):
        return Color[color].value
    if isinstance(color, Color):
        return color.value
    if isinstance(color, tuple):
        assert len(color) == 3
        for channel in color:
            assert 0 <= channel <= 255
        return color
    if isinstance(color, int):
        assert 0 <= color <= 255
        return color, color, color
    if isinstance(color, np.ndarray):
        assert color.ndim == 1 and color.size == 3
        assert np.all((color >= 0) & (color <= 255))
        color = color.astype(np.uint8)
        return tuple(color)
    raise TypeError(f'Invalid type for color: {type(color)}')


def imread(image, mode=None):
    """imread."""
    if isinstance(image, pathlib.Path):
        image = str(image)

    if isinstance(image, np.ndarray):
        pass
    elif isinstance(image, str):
        check_file_exist(image)
        image = Image.open(image)
        if mode:
            image = np.array(image.convert(mode))
    else:
        raise TypeError("Image must be a `ndarray`, `str` or Path object.")

    return image


def imwrite(image, image_path, auto_mkdir=True):
    """imwrite."""
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(image_path))
        if dir_name != '':
            dir_name = os.path.expanduser(dir_name)
            # Permission bits must be octal: decimal 777 is not rwxrwxrwx
            os.makedirs(dir_name, mode=0o777, exist_ok=True)

    image = Image.fromarray(image)
    image.save(image_path)


def imshow(img, win_name='', wait_time=0):
    """imshow"""
    cv2.imshow(win_name, imread(img))
    if wait_time == 0:  # avoid hanging if the window has been closed
        while True:
            ret = cv2.waitKey(1)

            closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1
            # if user closed window or if some key pressed
            if closed or ret != -1:
                break
    else:
        ret = cv2.waitKey(wait_time)


def show_result(img: str,
                result: Dict[int, float],
                text_color: str = 'green',
                font_scale: float = 0.5,
                row_width: int = 20,
                show: bool = False,
                win_name: str = '',
                wait_time: int = 0,
                out_file: Optional[str] = None) -> None:
    """Mark the prediction results on the picture."""
    img = imread(img, mode="RGB")
    img = img.copy()
    x, y = 0, row_width
    text_color = color_val(text_color)
    for k, v in result.items():
        if isinstance(v, float):
            v = f'{v:.2f}'
        label_text = f'{k}: {v}'
        cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
                    font_scale, text_color)
        y += row_width
    if out_file:
        show = False
        imwrite(img, out_file)

    if show:
        imshow(img, win_name, wait_time)


def index2label(args):
    """Dictionary output for image numbers and categories of the ImageNet dataset."""
    metafile = os.path.join(args.data_url, "ILSVRC2012_devkit_t12/data/meta.mat")
    meta = io.loadmat(metafile, squeeze_me=True)['synsets']

    nums_children = list(zip(*meta))[4]
    meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]

    _, wnids, classes = list(zip(*meta))[:3]
    clssname = [tuple(clss.split(', ')) for clss in classes]
    wnid2class = {wnid: clss for wnid, clss in zip(wnids, clssname)}
    wind2class_name = sorted(wnid2class.items(), key=lambda x: x[0])

    mapping = {}
    for index, (_, class_name) in enumerate(wind2class_name):
        mapping[index] = class_name[0]
    return mapping


if __name__ == '__main__':
    infer(args)

=> using amp_level O1
=> loading pretrained weights from '/home/ma-user/work/check/src/train_parallel0/ckpt_0/convnext_tiny0-300_533.ckpt'
i is : 0
predict is finished.
{394: 'sturgeon'}
i is : 1
predict is finished.
{394: 'sturgeon'}
i is : 2
predict is finished.
{394: 'sturgeon'}
i is : 3
predict is finished.
{394: 'sturgeon'}
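
The dictionaries printed above pair a class index with a class name, and that pairing comes from `index2label`: it keeps only the leaf synsets, sorts them by WordNet ID, and numbers them consecutively. The sketch below reproduces that logic on a tiny hand-written list instead of `meta.mat`. The function name `build_index_to_label` and the child count of the internal node are assumptions made for illustration.

```python
def build_index_to_label(synsets):
    """Map class index -> first class name, ordered by WordNet ID.

    synsets is an iterable of (wnid, comma-separated names, num_children);
    only leaf synsets (num_children == 0) become classes.
    """
    leaves = [(wnid, names) for wnid, names, num_children in synsets
              if num_children == 0]
    mapping = {}
    for index, (_, names) in enumerate(sorted(leaves)):
        # Keep only the first of the comma-separated synonyms.
        mapping[index] = names.split(', ')[0]
    return mapping


# A tiny illustrative subset, deliberately out of order.
synsets = [
    ('n01443537', 'goldfish, Carassius auratus', 0),
    ('n02640242', 'sturgeon', 0),
    ('n01440764', 'tench, Tinca tinca', 0),
    ('n02512053', 'fish', 3),  # internal node, skipped; child count is made up
]

mapping = build_index_to_label(synsets)
print(mapping)
```

Sorting by WordNet ID matters because the predicted index is only meaningful if it follows the same class order that was used for the labels during training.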


The image annotated with the inference result is shown below:
<div align=center><img src="./images/convnext_infer.jpg"></div>

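## Summary

This tutorial walked through training, evaluating, and running inference with a ConvNeXt model on the ImageNet dataset, and gave a brief introduction to the structure and principles of the model.

> To understand how the ConvNeXt model works in more detail, reading the source code is recommended; see the following link:
> https://gitee.com/mindspore/course/tree/master/application_example/convnext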

## References

[1] Liu Z, Mao H, Wu C Y, et al. A ConvNet for the 2020s[J]. arXiv preprint arXiv:2201.03545, 2022.