In [1]:
print("Working...")

Working...


In [2]:
# from experanto.interpolators import SpikesInterpolator as Interpolator
from experanto.interpolators import Interpolator
from experanto.intervals import TimeInterval

import numpy as np
import yaml
import tempfile
import shutil
from pathlib import Path
from numba import njit, prange
import time

In [None]:
import numpy as np
from numba import njit, prange
from pathlib import Path

# --- THE ENGINE ---
# 'parallel=True' allows it to use all CPU cores.
@njit(parallel=True, fastmath=True)
def fast_count_spikes(all_spikes, indices, window_starts, window_ends, out_counts):
    """
    all_spikes: 1D array (1.1GB)
    indices: 1D array (38k) - start/end of each neuron in all_spikes
    window_starts: 1D array (BatchSize) - start times for the query
    window_ends: 1D array (BatchSize)
    out_counts: 2D array (BatchSize, N_Neurons) - Result placeholder
    """
    n_batch = len(window_starts)
    n_neurons = len(indices) - 1
    
    # We parallelize the OUTER loop (the batch). 
    # Or we can parallelize the NEURON loop. 
    # Since N_Neurons (38k) > Batch (e.g. 128), parallelizing neurons is better.
    
    for i in prange(n_neurons):
        # 1. Get the slice for this neuron
        # (This is zero-copy in Numba)
        idx_start = indices[i]
        idx_end = indices[i+1]
        neuron_spikes = all_spikes[idx_start:idx_end]
        
        # 2. Check all time windows for this neuron
        # Since spikes are sorted, we use binary search
        for b in range(n_batch):
            t0 = window_starts[b]
            t1 = window_ends[b]
            
            # Binary Search
            # np.searchsorted is supported natively in Numba
            # It finds where t0 and t1 would fit in the sorted array
            c_start = np.searchsorted(neuron_spikes, t0)
            c_end = np.searchsorted(neuron_spikes, t1)
            
            out_counts[b, i] = c_end - c_start

