# Torch vs. ONNX

Verify that the spike-detection outputs of the PyTorch model and of its ONNX export agree when both are run from Python.


In [1]:
import json
import numpy as np
import torch
import matplotlib.pyplot as plt

## Load Trained Model


In [3]:
# Rebuild the trained spike-detection model from its checkpoint directory:
# init_dict.json supplies the constructor keyword arguments and state_dict.pt
# supplies the weights.
from braindance.core.spikedetector.model import ModelSpikeSorter

with open(
    "checkpoints/spikedetector/mea/init_dict.json", "r", encoding="utf-8"
) as f:
    init_dict = json.load(f)
pytorch_model = ModelSpikeSorter(**init_dict)

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while it is loaded.
state_dict = torch.load(
    "checkpoints/spikedetector/mea/state_dict.pt",
    map_location="cpu",
    weights_only=True,
)
# Last expression: displays the key-matching report returned by load_state_dict
pytorch_model.load_state_dict(state_dict)

<All keys matched successfully>

## Run via Braindance


In [5]:
from braindance.core.spikesorter.rt_sort import detect_sequences

# Run RT-Sort sequence detection on the start of the recording.
# NOTE: the window is expressed in milliseconds, so (0, 5 * 1000) covers the
# first 5 seconds of the recording.
recording_path = "data/MEA_rec_patch_ground_truth_cell7.raw.h5"
inter_path = "data/inter"  # intermediate .npy files are read back from here below
window_ms = (0, 5 * 1000)

rt_sort = detect_sequences(
    recording_path,
    inter_path,
    pytorch_model,
    recording_window_ms=window_ms,
    device="cpu",
    verbose=True,
    debug=True,
    # num_processes=1,  # Uncomment for debugging
)

Skipping saving scaled traces because file scaled_traces.npy already exists and debug=True
Skipping running detection model because file model_outputs.npy already exists and debug=True
Skipping detecting preliminary propagation sequences because file all_clusters.pickle already exists and debug=True
Skipping reassigning spikes because file all_clusters_reassigned.pickle already exists and debug=True
Merging preliminary propagation sequences - first round


100%|██████████| 7/7 [00:02<00:00,  3.19it/s]


7 sequences after first merging
Merging preliminary propagation sequences - second round ...

RT-Sort detected 7 sequences


Output with debug=False
```
Saving traces:
100%|██████████| 1/1 [00:04<00:00,  4.46s/it]
Running detection model:
Compiling detection model for 942 elecs ...
Cannot compile detection model with torch_tensorrt because cannot load torch_tensorrt. Skipping NVIDIA compilation
Allocating disk space to save model traces and outputs ...
Inference scaling: 0.3761194029850746
Running model ...
100%|██████████| 832/832 [09:36<00:00,  1.44it/s]
Detecting sequences
100%|██████████| 942/942 [00:04<00:00, 211.68it/s]
Detected 10 preliminary propagation sequences
Extracting sequences' detections, intervals, and amplitudes

100%|██████████| 10/10 [00:02<00:00,  4.03it/s]
8 clusters remain after filtering
Reassigning spikes to preliminary propagation sequences
Initializing ...
Sorting recording
100%|██████████| 1000/1000 [00:00<00:00, 3377.24it/s]
Extracting sequences' detections, intervals, and amplitudes

100%|██████████| 7/7 [00:02<00:00,  2.80it/s]
7 clusters remain after filtering
Merging preliminary propagation sequences - first round

100%|██████████| 7/7 [00:02<00:00,  3.14it/s]
7 sequences after first merging
Merging preliminary propagation sequences - second round ...

RT-Sort detected 7 sequences
```

## Run via PyTorch

In [9]:
# Reproduce the inference scaling factor: a fixed numerator divided by the
# median (across rows) of each row's IQR over the first 1000 frames
import scipy.stats

scaled_traces = np.load("data/inter/scaled_traces.npy")

# Defaults from rt_sort.py
inference_scaling_numerator = 12.6
pre_median_frames = 1000

per_row_iqr = scipy.stats.iqr(scaled_traces[:, :pre_median_frames], axis=1)
inference_scaling = inference_scaling_numerator / np.median(per_row_iqr)
print("inference_scaling:", inference_scaling)

# Must agree with the "Inference scaling" value logged by detect_sequences above
expected_inference_scaling = 0.3761194029850746
assert np.isclose(inference_scaling, expected_inference_scaling, rtol=1e-5)

