In [1]:
import os
import warnings
from os.path import join as pjoin

import h5py
import torch
import torch.nn.functional as F
from braindecode.torch_ext.optimizers import AdamW
from braindecode.torch_ext.util import set_random_seeds
from brevitas.export import export_qonnx
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.util.cleanup import cleanup as qonnx_cleanup

from models.quantized_deep4 import QuantDeep4Net
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN

## Load dataset

In [2]:
# Which subject's group to read from the HDF5 file, and where that file lives.
subj = 3
datapath = "./dataset/KU_mi_smt.h5"

# Fix the random seeds (braindecode helper) for reproducible runs; the model is
# kept on CPU in this notebook, so CUDA seeding is disabled.
set_random_seeds(seed=20200205, cuda=False)

In [3]:
# Read this subject's trials (X) and labels (Y) fully into memory.
# HDF5 group paths always use "/" as separator, so build them explicitly rather
# than with os.path.join (which would insert "\" on Windows). The context
# manager guarantees the file is closed even if a read fails.
with h5py.File(datapath, 'r') as dfile:
    dpath = '/s' + str(subj)
    X = dfile[f'{dpath}/X'][:]
    Y = dfile[f'{dpath}/Y'][:]

In [4]:
# Contiguous split by trial index: first N_TRAIN trials for training, the next
# N_VAL for validation, and everything after that for testing.
N_TRAIN, N_VAL = 200, 100
assert len(X) == len(Y), "X and Y must hold the same number of trials"
assert len(X) > N_TRAIN + N_VAL, (
    f"need more than {N_TRAIN + N_VAL} trials for a non-empty test set, got {len(X)}"
)

X_train, Y_train = X[:N_TRAIN], Y[:N_TRAIN]
X_val, Y_val = X[N_TRAIN:N_TRAIN + N_VAL], Y[N_TRAIN:N_TRAIN + N_VAL]
X_test, Y_test = X[N_TRAIN + N_VAL:], Y[N_TRAIN + N_VAL:]

# NOTE(review): assumes Y contains exactly two distinct labels — confirm.
n_classes = 2
# Axis 1 of X is the channel axis (axis 2 is used as the time length below).
in_chans = X.shape[1]

In [5]:
# Sanity-check array shapes: X is (trials, channels, samples), Y is (trials,).
print(f"{X.shape} {Y.shape}")

(400, 62, 1000) (400,)


## Define Model

In [6]:
# Quantized Deep4 network (2-bit weights and activations), sized from the data.
# The model is kept on CPU (no .cuda() call).
# NOTE(review): the intent is a single output in the time dimension (what
# final_conv_length='auto' would compute), but 1 is passed here — confirm that
# QuantDeep4Net yields one time step with this setting.
model = QuantDeep4Net(
    in_chans=in_chans,
    n_classes=n_classes,
    input_time_length=X.shape[2],
    final_conv_length=1,
    split_first_layer=False,
    weight_bit_width=2,
    act_bit_width=2,
)

In [7]:
# Hyperparameters that work well for the deep model: lr 1e-2, weight decay 5e-4.
optimizer = AdamW(model.parameters(), weight_decay=5e-4, lr=1e-2)
model.compile(
    loss=F.nll_loss,
    optimizer=optimizer,
    iterator_seed=1,
)

## Train Model

In [8]:
%%time
with warnings.catch_warnings():
    # Suppress all warnings, but only for the duration of training.
    warnings.simplefilter('ignore')
    model.fit(
        X_train,
        Y_train,
        epochs=200,
        batch_size=16,
        scheduler='cosine',
        validation_data=(X_val, Y_val),
        remember_best_column='valid_loss',
    )

CPU times: user 17min 43s, sys: 1min 4s, total: 18min 48s
Wall time: 4min 42s


## Test Model

In [9]:
%%time
with warnings.catch_warnings():
    warnings.simplefilter(action='ignore')
    # Score the held-out test trials; the result is a dict of metrics.
    test_loss = model.evaluate(X_test, Y_test)

CPU times: user 822 ms, sys: 35.9 ms, total: 858 ms
Wall time: 213 ms


In [None]:
# Record these numbers: the FPGA deployment should reproduce them
# (accuracy = 1 - misclass).
print(f"{test_loss}")

{'loss': 0.6183669567108154, 'misclass': 0.31999999999999995, 'runtime': 0.00039458274841308594}


## Export Model to QONNX/FINN

In [11]:
# Output location for the exported ONNX artifacts. The file name is derived
# from `subj` so exporting another subject does not reuse the s3 name.
model_dir = "./onnx"
# Create the directory up front so writing the export does not fail on a
# fresh checkout.
os.makedirs(model_dir, exist_ok=True)
model_file = pjoin(model_dir, f"model_s{subj}.onnx")

In [12]:
# Export on CPU in eval mode, using a random dummy input shaped
# (1, in_chans, n_samples, 1). The shape is derived from the loaded data
# (not hard-coded) so it stays consistent with the model construction.
model.network.cpu().eval()
input_t = torch.randn(1, in_chans, X.shape[2], 1)

# Export model to QONNX
export_qonnx(model.network, export_path=model_file, input_t=input_t);

In [13]:
# Clean up the exported QONNX model; output file names are derived from `subj`.
clean_file = pjoin(model_dir, f"model_s{subj}_clean.onnx")
qonnx_cleanup(model_file, out_file=clean_file)

# Convert QONNX → FINN
finn_file = pjoin(model_dir, f"model_s{subj}_finn.onnx")
mw = ModelWrapper(clean_file)
mw = mw.transform(ConvertQONNXtoFINN())
mw.save(finn_file)