In [1]:
# 设置设备


# 导入必要的模块
import os
import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, jit
from mindspore.dataset import MnistDataset
from mindspore.dataset.transforms import TypeCast
from mindspore.dataset.vision import Rescale, HWC2CHW, RandomCrop, RandomHorizontalFlip
from mindspore import context


context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
# 数据加载函数
def create_dataset(data_path, usage="train", batch_size=256, num_workers=4):
    mnist_dataset = MnistDataset(dataset_dir=data_path, usage=usage, shuffle=True)

    transform = [
        Rescale(1.0 / 127.5, -1),  # 归一化到 [-1, 1]
        RandomCrop(28, padding=4),  # 随机裁剪
        RandomHorizontalFlip(prob=0.5),  # 随机水平翻转
        HWC2CHW()  # 确保通道在前 (HWC -> CHW)
    ]

    mnist_dataset = mnist_dataset.map(transform, 'image')
    mnist_dataset = mnist_dataset.map(TypeCast(ms.float32), 'image')  # 图像数据转为 float32
    mnist_dataset = mnist_dataset.map(TypeCast(ms.int32), 'label')   # 标签数据转为 int32

    mnist_dataset = mnist_dataset.batch(batch_size, drop_remainder=True)  # 按批量分组

    return mnist_dataset


# 定义网络
class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5, stride=1, pad_mode='valid')
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, stride=1, pad_mode='valid')
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.fc1 = nn.Dense(64 * 4 * 4, 512)
        self.fc2 = nn.Dense(512, 128)
        self.fc3 = nn.Dense(128, 10)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def construct(self, x):
        x = self.pool(self.relu(self.bn1(self.conv1(x))))
        x = self.pool(self.relu(self.bn2(self.conv2(x))))
        x = x.view(x.shape[0], -1)  # Flatten
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        logits = self.fc3(x)
        return logits, [logits]


# 训练函数
def train(net, loss_fn, optimizer, train_dataset, epoch_size=5):
    total_train_samples = 60000
    batch_size = 256
    batches_per_epoch = total_train_samples // batch_size

    @jit(hash_args=lambda data, label: (data.shape, label.shape), compile_once=True)
    def forward_fn(data, label):
        logits, _ = net(data)
        loss = loss_fn(logits, label)
        return loss

    grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)

    @jit(hash_args=lambda data, label: (data.shape, label.shape), compile_once=True)
    def train_step(data, label):
        loss, grads = grad_fn(data, label)
        optimizer(grads)
        return loss

    try:
        for epoch in range(epoch_size):
            epoch_loss = 0
            for batch_idx, (data, label) in enumerate(train_dataset.create_tuple_iterator()):
                loss = train_step(data, label)
                epoch_loss += loss.asnumpy()

                if np.isnan(loss.asnumpy()):
                    print(f"NaN detected at batch {batch_idx}")
                    break

                if batch_idx % 100 == 0:
                    processed_samples = batch_idx * len(data)
                    print(
                        f'Train Epoch: {epoch} [{processed_samples}/{total_train_samples} ({100. * batch_idx / batches_per_epoch:0.0f}%)]\tLoss: {loss.asnumpy():.6f}')

            print(f"Epoch {epoch + 1} completed, Average Loss: {epoch_loss / batches_per_epoch:.6f}")
    except KeyboardInterrupt:
        print("Training interrupted by user.")
    except Exception as e:
        print(f"An error occurred: {e}")


# 测试函数
def test(net, test_dataset):
    net.set_train(False)
    correct = 0
    total = 0
    for data, label in test_dataset.create_tuple_iterator():
        logits, _ = net(data)
        pred = ops.Argmax(axis=1)(logits)
        correct += (pred == label).asnumpy().sum()
        total += len(label)

    accuracy = correct / total
    print(f"Test Accuracy: {accuracy * 100:.2f}%")


# 主程序
if __name__ == "__main__":
    # 数据集路径
    data_path = r"D:\MindSporeProject\pythonProject1\Large_Margin_Loss_PyTorch\data"

    # 加载训练集和测试集
    train_dataset = create_dataset(data_path, usage="train", batch_size=256)
    test_dataset = create_dataset(data_path, usage="test", batch_size=2048)

    # 初始化网络、损失函数和优化器
    net = Net()
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')  # 使用交叉熵损失

    # 定义优化器
    optimizer = nn.Adam(net.trainable_params(), learning_rate=1e-3, weight_decay=1e-4)

    # 训练和测试
    for epoch in range(5):
        train(net, loss_fn, optimizer, train_dataset, epoch_size=1)
        test(net, test_dataset)



Epoch 1 completed, Average Loss: 0.808527




Test Accuracy: 88.37%
Epoch 1 completed, Average Loss: 0.291046




Test Accuracy: 92.50%
Epoch 1 completed, Average Loss: 0.199898
Test Accuracy: 94.63%
Epoch 1 completed, Average Loss: 0.157663




Test Accuracy: 96.13%
Epoch 1 completed, Average Loss: 0.128788
Test Accuracy: 95.95%


