In [2]:
# 导入相关模块
import os
import mindspore as ms
import mindspore.context as context
#transforms.c_transforms用于通用型数据增强，vision.c_transforms用于图像类数据增强
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
#nn模块用于定义网络，model模块用于编译模型，callback模块用于设定监督指标
from mindspore import nn
from mindspore.train import Model
from mindspore.train.callback import LossMonitor
#设定运行模式为图模式，运行硬件为昇腾芯片
context.set_context(mode=context.GRAPH_MODE, device_target='GPU') # CPU, GPU
# 从OBS导入数据
import moxing
moxing.file.copy_parallel(src_url="obs://cifar-10-bc53/cifar-10", dst_url='CIFAR10/')
data_path = 'CIFAR10/'


In [3]:
def create_dataset(data_dir, training=True, batch_size=32, resize=(32, 32),
                   rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64):
    #生成路径
    data_dir = os.path.join(data_dir)
    #利用Cifar10Dataset方法读cifar10数据集，如果training是True则读取训练集
    ds = ms.dataset.Cifar10Dataset(dataset_dir=data_dir, usage='train' if training else 'test')
    # 数据增强操作
    transforms = [
        CV.Resize(resize),  # 调整图像大小
        CV.Rescale(rescale, shift),  # 归一化
        CV.HWC2CHW(),  # HWC to CHW
        CV.RandomRotation(degrees=1),  # 随机旋转图像
        # CV.RandomVerticalFlip(),  # 随机垂直翻转
        # CV.RandomHorizontalFlip()
    ]
    #map方法是非常有效的方法，可以整体对数据集进行处理，resize改变数据形状，rescale进行归一化，HWC2CHW改变图像通道
    ds = ds.map(input_columns=["image"], operations=transforms)
    #利用map方法改变数据集标签的数据类型
    ds = ds.map(input_columns=["label"], operations=C.TypeCast(ms.int32))
    # shuffle是打乱操作，同时设定了batchsize的大小，并将最后不足一个batch的数据抛弃
    ds = ds.shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True)

    return ds

In [4]:
#定义模型结构，MindSpore中的模型时通过construct定义模型结构，在__init__中初始化各层的对象
class LeNet5(nn.Cell):
    def __init__(self):
        super(LeNet5, self).__init__()
        #定义卷积层，ReLU激活函数，平坦层和全连接层
        #conv2d的输入通道为3维，输出为6维，卷积核尺寸为5*5，步长为1，不适用padding
        self.conv1 = nn.Conv2d(3, 6, 5, stride=1, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, stride=1, pad_mode='valid')
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(keep_prob=0.9)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Dense(400, 512)
        self.fc1_dropout = nn.Dropout(keep_prob=0.5)
        self.bn1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Dense(512, 256)
        self.fc2_dropout = nn.Dropout(keep_prob=0.5)
        self.bn2 = nn.BatchNorm1d(256)
        self.fc3 = nn.Dense(256,128)
        self.fc3_dropout = nn.Dropout(keep_prob=0.7)
        self.bn3 = nn.BatchNorm1d(128)
        self.fc4 = nn.Dense(128,64)
        self.fc4_dropout = nn.Dropout(keep_prob=0.8)
        self.bn4 = nn.BatchNorm1d(64)
        self.fc5 = nn.Dense(64, 10)
        # self.bn5 = nn.BatchNorm1d(32)
        # self.fc6 = nn.Dense(32,10)
    #构建Lenet5架构，x代表网络的输入
    def construct(self, x):
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = self.dropout(x)
        x = self.flatten(x)
        x = self.relu(self.bn1(self.fc1_dropout(self.fc1(x))))
        x = self.relu(self.bn2(self.fc2_dropout(self.fc2(x))))
        x = self.relu(self.bn3(self.fc3_dropout(self.fc3(x))))
        x = self.relu(self.bn4(self.fc4_dropout(self.fc4(x))))
        x = self.fc5(x)
        return x


In [5]:
from mindspore.train.callback import Callback
import numpy as np

