# MLP with floating point weights and baremetal C implementation

In [1]:
import sys
import copy
import time
import serial
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import TensorDataset, DataLoader

import src.python.dataset.yalefaces as yalefaces
import src.python.model.util as util

np.random.seed(99)

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

## Load dataset
### Load and normalize the raw dataset

In [3]:
X, y = yalefaces.load("dataset/yalefaces", flatten=True)
X = X.astype("float32") / 255.0

### Split dataset into train and test

In [4]:
X_train_raw, X_test_raw, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, shuffle=True, stratify=y)

### Compress dataset with PCA

In [5]:
num_train_faces, num_pixels = X_train_raw.shape
num_principal_components = int(num_train_faces)
pca = PCA(n_components=num_principal_components)

pca.fit(X)

X_train = pca.transform(X_train_raw)
X_test = pca.transform(X_test_raw)

In [6]:
print(X_train.shape, X_test.shape)

(132, 132) (33, 132)


### Convert datasets to pythorch tensors

In [7]:
train_dataset = TensorDataset(torch.Tensor(X_train), torch.LongTensor(y_train))
test_dataset = TensorDataset(torch.Tensor(X_test), torch.LongTensor(y_test))

## Train MLP

In [8]:
class MLP(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.layers = torch.nn.Sequential(
      torch.nn.Linear(132, 96, bias=True),
      torch.nn.ReLU(),
      torch.nn.Linear(96, 15, bias=True),
    )

  def forward(self, x):
    return self.layers(x)

In [9]:
model = MLP()
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=0.0001)

In [10]:
train_accs= []
train_losses = []

for epoch in range(200):
    train_data = DataLoader(train_dataset, batch_size=64, shuffle=True)
    
    error, num_samples = util.train(model, device, train_data, optimizer)
    loss = float(error)/float(num_samples)
    train_losses.append(loss)
    
    acc = util.test(model, device, train_data)
    train_accs.append(acc)

In [11]:
test_data = DataLoader(test_dataset, batch_size=len(test_dataset))
acc = util.test(model, device, test_data)
print(f"Test accuracy: {acc * 100:.2f}")

Test accuracy: 84.85


## Quantize model

In [12]:
# copy our original model
qmodel_float = copy.deepcopy(model.layers)
qmodel_float.eval()

Sequential(
  (0): Linear(in_features=132, out_features=96, bias=True)
  (1): ReLU()
  (2): Linear(in_features=96, out_features=15, bias=True)
)

In [13]:
# fuse layers (weights+activation)
torch.quantization.fuse_modules(qmodel_float, ['0', '1'], inplace=True)

# add quantization of input and output
qmodel_float = torch.nn.Sequential(
    torch.quantization.QuantStub(),
    *qmodel_float,
    torch.quantization.DeQuantStub()
)

In [14]:
# configure quantization
qmodel_float.qconfig = torch.quantization.default_qconfig
qmodel_float=qmodel_float.to('cpu')
qmodel_float.qconfig

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, quant_min=0, quant_max=127){}, weight=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){})

In [15]:
# initialize quantization parameters
torch.quantization.prepare(qmodel_float, inplace=True)

