In [None]:
import time
import numpy as np
import mindspore as ms
import mindspore.dataset as ds
from mindspore import nn, ops, Tensor
from mindspore.dataset.transforms import Compose
from mindspore.common.initializer import HeUniform
from mindspore.dataset.vision import ToTensor, Normalize


def one_hot_encode(img0, lab):
    """Embed a label into the first 10 entries of each image row.

    Returns a copy of ``img0`` (the input is not modified) in which, for
    every row, columns 0..9 are set to the global minimum of ``img0`` and
    the column indexed by that row's entry in ``lab`` is set to the global
    maximum of ``img0``.
    """
    lowest, highest = img0.min(), img0.max()
    row_ids = np.arange(img0.shape[0])

    encoded = img0.copy()
    encoded[:, :10] = lowest        # blank out the label slots
    encoded[row_ids, lab] = highest  # mark the slot of each row's label
    return encoded

# Load MNIST data
data_path = './MNIST_data/'

# Per-image preprocessing: convert to tensor layout, normalize with the
# constants (0.1307,) / (0.3081,), then flatten each image to 1-D.
transform = Compose([
    ToTensor(),
    Normalize((0.1307,), (0.3081,)),
    lambda x: x.flatten(),
])


def _full_batch_loader(usage, n_images):
    """Build an MNIST pipeline for `usage` that yields batches of `n_images`."""
    pipeline = ds.MnistDataset(data_path, usage=usage)
    pipeline = pipeline.map(operations=transform, input_columns="image")
    return pipeline.batch(n_images)


train_loader = _full_batch_loader('train', 60000)
test_loader = _full_batch_loader('test', 10000)

ms.context.set_context(device_target="GPU")
device = ms.context.get_context('device_target')
print('Using device:', device)

# Pull the first batch of each split into host memory as numpy arrays.
train_batch = next(train_loader.create_dict_iterator())
img0 = train_batch["image"].asnumpy()
lab = train_batch["label"].asnumpy()

test_batch = next(test_loader.create_dict_iterator())
img0_tst = test_batch["image"].asnumpy()
lab_tst = test_batch["label"].asnumpy()

# Forward-Forward applied to a single perceptron layer for MNIST classification

# Layer dimensions: flattened image length in, number of hidden units out.
n_input = 784
n_out = 125

# Optimisation settings.
batch_size = 10
learning_rate = 0.0003
epochs = 250

# Goodness threshold that the loss compares the mean squared activation against.
g_threshold = 10

# One dense layer followed by ReLU; its activations define the "goodness".
perceptron = nn.SequentialCell(
    nn.Dense(n_input, n_out, weight_init=HeUniform(), has_bias=True),
    nn.ReLU(),
)

optimizer = nn.Adam(
    perceptron.trainable_params(),
    learning_rate=learning_rate,
)

# Define forward propagation and loss calculations
def _softplus(x):
    """Return log(1 + exp(x)) elementwise without overflowing.

    Uses the identity log(1 + exp(x)) = max(x, 0) + log(1 + exp(-|x|)),
    which only ever exponentiates non-positive numbers. The direct form
    overflows to inf (and yields NaN gradients) for large positive x.
    """
    return ops.maximum(x, ops.zeros_like(x)) + ops.log(1 + ops.exp(-ops.abs(x)))


def _goodness(img_batch):
    """Return the per-sample goodness: mean squared activation of the layer."""
    return (perceptron(img_batch) ** 2).mean(axis=1)


def forward_fn(img_pos_batch, img_neg_batch):
    """Compute the Forward-Forward loss for one batch.

    Args:
        img_pos_batch: batch of inputs carrying the correct label.
        img_neg_batch: batch of inputs carrying an incorrect label.

    Returns:
        Scalar loss: softplus(g_threshold - g_pos) summed over the positive
        batch plus softplus(g_neg - g_threshold) summed over the negative
        batch, where g is the goodness of each sample.
    """
    g_pos = _goodness(img_pos_batch)
    g_neg = _goodness(img_neg_batch)

    # Positive samples are penalised for goodness below the threshold,
    # negative samples for goodness above it.
    loss_pos = _softplus(-(g_pos - g_threshold)).sum()
    loss_neg = _softplus(g_neg - g_threshold).sum()

    return loss_pos + loss_neg

# Build a function returning (loss, gradients) of forward_fn with respect to
# the optimizer's parameters; no gradients are taken for the inputs.
grad_fn = ops.value_and_grad(
    forward_fn,
    grad_position=None,
    weights=optimizer.parameters,
    has_aux=False,
)

N_trn = img0.shape[0]  # Use all training images (60000)
N_tst = img0_tst.shape[0]  # Use all test images (10000)
IMG_SIDE = 28  # Flattened rows are reshaped to IMG_SIDE x IMG_SIDE for shifting


