# Testbench for `mnist_app()`

In [None]:
#
# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Author: Mark Rollins

In [None]:
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
from tensorflow.keras import mixed_precision
import numpy as np
import os.path

## Layer Definition

In [None]:
# ConvNet topology (Keras functional API). The layer names mirror the
# per-layer source directories whose trained weights are loaded below.
inputs = keras.Input(shape=(28, 28, 1), name="input")
x1 = layers.Conv2D(filters=16, kernel_size=3, activation="relu", name="conv2d_w1")(inputs)
x2 = layers.MaxPooling2D(pool_size=2, name="max_pooling2d_w2")(x1)
x3 = layers.Conv2D(filters=64, kernel_size=3, activation="relu", name="conv2d_w3")(x2)
x4 = layers.MaxPooling2D(pool_size=2, name="max_pooling2d_w4")(x3)
x5 = layers.Conv2D(filters=128, kernel_size=3, activation="relu", name="conv2d_w5")(x4)
x6 = layers.Flatten()(x5)
outputs = layers.Dense(10, activation="softmax")(x6)
model = keras.Model(inputs=inputs, outputs=outputs, name="ConvNet")


def _load_bf16(path, shape=None):
    """Load a whitespace-separated text array, quantize it to bfloat16 and
    return it as float32, optionally reshaped to ``shape``.

    The bfloat16 round-trip makes the Keras reference use the same
    quantized values that are later exported for the AIE graph.
    """
    values = np.loadtxt(path).astype("bfloat16").astype("float32")
    if shape is not None:
        values = np.reshape(values, shape)
    return values


# Trained parameters, loaded in the same order the model expects them.
w1_taps = _load_bf16('../conv2d_w1/taps_trained.txt', (3, 3, 1, 16))
w1_bias = _load_bf16('../conv2d_w1/bias_trained.txt')
w3_taps = _load_bf16('../conv2d_w3/taps_trained.txt', (3, 3, 16, 64))
w3_bias = _load_bf16('../conv2d_w3/bias_trained.txt')
w5_taps = _load_bf16('../conv2d_w5/taps_trained.txt', (3, 3, 64, 128))
w5_bias = _load_bf16('../conv2d_w5/bias_trained.txt')
w7_taps = _load_bf16('../dense_w7/taps_trained.txt', (1152, 10))
w7_bias = _load_bf16('../dense_w7/bias_trained.txt')
model.set_weights((w1_taps, w1_bias, w3_taps, w3_bias,
                   w5_taps, w5_bias, w7_taps, w7_bias))

In [None]:
# Print a per-layer text summary of the ConvNet assembled above.
model.summary()

## Load MNIST Images

In [None]:
# Set batch size and number of batches.
BS = 1
NB_tst = 12  # Must be a multiple of 4: NUM_ITER below is NB_tst / 4.
# Enforce the constraint instead of silently rounding NUM_ITER; check it
# before the (slow) dataset load so a bad setting fails immediately.
if NB_tst % 4 != 0:
    raise ValueError(f"NB_tst must be a multiple of 4, got {NB_tst}")

# Load the MNIST database and keep only the first NB_tst*BS test images.
(trn_images, trn_labels), (tst_images, tst_labels) = mnist.load_data()
tst_images = tst_images.reshape((-1, 28, 28, 1))
tst_images = tst_images[:NB_tst * BS, :, :, :]
tst_labels = tst_labels[:NB_tst * BS]
# Scale pixels to [0, 1], then round through bfloat16 so the Keras
# reference sees the same values the AIE graph is fed (as bfloat16).
tst_images = tst_images.astype("float32") / 255
tst_images = (tst_images.astype("bfloat16")).astype("float32")

# Emit the iteration count consumed by the C++ side. Integer division
# (validated above) keeps this exact; ':5d' preserves the original layout.
with open("num_iter.h", "w") as fid:
    fid.write('//\n')
    fid.write('// Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.\n')
    fid.write('// SPDX-License-Identifier: MIT\n')
    fid.write('//\n')
    fid.write('//  Author: Mark Rollins\n')
    fid.write(f'#define NUM_ITER {NB_tst // 4:5d}\n')

## Compute Golden Outputs

In [None]:
# Golden reference: run the Keras model one image at a time and quantize
# its class scores through bfloat16 (returned as float32).
predict = model.predict(tst_images, batch_size=1).astype("bfloat16").astype("float32")
print(predict.shape)

## Store Layer Inputs

In [None]:
# Network input images: write them four values per line to a text file,
# and keep a flattened single-column copy for the functional simulation.
print(tst_images.shape)
np.savetxt('data/ifm_i.txt', tst_images.reshape(-1, 4), fmt='%f %f %f %f')
ifm_i = tst_images.reshape(-1, 1)

## Store Layer Outputs

In [None]:
# Golden outputs: write four values per line and keep a single-column copy
# for the comparison against the AIE results.
print(predict.shape)
np.savetxt('data/ofm_o.txt', predict.reshape(-1, 4), fmt='%f %f %f %f')
ofm_o = predict.reshape(-1, 1)

## Store Weights & Bias

In [None]:
# Layer-1 parameters: flattened kernel followed by its bias, as one column.
taps = np.vstack((w1_taps.reshape(-1, 1), w1_bias.reshape(-1, 1)))
np.savetxt('data/wts1_i.txt', taps.reshape(-1, 4), fmt='%f %f %f %f')
wts1_i = taps.reshape(-1, 1)

