# Neuron analysis

[**DRIADA**](https://driada.readthedocs.io) (Dimensionality Reduction for
Integrated Activity Data) is a Python framework for neural data analysis.
It bridges two perspectives that are usually treated separately: what
*individual* neurons encode, and how the *population as a whole* represents
information.  The typical analysis workflow looks like this:

| Step | Notebook | What it does |
|---|---|---|
| **Overview** | [00 -- DRIADA overview](https://colab.research.google.com/github/iabs-neuro/driada/blob/main/notebooks/00_driada_overview.ipynb) | Core data structures, quick tour of INTENSE, DR, networks |
| **Neuron analysis** | **01 -- this notebook** | Spike reconstruction, kinetics optimization, quality metrics, surrogates |
| **Single-neuron selectivity** | [02 -- INTENSE](https://colab.research.google.com/github/iabs-neuro/driada/blob/main/notebooks/02_selectivity_detection_intense.ipynb) | Detect which neurons encode which behavioral variables |
| **Population geometry** | [03 -- Dimensionality reduction](https://colab.research.google.com/github/iabs-neuro/driada/blob/main/notebooks/03_population_geometry_dr.ipynb) | Extract low-dimensional manifolds from population activity |
| **Network analysis** | [04 -- Networks](https://colab.research.google.com/github/iabs-neuro/driada/blob/main/notebooks/04_network_analysis.ipynb) | Build and analyze cell-cell interaction graphs |
| **Putting it together** | [05 -- Advanced](https://colab.research.google.com/github/iabs-neuro/driada/blob/main/notebooks/05_advanced_capabilities.ipynb) | Combine INTENSE + DR, leave-one-out importance, RSA, RNN analysis |

This notebook focuses on individual neuron analysis: spike reconstruction,
kinetics optimization, and quality assessment.

**What you will learn:**

1. **Single neuron analysis** -- create a [`Neuron`](https://driada.readthedocs.io/en/latest/api/experiment/core.html#driada.experiment.neuron.Neuron), reconstruct spikes, optimize kinetics, compute quality metrics, and generate surrogates.
2. **Threshold vs wavelet reconstruction** -- compare two spike detection methods on the same signal, visualize their reconstructions and event regions, and compare event counts across a small synthetic population.

In [None]:
# TODO: revert to '!pip install -q driada' after v1.0.0 PyPI release
!pip install -q git+https://github.com/iabs-neuro/driada.git@main
%matplotlib inline

import warnings

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

from driada.experiment import generate_synthetic_exp
from driada.experiment.neuron import Neuron
from driada.experiment.synthetic import generate_pseudo_calcium_signal

## 1. Setup

For an introduction to `Experiment` objects, feature types, and `TimeSeries`, see
[Notebook 00 -- DRIADA overview](https://colab.research.google.com/github/iabs-neuro/driada/blob/main/notebooks/00_driada_overview.ipynb).

We work directly with
[`Neuron`](https://driada.readthedocs.io/en/latest/api/experiment/core.html#driada.experiment.neuron.Neuron)
objects and synthetic calcium traces throughout.

## 2. Single neuron analysis

This section takes a deep dive into a single neuron: generate a synthetic calcium
signal with [`generate_pseudo_calcium_signal`](https://driada.readthedocs.io/en/latest/api/experiment/synthetic.html#driada.experiment.synthetic.core.generate_pseudo_calcium_signal),
create a [`Neuron`](https://driada.readthedocs.io/en/latest/api/experiment/core.html#driada.experiment.neuron.Neuron) object, reconstruct spikes, optimize kinetics,
compute quality metrics, and generate surrogates for null-hypothesis testing.
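
To make the idea of a synthetic calcium signal concrete, the sketch below builds one by hand with NumPy: random events are convolved with a double-exponential kernel and Gaussian noise is added. It is an illustration under stated assumptions, not DRIADA's implementation. The helper names `double_exp_kernel` and `toy_calcium_trace` are hypothetical, events are drawn as independent per-frame Bernoulli trials, and the kernel is normalized to unit peak.

```python
import numpy as np

def double_exp_kernel(fps, rise_time, decay_time, length_s=None):
    """Peak-normalized double-exponential kernel sampled at `fps` (hypothetical helper)."""
    if length_s is None:
        length_s = 5.0 * decay_time  # long enough for the decay to die out
    t = np.arange(0.0, length_s, 1.0 / fps)
    kernel = np.exp(-t / decay_time) - np.exp(-t / rise_time)
    return kernel / kernel.max()

def toy_calcium_trace(duration, fps, event_rate, amp_range,
                      rise_time, decay_time, noise_std, seed=0):
    """Events convolved with the kernel, plus additive Gaussian noise."""
    rng = np.random.default_rng(seed)
    n_frames = int(round(duration * fps))
    # Bernoulli approximation of a Poisson process: at most one event per frame
    is_event = rng.random(n_frames) < event_rate / fps
    amplitudes = np.zeros(n_frames)
    amplitudes[is_event] = rng.uniform(*amp_range, size=int(is_event.sum()))
    kernel = double_exp_kernel(fps, rise_time, decay_time)
    clean = np.convolve(amplitudes, kernel)[:n_frames]
    return clean + rng.normal(0.0, noise_std, n_frames), amplitudes

trace, amplitudes = toy_calcium_trace(
    duration=200.0, fps=20.0, event_rate=0.15, amp_range=(1.0, 2.0),
    rise_time=0.15, decay_time=1.5, noise_std=0.05, seed=42,
)
print(trace.shape, int((amplitudes > 0).sum()))
```

The parameter values mirror the ones used in the next cell, so the toy trace has the same length and a comparable event density.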

In [None]:
np.random.seed(42)

print("1. Generating synthetic calcium signal...")

signal = generate_pseudo_calcium_signal(
    duration=200.0,              # Signal duration in seconds
    sampling_rate=20.0,          # Sampling rate (Hz)
    event_rate=0.15,             # Average event rate (Hz)
    amplitude_range=(1.0, 2.0),  # Event amplitude range (dF/F0)
    decay_time=1.5,              # Calcium decay time constant (seconds)
    rise_time=0.15,              # Calcium rise time constant (seconds)
    noise_std=0.05,              # Additive Gaussian noise level
    kernel='double_exponential'  # Realistic calcium kernel
)

print(f"   [OK] Generated signal: {len(signal)} frames ({len(signal)/20:.1f} seconds)")

print("\n2. Creating Neuron object...")

neuron = Neuron(
    cell_id='example_neuron',
    ca=signal,              # Calcium signal
    sp=None,                # No ground-truth spikes (will be reconstructed)
    fps=20.0                # Sampling rate
)

print(f"   [OK] Neuron created: {neuron.cell_id}")
print(f"   [OK] Signal length: {neuron.n_frames} frames")
print(f"   [OK] Sampling rate: {neuron.fps} Hz")

### Spike reconstruction

The **wavelet** method detects calcium transient events via CWT (continuous
wavelet transform) ridge analysis.  CWT ridge detection identifies
scale-persistent features in the wavelet scalogram -- ridges that persist
across multiple scales correspond to true transient events rather than noise.
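
The scale-persistence idea can be illustrated without any wavelet library. The sketch below is a crude stand-in, not DRIADA's ridge detector: it uses a hand-written Ricker wavelet instead of the wavelet DRIADA uses, skips ridge tracking entirely, and simply keeps frames whose coefficient exceeds a robust noise threshold at every scale. The function names, the widths, and the `n_sigma` and `max_gap` defaults are all assumptions chosen for the toy trace.

```python
import numpy as np

def ricker(n_points, width):
    """Ricker (Mexican hat) wavelet with approximately unit energy."""
    t = np.arange(n_points) - (n_points - 1) / 2.0
    norm = 2.0 / (np.sqrt(3.0 * width) * np.pi ** 0.25)
    return norm * (1.0 - (t / width) ** 2) * np.exp(-0.5 * (t / width) ** 2)

def cwt_ricker(signal, widths):
    """One row of wavelet coefficients per width (widths are in frames)."""
    return np.array([
        np.convolve(signal, ricker(10 * int(w) + 1, w), mode="same")
        for w in widths
    ])

def persistent_events(signal, widths=(4, 8, 16), n_sigma=5.0, max_gap=5):
    """Return (start, end) frame pairs where the response is strong at all scales."""
    coefs = cwt_ricker(signal, widths)
    # Robust per-scale noise level from the median absolute deviation
    med = np.median(coefs, axis=1, keepdims=True)
    sigma = 1.4826 * np.median(np.abs(coefs - med), axis=1, keepdims=True)
    # A frame counts only if its coefficient is large at every scale
    persistent = np.all(coefs > med + n_sigma * sigma, axis=0)
    idx = np.flatnonzero(persistent)
    if idx.size == 0:
        return []
    # Group frames into events, bridging gaps of up to `max_gap` frames
    breaks = np.flatnonzero(np.diff(idx) > max_gap)
    starts = np.concatenate(([idx[0]], idx[breaks + 1]))
    ends = np.concatenate((idx[breaks], [idx[-1]]))
    return list(zip(starts.tolist(), ends.tolist()))

# Toy trace: two transients on a noisy baseline, sampled at 20 Hz
rng = np.random.default_rng(0)
fps = 20.0
t = np.arange(0.0, 8.0, 1.0 / fps)
kernel = np.exp(-t / 1.5) - np.exp(-t / 0.15)
kernel /= kernel.max()
onsets = [200, 600]
impulses = np.zeros(1000)
impulses[onsets] = [1.0, 1.5]
trace = np.convolve(impulses, kernel)[:1000] + rng.normal(0.0, 0.05, 1000)

events = persistent_events(trace)
print(events)
```

Noise rarely produces a large coefficient at all three widths in the same frame, which is why requiring agreement across scales suppresses false detections.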

In [None]:
print("3. Reconstructing spikes using wavelet method...")

spikes = neuron.reconstruct_spikes(
    method='wavelet',
    create_event_regions=True  # Create event regions for quality metrics
)

n_events = int(np.sum(neuron.asp.data > 0))
print(f"   [OK] Detected {n_events} calcium events")
print(f"   [OK] Spike train stored in neuron.sp")
print(f"   [OK] Amplitude spikes stored in neuron.asp")

### Wavelet scalogram

The CWT scalogram shows how the signal's energy is distributed across
scales (larger scales correspond to lower frequencies) and time.  Calcium transient events appear as
bright cones extending from low to high scales.  Grey shading marks
the detected event regions.

In [None]:
from driada.experiment.wavelet_event_detection import (
    get_adaptive_wavelet_scales,
)
from scipy.ndimage import gaussian_filter1d

# Compute CWT scalogram using ssqueezepy (installed with driada)
from ssqueezepy import cwt, Wavelet as SqzWavelet

fps_neuron = neuron.fps
scales = get_adaptive_wavelet_scales(fps_neuron)
sig_smooth = gaussian_filter1d(neuron.ca.data, sigma=int(0.4 * fps_neuron))
wavelet_obj = SqzWavelet(('gmw', {'beta': 2, 'gamma': 3}))
Wx, _ = cwt(sig_smooth, wavelet_obj, scales=scales)

time_sec = np.arange(len(sig_smooth)) / fps_neuron

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 7), sharex=True,
                                gridspec_kw={'height_ratios': [1, 1]})

# Top: scalogram
ax1.imshow(np.abs(Wx), aspect='auto', cmap='turbo',
           extent=[0, time_sec[-1], scales[-1], scales[0]])
ax1.set_ylabel('Scale')
ax1.set_title('CWT scalogram (Generalized Morse Wavelet)')

# Bottom: signal + event regions
ax2.plot(time_sec, neuron.ca.data, 'b', linewidth=0.8)
if neuron.wvt_ridges:
    for ridge in neuron.wvt_ridges:
        t0 = ridge.start / fps_neuron
        t1 = ridge.end / fps_neuron
        ax2.axvspan(t0, t1, alpha=0.3, color='grey')
ax2.set_xlabel('Time (s)')
ax2.set_ylabel('dF/F0')
ax2.set_title(f'Calcium signal with {n_events} detected events')

plt.tight_layout()
plt.show()

### Kinetics optimization

Fit rise and decay time constants to detected events using the **direct
measurement** method.  The `direct` method measures t_rise from the
derivative of the onset-to-peak waveform and t_off by fitting an
exponential to the peak-to-baseline decay, avoiding iterative optimization.
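
The decay half of this idea fits in a few lines: once the rise term has died out, the log of the trace is linear in time, so a straight-line fit gives the time constant. The sketch below is an assumption-laden illustration, not DRIADA's `direct` method. `fit_decay_tau`, `skip_s`, and `floor_frac` are hypothetical names, and the event is a single noise-free transient.

```python
import numpy as np

def fit_decay_tau(event, fps, skip_s=0.5, floor_frac=0.05):
    """Estimate a decay time constant (seconds) from one isolated transient."""
    peak = int(np.argmax(event))
    # Skip the first part of the decay, where the rise term still contributes
    start = peak + int(round(skip_s * fps))
    # Stop once the trace falls below a fraction of the peak value
    below = np.flatnonzero(event[start:] < floor_frac * event[peak])
    stop = start + int(below[0]) if below.size else len(event)
    t = np.arange(start, stop) / fps
    # log(A * exp(-t / tau)) is linear in t with slope -1 / tau
    slope, _ = np.polyfit(t, np.log(event[start:stop]), 1)
    return -1.0 / slope

fps = 20.0
t = np.arange(0.0, 10.0, 1.0 / fps)
transient = np.exp(-t / 1.5) - np.exp(-t / 0.15)  # noise-free toy event
tau_off = fit_decay_tau(transient, fps)
print(f"estimated t_off: {tau_off:.3f} s")
```

With real, noisy events the log transform amplifies noise near the baseline, so the fitting window and any averaging across events matter much more than in this toy case.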

In [None]:
print("4. Optimizing calcium kinetics...")

kinetics = neuron.get_kinetics(
    method='direct',           # Direct measurement from detected events
    use_cached=False          # Force recomputation
)

print(f"   [OK] Optimized rise time (t_rise): {kinetics['t_rise']:.3f} seconds")
print(f"   [OK] Optimized decay time (t_off): {kinetics['t_off']:.3f} seconds")
print(f"   [OK] Events used: {kinetics['n_events_detected']}")

### Quality metrics

- **Wavelet SNR** -- ratio of event amplitude to baseline noise
- **R-squared** -- reconstruction quality (1.0 = perfect, >0.7 = good)
- **Event-only R-squared** -- quality restricted to event regions
- **NRMSE** -- normalized root mean squared error (lower is better)
- **NMAE** -- normalized mean absolute error (lower is better)
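
For reference, the error-based metrics above can be written out directly. This is a sketch of the textbook definitions, not DRIADA's code: the normalization used here for NRMSE and NMAE (the range of the observed signal) is an assumption, and the library may normalize differently. The optional `mask` argument mimics an event-only variant.

```python
import numpy as np

def r_squared(y, y_hat, mask=None):
    """Coefficient of determination, optionally restricted to masked frames."""
    if mask is not None:
        y, y_hat = y[mask], y_hat[mask]
    ss_res = np.sum((y - y_hat) ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

def nrmse(y, y_hat):
    """RMSE divided by the signal range (normalization is an assumption)."""
    return np.sqrt(np.mean((y - y_hat) ** 2)) / (y.max() - y.min())

def nmae(y, y_hat):
    """MAE divided by the signal range (normalization is an assumption)."""
    return np.mean(np.abs(y - y_hat)) / (y.max() - y.min())

y = np.array([0.0, 1.0, 2.0, 3.0, 4.0])      # observed signal
y_hat = np.array([0.0, 1.0, 2.0, 3.0, 5.0])  # model reconstruction
events = np.array([False, False, False, True, True])

print(r_squared(y, y_hat))          # full-signal R-squared
print(r_squared(y, y_hat, events))  # event-only R-squared
print(nrmse(y, y_hat), nmae(y, y_hat))
```

In this toy example the full-signal value is high while the event-only value is negative, which shows why the two are reported separately: a single bad event can hide behind a long, well-fitted baseline.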

In [None]:
print("5. Computing wavelet SNR...")

wavelet_snr = neuron.get_wavelet_snr()

print(f"   [OK] Wavelet SNR: {wavelet_snr:.2f}")
print(f"       (Ratio of event amplitude to baseline noise)")

print("\n6. Computing reconstruction quality metrics...")

r2 = neuron.get_reconstruction_r2()
print(f"   [OK] Reconstruction R2:  {r2:.4f}")

r2_events = neuron.get_reconstruction_r2(event_only=True)
print(f"   [OK] Event-only R2:      {r2_events:.4f}")

nrmse = neuron.get_nrmse()
print(f"   [OK] Normalized RMSE:    {nrmse:.4f}")

nmae = neuron.get_nmae()
print(f"   [OK] Normalized MAE:     {nmae:.4f}")

### Event and noise metrics

Event-specific metrics assess quality only within detected transient
regions, which matters more than full-signal metrics when events are
sparse.  Noise estimates help set detection thresholds.
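
A MAD-based estimate is a common way to measure noise when events are present, because the median ignores the minority of frames that carry transients. The sketch below shows the standard estimator, with the factor 1.4826 that makes it match the standard deviation for Gaussian noise. Whether DRIADA's `get_mad` applies this scaling is not stated here, so treat `mad_sigma` as a hypothetical helper.

```python
import numpy as np

def mad_sigma(x):
    """Robust noise estimate: scaled median absolute deviation."""
    x = np.asarray(x, dtype=float)
    return 1.4826 * np.median(np.abs(x - np.median(x)))

rng = np.random.default_rng(0)
trace = rng.normal(0.0, 0.05, 4000)  # baseline noise with std 0.05
trace[::20] += 1.0                   # 5% of frames carry a large "event"

print(f"plain std:  {np.std(trace):.4f}")
print(f"MAD-based:  {mad_sigma(trace):.4f}")
```

The plain standard deviation is dominated by the event frames, while the MAD-based value stays close to the true baseline noise level.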

In [None]:
# Event-level metrics
event_count = neuron.get_event_count()
event_rmse = neuron.get_event_rmse()
event_mae = neuron.get_event_mae()
event_snr = neuron.get_event_snr()

print(f"Event count:      {event_count}")
print(f"Event RMSE:       {event_rmse:.4f}")
print(f"Event MAE:        {event_mae:.4f}")
print(f"Event SNR:        {event_snr:.2f}")

# Noise characterization
mad = neuron.get_mad()
baseline_std = neuron.get_baseline_noise_std()

print(f"\nMAD (robust noise):    {mad:.4f}")
print(f"Baseline noise std:    {baseline_std:.4f}")

### Accessing the reconstruction

The `reconstructed` property returns the cached model fit (calcium
kernel convolved with detected spikes).  Use `get_reconstructed()` to
recompute with custom kinetics.

In [None]:
# Cached reconstruction (uses optimized kinetics)
recon = neuron.reconstructed  # TimeSeries object
print(f"Reconstructed shape: {recon.data.shape}")
print(f"Reconstructed range: [{recon.data.min():.3f}, {recon.data.max():.3f}]")

# Recompute with custom time constants (in frames)
recon_custom = neuron.get_reconstructed(t_rise_frames=3, t_off_frames=30)
print(f"Custom recon range:  [{recon_custom.data.min():.3f}, {recon_custom.data.max():.3f}]")

### Surrogate generation

Four surrogate methods are available for null-hypothesis testing. Three operate on the calcium trace and one on the spike train:

| Method | Type | Preserves |
|---|---|---|
| **roll_based** | Calcium | Autocorrelation structure, amplitude distribution |
| **waveform_based** | Calcium | Individual waveform shapes, event count |
| **chunks_based** | Calcium | Local structure within chunks |
| **isi_based** | Spikes | Inter-spike interval distribution |
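
Two of these ideas are simple enough to sketch with NumPy alone. The code below is illustrative and makes no claim about how DRIADA implements `roll_based` or `isi_based`: the function names are hypothetical, the roll shift is drawn uniformly, and the ISI variant keeps the first spike in place and only permutes the intervals.

```python
import numpy as np

def roll_surrogate(trace, rng):
    """Circularly shift the trace by a random non-zero number of frames."""
    shift = int(rng.integers(1, len(trace)))
    return np.roll(trace, shift)

def isi_surrogate(spikes, rng):
    """Permute inter-spike intervals of a binary spike train."""
    frames = np.flatnonzero(spikes)
    if frames.size < 2:
        return spikes.copy()
    shuffled_isis = rng.permutation(np.diff(frames))
    # Rebuild spike times from the first spike plus the shuffled intervals
    new_frames = frames[0] + np.concatenate(([0], np.cumsum(shuffled_isis)))
    out = np.zeros_like(spikes)
    out[new_frames] = 1
    return out

rng = np.random.default_rng(42)

trace = rng.normal(0.0, 1.0, 200)
rolled = roll_surrogate(trace, rng)

spikes = np.zeros(200, dtype=int)
spikes[[10, 25, 70, 80, 150]] = 1
shuffled = isi_surrogate(spikes, rng)

print(np.flatnonzero(spikes), np.flatnonzero(shuffled))
```

A circular shift keeps every sample value and the trace's autocorrelation while breaking its alignment with behavior. The ISI shuffle keeps the spike count and the set of intervals while changing when each spike occurs.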

In [None]:
print("7. Surrogate generation methods...")

shuffled_roll = neuron.get_shuffled_calcium(method='roll_based', seed=42)
print(f"\n   [Roll-based] Circular shift surrogate:")
print(f"       Mean: {np.mean(shuffled_roll):.4f}  (original: {np.mean(neuron.ca.data):.4f})")
print(f"       Std:  {np.std(shuffled_roll):.4f}  (original: {np.std(neuron.ca.data):.4f})")

shuffled_wf = neuron.get_shuffled_calcium(method='waveform_based', seed=42)
print(f"\n   [Waveform-based] Spike-shuffle + reconstruct surrogate:")
print(f"       Mean: {np.mean(shuffled_wf):.4f}  (original: {np.mean(neuron.ca.data):.4f})")
print(f"       Std:  {np.std(shuffled_wf):.4f}  (original: {np.std(neuron.ca.data):.4f})")

shuffled_chunks = neuron.get_shuffled_calcium(method='chunks_based', seed=42)
print(f"\n   [Chunks-based] Chunk reordering surrogate:")
print(f"       Mean: {np.mean(shuffled_chunks):.4f}  (original: {np.mean(neuron.ca.data):.4f})")
print(f"       Std:  {np.std(shuffled_chunks):.4f}  (original: {np.std(neuron.ca.data):.4f})")

shuffled_sp = neuron.get_shuffled_spikes(method='isi_based', seed=42)
original_spike_count = int(np.sum(neuron.sp.data > 0))
shuffled_spike_count = int(np.sum(shuffled_sp > 0))
print(f"\n   [ISI-based] Spike train surrogate:")
print(f"       Spike count: {shuffled_spike_count}  (original: {original_spike_count})")

In [None]:
# Visualize surrogates: original trace + 4 surrogate methods
fig, axes = plt.subplots(5, 1, figsize=(14, 10), sharex=True)

time_surr = np.arange(len(neuron.ca.data)) / neuron.fps
xlim = (0, min(30, time_surr[-1]))  # Show first 30 seconds

surrogates = [
    ('Original', neuron.ca.data, 'black'),
    ('Roll-based', shuffled_roll, '#1f77b4'),
    ('Waveform-based', shuffled_wf, '#ff7f0e'),
    ('Chunks-based', shuffled_chunks, '#2ca02c'),
    ('ISI-based', neuron.get_reconstructed(
        t_rise_frames=int(0.15 * neuron.fps),
        t_off_frames=int(1.5 * neuron.fps),
        spikes=shuffled_sp,
    ).data, '#d62728'),
]

for ax, (label, trace, color) in zip(axes, surrogates):
    ax.plot(time_surr, trace, color=color, linewidth=0.8)
    ax.set_ylabel(label, fontsize=9)
    ax.set_xlim(xlim)
    ax.grid(True, alpha=0.3)

axes[-1].set_xlabel('Time (s)')
axes[0].set_title('Surrogate methods comparison (first 30 s)')
plt.tight_layout()
plt.show()

## 3. Threshold vs wavelet reconstruction

Above we analyzed one neuron with the wavelet method. DRIADA also supports
threshold-based detection. How do the two methods compare on the same signal?

The two methods at a glance:

| Method | Description |
|---|---|
| **Threshold** | Signal crossing of a MAD-based (median absolute deviation) threshold -- fast, good for high SNR |
| **Wavelet** | CWT ridge detection -- more sensitive, better for low SNR or overlapping events |

Both support iterative detection (detect-subtract-detect) to recover
weaker events hidden under larger transients.
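
The threshold idea can be sketched as a single detection pass: estimate the noise from the MAD, then report every run of frames above `median + n_mad * sigma`. This is an illustration only. `threshold_events` and `min_len` are hypothetical names, the 1.4826 scaling is an assumption, and the subtract-and-repeat step of iterative detection is left out.

```python
import numpy as np

def threshold_events(trace, n_mad=4.0, min_len=2):
    """Return inclusive (start, end) frame pairs where the trace exceeds the threshold."""
    baseline = np.median(trace)
    sigma = 1.4826 * np.median(np.abs(trace - baseline))
    above = trace > baseline + n_mad * sigma
    # Pad with False so that runs touching either end are still closed
    padded = np.concatenate(([False], above, [False]))
    edges = np.diff(padded.astype(int))
    starts = np.flatnonzero(edges == 1)
    ends = np.flatnonzero(edges == -1) - 1
    return [(int(s), int(e)) for s, e in zip(starts, ends) if e - s + 1 >= min_len]

rng = np.random.default_rng(1)
trace = rng.normal(0.0, 0.04, 1000)  # noisy baseline
trace[100:120] += 1.0                # first event
trace[500:530] += 1.0                # second event

print(threshold_events(trace, n_mad=4.0))
```

Lowering `n_mad` makes the detector more sensitive at the cost of more noise-driven detections. The `min_len` filter removes isolated single-frame crossings.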

In [None]:
# Generate a synthetic signal with non-default kinetics
fps_cmp = 30.0
signal_cmp = generate_pseudo_calcium_signal(
    duration=120.0, sampling_rate=fps_cmp, event_rate=0.3,  # higher rate for denser events
    amplitude_range=(0.3, 1.2), decay_time=0.8, rise_time=0.10,
    noise_std=0.04, kernel='double_exponential',
)
time_cmp = np.arange(len(signal_cmp)) / fps_cmp

# Wavelet: iterative detection with kinetics optimization
n_wvt = Neuron(cell_id='wavelet', ca=signal_cmp.copy(), sp=None, fps=fps_cmp)
with warnings.catch_warnings():
    warnings.simplefilter('ignore', UserWarning)
    n_wvt.reconstruct_spikes(method='wavelet', iterative=True, n_iter=3,
                              create_event_regions=True)
    n_wvt.optimize_kinetics(method='direct', fps=fps_cmp,
                             update_reconstruction=True, detection_method='wavelet')

# Threshold: iterative detection with kinetics optimization
n_thr = Neuron(cell_id='threshold', ca=signal_cmp.copy(), sp=None, fps=fps_cmp)
with warnings.catch_warnings():
    warnings.simplefilter('ignore', UserWarning)
    n_thr.reconstruct_spikes(method='threshold', iterative=True, n_iter=3,
                              n_mad=4.0, create_event_regions=True,  # n_mad: noise multiplier for threshold
                              adaptive_thresholds=True)
    n_thr.optimize_kinetics(method='direct', fps=fps_cmp,
                             update_reconstruction=True, detection_method='threshold',
                             n_mad=4.0, iterative=True, n_iter=3,
                             adaptive_thresholds=True)

wvt_events = len(n_wvt.wvt_ridges) if n_wvt.wvt_ridges else 0
thr_events = len(n_thr.threshold_events) if n_thr.threshold_events else 0

print(f"Wavelet:   {wvt_events} events, R2={n_wvt.get_reconstruction_r2():.4f}")
print(f"Threshold: {thr_events} events, R2={n_thr.get_reconstruction_r2():.4f}")

In [None]:
fig, axes = plt.subplots(3, 1, figsize=(14, 7), sharex=True)

ax = axes[0]
ax.plot(time_cmp, signal_cmp, 'k', linewidth=0.8)
ax.set_ylabel('dF/F0')
ax.set_title('Original calcium signal')
ax.grid(True, alpha=0.3)

ax = axes[1]
ax.plot(time_cmp, signal_cmp, 'k', linewidth=0.5, alpha=0.4)
ax.plot(time_cmp, n_wvt._reconstructed.data, 'b', linewidth=1.2,
        label=f'Wavelet (R2={n_wvt.get_reconstruction_r2():.3f})')
if n_wvt.wvt_ridges:
    for r in n_wvt.wvt_ridges:
        ax.axvspan(r.start / fps_cmp, r.end / fps_cmp, alpha=0.15, color='blue')
ax.set_ylabel('dF/F0')
ax.legend(loc='upper right', fontsize=9)
ax.grid(True, alpha=0.3)

ax = axes[2]
ax.plot(time_cmp, signal_cmp, 'k', linewidth=0.5, alpha=0.4)
ax.plot(time_cmp, n_thr._reconstructed.data, 'r', linewidth=1.2,
        label=f'Threshold (R2={n_thr.get_reconstruction_r2():.3f})')
if n_thr.threshold_events:
    for ev in n_thr.threshold_events:
        ax.axvspan(ev.start / fps_cmp, ev.end / fps_cmp, alpha=0.15, color='red')
ax.set_ylabel('dF/F0')
ax.set_xlabel('Time (s)')
ax.legend(loc='upper right', fontsize=9)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Compare event counts from both methods across a small synthetic population
print("Generating synthetic experiment...")
exp4 = generate_synthetic_exp(
    n_dfeats=2, n_cfeats=1, nneurons=5, duration=120, fps=20, seed=42
)

fps4 = exp4.fps
n_neurons4 = exp4.calcium.scdata.shape[0]
time4 = np.arange(exp4.calcium.scdata.shape[1]) / fps4

wavelet_events = []
threshold_events = []

for n4 in exp4.neurons:
    n4.reconstruct_spikes(method="wavelet", iterative=True, n_iter=3, fps=fps4)
    wavelet_events.append(list(n4.wvt_ridges))
    n4.reconstruct_spikes(method="threshold", iterative=True, n_iter=3,
                           n_mad=4.0, adaptive_thresholds=True, fps=fps4)  # n_mad: noise multiplier; lower values catch more events
    threshold_events.append(list(n4.threshold_events))

for i in range(n_neurons4):
    print(f"  Neuron {i}: wavelet={len(wavelet_events[i])}, "
          f"threshold={len(threshold_events[i])}")

## Further reading

Standalone examples (run directly, no external data needed):
- [neuron_basic_usage](https://github.com/iabs-neuro/driada/tree/main/examples/neuron_basic_usage) -- Core Neuron class and quality metrics
- [spike_reconstruction](https://github.com/iabs-neuro/driada/tree/main/examples/spike_reconstruction) -- Wavelet vs threshold comparison
- [threshold_vs_wavelet_optimization](https://github.com/iabs-neuro/driada/tree/main/examples/spike_reconstruction) -- Optimization modes and benchmarks

[All examples](https://github.com/iabs-neuro/driada/tree/main/examples)