def _jitter_images(images):
    """Return a copy of `images` with every row circularly shifted at random.

    Each flattened image is viewed as IMG_SIDE x IMG_SIDE and rolled by an
    independent (dy, dx) offset along rows and columns respectively.

    NOTE(review): offsets are drawn from [-2, 2) exactly as the original
    sampling did; the upper bound is exclusive, so a shift of +2 is never
    produced. Confirm whether a symmetric range was intended.
    """
    n_img = images.shape[0]
    frames = images.reshape(n_img, IMG_SIDE, IMG_SIDE)
    shifts = np.random.randint(-2, 2, size=(n_img, 2))  # columns: dy, dx

    jittered = np.empty_like(frames)
    # Only a handful of distinct offsets exist, so roll each group at once
    # instead of looping over every image.
    for dy in range(-2, 2):
        for dx in range(-2, 2):
            group = (shifts[:, 0] == dy) & (shifts[:, 1] == dx)
            if group.any():
                jittered[group] = np.roll(frames[group], shift=(dy, dx), axis=(1, 2))
    return jittered.reshape(images.shape)


def _wrong_labels(labels):
    """Return labels in 0..9 that always differ from `labels`.

    Adds a random offset in 1..9 and wraps modulo 10, so the result can
    never equal the original label.
    """
    offsets = np.random.randint(1, 10, size=labels.shape)
    return (labels.astype(np.int64) + offsets) % 10


tic = time.time()

for epoch in range(epochs):
    # Random jittering of ALL training images (the previous loop only
    # touched the first 10 images, leaving the rest un-augmented).
    img = _jitter_images(img0)

    perm = np.random.permutation(N_trn)
    img_perm = img[perm]  # shuffle once and reuse for both encodings
    lab_perm = lab[perm]

    img_pos = one_hot_encode(img_perm, lab_perm)  # Good data (true label)
    img_neg = one_hot_encode(img_perm, _wrong_labels(lab_perm))  # Bad data (random error in label)

    L_tot = 0
    perceptron.set_train(True)

    for i in range(0, N_trn, batch_size):
        img_pos_batch = Tensor(img_pos[i:i+batch_size], dtype=ms.float32)
        img_neg_batch = Tensor(img_neg[i:i+batch_size], dtype=ms.float32)

        # value_and_grad already returns the loss, so a separate forward
        # pass is not needed.
        loss, grads = grad_fn(img_pos_batch, img_neg_batch)
        optimizer(grads)  # Update parameters
        L_tot += loss.asnumpy()  # Accumulate total loss for epoch

    # Test model with validation set
    perceptron.set_train(False)

    # Evaluate goodness for all test images and candidate labels 0...9
    g_tst = np.zeros((10, N_tst), dtype=np.float32)
    for n in range(10):
        img_tst = one_hot_encode(img0_tst, np.full_like(lab_tst, n))
        img_tst = Tensor(img_tst, dtype=ms.float32)
        g_tst[n] = (perceptron(img_tst) ** 2).mean(axis=1).asnumpy()

    # The predicted label is the candidate with the highest goodness
    predicted_label = g_tst.argmax(axis=0)

    # Count number of correctly classified images in validation set
    Ncorrect = (predicted_label == lab_tst).sum()

    print(f"Epoch {epoch+1}:\tLoss {L_tot}\tTime {round(time.time() - tic)}s\tTest Error {100 - Ncorrect / N_tst * 100}%")

Using device: GPU
Epoch 1:	Loss 59030.00390625	Time 56s	Test Error 12.299999999999997%
Epoch 2:	Loss 30825.529296875	Time 110s	Test Error 9.350000000000009%
Epoch 3:	Loss 22349.34375	Time 166s	Test Error 7.8999999999999915%
Epoch 4:	Loss 18233.3671875	Time 221s	Test Error 6.780000000000001%
Epoch 5:	Loss 15429.23046875	Time 276s	Test Error 6.179999999999993%
Epoch 6:	Loss 13746.412109375	Time 331s	Test Error 5.679999999999993%
Epoch 7:	Loss 12297.1728515625	Time 387s	Test Error 5.210000000000008%
Epoch 8:	Loss 11008.462890625	Time 442s	Test Error 4.939999999999998%
Epoch 9:	Loss 10210.99609375	Time 500s	Test Error 4.579999999999998%
Epoch 10:	Loss 9538.9736328125	Time 556s	Test Error 4.200000000000003%
Epoch 11:	Loss 8841.65625	Time 612s	Test Error 3.9599999999999937%
Epoch 12:	Loss 8372.162109375	Time 667s	Test Error 4.039999999999992%
Epoch 13:	Loss 7836.30517578125	Time 723s	Test Error 3.8400000000000034%
Epoch 14:	Loss 7419.05322265625	Time 778s	Test Error 3.3100000000000023%
Epoch