# NLDisco MC_Maze pipeline

This code uses the **NLDisco** (**N**eural **L**atent **Disco**very) pipeline on MC_Maze data.

**Goal:** Discover interpretable latents (i.e., *features*) in high-dimensional neural data.

**Terminology:**

- *Neural / Neuronal:* Refers to biological neurons. Distinguished from *model neurons* (see below).
- *Units:* Putative biological neurons -- the output from spike sorting extracellular electrophysiological data.
- *Model neurons:* Neurons in a neural network model (aka *latents*).
- *Features:* Interpretable latents (latent dimensions that align with meaningful behavioral or environmental variables)

![](./figures/sae.svg)

*(If the image above does not render, see ./figures/sae.svg)*

Motivated by successful applications of sparse dictionary learning in AI mechanistic interpretability, NLDisco trains overcomplete sparse encoder-decoder (SED) models to reconstruct neural activity based on a set of sparsely active dictionary elements (i.e. latents), implemented as hidden layer neurons. In the figure above, this is illustrated as reconstructing target neural activity $z$ from input neural activity $y$ via $d$. Sparsity in the latent space encourages a monosemantic dictionary, where each hidden layer neuron corresponds to a single neural representation that can be judged for interpretability, making SEDs a simple but effective tool for neural latent discovery. 

These SEDs can be configured as autoencoders (SAEs) if the target for $z$ is $y$ (e.g. M1 activity based on M1 activity), or as transcoders if the target for $z$ is dependent on or related to $y$ (e.g. M1 activity based on M2 activity, or M1 activity on day 2 based on M1 activity on day 1). In this tutorial, we will work exclusively with SAEs.

We consider a latent’s interpretability in two key aspects:

1. its correspondence to a specific external variable – a "natural" behavioral or environmental feature

2. its explicit composition from contributing neural activity

**NLDisco pipeline:**

1. Load and prepare data
    - Neural data in the form of $[time \times space]$, and in this tutorial specifically as binned spike counts of $[examples \times units]$

2. Train models 
    - Train models to reconstruct the neural data (in this tutorial, from the Churchland MC Maze dataset)
    - Validate the quality of the models by looking at the sparsity of the latent activations and reconstruction quality of the neural data

3. Save or load the model activations

4. Decode a behavioural variable (in this case hand velocity)
    - Using NLDisco latents
    - Using CEBRA embeddings, as a comparison

5. Find features
    - Automatically generate promising mappings between model neurons (latents) and behavioral and/or environmental data
    - Find meaningful features and their top contributing units using an interactive dashboard
        - The mappings are a starting point to guide the search
        - A user can also choose to look through model neurons manually

6. Make paper plots

___

## Setup

**Environment setup:**

Prerequisite: an installed version of [pixi](https://pixi.sh/latest/)

Steps:
1. In the repo's root directory, run `pixi install --manifest-path ./pyproject.toml`. This will create an environment in a newly created `.pixi/envs` folder.
2. Run `pixi run postinstall`

**Data download:**

Once your environment is set up, use it as a kernel for this notebook and run the cell below to automatically download and preprocess the [churchland_shenoy_neural_2012 dataset](https://brainsets.readthedocs.io/en/latest/glossary/brainsets.html#churchland-shenoy-neural-2012) (also known as MC_Maze). With the data downloaded, you can jump straight into the tutorial!

In [None]:
# from nldisco import mc_maze

# # Directories for raw and processed data
# raw_data_dir = "../data/raw"
# processed_data_dir = "../data/processed"

# # Subject name and number of their sessions to download
# # Max 4 for jenkins and 3 for nitschke
# # Be aware that these files are large (~2-7GB each)
# subject_name = "nitschke"  # "jenkins" or "nitschke"
# num_files = 1  

# mc_maze.download_and_preprocess(raw_data_dir, processed_data_dir, subject_name, num_files=num_files)

# # To load all the data, use:
# # mc_maze.download_and_preprocess(raw_data_dir, processed_data_dir, "jenkins", num_files=4)
# # mc_maze.download_and_preprocess(raw_data_dir, processed_data_dir, "nitschke", num_files=3)

___

## Tutorial start!

In the churchland_shenoy_neural_2012 dataset, the subjects are performing a center-out reaching task on a variety of different maze configurations. Each maze configuration comes in 3 versions:
- 1 target and no barriers.
- 1 target with barriers.
- 3 targets with barriers, where 2 of the targets are distractors that are inaccessible given the barrier configuration.

So maze conditions 1, 2, 3 are related; as are maze conditions 4, 5, 6; and so on.
Neural activity was recorded from the dorsal premotor (PMd) and primary motor (M1) cortices. A variety of other data (monkey hand position, velocity and acceleration, gaze position...) is also provided.
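As a small illustration of the grouping described above, the sketch below maps a maze condition number to the index of its maze configuration. It assumes conditions are numbered consecutively from 1, three per configuration; `maze_group` is a hypothetical helper name and is not part of `nldisco`.

```python
def maze_group(condition: int) -> int:
    """Return the 0-based maze configuration index for a 1-based condition number.

    Assumption: conditions are numbered consecutively, three per configuration.
    """
    return (condition - 1) // 3


# Collect the conditions that share a maze configuration
groups = {}
for cond in range(1, 10):
    groups.setdefault(maze_group(cond), []).append(cond)
print(groups)
```

Grouping related conditions this way can be handy later when checking whether a latent responds to a whole maze configuration or only to one of its versions.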

In [None]:
"""Set notebook settings."""

%load_ext autoreload
%autoreload 2

In [None]:
"""Import packages."""

# Standard library
from datetime import datetime
from pathlib import Path

# IPython/Jupyter
from IPython.display import display

# Third-party
import cebra
import numpy as np
import pandas as pd
import seaborn as sns
import torch as t
from einops import asnumpy, reduce
from matplotlib import pyplot as plt
from scipy.ndimage import gaussian_filter1d
from sklearn.metrics import r2_score
from sklearn.preprocessing import StandardScaler

# Local project modules
from nldisco import train as nt
from nldisco import mc_maze, decode, pipeline, cebra_utils
from nldisco.train_val_split import train_val_split_by_proportion, train_val_split_by_session

# 1. Load and prepare data

## Load and prepare spike data

In [None]:
"""Load session data and bin spikes."""

data_path = Path(r"../data/processed")
subject_name = "nitschke"  # "jenkins" or "nitschke"
sessions = mc_maze.load_sessions(data_path, subject_name)

bin_size = 0.05  # in seconds
spikes_df = mc_maze.bin_spike_data(sessions, bin_size)  # this can take several minutes

display(spikes_df)

In [None]:
"""Quick plots and stats to get a sense of the spike data."""

print("Firing rates distribution:")
# Compute mean firing rate (Hz) per unit
duration_sec = len(spikes_df) * bin_size
mean_firing_rates = spikes_df.sum(axis=0) / duration_sec  # spikes/sec
# Plot histogram
plt.figure(figsize=(8, 4))
plt.hist(mean_firing_rates, bins=30, edgecolor='black')
plt.xlabel('Mean Firing Rate (Hz)')
plt.ylabel('Number of Units')
plt.title('Distribution of Mean Firing Rates per Unit')
plt.grid(True)
plt.tight_layout()
plt.show()
# Print stats
print("Mean: {:.2f} Hz".format(mean_firing_rates.mean()))
print("Range: {:.2f}–{:.2f} Hz".format(mean_firing_rates.min(), mean_firing_rates.max()))

print("\n\nSpike count distribution and sparsity stats:")
# Flatten spike counts
flattened_spike_counts = spikes_df.values.flatten()
# Define bins that align exactly to integer spike counts
max_count = flattened_spike_counts.max()
bins = np.arange(0, max_count + 2) - 0.5  # centers bins on integers
# Plot histogram
plt.figure(figsize=(8, 5))
plt.hist(flattened_spike_counts, bins=bins, edgecolor='black')
plt.title("Distribution of Spike Counts per Unit per Time Bin")
plt.xlabel("Spike Count")
plt.ylabel("Number of Unit-Bin Combinations")
plt.yscale("log")
plt.grid(True)
plt.tight_layout()
plt.show()
# Print stats
frac_nonzero_bins = (spikes_df != 0).values.sum() / spikes_df.size
frac_nonzero_examples = (spikes_df.sum(axis=1) > 0).mean()
print(f"Fraction of non-zero bins: {frac_nonzero_bins:.4f}")
print(f"Fraction of non-zero examples: {frac_nonzero_examples:.4f}")


## Load and prepare environment / behavior (meta)data

In [None]:
"""Load and bin (meta)data to match spike bins."""

# Retrieve and collate metadata (hand/eye/events) across sessions
metadata, trials_df = mc_maze.retrieve_metadata(sessions) # this can take a minute
# Bin metadata to match the binned spikes_df
metadata_binned = mc_maze.bin_metadata(metadata, trials_df, bin_size, spikes_df.index)

print("Metadata:")
display(metadata)
print("Binned metadata:")
display(metadata_binned)

## Train/val split

In [None]:
"""Train/val split, smooth and normalize spikes."""

split_by_session = False # if False, will split by proportion

if split_by_session:
    train_trials, val_trials = train_val_split_by_session(
        metadata_binned["trial_idx"].to_numpy(),
        metadata_binned["session"].to_numpy(),
        train_sessions=[1, 2],  # pick your training sessions
        shuffle=True,
        seed=0,
    )
else:
    train_trials, val_trials = train_val_split_by_proportion(
        metadata_binned["trial_idx"].values,
        train_proportion=0.8,
        shuffle=True,
        seed=0,
    )

# Create boolean masks to split metadata and spikes into train/val sets
train_mask = metadata_binned['trial_idx'].isin(train_trials)
val_mask = metadata_binned['trial_idx'].isin(val_trials)

# Split metadata
metadata_binned_train = metadata_binned[train_mask].reset_index(drop=True)
metadata_binned_val = metadata_binned[val_mask].reset_index(drop=True)

# Split spikes 
spikes_arr = spikes_df.values.astype(np.float32)
spikes_train_arr = spikes_arr[train_mask]
spikes_val_arr = spikes_arr[val_mask]

# Smooth spikes
sigma = 0.05 / bin_size  # Gaussian kernel s.d. of 50 ms, expressed in bins
spikes_train_arr = gaussian_filter1d(spikes_train_arr, sigma=sigma, axis=0)
spikes_val_arr = gaussian_filter1d(spikes_val_arr, sigma=sigma, axis=0)

# Normalize spikes (fit normalization on training data only)
train_max = spikes_train_arr.max()
spikes_train_arr = spikes_train_arr / train_max
spikes_val_arr = spikes_val_arr / train_max

# Summary
print(f"Train set: {len(train_trials)} trials ({train_mask.sum()} time bins)")
print(f"Val set: {len(val_trials)} trials ({val_mask.sum()} time bins)")
print(f"Spike data shapes: train {spikes_train_arr.shape}, val {spikes_val_arr.shape}")

In [None]:
"""Convert spike data to torch tensors and move to torch device."""

# it's best to have a gpu for training!
device = t.device("cuda" if t.cuda.is_available() else "cpu")
print(f"{device=}")

spikes_train = t.from_numpy(spikes_train_arr).to(device).to(dtype=t.bfloat16)
spikes_val = t.from_numpy(spikes_val_arr).to(device).to(dtype=t.bfloat16)

display(spikes_train)
display(spikes_val)

# 2. Train models

> If desired, you can choose to skip this section and load pre-saved SAE activations instead. A set of activations per subject (Jenkins and Nitschke) is provided - go to section "3. Save/load SAE activations" for more information. It is however still highly recommended to read the rest of this section to understand how the SAEs are trained.

This code trains 2 SAEs with identical setups so that you can compare the different instances and ensure they both give similar results. For each time bin of neural data in the train and val sets, the SAEs' latent activations are calculated (i.e., the output values of their hidden layer neurons), and it is these activations that will be used to find a latent's correspondence with external variables (features). A particularity of the NLDisco pipeline below is that it trains *Matryoshka* SAEs.

**Matryoshka architecture:**

The Matryoshka architecture segments the latent space into multiple levels, each of which attempts a full reconstruction of the target neural activity. In the figure below, black boxes indicate the latents (model neurons) involved in a given level, while light-red boxes indicate additional latents recruited at lower levels. A top-$k$ selection is used to choose which latents to recruit for reconstruction at each level (yellow neuron within each light-red box -- $k=1$ for each level in this example).

This nested arrangement is motivated by the idea that multi-scale feature learning can mitigate “feature absorption” (a common issue where a more specific feature subsumes a portion of a more general feature), allowing both coarse and detailed representations to emerge simultaneously.

- Latents in the highest level ($L_1$) typically correspond to broad, high-level features (e.g., a round object).
- Latents exclusive to the lowest level ($L_3$) often correspond to more specific, fine-grained features (e.g., a basketball).

![](./figures/msae.svg)

*(If the image above does not render, see ./figures/msae.svg)*
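To make the nested levels concrete, here is a minimal NumPy sketch of a Matryoshka-style forward pass. It is one plausible reading of the description above, not the `nldisco` implementation: it assumes each level uses a prefix of the latents, applies top-$k$ over that prefix, and produces its own full reconstruction. The names `topk_mask`, `matryoshka_forward` and `loss_weights` are hypothetical.

```python
import numpy as np


def topk_mask(acts: np.ndarray, k: int) -> np.ndarray:
    """Keep the k largest activations per row and zero out the rest."""
    out = np.zeros_like(acts)
    idx = np.argpartition(acts, -k, axis=1)[:, -k:]  # indices of the k largest per row
    rows = np.arange(acts.shape[0])[:, None]
    out[rows, idx] = acts[rows, idx]
    return out


def matryoshka_forward(y, W_enc, b_enc, W_dec, b_dec, topk_map):
    """Return one reconstruction per nested level.

    topk_map maps level size -> k; a level of size s only uses the first s latents.
    """
    pre = np.maximum(y @ W_enc + b_enc, 0.0)  # ReLU pre-activations, [examples, latents]
    recons = {}
    for size, k in sorted(topk_map.items()):
        acts = topk_mask(pre[:, :size], k)          # sparse code restricted to this level
        recons[size] = acts @ W_dec[:size] + b_dec  # full reconstruction from this level alone
    return recons


rng = np.random.default_rng(0)
n_units, n_latents = 5, 8
y = rng.random((3, n_units))
W_enc = rng.normal(size=(n_units, n_latents))
W_dec = rng.normal(size=(n_latents, n_units))
recons = matryoshka_forward(
    y, W_enc, np.zeros(n_latents), W_dec, np.zeros(n_units), topk_map={4: 1, 8: 2}
)

# Weighted sum of per-level MSE (hypothetical weights, one per level)
loss_weights = {4: 1.0, 8: 1.5}
loss = sum(w * np.mean((recons[s] - y) ** 2) for s, w in loss_weights.items())
print({s: r.shape for s, r in recons.items()}, float(loss))
```

Because every level has to reconstruct the target on its own, the first latents are pushed toward broad structure, while latents that only appear in larger levels are free to capture finer detail.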

**Key training parameters to play with:**

`SedConfig` (model-level):
- `dsed_topk_map`: how many top-k model neurons are kept active at each level (controls sparsity per level)
- `dsed_loss_x_map`: relative weight of each Matryoshka level in the overall reconstruction loss

`optimize()` (optimizer-level):
- `n_steps`: total training steps
- `batch_sz`: how many examples per batch
- `lr` (set in the optimizer): learning rate of the optimizer you choose (e.g., Adam below); a learning-rate scheduler can optionally be enabled with `use_lr_sched=True`
- `dead_latent_window`: number of steps a latent (model neuron) can stay inactive before being flagged as “dead”  
  - Dead latents are revived with an auxiliary loss: instead of reconstructing the full input, they try to reconstruct only the residual error (the part the active neurons failed to capture)
  - This gives inactive latents a chance to become useful again, preventing them from staying permanently silent   
- `loss_fn`: reconstruction objective; the built-ins are MSE and MSLE (e.g., `nt.msle`, used below), or you can pass a custom callable
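The dead-latent bookkeeping described above can be sketched as follows. This is an illustrative NumPy version under the assumption that a latent counts as "active" whenever it has a positive activation anywhere in the batch; `update_dead_latents` is a hypothetical helper, not the function used inside `nt.optimize`.

```python
import numpy as np


def update_dead_latents(steps_since_active, batch_acts, dead_latent_window):
    """Update per-latent inactivity counters and return a boolean 'dead' mask."""
    fired = (batch_acts > 0).any(axis=0)  # did each latent fire at least once in this batch?
    steps_since_active = np.where(fired, 0, steps_since_active + 1)
    return steps_since_active, steps_since_active >= dead_latent_window


counters = np.zeros(4, dtype=int)
for step in range(3):
    batch_acts = np.zeros((2, 4))
    batch_acts[:, 0] = 1.0      # latent 0 fires on every step
    if step == 0:
        batch_acts[0, 1] = 0.5  # latent 1 fires only on the first step
    counters, dead = update_dead_latents(counters, batch_acts, dead_latent_window=2)

print(counters, dead)
```

Latents flagged in `dead` would then be the ones given the auxiliary objective of reconstructing the residual (target minus current reconstruction).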


## Train SAEs

In [None]:
"""Set config."""

# total of 1024 model neurons in 3 nested levels: 0-256, 0-512 and 0-1024
dsae_topk_map = {256: 8, 512: 16, 1024: 24}
dsae_topk_map = dict(sorted(dsae_topk_map.items()))  # ensure sorted from smallest to largest
dsae_loss_x_map = {256: 1, 512: 1.25, 1024: 1.5}
dsae_loss_x_map = dict(sorted(dsae_loss_x_map.items()))
dsae = max(dsae_topk_map.keys())
n_inst = 2  # number of SAE instances to train in parallel

In [None]:
"""Train model."""

sae_cfg = nt.SedConfig(
    n_input=spikes_train.shape[1],
    dsed_topk_map=dsae_topk_map,
    dsed_loss_x_map=dsae_loss_x_map,
    n_instances=n_inst,
)
sae = nt.Sed(sae_cfg).to(device)
loss_fn = nt.msle
lr = 5e-3

n_epochs = 20
batch_sz = 1024
n_steps = (spikes_train.shape[0] // batch_sz) * n_epochs
log_freq = max(1, n_steps // n_epochs // 2)
dead_latent_window = max(1, n_steps // n_epochs // 3)

data_log = nt.optimize(  # train model
    spk_cts=spikes_train,
    sed=sae,
    loss_fn=loss_fn,
    optimizer=t.optim.Adam(sae.parameters(), lr=lr),
    use_lr_sched=True,
    dead_latent_window=dead_latent_window,
    n_steps=n_steps,
    log_freq=log_freq,
    batch_sz=batch_sz,
    log_wandb=False,
    plot_l0=False,
)

## Validate SAEs

To validate the SAEs, examine the following printed values and plots.

1. **NaN check:**
Confirms that no NaN values appear in the encoder/decoder weights.

2. **Decoder weights:**
The histograms show the distribution of decoder weights for each SAE instance. Both should look similar and roughly centered around zero. If one model has a very different distribution, it may not have trained properly.

3. **L0 of latents:**
Shows how many latents are active per example, measured at the final Matryoshka level (which contains all latents). You want the median to be around the top-$k$ setting you chose for this largest level.

4. **Latent activity density:**
Fraction of time each latent is active. Most should fire sparsely (low fractions). Many at 0 (dead) is common but ideally limited, while many at 1 (always-on) is undesirable as it breaks sparsity.

5. **R² of reconstructions:**
R² between reconstructions and true spike counts, shown per example and per unit.

6. **Cosine similarity of reconstructions:**
Cosine similarity between reconstructions and true spike counts, shown per example and per unit.

7. **R² of summed spike counts:**
Reports how well the reconstructions capture the total population activity (sum of all spikes per example).

**What to look for:**
- No NaNs in encoder/decoder weights.
- All SAE instances producing broadly similar plots.
- Decoder weights centered near zero.
- Latents used sparsely but not all dead.
- Reconstruction metrics high (R² and cosine similarity near 1).
- High R² for summed spike counts (close to 1).
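For readers who want to see what these quantities measure, the following NumPy sketch computes L0 per example, latent activity density, and per-unit R² and cosine similarity from dense arrays. It is a simplified stand-in for intuition only; `validation_metrics` is a hypothetical name, and `nt.eval_model` may compute or aggregate these differently.

```python
import numpy as np


def validation_metrics(acts, y_true, y_pred, eps=1e-12):
    """Compute L0 per example, latent density, and per-unit R² / cosine similarity."""
    active = acts > 0
    l0_per_example = active.sum(axis=1)  # number of active latents per example
    density = active.mean(axis=0)        # fraction of examples in which each latent is active
    ss_res = ((y_true - y_pred) ** 2).sum(axis=0)
    ss_tot = ((y_true - y_true.mean(axis=0)) ** 2).sum(axis=0)
    r2_per_unit = 1.0 - ss_res / (ss_tot + eps)
    cossim_per_unit = (y_true * y_pred).sum(axis=0) / (
        np.linalg.norm(y_true, axis=0) * np.linalg.norm(y_pred, axis=0) + eps
    )
    return l0_per_example, density, r2_per_unit, cossim_per_unit


# Toy example: 2 examples, 3 latents, 2 units, perfect reconstruction
acts = np.array([[0.0, 1.0, 0.0], [2.0, 0.0, 3.0]])
y_true = np.array([[0.0, 1.0], [1.0, 3.0]])
y_pred = y_true.copy()
l0, density, r2, cossim = validation_metrics(acts, y_true, y_pred)
print(l0, density, r2, cossim)
```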

In [None]:
"""Check for nans in weights."""

sae.W_dec.isnan().sum(), sae.W_enc.isnan().sum()

In [None]:
"""Visualize weights."""

fig, ax = plt.subplots(figsize=(8, 6))
for inst in range(n_inst):
    W_dec_flat = asnumpy(sae.W_dec[inst].float()).ravel()
    sns.histplot(W_dec_flat, bins=1000, stat="probability", alpha=0.7, label=f"SAE {inst}")
    
ax.set_title("SAE decoder weights")
ax.set_xlabel("Weight value")
ax.set_ylabel("Frequency")
ax.legend()

In [None]:
"""Visualize metrics over all examples and units."""

fig, topk_acts_4d_train, recon_spk_cts_train, r2_per_unit_train, _, cossim_per_unit_train, _ = (
    nt.eval_model(spikes_train, sae, batch_sz=batch_sz)
)

In [None]:
"""Calculate variance explained of summed spike counts."""

n_recon_examples = recon_spk_cts_train.shape[0]
recon_summed_spk_cts = reduce(recon_spk_cts_train, "example inst neuron -> example inst", "sum")

actual_summed_spk_cts = reduce(spikes_train, "example neuron -> example", "sum")
actual_summed_spk_cts = actual_summed_spk_cts[:n_recon_examples]  # trim to match

for inst in range(n_inst):
    r2 = r2_score(
        asnumpy(actual_summed_spk_cts.float()),
        asnumpy(recon_summed_spk_cts[:, inst].float()),
    )
    print(f"SAE instance {inst} R² (summed spike count over all units per example) = {r2:.3f}")


## Remove bad units and retrain

In [None]:
"""Remove bad units and retrain."""

# Set threshold for removing units
r2_thresh = 0.1
inst = 0
r2_inst = r2_per_unit_train[:, inst]
keep_mask = r2_inst > r2_thresh
print(f"frac units above {r2_thresh=}: {keep_mask.sum() / keep_mask.shape[0]:.2f}")
print(f"Number to keep: {keep_mask.sum()} / {keep_mask.shape[0]}")

if keep_mask.all():
    print("All units pass threshold — skipping retrain.")
    spikes_train_pruned = spikes_train
    spikes_val_pruned = spikes_val
else:
    # Prune
    spikes_train_pruned = spikes_train[:, keep_mask]
    spikes_val_pruned = spikes_val[:, keep_mask]

    # Retrain SAE on pruned train data
    sae_cfg = nt.SedConfig(
        n_input=spikes_train_pruned.shape[1],
        dsed_topk_map=dsae_topk_map,
        dsed_loss_x_map=dsae_loss_x_map,
        seq_len=1,
        n_instances=n_inst,
    )
    sae = nt.Sed(sae_cfg).to(device)
    loss_fn = nt.msle
    lr = 5e-3

    n_epochs = 20
    batch_sz = 1024
    n_steps = (spikes_train_pruned.shape[0] // batch_sz) * n_epochs
    log_freq = max(1, n_steps // n_epochs // 2)
    dead_latent_window = max(1, n_steps // n_epochs // 3)

    data_log = nt.optimize(
        spk_cts=spikes_train_pruned,
        sed=sae,
        loss_fn=loss_fn,
        optimizer=t.optim.Adam(sae.parameters(), lr=lr),
        use_lr_sched=True,
        dead_latent_window=dead_latent_window,
        n_steps=n_steps,
        log_freq=log_freq,
        batch_sz=batch_sz,
        log_wandb=False,
        plot_l0=False,
    )

In [None]:
"""Re-visualize metrics over all examples and units."""

if keep_mask.all():
    print("All units pass threshold — skipping re-visualization.")
else:
    fig, topk_acts_4d_train, recon_spk_cts_train, r2_per_unit_train, _, cossim_per_unit_train, _ = (
        nt.eval_model(spikes_train_pruned, sae, batch_sz=batch_sz)
    )
    n_recon_examples_train = recon_spk_cts_train.shape[0]
    recon_summed_train = reduce(recon_spk_cts_train, "example inst unit -> example inst", "sum")

    actual_summed_train = reduce(spikes_train_pruned, "example unit -> example", "sum")
    actual_summed_train = actual_summed_train[:n_recon_examples_train]

    for inst in range(n_inst):
        r2 = r2_score(
            asnumpy(actual_summed_train.float()),
            asnumpy(recon_summed_train[:, inst].float()),
        )
        print(f"SAE instance {inst} R² (summed spike count per example) = {r2:.3f}")

In [None]:
"""Visualize metrics on validation data."""

if spikes_val_pruned.shape[0] == 0:
    print("No validation data available.")
else:
    print("Validation data metrics:")

    fig, topk_acts_4d_val, recon_spk_cts_val, r2_per_unit_val, _, cossim_per_unit_val, _ = (
        nt.eval_model(spikes_val_pruned, sae, batch_sz=batch_sz)
    )
    n_recon_examples_val = recon_spk_cts_val.shape[0]
    recon_summed_val = reduce(recon_spk_cts_val, "example inst unit -> example inst", "sum")

    actual_summed_val = reduce(spikes_val_pruned, "example unit -> example", "sum")
    actual_summed_val = actual_summed_val[:n_recon_examples_val]

    for inst in range(n_inst):
        r2 = r2_score(
            asnumpy(actual_summed_val.float()),
            asnumpy(recon_summed_val[:, inst].float()),
        )
        print(f"SAE instance {inst} R² (summed spike count per example) = {r2:.3f}")

# 3. Save/load SAE activations

A set of activations per subject (Jenkins and Nitschke) is provided in the `saved_sae_latent_activations` folder. They were both generated using a split by proportion (80/20 train/val split). If you want to use them, set the `load_activations` variable below to `True`.

In [None]:
"""Load saved activations if available; otherwise build acts_df and (optionally) save."""

save_path = Path(r"../saved_sae_latent_activations")
load_activations = True
save_activations = False
activations_file_train = "sae_activations_train.parquet"
activations_file_val = "sae_activations_val.parquet"
mask_file_train = "train_mask.parquet"
mask_file_val = "val_mask.parquet"

# Build save path (same style as before)
session_dates = []
for session in sessions:
    session_date = datetime.fromtimestamp(session.session.recording_date).strftime("%Y%m%d")
    session_dates.append(session_date)
session_dates_str = "_".join(session_dates)

data_identifier = f"{subject_name}_{session_dates_str}"
activations_save_path = save_path / data_identifier / "sae_activations"
activations_save_path.mkdir(parents=True, exist_ok=True)

if load_activations:
    acts_df_train = pd.read_parquet(activations_save_path / activations_file_train)
    acts_df_val = (
        pd.read_parquet(activations_save_path / activations_file_val) 
        if (activations_save_path / activations_file_val).exists() else None
    )
    train_mask = pd.read_parquet(activations_save_path / mask_file_train)["mask"]
    val_mask = pd.read_parquet(activations_save_path / mask_file_val)["mask"]
    print(f"Loaded activations from {activations_save_path}")
else:
    # Train
    arr_tr = asnumpy(topk_acts_4d_train)  # [example_idx, instance_idx, latent_idx, act_value]
    # Sparse activations (tight dtypes on indices, fp32 values)
    acts_df_train = pd.DataFrame({
        "example_idx": arr_tr[:, 0].astype(int),
        "instance_idx": arr_tr[:, 1].astype(int),
        "latent_idx": arr_tr[:, 2].astype(int),
        "activation_value": arr_tr[:, 3].astype(np.float32),
    })

    if spikes_val_pruned.shape[0] > 0:
        # Val
        arr_va = asnumpy(topk_acts_4d_val)
        acts_df_val = pd.DataFrame({
            "example_idx": arr_va[:, 0].astype(int),
            "instance_idx": arr_va[:, 1].astype(int),
            "latent_idx": arr_va[:, 2].astype(int),
            "activation_value": arr_va[:, 3].astype(np.float32),
        })
    else:
        acts_df_val = None

    n_examples_train = (int(acts_df_train["example_idx"].max()) + 1)
    std_threshold = 1e-6

    # Precompute squared values once, then sum both in one grouped pass
    acts_df_train_with_sq = (
        acts_df_train.assign(activation_value_sq=acts_df_train["activation_value"] ** 2)
    )

    latent_stats = (
        acts_df_train_with_sq.groupby(["instance_idx", "latent_idx"], as_index=False)
        .agg(sum_val=("activation_value", "sum"),
            sum_sq=("activation_value_sq", "sum"))
    )
    latent_stats["mean"] = latent_stats["sum_val"] / n_examples_train
    latent_stats["var"]  = (latent_stats["sum_sq"] / n_examples_train) - latent_stats["mean"]**2
    latent_stats["std"]  = np.sqrt(np.clip(latent_stats["var"].to_numpy(), 0.0, None))

    kept_latents = latent_stats.loc[latent_stats["std"] > std_threshold, ["instance_idx", "latent_idx"]]
    n_dropped = len(latent_stats) - len(kept_latents)

    if n_dropped:
        acts_df_train = acts_df_train.merge(kept_latents, on=["instance_idx", "latent_idx"], how="inner")
        acts_df_val = (
            acts_df_val.merge(kept_latents, on=["instance_idx", "latent_idx"], how="inner")
            if (spikes_val_pruned.shape[0] > 0) else None
        )
        print(f"Pruned {n_dropped} SAE latents (std ≤ {std_threshold}). Kept {len(kept_latents)}.")

    if save_activations:
        acts_df_train.to_parquet(activations_save_path / activations_file_train, index=False)
        if spikes_val_pruned.shape[0] > 0:
            acts_df_val.to_parquet(activations_save_path / activations_file_val, index=False)
        train_mask.to_frame(name="mask").to_parquet(
            activations_save_path / mask_file_train, index=True
        )
        val_mask.to_frame(name="mask").to_parquet(
            activations_save_path / mask_file_val, index=True
        )
        print(f"Saved activations to {activations_save_path}")

print(f"Activations: \nTrain shape: {acts_df_train.shape}" + (f", Val shape: {acts_df_val.shape}" if acts_df_val is not None else ""))

# 4. Decode

Here we use our NLDisco latents to decode a behavioural variable (in this case hand velocity). We also provide code to train CEBRA models and use the resulting CEBRA embeddings to decode as a comparison.
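To give a feel for what a lag sweep does, here is a self-contained NumPy sketch: for each candidate lag, features at bin $t$ are paired with behavior at bin $t + \text{lag}$, a ridge regression is fit on the training split, and the lag with the best validation R² is kept. This is an assumption about the general technique, not a description of the internals of `decode.decode_with_lag_sweep`; the helpers `ridge_fit`, `shift` and `lag_sweep` are hypothetical, and for simplicity the shift ignores trial boundaries.

```python
import numpy as np


def ridge_fit(X, y, alpha):
    """Closed-form ridge regression with an unpenalized intercept."""
    x_mean, y_mean = X.mean(axis=0), y.mean(axis=0)
    Xc, yc = X - x_mean, y - y_mean
    W = np.linalg.solve(Xc.T @ Xc + alpha * np.eye(X.shape[1]), Xc.T @ yc)
    return W, y_mean - x_mean @ W


def r2_per_dim(y_true, y_pred):
    """Coefficient of determination, computed separately per output dimension."""
    ss_res = ((y_true - y_pred) ** 2).sum(axis=0)
    ss_tot = ((y_true - y_true.mean(axis=0)) ** 2).sum(axis=0)
    return 1.0 - ss_res / ss_tot


def shift(X, y, lag):
    """Pair features at bin t with behavior at bin t + lag."""
    return (X[:-lag], y[lag:]) if lag > 0 else (X, y)


def lag_sweep(X_tr, y_tr, X_va, y_va, lags, alpha):
    """Fit one ridge model per lag and return the best one by mean validation R²."""
    best = None
    for lag in lags:
        W, b = ridge_fit(*shift(X_tr, y_tr, lag), alpha)
        X_l, y_l = shift(X_va, y_va, lag)
        r2 = r2_per_dim(y_l, X_l @ W + b)
        if best is None or r2.mean() > best["r2_mean"]:
            best = {"lag": lag, "r2_per_dim": r2, "r2_mean": float(r2.mean())}
    return best


# Synthetic check: behavior follows the features with a 2-bin delay
rng = np.random.default_rng(0)
X = rng.normal(size=(400, 6))
W_true = rng.normal(size=(6, 2))
y = np.zeros((400, 2))
y[2:] = X[:-2] @ W_true
best = lag_sweep(X[:300], y[:300], X[300:], y[300:], lags=range(0, 6), alpha=1e-3)
print(best["lag"], round(best["r2_mean"], 3))
```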

In [None]:
"""Define decoding target."""

# Target for decoding (can be replaced with other metadata e.g., accel, position, etc.)
decoding_target = np.column_stack([
    metadata_binned["vel_x"].to_numpy(dtype=np.float32),
    metadata_binned["vel_y"].to_numpy(dtype=np.float32),
])
y_train = decoding_target[train_mask]
y_val = decoding_target[val_mask]

### With NLDisco

In [None]:
"""Prepare the data for decoding."""

if acts_df_val is None:
    raise ValueError("Validation split is required for this section.")

# Build CSR matrices
n_examples_train = int(train_mask.sum())
n_examples_val = int(val_mask.sum()) if (acts_df_val is not None) else 0
feature_index = decode.build_feature_index(acts_df_train, acts_df_val)
X_train = decode.df_to_csr(acts_df_train, feature_index, n_examples=n_examples_train)
X_val = decode.df_to_csr(acts_df_val, feature_index, n_examples=n_examples_val) if (acts_df_val is not None) else None
print(f"X_train shape: {X_train.shape}" + (f", X_val shape: {X_val.shape}" if X_val is not None else "")) # (examples, latents)


In [None]:
"""Decode using NLDisco latents."""

best = decode.decode_with_lag_sweep(
    X_tr=X_train,
    X_va=X_val,
    y_tr=y_train,
    y_va=y_val,
    lags=range(0, 6), # bins to sweep
    alpha=30.0 # ridge strength
)

print(f"Best lag (bins): {best['lag']}")
print(f"R² per dimension: {np.round(best['r2_per_dim'], 3)}")
print(f"Mean R²: {best['r2_mean']:.3f}")

### With CEBRA

In [None]:
"""Train CEBRA models."""

if val_mask.sum() == 0:
    raise ValueError("Validation split is required for this section.")

save_path = Path(r"../saved_CEBRA_models")
cebra_save_path = save_path / data_identifier / "CEBRA_models"

# Prepare data for CEBRA
trial_ids_train = metadata_binned['trial_idx'][train_mask].values
trial_ids_val = metadata_binned['trial_idx'][val_mask].values

# Only train if there are not already models saved
if any(cebra_save_path.glob("*.pt")):
    print(f"CEBRA models already found in {cebra_save_path}, skipping training.")
else:
    print(f"Saving CEBRA models to {cebra_save_path}")
    cebra_save_path.mkdir(parents=True, exist_ok=True)
    params_grid = dict(
        output_dimension=[48],
        time_offsets=[1, 2],  # the paper uses 1-2 for 20 ms bins; for 50 ms bins, 1 alone (or [0, 1]) may suffice
        model_architecture='offset10-model',
        temperature_mode='constant',
        temperature=np.linspace(0.0001, 0.004, 10).tolist(),
        max_iterations=[5000],
        batch_size=[512],
        device='cuda_if_available',
        num_hidden_units=[[128, 256, 512]],
        verbose=True)

    # Run the grid search
    grid_search = cebra.grid_search.GridSearch()
    datasets = {data_identifier: (spikes_train_arr, trial_ids_train)}  # float32 NumPy array, matching the input to transform() below
    grid_search.fit_models(datasets, params=params_grid, models_dir=cebra_save_path)

In [None]:
"""Validate top CEBRA model and visualise embeddings."""

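# Load top model (re-create the GridSearch helper so this cell also works when training was skipped)
grid_search = cebra.grid_search.GridSearch()
df_results = grid_search.get_df_results(models_dir=cebra_save_path)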
dataset_name = df_results["dataset_name"].iloc[0]
best_model, best_model_name = grid_search.get_best_model(dataset_name=dataset_name, models_dir=cebra_save_path)
print("The best model is:", best_model_name)
model_path = cebra_save_path / f"{best_model_name}.pt"
top_model = cebra.CEBRA.load(model_path, weights_only=False)
print("Training InfoNCE loss curve:")
ax = cebra.plot_loss(top_model)
plt.show()
print("Final InfoNCE training loss:", df_results["loss"].min())

# Transform
top_train_embedding = top_model.transform(spikes_train_arr)
top_val_embedding = top_model.transform(spikes_val_arr)

# InfoNCE loss
loss_train = cebra.sklearn.metrics.infonce_loss(top_model, spikes_train_arr, trial_ids_train, num_batches=200, correct_by_batchsize=True)
loss_val = cebra.sklearn.metrics.infonce_loss(top_model, spikes_val_arr, trial_ids_val, num_batches=200, correct_by_batchsize=True)
print("\nInfoNCE loss recalculated:")
print("Train:", loss_train)
print("Validation: ", loss_val)
print("\n")

# Plot embeddings
# Decide which metadata variable to use for coloring
feature_train = metadata_binned["vel_x"].to_numpy(dtype=np.float32)[train_mask]
feature_val = metadata_binned["vel_x"].to_numpy(dtype=np.float32)[val_mask]
# Create random samples for plotting
n_plot_train = min(10_000, top_train_embedding.shape[0])
train_sample = np.random.choice(top_train_embedding.shape[0], size=n_plot_train, replace=False)
n_plot_val = min(10_000, top_val_embedding.shape[0])
val_sample = np.random.choice(top_val_embedding.shape[0], size=n_plot_val, replace=False)
# Sample embeddings and features
top_train_embedding_sample = top_train_embedding[train_sample, :]
feature_train_sample = feature_train[train_sample]
top_val_embedding_sample = top_val_embedding[val_sample, :]
feature_val_sample = feature_val[val_sample]
# Plot
fig = cebra.integrations.plotly.plot_embedding_interactive(
    top_train_embedding_sample,
    embedding_labels=feature_train_sample,
    title="CEBRA-Time (train)",
    markersize=3,
    cmap="rainbow"
)
fig.show()
fig = cebra.integrations.plotly.plot_embedding_interactive(
    top_val_embedding_sample,
    embedding_labels=feature_val_sample,
    title="CEBRA-Time (validation)",
    markersize=3,
    cmap="rainbow"
)
fig.show()

In [None]:
"""Load CEBRA models and average embeddings."""

# Load all .pt files in folder
pt_files = sorted(cebra_save_path.glob("*.pt"))
if not pt_files:
    raise FileNotFoundError(f"No .pt files found in {cebra_save_path}")
print(f"Found {len(pt_files)} models.")

# Compute validation losses and store together with paths
model_losses = []
for p in pt_files:
    model = cebra.CEBRA.load(p, weights_only=False)
    loss_val = cebra.sklearn.metrics.infonce_loss(model, spikes_val_arr, trial_ids_val, num_batches=200, correct_by_batchsize=True)
    model_losses.append((p, loss_val))

# Sort by loss and use this to determine a loss threshold to use with the function below
model_losses.sort(key=lambda x: x[1])
print("Model losses (sorted):")
for p, loss in model_losses:
    print(f"{p.name}: {loss:.4f}")

# Average embeddings
X_train, X_val = cebra_utils.average_cebra_embeddings_procrustes(
    model_losses, spikes_train_arr, spikes_val_arr, loss_threshold=10
)

In [None]:
"""Decode using averaged CEBRA embeddings."""

best = decode.decode_with_lag_sweep(
    X_tr=X_train,
    X_va=X_val,
    y_tr=y_train,
    y_va=y_val,
    lags=range(0, 6),
    scaler=StandardScaler(with_mean=True, with_std=True),
    alpha=1.0,
)

print(f"Best lag (bins): {best['lag']}")
print(f"R² per dimension: {best['r2_per_dim']}")
print(f"Mean R²: {best['r2_mean']}")

# 5. Find features

## Pick dataset

In [None]:
"""Pick whether to find features in training or validation set."""

search_train = True

if search_train:
    acts_df_split = acts_df_train
    metadata_binned_split = metadata_binned[train_mask].copy()
    spikes_df_split = spikes_df[train_mask].copy()
else:
    acts_df_split = acts_df_val
    metadata_binned_split = metadata_binned[val_mask].copy()
    spikes_df_split = spikes_df[val_mask].copy()

## Automatically map latents to metadata

Next, we test whether latents track known continuous and discrete behavioral and environmental variables in the MC_Maze task.

**How it works:**

Latents are mapped to real-world variables by computing a selectivity score. For a latent $l$ and a condition $c$ (a variable/value combination, e.g., velocity between 0 and 1, or maze condition = 3):

$$
\text{activation\_frac\_during} =
\frac{\#\{\text{activations of } l \text{ in examples with } c\}}
     {\#\{\text{examples with } c\}}
$$

$$
\text{activation\_frac\_outside} =
\frac{\#\{\text{activations of } l \text{ in examples without } c\}}
     {\#\{\text{examples without } c\}}
$$

$$
\text{selectivity\_score} =
\frac{\text{activation\_frac\_during}}
     {\text{activation\_frac\_during} + \text{activation\_frac\_outside}}
$$

- $\approx 1$: latent mainly active *during* the condition (highly selective)  
- $\approx 0.5$: latent active equally in/out (not selective)  
- $\approx 0$: latent mostly active *outside* the condition
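
The three formulas above can be sketched in a few lines of NumPy. This is a minimal illustration, not the pipeline's implementation: the function name `selectivity_score`, the toy arrays, and the choice to treat "activation" as a boolean per example are all assumptions made for the example. It also returns `NaN` when the latent never activates, a case the formula leaves undefined.

```python
import numpy as np


def selectivity_score(active, in_condition):
    """Return (activation_frac_during, activation_frac_outside, selectivity_score).

    active:       boolean array, True where the latent is active (assumed definition).
    in_condition: boolean array, True where the example satisfies condition c.
    """
    active = np.asarray(active, dtype=bool)
    in_condition = np.asarray(in_condition, dtype=bool)

    n_during = in_condition.sum()
    n_outside = (~in_condition).sum()

    # Fraction of examples in which the latent is active, inside vs. outside c
    frac_during = active[in_condition].mean() if n_during else 0.0
    frac_outside = active[~in_condition].mean() if n_outside else 0.0

    denom = frac_during + frac_outside
    score = frac_during / denom if denom > 0 else np.nan  # undefined if never active
    return frac_during, frac_outside, score


# Toy data: 10 examples, the first 4 satisfy the condition
in_condition = np.array([1, 1, 1, 1, 0, 0, 0, 0, 0, 0], dtype=bool)
active = np.array([1, 1, 1, 0, 1, 0, 0, 0, 0, 0], dtype=bool)

frac_during, frac_outside, score = selectivity_score(active, in_condition)
print(frac_during, frac_outside, score)
```

In this toy case the latent is active in 3 of 4 condition examples and 1 of 6 other examples, so the score is 0.75 / (0.75 + 0.167) ≈ 0.82, i.e. fairly selective.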

The `map_latents_to_metadata` function:
1. For discrete variables: computes activation fractions + selectivity score per condition value.  
2. For continuous variables: bins, then reuses discrete analysis.  
3. Results are ranked by selectivity score and the `top_n_mappings` are returned.
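
To illustrate step 2, the sketch below bins a continuous variable so that each bin can be treated as a discrete condition. How `map_latents_to_metadata` bins internally is not shown in this tutorial, so the use of `np.digitize`, the made-up `speed` values, and the variable names are assumptions for illustration only.

```python
import numpy as np

# Made-up hand-speed values (illustrative only)
speed = np.array([0.5, 1.5, 3.0, 12.0, 90.0, 600.0])

# Option A: explicit, unevenly spaced bin edges
edges = np.array([0, 1, 2, 4, 6, 10, 15, 40, 100, 250, 1300])
# Passing only the interior edges gives bin indices 0 .. len(edges) - 2
bin_idx_explicit = np.digitize(speed, edges[1:-1])

# Option B: a fixed number of equal-width bins spanning the observed range
n_bins = 3
equal_edges = np.linspace(speed.min(), speed.max(), n_bins + 1)
bin_idx_equal = np.digitize(speed, equal_edges[1:-1])

print(bin_idx_explicit.tolist())  # [0, 1, 2, 5, 7, 9]
print(bin_idx_equal.tolist())     # [0, 0, 0, 0, 0, 2]

# Each bin now acts like one value of a discrete variable,
# e.g. a boolean mask for "speed falls in explicit bin 5":
in_condition = bin_idx_explicit == 5
```

With a heavy-tailed variable, equal-width bins place almost every example in the first bin, whereas uneven edges keep the low range resolved.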

**Key arguments to play with:**
- `discrete_vars` and `continuous_vars`: by default, only one of each is included because the function takes time to run, but you can explore different or additional variables (see the commented-out lists in the code cell below)
- `n_bins_continuous`: number of bins (or explicit bin edges) for continuous variables. Fewer bins yield more general features (e.g., distinguishing only fast from slow hand velocity); more bins yield more specific features (e.g., distinguishing very fast, fast, intermediate, slow, and very slow hand velocity)
- `min_activation_frac`: minimum fraction of condition examples a latent must activate in
- `top_n_mappings`: number of highest-scoring mappings kept per variable/value/instance combination (default `3`)
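
The combined effect of `min_activation_frac` and `top_n_mappings` can be pictured as a filter followed by a per-condition ranking. The pandas sketch below is a hypothetical illustration, not the pipeline's code: the table, its column names, and the example values (such as `"move_onset"`) are invented, and the SAE instance dimension is left out for brevity.

```python
import pandas as pd

# Hypothetical candidate mappings; column names and values are illustrative only
candidates = pd.DataFrame({
    "variable": ["event"] * 4 + ["vel_magnitude"] * 3,
    "value": ["move_onset"] * 4 + ["(100, 250]"] * 3,
    "latent": [3, 7, 11, 20, 3, 5, 9],
    "activation_frac_during": [0.90, 0.60, 0.30, 0.80, 0.70, 0.55, 0.95],
    "selectivity_score": [0.95, 0.70, 0.99, 0.80, 0.60, 0.90, 0.85],
})

min_activation_frac = 0.5
top_n_mappings = 2

# 1) Drop latents that activate in too few of the condition's examples
kept = candidates[candidates["activation_frac_during"] >= min_activation_frac]

# 2) Rank by selectivity and keep the top N per variable/value combination
top = (
    kept.sort_values("selectivity_score", ascending=False)
    .groupby(["variable", "value"], sort=False)
    .head(top_n_mappings)
    .reset_index(drop=True)
)
print(top)
```

Latent 11 has the highest selectivity score here but is dropped because it activates in only 30% of the condition's examples. This is the trade-off `min_activation_frac` controls: lowering it admits highly selective latents that fire rarely.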

In [None]:
"""Map latents to metadata variables."""

# discrete_vars = ['event', 'maze_condition', 'barriers', 'targets', 'hit_position_x', 'hit_position_y', 'hit_position_angle']
# continuous_vars = ['vel_magnitude', 'accel_magnitude', 'movement_angle']
discrete_vars = ['event']
continuous_vars = ['vel_magnitude']
latent_metadata_mapping = pipeline.map_latents_to_metadata(  # this can take a minute
    acts_df_split,
    metadata_binned_split,
    discrete_vars=discrete_vars,
    continuous_vars=continuous_vars,
    min_activation_frac=0.5,
    n_bins_continuous=[[0, 1, 2, 4, 6, 10, 15, 40, 100, 250, 1300]], # Or [12] for 12 equal-width bins
    top_n_mappings=5
)
latent_metadata_mapping

## Find meaningful features and their contributing units

The code below generates a dashboard like this one, providing an interactive way to explore a model's latents in search of meaningful features and their contributing units:

![](./figures/feature_finding_dashboard.png)

*(If the image above does not render, see ./figures/feature_finding_dashboard.png)*

The `latent_metadata_mapping` dataframe generated above serves as a starting point for identifying promising features ("preset" option on the dashboard). If you want more preset options to explore, return to the "Automatically map latents to metadata" section and adjust the arguments that control the mapping process.

You may also choose to look through latents manually, though this tends to take more time ("manual selection" option on the dashboard).

Remember:
- Matryoshka SAEs split latents across multiple levels, with higher levels capturing broad patterns and lower levels capturing more specific ones. Check the `dsae` settings you used to train your SAEs to see what latents were allocated to each level.
- The neural recordings come from two brain regions (PMd and M1). Use the `units_df` dataframe generated below to look up the mapping of unit IDs to brain regions - this can help you see whether particular features are driven more strongly by units in one region or the other.
- By default, this tutorial sets `search_train = True` at the beginning of this section, so the feature search runs on the training set. You can switch to the validation set instead, which lets you check that the SAEs generalise to unseen data (and to unseen sessions if you chose to split by session in section "1. Load and prepare data").

*Note*: Here we present a prepared dashboard for viewing the activity of latents in the context of the MC_Maze task. In general, when working with new datasets, we recommend spending some time designing similarly simple, bespoke dashboards.

In [None]:
"""Print unit to brain region mapping for reference."""

unit_stats_ids = [u.decode() for u in session.units.id]
units_df = pd.DataFrame({
    "unit_id": [int(u.split("elec")[1]) for u in unit_stats_ids],
    "region": ["PMd" if "group_1" in u else "M1" for u in unit_stats_ids]
})

units_df

In [None]:
"""Feature finding dashboard."""

save_path = Path(r"../saved_plot_data_paper")
plot_data_save_path = save_path / data_identifier / "plot_data_all_data"

pipeline.build_feature_finding_dashboard(
    latent_metadata_mapping=latent_metadata_mapping,
    acts_df=acts_df_split,
    spikes_df=spikes_df_split,
    metadata_binned=metadata_binned_split,
    save_path=plot_data_save_path
)

# 6. Make paper plots

In [None]:
plot_data_save_path_nitschke_all_data = save_path / "nitschke_20090812_20090819_20090910" / "plot_data_all_data"
plot_data_save_path_nitschke_sessions_1_2 = save_path / "nitschke_20090812_20090819_20090910" / "plot_data_sessions_1_2"
plot_data_save_path_jenkins_all_data = save_path / "jenkins_20090912_20090916_20090918_20090923" / "plot_data_all_data"

# Extreme velocity features
fig = pipeline.plot_selectivity_score_from_saved(
    plot_data_save_path_nitschke_all_data / "plotdata_inst0_latent24_vel_magnitude_binned.csv",
    plot_data_save_path_nitschke_sessions_1_2 / "plotdata_inst1_latent98_vel_magnitude_binned.csv",
    plot_data_save_path_jenkins_all_data / "plotdata_inst1_latent78_vel_magnitude_binned.csv",
    labels=["N", "N (unseen session)", "J"],
    y_max=0.8,
)
fig.show()

# High velocity features
fig = pipeline.plot_selectivity_score_from_saved(
    plot_data_save_path_nitschke_all_data / "plotdata_inst0_latent2_vel_magnitude_binned.csv",
    plot_data_save_path_nitschke_sessions_1_2 / "plotdata_inst1_latent62_vel_magnitude_binned.csv",
    plot_data_save_path_jenkins_all_data / "plotdata_inst1_latent137_vel_magnitude_binned.csv",
    labels=["N", "N (unseen session)", "J"],
    y_max=0.8,
)
fig.show()

# Low velocity features
fig = pipeline.plot_selectivity_score_from_saved(
    plot_data_save_path_nitschke_all_data / "plotdata_inst0_latent127_vel_magnitude_binned.csv",
    plot_data_save_path_nitschke_sessions_1_2 / "plotdata_inst1_latent505_vel_magnitude_binned.csv",
    plot_data_save_path_jenkins_all_data / "plotdata_inst1_latent43_vel_magnitude_binned.csv",
    labels=["N", "N (unseen session)", "J"],
    y_max=0.8,
)
fig.show()

# Intermediate velocity features
fig = pipeline.plot_selectivity_score_from_saved(
    plot_data_save_path_nitschke_all_data / "plotdata_inst0_latent208_vel_magnitude_binned.csv",
    plot_data_save_path_nitschke_sessions_1_2 / "plotdata_inst1_latent191_vel_magnitude_binned.csv",
    labels=["N", "N (unseen session)"],
    y_max=0.8,
)
fig.show()

# Hit position on right feature
fig = pipeline.plot_selectivity_score_from_saved(
    plot_data_save_path_nitschke_all_data / "plotdata_inst0_latent169_hit_position_x.csv",
    labels=["N"],
    y_max=0.6,
)
fig.show()