# In [None]:
import mindspore
from mindspore.dataset import MnistDataset
from mindspore.dataset import transforms
from mindspore.dataset import vision
import mindspore.nn as nn
import mindspore.ops as ops
from download import *
import time
from matplotlib import pyplot as plt
import os
import cv2


class LeNet(nn.Cell):
    """LeNet-style convolutional classifier that emits 10 log-probabilities.

    Data flow, as wired in ``construct``:
    conv(1->6, kernel 5) -> ReLU -> max-pool(2, stride 2) ->
    conv(6->16, kernel 5) -> max-pool(2, stride 2) -> flatten to
    ``16 * 8 * 8`` features -> dense 120 -> ReLU -> dense 84 -> ReLU ->
    dense 10 -> log-softmax over axis 1.
    """

    def __init__(self):
        super().__init__()
        # First convolution stage: 1 input channel, 6 output channels.
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.relu = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(2, 2)
        # Second convolution stage: 6 input channels, 16 output channels.
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.maxpool2 = nn.MaxPool2d(2, 2)
        # Classifier head. The 16 * 8 * 8 input width must match the flatten
        # in construct(); presumably this corresponds to 32x32 inputs with the
        # framework's default convolution padding -- TODO confirm.
        self.fc1 = nn.Dense(16 * 8 * 8, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, 10)

    def construct(self, x):
        """Run the forward pass and return log-softmax scores along axis 1."""
        features = self.maxpool1(self.relu(self.conv1(x)))
        # No activation is applied between the second convolution and its pool.
        features = self.maxpool2(self.conv2(features))
        flat = features.reshape((-1, 16 * 8 * 8))
        hidden = ops.relu(self.fc1(flat))
        hidden = ops.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return ops.log_softmax(logits, axis=1)


# Training loop for a single epoch
def train(model, trainloader, optimizer, epoch, log_interval=1000,
          loss_history=None, acc_history=None):
    """Run one training epoch and return the last batch loss and the accuracy.

    Args:
        model: Callable invoked as ``model(inputs, labels)`` that returns
            ``((loss, outputs), grads)``. Despite the name this is the
            value-and-gradient function, not the bare network.
        trainloader: Dataset exposing ``create_tuple_iterator()`` that yields
            ``(inputs, labels)`` batches.
        optimizer: Callable applied to ``grads`` to update the parameters.
        epoch: Epoch number; only used in the progress message.
        log_interval: Print and record progress every ``log_interval`` batches
            (the first batch is always logged).
        loss_history: Optional list that receives the loss at every logged
            step. When omitted, a module-level list named ``Loss`` is used if
            one exists (legacy behaviour); otherwise nothing is recorded.
        acc_history: Optional list that receives the running accuracy (as a
            fraction) at every logged step. Falls back to a module-level
            ``Accuracy`` list in the same way.

    Returns:
        Tuple ``(last_loss, accuracy)``: the loss of the final batch as a
        Python number and the fraction of correctly classified samples over
        the whole epoch.

    Raises:
        ValueError: If ``log_interval`` is smaller than 1 or ``trainloader``
            yields no samples.
    """
    if log_interval < 1:
        raise ValueError("log_interval must be a positive integer")

    # Older callers relied on this function appending to the global ``Loss``
    # and ``Accuracy`` lists. Look them up lazily so that the function no
    # longer raises NameError when those globals are not defined.
    if loss_history is None:
        loss_history = globals().get("Loss")
    if acc_history is None:
        acc_history = globals().get("Accuracy")

    total = 0
    correct = 0.0
    loss = None
    for step, (inputs, labels) in enumerate(trainloader.create_tuple_iterator()):
        # Forward pass plus gradient computation.
        (loss, outputs), grads = model(inputs, labels)
        # Index of the highest score is the predicted class.
        predict = outputs.argmax(axis=1)
        total += labels.shape[0]
        correct += (predict == labels).sum().item()
        # Parameter update using the gradients computed above.
        optimizer(grads)
        if step % log_interval == 0:
            loss_value = loss.item()
            running_acc = correct / total
            print(
                "epoch{}:Loss: {:.6f}, accuracy: {:.6f}%".format(epoch, loss_value, 100 * running_acc))
            if loss_history is not None:
                loss_history.append(loss_value)
            if acc_history is not None:
                acc_history.append(running_acc)

    if loss is None or total == 0:
        raise ValueError("trainloader yielded no samples; nothing to train on")
    return loss.item(), correct / total


