# Torch vs. ONNX

Verify that the spike-detection model produces concordant outputs when run through Braindance, plain PyTorch, and ONNX Runtime in Python.


In [4]:
import json
import numpy as np
import torch
import matplotlib.pyplot as plt

## Load Trained Model


In [5]:
# Instantiate a trained model
from braindance.core.spikedetector.model import ModelSpikeSorter

# detection_model = ModelSpikeSorter.load("checkpoints/spikedetector/mea")

# Rebuild the detector from its saved constructor arguments ...
with open("checkpoints/spikedetector/mea/init_dict.json", "r") as f:
    init_dict = json.load(f)
pytorch_model = ModelSpikeSorter(**init_dict)

# ... then restore the trained weights on the CPU.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded.
state_dict = torch.load(
    "checkpoints/spikedetector/mea/state_dict.pt",
    map_location="cpu",
    weights_only=True,
)
# Left as the last expression so the key-matching report is shown as cell output
pytorch_model.load_state_dict(state_dict)

<All keys matched successfully>

## Run via Braindance
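Populates `data/inter` with intermediate files; the call below passes `debug=True`, and later cells load the saved `.npy` arrays from that directory.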

```
Saving traces:
100%|██████████| 1/1 [00:04<00:00,  4.46s/it]
Running detection model:
Compiling detection model for 942 elecs ...
Cannot compile detection model with torch_tensorrt because cannot load torch_tensorrt. Skipping NVIDIA compilation
Allocating disk space to save model traces and outputs ...
Inference scaling: 0.3761194029850746
Running model ...
100%|██████████| 832/832 [09:36<00:00,  1.44it/s]
Detecting sequences
100%|██████████| 942/942 [00:04<00:00, 211.68it/s]
Detected 10 preliminary propagation sequences
Extracting sequences' detections, intervals, and amplitudes

100%|██████████| 10/10 [00:02<00:00,  4.03it/s]
8 clusters remain after filtering
Reassigning spikes to preliminary propagation sequences
Initializing ...
Sorting recording
100%|██████████| 1000/1000 [00:00<00:00, 3377.24it/s]
Extracting sequences' detections, intervals, and amplitudes

100%|██████████| 7/7 [00:02<00:00,  2.80it/s]
7 clusters remain after filtering
Merging preliminary propagation sequences - first round

100%|██████████| 7/7 [00:02<00:00,  3.14it/s]
7 sequences after first merging
Merging preliminary propagation sequences - second round ...

RT-Sort detected 7 sequences
```

In [6]:
from braindance.core.spikesorter.rt_sort import detect_sequences

# Detect sequences in the first 5 seconds of a recording
# (recording_window_ms is given in milliseconds: 5 * 1000 ms = 5 s)
rt_sort = detect_sequences(
    "data/MEA_rec_patch_ground_truth_cell7.raw.h5",
    "data/inter",  # later cells load scaled_traces.npy / model_outputs.npy from here
    pytorch_model,
    recording_window_ms=(0, 5 * 1000),
    device="cpu",
    verbose=True,
    debug=True,  # NOTE(review): presumably what keeps the intermediate files on disk - confirm
    # num_processes=1,  # Uncomment for debugging
)

Saving traces:


100%|██████████| 1/1 [00:04<00:00,  4.46s/it]


Running detection model:
Compiling detection model for 942 elecs ...
Cannot compile detection model with torch_tensorrt because cannot load torch_tensorrt. Skipping NVIDIA compilation
Allocating disk space to save model traces and outputs ...
Inference scaling: 0.3761194029850746
Running model ...


100%|██████████| 832/832 [09:57<00:00,  1.39it/s]


Detecting sequences


100%|██████████| 942/942 [00:04<00:00, 213.90it/s]

Detected 10 preliminary propagation sequences
Extracting sequences' detections, intervals, and amplitudes



100%|██████████| 10/10 [00:02<00:00,  3.99it/s]


8 clusters remain after filtering
Reassigning spikes to preliminary propagation sequences
Initializing ...
Sorting recording


100%|██████████| 1000/1000 [00:00<00:00, 3306.13it/s]

Extracting sequences' detections, intervals, and amplitudes



100%|██████████| 7/7 [00:02<00:00,  2.80it/s]

7 clusters remain after filtering
Merging preliminary propagation sequences - first round



100%|██████████| 7/7 [00:02<00:00,  2.94it/s]

7 sequences after first merging
Merging preliminary propagation sequences - second round ...

RT-Sort detected 7 sequences





In [32]:
# Load the arrays written to data/inter by the Braindance run above; they are
# the reference for the comparisons in the following sections.
scaled_traces = np.load("data/inter/scaled_traces.npy")
# (To iterate faster, slice scaled_traces down to a few channels and a few
# sample_size-long windows right after loading.)
braindance_model_outputs = np.load("data/inter/model_outputs.npy")

## Run via PyTorch
Run the model directly in PyTorch, using a simplified version of `run_detection_model` from Braindance's `rt_sort`.

In [42]:
import torch
import numpy as np
from tqdm import tqdm


def run_detection_model(
    scaled_traces,
    model,
    inference_scaling_numerator=12.6,
    pre_median_frames=1000,
    device="cpu",
):
    """
    Simplified function to run a PyTorch detection model on scaled traces using windowed computation.

    Parameters:
        scaled_traces (np.ndarray): Input data array of shape (num_channels, recording_duration).
        model: Pre-instantiated PyTorch model with attributes sample_size, num_output_locs, input_scale.
        inference_scaling_numerator (float): Numerator for scaling factor calculation.
        pre_median_frames (int): Number of frames for initial median calculation.
        device (str): Device to run the model on ("cuda" or "cpu").

    Returns:
        torch.Tensor: Model outputs of shape (num_channels, processed_duration).
    """
    # Convert input to torch tensor
    scaled_traces = torch.tensor(scaled_traces, device=device, dtype=torch.float16)

    # Get model parameters
    sample_size = model.sample_size
    num_output_locs = model.num_output_locs
    input_scale = model.input_scale
    num_chans, rec_duration = scaled_traces.shape

    # Calculate inference scaling based on initial window
    window = (
        scaled_traces[:, :pre_median_frames].to(torch.float32).cpu()
    )  # Cast to float32 and move to CPU
    if window.dtype != torch.float32:
        raise ValueError(
            f"Window tensor dtype is {window.dtype}, expected torch.float32"
        )
    iqrs = torch.quantile(window, 0.75, dim=1) - torch.quantile(window, 0.25, dim=1)
    median_iqr = torch.median(iqrs)
    inference_scaling = (
        inference_scaling_numerator / median_iqr if median_iqr != 0 else 1
    )

    # Define windows for processing
    all_start_frames = list(range(0, rec_duration - sample_size + 1, num_output_locs))[
        0:2
    ]
    output_duration = rec_duration - sample_size + 1
    outputs_all = torch.zeros(
        (num_chans, output_duration), device=device, dtype=torch.float16
    )

    # Process each window
    with torch.no_grad():
        for start_frame in tqdm(all_start_frames):
            # Extract window
            traces_torch = scaled_traces[:, start_frame : start_frame + sample_size]

            # Subtract median for baseline correction
            traces_torch = (
                traces_torch - torch.median(traces_torch, dim=1, keepdim=True).values
            )

            # Run model on window and store output
            outputs = model.model(
                traces_torch[:, None, :] * input_scale * inference_scaling
            )
            outputs_all[:, start_frame : start_frame + num_output_locs] = outputs

    return outputs_all.cpu()


# Run the simplified PyTorch path on the traces Braindance saved. Only the first
# two windows are computed (see the 2/2 progress bar below), which is all the
# comparison in the next section looks at.
pytorch_model_outputs = run_detection_model(
    scaled_traces=scaled_traces, model=pytorch_model, device="cpu"
)

100%|██████████| 2/2 [00:01<00:00,  1.39it/s]


## Compare Braindance to PyTorch Model Outputs

In [51]:
# Compare the first two channels over the two windows the PyTorch run produced.
num_compared = 2 * pytorch_model.num_output_locs
braindance_slice = braindance_model_outputs[0:2, 0:num_compared]
pytorch_slice = pytorch_model_outputs.detach().numpy()[0:2, 0:num_compared]

# Element-wise closeness, reduced to a single verdict shown as the cell output
np.isclose(braindance_slice, pytorch_slice, rtol=1e-4).all()

np.True_

## Run via PyTorch (single window, step by step)

In [None]:
# Calculate inference scaling factor based on the median IQR of the first 1000 frames
import scipy.stats

# Reload the traces Braindance saved, so this section does not depend on the
# earlier loading cell.
scaled_traces = np.load("data/inter/scaled_traces.npy")

# Defaults from rt_sort.py
inference_scaling_numerator, pre_median_frames = 12.6, 1000

# Per-channel interquartile range over the leading frames, then the median of
# those IQRs across channels.
window = scaled_traces[:, :pre_median_frames]
iqrs = scipy.stats.iqr(window, axis=1)
median = np.median(iqrs)

inference_scaling = inference_scaling_numerator / median
print("inference_scaling:", inference_scaling)

# Must reproduce the "Inference scaling" value logged by detect_sequences above
expected_inference_scaling = 0.3761194029850746
assert np.isclose(inference_scaling, expected_inference_scaling, rtol=1e-5)

In [None]:
# Build one model input window: the first sample_size frames of every channel,
# cast to float16 and centred on each channel's median.
first_window = scaled_traces[:, : pytorch_model.sample_size]
traces_torch = torch.tensor(first_window, dtype=torch.float16)
traces_torch = traces_torch - traces_torch.median(dim=1, keepdim=True).values

## Compare PyTorch to Braindance

In [None]:
# Limit comparison to the first num_output_locs locations as there is some overlap
model_traces = np.load("data/inter/model_traces.npy")
num_compared = pytorch_model.num_output_locs

inputs_concordant = bool(
    np.isclose(
        traces_torch.numpy()[:, :num_compared],
        model_traces[:, :num_compared],
        rtol=1e-5,
    ).all()
)

# Report the actual verdict: the message used to be printed before (and
# regardless of) the check, so it claimed concordance even on a mismatch.
print(
    "Model inputs are concordant"
    if inputs_concordant
    else "Model inputs are NOT concordant"
)
inputs_concordant

In [None]:
# Run the first window through the raw PyTorch network. no_grad: this is
# inference only, so there is no need to record an autograd graph.
with torch.no_grad():
    pytorch_outputs = pytorch_model.model(
        traces_torch[:, None, :] * pytorch_model.input_scale * inference_scaling
    )

# Reference outputs written by the Braindance run
model_outputs = np.load("data/inter/model_outputs.npy")
num_compared = pytorch_model.num_output_locs

# NOTE(review): the slices below index axis 1 of both arrays; if the network
# returns (channels, 1, locs) rather than (channels, locs), confirm that the
# comparison lines up as intended.
outputs_concordant = bool(
    np.isclose(
        pytorch_outputs.detach().numpy()[:, :num_compared],
        model_outputs[:, :num_compared],
        rtol=1e-5,
    ).all()
)

# Print the real verdict instead of an unconditional success message
print(
    "Model outputs are concordant"
    if outputs_concordant
    else "Model outputs are NOT concordant"
)
outputs_concordant

# Export to .onnx


In [None]:
import onnx

import copy
import os

pytorch_model.model.eval()  # Set model to evaluation mode

# Export a float32 *copy* of the network. nn.Module.float() converts a module in
# place, so calling it on pytorch_model.model itself would leave the shared
# model in float32 and make the float16 cells above fail when re-run.
model = copy.deepcopy(pytorch_model.model).float()

# torch.onnx.export is given a path under models/; make sure the folder exists
os.makedirs("models", exist_ok=True)

torch.onnx.export(
    model,
    # Dummy input used for tracing: one channel, one sample_size-long window
    torch.zeros(1, 1, pytorch_model.sample_size, dtype=torch.float32),
    "models/detect-mea.onnx",
    input_names=["input"],
    output_names=["output"],
    # Batch size and window length are declared dynamic on input and output
    dynamic_axes={
        "input": {0: "batch_size", 2: "sequence_length"},
        "output": {0: "batch_size", 2: "sequence_length"},
    },
    opset_version=12,
    verbose=False,
)

In [None]:
import onnxruntime

# Load the exported graph and look up the name and shape of its first input
ort_session = onnxruntime.InferenceSession("models/detect-mea.onnx")

input_meta = ort_session.get_inputs()[0]
input_name = input_meta.name
onnx_input_shape = input_meta.shape

# Same scaled first window that was fed to the PyTorch network
test_input = traces_torch[:, None, :] * pytorch_model.input_scale * inference_scaling

# The graph was exported with a float32 dummy input, so upcast before feeding it.
# The input is addressed by the name read from the session, not a hard-coded key.
onnx_outputs = ort_session.run(
    ["output"],
    {input_name: test_input.to(torch.float).numpy()},
)

num_compared = pytorch_model.num_output_locs
onnx_concordant = bool(
    np.isclose(
        pytorch_outputs.detach().numpy()[:, :num_compared],
        onnx_outputs[0][:, :num_compared],
        rtol=1e-2,
    ).all()
)

# Print the real verdict: the message used to be emitted before the check ran
print(
    "Pytorch and ONNX outputs are concordant"
    if onnx_concordant
    else "Pytorch and ONNX outputs are NOT concordant"
)
onnx_concordant

## Find Spikes

Plot the traces and find spikes via a threshold


In [None]:
import matplotlib.pyplot as plt

# Plot the waveform as a time series
# NOTE(review): `model_input` is not assigned by any live cell of this notebook
# (its only assignment is inside a commented-out line of the export cell), so
# this cell raises NameError on a fresh kernel - TODO restore its definition.
# NOTE(review): the x axis is built from len(waveform_data), i.e. the first axis
# after squeeze(); that is the sample axis only for a single-channel input.
plt.figure(figsize=(12, 6))
waveform_data = model_input.squeeze().numpy()  # Remove batch and channel dimensions
time_points = np.arange(len(waveform_data))
plt.plot(time_points, waveform_data)
plt.xlabel("Sample")
plt.ylabel("Amplitude")
plt.title("Waveform Time Series")
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
# Threshold-based spike search: report every sample of model_input below -30.
# NOTE(review): `model_input` is not assigned by any live cell of this notebook
# (only in a commented-out line of the export cell) - TODO restore it.
waveform_data = model_input.squeeze().numpy()

if waveform_data.ndim != 2:
    # Single channel case
    spike_samples = np.where(waveform_data < -30)[0]
    print(f"Single channel: {len(spike_samples)} spikes at samples {spike_samples}")
    for sample in spike_samples:
        print(f"  Sample {sample}: {waveform_data[sample]:.2f}")
else:
    # Multi-channel case: scan one channel (row) at a time
    num_channels, sequence_length = waveform_data.shape

    spike_locations = []
    for channel, channel_trace in enumerate(waveform_data):
        # Sample indices whose value drops below the threshold
        spike_samples = np.flatnonzero(channel_trace < -30)
        if spike_samples.size == 0:
            continue

        print(
            f"Channel {channel}: {len(spike_samples)} spikes at samples {spike_samples}"
        )
        # Remember (channel, sample, value) for the summary below
        spike_locations.extend(
            (channel, sample, channel_trace[sample]) for sample in spike_samples
        )

    # Summary of all spikes
    print(f"\nTotal spikes found: {len(spike_locations)}")
    if spike_locations:
        print("(Channel, Sample, Value):")
        for channel, sample, value in spike_locations:
            print(f"  ({channel}, {sample}, {value:.2f})")

In [None]:
# Cut one channel and one sample_size-long window out of model_input and plot it.
# NOTE(review): channel 281 / frame 2725 are presumably taken from the spike
# listing printed by the previous cell - confirm.
# NOTE(review): `model_input` is not assigned by any live cell of this notebook
# (only in a commented-out line of the export cell) - TODO restore it.
channel_num = 281
start_frame = 2725
test_input = model_input[
    channel_num : channel_num + 1,
    :,
    start_frame : start_frame + pytorch_model.sample_size,
]

# Plot the waveform as a time series
plt.figure(figsize=(12, 6))
waveform_data = test_input.squeeze().numpy()  # Remove batch and channel dimensions
time_points = np.arange(len(waveform_data))
plt.plot(time_points, waveform_data)
plt.xlabel("Sample")
plt.ylabel("Amplitude")
plt.title("Waveform Time Series")
plt.grid(True, alpha=0.3)
plt.show()

## PyTorch Run


In [None]:
import os

# Inference only: no_grad skips recording an autograd graph for this forward pass
with torch.no_grad():
    pytorch_outputs = pytorch_model.model(test_input).cpu()

# np.save does not create missing folders, so make sure the target exists
os.makedirs("data/pytorch", exist_ok=True)
np.save("data/pytorch/model_outputs.npy", pytorch_outputs.detach().numpy())

# Shown as the cell output
pytorch_outputs

# Run via ONNX Python Runtime


In [None]:
import onnxruntime

# Open the exported ONNX graph and read the metadata of its first input.
# (Repeats the session setup of the earlier ONNX comparison cell, so this
# section can be run without it.)
ort_session = onnxruntime.InferenceSession("models/detect-mea.onnx")

input_meta = ort_session.get_inputs()[0]
input_name = input_meta.name
onnx_input_shape = input_meta.shape

In [None]:
# Feed the same single-channel window to ONNX Runtime (cast to float32 first)
# and display the raw outputs for comparison with the PyTorch run above.
ort_outs = ort_session.run(["output"], {"input": test_input.to(torch.float).numpy()})
ort_outs