In [2]:
import os
import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, context
from mindspore.dataset import MnistDataset
from mindspore.dataset.transforms import TypeCast
from mindspore.dataset.vision import Rescale, HWC2CHW, RandomCrop, RandomHorizontalFlip
from mindspore.train import Model
from large_margin import LargeMarginLoss

context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")


# 数据加载函数
def create_dataset(data_path, usage="train", batch_size=256):
    mnist_dataset = MnistDataset(dataset_dir=data_path, usage=usage, shuffle=True)
    transform = [
        Rescale(1.0 / 127.5, -1),
        RandomCrop(28, padding=4),
        RandomHorizontalFlip(prob=0.5),
        HWC2CHW()
    ]
    mnist_dataset = mnist_dataset.map(transform, 'image')
    mnist_dataset = mnist_dataset.map(TypeCast(ms.float32), 'image')
    mnist_dataset = mnist_dataset.map(TypeCast(ms.int32), 'label')
    mnist_dataset = mnist_dataset.batch(batch_size, drop_remainder=True)
    return mnist_dataset

# 定义网络
class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, pad_mode='same')
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, pad_mode='same')
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.flatten = ops.Flatten()
        self.fc1 = nn.Dense(64 * 7 * 7, 256)
        self.relu_fc1 = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Dense(256, 10)

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        conv1 = x

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        conv2 = x

        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu_fc1(x)
        x = self.dropout(x)
        logits = self.fc2(x)

        return logits, [conv1, conv2]

# 训练函数
# 将 test 定义为独立函数
def test(net, test_dataset):
    net.set_train(False)
    correct = 0
    total = 0
    for data, label in test_dataset.create_tuple_iterator():
        logits, _ = net(data)
        pred = ops.Argmax(axis=1)(logits)
        correct += (pred == label).asnumpy().sum()
        total += len(label)

    accuracy = correct / total
    print(f"Test Accuracy: {accuracy * 100:.2f}%")
    return accuracy


def train_lm(net, train_dataset, test_dataset, optimizer, epoch_size=5):
    lm_loss = LargeMarginLoss(gamma=10000, alpha_factor=4, top_k=1, epsilon=1e-6)

    # 定义前向传播和梯度计算的函数
    def forward_fn(data, label):
        logits, feature_maps = net(data)
        one_hot = Tensor(np.eye(10)[label.asnumpy()], ms.float32)
        loss = lm_loss(logits, one_hot, feature_maps)
        return loss

    # 使用MindSpore的value_and_grad函数来获取梯度
    grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)

    # 定义单步训练函数
    def train_step(data, label):
        loss, grads = grad_fn(data, label)
        optimizer(grads)  # MindSpore中优化器直接用来更新参数
        return loss

    best_accuracy = 0.0

    for epoch in range(epoch_size):
        net.set_train(True)
        epoch_loss = 0
        total_batches = len(train_dataset)  # 获取批次总数

        for batch_idx, (data, label) in enumerate(train_dataset.create_tuple_iterator()):
            loss = train_step(data, label)
            epoch_loss += loss.asnumpy()

            if batch_idx % 100 == 0:
                batch_size = data.shape[0]  # 从当前批次数据获取批次大小
                processed_samples = batch_idx * batch_size
                print(
                    f'Train Epoch: {epoch} [{processed_samples}/{total_batches * batch_size} ({100. * batch_idx / total_batches:.0f}%)]\tLoss: {loss.asnumpy():.6f}')

        print(f"Epoch {epoch + 1} completed, Average Loss: {epoch_loss / total_batches:.6f}")

        # 在每个周期后评估模型
        accuracy = test(net, test_dataset)

        # 跟踪最佳准确率
        if accuracy > best_accuracy:
            best_accuracy = accuracy

    print(f"Training completed. Best accuracy: {best_accuracy * 100:.2f}%")


# 主程序
if __name__ == "__main__":
    # 数据集路径
    data_path = r"D:\MindSporeProject\pythonProject1\Large_Margin_Loss_PyTorch\data"

    # 加载训练集和测试集
    train_dataset = create_dataset(data_path, usage="train", batch_size=256)
    test_dataset = create_dataset(data_path, usage="test", batch_size=2048)

    # 初始化网络和优化器
    net = Net()
    optimizer = nn.Adam(net.trainable_params(), learning_rate=1e-3, weight_decay=1e-4)

    # 训练和测试
    train_lm(net, train_dataset, test_dataset, optimizer, epoch_size=5)



Epoch 1 completed, Average Loss: 0.046905
Test Accuracy: 66.72%
Epoch 2 completed, Average Loss: 0.035024
Test Accuracy: 72.80%
Epoch 3 completed, Average Loss: 0.031155
Test Accuracy: 75.28%
Epoch 4 completed, Average Loss: 0.027380
Test Accuracy: 76.33%
Epoch 5 completed, Average Loss: 0.024723
Test Accuracy: 65.60%
Training completed. Best accuracy: 76.33%
