# Explore using `numba` to make the MUA calculation faster

In [1]:
import numba
from real_spike.utils import get_meta, get_sample_data, butter_filter
import numpy as np
import time

In [2]:
meta_path = "/home/clewis/repos/realSpike/data/120s_test/rb50_20250126_g0_t0.imec0.ap.meta"
bin_path = "/home/clewis/repos/realSpike/data/120s_test/rb50_20250126_g0_t0.imec0.ap.bin"

In [3]:
meta_data = get_meta(meta_path)
data = get_sample_data(bin_path, meta_data)

In [4]:
data.shape

(385, 3600001)

In [5]:
2 * 30 * 1_000

60000

In [6]:
vmax = float(meta_data["imAiRangeMax"])
# get Imax
imax = float(meta_data["imMaxInt"])
# get gain
gain = float(meta_data['imroTbl'].split(sep=')')[1].split(sep=' ')[3])

In [7]:
med_data = data[:150, :60_000]

In [8]:
conv_data = 1e6 * med_data / vmax / imax / gain
    
filt_data = butter_filter(conv_data, 1_000, 30_000)

In [9]:
median = np.median(filt_data, axis=1)
median.shape

(150,)

In [10]:
5 * 30 * 1_000

150000

In [11]:
t_data = data[:150, 60_001:(60_001 + 150)]
t_data.shape

(150, 150)

In [12]:
conv_data = 1e6 * t_data / vmax / imax / gain
    
filt_data = butter_filter(conv_data, 1_000, 30_000)

# Without `numba`

In [13]:
import scipy

In [32]:
def get_spike_events(data: np.ndarray, median, num_dev=4):
    # calculate mad
    mad = scipy.stats.median_abs_deviation(data, axis=1)

    # Calculate threshold
    thresh = (num_dev * mad) + median

    # Vectorized computation of absolute data
    abs_data = np.abs(data)

    # Find indices where threshold is crossed for each channel
    spike_indices = [np.where(abs_data[i] > thresh[i])[0] for i in range(data.shape[0])]

    spike_counts = [np.count_nonzero(arr) for arr in spike_indices]

    return spike_indices, spike_counts

In [33]:
times = list()
for i in range(200):
    t = time.perf_counter_ns()
    get_spike_events(filt_data, median)
    times.append((time.perf_counter_ns() - t) / 1e6)

In [34]:
sum(times) / len(times)

0.5368111600000001

# With `numba`

In [17]:
@numba.jit()
def row_median(arr):
    n, m = arr.shape
    out = np.empty(n)
    for i in range(n):
        row = np.sort(arr[i])   # Numba supports np.sort
        mid = m // 2
        if m % 2 == 0:
            out[i] = 0.5 * (row[mid - 1] + row[mid])
        else:
            out[i] = row[mid]
    return out

In [18]:
@numba.jit()
def get_spike_events2(data, median, num_dev=4):
    # calculate mad
    absolute_deviations = np.abs(data - median)
    mad = row_median(absolute_deviations)
    print(mad)

    # Calculate threshold
    thresh = (num_dev * mad) + median

    # Vectorized computation of absolute data
    abs_data = np.abs(data)

    # Find indices where threshold is crossed for each channel
    spike_indices = [np.where(abs_data[i] > thresh[i])[0] for i in range(data.shape[0])]

    spike_counts = [np.count_nonzero(arr) for arr in spike_indices]

    return spike_indices, spike_counts

In [27]:
times = list()
for i in range(1):
    t = time.perf_counter_ns()
    get_spike_events2(filt_data, median)
    times.append((time.perf_counter_ns() - t) / 1e6)

[12.07348797 13.37158355 17.35029044 14.26853317 15.7864196  17.45572941
 15.55268171 13.27807263 15.66793779 15.66593277 20.98800553 16.27118885
 16.95902169 16.84292603 19.04865573 15.62230482 16.0485322  16.18698044
 18.57357412 15.84873675 22.78988386 16.1502103  21.79177677 20.64439958
 21.5297344  20.2975379  24.73272513 26.5347807  22.0361881  19.13391767
 21.32572427 23.98180013 22.28049116 21.53624218 26.53002414 22.48515198
 26.27242955 26.24028438 22.03191748 23.99363165 22.28520479 22.64031145
 18.10347501 20.90003791 21.41539471 24.64435929 21.39287083 24.28414127
 22.55637856 32.74149008 24.99200695 28.05451625 29.63305253 30.18124962
 26.70565667 30.45903711 22.70865283 19.53382037 22.26012343 20.99610384
 21.31487351 21.82964739 27.34759756 26.63070825 28.11885595 23.76339832
 24.01870605 25.43222489 20.9907147  25.99043609 26.03847683 20.74892275
 27.99326657 24.25293112 22.58963509 28.67221092 25.8250474  24.25948231
 27.25174261 27.74591173 33.29227147 31.36039008 28

In [28]:
sum(times) / len(times)

2.3067

In [56]:
absolute_deviations = np.abs(filt_data - median)
absolute_deviations.shape
mad = np.median(absolute_deviations, axis=1)

In [57]:
mad.shape

(150,)

# Try using `jax`

In [35]:
import jax.numpy as jnp
import jax

In [36]:
def get_spike_count3(data, median, num_dev=4.0):
    # Median absolute deviation
    mad = jnp.median(jnp.abs(data - median))
    thresh = (num_dev * mad) + median
    
    # Count spikes
  #  spike_indices = [jnp.where(jnp.abs(data)[i] > thresh[i])[0] for i in range(data.shape[0])]

    spike_counts = jnp.sum(jnp.abs(data) > thresh)

    return spike_counts

  #  return spike_indices, spike_counts

# Example: apply row-wise
get_spike_count_jax = jax.vmap(get_spike_count3, in_axes=(0, 0))

In [51]:
times = list()
for i in range(200):
    t = time.perf_counter_ns()
    get_spike_count_jax(filt_data, median)
    times.append((time.perf_counter_ns() - t) / 1e6)

In [52]:
sum(times) / len(times)

1.7577583000000008