In [11]:
# from experanto.interpolators import SpikesInterpolator as Interpolator
from experanto.interpolators import Interpolator
from experanto.intervals import TimeInterval

import numpy as np
import yaml
import tempfile
import shutil
from pathlib import Path
from numba import njit, prange

In [12]:
import numpy as np
from numba import njit, prange
from pathlib import Path

# --- THE ENGINE ---
# This function is compiled to machine code the first time it runs.
# 'parallel=True' allows it to use all CPU cores.
@njit(parallel=True, fastmath=True)
def fast_count_spikes(all_spikes, indices, window_starts, window_ends, out_counts):
    """
    all_spikes: 1D array (1.1GB)
    indices: 1D array (38k) - start/end of each neuron in all_spikes
    window_starts: 1D array (BatchSize) - start times for the query
    window_ends: 1D array (BatchSize)
    out_counts: 2D array (BatchSize, N_Neurons) - Result placeholder
    """
    n_batch = len(window_starts)
    n_neurons = len(indices) - 1
    
    # We parallelize the OUTER loop (the batch). 
    # Or we can parallelize the NEURON loop. 
    # Since N_Neurons (38k) > Batch (e.g. 128), parallelizing neurons is better.
    
    for i in prange(n_neurons):
        # 1. Get the slice for this neuron
        # (This is zero-copy in Numba)
        idx_start = indices[i]
        idx_end = indices[i+1]
        neuron_spikes = all_spikes[idx_start:idx_end]
        
        # 2. Check all time windows for this neuron
        # Since spikes are sorted, we use binary search
        for b in range(n_batch):
            t0 = window_starts[b]
            t1 = window_ends[b]
            
            # Binary Search
            # np.searchsorted is supported natively in Numba
            # It finds where t0 and t1 would fit in the sorted array
            c_start = np.searchsorted(neuron_spikes, t0)
            c_end = np.searchsorted(neuron_spikes, t1)
            
            out_counts[b, i] = c_end - c_start

