In [2]:
import os
import matplotlib.pyplot as plt
import mindspore.ops as ops
import numpy as np
import mindspore.dataset as ds
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.train import Model
from mindspore.common import dtype as mstype
from mindspore.nn.loss import MSELoss, SoftmaxCrossEntropyWithLogits
from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import LossMonitor, Callback
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.common.initializer import TruncatedNormal

# Configure MindSpore to run in graph mode on the GPU backend.
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") # or "Ascend", "CPU"

# 数据集处理函数 - 修改后
def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1, apply_one_hot=False,
                   shuffle=True, drop_remainder=True):
    """Build a preprocessed, batched MNIST pipeline from ``data_path``.

    Labels are cast to int32 and, when ``apply_one_hot`` is set, expanded to a
    10-class one-hot vector that is then cast to float32. Images are resized
    to 32x32, rescaled by 1/255, normalized as ``(x - 0.1307) / 0.3081`` (done
    with a second Rescale) and converted from HWC to CHW layout.

    Args:
        data_path: Directory passed to ``ds.MnistDataset``.
        batch_size: Number of samples per batch.
        repeat_size: Number of times the batched dataset is repeated.
        num_parallel_workers: Worker count used by every ``map`` operation.
        apply_one_hot: If True, one-hot encode the labels (float32) before
            batching; otherwise labels stay as int32 class indices.
        shuffle: If True (default, the historical behaviour), shuffle with a
            buffer of 10000 samples before batching.
        drop_remainder: If True (default, the historical behaviour), drop the
            final batch when it holds fewer than ``batch_size`` samples.

    Returns:
        The transformed dataset object.
    """
    mnist_ds = ds.MnistDataset(data_path)
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    # Rescale(1/std, -mean/std) is equivalent to (x - mean) / std.
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # Image / label transform operators.
    resize_op = ds.vision.c_transforms.Resize((resize_height, resize_width))
    rescale_op = ds.vision.c_transforms.Rescale(rescale, shift)
    rescale_nml_op = ds.vision.c_transforms.Rescale(rescale_nml, shift_nml)
    hwc2chw_op = ds.vision.c_transforms.HWC2CHW()
    type_cast_op_label = ds.transforms.c_transforms.TypeCast(mstype.int32)

    # Cast the raw labels to int32 first.
    mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op_label, num_parallel_workers=num_parallel_workers)

    # Optionally one-hot encode the labels (must happen before batching).
    if apply_one_hot:
        depth = 10  # number of classes
        one_hot_op = ds.transforms.c_transforms.OneHot(num_classes=depth)
        # Per the original author's note, the one-hot output is an integer
        # type while the MSE loss path wants float32 labels, hence the cast.
        type_cast_op_onehot = ds.transforms.c_transforms.TypeCast(mstype.float32)
        mnist_ds = mnist_ds.map(input_columns="label", operations=[one_hot_op, type_cast_op_onehot], num_parallel_workers=num_parallel_workers)
        print(f"Applied OneHot encoding to labels for path: {data_path}")

    # Apply the image transforms.
    mnist_ds = mnist_ds.map(input_columns="image", operations=[resize_op, rescale_op, rescale_nml_op, hwc2chw_op],
                            num_parallel_workers=num_parallel_workers)

    # Shuffle (optional), batch, repeat.
    if shuffle:
        mnist_ds = mnist_ds.shuffle(buffer_size=10000)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=drop_remainder)
    mnist_ds = mnist_ds.repeat(repeat_size)
    return mnist_ds

# 定义LeNet5网络 (保持不变)
class LeNet5(nn.Cell):
    """LeNet-5 style CNN: two conv + max-pool stages followed by three dense layers.

    Args:
        activation: ``"relu"`` selects ``nn.ReLU``; any other value selects
            ``nn.Sigmoid``.
        use_dropout: When True, dropout (``keep_prob=0.5``) is applied to the
            activated output of ``fc1``.

    The first dense layer expects 16 * 5 * 5 flattened features, which
    presumably corresponds to 1x32x32 inputs — confirm against the data
    pipeline.
    """

    def __init__(self, activation="relu", use_dropout=False):
        super().__init__()
        self.use_dropout = use_dropout

        # Anything other than "relu" falls back to Sigmoid.
        if activation == "relu":
            self.activation = nn.ReLU()
        else:
            self.activation = nn.Sigmoid()

        # Convolutional feature extractor (no padding).
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5,
                               pad_mode="valid", weight_init=TruncatedNormal(0.02))
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5,
                               pad_mode="valid", weight_init=TruncatedNormal(0.02))

        # Dense classifier head producing 10 logits.
        self.fc1 = nn.Dense(in_channels=16 * 5 * 5, out_channels=120,
                            weight_init=TruncatedNormal(0.02))
        self.fc2 = nn.Dense(in_channels=120, out_channels=84,
                            weight_init=TruncatedNormal(0.02))
        self.fc3 = nn.Dense(in_channels=84, out_channels=10,
                            weight_init=TruncatedNormal(0.02))

        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        # NOTE(review): the original comment says newer MindSpore releases
        # recommend p=0.5 instead of keep_prob — confirm for the installed version.
        self.dropout = nn.Dropout(keep_prob=0.5)

    def construct(self, x):
        """Run the forward pass and return the 10-way logits."""
        # Stage 1: conv -> activation -> pool.
        feat = self.conv1(x)
        feat = self.activation(feat)
        feat = self.max_pool2d(feat)

        # Stage 2: conv -> activation -> pool.
        feat = self.conv2(feat)
        feat = self.activation(feat)
        feat = self.max_pool2d(feat)

        # Classifier head.
        hidden = self.flatten(feat)
        hidden = self.fc1(hidden)
        hidden = self.activation(hidden)
        if self.use_dropout:
            hidden = self.dropout(hidden)
        hidden = self.fc2(hidden)
        hidden = self.activation(hidden)
        return self.fc3(hidden)