class EarlyStopping(Callback):
    def __init__(self, patience=4, delta=0, monitor='loss', mode='min'):
        """
        Early stopping callback.
        
        :param patience: 当验证性能在若干个周期内没有改善时，训练会停止。默认是 5。
        :param delta: 用于判断改善的阈值。如果 `monitor` 是 'loss'，则损失小于 delta 表示改善。
        :param monitor: 用于监控的指标，通常是 'loss' 或 'acc'。
        :param mode: 'min' 或 'max'。表示在监控指标上，我们希望是最小化（'min'）还是最大化（'max'）。
        """
        self.patience = patience
        self.delta = delta
        self.monitor = monitor
        self.mode = mode
        self.best_score = None
        self.epochs_without_improvement = 0

    def step(self, run_context):
        # 获取当前周期的验证损失或准确率
        metrics = run_context.original_args()
        current_score = metrics[self.monitor]  # 使用 'loss' 或 'acc'

        # 判断是否为第一次初始化
        if self.best_score is None:
            self.best_score = current_score
        elif (self.mode == 'min' and current_score < self.best_score - self.delta) or \
             (self.mode == 'max' and current_score > self.best_score + self.delta):
            self.best_score = current_score
            self.epochs_without_improvement = 0  # 重置计数器
        else:
            self.epochs_without_improvement += 1

        # 如果在 `patience` 轮内没有改善，则提前停止训练
        if self.epochs_without_improvement >= self.patience:
            print(f"Early stopping triggered at epoch {run_context.cur_epoch_num}")
            run_context.request_stop()  # 停止训练


In [6]:
# 余弦退火学习率衰减
def lr_schedule(epoch, initial_lr=0.001, total_epochs=50):
    if epoch < 5:  # 前5个epoch进行学习率预热
        lr = initial_lr * (epoch / 5)
    else:
        lr = initial_lr * 0.5 * (1 + np.cos(np.pi * (epoch - 5) / (total_epochs - 5)))
    return lr


In [7]:
# 构建训练、验证函数进行模型训练和验证，提供数据路径，设定学习率，epoch数量
def train(data_dir, lr=0.0005, momentum=0.9, num_epochs=50):
    #调用函数，读取训练集
    ds_train = create_dataset(data_dir)
    #调用函数，读取验证集
    ds_eval = create_dataset(data_dir, training=False)
    #构建网络
    net = LeNet5()
    #设定loss函数
    loss = nn.loss.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    #设定优化器
    # opt = nn.Momentum(net.trainable_params(), lr, momentum)
    opt = nn.Adam(net.trainable_params(),lr)
    #设定损失监控
    loss_cb = LossMonitor(per_print_times=ds_train.get_dataset_size())
     # 创建EarlyStopping回调
    early_stopping_cb = EarlyStopping(patience=4,monitor='loss',mode='min')
    #编译形成模型
    model = Model(net, loss, opt, metrics={'acc', 'loss'})
    # 训练网络，dataset_sink_mode为on_device模式
    model.train(num_epochs, ds_train, callbacks=[loss_cb,early_stopping_cb], dataset_sink_mode=False)
    #用验证机评估网络表现
    metrics = model.eval(ds_eval, dataset_sink_mode=False)
    #输出相关指标
    print('Metrics:', metrics)
    # for epoch in range(num_epochs):
    #     model.train(1, ds_train, callbacks=[loss_cb,early_stopping_cb], dataset_sink_mode=False)
    #     metrics = model.eval(ds_eval, dataset_sink_mode=False)
    #     print('Metrics:', metrics)

In [9]:
#main函数负责调用之前定义的函数，完成整个训练验证过程
if __name__ == "__main__":
    #argsparse是python的命令行解析的标准模块，可以通过命令行传入参数
    import argparse
    parser = argparse.ArgumentParser()
    #设定训练数据路径
    parser.add_argument('--data_url', required=False, default=None, help='Location of data.')
    parser.add_argument('--train_url', required=False, default=None, help='Location of training outputs.')
    args, unknown = parser.parse_known_args()
    #调用train函数，训练并验证模型
    train(data_path)

epoch: 1 step: 1562, loss is 1.5643280744552612
epoch: 2 step: 1562, loss is 1.4399712085723877
epoch: 3 step: 1562, loss is 1.293887734413147
epoch: 4 step: 1562, loss is 1.081912875175476
epoch: 5 step: 1562, loss is 1.0442205667495728
epoch: 6 step: 1562, loss is 1.3243683576583862
epoch: 7 step: 1562, loss is 1.4306213855743408
epoch: 8 step: 1562, loss is 0.9349492788314819
epoch: 9 step: 1562, loss is 1.3609750270843506
epoch: 10 step: 1562, loss is 1.4849300384521484
epoch: 11 step: 1562, loss is 0.9384766221046448
epoch: 12 step: 1562, loss is 1.0386909246444702
epoch: 13 step: 1562, loss is 1.2069979906082153
epoch: 14 step: 1562, loss is 1.4659512042999268
epoch: 15 step: 1562, loss is 1.172181487083435
epoch: 16 step: 1562, loss is 1.0796278715133667
epoch: 17 step: 1562, loss is 1.0796332359313965
epoch: 18 step: 1562, loss is 1.1411136388778687
epoch: 19 step: 1562, loss is 1.3716362714767456
epoch: 20 step: 1562, loss is 0.713984489440918
epoch: 21 step: 1562, loss is 0.9