In [None]:
from mindvision.dataset import Mnist

# Configure the MNIST train and test pipelines rooted at ./mnist
# (download=True requests that the data be fetched).
download_train = Mnist(
    path="./mnist",
    split="train",
    batch_size=32,
    repeat_num=1,
    shuffle=True,
    resize=32,
    download=True,
)
download_eval = Mnist(
    path="./mnist",
    split="test",
    batch_size=32,
    resize=32,
    download=True,
)

# run() produces the dataset objects that training and evaluation consume.
dataset_train, dataset_eval = download_train.run(), download_eval.run()

In [None]:
from mindvision.classification.models import lenet

# Instantiate LeNet for 10 output classes, without pretrained weights.
network = lenet(pretrained=False, num_classes=10)

In [None]:
from mindspore import nn
from mindspore.train import Model

# Loss: softmax cross-entropy on logits, configured with sparse=True and
# mean reduction.
net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction='mean', sparse=True)

# Optimizer: momentum updates applied to the network's trainable parameters.
net_opt = nn.Momentum(
    network.trainable_params(),
    learning_rate=0.01,
    momentum=0.9,
)

In [None]:
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint

# Checkpoint policy: save every 1875 steps and retain at most 10 files.
# NOTE(review): 1875 presumably matches the number of batches per epoch at
# batch_size=32 — confirm against the training dataset.
config_ck = CheckpointConfig(
    keep_checkpoint_max=10,
    save_checkpoint_steps=1875,
)

# Checkpoint callback writing into ./lenet with the "lenet" filename prefix.
ckpoint = ModelCheckpoint(config=config_ck, directory="./lenet", prefix="lenet")

In [None]:
from mindvision.engine.callback import LossMonitor

# Bundle the network, loss, and optimizer into a Model that tracks accuracy.
model = Model(
    network,
    loss_fn=net_loss,
    optimizer=net_opt,
    metrics={'accuracy'},
)

# Train for 10 epochs with the checkpoint and loss-monitor callbacks attached.
# NOTE(review): LossMonitor receives 0.01 and 1875 — presumably the learning
# rate and the reporting interval in steps; confirm against its signature.
model.train(
    10,
    dataset_train,
    callbacks=[ckpoint, LossMonitor(0.01, 1875)],
)

In [None]:
# Evaluate the trained model on the test-split dataset and print the metrics.
acc = model.eval(dataset_eval)
print(acc)

In [None]:
# Reload trained weights from disk.
from mindspore import load_checkpoint, load_param_into_net

# Load the checkpoint written at the end of the final (10th) training epoch.
# Training runs for 10 epochs with a checkpoint every 1875 steps and up to 10
# files retained, so "lenet-10_1875.ckpt" is the most-trained snapshot.
# Loading "lenet-1_1875.ckpt" instead would roll the network back to its
# first-epoch weights right before inference.
# NOTE(review): assumes checkpoints are named <prefix>-<epoch>_<step>.ckpt, as
# the previously used path implies, and that ./lenet holds this run's files.
param_dict = load_checkpoint("./lenet/lenet-10_1875.ckpt")

# Push the checkpoint parameters into the existing network object.
load_param_into_net(network, param_dict)

In [None]:
# Run inference on a handful of test images and compare against the labels.
import numpy as np
from mindspore import Tensor
import matplotlib.pyplot as plt

# Pull one batch of six samples from the test split.
mnist = Mnist("./mnist", split="test", batch_size=6, resize=32)
dataset_infer = mnist.run()
ds_test = dataset_infer.create_dict_iterator()
data = next(ds_test)
images, labels = data["image"].asnumpy(), data["label"].asnumpy()

# Draw the six images on a 2x3 grid; index [0] selects the first plane of
# each image (presumably the single grayscale channel).
plt.figure()
for slot in range(6):
    plt.subplot(2, 3, slot + 1)
    plt.imshow(images[slot][0], cmap="gray", interpolation="None")
plt.show()

# Feed the batch through the model; the arg-max along axis 1 of the output is
# taken as the predicted class for each image.
output = model.predict(Tensor(data['image']))
predicted = np.argmax(output.asnumpy(), axis=1)

# Show predictions next to the ground-truth labels.
print(f'Predicted: "{predicted}", Actual: "{labels}"')