一、数据集准备(我已经上传了处理好的数据，这部分不用运行)
    1、下载LUNA16数据集到本地，放在data文件夹中
    2、进入路径./data/LUNA16, 将得到的subset0-9.rar和seg-lungs-LUNA16.rar共11个文件解压
    3、设置路径，运行下面代码，划分训练集和验证集，并将数据格式转换为NIfTI（.nii.gz），得到如下的文件结构
    ./train
        /image
        /seg
    ./val
        /image
        /seg
    注：
        （1）LUNA16数据集共887个volume，我们选择subset9中（按文件名排序后）的最后十个volume作为验证集，subset0-8以及subset9剩下的部分作为训练集。
        （2）可以通过下载ITK-SNAP软件可视化图片和分割标签

In [None]:
# Function: Adjust Data Content Structure & Convert volume to nifti format
# Author: Zhang Zhongzhou
# Date: 2022/9/16
# LUNA16
#   /train
#       /image
#       /seg
#   /val
#       /image
#       /seg
import os
import os.path as osp
import shutil
import SimpleITK as sitk
import glob
from tqdm import tqdm
from src.model_utils.config import config

def convert_nifti(input_im, save_image, roi_size, skip_small=False):
    """
    Convert a volume into NIfTI format.

    Args:
        input_im(str): input image name (any format SimpleITK can read, e.g. ``.mhd``).
        save_image(str): output image name (e.g. ``*.nii.gz``).
        roi_size(list): crop size as ``[H, W, D]``; volumes smaller than this along
            any axis are reported, because they cannot be cropped to ``roi_size``.
        skip_small(bool): if True, such undersized volumes are not written.
            Defaults to False, which converts every volume (previous behaviour).

    Returns:
        bool: True if ``save_image`` was written, False if the volume was skipped.
    """
    img = sitk.ReadImage(input_im)
    # SimpleITK reports the size in (x, y, z) order, i.e. (W, H, D): the reverse of
    # the (D, H, W) array layout. Reading it directly avoids copying the voxel data.
    W, H, D = img.GetSize()
    if H < roi_size[0] or W < roi_size[1] or D < roi_size[2]:
        if skip_small:
            print("file {} size is smaller than roi size, ignore it.".format(input_im))
            return False
        # The old message claimed the file was ignored, but it was always written.
        print("file {} size is smaller than roi size, converting it anyway.".format(input_im))
    sitk.WriteImage(img, save_image)
    return True


if __name__ == "__main__":

    root = "../data/LUNA16"
    image_fold = ["subset{}".format(i) for i in range(10)]
    label_fold = "seg-lungs-LUNA16"
    # Validation volumes: the last `num_val` volumes (sorted by file name) of `val_subset`.
    val_subset = "subset9"
    num_val = 10

    save_train_image_path = osp.join(root, "train", "image")
    save_train_seg_path = osp.join(root, "train", "seg")
    save_val_image_path = osp.join(root, "val", "image")
    save_val_seg_path = osp.join(root, "val", "seg")

    # create save paths (no-op if they already exist, so the script can be re-run)
    for save_path in (save_train_image_path, save_train_seg_path, save_val_image_path, save_val_seg_path):
        os.makedirs(save_path, exist_ok=True)

    # Split train / val while converting: every volume is written straight into its
    # final directory. Moving "the last 10 files of train/" afterwards picked the 10
    # names that sort last across *all* subsets (not necessarily subset9), and the
    # move failed when re-run over existing files.
    for subset in image_fold:
        image_list = sorted(glob.glob(osp.join(root, subset, "*.mhd")))
        val_images = set(image_list[-num_val:]) if subset == val_subset else set()
        for im in tqdm(image_list):
            # basename/splitext instead of splitting on "/" so Windows paths work too
            im_name = osp.splitext(osp.basename(im))[0]
            seg = osp.join(root, label_fold, im_name + ".mhd")
            if im in val_images:
                image_dir, seg_dir = save_val_image_path, save_val_seg_path
            else:
                image_dir, seg_dir = save_train_image_path, save_train_seg_path
            save_im_name = osp.join(image_dir, im_name + ".nii.gz")
            save_seg_name = osp.join(seg_dir, im_name + ".nii.gz")
            convert_nifti(im, save_im_name, config.roi_size)
            convert_nifti(seg, save_seg_name, config.roi_size)

二、参数设置（训练参数和dataloader参数）
在训练过程中的参数设置，如图片大小、训练batchsize大小等

In [20]:
!pip install -r requirements.txt


import ml_collections
import warnings
# warnings.filterwarnings("ignore")

def get_config():
    """
    Build the notebook's configuration.

    All values are hard-coded below; no yaml file or command-line arguments are read.

    Returns:
        ml_collections.ConfigDict holding runtime, data-loading, training,
        export and ModelArts options.
    """
    cfg = ml_collections.ConfigDict()
    # Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
    cfg.enable_fp16_gpu = False
    cfg.enable_modelarts = False
    # Url for modelarts
    cfg.data_url = ""
    # `moxing_wrapper` reads `config.train_url`; the key used to be misspelled
    # `train_ur`, which made ModelArts runs fail with AttributeError.
    cfg.train_url = ""
    cfg.train_ur = ""  # legacy misspelled key, kept for backward compatibility
    cfg.checkpoint_url = ""
    # Path for local
    cfg.run_distribute = False
    cfg.enable_profiling = False
    cfg.data_path = "data/LUNA16/train/"
    cfg.output_path = "output"
    cfg.load_path = "/checkpoint_path/"
    cfg.device_target = "GPU"  # or "Ascend"
    cfg.checkpoint_path = "./checkpoint/"
    cfg.checkpoint_file_path = "Unet3d-10-110.ckpt"

    # ==============================================================================
    # data loader options
    cfg.num_worker = 4
    # Training options
    cfg.max_epoch = 3
    cfg.lr = 0.0005
    cfg.batch_size = 1
    cfg.epoch_size = 10
    cfg.warmup_step = 120
    cfg.warmup_ratio = 0.3
    cfg.num_classes = 4
    cfg.in_channels = 1
    cfg.keep_checkpoint_max = 1
    cfg.loss_scale = 256.0
    cfg.roi_size = [224, 224, 96]
    cfg.overlap = 0.25
    # intensity window mapped to [0, 1] by ScaleIntensityRange
    cfg.min_val = -500
    cfg.max_val = 1000
    # label range kept by ConvertLabel
    cfg.upper_limit = 5
    cfg.lower_limit = 3

    # Export options
    cfg.device_id = 0
    cfg.ckpt_file = "./checkpoint/Unet3d-10-110.ckpt"
    cfg.file_name = "unet3d"
    cfg.file_format = "MINDIR"

    # 310 infer options
    cfg.pre_result_path = "./preprocess_Result"
    cfg.post_result_path = "./result_Files"

    return cfg

# Build the notebook-wide configuration object that the later cells read
# (e.g. create_dataset and the UNet3d classes use `config.*`) and display it.
config = get_config()
print(config)

batch_size: 1
checkpoint_file_path: Unet3d-10-110.ckpt
checkpoint_path: ./checkpoint/
checkpoint_url: ''
ckpt_file: ./checkpoint/Unet3d-10-110.ckpt
data_path: data/LUNA16/train/
data_url: ''
device_id: 0
device_target: GPU
enable_fp16_gpu: false
enable_modelarts: false
enable_profiling: false
epoch_size: 10
file_format: MINDIR
file_name: unet3d
in_channels: 1
keep_checkpoint_max: 1
load_path: /checkpoint_path/
loss_scale: 256.0
lower_limit: 3
lr: 0.0005
max_epoch: 3
max_val: 1000
min_val: -500
num_classes: 4
num_worker: 4
output_path: output
overlap: 0.25
post_result_path: ./result_Files
pre_result_path: ./preprocess_Result
roi_size:
- 224
- 224
- 96
run_distribute: false
train_ur: ''
upper_limit: 5
warmup_ratio: 0.3
warmup_step: 120



三、创建不同的数据集增强方式
完成上述文件格式转换并划分训练集和验证集之后，如果直接将图片数据送入网络训练，结果往往不太理想，因此需要通过一系列transform操作对数据进行预处理和增强，包括：LoadData、ExpandChannel、Orientation、ScaleIntensityRange、RandomCropSamples、ConvertLabel、OneHot等。

In [21]:
# transforms
import re
import numpy as np
import logger

np_str_obj_array_pattern = re.compile(r'[SaUO]')

