# Explore using `numba` to make the MUA calculation faster

In [1]:
from numba import njit
from real_spike.utils import get_meta, get_sample_data, butter_filter
import numpy as np
import time

In [2]:
# Absolute paths to the recording's metadata (.meta) and binary (.bin) files.
# NOTE(review): both are hard-coded to one user's local checkout; adjust them
# before running this notebook on another machine.
meta_path = "/home/clewis/repos/realSpike/data/120s_test/rb50_20250126_g0_t0.imec0.ap.meta"
bin_path = "/home/clewis/repos/realSpike/data/120s_test/rb50_20250126_g0_t0.imec0.ap.bin"

In [3]:
# Load the metadata first: it is passed to the sample loader and is later
# indexed by string key ("imAiRangeMax", "imMaxInt", "imroTbl").
meta_data = get_meta(meta_path)
# Load the samples from the binary file. The result exposes `.shape` and is
# sliced with 2-D indexing below (presumably channels x samples — TODO confirm
# against real_spike.utils.get_sample_data).
data = get_sample_data(bin_path, meta_data)

In [4]:
# Display the dimensions of the loaded array (recorded output: (385, 3600001)).
data.shape

(385, 3600001)

In [5]:
# Number of columns used for the baseline window below (60_000);
# presumably 2 s at a 30 kHz sampling rate — TODO confirm.
2 * 30 * 1_000

60000

In [6]:
# Scaling constants pulled from the metadata; every value is stored as text,
# so each one is converted to float.

# "imAiRangeMax" field
vmax = float(meta_data["imAiRangeMax"])

# "imMaxInt" field
imax = float(meta_data["imMaxInt"])

# Gain: take the second ")"-delimited chunk of "imroTbl" and read its fourth
# space-separated token.
# NOTE(review): this reads the gain of a single table entry and applies it to
# every channel — confirm that all channels share the same gain.
_imro_entry = meta_data["imroTbl"].split(")")[1]
gain = float(_imro_entry.split(" ")[3])

In [7]:
med_data = data[:150, :60_000]

In [8]:
conv_data = 1e6 * med_data / vmax / imax / gain
    
filt_data = butter_filter(conv_data, 1_000, 30_000)

In [9]:
median = np.median(filt_data, axis=1)
median.shape

(150,)

In [10]:
5 * 30 * 1_000

150000

In [11]:
t_data = data[:150, 60_001:(60_001 + 150_000)]
t_data.shape

(150, 150000)

In [12]:
conv_data = 1e6 * t_data / vmax / imax / gain
    
filt_data = butter_filter(conv_data, 1_000, 30_000)

# Without `numba`

In [13]:
import scipy

In [14]:
def get_spike_events(data: np.ndarray, median, num_dev=4):
    """Find threshold crossings on every row of a 2-D array.

    For each row, the threshold is ``num_dev * MAD(row) + median`` and a
    sample counts as a spike event when its absolute value exceeds it.

    Parameters
    ----------
    data : np.ndarray
        2-D array; the MAD is taken along axis 1 and rows are scanned
        independently.
    median
        Offset added to the scaled MAD; a scalar or one value per row
        (it is broadcast against the per-row MAD).
    num_dev : int or float, optional
        Multiplier applied to the MAD (default 4).

    Returns
    -------
    spike_indices : list of np.ndarray
        For each row, the column indices where ``abs(data)`` exceeds the
        row's threshold.
    spike_counts : list of int
        For each row, the number of such indices.
    """
    # Per-row median absolute deviation.
    mad = scipy.stats.median_abs_deviation(data, axis=1)

    # Per-row detection threshold.
    thresh = (num_dev * mad) + median

    # Crossings are detected on the magnitude, so both polarities count.
    abs_data = np.abs(data)

    spike_indices = [
        np.nonzero(row > thresh[i])[0] for i, row in enumerate(abs_data)
    ]

    # Count events by the NUMBER of indices. Counting the non-zero entries of
    # the index array would silently drop a crossing located at column 0.
    spike_counts = [int(idx.size) for idx in spike_indices]

    return spike_indices, spike_counts

In [15]:
# Wall-clock timing of one call to the NumPy/SciPy implementation,
# printed in milliseconds (nanoseconds / 1e6).
t = time.perf_counter_ns()
get_spike_events(filt_data, median)
elapsed_ms = (time.perf_counter_ns() - t) / 1e6
print(elapsed_ms)

423.446167


# With `numba`

In [16]:
@njit
def get_spike_events2(mad, median, num_dev=4):
    # calculate mad
    

    # Calculate threshold
    thresh = (num_dev * mad) + median

    # Vectorized computation of absolute data
    abs_data = np.abs(data)

    # Find indices where threshold is crossed for each channel
    spike_indices = [np.where(abs_data[i] > thresh[i])[0] for i in range(data.shape[0])]

    spike_counts = [np.count_nonzero(arr) for arr in spike_indices]

    return spike_indices, spike_counts

In [None]:
t = time.perf_counter_ns()
mad = scipy.stats.median_abs_deviation(data, axis=1)
get_spike_events2(mad, median)
print((time.perf_counter_ns() - t) / 1e6)