In [111]:
import mindspore
from mindspore import nn
import mindspore.dataset.vision as vision
from mindspore.dataset import MnistDataset, transforms
from mindspore.common.initializer import HeUniform

In [112]:
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512, weight_init=HeUniform(), bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 512, weight_init=HeUniform(), bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 10, weight_init=HeUniform(), bias_init="zeros"),
        )


    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

In [113]:
def datapipe(path: str, batch_size: int = 32):
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        vision.HWC2CHW()
    ]

    label_transforms = transforms.TypeCast(mindspore.int32)

    dataset = MnistDataset(path)
    dataset = dataset.map(operations=image_transforms, input_columns="image")
    dataset = dataset.map(operations=label_transforms, input_columns="label")
    dataset = dataset.batch(batch_size=batch_size)

    return dataset


train_dataset = datapipe("MNIST_Data/train", 64)
test_dataset = datapipe("MNIST_Data/test", 64)

In [114]:
model = Network()

In [115]:
learning_rate = 0.0001
epochs = 3

## 单步训练逻辑及过程

In [116]:
# 定义损失函数
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

In [117]:
# 定义前向传播函数
def forward_fn(data, label):
    logits = model(data)
    loss = loss_fn(logits, label)
    return loss, logits

In [118]:
# 定义优化器
optimizer = nn.Adam(model.trainable_params(), learning_rate=learning_rate)

In [119]:
# 定义梯度更新函数
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

In [120]:
# 定义单步训练函数
def train_step(data, label):
    (loss, _), grads = grad_fn(data, label)
    optimizer(grads)
    return loss

## 数据集遍历迭代

In [121]:
def train(model, dataset):
    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 100 == 0:
            loss, current = loss.asnumpy().mean().item(), batch

            print(f"loss: {loss:>7f}[{current:>3d}/{size:>3d}]")

In [122]:
for t in range(epochs):
    print(f"Epoch {t+1}\n--------------------------------------------------")
    train(model, train_dataset)

print("Done!")

Epoch 1
--------------------------------------------------
loss: 0.640295[  0/938]




loss: -214.255569[100/938]
loss: -2382.041748[200/938]
loss: -9708.977539[300/938]
loss: -24820.634766[400/938]
loss: -48446.644531[500/938]
loss: -93399.703125[600/938]
loss: -139511.625000[700/938]
loss: -205110.468750[800/938]
loss: -303632.500000[900/938]
Epoch 2
--------------------------------------------------
loss: -325174.562500[  0/938]
loss: -453491.406250[100/938]
loss: -565709.500000[200/938]
loss: -740478.875000[300/938]
loss: -919496.375000[400/938]
loss: -1140126.000000[500/938]
loss: -1293254.250000[600/938]
loss: -1537703.750000[700/938]
loss: -1850517.750000[800/938]
loss: -2099988.500000[900/938]
Epoch 3
--------------------------------------------------
loss: -2165315.000000[  0/938]
loss: -2588993.500000[100/938]
loss: -2872904.500000[200/938]
loss: -3503566.750000[300/938]
loss: -3880827.250000[400/938]
loss: -4258735.000000[500/938]
loss: -4828865.000000[600/938]
loss: -5215308.000000[700/938]
loss: -5961315.000000[800/938]
loss: -6175308.500000[900/938]
Done!