def correct_nifti_head(img):
    """
    Compare a nifti header's pixdim with the voxel sizes implied by its affine.

    The header-rectification step is disabled in this notebook, so the check has
    no effect: the same image object is returned on every path.

    Args:
        img: nifti image object (needs ``header`` and ``affine``).

    Returns:
        ``img`` itself.
    """
    n_dims = img.header["dim"][0]
    if n_dims >= 5:
        return img
    header_zooms = np.asarray(img.header.get_zooms())[:n_dims]
    affine_zooms = np.sqrt(np.sum(np.square(img.affine[:n_dims, :n_dims]), 0))
    if not np.allclose(header_zooms, affine_zooms):
        # Mismatch detected; rectification (rectify_header_sform_qform) is
        # intentionally not applied here, so the image passes through unchanged.
        pass
    return img

class LoadData:
    """
    Load NIfTI image data from the file paths delivered by the dataset pipeline.
    """
    def __init__(self, canonical=False, dtype=np.float32):
        """
        Args:
        canonical: if True, load the image as closest to canonical axis format.
        dtype: convert the loaded image to this data type.
        """
        self.canonical = canonical
        self.dtype = dtype

    @staticmethod
    def _to_path(name):
        """Return a filename scalar (bytes / numpy bytes / str / numpy str) as a str path."""
        # The previous `str(name)[2:-1]` stripped the "b'...'" of a bytes repr: it
        # chopped real characters off str paths and kept escape sequences
        # (e.g. doubled backslashes) from bytes paths.
        if isinstance(name, bytes):
            return name.decode("utf-8")
        return str(name)

    def operation(self, filename):
        """
        Load the file named by `filename` and return its voxel data.

        Args:
            filename: numpy array holding a single path (bytes or str), as
                produced by the dataset pipeline.

        Returns:
            ndarray of `self.dtype` (several files would be stacked on axis 0).

        Raises:
            ValueError: if several files are given and their affines differ.
        """
        img_array = list()
        compatible_meta = dict()
        filenames = [filename.item()]
        for name in filenames:
            path = self._to_path(name)
            img = nib.load(path)
            img = correct_nifti_head(img)
            header = dict(img.header)
            header["filename_or_obj"] = path
            header["affine"] = img.affine
            header["original_affine"] = img.affine.copy()
            header["canonical"] = self.canonical
            ndim = img.header["dim"][0]
            spatial_rank = min(ndim, 3)
            header["spatial_shape"] = img.header["dim"][1 : spatial_rank + 1]
            if self.canonical:
                img = nib.as_closest_canonical(img)
                header["affine"] = img.affine
            img_array.append(np.array(img.get_fdata(dtype=self.dtype)))
            # drop nibabel's cached copy of the data; the array above is our own
            img.uncache()
            if not compatible_meta:
                # keep the first file's metadata, skipping string/object arrays
                for meta_key, meta_datum in header.items():
                    if isinstance(meta_datum, np.ndarray) \
                        and np_str_obj_array_pattern.search(meta_datum.dtype.str) is not None:
                        continue
                    compatible_meta[meta_key] = meta_datum
            elif not np.allclose(header["affine"], compatible_meta["affine"]):
                raise ValueError("affine of '{}' does not match the first loaded file.".format(path))

        img_array = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0]
        return img_array

    def __call__(self, img, label):
        """Load both the image and its segmentation; return ``(image, seg)`` arrays."""
        img_array = self.operation(img)
        seg_array = self.operation(label)
        return img_array, seg_array

class ExpandChannel:
    """
    Prepend a length-1 channel axis to both the image and the label arrays.
    """
    @staticmethod
    def _with_channel_axis(array):
        # np.newaxis is None, so this is exactly `array[None]`
        return array[np.newaxis]

    def __call__(self, img, label):
        return self._with_channel_axis(img), self._with_channel_axis(label)

class Orientation:
    """
    Change the input image's orientation into the one specified by `ax`.

    `__call__` does not pass an affine, so the source orientation is derived
    from an identity affine. With the default ``ax="RAS"`` and default labels,
    source and target then coincide and the data should come back unchanged.
    NOTE(review): confirm whether the image's real affine was meant to be used.
    """
    def __init__(self, ax="RAS", labels=tuple(zip("LPI", "RAS"))):
        """
        Args:
            ax: N elements sequence for ND input's orientation.
            labels: optional, None or sequence of (2,) sequences
                (2,) sequences are labels for (beginning, end) of output axis.
        """
        self.ax = ax
        self.labels = labels

    @staticmethod
    def _to_affine_nd(spatial_rank, affine):
        """
        Resize `affine` to ``(spatial_rank + 1) x (spatial_rank + 1)``.

        The leading spatial block and the translation column are copied from
        `affine`; extra dimensions are filled from the identity.

        Raises:
            ValueError: if `affine` is not a square 2-D matrix.
        """
        affine = np.asarray(affine, dtype=np.float64)
        if affine.ndim != 2 or affine.shape[0] != affine.shape[1]:
            raise ValueError("affine must be a square 2-D matrix, got shape {}.".format(affine.shape))
        new_affine = np.eye(spatial_rank + 1, dtype=np.float64)
        d = min(affine.shape[0] - 1, spatial_rank)
        if d > 0:
            new_affine[:d, :d] = affine[:d, :d]
            new_affine[:d, -1] = affine[:d, -1]
        return new_affine

    def operation(self, data, affine=None):
        """
        Reorient `data`, whose original orientation is defined by `affine`.

        Args:
            data: in shape (num_channels, H[, W, ...]).
            affine (matrix): (N+1)x(N+1) original affine matrix for spatially ND `data`. Defaults to identity.

        Returns:
            `data` reoriented to `self.ax`; the channel axis 0 stays in place.

        Raises:
            ValueError: if `data` has no spatial dimension.
        """
        if data.ndim <= 1:
            raise ValueError("data must have at least one spatial dimension.")
        spatial_rank = data.ndim - 1
        if affine is None:
            affine = np.eye(spatial_rank + 1, dtype=np.float64)
        else:
            # previously delegated to an undefined `to_affine_nd` (NameError)
            affine = self._to_affine_nd(spatial_rank, affine)
        src = nib.io_orientation(affine)
        dst = nib.orientations.axcodes2ornt(self.ax[:spatial_rank], labels=self.labels)
        spatial_ornt = nib.orientations.ornt_transform(src, dst)
        ornt = spatial_ornt.copy()
        # shift spatial axis indices past the channel axis, which is kept as-is
        ornt[:, 0] += 1
        ornt = np.concatenate([np.array([[0, 1]]), ornt])
        return nib.orientations.apply_orientation(data, ornt)

    def __call__(self, img, label):
        img_array = self.operation(img)
        seg_array = self.operation(label)
        return img_array, seg_array

class ScaleIntensityRange:
    """
    Linearly map intensities from [src_min, src_max] onto [tgt_min, tgt_max].

    Only the image is rescaled; the label passes through untouched. When
    `is_clip` is set, the rescaled values are clamped to [tgt_min, tgt_max].

    Args:
        src_min: lower bound of the source intensity range.
        src_max: upper bound of the source intensity range.
        tgt_min: lower bound of the target range.
        tgt_max: upper bound of the target range.
        is_clip: clamp the result into the target range.
    """
    def __init__(self, src_min, src_max, tgt_min, tgt_max, is_clip=False):
        self.src_min, self.src_max = src_min, src_max
        self.tgt_min, self.tgt_max = tgt_min, tgt_max
        self.is_clip = is_clip

    def operation(self, data):
        source_span = self.src_max - self.src_min
        if source_span == 0.0:
            # degenerate source range: shift only, never divide
            logger.warning("Divide by zero (src_min == src_max)")
            return data - self.src_min + self.tgt_min
        unit = (data - self.src_min) / source_span
        rescaled = unit * (self.tgt_max - self.tgt_min) + self.tgt_min
        return np.clip(rescaled, self.tgt_min, self.tgt_max) if self.is_clip else rescaled

    def __call__(self, image, label):
        return self.operation(image), label