In [None]:
# Layer-3 parameters. Split the 64 output channels into 16 groups of 4 and
# the 16 input channels into two halves of 8. Then emit each (3,3,8,4) tile
# in order: output group g = 0..15 (outer), input half h = 0..1 (inner).
# That is w3_taps[:, :, 8h:8h+8, 4g:4g+4] for each (g, h), followed by the
# 64 biases.
# The reshape/transpose replaces 32 hand-written slices with the same
# element order. It mirrors the layout scheme used for the w5 weights.
taps = np.reshape(w3_taps, (3, 3, 2, 8, 16, 4))   # (kh, kw, h, c, g, k)
taps = np.transpose(taps, (4, 2, 0, 1, 3, 5))     # (g, h, kh, kw, c, k)
taps = np.concatenate((np.reshape(taps, (-1, 1)), np.reshape(w3_bias, (-1, 1))), axis=0)
np.savetxt('data/wts3_i.txt', np.reshape(taps, (-1, 4)), fmt='%f %f %f %f')
wts3_i = np.reshape(taps, (-1, 1))

In [None]:
# Layer-5 parameters are split across four files. Each file carries the
# taps for 32 output channels (8 groups of 4 after the transpose), followed
# by the matching 32 biases.
taps = np.transpose(np.reshape(w5_taps, (3, 3, 8, 8, 32, 4)), (4, 2, 0, 1, 3, 5))
taps0, taps1, taps2, taps3 = (
    np.concatenate((np.reshape(taps[8 * k:8 * (k + 1), :, :, :, :, :], (-1, 1)),
                    np.reshape(w5_bias[32 * k:32 * (k + 1)], (-1, 1))), axis=0)
    for k in range(4)
)
for k, part in enumerate((taps0, taps1, taps2, taps3)):
    np.savetxt(f'data/wts5_{k}_i.txt', np.reshape(part, (-1, 4)), fmt='%f %f %f %f')
wts5_0_i, wts5_1_i, wts5_2_i, wts5_3_i = (
    np.reshape(part, (-1, 1)) for part in (taps0, taps1, taps2, taps3)
)

In [None]:
# Dense-layer parameters: transpose the (1152, 10) kernel to output-major
# order, flatten it, append the 10 biases, then zero-pad by 6 values for
# DM alignment & 64-bit PLIO.
taps = np.reshape(w7_taps, (1152, 10)).T
taps = np.concatenate((taps.reshape(-1, 1), w7_bias.reshape(-1, 1), np.zeros((6, 1))), axis=0)
np.savetxt('data/wts7_i.txt', taps.reshape(-1, 4), fmt='%f %f %f %f')
wts7_i = taps.reshape(-1, 1)

## Run Vitis Functional Simulation

In [None]:
import os
import vfs

# Build the Vitis functional-simulation model of the AIE graph from
# mnist_app.cpp, with the per-layer kernel directories on the include path.
mnist_graph = vfs.aieGraph(
    input_file='mnist_app.cpp',
    part="xcve2802-vsvh1760-2MP-e-S",
    include_paths=['./', '../wts_init', '../conv2d_w1', '../conv2d_w3', '../conv2d_w5',
                   '../max_pooling2d_w2', '../max_pooling2d_w4', '../dense_w7'])

# Positional inputs: images first, then each layer's parameter stream.
# NOTE(review): the order presumably matches the graph's input ports in
# mnist_app.cpp. Confirm there before reordering.
graph_inputs = (ifm_i, wts1_i, wts3_i, wts5_0_i, wts5_1_i, wts5_2_i, wts5_3_i, wts7_i)
act_o = mnist_graph.run(*(vfs.array(column[:, 0], vfs.bfloat16) for column in graph_inputs))

In [None]:
# Compare the AIE functional-simulation output with the Keras golden output.
# The AIE graph emits 16 values per image (10 scores + zero padding); Keras
# emits 10. New names are used instead of overwriting act_o / ofm_o, so this
# cell can be re-run without the reshape to 16 columns failing.
act_mat = np.reshape(np.array(act_o), (-1, 16))[:, :10]   # Remove zero pad
ofm_mat = np.reshape(ofm_o, (-1, 10))
if act_mat.shape != ofm_mat.shape:
    raise ValueError(
        f"AIE output shape {act_mat.shape} does not match Keras output shape {ofm_mat.shape}")

tmp = np.reshape(ofm_mat, (-1, 1))
error = np.reshape(act_mat - ofm_mat, (-1, 1))
# Tolerance band: 2**-8 of the peak reference magnitude (presumably chosen
# to match bfloat16's 8-bit significand).
lvl = np.max(np.abs(tmp))
tol = (0.5 ** 8) * lvl
lvl_min = np.full((len(tmp), 1), -tol)
lvl_max = np.full((len(tmp), 1), +tol)
print(f"max |error| = {np.max(np.abs(error)):g}, tolerance = {tol:g}")

tt = np.arange(0, len(tmp))
fig, ax = plt.subplots(nrows=1, ncols=2)
ax[0].plot(tt, error, color="b")
ax[0].plot(tt, lvl_min, color="r")
ax[0].plot(tt, lvl_max, color="r")
ax[0].set_title("Error: Keras vs. AIE")
ax[1].plot(tt, tmp, label="Keras")
ax[1].plot(tt, np.reshape(act_mat, (-1, 1)), linestyle="dashed", label="AIE")
ax[1].set_title("ConvNet Outputs")
ax[1].legend()
plt.show()