# Population geometry & dimensionality reduction

Individual neurons encode specific variables (Notebook 02), but the
population *as a whole* forms a low-dimensional manifold whose geometry
reflects the task. [**DRIADA**](https://driada.readthedocs.io) provides
a unified dimensionality reduction (DR) toolkit to extract, compare, and evaluate these manifolds.

| Step | Notebook | What it does |
|---|---|---|
| Load & inspect | [01 -- Data loading](https://colab.research.google.com/github/iabs-neuro/driada/blob/main/notebooks/01_data_loading_and_neurons.ipynb) | Wrap your recording into an `Experiment`, reconstruct spikes, assess quality |
| Single-neuron selectivity | [02 -- INTENSE](https://colab.research.google.com/github/iabs-neuro/driada/blob/main/notebooks/02_selectivity_detection_intense.ipynb) | Detect which neurons encode which behavioral variables |
| **Population geometry** | **03 -- this notebook** | Extract low-dimensional manifolds from population activity |
| Network analysis | [04 -- Networks](https://colab.research.google.com/github/iabs-neuro/driada/blob/main/notebooks/04_network_analysis.ipynb) | Build and analyze cell-cell interaction graphs |
| Putting it together | [05 -- Advanced](https://colab.research.google.com/github/iabs-neuro/driada/blob/main/notebooks/05_advanced_capabilities.ipynb) | Combine INTENSE + DR, leave-one-out importance, RSA, RNN analysis |

**Sections:**

1. **DR API quick reference** -- `MVData` wraps a matrix and provides
   one-line access to 7 DR methods.
2. **Method comparison** -- Systematic benchmark on synthetic datasets
   with quality metrics (k-NN preservation, trustworthiness, continuity,
   normalized stress).
3. **Sequential DR on neural data** -- PCA first (denoise), then UMAP.
   Often better than direct UMAP on high-dimensional neural recordings.
4. **Autoencoder-based DR** -- Standard AE with `continue_learning`,
   Beta-VAE, and a PCA baseline on a circular manifold. Requires PyTorch.
5. **Circular manifold & dimensionality estimation** -- Head direction
   cells encode a ring. Extract it via DR and estimate intrinsic
   dimensionality.
6. **INTENSE-guided DR** -- Use INTENSE to select neurons before DR.
   Selective neurons produce cleaner spatial embeddings.

In [None]:
# TODO: revert to '!pip install -q driada' after v1.0.0 PyPI release
!pip install -q git+https://github.com/iabs-neuro/driada.git@main
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import time
import tracemalloc
import warnings
from typing import Dict, Tuple
import pandas as pd

from sklearn.datasets import make_swiss_roll, make_s_curve, make_blobs
from sklearn.preprocessing import StandardScaler
from scipy.sparse import csr_matrix

# DRIADA dimensionality reduction
from driada.dim_reduction import (
    MVData,
    dr_sequence,
    knn_preservation_rate,
    trustworthiness,
    continuity,
    stress,
)
from driada.dim_reduction.manifold_metrics import (
    manifold_preservation_score,
    compute_embedding_alignment_metrics,
    procrustes_analysis,
)

# DRIADA network analysis (used in the ProximityGraph demo, "Graph structure behind DR", Section 1)
from driada.network import Network

# DRIADA experiment / synthetic data
from driada.experiment.synthetic import (
    generate_2d_manifold_data,
    generate_circular_manifold_data,
)
from driada.experiment import generate_circular_manifold_exp

# DRIADA dimensionality estimation
from driada.dimensionality import (
    eff_dim,
    correlation_dimension,
    geodesic_dimension,
    pca_dimension,
)

# DRIADA INTENSE + mixed population
from driada import (
    compute_cell_feat_significance,
    generate_mixed_population_exp,
)
from driada.utils import (
    compute_spatial_decoding_accuracy,
    compute_spatial_information,
)

# DRIADA visualization
from driada.utils.visual import (
    visualize_circular_manifold,
    plot_trajectories,
    plot_embeddings_grid,
    plot_neuron_selectivity_summary,
    DEFAULT_DPI,
)

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore', category=UserWarning)

## 1. DR API quick reference

`MVData` wraps an *(n_features x n_samples)* matrix and provides one-line
DR via [`get_embedding`](https://driada.readthedocs.io/en/latest/api/dim_reduction/data_structures.html). Seven methods are available, including **PCA**, **Isomap**, **LLE**,
**Laplacian Eigenmaps**, and **UMAP**, which are demonstrated below. For multi-step
pipelines, see [`dr_sequence`](https://driada.readthedocs.io/en/latest/api/dim_reduction/algorithms.html).

```python
emb = mvdata.get_embedding(method='pca')          # defaults
emb = mvdata.get_embedding(method='umap',          # with params
                           n_neighbors=30, min_dist=0.1)
```
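The data orientation is easy to get wrong: scikit-learn expects *(n_samples, n_features)*, the transpose of the `MVData` layout. The sketch below uses only NumPy and scikit-learn (not DRIADA) to show the round trip. The variable names are illustrative, and the final transpose reproduces the *(dim x n_samples)* layout that the cells below rely on when they index `coords[0, :]`.

```python
import numpy as np
from sklearn.datasets import make_swiss_roll
from sklearn.decomposition import PCA

# scikit-learn convention: rows are samples, columns are features
X_sklearn, t = make_swiss_roll(n_samples=500, noise=0.1, random_state=0)  # (500, 3)

# MVData convention described above: rows are features, columns are samples
X_mvdata = X_sklearn.T  # (3, 500)

# A plain scikit-learn PCA needs the samples-first orientation back...
coords_sklearn = PCA(n_components=2).fit_transform(X_mvdata.T)  # (500, 2)

# ...and transposing its output gives a (dim x n_samples) array, so that
# row 0 holds component 1 for every sample, as with embedding.coords below
coords_like_mvdata = coords_sklearn.T  # (2, 500)

print(X_mvdata.shape, coords_like_mvdata.shape)
```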

In [None]:
# Generate Swiss roll data for demonstration
n_samples = 1000
X_raw, color = make_swiss_roll(n_samples, noise=0.1, random_state=42)
X = X_raw.T  # Transpose to match MVData format (features x samples)

# Create MVData object
mvdata = MVData(X)

print('=' * 60)
print('DIMENSIONALITY REDUCTION METHODS')
print('=' * 60)

# Linear methods
print('\n--- Linear Methods ---')

print('\n1. PCA (Principal Component Analysis):')
print('   Default usage:')
print("     emb = mvdata.get_embedding(method='pca')")
print('   With parameters:')
print("     emb = mvdata.get_embedding(method='pca', dim=3)")
emb_pca = mvdata.get_embedding(method='pca')
print(f'   -> Result shape: {emb_pca.coords.shape}')

# Manifold learning methods
print('\n--- Manifold Learning Methods ---')

print('\n2. Isomap (Isometric Mapping):')
print('   Default usage:')
print("     emb = mvdata.get_embedding(method='isomap')")
print('   With parameters:')
print("     emb = mvdata.get_embedding(method='isomap', n_neighbors=30, dim=3)")
emb_iso = mvdata.get_embedding(method='isomap')
print(f'   -> Result shape: {emb_iso.coords.shape}')

print('\n3. LLE (Locally Linear Embedding):')
print('   Default usage:')
print("     emb = mvdata.get_embedding(method='lle')")
emb_lle = mvdata.get_embedding(method='lle')
print(f'   -> Result shape: {emb_lle.coords.shape}')

print('\n4. Laplacian Eigenmaps:')
print('   Default usage:')
print("     emb = mvdata.get_embedding(method='le')")
emb_le = mvdata.get_embedding(method='le')
print(f'   -> Result shape: {emb_le.coords.shape}')

In [None]:
# Visualization methods
print('--- Visualization Methods ---')

print('\n5. UMAP (Uniform Manifold Approximation and Projection):')
print('   Default usage:')
print("     emb = mvdata.get_embedding(method='umap')")
print('   With parameters:')
print("     emb = mvdata.get_embedding(method='umap', n_neighbors=50, min_dist=0.3)")
emb_umap = mvdata.get_embedding(method='umap', n_neighbors=50, min_dist=0.3)
print(f'   -> Result shape: {emb_umap.coords.shape}')

# Show parameter options
print('\n' + '=' * 60)
print('COMMON PARAMETERS')
print('=' * 60)
print('\nAll methods accept:')
print('  dim: int - Number of output dimensions (default: 2)')
print('\nGraph-based methods accept:')
print('  n_neighbors: int - Number of nearest neighbors')
print('\nUMAP specific:')
print('  min_dist: float - Minimum distance between points in embedding')
print('\nDiffusion maps specific:')
print('  dm_alpha: float - Diffusion map alpha parameter')

In [None]:
# Visualize PCA, Isomap, UMAP, LE side by side
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.ravel()

for i, (emb, name) in enumerate(zip(
    [emb_pca, emb_iso, emb_umap, emb_le],
    ['PCA', 'Isomap', 'UMAP', 'Laplacian Eigenmaps'],
)):
    ax = axes[i]
    coords = emb.coords

    scatter = ax.scatter(
        coords[0, :], coords[1, :], c=color, cmap='viridis', s=20, alpha=0.7
    )
    ax.set_title(f'{name} Embedding')
    ax.set_xlabel('Component 1')
    ax.set_ylabel('Component 2')

    # Add colorbar to first subplot
    if i == 0:
        plt.colorbar(scatter, ax=ax, label='Position on roll')

plt.tight_layout()
plt.show()

### Advanced: sequential DR, custom distance metrics, sparse input

In [None]:
# Advanced usage: sequential DR (PCA -> UMAP), custom distance metric, sparse input
print('ADVANCED USAGE PATTERNS')
print('=' * 60)

# Generate high-dimensional data
high_dim_data = np.random.randn(100, 500)  # 100 features, 500 samples
mvdata_highdim = MVData(high_dim_data)

print('\n1. High-dimensional data:')
emb_10d = mvdata_highdim.get_embedding(method='pca', dim=10)
print(f'   -> 10D embedding shape: {emb_10d.coords.shape}')

mvdata_10d = MVData(emb_10d.coords)
emb_2d = mvdata_10d.get_embedding(method='umap')
print(f'   -> Final 2D embedding shape: {emb_2d.coords.shape}')

print('\n2. Using custom metrics:')
emb_cosine = mvdata.get_embedding(method='isomap', metric='cosine')
print(f'   -> Cosine metric embedding shape: {emb_cosine.coords.shape}')

print('\n3. Handling sparse data:')
sparse_data = csr_matrix(X)
print(f'   -> Sparse matrix shape: {sparse_data.shape}')
mvdata_sparse = MVData(sparse_data)
emb_sparse = mvdata_sparse.get_embedding(method='pca')
print(f'   -> Sparse data embedding shape: {emb_sparse.coords.shape}')

print('\n4. Sequential dimensionality reduction (use high-dim data):')
print('   Method 1 (intuitive - manual chaining):')
emb1_seq = mvdata_highdim.get_embedding(method='pca', dim=20)
mvdata2_seq = MVData(emb1_seq.coords)
emb2_seq = mvdata2_seq.get_embedding(method='umap', dim=2)
print(f'   -> Result shape: {emb2_seq.coords.shape}')

print('\n   Method 2 (recommended - using dr_sequence):')
emb_seq = dr_sequence(mvdata_highdim, steps=[
    ('pca', {'dim': 20}),
    ('umap', {'dim': 2})
])
print(f'   -> Result shape: {emb_seq.coords.shape}')

### Graph structure behind DR

Graph-based DR methods (Isomap, LLE, Laplacian Eigenmaps) don't just
produce coordinates -- they construct an internal **proximity graph** where
nodes are data points and edges connect neighbors. In DRIADA, this graph
is a [`ProximityGraph`](https://driada.readthedocs.io/en/latest/api/dim_reduction/data_structures.html)
that **inherits from [`Network`](https://driada.readthedocs.io/en/latest/api/network/core.html)**,
giving you access to spectral decomposition, entropy, community detection,
and all other `Network` analysis methods.

Access it via `embedding.graph` after running any graph-based method.
For a full treatment of network spectral analysis, see
[Notebook 04 -- Network analysis](https://colab.research.google.com/github/iabs-neuro/driada/blob/main/notebooks/04_network_analysis.ipynb).
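To give a sense of what such a graph contains, here is one common way to build a symmetric k-NN proximity graph with scikit-learn and SciPy. This is a sketch, not necessarily how DRIADA constructs `ProximityGraph` internally: its neighbor count, symmetrization rule, and handling of lost nodes may differ. The quantities it prints are the same kind reported in the next cell, which uses the `nnz // 2` edge count and the mean degree.

```python
import numpy as np
from scipy.sparse.csgraph import connected_components
from sklearn.datasets import make_swiss_roll
from sklearn.neighbors import kneighbors_graph

X, _ = make_swiss_roll(n_samples=1000, noise=0.1, random_state=42)  # (n_samples, 3)

# Directed k-NN graph: row i has exactly k nonzeros, the neighbors of point i
k = 10
A_directed = kneighbors_graph(X, n_neighbors=k, mode='connectivity', include_self=False)

# Symmetrize with an "OR" rule: keep an edge if either point lists the other
A = A_directed.maximum(A_directed.T).tocsr()

n_nodes = A.shape[0]
n_edges = A.nnz // 2                        # each undirected edge is stored twice
degrees = np.asarray(A.sum(axis=1)).ravel()  # binary weights, so row sums are degrees
n_comp, _ = connected_components(A, directed=False)

print(f'Nodes: {n_nodes}, edges: {n_edges}, '
      f'mean degree: {degrees.mean():.1f}, components: {n_comp}')
```

Under the "OR" rule every node keeps at least its own k neighbors, so the mean degree is at least k. An "AND" (mutual k-NN) rule gives sparser graphs that are more likely to split into components.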

In [None]:
# The Isomap embedding from Section 1 built a k-NN proximity graph internally
pgraph = emb_iso.graph

print(f"Type: {type(pgraph).__name__}")
print(f"  inherits from Network: {isinstance(pgraph, Network)}")
print(f"Nodes: {pgraph.n}")
print(f"Edges: {pgraph.adj.nnz // 2}")
print(f"Mean degree: {pgraph.deg.mean():.1f}")
print(f"Metric used: {pgraph.metric}")

In [None]:
# Spectral analysis of the k-NN graph that powers Isomap
pgraph.diagonalize(mode='nlap')
nlap_spectrum = pgraph.get_spectrum('nlap')
ipr = pgraph.get_ipr('nlap')

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Normalized Laplacian spectrum
sorted_spec = np.sort(np.real(nlap_spectrum))
axes[0].plot(sorted_spec, 'o', markersize=2)
axes[0].set_xlabel('Index')
axes[0].set_ylabel('Eigenvalue')
axes[0].set_title('Normalized Laplacian spectrum')
axes[0].grid(True, alpha=0.3)

# IPR -- eigenvector localization
axes[1].plot(np.sort(ipr), 'o', markersize=2)
axes[1].axhline(1.0 / pgraph.n, color='r', linestyle='--',
                label=f'1/N = {1.0/pgraph.n:.4f}')
axes[1].set_xlabel('Eigenvector index (sorted)')
axes[1].set_ylabel('IPR')
axes[1].set_title('Inverse participation ratio')
axes[1].legend(fontsize=9)
axes[1].grid(True, alpha=0.3)

# Thermodynamic entropy
tlist = np.logspace(-2, 2, 50)
entropy = pgraph.calculate_thermodynamic_entropy(tlist, norm=True)
axes[2].semilogx(tlist, entropy, linewidth=2)
axes[2].set_xlabel('Temperature')
axes[2].set_ylabel('Entropy (bits)')
axes[2].set_title('Von Neumann entropy S(t)')
axes[2].grid(True, alpha=0.3)

plt.suptitle('Spectral analysis of Isomap k-NN graph', fontsize=14)
plt.tight_layout()
plt.show()

print(f'Fiedler value: {sorted_spec[1]:.4f}')
print(f'Spectral gap: {sorted_spec[1] - sorted_spec[0]:.4f}')
print(f'Max entropy: {np.max(entropy):.2f} bits '
      f'(upper bound = log2(N) = {np.log2(pgraph.n):.2f})')

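The Laplacian spectrum reveals the graph's connectivity structure:
the number of zero eigenvalues equals the number of connected components,
a large gap between zero and the next eigenvalue (the Fiedler value) indicates
a well-connected graph, and several eigenvalues close to zero suggest loosely
connected clusters. The IPR of a unit-norm eigenvector is the sum of the fourth
powers of its entries. It is close to 1/N (the red dashed line) when the
eigenvector is delocalized (spread evenly across all nodes) and grows toward 1
as the eigenvector becomes localized on fewer nodes.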

These same spectral tools apply to *any* `Network` -- functional
connectivity from INTENSE, structural connectomes, or correlation
networks. See [Notebook 04](https://colab.research.google.com/github/iabs-neuro/driada/blob/main/notebooks/04_network_analysis.ipynb)
for the full spectral analysis toolkit.
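The three quantities plotted above can be reproduced on a toy graph with NumPy alone. This sketch uses a ring (cycle) graph because its normalized Laplacian spectrum is known in closed form: the eigenvalues are 1 - cos(2πk/N). It also uses the standard spectral definition of von Neumann entropy, S(t) = -Σ p_i log2 p_i with p_i ∝ exp(-t λ_i). The helper names are hypothetical, and DRIADA's `calculate_thermodynamic_entropy` and its `norm` option may differ in detail.

```python
import numpy as np

def ring_adjacency(n):
    """Adjacency matrix of a cycle graph: node i is linked to i-1 and i+1 (mod n)."""
    A = np.zeros((n, n))
    idx = np.arange(n)
    A[idx, (idx + 1) % n] = 1.0
    A[(idx + 1) % n, idx] = 1.0
    return A

def normalized_laplacian(A):
    """L = I - D^{-1/2} A D^{-1/2} for an undirected graph without isolated nodes."""
    d_inv_sqrt = 1.0 / np.sqrt(A.sum(axis=1))
    return np.eye(len(A)) - d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]

def ipr(eigvecs):
    """Inverse participation ratio of each unit-norm eigenvector (one per column)."""
    return np.sum(eigvecs ** 4, axis=0)

def von_neumann_entropy(eigvals, t):
    """Entropy in bits of rho(t) = exp(-t L) / Tr exp(-t L), from the eigenvalues of L."""
    w = np.exp(-t * (eigvals - eigvals.min()))  # shifting by the minimum leaves p unchanged
    p = w / w.sum()
    p = p[p > 0]
    return -np.sum(p * np.log2(p))

N = 100
L = normalized_laplacian(ring_adjacency(N))
evals, evecs = np.linalg.eigh(L)  # ascending eigenvalues, orthonormal eigenvectors

print(f'Smallest eigenvalue: {evals[0]:.2e}')
print(f'Fiedler value: {evals[1]:.4f} (theory: {1 - np.cos(2 * np.pi / N):.4f})')
print(f'IPR of the zero mode: {ipr(evecs)[0]:.4f} (1/N = {1 / N:.4f})')
print(f'Entropy at t=0.01: {von_neumann_entropy(evals, 0.01):.2f} bits '
      f'(log2 N = {np.log2(N):.2f})')
```

At small t, all p_i are nearly equal and the entropy approaches the log2(N) upper bound printed in the cell above. At large t, only the zero eigenvalues survive, so S(t) tends to log2 of the number of connected components.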

## 2. Method comparison

Systematic benchmark of three methods (PCA, Isomap, UMAP) on six synthetic datasets. Quality metrics:

- [`knn_preservation_rate`](https://driada.readthedocs.io/en/latest/api/dim_reduction/manifold_metrics.html) -- fraction of original k nearest neighbors preserved in the embedding (higher is better).
- [`trustworthiness`](https://driada.readthedocs.io/en/latest/api/dim_reduction/manifold_metrics.html) -- fraction of embedding neighbors that are true neighbors in the original space (higher is better).
- [`continuity`](https://driada.readthedocs.io/en/latest/api/dim_reduction/manifold_metrics.html) -- fraction of true neighbors that remain neighbors in the embedding (higher is better).
- [`stress`](https://driada.readthedocs.io/en/latest/api/dim_reduction/manifold_metrics.html) -- normalized Frobenius distance between the original and embedded distance matrices (lower is better).

For a single composite score, use [`manifold_preservation_score`](https://driada.readthedocs.io/en/latest/api/dim_reduction/manifold_metrics.html).
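To make these definitions concrete, here is a minimal sketch of k-NN preservation and normalized stress using NumPy and scikit-learn. These are illustrative re-implementations under the definitions above, not DRIADA's code, and the helper names are hypothetical. Trustworthiness comes from scikit-learn's `sklearn.manifold.trustworthiness`. Continuity is computed as trustworthiness with the two spaces swapped, which follows from the two metrics being defined symmetrically. DRIADA's versions may normalize or break ties differently, so absolute values need not match.

```python
import numpy as np
from scipy.spatial.distance import pdist
from sklearn.datasets import make_swiss_roll
from sklearn.decomposition import PCA
from sklearn.manifold import trustworthiness as sk_trustworthiness
from sklearn.neighbors import NearestNeighbors

def knn_sets(X, k):
    """Indices of the k nearest neighbors of every row of X, excluding the point itself."""
    nn = NearestNeighbors(n_neighbors=k).fit(X)
    # With no query argument, scikit-learn does not count a point as its own neighbor
    return nn.kneighbors(return_distance=False)

def knn_preservation(X_high, X_low, k=10):
    """Mean fraction of each point's k nearest neighbors shared by both spaces."""
    high, low = knn_sets(X_high, k), knn_sets(X_low, k)
    overlap = [len(set(h) & set(l)) for h, l in zip(high, low)]
    return np.mean(overlap) / k

def normalized_stress(X_high, X_low):
    """||D_high - D_low|| / ||D_high|| over all pairwise Euclidean distances."""
    d_high, d_low = pdist(X_high), pdist(X_low)
    return np.linalg.norm(d_high - d_low) / np.linalg.norm(d_high)

X, _ = make_swiss_roll(n_samples=800, noise=0.05, random_state=0)  # (n_samples, 3)
Y = PCA(n_components=2).fit_transform(X)                           # (n_samples, 2)

print(f'k-NN preservation: {knn_preservation(X, Y):.3f}')
print(f'Trustworthiness:   {sk_trustworthiness(X, Y, n_neighbors=10):.3f}')
print(f'Continuity:        {sk_trustworthiness(Y, X, n_neighbors=10):.3f}')
print(f'Normalized stress: {normalized_stress(X, Y):.3f}')
```

Stress is computed on the condensed distance vector from `pdist`. The full matrix contains each distance twice, which scales both norms by the same factor, so the ratio is unchanged.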

In [None]:
def generate_test_datasets(n_samples=1000, noise=0.0, seed=42):
    """Generate various test datasets for DR method comparison."""
    np.random.seed(seed)
    datasets = {}

    # 1. Swiss Roll - classic nonlinear manifold
    print('Generating Swiss Roll...')
    X_swiss, color_swiss = make_swiss_roll(
        n_samples=n_samples, noise=noise, random_state=seed
    )
    datasets['swiss_roll'] = (X_swiss, color_swiss)

    # 2. S-Curve - another nonlinear manifold
    print('Generating S-Curve...')
    X_scurve, color_scurve = make_s_curve(
        n_samples=n_samples, noise=noise, random_state=seed
    )
    datasets['s_curve'] = (X_scurve, color_scurve)

    # 3. Circular manifold - tests circular topology
    print('Generating Circular manifold...')
    angles = np.linspace(0, 2 * np.pi, n_samples, endpoint=False)
    circle_3d = np.column_stack([
        np.cos(angles), np.sin(angles),
        0.1 * np.random.randn(n_samples),
    ])
    datasets['circle_3d'] = (circle_3d, angles)

    # 4. High-dimensional Gaussian (intrinsic dim ~5 in 50D)
    print('Generating High-D Gaussian...')
    U = np.random.randn(50, 5)
    V = np.random.randn(5, n_samples)
    X_gaussian = (U @ V + noise * np.random.randn(50, n_samples)).T
    datasets['gaussian_50d'] = (X_gaussian, X_gaussian @ U[:, 0])

    # 5. Clustered data (5 clusters in 20D)
    print('Generating Clustered data...')
    X_clusters, y_clusters = make_blobs(
        n_samples=n_samples, n_features=20,
        centers=5, cluster_std=0.5, random_state=seed,
    )
    datasets['clusters_20d'] = (X_clusters, y_clusters)

    # 6. Noisy sphere
    print('Generating Noisy sphere...')
    phi = np.random.uniform(0, 2 * np.pi, n_samples)
    theta = np.random.uniform(0, np.pi, n_samples)
    sphere_3d = np.column_stack([
        np.sin(theta) * np.cos(phi),
        np.sin(theta) * np.sin(phi),
        np.cos(theta),
    ])
    sphere_3d += noise * np.random.randn(n_samples, 3)
    datasets['sphere_3d'] = (sphere_3d, phi)

    return datasets


# Generate test datasets
print('1. GENERATING TEST DATASETS')
print('-' * 40)
datasets = generate_test_datasets(n_samples=1000, noise=0.05)
print(f'\nGenerated {len(datasets)} test datasets')

In [None]:
def get_dr_method_configs() -> Dict[str, Dict]:
    """Get configuration parameters for each DR method."""
    configs = {}
    configs['pca'] = {
        'params': {},
        'description': 'Linear projection maximizing variance',
    }
    configs['isomap'] = {
        'params': {'n_neighbors': 10, 'max_deleted_nodes': 0.3},
        'description': 'Preserves geodesic distances',
    }
    configs['umap'] = {
        'params': {'n_neighbors': 15, 'min_dist': 0.1},
        'description': 'Balances local and global structure',
    }
    return configs


# Configure methods
print('2. CONFIGURING DR METHODS')
print('-' * 40)
methods = get_dr_method_configs()
print(f'Configured {len(methods)} DR methods: {", ".join(methods.keys())}')

In [None]:
def evaluate_dr_method(data, labels, method_name, method_config, k_neighbors=10):
    """Evaluate a single DR method on a dataset and return quality metrics."""
    results = {
        'method': method_name,
        'description': method_config.get('description', ''),
        'success': False,
    }

    try:
        mvdata = MVData(data.T)  # MVData expects (n_features, n_samples)
        tracemalloc.start()
        start_time = time.time()
        params = method_config['params']

        if method_config.get('requires_distmat', False):
            mvdata.get_distmat()

        embedding = mvdata.get_embedding(method=method_name, **params)
        runtime = time.time() - start_time
        current, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()

        coords = embedding.coords.T  # Shape: (n_samples, n_dims)

        # Handle lost nodes if graph-based method
        data_filtered, labels_filtered = data, labels
        if hasattr(embedding, 'graph') and hasattr(embedding.graph, 'lost_nodes'):
            lost_nodes = embedding.graph.lost_nodes
            if len(lost_nodes) > 0:
                kept_mask = np.ones(data.shape[0], dtype=bool)
                kept_mask[lost_nodes] = False
                data_filtered = data[kept_mask]
                labels_filtered = labels[kept_mask]

        # Compute quality metrics
        if coords.shape[0] == data_filtered.shape[0]:
            results.update({
                'success': True,
                'embedding': coords,
                'labels': labels_filtered,
                'runtime': runtime,
                'memory_mb': peak / 1024 / 1024,
                'knn_preservation': knn_preservation_rate(data_filtered, coords, k=k_neighbors),
                'trustworthiness': trustworthiness(data_filtered, coords, k=k_neighbors),
                'continuity': continuity(data_filtered, coords, k=k_neighbors),
                'stress': stress(data_filtered, coords, normalized=True),
                'n_samples': coords.shape[0],
                'n_lost': data.shape[0] - coords.shape[0],
            })
        else:
            results['error'] = (
                f'Dimension mismatch: {coords.shape[0]} vs {data_filtered.shape[0]}'
            )
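    except Exception as e:
        results['error'] = f'{type(e).__name__}: {str(e)}'
    finally:
        # Stop memory tracing even if the embedding failed before tracemalloc.stop()
        if tracemalloc.is_tracing():
            tracemalloc.stop()

    return results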

In [None]:
def run_comparison(
    datasets: Dict[str, Tuple[np.ndarray, np.ndarray]],
    methods: Dict[str, Dict],
) -> pd.DataFrame:
    """
    Run systematic comparison of DR methods on all datasets.
    """
    all_results = []
    total_comparisons = len(datasets) * len(methods)
    current = 0

    for dataset_name, (data, labels) in datasets.items():
        print(f'\n{"=" * 60}')
        print(f'Dataset: {dataset_name} (shape: {data.shape})')
        print(f'{"=" * 60}')

        # Standardize data
        scaler = StandardScaler()
        data_scaled = scaler.fit_transform(data)

        for method_name, method_config in methods.items():
            current += 1
            print(
                f'\n[{current}/{total_comparisons}] Evaluating {method_name}...',
                end='', flush=True,
            )

            result = evaluate_dr_method(data_scaled, labels, method_name, method_config)
            result['dataset'] = dataset_name
            result['n_features'] = data.shape[1]

            if result['success']:
                print(
                    f' Done! (runtime: {result["runtime"]:.2f}s, '
                    + f'k-NN: {result["knn_preservation"]:.3f}, '
                    + f'trust: {result["trustworthiness"]:.3f})'
                )
            else:
                print(f' Failed! ({result.get("error", "Unknown error")})')

            all_results.append(result)

    return pd.DataFrame(all_results)


# Run the comparison
print('\n3. RUNNING SYSTEMATIC COMPARISON')
print('-' * 40)
results_df = run_comparison(datasets, methods)

### Speed benchmark & quality summary

In [None]:
# Summarize successful runs and pick the best method per criterion
success_df = results_df[results_df['success']].copy()

# Calculate summary statistics
method_summary = (
    success_df.groupby('method')
    .agg({
        'knn_preservation': 'mean',
        'trustworthiness': 'mean',
        'continuity': 'mean',
        'stress': 'mean',
        'runtime': 'mean',
    })
    .round(3)
)

# Best overall quality
success_df['avg_quality'] = success_df[
    ['knn_preservation', 'trustworthiness', 'continuity']
].mean(axis=1)
best_quality = success_df.groupby('method')['avg_quality'].mean().idxmax()

# Fastest method
fastest = method_summary['runtime'].idxmin()

# Best for visualization (high trustworthiness)
best_viz = method_summary['trustworthiness'].idxmax()

print('[SUMMARY] METHOD SUMMARY:')
print(method_summary)

print(f'\n[BEST] Best Overall Quality: {best_quality}')
print(f'[BEST] Fastest: {fastest} (avg {method_summary.loc[fastest, "runtime"]:.3f}s)')
print(f'[BEST] Best Visualization: {best_viz} '
       f'(trustworthiness: {method_summary.loc[best_viz, "trustworthiness"]:.3f})')

print('\n[RECOMMENDATIONS] GENERAL RULES OF THUMB (not derived from the table above):')
print('  - Exploratory visualization: UMAP')
print('  - Geodesic / global distance preservation: Isomap')
print('  - Linear relationships: PCA (fast, interpretable)')
print('  - Nonlinear manifold learning: Isomap or UMAP')
print('  - Large datasets: PCA or UMAP (computationally efficient)')

In [None]:
# Visualizations: quality metrics heatmap
plt.figure(figsize=(12, 8))

metrics = ['knn_preservation', 'trustworthiness', 'continuity']
metric_labels = {
    'knn_preservation': 'k-NN preservation',
    'trustworthiness': 'Trustworthiness',
    'continuity': 'Continuity',
}

for i, metric in enumerate(metrics):
    plt.subplot(2, 2, i + 1)
    pivot = success_df.pivot_table(
        values=metric, index='method', columns='dataset', aggfunc='mean'
    )
    sns.heatmap(
        pivot, annot=True, fmt='.3f', cmap='RdYlGn',
        vmin=0, vmax=1,
        cbar_kws={'label': metric_labels[metric]},
    )
    plt.title(f'{metric_labels[metric]} by method and dataset')
    plt.xlabel('Dataset')
    plt.ylabel('Method')

# Runtime comparison
plt.subplot(2, 2, 4)
runtime_pivot = success_df.pivot_table(
    values='runtime', index='method', columns='dataset', aggfunc='mean'
)
sns.heatmap(
    np.log10(runtime_pivot + 0.001), annot=runtime_pivot.round(2),
    fmt='g', cmap='YlOrRd',
    cbar_kws={'label': 'log10(Runtime in seconds)'},
)
plt.title('Runtime by method and dataset')
plt.xlabel('Dataset')
plt.ylabel('Method')

plt.tight_layout()
plt.show()

In [None]:
# Quality vs Speed trade-off
plt.figure(figsize=(10, 6))

for dataset in success_df['dataset'].unique():
    mask = success_df['dataset'] == dataset
    plt.scatter(
        success_df[mask]['runtime'],
        success_df[mask]['avg_quality'],
        label=dataset, s=100, alpha=0.7,
    )

    # Add method labels
    for _, row in success_df[mask].iterrows():
        plt.annotate(
            row['method'],
            (row['runtime'], row['avg_quality']),
            xytext=(5, 5), textcoords='offset points',
            fontsize=8, alpha=0.7,
        )

plt.xscale('log')
plt.xlabel('Runtime (seconds, log scale)')
plt.ylabel('Average Quality Score')
plt.title('Quality vs speed trade-off')
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# Example embeddings for Swiss Roll
swiss_results = success_df[success_df['dataset'] == 'swiss_roll']

if len(swiss_results) > 0:
    n_methods = len(swiss_results)
    fig, axes = plt.subplots(1, n_methods, figsize=(5 * n_methods, 5))
    if n_methods == 1:
        axes = [axes]

    for i, (_, result) in enumerate(swiss_results.iterrows()):
        ax = axes[i]
        embedding = result['embedding']
        labels = result['labels']

        scatter = ax.scatter(
            embedding[:, 0], embedding[:, 1],
            c=labels, cmap='viridis', s=20, alpha=0.7,
        )

        ax.set_title(
            f"{result['method']}\n"
            f"(k-NN: {result['knn_preservation']:.3f}, "
            f"Trust: {result['trustworthiness']:.3f})"
        )
        ax.set_xlabel('Dim 1')
        ax.set_ylabel('Dim 2')

    plt.suptitle('Swiss roll embeddings', fontsize=16)
    plt.tight_layout()
    plt.show()

## 3. Sequential DR on neural data

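PCA first (denoise), then UMAP. This often works better than direct UMAP on
high-dimensional neural data: PCA discards low-variance, noise-dominated
dimensions, so UMAP's nearest-neighbor graph is built from distances dominated
by shared population structure rather than independent per-neuron noise.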
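The same two-step idea can be sketched with scikit-learn alone. Here UMAP is swapped for scikit-learn's `Isomap`, so the example runs without `umap-learn`. The "recording" is hypothetical: a Swiss roll linearly mixed into 100 noisy channels. Inside DRIADA the equivalent pipeline is `dr_sequence` with a `('pca', ...)` step followed by a `('umap', ...)` step, as in the cell below. The channel count, noise level, and `n_components` values here are illustrative choices.

```python
import numpy as np
from sklearn.datasets import make_swiss_roll
from sklearn.decomposition import PCA
from sklearn.manifold import Isomap
from sklearn.pipeline import make_pipeline

rng = np.random.default_rng(0)

# Hypothetical recording: a 3D Swiss roll mixed into 100 channels plus channel noise
X3, t = make_swiss_roll(n_samples=600, noise=0.05, random_state=0)
mixing = rng.normal(size=(3, 100))
X100 = X3 @ mixing + 0.5 * rng.normal(size=(600, 100))  # (n_samples, n_channels)

# Step 1: PCA keeps the 10 highest-variance directions (denoising)
# Step 2: a nonlinear method unrolls the manifold in the reduced space
pipe = make_pipeline(PCA(n_components=10), Isomap(n_neighbors=12, n_components=2))
emb = pipe.fit_transform(X100)  # (n_samples, 2)

print(emb.shape)
```

Note that scikit-learn returns *(n_samples, dim)*, whereas DRIADA's `embedding.coords` is *(dim, n_samples)*.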

In [None]:
def compute_manifold_metrics(true_positions, embedding_coords, k=10):
    """Compute various manifold preservation metrics."""
    metrics = {}

    # KNN preservation
    metrics['knn_preservation'] = knn_preservation_rate(
        true_positions.T, embedding_coords.T, k=k,
    )

    # Trustworthiness and continuity
    metrics['trustworthiness'] = trustworthiness(
        true_positions.T, embedding_coords.T, k=k
    )
    metrics['continuity'] = continuity(
        true_positions.T, embedding_coords.T, k=k
    )

    # Overall manifold preservation score (returns dict)
    manifold_scores = manifold_preservation_score(
        true_positions.T, embedding_coords.T, k_neighbors=k
    )
    # Extract the overall score
    metrics['manifold_score'] = manifold_scores['overall_score']

    return metrics


print('Generating synthetic neural data from 2D spatial environment...')

# Generate synthetic data with 2D spatial manifold
calcium, positions, place_field_centers, firing_rates = generate_2d_manifold_data(
    n_neurons=100,
    duration=800,  # seconds
    fps=20.0,  # Hz
    field_sigma=0.1,
    step_size=0.02,
    seed=123,
    verbose=True,
)

# Extract neural activity and true positions
neural_data = calcium  # (n_neurons, n_timepoints)
true_positions = positions  # (2, n_timepoints)

print(f'Neural data shape: {neural_data.shape}')
print(f'True positions shape: {true_positions.shape}')

In [None]:
# Create MVData object with downsampling
mvdata_neural = MVData(neural_data, downsampling=5)

# Approach 1: Direct UMAP on all neurons
print('\n=== Approach 1: Direct UMAP ===')
embedding_direct = mvdata_neural.get_embedding(
    method='umap', dim=2, n_neighbors=50, min_dist=0.8
)

# Approach 2: PCA -> UMAP sequence
print('\n=== Approach 2: PCA -> UMAP ===')
embedding_sequence = dr_sequence(
    mvdata_neural,
    steps=[
        ('pca', {'dim': 20}),  # First reduce to 20 PCs
        ('umap', {'dim': 2, 'n_neighbors': 50, 'min_dist': 0.8}),
    ],
)

# Compute manifold preservation metrics
print('\n=== Manifold Preservation Metrics ===')

# Downsample true positions to match the embeddings
true_positions_ds = true_positions[:, ::5]

metrics_direct = compute_manifold_metrics(
    true_positions_ds, embedding_direct.coords
)
metrics_sequence = compute_manifold_metrics(
    true_positions_ds, embedding_sequence.coords
)

print('\nDirect UMAP:')
for name, value in metrics_direct.items():
    print(f'  {name}: {value:.4f}')

print('\nPCA -> UMAP:')
for name, value in metrics_sequence.items():
    print(f'  {name}: {value:.4f}')

In [None]:
# Visualization: side-by-side comparison
fig = plt.figure(figsize=(15, 5))

# Plot true positions
ax1 = plt.subplot(131)
scatter = ax1.scatter(
    true_positions_ds[0], true_positions_ds[1],
    c=np.arange(true_positions_ds.shape[1]),
    cmap='viridis', alpha=0.6, s=20,
)
ax1.set_title('True 2D Positions')
ax1.set_xlabel('X position')
ax1.set_ylabel('Y position')
ax1.set_aspect('equal')

# Plot direct UMAP embedding
ax2 = plt.subplot(132)
ax2.scatter(
    embedding_direct.coords[0], embedding_direct.coords[1],
    c=np.arange(embedding_direct.coords.shape[1]),
    cmap='viridis', alpha=0.6, s=20,
)
ax2.set_title(
    f'Direct UMAP\n(Manifold score: {metrics_direct["manifold_score"]:.3f})'
)
ax2.set_xlabel('UMAP 1')
ax2.set_ylabel('UMAP 2')
ax2.set_aspect('equal')

# Plot PCA->UMAP embedding
ax3 = plt.subplot(133)
ax3.scatter(
    embedding_sequence.coords[0], embedding_sequence.coords[1],
    c=np.arange(embedding_sequence.coords.shape[1]),
    cmap='viridis', alpha=0.6, s=20,
)
ax3.set_title(
    f'PCA -> UMAP\n(Manifold score: {metrics_sequence["manifold_score"]:.3f})'
)
ax3.set_xlabel('UMAP 1')
ax3.set_ylabel('UMAP 2')
ax3.set_aspect('equal')

plt.colorbar(scatter, ax=[ax1, ax2, ax3], label='Time', fraction=0.02)
plt.tight_layout()
plt.show()

# Print improvement percentages
print('\n=== Method Comparison ===')
print('PCA -> UMAP vs Direct UMAP:')
for metric in ['knn_preservation', 'trustworthiness', 'continuity']:
    improvement = (
        (metrics_sequence[metric] - metrics_direct[metric])
        / metrics_direct[metric] * 100
    )
    print(f'  {metric}: {improvement:+.1f}%')

In [None]:
# Detailed metrics comparison bar chart
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Plot true positions for reference
time_points = np.arange(true_positions_ds.shape[1])
ax = axes[0]
scatter = ax.scatter(
    true_positions_ds[0], true_positions_ds[1],
    c=time_points, cmap='viridis', s=20, alpha=0.6,
)
ax.set_title('True 2D Positions')
ax.set_xlabel('X position')
ax.set_ylabel('Y position')
ax.set_aspect('equal')
plt.colorbar(scatter, ax=ax, label='Time')

# Summary metrics comparison
ax = axes[1]
metrics_names = list(metrics_direct.keys())
x = np.arange(len(metrics_names))
width = 0.35

values_direct = [metrics_direct[m] for m in metrics_names]
values_sequence = [metrics_sequence[m] for m in metrics_names]

ax.bar(x - width / 2, values_direct, width, label='Direct UMAP', alpha=0.8)
ax.bar(x + width / 2, values_sequence, width, label='PCA -> UMAP', alpha=0.8)

ax.set_ylabel('Score')
ax.set_title('Manifold preservation metrics')
ax.set_xticks(x)
ax.set_xticklabels(
    [m.replace('_', '\n') for m in metrics_names], rotation=45, ha='right'
)
ax.legend()
ax.set_ylim(0, 1.1)
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

# Print improvement summary
print('\n=== Improvement Summary ===')
for metric in metrics_names:
    improvement = (
        (metrics_sequence[metric] - metrics_direct[metric])
        / metrics_direct[metric] * 100
    )
    print(f'{metric}: {improvement:+.1f}%')

## 4. Autoencoder-based DR

Neural-network alternatives to classical DR are available via [`flexible_ae`](https://driada.readthedocs.io/en/latest/api/dim_reduction/neural_methods.html). This section trains a **standard autoencoder** (AE) in stages with
`continue_learning` and a **Beta-VAE**, which adds a β-weighted KL-divergence term to the loss, and compares both against a PCA baseline.
The `architecture` parameter (`'ae'` or `'vae'`) selects the model type.
`continue_learning`, a method of the returned embedding, resumes training from the current weights instead of reinitializing them.
Requires PyTorch.
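For reference, the Beta-VAE objective adds a β-weighted KL term to the reconstruction error. For a diagonal-Gaussian encoder with mean μ and log-variance log σ², KL(q(z|x) || N(0, I)) = -½ Σ_j (1 + log σ_j² - μ_j² - σ_j²), and β = 1 recovers the standard VAE. The NumPy sketch below defines a hypothetical helper `beta_vae_loss` (not part of DRIADA) and evaluates it on random arrays. How `flexible_ae` averages or weights these terms is not specified here and may differ, for example summing over the batch instead of averaging.

```python
import numpy as np

def beta_vae_loss(x, x_hat, mu, logvar, beta=4.0):
    """Beta-VAE objective for a diagonal-Gaussian encoder q(z|x) = N(mu, diag(exp(logvar))).

    Returns (total, reconstruction, kl), where each term is the mean over the batch of
    the per-sample summed squared error and of the per-sample KL(q(z|x) || N(0, I)).
    """
    recon = np.mean(np.sum((x - x_hat) ** 2, axis=1))
    kl = np.mean(-0.5 * np.sum(1 + logvar - mu ** 2 - np.exp(logvar), axis=1))
    return recon + beta * kl, recon, kl

rng = np.random.default_rng(0)
x = rng.normal(size=(32, 200))                # batch of 32 samples, 200 neurons
x_hat = x + 0.1 * rng.normal(size=x.shape)    # stand-in for decoder output
mu = rng.normal(size=(32, 2))                 # 2D latent, matching dim=2 in the cells below
logvar = np.zeros((32, 2))                    # unit variance

total, recon, kl = beta_vae_loss(x, x_hat, mu, logvar, beta=4.0)
print(f'total={total:.3f}, reconstruction={recon:.3f}, KL={kl:.3f}')
```

Larger β pushes the latent posterior toward the N(0, I) prior at the cost of reconstruction quality. This trade-off is why the Beta-VAE run below is labeled with its β value.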

In [None]:
try:
    import torch  # noqa: F401
    HAS_TORCH = True
    print('PyTorch available -- autoencoder examples will run.')
except ImportError:
    HAS_TORCH = False
    print(
        'PyTorch not found. Install with: pip install torch\n'
        'Autoencoder cells will be skipped.'
    )

In [None]:
if HAS_TORCH:
    print('=' * 60)
    print('DRIADA autoencoder DR example')
    print('=' * 60)

    # ------------------------------------------------------------------
    # 1. Generate synthetic data (head direction cells on circular manifold)
    # ------------------------------------------------------------------
    print('\n[1] Generating synthetic head direction cell data')
    print('-' * 40)
    calcium_ae, head_direction_ae, preferred_dirs_ae, rates_ae = (
        generate_circular_manifold_data(
            n_neurons=200,
            kappa=4.0,
            duration=1200,
            seed=42,
            verbose=True,
        )
    )
    print(f'  Calcium shape: {calcium_ae.shape}')
    print(f'  Head direction shape: {head_direction_ae.shape}')

    mvdata_ae = MVData(calcium_ae, verbose=False)
    color_ae = head_direction_ae  # angle for coloring

In [None]:
if HAS_TORCH:
    # ------------------------------------------------------------------
    # 2. Standard autoencoder with continue_learning
    # ------------------------------------------------------------------
    print('\n[2] Standard autoencoder')
    print('-' * 40)

    # Train for 5 epochs (not fully converged)
    emb_ae = mvdata_ae.get_embedding(
        method='flexible_ae',
        dim=2,
        architecture='ae',
        inter_dim=64,
        epochs=5,
        lr=1e-3,
        feature_dropout=0.1,
        loss_components=[{'name': 'reconstruction', 'weight': 1.0, 'loss_type': 'mse'}],
        verbose=False,
    )
    print(f'  After 5 epochs   - loss: {emb_ae.nn_loss:.4f}')

    # Continue training for 25 more epochs
    emb_ae.continue_learning(25, lr=1e-3, verbose=False)
    print(f'  After 25 more    - loss: {emb_ae.nn_loss:.4f}')

    # Fine-tune with lower learning rate
    emb_ae.continue_learning(20, lr=1e-4, verbose=False)
    print(f'  After 20 fine-tune - loss: {emb_ae.nn_loss:.4f}')

    # ------------------------------------------------------------------
    # 3. Beta-VAE
    # ------------------------------------------------------------------
    print('\n[3] Beta-VAE (beta=4.0)')
    print('-' * 40)
    emb_vae = mvdata_ae.get_embedding(
        method='flexible_ae',
        dim=2,
        architecture='vae',
        inter_dim=64,
        epochs=150,
        lr=1e-3,
        feature_dropout=0.1,
        loss_components=[
            {'name': 'reconstruction', 'weight': 1.0, 'loss_type': 'mse'},
            {'name': 'beta_vae', 'weight': 1.0, 'beta': 4.0},
        ],
        verbose=False,
    )
    print(f'  Embedding shape: {emb_vae.coords.shape}')
    print(f'  Test loss (reconstruction + KL): {emb_vae.nn_loss:.4f}')

    # ------------------------------------------------------------------
    # 4. PCA baseline
    # ------------------------------------------------------------------
    print('\n[4] PCA baseline')
    print('-' * 40)
    emb_pca_ae = mvdata_ae.get_embedding(method='pca', dim=2)
    print(f'  Embedding shape: {emb_pca_ae.coords.shape}')

In [None]:
if HAS_TORCH:
    # ------------------------------------------------------------------
    # 5. Side-by-side visualization
    # ------------------------------------------------------------------
    print('\n[5] Creating comparison plot')
    print('-' * 40)

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    embeddings_ae = [
        (emb_pca_ae, 'PCA'),
        (emb_ae, 'AE (5 + 25 + 20 epochs)'),
        (emb_vae, 'Beta-VAE'),
    ]

    for ax, (emb, title) in zip(axes, embeddings_ae):
        coords = emb.coords  # (dim, n_samples)
        sc = ax.scatter(
            coords[0], coords[1], c=color_ae, cmap='hsv',
            s=2, alpha=0.5, vmin=0, vmax=2 * np.pi
        )
        ax.set_title(title)
        ax.set_xlabel('Dim 1')
        ax.set_ylabel('Dim 2')

    fig.colorbar(sc, ax=axes[-1], label='Head direction (rad)')
    plt.suptitle('Circular manifold recovery (colored by head direction)')
    plt.tight_layout()
    plt.show()

    print('\n' + '=' * 60)
    print('Autoencoder DR example complete')
    print('=' * 60)

## 5. Circular manifold & dimensionality estimation

Head direction (HD) cells jointly encode a ring, because heading is a single circular variable. We generate 100 HD cells with
[`generate_circular_manifold_exp`](https://driada.readthedocs.io/en/latest/api/experiment/synthetic.html) and estimate
the intrinsic dimensionality with several methods: [`pca_dimension`](https://driada.readthedocs.io/en/latest/api/dimensionality/linear.html) variance thresholds,
[`correlation_dimension`](https://driada.readthedocs.io/en/latest/api/dimensionality/intrinsic.html),
[`geodesic_dimension`](https://driada.readthedocs.io/en/latest/api/dimensionality/intrinsic.html),
and the [`eff_dim`](https://driada.readthedocs.io/en/latest/api/dimensionality/effective.html) participation ratio. We then compare real against
shuffled data and extract the circular manifold via DR
([`visualize_circular_manifold`](https://driada.readthedocs.io/en/latest/api/utils/visualization.html)).
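Two of the linear estimates have simple closed forms on the covariance eigenvalues. The participation ratio is `(sum of eigenvalues)^2 / (sum of squared eigenvalues)`. The PCA threshold estimate is the number of leading components needed to reach a cumulative explained-variance fraction. The NumPy sketch below implements these standard definitions. It is an assumption that `pca_dimension` and `eff_dim` (with `q=2`, no correction) match them exactly, and the helper names are hypothetical. The sketch also shows why linear estimates exceed 1 on a ring: even a perfect circle spans two linear dimensions.

```python
import numpy as np


def covariance_eigenvalues(data):
    """Eigenvalues of the feature covariance, in descending order.

    data : array of shape (n_samples, n_features)
    """
    cov = np.cov(data, rowvar=False)
    return np.sort(np.linalg.eigvalsh(cov))[::-1]


def participation_ratio(eigenvalues):
    """(sum of eigenvalues)^2 / (sum of squared eigenvalues)."""
    lam = np.asarray(eigenvalues, dtype=float)
    return float(lam.sum() ** 2 / np.sum(lam**2))


def n_components_for_variance(eigenvalues, threshold=0.90):
    """Smallest number of leading components whose variance fraction reaches `threshold`."""
    lam = np.asarray(eigenvalues, dtype=float)
    cumvar = np.cumsum(lam) / lam.sum()
    return int(np.searchsorted(cumvar, threshold) + 1)


# A noise-free unit circle: intrinsically 1-D, but it spans 2 linear dimensions
toy_angles = np.linspace(0, 2 * np.pi, 1000, endpoint=False)
toy_circle = np.column_stack([np.cos(toy_angles), np.sin(toy_angles)])
lam_circle = covariance_eigenvalues(toy_circle)
print(participation_ratio(lam_circle), n_components_for_variance(lam_circle, 0.90))
```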

In [None]:
def estimate_dimensionality(neural_data, methods=None, ds=1):
    """Estimate intrinsic dimensionality using multiple DRIADA methods."""
    if methods is None:
        methods = [
            'pca_90', 'pca_95', 'participation_ratio',
            'correlation_dim', 'geodesic_dim',
        ]

    dim_estimates = {}

    # Downsample data if requested
    if ds > 1:
        neural_data_ds = neural_data[:, ::ds]
        print(f'  Downsampled: {neural_data.shape} -> {neural_data_ds.shape}')
    else:
        neural_data_ds = neural_data

    # Transpose data for methods that expect (n_samples, n_features)
    data_transposed = neural_data_ds.T

    # Linear methods
    if 'pca_90' in methods:
        dim_estimates['pca_90'] = pca_dimension(data_transposed, threshold=0.90)
    if 'pca_95' in methods:
        dim_estimates['pca_95'] = pca_dimension(data_transposed, threshold=0.95)

    # Nonlinear intrinsic methods
    if 'correlation_dim' in methods:
        try:
            print('  Computing correlation dimension...')
            dim_estimates['correlation_dim'] = correlation_dimension(data_transposed)
        except Exception as e:
            print(f'  Warning: correlation_dimension failed: {e}')
            dim_estimates['correlation_dim'] = np.nan

    if 'geodesic_dim' in methods:
        try:
            print('  Computing geodesic dimension (this may take time)...')
            dim_estimates['geodesic_dim'] = geodesic_dimension(
                data_transposed, k=20, mode='fast', factor=4
            )
        except Exception as e:
            print(f'  Warning: geodesic_dimension failed: {e}')
            dim_estimates['geodesic_dim'] = np.nan

    # Effective dimensionality (participation ratio)
    if 'participation_ratio' in methods:
        dim_estimates['participation_ratio'] = eff_dim(
            data_transposed, enable_correction=False, q=2
        )

    return dim_estimates
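To build intuition for what a correlation-dimension estimate measures, here is a minimal Grassberger-Procaccia sketch. It computes C(r), the fraction of point pairs closer than r, over a range of radii, then fits the slope of log C(r) against log r. The radius range is a user choice in this sketch; DRIADA's `correlation_dimension` may select scales differently, so treat this as an assumption. The full pairwise-distance matrix makes the sketch O(n^2) in memory, so downsample long recordings first, as the `ds` argument does above.

```python
import numpy as np


def correlation_dimension_gp(points, r_min, r_max, n_radii=10):
    """Grassberger-Procaccia estimate: slope of log C(r) versus log r.

    points : array of shape (n_samples, n_features)
    C(r) is the fraction of distinct point pairs closer than r.
    """
    x = np.asarray(points, dtype=float)
    sq_norms = np.sum(x**2, axis=1)
    sq_dists = sq_norms[:, None] + sq_norms[None, :] - 2.0 * x @ x.T
    iu = np.triu_indices(len(x), k=1)  # each unordered pair once
    dists = np.sqrt(np.clip(sq_dists[iu], 0.0, None))
    radii = np.geomspace(r_min, r_max, n_radii)
    c = np.array([np.mean(dists < r) for r in radii])
    if np.any(c == 0):
        raise ValueError('r_min is too small: no pairs fall within some radii')
    slope, _ = np.polyfit(np.log(radii), np.log(c), 1)
    return float(slope)


toy_rng = np.random.default_rng(0)
toy_theta = toy_rng.uniform(0, 2 * np.pi, 800)
toy_ring = np.column_stack([np.cos(toy_theta), np.sin(toy_theta)])  # 1-D manifold in 2-D
toy_square = toy_rng.uniform(0, 1, size=(800, 2))  # filled 2-D region
print(correlation_dimension_gp(toy_ring, 0.05, 0.5))
print(correlation_dimension_gp(toy_square, 0.02, 0.2))
```

The ring gives a slope near 1 and the filled square a slope near 2. Edge effects bias the square's estimate slightly downward at larger radii.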

In [None]:
print('=' * 70)
print('CIRCULAR MANIFOLD EXTRACTION FROM HEAD DIRECTION CELLS')
print('=' * 70)

print('\n1. Generating head direction cell population...')

# Generate synthetic head direction cells
exp_circ, info_circ = generate_circular_manifold_exp(
    n_neurons=100,
    duration=600,  # 10 minutes
    kappa=4.0,     # Tuning width
    seed=42,
    verbose=True,
    return_info=True,
)

# Extract neural activity and true head directions
neural_data_circ = exp_circ.calcium.scdata  # Shape: (n_neurons, n_timepoints)
true_angles = info_circ['head_direction']   # Ground truth angles

print(f'\nGenerated {neural_data_circ.shape[0]} neurons, '
      f'{neural_data_circ.shape[1]} timepoints')
print(f'Neural activity shape: {neural_data_circ.shape}')

In [None]:
# Estimate intrinsic dimensionality
print('\n2. Estimating intrinsic dimensionality of neural population...')
print('-' * 50)

dim_methods = [
    'pca_90', 'pca_95', 'participation_ratio',
    'correlation_dim', 'geodesic_dim',
]

# Use ds=5 downsampling for faster computation
dim_estimates = estimate_dimensionality(
    neural_data_circ, methods=dim_methods, ds=5
)

print('Dimensionality estimates:')
for method, estimate in dim_estimates.items():
    print(f'  {method:20s}: {estimate:.2f}')

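print('\nNote: Head direction cells lie on a ring, so the intrinsic dimensionality should be ~1.')
print('      Linear estimates (PCA, participation ratio) come out higher because a ring')
print('      spans at least 2 linear dimensions; noise and finite sampling also inflate estimates.')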

# Compare with temporally shuffled data to demonstrate manifold structure
print('\n2b. Comparing with temporally shuffled data (destroys manifold)...')
print('-' * 50)

# Get shuffled calcium data from experiment
shuffled_calcium = exp_circ.get_multicell_shuffled_calcium()

# Estimate dimensionality on shuffled data
dim_estimates_shuffled = estimate_dimensionality(
    shuffled_calcium, methods=dim_methods, ds=5
)

print('\nDimensionality estimates (SHUFFLED data):')
for method, estimate in dim_estimates_shuffled.items():
    print(f'  {method:20s}: {estimate:.2f}')

print('\nComparison (Real vs Shuffled):')
print(f'{"Method":<20s} {"Real":>8s} {"Shuffled":>8s} {"Increase":>10s}')
print('-' * 50)
for method in dim_methods:
    real = dim_estimates[method]
    shuffled = dim_estimates_shuffled[method]
    increase = ((shuffled - real) / real) * 100
    print(f'{method:<20s} {real:8.2f} {shuffled:8.2f} {increase:+9.1f}%')

print('\nInterpretation: Temporal shuffling destroys the coordinated (manifold) structure across')
print('                neurons, so the dimensionality estimates should increase.')
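A temporally shuffled control like the one above can be built by circularly shifting each neuron's trace by an independent random offset. Each neuron keeps its own activity distribution and autocorrelation, but the cross-neuron coordination that defines the manifold is broken. It is an assumption that `get_multicell_shuffled_calcium` uses exactly this scheme. The hypothetical helper `circular_shift_shuffle` below only illustrates the idea.

```python
import numpy as np


def circular_shift_shuffle(data, seed=None):
    """Circularly shift each row (neuron) of a (n_neurons, n_timepoints) array
    by an independent random offset in [1, n_timepoints - 1]."""
    rng = np.random.default_rng(seed)
    n_neurons, n_time = data.shape
    shifts = rng.integers(1, n_time, size=n_neurons)
    return np.stack([np.roll(row, shift) for row, shift in zip(data, shifts)])


# Toy population: 20 neurons with von Mises-like tuning to a slowly drifting angle
toy_rng = np.random.default_rng(1)
toy_heading = np.cumsum(toy_rng.normal(0, 0.1, 2000)) % (2 * np.pi)
toy_preferred = np.linspace(0, 2 * np.pi, 20, endpoint=False)
toy_activity = np.exp(2.0 * np.cos(toy_heading[None, :] - toy_preferred[:, None]))
toy_shuffled = circular_shift_shuffle(toy_activity, seed=2)
```

Single-cell statistics are unchanged, so any loss of low-dimensional structure comes from the lost coordination between neurons.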

In [None]:
# Plot eigenvalue spectrum
print('\n3. Plotting eigenvalue spectrum...')

# Compute correlation matrix
data_centered = neural_data_circ - np.mean(neural_data_circ, axis=1, keepdims=True)
corr_mat = np.corrcoef(data_centered)

# Get eigenvalues
eigenvalues = np.linalg.eigvalsh(corr_mat)[::-1]  # Descending order
eigenvalues = eigenvalues[eigenvalues > 0]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Eigenvalue spectrum
ax1.plot(eigenvalues, 'o-', markersize=4)
ax1.set_xlabel('Component')
ax1.set_ylabel('Eigenvalue')
ax1.set_title('Eigenvalue spectrum')
ax1.set_yscale('log')
ax1.grid(True, alpha=0.3)

# Cumulative variance explained
cumvar = np.cumsum(eigenvalues) / np.sum(eigenvalues)
ax2.plot(cumvar, 'o-', markersize=4)
ax2.axhline(0.9, color='r', linestyle='--', label='90% variance')
ax2.axhline(0.95, color='orange', linestyle='--', label='95% variance')
ax2.set_xlabel('Number of Components')
ax2.set_ylabel('Cumulative Variance Explained')
ax2.set_title('Cumulative variance explained')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Apply dimensionality reduction using MVData
print('\n4. Applying dimensionality reduction methods using MVData...')
print('-' * 50)

# Create MVData object from calcium data with downsampling
downsampling_circ = 10
mvdata_circ = MVData(neural_data_circ, downsampling=downsampling_circ)

# Downsample true angles to match
true_angles_ds = true_angles[::downsampling_circ]

# Dictionary to store embeddings
embeddings_dict_circ = {}

# PCA
print('- PCA...')
pca_embedding_circ = mvdata_circ.get_embedding(method='pca', dim=2)
embeddings_dict_circ['PCA'] = pca_embedding_circ.coords.T
print(
    f'  First 2 PCs explain '
    f'{100 * sum(pca_embedding_circ.reducer_.explained_variance_ratio_):.1f}% of variance'
)

# Isomap
print('- Isomap...')
isomap_embedding_circ = mvdata_circ.get_embedding(
    method='isomap', dim=2, n_neighbors=50
)
embeddings_dict_circ['Isomap'] = isomap_embedding_circ.coords.T

# UMAP with large n_neighbors and min_dist to favor global (ring) structure
print('- UMAP...')
umap_embedding_circ = mvdata_circ.get_embedding(
    method='umap', dim=2, n_neighbors=100, min_dist=0.5
)
embeddings_dict_circ['UMAP'] = umap_embedding_circ.coords.T

In [None]:
# Visualize extracted manifolds
print('\n5. Visualizing extracted manifolds...')

# Create embedding comparison visualization
embeddings_list_circ = [
    embeddings_dict_circ[m] for m in ['PCA', 'Isomap', 'UMAP']
]
fig_embedding = visualize_circular_manifold(
    embeddings_list_circ, true_angles_ds, ['PCA', 'Isomap', 'UMAP']
)
plt.show()

# Trajectory visualization
print('\n6. Analyzing temporal continuity of extracted manifolds...')

# Plot only the first 1000 (downsampled) timepoints to keep trajectories readable
traj_len = min(1000, embeddings_dict_circ['PCA'].shape[0])
trajectories_dict = {
    method: emb[:traj_len] for method, emb in embeddings_dict_circ.items()
}

fig3 = plot_trajectories(
    embeddings=trajectories_dict,
    trajectory_kwargs={'arrow_spacing': 50, 'linewidth': 0.5, 'alpha': 0.5},
    figsize=(15, 5),
    dpi=DEFAULT_DPI,
)
plt.show()

In [None]:
# Summary statistics
print('\n7. Summary of manifold extraction quality:')
print('-' * 60)
print(f'{"Method":10s} | {"Correlation":12s} | {"Mean Error":12s} | {"Quality":8s}')
print('-' * 60)

for method, embedding in embeddings_dict_circ.items():
    # Use manifold metrics API
    alignment_metrics = compute_embedding_alignment_metrics(
        embedding, true_angles_ds, 'circular'
    )
    r = alignment_metrics['correlation']
    error = alignment_metrics['error']

    # Quality assessment
    if abs(r) > 0.95:
        quality_str = 'Excellent'
    elif abs(r) > 0.85:
        quality_str = 'Good'
    elif abs(r) > 0.70:
        quality_str = 'Fair'
    else:
        quality_str = 'Poor'

    print(f'{method:10s} | {r:12.3f} | {error:8.3f} rad | {quality_str:8s}')

print('\n' + '=' * 70)
print('CONCLUSIONS (expected behavior; check against the numbers above):')
print('- Head direction cells have low intrinsic dimensionality (~1-2)')
print('- Temporal shuffling destroys manifold structure (estimates increase)')
print('- Nonlinear methods (Isomap, UMAP) typically preserve circular topology better')
print('- PCA captures variance but may distort circular structure')
print('- Larger n_neighbors tends to preserve global structure')
print('=' * 70)
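Here is one simple way to score how well a 2-D embedding recovers the ring. First, read off each point's angle around the embedding centroid. Next, align it to the true heading with a rotation, plus an optional reflection, since DR embeddings have arbitrary orientation. Finally, average the absolute angular residual. This hypothetical helper is sketched for intuition only; it is not necessarily how `compute_embedding_alignment_metrics` computes its `'error'`. The rotation offset is the circular mean of the residuals, not an exact minimizer of the absolute error.

```python
import numpy as np


def wrap_angle(a):
    """Wrap angles to (-pi, pi]."""
    return np.angle(np.exp(1j * a))


def circular_alignment_error(embedding, true_angles):
    """Mean absolute angular error (rad) of a 2-D embedding vs. true angles.

    embedding : (n_samples, >=2) array; only the first two columns are used.
    Tries both orientations (reflection) and a circular-mean rotation offset.
    """
    xy = embedding[:, :2] - embedding[:, :2].mean(axis=0)
    theta = np.arctan2(xy[:, 1], xy[:, 0])
    errors = []
    for sign in (1.0, -1.0):  # sign = -1 undoes a mirror-image embedding
        residual = true_angles - sign * theta
        offset = np.angle(np.mean(np.exp(1j * residual)))  # circular mean
        errors.append(float(np.mean(np.abs(wrap_angle(residual - offset)))))
    return min(errors)


toy_angles = np.linspace(0, 2 * np.pi, 500, endpoint=False)
# Mirrored, rotated, scaled and shifted copy of the ring
toy_ring = 3.0 * np.column_stack([np.cos(1.0 - toy_angles), np.sin(1.0 - toy_angles)]) + [5.0, -2.0]
toy_noise = np.random.default_rng(0).standard_normal((500, 2))
print(circular_alignment_error(toy_ring, toy_angles),
      circular_alignment_error(toy_noise, toy_angles))
```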

## 6. INTENSE-guided DR

Use [`compute_cell_feat_significance`](https://driada.readthedocs.io/en/latest/api/intense/pipelines.html) (INTENSE) to identify spatially selective neurons
in a [`generate_mixed_population_exp`](https://driada.readthedocs.io/en/latest/api/experiment/synthetic.html) dataset. Then compare DR
quality on **all neurons**, **selective neurons only**, a **random half** of the
population, and **non-selective neurons**. If spatial selectivity drives the population
geometry, the selective subset should yield the cleanest spatial embeddings.

In [None]:
from scipy.spatial.distance import pdist


def compute_spatial_correspondence_metrics(embedding, true_positions):
    """
    Compute spatial-specific metrics for evaluating embedding quality.

    Parameters
    ----------
    embedding : ndarray, shape (n_samples, n_dims)
        Low-dimensional embedding
    true_positions : ndarray, shape (n_samples, n_spatial_dims)
        True spatial positions

    Returns
    -------
    metrics : dict
        Dictionary containing spatial correspondence metrics
    """
    metrics = {}

    # 1. SPATIAL DECODING ACCURACY
    decoding_metrics = compute_spatial_decoding_accuracy(
        embedding.T,
        true_positions,
        test_size=0.5,
        n_estimators=20,
        max_depth=3,
        min_samples_leaf=50,
        random_state=42,
    )
    metrics.update(decoding_metrics)

    # 2. SPATIAL INFORMATION CONTENT
    mi_metrics = compute_spatial_information(
        embedding.T, true_positions
    )
    metrics.update(mi_metrics)

    # 3. DISTANCE CORRELATION
    try:
        dist_embed = pdist(embedding)
        dist_true = pdist(true_positions)
        metrics['distance_correlation'] = np.corrcoef(dist_embed, dist_true)[0, 1]
    except Exception:
        metrics['distance_correlation'] = 0.0

    # 4. PROCRUSTES ANALYSIS
    try:
        embedding_2d = embedding[:, :2] if embedding.shape[1] >= 2 else embedding
        _, disparity, _ = procrustes_analysis(
            true_positions, embedding_2d, scaling=True, reflection=True
        )
        metrics['procrustes_disparity'] = disparity
    except Exception:
        metrics['procrustes_disparity'] = 1.0

    return metrics
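Two of the metrics above can be written directly in NumPy for reference. Procrustes disparity first centers both point sets and scales each to unit Frobenius norm. It then finds the best orthogonal map between them, with reflections allowed (matching `reflection=True`). The disparity is the residual sum of squares, which equals 1 minus the squared sum of singular values of `A.T @ B`. The notebook's `distance_correlation` is the Pearson correlation between pairwise-distance vectors, not Székely's distance correlation statistic. It is an assumption that `procrustes_analysis` computes exactly this; `scipy.spatial.procrustes` uses the same standardization. The helper names are hypothetical.

```python
import numpy as np


def procrustes_disparity(reference, target):
    """Residual sum of squares after optimal translation, scaling and
    orthogonal transform (reflections allowed); 0 means a perfect match.

    reference, target : arrays of the same shape (n_samples, n_dims)
    """
    a = np.asarray(reference, dtype=float)
    b = np.asarray(target, dtype=float)
    a = a - a.mean(axis=0)
    b = b - b.mean(axis=0)
    a = a / np.linalg.norm(a)
    b = b / np.linalg.norm(b)
    # The best orthogonal alignment achieves trace = sum of singular values of A^T B
    singular_values = np.linalg.svd(a.T @ b, compute_uv=False)
    return float(1.0 - singular_values.sum() ** 2)


def pairwise_distance_correlation(x, y):
    """Pearson correlation between the pairwise Euclidean distances of x and y."""
    def upper_distances(z):
        z = np.asarray(z, dtype=float)
        d = np.sqrt(np.sum((z[:, None, :] - z[None, :, :]) ** 2, axis=-1))
        return d[np.triu_indices(len(z), k=1)]
    return float(np.corrcoef(upper_distances(x), upper_distances(y))[0, 1])


toy_rng = np.random.default_rng(0)
toy_positions = toy_rng.uniform(0, 1, size=(200, 2))
toy_rot = np.array([[0.0, -1.0], [1.0, 0.0]])  # 90-degree rotation
toy_embedding = 2.5 * toy_positions @ toy_rot + 4.0  # rotated, scaled, shifted copy
print(procrustes_disparity(toy_positions, toy_embedding),
      pairwise_distance_correlation(toy_positions, toy_embedding))
```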

In [None]:
print('=' * 70)
print('INTENSE-Guided Dimensionality Reduction for Spatial Data')
print('=' * 70)

# 1. Generate mixed population data
print('\n1. Generating mixed population with spatial and non-spatial neurons...')

n_neurons_intense = 50   # Minimal for notebook execution speed
duration_intense = 300   # 5 minutes
n_shuffles_1 = 100       # FFT makes shuffle count cheap
n_shuffles_2 = 5000      # Better statistics with minimal overhead
ds_intense = 5

exp_intense = generate_mixed_population_exp(
    n_neurons=n_neurons_intense,
    manifold_type='2d_spatial',
    manifold_fraction=0.5,  # 1/2 place cells, 1/2 feature cells
    n_discrete_features=3,
    n_continuous_features=3,
    duration=duration_intense,
    seed=42,
    verbose=True,
)
print(f'  Created experiment with {exp_intense.n_cells} neurons, '
      f'{exp_intense.n_frames} timepoints')
print(f'  Available features: {list(exp_intense.dynamic_features.keys())}')

In [None]:
# 2. Run INTENSE with position_2d MultiTimeSeries
print('\n2. Running INTENSE analysis on 2D position (MultiTimeSeries)...')
stats_i, significance_i, info_i, results_i = compute_cell_feat_significance(
    exp_intense,
    feat_bunch=['position_2d'],  # Using MultiTimeSeries only
    find_optimal_delays=False,
    mode='two_stage',
    n_shuffles_stage1=n_shuffles_1,
    n_shuffles_stage2=n_shuffles_2,
    ds=ds_intense,
    pval_thr=0.01,
    multicomp_correction=None,
    verbose=True,
)
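The `two_stage` mode spends compute only where it matters. A cheap screen with `n_shuffles_stage1` shuffles discards clearly non-selective neuron-feature pairs, and only the survivors get the expensive `n_shuffles_stage2` test. The toy sketch below illustrates that logic using absolute Pearson correlation against circular-shift nulls. INTENSE itself uses mutual information, delay handling and its own thresholds. The statistic, the `stage1_thr` value and the helper names here are therefore all assumptions.

```python
import numpy as np


def shift_pvalue(signal, feature, n_shuffles, rng, min_shift=1):
    """Permutation p-value of |corr(signal, feature)| against circularly shifted
    copies of the feature. Shifts lie in [min_shift, n - min_shift]; use a larger
    min_shift for strongly autocorrelated signals."""
    n = len(feature)
    observed = abs(np.corrcoef(signal, feature)[0, 1])
    shifts = rng.integers(min_shift, n - min_shift + 1, size=n_shuffles)
    null = np.array([abs(np.corrcoef(signal, np.roll(feature, s))[0, 1]) for s in shifts])
    # The +1 terms count the observed statistic itself, so p is never exactly 0
    return (1 + np.sum(null >= observed)) / (1 + n_shuffles)


def two_stage_selective(signals, feature, n1=100, n2=5000,
                        stage1_thr=0.05, pval_thr=0.01, seed=0):
    """Indices of signals that pass a cheap screen and then a strict test."""
    rng = np.random.default_rng(seed)
    selective = []
    for idx, signal in enumerate(signals):
        if shift_pvalue(signal, feature, n1, rng) > stage1_thr:
            continue  # rejected cheaply; the expensive stage is skipped
        if shift_pvalue(signal, feature, n2, rng) < pval_thr:
            selective.append(idx)
    return selective


toy_rng = np.random.default_rng(42)
toy_feature = toy_rng.standard_normal(1000)
toy_signals = [
    toy_feature + 0.5 * toy_rng.standard_normal(1000),  # coupled to the feature
    toy_rng.standard_normal(1000),                      # independent
]
print(two_stage_selective(toy_signals, toy_feature, n1=100, n2=2000))
```

With 100 shuffles the smallest attainable p-value is about 0.01, so the first stage can only screen. The second stage, with many more shuffles, is what resolves p-values below `pval_thr`.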

In [None]:
# 3. Categorize neurons by selectivity
print('\n3. Categorizing neurons by selectivity...')

# Get neurons selective to spatial position
sig_neurons_2d = list(exp_intense.get_significant_neurons(fbunch='position_2d').keys())
spatial_neurons_i = sig_neurons_2d

print(f'  Spatial neurons (position_2d): {len(sig_neurons_2d)}')
print(f'  Non-spatial neurons: {exp_intense.n_cells - len(spatial_neurons_i)}')

# Check if we have enough spatial neurons
if len(spatial_neurons_i) < 5:
    print('\nWARNING: Not enough spatial neurons detected!')
    print('Try running with more neurons or adjusting parameters.')

# Extract true positions
position_2d_i = exp_intense.dynamic_features['position_2d'].data
x_pos_i = position_2d_i[0, :]
y_pos_i = position_2d_i[1, :]
true_positions_i = np.column_stack([x_pos_i, y_pos_i])

# Downsample positions to match calcium data
if ds_intense > 1:
    true_positions_i = true_positions_i[::ds_intense]

# VERIFICATION: Check ground truth vs detected spatial neurons
print('\n[CHECK] Analyzing ground truth vs detected spatial neurons...')

# Get ground truth spatial neurons (first 50% are spatial by construction)
n_true_spatial = int(exp_intense.n_cells * 0.5)
true_spatial_neurons = list(range(n_true_spatial))
true_nonspatial_neurons = list(range(n_true_spatial, exp_intense.n_cells))

print(f'  Ground truth spatial neurons: {n_true_spatial} '
      f'(indices 0-{n_true_spatial - 1})')
print(f'  Ground truth non-spatial neurons: '
      f'{exp_intense.n_cells - n_true_spatial} '
      f'(indices {n_true_spatial}-{exp_intense.n_cells - 1})')

# Check detection accuracy
detected_spatial_set = set(spatial_neurons_i)
true_spatial_set = set(true_spatial_neurons)
true_nonspatial_set = set(true_nonspatial_neurons)

true_positives = detected_spatial_set & true_spatial_set
false_positives = detected_spatial_set & true_nonspatial_set
false_negatives = true_spatial_set - detected_spatial_set

print(f'  True positives (correctly detected spatial): {len(true_positives)}')
print(f'  False positives (non-spatial detected as spatial): {len(false_positives)}')
print(f'  False negatives (spatial missed): {len(false_negatives)}')

precision_i = (
    len(true_positives) / len(detected_spatial_set) if detected_spatial_set else 0
)
recall_i = len(true_positives) / len(true_spatial_set) if true_spatial_set else 0
f1_i = (
    2 * precision_i * recall_i / (precision_i + recall_i)
    if (precision_i + recall_i) > 0 else 0
)

print(f'  Detection Precision: {precision_i:.3f}')
print(f'  Detection Recall: {recall_i:.3f}')
print(f'  Detection F1-score: {f1_i:.3f}')

In [None]:
# 4. Create scenarios to demonstrate benefit
print('\n4. Creating test scenarios...')

# Get all neurons
calcium_all_i = exp_intense.calcium.scdata[:, ::ds_intense]

# Get spatial neurons (detected by INTENSE)
calcium_spatial_i = exp_intense.calcium.scdata[spatial_neurons_i, ::ds_intense]

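# Get non-selective neurons (not significant for any tested feature).
# Note: INTENSE above was run on 'position_2d' only, so unless the other
# features were also analyzed, this is every neuron not detected as spatial.
all_neurons_i = set(range(exp_intense.n_cells))
selective_neurons_i = set(spatial_neurons_i)
for feat in ['d_feat_0', 'd_feat_1', 'd_feat_2',
             'c_feat_0', 'c_feat_1', 'c_feat_2']:
    try:
        feat_neurons = exp_intense.get_significant_neurons(fbunch=feat)
        if feat_neurons:
            selective_neurons_i.update(feat_neurons.keys())
    except Exception:
        # Feature not analyzed or unavailable: contributes no selective neurons
        pass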
non_selective_neurons_i = list(all_neurons_i - selective_neurons_i)
calcium_non_selective_i = (
    exp_intense.calcium.scdata[non_selective_neurons_i, ::ds_intense]
    if non_selective_neurons_i else None
)

# Get random half of all neurons
np.random.seed(42)
random_half_idx_i = np.random.choice(
    exp_intense.n_cells, size=exp_intense.n_cells // 2, replace=False
)
calcium_random_half_i = exp_intense.calcium.scdata[random_half_idx_i, ::ds_intense]

print(f'  All neurons: {calcium_all_i.shape[0]} neurons')
print(f'  Spatial neurons (INTENSE): {calcium_spatial_i.shape[0]} neurons')
print(f'  Random half: {calcium_random_half_i.shape[0]} neurons')
print(f'  Non-selective neurons: {len(non_selective_neurons_i)} neurons')

In [None]:
# 5. Define DR methods and scenarios
print('\n5. Applying dimensionality reduction methods...')

dr_methods_i = {
    'PCA': {'method': 'pca', 'params': {'dim': 2}},
    'Isomap': {'method': 'isomap', 'params': {'dim': 2, 'n_neighbors': 30}},
    'UMAP': {
        'method': 'umap',
        'params': {
            'dim': 2, 'n_neighbors': 80,
            'min_dist': 0.8, 'random_state': 42,
        },
    },
}

results_intense = {}

scenarios_i = [
    ('All neurons', calcium_all_i),
    ('Spatial neurons', calcium_spatial_i),
    ('Random half', calcium_random_half_i),
    ('Non-selective', calcium_non_selective_i),
]

In [None]:
# Run DR methods on each scenario and compute spatial metrics
for method_name, method_config in dr_methods_i.items():
    print(f'\n  {method_name}:')
    results_intense[method_name] = {}

    for scenario_name, calcium_data in scenarios_i:
        if calcium_data is None:
            print(f'    - {scenario_name}: No neurons in this category, skipping')
            continue
        if calcium_data.shape[0] < 10:
            print(f'    - {scenario_name}: Too few neurons ({calcium_data.shape[0]}), skipping')
            continue

        print(f'    - {scenario_name}...')

        try:
            mvdata_i = MVData(calcium_data)

            # Cap n_neighbors at 10% of the number of samples (timepoints)
            params = method_config['params'].copy()
            if 'n_neighbors' in params:
                params['n_neighbors'] = min(
                    params['n_neighbors'], calcium_data.shape[1] // 10
                )

            embedding_obj = mvdata_i.get_embedding(
                method=method_config['method'], **params
            )
            embedding_i = embedding_obj.coords.T

            metrics_i = compute_spatial_correspondence_metrics(
                embedding_i, true_positions_i
            )

            results_intense[method_name][scenario_name] = {
                'embedding': embedding_i,
                'metrics': metrics_i,
            }

            print(
                f"      Spatial decoding R^2: {metrics_i['r2_avg']:.3f}, "
                f"Distance corr: {metrics_i['distance_correlation']:.3f}, "
                f"MI: {metrics_i['mi_total']:.3f}"
            )

        except Exception as e:
            print(f'      Failed: {e}')
            results_intense[method_name][scenario_name] = None

    # Calculate improvements
    if (
        'All neurons' in results_intense[method_name]
        and 'Spatial neurons' in results_intense[method_name]
    ):
        if (
            results_intense[method_name]['All neurons']
            and results_intense[method_name]['Spatial neurons']
        ):
            r2_all_i = results_intense[method_name]['All neurons']['metrics']['r2_avg']
            r2_spatial_i = results_intense[method_name]['Spatial neurons']['metrics']['r2_avg']
            improvement_i = (r2_spatial_i / max(r2_all_i, 0.001) - 1) * 100
            print(f'    Spatial vs All improvement: {improvement_i:+.1f}%')

In [None]:
# 6. Visualize results
print('\n6. Creating visualizations...')

# Prepare embeddings for grid plot
grid_embeddings = {}
grid_metrics = {}
for method_name in dr_methods_i.keys():
    grid_embeddings[method_name] = {}
    grid_metrics[method_name] = {}
    for scenario in ['All neurons', 'Spatial neurons', 'Random half', 'Non-selective']:
        if scenario in results_intense[method_name] and results_intense[method_name][scenario]:
            grid_embeddings[method_name][scenario] = (
                results_intense[method_name][scenario]['embedding']
            )
            grid_metrics[method_name][scenario] = {
                'R^2': results_intense[method_name][scenario]['metrics']['r2_avg']
            }

# Color embedding points by time index
labels_i = np.arange(len(true_positions_i))

fig1 = plot_embeddings_grid(
    embeddings=grid_embeddings,
    labels=labels_i,
    metrics=grid_metrics,
    colormap='viridis',
    figsize=(18, 12),
    n_cols=4,
    dpi=DEFAULT_DPI,
)

fig1.suptitle(
    'INTENSE-Guided DR: Benefit of Spatial Neuron Selection', fontsize=14
)
plt.show()

# Create neuron selectivity summary
selectivity_counts = {
    'Spatial\n(any)': len(spatial_neurons_i),
    'Non-spatial': exp_intense.n_cells - len(spatial_neurons_i),
}

fig_summary = plot_neuron_selectivity_summary(
    selectivity_counts=selectivity_counts,
    total_neurons=exp_intense.n_cells,
    figsize=(8, 6),
    dpi=DEFAULT_DPI,
)
plt.show()

In [None]:
# 7. Quality metrics comparison figure
print('\n7. Creating quality metrics comparison...')

fig2, axes = plt.subplots(2, 2, figsize=(12, 10))
fig2.suptitle('Dimensionality reduction quality metrics comparison', fontsize=16)

metrics_to_show = [
    ('r2_avg', 'Spatial Decoding R^2'),
    ('distance_correlation', 'Distance Correlation'),
    ('mi_total', 'Spatial Information (MI)'),
    ('procrustes_disparity', 'Procrustes Disparity (lower is better)'),
]

scenarios_to_show = [
    ('All neurons', 'All'),
    ('Spatial neurons', 'Spatial'),
    ('Random half', 'Random'),
    ('Non-selective', 'Non-sel'),
]

for idx, (metric_key, metric_title) in enumerate(metrics_to_show):
    ax = axes[idx // 2, idx % 2]
    method_names = list(dr_methods_i.keys())

    x = np.arange(len(scenarios_to_show))
    width = 0.25

    for i, method in enumerate(method_names):
        values = []
        for scenario, _ in scenarios_to_show:
            if scenario in results_intense[method] and results_intense[method][scenario]:
                value = results_intense[method][scenario]['metrics'][metric_key]
                values.append(value)
            else:
                values.append(0)

        offset = (i - len(method_names) / 2 + 0.5) * width
        bars = ax.bar(x + offset, values, width, label=method, alpha=0.8)

        # Add value labels on bars
        for bar, value in zip(bars, values):
            if value != 0:
                height = bar.get_height()
                ax.text(
                    bar.get_x() + bar.get_width() / 2.0, height,
                    f'{value:.2f}', ha='center', va='bottom', fontsize=8,
                )

    ax.set_ylabel(metric_title)
    ax.set_title(metric_title)
    ax.set_xticks(x)
    ax.set_xticklabels(
        [label for _, label in scenarios_to_show], rotation=45, ha='right'
    )
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    if 'r2' in metric_key:
        ax.set_ylim(0, 1.0)

plt.tight_layout()
plt.show()

In [None]:
# 8. Summary statistics
print('\n' + '=' * 70)
print('SUMMARY')
print('=' * 70)

print('\nBest performing method for spatial reconstruction:')
best_method_i = None
best_score_i = -1
for method_name in results_intense.keys():
    if (
        'Spatial neurons' in results_intense[method_name]
        and results_intense[method_name]['Spatial neurons']
    ):
        score = results_intense[method_name]['Spatial neurons']['metrics']['r2_avg']
        if score > best_score_i:
            best_score_i = score
            best_method_i = method_name

if best_method_i:
    print(f'  {best_method_i} with spatial neurons')
    print(f'  Spatial decoding R^2: {best_score_i:.3f}')

print('\nSpatial decoding R^2 comparison:')
for method_name in results_intense.keys():
    print(f'\n  {method_name}:')
    scenarios_order = [
        'All neurons', 'Spatial neurons', 'Random half', 'Non-selective',
    ]
    for scenario in scenarios_order:
        if scenario in results_intense[method_name] and results_intense[method_name][scenario]:
            r2 = results_intense[method_name][scenario]['metrics']['r2_avg']
            print(f'    {scenario:20s}: {r2:.3f}')

    # Calculate key comparisons
    if (
        'All neurons' in results_intense[method_name]
        and results_intense[method_name]['All neurons']
    ):
        r2_all_s = results_intense[method_name]['All neurons']['metrics']['r2_avg']

        if (
            'Spatial neurons' in results_intense[method_name]
            and results_intense[method_name]['Spatial neurons']
        ):
            r2_spatial_s = results_intense[method_name]['Spatial neurons']['metrics']['r2_avg']
            imp_s = (r2_spatial_s / max(r2_all_s, 0.001) - 1) * 100
            print(f'    -> Spatial vs All improvement: {imp_s:+.1f}%')

        if (
            'Random half' in results_intense[method_name]
            and results_intense[method_name]['Random half']
        ):
            r2_random_s = results_intense[method_name]['Random half']['metrics']['r2_avg']
            ratio_s = r2_random_s / max(r2_all_s, 0.001)
            print(f'    -> Random half / All ratio: {ratio_s:.2f}')