class RandomCropSamples:
    """
    Draw `num_samples` random spatial crops of size `roi_size` from a
    channel-first image and the matching crops from its label.

    The generator is seeded with 0 at construction, so the sequence of crops is
    reproducible for a freshly built instance.
    """
    def __init__(self, roi_size, num_samples=1):
        self.roi_size = roi_size
        self.num_samples = num_samples
        self.set_random_state(0)

    def set_random_state(self, seed=None):
        """
        (Re)create the random generator used to place the crops.

        Args:
            seed: integer seed (reduced modulo 2**32), or None for an unseeded generator.

        Returns:
            self, to allow chaining.
        """
        seed_space = np.iinfo(np.uint32).max + 1
        self.rand_fn = np.random.RandomState() if seed is None else np.random.RandomState(seed % seed_space)
        return self

    def get_random_patch(self, dims, patch_size, rand_fn=None):
        """
        Return slices selecting a random `patch_size` window inside an array of shape `dims`.

        Axes where the array is not larger than the patch start at 0.
        """
        draw = rand_fn.randint if rand_fn is not None else np.random.randint
        starts = []
        for extent, patch in zip(dims, patch_size):
            starts.append(draw(0, extent - patch + 1) if extent > patch else 0)
        return tuple(slice(start, start + patch) for start, patch in zip(starts, patch_size))

    def get_random_slice(self, img_size):
        # keep every channel; crop only the spatial axes
        return (slice(None), *self.get_random_patch(img_size, self.roi_size, self.rand_fn))

    def __call__(self, image, label):
        windows = [self.get_random_slice(image.shape[1:]) for _ in range(self.num_samples)]
        return np.array([image[w] for w in windows]), np.array([label[w] for w in windows])

class OneHot:
    """
    One-hot encode an integer label volume along a new class axis.
    """
    def __init__(self, num_classes):
        self.num_classes = num_classes

    def one_hot(self, labels):
        """
        Args:
            labels: integer array of shape (N, K) with values in [0, num_classes).

        Returns:
            float32 array of shape (N, num_classes, K) where
            ``out[i, labels[i, j], j] == 1`` and every other entry is 0.
        """
        # Row-indexing an identity matrix gives (N, K, C) in one vectorized step.
        # The previous per-voxel Python double loop cost ~19M interpreter
        # iterations per 4-crop sample at the default ROI. Then move the class
        # axis to position 1.
        encoded = np.eye(self.num_classes, dtype=np.float32)[labels]
        return np.ascontiguousarray(np.transpose(encoded, (0, 2, 1)))

    def operation(self, labels):
        """Encode labels of shape (N, 1, D, H, W) into (N, num_classes, D, H, W)."""
        N, _, D, H, W = labels.shape
        labels = labels.astype(np.int32)
        labels = np.reshape(labels, (N, -1))
        labels = self.one_hot(labels)
        labels = np.reshape(labels, (N, self.num_classes, D, H, W))
        return labels

    def __call__(self, image, label):
        label = self.operation(label)
        return image, label

class ConvertLabel:
    """
    Remap raw segmentation values to contiguous class indices.

    Values in ``[lower_limit, upper_limit]`` become ``1 .. upper_limit - lower_limit + 1``;
    values above ``upper_limit`` or below ``lower_limit`` become 0 (background).
    The image passes through unchanged.

    Args:
        upper_limit: largest raw label value to keep; defaults to ``config.upper_limit``.
        lower_limit: smallest raw label value to keep; defaults to ``config.lower_limit``.
    """
    def __init__(self, upper_limit=None, lower_limit=None):
        self.upper_limit = upper_limit
        self.lower_limit = lower_limit

    def _limits(self):
        # Resolve at call time so the global config is only consulted when needed.
        upper = config.upper_limit if self.upper_limit is None else self.upper_limit
        lower = config.lower_limit if self.lower_limit is None else self.lower_limit
        return upper, lower

    def operation(self, data):
        """
        Return the remapped label array; the input array is not modified.
        """
        upper, lower = self._limits()
        # np.where builds a new array; the old `data[...] = 0` mutated the caller's array.
        data = np.where(data > upper, 0, data)
        data = data - (lower - 1)
        # The old upper clip bound was `lower_limit`, which equals the top class
        # index only for the default limits (5, 3). The top class is upper - lower + 1.
        return np.clip(data, 0, upper - lower + 1)

    def __call__(self, image, label):
        label = self.operation(label)
        return image, label


四、创建Dataloader
设置好数据预处理之后，接下来需要定义一个可迭代的Dataloader用于数据加载，然后送入网络

In [22]:
import glob
import mindspore.dataset as ds
from mindspore.dataset.transforms.transforms import Compose
import nibabel as nib
import os

class Dataset:
    """
    Random-access source pairing each image path with its segmentation path.

    Each item is a pair of one-element lists, matching the two
    ``column_names`` given to ``ds.GeneratorDataset``.
    """
    def __init__(self, data, seg):
        self.data, self.seg = data, seg

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return [self.data[index]], [self.seg[index]]

def create_dataset(data_path, seg_path, rank_size=1, rank_id=0, is_training=True):
    """
    Build the MindSpore pipeline for the (image, seg) NIfTI pairs.

    Image files are located by reusing each ``*.nii.gz`` file name found under
    `seg_path` inside `data_path`.

    Args:
        data_path: directory with the image volumes.
        seg_path: directory with the segmentation volumes.
        rank_size: number of shards (distributed training).
        rank_id: shard index of this process.
        is_training: if True, shuffle and add random crops + one-hot labels;
            otherwise keep whole volumes and batch them one at a time.

    Returns:
        The mapped dataset (batched with size 1 when not training).
    """
    seg_files = sorted(glob.glob(os.path.join(seg_path, "*.nii.gz")))
    image_files = [os.path.join(data_path, os.path.basename(seg_file)) for seg_file in seg_files]
    source = Dataset(data=image_files, seg=seg_files)
    loader = ds.GeneratorDataset(source, column_names=["image", "seg"],
                                 num_parallel_workers=config.num_worker,
                                 shuffle=is_training, num_shards=rank_size, shard_id=rank_id)

    # Transforms shared by training and evaluation.
    pipeline = [LoadData(),
                ExpandChannel(),
                Orientation(),
                ScaleIntensityRange(src_min=config.min_val, src_max=config.max_val,
                                    tgt_min=0.0, tgt_max=1.0, is_clip=True)]
    if is_training:
        pipeline += [RandomCropSamples(roi_size=config.roi_size, num_samples=4),
                     ConvertLabel(),
                     OneHot(num_classes=config.num_classes)]
    else:
        pipeline.append(ConvertLabel())

    loader = loader.map(operations=Compose(pipeline),
                        input_columns=["image", "seg"],
                        num_parallel_workers=config.num_worker,
                        python_multiprocessing=True)
    return loader if is_training else loader.batch(1)

if __name__ == "__main__":
    # Build the shuffled, augmented training pipeline from <data_path>/image and
    # <data_path>/seg. os.path.join works whether or not config.data_path ends
    # with a separator; plain string concatenation silently needed the slash.
    train_dataset = create_dataset(data_path=os.path.join(config.data_path, "image"),
                                   seg_path=os.path.join(config.data_path, "seg"),
                                   is_training=True)
    train_data_size = train_dataset.get_dataset_size()
    print("create dataloader successfully!!")
    print("train dataset length is:", train_data_size)

create dataloader successfully!!
train dataset length is: 877


五、构建Unet3D网络结构
构建UNet3D网络，包括Encoder和Decoder两部分：Encoder有4个下采样层，Decoder有4个上采样层，最终输出与原图大小相同的分割结果。

In [23]:
import mindspore as ms
import mindspore.nn as nn
from mindspore import dtype as mstype
from mindspore.ops import operations as P
from mindspore import Tensor, Model, context
# from src.unet3d_parts import Down, Up
import numpy as np

class BatchNorm3d(nn.Cell):
    """
    3-D batch normalization emulated with ``nn.BatchNorm2d``.

    The 5-D input has its third and fourth axes merged, is normalized by a
    BatchNorm2d over `num_features` channels (axis 1), and is reshaped back to
    its original shape.

    Args:
        num_features: number of channels (axis 1) of the input.
    """
    def __init__(self, num_features):
        super().__init__()
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        self.bn2d = nn.BatchNorm2d(num_features, data_format="NCHW")

    def construct(self, x):
        # assumes x is rank 5, e.g. (N, C, D, H, W); the indexing below needs 5 axes
        x_shape = self.shape(x)
        # merge axes 2 and 3 so BatchNorm2d receives a 4-D (N, C, D*H, W) tensor
        x = self.reshape(x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4]))
        bn2d_out = self.bn2d(x)
        # restore the original 5-D shape
        bn3d_out = self.reshape(bn2d_out, x_shape)
        return bn3d_out