# --- THE CLASS ---
class SpikesInterpolator(Interpolator):
    def __init__(
            self, 
            root_folder: str,
            cache_data: bool = False,
            interpolation_window: float = 0.3,
            interpolation_align: str = "center",
            load_to_ram: bool = False,
            ):
        super().__init__(root_folder)

        meta = self.load_meta()

        self.start_time = meta.get("start_time", 0)
        self.end_time = meta.get("end_time", np.inf)
        self.valid_interval = TimeInterval(self.start_time, self.end_time)

        self.cache_trials = cache_data
        self.interpolation_window = interpolation_window
        self.interpolation_align = interpolation_align

        # Use self.root_folder, defined in the base class
        self.dat_path = self.root_folder / "spikes.npy"
        
        # Ensure indices are typed correctly for Numba
        self.indices = np.array(meta["spike_indices"]).astype(np.int64)
        self.n_signals = len(self.indices) - 1

        if load_to_ram:
            print("Loading spikes to RAM...")
            self.spikes = np.fromfile(self.dat_path, dtype='float64')
        else:
            self.spikes = np.memmap(self.dat_path, dtype='float64', mode='r')

    def interpolate(self, times: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        # 1. Filter for valid times
        valid = self.valid_times(times)
        valid_times = times[valid]
        
        # Handle edge case where no times are valid
        if len(valid_times) == 0:
            return np.empty((0, self.n_signals)), valid

        valid_times += 1e-4

        # 2. Prepare boundaries
        if self.interpolation_align == "center":
            starts = valid_times - self.interpolation_window / 2
            ends   = valid_times + self.interpolation_window / 2
        elif self.interpolation_align == "left":
            starts = valid_times
            ends   = valid_times + self.interpolation_window
        elif self.interpolation_align == "right":
            starts = valid_times - self.interpolation_window
            ends   = valid_times
        else:
            raise ValueError(f"Unknown alignment mode: {self.interpolation_align}")

        # 3. Prepare Output
        # SIZE FIX: Only allocate for the VALID batch size
        batch_size = len(valid_times)
        counts = np.zeros((batch_size, self.n_signals), dtype=np.float64)
        
        # 4. Call Numba Engine
        fast_count_spikes(self.spikes, self.indices, starts, ends, counts)
        
        # SIGNATURE FIX: Return both data and the mask
        return counts, valid

In [4]:
print("Loaded SpikesInterpolator")

Loaded SpikesInterpolator


In [5]:
# ==========================================
# 2. DATA GENERATION & TESTING
# ==========================================

def create_dummy_data(folder_path, n_neurons=50, duration=100.0, rate=20):
    """Creates synthetic spike data and metadata."""
    folder = Path(folder_path)
    folder.mkdir(parents=True, exist_ok=True)
    
    all_spikes = []
    indices = [0]
    
    for _ in range(n_neurons):
        n_spikes = int(duration * rate)
        # Random spikes sorted
        spikes = np.sort(np.random.uniform(0, duration, n_spikes))
        all_spikes.append(spikes)
        indices.append(indices[-1] + len(spikes))
        
    flat_spikes = np.concatenate(all_spikes)
    flat_spikes.tofile(folder / "spikes.npy")
    
    meta = {
        "modality": "spikes",
        "n_signals": n_neurons,
        "spike_indices": indices,
        "start_time": 0.0,
        "end_time": duration,
        "sampling_rate": 1000.0 
    }
    with open(folder / "meta.yml", "w") as f:
        yaml.dump(meta, f)
        
    return all_spikes, flat_spikes, indices

In [26]:
# Setup temporary directory
temp_dir = tempfile.mkdtemp()

try:
    print(f"Test Environment created at: {temp_dir}")
    
    # 1. Create Data
    gt_spikes_list, gt_flat, gt_indices = create_dummy_data(temp_dir)
    
    # 2. Instantiate Interpolator
    # We use a window of 1.0s to make manual verification easy
    interpolator = SpikesInterpolator(temp_dir, interpolation_window=1.0, interpolation_align="center")
    
    # 3. Define Query Times
    # Time 5.0 -> Window [4.5, 5.5]
    query_times = np.array([1.0, 5.0, 9.0]) 
    query_times = np.sort(np.random.uniform(1.0, 1000.0, size=100))  # 100 random query times
    print(f"Query Times: {query_times}")
    
    # 4. Run Interpolation
    print("Running interpolation...")
    counts, valid = interpolator.interpolate(query_times)
    
    print(f"Output Shape: {counts.shape}")
    print(f"Output Counts:\n{counts}")

    # 5. Verify Correctness (Slow Check)
    print("\nVerifying accuracy against ground truth...")
    window = 1.0
    errors = 0
    total_checks = 0
    
    for t_idx, t in enumerate(query_times):
        t_start = t - window/2
        t_end = t + window/2
        
        for n_idx in range(len(gt_spikes_list)):
            total_checks += 1
            # Manual count using the original list
            neuron_spikes = gt_spikes_list[n_idx]
            manual_count = np.sum((neuron_spikes >= t_start) & (neuron_spikes < t_end))
            
            numba_count = counts[t_idx, n_idx]
            
            if manual_count != numba_count:
                print(f"Mismatch at time {t}, neuron {n_idx}: Expected {manual_count}, got {numba_count}")
                errors += 1
                
    if errors == 0:
        print("SUCCESS: All counts match exactly.")
    else:
        print(f"FAILED: {errors}/{total_checks} mismatches found. {errors/total_checks*100:.2f}%")

finally:
    # Cleanup
    shutil.rmtree(temp_dir)

Test Environment created at: /mnt/lustre-grete/tmp/u18196/tmphihyfba2
Query Times: [  3.50023038  36.49095793  38.16130216  46.2304197   53.75492245
  54.92828752  76.8985505   77.79036471 145.01584162 146.48955694
 168.20422344 171.82960657 177.97242342 178.60321715 186.13890857
 195.19165015 209.96918032 210.11792344 236.64218954 245.96599061
 259.70832701 263.24620955 265.49658938 270.28946855 278.10987713
 285.6027504  285.6467078  287.02912433 298.63076316 306.84322179
 323.27555511 324.3676068  329.54467709 332.16290704 335.2584511
 352.26588251 359.24612799 380.05049274 390.92145254 393.77305273
 398.78024962 408.16446144 408.18004048 420.25153471 424.32825405
 435.15417952 438.044676   445.87932196 448.70185216 459.84935268
 461.67695086 463.04016061 473.30555321 490.71645631 523.43517979
 524.19936891 533.56551954 545.84738164 556.60822923 571.0013136
 580.67333565 586.23418273 589.5733758  593.32394789 596.04500191
 602.72315072 625.43961154 632.97624584 640.95432006 652.7765

In [33]:
# ==========================================
# 3. TEST RUNNER (ALIGNMENT + SPEED)
# ==========================================

temp_dir = tempfile.mkdtemp()

try:
    print(f"Test Environment: {temp_dir}")
    gt_spikes_list, gt_flat, gt_indices = create_dummy_data(temp_dir, n_neurons=5000, duration=1000.0)
    
    # Define query times (randomly sampled)
    n_queries = 1000
    query_times = np.sort(np.random.uniform(1.0, 999.0, size=n_queries))
    window_size = 0.5
    
    # alignments_to_test = ["center"]
    alignments_to_test = ["center", "left", "right"]
    
    print(f"\n{'='*60}")
    print(f"STARTING TESTS: {n_queries} queries, Window={window_size}s")
    print(f"{'='*60}")

    for align in alignments_to_test:
        print(f"\n>>> Testing Alignment: {align.upper()}")
        
        # 1. Instantiate
        interpolator = SpikesInterpolator(
            temp_dir, 
            interpolation_window=window_size, 
            interpolation_align=align
        )
        
        # 2. Run & Time
        # Warmup (optional, to compile JIT)
        _ = interpolator.interpolate(query_times[:10])
        
        start_t = time.perf_counter()
        counts, valid = interpolator.interpolate(query_times)
        end_t = time.perf_counter()
        
        duration_sec = end_t - start_t
        speed_qps = n_queries / duration_sec
        
        print(f"Time: {duration_sec*1000:.2f} ms")
        print(f"Speed: {speed_qps:.0f} queries/sec")
        
        # 3. Verify Correctness
        print("Verifying accuracy...")
        errors = 0
        big_errors = 0
        total_checks = 0
        
        for t_idx, t in enumerate(query_times):
            # Adjust ground truth logic based on alignment
            if align == "center":
                t_start = t - window_size/2
                t_end   = t + window_size/2
            elif align == "left":
                t_start = t
                t_end   = t + window_size
            elif align == "right":
                t_start = t - window_size
                t_end   = t
                
            for n_idx in range(len(gt_spikes_list)):
                total_checks += 1
                neuron_spikes = gt_spikes_list[n_idx]
                # Ground truth count
                manual_count = np.sum((neuron_spikes >= t_start) & (neuron_spikes < t_end))
                
                numba_count = counts[t_idx, n_idx]
                
                if manual_count != numba_count:
                    # Print only the first error to avoid spamming
                    if errors%1000 == 0:
                        print(f"Mismatch at time {t:.2f}, neuron {n_idx}: Expected {manual_count}, got {numba_count}")
                    errors += 1
                    if abs(manual_count - numba_count) > 1:
                        big_errors += 1

        if errors == 0:
            print("SUCCESS: All counts match.")
        else:
            print(f"FAILED: {errors}/{total_checks} mismatches found. {errors/total_checks*100:.2f}%")
            print(f"Large Errors (>1 count difference): {big_errors}")

finally:
    shutil.rmtree(temp_dir)
    print(f"\n{'='*60}")
    print("Cleanup complete.")

Test Environment: /mnt/lustre-grete/tmp/u18196/tmpe7hu7mym

STARTING TESTS: 1000 queries, Window=0.5s

>>> Testing Alignment: CENTER
Time: 316.91 ms
Speed: 3155 queries/sec
Verifying accuracy...
Mismatch at time 3.06, neuron 214: Expected 23, got 24.0
Mismatch at time 45.66, neuron 1243: Expected 13, got 12.0
Mismatch at time 94.01, neuron 172: Expected 8, got 9.0
Mismatch at time 139.38, neuron 3738: Expected 11, got 12.0
Mismatch at time 196.35, neuron 4316: Expected 12, got 11.0
Mismatch at time 231.95, neuron 2538: Expected 9, got 10.0
Mismatch at time 291.18, neuron 2515: Expected 14, got 13.0
Mismatch at time 343.45, neuron 3928: Expected 15, got 14.0
Mismatch at time 407.50, neuron 3860: Expected 11, got 10.0
Mismatch at time 444.32, neuron 2907: Expected 10, got 11.0
Mismatch at time 494.04, neuron 2428: Expected 7, got 8.0
Mismatch at time 540.11, neuron 3895: Expected 7, got 8.0
Mismatch at time 585.65, neuron 1185: Expected 6, got 7.0
Mismatch at time 628.26, neuron 900: Exp

In [6]:
# ==========================================
# 3. TEST RUNNER (ALIGNMENT + SPEED)
# ==========================================

import sys


temp_dir = tempfile.mkdtemp()
temp_dir = "/mnt/vast-nhr/projects/nix00014/goirik/data/dummy_data"

try:
    duration=1000.0
    print(f"Test Environment: {temp_dir}")
    start = time.perf_counter()
    gt_spikes_list, gt_flat, gt_indices = create_dummy_data(temp_dir, n_neurons=5000, duration=duration)
    end = time.perf_counter()
    print(f"Data creation time: {end - start:.2f} seconds")

    print(f"Size of gt_spikes_list in GB: {sys.getsizeof(gt_spikes_list)*1e-9}")
    print(f"Size of gt_flat in GB: {sys.getsizeof(gt_flat)*1e-9}")
    print(f"Size of gt_indices in GB: {sys.getsizeof(gt_indices)*1e-9}")
    
    # Define query times (randomly sampled)
    n_queries = 1000
    query_times = np.sort(np.random.uniform(1.0, duration, size=n_queries))
    window_size = 0.5
    
    alignments_to_test = ["center"]
    # alignments_to_test = ["center", "left", "right"]
    
    print(f"\n{'='*60}")
    print(f"STARTING TESTS: {n_queries} queries, Window={window_size}s")
    print(f"{'='*60}")

    for align in alignments_to_test:
        print(f"\n>>> Testing Alignment: {align.upper()}")
        
        # 1. Instantiate
        interpolator = SpikesInterpolator(
            temp_dir, 
            interpolation_window=window_size, 
            interpolation_align=align
        )
        
        # 2. Run & Time
        # Warmup (optional, to compile JIT)
        _ = interpolator.interpolate(query_times[:10])
        
        start_t = time.perf_counter()
        counts, valid = interpolator.interpolate(query_times)
        end_t = time.perf_counter()
        
        duration_sec = end_t - start_t
        speed_qps = n_queries / duration_sec
        
        print(f"Time: {duration_sec*1000:.2f} ms")
        print(f"Speed: {speed_qps:.0f} queries/sec")

        easdgs
        
        # 3. Verify Correctness
        print("Verifying accuracy...")
        errors = 0
        big_errors = 0
        total_checks = 0
        
        for t_idx, t in enumerate(query_times):
            # Adjust ground truth logic based on alignment
            if align == "center":
                t_start = t - window_size/2
                t_end   = t + window_size/2
            elif align == "left":
                t_start = t
                t_end   = t + window_size
            elif align == "right":
                t_start = t - window_size
                t_end   = t
                
            for n_idx in range(len(gt_spikes_list)):
                total_checks += 1
                neuron_spikes = gt_spikes_list[n_idx]
                # Ground truth count
                manual_count = np.sum((neuron_spikes >= t_start) & (neuron_spikes < t_end))
                
                numba_count = counts[t_idx, n_idx]
                
                if manual_count != numba_count:
                    # Print only the first error to avoid spamming
                    if errors%1000 == 0:
                        print(f"Mismatch at time {t:.2f}, neuron {n_idx}: Expected {manual_count}, got {numba_count}")
                    errors += 1
                    if abs(manual_count - numba_count) > 1:
                        big_errors += 1

        if errors == 0:
            print("SUCCESS: All counts match.")
        else:
            print(f"FAILED: {errors}/{total_checks} mismatches found. {errors/total_checks*100:.2f}%")
            print(f"Large Errors (>1 count difference): {big_errors}")

finally:
    shutil.rmtree(temp_dir)
    print(f"\n{'='*60}")
    print("Cleanup complete.")

Test Environment: /mnt/vast-nhr/projects/nix00014/goirik/data/dummy_data
Data creation time: 3.11 seconds
Size of gt_spikes_list in GB: 4.1880000000000006e-05
Size of gt_flat in GB: 0.8000001120000001
Size of gt_indices in GB: 4.1880000000000006e-05

STARTING TESTS: 1000 queries, Window=0.5s

>>> Testing Alignment: CENTER
Time: 220.79 ms
Speed: 4529 queries/sec


OSError: [Errno 39] Directory not empty: '/mnt/vast-nhr/projects/nix00014/goirik/data/dummy_data'

In [None]:
# create_dummy_data(temp_dir, n_neurons=38150, duration=1683227.0)


In [13]:
type(gt_spikes_list), len(gt_spikes_list), type(gt_flat), gt_flat.shape, type(gt_indices), len(gt_indices)

(list, 5000, numpy.ndarray, (100000000,), list, 5001)

In [None]:
# ==========================================
# 3. TEST RUNNER (ALIGNMENT + SPEED)
# ==========================================
# temp_dir = "/mnt/vast-nhr/projects/nix00014/goirik/data/dummy_data"
temp_dir = "/mnt/vast-nhr/projects/nix00014/goirik/mozaik-models/experanto/data"
duration=1000.0

# Define query times (randomly sampled)
n_queries = 1000
query_times = np.sort(np.random.uniform(1.0, duration, size=n_queries))
window_size = 0.5

# alignments_to_test = ["center"]
alignments_to_test = ["center", "left", "right"]

print(f"\n{'='*60}")
print(f"STARTING TESTS: {n_queries} queries, Window={window_size}s")
print(f"{'='*60}")

for align in alignments_to_test:
    print(f"\n>>> Testing Alignment: {align.upper()}")
    
    # 1. Instantiate
    interpolator = SpikesInterpolator(
        temp_dir, 
        interpolation_window=window_size, 
        interpolation_align=align
    )
    
    # 2. Run & Time
    # Warmup (optional, to compile JIT)
    _ = interpolator.interpolate(query_times[:10])
    
    start_t = time.perf_counter()
    counts, valid = interpolator.interpolate(query_times)
    end_t = time.perf_counter()
    
    duration_sec = end_t - start_t
    speed_qps = n_queries / duration_sec
    
    print(f"Time: {duration_sec*1000:.2f} ms")
    print(f"Speed: {speed_qps:.0f} queries/sec")



STARTING TESTS: 1000 queries, Window=0.5s

>>> Testing Alignment: CENTER
Time: 389.29 ms
Speed: 2569 queries/sec
