# Data loading & neuron analysis

This notebook covers the essential first steps of working with **DRIADA** --
loading calcium imaging data, inspecting core objects, reconstructing spikes,
and assessing recording quality.

**What you will learn:**

1. **Loading your data** -- wrap numpy arrays (from Suite2P, CaImAn, DeepLabCut, etc.) into a DRIADA `Experiment`.
2. **Single neuron analysis** -- create a `Neuron`, reconstruct spikes, optimize kinetics, compute quality metrics, and generate surrogates.
3. **Threshold vs wavelet reconstruction** -- compare two spike detection methods across four optimization modes.
4. **Method agreement** -- quantify event-region overlap between threshold and wavelet at varying tolerance.

In [None]:
!pip install -q driada
%matplotlib inline

import os
import time
import tempfile
import warnings

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

from driada.experiment import (
    load_exp_from_aligned_data,
    save_exp_to_pickle,
    load_exp_from_pickle,
    generate_synthetic_exp,
)
from driada.experiment.neuron import Neuron
from driada.experiment.synthetic import generate_pseudo_calcium_signal

## 1. Loading your data into DRIADA

You have numpy arrays from your recording pipeline (Suite2P, CaImAn,
DeepLabCut, etc.).  DRIADA wraps them into an **Experiment** object that
keeps neural activity and behavioral features aligned and annotated.

The only required key is `'calcium'` -- a `(n_neurons, n_frames)` array of
fluorescence traces.  An optional `'spikes'` array of the same shape can be
passed as well.  Every other key becomes a **dynamic feature** and must hold
one value per timepoint (length `n_frames`).
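Misaligned arrays (for example, a behavior trace with one frame more than the imaging data) are a common source of errors. The helper below is a hypothetical pre-flight check: `check_aligned` is not part of DRIADA. It makes the alignment requirement concrete. `calcium` must be 2D, an optional `spikes` array must match its shape, and every other array must end in an `n_frames` axis.

```python
import numpy as np


def check_aligned(data):
    """Hypothetical shape check for a DRIADA-style data dict (not a DRIADA API).

    Returns n_frames if every array is aligned with data["calcium"],
    otherwise raises KeyError / ValueError.
    """
    if "calcium" not in data:
        raise KeyError("'calcium' is required")
    calcium = np.asarray(data["calcium"])
    if calcium.ndim != 2:
        raise ValueError(f"calcium must be 2D (n_neurons, n_frames), got {calcium.shape}")
    n_frames = calcium.shape[1]

    for key, arr in data.items():
        arr = np.asarray(arr)
        if key in ("calcium", "spikes"):
            # Neural arrays must match the calcium matrix exactly
            if arr.shape != calcium.shape:
                raise ValueError(f"{key}: shape {arr.shape} != calcium {calcium.shape}")
        elif arr.ndim == 0 or arr.shape[-1] != n_frames:
            # Dynamic features need one value per timepoint on the last axis
            raise ValueError(f"{key}: expected last axis of length {n_frames}, got {arr.shape}")
    return n_frames
```

In the notebook you could call `check_aligned(data)` on the dict built below before passing it to `load_exp_from_aligned_data`.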

In [None]:
# In practice: raw = np.load('your_recording.npz')
# Here we generate synthetic arrays that mimic a real recording.

np.random.seed(0)
n_neurons, n_frames = 50, 10000
fps = 30.0

calcium = np.random.randn(n_neurons, n_frames) * 0.1          # (50, 10000)
x_pos = np.cumsum(np.random.randn(n_frames) * 0.5)            # continuous
y_pos = np.cumsum(np.random.randn(n_frames) * 0.5)            # continuous
speed = np.abs(np.random.randn(n_frames)) * 5.0               # continuous
head_direction = np.random.uniform(0, 2 * np.pi, n_frames)    # circular (radians)
trial_type = np.random.choice([0, 1, 2], size=n_frames)       # discrete labels

print(f"calcium:        shape={calcium.shape}, dtype={calcium.dtype}")
print(f"x_pos:          shape={x_pos.shape}, dtype={x_pos.dtype}")
print(f"y_pos:          shape={y_pos.shape}, dtype={y_pos.dtype}")
print(f"speed:          shape={speed.shape}, dtype={speed.dtype}")
print(f"head_direction: shape={head_direction.shape}, dtype={head_direction.dtype}")
print(f"trial_type:     shape={trial_type.shape}, dtype={trial_type.dtype}")

### Feature types and aggregation

DRIADA auto-detects whether each feature is **continuous** or **discrete**.
You can override the detection with a `feature_types` dict.  Valid type
strings include: `continuous`, `circular`, `categorical`, `binary`, `count`.

`aggregate_features` groups related 1D features into a single
`MultiTimeSeries` (e.g. x_pos + y_pos -> position_2d).

`create_circular_2d=True` (the default) auto-creates a `(cos, sin)` encoding
for every circular feature.  This is important because MI estimators (GCMI,
KSG) assume the real line -- a raw angle wraps at 0 / 2*pi, breaking the
distance metric.  The `(cos, sin)` encoding maps the circle onto R^2 where
Euclidean distance is meaningful.
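The wrap-around problem is easy to see numerically. In this sketch (plain NumPy, independent of DRIADA's own encoding code), two headings 0.1 rad apart straddle the 0 / 2*pi boundary. Their raw difference is almost a full turn, while the Euclidean distance between their `(cos, sin)` points is the short chord between them.

```python
import numpy as np


def to_cos_sin(theta):
    """Map angles in radians onto the unit circle in R^2 (last axis = [cos, sin])."""
    theta = np.asarray(theta, dtype=float)
    return np.stack([np.cos(theta), np.sin(theta)], axis=-1)


a, b = 0.05, 2 * np.pi - 0.05  # two headings ~0.1 rad apart on the circle

raw_dist = abs(a - b)                                      # ignores the wrap at 2*pi
enc_dist = np.linalg.norm(to_cos_sin(a) - to_cos_sin(b))   # chord length on the circle

print(f"raw distance:        {raw_dist:.3f}")  # 6.183
print(f"(cos, sin) distance: {enc_dist:.3f}")  # 0.100
```

The chord length equals `2 * sin(delta / 2)`, which is close to the true angular difference `delta` for small angles, so nearby headings stay nearby after encoding.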

In [None]:
# Build the data dict
data = {
    # --- neural activity (required) --------------------------------
    "calcium": calcium,               # (50, 10000)
    # "spikes": my_spikes_array,      # optional, same shape as calcium
    # --- dynamic features: behavioral variables (one per timepoint) -
    "x_pos": x_pos,                   # continuous
    "y_pos": y_pos,                   # continuous
    "speed": speed,                   # continuous
    "head_direction": head_direction,  # circular (radians)
    "trial_type": trial_type,         # discrete labels
}

# Override auto-detected feature types (optional)
feature_types = {
    "head_direction": "circular",   # auto-detection may miss this
    "trial_type": "categorical",    # refine from generic discrete
}

# Aggregate multi-component features (optional)
aggregate_features = {
    ("x_pos", "y_pos"): "position_2d",
}

# Build the Experiment
exp = load_exp_from_aligned_data(
    data_source="MyLab",
    exp_params={"name": "demo_recording"},
    data=data,
    feature_types=feature_types,
    aggregate_features=aggregate_features,
    static_features={"fps": 30.0},
    # create_circular_2d=True (the default) adds head_direction_2d,
    # the (cos, sin) encoding of head_direction described above.
    create_circular_2d=True,
    verbose=True,
)

### Inspecting the Experiment

Note the auto-generated features in the list below:
- **position_2d** -- from `aggregate_features` (x_pos + y_pos)
- **head_direction_2d** -- from `create_circular_2d` (cos + sin encoding)

In [None]:
print(f"Neurons:     {exp.n_cells}")
print(f"Timepoints:  {exp.n_frames}")
print(f"FPS:         {exp.static_features.get('fps', 'unknown')}")
print(f"Calcium:     {exp.calcium.data.shape}")

print("\nDynamic features (time-varying behavioral variables):")
for name, ts in sorted(exp.dynamic_features.items()):
    ti = getattr(ts, "type_info", None)
    if ti and hasattr(ti, "primary_type"):
        dtype_str = f"{ti.primary_type}/{ti.subtype}"
        if ti.is_circular:
            dtype_str += " (circular)"
    else:
        dtype_str = "discrete" if ts.discrete else "continuous"
    shape = ts.data.shape
    print(f"  {name:25s}  shape={str(shape):15s}  type={dtype_str}")

### TimeSeries and MultiTimeSeries

Each dynamic feature is stored as one of two classes:

| Class | Description |
|---|---|
| **TimeSeries** | A single 1D variable (e.g. `speed`) |
| **MultiTimeSeries** | Multiple aligned 1D variables stacked into a 2D array (e.g. `position_2d = [x, y]`) |

Key attributes on both:
- `.data` -- raw numpy array (1D or 2D)
- `.discrete` -- True if discrete, False if continuous
- `.type_info` -- rich type metadata (subtype, circularity)
- `.copula_normal_data` -- GCMI-ready transform (continuous only)
- `.int_data` -- integer-coded values (discrete only)

MultiTimeSeries additionally has `.ts_list` (list of component TimeSeries)
and `.n_dim` (number of components).
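`.copula_normal_data` is described only as a "GCMI-ready transform". In the Gaussian-copula MI (GCMI) approach, that transform is usually a rank-based mapping to standard-normal marginals. The sketch below shows that standard technique. It assumes SciPy is installed, and `copnorm` is an illustrative name, not DRIADA's implementation.

```python
import numpy as np
from scipy.special import ndtri      # inverse of the standard-normal CDF
from scipy.stats import rankdata


def copnorm(x):
    """Gaussian-copula normalization of a 1D array (illustrative sketch).

    Ranks are mapped into (0, 1) as rank / (n + 1), then through the inverse
    normal CDF: the output is standard-normal but keeps the input's ordering.
    """
    x = np.asarray(x, dtype=float)
    u = rankdata(x) / (x.size + 1)
    return ndtri(u)


rng = np.random.default_rng(0)
speed = np.abs(rng.standard_normal(10_000)) * 5.0   # skewed, non-Gaussian input
z = copnorm(speed)
print(f"mean={z.mean():.3f}, std={z.std():.3f}")
```

Because only ranks are used, any monotonic rescaling of `speed` gives the same `z`. This is also why the transform is only meaningful for continuous features, matching the "continuous only" note above.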

In [None]:
# Features are accessible as attributes: exp.speed, exp.position_2d, etc.
# This is equivalent to exp.dynamic_features["speed"].
speed_ts = exp.speed
print(f"speed.data.shape:   {speed_ts.data.shape}")
print(f"speed.discrete:     {speed_ts.discrete}")
print(f"speed.type_info:    {speed_ts.type_info.primary_type}"
      f"/{speed_ts.type_info.subtype}")
print(f"speed has copula:   {speed_ts.copula_normal_data is not None}")

# Access a 2D feature (MultiTimeSeries)
pos_mts = exp.position_2d
print(f"\nposition_2d.data.shape: {pos_mts.data.shape}")
print(f"position_2d.n_dim:      {pos_mts.n_dim}  (x and y)")
# Individual components are full TimeSeries objects:
print(f"position_2d.ts_list[0]: {pos_mts.ts_list[0].name}"
      f"  shape={pos_mts.ts_list[0].data.shape}")

# Discrete feature
trial_ts = exp.trial_type
print(f"\ntrial_type.discrete:  {trial_ts.discrete}")
print(f"trial_type.int_data:  {trial_ts.int_data[:8]}...")
print(f"trial_type has copula: {trial_ts.copula_normal_data is not None}")

### Batch spike reconstruction

`reconstruct_all_neurons()` applies the same reconstruction method across
the whole population.  After reconstruction, per-neuron quality metrics
(wavelet SNR, R-squared, event counts) are available.

In [None]:
exp.reconstruct_all_neurons(method='threshold', n_iter=3, show_progress=True)
print(f"[OK] Reconstructed spikes for {exp.n_cells} neurons")

# Collect per-neuron quality metrics
snr_list = []
r2_list = []
event_counts = []

for n in exp.neurons:
    snr_list.append(n.get_wavelet_snr())
    r2_list.append(n.get_reconstruction_r2())
    event_counts.append(n.get_event_count())

snr_arr = np.array(snr_list)
r2_arr = np.array(r2_list)
evt_arr = np.array(event_counts)

print(f"\nPopulation quality summary ({exp.n_cells} neurons):")
print(f"  Wavelet SNR:  {np.mean(snr_arr):.2f} +/- {np.std(snr_arr):.2f}"
      f"  (range {np.min(snr_arr):.2f} - {np.max(snr_arr):.2f})")
print(f"  Recon R2:     {np.mean(r2_arr):.4f} +/- {np.std(r2_arr):.4f}"
      f"  (range {np.min(r2_arr):.4f} - {np.max(r2_arr):.4f})")
print(f"  Event count:  {np.mean(evt_arr):.1f} +/- {np.std(evt_arr):.1f}"
      f"  (range {np.min(evt_arr)} - {np.max(evt_arr)})")

### Neural data access

Neural activity is stored in two complementary ways:

| View | Description |
|---|---|
| `exp.calcium` | `MultiTimeSeries` (n_neurons, n_frames) -- convenient for population-level analysis (DR, RSA, decoding) |
| `exp.neurons` | List of `Neuron` objects -- for single-cell analysis (reconstruction, kinetics, quality) |

In [None]:
# Population-level: full calcium matrix as MultiTimeSeries
print(f"exp.calcium:        {type(exp.calcium).__name__}"
      f"  shape={exp.calcium.data.shape}")
has_spikes = exp.spikes is not None and exp.spikes.data.any()
print(f"exp.spikes:         {'available' if has_spikes else 'not provided'}")

# Single-neuron level: list of Neuron objects
neuron = exp.neurons[0]
print(f"\nexp.neurons:        {len(exp.neurons)} Neuron objects")
print(f"neuron.cell_id:     {neuron.cell_id}")
print(f"neuron.ca:          {type(neuron.ca).__name__}"
      f"  shape={neuron.ca.data.shape}")
print(f"neuron.sp:          "
      f"{'shape=' + str(neuron.sp.data.shape) if neuron.sp else 'None (no spikes provided)'}")
print(f"neuron.fps:         {neuron.fps}")
# See Section 2 for spike reconstruction, event detection,
# kinetics optimization, and other Neuron methods.

### Save and reload

The entire Experiment (neural data + features + metadata) can be serialized
with pickle for fast roundtrip storage.

In [None]:
pkl_path = os.path.join(tempfile.gettempdir(), "demo_experiment.pkl")
save_exp_to_pickle(exp, pkl_path, verbose=False)
file_size_mb = os.path.getsize(pkl_path) / 1024 / 1024
print(f"Saved:  {pkl_path} ({file_size_mb:.1f} MB)")

exp_loaded = load_exp_from_pickle(pkl_path, verbose=False)
print(f"Loaded: {exp_loaded.n_cells} neurons, {exp_loaded.n_frames} frames")

# Verify roundtrip
assert exp_loaded.n_cells == exp.n_cells
assert exp_loaded.n_frames == exp.n_frames
assert np.allclose(exp_loaded.calcium.data, exp.calcium.data)
print("Roundtrip verified -- data matches.")

# Clean up
os.remove(pkl_path)
print(f"Cleaned up {pkl_path}")

## 2. Single neuron analysis

Deep dive into individual neuron quality: generate a synthetic calcium
signal, create a `Neuron` object, reconstruct spikes, optimize kinetics,
compute quality metrics, and generate surrogates for null-hypothesis testing.
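The next cell uses `generate_pseudo_calcium_signal` with a `double_exponential` kernel. To show what such a signal is made of, here is a minimal stand-in generator. It assumes Poisson-like event times, uniform amplitudes, a unit-peak kernel `exp(-t/decay) - exp(-t/rise)` and additive Gaussian noise. `sketch_pseudo_calcium` is hypothetical, and DRIADA's generator may differ in its details.

```python
import numpy as np


def sketch_pseudo_calcium(duration=200.0, fps=20.0, event_rate=0.15,
                          amplitude_range=(1.0, 2.0), rise_time=0.15,
                          decay_time=1.5, noise_std=0.05, seed=0):
    """Hypothetical pseudo-calcium generator (requires decay_time > rise_time).

    Returns (signal, event_amplitudes), both of length duration * fps.
    """
    rng = np.random.default_rng(seed)
    n_frames = int(round(duration * fps))

    # Bernoulli approximation of a Poisson process: P(event in a frame) = rate / fps
    events = rng.random(n_frames) < event_rate / fps
    amps = np.zeros(n_frames)
    amps[events] = rng.uniform(*amplitude_range, size=events.sum())

    # Double-exponential kernel over 10 decay constants, scaled to unit peak
    t = np.arange(int(10 * decay_time * fps)) / fps
    kernel = np.exp(-t / decay_time) - np.exp(-t / rise_time)
    kernel /= kernel.max()

    clean = np.convolve(amps, kernel)[:n_frames]   # causal: trace follows each event
    return clean + rng.normal(0.0, noise_std, n_frames), amps


sig, amps = sketch_pseudo_calcium()
print(f"{sig.size} frames, {np.count_nonzero(amps)} events")
```

Scaling the kernel to unit peak means each isolated event reaches its drawn amplitude, so amplitudes can be read directly in dF/F0 units.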

In [None]:
# Set random seed for reproducibility
np.random.seed(42)

# =============================================================================
# Step 1: Generate Synthetic Calcium Signal
# =============================================================================
print("1. Generating synthetic calcium signal...")

signal = generate_pseudo_calcium_signal(
    duration=200.0,              # Signal duration in seconds
    sampling_rate=20.0,          # Sampling rate (Hz)
    event_rate=0.15,             # Average event rate (Hz)
    amplitude_range=(1.0, 2.0),  # Event amplitude range (dF/F0)
    decay_time=1.5,              # Calcium decay time constant (seconds)
    rise_time=0.15,              # Calcium rise time constant (seconds)
    noise_std=0.05,              # Additive Gaussian noise level
    kernel='double_exponential'  # Realistic calcium kernel
)

print(f"   [OK] Generated signal: {len(signal)} frames ({len(signal)/20:.1f} seconds)")

# =============================================================================
# Step 2: Create Neuron Object
# =============================================================================
print("\n2. Creating Neuron object...")

neuron = Neuron(
    cell_id='example_neuron',
    ca=signal,              # Calcium signal
    sp=None,                # No ground-truth spikes (will be reconstructed)
    fps=20.0                # Sampling rate
)

print(f"   [OK] Neuron created: {neuron.cell_id}")
print(f"   [OK] Signal length: {neuron.n_frames} frames")
print(f"   [OK] Sampling rate: {neuron.fps} Hz")

### Spike reconstruction

The **wavelet** method detects calcium transient events via CWT (continuous
wavelet transform) ridge analysis.
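To build intuition for why a CWT responds to transients, the sketch below computes a Ricker ("Mexican hat") wavelet transform with plain NumPy and marks local maxima of the strongest across-scale response. This is a deliberately crude stand-in. It does not track ridges across scales the way DRIADA's wavelet method does, and `ricker`, `cwt_ricker` and `detect_cwt_peaks` are illustrative names.

```python
import numpy as np


def ricker(n_points, width):
    """Ricker wavelet with scale `width`, centred in n_points samples (unit L2 norm)."""
    t = np.arange(n_points) - (n_points - 1) / 2.0
    norm = 2.0 / (np.sqrt(3.0 * width) * np.pi ** 0.25)
    return norm * (1.0 - (t / width) ** 2) * np.exp(-(t ** 2) / (2.0 * width ** 2))


def cwt_ricker(x, widths):
    """Continuous wavelet transform: one row of coefficients per width."""
    x = np.asarray(x, dtype=float)
    coefs = np.empty((len(widths), x.size))
    for i, w in enumerate(widths):
        n = min(10 * int(w), x.size)                  # truncate the wavelet support
        coefs[i] = np.convolve(x, ricker(n, w), mode="same")
    return coefs


def detect_cwt_peaks(x, widths=(2, 4, 8, 16), n_mad=10.0):
    """Frames where the max-over-scales CWT response has a local maximum above
    median + n_mad * MAD (a conservative, illustrative threshold)."""
    envelope = cwt_ricker(x, widths).max(axis=0)
    med = np.median(envelope)
    mad = np.median(np.abs(envelope - med))
    mid = envelope[1:-1]
    is_peak = (mid > envelope[:-2]) & (mid >= envelope[2:]) & (mid > med + n_mad * mad)
    return np.flatnonzero(is_peak) + 1


# Two Gaussian-shaped transients (sigma = 4 frames) in low-level noise
rng = np.random.default_rng(1)
frames = np.arange(600)
x = 0.01 * rng.standard_normal(600)
for centre in (150, 400):
    x += np.exp(-((frames - centre) ** 2) / (2 * 4.0 ** 2))

peaks = detect_cwt_peaks(x)
print(peaks)
```

Testing several widths lets the transform match transients of different durations. Ridge analysis goes further by following each peak across scales to reject noise that appears at only one scale.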

In [None]:
# =============================================================================
# Step 3: Reconstruct Spikes with Wavelet Method
# =============================================================================
print("3. Reconstructing spikes using wavelet method...")

spikes = neuron.reconstruct_spikes(
    method='wavelet',
    create_event_regions=True  # Create event regions for quality metrics
)

n_events = int(np.sum(neuron.asp.data > 0))
print(f"   [OK] Detected {n_events} calcium events")
print(f"   [OK] Spike train stored in neuron.sp")
print(f"   [OK] Amplitude spikes stored in neuron.asp")

### Kinetics optimization

Fit rise and decay time constants to detected events using the **direct
measurement** method.
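The idea behind measuring a decay constant directly can be shown with one event: after the peak, a trace `A * exp(-t / tau)` is a straight line in log space with slope `-1 / tau`. The sketch below fits that line with `np.polyfit`. It illustrates the principle only. The exact procedure behind DRIADA's `method='direct'` is not described here, and `estimate_decay_tau` is a hypothetical helper.

```python
import numpy as np


def estimate_decay_tau(trace, fps, peak_idx, n_frames_fit):
    """Decay constant (seconds) from a log-linear fit over a post-peak window."""
    seg = np.asarray(trace[peak_idx:peak_idx + n_frames_fit], dtype=float)
    t = np.arange(seg.size) / fps
    mask = seg > 0                          # log is only defined for positive values
    slope, _ = np.polyfit(t[mask], np.log(seg[mask]), 1)
    return -1.0 / slope


fps = 30.0
t = np.arange(90) / fps
trace = 1.2 * np.exp(-t / 0.8)              # noiseless decay with tau = 0.8 s
tau = estimate_decay_tau(trace, fps, peak_idx=0, n_frames_fit=60)
print(f"tau = {tau:.3f} s")                 # tau = 0.800 s
```

On real data the fit window should end before the trace sinks into the noise floor, since noisy low values distort the logarithm.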

In [None]:
# =============================================================================
# Step 4: Optimize Calcium Kinetics
# =============================================================================
print("4. Optimizing calcium kinetics...")

kinetics = neuron.get_kinetics(
    method='direct',           # Direct measurement from detected events
    use_cached=False          # Force recomputation
)

print(f"   [OK] Optimized rise time (t_rise): {kinetics['t_rise']:.3f} seconds")
print(f"   [OK] Optimized decay time (t_off): {kinetics['t_off']:.3f} seconds")
print(f"   [OK] Events used: {kinetics['n_events_detected']}")

### Quality metrics

- **Wavelet SNR** -- ratio of event amplitude to baseline noise
- **R-squared** -- reconstruction quality (1.0 = perfect, >0.7 = good)
- **Event-only R-squared** -- quality restricted to event regions
- **NRMSE** -- normalized root mean squared error (lower is better)
- **NMAE** -- normalized mean absolute error (lower is better)
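These metrics have standard textbook forms, sketched below with NumPy. R-squared is `1 - SS_res / SS_tot`. For NRMSE and NMAE this sketch assumes normalization by the signal's range (`max - min`). DRIADA may normalize differently (for example, by the standard deviation), so treat the absolute values as illustrative. The "event-only" variant is approximated by restricting both traces to a boolean mask.

```python
import numpy as np


def r2_score(y, y_hat):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    y, y_hat = np.asarray(y, float), np.asarray(y_hat, float)
    ss_res = np.sum((y - y_hat) ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2)
    return 1.0 - ss_res / ss_tot


def nrmse(y, y_hat):
    """Root mean squared error divided by the range of y (assumed normalization)."""
    y, y_hat = np.asarray(y, float), np.asarray(y_hat, float)
    return np.sqrt(np.mean((y - y_hat) ** 2)) / (y.max() - y.min())


def nmae(y, y_hat):
    """Mean absolute error divided by the range of y (assumed normalization)."""
    y, y_hat = np.asarray(y, float), np.asarray(y_hat, float)
    return np.mean(np.abs(y - y_hat)) / (y.max() - y.min())


y = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
y_hat = np.array([0.0, 1.0, 2.0, 3.0, 3.0])
print(r2_score(y, y_hat), nrmse(y, y_hat), nmae(y, y_hat))  # ~0.9, ~0.1118, ~0.05

# "Event-only" style score: evaluate only where a mask marks event frames
event_mask = np.array([False, False, True, True, True])
r2_events = r2_score(y[event_mask], y_hat[event_mask])
```

Event-only scores are typically lower than whole-trace scores, because the flat baseline that a reconstruction matches easily is excluded.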

In [None]:
# =============================================================================
# Step 5: Calculate Wavelet SNR
# =============================================================================
print("5. Computing wavelet SNR...")

wavelet_snr = neuron.get_wavelet_snr()

print(f"   [OK] Wavelet SNR: {wavelet_snr:.2f}")
print(f"       (Ratio of event amplitude to baseline noise)")

# =============================================================================
# Step 6: Calculate Reconstruction Quality Metrics
# =============================================================================
print("\n6. Computing reconstruction quality metrics...")

# R2 (coefficient of determination)
r2 = neuron.get_reconstruction_r2()
print(f"   [OK] Reconstruction R2: {r2:.4f}")
print(f"       (1.0 = perfect, >0.7 = good quality)")

# Event-only R2 (focuses on event regions)
r2_events = neuron.get_reconstruction_r2(event_only=True)
print(f"   [OK] Event-only R2: {r2_events:.4f}")
print(f"       (Quality in event regions only)")

# Normalized RMSE
nrmse = neuron.get_nrmse()
print(f"   [OK] Normalized RMSE: {nrmse:.4f}")
print(f"       (Lower is better)")

# Normalized MAE
nmae = neuron.get_nmae()
print(f"   [OK] Normalized MAE: {nmae:.4f}")
print(f"       (Lower is better)")

### Surrogate generation

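Four surrogate methods for null-hypothesis testing.  Each scrambles the timing
of the neuron's activity (and therefore its alignment with behavior) while
preserving the statistics listed below: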

| Method | Type | Preserves |
|---|---|---|
| **roll_based** | Calcium | Autocorrelation structure, amplitude distribution |
| **waveform_based** | Calcium | Individual waveform shapes, event count |
| **chunks_based** | Calcium | Local structure within chunks |
| **isi_based** | Spikes | Inter-spike interval distribution |
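Two of these ideas are simple enough to sketch in NumPy: a random circular shift (the principle behind `roll_based`) and an inter-spike-interval shuffle (the principle behind `isi_based`). The helpers `roll_surrogate` and `isi_surrogate` are hypothetical. For instance, the ISI sketch keeps the first spike in place and handles binary trains only, which may not match DRIADA's choices.

```python
import numpy as np


def roll_surrogate(x, rng, min_shift=1):
    """Circularly shift a trace by a random offset in [min_shift, n - min_shift]."""
    x = np.asarray(x)
    shift = rng.integers(min_shift, x.size - min_shift + 1)
    return np.roll(x, shift)


def isi_surrogate(spikes, rng):
    """Binary spike train with the same inter-spike intervals in shuffled order."""
    spikes = np.asarray(spikes)
    times = np.flatnonzero(spikes > 0)
    if times.size < 2:
        return spikes.copy()
    isis = np.diff(times)
    rng.shuffle(isis)                                   # permute intervals in place
    new_times = times[0] + np.concatenate(([0], np.cumsum(isis)))
    out = np.zeros_like(spikes)
    out[new_times] = 1                                  # last spike stays at times[-1]
    return out


rng = np.random.default_rng(42)
x = np.arange(100.0)
rolled = roll_surrogate(x, rng)

sp = np.zeros(200)
sp[[10, 30, 35, 90, 150]] = 1
sp_shuffled = isi_surrogate(sp, rng)
```

A roll keeps every value and the full autocorrelation (up to the wrap point) while moving activity relative to behavior. The ISI shuffle keeps spike count and interval distribution but changes when the spikes occur.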

In [None]:
# =============================================================================
# Step 7: Surrogate Generation Methods
# =============================================================================
print("7. Surrogate generation methods...")
print("   Three calcium surrogate types and one spike surrogate type.")

# --- Calcium surrogates ---

# 7a. Roll-based: circular shift preserving all autocorrelations
shuffled_roll = neuron.get_shuffled_calcium(method='roll_based', seed=42)
print(f"\n   [Roll-based] Circular shift surrogate:")
print(f"       Mean: {np.mean(shuffled_roll):.4f}  (original: {np.mean(neuron.ca.data):.4f})")
print(f"       Std:  {np.std(shuffled_roll):.4f}  (original: {np.std(neuron.ca.data):.4f})")
print(f"       Preserves: autocorrelation structure, amplitude distribution")

# 7b. Waveform-based: shuffle detected spike times, reconstruct calcium
shuffled_wf = neuron.get_shuffled_calcium(method='waveform_based', seed=42)
print(f"\n   [Waveform-based] Spike-shuffle + reconstruct surrogate:")
print(f"       Mean: {np.mean(shuffled_wf):.4f}  (original: {np.mean(neuron.ca.data):.4f})")
print(f"       Std:  {np.std(shuffled_wf):.4f}  (original: {np.std(neuron.ca.data):.4f})")
print(f"       Preserves: individual waveform shapes, event count")

# 7c. Chunks-based: divide signal into chunks and reorder
shuffled_chunks = neuron.get_shuffled_calcium(method='chunks_based', seed=42)
print(f"\n   [Chunks-based] Chunk reordering surrogate:")
print(f"       Mean: {np.mean(shuffled_chunks):.4f}  (original: {np.mean(neuron.ca.data):.4f})")
print(f"       Std:  {np.std(shuffled_chunks):.4f}  (original: {np.std(neuron.ca.data):.4f})")
print(f"       Preserves: local structure within chunks")

# --- Spike surrogates ---

# 7d. ISI-based: shuffle inter-spike intervals, preserving ISI distribution
shuffled_sp = neuron.get_shuffled_spikes(method='isi_based', seed=42)
original_spike_count = int(np.sum(neuron.sp.data > 0))
shuffled_spike_count = int(np.sum(shuffled_sp > 0))
print(f"\n   [ISI-based] Spike train surrogate:")
print(f"       Spike count: {shuffled_spike_count}  (original: {original_spike_count})")
print(f"       Preserves: inter-spike interval distribution")

## 3. Threshold vs wavelet reconstruction

Two detection methods, four optimization modes:

| Mode | Description |
|---|---|
| Default kinetics | Single pass, preset rise/decay times |
| Optimized kinetics | Single pass + fit kinetics to your signal |
| Iterative n=2 + optimized | Detect, subtract the reconstruction, re-detect on the residual, then optimize kinetics |
| Iterative n=3 + optimized | A third pass on the residual to catch weaker or overlapping events |

**Threshold** is faster; **wavelet** is more sensitive, especially for low
SNR or overlapping events.
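The threshold method is fast partly because its core step is cheap. Frames are flagged where the signal exceeds `median + n_mad * MAD`, and runs of consecutive flagged frames become events. The sketch below implements only that core step, using the same `n_mad` and `min_duration_frames` names as the cells below. It uses the raw (unscaled) MAD, omits whatever refinements DRIADA applies, and `mad_threshold_events` is a hypothetical helper.

```python
import numpy as np


def mad_threshold_events(x, n_mad=4.0, min_duration_frames=2):
    """(start, end) frame pairs (end exclusive) where x > median + n_mad * MAD."""
    x = np.asarray(x, dtype=float)
    med = np.median(x)
    mad = np.median(np.abs(x - med))          # robust noise estimate, ignores outliers
    above = x > med + n_mad * mad

    # Rising (+1) and falling (-1) edges of the boolean mask
    edges = np.diff(np.concatenate(([0], above.astype(int), [0])))
    starts = np.flatnonzero(edges == 1)
    ends = np.flatnonzero(edges == -1)
    keep = (ends - starts) >= min_duration_frames
    return [(int(s), int(e)) for s, e in zip(starts[keep], ends[keep])]


frames = np.arange(100)
x = 0.05 * np.sin(1.7 * frames)   # bounded "noise"
x[20:30] += 1.0                   # a 10-frame transient
x[60] += 1.0                      # a 1-frame blip, shorter than min_duration_frames
events = mad_threshold_events(x, n_mad=4.0, min_duration_frames=2)
print(events)                     # [(20, 30)]
```

The median and MAD barely move when a few large transients are present, which is why they are preferred over mean and standard deviation for setting the threshold.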

In [None]:
def create_synthetic_neuron(duration=60.0, fps=30.0, event_rate=0.3, seed=42):
    """Generate synthetic calcium signal with known ground truth.

    Uses kinetics different from defaults to demonstrate optimization benefit.
    Default kinetics: t_rise=0.25s, t_off=2.0s
    True kinetics: t_rise=0.10s, t_off=0.8s (faster indicator)
    """
    np.random.seed(seed)

    # Kinetics faster than defaults - optimization should help
    t_rise_true = 0.10   # Faster than default 0.25s
    t_off_true = 0.8     # Faster than default 2.0s

    signal = generate_pseudo_calcium_signal(
        duration=duration,
        sampling_rate=fps,
        event_rate=event_rate,
        amplitude_range=(0.3, 1.2),
        decay_time=t_off_true,
        rise_time=t_rise_true,
        noise_std=0.04,               # Moderate noise
        kernel='double_exponential'
    )

    return signal, {'t_rise_true': t_rise_true, 't_off_true': t_off_true, 'event_rate': event_rate}


def reconstruct_with_mode(neuron, fps, method, mode_name, iterative=False, n_iter=1,
                          optimize=False, adaptive_thresholds=False):
    """Reconstruct with specified method and mode.

    Returns dict with reconstruction results.
    """
    start = time.time()

    # Suppress the default kinetics warning
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning)

        if method == 'threshold':
            neuron.reconstruct_spikes(
                method='threshold',
                n_mad=4.0,                    # Balanced threshold for noisy data
                min_duration_frames=2,        # Allow shorter events
                create_event_regions=True,
                iterative=iterative,
                n_iter=n_iter,
                adaptive_thresholds=adaptive_thresholds
            )
        else:  # wavelet
            neuron.reconstruct_spikes(
                method='wavelet',
                create_event_regions=True,
                iterative=iterative,
                n_iter=n_iter,
                adaptive_thresholds=adaptive_thresholds
            )

    # Optionally optimize kinetics
    optimized = False
    if optimize:
        result = neuron.optimize_kinetics(
            method='direct',
            fps=fps,
            update_reconstruction=True,
            detection_method=method,
            # Pass through to reconstruct_spikes
            n_mad=4.0,
            iterative=iterative,
            n_iter=n_iter,
            adaptive_thresholds=adaptive_thresholds
        )
        optimized = result.get('optimized', False)

    time_total = time.time() - start

    # Get kinetics (optimization may fall back to defaults if events
    # are poorly characterized from single-pass with wrong kinetics)
    t_rise = neuron.t_rise if neuron.t_rise else neuron.default_t_rise
    t_off = neuron.t_off if neuron.t_off else neuron.default_t_off

    # Count events
    if method == 'threshold':
        n_events = len(neuron.threshold_events) if neuron.threshold_events else 0
    else:
        n_events = len(neuron.wvt_ridges) if neuron.wvt_ridges else 0

    # Quality metrics from Neuron API
    r2 = neuron.get_reconstruction_r2()
    corr = np.corrcoef(neuron.ca.data, neuron._reconstructed.data)[0, 1]

    return {
        'mode': mode_name,
        'reconstruction': neuron._reconstructed.data,
        't_rise': t_rise / fps,
        't_off': t_off / fps,
        'n_events': n_events,
        'optimized': optimized,
        'r2': r2,
        'correlation': corr,
        'time': time_total,
    }


def reconstruct_all_modes(signal, fps, method):
    """Run all 4 reconstruction modes for a given method.

    Modes:
    1. Default kinetics (single pass)
    2. Optimized kinetics (single pass)
    3. Iterative n_iter=2 + optimized kinetics
    4. Iterative n_iter=3 + optimized kinetics
    """
    results = []

    # Mode 1: Default kinetics (single pass, no optimization)
    print(f"   Mode 1: Default kinetics...")
    neuron = Neuron(cell_id=f'{method}_default', ca=signal.copy(), sp=None, fps=fps)
    results.append(reconstruct_with_mode(
        neuron, fps, method,
        mode_name='Default kinetics',
        iterative=False, n_iter=1, optimize=False
    ))

    # Mode 2: Optimized kinetics (single pass + optimization)
    print(f"   Mode 2: Optimized kinetics...")
    neuron = Neuron(cell_id=f'{method}_optimized', ca=signal.copy(), sp=None, fps=fps)
    results.append(reconstruct_with_mode(
        neuron, fps, method,
        mode_name='Optimized kinetics',
        iterative=False, n_iter=1, optimize=True
    ))

    # Mode 3: Iterative n_iter=2 + optimized kinetics
    print(f"   Mode 3: Iterative (n_iter=2) + optimized...")
    neuron = Neuron(cell_id=f'{method}_iter2', ca=signal.copy(), sp=None, fps=fps)
    results.append(reconstruct_with_mode(
        neuron, fps, method,
        mode_name='Iterative n=2 + opt',
        iterative=True, n_iter=2, optimize=True, adaptive_thresholds=True
    ))

    # Mode 4: Iterative n_iter=3 + optimized kinetics
    print(f"   Mode 4: Iterative (n_iter=3) + optimized...")
    neuron = Neuron(cell_id=f'{method}_iter3', ca=signal.copy(), sp=None, fps=fps)
    results.append(reconstruct_with_mode(
        neuron, fps, method,
        mode_name='Iterative n=3 + opt',
        iterative=True, n_iter=3, optimize=True, adaptive_thresholds=True
    ))

    return results

In [None]:
# Create synthetic neuron with non-default kinetics (t_rise=0.10s, t_off=0.8s)
print("1. Generating synthetic calcium signal...")
fps = 30.0
signal, ground_truth = create_synthetic_neuron(duration=300.0, fps=fps, seed=42)  # 5 minutes
print(f"   Signal: {len(signal)} frames ({len(signal)/fps:.1f} seconds)")
print(f"   Ground truth: t_rise={ground_truth['t_rise_true']:.3f}s, "
      f"t_off={ground_truth['t_off_true']:.3f}s")

In [None]:
# Threshold reconstruction - all 4 modes
print("2. Threshold-based reconstruction (FAST)...")
threshold_results = reconstruct_all_modes(signal, fps, method='threshold')

In [None]:
# Wavelet reconstruction - all 4 modes
print("3. Wavelet-based reconstruction (SENSITIVE)...")
wavelet_results = reconstruct_all_modes(signal, fps, method='wavelet')

In [None]:
# Print summary table
time_axis = np.arange(len(signal)) / fps
print("=" * 120)
print("RECONSTRUCTION QUALITY SUMMARY")
print("=" * 120)
print(f"{'Method':<12} {'Mode':<22} {'Events':<8} {'t_rise(s)':<11} {'t_off(s)':<11} "
      f"{'R^2':<8} {'Corr':<8} {'Opt':<6} {'Time(s)':<10}")
print("-" * 120)

for res in threshold_results:
    opt = "yes" if res['optimized'] else "-"
    print(f"{'Threshold':<12} {res['mode']:<22} {res['n_events']:<8} "
          f"{res['t_rise']:<11.3f} {res['t_off']:<11.3f} "
          f"{res['r2']:<8.4f} {res['correlation']:<8.4f} "
          f"{opt:<6} {res['time']:<10.4f}")

print("-" * 120)

for res in wavelet_results:
    opt = "yes" if res['optimized'] else "-"
    print(f"{'Wavelet':<12} {res['mode']:<22} {res['n_events']:<8} "
          f"{res['t_rise']:<11.3f} {res['t_off']:<11.3f} "
          f"{res['r2']:<8.4f} {res['correlation']:<8.4f} "
          f"{opt:<6} {res['time']:<10.4f}")

# Calculate speedup
threshold_total_time = sum(r['time'] for r in threshold_results)
wavelet_total_time = sum(r['time'] for r in wavelet_results)
speedup = wavelet_total_time / threshold_total_time

print("\n" + "=" * 80)
print(f"PERFORMANCE: Threshold is {speedup:.1f}x faster than Wavelet")
print(f"  Threshold total: {threshold_total_time:.3f}s")
print(f"  Wavelet total:   {wavelet_total_time:.3f}s")
print("=" * 80)

In [None]:
# Visualization: reconstruction traces (2 columns, 5 rows)
fig = plt.figure(figsize=(16, 12))
gs = fig.add_gridspec(5, 2, hspace=0.4, wspace=0.3)

# Column 1: Threshold method
# Original signal
ax0 = fig.add_subplot(gs[0, 0])
ax0.plot(time_axis, signal, 'k-', linewidth=1, label='Calcium signal')
ax0.set_ylabel('dF/F0')
ax0.set_title('THRESHOLD METHOD (fast)', fontweight='bold', fontsize=12)
ax0.grid(True, alpha=0.3)
ax0.legend(loc='upper right')

# Reconstruction modes for threshold
for i, res in enumerate(threshold_results):
    ax = fig.add_subplot(gs[i+1, 0])
    ax.plot(time_axis, signal, 'k-', linewidth=0.8, alpha=0.5, label='Original')
    ax.plot(time_axis, res['reconstruction'], 'b-', linewidth=1.2,
            label=f"Reconstruction (R^2={res['r2']:.3f})")

    ax.set_title(f"{res['mode']} | t_rise={res['t_rise']:.3f}s, t_off={res['t_off']:.3f}s",
                fontsize=10)
    ax.set_ylabel('dF/F0')
    if i == 3:
        ax.set_xlabel('Time (s)')
    ax.grid(True, alpha=0.3)
    ax.legend(loc='upper right', fontsize=8)

# Column 2: Wavelet method
# Original signal
ax0 = fig.add_subplot(gs[0, 1])
ax0.plot(time_axis, signal, 'k-', linewidth=1, label='Calcium signal')
ax0.set_ylabel('dF/F0')
ax0.set_title('WAVELET METHOD (sensitive)', fontweight='bold', fontsize=12)
ax0.grid(True, alpha=0.3)
ax0.legend(loc='upper right')

# Reconstruction modes for wavelet
for i, res in enumerate(wavelet_results):
    ax = fig.add_subplot(gs[i+1, 1])
    ax.plot(time_axis, signal, 'k-', linewidth=0.8, alpha=0.5, label='Original')
    ax.plot(time_axis, res['reconstruction'], 'r-', linewidth=1.2,
            label=f"Reconstruction (R^2={res['r2']:.3f})")

    ax.set_title(f"{res['mode']} | t_rise={res['t_rise']:.3f}s, t_off={res['t_off']:.3f}s",
                fontsize=10)
    ax.set_ylabel('dF/F0')
    if i == 3:
        ax.set_xlabel('Time (s)')
    ax.grid(True, alpha=0.3)
    ax.legend(loc='upper right', fontsize=8)

plt.suptitle(
    f'Spike reconstruction comparison | Ground truth: t_rise={ground_truth["t_rise_true"]:.3f}s, '
    f't_off={ground_truth["t_off_true"]:.3f}s | Speedup: {speedup:.1f}x',
    fontsize=14, fontweight='bold', y=0.995
)
plt.show()

In [None]:
# Convergence metrics (2x2 subplots)
fig2, axes = plt.subplots(2, 2, figsize=(14, 9))

modes = range(len(threshold_results))
mode_labels = ['Default', 'Optimized', 'Iter n=2', 'Iter n=3']

# R^2 comparison
ax = axes[0, 0]
ax.plot(modes, [r['r2'] for r in threshold_results], 'bo-', label='Threshold', linewidth=2, markersize=8)
ax.plot(modes, [r['r2'] for r in wavelet_results], 'ro-', label='Wavelet', linewidth=2, markersize=8)
ax.set_ylabel('R^2')
ax.set_title('Reconstruction quality (R^2)', fontweight='bold')
ax.set_xticks(list(modes))
ax.set_xticklabels(mode_labels, rotation=15)
ax.grid(True, alpha=0.3)
ax.legend()

# Correlation comparison
ax = axes[0, 1]
ax.plot(modes, [r['correlation'] for r in threshold_results], 'bo-', label='Threshold', linewidth=2, markersize=8)
ax.plot(modes, [r['correlation'] for r in wavelet_results], 'ro-', label='Wavelet', linewidth=2, markersize=8)
ax.set_ylabel('Correlation')
ax.set_title('Correlation coefficient', fontweight='bold')
ax.set_xticks(list(modes))
ax.set_xticklabels(mode_labels, rotation=15)
ax.grid(True, alpha=0.3)
ax.legend()

# Event count comparison
ax = axes[1, 0]
ax.plot(modes, [r['n_events'] for r in threshold_results], 'bo-', label='Threshold', linewidth=2, markersize=8)
ax.plot(modes, [r['n_events'] for r in wavelet_results], 'ro-', label='Wavelet', linewidth=2, markersize=8)
ax.set_ylabel('Events detected')
ax.set_title('Number of events detected', fontweight='bold')
ax.set_xticks(list(modes))
ax.set_xticklabels(mode_labels, rotation=15)
ax.grid(True, alpha=0.3)
ax.legend()

# Kinetics comparison (t_rise and t_off with ground truth)
ax = axes[1, 1]
ax.plot(modes, [r['t_rise'] for r in threshold_results], 'b^-', label='Thr t_rise', linewidth=2, markersize=8)
ax.plot(modes, [r['t_rise'] for r in wavelet_results], 'r^-', label='Wvt t_rise', linewidth=2, markersize=8)
ax.plot(modes, [r['t_off'] for r in threshold_results], 'bs-', label='Thr t_off', linewidth=2, markersize=8)
ax.plot(modes, [r['t_off'] for r in wavelet_results], 'rs-', label='Wvt t_off', linewidth=2, markersize=8)
ax.axhline(ground_truth['t_rise_true'], color='g', linestyle='--', linewidth=1.5, alpha=0.7)
ax.axhline(ground_truth['t_off_true'], color='g', linestyle='--', linewidth=1.5, alpha=0.7, label='Ground truth')
ax.set_ylabel('Time (s)')
ax.set_title('Kinetics estimation', fontweight='bold')
ax.set_xticks(list(modes))
ax.set_xticklabels(mode_labels, rotation=15)
ax.grid(True, alpha=0.3)
ax.legend(fontsize=8)

plt.suptitle(f'Reconstruction metrics by mode | Speedup: {speedup:.1f}x',
             fontsize=14, fontweight='bold', y=0.995)
plt.tight_layout()
plt.show()

## 4. Method agreement analysis

Given the same data, how well do threshold and wavelet agree?  Both methods
detect calcium transient regions (event start to end) but use different
signal processing: wavelet uses CWT ridge detection, while threshold uses
crossings of a median absolute deviation (MAD) based threshold.  Comparing
event-region overlap at varying tolerance reveals timing differences between
the two detection mechanisms.
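One simple way to score agreement is the fraction of events from one method that overlap an event from the other, after widening each interval by a tolerance of `tol` frames. The sketch below assumes events are `(start, end)` frame pairs with exclusive ends. Converting `neuron.wvt_ridges` and `neuron.threshold_events` into that form depends on attributes not shown in this notebook, so that step is left out. `matched_fraction` is a hypothetical helper.

```python
import numpy as np


def matched_fraction(events_a, events_b, tol=0):
    """Fraction of events in events_a that overlap some event in events_b.

    Each interval in events_b is widened by `tol` frames on both sides;
    half-open intervals [s, e) overlap when s_a < e_b + tol and s_b - tol < e_a.
    """
    if len(events_a) == 0:
        return np.nan
    hits = 0
    for s_a, e_a in events_a:
        for s_b, e_b in events_b:
            if s_a < e_b + tol and s_b - tol < e_a:
                hits += 1
                break                       # count each event in A at most once
    return hits / len(events_a)


wvt = [(10, 20), (50, 60), (100, 110)]
thr = [(12, 22), (63, 70)]
for tol in (0, 5):
    print(tol, matched_fraction(wvt, thr, tol), matched_fraction(thr, wvt, tol))
```

The score is asymmetric (A against B differs from B against A when the event counts differ), so reporting both directions gives a fuller picture. Sweeping `tol` shows whether disagreements are genuine misses or only small timing offsets.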

In [None]:
# Generate 5-neuron experiment, run both methods (iterative n=3)
print("Generating synthetic calcium imaging data...")
exp4 = generate_synthetic_exp(
    n_dfeats=2, n_cfeats=1, nneurons=5, duration=120, fps=20, seed=42  # 2 minutes
)

calcium4 = exp4.calcium
fps4 = exp4.fps
n_neurons4 = calcium4.scdata.shape[0]
time4 = np.arange(calcium4.scdata.shape[1]) / fps4

# Both methods use Neuron-level iterative reconstruction (n_iter=3)
# to catch overlapping events via residual analysis.
wavelet_events = []
threshold_events = []

for neuron in exp4.neurons:
    # Wavelet: CWT ridge detection on residuals
    print(f"  Neuron {neuron.cell_id}: wavelet...", end="")
    neuron.reconstruct_spikes(
        method="wavelet", iterative=True, n_iter=3, fps=fps4
    )
    wavelet_events.append(list(neuron.wvt_ridges))

    # Threshold: MAD-based event detection on residuals
    print(" threshold...", end="")
    neuron.reconstruct_spikes(
        method="threshold", iterative=True, n_iter=3, n_mad=4.0,
        adaptive_thresholds=True, fps=fps4,
    )
    threshold_events.append(list(neuron.threshold_events))
    print(
        f" done ({len(wavelet_events[-1])} / {len(threshold_events[-1])} events)"
    )
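
The comment in the cell above says both methods iterate on residuals to catch overlapping events. The idea can be illustrated with a hypothetical greedy scheme; the helper names `double_exp_kernel` and `iterative_detect`, the difference-of-exponentials template, and the one-kernel-per-region fit are all assumptions, not DRIADA's algorithm. Each pass finds supra-threshold regions of the residual, fits one unit-peak kernel per region at its maximum, and subtracts the fitted kernels before the next pass. An event buried in the decay of a neighbour merges into a single region on the first pass and only becomes visible in the residual.

```python
import numpy as np


def double_exp_kernel(t_rise, t_off, length):
    """Unit-peak transient template exp(-t/t_off) - exp(-t/t_rise), t in frames.

    The difference-of-exponentials form is an assumption for illustration.
    """
    t = np.arange(length, dtype=float)
    k = np.exp(-t / t_off) - np.exp(-t / t_rise)
    return k / k.max()


def iterative_detect(signal, kernel, threshold, n_iter=3):
    """Greedy residual-based event detection (hypothetical sketch).

    Each pass finds contiguous runs where the residual exceeds `threshold`,
    places one kernel per run so that the kernel peak sits on the run's maximum,
    then subtracts all fitted kernels. Returns ([(onset, amplitude), ...], residual).
    """
    residual = np.asarray(signal, dtype=float).copy()
    kernel = np.asarray(kernel, dtype=float)
    peak_offset = int(np.argmax(kernel))  # frames from onset to kernel peak
    events = []
    for _ in range(n_iter):
        above = residual > threshold
        if not above.any():
            break  # nothing left to explain
        edges = np.diff(np.concatenate(([False], above, [False])).astype(int))
        starts = np.flatnonzero(edges == 1)
        ends = np.flatnonzero(edges == -1)  # exclusive
        model = np.zeros_like(residual)
        for s, e in zip(starts, ends):
            peak = int(s) + int(np.argmax(residual[s:e]))
            onset = peak - peak_offset  # may be negative near the trace start
            amp = residual[peak] / kernel[peak_offset]
            lo, hi = max(onset, 0), min(onset + len(kernel), len(residual))
            model[lo:hi] += amp * kernel[lo - onset:hi - onset]
            events.append((onset, float(amp)))
        residual -= model  # the next pass only sees what this pass left unexplained
    return events, residual
```

With a toy kernel `0.5 ** np.arange(10)` and two unit events starting at frames 10 and 11, one pass reports a single event at frame 11, while three passes also recover frame 10. Because amplitudes are not refit jointly, the first fitted event absorbs part of its neighbour (amplitude 1.5 instead of 1.0); a joint refit, for example non-negative least squares, would correct that.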

In [None]:
# Event counts per neuron
print("=" * 50)
print("EVENT COUNTS (iterative, n_iter=3)")
print("=" * 50)
print(f"{'Neuron':<10} {'Wavelet':<14} {'Threshold':<14}")
print("-" * 50)
for i in range(n_neurons4):
    n_w = len(wavelet_events[i])
    n_t = len(threshold_events[i])
    print(f"{i:<10} {n_w:<14} {n_t:<14}")
total_w = sum(len(wavelet_events[i]) for i in range(n_neurons4))
total_t = sum(len(threshold_events[i]) for i in range(n_neurons4))
print("-" * 50)
print(f"{'Total':<10} {total_w:<14} {total_t:<14}")

In [None]:
# Event region visualization for one neuron
neuron_idx = 2

fig, axes = plt.subplots(3, 1, figsize=(12, 8), sharex=True)

# Plot calcium signal (scaled data)
ax = axes[0]
ax.plot(time4, calcium4.scdata[neuron_idx, :], "k-", linewidth=1)
ax.set_ylabel("Calcium\n(scaled)")
ax.set_title(f"Neuron {neuron_idx}: spike reconstruction comparison")
ax.grid(True, alpha=0.3)

# Plot wavelet-detected event regions
ax = axes[1]
for ev in wavelet_events[neuron_idx]:
    ax.axvspan(ev.start / fps4, ev.end / fps4, alpha=0.5, color="blue")
ax.set_ylabel("Wavelet\nEvents")
ax.set_ylim(-0.1, 1.1)
ax.grid(True, alpha=0.3)
ax.legend(
    handles=[Patch(facecolor="blue", alpha=0.5, label="Event region")],
    loc="upper right",
)

# Plot threshold-detected event regions
ax = axes[2]
for ev in threshold_events[neuron_idx]:
    ax.axvspan(ev.start / fps4, ev.end / fps4, alpha=0.5, color="red")
ax.set_ylabel("Threshold\nEvents")
ax.set_ylim(-0.1, 1.1)
ax.set_xlabel("Time (s)")
ax.grid(True, alpha=0.3)
ax.legend(
    handles=[Patch(facecolor="red", alpha=0.5, label="Event region")],
    loc="upper right",
)

plt.tight_layout()
plt.show()

In [None]:
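# Agreement vs tolerance curve
# Both methods detect event regions via iterative residual analysis.
# Sweeping the tolerance shows how closely the detected regions align.
# Agreement = fraction of wavelet events (pooled over neurons) that overlap at
# least one threshold event once the wavelet event is padded by `tol` frames
# on each side. The wavelet count is the denominator, so the score is one-sided.
tolerance_sec = np.linspace(0, 1, 21)                              # 0.00 to 1.00 s in 0.05 s steps
tolerance_frames_arr = np.round(tolerance_sec * fps4).astype(int)  # round rather than truncate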

# Extract start/end arrays
w_starts = [[int(e.start) for e in wavelet_events[i]] for i in range(n_neurons4)]
w_ends = [[int(e.end) for e in wavelet_events[i]] for i in range(n_neurons4)]
t_starts = [[int(e.start) for e in threshold_events[i]] for i in range(n_neurons4)]
t_ends = [[int(e.end) for e in threshold_events[i]] for i in range(n_neurons4)]

matched_counts = []
agreements = []
for tol in tolerance_frames_arr:
    matched = 0
    for i in range(n_neurons4):
        for ws, we in zip(w_starts[i], w_ends[i]):
            # Check whether any threshold event overlaps this wavelet event
            # after padding the wavelet event by `tol` frames on each side
            for ts, te in zip(t_starts[i], t_ends[i]):
                if ts <= (we + tol) and te >= (ws - tol):
                    matched += 1
                    break
    matched_counts.append(matched)
    agreements.append(matched / total_w if total_w > 0 else 0.0)

print("=" * 50)
print("AGREEMENT VS TOLERANCE")
print("=" * 50)
print(f"{'Tolerance (s)':<16} {'Matched':<12} {'Agreement':<12}")
print("-" * 50)
# Print every fifth tolerance step (0.00, 0.25, 0.50, 0.75, 1.00 s) and use the
# stored integer counts instead of recovering them from floating-point fractions.
for tol_s, n_match, agr in list(zip(tolerance_sec, matched_counts, agreements))[::5]:
    print(f"{tol_s:<16.2f} {n_match:<12} {agr:<12.1%}")

# Plot tolerance curve
fig2, ax2 = plt.subplots(figsize=(8, 5))
ax2.plot(
    tolerance_sec, [a * 100 for a in agreements], "ko-", linewidth=2, markersize=4
)
ax2.set_xlabel("Tolerance (s)")
ax2.set_ylabel("Wavelet events matched (%)")
ax2.set_title("Event-level agreement: wavelet vs threshold")
ax2.set_ylim(0, 105)
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
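
The tolerance curve scores only one direction: its denominator is the number of wavelet events, so extra threshold events never lower the score. A hedged sketch of a two-way check reuses the same padded-overlap test and pools counts across neurons; the helper names `count_matched` and `pooled_agreement` are hypothetical and not part of the DRIADA API.

```python
import numpy as np


def count_matched(ref_events, other_events, tol=0):
    """Count reference (start, end) intervals that overlap at least one interval
    in `other_events` after padding the reference by `tol` frames on each side.
    Same test as the notebook loop: ts <= we + tol and te >= ws - tol.
    """
    other = np.asarray(other_events, dtype=float).reshape(-1, 2)  # (n_other, 2)
    count = 0
    for start, end in ref_events:
        hits = (other[:, 0] <= end + tol) & (other[:, 1] >= start - tol)
        count += int(hits.any())
    return count


def pooled_agreement(events_a, events_b, tol=0):
    """Match events within each neuron, pool the counts over neurons, and return
    (fraction of A matched in B, fraction of B matched in A)."""
    n_a = sum(len(ev) for ev in events_a)
    n_b = sum(len(ev) for ev in events_b)
    hit_a = sum(count_matched(a, b, tol) for a, b in zip(events_a, events_b))
    hit_b = sum(count_matched(b, a, tol) for a, b in zip(events_a, events_b))
    return (hit_a / n_a if n_a else 0.0, hit_b / n_b if n_b else 0.0)
```

With the notebook's variables, per-neuron pairs can be built as `list(zip(w_starts[i], w_ends[i]))` and `list(zip(t_starts[i], t_ends[i]))`. A large gap between the two returned fractions indicates that one method reports many events the other misses, which the one-sided curve cannot show.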