# --- THE CLASS ---
class SpikesInterpolator(Interpolator):
    def __init__(
            self, 
            root_folder: str,
            cache_data: bool = False,
            interpolation_window: float = 0.3,
            interpolation_align: str = "center",
            load_to_ram: bool = False,
            ):
        super().__init__(root_folder)

        meta = self.load_meta()

        self.start_time = meta.get("start_time", 0)
        self.end_time = meta.get("end_time", np.inf)
        self.valid_interval = TimeInterval(self.start_time, self.end_time)

        self.cache_trials = cache_data
        self.interpolation_window = interpolation_window
        self.interpolation_align = interpolation_align

        # Use self.root_folder, defined in the base class
        self.dat_path = self.root_folder / "spikes.npy"
        
        # Ensure indices are typed correctly for Numba
        self.indices = np.array(meta["spike_indices"]).astype(np.int64)
        self.n_signals = len(self.indices) - 1

        if load_to_ram:
            print("Loading spikes to RAM...")
            self.spikes = np.fromfile(self.dat_path, dtype='float64')
        else:
            self.spikes = np.memmap(self.dat_path, dtype='float64', mode='r')

    def interpolate(self, times: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        # 1. Filter for valid times
        valid = self.valid_times(times)
        valid_times = times[valid]
        
        # Handle edge case where no times are valid
        if len(valid_times) == 0:
            return np.empty((0, self.n_signals)), valid

        valid_times += 1e-4

        # 2. Prepare boundaries
        if self.interpolation_align == "center":
            starts = valid_times - self.interpolation_window / 2
            ends   = valid_times + self.interpolation_window / 2
        elif self.interpolation_align == "left":
            starts = valid_times
            ends   = valid_times + self.interpolation_window
        elif self.interpolation_align == "right":
            starts = valid_times - self.interpolation_window
            ends   = valid_times
        else:
            raise ValueError(f"Unknown alignment mode: {self.interpolation_align}")

        # 3. Prepare Output
        # SIZE FIX: Only allocate for the VALID batch size
        batch_size = len(valid_times)
        counts = np.zeros((batch_size, self.n_signals), dtype=np.float64)
        
        # 4. Call Numba Engine
        fast_count_spikes(self.spikes, self.indices, starts, ends, counts)
        
        # SIGNATURE FIX: Return both data and the mask
        return counts, valid

In [13]:
print("Loaded SpikesInterpolator")

Loaded SpikesInterpolator


In [21]:
# ==========================================
# 2. DATA GENERATION & TESTING
# ==========================================

def create_dummy_data(folder_path):
    """Creates synthetic spike data and metadata."""
    folder = Path(folder_path)
    folder.mkdir(parents=True, exist_ok=True)
    
    n_neurons = 500
    duration = 1000.0
    avg_rate = 200 # Hz
    
    all_spikes = []
    indices = [0]
    
    # Generate sorted random spikes for each neuron
    for _ in range(n_neurons):
        n_spikes = int(duration * avg_rate)
        spikes = np.sort(np.random.uniform(0, duration, n_spikes))
        all_spikes.append(spikes)
        indices.append(indices[-1] + len(spikes))
        
    flat_spikes = np.concatenate(all_spikes)
    
    # Save spikes.npy
    flat_spikes.tofile(folder / "spikes.npy")
    
    # Save meta.yml
    meta = {
        "modality": "spikes",
        "n_signals": n_neurons,
        "spike_indices": indices,
        "start_time": 0.0,
        "end_time": duration,
        "sampling_rate": 1000.0 # dummy value
    }
    
    with open(folder / "meta.yml", "w") as f:
        yaml.dump(meta, f)
        
    return all_spikes, flat_spikes, indices

In [23]:
# Setup temporary directory
temp_dir = tempfile.mkdtemp()

try:
    print(f"Test Environment created at: {temp_dir}")
    
    # 1. Create Data
    gt_spikes_list, gt_flat, gt_indices = create_dummy_data(temp_dir)
    
    # 2. Instantiate Interpolator
    # We use a window of 1.0s to make manual verification easy
    interpolator = SpikesInterpolator(temp_dir, interpolation_window=1.0, interpolation_align="left")
    
    # 3. Define Query Times
    # Time 5.0 -> Window [4.5, 5.5]
    query_times = np.array([1.0, 5.0, 9.0]) 
    
    # 4. Run Interpolation
    print("Running interpolation...")
    counts, valid = interpolator.interpolate(query_times)
    
    print(f"Output Shape: {counts.shape}")
    print(f"Output Counts:\n{counts}")

    # 5. Verify Correctness (Slow Check)
    print("\nVerifying accuracy against ground truth...")
    window = 1.0
    errors = 0
    
    for t_idx, t in enumerate(query_times):
        t_start = t - window/2
        t_end = t + window/2
        
        for n_idx in range(len(gt_spikes_list)):
            # Manual count using the original list
            neuron_spikes = gt_spikes_list[n_idx]
            manual_count = np.sum((neuron_spikes >= t_start) & (neuron_spikes < t_end))
            
            numba_count = counts[t_idx, n_idx]
            
            if manual_count != numba_count:
                print(f"Mismatch at time {t}, neuron {n_idx}: Expected {manual_count}, got {numba_count}")
                errors += 1
                
    if errors == 0:
        print("SUCCESS: All counts match exactly.")
    else:
        print(f"FAILED: {errors} mismatches found.")

finally:
    # Cleanup
    shutil.rmtree(temp_dir)

Test Environment created at: /mnt/lustre-grete/tmp/u18196/tmpuhc58r7j
Running interpolation...
Output Shape: (3, 500)
Output Counts:
[[200. 230. 202. ... 194. 187. 207.]
 [181. 204. 228. ... 210. 196. 186.]
 [215. 190. 179. ... 205. 200. 188.]]

Verifying accuracy against ground truth...
Mismatch at time 1.0, neuron 0: Expected 210, got 200.0
Mismatch at time 1.0, neuron 1: Expected 227, got 230.0
Mismatch at time 1.0, neuron 2: Expected 195, got 202.0
Mismatch at time 1.0, neuron 3: Expected 215, got 219.0
Mismatch at time 1.0, neuron 4: Expected 196, got 191.0
Mismatch at time 1.0, neuron 5: Expected 228, got 235.0
Mismatch at time 1.0, neuron 6: Expected 185, got 199.0
Mismatch at time 1.0, neuron 7: Expected 211, got 214.0
Mismatch at time 1.0, neuron 8: Expected 225, got 212.0
Mismatch at time 1.0, neuron 9: Expected 175, got 181.0
Mismatch at time 1.0, neuron 10: Expected 209, got 187.0
Mismatch at time 1.0, neuron 11: Expected 221, got 204.0
Mismatch at time 1.0, neuron 12: Expe