inference_scaling: 0.3761194029850746


In [None]:
# Build one model-sized input frame: the first sample_size frames of every row,
# converted to float16, with each row's median subtracted
first_frame = torch.tensor(
    scaled_traces[:, : pytorch_model.sample_size],
    dtype=torch.float16,
)
row_medians = torch.median(first_frame, dim=1, keepdim=True).values
traces_torch = first_frame - row_medians

## Compare PyTorch to Braindance

In [43]:
# Compare the hand-built input frame with the traces saved by Braindance.
# Limit comparison to the first num_output_locs locations as there is some overlap
model_traces = np.load("data/inter/model_traces.npy")
compared_frames = slice(0, pytorch_model.num_output_locs)
inputs_concordant = np.isclose(
    traces_torch.numpy()[:, compared_frames],
    model_traces[:, compared_frames],
    rtol=1e-5,
).all()

# Only announce success once the check has actually passed; a mismatch now
# stops the notebook instead of printing a misleading message.
assert inputs_concordant, "PyTorch model inputs differ from Braindance model_traces.npy"
print("Model inputs are concordant")
inputs_concordant

Model inputs are concordant


np.True_

In [46]:
# Run the PyTorch model directly on the scaled input frame and compare against
# the outputs saved by Braindance.
# no_grad: this is inference only, so skip building an autograd graph.
with torch.no_grad():
    pytorch_outputs = pytorch_model.model(
        traces_torch[:, None, :] * pytorch_model.input_scale * inference_scaling
    )
model_outputs = np.load("data/inter/model_outputs.npy")

compared_locs = slice(0, pytorch_model.num_output_locs)
outputs_concordant = np.isclose(
    pytorch_outputs.detach().numpy()[:, compared_locs],
    model_outputs[:, compared_locs],
    rtol=1e-5,
).all()

# Only announce success once the check has actually passed; a mismatch now
# stops the notebook instead of printing a misleading message.
assert outputs_concordant, "PyTorch model outputs differ from Braindance model_outputs.npy"
print("Model outputs are concordant")
outputs_concordant

Model outputs are concordant


np.True_

# Export to .onnx


In [100]:
import copy

import onnx

# NOTE(review): later cells reference `model_input`, which no live cell defines.
# Its only trace was a commented-out line in this cell:
#     model_input = traces_torch[:, None, :] * pytorch_model.input_scale
# Restore or redefine it before running the "Find Spikes" section.

pytorch_model.model.eval()  # Set model to evaluation mode

# nn.Module.float() casts a module's parameters in place. Export a float32 deep
# copy so pytorch_model.model keeps its original dtype and the float16 cells
# elsewhere in this notebook still run after this cell has been executed.
model = copy.deepcopy(pytorch_model.model).float()

torch.onnx.export(
    model,
    # Dummy float32 input used for tracing: (batch, channel, frames)
    torch.zeros(1, 1, pytorch_model.sample_size, dtype=torch.float32),
    "models/detect-mea.onnx",
    input_names=["input"],
    output_names=["output"],
    # Batch size and sequence length are left dynamic in the exported graph
    dynamic_axes={
        "input": {0: "batch_size", 2: "sequence_length"},
        "output": {0: "batch_size", 2: "sequence_length"},
    },
    opset_version=12,
    verbose=False,
)

In [103]:
import onnxruntime

# Load the exported graph and look up the name/shape of its (first) input
ort_session = onnxruntime.InferenceSession("models/detect-mea.onnx")

input_meta = ort_session.get_inputs()[0]
input_name = input_meta.name
onnx_input_shape = input_meta.shape

# Same scaled frame that was fed to the PyTorch model
test_input = traces_torch[:, None, :] * pytorch_model.input_scale * inference_scaling

# The exported model is float32, so cast the float16 frame before feeding it.
# Key the feed by the name reported by the session rather than a literal.
onnx_outputs = ort_session.run(
    ["output"],
    {input_name: test_input.to(torch.float).numpy()},
)

# Looser tolerance than the earlier checks: the ONNX graph runs in float32
# while the PyTorch reference was computed from float16 inputs.
compared_locs = slice(0, pytorch_model.num_output_locs)
onnx_concordant = np.isclose(
    pytorch_outputs.detach().numpy()[:, compared_locs],
    onnx_outputs[0][:, compared_locs],
    rtol=1e-2,
).all()

