# Advanced capabilities

The previous notebooks introduced the building blocks separately:
selectivity (02), manifolds (03), and networks (04).  This notebook
shows how they combine -- and how
[**DRIADA**](https://driada.readthedocs.io) extends beyond calcium
imaging to any system that produces time-varying population activity.

| Step | Notebook | What it does |
|---|---|---|
| Load & inspect | [01 -- Data loading](https://colab.research.google.com/github/iabs-neuro/driada/blob/main/notebooks/01_data_loading_and_neurons.ipynb) | Wrap your recording into an `Experiment`, reconstruct spikes, assess quality |
| Single-neuron selectivity | [02 -- INTENSE](https://colab.research.google.com/github/iabs-neuro/driada/blob/main/notebooks/02_selectivity_detection_intense.ipynb) | Detect which neurons encode which behavioral variables |
| Population geometry | [03 -- Dimensionality reduction](https://colab.research.google.com/github/iabs-neuro/driada/blob/main/notebooks/03_population_geometry_dr.ipynb) | Extract low-dimensional manifolds from population activity |
| Network analysis | [04 -- Networks](https://colab.research.google.com/github/iabs-neuro/driada/blob/main/notebooks/04_network_analysis.ipynb) | Build and analyze cell-cell interaction graphs |
| **Putting it together** | **05 -- this notebook** | Combine INTENSE + DR, leave-one-out importance, RSA, RNN analysis |

**Sections:**

1. **Embedding selectivity (DR -> INTENSE)** -- Reverse the usual
   direction: run INTENSE on embedding components to discover what each
   DR dimension encodes. Identifies functional clusters in the
   population geometry.
2. **Leave-one-out neuron importance** -- Remove each neuron and
   measure manifold degradation. Validates whether neurons important
   for the embedding are the same ones identified by INTENSE.
3. **Representational similarity analysis (RSA)** -- Compare
   population-level representations across regions, sessions, or
   conditions using Representational Dissimilarity Matrices (RDMs).
4. **Beyond calcium: DRIADA on RNN activations** -- Full pipeline on
   simulated RNN units: INTENSE + DR + network analysis.

In [None]:
# TODO: revert to '!pip install -q driada' after v1.0.0 PyPI release
!pip install -q git+https://github.com/iabs-neuro/driada.git@main
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import warnings

warnings.filterwarnings('ignore', category=UserWarning)

## 1. Embedding selectivity (DR -> INTENSE)

Notebooks 2--3 used INTENSE to select neurons for DR. Here we reverse
the direction: create an embedding first (PCA, UMAP), then treat each
component as a *feature* and run INTENSE on it. This answers: **which
neurons drive which embedding dimensions?**

Key APIs:
- [`compute_embedding_selectivity`](https://driada.readthedocs.io/en/latest/api/intense/pipelines.html#driada.intense.pipelines.compute_embedding_selectivity) -- INTENSE on embedding components.
  Internally, each embedding component is added as a temporary dynamic
  feature and tested via standard INTENSE, so parameters like `n_shuffles`
  and `pval_thr` pass through to [`compute_cell_feat_significance`](https://driada.readthedocs.io/en/latest/api/intense/pipelines.html#driada.intense.pipelines.compute_cell_feat_significance).
- [`get_functional_organization`](https://driada.readthedocs.io/en/latest/api/integration.html#driada.integration.manifold_analysis.get_functional_organization) -- cluster and participation analysis
- [`compare_embeddings`](https://driada.readthedocs.io/en/latest/api/integration.html#driada.integration.manifold_analysis.compare_embeddings) -- cross-method comparison
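The mechanism described above can be sketched without DRIADA: each embedding component becomes a feature, and each neuron-component pair gets a shuffle test. This is a minimal illustration under simplifying assumptions, not the INTENSE implementation. PCA is computed directly with an SVD, absolute Pearson correlation stands in for mutual information, and a single circular-shift null replaces INTENSE's two-stage procedure. All function names are hypothetical.

```python
import numpy as np


def pca_components(data, n_components):
    """Top principal-component time series of (n_neurons, n_frames) data.

    Returns an (n_frames, n_components) array.
    """
    centered = data - data.mean(axis=1, keepdims=True)
    # SVD of the (n_frames, n_neurons) matrix: columns of u * s are PC scores
    u, s, _ = np.linalg.svd(centered.T, full_matrices=False)
    return u[:, :n_components] * s[:n_components]


def shuffle_pvalue(x, y, n_shuffles=200, min_shift=100, rng=None):
    """Circular-shift test for |Pearson r| between two time series.

    Shifts closer than `min_shift` frames to zero are excluded so that
    autocorrelation alone cannot reproduce the observed dependence.
    """
    rng = np.random.default_rng(rng)
    observed = abs(np.corrcoef(x, y)[0, 1])
    shifts = rng.integers(min_shift, len(x) - min_shift, size=n_shuffles)
    null = np.array([abs(np.corrcoef(np.roll(x, k), y)[0, 1]) for k in shifts])
    # +1 correction: the observed value counts as one member of the null
    return (1 + np.sum(null >= observed)) / (1 + n_shuffles)


def component_selectivity(data, n_components=2, pval_thr=0.01, seed=0):
    """Map each component index to the neurons significantly related to it."""
    comps = pca_components(data, n_components)
    selective = {c: [] for c in range(n_components)}
    for c in range(n_components):
        for i, trace in enumerate(data):
            if shuffle_pvalue(trace, comps[:, c], rng=seed + i) < pval_thr:
                selective[c].append(i)
    return selective


# Toy data: neurons 0-4 follow a strong latent signal, 5-9 a weaker one,
# and 10-14 are pure noise
rng = np.random.default_rng(0)
n_frames = 2000
kernel = np.ones(20) / 20


def smooth_signal():
    """Standardized, temporally smoothed Gaussian noise."""
    x = np.convolve(rng.standard_normal(n_frames), kernel, mode='same')
    return (x - x.mean()) / x.std()


latent_a, latent_b = smooth_signal(), smooth_signal()
data = np.vstack([
    3.0 * latent_a + 0.3 * rng.standard_normal((5, n_frames)),
    1.0 * latent_b + 0.3 * rng.standard_normal((5, n_frames)),
    0.3 * rng.standard_normal((5, n_frames)),
])

selective = component_selectivity(data, n_components=2)
for comp, neurons in selective.items():
    print(f'component {comp}: neurons {neurons}')
```

With these settings, neurons 0-4 should land on component 0 (the stronger latent) and neurons 5-9 on component 1. Noise neurons can occasionally pass by chance at this threshold, which is the kind of false positive that more shuffles and stricter thresholds help control.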

The synthetic data below is created with
[`generate_tuned_selectivity_exp`](https://driada.readthedocs.io/en/latest/api/experiment/synthetic.html#driada.experiment.synthetic.generators.generate_tuned_selectivity_exp).

In [None]:
from driada.experiment.synthetic import generate_tuned_selectivity_exp
from driada.intense import compute_embedding_selectivity
from driada.integration import get_functional_organization, compare_embeddings
from driada.utils.visual import plot_embedding_comparison

print('1. Generating synthetic population...')

population = [
    {'name': 'hd_cells', 'count': 12,
     'features': ['head_direction']},
    {'name': 'place_cells', 'count': 10,
     'features': ['position_2d']},
    {'name': 'speed_cells', 'count': 8,
     'features': ['speed']},
    {'name': 'conjunctive', 'count': 5,
     'features': ['head_direction', 'speed'], 'combination': 'and'},
    {'name': 'non_selective', 'count': 15,
     'features': []},
]

exp_emb = generate_tuned_selectivity_exp(
    population, duration=300, fps=20, seed=42, verbose=True
)

gt_groups = {}
idx = 0
for group in population:
    for _ in range(group['count']):
        gt_groups[idx] = group['name']
        idx += 1

print(f'  {exp_emb.n_cells} neurons, {exp_emb.n_frames} frames')
for group in population:
    print(f'    {group["name"]:20s}: {group["count"]} neurons')

In [None]:
fig, axes = plt.subplots(2, 1, figsize=(14, 5))

n_show = min(5, exp_emb.n_cells)
time_sec = np.arange(exp_emb.n_frames) / exp_emb.fps

ax = axes[0]
for i in range(n_show):
    ax.plot(time_sec, exp_emb.calcium.data[i], linewidth=0.6, label=f'neuron {i}')
ax.set_xlabel('Time (s)')
ax.set_ylabel('dF/F0')
ax.set_title(f'Synthetic neural traces ({exp_emb.n_cells} neurons)')
ax.legend(loc='upper right', fontsize=8)
ax.grid(True, alpha=0.3)

ax = axes[1]
ax.imshow(exp_emb.calcium.data, aspect='auto', cmap='hot', interpolation='none')
ax.set_xlabel('Frame')
ax.set_ylabel('Neuron')

plt.tight_layout()
plt.show()

In [None]:
print('\n2. Creating embeddings...')

n_pca_components = 4
pca_emb = exp_emb.create_embedding('pca', n_components=n_pca_components, ds=5)
print(f'  PCA: {pca_emb.shape} (n_frames, n_components)')

umap_emb = exp_emb.create_embedding(
    'umap', n_components=3, n_neighbors=50, min_dist=0.8, random_state=42, ds=5
)
print(f'  UMAP: {umap_emb.shape}')

In [None]:
print('\n3. Computing embedding selectivity (INTENSE on components)...')

results_emb = compute_embedding_selectivity(
    exp_emb,
    embedding_methods=['pca', 'umap'],
    mode='two_stage',
    n_shuffles_stage1=50,
    n_shuffles_stage2=1000,
    find_optimal_delays=False,
    pval_thr=0.01,
    ds=5,
    verbose=True,
    seed=42,
)

n_total = exp_emb.n_cells
for method in ['pca', 'umap']:
    r = results_emb[method]
    n_sig = len(r['significant_neurons'])
    print(f'\n  {method.upper()} summary:')
    print(f'    {n_sig}/{n_total} neurons significantly selective '
          f'({100 * n_sig / n_total:.0f}%)')
    for comp_idx in range(r['n_components']):
        n_sel = len(r['component_selectivity'][comp_idx])
        if n_sel > 0:
            print(f'    component {comp_idx}: {n_sel} selective neurons')

In [None]:
print('\n4. Functional organization (PCA)...')

org = get_functional_organization(
    exp_emb, 'pca', intense_results=results_emb['pca']['intense_results']
)

print('\n  Component importance (variance explained):')
for i, imp in enumerate(org['component_importance']):
    print(f'    component {i}: {imp:.3f}')

print(f'\n  Participating neurons: {org["n_participating_neurons"]}/{n_total}')
print(f'  Mean components per neuron: {org["mean_components_per_neuron"]:.2f}')

print('\n  Component specialization:')
for comp_idx, spec in org['component_specialization'].items():
    n_sel = spec['n_selective_neurons']
    rate = spec['selectivity_rate']
    if n_sel > 0:
        group_counts = {}
        for nid in spec['selective_neurons']:
            g = gt_groups.get(nid, 'unknown')
            group_counts[g] = group_counts.get(g, 0) + 1
        groups_str = ', '.join(f'{g}={c}' for g, c in sorted(group_counts.items()))
        print(f'    component {comp_idx}: {n_sel} neurons ({rate:.0%}) -- {groups_str}')

print(f'\n  Functional clusters: {len(org["functional_clusters"])}')
for i, cluster in enumerate(org['functional_clusters']):
    comps = cluster['components']
    size = cluster['size']
    group_counts = {}
    for nid in cluster['neurons']:
        g = gt_groups.get(nid, 'unknown')
        group_counts[g] = group_counts.get(g, 0) + 1
    groups_str = ', '.join(f'{g}={c}' for g, c in sorted(group_counts.items()))
    print(f'    cluster {i}: components {comps}, {size} neurons -- {groups_str}')

In [None]:
print('\n5. Comparing PCA vs UMAP functional organization...')

intense_dict = {
    m: results_emb[m]['intense_results'] for m in ['pca', 'umap']
}
comparison = compare_embeddings(
    exp_emb, ['pca', 'umap'], intense_results_dict=intense_dict
)

for method in comparison['methods']:
    n_part = comparison['n_participating_neurons'][method]
    mean_comp = comparison['mean_components_per_neuron'][method]
    n_clust = comparison['n_functional_clusters'][method]
    print(f'  {method.upper():6s}: {n_part} participating neurons, '
          f'{mean_comp:.2f} mean components, {n_clust} clusters')

if 'participation_overlap' in comparison:
    for pair, overlap in comparison['participation_overlap'].items():
        print(f'  Participation overlap ({pair}): {overlap:.2f}')

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle('Functional organization analysis', fontsize=14)

# (a) Component importance
ax = axes[0, 0]
comp_imp = org['component_importance']
ax.bar(range(len(comp_imp)), comp_imp, color='steelblue', edgecolor='white')
ax.set_xlabel('PCA component')
ax.set_ylabel('Variance explained (fraction)')
ax.set_title('Component importance')
ax.set_xticks(range(len(comp_imp)))

# (b) Component specialization by neuron group
ax = axes[0, 1]
group_names = [g['name'] for g in population]
group_colors = plt.cm.Set2(np.linspace(0, 1, len(group_names)))
color_map = dict(zip(group_names, group_colors))

comp_indices = sorted(org['component_specialization'].keys())
bottom = np.zeros(len(comp_indices))
for gname in group_names:
    counts = []
    for comp_idx in comp_indices:
        spec = org['component_specialization'][comp_idx]
        c = sum(1 for nid in spec['selective_neurons']
                if gt_groups.get(nid) == gname)
        counts.append(c)
    ax.bar(comp_indices, counts, bottom=bottom,
           label=gname, color=color_map[gname], edgecolor='white')
    bottom += counts
ax.set_xlabel('PCA component')
ax.set_ylabel('Selective neurons')
ax.set_title('Component specialization by group')
ax.legend(fontsize=7, loc='upper right')
ax.set_xticks(comp_indices)

# (c) Neuron participation histogram
ax = axes[1, 0]
participation = org.get('neuron_participation', {})
if participation:
    n_comps_per_neuron = [len(comps) for comps in participation.values()]
    max_comps = max(n_comps_per_neuron) if n_comps_per_neuron else 1
    bins = np.arange(0.5, max_comps + 1.5, 1)
    ax.hist(n_comps_per_neuron, bins=bins, color='steelblue',
            edgecolor='white', rwidth=0.8)
    ax.set_xlabel('Number of components')
    ax.set_ylabel('Number of neurons')
    ax.set_title('Neuron participation distribution')
    ax.set_xticks(range(1, max_comps + 1))
else:
    ax.text(0.5, 0.5, 'No participating neurons',
            ha='center', va='center', transform=ax.transAxes)
    ax.set_title('Neuron participation distribution')

# (d) PCA vs UMAP comparison
ax = axes[1, 1]
methods = comparison['methods']
n_parts = [comparison['n_participating_neurons'][m] for m in methods]
n_clusts = [comparison['n_functional_clusters'][m] for m in methods]
x = np.arange(len(methods))
w = 0.35
ax.bar(x - w / 2, n_parts, w, label='Participating neurons',
       color='steelblue', edgecolor='white')
ax.bar(x + w / 2, n_clusts, w, label='Functional clusters',
       color='coral', edgecolor='white')
ax.set_xticks(x)
ax.set_xticklabels([m.upper() for m in methods])
ax.set_ylabel('Count')
ax.set_title('PCA vs UMAP comparison')
ax.legend(fontsize=8)

if 'participation_overlap' in comparison:
    for pair, overlap in comparison['participation_overlap'].items():
        ax.annotate(f'Overlap: {overlap:.2f}',
                    xy=(0.5, 0.95), xycoords='axes fraction',
                    ha='center', fontsize=9, color='gray')

plt.tight_layout()
plt.show()

In [None]:
ds_emb = 5  # must match ds used in create_embedding
hd = exp_emb.dynamic_features['head_direction'].data[::ds_emb]
spd = exp_emb.dynamic_features['speed'].data[::ds_emb]

fig_cmp = plot_embedding_comparison(
    embeddings={'PCA': pca_emb[:, :2], 'UMAP': umap_emb[:, :2]},
    features={'head_direction': hd, 'speed': spd},
    with_trajectory=False,
    compute_metrics=False,
    scatter_size=8,
)
plt.show()

## 2. Leave-one-out neuron importance

Remove each neuron from the population and measure how much the manifold
degrades. If INTENSE and DR agree, the neurons with the highest INTENSE
mutual information (MI) for head direction should also be the ones whose
removal degrades the embedding most; checking this validates INTENSE
results from a DR perspective.
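For reference, the LOO code cell in this section scores neuron $i$ by the mean relative degradation of the three ground-truth metrics described next. Here $e$, $a$ and $d$ are the baseline reconstruction error, alignment correlation and decoding accuracy, and the subscript $-i$ marks the same metric recomputed without neuron $i$:

$$
I_i = \frac{1}{3}\left[\frac{e_{-i} - e}{e} + \frac{a - a_{-i}}{a} + \frac{d - d_{-i}}{d}\right]
$$

Positive $I_i$ means that removing the neuron makes the manifold worse. The code adds $10^{-10}$ to each denominator to avoid division by zero.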

The analysis uses ground truth reconstruction metrics: alignment
correlation, decoding accuracy, and reconstruction error (comparing the
embedding against known head direction angles). INTENSE selectivity is
computed with
[`compute_cell_feat_significance`](https://driada.readthedocs.io/en/latest/api/intense/pipelines.html#driada.intense.pipelines.compute_cell_feat_significance)
for comparison.
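To make the LOO logic concrete, here is a self-contained toy version. It is a sketch under stated assumptions rather than DRIADA's code: a 2-D PCA embedding stands in for Isomap, the embedding is read out as an angle with `arctan2`, and a single hypothetical metric (`circular_alignment`, the mean resultant length of the angular error, allowing for a global rotation or reflection) replaces the three DRIADA metrics. Importance is the relative drop of that score when a neuron is left out.

```python
import numpy as np


def circular_alignment(true_angles, est_angles):
    """Mean resultant length of the angular error, allowing for a global
    rotation and reflection of the embedding (1 = perfect, ~0 = unrelated)."""
    direct = np.abs(np.mean(np.exp(1j * (est_angles - true_angles))))
    reflected = np.abs(np.mean(np.exp(1j * (est_angles + true_angles))))
    return max(direct, reflected)


def embedding_angle(data):
    """2-D PCA embedding of (n_neurons, n_frames) data, read out as an angle.
    PCA stands in here for the Isomap embedding used in the notebook."""
    centered = data - data.mean(axis=1, keepdims=True)
    u, s, _ = np.linalg.svd(centered.T, full_matrices=False)
    coords = u[:, :2] * s[:2]  # (n_frames, 2)
    return np.arctan2(coords[:, 1], coords[:, 0])


def loo_importance(data, true_angles):
    """Baseline score and relative drop in alignment when each neuron is removed."""
    baseline = circular_alignment(true_angles, embedding_angle(data))
    scores = np.empty(data.shape[0])
    for i in range(data.shape[0]):
        reduced = np.delete(data, i, axis=0)
        scores[i] = circular_alignment(true_angles, embedding_angle(reduced))
    return baseline, (baseline - scores) / baseline


# Toy population: 12 von Mises head-direction cells + 8 noise-only cells
rng = np.random.default_rng(1)
n_frames, n_hd, n_noise = 2000, 12, 8
theta = rng.uniform(-np.pi, np.pi, n_frames)
preferred = np.linspace(-np.pi, np.pi, n_hd, endpoint=False)
tuning = np.exp(2.0 * (np.cos(theta[None, :] - preferred[:, None]) - 1.0))
data = np.vstack([tuning, np.zeros((n_noise, n_frames))])
data += 0.3 * rng.standard_normal(data.shape)

baseline_score, importance = loo_importance(data, theta)
print(f'baseline alignment: {baseline_score:.3f}')
print(f'mean importance, HD cells:    {importance[:n_hd].mean():+.4f}')
print(f'mean importance, noise cells: {importance[n_hd:].mean():+.4f}')
```

Tuned cells should show positive importance, because their removal lowers alignment, while noise-only cells stay near zero. This is the pattern the real analysis compares against INTENSE MI.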

In [None]:
import pandas as pd
from scipy.stats import spearmanr
from tqdm import tqdm

from driada.experiment.synthetic import generate_tuned_selectivity_exp
from driada.dim_reduction import MVData
from driada.dim_reduction.manifold_metrics import (
    compute_reconstruction_error,
    compute_decoding_accuracy,
    compute_embedding_alignment_metrics,
    circular_structure_preservation,
)
from driada.intense import compute_cell_feat_significance
from driada.utils.visual import visualize_circular_manifold

print('[1/4] Generating mixed population experiment...')
population_loo = [
    {'name': 'hd_broad', 'count': 15, 'features': ['head_direction'],
     'tuning_params': {'kappa': 2.0}},
    {'name': 'hd_medium', 'count': 15, 'features': ['head_direction'],
     'tuning_params': {'kappa': 5.0}},
    {'name': 'hd_sharp', 'count': 15, 'features': ['head_direction'],
     'tuning_params': {'kappa': 10.0}},
    {'name': 'nonselective', 'count': 15, 'features': []},
]
exp_loo = generate_tuned_selectivity_exp(
    population=population_loo, duration=600, seed=42
)
print(f'  Created: {exp_loo.n_cells} neurons, {exp_loo.calcium.data.shape[1]} timepoints')
group_desc = ' + '.join(
    f'{g["count"]} {g["name"]}' for g in population_loo
)
print(f'  Population: {group_desc}')

In [None]:
dr_method = 'isomap'
dr_params = {'dim': 2, 'nn': 20, 'max_deleted_nodes': 0.3}
ds_loo = 5

print(f'\n[2/4] Running LOO-DR analysis with {dr_method}...')

neural_data_loo = exp_loo.calcium.scdata  # (n_neurons, n_timepoints)
n_neurons_loo = neural_data_loo.shape[0]

# Ground truth: head direction angles
ground_truth_loo = exp_loo.dynamic_features['head_direction'].data[::ds_loo]


def _compute_loo_embedding(data_matrix, method, params, downsampling):
    """Compute embedding from a neuron x timepoints matrix."""
    mv = MVData(data_matrix, downsampling=downsampling)
    emb = mv.get_embedding(method=method, **params)
    coords = emb.coords.T  # (n_samples, n_dims)
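    # Isomap may drop disconnected graph nodes ("lost nodes"); remove the matching
    # samples from the global ground-truth angles so lengths stay aligned
    gt = ground_truth_loo.copy()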
    if hasattr(emb, 'graph') and hasattr(emb.graph, 'lost_nodes'):
        lost = set(emb.graph.lost_nodes)
        if lost:
            surviving = [i for i in range(len(gt)) if i not in lost]
            gt = gt[surviving]
    return coords, gt


def _compute_metrics(coords, gt):
    """Compute ground truth reconstruction metrics."""
    metrics = {}
    try:
        result = compute_reconstruction_error(coords, gt, manifold_type='circular')
        metrics['reconstruction_error'] = result['error'] if isinstance(result, dict) else result
    except Exception:
        metrics['reconstruction_error'] = np.nan
    try:
        result = compute_decoding_accuracy(coords, gt, manifold_type='circular')
        metrics['decoding_accuracy'] = result['test_r2']
    except Exception:
        metrics['decoding_accuracy'] = np.nan
    try:
        result = compute_embedding_alignment_metrics(coords, gt, manifold_type='circular')
        metrics['alignment_corr'] = result['correlation']
    except Exception:
        metrics['alignment_corr'] = np.nan
    return metrics


print('  Computing baseline...')
baseline_coords, baseline_gt = _compute_loo_embedding(
    neural_data_loo, dr_method, dr_params, ds_loo
)
baseline_metrics = _compute_metrics(baseline_coords, baseline_gt)

loo_metric_rows = [{'neuron': 'all', **baseline_metrics}]
print(f'  LOO analysis for {n_neurons_loo} neurons...')
for nidx in tqdm(range(n_neurons_loo), desc=f'LOO {dr_method}'):
    mask = np.ones(n_neurons_loo, dtype=bool)
    mask[nidx] = False
    try:
        coords_i, gt_i = _compute_loo_embedding(
            neural_data_loo[mask], dr_method, dr_params, ds_loo
        )
        m = _compute_metrics(coords_i, gt_i)
    except Exception:
        m = {k: np.nan for k in ['reconstruction_error', 'decoding_accuracy', 'alignment_corr']}
    loo_metric_rows.append({'neuron': nidx, **m})

loo_results = pd.DataFrame(loo_metric_rows).set_index('neuron')

baseline_row = loo_results.loc['all']
importance_scores = []
for nidx in range(n_neurons_loo):
    row = loo_results.loc[nidx]
    if row.isna().all():
        importance_scores.append(np.nan)
        continue
    # Relative degradation, positive when removing the neuron hurts the manifold:
    # error goes up, while alignment and decoding go down
    error_deg = (row['reconstruction_error'] - baseline_row['reconstruction_error']) / (baseline_row['reconstruction_error'] + 1e-10)
    align_deg = (baseline_row['alignment_corr'] - row['alignment_corr']) / (baseline_row['alignment_corr'] + 1e-10)
    decode_deg = (baseline_row['decoding_accuracy'] - row['decoding_accuracy']) / (baseline_row['decoding_accuracy'] + 1e-10)
    importance_scores.append((error_deg + align_deg + decode_deg) / 3)

importance = pd.Series(importance_scores, index=range(n_neurons_loo), name='importance')

print('\nLOO-DR Results:')
print('  Baseline metrics:')
print(f'    alignment_corr:       {baseline_row["alignment_corr"]:.4f}')
print(f'    decoding_accuracy:    {baseline_row["decoding_accuracy"]:.4f}')
print(f'    reconstruction_error: {baseline_row["reconstruction_error"]:.4f}')

if not importance.isna().all():
    print(f'\n  Top 5 most important neurons:')
    for neuron, score in importance.nlargest(5).items():
        row = loo_results.loc[neuron]
        align_delta = baseline_row['alignment_corr'] - row['alignment_corr']
        decode_delta = baseline_row['decoding_accuracy'] - row['decoding_accuracy']
        error_delta = row['reconstruction_error'] - baseline_row['reconstruction_error']
        print(f'    Neuron {neuron}: importance={score:.4f}')
        print(f'      align: {align_delta:+.4f}, decode: {decode_delta:+.4f}, '
              f'error: {error_delta:+.4f}')

In [None]:
print('  Verifying circular structure...')
# Reuse the baseline embedding from the LOO cell, so the verification refers to
# exactly the same data (scaled calcium, ds_loo) and DR parameters
n_lost = len(ground_truth_loo) - len(baseline_gt)
if n_lost:
    print(f'    Note: {n_lost} nodes lost in graph construction')

circular_metrics = circular_structure_preservation(
    baseline_coords, true_angles=baseline_gt, k_neighbors=3
)
print('  Circular structure verification:')
print(f'    Distance CV: {circular_metrics["distance_cv"]:.3f} (lower = more circular)')
print(f'    Consecutive preservation: {circular_metrics["consecutive_preservation"]:.1%}')
print(f'    Circular correlation: {circular_metrics["circular_correlation"]:.3f}')

fig = visualize_circular_manifold(
    [baseline_coords], baseline_gt, [dr_method.upper()]
)
plt.show()

In [None]:
print('\n[3/4] Running INTENSE analysis...')
stats_loo, significant_loo, info_loo, intense_res_loo = compute_cell_feat_significance(
    exp_loo,
    feat_bunch=['head_direction_2d'],
    n_shuffles_stage1=100,
    n_shuffles_stage2=5000,
    find_optimal_delays=True,
    ds=5,
    verbose=True,
)

In [None]:
print('\n[4/4] Comparing LOO importance with INTENSE selectivity...')

intense_mi_values = np.full(n_neurons_loo, np.nan)
intense_pval_values = np.full(n_neurons_loo, np.nan)
for nid in range(n_neurons_loo):
    if nid in stats_loo and 'head_direction_2d' in stats_loo[nid]:
        intense_mi_values[nid] = stats_loo[nid]['head_direction_2d'].get('me', np.nan)
        intense_pval_values[nid] = stats_loo[nid]['head_direction_2d'].get('pval', np.nan)

combined = pd.DataFrame({
    'loo_importance': importance.values,
    'intense_mi': intense_mi_values,
    'intense_pval': intense_pval_values,
}, index=range(n_neurons_loo))

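valid_data = combined.dropna(subset=['loo_importance', 'intense_mi'])
# Defined up front so the plotting cell also works when too few neurons are valid
corr, pval_corr = np.nan, np.nan

if len(valid_data) >= 5:
    corr, pval_corr = spearmanr(valid_data['loo_importance'], valid_data['intense_mi'])
    print('\n' + '=' * 70)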
    print('KEY RESULT: Correlation between LOO importance and INTENSE selectivity')
    print('=' * 70)
    print(f'  Spearman correlation: r = {corr:.3f}')
    print(f'  p-value: {pval_corr:.3e}')
    if pval_corr < 0.05 and corr > 0:
        print('  -> Significant POSITIVE correlation: neurons important for')
        print('     manifold reconstruction ARE the ones selective for head_direction')
    elif pval_corr < 0.05:
        print('  -> Significant NEGATIVE correlation (unexpected)')
    else:
        print('  -> No significant correlation found')

    print(f'\n  Top 5 neurons by INTENSE MI vs their LOO importance:')
    top_intense = valid_data.nlargest(5, 'intense_mi')
    for idx, row in top_intense.iterrows():
        print(f'    Neuron {idx}: MI={row["intense_mi"]:.3f} bits, '
              f'LOO={row["loo_importance"]:.4f}')

In [None]:
from matplotlib.patches import Patch

fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Neurons are ordered by group: the three HD groups first, then non-selective
n_hd_loo = sum(g['count'] for g in population_loo if g['features'])

# (a) Scatter plot: LOO importance vs INTENSE MI
ax = axes[0]
valid = combined.dropna(subset=['loo_importance', 'intense_mi'])
ax.scatter(valid['intense_mi'], valid['loo_importance'], alpha=0.6, s=30)
ax.set_xlabel('INTENSE MI (bits)')
ax.set_ylabel('LOO importance')
ax.set_title(f'LOO importance vs INTENSE MI\nr={corr:.3f}, p={pval_corr:.2e}')
ax.grid(True, alpha=0.3)

# (b) Importance ranking, colored by ground-truth group
ax = axes[1]
sorted_imp = importance.dropna().sort_values(ascending=False)
colors = ['steelblue' if i < n_hd_loo else 'lightcoral' for i in sorted_imp.index]
ax.bar(range(len(sorted_imp)), sorted_imp.values, color=colors, width=1.0)
ax.set_xlabel('Neuron (sorted by importance)')
ax.set_ylabel('LOO importance')
ax.set_title('Neuron importance ranking')
ax.legend(handles=[
    Patch(color='steelblue', label=f'HD neurons (0-{n_hd_loo - 1})'),
    Patch(color='lightcoral', label=f'Non-selective ({n_hd_loo}-{n_neurons_loo - 1})'),
], fontsize=8)

# (c) INTENSE MI distribution by group
ax = axes[2]
mi_values = combined['intense_mi'].values
hd_mi = mi_values[:n_hd_loo]
non_mi = mi_values[n_hd_loo:]
ax.boxplot([hd_mi[~np.isnan(hd_mi)], non_mi[~np.isnan(non_mi)]])
# Set tick labels explicitly (boxplot's `labels=` argument is deprecated in newer Matplotlib)
ax.set_xticks([1, 2], ['HD neurons', 'Non-selective'])
ax.set_ylabel('INTENSE MI (bits)')
ax.set_title('MI distribution by group')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 3. Representational similarity analysis (RSA)

Compare population-level representations across regions, sessions, or
conditions using **Representational Dissimilarity Matrices** (RDMs).
An RDM captures pairwise dissimilarity between stimulus conditions,
abstracting away neuron identity and number.
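A common way to build an RDM from a time series with condition labels is to average the population vector within each condition and take pairwise distances between those averages. The sketch below does exactly that with Euclidean distance. Whether `compute_rdm_unified` averages in precisely this way is an assumption here, and `rdm_from_labels` is a hypothetical name.

```python
import numpy as np


def rdm_from_labels(data, labels):
    """RDM from (n_neurons, n_timepoints) data and one label per timepoint.

    Averages the population vector within each condition, then takes the
    pairwise Euclidean distance between condition averages.
    """
    conditions = np.unique(labels)
    means = np.stack([data[:, labels == c].mean(axis=1) for c in conditions])
    diff = means[:, None, :] - means[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1)), conditions


# Three conditions over three neurons: conditions 0 and 1 are similar, 2 differs
rng = np.random.default_rng(0)
patterns = np.array([[1.0, 0.0, 0.0],
                     [0.9, 0.1, 0.0],
                     [0.0, 0.0, 1.0]])
labels = np.repeat([0, 1, 2], 50)
data = patterns[labels].T + 0.05 * rng.standard_normal((3, 150))

rdm, conds = rdm_from_labels(data, labels)
print(np.round(rdm, 2))
```

Because only condition-by-condition distances survive, two populations with different neurons (or different numbers of neurons) yield RDMs of the same shape.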

DRIADA's `rsa` module provides:
- [`compute_rdm_unified`](https://driada.readthedocs.io/en/latest/api/rsa/core.html#driada.rsa.core.compute_rdm_unified) -- RDM from neural data + condition labels
- [`rsa_compare`](https://driada.readthedocs.io/en/latest/api/rsa/core.html#driada.rsa.core.rsa_compare) -- Compare two populations directly
- `compare_rdms` -- Correlate two pre-computed RDMs
- [`bootstrap_rdm_comparison`](https://driada.readthedocs.io/en/latest/api/rsa/core.html#driada.rsa.core.bootstrap_rdm_comparison) -- Statistical significance via bootstrap
- `plot_rdm`, `plot_rdm_comparison` -- Visualization

### 3.1 RDM from stimulus-selective populations

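We create a population with **two stimulus categories**. Stimuli A and B
(`event_0`, `event_1`) share a pool of neurons (Category A), and stimuli
C and D (`event_2`, `event_3`) share another (Category B); each stimulus
also drives its own specific neurons. Cross-category pairs share no
neurons, so the RDM should reveal a 2x2 block structure: low
dissimilarity within a category, high dissimilarity between categories.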
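The expected block structure can be checked by hand. The sketch below assumes idealized, noise-free binary responses (1 if a neuron responds to a stimulus, 0 otherwise) and the group sizes of `population_rsa` in the next cell. Under that assumption, two stimuli in the same category differ only on their 15 + 15 specific neurons, while stimuli from different categories differ on all 35 + 35 responsive neurons.

```python
import numpy as np

# Neuron groups in the order used by population_rsa: (name, size, stimuli driving it)
groups = [
    ('cat_a_shared', 20, {0, 1}),
    ('event_0_specific', 15, {0}),
    ('event_1_specific', 15, {1}),
    ('cat_b_shared', 20, {2, 3}),
    ('event_2_specific', 15, {2}),
    ('event_3_specific', 15, {3}),
]
n_stimuli = 4

# Idealized binary response pattern per stimulus (rows) over all neurons (columns)
patterns = np.zeros((n_stimuli, sum(size for _, size, _ in groups)))
start = 0
for _, size, stims in groups:
    for s in stims:
        patterns[s, start:start + size] = 1.0
    start += size

# Euclidean distance between every pair of stimulus patterns
diff = patterns[:, None, :] - patterns[None, :, :]
ideal_rdm = np.sqrt((diff ** 2).sum(axis=-1))
print(np.round(ideal_rdm, 2))
```

Within-category entries are sqrt(30) ≈ 5.48 and cross-category entries sqrt(70) ≈ 8.37. The RDM computed from simulated spikes will be noisier and on a different scale, but it should show the same 2x2 pattern.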

In [None]:
from driada import rsa
from driada.experiment.synthetic import generate_tuned_selectivity_exp


def create_stimulus_labels_from_events(exp, event_names):
    """Convert binary event features to categorical stimulus labels.

    Timepoints with no event or multiple simultaneous events are labeled -1.
    """
    n_timepoints = exp.calcium.data.shape[1]
    labels = np.full(n_timepoints, -1, dtype=int)

    event_count = np.zeros(n_timepoints, dtype=int)
    for event_name in event_names:
        event_data = exp.dynamic_features[event_name].data
        event_count += (event_data > 0).astype(int)

    for idx, event_name in enumerate(event_names):
        event_data = exp.dynamic_features[event_name].data
        single_event = (event_data > 0) & (event_count == 1)
        labels[single_event] = idx

    return labels


population_rsa = [
    {'name': 'cat_a_shared', 'count': 20,
     'features': ['event_0', 'event_1'], 'combination': 'or'},
    {'name': 'event_0_specific', 'count': 15, 'features': ['event_0']},
    {'name': 'event_1_specific', 'count': 15, 'features': ['event_1']},
    {'name': 'cat_b_shared', 'count': 20,
     'features': ['event_2', 'event_3'], 'combination': 'or'},
    {'name': 'event_2_specific', 'count': 15, 'features': ['event_2']},
    {'name': 'event_3_specific', 'count': 15, 'features': ['event_3']},
]

print('Generating stimulus-selective neurons (100 neurons, 4 conditions)...')
exp_rsa = generate_tuned_selectivity_exp(
    population=population_rsa,
    n_discrete_features=4,
    duration=600,
    event_active_fraction=0.08,
    event_avg_duration=1.0,
    baseline_rate=0.05,
    peak_rate=2.0,
    seed=42,
    verbose=False,
    reconstruct_spikes='threshold',
)

print('Computing RDM from spike patterns...')
stimulus_labels = create_stimulus_labels_from_events(
    exp_rsa, ['event_0', 'event_1', 'event_2', 'event_3']
)

valid_mask = stimulus_labels >= 0
rdm1, labels1 = rsa.compute_rdm_unified(
    exp_rsa.spikes.data[:, valid_mask],
    items=stimulus_labels[valid_mask],
    metric='euclidean',
)

print(f'RDM shape: {rdm1.shape}')
print(f'Stimulus conditions: {labels1}')

label_names = ['Stim A', 'Stim B', 'Stim C', 'Stim D']
fig = rsa.plot_rdm(
    rdm1, labels=label_names[:len(labels1)],
    title='Neural RDM - stimulus conditions', show_values=True,
)
plt.show()

### 3.2 Comparing representations between regions

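Generate two synthetic "regions", V1 (100 neurons) and V2 (150 neurons),
whose item responses are different random linear read-outs of the same
latent patterns (V2 first passes them through an extra random transform).
Compare their representations directly using [`rsa_compare`](https://driada.readthedocs.io/en/latest/api/rsa/core.html#driada.rsa.core.rsa_compare) (no pre-computed RDMs
needed); because RSA compares item-by-item dissimilarities, the different
neuron counts do not matter. The cell also tries several distance metrics
and comparison methods.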
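Conceptually, comparing two populations reduces to building one RDM per population and correlating their off-diagonal entries, which is why neuron counts can differ. The sketch assumes Euclidean RDMs and a Spearman comparison computed as Pearson on ranks (no ties, which holds for continuous data). It illustrates the idea and is not `rsa_compare`'s code; the helper names are hypothetical. An unrelated random population serves as a control.

```python
import numpy as np


def euclidean_rdm(patterns):
    """RDM for an (n_items, n_neurons) matrix."""
    diff = patterns[:, None, :] - patterns[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))


def upper_triangle(rdm):
    """Off-diagonal entries above the diagonal, as a flat vector."""
    i, j = np.triu_indices(rdm.shape[0], k=1)
    return rdm[i, j]


def rank(x):
    """Ranks 0..n-1 (assumes no ties)."""
    r = np.empty(len(x))
    r[np.argsort(x)] = np.arange(len(x))
    return r


def rsa_similarity(a, b):
    """Spearman correlation between the RDMs of two (n_items, n_neurons) matrices."""
    ra = rank(upper_triangle(euclidean_rdm(a)))
    rb = rank(upper_triangle(euclidean_rdm(b)))
    return np.corrcoef(ra, rb)[0, 1]


rng = np.random.default_rng(42)
latent = rng.standard_normal((20, 5))  # 20 items, 5 latent dimensions
region_a = latent @ rng.standard_normal((5, 100)) + 0.2 * rng.standard_normal((20, 100))
region_b = latent @ rng.standard_normal((5, 150)) + 0.2 * rng.standard_normal((20, 150))
unrelated = rng.standard_normal((20, 150))

related_sim = rsa_similarity(region_a, region_b)
control_sim = rsa_similarity(region_a, unrelated)
print(f'related regions:   {related_sim:.3f}')
print(f'unrelated control: {control_sim:.3f}')
```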

In [None]:
np.random.seed(42)
n_items = 20
n_neurons_v1 = 100
n_neurons_v2 = 150

# Create base patterns that both regions respond to
base_patterns = np.random.randn(n_items, 50)

# V1: Direct representation with noise
v1_data = base_patterns @ np.random.randn(50, n_neurons_v1)
v1_data += 0.2 * np.random.randn(n_items, n_neurons_v1)

# V2: Transformed representation with noise
transform = np.random.randn(50, 50)
v2_data = (base_patterns @ transform) @ np.random.randn(50, n_neurons_v2)
v2_data += 0.2 * np.random.randn(n_items, n_neurons_v2)

print('Comparing V1 and V2 representations...')
similarity = rsa.rsa_compare(v1_data, v2_data)
print(f'V1-V2 similarity (Spearman): {similarity:.3f}')

print('\nDifferent distance metrics:')
for metric in ['correlation', 'euclidean', 'cosine']:
    sim = rsa.rsa_compare(v1_data, v2_data, metric=metric)
    print(f'  {metric}: {sim:.3f}')

print('\nDifferent comparison methods:')
# `comp_method` avoids overwriting the `comparison` results from Section 1
for comp_method in ['spearman', 'pearson', 'kendall']:
    sim = rsa.rsa_compare(v1_data, v2_data, comparison=comp_method)
    print(f'  {comp_method}: {sim:.3f}')

rdm_v1 = rsa.compute_rdm(v1_data)
rdm_v2 = rsa.compute_rdm(v2_data)

fig = rsa.plot_rdm_comparison(
    [rdm_v1, rdm_v2], titles=['V1 representation', 'V2 representation']
)
plt.show()

### 3.3 Cross-session comparison & bootstrap testing

Compare the same population structure recorded in two sessions (different
noise). Bootstrap significance testing quantifies whether the RDM
correlation is reliably above chance.
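One generic way to obtain such a confidence interval is to resample timepoints with replacement within each condition, recompute both RDMs, and take percentiles of the resulting correlations. The sketch below follows that recipe with Euclidean RDMs and a Pearson comparison. The resampling scheme `bootstrap_rdm_comparison` actually uses is not stated here, so treat this as an assumption and the function names as hypothetical.

```python
import numpy as np


def condition_rdm(data, labels, conditions):
    """Euclidean RDM between condition-averaged population vectors."""
    means = np.stack([data[:, labels == c].mean(axis=1) for c in conditions])
    diff = means[:, None, :] - means[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))


def rdm_pearson(rdm_a, rdm_b):
    """Pearson correlation between the upper triangles of two RDMs."""
    i, j = np.triu_indices(rdm_a.shape[0], k=1)
    return np.corrcoef(rdm_a[i, j], rdm_b[i, j])[0, 1]


def resample_within_conditions(labels, conditions, rng):
    """Bootstrap indices that keep the number of timepoints per condition."""
    idx = [rng.choice(np.flatnonzero(labels == c), size=np.sum(labels == c), replace=True)
           for c in conditions]
    return np.concatenate(idx)


def bootstrap_rdm_ci(data1, labels1, data2, labels2, n_bootstrap=200, alpha=0.05, seed=0):
    """Observed RDM correlation and a percentile bootstrap CI."""
    rng = np.random.default_rng(seed)
    conditions = np.intersect1d(labels1, labels2)  # conditions present in both sessions
    observed = rdm_pearson(condition_rdm(data1, labels1, conditions),
                           condition_rdm(data2, labels2, conditions))
    boot = np.empty(n_bootstrap)
    for b in range(n_bootstrap):
        i1 = resample_within_conditions(labels1, conditions, rng)
        i2 = resample_within_conditions(labels2, conditions, rng)
        boot[b] = rdm_pearson(condition_rdm(data1[:, i1], labels1[i1], conditions),
                              condition_rdm(data2[:, i2], labels2[i2], conditions))
    lower, upper = np.quantile(boot, [alpha / 2, 1 - alpha / 2])
    return observed, lower, upper


# Two "sessions": same condition patterns, independent noise
rng = np.random.default_rng(7)
patterns = rng.standard_normal((6, 40))  # 6 conditions x 40 neurons
labels = np.repeat(np.arange(6), 30)     # 30 timepoints per condition
session1 = patterns[labels].T + 0.5 * rng.standard_normal((40, 180))
session2 = patterns[labels].T + 0.5 * rng.standard_normal((40, 180))

observed, lower, upper = bootstrap_rdm_ci(session1, labels, session2, labels)
print(f'observed r = {observed:.3f}, 95% CI = [{lower:.3f}, {upper:.3f}]')
```

A lower bound above 0 suggests the cross-session similarity is not a resampling fluke, which is the same reading the notebook applies to its own CI.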

In [None]:
population_sessions = [
    {'name': 'cat_a_shared', 'count': 14,
     'features': ['event_0', 'event_1'], 'combination': 'or'},
    {'name': 'event_0_only', 'count': 10, 'features': ['event_0']},
    {'name': 'event_1_only', 'count': 10, 'features': ['event_1']},
    {'name': 'cat_b_shared', 'count': 14,
     'features': ['event_2', 'event_3'], 'combination': 'or'},
    {'name': 'event_2_only', 'count': 10, 'features': ['event_2']},
    {'name': 'event_3_only', 'count': 10, 'features': ['event_3']},
    {'name': 'cat_c_shared', 'count': 12,
     'features': ['event_4', 'event_5'], 'combination': 'or'},
    {'name': 'event_4_only', 'count': 10, 'features': ['event_4']},
    {'name': 'event_5_only', 'count': 10, 'features': ['event_5']},
]

event_names_6 = ['event_0', 'event_1', 'event_2', 'event_3', 'event_4', 'event_5']

print('Generating session 1 (100 neurons, 6 conditions, 3 categories)...')
exp_s1 = generate_tuned_selectivity_exp(
    population=population_sessions, n_discrete_features=6, duration=600,
    event_active_fraction=0.08, event_avg_duration=1.0,
    baseline_rate=0.05, peak_rate=2.0,
    seed=42, verbose=False, reconstruct_spikes='threshold',
)

print('Generating session 2 (same structure, different noise)...')
exp_s2 = generate_tuned_selectivity_exp(
    population=population_sessions, n_discrete_features=6, duration=600,
    event_active_fraction=0.08, event_avg_duration=1.0,
    baseline_rate=0.05, peak_rate=2.0,
    seed=123, verbose=False, reconstruct_spikes='threshold',
)

stim_labels_1 = create_stimulus_labels_from_events(exp_s1, event_names_6)
stim_labels_2 = create_stimulus_labels_from_events(exp_s2, event_names_6)

valid_1 = stim_labels_1 >= 0
valid_2 = stim_labels_2 >= 0

rdm_s1, labels_s1 = rsa.compute_rdm_unified(
    exp_s1.spikes.data[:, valid_1], items=stim_labels_1[valid_1],
    metric='euclidean',
)
rdm_s2, labels_s2 = rsa.compute_rdm_unified(
    exp_s2.spikes.data[:, valid_2], items=stim_labels_2[valid_2],
    metric='euclidean',
)

similarity_sessions = rsa.compare_rdms(rdm_s1, rdm_s2, method='spearman')
print(f'Cross-session RDM similarity: {similarity_sessions:.3f}')

label_names_6 = ['Stim A', 'Stim B', 'Stim C', 'Stim D', 'Stim E', 'Stim F']
fig = rsa.plot_rdm_comparison(
    [rdm_s1, rdm_s2], labels=label_names_6[:len(labels_s1)],
    titles=['Session 1', 'Session 2'],
)
plt.show()

In [None]:
print('Running bootstrap significance test (Pearson)...')
bootstrap_results = rsa.bootstrap_rdm_comparison(
    exp_s1.spikes.data[:, valid_1],
    exp_s2.spikes.data[:, valid_2],
    stim_labels_1[valid_1],
    stim_labels_2[valid_2],
    metric='euclidean',
    comparison_method='pearson',
    n_bootstrap=100,
    random_state=42,
)

print(f'Observed similarity: {bootstrap_results["observed"]:.3f}')
print(f'95% CI: [{bootstrap_results["ci_lower"]:.3f}, '
      f'{bootstrap_results["ci_upper"]:.3f}]')
print(f'Bootstrap stability p-value: {bootstrap_results["p_value"]:.3f}')
print('  (Tests whether the observed value is extreme relative to the bootstrap')
print('   mean; ~0.5 means stable. A CI above 0 indicates reliable similarity.)')
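The sketch below shows one common bootstrap scheme for a confidence interval on RDM similarity. It resamples timepoints with replacement inside each condition, recomputes both RDMs, correlates them (Pearson), and takes percentiles. The resampling unit and the p-value definition used by `bootstrap_rdm_comparison` are not shown in this notebook, so treat this as an illustration of the idea rather than DRIADA's procedure. All function names here are hypothetical.

```python
import numpy as np
from scipy.spatial.distance import pdist

def rdm_vector(data, labels, conds):
    """Condensed Euclidean RDM between condition-mean population vectors."""
    means = np.stack([data[:, labels == c].mean(axis=1) for c in conds])
    return pdist(means, metric='euclidean')

def resample_within_conditions(labels, rng):
    """Bootstrap indices: draw timepoints with replacement inside each condition."""
    return np.concatenate([
        rng.choice(np.flatnonzero(labels == c), size=int(np.sum(labels == c)), replace=True)
        for c in np.unique(labels)
    ])

def bootstrap_rdm_similarity(d1, l1, d2, l2, n_boot=200, seed=0):
    rng = np.random.default_rng(seed)
    conds = np.intersect1d(l1, l2)  # conditions present in both sessions
    observed = np.corrcoef(rdm_vector(d1, l1, conds), rdm_vector(d2, l2, conds))[0, 1]
    boots = np.empty(n_boot)
    for b in range(n_boot):
        i1 = resample_within_conditions(l1, rng)
        i2 = resample_within_conditions(l2, rng)
        boots[b] = np.corrcoef(rdm_vector(d1[:, i1], l1[i1], conds),
                               rdm_vector(d2[:, i2], l2[i2], conds))[0, 1]
    ci_lower, ci_upper = np.percentile(boots, [2.5, 97.5])
    return float(observed), float(ci_lower), float(ci_upper)

# Toy example: two noisy sessions sharing the same 4 condition patterns
rng_demo = np.random.default_rng(0)
patterns_demo = rng_demo.normal(size=(4, 30))
labels_demo = np.repeat(np.arange(4), 50)
sess1 = patterns_demo[labels_demo].T + 0.3 * rng_demo.normal(size=(30, 200))
sess2 = patterns_demo[labels_demo].T + 0.3 * rng_demo.normal(size=(30, 200))
observed, ci_lower, ci_upper = bootstrap_rdm_similarity(sess1, labels_demo, sess2, labels_demo)
print(f'Observed r = {observed:.3f}, 95% CI = [{ci_lower:.3f}, {ci_upper:.3f}]')
```

Resampling within each condition keeps every condition represented in every bootstrap replicate, so each replicate RDM has the same shape as the observed one.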

### 3.4 MVData integration

RSA also works with [`MVData`](https://driada.readthedocs.io/en/latest/api/dim_reduction/data_structures.html#driada.dim_reduction.data.MVData) objects from the DR pipeline.
You can pass an `MVData` object directly to [`compute_rdm_unified`](https://driada.readthedocs.io/en/latest/api/rsa/core.html#driada.rsa.core.compute_rdm_unified) in place of a raw `(n_features x n_timepoints)` array.

In [None]:
from driada.dim_reduction.data import MVData

rng_mv = np.random.default_rng(0)  # seeded for reproducibility

n_features = 100
n_timepoints = 1000
n_conditions = 5

# Each condition occupies one contiguous block of timepoints
condition_duration = n_timepoints // n_conditions
conditions = np.repeat(np.arange(n_conditions), condition_duration)

# One random pattern per condition; each timepoint = its pattern + small noise
patterns = rng_mv.standard_normal((n_conditions, n_features))
data_mvdata = (patterns[conditions].T
               + 0.1 * rng_mv.standard_normal((n_features, n_timepoints)))

mvdata_rsa = MVData(data_mvdata)

print('Computing RDM from MVData object...')
rdm_mv, labels_mv = rsa.compute_rdm_unified(mvdata_rsa, items=conditions)

print(f'RDM shape: {rdm_mv.shape}')
print(f'Unique conditions: {labels_mv}')

fig = rsa.plot_rdm(
    rdm_mv, labels=[f'Cond {i}' for i in labels_mv],
    title='RDM from MVData', show_values=True,
)
plt.show()

## 4. Beyond calcium: DRIADA on RNN activations

DRIADA is not limited to calcium imaging: any `(n_units x n_timepoints)`
matrix works. This section runs the full pipeline on a randomly connected
rate RNN driven by synthetic behavioral inputs. The steps are input
generation, RNN simulation, and then INTENSE, DR, and network analysis.

Data is loaded with
[`load_exp_from_aligned_data`](https://driada.readthedocs.io/en/latest/api/experiment/loading.html#driada.experiment.exp_build.load_exp_from_aligned_data),
selectivity tested with [`compute_cell_feat_significance`](https://driada.readthedocs.io/en/latest/api/intense/pipelines.html#driada.intense.pipelines.compute_cell_feat_significance), pairwise
dependencies found with
[`compute_cell_cell_significance`](https://driada.readthedocs.io/en/latest/api/intense/pipelines.html#driada.intense.pipelines.compute_cell_cell_significance),
and the resulting adjacency matrix wrapped in a
[`Network`](https://driada.readthedocs.io/en/latest/api/network/core.html#driada.network.net_base.Network)
object.

The `Experiment` constructor accepts several aliases for the neural data key:
`calcium`, `activations`, `neural_data`, `activity`, `rates`. The RNN example
below stores its matrix under `activations`.
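For reference, the simulation in step [2] below implements the following discrete-time update. The equations are read directly off that code. Here $N$ is `n_units`, $\mathbf{u}_t$ stacks the six input channels ($n_{\mathrm{in}} = 6$), $\mathbf{s}_0 = 0$, and $\Delta t = 1/\mathrm{fps}$:

$$
\begin{aligned}
\mathbf{r}_t &= \max(\mathbf{s}_t, 0), \qquad \boldsymbol{\xi}_t \sim \mathcal{N}(0, \sigma^2 I),\\
\mathbf{s}_{t+1} &= \mathbf{s}_t + \frac{\Delta t}{\tau}\left(-\mathbf{s}_t + W_{\mathrm{rec}}\,\mathbf{r}_t + W_{\mathrm{in}}\,\mathbf{u}_t + \boldsymbol{\xi}_t\right),\\
\mathbf{a}_t &= (1-\alpha)\,\mathbf{a}_{t-1} + \alpha\,\mathbf{r}_t, \qquad \mathbf{a}_0 = \mathbf{r}_0, \qquad \alpha = \frac{\Delta t}{\tau_{\mathrm{smooth}}},
\end{aligned}
$$

The weights are drawn as $(W_{\mathrm{rec}})_{ij} \sim \mathcal{N}(0, g^2/N)$ and $(W_{\mathrm{in}})_{ik} \sim \mathcal{N}(0, w_{\mathrm{in}}^2/n_{\mathrm{in}})$. Observation noise $\mathcal{N}(0, \sigma_{\mathrm{obs}}^2)$ is then added to $\mathbf{a}_t$, and the result is rectified at zero. With the values in `CONFIG`, $\Delta t = 0.05$ s, $\Delta t/\tau = 0.25$ and $\alpha = 0.05$.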

In [None]:
import scipy.sparse as sp
import networkx as nx

import driada
from driada.dim_reduction import MVData
from driada.experiment import load_exp_from_aligned_data
from driada.intense import compute_cell_cell_significance
from driada.network import Network

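# Configuration
CONFIG = {
    # RNN
    'n_units': 64,            # number of recurrent units
    'tau': 0.2,               # unit time constant (s)
    'g_rec': 1.2,             # recurrent gain: std of W_rec is g_rec / sqrt(n_units)
    'w_in': 0.8,              # input gain: std of W_in is w_in / sqrt(n_inputs)
    'noise_sigma': 0.05,      # std of private noise added at each step
    # Recording
    'duration': 300,          # seconds (5 min, shortened for notebook speed)
    'fps': 20,                # frames per second
    'tau_smooth': 1.0,        # time constant (s) of the exponential smoothing
    'obs_noise': 0.01,        # std of additive observation noise
    # INTENSE cell-feature tests (stage 1 count also used by the cell-cell test)
    'n_shuffles_stage1': 100,
    'n_shuffles_stage2': 5000,
    'pval_thr': 0.001,
    'ds': 5,                  # downsampling factor for the significance tests
    # Cell-cell functional network
    'cc_n_shuffles_stage2': 5000,
    'cc_pval_thr': 0.01,
    'seed': 42,
}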
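These settings imply the following test sizes. With 64 units, the cell-cell test in step [6] covers $64 \cdot 63 / 2 = 2016$ pairs, corrected with `multicomp_correction='holm'` at `cc_pval_thr=0.01`. Assuming `ds` is a plain downsampling factor, the significance tests see $6000 / 5 = 1200$ of the 6000 frames. The sketch below implements the textbook Holm step-down rule. Whether DRIADA applies it to exactly this set of p-values is an assumption, and `holm_reject` is a hypothetical helper, not a DRIADA function.

```python
import numpy as np

def holm_reject(pvals, alpha):
    """Holm step-down: compare the k-th smallest p-value (k = 0, 1, ...) with alpha / (m - k)."""
    p = np.asarray(pvals, dtype=float)
    m = p.size
    reject = np.zeros(m, dtype=bool)
    for k, idx in enumerate(np.argsort(p)):
        if p[idx] > alpha / (m - k):
            break  # stop at the first non-rejection
        reject[idx] = True
    return reject

n_units_cfg, duration_cfg, fps_cfg, ds_cfg = 64, 300, 20, 5  # values from CONFIG
n_frames_cfg = duration_cfg * fps_cfg                         # 6000 frames
n_pairs_cfg = n_units_cfg * (n_units_cfg - 1) // 2            # 2016 unit pairs
print(n_frames_cfg, n_frames_cfg // ds_cfg, n_pairs_cfg)
print(f'Strictest Holm threshold: {0.01 / n_pairs_cfg:.2e}')
print(holm_reject([0.01, 0.04, 0.03, 0.005], alpha=0.05))
```

Holm controls the family-wise error rate like Bonferroni but is uniformly more powerful. Only the smallest p-value faces the full $\alpha/m$ threshold, and later ones face progressively looser thresholds.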

In [None]:
print('[1] GENERATING BEHAVIORAL INPUTS')
print('-' * 40)

rng = np.random.default_rng(CONFIG['seed'])
n_frames = CONFIG['duration'] * CONFIG['fps']
dt = 1.0 / CONFIG['fps']

# Smooth random walk for position
step_std = 0.3 * np.sqrt(dt)
x = np.empty(n_frames)
y = np.empty(n_frames)
x[0], y[0] = 0.5, 0.5
dx_raw = rng.normal(0, step_std, n_frames)
dy_raw = rng.normal(0, step_std, n_frames)
for t in range(1, n_frames):
    x[t] = x[t - 1] + dx_raw[t]
    y[t] = y[t - 1] + dy_raw[t]
    # Reflecting boundaries
    if x[t] < 0: x[t] = -x[t]
    elif x[t] > 1: x[t] = 2.0 - x[t]
    if y[t] < 0: y[t] = -y[t]
    elif y[t] > 1: y[t] = 2.0 - y[t]

# Derived features
dx = np.diff(x, prepend=x[0])
dy = np.diff(y, prepend=y[0])
speed = np.sqrt(dx**2 + dy**2) * CONFIG['fps']
head_direction = np.arctan2(dy, dx) % (2 * np.pi)

# Trial type: block-structured categorical
trial_type = np.zeros(n_frames, dtype=int)
t = 0
while t < n_frames:
    label = rng.integers(0, 3)
    block_len = max(int(rng.exponential(7.0) * CONFIG['fps']), int(CONFIG['fps']))
    trial_type[t:t + block_len] = label
    t += block_len

# Sparse binary event
event = (rng.random(n_frames) < 0.03).astype(float)

inputs = {'x': x, 'y': y, 'speed': speed, 'head_direction': head_direction,
          'trial_type': trial_type, 'event': event}
print(f'  Frames: {n_frames}, features: {list(inputs.keys())}')
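The reflecting-boundary rule in the random-walk loop can be stated as a vectorized per-step function. The walk itself still needs the sequential loop, because each step starts from the previous reflected position. The rule is valid while a single step overshoots a wall by less than 1, which holds here because the step std is about 0.067. The sketch also shows the heading wrap. Because `np.diff(..., prepend=x[0])` makes the first step zero, frame 0 gets speed 0 and heading `arctan2(0, 0) = 0`. `reflect_unit_interval` is a hypothetical helper for illustration.

```python
import numpy as np

def reflect_unit_interval(v):
    """Fold proposed positions back into [0, 1] (single reflection per wall)."""
    v = np.where(v < 0, -v, v)
    return np.where(v > 1, 2.0 - v, v)

pos_demo = reflect_unit_interval(np.array([-0.1, 0.3, 1.2]))

# Heading is wrapped from arctan2's (-pi, pi] range into [0, 2*pi)
heading_demo = np.arctan2(np.array([0.0, -1.0]), np.array([0.0, 0.0])) % (2 * np.pi)
print(pos_demo, heading_demo)
```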

In [None]:
print('\n[2] SIMULATING RNN')
print('-' * 40)

n_units = CONFIG['n_units']
tau = CONFIG['tau']
g = CONFIG['g_rec']
w_in = CONFIG['w_in']
sigma = CONFIG['noise_sigma']

# Stack input channels
input_names = ['x', 'y', 'speed', 'head_direction', 'trial_type', 'event']
u = np.stack([inputs[k].astype(float) for k in input_names], axis=0)
n_input = u.shape[0]

# Random fixed weights
W_rec = rng.normal(0, g / np.sqrt(n_units), (n_units, n_units))
W_in = rng.normal(0, w_in / np.sqrt(n_input), (n_units, n_input))

# Euler integration
state = np.zeros(n_units)
raw = np.empty((n_units, n_frames))
for t in range(n_frames):
    r = np.maximum(state, 0)  # ReLU
    raw[:, t] = r
    noise = rng.normal(0, sigma, n_units)
    state += (dt / tau) * (-state + W_rec @ r + W_in @ u[:, t] + noise)

# Exponential smoothing (mimics slow indicator dynamics)
alpha = dt / CONFIG['tau_smooth']
activations = np.empty_like(raw)
activations[:, 0] = raw[:, 0]
for t in range(1, n_frames):
    activations[:, t] = (1 - alpha) * activations[:, t - 1] + alpha * raw[:, t]

activations += rng.normal(0, CONFIG['obs_noise'], activations.shape)
activations = np.maximum(activations, 0)

mean_rate = activations.mean()
frac_active = (activations > 0).mean()
print(f'  Activations: {activations.shape}')
print(f'  Mean rate: {mean_rate:.3f}, fraction active: {frac_active:.2f}')
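The exponential smoothing loop above is a first-order IIR (leaky-integrator) filter, so it can also be written with `scipy.signal.lfilter`. The initial filter state is chosen so that the first output equals the first input, matching `activations[:, 0] = raw[:, 0]`. This is an alternative formulation, not what the notebook runs. The sketch checks on random data that the loop and the filter agree.

```python
import numpy as np
from scipy.signal import lfilter

def smooth_loop(raw, alpha):
    """Same recursion as the notebook cell, one timepoint at a time."""
    out = np.empty_like(raw)
    out[:, 0] = raw[:, 0]
    for t in range(1, raw.shape[1]):
        out[:, t] = (1 - alpha) * out[:, t - 1] + alpha * raw[:, t]
    return out

def smooth_lfilter(raw, alpha):
    """y[t] = alpha * x[t] + (1 - alpha) * y[t-1], i.e. b = [alpha], a = [1, -(1 - alpha)]."""
    # zi = (1 - alpha) * x[0] gives y[0] = alpha * x[0] + (1 - alpha) * x[0] = x[0]
    zi = (1 - alpha) * raw[:, :1]
    y, _ = lfilter([alpha], [1.0, -(1.0 - alpha)], raw, axis=1, zi=zi)
    return y

rng_demo = np.random.default_rng(0)
raw_demo = np.maximum(rng_demo.normal(size=(8, 500)), 0)
alpha_demo = 0.05  # dt / tau_smooth = (1 / 20) / 1.0
print(np.allclose(smooth_loop(raw_demo, alpha_demo), smooth_lfilter(raw_demo, alpha_demo)))
```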

In [None]:
print('\n[3] LOADING INTO DRIADA')
print('-' * 40)

data_rnn = {'activations': activations, **inputs}

exp_rnn = load_exp_from_aligned_data(
    data_source='RNN',
    exp_params={'name': 'random_rnn'},
    data=data_rnn,
    feature_types={'head_direction': 'circular', 'speed': 'linear'},
    aggregate_features={('x', 'y'): 'position_2d'},
    static_features={'fps': float(CONFIG['fps'])},
    create_circular_2d=True,
    verbose=True,
)
print(f'  Experiment: {exp_rnn.n_cells} units, {exp_rnn.n_frames} frames')
print(f'  Features: {list(exp_rnn.dynamic_features.keys())}')

In [None]:
print('\n[4] INTENSE SELECTIVITY ANALYSIS')
print('-' * 40)

stats_rnn, significance_rnn, info_rnn, results_rnn = (
    driada.compute_cell_feat_significance(
        exp_rnn,
        mode='two_stage',
        n_shuffles_stage1=CONFIG['n_shuffles_stage1'],
        n_shuffles_stage2=CONFIG['n_shuffles_stage2'],
        pval_thr=CONFIG['pval_thr'],
        ds=CONFIG['ds'],
        verbose=True,
    )
)
significant_neurons_rnn = exp_rnn.get_significant_neurons()

# Per-feature summary
feat_counts = {}
for feats in significant_neurons_rnn.values():
    for f in feats:
        feat_counts[f] = feat_counts.get(f, 0) + 1
n_mixed = sum(1 for feats in significant_neurons_rnn.values() if len(feats) > 1)

print(f'\n  Selective units: {len(significant_neurons_rnn)} / {exp_rnn.n_cells}')
for feat, cnt in sorted(feat_counts.items(), key=lambda x: -x[1]):
    print(f'    {feat}: {cnt} units')
print(f'  Mixed selectivity (>1 feature): {n_mixed} units')
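The core idea behind the selectivity test can be shown in a few lines. INTENSE (notebook 02) compares the mutual information between a unit and a feature with a null distribution built from randomly time-shifted copies of the signal. The sketch uses a simple plug-in histogram MI estimator and a single-stage shift test. Both are simplifications: DRIADA uses its own MI estimators, the two-stage shuffle scheme configured above, and multiple-comparison correction. `hist_mi` and `shift_test` are hypothetical helpers.

```python
import numpy as np

def hist_mi(a, b, bins=10):
    """Plug-in mutual information (bits) from a 2D histogram; biased upward for small samples."""
    joint, _, _ = np.histogram2d(a, b, bins=bins)
    pxy = joint / joint.sum()
    px = pxy.sum(axis=1, keepdims=True)
    py = pxy.sum(axis=0, keepdims=True)
    nz = pxy > 0
    return float(np.sum(pxy[nz] * np.log2(pxy[nz] / (px @ py)[nz])))

def shift_test(signal, feature, n_shuffles=100, min_shift=50, seed=0):
    """Compare observed MI with MI after random circular shifts of the signal."""
    rng = np.random.default_rng(seed)
    observed = hist_mi(signal, feature)
    shifts = rng.integers(min_shift, len(signal) - min_shift, n_shuffles)
    null = np.array([hist_mi(np.roll(signal, s), feature) for s in shifts])
    p_value = (1 + np.sum(null >= observed)) / (n_shuffles + 1)
    return observed, p_value

rng_demo = np.random.default_rng(1)
feature_demo = rng_demo.random(5000)                  # e.g. a normalized position
tuned_demo = np.exp(-(feature_demo - 0.5) ** 2 / 0.02) + 0.05 * rng_demo.normal(size=5000)
mi_demo, p_demo = shift_test(tuned_demo, feature_demo)
print(f'MI = {mi_demo:.2f} bits, p = {p_demo:.3f}')
```

Circular shifts preserve each signal's own autocorrelation while breaking its alignment with the feature. This is why the shift null is preferred over shuffling individual timepoints for slow signals like the smoothed RNN activations.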

In [None]:
print('\n[5] DIMENSIONALITY REDUCTION')
print('-' * 40)

mvdata_rnn = MVData(exp_rnn.calcium.data)
emb_rnn = mvdata_rnn.get_embedding(method='pca')
print(f'  PCA embedding: {emb_rnn.coords.shape}')
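For intuition about what `emb_rnn.coords` contains, here is a plain SVD-based PCA on a `(n_units, n_timepoints)` matrix. It assumes that MVData's `'pca'` treats timepoints as samples and units as features, which matches the `(n_dims, n_timepoints)` layout indexed in the plotting cell. Centering, scaling and sign conventions in DRIADA may differ. `pca_coords` is a hypothetical helper.

```python
import numpy as np

def pca_coords(data, n_components=2):
    """data: (n_units, n_timepoints). Returns (n_components, n_timepoints) scores
    and the explained-variance ratio of each returned component."""
    centered = data - data.mean(axis=1, keepdims=True)  # center each unit over time
    # Columns of u are principal axes in unit space; s are singular values
    u, s, vt = np.linalg.svd(centered, full_matrices=False)
    coords = s[:n_components, None] * vt[:n_components]  # equals u[:, :k].T @ centered
    evr = s ** 2 / np.sum(s ** 2)
    return coords, evr[:n_components]

# Toy data: 20 units driven by two latent signals plus small noise
rng_demo = np.random.default_rng(0)
latent_demo = rng_demo.normal(size=(2, 1000)) * np.array([[3.0], [1.0]])
mixing_demo = rng_demo.normal(size=(20, 2))
data_demo = mixing_demo @ latent_demo + 0.1 * rng_demo.normal(size=(20, 1000))
coords_demo, evr_demo = pca_coords(data_demo, n_components=2)
print(coords_demo.shape, np.round(evr_demo, 3))
```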

In [None]:
print('\n[6] CELL-CELL FUNCTIONAL NETWORK')
print('-' * 40)

sim_mat_rnn, sig_mat_rnn, pval_mat_rnn, cell_ids_rnn, cc_info_rnn = (
    compute_cell_cell_significance(
        exp_rnn,
        data_type='calcium',
        ds=CONFIG['ds'],
        n_shuffles_stage1=CONFIG['n_shuffles_stage1'],
        n_shuffles_stage2=CONFIG['cc_n_shuffles_stage2'],
        pval_thr=CONFIG['cc_pval_thr'],
        multicomp_correction='holm',
        verbose=True,
    )
)

n_sig_rnn = int(np.sum(np.triu(sig_mat_rnn, k=1)))
n_pairs_rnn = len(cell_ids_rnn) * (len(cell_ids_rnn) - 1) // 2
print(f'\n  Significant pairs: {n_sig_rnn} / {n_pairs_rnn}')

weighted_rnn = sp.csr_matrix(sim_mat_rnn * sig_mat_rnn)
net_rnn = Network(
    adj=weighted_rnn, preprocessing='giant_cc', name='RNN functional network'
)
print(f'  Network: {net_rnn.n} nodes, {net_rnn.graph.number_of_edges()} edges')
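The network step keeps only the significant similarity weights and then restricts the graph to its largest connected component. The sketch below reproduces that roughly with NetworkX, assuming `preprocessing='giant_cc'` means "largest connected component" as the summary panel's "Nodes in GCC" suggests. `Network`'s actual preprocessing may do more. `masked_giant_component` is a hypothetical helper.

```python
import numpy as np
import networkx as nx

def masked_giant_component(sim, sig):
    """Zero out non-significant weights, then return the largest connected component."""
    w = np.where(sig.astype(bool), sim, 0.0)
    w = np.triu(w, k=1)
    w = w + w.T                   # symmetric weights, no self-loops
    g = nx.from_numpy_array(w)    # zero entries produce no edge
    giant = max(nx.connected_components(g), key=len)
    return g.subgraph(giant).copy()

# Toy 5-unit example: pairs (0,1), (1,2), (3,4) are significant; (0,2), (1,4) are not
sim_demo = np.array([
    [0.0, 0.5, 0.1, 0.0, 0.0],
    [0.5, 0.0, 0.4, 0.0, 0.05],
    [0.1, 0.4, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.3],
    [0.0, 0.05, 0.0, 0.3, 0.0],
])
sig_demo = np.zeros((5, 5), dtype=int)
for i, j in [(0, 1), (1, 2), (3, 4)]:
    sig_demo[i, j] = sig_demo[j, i] = 1
gcc_demo = masked_giant_component(sim_demo, sig_demo)
print(sorted(gcc_demo.nodes), gcc_demo.number_of_edges())
```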

In [None]:
fig = plt.figure(figsize=(18, 14))
fps = CONFIG['fps']

show_frames = min(50 * fps, n_frames)
t_sec = np.arange(show_frames) / fps

# ---- Row 1: Data overview ----

# (1,1) Input signals
ax = fig.add_subplot(3, 3, 1)
signals = [
    ('x', inputs['x'][:show_frames]),
    ('speed', inputs['speed'][:show_frames]),
    ('HD', inputs['head_direction'][:show_frames] / (2 * np.pi)),
    ('trial', inputs['trial_type'][:show_frames].astype(float) / 2),
]
for i, (label, sig) in enumerate(signals):
    ax.plot(t_sec, sig + i * 1.2, lw=0.5)
    ax.text(-1, i * 1.2 + 0.4, label, fontsize=7, ha='right')
ax.set_xlabel('Time (s)')
ax.set_yticks([])
ax.set_title('Input signals')

# (1,2) RNN activity raster
ax = fig.add_subplot(3, 3, 2)
order = np.argsort(activations.mean(axis=1))
ax.imshow(
    activations[order, :show_frames], aspect='auto', cmap='inferno',
    extent=[0, show_frames / fps, 0, activations.shape[0]],
)
ax.set_xlabel('Time (s)')
ax.set_ylabel('Unit (sorted)')
ax.set_title('RNN activations')

# (1,3) INTENSE selectivity heatmap
ax = fig.add_subplot(3, 3, 3)
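# Skip raw x, y and head_direction: presumably they are covered by
# 'position_2d' and the 2D circular encoding requested via create_circular_2d=True
feat_names_rnn = [
    f for f in exp_rnn.dynamic_features
    if f not in ('x', 'y', 'head_direction')
]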
mi_matrix_rnn = np.zeros((exp_rnn.n_cells, len(feat_names_rnn)))
for uid, feats in significant_neurons_rnn.items():
    idx_n = int(uid)
    for fname in feats:
        if fname in feat_names_rnn:
            pair_stats = exp_rnn.get_neuron_feature_pair_stats(uid, fname)
            col = feat_names_rnn.index(fname)
            mi_matrix_rnn[idx_n, col] = pair_stats.get('me', 0)
im = ax.imshow(mi_matrix_rnn, aspect='auto', cmap='viridis')
ax.set_xlabel('Feature')
ax.set_ylabel('Unit')
ax.set_xticks(range(len(feat_names_rnn)))
ax.set_xticklabels(feat_names_rnn, rotation=45, ha='right', fontsize=7)
ax.set_title('Selectivity (MI, significant only)')
plt.colorbar(im, ax=ax, fraction=0.046, label='MI (bits)')

# ---- Row 2: PCA embedding colored by different variables ----
coords_rnn = emb_rnn.coords
ds_rnn = 10
x_pc = coords_rnn[0, ::ds_rnn]
y_pc = coords_rnn[1, ::ds_rnn]

color_vars = [
    ('x position', inputs['x'][::ds_rnn]),
    ('head direction', inputs['head_direction'][::ds_rnn]),
    ('trial type', inputs['trial_type'][::ds_rnn].astype(float)),
]
cmaps = ['viridis', 'twilight', 'Set1']
for i, (label, cvar) in enumerate(color_vars):
    ax = fig.add_subplot(3, 3, 4 + i)
    sc = ax.scatter(x_pc, y_pc, c=cvar, cmap=cmaps[i], s=1, alpha=0.3,
                    rasterized=True)
    ax.set_xlabel('PC 1')
    ax.set_ylabel('PC 2')
    ax.set_title(f'PCA colored by {label}')
    plt.colorbar(sc, ax=ax, fraction=0.046)

# ---- Row 3: Network analysis ----

# (3,1) Similarity matrix
ax = fig.add_subplot(3, 3, 7)
im = ax.imshow(sim_mat_rnn, cmap='hot', aspect='auto')
ax.set_xlabel('Unit')
ax.set_ylabel('Unit')
ax.set_title('Cell-cell similarity (MI)')
plt.colorbar(im, ax=ax, fraction=0.046)

# (3,2) Network graph
ax = fig.add_subplot(3, 3, 8)
if net_rnn.graph.number_of_edges() > 0:
    pos = nx.spring_layout(net_rnn.graph, seed=CONFIG['seed'])
    nx.draw_networkx_nodes(net_rnn.graph, pos, ax=ax, node_size=30,
                           node_color='steelblue')
    nx.draw_networkx_edges(net_rnn.graph, pos, ax=ax, alpha=0.2, width=0.5)
    ax.set_title(f'Functional network ({net_rnn.n} nodes, '
                 f'{net_rnn.graph.number_of_edges()} edges)')
else:
    ax.text(0.5, 0.5, 'No significant edges',
            ha='center', va='center', transform=ax.transAxes)
    ax.set_title('Functional network')
ax.axis('off')

# (3,3) Summary text
ax = fig.add_subplot(3, 3, 9)
ax.axis('off')
n_selective = len(significant_neurons_rnn)
n_mixed_rnn = sum(1 for feats in significant_neurons_rnn.values() if len(feats) > 1)
n_sig_pairs_rnn = int(np.sum(np.triu(sig_mat_rnn, k=1)))
total_pairs_rnn = exp_rnn.n_cells * (exp_rnn.n_cells - 1) // 2
density_rnn = n_sig_pairs_rnn / total_pairs_rnn if total_pairs_rnn > 0 else 0

text = (
    f"RNN: {CONFIG['n_units']} units, g={CONFIG['g_rec']}\n"
    f"Recording: {CONFIG['duration']}s at {CONFIG['fps']} Hz\n\n"
    f"INTENSE selectivity:\n"
    f"  Selective units: {n_selective}/{exp_rnn.n_cells}\n"
    f"  Mixed selectivity: {n_mixed_rnn}\n\n"
    f"Functional network:\n"
    f"  Significant pairs: {n_sig_pairs_rnn}/{total_pairs_rnn}\n"
    f"  Density: {density_rnn:.3f}\n"
    f"  Nodes in GCC: {net_rnn.n}"
)
ax.text(0.05, 0.95, text, transform=ax.transAxes, fontsize=10,
        verticalalignment='top', fontfamily='monospace')

plt.tight_layout()
plt.show()