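# Human Keypoint Detection with MSPN

Human keypoint detection, also known as human pose estimation, is a fundamental task in computer vision. Existing pose estimation models can be roughly divided by the number of processing stages into single-stage and multi-stage models. They can also be divided by the order in which keypoints are handled: Top-Down methods first locate each person and then detect keypoints inside each detection box, while Bottom-Up methods first detect all keypoints and then associate them with person instances. As a prerequisite for human action recognition, behavior analysis, human-computer interaction, and similar tasks, pose estimation has been widely adopted in industry applications and regularly appears as a track in major computer vision competitions. The MSPN model introduced here won first place in the keypoints track of the MS COCO Competition 2018 and achieves strong detection results on the COCO dataset.

## Model Overview

Multi-Stage Pose Network (MSPN) is a multi-stage, Top-Down human pose estimation algorithm. MSPN builds its backbone from Residual Blocks, uses Cross Stage Feature Aggregation to fuse information across the networks of different stages, and adds Coarse to Fine Supervision together with carefully chosen hyperparameters, which improves the detection accuracy of multi-stage models.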

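## Data Preprocessing

Before running this case, make sure the runtime environment has a suitable version of Python and that the MindSpore Vision suite is installed.

### Downloading the Data

This case uses the MS COCO (2014) dataset as both the training set and the validation set. Download the image files from the [MS COCO website](https://cocodataset.org/#download); the MS COCO annotation files are stored in the ./datasets directory.

Extract the downloaded dataset and annotation files. The image files can be stored at any path on the device hosting the runtime environment; here they are assumed to be under "/data0/coco/coco2014". Place the annotation files in the same directory as the src folder, that is, under "./annotation". The coco2014 folder contains the training and validation image folders of MS COCO. The annotation folder contains two subfolders, det_json and gt_json, which hold the person detection box annotation files and the MS COCO training and validation annotation files, respectively.

A sample image from the dataset is shown below: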

![MS COCO](./images/90891.jpg)

### Loading the Data

The dataset is loaded through the dataset loading interface, normalized, and output as Tensors.

In [None]:
from mindspore.dataset.vision.c_transforms import Normalize
from mindspore.dataset.transforms.py_transforms import Compose
import mindspore.dataset.vision.py_transforms as py_trans
import mindspore.dataset as ds

from src.process_datasets.coco import COCODataset


# Build the data reader
normalize = Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
transform = Compose([normalize, py_trans.ToTensor()])
dataset = COCODataset(data_dir="/data0/coco/coco2014",
                      gt_file_path="./annotation/gt_json/train_val_minus_minival_2014.json",
                      keypoint_num=17,
                      flip_pairs=[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]],
                      upper_body_ids=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
                      lower_body_ids=[11, 12, 13, 14, 15, 16],
                      input_shape=[256, 192],
                      output_shape=[64, 48],
                      stage='train'
                      )
dataloader = ds.GeneratorDataset(source=dataset,
                                 column_names=["img", "valid", "labels"],
                                 shuffle=True,
                                 num_parallel_workers=4)
dataloader = dataloader.map(operations=transform, input_columns=["img"]).batch(32)
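
The `flip_pairs` argument above lists the indices of left/right keypoint pairs. A minimal sketch of how such pairs are typically used in horizontal-flip augmentation is given below. The helper `flip_keypoints` is hypothetical and written only for illustration; it is not taken from `COCODataset`, whose internals are not shown here.

```python
# Hypothetical helper illustrating a common use of flip_pairs during
# horizontal-flip augmentation; this is not the COCODataset implementation.
FLIP_PAIRS = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]


def flip_keypoints(keypoints, image_width, flip_pairs):
    """Mirror (x, y) keypoints horizontally and swap left/right joints."""
    # Mirror the x coordinate around the vertical centre line of the image.
    flipped = [(image_width - 1 - x, y) for x, y in keypoints]
    # After mirroring, a joint that was on the left is now on the right,
    # so each left/right pair swaps its position in the list.
    for left, right in flip_pairs:
        flipped[left], flipped[right] = flipped[right], flipped[left]
    return flipped


# 17 dummy keypoints placed along a diagonal.
points = [(float(i), float(i)) for i in range(17)]
mirrored = flip_keypoints(points, image_width=192, flip_pairs=FLIP_PAIRS)
print(mirrored[0])  # index 0 has no partner: only x is mirrored
print(mirrored[1])  # index 1 now holds the mirrored former index 2
```

Index 0 (the nose in the COCO keypoint order) has no partner, so it is only mirrored, while every other index trades places with its counterpart.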

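## Model Construction

The overall structure of MSPN is shown below:

![MSPN](./images/mspn.png)

### Basic Convolutional Blocks

This part builds the basic convolutional blocks of the model: the Conv-BN-Activation block, the Residual Block, and the pre-convolution module applied to the input image Tensor.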

In [None]:
from typing import Optional

import mindspore.nn as nn


class ConvBlock(nn.Cell):
    """ Basic Convolutional Block for MSPN

    Args:
        in_channels (int): Input Tensor Channels
        out_channels (int): Output Tensor Channels
        kernel_size (int): Convolutional Kernel Size
        stride (int): Convolutional Stride
        padding (int): Convolutional Padding
        use_bn (bool): Whether to Use Batch Normalization. Default: True.
        use_relu (bool): Whether to Use ReLU. Default: True.

    Inputs:
        Tensor

    Returns:
        Tensor, output tensor.

    Examples:
        >>> conv_block = ConvBlock(3, 16, kernel_size=1, stride=1, padding=0, use_bn=True, use_relu=True)
    """
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int,
                 padding: int,
                 use_bn: bool = True,
                 use_relu: bool = True
                 ) -> None:
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, pad_mode='pad',
                              padding=padding, has_bias=True)
        self.use_bn = use_bn
        self.use_relu = use_relu
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def construct(self, x):
        """Construct Func"""
        x = self.conv(x)
        if self.use_bn:
            x = self.bn(x)
        if self.use_relu:
            x = self.relu(x)

        return x


class Bottleneck(nn.Cell):
    """ Residual Block for MSPN

    Args:
        in_channels (int): Input Tensor Channels
        channels (int): Bottleneck Channels; the output Tensor has channels * expansion channels
        stride (int): Convolutional Stride. Default: 1.
        downsample (nn.Cell, optional): Downsample Module Implemented Under the Wrap of nn.Cell. Default: None.

    Inputs:
        Tensor

    Returns:
        Tensor, output tensor.

    Examples:
        >>> bottleneck = Bottleneck(3, 16, stride=1, downsample=None)
    """
    expansion = 4

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 stride: int = 1,
                 downsample: Optional[nn.Cell] = None
                 ) -> None:
        super(Bottleneck, self).__init__()
        self.conv_bn_relu1 = ConvBlock(in_channels, channels, kernel_size=1, stride=1, padding=0, use_bn=True,
                                       use_relu=True)
        self.conv_bn_relu2 = ConvBlock(channels, channels, kernel_size=3, stride=stride, padding=1, use_bn=True,
                                       use_relu=True)
        self.conv_bn_relu3 = ConvBlock(channels, channels * self.expansion, kernel_size=1, stride=1, padding=0,
                                       use_bn=True, use_relu=False)
        self.relu = nn.ReLU()
        self.downsample = downsample

    def construct(self, x):
        """Construct Func"""
        conv_block_out = self.conv_bn_relu1(x)
        conv_block_out = self.conv_bn_relu2(conv_block_out)
        conv_block_out = self.conv_bn_relu3(conv_block_out)

        if self.downsample:
            x = self.downsample(x)

        conv_block_out += x
        conv_block_out = self.relu(conv_block_out)

        return conv_block_out


class ResNetTop(nn.Cell):
    """ First Module of MSPN

    Inputs:
        Tensor

    Returns:
        Tensor, output tensor.

    Examples:
        >>> res_top = ResNetTop()
    """
    def __init__(self) -> None:
        super(ResNetTop, self).__init__()
        self.conv = ConvBlock(3, 64, kernel_size=7, stride=2, padding=3, use_bn=True, use_relu=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

    def construct(self, x):
        """Construct Func"""
        out = self.conv(x)
        out = self.maxpool(out)

        return out
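
To see how `ResNetTop` relates the `input_shape` of 256x192 to the `output_shape` of 64x48 used throughout this case, the spatial arithmetic can be sketched in plain Python. The helper names are hypothetical, and the formulas are the usual ones for a convolution with explicit padding and for pooling with `pad_mode='same'`.

```python
import math


def conv_out_size(size, kernel_size, stride, padding):
    """Output length of a convolution with explicit padding (pad_mode='pad')."""
    return (size + 2 * padding - kernel_size) // stride + 1


def same_pool_out_size(size, stride):
    """Output length of a pooling layer with pad_mode='same'."""
    return math.ceil(size / stride)


def resnet_top_out_shape(height, width):
    """Spatial shape after the 7x7 stride-2 conv and the 3x3 stride-2 max pool."""
    h = same_pool_out_size(conv_out_size(height, 7, 2, 3), 2)
    w = same_pool_out_size(conv_out_size(width, 7, 2, 3), 2)
    return h, w


print(resnet_top_out_shape(256, 192))  # (64, 48)
```

The two stride-2 operations reduce each side by a factor of 4, which is why the heatmaps are a quarter of the input resolution.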

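### Downsample Module

This part implements the downsample module. It convolves the input with the basic convolutional blocks and residual blocks built above, reducing the Tensor size while increasing the number of channels to extract high-level semantic information. It also outputs the skip-connection features that the upsample module aggregates, as well as the features used for Cross Stage Feature Aggregation.

The detailed structure of Cross Stage Feature Aggregation is shown below: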

![Cross Stage Feature Aggregation](./images/csfa.png)

In [None]:
import mindspore
from mindspore.common.initializer import initializer, HeNormal


class ResNetDownsampleModule(nn.Cell):
    """ Residual Downsample Module for MSPN

    Args:
        block (nn.Cell): Residual Downsample Module Implemented Under the Wrap of nn.Cell
        layer_num_list (list): Num of Stacking Residual Blocks
        use_skip (bool): Whether to Use Cross Stage Feature Aggregation. Default: False.
        zero_init_bn (bool): Whether to Use zero initialization On Weight of Batch Normalization. Default: False.

    Inputs:
        Tensor, skip_tensor of previous MSPN Stage

    Returns:
        Tensor, output tensor.

    Examples:
        >>> downsample = ResNetDownsampleModule(Bottleneck, layer_num_list=[3, 4, 6, 3])
    """
    def __init__(self,
                 block: nn.Cell,
                 layer_num_list: list,
                 use_skip: bool = False,
                 zero_init_bn: bool = False
                 ) -> None:
        super(ResNetDownsampleModule, self).__init__()
        self.use_skip = use_skip
        self.in_channels = 64
        self.layer1 = self._make_layer(block, 64, layer_num_list[0])
        self.layer2 = self._make_layer(block, 128, layer_num_list[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layer_num_list[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layer_num_list[3], stride=2)

        for _, m in self.cells_and_names():
            if isinstance(m, nn.Conv2d):
                m.weight = initializer(HeNormal(mode='fan_out', nonlinearity='relu'), m.weight.shape, mindspore.float32)
            elif isinstance(m, nn.BatchNorm2d):
                m.gamma = initializer(1, m.gamma.shape, mindspore.float32)
                m.beta = initializer(0, m.beta.shape, mindspore.float32)

        if zero_init_bn:
            for _, m in self.cells_and_names():
                if isinstance(m, Bottleneck):
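                    # Bottleneck has no `bn3` attribute; its last BatchNorm lives in conv_bn_relu3.bn
                    bn3 = m.conv_bn_relu3.bn
                    bn3.gamma = initializer(0, bn3.gamma.shape, mindspore.float32)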

    def _make_layer(self, block, channels, num_blocks, stride=1):
        """Stacking Residual Blocks"""
        downsample = None
        if stride != 1 or self.in_channels != channels * block.expansion:
            downsample = ConvBlock(self.in_channels, channels * block.expansion, kernel_size=1, stride=stride,
                                   padding=0, use_bn=True, use_relu=False)

        layers = list()
        layers.append(block(self.in_channels, channels, stride, downsample))
        self.in_channels = channels * block.expansion
        for _ in range(1, num_blocks):
            layers.append(block(self.in_channels, channels))

        return nn.SequentialCell(*layers)

    def construct(self, x, skip_tensor_1, skip_tensor_2):
        """Construct Func"""
        x1 = self.layer1(x)
        if self.use_skip:
            x1 = x1 + skip_tensor_1[0] + skip_tensor_2[0]
        x2 = self.layer2(x1)
        if self.use_skip:
            x2 = x2 + skip_tensor_1[1] + skip_tensor_2[1]
        x3 = self.layer3(x2)
        if self.use_skip:
            x3 = x3 + skip_tensor_1[2] + skip_tensor_2[2]
        x4 = self.layer4(x3)
        if self.use_skip:
            x4 = x4 + skip_tensor_1[3] + skip_tensor_2[3]

        return x4, x3, x2, x1
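
The shapes of `x1` to `x4` determine both the `in_channels` list of the upsample module and the shapes that the cross-stage skip tensors must have to be added in `construct`. A small sketch of that bookkeeping follows, assuming the 64x48 feature map produced by `ResNetTop` for a 256x192 input; `downsample_shapes` is a hypothetical helper.

```python
def downsample_shapes(in_shape, base_channels=(64, 128, 256, 512), expansion=4):
    """Return (channels, height, width) after layer1..layer4.

    layer1 keeps the resolution (stride 1); layer2 to layer4 halve it
    (stride 2), and every layer outputs channels * expansion channels.
    """
    h, w = in_shape
    shapes = []
    for index, channels in enumerate(base_channels):
        if index > 0:
            # 3x3 conv, stride 2, padding 1: (n + 2 - 3) // 2 + 1
            h = (h - 1) // 2 + 1
            w = (w - 1) // 2 + 1
        shapes.append((channels * expansion, h, w))
    return shapes


stage_shapes = downsample_shapes((64, 48))
for name, shape in zip(["x1", "x2", "x3", "x4"], stage_shapes):
    print(name, shape)
```

Under these assumptions `x4` has 2048 channels at 8x6 and `x1` has 256 channels at 64x48, so the skip tensors coming from the previous stage need exactly these shapes for the element-wise additions to work.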

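### Upsample Module

This part implements the upsample module. Each unit performs the upsampling operation and aggregates the features output by the downsample module with the output of the previous upsample unit. It also outputs the predictions that are supervised by Coarse to Fine Supervision and the cross-stage features that are passed to the next stage for Cross Stage Feature Aggregation.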

In [None]:
class UpsampleUnit(nn.Cell):
    """ Upsample Unit for MSPN

    Args:
        ind (int): The order index of UpsampleUnit in UpsampleModule
        in_channels (int): Input Tensor Channels
        upsample_size (tuple): Upsample Tensor Shape
        output_channel_num (int): Output Tensor Channels
        output_shape (tuple): Output Tensor Shape
        channel_num (int): Interim Tensor Channels. Default: 256.
        generate_skip (bool): Whether to generate skip connection output tensor. Default: False.
        generate_cross_conv (bool): Whether to apply Convolution for the output Tensor to further feed into next stage \
        Module. Default: False.

    Inputs:
        Tensor, Upsample Tensor from Previous Layer.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> upsample_unit = UpsampleUnit(0, 3, (128, 128), 16, (256, 256), channel_num=256, generate_skip=False)
    """
    def __init__(self,
                 ind: int,
                 in_channels: int,
                 upsample_size: tuple,
                 output_channel_num: int,
                 output_shape: tuple,
                 channel_num: int = 256,
                 generate_skip: bool = False,
                 generate_cross_conv: bool = False
                 ) -> None:
        super(UpsampleUnit, self).__init__()
        self.output_shape = output_shape
        self.resize_bilinear = nn.ResizeBilinear()
        self.u_skip = ConvBlock(in_channels, channel_num, kernel_size=1, stride=1, padding=0, use_bn=True,
                                use_relu=False)
        self.relu = nn.ReLU()

        self.ind = ind
        if self.ind > 0:
            self.upsample_size = upsample_size
            self.up_conv = ConvBlock(channel_num, channel_num, kernel_size=1, stride=1, padding=0, use_bn=True,
                                     use_relu=False)

        self.generate_skip = generate_skip
        if self.generate_skip:
            self.skip1 = ConvBlock(in_channels, in_channels, kernel_size=1, stride=1, padding=0, use_bn=True,
                                   use_relu=True)
            self.skip2 = ConvBlock(channel_num, in_channels, kernel_size=1, stride=1, padding=0, use_bn=True,
                                   use_relu=True)

        self.generate_cross_conv = generate_cross_conv
        if self.ind == 3 and self.generate_cross_conv:
            self.cross_conv = ConvBlock(channel_num, 64, kernel_size=1, stride=1, padding=0, use_bn=True,
                                        use_relu=True)

        self.res_conv1 = ConvBlock(channel_num, channel_num, kernel_size=1, stride=1, padding=0, use_bn=True,
                                   use_relu=True)
        self.res_conv2 = ConvBlock(channel_num, output_channel_num, kernel_size=3, stride=1, padding=1, use_bn=True,
                                   use_relu=False)

    def construct(self, x, up_x):
        """Construct Func"""
        out = self.u_skip(x)

        if self.ind > 0:
            up_x = self.resize_bilinear(up_x, size=self.upsample_size, align_corners=True)
            up_x = self.up_conv(up_x)
            out += up_x
        out = self.relu(out)

        res = self.res_conv1(out)
        res = self.res_conv2(res)
        res = self.resize_bilinear(res, size=self.output_shape, align_corners=True)

        skip_tensor_1 = None
        skip_tensor_2 = None
        if self.generate_skip:
            skip_tensor_1 = self.skip1(x)
            skip_tensor_2 = self.skip2(out)

        cross_conv = None
        if self.ind == 3 and self.generate_cross_conv:
            cross_conv = self.cross_conv(out)

        return out, res, skip_tensor_1, skip_tensor_2, cross_conv


class UpsampleModule(nn.Cell):
    """ Upsample Module for MSPN

    Args:
        output_channel_num (int): Output Tensor Channels
        output_shape (tuple): Output Tensor Shape
        channel_num (int): Interim Tensor Channels. Default: 256.
        generate_skip (bool): Whether to generate skip connection output tensor. Default: False.
        generate_cross_conv (bool): Whether to apply Convolution for the output Tensor to further feed into next stage \
        Module. Default: False.

    Inputs:
        Tensors from Output of Every Downsample Layer.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> upsample = UpsampleModule(64, (256, 256), channel_num=256, generate_skip=False, generate_cross_conv=False)
    """
    def __init__(self,
                 output_channel_num: int,
                 output_shape: tuple,
                 channel_num: int = 256,
                 generate_skip: bool = False,
                 generate_cross_conv: bool = False
                 ) -> None:
        super(UpsampleModule, self).__init__()
        self.in_channels = [2048, 1024, 512, 256]
        h, w = output_shape
        self.upsample_sizes = [(h // 8, w // 8), (h // 4, w // 4), (h // 2, w // 2), (h, w)]
        self.generate_skip = generate_skip
        self.generate_cross_conv = generate_cross_conv

        self.up1 = UpsampleUnit(0, self.in_channels[0], self.upsample_sizes[0], output_channel_num=output_channel_num,
                                output_shape=output_shape, channel_num=channel_num, generate_skip=self.generate_skip,
                                generate_cross_conv=self.generate_cross_conv)
        self.up2 = UpsampleUnit(1, self.in_channels[1], self.upsample_sizes[1], output_channel_num=output_channel_num,
                                output_shape=output_shape, channel_num=channel_num, generate_skip=self.generate_skip,
                                generate_cross_conv=self.generate_cross_conv)
        self.up3 = UpsampleUnit(2, self.in_channels[2], self.upsample_sizes[2], output_channel_num=output_channel_num,
                                output_shape=output_shape, channel_num=channel_num, generate_skip=self.generate_skip,
                                generate_cross_conv=self.generate_cross_conv)
        self.up4 = UpsampleUnit(3, self.in_channels[3], self.upsample_sizes[3], output_channel_num=output_channel_num,
                                output_shape=output_shape, channel_num=channel_num, generate_skip=self.generate_skip,
                                generate_cross_conv=self.generate_cross_conv)

    def construct(self, x4, x3, x2, x1):
        """Construct Func"""
        out1, res1, skip1_1, skip2_1, _ = self.up1(x4, None)
        out2, res2, skip1_2, skip2_2, _ = self.up2(x3, out1)
        out3, res3, skip1_3, skip2_3, _ = self.up3(x2, out2)
        _, res4, skip1_4, skip2_4, cross_conv = self.up4(x1, out3)

        # 'res' starts from small size
        res = [res1, res2, res3, res4]
        skip_tensor_1 = [skip1_4, skip1_3, skip1_2, skip1_1]
        skip_tensor_2 = [skip2_4, skip2_3, skip2_2, skip2_1]

        return res, skip_tensor_1, skip_tensor_2, cross_conv
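
The upsample units resize feature maps with bilinear interpolation and `align_corners=True`. As an illustration of what that flag means, here is a one-dimensional sketch in plain Python; bilinear resizing applies the same rule along height and width. The function is a hypothetical stand-in and not the MindSpore operator.

```python
def resize_linear_align_corners(values, out_size):
    """Resize a 1-D sequence with linear interpolation, align_corners=True."""
    in_size = len(values)
    if out_size == 1 or in_size == 1:
        return [float(values[0])] * out_size
    # With align_corners=True the first and last samples of input and output
    # coincide, so the step between output samples is (in - 1) / (out - 1).
    scale = (in_size - 1) / (out_size - 1)
    resized = []
    for i in range(out_size):
        pos = i * scale
        left = min(int(pos), in_size - 1)
        right = min(left + 1, in_size - 1)
        frac = pos - left
        resized.append(values[left] * (1.0 - frac) + values[right] * frac)
    return resized


print(resize_linear_align_corners([0.0, 10.0], 5))  # [0.0, 2.5, 5.0, 7.5, 10.0]
```

The end points of the input are preserved exactly, and the intermediate samples are spaced evenly between them.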

### Single-Stage Module

With the basic modules above in place, we can build the single-stage network module of MSPN.

In [None]:
class SingleStageModule(nn.Cell):
    """ Single Stage Module for MSPN

    Args:
        output_channel_num (int): Output Tensor Channels
        output_shape (tuple): Output Tensor Shape
        use_skip (bool): Whether to Use Cross Stage Feature Aggregation. Default: False.
        generate_skip (bool): Whether to generate skip connection output tensor. Default: False.
        generate_cross_conv (bool): Whether to apply Convolution for the output Tensor to further feed into next stage \
        Module. Default: False.
        channel_num (int): Interim Tensor Channels. Default: 256.
        zero_init_bn (bool): Whether to Use zero initialization On Weight of Batch Normalization. Default: False.

    Inputs:
        Tensor, skip_tensor of previous MSPN Stage

    Returns:
        Tensor, output tensor.

    Examples:
        >>> single_stage = SingleStageModule(128, (256, 256), use_skip=False, generate_skip=False)
    """
    def __init__(self,
                 output_channel_num: int,
                 output_shape: tuple,
                 use_skip: bool = False,
                 generate_skip: bool = False,
                 generate_cross_conv: bool = False,
                 channel_num: int = 256,
                 zero_init_bn: bool = False
                 ) -> None:
        super(SingleStageModule, self).__init__()
        self.use_skip = use_skip
        self.generate_skip = generate_skip
        self.generate_cross_conv = generate_cross_conv
        self.channel_num = channel_num
        self.zero_init_bn = zero_init_bn
        self.layers = [3, 4, 6, 3]
        self.downsample = ResNetDownsampleModule(Bottleneck, self.layers, self.use_skip, self.zero_init_bn)
        self.upsample = UpsampleModule(output_channel_num, output_shape, self.channel_num, self.generate_skip,
                                       self.generate_cross_conv)

    def construct(self, x, skip_tensor_1, skip_tensor_2):
        """Construct Func"""
        x4, x3, x2, x1 = self.downsample(x, skip_tensor_1, skip_tensor_2)
        res, skip_tensor_1, skip_tensor_2, cross_conv = self.upsample(x4, x3, x2, x1)

        return res, skip_tensor_1, skip_tensor_2, cross_conv

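### Loss Function

MSPN uses the L2 loss, i.e. Mean Squared Error (MSE), as its base loss function:

$$
MSE = \frac{\sum\limits_{i=1}^{n}(y_i-y_i^p)^2}{n}
$$

where $n$ is the number of samples, $y_i$ is the ground-truth value of sample $i$, and $y_i^p$ is the model's prediction for that sample.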
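
As a quick numerical check of the formula, the following plain-Python sketch computes the MSE of a three-element example. The function name `mse` is chosen only for illustration.

```python
def mse(targets, predictions):
    """Mean squared error between two equally long sequences."""
    if len(targets) != len(predictions):
        raise ValueError("targets and predictions must have the same length")
    squared_errors = [(y - y_p) ** 2 for y, y_p in zip(targets, predictions)]
    return sum(squared_errors) / len(squared_errors)


# Only the last element differs, by 2, so the result is 2**2 / 3.
print(mse([1.0, 2.0, 3.0], [1.0, 2.0, 5.0]))
```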

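On top of this, the authors use two auxiliary strategies when computing the loss to improve the model's performance:

- Coarse to Fine Supervision (CTF): the outputs of different stages are supervised with Ground Truth heatmaps generated by Gaussian kernels of different sizes. Earlier stages use larger kernels, giving a "blurrier" Ground Truth; later stages use smaller kernels, giving a "finer" Ground Truth. This progression matches the way the model refines its predictions from stage to stage, as illustrated below:

![Coarse to Fine Supervision](./images/ctf.png)

- Online Hard Keypoint Mining (OHKM): borrowing the idea of hard example mining, the loss of each stage is computed only over the K keypoints with the highest loss. This pushes the network to fit hard keypoints better and improves overall performance.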
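
The OHKM selection step can be sketched for a single sample as follows. This is a simplified plain-Python illustration with a hypothetical function name, assuming the per-keypoint losses have already been averaged over the heatmap pixels.

```python
def ohkm_loss(per_keypoint_loss, visible, topk=8):
    """Average the top-k largest per-keypoint losses of one sample."""
    # Keypoints that are not labelled contribute zero loss.
    masked = [loss if is_visible else 0.0
              for loss, is_visible in zip(per_keypoint_loss, visible)]
    # Keep only the k hardest keypoints and average their losses.
    hardest = sorted(masked, reverse=True)[:topk]
    return sum(hardest) / len(hardest)


losses = [0.1, 0.9, 0.2, 0.8, 0.05, 0.7]
visible = [True, True, True, False, True, True]
# The keypoint with loss 0.8 is masked out, so the two hardest are 0.9 and 0.7.
print(ohkm_loss(losses, visible, topk=2))
```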

In [None]:
class StageLoss(nn.Cell):
    """ Stage Loss for Every Stage of MSPN

    Args:
        has_ohkm (bool): Whether to use Online Hard Key points Mining (OHKM). Default: False.
        topk (int): OHKM Top-k Largest Loss Hyper-parameter. Default: 8.
        vis_thresh_wo_ohkm (int): Joints Visible Thresh when has_ohkm sets to False. Default: 1.
        vis_thresh_w_ohkm (int): Joints Visible Thresh when has_ohkm sets to True. Default: 0.

    Inputs:
        MSPN Output Tensor, Keypoints Visible Tensor, Keypoints Ground Truth Tensor

    Returns:
        Tensor of Loss.

    Examples:
        >>> stage_loss = StageLoss(has_ohkm=True, topk=8)
    """
    def __init__(self,
                 has_ohkm: bool = False,
                 topk: int = 8,
                 vis_thresh_wo_ohkm: int = 1,
                 vis_thresh_w_ohkm: int = 0
                 ) -> None:
        super(StageLoss, self).__init__()
        self.has_ohkm = has_ohkm
        self.topk = topk
        self.t1 = vis_thresh_wo_ohkm
        self.t2 = vis_thresh_w_ohkm
        method = 'none' if self.has_ohkm else 'mean'
        self.calculate = nn.MSELoss(reduction=method)

    def construct(self, output, valid, label):
        """Construct Func"""
        greater = mindspore.ops.Greater()
        topk = mindspore.ops.TopK(sorted=False)
        batch_size = output.shape[0]
        keypoint_num = output.shape[1]
        loss = 0

        for i in range(batch_size):
            pred = output[i].reshape(keypoint_num, -1)
            gt = label[i].reshape(keypoint_num, -1)
            if not self.has_ohkm:
                weight = greater(valid[i], self.t1).astype(mindspore.float32)
                gt = gt * weight

            tmp_loss = self.calculate(pred, gt)
            if self.has_ohkm:
                tmp_loss = tmp_loss.mean(axis=1)
                weight = greater(valid[i].squeeze(), self.t2).astype(mindspore.float32)
                tmp_loss = tmp_loss * weight
                topk_val, _ = topk(tmp_loss, self.topk)
                sample_loss = topk_val.mean(axis=0)
            else:
                sample_loss = tmp_loss

            loss = loss + sample_loss

        return loss / batch_size


class JointsL2Loss(nn.Cell):
    """ JointsL2 Loss for MSPN

    Args:
        stage_num (int): MSPN Stage Num
        ctf (bool): Whether to enable Coarse-to-Fine Supervision. Default: True.
        has_ohkm (bool): Whether to use Online Hard Key points Mining (OHKM). Default: False.
        topk (int): OHKM Top-k Largest Loss Hyper-parameter. Default: 8.

    Inputs:
        MSPN Output Tensor, Keypoints Visible Tensor, Keypoints Ground Truth Tensor

    Returns:
        Tensor of Loss.

    Examples:
        >>> joints_loss = JointsL2Loss(stage_num=2, ctf=True, has_ohkm=True, topk=8)
    """
    def __init__(self,
                 stage_num: int,
                 ctf: bool = True,
                 has_ohkm: bool = False,
                 topk: int = 8
                 ) -> None:
        super(JointsL2Loss, self).__init__()
        self.stage_num = stage_num
        self.ctf = ctf
        self.ohkm = has_ohkm
        self.topk = topk
        self.loss = StageLoss()
        self.loss_ohkm = StageLoss(has_ohkm=self.ohkm, topk=self.topk)

    def construct(self, outputs, valids, labels):
        """Construct Func"""
        loss = 0
        for i in range(self.stage_num):
            for j in range(4):
                ind = j
                if i == self.stage_num - 1 and self.ctf:
                    ind += 1
                tmp_labels = labels[:, ind, :, :, :]

                if j == 3 and self.ohkm:
                    tmp_loss = self.loss_ohkm(outputs[i][j], valids, tmp_labels)
                else:
                    tmp_loss = self.loss(outputs[i][j], valids, tmp_labels)

                if j < 3:
                    tmp_loss = tmp_loss / 4

                loss += tmp_loss

        return loss
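
`JointsL2Loss` picks a different slice of `labels` depending on the stage when Coarse to Fine Supervision is enabled. To illustrate what coarser and finer targets look like, the sketch below renders Gaussian heatmaps of decreasing spread for the same keypoint. The `sigmas` values and both helper functions are assumptions made for illustration; the actual kernels are chosen in the data pipeline, which is not shown here.

```python
import math


def gaussian_heatmap(height, width, center, sigma):
    """Heatmap with a 2-D Gaussian of standard deviation sigma at center."""
    cx, cy = center
    return [[math.exp(-((x - cx) ** 2 + (y - cy) ** 2) / (2.0 * sigma ** 2))
             for x in range(width)]
            for y in range(height)]


def count_above(heatmap, threshold):
    """Number of heatmap cells whose value exceeds threshold."""
    return sum(value > threshold for row in heatmap for value in row)


# Hypothetical spreads, ordered from coarse to fine.
sigmas = [4.0, 3.0, 2.0, 1.0]
targets = [gaussian_heatmap(64, 48, center=(24, 32), sigma=s) for s in sigmas]
for s, target in zip(sigmas, targets):
    # Peak value at the keypoint and size of the "hot" region around it.
    print(s, target[32][24], count_above(target, 0.5))
```

Every target peaks at the keypoint location, but the region above 0.5 shrinks as the spread decreases, which is the sense in which later targets are "finer".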

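### MSPN Model

So far we have implemented, starting from the most basic convolutional block, the downsample module, the upsample module, the single-stage MSPN network, and the loss function. We now combine these submodules into the full MSPN model.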

In [None]:
class MSPN(nn.Cell):
    """ MSPN

    Args:
        total_stage_num (int): MSPN Stage Num
        output_channel_num (int): Output Tensor Channels
        output_shape (tuple): Output Tensor Shape
        upsample_channel_num (int): Upsample Tensor Channels
        online_hard_key_mining (bool): Whether to use Online Hard Key points Mining (OHKM). Default: True.
        topk_keys (int): OHKM Top-k Largest Loss Hyper-parameter. Default: 8.
        coarse_to_fine (bool): Whether to enable Coarse-to-Fine Supervision. Default: True.

    Inputs:
        Image Tensor, Keypoints Visible Tensor, Keypoints Ground Truth.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> mspn = MSPN(2, 17, (64, 48), 256, online_hard_key_mining=True, topk_keys=8, coarse_to_fine=True)
    """
    def __init__(self,
                 total_stage_num: int,
                 output_channel_num: int,
                 output_shape: tuple,
                 upsample_channel_num: int,
                 online_hard_key_mining: bool = True,
                 topk_keys: int = 8,
                 coarse_to_fine: bool = True,
                 ) -> None:
        super(MSPN, self).__init__()
        self.top = ResNetTop()
        self.total_stage_num = total_stage_num
        self.output_channel_num = output_channel_num
        self.output_shape = output_shape
        self.upsample_channel_num = upsample_channel_num
        self.online_hard_key_mining = online_hard_key_mining
        self.topk_keys = topk_keys
        self.coarse_to_fine = coarse_to_fine
        self.mspn_modules = list()
        for i in range(self.total_stage_num):
            if i == 0:
                use_skip = False
            else:
                use_skip = True
            if i != self.total_stage_num - 1:
                generate_skip = True
                generate_cross_conv = True
            else:
                generate_skip = False
                generate_cross_conv = False
            self.mspn_modules.append(
                SingleStageModule(
                    self.output_channel_num, self.output_shape,
                    use_skip=use_skip, generate_skip=generate_skip,
                    generate_cross_conv=generate_cross_conv,
                    channel_num=self.upsample_channel_num,
                )
            )
            setattr(self, 'stage%d' % i, self.mspn_modules[i])
        self.loss = JointsL2Loss(stage_num=self.total_stage_num, ctf=self.coarse_to_fine,
                                 has_ohkm=self.online_hard_key_mining, topk=self.topk_keys)

    def construct(self, imgs, valids=None, labels=None):
        """Construct Func"""
        x = self.top(imgs)
        skip_tensor_1 = None
        skip_tensor_2 = None
        outputs = list()
        for i in range(self.total_stage_num):
            res, skip_tensor_1, skip_tensor_2, x = self.mspn_modules[i](x, skip_tensor_1, skip_tensor_2)
            outputs.append(res)

        if valids is None and labels is None:
            return outputs[-1][-1]

        return self.loss(outputs, valids, labels)

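## Model Training

With the modules above in place, we can train the MSPN model. To keep the training process simple and to accommodate devices of different capability, we train for only 1 epoch and set the number of MSPN stages to 2.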

In [None]:
from mindspore import context
from mindspore.train.callback import LossMonitor


context.set_context(mode=context.PYNATIVE_MODE, device_id=0, device_target="GPU")

step_size = dataloader.get_dataset_size()
EPOCH = 1

# Instantiate the MSPN model
net = MSPN(total_stage_num=2,
           output_channel_num=17,
           output_shape=[64, 48],
           upsample_channel_num=256,
           online_hard_key_mining=True,
           topk_keys=8,
           coarse_to_fine=True)

# Set up the learning rate schedule
lr = nn.cosine_decay_lr(min_lr=0.0,
                        max_lr=5e-4,
                        total_step=EPOCH * step_size,
                        step_per_epoch=step_size,
                        decay_epoch=EPOCH)

# Set up the optimizer
optimizer = nn.Adam(net.trainable_params(), learning_rate=lr)

# Load the pre-trained weights
mindspore.load_param_into_net(net, mindspore.load_checkpoint("mspn.ckpt"))

# Wrap the network in a Model
model = mindspore.Model(net, optimizer=optimizer)

# Start training
model.train(epoch=EPOCH, train_dataset=dataloader, callbacks=LossMonitor())

# Save Model ckpt
mindspore.save_checkpoint(net, "./mspn_new.ckpt")
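
To get a feel for the schedule, the sketch below re-implements in plain Python the epoch-wise cosine formula as it is described in the MindSpore documentation for `nn.cosine_decay_lr`. Treat it as an assumption to verify against the installed version; `cosine_decay_lr_sketch` is a hypothetical name.

```python
import math


def cosine_decay_lr_sketch(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
    """Per-step learning rates following an epoch-wise cosine decay."""
    rates = []
    for step in range(total_step):
        # The rate changes once per epoch and stops changing after decay_epoch.
        current_epoch = min(step // step_per_epoch, decay_epoch)
        cosine = 1.0 + math.cos(math.pi * current_epoch / decay_epoch)
        rates.append(min_lr + 0.5 * (max_lr - min_lr) * cosine)
    return rates


# Three epochs of two steps each: the rate drops once per epoch.
print(cosine_decay_lr_sketch(0.0, 5e-4, total_step=6, step_per_epoch=2, decay_epoch=3))
# One epoch, as in this case: every step is in epoch 0.
print(cosine_decay_lr_sketch(0.0, 5e-4, total_step=4, step_per_epoch=4, decay_epoch=1))
```

Under this reading, a single-epoch run stays at `max_lr` for every step, so the decay only becomes visible once `EPOCH` is larger than 1.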

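## Model Evaluation

After training, we evaluate the trained model on the MS COCO validation set.

To measure MSPN's performance more thoroughly, we follow the original MSPN paper [1] and evaluate the model with the AP and AR metrics. In object detection, IoU measures the similarity between a prediction and the Ground Truth, and the various AP and AR values are computed by thresholding IoU. For pose estimation, MS COCO defines an analogous similarity measure for keypoints:

$$
OKS=\frac{\sum_{i} \exp \left(\frac{-d_{i}^{2}}{2 s^{2} k_{i}^{2}}\right) \delta\left(v_{i}>0\right)}{\sum_{i} \delta\left(v_{i}>0\right)}
$$

OKS stands for Object Keypoint Similarity. Here $i$ is the keypoint index; $d_i$ is the Euclidean distance between the predicted position of keypoint $i$ and its Ground Truth; $s$ is the scale factor of the Ground Truth person, taken as the square root of the area of the person's detection box, and it does not depend on the keypoint; $k_i$ is twice the normalization factor of keypoint $i$. This factor is the standard deviation between human annotations and the true positions of that keypoint type, computed over all Ground Truth samples; a larger value means the keypoint type is harder to annotate, and for a given distance a larger $k_i$ yields a larger OKS. $v_i$ is the visibility flag of keypoint $i$: $0$ means the keypoint is not labeled, $1$ means it is labeled but occluded, and $2$ means it is labeled and visible.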
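
A minimal plain-Python sketch of OKS for one person is shown below. It is written for illustration and is not the evaluation code behind `evaluate`; the function name and the example values of $k_i$ and $s$ are hypothetical.

```python
import math


def object_keypoint_similarity(pred, gt, visibility, scale, k):
    """OKS between predicted and ground-truth keypoints of one person.

    pred, gt: lists of (x, y); visibility: COCO v flags; scale: s;
    k: per-keypoint constants k_i.
    """
    total = 0.0
    labelled = 0
    for (px, py), (gx, gy), v, k_i in zip(pred, gt, visibility, k):
        if v <= 0:
            continue  # unlabelled keypoints are excluded by delta(v_i > 0)
        d_squared = (px - gx) ** 2 + (py - gy) ** 2
        total += math.exp(-d_squared / (2.0 * scale ** 2 * k_i ** 2))
        labelled += 1
    return total / labelled if labelled else 0.0


gt = [(10.0, 10.0), (20.0, 20.0), (30.0, 30.0)]
pred = [(15.0, 10.0), (20.0, 20.0), (99.0, 99.0)]  # third point is ignored (v = 0)
oks = object_keypoint_similarity(pred, gt, visibility=[2, 1, 0],
                                 scale=100.0, k=[0.05, 0.1, 0.1])
print(oks)  # (exp(-0.5) + 1) / 2, roughly 0.80
```

Only the two labeled keypoints enter the average: the first is 5 pixels off and scores $\exp(-0.5)$, and the second is exact and scores 1.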

In [None]:
import json

from mindspore import context
from src.utils.mspn_utils import compute_on_dataset, evaluate


FLIP_PAIRS = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]

context.set_context(mode=context.PYNATIVE_MODE, device_id=0, device_target="GPU")

normalize = Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
transform = Compose([normalize, py_trans.ToTensor()])
dataset = COCODataset(data_dir="/data0/coco/coco2014",
                      det_file_path="./annotation/det_json/minival_2014_det.json",
                      gt_file_path="./annotation/gt_json/minival_2014.json",
                      keypoint_num=17,
                      flip_pairs=FLIP_PAIRS,
                      upper_body_ids=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
                      lower_body_ids=[11, 12, 13, 14, 15, 16],
                      input_shape=[256, 192],
                      output_shape=[64, 48],
                      stage='val'
                      )
dataloader = ds.GeneratorDataset(source=dataset,
                                 column_names=["img", "score", "center", "scale", "img_id"],
                                 shuffle=False,
                                 num_parallel_workers=4)
dataloader = dataloader.map(operations=transform, input_columns=["img"]).batch(32)

# Instantiate the MSPN model
net = MSPN(total_stage_num=2,
           output_channel_num=17,
           output_shape=[64, 48],
           upsample_channel_num=256,
           online_hard_key_mining=True,
           topk_keys=8,
           coarse_to_fine=True)

# Load the trained weights
mindspore.load_param_into_net(net, mindspore.load_checkpoint("mspn_new.ckpt"))

# Wrap the network in a Model
model = mindspore.Model(net)

# Start evaluation
results = compute_on_dataset(model=model,
                             dataloader=dataloader,
                             flip_pairs=FLIP_PAIRS,
                             keypoint_num=17,
                             input_shape=[256, 192],
                             output_shape=[64, 48])
results.sort(key=lambda res: (res['image_id'], res['score']), reverse=True)
with open('./results.json', 'w') as f:
    json.dump(results, f)

evaluate(val_gt_path="./annotation/gt_json/minival_2014.json",
         pred_path='./results.json')

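## Model Inference

Because MSPN is a Top-Down model, person detection boxes must be obtained first, and keypoints are then detected inside each box. Inference therefore requires an additional JSON file containing the person detection boxes for the images to be processed. A single detection box is annotated as follows: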

In [None]:
{
    "category_id": 1,
    "image_id": 398905,
    "bbox": [
        216.39,
        32.29,
        295.74,
        333.67
    ],
    "score": 0.9951
}

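- category_id: always 1 in MSPN, consistent with the person category label of MS COCO Detection.
- image_id: the file name ID of the image to run inference on.
- bbox: the box geometry in XYWH order, i.e. the x and y coordinates of the top-left corner followed by the width and height of the box.
- score: the confidence of the detection box.

Rename each image to be processed to a purely numeric file name that does not start with 0, and place it in the ./infer_img folder; place the detection box JSON file for these images in the ./annotation/det_json folder. Multiple images can be placed in ./infer_img and processed in one run, provided that each image_id in the JSON file matches the name of the corresponding image in ./infer_img.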
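
The naming rule and the entry format can be checked with a small script such as the one below. It is a sketch under the assumption that the JSON file holds a list of entries in the format shown above; `image_id_from_name` and `make_detection` are hypothetical helpers and not part of the MindSpore Vision suite.

```python
import json
from pathlib import Path


def image_id_from_name(file_name):
    """Return the integer image_id for a purely numeric file name."""
    stem = Path(file_name).stem
    if not (stem.isascii() and stem.isdigit()) or stem.startswith("0"):
        raise ValueError(f"{file_name!r} must be numeric and must not start with 0")
    return int(stem)


def make_detection(file_name, bbox_xywh, score):
    """Build one detection entry for the given image file."""
    return {
        "category_id": 1,  # person
        "image_id": image_id_from_name(file_name),
        "bbox": [float(value) for value in bbox_xywh],
        "score": float(score),
    }


detections = [make_detection("398905.jpg", [216.39, 32.29, 295.74, 333.67], 0.9951)]
print(json.dumps(detections, indent=4))
```

Deriving `image_id` from the file name keeps the JSON file and the image folder in sync, and rejecting names with a leading zero avoids IDs that no longer match the file name once converted to an integer.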

In [None]:
from mindspore import context
from src.utils.mspn_utils import compute_on_dataset, visualize


FLIP_PAIRS = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]

context.set_context(mode=context.PYNATIVE_MODE, device_id=0, device_target="GPU")

normalize = Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
transform = Compose([normalize, py_trans.ToTensor()])
dataset = COCODataset(data_dir="./infer_img",
                      det_file_path="./annotation/det_json/test.json",
                      keypoint_num=17,
                      flip_pairs=FLIP_PAIRS,
                      upper_body_ids=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
                      lower_body_ids=[11, 12, 13, 14, 15, 16],
                      input_shape=[256, 192],
                      output_shape=[64, 48],
                      stage='test'
                      )
dataloader = ds.GeneratorDataset(source=dataset,
                                 column_names=["img", "score", "center", "scale", "img_id"],
                                 shuffle=False,
                                 num_parallel_workers=4)
dataloader = dataloader.map(operations=transform, input_columns=["img"]).batch(32)

# Instantiate the MSPN model
net = MSPN(total_stage_num=2,
           output_channel_num=17,
           output_shape=[64, 48],
           upsample_channel_num=256,
           online_hard_key_mining=True,
           topk_keys=8,
           coarse_to_fine=True)

# Load the trained weights
mindspore.load_param_into_net(net, mindspore.load_checkpoint("mspn_new.ckpt"))

# Wrap the network in a Model
model = mindspore.Model(net)

# Run inference and visualize the results
results = compute_on_dataset(model=model,
                             dataloader=dataloader,
                             flip_pairs=FLIP_PAIRS,
                             keypoint_num=17,
                             input_shape=[256, 192],
                             output_shape=[64, 48])
results.sort(key=lambda res: (res['image_id'], res['score']), reverse=True)
visualize(results=results,
          infer_dir="./infer_img",
          save_dir="./res_img",
          score_thre=0.7)

![Image](./images/90891_res.jpg)

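## Summary

This case explained the MSPN model in detail, walked through the complete workflow from training and evaluation to inference, and discussed several problems that MSPN addresses. For the full code, see the MindSpore Vision suite.

## References

[1] Wenbo Li, Zhicheng Wang, Binyi Yin, et al. Rethinking on Multi-Stage Networks for Human Pose Estimation[J]. arXiv preprint arXiv:1901.00148v4, 2019.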