In [1]:
import os
import sys
import onnx
import json
import h5py
import torch
import logging
import warnings
import numpy as np
import torch.nn.functional as F

from os.path import join as pjoin
from braindecode.models.deep4 import Deep4Net
from braindecode.torch_ext.optimizers import AdamW
from braindecode.torch_ext.util import set_random_seeds
from brevitas.export import export_qonnx
from qonnx.util.cleanup import cleanup as qonnx_cleanup
from qonnx.core.modelwrapper import ModelWrapper
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN

In [2]:
sys.path.append("../src")

from models.quantized_deep4 import QuantDeep4Net

In [3]:
datapath = "../dataset/KU_mi_smt.h5"
subj = 3


set_random_seeds(seed=20200205, cuda=False)

In [4]:
dfile = h5py.File(datapath, 'r')
dpath = '/s' + str(subj)
X = dfile[pjoin(dpath, 'X')][:]
Y = dfile[pjoin(dpath, 'Y')][:]
dfile.close()

In [5]:
X_train, Y_train = X[:200], Y[:200]
X_val, Y_val = X[200:300], Y[200:300]
X_test, Y_test = X[300:], Y[300:]
n_classes = 2
in_chans = X.shape[1]

In [6]:
print(X.shape, Y.shape)

(400, 62, 1000) (400,)


In [7]:
# final_conv_length = auto ensures we only get a single output in the time dimension
model = QuantDeep4Net(in_chans=in_chans, n_classes=n_classes,
                 input_time_length=X.shape[2],
                 final_conv_length=1, split_first_layer=False, act_bit_width=2, weight_bit_width=2)#.cuda()

In [8]:
# make the network so parameters exist
net = model.create_network()

# Count parameters
num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f"Trainable parameters: {num_params:,}")

# If you also want frozen ones:
total_params = sum(p.numel() for p in net.parameters())
print(f"Total parameters (incl. frozen): {total_params:,}")

Trainable parameters: 251,733
Total parameters (incl. frozen): 251,733


In [9]:
# these are good values for the deep model
optimizer = AdamW(model.parameters(), lr=1 * 0.01, weight_decay=0.5*0.001)
model.compile(loss=F.nll_loss, optimizer=optimizer, iterator_seed=1, )

In [10]:
%%time
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    model.fit(X_train, Y_train, epochs=200, batch_size=16, scheduler='cosine', 
            validation_data=(X_val, Y_val), remember_best_column='valid_loss')

CPU times: user 13min 41s, sys: 37.1 s, total: 14min 18s
Wall time: 3min 34s


In [15]:
%%time
test_loss = model.evaluate(X_test, Y_test)

CPU times: user 590 ms, sys: 31.9 ms, total: 622 ms
Wall time: 163 ms


  self.epochs_df = self.epochs_df.append(row_dict, ignore_index=True)


In [11]:
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    test_loss = model.evaluate(X_test, Y_test)
print(test_loss)

{'loss': 0.6158695220947266, 'misclass': 0.30000000000000004, 'runtime': 0.00039577484130859375}


In [12]:
model_dir = os.environ['ROOT_DIR'] + "/onnx"
model_file = model_dir + "/model_s3_new.onnx"

In [13]:
# CPU + eval + dummy input
model.network.cpu().eval()
input_t = torch.randn(1, 62, 1000, 1)

# Export → clean
export_qonnx(model.network, export_path=model_file, input_t=input_t)

ir_version: 7
producer_name: "pytorch"
producer_version: "1.13.1"
graph {
  node {
    input: "inp.1"
    input: "/input_quant/act_quant/export_handler/Constant_1_output_0"
    input: "/input_quant/act_quant/export_handler/Constant_2_output_0"
    input: "/input_quant/act_quant/export_handler/Constant_output_0"
    output: "/input_quant/act_quant/export_handler/Quant_output_0"
    name: "/input_quant/act_quant/export_handler/Quant"
    op_type: "Quant"
    attribute {
      name: "narrow"
      i: 0
      type: INT
    }
    attribute {
      name: "rounding_mode"
      s: "ROUND"
      type: STRING
    }
    attribute {
      name: "signed"
      i: 1
      type: INT
    }
    domain: "onnx.brevitas"
  }
  node {
    input: "/conv_time/weight_quant/export_handler/Constant_1_output_0"
    input: "/conv_time/weight_quant/export_handler/Constant_2_output_0"
    input: "/input_quant/act_quant/export_handler/Constant_2_output_0"
    input: "/conv_time/weight_quant/export_handler/Constant_o

In [14]:
clean_file = model_dir + "/model_s3_clean.onnx"
qonnx_cleanup(model_file, out_file=clean_file)

# Convert QONNX → FINN (no explicit dtype needed)
finn_file = model_dir + "/model_s3_finn.onnx"
mw = ModelWrapper(clean_file)
mw = mw.transform(ConvertQONNXtoFINN())
mw.save(finn_file)