# Detecting neuron selectivity with INTENSE

Which neurons encode which behavioral variables -- and how do you tell
real selectivity from chance?  [**DRIADA**](https://driada.readthedocs.io)
answers this with **INTENSE**, an information-theoretic significance
testing pipeline.  This notebook walks through the method from first
principles to a full production run.

| Step | Notebook | What it does |
|---|---|---|
| Load & inspect | [01 -- Data loading](01_data_loading_and_neurons.ipynb) | Wrap your recording into an `Experiment`, reconstruct spikes, assess quality |
| **Single-neuron selectivity** | **02 -- this notebook** | Detect which neurons encode which behavioral variables |
| Population geometry | [03 -- Dimensionality reduction](03_population_geometry_dr.ipynb) | Extract low-dimensional manifolds from population activity |
| Functional networks | [04 -- Networks](04_functional_networks.ipynb) | Build and analyze cell-cell interaction graphs |
| Putting it together | [05 -- Advanced](05_advanced_capabilities.ipynb) | Combine INTENSE + DR, leave-one-out importance, RSA, RNN analysis |

**What you will learn:**

1. **Information theory fundamentals** -- mutual information estimation
   (GCMI vs KSG), similarity metrics, time-delayed MI, conditional MI,
   and interaction information.
2. **Basic INTENSE workflow** -- generate a synthetic population, run
   two-stage significance testing, and extract results.
3. **Complete pipeline with ground truth validation** -- all feature types,
   Holm correction, disentanglement, delay optimization, and validation
   against known selectivity.
4. **Feature-feature relations** -- discover which behavioral variables
   are themselves correlated before analyzing neurons.
5. **Mixed selectivity & disentanglement** -- neurons responding to
   multiple correlated features: true mixed selectivity vs redundant
   detections.

In [None]:
!pip install -q driada
%matplotlib inline

import os
import time
import tempfile

import numpy as np
import matplotlib.pyplot as plt

import driada
from driada.information import (
    TimeSeries,
    get_mi,
    get_sim,
    get_tdmi,
    conditional_mi,
    interaction_information,
)
from driada.information.info_base import MultiTimeSeries
from driada.experiment.synthetic import generate_tuned_selectivity_exp
from driada.intense import compute_feat_feat_significance
from driada.intense.io import save_results, load_results

## 1. Information theory fundamentals

Before running INTENSE, understand the building blocks: mutual information
estimation, similarity metrics, temporal lags, and conditional dependencies.

### Creating TimeSeries

Wrap numpy arrays as [`TimeSeries`](https://driada.readthedocs.io/en/latest/api/information/core.html) with a type hint such as `linear`, `categorical`,
or `circular`. The type determines which MI estimator and preprocessing
DRIADA uses internally. Continuous-continuous pairs use GCMI (Gaussian copula,
via `mi_gg`) or KSG; continuous-discrete pairs use `mi_model_gd`;
discrete-discrete pairs compute MI directly from the empirical joint distribution.
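
To make the discrete-discrete case concrete, here is a minimal NumPy-only sketch of a plug-in MI estimate built from a contingency table. The helper `plugin_mi_bits` is a hypothetical name for illustration and is not part of DRIADA; the library's own implementation may differ in details such as bias correction.

```python
import numpy as np

def plugin_mi_bits(a, b):
    """Plug-in MI (bits) between two discrete 1-D arrays (illustrative helper)."""
    a_codes = np.unique(a, return_inverse=True)[1].ravel()
    b_codes = np.unique(b, return_inverse=True)[1].ravel()
    joint = np.zeros((a_codes.max() + 1, b_codes.max() + 1))
    np.add.at(joint, (a_codes, b_codes), 1)    # contingency table of counts
    joint /= joint.sum()                        # empirical joint distribution
    pa = joint.sum(axis=1, keepdims=True)       # marginal of a, shape (A, 1)
    pb = joint.sum(axis=0, keepdims=True)       # marginal of b, shape (1, B)
    nz = joint > 0                              # empty cells contribute zero
    return float(np.sum(joint[nz] * np.log2(joint[nz] / (pa @ pb)[nz])))

demo_rng = np.random.default_rng(0)
labels = demo_rng.choice([0, 1, 2], size=5000, p=[0.5, 0.3, 0.2])
copy_mi = plugin_mi_bits(labels, labels)                         # equals H(labels)
indep_mi = plugin_mi_bits(labels, demo_rng.permutation(labels))  # close to zero
print(f"MI(labels, labels)   = {copy_mi:.3f} bits")
print(f"MI(labels, shuffled) = {indep_mi:.3f} bits")
```

The MI of a variable with itself equals its entropy, while a shuffled copy carries almost no information about the original.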

In [None]:
rng = np.random.default_rng(42)
n = 5000

# ------------------------------------------------------------------
# 1. Create TimeSeries from numpy arrays
# ------------------------------------------------------------------
print("[1] Creating TimeSeries objects")
print("-" * 40)

continuous = rng.normal(size=n)
ts_cont = TimeSeries(continuous, ts_type="linear", name="continuous")
print(f"  Continuous: type={ts_cont.type_info}, len={len(ts_cont.data)}")

discrete = rng.choice([0, 1, 2], size=n, p=[0.5, 0.3, 0.2])
ts_disc = TimeSeries(discrete, ts_type="categorical", name="discrete")
print(f"  Discrete:   type={ts_disc.type_info}, len={len(ts_disc.data)}")

circular = rng.uniform(0, 2 * np.pi, size=n)
ts_circ = TimeSeries(circular, ts_type="circular", name="circular")
print(f"  Circular:   type={ts_circ.type_info}, len={len(ts_circ.data)}")

### Pairwise MI: GCMI vs KSG

[`get_mi`](https://driada.readthedocs.io/en/latest/api/information/core.html)`()` estimates mutual information between two `TimeSeries`.
**GCMI** (Gaussian Copula MI) is fast but only captures monotonic
dependency. **KSG** (Kraskov-Stoegbauer-Grassberger) captures arbitrary
dependency but is slower.
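
The reason GCMI misses non-monotonic structure can be seen in a small stand-alone sketch of the Gaussian-copula idea for two 1-D variables: rank-transform each variable to standard-normal scores, then apply the Gaussian MI formula. The function names below are hypothetical and this is not DRIADA's `mi_gg`; ties and bias correction are ignored.

```python
import numpy as np
from statistics import NormalDist

def copula_normalize(v):
    """Map values to standard-normal scores via their ranks (no ties assumed)."""
    ranks = np.argsort(np.argsort(v)) + 1        # ranks 1..n
    quantiles = ranks / (len(v) + 1)             # strictly inside (0, 1)
    inv_cdf = NormalDist().inv_cdf
    return np.array([inv_cdf(q) for q in quantiles])

def gaussian_copula_mi_bits(a, b):
    """1-D Gaussian-copula MI in bits: -0.5 * log2(1 - rho^2)."""
    rho = np.corrcoef(copula_normalize(a), copula_normalize(b))[0, 1]
    return float(-0.5 * np.log2(1.0 - rho ** 2))

demo_rng = np.random.default_rng(0)
u = demo_rng.normal(size=2000)
v_mono = u + 0.5 * demo_rng.normal(size=2000)            # monotonic dependency
u_sym = demo_rng.uniform(-3, 3, size=2000)
v_quad = u_sym ** 2 + 0.3 * demo_rng.normal(size=2000)   # symmetric dependency

mi_mono = gaussian_copula_mi_bits(u, v_mono)
mi_quad = gaussian_copula_mi_bits(u_sym, v_quad)
print(f"monotonic: {mi_mono:.3f} bits, quadratic: {mi_quad:.3f} bits")
```

Because the estimate depends only on a single correlation coefficient, a symmetric relationship such as `y = x^2` yields a value near zero even though the variables are strongly dependent.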

In [None]:
# ------------------------------------------------------------------
# 2. get_mi -- pairwise MI
# ------------------------------------------------------------------
print("[2] Pairwise mutual information (get_mi)")
print("-" * 40)

x = rng.normal(size=n)
noise = rng.normal(size=n)
y_corr = x + 0.5 * noise          # correlated with x
y_indep = rng.normal(size=n)       # independent of x

ts_x = TimeSeries(x)
ts_y_corr = TimeSeries(y_corr)
ts_y_indep = TimeSeries(y_indep)

mi_corr = get_mi(ts_x, ts_y_corr)
mi_indep = get_mi(ts_x, ts_y_indep)
print(f"  MI(X, Y_correlated)  = {mi_corr:.4f} bits")
print(f"  MI(X, Y_independent) = {mi_indep:.4f} bits")
print(f"  Correlated MI >> independent MI: {mi_corr > 5 * mi_indep}")

# Compare estimators on monotonic vs non-monotonic relationships.
# For two 1-D variables GCMI reduces to -0.5*log2(1-rho^2) bits, where rho is the
# correlation of the rank-normalized (copula) variables, close to Spearman's rho,
# so it only captures monotonic dependency. KSG captures any dependency.
mi_gcmi = get_mi(ts_x, ts_y_corr, estimator="gcmi")
mi_ksg = get_mi(ts_x, ts_y_corr, estimator="ksg")
print(f"\n  Monotonic relationship (y = x + noise):")
print(f"    GCMI: {mi_gcmi:.4f} bits")
print(f"    KSG:  {mi_ksg:.4f} bits")
print(f"    (agree because relationship is monotonic)")

# Non-monotonic: y = x^2. x is symmetric about 0, so rank correlation ~ 0 and GCMI ~ 0.
x_sym = rng.uniform(-3, 3, size=n)
y_quad = x_sym ** 2 + 0.3 * rng.normal(size=n)
ts_x_sym = TimeSeries(x_sym)
ts_y_quad = TimeSeries(y_quad)
mi_gcmi_q = get_mi(ts_x_sym, ts_y_quad, estimator="gcmi")
mi_ksg_q = get_mi(ts_x_sym, ts_y_quad, estimator="ksg")
print(f"\n  Non-monotonic relationship (y = x^2 + noise):")
print(f"    GCMI: {mi_gcmi_q:.4f} bits  (blind to symmetric dependency)")
print(f"    KSG:  {mi_ksg_q:.4f} bits  (captures it)")
print(f"    KSG >> GCMI: {mi_ksg_q > 3 * mi_gcmi_q}")

### Similarity metrics

[`get_sim`](https://driada.readthedocs.io/en/latest/api/information/core.html)`()` exposes MI and correlation-based metrics through a unified interface.
Available metrics include `mi`, `pearsonr`, `spearmanr`, `kendalltau`,
`fast_pearsonr`, `av` (activity ratio for binary-gated signals), and any
scipy.stats correlation function by name.
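
The choice of metric matters when a relationship is monotonic but nonlinear. The sketch below, written with plain NumPy and hypothetical helper names rather than `get_sim`, contrasts Pearson r with Spearman rho (computed as Pearson r on ranks) for `v = exp(u)`.

```python
import numpy as np

def pearson(a, b):
    """Pearson correlation coefficient."""
    return float(np.corrcoef(a, b)[0, 1])

def spearman(a, b):
    """Spearman rho as Pearson r on ranks (continuous data, so no ties)."""
    def rank(v):
        return np.argsort(np.argsort(v))
    return pearson(rank(a), rank(b))

demo_rng = np.random.default_rng(0)
u = demo_rng.normal(size=2000)
v = np.exp(u)                      # strictly monotonic but nonlinear

r_lin = pearson(u, v)
r_rank = spearman(u, v)
print(f"Pearson r = {r_lin:.3f}, Spearman rho = {r_rank:.3f}")
```

Spearman rho is 1 up to floating-point error because the ranks are identical, while Pearson r stays noticeably below 1 because the relationship is not linear.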

In [None]:
# ------------------------------------------------------------------
# 3. get_sim -- compare metrics on the same data
# ------------------------------------------------------------------
print("[3] Similarity metrics comparison (get_sim)")
print("-" * 40)

metrics = ["mi", "pearsonr", "spearmanr"]
for metric in metrics:
    val = get_sim(ts_x, ts_y_corr, metric=metric)
    print(f"  {metric:12s}(X, Y_corr) = {val:.4f}")

### Time-delayed MI

[`get_tdmi`](https://driada.readthedocs.io/en/latest/api/information/core.html)`()` sweeps temporal lags to find the shift that maximizes MI.
This is useful for detecting delayed neural responses to behavior.
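
As a rough illustration of what a lag sweep does, the following sketch computes MI between a signal and shifted copies of itself, using a Gaussian (correlation-based) MI as a simple stand-in estimator. `lagged_mi_curve` is a hypothetical helper, not the `get_tdmi` implementation.

```python
import numpy as np

def gaussian_mi_bits(a, b):
    """MI in bits under a bivariate Gaussian assumption."""
    rho = np.corrcoef(a, b)[0, 1]
    return -0.5 * np.log2(1.0 - rho ** 2)

def lagged_mi_curve(sig, max_shift):
    """MI between sig[t] and sig[t + shift] for shift = 1..max_shift."""
    return np.array(
        [gaussian_mi_bits(sig[:-s], sig[s:]) for s in range(1, max_shift + 1)]
    )

demo_rng = np.random.default_rng(0)
n_demo, true_lag = 5000, 15
base_demo = demo_rng.normal(size=n_demo)
echo = np.zeros(n_demo)
echo[true_lag:] = base_demo[:-true_lag]          # delayed copy of the base signal
sig_demo = base_demo + 0.8 * echo + 0.3 * demo_rng.normal(size=n_demo)

curve = lagged_mi_curve(sig_demo, max_shift=50)
peak_lag = int(np.argmax(curve)) + 1             # index 0 corresponds to shift 1
print(f"peak at lag {peak_lag}, MI = {curve[peak_lag - 1]:.3f} bits")
```

The curve is flat except at the lag where the delayed copy lines up with the original, which is the behavior the sweep relies on.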

In [None]:
# ------------------------------------------------------------------
# 4. get_tdmi -- time-delayed MI
# ------------------------------------------------------------------
print("[4] Time-delayed MI (get_tdmi)")
print("-" * 40)

# Create a signal whose autocorrelation peaks at a known lag of 15 samples
base = rng.normal(size=n)
lag = 15
lagged = np.zeros(n)
lagged[lag:] = base[:-lag]
signal = base + 0.3 * rng.normal(size=n) + 0.8 * lagged

max_shift = 50
tdmi_values = np.array(get_tdmi(signal, max_shift=max_shift))
best_lag = np.argmax(tdmi_values) + 1  # get_tdmi starts at min_shift=1
print(f"  True lag: {lag}")
print(f"  TDMI peak lag: {best_lag}")
print(f"  TDMI at peak: {tdmi_values[best_lag - 1]:.4f} bits")
print(f"  Lag correctly detected: {abs(best_lag - lag) <= 2}")

### Conditional MI and interaction information

**Conditional MI** ([`conditional_mi`](https://driada.readthedocs.io/en/latest/api/information/core.html)) `I(X;Y|Z)` measures the dependence between X and Y that remains once Z is known.
**Interaction information** ([`interaction_information`](https://driada.readthedocs.io/en/latest/api/information/core.html)), in the sign convention used here, distinguishes synergy (>0) from
redundancy (<0).
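
For discrete variables both quantities can be written in terms of joint entropies, which makes the sign convention easy to check by hand. The sketch below uses plug-in entropies and hypothetical helper names; it assumes the convention `II = I(X;Y|Z) - I(X;Y)`, which matches the signs stated above, and it is not DRIADA's estimator.

```python
import numpy as np

def entropy_bits(*cols):
    """Plug-in joint entropy (bits) of one or more discrete columns."""
    _, counts = np.unique(np.column_stack(cols), axis=0, return_counts=True)
    p = counts / counts.sum()
    return float(-np.sum(p * np.log2(p)))

def mi_bits(x, y):
    """I(X;Y) = H(X) + H(Y) - H(X,Y)."""
    return entropy_bits(x) + entropy_bits(y) - entropy_bits(x, y)

def cmi_bits(x, y, z):
    """I(X;Y|Z) = H(X,Z) + H(Y,Z) - H(X,Y,Z) - H(Z)."""
    return (entropy_bits(x, z) + entropy_bits(y, z)
            - entropy_bits(x, y, z) - entropy_bits(z))

def interaction_info_bits(x, y, z):
    """II = I(X;Y|Z) - I(X;Y); positive means synergy in this convention."""
    return cmi_bits(x, y, z) - mi_bits(x, y)

demo_rng = np.random.default_rng(0)
bit_a = demo_rng.integers(0, 2, size=4000)
bit_b = demo_rng.integers(0, 2, size=4000)

ii_xor = interaction_info_bits(bit_a ^ bit_b, bit_a, bit_b)   # pure synergy
ii_copy = interaction_info_bits(bit_a, bit_a, bit_a)          # pure redundancy
print(f"XOR: II = {ii_xor:+.3f} bits, copies: II = {ii_copy:+.3f} bits")
```

With a true XOR, neither input alone says anything about the output, but together they determine it, so II is close to +1 bit. With three copies of the same bit, conditioning removes all information, so II is close to -1 bit.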

In [None]:
# ------------------------------------------------------------------
# 5. conditional_mi -- I(X;Y|Z)
# ------------------------------------------------------------------
print("[5] Conditional MI: I(X;Y|Z)")
print("-" * 40)

z = rng.normal(size=n)
x_from_z = z + 0.3 * rng.normal(size=n)
y_from_z = z + 0.3 * rng.normal(size=n)

ts_xz = TimeSeries(x_from_z)
ts_yz = TimeSeries(y_from_z)
ts_z = TimeSeries(z)

mi_xy = get_mi(ts_xz, ts_yz)
cmi_xy_z = conditional_mi(ts_xz, ts_yz, ts_z)

print(f"  I(X;Y)   = {mi_xy:.4f} bits  (shared via Z)")
print(f"  I(X;Y|Z) = {cmi_xy_z:.4f} bits  (residual after conditioning)")
print(f"  Conditioning reduces MI: {cmi_xy_z < mi_xy * 0.5}")

In [None]:
# ------------------------------------------------------------------
# 6. interaction_information -- synergy vs redundancy
# ------------------------------------------------------------------
print("[6] Interaction information: synergy vs redundancy")
print("-" * 40)

# Redundancy: Y and Z provide overlapping info about X
x_r = rng.normal(size=n)
y_r = TimeSeries(x_r + 0.2 * rng.normal(size=n))
z_r = TimeSeries(x_r + 0.2 * rng.normal(size=n))
ts_xr = TimeSeries(x_r)

ii_redund = interaction_information(ts_xr, y_r, z_r)
print(f"  Redundancy example: II = {ii_redund:.4f} (expected < 0)")

# Synergy: noisy sum of two independent bits (XOR-like: knowing B disambiguates A)
a = rng.choice([0, 1], size=n).astype(float)
b = rng.choice([0, 1], size=n).astype(float)
xor_signal = (a + b + 0.1 * rng.normal(size=n))

ts_xor = TimeSeries(xor_signal)
ts_a = TimeSeries(a, ts_type="binary")
ts_b = TimeSeries(b, ts_type="binary")

ii_synergy = interaction_information(ts_xor, ts_a, ts_b)
print(f"  Synergy example:    II = {ii_synergy:.4f} (expected > 0)")
print(f"  Redundancy is negative: {ii_redund < 0}")

## 2. Basic INTENSE workflow

The minimal pipeline: [`generate_tuned_selectivity_exp`](https://driada.readthedocs.io/en/latest/api/experiment/synthetic.html) creates a synthetic
population with known selectivity, [`compute_cell_feat_significance`](https://driada.readthedocs.io/en/latest/api/intense/pipelines.html) runs
two-stage significance testing, and results are extracted from the experiment.
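
The logic of two-stage testing can be sketched in a few lines. The version below is a simplified illustration under stated assumptions, not DRIADA's implementation: it uses a Gaussian MI stand-in, a stage 1 rule of "beat every shuffle", and an empirical stage 2 p-value, whereas the library may use different ranking rules and a fitted null distribution. All function names are hypothetical.

```python
import numpy as np

def gaussian_mi_bits(a, b):
    """MI in bits under a bivariate Gaussian assumption (stand-in metric)."""
    rho = np.corrcoef(a, b)[0, 1]
    return -0.5 * np.log2(1.0 - rho ** 2)

def shuffled_mis(neural, feature, n_shuffles, rng, min_shift=50):
    """Null MI values from random circular shifts of the neural trace."""
    shifts = rng.integers(min_shift, len(neural) - min_shift, size=n_shuffles)
    return np.array([gaussian_mi_bits(np.roll(neural, s), feature) for s in shifts])

def two_stage_test(neural, feature, rng, n1=100, n2=2000, pval_thr=0.01):
    """Return (is_significant, p_value); p_value is None if stage 1 rejects."""
    observed = gaussian_mi_bits(neural, feature)
    # Stage 1: cheap screen; keep the pair only if it beats every shuffle.
    if observed <= shuffled_mis(neural, feature, n1, rng).max():
        return False, None
    # Stage 2: many more shuffles give a finer empirical p-value.
    null = shuffled_mis(neural, feature, n2, rng)
    pval = (np.sum(null >= observed) + 1) / (n2 + 1)
    return bool(pval < pval_thr), float(pval)

demo_rng = np.random.default_rng(0)
feature = demo_rng.normal(size=3000)
driven = feature + demo_rng.normal(size=3000)     # depends on the feature
unrelated = demo_rng.normal(size=3000)            # independent of the feature

sig_dep, p_dep = two_stage_test(driven, feature, demo_rng)
sig_ind, p_ind = two_stage_test(unrelated, feature, demo_rng)
print(f"driven:    significant={sig_dep}, p={p_dep}")
print(f"unrelated: significant={sig_ind}, p={p_ind}")
```

The design point is cost: most unrelated pairs are discarded after the small stage 1 budget, so the expensive stage 2 shuffles are spent only on candidates that might be real.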

In [None]:
# Step 1: Generate synthetic experiment with meaningful features
print("1. Generating synthetic experiment...")
print("   - 10 neurons with realistic tuning")
print("   - Head direction cells (circular tuning)")
print("   - Speed cells (linear tuning)")
print("   - Event cells (discrete responses)")
print("   - 10 minutes recording")

# Define simple population with meaningful selectivity
population = [
    {"name": "hd_cells", "count": 2, "features": ["head_direction"]},
    {"name": "speed_cells", "count": 2, "features": ["speed"]},
    {"name": "event_cells", "count": 2, "features": ["event_0"]},
    {"name": "nonselective", "count": 4, "features": []},
]

exp = generate_tuned_selectivity_exp(
    population=population,
    duration=600,
    fps=20,
    seed=47,
    n_discrete_features=1,
    verbose=False,
)

print(
    f"   [OK] Created experiment with {exp.n_cells} neurons and {exp.n_frames} timepoints"
)
print(f"   [OK] Features: {list(exp.dynamic_features.keys())}")

In [None]:
# Step 2: Analyze neuronal selectivity
print("2. Running INTENSE analysis...")
print("   - Two-stage statistical testing")
print("   - Mutual information metric")
print("   - Fixed p-value threshold (0.001), no multiple comparison correction")

stats, significance, info, results = driada.compute_cell_feat_significance(
    exp,
    mode="two_stage",
    n_shuffles_stage1=100,
    n_shuffles_stage2=10000,
    pval_thr=0.001,
    multicomp_correction=None,
    ds=5,
    verbose=False,
)

print("   [OK] Analysis complete")

In [None]:
# Step 3: Extract significant results
print("3. Extracting significant results...")

significant_neurons = exp.get_significant_neurons()
total_pairs = sum(len(features) for features in significant_neurons.values())

print(f"   [OK] Found {len(significant_neurons)} neurons with significant selectivity")
print(f"   [OK] Total significant neuron-feature pairs: {total_pairs}")

# Step 4: Display results
print("\n4. Results summary:")

if significant_neurons:
    print("   Significant neuron-feature relationships:")
    for cell_id in list(significant_neurons.keys())[:3]:  # Show first 3
        for feat_name in significant_neurons[cell_id]:
            pair_stats = exp.get_neuron_feature_pair_stats(cell_id, feat_name)

            print(f"   - Neuron {cell_id} <-> Feature '{feat_name}':")
            print(f"     - Mutual Information: {pair_stats.get('me', 0):.4f} bits")
            if "pval" in pair_stats:
                print(f"     - P-value: {pair_stats['pval']:.2e}")
            # opt_delay is in frames; convert to seconds using experiment fps
            opt_delay_frames = pair_stats.get("opt_delay", 0)
            opt_delay_sec = opt_delay_frames / exp.fps if exp.fps else 0
            print(f"     - Optimal delay: {opt_delay_sec:.2f}s ({opt_delay_frames} frames)")

    if len(significant_neurons) > 3:
        remaining = len(significant_neurons) - 3
        print(f"   ... and {remaining} more significant neurons")
else:
    print("   No significant relationships found with current parameters.")
    print("   Try using different synthetic data or adjusting p-value threshold.")

In [None]:
# Step 5: Create visualization
print("5. Creating visualization...")

if significant_neurons:
    # Plot first significant neuron-feature pair
    cell_id = list(significant_neurons.keys())[0]
    feat_name = significant_neurons[cell_id][0]

    fig, ax = plt.subplots(figsize=(10, 6))
    driada.intense.plot_neuron_feature_pair(exp, cell_id, feat_name, ax=ax)
    plt.title(f"Neuron {cell_id} selectivity to {feat_name}")
    plt.tight_layout()
    plt.show()
else:
    print("   No visualization created (no significant relationships found)")

## 3. Complete pipeline with ground truth validation

Production workflow: all feature types (circular, spatial, linear,
discrete), Holm correction, disentanglement, delay optimization,
and ground truth validation.
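
Holm correction is a step-down procedure: sort the p-values, compare the smallest against `alpha / m`, the next against `alpha / (m - 1)`, and so on, stopping at the first failure. A minimal sketch follows, with `holm_reject` as a hypothetical helper that is independent of how DRIADA applies the correction internally.

```python
import numpy as np

def holm_reject(pvals, alpha=0.05):
    """Boolean mask of hypotheses rejected by the Holm step-down procedure."""
    pvals = np.asarray(pvals, dtype=float)
    m = len(pvals)
    reject = np.zeros(m, dtype=bool)
    for rank, idx in enumerate(np.argsort(pvals)):
        if pvals[idx] > alpha / (m - rank):   # threshold loosens at each step
            break                             # first failure stops the procedure
        reject[idx] = True
    return reject

demo_pvals = [0.01, 0.04, 0.03, 0.005]
holm_mask = holm_reject(demo_pvals, alpha=0.05)
print(holm_mask.tolist())   # [True, False, False, True]
```

Here 0.005 passes `0.05 / 4` and 0.01 passes `0.05 / 3`, but 0.03 fails `0.05 / 2`, so it and every larger p-value are retained as non-significant.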

In [None]:
# Population configuration - defines neuron groups and their selectivity
POPULATION = [
    {"name": "hd_cells", "count": 4, "features": ["head_direction"]},
    {"name": "place_cells", "count": 4, "features": ["position_2d"]},
    {"name": "speed_cells", "count": 4, "features": ["speed"]},
    {"name": "event_cells", "count": 4, "features": ["event_0"]},
    {"name": "mixed_cells", "count": 4, "features": ["head_direction", "event_0"]},
    {"name": "nonselective", "count": 4, "features": []},
]

# Analysis parameters
CONFIG = {
    # Recording parameters
    "duration": 900,        # seconds
    "fps": 20,              # sampling rate
    "seed": 42,
    # Tuning parameters
    "kappa": 4.0,           # von Mises concentration (HD cells)
    # Calcium dynamics
    "baseline_rate": 0.02,  # baseline firing rate
    "peak_rate": 2.0,       # peak response
    "decay_time": 1.5,      # calcium decay time
    "calcium_noise": 0.01,  # noise level
    # Discrete event parameters
    "n_discrete_features": 2,
    "event_active_fraction": 0.08,  # ~8% active time per event
    "event_avg_duration": 0.8,      # seconds
    # INTENSE analysis parameters
    "n_shuffles_stage1": 100,   # stage 1 screening shuffles
    "n_shuffles_stage2": 10000,  # stage 2 confirmation (FFT makes this fast)
    "pval_thr": 0.05,           # p-value threshold after correction
    "multicomp_correction": "holm",  # multiple comparison correction
}

# Custom tuning defaults based on config
tuning_defaults = {
    "head_direction": {"kappa": CONFIG["kappa"]},
}

exp3 = generate_tuned_selectivity_exp(
    population=POPULATION,
    tuning_defaults=tuning_defaults,
    duration=CONFIG["duration"],
    fps=CONFIG["fps"],
    baseline_rate=CONFIG["baseline_rate"],
    peak_rate=CONFIG["peak_rate"],
    decay_time=CONFIG["decay_time"],
    calcium_noise=CONFIG["calcium_noise"],
    n_discrete_features=CONFIG["n_discrete_features"],
    event_active_fraction=CONFIG["event_active_fraction"],
    event_avg_duration=CONFIG["event_avg_duration"],
    seed=CONFIG["seed"],
    verbose=True,
)
ground_truth = exp3.ground_truth

# Remap ground truth: INTENSE tests head_direction_2d (cos/sin), not raw angle
ground_truth["expected_pairs"] = [
    (nid, "head_direction_2d" if f == "head_direction" else f)
    for nid, f in ground_truth["expected_pairs"]
]

In [None]:
# Run INTENSE with full options
print("\nRunning INTENSE analysis...")
print(f"  Stage 1: {CONFIG['n_shuffles_stage1']} shuffles")
print(f"  Stage 2: {CONFIG['n_shuffles_stage2']} shuffles")
print(f"  P-value threshold: {CONFIG['pval_thr']}")

start_time = time.time()

# Build feature list: exclude x/y marginals (use position_2d) and raw
# circular features (use their _2d cos/sin representation instead)
feat_bunch = [
    feat_name for feat_name in exp3.dynamic_features.keys()
    if feat_name not in ["x", "y", "head_direction"]
]
print(f"  Features to test: {feat_bunch}")

# Run INTENSE with disentanglement to handle correlated features
# - find_optimal_delays=True: Search for best temporal alignment between
#   neural activity and features (compensates for calcium dynamics)
# - with_disentanglement=True: Identify redundant detections caused by
#   feature correlations (e.g., HD cells detecting position due to
#   trajectory patterns where animal faces certain directions at certain locations)
stats3, significance3, info3, results3, disent_results3 = driada.compute_cell_feat_significance(
    exp3,
    feat_bunch=feat_bunch,
    mode="two_stage",
    n_shuffles_stage1=CONFIG["n_shuffles_stage1"],
    n_shuffles_stage2=CONFIG["n_shuffles_stage2"],
    find_optimal_delays=True,  # Find best temporal alignment
    ds=5,  # Downsampling factor for speed
    pval_thr=CONFIG["pval_thr"],
    multicomp_correction=CONFIG["multicomp_correction"],
    use_precomputed_stats=False,  # Force fresh computation
    with_disentanglement=True,
    verbose=True,
)

analysis_time = time.time() - start_time
significant_neurons3 = exp3.get_significant_neurons()

total_pairs = sum(len(features) for features in significant_neurons3.values())
print(f"\n  Completed in {analysis_time:.1f} seconds")
print(f"  Significant neurons: {len(significant_neurons3)}/{exp3.n_cells}")
print(f"  Total significant pairs: {total_pairs}")

### Ground truth validation

Compare detections to known selectivity. `validate_against_ground_truth`
reports overall sensitivity, precision, and F1, plus per-neuron-type statistics.

In [None]:
# Validate against ground truth using IntenseResults method
metrics = results3.validate_against_ground_truth(ground_truth, verbose=True)

### Disentanglement

Which multi-feature detections are redundant (one feature explains
the other) vs true mixed selectivity? Disentanglement removes
redundant pairs and improves precision. It uses conditional MI -- for each
neuron with multiple significant features, it tests whether
`I(neuron; F1 | F2) > 0` to determine if F1 contributes information
beyond F2.
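
The conditional-MI test can be illustrated on toy Gaussian data where the answer is known. In the sketch below, a simulated neuron is driven by `f1` only, while `f2` is merely correlated with `f1`. The conditional MI is computed from the partial correlation, which is valid for jointly Gaussian variables; the helper names are hypothetical and this is not the estimator DRIADA uses.

```python
import numpy as np

def residual(target, given):
    """Part of `target` not linearly explained by `given`."""
    slope = np.cov(target, given)[0, 1] / np.var(given, ddof=1)
    return (target - target.mean()) - slope * (given - given.mean())

def gaussian_cmi_bits(x, y, z):
    """I(X;Y|Z) in bits for jointly Gaussian data, via partial correlation."""
    rho = np.corrcoef(residual(x, z), residual(y, z))[0, 1]
    return float(-0.5 * np.log2(1.0 - rho ** 2))

demo_rng = np.random.default_rng(0)
f1 = demo_rng.normal(size=5000)                   # feature that drives the neuron
f2 = f1 + 0.5 * demo_rng.normal(size=5000)        # correlated bystander feature
neuron = f1 + 0.5 * demo_rng.normal(size=5000)

cmi_f1_given_f2 = gaussian_cmi_bits(neuron, f1, f2)
cmi_f2_given_f1 = gaussian_cmi_bits(neuron, f2, f1)
print(f"I(neuron; f1 | f2) = {cmi_f1_given_f2:.3f} bits")
print(f"I(neuron; f2 | f1) = {cmi_f2_given_f1:.3f} bits")
```

Both features would show significant MI with the neuron on their own, but only `f1` keeps information once the other feature is conditioned on, which marks the `f2` detection as redundant.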

In [None]:
# Apply disentanglement corrections and compute updated metrics
print("DISENTANGLEMENT ANALYSIS")
print("=" * 60)

if disent_results3 is not None:
    summary = disent_results3.get("summary", {})
    per_neuron_disent = disent_results3.get("per_neuron_disent", {})

    if "overall_stats" in summary:
        stats_d = summary["overall_stats"]
        print(f"\n  Neuron-feature pairs analyzed: {stats_d.get('total_neuron_pairs', 0)}")
        print(f"  Redundancy rate: {stats_d.get('redundancy_rate', 0):.1f}%")
        print(f"  True mixed selectivity rate: {stats_d.get('true_mixed_selectivity_rate', 0):.1f}%")

    # Build corrected significant_neurons using final_sels
    corrected = {}
    n_removed = 0
    for neuron_id, features in significant_neurons3.items():
        if neuron_id in per_neuron_disent:
            final = per_neuron_disent[neuron_id].get("final_sels", features)
            n_removed += len(features) - len(final)
            if final:
                corrected[neuron_id] = final
        else:
            corrected[neuron_id] = features

    # Compute corrected metrics against ground truth
    expected_pairs = set(ground_truth["expected_pairs"])
    tp, fp, fn = 0, 0, 0
    for neuron_id, features in corrected.items():
        for feat_name in features:
            if (neuron_id, feat_name) in expected_pairs:
                tp += 1
            else:
                fp += 1
    fn = len(expected_pairs) - tp

    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    f1 = 2 * precision * sensitivity / (precision + sensitivity) if (precision + sensitivity) > 0 else 0

    # Show before/after
    print(f"\n  Pairs removed by disentanglement: {n_removed}")
    print(f"\n  {'Metric':<15} {'Before':<12} {'After'}")
    print(f"  {'-'*40}")
    print(f"  {'Sensitivity':<15} {metrics['sensitivity']:>10.1%}   {sensitivity:>10.1%}")
    print(f"  {'Precision':<15} {metrics['precision']:>10.1%}   {precision:>10.1%}")
    print(f"  {'F1 Score':<15} {metrics['f1']:>10.1%}   {f1:>10.1%}")
    print(f"  {'False Pos':<15} {metrics['false_positives']:>10}   {fp:>10}")
else:
    print("  Disentanglement not performed.")

### Optimal delays

The optimal delay is the temporal offset that maximizes MI between neural
activity and behavior. Positive delays mean neural activity lags behind
behavior (expected for calcium imaging due to indicator dynamics).
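
A delay search of this kind can be sketched as a sweep over signed shifts between two series. The example below uses a Gaussian MI stand-in and a hypothetical `best_delay` helper; the sign convention (positive means the neural trace lags) follows the text above, and the toy data uses a pure 8-frame delay rather than realistic calcium dynamics.

```python
import numpy as np

def gaussian_mi_bits(a, b):
    """MI in bits under a bivariate Gaussian assumption (stand-in metric)."""
    rho = np.corrcoef(a, b)[0, 1]
    return -0.5 * np.log2(1.0 - rho ** 2)

def best_delay(neural, behavior, max_delay):
    """Delay in frames that maximizes MI; positive = neural lags behavior."""
    n_pts = len(neural)
    delays = np.arange(-max_delay, max_delay + 1)
    scores = []
    for d in delays:
        if d >= 0:   # pair behavior[t] with neural[t + d]
            scores.append(gaussian_mi_bits(behavior[:n_pts - d], neural[d:]))
        else:        # pair behavior[t] with neural[t - |d|]
            scores.append(gaussian_mi_bits(behavior[-d:], neural[:n_pts + d]))
    return int(delays[int(np.argmax(scores))])

demo_rng = np.random.default_rng(0)
demo_fps = 20
behavior = demo_rng.normal(size=4000)
neural = np.roll(behavior, 8) + 0.5 * demo_rng.normal(size=4000)  # 8-frame lag

delay_frames = best_delay(neural, behavior, max_delay=20)
delay_sec = delay_frames / demo_fps
print(f"optimal delay: {delay_frames} frames ({delay_sec:+.2f}s)")
```

Dividing by the sampling rate converts the frame offset to seconds, the same conversion the reporting code in this notebook applies.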

In [None]:
# Print optimal delays for significant neuron-feature pairs
print("OPTIMAL DELAYS")
print("=" * 60)

optimal_delays = info3.get("optimal_delays")
if optimal_delays is not None:
    fps = CONFIG["fps"]
    print(f"\n  Delay optimization compensates for calcium indicator dynamics.")
    print(f"  Positive delays = neural activity lags behavior (expected).")
    print(f"  Sampling rate: {fps} Hz")

    # Report delays for significant pairs, grouped by neuron type
    neuron_types = ground_truth.get("neuron_types", {})
    type_delays = {}

    for neuron_id, features in significant_neurons3.items():
        neuron_type = neuron_types.get(neuron_id, "unknown")
        if neuron_type not in type_delays:
            type_delays[neuron_type] = []

        for feat_name in features:
            if feat_name in feat_bunch:
                feat_idx = feat_bunch.index(feat_name)
                # Cast to int so the ':4d' format below works for any numeric dtype
                delay_frames = int(optimal_delays[neuron_id, feat_idx])
                delay_sec = delay_frames / fps
                type_delays[neuron_type].append((neuron_id, feat_name, delay_frames, delay_sec))

    print(f"\n  Optimal delays for significant pairs:")
    for neuron_type in sorted(type_delays.keys()):
        delays = type_delays[neuron_type]
        if delays:
            # Calculate mean delay for this type
            mean_delay_sec = np.mean([d[3] for d in delays])
            print(f"\n  {neuron_type} (mean: {mean_delay_sec:.2f}s):")
            for neuron_id, feat_name, delay_frames, delay_sec in delays:
                print(f"    Neuron {neuron_id:2d} -> {feat_name:15s}: {delay_frames:4d} frames ({delay_sec:+.2f}s)")
else:
    print("  No delay optimization performed.")

In [None]:
# Visualization: selectivity heatmap
print("CREATING VISUALIZATIONS")
print("=" * 60)

# Create figure with subplots
fig = plt.figure(figsize=(14, 10))

# 1. Selectivity heatmap (main plot)
ax1 = fig.add_subplot(2, 2, (1, 2))

feature_names = feat_bunch
n_neurons = exp3.n_cells
n_features = len(feature_names)

# Create MI matrix using 'me'
mi_matrix = np.zeros((n_neurons, n_features))
for neuron_id, features in significant_neurons3.items():
    for feat_name in features:
        if feat_name in feature_names:
            feat_idx = feature_names.index(feat_name)
            pair_stats = exp3.get_neuron_feature_pair_stats(neuron_id, feat_name)
            mi_matrix[neuron_id, feat_idx] = pair_stats.get("me", 0)

im = ax1.imshow(mi_matrix, aspect="auto", cmap="viridis")
ax1.set_xlabel("Features")
ax1.set_ylabel("Neurons")
ax1.set_title("INTENSE selectivity heatmap (MI values)")
ax1.set_xticks(range(n_features))
ax1.set_xticklabels(feature_names, rotation=45, ha="right")

# Add colorbar
cbar = plt.colorbar(im, ax=ax1, shrink=0.8)
cbar.set_label("Mutual information (bits)")

# Add neuron type annotations
type_colors = {
    "hd_cells": "red",
    "place_cells": "blue",
    "speed_cells": "green",
    "event_cells": "orange",
    "mixed_cells": "purple",
    "nonselective": "gray",
}
for neuron_id, neuron_type in ground_truth["neuron_types"].items():
    color = type_colors.get(neuron_type, "gray")
    ax1.scatter(-0.7, neuron_id, c=color, s=20, marker="s")

# 2. Detection rates by type
ax2 = fig.add_subplot(2, 2, 3)
types = list(metrics["type_stats"].keys())
sensitivities = [metrics["type_stats"][t]["sensitivity"] * 100 for t in types]
colors = [type_colors.get(t, "gray") for t in types]

bars = ax2.bar(range(len(types)), sensitivities, color=colors)
ax2.set_xticks(range(len(types)))
ax2.set_xticklabels([t.replace("_", "\n") for t in types], fontsize=8)
ax2.set_ylabel("Detection rate (%)")
ax2.set_title("Detection rate by neuron type")
ax2.set_ylim(0, 105)
ax2.axhline(y=100, color="k", linestyle="--", alpha=0.3)

# Add percentage labels
for bar, pct in zip(bars, sensitivities):
    ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 2,
            f"{pct:.0f}%", ha="center", va="bottom", fontsize=9)

# 3. Summary statistics (raw detections, before disentanglement)
ax3 = fig.add_subplot(2, 2, 4)
ax3.axis("off")

summary_text = (
    f"VALIDATION SUMMARY\n"
    f"{'=' * 30}\n\n"
    f"{'Metric':<12} {'Raw':>8}\n"
    f"{'-' * 30}\n"
    f"{'Sensitivity':<12} {metrics['sensitivity']:>7.1%}\n"
    f"{'Precision':<12} {metrics['precision']:>7.1%}\n"
    f"{'F1 Score':<12} {metrics['f1']:>7.1%}\n\n"
    f"Detection counts:\n"
    f"  True Positives:  {metrics['true_positives']}\n"
    f"  False Positives: {metrics['false_positives']}\n"
    f"  False Negatives: {metrics['false_negatives']}\n\n"
    f"Population:\n"
    f"  Neurons: {exp3.n_cells}, Features: {len(exp3.dynamic_features)}\n"
    f"  Expected pairs: {len(ground_truth['expected_pairs'])}\n"
)
ax3.text(0.05, 0.95, summary_text, transform=ax3.transAxes,
        fontfamily="monospace", fontsize=9, verticalalignment="top")

plt.tight_layout()
plt.show()

### Save and load results

Persist INTENSE results to disk with [`save_results`](https://driada.readthedocs.io/en/latest/api/intense/base.html) and reload them
with [`load_results`](https://driada.readthedocs.io/en/latest/api/intense/base.html) for later analysis.

In [None]:
# Save/load round-trip
with tempfile.TemporaryDirectory() as tmpdir:
    results_path = os.path.join(tmpdir, "intense_results.npz")
    save_results(results3, results_path)
    file_mb = os.path.getsize(results_path) / 1024 / 1024
    print(f"  Saved results: {results_path} ({file_mb:.1f} MB)")

    loaded = load_results(results_path)
    print(f"  Reloaded: {len(loaded.stats)} neurons")
    print(f"  Stats keys match: {set(str(k) for k in results3.stats.keys()) == set(loaded.stats.keys())}")

## 4. Feature-feature relations

Before disentangling neuron selectivity, check which behavioral variables
are themselves correlated. [`compute_feat_feat_significance`](https://driada.readthedocs.io/en/latest/api/intense/pipelines.html) tests all
feature pairs with FFT-based circular shuffles.
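
One reason FFTs help with circular shuffles is that the cross-correlation theorem yields a correlation-type statistic at every circular shift in a single O(n log n) pass. The sketch below demonstrates that identity with NumPy and checks it against a brute-force loop. Whether DRIADA uses exactly this formulation is an assumption; the function names are hypothetical.

```python
import numpy as np

def circular_xcorr_fft(a, b):
    """c[k] = sum_t a[t] * b[(t + k) % n] for every shift k, in O(n log n)."""
    n_pts = len(a)
    return np.fft.irfft(np.conj(np.fft.rfft(a)) * np.fft.rfft(b), n=n_pts)

def circular_xcorr_naive(a, b):
    """Same quantity computed one shift at a time, in O(n^2)."""
    return np.array([np.dot(a, np.roll(b, -k)) for k in range(len(a))])

demo_rng = np.random.default_rng(0)
a_demo = demo_rng.normal(size=512)
b_demo = np.roll(a_demo, 40) + 0.5 * demo_rng.normal(size=512)  # a delayed by 40

fast = circular_xcorr_fft(a_demo, b_demo)
slow = circular_xcorr_naive(a_demo, b_demo)
print(f"max abs difference: {np.max(np.abs(fast - slow)):.2e}")
print(f"peak shift: {int(np.argmax(fast))}")
```

Since a Gaussian MI is a monotonic function of the squared correlation, having all shifted correlations at once is enough to evaluate a whole shuffle distribution for such metrics without looping over shifts.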

In [None]:
# Generate a fresh experiment for feature-feature analysis
POPULATION_FF = [
    {"name": "hd_cells", "count": 4, "features": ["head_direction"]},
    {"name": "place_cells", "count": 4, "features": ["position_2d"]},
    {"name": "speed_cells", "count": 4, "features": ["speed"]},
    {"name": "event_cells", "count": 4, "features": ["event_0"]},
    {"name": "nonselective", "count": 4, "features": []},
]

fps_ff = 20

print("[1] Generating synthetic experiment")
print("-" * 40)
exp4 = generate_tuned_selectivity_exp(
    population=POPULATION_FF,
    duration=600,
    fps=fps_ff,
    n_discrete_features=2,
    seed=42,
    verbose=True,
)
print(f"  Features: {list(exp4.dynamic_features.keys())}")

# Add derived feature with known correlation to speed
print("\n[2] Adding derived feature")
print("-" * 40)
speed_data = exp4.dynamic_features["speed"].data

# 1-second moving average of speed.
# Preserves enough variance for significant MI with the raw signal.
kernel_size = int(1 * fps_ff)
kernel = np.ones(kernel_size) / kernel_size
smoothed = np.convolve(speed_data, kernel, mode="same")
exp4.dynamic_features["speed_smoothed"] = TimeSeries(
    smoothed, ts_type="linear", name="speed_smoothed"
)
print(f"  Added speed_smoothed (1-second moving average of speed)")

In [None]:
# Compute feat-feat significance.
# Use feat_bunch to select features explicitly.
# Exclude raw head_direction -- the pipeline uses head_direction_2d
# (cos/sin encoding) which preserves circular topology.
print("[3] Computing feature-feature significance")
print("-" * 40)
features_to_test = [
    "head_direction_2d", "speed", "position_2d",
    "event_0", "event_1", "speed_smoothed",
]
print(f"  Testing: {features_to_test}")
print(f"  (head_direction excluded -- use head_direction_2d for circular data)")

sim_mat, sig_mat, pval_mat, feature_names_ff, info_ff = compute_feat_feat_significance(
    exp4,
    feat_bunch=features_to_test,
    n_shuffles_stage1=100,
    n_shuffles_stage2=1000,
    pval_thr=0.01,
    verbose=True,
)

In [None]:
# Display results
print("[4] Results summary")
print("-" * 40)

display_names = []
for name in feature_names_ff:
    if isinstance(name, (list, tuple)):
        display_names.append(", ".join(str(n) for n in name))
    else:
        display_names.append(str(name))

n_ff = len(feature_names_ff)
print(f"\n  Features analyzed: {n_ff}")
print(f"  Feature names: {display_names}")

print(f"\n  Significant pairs:")
n_sig = 0
for i in range(n_ff):
    for j in range(i + 1, n_ff):
        if sig_mat[i, j]:
            n_sig += 1
            print(
                f"    {display_names[i]:20s} <-> {display_names[j]:20s}  "
                f"MI={sim_mat[i, j]:.4f}  p={pval_mat[i, j]:.2e}"
            )
if n_sig == 0:
    print("    (none)")
print(f"\n  Total significant pairs: {n_sig}/{n_ff * (n_ff - 1) // 2}")

# Create MI heatmap
fig, ax = plt.subplots(figsize=(8, 7))

plot_mat = sim_mat.copy().astype(float)
np.fill_diagonal(plot_mat, np.nan)

im = ax.imshow(plot_mat, cmap="Blues", aspect="equal")
cbar = plt.colorbar(im, ax=ax, shrink=0.8)
cbar.set_label("Mutual information (bits)")

for i in range(n_ff):
    for j in range(n_ff):
        if i != j and sig_mat[i, j]:
            ax.text(j, i, "*", ha="center", va="center",
                    fontsize=14, fontweight="bold", color="red")

for i in range(n_ff):
    ax.add_patch(plt.Rectangle((i - 0.5, i - 0.5), 1, 1,
                               fill=True, facecolor="0.85", edgecolor="none"))

ax.set_xticks(range(n_ff))
ax.set_xticklabels(display_names, rotation=45, ha="right", fontsize=8)
ax.set_yticks(range(n_ff))
ax.set_yticklabels(display_names, fontsize=8)
ax.set_title("Feature-feature mutual information (* = significant)")

plt.tight_layout()
plt.show()

## 5. Mixed selectivity & disentanglement

When a neuron responds to several correlated features, is that true mixed
selectivity or a redundant detection? Here we generate a 30-neuron population with
explicit mixed selectivity groups and apply disentanglement.

In [None]:
# Generate synthetic data with known mixed selectivity patterns
print("=== GENERATING MIXED SELECTIVITY DATA ===")

# Define population with real features for mixed selectivity demonstration
population5 = [
    {"name": "hd_cells", "count": 5, "features": ["head_direction"]},
    {"name": "speed_cells", "count": 5, "features": ["speed"]},
    {"name": "event_cells", "count": 5, "features": ["event_0"]},
    {"name": "mixed_hd_speed", "count": 5, "features": ["head_direction", "speed"], "combination": "weighted_sum"},
    {"name": "mixed_hd_event", "count": 5, "features": ["head_direction", "event_0"], "combination": "weighted_sum"},
    {"name": "mixed_speed_event", "count": 5, "features": ["speed", "event_0"], "combination": "weighted_sum"},
]

exp5 = generate_tuned_selectivity_exp(
    population=population5,
    duration=900,
    fps=20,
    seed=42,
    n_discrete_features=1,
    baseline_rate=0.1,
    peak_rate=2.0,
    decay_time=2.0,
    calcium_noise=0.05,
    verbose=False,
)

print(
    f"Generated experiment: {exp5.n_cells} neurons, {len(exp5.dynamic_features)} features, {exp5.n_frames/exp5.fps:.1f}s recording"
)

In [None]:
# Run INTENSE analysis with disentanglement
print("=== RUNNING INTENSE ANALYSIS ===")

# Run comprehensive analysis with disentanglement
results5 = driada.compute_cell_feat_significance(
    exp5,
    mode="two_stage",
    n_shuffles_stage1=100,  # Stage 1 screening
    n_shuffles_stage2=10000,  # FFT optimization makes high counts fast
    verbose=False,
    with_disentanglement=True,  # Enable disentanglement analysis
    multifeature_map=driada.intense.DEFAULT_MULTIFEATURE_MAP,
    # Uses default gamma_zi distribution (better for MI null distribution)
    pval_thr=0.05,  # Slightly less conservative threshold
)

stats5, significance5, info5, intense_results5, disentanglement_results5 = results5

# Extract significant relationships
significant_neurons5 = exp5.get_significant_neurons()

# Also get neurons with mixed selectivity (at least 2 features)
mixed_candidates5 = exp5.get_significant_neurons(min_nspec=2)

# Count multifeature relationships
multifeature_count = 0
for cell_id, features in significant_neurons5.items():
    for feat in features:
        if feat in exp5.dynamic_features and isinstance(
            exp5.dynamic_features[feat], MultiTimeSeries
        ):
            multifeature_count += 1

print(
    f"Found {len(significant_neurons5)} significant neurons, "
    f"{len(mixed_candidates5)} with 2+ significant features (mixed selectivity candidates), "
    f"{multifeature_count} multifeature relationships"
)

### Disentanglement analysis

Extract and interpret disentanglement results from the pipeline.
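
As a rough illustration of the bookkeeping, the toy sketch below turns a pair of count matrices into per-pair percentages. The matrices and their values are invented, and the reading that entry `[i, j]` of the first matrix counts the neurons for which feature `i` was judged primary over feature `j` is an assumption made for this example, not a statement about the pipeline's output format.

```python
import numpy as np

# Hypothetical toy data: three features, values invented for illustration
toy_names = ["head_direction", "speed", "event_0"]
toy_count = np.array([[0, 10, 4],
                      [10, 0, 0],
                      [4, 0, 0]])   # neurons tested per feature pair
toy_disent = np.array([[0, 8, 2],
                       [2, 0, 0],
                       [2, 0, 0]])  # assumed: times the row feature was primary

# Percentage of tested neurons where the row feature was primary;
# pairs without data stay NaN instead of triggering a division by zero
toy_rel = np.full(toy_count.shape, np.nan)
np.divide(toy_disent, toy_count, out=toy_rel, where=toy_count > 0)
toy_rel *= 100

for i in range(len(toy_names)):
    for j in range(i + 1, len(toy_names)):
        if np.isnan(toy_rel[i, j]):
            print(f"{toy_names[i]} vs {toy_names[j]}: no data")
        else:
            print(f"{toy_names[i]} vs {toy_names[j]}: {toy_rel[i, j]:.0f}% primary")
```

Using `where=` with a NaN-filled `out` array keeps untested pairs explicit, so they cannot be mistaken for a 0% score.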

In [None]:
# Process disentanglement results from the pipeline
print("=== DISENTANGLEMENT ANALYSIS ===")

if mixed_candidates5 and disentanglement_results5 is not None:
    # Extract results from the pipeline
    disent_matrix5 = disentanglement_results5.get("disent_matrix")
    count_matrix5 = disentanglement_results5.get("count_matrix")
    feat_names5 = disentanglement_results5.get("feature_names", [])

    if disent_matrix5 is not None and count_matrix5 is not None:
        print("Disentanglement analysis completed by pipeline")
        print(
            f"Matrix shape: {disent_matrix5.shape}, Non-zero entries: {np.count_nonzero(count_matrix5)}"
        )
        print(f"Feature names analyzed: {feat_names5}")

        # Show summary if available
        if "summary" in disentanglement_results5:
            summary5 = disentanglement_results5["summary"]
            if "overall_stats" in summary5:
                stats5s = summary5["overall_stats"]
                print("\nOverall statistics:")
                print(f"  Total neuron pairs: {stats5s.get('total_neuron_pairs', 0)}")
                print(f"  Redundancy rate: {stats5s.get('redundancy_rate', 0):.1f}%")
                print(
                    f"  True mixed selectivity rate: {stats5s.get('true_mixed_selectivity_rate', 0):.1f}%"
                )
    else:
        print("Disentanglement matrices not found in results.")
        disent_matrix5, count_matrix5, feat_names5 = None, None, None
else:
    print("No mixed selectivity candidates or disentanglement results.")
    disent_matrix5, count_matrix5, feat_names5 = None, None, None

In [None]:
# Interpret: redundancy vs independence vs synergy
print("=== INTERPRETING DISENTANGLEMENT RESULTS ===")

if disent_matrix5 is not None and count_matrix5 is not None:
    redundancy_cases = []
    synergy_cases = []
    independence_cases = []

    # Calculate relative disentanglement matrix
    with np.errstate(divide="ignore", invalid="ignore"):
        rel_disent_matrix = np.divide(disent_matrix5, count_matrix5) * 100
        rel_disent_matrix[count_matrix5 == 0] = np.nan

    # Extract disentanglement cases based on matrix values
    for i in range(len(feat_names5)):
        for j in range(
            i + 1, len(feat_names5)
        ):  # Only upper triangle to avoid duplicates
            if count_matrix5[i, j] > 0:  # Only consider pairs with data
                feat1 = feat_names5[i]
                feat2 = feat_names5[j]

                # Get disentanglement score (percentage)
                disent_score = rel_disent_matrix[i, j]

                if not np.isnan(disent_score):
                    # Classify based on disentanglement score
                    if disent_score < 30:  # Redundancy: feat2 dominates
                        redundancy_cases.append(
                            (f"{feat1}-{feat2}", (feat1, feat2), disent_score / 100)
                        )
                    elif disent_score > 70:  # Synergy: feat1 dominates
                        synergy_cases.append(
                            (f"{feat1}-{feat2}", (feat1, feat2), disent_score / 100)
                        )
                    else:  # Independence: balanced
                        independence_cases.append(
                            (f"{feat1}-{feat2}", (feat1, feat2), disent_score / 100)
                        )

    # Summary statistics
    total_pairs = len(redundancy_cases) + len(synergy_cases) + len(independence_cases)
    print(
        f"Classified {total_pairs} feature pairs: {len(redundancy_cases)} redundancy, "
        f"{len(independence_cases)} independence, {len(synergy_cases)} synergy"
    )

    # Show one example per category if available
    if redundancy_cases:
        feat1, feat2 = redundancy_cases[0][1]
        print(f"Example redundancy: {feat1} <-> {feat2}")
    if synergy_cases:
        feat1, feat2 = synergy_cases[0][1]
        print(f"Example synergy: {feat1} + {feat2}")
else:
    print("No disentanglement data to interpret.")

In [None]:
# Selectivity heatmap visualization
print("Creating neuron-feature selectivity heatmap...")

try:
    fig_select, ax_select, stats_select = driada.intense.plot_selectivity_heatmap(
        exp5, significant_neurons5, metric="mi", use_log_scale=False, figsize=(12, 10)
    )

    print(
        f"  - {stats_select['n_selective']} selective neurons ({stats_select['selectivity_rate']:.1f}%)"
    )
    print(f"  - {stats_select['n_pairs']} neuron-feature pairs")
    plt.show()
except Exception as e:
    print(f"Error creating selectivity heatmap: {e}")

In [None]:
# Disentanglement heatmap
if disent_matrix5 is not None and count_matrix5 is not None:
    try:
        print("Creating disentanglement heatmap...")
        fig_disent, ax_disent = driada.intense.plot_disentanglement_heatmap(
            disent_matrix5,
            count_matrix5,
            feat_names5,
            title="Feature disentanglement analysis",
            figsize=(10, 8),
        )
        plt.show()
    except Exception as e:
        print(f"Error creating disentanglement heatmap: {e}")
else:
    print("No disentanglement results to visualize")