Sequential(
  (0): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (1): LinearReLU(
    (0): Linear(in_features=132, out_features=96, bias=True)
    (1): ReLU()
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (2): Identity()
  (3): Linear(
    in_features=96, out_features=15, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (4): DeQuantStub()
)

In [16]:
# learn quantization parameters from test samples
with torch.inference_mode():
    for batch_idx, (x, y) in enumerate(test_data):
        x,y = x.to('cpu'), y.to('cpu')
        qmodel_float(x)

# quantize weights
qmodel = torch.quantization.convert(qmodel_float, inplace=False)

In [17]:
print("Weight size before quantization:", qmodel_float[1][0].weight.element_size(), "byte(s)")
print("Weight size after quantization:", qmodel[1].weight().element_size(), "byte(s)")

Weight size before quantization: 4 byte(s)
Weight size after quantization: 1 byte(s)


In [18]:
# measure accuracy of the quantized model
test_data = DataLoader(test_dataset, batch_size=len(test_dataset))
acc = util.test(qmodel, 'cpu', test_data)
print(f"Test accuracy: {acc * 100:.2f}")

Test accuracy: 84.85


## Export model weights as C code

In [19]:
model_params = model.state_dict()
qmodel_params = qmodel.state_dict()

In [20]:
layer_indexes = [1, 3]

src_filename = 'src/embedded/2-mlp-baremetal-int8/esp32s3/main/mlp_weights.c'
hdr_filename = 'src/embedded/2-mlp-baremetal-int8/esp32s3/include/mlp_weights.h'

model.cpu()

with open(src_filename, 'w') as source, open(hdr_filename, 'w') as header:

    # header preamble
    header.write('#ifndef MLP_WEIGHTS\n#define MLP_WEIGHTS\n\n')
    header.write('#include <stdint.h>\n\n')

    # source includes
    source.write('#include "mlp_weights.h"\n\n')

    # input: quantization params
    x_scale, x_zero = util.get_input_qparams(qmodel_params)
    header.write(f"extern const float input_zero;\n")
    header.write(f"extern const float input_scale;\n\n")
    source.write(f"const float input_zero = {x_zero};\n")
    source.write(f"const float input_scale = {x_scale};\n\n")

    for layer in layer_indexes:
        weights = util.get_weights(model_params, layer-1).flatten()
        weights_scale, layer_scale = util.get_scale(qmodel_params, layer)
        weights_zero, layer_zero = util.get_zero(qmodel_params, layer)

        # quantize weights and bias
        qweights = (np.around(weights / weights_scale) + weights_zero).astype(int)
        qweights -= weights_zero

        # layer and weights: quantization params
        header.write(f"extern const int8_t layer_{layer}_weights_zero;\n")
        header.write(f"extern const float layer_{layer}_weights_scale;\n\n")
        header.write(f"extern const int8_t layer_{layer}_zero;\n")
        header.write(f"extern const float layer_{layer}_scale;\n\n")

        source.write(f"const int8_t layer_{layer}_weights_zero = {weights_zero};\n")
        source.write(f"const float layer_{layer}_weights_scale = {weights_scale};\n\n")
        source.write(f"const int8_t layer_{layer}_zero = {layer_zero};\n")
        source.write(f"const float layer_{layer}_scale = {layer_scale};\n\n")

        # Weights
        header.write(f"extern const int8_t layer_{layer}_weights[{len(weights)}];\n")
        source.write(f"const int8_t layer_{layer}_weights[{len(weights)}] = {{")
        for i in range(len(weights)-1):
            source.write(f"{qweights[i]}, ")
        source.write(f"{qweights[len(qweights)-1]}}};\n\n")

        header.write(f"extern const int16_t layer_{layer}_weights_s16[{len(weights)}];\n")
        source.write(f"const int16_t layer_{layer}_weights_s16[{len(weights)}] = {{")
        for i in range(len(weights)-1):
            source.write(f"{qweights[i]}, ")
        source.write(f"{qweights[len(qweights)-1]}}};\n\n")

    header.write('\n#endif // end of MLP_PARAMS\n')

## Talk to esp32

In [132]:
CPU_FREQ_KHZ = 240000
num_tests = len(X_test)

with serial.Serial("/dev/ttyUSB0", baudrate=115200, timeout=None) as esp32, tqdm(total=num_tests, file=sys.stdout) as pbar:
    esp32.read_until(b'Ready\n')
    
    num_correct = 0
    all_elapsed = []

    for i in range(num_tests):
        face = X_test[i]
        expected = y_test[i]

        # msg = esp32.read_until(b'Waiting for input\n')
        expected_msg = b'Waiting for input\n'
        msg = esp32.read_until(expected_msg)
        
        # msg = esp32.read(len(expected_msg))
        assert msg == expected_msg, msg
        # assert msg.__contains__(expected_msg), msg

        # send command=1 (new inference)
        esp32.write(b'\x01')

        # send input (132 bytes)
        esp32.write(face.tobytes())

        # read output
        subject = esp32.read(4)
        subject = int.from_bytes(subject, byteorder="little")

        # read inference duration
        elapsed = esp32.read(4)
        elapsed = int.from_bytes(elapsed, byteorder="little")
        all_elapsed.append(elapsed)

        # count correct inferences
        if expected == subject:
            num_correct += 1


        # print(expected, subject)
        # print status
        acc = num_correct/(i+1)
        pbar.set_description(f"Accuracy = {acc*100:.2f}%, Average Inference Duration = {elapsed/CPU_FREQ_KHZ:.3f}ms")
        pbar.update(1)


  0%|          | 0/33 [00:00<?, ?it/s]

Accuracy = 84.85%, Average Inference Duration = 0.383ms: 100%|██████████| 33/33 [00:01<00:00, 18.92it/s]


In [122]:
# CPU_FREQ_KHZ = 240000
# num_tests = len(X_test)

# with serial.Serial("/dev/ttyUSB0", baudrate=115200, timeout=None) as esp32:
#     esp32.read_until(b'Ready\n')
    
#     num_correct = 0
#     all_elapsed = []

#     for i in range(num_tests):
#         face = X_test[i]
#         expected = y_test[i]

#         # msg = esp32.read_until(b'Waiting for input\n')
#         expected_msg = b'Waiting for input\n'
#         msg = esp32.read_until(expected_msg)
        
#         assert msg == expected_msg, msg
        
#         # send command=1 (new inference)
#         esp32.write(b'\x01')

#         # send input (132 bytes)
#         esp32.write(face.tobytes())

#         ### MLP

#         # read quantized input 132*int8_t
#         mlp_quant_input = esp32.read(132)

#         # read MVM output as int32
#         mlp_mvm1 = esp32.read(96*4)

#         # read dequantized MVM output as float
#         mlp_mvm1_float = esp32.read(96*4)


#         # read quantized relu as 96*int8
#         mlp_quant_relu = esp32.read(96)

#         # read MVM 2 output as int32
#         mlp_mvm2 = esp32.read(15*4)

#         # read dequantized MVM 2 output as float
#         mlp_mvm2_float = esp32.read(15*4)



#         #### MLP OPTIMIZED

#         # read quantized input s16
#         mlp_s16_quant_input = esp32.read(132*2)

#         # read MVM int16 output
#         mlp_s16_mvm1 = esp32.read(96*2)
        
#         # read MVM 1 output as float
#         mlp_s16_mvm1_float = esp32.read(96*4)

        
#         # read quantized input s16
#         mlp_s16_quant_relu = esp32.read(96*2)

#         # read MVM 2 int16 output
#         mlp_s16_mvm2 = esp32.read(15*2)
        
#         # read MVM 2 output as float
#         mlp_s16_mvm2_float = esp32.read(15*4)

#         break

### Quantized input

In [124]:
# fail = False
# for i in range(0, 132):
#     a = int.from_bytes(mlp_s16_quant_input[(i*2):(i*2)+2], byteorder="little", signed=True)
#     b = int.from_bytes(mlp_quant_input[i:(i+1)], byteorder="little", signed=True)
    
#     if a != b:
#         fail = True
#         print(f"At index {i}, values {a}/{b} - FAIL")
# if fail:
#     print("Test FAILED")
# else:
#     print("Test PASSED")

Test PASSED


### First MVM

In [127]:
# fail = False
# for i in range(0, 96):
#     a = int.from_bytes(mlp_s16_mvm1[(i*2):(i*2)+2], byteorder="little", signed=True)
#     b = int.from_bytes(mlp_mvm1[(i*4):(i*4)+4], byteorder="little", signed=True)
#     # print(a, end=", ")

#     if a != b:
#         fail = True
#         print(f"At index {i}, values {a} {b} - FAIL")
# if fail:
#     print("Test FAILED")
# else:
#     print("Test PASSED")

    

Test PASSED


### Dequantized MVM 1

In [128]:
# fail = False
# for i in range(0, 96):
#     a = int.from_bytes(mlp_s16_mvm1_float[(i*4):(i*4)+4], byteorder="little", signed=True)
#     b = int.from_bytes(mlp_mvm1_float[(i*4):(i*4)+4], byteorder="little", signed=True)
#     # print(a, end=", ")

#     if a != b:
#         fail = True
#         print(f"At index {i}, values {a} {b} - FAIL")
# if fail:
#     print("Test FAILED")
# else:
#     print("Test PASSED")

Test PASSED


### Quantized ReLU

In [129]:
# fail = False
# for i in range(0, 96):
#     a = int.from_bytes(mlp_s16_quant_relu[(i*2):(i*2)+2], byteorder="little", signed=True)
#     b = int.from_bytes(mlp_quant_relu[i:(i+1)], byteorder="little", signed=True)
#     # print(a, end=", ")

#     if a != b:
#         fail = True
#         print(f"At index {i}, values {a} {b} - FAIL")
# if fail:
#     print("Test FAILED")
# else:
#     print("Test PASSED")

Test PASSED


### Second MVM

In [131]:
# fail = False
# for i in range(0, 15):
#     a = int.from_bytes(mlp_s16_mvm2[(i*2):(i*2)+2], byteorder="little", signed=True)
#     b = int.from_bytes(mlp_mvm2[(i*4):(i*4)+4], byteorder="little", signed=True)
#     # print(a, end=", ")

#     if a != b:
#         fail = True
#         print(f"At index {i}, values {a} {b} -> FAIL")
# if fail:
#     print("Test FAILED")
# else:
#     print("Test PASSED")

At index 1, values -14629 -8369 -> FAIL
At index 2, values 7051 -3127 -> FAIL
At index 3, values 13715 -6174 -> FAIL
At index 4, values 29992 -1609 -> FAIL
At index 5, values -30492 2510 -> FAIL
At index 6, values 4900 -2503 -> FAIL
At index 7, values -8267 1169 -> FAIL
At index 8, values -32145 2438 -> FAIL
At index 9, values 7275 -8095 -> FAIL
At index 10, values -18620 -2687 -> FAIL
At index 11, values 12956 -4674 -> FAIL
At index 12, values 14421 -2791 -> FAIL
At index 13, values 19953 10009 -> FAIL
At index 14, values -7391 -2698 -> FAIL
Test FAILED