class ResidualUnit(nn.Cell):
    """
    Conv3d -> BatchNorm -> PReLU block with an optional second conv and a residual path.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        stride: stride of the first conv, and of the residual conv when `down` is True.
        kernel_size: only used to pick the residual conv. A first element of 1 selects
            a 1x1x1 "valid" conv; anything else selects a 3x3x3 padded conv. The main
            convs are always 3x3x3.
        down: if True, add a second 3x3x3 conv + BatchNorm + PReLU and a learned conv
            on the residual path. If False, the input itself is added as the residual,
            so `down_conv_1` must preserve the input shape (the only caller in this
            file, ``Up``, uses stride=1 and in_channel == out_channel).
        is_output: if True, `construct` returns the raw output of `down_conv_1`
            (no normalization, activation or residual), and BatchNorm/PReLU are not built.
    """
    def __init__(self, in_channel, out_channel, stride=2, kernel_size=(3, 3, 3), down=True, is_output=False):
        super().__init__()
        self.stride = stride
        self.down = down
        self.in_channel = in_channel
        self.out_channel = out_channel
        # main conv: always 3x3x3 with padding 1, regardless of `kernel_size`
        self.down_conv_1 = nn.Conv3d(in_channel, out_channel, kernel_size=(3, 3, 3), \
                                     pad_mode="pad", stride=self.stride, padding=1)
        self.is_output = is_output
        if not is_output:
            self.batchNormal1 = BatchNorm3d(num_features=self.out_channel)
            self.relu1 = nn.PReLU()
        if self.down:
            self.down_conv_2 = nn.Conv3d(out_channel, out_channel, kernel_size=(3, 3, 3), \
                                         pad_mode="pad", stride=1, padding=1)
            self.relu2 = nn.PReLU()
            # residual projection; same stride as down_conv_1 so both paths match in shape
            if kernel_size[0] == 1:
                self.residual = nn.Conv3d(in_channel, out_channel, kernel_size=(1, 1, 1), \
                                          pad_mode="valid", stride=self.stride)
            else:
                self.residual = nn.Conv3d(in_channel, out_channel, kernel_size=(3, 3, 3), \
                                          pad_mode="pad", stride=self.stride, padding=1)
            self.batchNormal2 = BatchNorm3d(num_features=self.out_channel)


    def construct(self, x):
        out = self.down_conv_1(x)
        # output head: return raw conv logits
        if self.is_output:
            return out
        out = self.batchNormal1(out)
        out = self.relu1(out)
        if self.down:
            out = self.down_conv_2(out)
            out = self.batchNormal2(out)
            out = self.relu2(out)
            res = self.residual(x)
        else:
            # identity shortcut
            res = x
        return out + res

class Down(nn.Cell):
    """
    Encoder stage: a single ``ResidualUnit`` (with `down=True`) cast to `dtype`.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        stride: stride of the unit's first conv and residual conv (default 2).
        kernel_size: forwarded to ResidualUnit, where it selects the residual conv type.
        dtype: compute dtype applied with ``to_float``.
    """
    def __init__(self, in_channel, out_channel, stride=2, kernel_size=(3, 3, 3), dtype=mstype.float16):
        super().__init__()
        self.stride = stride
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.down_conv = ResidualUnit(self.in_channel, self.out_channel, stride, kernel_size).to_float(dtype)

    def construct(self, x):
        x = self.down_conv(x)
        return x


class Up(nn.Cell):
    """
    Decoder stage: concatenate the skip tensor, upsample, normalize, then refine.

    Note that the skip concatenation happens *before* the transposed conv, so both
    inputs are combined at the coarser resolution.

    Args:
        in_channel: channels of the tensor coming from the previous (deeper) stage.
        down_in_channel: channels of the skip tensor from the encoder.
        out_channel: channels produced by this stage.
        stride: stride of the transposed conv (the upsampling factor).
        is_output: forwarded to the trailing ResidualUnit; if True it returns the raw
            conv output (used for the final logits).
        dtype: compute dtype applied with ``to_float`` to the trailing ResidualUnit.
    """
    def __init__(self, in_channel, down_in_channel, out_channel, stride=2, is_output=False, dtype=mstype.float16):
        super().__init__()
        self.in_channel = in_channel
        self.down_in_channel = down_in_channel
        self.out_channel = out_channel
        self.stride = stride
        # input channels = previous stage + skip, because of the concat in construct
        self.conv3d_transpose = nn.Conv3dTranspose(in_channels=self.in_channel + self.down_in_channel, \
                                                   out_channels=self.out_channel, kernel_size=(3, 3, 3), \
                                                   pad_mode="pad", stride=self.stride, \
                                                   output_padding=(1, 1, 1), padding=1)

        self.concat = P.Concat(axis=1)
        # stride 1, no down path: identity residual, so in/out channels must match
        self.conv = ResidualUnit(self.out_channel, self.out_channel, stride=1, down=False, \
                                 is_output=is_output).to_float(dtype)
        self.batchNormal1 = BatchNorm3d(num_features=self.out_channel)
        self.relu = nn.PReLU()

    def construct(self, input_data, down_input):
        # concat on the channel axis; assumes both inputs have the same spatial shape
        x = self.concat((input_data, down_input))
        # upsample by `stride` and project to out_channel channels
        x = self.conv3d_transpose(x)
        x = self.batchNormal1(x)
        x = self.relu(x)
        x = self.conv(x)
        return x

class UNet3d_(nn.Cell):
    """
    Float32 3D U-Net (every Down/Up block is built with ``dtype=mstype.float32``).

    The original description says it supports fp32 and fp16 (AMP) training on GPU;
    presumably fp16 would come from external AMP settings (TODO confirm).
    Encoder widths are 16-32-64-128-256 channels from ``config.in_channels`` inputs;
    the last Up block emits ``config.num_classes`` channels.
    """
    def __init__(self):
        super(UNet3d_, self).__init__()
        self.n_channels = config.in_channels
        self.n_classes = config.num_classes

        # down
        # down1-down4 use the default stride 2; down5 keeps the resolution (stride 1)
        # and uses a 1x1x1 residual conv
        self.down1 = Down(in_channel=self.n_channels, out_channel=16, dtype=mstype.float32)
        self.down2 = Down(in_channel=16, out_channel=32, dtype=mstype.float32)
        self.down3 = Down(in_channel=32, out_channel=64, dtype=mstype.float32)
        self.down4 = Down(in_channel=64, out_channel=128, dtype=mstype.float32)
        self.down5 = Down(in_channel=128, out_channel=256, stride=1, kernel_size=(1, 1, 1), \
                          dtype=mstype.float32)
        # up
        # each Up takes the previous output plus the matching encoder skip tensor
        self.up1 = Up(in_channel=256, down_in_channel=128, out_channel=64, \
                      dtype=mstype.float32)
        self.up2 = Up(in_channel=64, down_in_channel=64, out_channel=32, \
                      dtype=mstype.float32)
        self.up3 = Up(in_channel=32, down_in_channel=32, out_channel=16, \
                      dtype=mstype.float32)
        self.up4 = Up(in_channel=16, down_in_channel=16, out_channel=self.n_classes, \
                      dtype=mstype.float32, is_output=True)

    def construct(self, input_data):
        x1 = self.down1(input_data)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x5 = self.down5(x4)

        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        # final stage returns raw per-class logits (is_output=True)
        x = self.up4(x, x1)
        return x

class UNet3d(nn.Cell):
    """
    Float16 variant of the 3D U-Net, with the same layer layout as ``UNet3d_``.

    Every sub-block is cast to float16 via ``to_float``. The input is cast to
    float16 on entry and the logits are cast back to float32 before returning.
    """
    def __init__(self):
        super(UNet3d, self).__init__()
        self.n_channels = config.in_channels
        self.n_classes = config.num_classes

        # down
        # NOTE(review): self.transpose is created but not used in construct
        self.transpose = P.Transpose()
        self.down1 = Down(in_channel=self.n_channels, out_channel=16, dtype=mstype.float16).to_float(mstype.float16)
        self.down2 = Down(in_channel=16, out_channel=32, dtype=mstype.float16).to_float(mstype.float16)
        self.down3 = Down(in_channel=32, out_channel=64, dtype=mstype.float16).to_float(mstype.float16)
        self.down4 = Down(in_channel=64, out_channel=128, dtype=mstype.float16).to_float(mstype.float16)
        self.down5 = Down(in_channel=128, out_channel=256, stride=1, kernel_size=(1, 1, 1), \
                          dtype=mstype.float16).to_float(mstype.float16)
        # up
        self.up1 = Up(in_channel=256, down_in_channel=128, out_channel=64, \
                      dtype=mstype.float16).to_float(mstype.float16)
        self.up2 = Up(in_channel=64, down_in_channel=64, out_channel=32, \
                      dtype=mstype.float16).to_float(mstype.float16)
        self.up3 = Up(in_channel=32, down_in_channel=32, out_channel=16, \
                      dtype=mstype.float16).to_float(mstype.float16)
        self.up4 = Up(in_channel=16, down_in_channel=16, out_channel=self.n_classes, \
                      dtype=mstype.float16, is_output=True).to_float(mstype.float16)

        self.cast = P.Cast()

    def construct(self, input_data):
        # compute in float16
        input_data = self.cast(input_data, mstype.float16)
        x1 = self.down1(input_data)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x5 = self.down5(x4)

        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        # return float32 logits to the caller (e.g. the loss)
        x = self.cast(x, mstype.float32)
        return x

