In [12]:
import jax.numpy as jnp
import jax
from tqdm import tqdm
from functools import partial
import h5py
import numpy as np


@partial(jax.jit, static_argnames=['dtype'])
def jax_count(spike_times, bin_edges, dtype=jnp.uint32):
    num_bins = bin_edges.size - 1
    idx = jnp.searchsorted(bin_edges, spike_times, side="right") - 1
    counts = jnp.bincount(idx, length=num_bins)

    return jax.lax.convert_element_type(counts, dtype)

def nwb_spike_count(file, interval, window, batchsize=None, dtype=jnp.uint32, save=False):
    T = interval[1] - interval[0]
    N = int(T // window)
    if save:
        readmode = "a"
    else:
        readmode = "r"
    with h5py.File(file, readmode) as data:
        spike_times = data['units']['spike_times'][:] # * [:] Loads the entire file at once
        units = len(data['units']['id'])
        indices = data['units']['spike_times_index']

        bin_edges = jnp.linspace(interval[0]-window, interval[1], N + 2) # Add an extra bin to count less-than spikes

        bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

        def indices2idxs(i):
            minidx = indices[i]
            maxidx = indices[i+1] if i < units - 1 else len(spike_times)
            return range(minidx, maxidx)

        if batchsize is not None:
            nbatches = np.ceil(units / batchsize)
        else:
            nbatches = 1

        batchis = np.arange(units)
        # numspikes = [indices[i+1] - indices[i] for i in range(units)-1]
        # numspikes.append(len(spike_times) - indices[-1])
        # batchis = np.argsort(numspikes)
        # batchis = np.argsort([len(spike_times[indices2idxs(i)]) for i in batchis])
        batchis = np.array_split(batchis, nbatches)

        @jax.jit
        def _count(spike_times):
            return jax_count(spike_times, bin_edges, dtype)

        def count(batchi):
            # * Get spike times for a batch
            batch = [spike_times[indices2idxs(i)] for i in batchi]

            # * Pad the batch with NaN's for JaX
            maxlen = max([len(b) for b in batch])
            batch = [jnp.pad(b, (0, maxlen - len(b)), constant_values=np.nan) for b in batch]
            batch = jnp.stack(batch)

            counts = jax.vmap(_count)(batch)
            return counts

        if save:
            if 'counts' in data:
                del data['counts']
            if 'bin_centers' in data:
                del data['bin_centers']
            data.create_dataset("counts", data=np.zeros((units, N), dtype=dtype))
            for batchi in tqdm(batchis):
                counts_batch = count(batchi)
                data["counts"][batchi, :] = counts_batch[:, 1:]
            data.create_dataset("bin_centers", data=np.array(bin_centers[1:]))
            return file
        else:
            counts = [count(batchi) for batchi in tqdm(batchis)]
            X = jnp.concatenate(counts, axis=0)
            return X[:, 1:], bin_centers[1:]


In [None]:
spike_counts, ts = nwb_spike_count("test.nwb", (0, 9600), 0.005, batchsize=500, dtype=jnp.uint8, save=False)

In [None]:
%load_ext line_profiler
%lprun -f nwb_spike_count nwb_spike_count("test.nwb", (0, 9600), 0.005, batchsize=100, dtype=jnp.uint8, save=False)

The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


100%|██████████| 1/1 [00:05<00:00,  5.51s/it]


Timer unit: 1e-09 s

Total time: 13.4483 s
File: /var/tmp/pbs.20495.headnode/ipykernel_2792436/2454981060.py
Function: nwb_spike_count at line 17

Line #      Hits         Time  Per Hit   % Time  Line Contents
    17                                           def nwb_spike_count(file, interval, window, batchsize=None, dtype=jnp.uint32, save=False):
    18         1        880.0    880.0      0.0      T = interval[1] - interval[0]
    19         1       2248.0   2248.0      0.0      N = int(T // window)
    20         1        370.0    370.0      0.0      if save:
    21                                                   readmode = "a"
    22                                               else:
    23         1        163.0    163.0      0.0          readmode = "r"
    24         2  115819365.0    6e+07      0.9      with h5py.File(file, readmode) as data:
    25         1 7812308779.0    8e+09     58.1          spike_times = data['units']['spike_times'][:] # * [:] Loads the entire file at