# Only announce success once the check has actually passed; a mismatch now
# stops the notebook instead of printing a misleading message.
assert onnx_concordant, "ONNX outputs differ from the PyTorch outputs"
print("Pytorch and ONNX outputs are concordant")
onnx_concordant

Pytorch and ONNX outputs are concordant


np.True_

## Find Spikes

Plot the traces and find spikes via a threshold


In [None]:
import matplotlib.pyplot as plt

# Draw model_input as a time series.
# NOTE(review): `model_input` is not assigned by any live cell in this notebook
# (only by a commented-out line in the export cell) — define it before running.
fig, ax = plt.subplots(figsize=(12, 6))
waveform_data = model_input.squeeze().numpy()  # drop singleton dimensions
time_points = np.arange(len(waveform_data))
ax.plot(time_points, waveform_data)
ax.set_xlabel("Sample")
ax.set_ylabel("Amplitude")
ax.set_title("Waveform Time Series")
ax.grid(True, alpha=0.3)
plt.show()

In [None]:
# Threshold-based spike search: every sample of model_input below -30 is
# reported as a spike.
# NOTE(review): `model_input` is not assigned by any live cell in this notebook.
waveform_data = model_input.squeeze().numpy()

if waveform_data.ndim == 2:
    # Multi-channel case: first axis is treated as channel, second as sample
    num_channels, sequence_length = waveform_data.shape

    spike_locations = []
    for channel, channel_trace in enumerate(waveform_data):
        # Sample indices where this channel dips below -30
        spike_samples = np.flatnonzero(channel_trace < -30)
        if spike_samples.size == 0:
            continue  # channels without spikes are not reported

        print(
            f"Channel {channel}: {len(spike_samples)} spikes at samples {spike_samples}"
        )
        # Record (channel, sample, value) triples for the summary below
        spike_locations.extend(
            (channel, sample, channel_trace[sample]) for sample in spike_samples
        )

    # Summary of all spikes
    print(f"\nTotal spikes found: {len(spike_locations)}")
    if spike_locations:
        print("(Channel, Sample, Value):")
        for channel, sample, value in spike_locations:
            print(f"  ({channel}, {sample}, {value:.2f})")
else:
    # Anything that is not 2-D is handled as a single channel
    spike_samples = np.where(waveform_data < -30)[0]
    print(f"Single channel: {len(spike_samples)} spikes at samples {spike_samples}")
    for sample in spike_samples:
        print(f"  Sample {sample}: {waveform_data[sample]:.2f}")

In [None]:
# Cut one sample_size-long window out of a single channel of model_input and
# plot it.
# NOTE(review): `model_input` is not assigned by any live cell in this notebook.
channel_num = 281
start_frame = 2725
end_frame = start_frame + pytorch_model.sample_size
test_input = model_input[channel_num : channel_num + 1, :, start_frame:end_frame]

fig, ax = plt.subplots(figsize=(12, 6))
waveform_data = test_input.squeeze().numpy()  # drop singleton dimensions
time_points = np.arange(len(waveform_data))
ax.plot(time_points, waveform_data)
ax.set_xlabel("Sample")
ax.set_ylabel("Amplitude")
ax.set_title("Waveform Time Series")
ax.grid(True, alpha=0.3)
plt.show()

## PyTorch Run


In [None]:
# Run the selected window through the PyTorch model, save the result to disk
# and display it
raw_outputs = pytorch_model.model(test_input)
pytorch_outputs = raw_outputs.cpu()

outputs_array = pytorch_outputs.detach().numpy()
np.save("data/pytorch/model_outputs.npy", outputs_array)
pytorch_outputs

# Run via ONNX Python Runtime


In [None]:
import onnxruntime

# Open the exported model with ONNX Runtime and record the name and shape of
# its first input
onnx_model_path = "models/detect-mea.onnx"
ort_session = onnxruntime.InferenceSession(onnx_model_path)

session_inputs = ort_session.get_inputs()
input_meta = session_inputs[0]
input_name, onnx_input_shape = input_meta.name, input_meta.shape

In [None]:
# Feed the selected window to ONNX Runtime as float32 and display the outputs
onnx_feed = {"input": test_input.to(torch.float).numpy()}
ort_outs = ort_session.run(["output"], onnx_feed)
ort_outs