if __name__=="__main__":
    # Smoke test: build the float16 UNet3d in graph mode on `config.device_target`
    # and run one forward pass on an all-zero (1, 1, 224, 224, 96) volume.
    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)
    x = ms.Tensor(np.zeros([1, 1, 224, 224, 96]), ms.float32)
    model = UNet3d().set_train(True)
    out = model(x)

    # print the network structure and the input / output shapes
    print(model)
    print("UNet3D model input:", x.shape)
    print("UNet3D model output:", out.shape)


UNet3d<
  (down1): Down<
    (down_conv): ResidualUnit<
      (down_conv_1): Conv3d<input_channels=1, output_channels=16, kernel_size=(3, 3, 3), stride=(2, 2, 2), pad_mode=pad, padding=1, dilation=(1, 1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCDHW>
      (batchNormal1): BatchNorm3d<
        (bn2d): BatchNorm2d<num_features=16, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=down1.down_conv.batchNormal1.bn2d.gamma, shape=(16,), dtype=Float32, requires_grad=True), beta=Parameter (name=down1.down_conv.batchNormal1.bn2d.beta, shape=(16,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=down1.down_conv.batchNormal1.bn2d.moving_mean, shape=(16,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=down1.down_conv.batchNormal1.bn2d.moving_variance, shape=(16,), dtype=Float32, requires_grad=False)>
        >
      (relu1): PReLU<>
      (down_conv_2): Conv3d<input_channels=16, output_channels=16, kernel_size=(3,

六、自定义Metrics
在医学图像分割领域，通过Dice coefficient、Jaccard coefficient（JC）、Hausdorff distance 95（HD95）、Average surface distance（ASD）、Average symmetric surface distance metric（ASSD）、sensitivity（Sens）等量化指标来衡量分割效果的好坏。

In [24]:
from medpy.metric import binary

class metrics:
    """
    Thin wrappers around ``medpy.metric.binary`` for binary segmentation masks.

    Every method takes a predicted mask `y_pred` and a ground-truth mask `y_label`.

    Args:
        smooth: smoothing constant stored on the instance (the medpy wrappers
            below do not use it). Previously ignored: ``self.smooth`` was always 1e-5.
    """
    def __init__(self, smooth=1e-5):
        self.smooth = smooth

    def dice_metric(self, y_pred, y_label, empty_score=1.0):
        """Dice coefficient. `empty_score` is accepted for compatibility but unused."""
        return binary.dc(y_pred, y_label)

    def jc_metric(self, y_pred, y_label):
        """Jaccard coefficient."""
        return binary.jc(y_pred, y_label)

    def hd95_metric(self, y_pred, y_label):
        """95th-percentile Hausdorff distance, with unit voxel spacing."""
        return binary.hd95(y_pred, y_label, voxelspacing=1)

    def asd_metric(self, y_pred, y_label):
        """Average surface distance metric."""
        return binary.asd(y_pred, y_label, voxelspacing=None)

    def assd_metric(self, y_pred, y_label):
        """Average symmetric surface distance metric."""
        return binary.assd(y_pred, y_label, voxelspacing=None)

    def precision_metric(self, y_pred, y_label):
        """Precision (positive predictive value)."""
        # medpy's binary.precision(result, reference) has no voxelspacing
        # parameter; passing one raised TypeError.
        return binary.precision(y_pred, y_label)

    def sensitivity_metric(self, y_pred, y_label, smooth=1e-5):
        """Recall (also sensitivity). `smooth` is accepted for compatibility but unused."""
        return binary.recall(y_pred, y_label)

七、设置学习率策略
学习率的设置对网络的训练至关重要。这里我们使用两阶段的学习率策略：前 warmup_step 个 step（默认120）进行warm up，学习率从 lr×warmup_ratio 线性上升到 lr；之后的 step 使用cosine下降学习率策略。

In [25]:
# dyanmic learning rate
import math

def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
    """
    Learning rate at `current_step` of a linear ramp from `init_lr` to `base_lr`.

    The ramp reaches `base_lr` at step `warmup_steps`.
    """
    start = float(init_lr)
    per_step = (float(base_lr) - start) / float(warmup_steps)
    return start + per_step * current_step

def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
    """
    Cosine-annealed learning rate: `base_lr` at `warmup_steps`, following
    ``(1 + cos(pi * progress)) / 2`` with ``progress = (step - warmup) / decay_steps``.
    """
    progress = float(current_step - warmup_steps) / float(decay_steps)
    cosine_factor = (1 + math.cos(progress * math.pi)) / 2
    return cosine_factor * base_lr

def dynamic_lr(config, base_step):
    """
    Build the per-step learning-rate list for the whole run.

    The first ``config.warmup_step`` steps ramp linearly from
    ``config.lr * config.warmup_ratio`` to ``config.lr``. The remaining steps
    follow a cosine decay measured over all ``base_step * config.epoch_size`` steps.

    Args:
        config: object exposing ``lr``, ``epoch_size``, ``warmup_step``, ``warmup_ratio``.
        base_step: number of steps per epoch.

    Returns:
        list of float, one learning rate per training step.
    """
    base_lr = config.lr
    total_steps = int(base_step * config.epoch_size)
    warmup_steps = config.warmup_step
    warmup_start = base_lr * config.warmup_ratio
    return [
        linear_warmup_learning_rate(step, warmup_steps, base_lr, warmup_start)
        if step < warmup_steps
        else a_cosine_learning_rate(step, base_lr, warmup_steps, total_steps)
        for step in range(total_steps)
    ]

In [26]:
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Moxing adapter for ModelArts"""

import os
import functools
from mindspore import context
# from src.model_utils.config import config

_global_sync_count = 0

def get_device_id():
    """Return the device index from the DEVICE_ID environment variable (default 0)."""
    return int(os.environ.get('DEVICE_ID', '0'))


def get_device_num():
    """Return the number of devices from the RANK_SIZE environment variable (default 1)."""
    return int(os.environ.get('RANK_SIZE', '1'))


def get_rank_id():
    """Return the global rank from the RANK_ID environment variable (default 0)."""
    return int(os.environ.get('RANK_ID', '0'))


def get_job_id():
    """Return the JOB_ID environment variable, or "default" when it is unset or empty."""
    # os.getenv returns None for an unset variable; the old `!= ""` test let that
    # None through instead of falling back to "default".
    return os.getenv('JOB_ID') or "default"

def sync_data(from_path, to_path):
    """
    Download data from remote obs to local directory if the first url is remote url and the second one is local path
    Upload data from local directory to remote obs in contrast.

    Only the first device of each server performs the copy and then creates a
    lock file. Every caller (including the copying one) waits until that lock
    file exists.

    Raises:
        OSError: if the lock file cannot be created for a reason other than it
            already existing (previously swallowed, which left all processes
            waiting forever).
    """
    import moxing as mox
    import time
    global _global_sync_count
    sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
    _global_sync_count += 1

    # Each server contains 8 devices as most.
    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
        print("from path: ", from_path)
        print("to path: ", to_path)
        mox.file.copy_parallel(from_path, to_path)
        print("===finish data synchronization===")
        # Create the lock portably; os.mknod needs extra privileges on some
        # platforms. An existing lock is fine, any other failure must surface.
        try:
            with open(sync_lock, "x"):
                pass
        except FileExistsError:
            pass
        print("===save flag===")

    while not os.path.exists(sync_lock):
        time.sleep(1)

    print("Finish sync data from {} to {}.".format(from_path, to_path))


def moxing_wrapper(pre_process=None, post_process=None):
    """
    Moxing wrapper to download dataset and upload outputs.

    When ``config.enable_modelarts`` is truthy, OBS synchronisation brackets the
    decorated function. Before the call, ``data_url``, ``checkpoint_url`` and
    ``train_url`` (each only if set) are copied to ``data_path``, ``load_path``
    and ``output_path``. After the call, ``output_path`` is copied back to
    ``train_url``. Otherwise the decorated function is simply called.

    Args:
        pre_process (callable, optional): ModelArts only; called after the
            downloads and before the decorated function.
        post_process (callable, optional): ModelArts only; called after the
            decorated function and before the upload.

    Returns:
        callable: a decorator. The decorated function returns whatever the
        original function returns.
    """
    def wrapper(run_func):
        @functools.wraps(run_func)
        def wrapped_func(*args, **kwargs):
            # Download data from data_url
            if config.enable_modelarts:
                if config.data_url:
                    sync_data(config.data_url, config.data_path)
                    print("Dataset downloaded: ", os.listdir(config.data_path))
                if config.checkpoint_url:
                    sync_data(config.checkpoint_url, config.load_path)
                    print("Preload downloaded: ", os.listdir(config.load_path))
                if config.train_url:
                    sync_data(config.train_url, config.output_path)
                    print("Workspace downloaded: ", os.listdir(config.output_path))

                context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
                config.device_num = get_device_num()
                config.device_id = get_device_id()
                if not os.path.exists(config.output_path):
                    os.makedirs(config.output_path)

                if pre_process:
                    pre_process()

            # Keep the wrapped function's result so the decorator is transparent.
            result = run_func(*args, **kwargs)

            # Upload data to train_url
            if config.enable_modelarts:
                if post_process:
                    post_process()

                if config.train_url:
                    print("Start to copy output directory")
                    sync_data(config.output_path, config.train_url)
            return result
        return wrapped_func
    return wrapper


八、主函数训练
主函数训练过程主要包括以下几步：

选择运行设备GPU或者Ascend；
调用create_dataset函数，创建dataloader；
调用Unet3D函数，构建网络；
定义损失函数，这里我们使用常见的dice loss和交叉熵损失（cross entropy loss），两者之和作为总损失；
调用学习率函数，设置优化器；
设置网络为训练模式；
使用for循环，不断将数据送入网络进行训练，并在每个epoch结束后保存模型checkpoint。

In [32]:
import os
import mindspore
import mindspore.nn as nn
from mindspore import ops
from mindspore import SummaryRecord
import mindspore.common.dtype as mstype
from mindspore import Tensor, Model, context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.loss_scale_manager import FixedLossScaleManager

# Configure the MindSpore execution context. Both device targets use the same
# graph-mode settings, so a single set_context call covers them.
if config.device_target == 'Ascend':
    # get_device_id() falls back to 0 when DEVICE_ID is unset, instead of
    # failing with int(None).
    device_id = get_device_id()
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)

mindspore.set_seed(1)

@moxing_wrapper()
def train_net(run_distribute=False):
    """
    Train the 3D UNet configured by ``config`` and save one checkpoint per epoch.

    The loss is the sum of ``nn.DiceLoss`` and ``nn.CrossEntropyLoss``, minimised
    with Adam under the ``dynamic_lr`` schedule. The learning rate and the three
    loss values are written to a MindSpore summary file at every step.

    Args:
        run_distribute (bool): if True, initialise collective communication and
            train in data-parallel mode. Defaults to False (single device).
    """
    if run_distribute:
        init()
        if config.device_target == 'Ascend':
            rank_id = get_device_id()
            rank_size = get_device_num()
        else:
            rank_id = get_rank()
            rank_size = get_group_size()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                          device_num=rank_size,
                                          gradients_mean=True)
    else:
        rank_id = 0
        rank_size = 1
    # (1) create dataloader
    train_dataset = create_dataset(data_path=config.data_path + "/image/",
                                   seg_path=config.data_path + "/seg/",
                                   rank_size=rank_size,
                                   rank_id=rank_id, is_training=True)
    train_data_size = train_dataset.get_dataset_size()
    print("train dataset length is:", train_data_size)
    # (2) construct network
    if config.device_target == 'Ascend':
        network = UNet3d()
    else:
        network = UNet3d_()

    # (3) define loss function
    loss_ce_fn = nn.CrossEntropyLoss()
    loss_dice_fn = nn.DiceLoss(smooth=1e-5)
    # (4) lr schedule and optimizer
    lr = Tensor(dynamic_lr(config, train_data_size), mstype.float32)
    optimizer = nn.Adam(params=network.trainable_params(), learning_rate=lr)
    # (5) set training mode
    network.set_train()

    def forward_fn(data, label):
        logits = network(data)
        loss_ce = loss_ce_fn(logits, label)
        loss_dice = loss_dice_fn(logits, label)
        loss = loss_dice + loss_ce
        return loss, loss_dice, loss_ce, logits

    # With has_aux=True only the first output (the total loss) is differentiated;
    # the other outputs are passed through unchanged.
    grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

    def train_step(data, label):
        (loss, loss_dice, loss_ce, _), grads = grad_fn(data, label)
        # Tie the returned loss to the optimizer update so the update is executed.
        loss = ops.depend(loss, optimizer(grads))
        return loss, loss_dice, loss_ce

    # Loop-invariant values, computed once.
    total_steps = train_data_size * config.max_epoch
    ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path)
    os.makedirs(ckpt_save_dir, exist_ok=True)

    # (6) Start training
    print("============== Starting Training ==============")
    with SummaryRecord('./summary_dir/summary_01') as summary_record:
        for epoch in range(config.max_epoch):
            for step, (data, label) in enumerate(train_dataset.create_tuple_iterator()):
                loss, loss_dice, loss_ce = train_step(data, label)

                current_step = epoch * train_data_size + step
                current_lr = optimizer.get_lr()
                # Stage this step's scalars first: record() writes what add_value
                # has staged. Calling record() first would attribute every value
                # to the following step and drop the last step's values.
                summary_record.add_value('scalar', 'lr', current_lr)
                summary_record.add_value('scalar', 'loss_total', loss)
                summary_record.add_value('scalar', 'loss_dice', loss_dice)
                summary_record.add_value('scalar', 'loss_ce', loss_ce)
                summary_record.record(current_step)

                loss, loss_dice, loss_ce = loss.asnumpy(), loss_dice.asnumpy(), loss_ce.asnumpy()
                print("Epoch: %d [%d/%d] [%d/%d] lr:%.7f Loss: %.4f Loss_dice: %.4f Loss_ce: %.4f" %
                      (epoch, step, train_data_size, current_step, total_steps,
                       current_lr, loss, loss_dice, loss_ce))
            # Save a checkpoint at the end of every epoch.
            ckpt_file = os.path.join(ckpt_save_dir, "Epoch_" + str(epoch) + "_model.ckpt")
            mindspore.save_checkpoint(network, ckpt_file)
            print("Saved Model to {}/Epoch_{}_model.ckpt".format(ckpt_save_dir, epoch))

    print("============== End Training ==============")

if __name__ == '__main__':
    # Single-device training; pass run_distribute=True for data-parallel runs.
    train_net(run_distribute=False)



train dataset length is: 877




Epoch: 0 [0/877] [0/2631] lr:0.0001529 Loss: 2.4361 Loss_dice: 1.0371 Loss_ce: 1.3990
Epoch: 0 [1/877] [1/2631] lr:0.0001588 Loss: 2.4083 Loss_dice: 1.0172 Loss_ce: 1.3911
Epoch: 0 [2/877] [2/2631] lr:0.0001646 Loss: 2.3624 Loss_dice: 0.9854 Loss_ce: 1.3770
Epoch: 0 [3/877] [3/2631] lr:0.0001704 Loss: 2.3289 Loss_dice: 0.9619 Loss_ce: 1.3670
Epoch: 0 [4/877] [4/2631] lr:0.0001763 Loss: 2.3446 Loss_dice: 0.9709 Loss_ce: 1.3736
Epoch: 0 [5/877] [5/2631] lr:0.0001821 Loss: 2.3213 Loss_dice: 0.9547 Loss_ce: 1.3667
Epoch: 0 [6/877] [6/2631] lr:0.0001879 Loss: 2.2483 Loss_dice: 0.9056 Loss_ce: 1.3427
Epoch: 0 [7/877] [7/2631] lr:0.0001938 Loss: 2.2662 Loss_dice: 0.9167 Loss_ce: 1.3495
Epoch: 0 [8/877] [8/2631] lr:0.0001996 Loss: 2.2343 Loss_dice: 0.8950 Loss_ce: 1.3392
Epoch: 0 [9/877] [9/2631] lr:0.0002054 Loss: 2.1536 Loss_dice: 0.8417 Loss_ce: 1.3119
Epoch: 0 [10/877] [10/2631] lr:0.0002112 Loss: 2.1069 Loss_dice: 0.8109 Loss_ce: 1.2960
Epoch: 0 [11/877] [11/2631] lr:0.0002171 Loss: 2.087



Saved Model to output/./checkpoint//Epoch_0_model.ckpt




Epoch: 1 [0/877] [877/2631] lr:0.0004583 Loss: 0.7121 Loss_dice: 0.3131 Loss_ce: 0.3991
Epoch: 1 [1/877] [878/2631] lr:0.0004582 Loss: 0.6560 Loss_dice: 0.2901 Loss_ce: 0.3659
Epoch: 1 [2/877] [879/2631] lr:0.0004581 Loss: 0.6204 Loss_dice: 0.2755 Loss_ce: 0.3449
Epoch: 1 [3/877] [880/2631] lr:0.0004580 Loss: 0.6449 Loss_dice: 0.2739 Loss_ce: 0.3710
Epoch: 1 [4/877] [881/2631] lr:0.0004579 Loss: 0.7549 Loss_dice: 0.2869 Loss_ce: 0.4680
Epoch: 1 [5/877] [882/2631] lr:0.0004578 Loss: 0.7031 Loss_dice: 0.2751 Loss_ce: 0.4280
Epoch: 1 [6/877] [883/2631] lr:0.0004577 Loss: 0.6264 Loss_dice: 0.2454 Loss_ce: 0.3810
Epoch: 1 [7/877] [884/2631] lr:0.0004576 Loss: 0.6576 Loss_dice: 0.2445 Loss_ce: 0.4131
Epoch: 1 [8/877] [885/2631] lr:0.0004575 Loss: 0.6069 Loss_dice: 0.2354 Loss_ce: 0.3715
Epoch: 1 [9/877] [886/2631] lr:0.0004574 Loss: 0.7210 Loss_dice: 0.2661 Loss_ce: 0.4549
Epoch: 1 [10/877] [887/2631] lr:0.0004573 Loss: 0.6224 Loss_dice: 0.2435 Loss_ce: 0.3789
Epoch: 1 [11/877] [888/2631] lr



Saved Model to output/./checkpoint//Epoch_1_model.ckpt




Epoch: 2 [0/877] [1754/2631] lr:0.0003373 Loss: 0.5599 Loss_dice: 0.2976 Loss_ce: 0.2623
Epoch: 2 [1/877] [1755/2631] lr:0.0003371 Loss: 0.5535 Loss_dice: 0.2873 Loss_ce: 0.2662
Epoch: 2 [2/877] [1756/2631] lr:0.0003370 Loss: 0.5350 Loss_dice: 0.2869 Loss_ce: 0.2481
Epoch: 2 [3/877] [1757/2631] lr:0.0003368 Loss: 0.5481 Loss_dice: 0.2872 Loss_ce: 0.2609
Epoch: 2 [4/877] [1758/2631] lr:0.0003366 Loss: 0.5650 Loss_dice: 0.2609 Loss_ce: 0.3041
Epoch: 2 [5/877] [1759/2631] lr:0.0003365 Loss: 0.5511 Loss_dice: 0.2890 Loss_ce: 0.2621
Epoch: 2 [6/877] [1760/2631] lr:0.0003363 Loss: 0.5289 Loss_dice: 0.2739 Loss_ce: 0.2550
Epoch: 2 [7/877] [1761/2631] lr:0.0003361 Loss: 0.5699 Loss_dice: 0.2927 Loss_ce: 0.2772
Epoch: 2 [8/877] [1762/2631] lr:0.0003360 Loss: 0.5328 Loss_dice: 0.2820 Loss_ce: 0.2508
Epoch: 2 [9/877] [1763/2631] lr:0.0003358 Loss: 0.5601 Loss_dice: 0.2947 Loss_ce: 0.2654
Epoch: 2 [10/877] [1764/2631] lr:0.0003356 Loss: 0.5359 Loss_dice: 0.2844 Loss_ce: 0.2515
Epoch: 2 [11/877] [1

In [33]:
import math
import numpy as np

def first(iterable, default=None):
    """
    Return the first element produced by ``iterable``, or ``default`` if it yields nothing.

    Handy with generator expressions, e.g. ``first(d for d in ds if pred(d))``.
    """
    return next(iter(iterable), default)

def _get_scan_interval(image_size, roi_size, num_image_dims, overlap):
    """
    Compute scan interval according to the image size, roi size and overlap.
    Scan interval will be `int((1 - overlap) * roi_size)`, if interval is 0,
    use 1 instead to make sure sliding window works.
    """
    if len(image_size) != num_image_dims:
        raise ValueError("image different from spatial dims.")
    if len(roi_size) != num_image_dims:
        raise ValueError("roi size different from spatial dims.")

    scan_interval = []
    for i in range(num_image_dims):
        if roi_size[i] == image_size[i]:
            scan_interval.append(int(roi_size[i]))
        else:
            interval = int(roi_size[i] * (1 - overlap))
            scan_interval.append(interval if interval > 0 else 1)
    return tuple(scan_interval)

def dense_patch_slices(image_size, patch_size, scan_interval):
    """
    Enumerate all slices defining ND patches of size `patch_size` from an `image_size` input image.

    Along each spatial axis, patch starts are placed every ``scan_interval``
    voxels. Once a patch would cross the image border, its start is shifted back
    so that the patch ends exactly at the border. The two leading axes of the
    indexed array are always taken whole.

    Args:
        image_size: dimensions of image to iterate over
        patch_size: size of patches to generate slices
        scan_interval: dense patch sampling interval

    Returns:
        a list of tuples ``(slice(None), slice(None), slice(s0, s0 + p0), ...)``,
        one per patch, with the last spatial axis varying fastest.
    """
    num_spatial_dims = len(image_size)
    scan_num = []
    for dim in range(num_spatial_dims):
        if scan_interval[dim] == 0:
            scan_num.append(1)
            continue
        num = int(math.ceil(float(image_size[dim]) / scan_interval[dim]))
        # Index of the first step whose patch reaches the image border.
        scan_dim = next((d for d in range(num)
                         if d * scan_interval[dim] + patch_size[dim] >= image_size[dim]), None)
        scan_num.append(scan_dim + 1 if scan_dim is not None else 1)

    starts = []
    for dim in range(num_spatial_dims):
        dim_starts = []
        for idx in range(scan_num[dim]):
            start_idx = idx * scan_interval[dim]
            # Pull the last patch back inside the image.
            start_idx -= max(start_idx + patch_size[dim] - image_size[dim], 0)
            dim_starts.append(start_idx)
        starts.append(dim_starts)

    # Cartesian product of per-axis starts, one row per patch.
    out = np.asarray([x.flatten() for x in np.meshgrid(*starts, indexing="ij")]).T
    return [(slice(None),) * 2 + tuple(slice(int(s), int(s) + patch_size[d]) for d, s in enumerate(x))
            for x in out]

def create_sliding_window(image, roi_size, overlap):
    """
    Cut ``image`` into ROI-sized windows for sliding-window inference.

    The first two axes of ``image`` are not windowed; the remaining axes are
    tiled with ``dense_patch_slices`` using steps from ``_get_scan_interval``.

    Args:
        image: array whose trailing axes are spatial (the caller passes the
            (batch, channel, ...) numpy arrays from the eval iterator).
        roi_size: window size for each spatial axis.
        overlap (float): fraction of overlap between neighbouring windows.

    Returns:
        tuple: (list of windowed sub-arrays, list of the index tuples used).

    Raises:
        AssertionError: if ``overlap`` is not in [0, 1).
    """
    spatial_dims = len(image.shape) - 2
    if overlap < 0 or overlap >= 1:
        raise AssertionError("overlap must be >= 0 and < 1.")
    spatial_shape = list(image.shape[2:])
    # Treat each axis as at least roi_size long. No padding is performed, so
    # NOTE(review): windows may be smaller than roi_size where the image is.
    effective_size = tuple(max(spatial_shape[axis], roi_size[axis]) for axis in range(spatial_dims))

    steps = _get_scan_interval(effective_size, roi_size, spatial_dims, overlap)
    window_slices = dense_patch_slices(effective_size, roi_size, steps)
    windows = [image[window_slice] for window_slice in window_slices]
    return windows, window_slices

def CalculateDice(y_pred, label):
    """
    Score a prediction against ground truth with several segmentation metrics.

    The prediction is reduced to class labels by argmax over axis 1. Both the
    prediction and ``label`` are then one-hot encoded with ``one_hot``, and the
    background channel is removed with ``ignore_background``.

    Args:
        y_pred: per-class scores with the class axis at axis 1.
        label: ground-truth labels, batch first. NOTE(review): ``one_hot``
            reshapes to (N, num_classes, D, H, W), so a single channel is
            presumably expected.

    Returns:
        tuple: (dice, hd95, jc, asd, sens) from the corresponding ``metrics`` methods.
    """
    scorer = metrics()
    predicted_labels = np.argmax(y_pred, axis=1)[:, np.newaxis]
    pred_onehot, label_onehot = ignore_background(one_hot(predicted_labels), one_hot(label))

    return (
        scorer.dice_metric(pred_onehot, label_onehot),
        scorer.hd95_metric(pred_onehot, label_onehot),
        scorer.jc_metric(pred_onehot, label_onehot),
        scorer.asd_metric(pred_onehot, label_onehot),
        scorer.sensitivity_metric(pred_onehot, label_onehot),
    )

def ignore_background(y_pred, label):
    """
    Remove the background channel (index 0 of axis 1) from both arrays.

    An array with a single channel on axis 1 is returned unchanged.

    Args:
        y_pred: prediction array, batch first, channels on axis 1.
        label: ground-truth array, batch first, channels on axis 1.

    Returns:
        tuple: (y_pred, label) with the background channel removed.
    """
    def _strip(arr):
        return arr if arr.shape[1] <= 1 else arr[:, 1:]

    return _strip(y_pred), _strip(label)

def one_hot(labels, num_classes=None):
    """
    One-hot encode an (N, 1, D, H, W) label volume along axis 1.

    Args:
        labels: array of class indices with shape (N, C, D, H, W). Values are
            cast to int32. The result is reshaped to (N, num_classes, D, H, W),
            so C must be 1.
        num_classes (int, optional): number of classes. Defaults to
            ``config.num_classes``.

    Returns:
        np.ndarray: float32 array of shape (N, num_classes, D, H, W).

    Raises:
        IndexError: if a label is outside [-num_classes, num_classes).
    """
    if num_classes is None:
        num_classes = config.num_classes
    n, _, d, h, w = labels.shape
    flat = np.reshape(labels, (n, -1)).astype(np.int32)
    k = flat.shape[1]
    encoding = np.zeros((n, num_classes, k), dtype=np.float32)
    # Vectorised scatter: encoding[i, flat[i, j], j] = 1 for every i, j. This
    # replaces a per-voxel Python loop, which is extremely slow on CT volumes.
    np.put_along_axis(encoding, flat[:, np.newaxis, :], 1.0, axis=1)
    return np.reshape(encoding, (n, num_classes, d, h, w))

九、模型预测
设置好测试数据集路径和加载模型路径，就可以开始测试啦！

In [37]:
import os
import numpy as np
from mindspore import dtype as mstype
from mindspore import Model, context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net


# Configure the MindSpore execution context. Both device targets use the same
# graph-mode settings, so a single set_context call covers them.
if config.device_target == 'Ascend':
    # get_device_id() falls back to 0 when DEVICE_ID is unset, instead of
    # failing with int(None).
    device_id = get_device_id()
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, save_graphs=False)

# @moxing_wrapper()
def test_net(data_path, ckpt_path):
    """
    Evaluate a checkpoint on a dataset with sliding-window inference.

    Each volume is split into ``config.roi_size`` windows overlapping by
    ``config.overlap``. Window predictions are averaged where windows overlap
    and scored with ``CalculateDice``. Per-volume and average metrics are printed.

    Args:
        data_path (str): dataset root containing ``image/`` and ``seg/`` folders.
        ckpt_path (str): path of the MindSpore checkpoint to load.

    Returns:
        dict: average "dice", "hd95", "jc", "asd" and "sens" over the evaluated
        batches (all 0.0 if the dataset yielded nothing).
    """
    data_dir = data_path + "/image/"
    seg_dir = data_path + "/seg/"
    eval_dataset = create_dataset(data_path=data_dir, seg_path=seg_dir, is_training=False)
    eval_data_size = eval_dataset.get_dataset_size()
    print("test dataset length is:", eval_data_size)

    if config.device_target == 'Ascend':
        network = UNet3d()
    else:
        network = UNet3d_()
    network.set_train(False)
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(network, param_dict)
    model = Model(network)

    metric_names = ("dice", "hd95", "jc", "asd", "sens")
    results = dict.fromkeys(metric_names, 0)
    # Uniform weight for every window position.
    importance_map = np.ones(config.roi_size, np.float32)
    num_batches = 0
    for index, batch in enumerate(eval_dataset.create_dict_iterator(num_epochs=1, output_numpy=True)):
        image = batch["image"]
        seg = batch["seg"]
        print("current image shape is {}".format(image.shape), flush=True)
        sliding_window_list, slice_list = create_sliding_window(image, config.roi_size, config.overlap)
        # Accumulators match the input's batch and spatial shape, with one channel
        # per class. The batch size is read from the data instead of overwriting
        # the shared config.
        image_size = (image.shape[0], config.num_classes) + image.shape[2:]
        output_image = np.zeros(image_size, np.float32)
        count_map = np.zeros(image_size, np.float32)
        for window, slice_ in zip(sliding_window_list, slice_list):
            window_image = Tensor(window, mstype.float32)
            pred_probs = model.predict(window_image)
            output_image[slice_] += pred_probs.asnumpy()
            count_map[slice_] += importance_map
        # Average the predictions where windows overlap.
        output_image = output_image / count_map
        dice, hd95, jc, asd, sens = CalculateDice(output_image, seg)

        print("The %d batch Dice: %.4f, HD95:%.4f, JC: %.4f, ASD: %.4f, Sens: %.4f"%(index, dice, hd95, jc, asd, sens))
        for name, value in zip(metric_names, (dice, hd95, jc, asd, sens)):
            results[name] += value
        num_batches += 1

    # Divide by the number of batches actually evaluated; guard the empty case.
    averages = {name: (total / num_batches if num_batches else 0.0) for name, total in results.items()}
    print("**********************End Eval***************************************")
    print("The average Dice: %.4f, HD95:%.4f, JC: %.4f, ASD: %.4f, Sens: %.4f" %
          (averages["dice"], averages["hd95"], averages["jc"], averages["asd"], averages["sens"]))
    return averages

if __name__ == '__main__':
    # NOTE(review): train_net in this notebook saves "Epoch_<n>_model.ckpt"
    # files. Confirm that "model.ckpt" exists at this path before running.
    val_root = "data/LUNA16/val/"
    checkpoint_file = "./output/checkpoint/model.ckpt"
    test_net(data_path=val_root, ckpt_path=checkpoint_file)



test dataset length is: 10




current image shape is (1, 1, 512, 512, 140)
The 0 batch Dice: 0.9764, HD95:0.0000, JC: 0.9539, ASD: 0.2789, Sens: 0.9799
current image shape is (1, 1, 512, 512, 158)
The 1 batch Dice: 0.9766, HD95:0.0000, JC: 0.9542, ASD: 0.4121, Sens: 0.9796
current image shape is (1, 1, 512, 512, 133)
The 2 batch Dice: 0.9728, HD95:0.0000, JC: 0.9470, ASD: 0.6999, Sens: 0.9795
current image shape is (1, 1, 512, 512, 114)
The 3 batch Dice: 0.9201, HD95:27.0185, JC: 0.8521, ASD: 5.8562, Sens: 0.9867
current image shape is (1, 1, 512, 512, 350)
The 4 batch Dice: 0.9581, HD95:0.0000, JC: 0.9195, ASD: 0.0336, Sens: 0.9396
current image shape is (1, 1, 512, 512, 290)
The 5 batch Dice: 0.9668, HD95:0.0000, JC: 0.9357, ASD: 1.9029, Sens: 0.9753
current image shape is (1, 1, 512, 512, 172)
The 6 batch Dice: 0.9849, HD95:0.0000, JC: 0.9703, ASD: 0.0262, Sens: 0.9892
current image shape is (1, 1, 512, 512, 125)
The 7 batch Dice: 0.9795, HD95:0.0000, JC: 0.9598, ASD: 0.0651, Sens: 0.9805
current image shape is 