# 自定义训练函数 - 修改后
def train_model(loss_fn, batch_size, activation, use_dropout, epoch_size=10,
                train_path="./data/train", test_path="./data/test"):
    """Train a LeNet5 on MNIST with the given settings, evaluate it, and return the metrics.

    Args:
        loss_fn: Loss cell instance. An ``MSELoss`` instance makes the datasets
            emit one-hot float labels. A ``SoftmaxCrossEntropyWithLogits``
            instance is replaced by a freshly built one with
            ``reduction="mean"`` and ``sparse`` set to match the label format.
        batch_size: Batch size for both the training and the test dataset.
        activation: Forwarded to ``LeNet5`` (``"relu"`` or anything else for sigmoid).
        use_dropout: Forwarded to ``LeNet5``.
        epoch_size: Number of training epochs.
        train_path: Directory of the training split (default ``"./data/train"``).
        test_path: Directory of the test split (default ``"./data/test"``).

    Returns:
        The value returned by ``model.eval`` on the test dataset (the metrics
        keyed by name, here ``"Accuracy"``), so callers can tabulate results
        instead of parsing printed output.
    """
    # MSE compares logits against a dense target, so labels must be one-hot.
    apply_one_hot_for_loss = isinstance(loss_fn, MSELoss)

    train_dataset = create_dataset(train_path, batch_size=batch_size, apply_one_hot=apply_one_hot_for_loss)
    test_dataset = create_dataset(test_path, batch_size=batch_size, apply_one_hot=apply_one_hot_for_loss)

    network = LeNet5(activation=activation, use_dropout=use_dropout)
    optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)

    # The cross-entropy loss must know whether labels are class indices
    # (sparse=True) or one-hot vectors (sparse=False); rebuild it accordingly
    # with mean reduction so a scalar loss is produced.
    if isinstance(loss_fn, SoftmaxCrossEntropyWithLogits):
        loss_fn = SoftmaxCrossEntropyWithLogits(sparse=not apply_one_hot_for_loss, reduction="mean")

    model = Model(network, loss_fn, optimizer, metrics={"Accuracy": Accuracy()})

    print(f"开始训练: Loss={loss_fn.__class__.__name__}, Batch Size={batch_size}, Activation={activation}, Dropout={use_dropout}, ApplyOneHot={apply_one_hot_for_loss}")
    model.train(epoch_size, train_dataset, callbacks=[LossMonitor(1000)], dataset_sink_mode=False)

    # Per the original author's note, the Accuracy metric accepts either
    # one-hot or class-index labels, so no extra conversion is done here.
    acc = model.eval(test_dataset)
    print(f"测试结果: {acc}")
    print("-" * 40)
    return acc

# --- Comparison experiments ---
# Each entry pairs the banner printed before a comparison with the runs that
# make it up: (loss class, batch size, activation, use_dropout). The loss
# object is instantiated right before each run so every run gets a fresh one.
_EXPERIMENTS = (
    # 1. Squared-error loss vs cross-entropy loss
    ("实验 1: MSELoss vs SoftmaxCrossEntropyWithLogits", (
        (MSELoss, 32, "relu", False),
        (SoftmaxCrossEntropyWithLogits, 32, "relu", False),
    )),
    # 2. Mini-batch vs no batching (batch size 1)
    ("实验 2: Batch Size 32 vs 1", (
        (SoftmaxCrossEntropyWithLogits, 32, "relu", False),
        (SoftmaxCrossEntropyWithLogits, 1, "relu", False),
    )),
    # 3. ReLU vs Sigmoid
    ("实验 3: ReLU vs Sigmoid", (
        (SoftmaxCrossEntropyWithLogits, 32, "relu", False),
        (SoftmaxCrossEntropyWithLogits, 32, "sigmoid", False),
    )),
    # 4. With dropout vs without dropout
    ("实验 4: Dropout vs No Dropout", (
        (SoftmaxCrossEntropyWithLogits, 32, "relu", True),
        (SoftmaxCrossEntropyWithLogits, 32, "relu", False),
    )),
)

for _banner, _runs in _EXPERIMENTS:
    print(_banner)
    for _loss_cls, _batch, _act, _dropout in _runs:
        train_model(_loss_cls(), batch_size=_batch, activation=_act, use_dropout=_dropout)

print("所有实验完成。")

实验 1: MSELoss vs SoftmaxCrossEntropyWithLogits
Applied OneHot encoding to labels for path: ./data/train
Applied OneHot encoding to labels for path: ./data/test
开始训练: Loss=MSELoss, Batch Size=32, Activation=relu, Dropout=False, ApplyOneHot=True
epoch: 1 step: 1000, loss is 0.09021913260221481
epoch: 2 step: 125, loss is 0.08975284546613693
epoch: 2 step: 1125, loss is 0.08704525232315063
epoch: 3 step: 250, loss is 0.026349369436502457
epoch: 3 step: 1250, loss is 0.019763758406043053
epoch: 4 step: 375, loss is 0.011463145725429058
epoch: 4 step: 1375, loss is 0.012111770920455456
epoch: 5 step: 500, loss is 0.008468328975141048
epoch: 5 step: 1500, loss is 0.009978756308555603
epoch: 6 step: 625, loss is 0.009144991636276245
epoch: 6 step: 1625, loss is 0.006385263986885548
epoch: 7 step: 750, loss is 0.0055742026306688786
epoch: 7 step: 1750, loss is 0.0029047676362097263
epoch: 8 step: 875, loss is 0.001899060676805675
epoch: 8 step: 1875, loss is 0.0028625493869185448
epoch: 9 step