In [1]:
# 导入相关模块
import os
import mindspore as ms
import mindspore.context as context
#transforms.c_transforms用于通用型数据增强，vision.c_transforms用于图像类数据增强
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
#nn模块用于定义网络，model模块用于编译模型，callback模块用于设定监督指标
from mindspore import nn
from mindspore.train import Model
from mindspore.train.callback import LossMonitor
#设定运行模式为图模式，运行硬件为GPU
context.set_context(mode=context.GRAPH_MODE, device_target='GPU') #  CPU, GPU
# 从obs下载数据
import moxing
moxing.file.copy_parallel(src_url="obs://minst-lab3-10c3/mnist-demo0/data", dst_url='MNIST/')
data_path = 'MNIST/'

INFO:root:Using MoXing-v2.1.0.5d9c87c8-5d9c87c8
INFO:root:Using OBS-Python-SDK-3.20.9.1


In [2]:
# 数据集生成函数
def create_dataset(data_dir, training=True, batch_size=32, resize=(28, 28),
                   rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64):
    #生成训练集和测试集的路径
    data_train = os.path.join(data_dir, 'train') # train set
    data_test = os.path.join(data_dir, 'test') # test set
    #利用MnistDataset方法读取mnist数据集，如果training是True则读取训练集
    ds = ms.dataset.MnistDataset(data_train if training else data_test)
    # 数据增强操作
    transforms = [
        CV.Resize(resize),  # 调整图像大小
        CV.Rescale(rescale, shift),  # 归一化
        CV.HWC2CHW(),  # HWC to CHW
        CV.RandomRotation(degrees=1),  # 随机旋转图像
       # CV.RandomVerticalFlip(),  # 随机垂直翻转
       # CV.RandomHorizontalFlip()
    ]
    ds = ds.map(input_columns=["image"], operations=transforms)
    #利用map方法改变数据集标签的数据类型
    ds = ds.map(input_columns=["label"], operations=C.TypeCast(ms.int32))
    # shuffle是打乱操作，同时设定了batchsize的大小，并将最后不足一个batch的数据抛弃
    ds = ds.shuffle(buffer_size=buffer_size).batch(batch_size, drop_remainder=True)

    return ds

