In [1]:
import os
import sys

# Make the TVM Relax Python package importable. The default is the original
# author's local checkout; set the TVM_PYTHON_PATH environment variable to
# point at your own build instead of editing this cell.
TVM_PYTHON_PATH = os.environ.get("TVM_PYTHON_PATH", "/home/user/relax/python")
# Guard against duplicate entries when the cell is re-run.
if TVM_PYTHON_PATH not in sys.path:
    sys.path.append(TVM_PYTHON_PATH)

## Training a simple model with TVM Relax to recognize handwritten digits from the MNIST dataset.

In [2]:
import tvm
from tvm.relay import Call
from tvm import relax, tir
from tvm.relax.testing import nn
from tvm.script import relax as R
import numpy as np

import pandas as pd

from tvm.relax.training import SetupTrainer
from tvm.relax.training.trainer import Trainer
from tvm.script.parser import ir as I
from tvm.relax.training.optimizer import SGD, Adam
from tvm.relax.training.loss import MSELoss, CrossEntropyLoss, L1Loss

In [3]:
import tvm.testing
import numpy as np
from tvm import relax
from tvm.script.parser import ir as I
from tvm.relax.training.optimizer import SGD
from tvm.relax.training.loss import MSELoss
from tvm.relax.training.trainer import Trainer

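### 1. Create a model that consists of two fully connected layers. The model takes a flattened image of shape (1, 784) and returns a vector of shape (1, 10) with one score per digit. The last layer applies ReLU rather than softmax, so the scores are non-negative but are not probabilities.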

In [6]:
# Two-layer fully connected network for MNIST: 784 -> 512 -> 10.
# NOTE(review): presumably the four weight/bias arguments are the trainable
# parameters (params_num = 4 in the training cell) and the image `x` is the
# single data input -- confirm against the Trainer's argument-ordering convention.
@I.ir_module
class MLP:
    @R.function
    def predict(
        w0: R.Tensor((784, 512), "float32"),  # first layer weights
        b0: R.Tensor((1, 512), "float32"),    # first layer bias
        w1: R.Tensor((512, 10), "float32"),   # second layer weights
        b1: R.Tensor((1, 10), "float32"),     # second layer bias
        x: R.Tensor((1, 784), "float32")      # one flattened 28x28 image
    ):
        # Returns a (1, 10) tensor of per-digit scores. The final ReLU makes
        # the scores non-negative; no softmax is applied, so they are not
        # normalized probabilities.
        with R.dataflow():
            lv0 = R.matmul(x, w0)  # (1, 784) @ (784, 512) -> (1, 512)
            lv1 = R.add(lv0, b0)
            lv2 = R.nn.relu(lv1)
            lv3 = R.matmul(lv2, w1)  # (1, 512) @ (512, 10) -> (1, 10)
            lv4 = R.add(lv3, b1)
            out = R.nn.relu(lv4)
            R.output(out)
        return out

In [11]:
# Print the module as TVMScript. Formatted output requires the optional
# 'Black' package, as the message printed below explains.
MLP.show()

To print formatted TVM script, please install the formatter 'Black':
/usr/bin/python3 -m pip install "black==22.3.0" --upgrade --user


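### 2. Load the MNIST dataset from Keras. Each image is flattened to shape (1, 784) and scaled to [0, 1]. Each label is one-hot encoded to shape (1, 10).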

In [7]:
from keras.datasets import mnist
from keras.utils import to_categorical

dtype = "float32"        # dtype of the image tensors fed to the model
target_type = "float32"  # dtype of the one-hot label tensors
NUM_CLASSES = 10         # digits 0-9; matches the model's (1, 10) output

(X_train, Y_train), (_, _) = mnist.load_data()
num_train = len(X_train)  # number of training images (taken from the data, not hard-coded)

# Flatten each 28x28 image into a (1, 784) row and scale pixel values to [0, 1].
X_train = X_train.reshape((num_train, 1, 28 * 28)).astype(dtype) / 255
# One-hot encode the labels: shape (num_train, NUM_CLASSES).
Y_train = to_categorical(Y_train, num_classes=NUM_CLASSES)
# Each label is fed to the trainer as a (1, NUM_CLASSES) row, like the model output.
train_labels = Y_train.astype(target_type).reshape((num_train, 1, NUM_CLASSES))

# List of [image (1, 784), one-hot label (1, 10)] training pairs.
dataset = [[X_train[i], train_labels[i]] for i in range(num_train)]

### 3. Model training.
#### Define `setup_trainer` with an MSE loss function and an SGD optimizer. The result is compared with a similar model trained with Keras using a categorical cross-entropy loss. Cross-entropy is the more appropriate loss for a multi-class classification problem, but categorical cross-entropy loss is not yet supported in TVM. Train the model for 10 epochs. Training with TVM took 20 minutes; training the Keras model for 10 epochs took 20 seconds.

#### SGD with learning rate 0.01, MSELoss with `reduction="sum"`

In [None]:
# Training configuration. The markdown above describes a 10-epoch run over the
# full training set; set NUM_TRAIN_SAMPLES to a small int (e.g. 60) for a quick
# smoke test, or leave it as None to use every sample in `dataset`.
NUM_EPOCHS = 10
NUM_TRAIN_SAMPLES = None

# Both loss arguments (prediction and one-hot target) have shape (1, 10).
pred_sinfo = relax.TensorStructInfo((1, 10), "float32")
setup_trainer = SetupTrainer(
    MSELoss(reduction="sum"),
    SGD(0.01),
    [pred_sinfo, pred_sinfo],
)
params_num = 4  # w0, b0, w1, b1 in MLP.predict
trainer = Trainer(MLP, params_num, setup_trainer)
trainer.build("llvm", tvm.cpu(0))
trainer.xaiver_uniform_init_params()  # spelling matches the Trainer API

train_samples = dataset if NUM_TRAIN_SAMPLES is None else dataset[:NUM_TRAIN_SAMPLES]
assert len(train_samples) > 0, "training dataset is empty"

for epoch in range(NUM_EPOCHS):
    epoch_loss = 0
    # One parameter update per sample; report the mean loss over the epoch.
    for image, label in train_samples:
        epoch_loss += trainer.update_params(image, label).numpy()
    print("#epoch", epoch, "loss=", epoch_loss / len(train_samples))

# Show the trained model's scores for the first training image.
trainer.predict(dataset[0][0])

### 4. Define the evaluation function, which measures accuracy on the MNIST test set.

In [43]:
from sklearn.metrics import accuracy_score
from keras.datasets import mnist

def eval(_trainer, X_test=None, Y_test=None, dtype="float32"):
    """Print and return the classification accuracy of a trained model.

    Each image is cast to ``dtype``, scaled to [0, 1], flattened to a
    (1, N) row and passed to ``_trainer.predict``. The arg-max of the
    returned scores is taken as the predicted digit.

    Note: this name shadows Python's built-in ``eval``. It is kept so that
    the existing ``eval(trainer)`` call keeps working.

    Args:
        _trainer: object whose ``predict(image)`` returns a value with a
            ``.numpy()`` method yielding a (1, num_classes) score array.
        X_test: raw test images with pixel values in 0-255, one image per
            entry of the leading axis. If both ``X_test`` and ``Y_test`` are
            None, the MNIST test split is loaded via ``mnist.load_data()``.
        Y_test: integer class labels matching ``X_test``.
        dtype: dtype the images are cast to before scaling. Defaults to
            "float32", the dtype of the MLP's input tensor.

    Returns:
        The accuracy as a float in [0, 1].

    Raises:
        ValueError: if only one of ``X_test``/``Y_test`` is given, if their
            lengths differ, or if the test set is empty.
    """
    if X_test is None and Y_test is None:
        (_, _), (X_test, Y_test) = mnist.load_data()
    elif X_test is None or Y_test is None:
        raise ValueError("X_test and Y_test must be provided together")

    images = np.asarray(X_test).astype(dtype) / 255
    targets = np.asarray(Y_test).reshape(-1)
    if len(images) != len(targets):
        raise ValueError(
            f"X_test has {len(images)} samples but Y_test has {len(targets)}"
        )
    if len(images) == 0:
        raise ValueError("cannot compute accuracy on an empty test set")

    predicted = np.empty(len(images), dtype=np.int64)
    for i, image in enumerate(images):
        # The model expects one flattened image per call, e.g. (1, 784) for MNIST.
        scores = _trainer.predict(image.reshape(1, -1)).numpy()
        predicted[i] = np.argmax(scores, axis=1)[0]

    accuracy = float(np.mean(predicted == targets))
    print("Accuracy:", accuracy)
    return accuracy

### Result.
#### The model trained with TVM reaches a test accuracy of 0.9822, equal to that of a similar model trained with Keras.

In [44]:
# Evaluate the trained model with the notebook's eval() helper defined above
# (which shadows Python's built-in eval).
eval(trainer)

Accuracy: 0.9822
