# NLDisco Tutorial

This tutorial introduces **NLDisco** (**N**eural **L**atent **Disco**very pipeline).

**Goal:** discover interpretable latents (i.e., *features*) in high-dimensional neural data.

NLDisco trains sparse autoencoders (SAEs): shallow encoder–decoder models trained to reconstruct neural activity ($y$ in the figure) from a set of sparsely active dictionary elements (hidden units, $d$). Sparsity encourages a monosemantic dictionary, where each unit corresponds to a single interpretable feature, making SAEs a simple but effective tool for neural latent discovery. The approach has been applied successfully in mechanistic interpretability research on AI models.

![](./tutorial_figures/sae.svg)

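To make the description above concrete, here is a minimal NumPy sketch of a top-$k$ SAE forward pass on random data. It assumes a ReLU encoder followed by a top-$k$ selection (the sparsity mechanism described for the Matryoshka SAEs in section 2); the function name, weight names and shapes are illustrative and are not NLDisco's API.

```python
import numpy as np

def topk_sae_forward(y, W_enc, b_enc, W_dec, b_dec, k):
    """Encode neural activity y into sparse latents d, then reconstruct y from d."""
    # Rectified hidden-unit pre-activations: [examples, units]
    pre = np.maximum(y @ W_enc + b_enc, 0.0)
    # Column indices of the k largest activations in each row
    topk_idx = np.argpartition(pre, -k, axis=1)[:, -k:]
    # Keep only the top-k activations; all other units are set to zero
    d = np.zeros_like(pre)
    rows = np.arange(pre.shape[0])[:, None]
    d[rows, topk_idx] = pre[rows, topk_idx]
    # The reconstruction is a sparse linear combination of decoder rows (dictionary elements)
    y_hat = d @ W_dec + b_dec
    return d, y_hat

rng = np.random.default_rng(0)
n_examples, n_neurons, n_units, k = 5, 20, 64, 4
y = rng.poisson(1.0, size=(n_examples, n_neurons)).astype(np.float32)
W_enc = rng.normal(0.0, 0.1, size=(n_neurons, n_units)).astype(np.float32)
W_dec = W_enc.T.copy()  # tied initialization, for illustration only
b_enc = np.zeros(n_units, dtype=np.float32)
b_dec = np.zeros(n_neurons, dtype=np.float32)

d, y_hat = topk_sae_forward(y, W_enc, b_enc, W_dec, b_dec, k)
print((d != 0).sum(axis=1))  # at most k active units per example
```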
We assess a latent’s interpretability in terms of two key aspects:
1. its correspondence to a specific external variable, i.e., a "natural" behavioral or environmental feature
2. its explicit composition from contributing neural activity

**Terminology:**

This tutorial uses the following terms:
- *Features:* interpretable latents (latent dimensions that align with a meaningful behavioral or environmental variable)
- *Neurons:* biological neurons
- *Units:* hidden units of the SAE (model neurons, each corresponding to one latent dimension)

**NLDisco pipeline:**

1. Load and prepare data
    - Spike data (binned spike counts of shape $[\text{examples} \times \text{neurons}]$)
    - Behavioral and/or environmental (meta)data

2. Train SAEs 
    - Train them to reconstruct the neural data (in this case, from the MC Maze dataset)
    - Validate the quality of the SAEs by looking at the sparsity of the latent activations and reconstruction quality of the neural data

3. Save or load the SAE activations

4. Find features
    - Automatically generate promising mappings between SAE units (latents) and metadata
    - Find meaningful features and their associated biological neurons using an interactive dashboard
        - The mapping is a starting point to guide the search
        - Alternatively, you can look through units manually

___

## Setup

**Environment setup:**

Prerequisite: an installed version of [pixi](https://pixi.sh/latest/)

Steps:
1. In the repo's root directory, run `pixi install --manifest-path ./pyproject.toml`. This will create an environment in a newly created `.pixi/envs` folder.
2. Run `pixi run postinstall`

**Data download:**

Once your environment is set up, use it as a kernel for this notebook and run the cell below to automatically download and preprocess the [churchland_shenoy_neural_2012 dataset](https://brainsets.readthedocs.io/en/latest/glossary/brainsets.html#churchland-shenoy-neural-2012) (also known as MC_Maze). With the data ready, you can jump straight into the tutorial!

In [None]:
from nldisco import mc_maze

# Directories for raw and processed data
raw_data_dir = "../data/raw"
processed_data_dir = "../data/processed"

# Subject name and number of sessions to download
# (max 4 for jenkins and 3 for nitschke)
# Be aware that these files are large (~2-7GB each)
subject_name = "nitschke"  # "jenkins" or "nitschke"
num_files = 1

mc_maze.download_and_preprocess(raw_data_dir, processed_data_dir, subject_name, num_files=num_files)

# To load all the data, use:
# mc_maze.download_and_preprocess(raw_data_dir, processed_data_dir, "jenkins", num_files=4)
# mc_maze.download_and_preprocess(raw_data_dir, processed_data_dir, "nitschke", num_files=3)


___

## Tutorial start!

In [None]:
"""Set notebook settings."""

%load_ext autoreload
%autoreload 2

In [None]:
"""Import packages."""

# Standard library
from datetime import datetime
from pathlib import Path

# IPython/Jupyter
from IPython.display import display

# Third-party
import numpy as np
import pandas as pd
import seaborn as sns
import torch as t
from einops import asnumpy, reduce
from matplotlib import pyplot as plt
from scipy.ndimage import gaussian_filter1d
from sklearn.metrics import r2_score
from tqdm.notebook import tqdm

# Local project modules
from nldisco import train as mt
from nldisco import mc_maze
from nldisco.train_val_split import train_val_split_by_proportion, train_val_split_by_session
from nldisco import pipeline

# 1. Load and prepare data

In the churchland_shenoy_neural_2012 dataset, the subjects perform a center-out reaching task across a variety of maze configurations. Each maze configuration comes in 3 versions:
- 1 target and no barriers.
- 1 target with barriers.
- 3 targets with barriers, 2 of which are distractors that are inaccessible given the barrier configuration.

Hence maze conditions 1, 2 and 3 are related, as are maze conditions 4, 5 and 6, and so on.
Neural activity was recorded from the dorsal premotor (PMd) and primary motor (M1) cortices. Other data are also provided, such as the monkey's hand position, velocity and acceleration, and gaze position.

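If you want to treat the three versions of a maze as one group (for example, to look for features shared across them), the grouping described above reduces to integer division. This sketch assumes conditions are 1-indexed integers with the three versions of each maze numbered consecutively, as the numbering above suggests; `maze_group` is a hypothetical name, not a column provided by the dataset.

```python
import pandas as pd

# Example condition labels; in the tutorial these would come from metadata_binned["maze_condition"]
maze_condition = pd.Series([1, 2, 3, 4, 5, 6, 7])
# Conditions 1-3 -> group 0, 4-6 -> group 1, and so on (assumes 1-indexed, consecutive triplets)
maze_group = (maze_condition - 1) // 3
print(maze_group.tolist())  # [0, 0, 0, 1, 1, 1, 2]
```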
## Load and prepare spike data

In [None]:
data_path = Path(r"../data/processed")
save_path = Path(r"../saved_sae_unit_activations")
subject_name = "nitschke"  # "jenkins" or "nitschke"

# Load
sessions = mc_maze.load_sessions(data_path, subject_name)

In [None]:
"""Bin spike data."""

bin_size = 0.05 # in seconds
spikes_df = mc_maze.bin_spike_data(sessions, bin_size) # this can take several minutes

display(spikes_df)

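`mc_maze.bin_spike_data` handles binning for you. As a rough illustration of the general idea only (not necessarily how that function is implemented, e.g. with respect to trial alignment), the sketch below turns per-neuron spike times into a $[\text{time bins} \times \text{neurons}]$ count matrix using `np.histogram`; `bin_spikes` is a hypothetical helper.

```python
import numpy as np

def bin_spikes(spike_times_per_neuron, t_start, t_stop, bin_size):
    """Count each neuron's spikes in consecutive bins of width bin_size (seconds)."""
    n_bins = int(round((t_stop - t_start) / bin_size))
    edges = t_start + bin_size * np.arange(n_bins + 1)
    # np.histogram uses half-open bins [a, b), except the last, which also includes b
    counts = [np.histogram(times, bins=edges)[0] for times in spike_times_per_neuron]
    return np.stack(counts, axis=1)  # [time bins x neurons]

spike_times = [np.array([0.01, 0.02, 0.07, 0.19]), np.array([0.12])]
counts = bin_spikes(spike_times, t_start=0.0, t_stop=0.2, bin_size=0.05)
print(counts.T)  # neuron 0: 2, 1, 0, 1 spikes per bin; neuron 1: 0, 0, 1, 0
```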
In [None]:
"""Quick plots and stats to get a sense of the spike data."""

print("Firing rates distribution:")
# Compute mean firing rate (Hz) per neuron
duration_sec = len(spikes_df) * bin_size
mean_firing_rates = spikes_df.sum(axis=0) / duration_sec  # spikes per second (Hz)
# Plot histogram
plt.figure(figsize=(8, 4))
plt.hist(mean_firing_rates, bins=30, edgecolor='black')
plt.xlabel('Mean Firing Rate (Hz)')
plt.ylabel('Number of Neurons')
plt.title('Distribution of Mean Firing Rates per Neuron')
plt.grid(True)
plt.tight_layout()
plt.show()
# Print stats
print("Mean: {:.2f} Hz".format(mean_firing_rates.mean()))
print("Range: {:.2f}–{:.2f} Hz".format(mean_firing_rates.min(), mean_firing_rates.max()))

print("\n\nSpike count distribution and sparsity stats:")
# Flatten spike counts
flattened_spike_counts = spikes_df.values.flatten()
# Define bins that align exactly to integer spike counts
max_count = flattened_spike_counts.max()
bins = np.arange(0, max_count + 2) - 0.5  # centers bins on integers
# Plot histogram
plt.figure(figsize=(8, 5))
plt.hist(flattened_spike_counts, bins=bins, edgecolor='black')
plt.title("Distribution of Spike Counts per Neuron per Time Bin")
plt.xlabel("Spike Count")
plt.ylabel("Number of Neuron-Bin Combinations")
plt.yscale("log")
plt.grid(True)
plt.tight_layout()
plt.show()
# Print stats
frac_nonzero_bins = (spikes_df != 0).values.sum() / spikes_df.size
frac_nonzero_examples = (spikes_df.sum(axis=1) > 0).mean()
print(f"Fraction of non-zero bins: {frac_nonzero_bins:.4f}")
print(f"Fraction of non-zero examples: {frac_nonzero_examples:.4f}")


## Load and prepare environment / behavior (meta)data

In [None]:
# Retrieve and collate metadata (hand/eye/events) across sessions
metadata, trials_df = mc_maze.retrieve_metadata(sessions) # this can take a minute
# Bin metadata to match the binned spikes_df
metadata_binned = mc_maze.bin_metadata(metadata, trials_df, bin_size, spikes_df.index)

print("Metadata:")
display(metadata)
print("Binned metadata:")
display(metadata_binned)

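For continuously sampled signals such as hand velocity, one simple way to align them to spike bins is to average all samples that fall inside each bin. The sketch below assumes averaging; `mc_maze.bin_metadata` may use a different rule (e.g., sampling at bin centers), and `bin_continuous_signal` is a hypothetical helper.

```python
import numpy as np
import pandas as pd

def bin_continuous_signal(timestamps, values, bin_edges):
    """Average a continuously sampled signal within each time bin (NaN for empty bins)."""
    timestamps = np.asarray(timestamps, dtype=float)
    values = np.asarray(values, dtype=float)
    n_bins = len(bin_edges) - 1
    # np.digitize returns i such that bin_edges[i-1] <= t < bin_edges[i]; shift to 0-based bins
    bin_idx = np.digitize(timestamps, bin_edges) - 1
    valid = (bin_idx >= 0) & (bin_idx < n_bins)
    means = pd.Series(values[valid]).groupby(bin_idx[valid]).mean()
    return means.reindex(range(n_bins)).to_numpy()

edges = np.array([0.0, 0.05, 0.10, 0.15, 0.20])
hand_speed = bin_continuous_signal(
    timestamps=[0.00, 0.01, 0.06, 0.07, 0.16],
    values=[1.0, 3.0, 10.0, 20.0, 5.0],
    bin_edges=edges,
)
print(hand_speed)  # bin means: 2, 15, NaN (no samples), 5
```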
## Train/val split

In [None]:
"""Train/val split, smooth and normalize spikes."""

split_by_session = False # if False, will split by proportion

if split_by_session:
    train_trials, val_trials = train_val_split_by_session(
        metadata_binned["trial_idx"].to_numpy(),
        metadata_binned["session"].to_numpy(),
        train_sessions=[1, 2],  # pick your training sessions
        shuffle=True,
        seed=0,
    )
else:
    train_trials, val_trials = train_val_split_by_proportion(
        metadata_binned["trial_idx"].values,
        train_proportion=0.8,
        shuffle=True,
        seed=0,
    )

# Create boolean masks to split metadata and spikes into train/val sets
train_mask = metadata_binned['trial_idx'].isin(train_trials)
val_mask = metadata_binned['trial_idx'].isin(val_trials)

# Split metadata
metadata_binned_train = metadata_binned[train_mask].reset_index(drop=True)
metadata_binned_val = metadata_binned[val_mask].reset_index(drop=True)

# Split spikes 
spikes_arr = spikes_df.values.astype(np.float32)
spikes_train_arr = spikes_arr[train_mask]
spikes_val_arr = spikes_arr[val_mask]

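# Smooth spikes with a Gaussian kernel (50 ms standard deviation, expressed in bins).
# Smoothing runs over the concatenated bins, so it can blur slightly across trial boundaries.
sigma = 0.05 / bin_size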
spikes_train_arr = gaussian_filter1d(spikes_train_arr, sigma=sigma, axis=0)
spikes_val_arr = gaussian_filter1d(spikes_val_arr, sigma=sigma, axis=0)

# Normalize spikes (fit normalization on training data only)
train_max = spikes_train_arr.max()
spikes_train_arr = spikes_train_arr / train_max
spikes_val_arr = spikes_val_arr / train_max

# Summary
print(f"Train set: {len(train_trials)} trials ({train_mask.sum()} time bins)")
print(f"Val set: {len(val_trials)} trials ({val_mask.sum()} time bins)")
print(f"Spike data shapes: train {spikes_train_arr.shape}, val {spikes_val_arr.shape}")

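Splitting by trial rather than by time bin keeps every bin of a trial on the same side of the split, so neighboring (and, after smoothing, correlated) bins of one trial cannot appear in both the train and validation sets. The sketch below is a hypothetical reimplementation of the idea behind `train_val_split_by_proportion`, not its actual source; its shuffling and rounding may differ. Variable names are prefixed with `demo_` so running it does not overwrite the notebook's split.

```python
import numpy as np

def split_trials_by_proportion(trial_idx, train_proportion=0.8, shuffle=True, seed=0):
    """Assign whole trials (not individual time bins) to the train or validation set."""
    trials = np.unique(trial_idx)
    if shuffle:
        trials = np.random.default_rng(seed).permutation(trials)
    n_train = int(round(train_proportion * len(trials)))
    return np.sort(trials[:n_train]), np.sort(trials[n_train:])

demo_trial_idx = np.repeat(np.arange(10), 3)  # 10 trials with 3 time bins each
demo_train_trials, demo_val_trials = split_trials_by_proportion(demo_trial_idx)
demo_train_mask = np.isin(demo_trial_idx, demo_train_trials)
print(len(demo_train_trials), len(demo_val_trials), demo_train_mask.sum())  # 8 2 24
```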
In [None]:
"""Convert to torch tensors and move to device."""

# it's best to have a gpu for training!
device = t.device("cuda" if t.cuda.is_available() else "cpu")
print(f"{device=}")

spikes_train = t.from_numpy(spikes_train_arr).to(device).to(dtype=t.bfloat16)
spikes_val = t.from_numpy(spikes_val_arr).to(device).to(dtype=t.bfloat16)

display(spikes_train)
display(spikes_val)

# 2. Train SAEs

> If desired, you can skip this section and load pre-saved SAE unit activations instead. A set of activations is provided for each subject (Jenkins and Nitschke); see section "3. Save/load SAE activations" for more information. We nonetheless recommend reading the rest of this section to understand how the SAEs are trained.

This code trains 2 SAEs with identical setups so that you can compare the two instances and check that they give similar results. For each time bin of neural data in the train and val sets, the SAEs' hidden-layer unit activations are computed; these activations are then used to find each unit's correspondence with external variables (features). A particularity of the NLDisco pipeline below is that it trains *Matryoshka* SAEs.

**Matryoshka architecture:**

The Matryoshka architecture segments the latent space into multiple levels, each of which attempts a full reconstruction of the target neural activity. Black boxes indicate the latents involved in a given level, while light-red boxes indicate additional latents recruited at lower levels. A top-$k$ selection is used to choose which latents to recruit for reconstruction at each level (yellow neuron within each light-red box, $k=1$ for this example figure). 

This hierarchical arrangement is motivated by the idea that multi-scale feature learning could mitigate “feature absorption” (where broad features dominate more specific ones), potentially allowing both coarse and detailed patterns to be represented simultaneously.

- Latents in the highest level ($L_1$) typically correspond to broad, high-level features (e.g., a round object)
- Latents exclusive to the lowest level ($L_3$) often correspond to more specific, fine-grained features (e.g., a basketball)

![](./tutorial_figures/msae.svg)

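The nested-level objective can be sketched in PyTorch as follows. This is an illustration under stated assumptions, not NLDisco's implementation: it assumes a ReLU encoder, applies top-$k$ over each level's full prefix of units, and uses MSE for brevity (the tutorial trains with `mt.msle`). `topk_map` and `loss_x_map` mirror the structure of `dsae_topk_map` and `dsae_loss_x_map` at a smaller size.

```python
import torch as t

def matryoshka_loss(x, W_enc, b_enc, W_dec, b_dec, topk_map, loss_x_map):
    """Weighted sum of per-level reconstruction losses over nested prefixes of latents."""
    pre = t.relu(x @ W_enc + b_enc)  # [batch, n_units]
    total = x.new_zeros(())
    for n_level, k in sorted(topk_map.items()):
        prefix = pre[:, :n_level]  # only the first n_level units exist at this level
        vals, idx = prefix.topk(k, dim=1)  # keep the k largest activations
        acts = t.zeros_like(prefix).scatter(1, idx, vals)
        recon = acts @ W_dec[:n_level] + b_dec  # each level reconstructs the full input
        total = total + loss_x_map[n_level] * ((recon - x) ** 2).mean()
    return total

t.manual_seed(0)
n_in, n_units = 10, 32
topk_map = {8: 2, 16: 4, 32: 6}
loss_x_map = {8: 1.0, 16: 1.25, 32: 1.5}
W_enc = (0.1 * t.randn(n_in, n_units)).requires_grad_()
W_dec = (0.1 * t.randn(n_units, n_in)).requires_grad_()
b_enc = t.zeros(n_units, requires_grad=True)
b_dec = t.zeros(n_in, requires_grad=True)

x = t.rand(16, n_in)
loss = matryoshka_loss(x, W_enc, b_enc, W_dec, b_dec, topk_map, loss_x_map)
loss.backward()  # gradients reach the encoder and decoder weights used at every level
print(float(loss))
```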
**Key training parameters to play with:**

`SaeConfig` (model-level):
- `dsae_topk_map`: how many top-k units are kept active at each level (controls sparsity per level)
- `dsae_loss_x_map`: relative weight of each Matryoshka level to the overall reconstruction loss

`optimize()` (optimizer-level):
- `n_steps`: total training steps
- `batch_sz`: how many examples per batch
- `lr`: learning rate passed to the optimizer (Adam in the cell below); you can optionally enable a learning-rate scheduler with `use_lr_sched=True`
- `dead_neuron_window`: number of steps a unit can stay inactive before being flagged as “dead”  
  - Dead units are revived with an auxiliary loss: instead of reconstructing the full input, they try to reconstruct only the residual error (the part the active units failed to capture)
  - This gives inactive units a chance to become useful again, preventing them from staying permanently silent   
- `loss_fn`: reconstruction objective; the built-in options are MSE and MSLE (the cell below uses `mt.msle`), or you can pass a custom callable


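Two of the ideas listed above can be sketched as follows. `msle` uses the standard mean squared logarithmic error, which compresses large values and so reduces the influence of high-firing neurons relative to MSE; NLDisco's `mt.msle` may differ in details (e.g., clamping or averaging across instances). `DeadUnitTracker` is a hypothetical helper showing how a window like `dead_neuron_window` can flag dead units; the auxiliary residual loss applied to those units is not shown.

```python
import math
import torch as t

def msle(recon, target):
    """Mean squared logarithmic error: mean of (log(1 + recon) - log(1 + target))^2."""
    # Clamp so log1p stays defined if a reconstruction dips below zero (an assumption of this sketch)
    return ((t.log1p(recon.clamp(min=0)) - t.log1p(target)) ** 2).mean()

class DeadUnitTracker:
    """Flag units that have not fired for at least `window` consecutive training steps."""

    def __init__(self, n_units, window):
        self.window = window
        self.steps_since_active = t.zeros(n_units, dtype=t.long)

    def update(self, acts):
        """acts: [batch, n_units] latent activations for one training step."""
        active = (acts > 0).any(dim=0)
        self.steps_since_active += 1
        self.steps_since_active[active] = 0
        return self.steps_since_active >= self.window  # boolean mask of dead units

print(msle(t.tensor([math.e - 1]), t.tensor([0.0])))  # (log(e) - log(1))^2, i.e. about 1.0

tracker = DeadUnitTracker(n_units=3, window=2)
tracker.update(t.tensor([[1.0, 0.0, 0.0]]))
dead = tracker.update(t.tensor([[0.0, 1.0, 0.0]]))
print(dead)  # only unit 2 has been silent for 2 steps
```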
## Train SAEs

In [None]:
"""Set config."""

dsae_topk_map = {256: 8, 512: 16, 1024: 24} # total of 1024 units, 3 overlapping levels: 0-256, 0-512 and 0-1024
dsae_topk_map = dict(sorted(dsae_topk_map.items()))  # ensure sorted from smallest to largest
dsae_loss_x_map = {256: 1, 512: 1.25, 1024: 1.5}
dsae_loss_x_map = dict(sorted(dsae_loss_x_map.items()))
dsae = max(dsae_topk_map.keys())
n_inst = 2

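Because the levels are nested prefixes, units 0-255 belong to every level, units 256-511 are first recruited at $L_2$, and units 512-1023 are exclusive to $L_3$. The small helper below (hypothetical, not part of NLDisco) lists those ranges from a `dsae_topk_map`-style dictionary, which can help when interpreting unit indices later on.

```python
def level_unit_ranges(topk_map):
    """Map each Matryoshka level to the half-open range of unit indices it introduces."""
    sizes = sorted(topk_map)
    starts = [0] + sizes[:-1]
    return {f"L{i + 1}": (start, stop) for i, (start, stop) in enumerate(zip(starts, sizes))}

ranges = level_unit_ranges({256: 8, 512: 16, 1024: 24})  # or level_unit_ranges(dsae_topk_map)
print(ranges)  # {'L1': (0, 256), 'L2': (256, 512), 'L3': (512, 1024)}
```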
In [None]:
"""Train model."""

sae_cfg = mt.SaeConfig(
    n_input_ae=spikes_train.shape[1],
    dsae_topk_map=dsae_topk_map,
    dsae_loss_x_map=dsae_loss_x_map,
    n_instances=n_inst,
)
sae = mt.Sae(sae_cfg).to(device)
loss_fn = mt.msle
tau = 1.0
lr = 5e-3

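n_epochs = 20
batch_sz = 1024
steps_per_epoch = spikes_train.shape[0] // batch_sz
n_steps = steps_per_epoch * n_epochs
log_freq = max(1, steps_per_epoch // 2)  # log roughly twice per epoch
dead_unit_window = max(1, steps_per_epoch // 3)  # units silent for ~1/3 of an epoch are flagged as dead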

data_log = mt.optimize(  # train model
    spk_cts=spikes_train,
    sae=sae,
    loss_fn=loss_fn,
    optimizer=t.optim.Adam(sae.parameters(), lr=lr),
    use_lr_sched=True,
    dead_neuron_window=dead_unit_window,
    n_steps=n_steps,
    log_freq=log_freq,
    batch_sz=batch_sz,
    log_wandb=False,
    plot_l0=False,
    tau=tau,
)

## Validate SAEs

In [None]:
"""Check for nans in weights."""

sae.W_dec.isnan().sum(), sae.W_enc.isnan().sum()

In [None]:
"""Visualize weights."""

fig, ax = plt.subplots(figsize=(8, 6))
for inst in range(n_inst):
    W_dec_flat = asnumpy(sae.W_dec[inst].float()).ravel()
    sns.histplot(W_dec_flat, bins=1000, stat="probability", alpha=0.7, label=f"SAE {inst}", ax=ax)

ax.set_title("SAE decoder weights")
ax.set_xlabel("Weight value")
ax.set_ylabel("Probability")
ax.legend()

In [None]:
"""Visualize metrics over all examples and neurons."""

topk_acts_4d_train, recon_spk_cts_train, r2_per_neuron_train, _, cossim_per_neuron_train, _ = mt.eval_model(
    spikes_train, sae, batch_sz=batch_sz
)

In [None]:
"""Calculate variance explained of summed spike counts."""

n_recon_examples = recon_spk_cts_train.shape[0]
recon_summed_spk_cts = reduce(recon_spk_cts_train, "example inst neuron -> example inst", "sum")

actual_summed_spk_cts = reduce(spikes_train, "example neuron -> example", "sum")
actual_summed_spk_cts = actual_summed_spk_cts[:n_recon_examples]  # trim to match

for inst in range(n_inst):
    r2 = r2_score(
        asnumpy(actual_summed_spk_cts.float()),
        asnumpy(recon_summed_spk_cts[:, inst].float()),
    )
    print(f"SAE instance {inst} R² (summed spike count over all neurons per example) = {r2:.3f}")


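For a single target, scikit-learn's `r2_score` computes $R^2 = 1 - SS_{\text{res}} / SS_{\text{tot}}$: 1 means a perfect reconstruction, 0 means no better than always predicting the mean, and negative values mean worse than the mean. Note that here it is computed on smoothed, max-normalized activity rather than raw spike counts. The hypothetical `r2` helper below spells out the formula.

```python
import numpy as np

def r2(y_true, y_pred):
    """Coefficient of determination R^2 = 1 - SS_res / SS_tot (single target)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)         # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # total sum of squares around the mean
    return 1.0 - ss_res / ss_tot

y = np.array([1.0, 2.0, 3.0, 4.0])
print(r2(y, y))                          # 1.0: perfect reconstruction
print(r2(y, np.full_like(y, y.mean())))  # 0.0: no better than predicting the mean
print(r2(y, [2.0, 2.0, 3.0, 3.0]))       # 1 - 2/5 = 0.6
```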
## Remove bad neurons and retrain

In [None]:
"""Remove bad neurons and retrain."""

# Set threshold for removing neurons
r2_thresh = 0.1
inst = 0
r2_inst = r2_per_neuron_train[:, inst]
keep_mask = r2_inst > r2_thresh
print(f"frac neurons above {r2_thresh=}: {keep_mask.sum() / keep_mask.shape[0]:.2f}")
print(f"Number to keep: {keep_mask.sum()} / {keep_mask.shape[0]}")

if keep_mask.all():
    print("All neurons pass threshold — skipping retrain.")
    spikes_train_pruned = spikes_train
    spikes_val_pruned = spikes_val
else:
    # Prune
    spikes_train_pruned = spikes_train[:, keep_mask]
    spikes_val_pruned = spikes_val[:, keep_mask]

    # Retrain SAE on pruned train data
    sae_cfg = mt.SaeConfig(
        n_input_ae=spikes_train_pruned.shape[1],
        dsae_topk_map=dsae_topk_map,
        dsae_loss_x_map=dsae_loss_x_map,
        seq_len=1,
        n_instances=n_inst,
    )
    sae = mt.Sae(sae_cfg).to(device)
    loss_fn = mt.msle
    tau = 1.0
    lr = 5e-3

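    n_epochs = 20
    batch_sz = 1024
    steps_per_epoch = spikes_train_pruned.shape[0] // batch_sz
    n_steps = steps_per_epoch * n_epochs
    log_freq = max(1, steps_per_epoch // 2)  # log roughly twice per epoch
    dead_unit_window = max(1, steps_per_epoch // 3)  # units silent for ~1/3 of an epoch are flagged as dead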

    data_log = mt.optimize(
        spk_cts=spikes_train_pruned,
        sae=sae,
        loss_fn=loss_fn,
        optimizer=t.optim.Adam(sae.parameters(), lr=lr),
        use_lr_sched=True,
        dead_neuron_window=dead_unit_window,
        n_steps=n_steps,
        log_freq=log_freq,
        batch_sz=batch_sz,
        log_wandb=False,
        plot_l0=False,
        tau=tau,
    )

In [None]:
"""Re-visualize metrics over all examples and neurons."""

if keep_mask.all():
    print("All neurons pass threshold — skipping re-visualization.")
else:
    topk_acts_4d_train, recon_spk_cts_train, r2_per_neuron_train, _, cossim_per_neuron_train, _ = mt.eval_model(
        spikes_train_pruned, sae, batch_sz=batch_sz
    )

    n_recon_examples_train = recon_spk_cts_train.shape[0]
    recon_summed_train = reduce(recon_spk_cts_train, "example inst neuron -> example inst", "sum")

    actual_summed_train = reduce(spikes_train_pruned, "example neuron -> example", "sum")
    actual_summed_train = actual_summed_train[:n_recon_examples_train]

    for inst in range(n_inst):
        r2 = r2_score(
            asnumpy(actual_summed_train.float()),
            asnumpy(recon_summed_train[:, inst].float()),
        )
        print(f"SAE instance {inst} R² (summed spike count per example) = {r2:.3f}")

In [None]:
"""Visualize metrics on validation data."""

if spikes_val_pruned.shape[0] == 0:
    print("No validation data available.")
else:
    print("Validation data metrics:")

    topk_acts_4d_val, recon_spk_cts_va, r2_per_neuron_va, _, cossim_per_neuron_va, _ = mt.eval_model(
        spikes_val_pruned, sae, batch_sz=batch_sz
    )

    n_recon_examples_val = recon_spk_cts_va.shape[0]
    recon_summed_val = reduce(recon_spk_cts_va, "example inst neuron -> example inst", "sum")

    actual_summed_val = reduce(spikes_val_pruned, "example neuron -> example", "sum")
    actual_summed_val = actual_summed_val[:n_recon_examples_val]

    for inst in range(n_inst):
        r2 = r2_score(
            asnumpy(actual_summed_val.float()),
            asnumpy(recon_summed_val[:, inst].float()),
        )
        print(f"SAE instance {inst} R² (summed spike count per example) = {r2:.3f}")

# 3. Save/load SAE activations

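A set of activations per subject (Jenkins and Nitschke) is provided in the `saved_sae_unit_activations` folder. Both were generated using a split by proportion (80/20 train/val split). To use them, set `load_activations = True` in the cell below. The load path is built from the subject name and the recording dates of the loaded sessions, so the sessions you loaded in section 1 must match those used to generate the saved activations.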

In [None]:
"""Load saved activations if available; otherwise build acts_df and (optionally) save."""

load_activations = False
save_activations = True
activations_file_train = "sae_activations_train.parquet"
activations_file_val = "sae_activations_val.parquet"
mask_file_train = "train_mask.parquet"
mask_file_val = "val_mask.parquet"

# Build the save path from the subject name and session recording dates
session_dates = []
for session in sessions:
    session_date = datetime.fromtimestamp(session.session.recording_date).strftime("%Y%m%d")
    session_dates.append(session_date)
session_dates_str = "_".join(session_dates)

activations_save_path = save_path / f"{subject_name}_{session_dates_str}" / "sae_activations"
activations_save_path.mkdir(parents=True, exist_ok=True)

if load_activations:
    acts_df_train = pd.read_parquet(activations_save_path / activations_file_train)
    acts_df_val = pd.read_parquet(activations_save_path / activations_file_val) if (activations_save_path / activations_file_val).exists() else None
    train_mask = pd.read_parquet(activations_save_path / mask_file_train)["mask"]
    val_mask = pd.read_parquet(activations_save_path / mask_file_val)["mask"]
    print(f"Loaded activations from {activations_save_path}")
else:
    # Train
    arr_tr = asnumpy(topk_acts_4d_train)  # [example_idx, instance_idx, unit_idx, activation_value]
    # Sparse long-format activations: integer indices, float32 values
    acts_df_train = pd.DataFrame({
        "example_idx": arr_tr[:, 0].astype(int),
        "instance_idx": arr_tr[:, 1].astype(int),
        "unit_idx": arr_tr[:, 2].astype(int),
        "activation_value": arr_tr[:, 3].astype(np.float32),
    })

    if spikes_val_pruned.shape[0] > 0:
        # Val
        arr_va = asnumpy(topk_acts_4d_val)
        acts_df_val = pd.DataFrame({
            "example_idx": arr_va[:, 0].astype(int),
            "instance_idx": arr_va[:, 1].astype(int),
            "unit_idx": arr_va[:, 2].astype(int),
            "activation_value": arr_va[:, 3].astype(np.float32),
        })
    else:
        acts_df_val = None

    n_examples_train = (int(acts_df_train["example_idx"].max()) + 1)
    std_threshold = 1e-6

    # Precompute squared values once, then sum both in one grouped pass
    acts_df_train_with_sq = acts_df_train.assign(activation_value_sq=acts_df_train["activation_value"] ** 2)

    unit_stats = (
        acts_df_train_with_sq.groupby(["instance_idx", "unit_idx"], as_index=False)
          .agg(sum_val=("activation_value", "sum"),
               sum_sq=("activation_value_sq", "sum"))
    )
    # Population variance which takes into account missing rows where activations are zero
    unit_stats["mean"] = unit_stats["sum_val"] / n_examples_train
    unit_stats["var"]  = (unit_stats["sum_sq"] / n_examples_train) - unit_stats["mean"]**2
    unit_stats["std"]  = np.sqrt(np.clip(unit_stats["var"].to_numpy(), 0.0, None))

    kept_unit = unit_stats.loc[unit_stats["std"] > std_threshold, ["instance_idx", "unit_idx"]]
    n_dropped = len(unit_stats) - len(kept_unit)

    if n_dropped:
        # Semi-join to keep only surviving (instance, unit) pairs
        acts_df_train = acts_df_train.merge(kept_unit, on=["instance_idx", "unit_idx"], how="inner")
        acts_df_val = acts_df_val.merge(kept_unit, on=["instance_idx", "unit_idx"], how="inner") if (spikes_val_pruned.shape[0] > 0) else None
        print(f"Pruned {n_dropped} units (std ≤ {std_threshold}). Kept {len(kept_unit)}.")

    if save_activations:
        acts_df_train.to_parquet(activations_save_path / activations_file_train, index=False)
        if acts_df_val is not None:
            acts_df_val.to_parquet(activations_save_path / activations_file_val, index=False)
        train_mask.to_frame(name="mask").to_parquet(activations_save_path / mask_file_train, index=True)
        val_mask.to_frame(name="mask").to_parquet(activations_save_path / mask_file_val, index=True)
        print(f"Saved activations to {activations_save_path}")

if acts_df_val is not None:
    print(f"Activations: \nTrain shape: {acts_df_train.shape}, Val shape: {acts_df_val.shape}")
else:
    print(f"Activations: \nTrain shape: {acts_df_train.shape}")

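The activations are stored in a sparse long format: one row per non-zero activation, with columns `example_idx`, `instance_idx`, `unit_idx` and `activation_value`. Because zeros are not stored, the pruning step divides by the total number of examples rather than the number of rows. The sketch below uses synthetic data and a single instance for brevity; it checks that this gives the same population variance as `np.var` on the dense matrix, and shows how to rebuild the dense matrix.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n_examples, n_units = 50, 6
# Synthetic dense activations in which roughly 80% of entries are exactly zero
dense = rng.random((n_examples, n_units)) * (rng.random((n_examples, n_units)) < 0.2)

# Long format: one row per non-zero activation; zeros are implicit
ex_idx, unit_idx = np.nonzero(dense)
acts = pd.DataFrame({"example_idx": ex_idx, "unit_idx": unit_idx,
                     "activation_value": dense[ex_idx, unit_idx]})

# Per-unit population mean/variance, dividing by the total number of examples (not rows)
stats = (acts.assign(activation_value_sq=acts["activation_value"] ** 2)
             .groupby("unit_idx")
             .agg(sum_val=("activation_value", "sum"), sum_sq=("activation_value_sq", "sum")))
mean = stats["sum_val"] / n_examples
var = stats["sum_sq"] / n_examples - mean ** 2

# Rebuild the dense matrix from the long format
dense_back = np.zeros((n_examples, n_units))
dense_back[acts["example_idx"].to_numpy(), acts["unit_idx"].to_numpy()] = acts["activation_value"].to_numpy()

print(np.allclose(dense_back, dense))
print(np.allclose(var.to_numpy(), dense.var(axis=0)[stats.index.to_numpy()]))
```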
# 4. Find features

## Pick dataset

In [None]:
"""Pick whether to find features in training or validation set."""

search_train = True

if search_train:
    acts_df_split = acts_df_train
    metadata_binned_split = metadata_binned[train_mask].copy()
    spikes_df_split = spikes_df[train_mask].copy()
else:
    acts_df_split = acts_df_val
    metadata_binned_split = metadata_binned[val_mask].copy()
    spikes_df_split = spikes_df[val_mask].copy()

## Automatically map units to metadata

**How it works:**

Units are mapped to metadata variables by computing a selectivity score. For a unit $u$ and a condition $c$ (a variable/value combination, e.g., velocity between 0 and 1, or maze condition = 3):

$$
\text{activation\_frac\_during} =
\frac{\#\{\text{activations of } u \text{ in examples with } c\}}
     {\#\{\text{examples with } c\}}
$$

$$
\text{activation\_frac\_outside} =
\frac{\#\{\text{activations of } u \text{ in examples without } c\}}
     {\#\{\text{examples without } c\}}
$$

$$
\text{selectivity\_score} =
\frac{\text{activation\_frac\_during}}
     {\text{activation\_frac\_during} + \text{activation\_frac\_outside}}
$$

- $\approx 1$: unit mainly active *during* the condition (highly selective)  
- $\approx 0.5$: unit active equally in/out (not selective)  
- $\approx 0$: unit mostly active *outside* the condition

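The three formulas above can be computed for one unit and one condition as in the sketch below (a hypothetical helper, not the pipeline's code). For a real unit, `active` could be built from `acts_df_split` by marking the examples that have a row for that (instance, unit) pair, and `in_condition` from the matching rows of `metadata_binned_split`.

```python
import numpy as np

def selectivity_score(active, in_condition):
    """Return (activation_frac_during, activation_frac_outside, selectivity_score).

    active:       boolean array over examples, True where the unit is active
    in_condition: boolean array over examples, True where condition c holds
    """
    active = np.asarray(active, dtype=bool)
    in_condition = np.asarray(in_condition, dtype=bool)
    frac_during = active[in_condition].mean()
    frac_outside = active[~in_condition].mean()
    denom = frac_during + frac_outside
    score = frac_during / denom if denom > 0 else float("nan")
    return frac_during, frac_outside, score

# 10 examples: condition holds in the first 5; the unit fires in 3 of those and in 1 of the other 5
in_condition = np.array([True] * 5 + [False] * 5)
active = np.array([True, True, True, False, False, True, False, False, False, False])
during, outside, score = selectivity_score(active, in_condition)
print(during, outside, score)  # approximately 0.6, 0.2 and 0.75
```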
The `pipeline.map_units_to_metadata` function:
1. For discrete variables: computes the activation fractions and selectivity score for each condition value.
2. For continuous variables: bins the values into `n_bins_continuous` bins, then reuses the discrete analysis with each bin as a condition.
3. Ranks the results by selectivity score and returns the `top_n_mappings` highest-scoring ones.

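The binning scheme used inside `map_units_to_metadata` is not shown in this tutorial. To illustrate how the number of bins trades off granularity, the sketch below assumes equal-width bins via `pd.cut` (quantile bins via `pd.qcut` are another common choice) on a few made-up velocity values.

```python
import pandas as pd

vel_magnitude = pd.Series([0.0, 0.1, 0.4, 0.5, 0.9, 1.0])
# Equal-width bins; labels=False returns the integer bin index of each value
coarse = pd.cut(vel_magnitude, bins=2, labels=False)  # e.g., slow vs fast
fine = pd.cut(vel_magnitude, bins=4, labels=False)    # finer speed ranges
print(coarse.tolist())  # [0, 0, 0, 0, 1, 1]
print(fine.tolist())    # [0, 0, 1, 1, 3, 3]

# Each bin then acts as one discrete condition, e.g. "vel_magnitude in bin 1"
in_condition = (fine == 1).to_numpy()
```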
**Key arguments to play with:**
- `discrete_vars` and `continuous_vars`: by default only one of each is included because the function takes time to run, but you may want to explore different or additional variables (the commented-out lists in the cell below name further candidates)
- `n_bins_continuous`: number of bins for each continuous variable. Fewer bins give coarser conditions and therefore more general features (e.g., you can only distinguish fast from slow hand velocity); more bins give finer conditions and more specific features (e.g., very fast vs fast vs intermediate vs slow vs very slow hand velocity)
- `min_activation_frac`: minimum fraction of a condition's examples in which a unit must be active
- `top_n_mappings`: number of highest-scoring mappings kept per variable/value/instance combination (default `3`; the cell below uses `5`)

In [None]:
"""Map units to metadata variables."""

# discrete_vars = ['event', 'maze_condition', 'barriers', 'targets', 'hit_position_x', 'hit_position_y', 'hit_position_angle']
# continuous_vars = ['vel_magnitude', 'accel_magnitude', 'movement_angle']
discrete_vars = ['event']
continuous_vars = ['vel_magnitude']
unit_metadata_mapping = pipeline.map_units_to_metadata( # this can take a minute
    acts_df_split, metadata_binned_split,
    discrete_vars=discrete_vars,
    continuous_vars=continuous_vars,
    min_activation_frac=0.5,
    n_bins_continuous=[12],
    top_n_mappings=5
)
unit_metadata_mapping

## Find meaningful features and their associated neurons

The code below generates a dashboard like this one, providing an interactive way to explore the SAE units in the search for meaningful features and their associated neurons:

![](./tutorial_figures/feature_finding_dashboard.png)

The `unit_metadata_mapping` dataframe generated above is used as a starting point for identifying promising features ("preset" option on the dashboard). For more preset options, return to the "Automatically map units to metadata" section and adjust the arguments that control the mapping process.

You may also choose to look through units manually, though this tends to take more time ("manual selection" option on the dashboard).

Remember:
- Matryoshka SAEs split units across multiple levels, with higher levels capturing broad patterns and lower levels capturing more specific ones. Check the `dsae_topk_map` you used to train your SAEs to see which units were allocated to each level.
- The neural recordings come from two brain regions (PMd and M1). Use the `neurons_df` dataframe generated below to look up the brain region of each neuron ID; this can help you see whether particular features are driven more strongly by neurons in one region or the other.
- By default, this tutorial sets `search_train = True` at the beginning of this section, so the feature search runs on the training set. You can switch to the validation set instead to check that the SAEs generalize to unseen data (and to unseen sessions if you chose to split by session in section "1. Load and prepare data").

In [None]:
"""Print neuron to brain region mapping for reference."""

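# Neuron IDs of the last loaded session (assumes all loaded sessions share the same units)
session = sessions[-1]
neuron_ids = [u.decode() for u in session.units.id]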
neurons_df = pd.DataFrame({
    "neuron_id": [int(u.split("elec")[1]) for u in neuron_ids],
    "region": ["PMd" if "group_1" in u else "M1" for u in neuron_ids]
})

neurons_df

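One way to gauge whether a unit is driven more by PMd or M1 neurons is to look at how its decoder weights are distributed across regions. The hypothetical helper below sums absolute weights per region (ignoring sign is a choice of this sketch). It assumes you have extracted one unit's decoder weights as a vector over neurons (e.g., one row of `sae.W_dec[inst]`, assuming that tensor has shape [units × neurons]) and a region label per neuron aligned with it; if neurons were pruned before retraining, the labels must be filtered the same way.

```python
import numpy as np
import pandas as pd

def region_weight_share(decoder_row, regions):
    """Fraction of one unit's absolute decoder weight that falls on each brain region."""
    weights = pd.Series(np.abs(np.asarray(decoder_row, dtype=float)))
    by_region = weights.groupby(np.asarray(regions)).sum()
    return by_region / by_region.sum()

# Hypothetical decoder weights of one SAE unit over 5 neurons, with their region labels
decoder_row = [0.5, -0.25, 0.0, 0.125, 0.125]
regions = ["PMd", "PMd", "M1", "M1", "M1"]
print(region_weight_share(decoder_row, regions).to_dict())  # {'M1': 0.25, 'PMd': 0.75}
```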
In [None]:
"""Feature finding dashboard."""

pipeline.build_feature_finding_dashboard(
    unit_metadata_mapping=unit_metadata_mapping,
    acts_df=acts_df_split,
    spikes_df=spikes_df_split,
    metadata_binned=metadata_binned_split
)