In [None]:
import onnx
import torch
import numpy as np
import torch.nn.functional as F

from quantized_deep4 import QuantDeep4Net
from braindecode.torch_ext.optimizers import AdamW
from braindecode.torch_ext.util import set_random_seeds
from brevitas.export import export_qonnx
from qonnx.util.cleanup import cleanup as qonnx_cleanup
from qonnx.core.modelwrapper import ModelWrapper
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN

In [None]:
# --- Synthetic, learnable within-subject data ---
# Array layout is (N, C, T, 1) = (400, 62, 1000, 1).
rng = np.random.default_rng(42)
N, C, T = 400, 62, 1000

# Start from standard-normal noise everywhere, stored as float32.
X = rng.normal(0, 1, size=(N, C, T, 1)).astype(np.float32)

# Balanced binary labels: N//2 zeros followed by ones, then shuffled in place
# with the same generator.
Y = np.zeros(N, dtype=np.int64)
Y[N // 2:] = 1
rng.shuffle(Y)

# Give class 1 a weak signature so there is something to learn:
# a sinusoid with 8 cycles across the T samples (8 Hz if the window is read
# as one second) and amplitude 0.25.
t = np.linspace(0, 1, T, dtype=np.float32)
sinusoid = 0.25 * np.sin(2 * np.pi * 8 * t)
mask = Y == 1
# The (T,) sinusoid broadcasts across the selected trials and channels 0-4.
X[mask, :5, :, 0] += sinusoid

# Per-channel standardisation: statistics are pooled over trials, time and
# the trailing singleton axis; the small epsilon guards against division by 0.
mean = np.mean(X, axis=(0, 2, 3), keepdims=True)
std = np.std(X, axis=(0, 2, 3), keepdims=True) + 1e-6
X = ((X - mean) / std).astype(np.float32)

# Sanity check of the final layout:
#   X.shape == (N, 62, 1000, 1)
#   Y.shape == (N,)
print(X.shape)
print(Y.shape)

In [None]:
# Seed the global RNGs before any network is built so weight initialisation
# and training are repeatable on Restart & Run All. The model stays on CPU in
# this notebook, hence cuda=False (switch to True if .cuda() is enabled below).
set_random_seeds(seed=42, cuda=False)

# Contiguous train / validation / test split (labels were shuffled when the
# data was generated).
# NOTE(review): trials 0-49 are not assigned to any split — confirm this is
# intentional.
X_train, Y_train = X[50:250], Y[50:250]
X_val, Y_val = X[250:300], Y[250:300]
X_test, Y_test = X[300:], Y[300:]

n_classes = 2
in_chans = X.shape[1]

# NOTE(review): an earlier comment here described final_conv_length='auto'
# (a single output in the time dimension), but the value actually passed is 1
# — confirm which one is intended.
# Append .cuda() to the constructor call to train on GPU.
model = QuantDeep4Net(
    in_chans=in_chans,
    n_classes=n_classes,
    input_time_length=X.shape[2],
    final_conv_length=1,
    split_first_layer=False,
    quant_bit_width=2,
)

In [None]:
# Build the network so that its parameters exist and can be counted.
# NOTE(review): create_network() may return a fresh instance that is separate
# from the model.network used for training — confirm against QuantDeep4Net.
net = model.create_network()

# Materialise the parameter list once and reuse it for both tallies.
net_params = list(net.parameters())

# Parameters that receive gradients.
num_params = sum(p.numel() for p in net_params if p.requires_grad)
print(f"Trainable parameters: {num_params:,}")

# Every parameter, whether or not it requires gradients.
total_params = sum(p.numel() for p in net_params)
print(f"Total parameters (incl. frozen): {total_params:,}")

In [None]:
# Hyperparameters reported as good values for the deep model.
learning_rate = 1 * 0.01
weight_decay = 0.5 * 0.001

optimizer = AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

# Attach the loss and optimizer to the model; iterator_seed is fixed to 1.
model.compile(
    loss=F.nll_loss,
    optimizer=optimizer,
    iterator_seed=1,
)

In [None]:
# Train for 10 epochs with batches of 16 and the 'cosine' scheduler, passing
# (X_val, Y_val) as validation data.
# remember_best_column='valid_loss' is currently not passed.
model.fit(
    X_train,
    Y_train,
    epochs=10,
    batch_size=16,
    scheduler='cosine',
    validation_data=(X_val, Y_val),
)

In [None]:
# Evaluate the trained model on the held-out test split and print the result.
# NOTE(review): despite its name, test_loss holds whatever evaluate() returns,
# which may be more than a single loss value — confirm.
test_loss = model.evaluate(
    X_test,
    Y_test,
)
print(test_loss)

In [None]:
# Single output path: each stage below reads and overwrites this same file.
ready_model_filename = "model.onnx"

# Export from CPU with the network switched to eval mode.
model.network.cpu().eval()

# Dummy input for the export: one trial with the same per-trial shape as the
# data the model was configured from (in_chans / input_time_length come from
# X.shape), so the exported graph cannot drift from the trained model.
input_t = torch.randn(1, *X.shape[1:])

# Export to QONNX, then run the QONNX cleanup pass in place.
export_qonnx(model.network, export_path=ready_model_filename, input_t=input_t)
qonnx_cleanup(ready_model_filename, out_file=ready_model_filename)

# Convert QONNX -> FINN and save over the same file.
mw = ModelWrapper(ready_model_filename)
mw = mw.transform(ConvertQONNXtoFINN())
mw.save(ready_model_filename)