In [3]:
#定义模型结构，MindSpore中的模型时通过construct定义模型结构，在__init__中初始化各层的对象
class MLP(nn.Cell):
    def __init__(self):
        super(MLP, self).__init__()
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.softmax = nn.Softmax()
        self.fc1 = nn.Dense(784, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Dense(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.fc3 = nn.Dense(256,128)
        self.bn3 = nn.BatchNorm1d(128)
        self.fc4 = nn.Dense(128,64)
        self.bn4 = nn.BatchNorm1d(64)
        self.fc5 = nn.Dense(64,10)
        
    def construct(self, x):
        x = self.flatten(x)
        x = self.relu(self.bn1(self.fc1(x)))
        x = self.relu(self.bn2(self.fc2(x)))
        x = self.relu(self.bn3(self.fc3(x)))
        x = self.relu(self.bn4(self.fc4(x)))
        x = self.fc5(x)
        x = self.softmax(x)
        return x

In [4]:
from mindspore.train.callback import Callback
import numpy as np

class EarlyStopping(Callback):
    def __init__(self, patience=4, delta=0, monitor='loss', mode='min'):
        """
        Early stopping callback.
        
        :param patience: 当验证性能在若干个周期内没有改善时，训练会停止。默认是 4。
        :param delta: 用于判断改善的阈值。如果 `monitor` 是 'loss'，则损失小于 delta 表示改善。
        :param monitor: 用于监控的指标，通常是 'loss' 或 'acc'。
        :param mode: 'min' 或 'max'。表示在监控指标上，我们希望是最小化（'min'）还是最大化（'max'）。
        """
        self.patience = patience
        self.delta = delta
        self.monitor = monitor
        self.mode = mode
        self.best_score = None
        self.epochs_without_improvement = 0

    def step(self, run_context):
        # 获取当前周期的验证损失或准确率
        metrics = run_context.original_args()
        current_score = metrics[self.monitor]  # 使用 'loss' 或 'acc'

        # 判断是否为第一次初始化
        if self.best_score is None:
            self.best_score = current_score
        elif (self.mode == 'min' and current_score < self.best_score - self.delta) or \
             (self.mode == 'max' and current_score > self.best_score + self.delta):
            self.best_score = current_score
            self.epochs_without_improvement = 0  # 重置计数器
        else:
            self.epochs_without_improvement += 1

        # 如果在 `patience` 轮内没有改善，则提前停止训练
        if self.epochs_without_improvement >= self.patience:
            print(f"Early stopping triggered at epoch {run_context.cur_epoch_num}")
            run_context.request_stop()  # 停止训练


In [8]:
# 构建训练、验证函数进行模型训练和验证，提供数据路径，设定学习率，epoch数量
def train(data_dir, lr=0.001, momentum=0.9, num_epochs=20):
    #调用函数，读取训练集
    ds_train = create_dataset(data_dir)
    #调用函数，读取验证集
    ds_eval = create_dataset(data_dir, training=False)
    #构建网络
    net = MLP()
    #设定loss函数
    loss = nn.loss.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    #设定优化器
    opt = nn.Momentum(net.trainable_params(), lr, momentum)
    # opt = nn.Adam(net.trainable_params(),lr)

    #设定损失监控
    loss_cb = LossMonitor(per_print_times=ds_train.get_dataset_size())
    
    # 创建EarlyStopping回调
    early_stopping_cb = EarlyStopping(patience=4,monitor='loss',mode='min')
    #编译形成模型
    model = Model(net, loss, opt, metrics={'acc', 'loss'})
    # 训练网络，dataset_sink_mode为on_device模式
    model.train(num_epochs, ds_train, callbacks=[loss_cb], dataset_sink_mode=False)
    #用验证机评估网络表现
    metrics = model.eval(ds_eval, dataset_sink_mode=False)
    #输出相关指标
    print('Metrics:', metrics)
    # for epoch in range(num_epochs):
    #     new_lr = adjust_learning_rate(epoch,initial_lr=lr,decay_rate=0.9,decay_epochs=10)
    #     lr_scheduler = nn.piecewise_constant_lr([1, 10, 20], [new_lr, new_lr * 0.5, new_lr * 0.1])
    #     opt.learning_rate = lr_scheduler
    #     model.train(1, ds_train, callbacks=[loss_cb], dataset_sink_mode=False)
    #     metrics = model.eval(ds_eval, dataset_sink_mode=False)
    #     print('Metrics:', metrics)

In [9]:
#main函数负责调用之前定义的函数，完成整个训练验证过程
if __name__ == "__main__":
    #argsparse是python的命令行解析的标准模块，可以通过命令行传入参数
    import argparse
    parser = argparse.ArgumentParser()
    #设定训练数据路径
    parser.add_argument('--data_url', required=False, default=None, help='Location of data.')
    parser.add_argument('--train_url', required=False, default=None, help='Location of training outputs.')
    args, unknown = parser.parse_known_args()
    #调用train函数，训练并验证模型
    train(data_path)

epoch: 1 step: 1875, loss is 1.9828053712844849
epoch: 2 step: 1875, loss is 1.759229302406311
epoch: 3 step: 1875, loss is 1.7169243097305298
epoch: 4 step: 1875, loss is 1.5133346319198608
epoch: 5 step: 1875, loss is 1.5038005113601685
epoch: 6 step: 1875, loss is 1.5055663585662842
epoch: 7 step: 1875, loss is 1.5337867736816406
epoch: 8 step: 1875, loss is 1.4641516208648682
epoch: 9 step: 1875, loss is 1.5298594236373901
epoch: 10 step: 1875, loss is 1.462080478668213
epoch: 11 step: 1875, loss is 1.4763365983963013
epoch: 12 step: 1875, loss is 1.465675950050354
epoch: 13 step: 1875, loss is 1.4702107906341553
epoch: 14 step: 1875, loss is 1.4808160066604614
epoch: 15 step: 1875, loss is 1.5062050819396973
epoch: 16 step: 1875, loss is 1.4798637628555298
epoch: 17 step: 1875, loss is 1.4631993770599365
epoch: 18 step: 1875, loss is 1.4623202085494995
epoch: 19 step: 1875, loss is 1.476589322090149
epoch: 20 step: 1875, loss is 1.5100959539413452
Metrics: {'acc': 0.98417467948717