# Model evaluation
def test(model, testloader):
    """Evaluate ``model`` on ``testloader`` and print mean loss and accuracy.

    The network is switched to inference mode for the duration of the pass and
    its previous training flag is restored afterwards, so calling this between
    epochs does not leave later epochs running in inference mode.

    Args:
        model: Network called as ``model(data)``; must provide ``set_train``.
        testloader: Dataset exposing ``create_tuple_iterator()`` that yields
            ``(data, label)`` batches.

    Returns:
        Tuple ``(mean_loss, accuracy)``: the per-sample mean cross-entropy and
        the fraction of correctly classified samples.

    Raises:
        ValueError: If ``testloader`` yields no samples.
    """
    # Remember the current mode; fall back to False (the previous behaviour of
    # staying in inference mode) if the object does not expose the flag.
    was_training = getattr(model, "training", False)
    model.set_train(False)

    correct = 0.0
    loss_sum = 0.0
    total = 0
    try:
        for data, label in testloader.create_tuple_iterator():
            output = model(data)
            batch_size = label.shape[0]
            # cross_entropy returns the batch mean by default, so weight it by
            # the batch size; dividing the sum by the sample count below then
            # gives a true per-sample mean even when the last batch is short.
            loss_sum += ops.cross_entropy(output, label).item() * batch_size
            predict = output.argmax(axis=1)
            total += batch_size
            correct += (predict == label).sum().item()
    finally:
        model.set_train(was_training)

    if total == 0:
        raise ValueError("testloader yielded no samples; nothing to evaluate")

    mean_loss = loss_sum / total
    accuracy = correct / total
    print("测试集loss: {:.6f}, 测试集accuracy: {:.6f}%".format(mean_loss, 100 * accuracy))
    return mean_loss, accuracy





# Data preprocessing
def datapipe(path, batch_size, augment=True):
    """Build a batched MNIST pipeline from the dataset directory ``path``.

    Image transforms, in order: optional random horizontal flip, resize to
    32x32, rescale pixel values by 1/255, normalize with mean 0.1307 and
    std 0.3081, and convert HWC to CHW. Labels are cast to ``int32``.

    Args:
        path: Directory passed to ``MnistDataset``.
        batch_size: Number of samples per batch.
        augment: Whether to apply the random horizontal flip. Defaults to
            ``True`` to preserve the existing behaviour.
            NOTE(review): evaluation pipelines should normally pass
            ``augment=False`` so that results are deterministic, and mirroring
            digits is a questionable augmentation -- confirm intent.

    Returns:
        The mapped and batched dataset.
    """
    image_transforms = []
    if augment:
        # Random horizontal flip (the previous comment called this a rotation).
        image_transforms.append(vision.RandomHorizontalFlip())
    image_transforms += [
        # Resize every image to 32x32.
        vision.Resize((32, 32)),
        # Bring raw pixel values into [0, 1]; the normalization statistics
        # below are expressed on that scale.
        vision.Rescale(1.0 / 255.0, 0),
        # Standardize with the dataset mean and standard deviation.
        vision.Normalize((0.1307,), (0.3081,)),
        vision.HWC2CHW(),
    ]
    label_transform = transforms.TypeCast(mindspore.int32)

    dataset = MnistDataset(path)
    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset


if __name__ == "__main__":
    # Download and unpack the MNIST archive into the working directory.
    url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
          "notebook/datasets/MNIST_Data.zip"
    path = download(url, "./", kind="zip", replace=True)

    # Configure the runtime before building anything that depends on it.
    mindspore.set_context(device_target='CPU', mode=mindspore.PYNATIVE_MODE)

    # Build the batched input pipelines.
    train_dataset = datapipe('MNIST_Data/train', batch_size=64)
    test_dataset = datapipe('MNIST_Data/test', batch_size=64)

    # Network, optimizer and loss function.
    model = LeNet()
    optimizer = nn.Adam(model.trainable_params(), learning_rate=0.001)
    loss_fn = nn.CrossEntropyLoss()

    def forward_fn(data, label):
        """Return the loss together with the network output as auxiliary data."""
        logits = model(data)
        loss = loss_fn(logits, label)
        return loss, logits

    # has_aux=True: only the first output (the loss) is differentiated; the
    # network output is passed through for the accuracy computation.
    grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

    # Train, recording loss and accuracy for the plots below.
    num_epochs = 5
    # train() appends to these module-level lists by name, so the names must
    # stay exactly ``Loss`` and ``Accuracy``.
    Loss = []
    Accuracy = []
    for epoch in range(1, num_epochs + 1):
        # test() switches the network to inference mode, so training mode has
        # to be re-enabled at the start of every epoch, not just once.
        model.set_train()
        loss, acc = train(grad_fn, train_dataset, optimizer, epoch)
        Loss.append(loss)
        Accuracy.append(acc)
        test(model, test_dataset)
        print()

    print('训练结束！')
    plt.subplot(2, 1, 1)
    plt.plot(Loss)
    plt.title('Loss')
    plt.subplot(2, 1, 2)
    plt.plot(Accuracy)
    plt.title('Accuracy')
    # Keep the lower subplot's title from overlapping the upper subplot.
    plt.tight_layout()
    plt.show()

    # Save the trained model. The directory is derived from the checkpoint
    # path so that the folder created is the one actually written to
    # (previously './models' was created while the file went to './model').
    ckpt_path = './model/mnist_model.ckpt'
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
    mindspore.save_checkpoint(model, ckpt_path)