## GLM Analysis with NeMoS

This notebook demonstrates how to fit **Generalized Linear Models (GLMs)** to predict neural spiking activity from kinematic features using [NeMoS](https://nemos.readthedocs.io/) (Neural Models). We reproduce the methodology of the Turner lab paper *"Movement encoding deficits in the motor cortex of parkinsonian macaque monkeys"* (Pasquereau & Turner, Brain 2016).

**What you'll learn:**
- How to build design matrices for neural encoding models
- Fitting Poisson GLMs with NeMoS
- Using JAX's `vmap` for efficient vectorized fitting
- Statistical validation via shuffle-based significance testing

### Why GLMs for Neural Data?

Peri-event time histograms (PETHs, as in the [PETH tutorial](https://github.com/catalystneuro/turner-lab-to-nwb/blob/main/notebooks/turner_m1_peth.ipynb)) show *when* neurons fire relative to events, but don't tell us *why*. GLMs let us ask: **which kinematic features predict spiking?**

The model assumes spike counts follow a Poisson distribution:

$$\text{E}[\text{spike count}] = \exp(\beta_0 + \beta_1 x_1 + \beta_2 x_2 + \dots)$$

Where $x_i$ are kinematic features (direction, position, velocity, etc.) and $\beta_i$ are learned coefficients. A positive $\beta$ means higher feature values increase firing rate; negative means they decrease it.
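
To make the link between coefficients and expected counts concrete, here is a minimal NumPy sketch. The intercept, coefficients, and feature values are made-up illustration numbers, not values fitted to this dataset.

```python
import numpy as np

# Hypothetical parameters, chosen only for illustration
beta_0 = np.log(5.0)            # baseline: 5 expected spikes per window
beta = np.array([0.3, -0.2])    # one coefficient per feature

def expected_count(x):
    """Expected spike count under the Poisson GLM: exp(beta_0 + x . beta)."""
    return np.exp(beta_0 + x @ beta)

baseline = expected_count(np.array([0.0, 0.0]))
feature_1_up = expected_count(np.array([1.0, 0.0]))  # positive coefficient
feature_2_up = expected_count(np.array([0.0, 1.0]))  # negative coefficient

print(f"baseline:        {baseline:.2f}")
print(f"feature 1 at +1: {feature_1_up:.2f} (x{feature_1_up / baseline:.2f})")
print(f"feature 2 at +1: {feature_2_up:.2f} (x{feature_2_up / baseline:.2f})")
```

Because the coefficients sit inside an exponential, each one acts multiplicatively on the expected count rather than adding a fixed number of spikes.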

In [None]:
import jax
import matplotlib.pyplot as plt
import nemos as nmo
import numpy as np
import pynapple as nap
from pynwb import NWBHDF5IO
from dandi.dandiapi import DandiAPIClient
import remfile
import h5py

# Stream NWB file from DANDI
dandiset_id = "001636"
session_id = "V++v2703++PreMPTP++Depth19880um++19990607"

client = DandiAPIClient()
dandiset = client.get_dandiset(dandiset_id)
assets = list(dandiset.get_assets())
asset = next(a for a in assets if session_id in a.path)
print(f"Streaming: {asset.path}")

# Open remote file
url = asset.get_content_url(follow_redirects=1, strip_query=True)
file = remfile.File(url)
h5_file = h5py.File(file, "r")
io = NWBHDF5IO(file=h5_file, load_namespaces=True)
nwbfile = io.read()
print(f"Session: {nwbfile.session_id}")


### Preparing the Data with Pynapple

[Pynapple](https://pynapple.org) and [NeMoS](https://nemos.readthedocs.io/) are designed to work together seamlessly:

- **Pynapple** handles data loading, time alignment, and preprocessing (trial extraction, perievent alignment, binning)
- **NeMoS** handles the statistical modeling (GLM fitting, regularization, scoring)

Data flows naturally between them as arrays. Pynapple's time series objects, such as `Tsd` and `TsdFrame`, expose their data as NumPy arrays through `.values`, and NeMoS, although built on JAX, accepts those arrays directly. This division of labor keeps each library focused: pynapple doesn't need to implement GLMs, and NeMoS doesn't need to handle NWB files or trial alignment.

We'll use pynapple to:
1. Load NWB data with `nap.NWBFile()`
2. Align kinematics and spikes to movement onset with `compute_perievent_continuous()` and `compute_perievent()`
3. Count spikes in time windows with `.restrict()` and `.count()`

Then pass the resulting NumPy arrays to NeMoS for GLM fitting.

In [None]:
data = nap.NWBFile(nwbfile)
print(f"Trials: {len(data['trials'])}, Units: {len(data['units'])}")

In [None]:
# Extract data
trials = data["trials"]
units = data["units"]
spikes = units[0]  # First unit

# Kinematics (handle different naming conventions)
elbow_position = data["ElbowAngle"] if "ElbowAngle" in data.keys() else data["SpatialSeriesElbowAngle"]
elbow_velocity = data["ElbowVelocity"] if "ElbowVelocity" in data.keys() else data["TimeSeriesElbowVelocity"]

# Convert to 1D Tsd - pynapple may load these as TsdFrame (2D) depending on the session.
# We need 1D Tsd so that .get(t) returns a scalar that broadcasts correctly across trials.
elbow_position = nap.Tsd(t=elbow_position.t, d=np.asarray(elbow_position.values).flatten())
elbow_velocity = nap.Tsd(t=elbow_velocity.t, d=np.asarray(elbow_velocity.values).flatten())

# Compute acceleration using pynapple's derivative method (wraps numpy.gradient)
elbow_acceleration = elbow_velocity.derivative()

dt = np.median(np.diff(elbow_velocity.t))
print(f"Kinematics sampling rate: {1/dt:.1f} Hz")

### Z-scoring Features

GLM coefficients are more interpretable when features are standardized. A coefficient of 0.5 then means: "a 1 standard deviation increase in this feature multiplies the firing rate by $e^{0.5} \approx 1.65$."
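
As a quick check of that interpretation, the sketch below z-scores a synthetic feature and converts a coefficient between standardized and raw units. The velocity samples and the coefficient value are invented for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic "velocity" samples in deg/s (illustrative values only)
velocity = rng.normal(loc=20.0, scale=40.0, size=1000)

velocity_z = (velocity - velocity.mean()) / velocity.std()
print(f"mean = {velocity_z.mean():.3f}, std = {velocity_z.std():.3f}")

beta_z = 0.5                         # hypothetical coefficient on the z-scored feature
gain_per_sd = np.exp(beta_z)         # multiplicative change in rate per +1 SD
beta_raw = beta_z / velocity.std()   # equivalent coefficient per deg/s
print(f"rate multiplier per +1 SD: {gain_per_sd:.2f}")
print(f"equivalent coefficient per deg/s: {beta_raw:.4f}")
```

Standardizing does not change what the model can fit; it only rescales each coefficient by the feature's standard deviation, which puts features with different units on a comparable footing.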

In [None]:
# Z-score kinematics
position_z = nap.Tsd(t=elbow_position.t, d=(elbow_position.values - elbow_position.mean()) / elbow_position.std())
velocity_z = nap.Tsd(t=elbow_velocity.t, d=(elbow_velocity.values - elbow_velocity.mean()) / elbow_velocity.std())
acceleration_z = nap.Tsd(t=elbow_acceleration.t, d=(elbow_acceleration.values - elbow_acceleration.mean()) / elbow_acceleration.std())

# Movement onset times and direction
movement_onset_times = nap.Ts(trials["derived_movement_onset_time"].values)
direction_values = np.array([1.0 if mt == "flexion" else -1.0 for mt in trials["movement_type"]])

# Reaction time (z-scored)
lateral_target_times = trials["lateral_target_appearance_time"].values
movement_onset_values = trials["derived_movement_onset_time"].values
reaction_times = movement_onset_values - lateral_target_times
reaction_times_z = (reaction_times - reaction_times.mean()) / reaction_times.std()

print(f"Trials: {len(trials)} ({(direction_values == 1).sum()} flexion, {(direction_values == -1).sum()} extension)")

### Building the Design Matrix

**The paper's approach:** Pasquereau & Turner (2016) asked whether M1 neurons encode kinematic features differently before vs. after MPTP-induced parkinsonism. To answer this, they fit a **separate GLM for each time point** relative to movement onset, using trials as observations.

**Why time-resolved fitting?** Neural encoding is not static. A neuron might encode velocity *before* movement starts (planning) but switch to encoding position *during* movement (feedback). By fitting independent models at each time point, we can track how encoding evolves.

**The sliding window approach:**
- **Window size**: 200ms (long enough to get reliable spike counts)
- **Step size**: 25ms (fine temporal resolution)
- **Time range**: -500ms to +800ms relative to movement onset

This creates 45 overlapping windows. For each window, we count spikes and sample kinematic features, then fit a GLM predicting spike count from features. The result is a time series of coefficients showing when each feature matters.
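
The window count follows directly from the three parameters above. A small NumPy sketch of the arithmetic (variable names here are illustrative):

```python
import numpy as np

window_ms, step_ms = 200.0, 25.0
t_min_ms, t_max_ms = -500.0, 800.0

# The last window must end at t_max_ms, so its start is t_max_ms - window_ms
starts = np.arange(t_min_ms, t_max_ms - window_ms + step_ms, step_ms)
ends = starts + window_ms
centers = starts + window_ms / 2

overlap_ms = window_ms - step_ms   # consecutive windows share this much time

print(f"number of windows: {len(starts)}")
print(f"first window: {starts[0]:.0f} to {ends[0]:.0f} ms")
print(f"last window:  {starts[-1]:.0f} to {ends[-1]:.0f} ms")
print(f"overlap between neighbours: {overlap_ms:.0f} ms")
```

Since neighbouring windows share most of their spikes, the coefficient time series is smooth by construction and adjacent bins are not statistically independent.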

**Design matrix structure:**
For each time window, we need one row per trial with columns for each feature:

| Feature | Description | Values |
|---------|-------------|--------|
| Direction | Movement type (categorical) | +1 (flexion), -1 (extension) |
| Position | Elbow angle at window center | z-scored degrees |
| Velocity | Angular velocity at window center | z-scored deg/s |
| Acceleration | Angular acceleration at window center | z-scored deg/s² |
| RT | Reaction time (constant per trial) | z-scored seconds |

The first code cell below aligns all signals to movement onset using pynapple's `compute_perievent` functions. The second cell loops through time windows, extracting spike counts (y) and kinematic features (X) to build the full 3D tensor: `(n_time_bins, n_trials, n_features)`.
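
Before working with the real data, it can help to see the target shapes on a toy example. The sketch below fills the same `(n_time_bins, n_trials, n_features)` layout with random stand-in values; the trial count of 20 and all numbers are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n_time_bins, n_trials = 45, 20   # 20 trials is an illustrative count
feature_names = ["Direction", "Position", "Velocity", "Acceleration", "RT"]

# Per-trial features that do not change across windows
direction = rng.choice([-1.0, 1.0], size=n_trials)
rt_z = rng.standard_normal(n_trials)
# Time-varying features sampled at each window center (random stand-ins)
kinematics = rng.standard_normal((n_time_bins, n_trials, 3))

X = np.empty((n_time_bins, n_trials, len(feature_names)))
X[:, :, 0] = direction       # broadcast over time bins
X[:, :, 1:4] = kinematics    # position, velocity, acceleration
X[:, :, 4] = rt_z            # broadcast over time bins
y = rng.poisson(lam=3.0, size=(n_time_bins, n_trials)).astype(float)

print(X.shape, y.shape)
```

Each slice `X[t]` is an ordinary 2D design matrix for one window, which is what makes it possible to fit all windows with a single batched call later on.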

In [None]:
# Parameters matching the paper
WINDOW_SIZE_MS = 200.0
STEP_SIZE_MS = 25.0
TIME_RANGE_MS = (-500.0, 800.0)

trial_start_s = TIME_RANGE_MS[0] / 1000.0
trial_end_s = TIME_RANGE_MS[1] / 1000.0

# Align kinematics and spikes to movement onset using pynapple
position_aligned = nap.compute_perievent_continuous(position_z, tref=movement_onset_times, minmax=(trial_start_s, trial_end_s))
velocity_aligned = nap.compute_perievent_continuous(velocity_z, tref=movement_onset_times, minmax=(trial_start_s, trial_end_s))
acceleration_aligned = nap.compute_perievent_continuous(acceleration_z, tref=movement_onset_times, minmax=(trial_start_s, trial_end_s))
peth = nap.compute_perievent(timestamps=spikes, tref=movement_onset_times, minmax=(trial_start_s, trial_end_s))

print(f"Aligned position shape: {position_aligned.shape} (time x trials)")

In [None]:
# Create time windows
starts_ms = np.arange(TIME_RANGE_MS[0], TIME_RANGE_MS[1] - WINDOW_SIZE_MS + STEP_SIZE_MS, STEP_SIZE_MS)
ends_ms = starts_ms + WINDOW_SIZE_MS
n_time_bins = len(starts_ms)
n_trials = len(trials)

# Feature names
feature_names = ["Direction", "Position", "Velocity", "Acceleration", "RT"]
n_features = len(feature_names)

# Build tensors: X is (time_bins, trials, features), y is (time_bins, trials)
X = np.zeros((n_time_bins, n_trials, n_features))
y = np.zeros((n_time_bins, n_trials))  # spike counts

for win_index, (start_ms, end_ms) in enumerate(zip(starts_ms, ends_ms)):
    interval = nap.IntervalSet(start=start_ms / 1000.0, end=end_ms / 1000.0)

    # Spike counts in this window
    spikes_in_window = peth.restrict(interval)
    spike_counts = spikes_in_window.count()
    y[win_index, :] = spike_counts.values.flatten()

    # Kinematic features at window center
    t_center = spike_counts.t[0]  # count() over a single interval yields a single timestamp
    X[win_index, :, 0] = direction_values
    X[win_index, :, 1] = position_aligned.restrict(interval).get(t_center)
    X[win_index, :, 2] = velocity_aligned.restrict(interval).get(t_center)
    X[win_index, :, 3] = acceleration_aligned.restrict(interval).get(t_center)
    X[win_index, :, 4] = reaction_times_z  # RT is constant across time bins

print(f"Design matrix X: {X.shape} (time_bins, trials, features)")
print(f"Spike counts y: {y.shape}")
print(f"Total spikes: {y.sum():.0f}")

### Fitting GLMs with NeMoS

Here's where NeMoS shines. Instead of writing loops to fit 45 separate models (one per time bin), we use JAX's `vmap` to vectorize the fitting. This is both faster and cleaner.

**Key NeMoS concepts:**

- **`nmo.glm.GLM()`** - The model object, which by default uses a Poisson observation model with an exponential inverse link function (equivalently, a log link)

- **`regularizer="Ridge"`** - L2 regularization adds a penalty on large coefficients. This is important when features are correlated (like velocity and acceleration) because without regularization, the optimizer can find many equivalent solutions, leading to unstable or extreme coefficients.

- **`regularizer_strength=0.1`** - Controls how strongly we penalize large coefficients. Higher values = more shrinkage toward zero. We use 0.1 as a reasonable default; you could tune this via cross-validation.

- **`solver_kwargs`** - Fine-tune the underlying optimizer (JAXopt's gradient descent). The defaults work for most cases, but time-resolved fitting across many windows can hit edge cases:
  - **`stepsize=0.001`** - Fixed learning rate for gradient descent. Smaller values converge more slowly but more reliably; a larger step can overshoot and diverge on some time bins.
  - **`acceleration=False`** - Disables Nesterov momentum. Acceleration speeds up convergence but can cause oscillations when the loss surface is tricky. Some time bins have very few spikes, creating ill-conditioned problems where acceleration hurts.
  - **`maxiter=5000`** - Maximum optimization steps. More iterations ensure convergence even with the smaller stepsize.

These conservative settings ensure we get valid coefficients for *all* 45 time bins, not just most of them. Without them, you may see NaN values in certain windows where the optimizer diverged.
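
To show what these settings control, here is a minimal NumPy sketch of ridge-penalized Poisson regression fitted by plain gradient descent. It is not the NeMoS solver: the function `fit_ridge_poisson` is hypothetical, the data are synthetic, and the exact scaling of the penalty term in NeMoS may differ from the one assumed here.

```python
import numpy as np

def fit_ridge_poisson(X, y, strength=0.1, stepsize=0.001, maxiter=5000):
    """Gradient descent on mean Poisson NLL + 0.5 * strength * ||w||^2."""
    n_samples, n_features = X.shape
    w = np.zeros(n_features)
    b = -1.0
    for _ in range(maxiter):
        rate = np.exp(X @ w + b)
        resid = rate - y   # gradient of the NLL w.r.t. the linear predictor
        grad_w = X.T @ resid / n_samples + strength * w  # intercept is not penalized
        grad_b = resid.mean()
        w -= stepsize * grad_w
        b -= stepsize * grad_b
    return w, b

rng = np.random.default_rng(0)
X = rng.standard_normal((200, 2))
X[:, 1] = X[:, 0] + 0.1 * rng.standard_normal(200)   # two strongly correlated features
y = rng.poisson(np.exp(0.5 + 0.4 * X[:, 0])).astype(float)

w_weak, _ = fit_ridge_poisson(X, y, strength=0.01)
w_strong, _ = fit_ridge_poisson(X, y, strength=1.0)
print("weak penalty:  ", np.round(w_weak, 3))
print("strong penalty:", np.round(w_strong, 3))
```

A stronger penalty pulls the coefficient vector toward zero, and with correlated features it tends to spread the weight between them instead of letting one coefficient grow at the other's expense.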

In [None]:
# Configure the GLM
# Ridge regularization stabilizes fitting with correlated features
# Smaller stepsize and no acceleration help convergence across all time bins
model = nmo.glm.GLM(
    regularizer="Ridge",
    regularizer_strength=0.1,
    solver_kwargs={"stepsize": 0.001, "acceleration": False, "maxiter": 5000},
)

# Initialize parameters
weights_init = np.zeros(n_features)
intercept_init = np.array([-1.0])

# This is the key step: use JAX vmap to fit all time bins in parallel
model.instantiate_solver()
vmap_fit = jax.vmap(model.solver_run, in_axes=(None, 0, 0))
(coefficients, intercepts), _ = vmap_fit((weights_init, intercept_init), X, y)

coefficients = np.array(coefficients)
intercepts = np.array(intercepts).flatten()

print(f"Coefficients shape: {coefficients.shape} (time_bins x features)")
print("Fitting complete!")

### Understanding `vmap`

Recall that we need to fit 45 separate GLMs, one for each time window. The naive approach would be a Python loop:

```python
# Slow loop version
coefficients = []
for t in range(n_time_bins):
    model.fit(X[t], y[t])
    coefficients.append(model.coef_)
```

This works, but it is slow. Each iteration has Python overhead, and the fits run sequentially. With 45 time bins and 100 shuffle iterations for significance testing, that is 4,500 model fits.

JAX's `vmap` (vectorized map) solves this by transforming a function that operates on single arrays into one that operates on batches. The key line is:

```python
vmap_fit = jax.vmap(model.solver_run, in_axes=(None, 0, 0))
```

This says: "Take `solver_run`, which fits one GLM, and create a new function that applies it across axis 0 of X and y simultaneously." The `in_axes=(None, 0, 0)` specifies:
- `None`: Keep the initial parameters the same for all fits
- `0`: Iterate over the first axis of X (the time bins)
- `0`: Iterate over the first axis of y (the time bins)

Now instead of 45 sequential Python calls, we have one vectorized call that JAX compiles into efficient parallel code. On a GPU, all 45 fits can literally run in parallel. Even on CPU, the compiled code avoids Python overhead and enables SIMD optimizations.

This pattern is essential for time-resolved neural analysis where you fit many models across time, neurons, or cross-validation folds.
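
The same `in_axes=(None, 0, 0)` pattern can be tried on a smaller problem. The sketch below swaps the Poisson GLM solver for closed-form ridge regression, purely to keep the example short; `ridge_one_bin` is a hypothetical helper and the data are random.

```python
import jax
import jax.numpy as jnp
import numpy as np

def ridge_one_bin(lam, X_t, y_t):
    """Closed-form ridge regression for a single time bin."""
    n_features = X_t.shape[1]
    gram = X_t.T @ X_t + lam * jnp.eye(n_features)
    return jnp.linalg.solve(gram, X_t.T @ y_t)

rng = np.random.default_rng(0)
X = jnp.asarray(rng.standard_normal((45, 20, 5)))   # (time_bins, trials, features)
y = jnp.asarray(rng.standard_normal((45, 20)))      # (time_bins, trials)

# None: share lam across bins; 0: map over the leading axis of X and of y
ridge_all_bins = jax.vmap(ridge_one_bin, in_axes=(None, 0, 0))
coefs_vmap = ridge_all_bins(0.1, X, y)

# Same computation as an explicit Python loop, for comparison
coefs_loop = jnp.stack([ridge_one_bin(0.1, X[t], y[t]) for t in range(X.shape[0])])

print(coefs_vmap.shape)
print(bool(jnp.allclose(coefs_vmap, coefs_loop, atol=1e-4)))
```

Note that `ridge_one_bin` is written as if it only ever sees one time bin; `vmap` adds the batch dimension without any change to the function body.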

### Visualizing the Coefficients

Now we can see how each kinematic feature's influence on firing rate evolves through the movement.

In [None]:
# Plot all coefficients
fig, ax = plt.subplots(figsize=(14, 6))
colors = ["tab:green", "tab:blue", "tab:orange", "tab:red", "tab:purple"]

for i, name in enumerate(feature_names):
    ax.plot(starts_ms, coefficients[:, i], color=colors[i], linewidth=3, label=name)

ax.axhline(0, color="gray", linestyle="--", alpha=0.5, linewidth=1.5)
ax.axvline(0, color="black", linestyle="--", linewidth=2, label="Movement onset")
ax.set_xlabel("Time from movement onset (ms)", fontsize=18, fontweight="bold")
ax.set_ylabel("GLM Coefficient", fontsize=18, fontweight="bold")
ax.legend(loc="upper right", fontsize=14, frameon=False)
ax.set_title(f"GLM Coefficients Across Time\n{asset.path}", fontsize=20, fontweight="bold")
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_linewidth(1.5)
ax.spines['bottom'].set_linewidth(1.5)
ax.tick_params(axis='both', which='major', labelsize=14, width=1.5, length=6)
ax.grid(True, alpha=0.3, linewidth=0.8)
plt.tight_layout()
plt.show()

### Statistical Validation: Shuffle Test

A coefficient being non-zero does not mean it is *significant*. The GLM will always find some coefficients that minimize the loss, even if the features have no real relationship to spiking. We need to ask: could we have obtained this coefficient by chance?

**The idea behind permutation testing:**

The null hypothesis is that there is no relationship between kinematic features and spike counts. If this were true, which trial had which spike count would be arbitrary. Shuffling the spike counts across trials simulates this null world: the features (X) stay the same, but the spike counts (y) are randomly reassigned to different trials.

**What "shuffling trials" means concretely:**

For each time bin, we have spike counts for 20 trials: `y = [3, 5, 2, 8, 1, ...]`. Shuffling means randomly reordering these values: `y_shuffled = [8, 1, 5, 2, 3, ...]`. Now trial 1's features (direction, position, velocity, etc.) are paired with trial 4's spike count.

**What information is preserved vs. destroyed:**

Shuffling preserves the *marginal* statistics:
- Total spike count stays the same
- Mean firing rate stays the same  
- Variance of spike counts stays the same
- Distribution of each feature stays the same

What is destroyed is the **joint distribution**, the covariance between features and spikes. Consider this example:

| Trial | Direction | Spikes |
|-------|-----------|--------|
| 1 | Flexion (+1) | 8 |
| 2 | Extension (-1) | 2 |
| 3 | Flexion (+1) | 7 |
| 4 | Extension (-1) | 3 |

The GLM learns: "Flexion trials have ~7.5 spikes, extension trials have ~2.5 spikes, so the direction coefficient is positive."

After shuffling the spike counts:

| Trial | Direction | Spikes (shuffled) |
|-------|-----------|--------|
| 1 | Flexion (+1) | 3 |
| 2 | Extension (-1) | 8 |
| 3 | Flexion (+1) | 2 |
| 4 | Extension (-1) | 7 |

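In this particular shuffle the relationship happens to be reversed (flexion ~2.5 spikes, extension ~7.5), but that is one random draw. Across many shuffles the direction coefficient scatters around zero, positive as often as negative, because the pairing between direction and spike count is arbitrary.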
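
The preserved and destroyed quantities can be checked numerically. This sketch uses synthetic counts (illustrative rates of 8 and 2 spikes) and a simple correlation in place of a GLM coefficient.

```python
import numpy as np

rng = np.random.default_rng(0)
n_trials = 20
direction = np.repeat([1.0, -1.0], n_trials // 2)
# Synthetic counts: flexion trials fire more (illustrative rates)
spikes = rng.poisson(np.where(direction > 0, 8.0, 2.0)).astype(float)

shuffled = rng.permutation(spikes)

# Marginal statistics are identical after shuffling
print(spikes.sum() == shuffled.sum(), np.isclose(spikes.var(), shuffled.var()))

# The feature-spike correlation averages out to about zero across shuffles
r_actual = np.corrcoef(direction, spikes)[0, 1]
r_null = np.array([np.corrcoef(direction, rng.permutation(spikes))[0, 1]
                   for _ in range(1000)])
print(f"actual r = {r_actual:.2f}, mean null r = {r_null.mean():.2f}")
```

Any single shuffle can show a sizeable correlation in either direction; it is the spread of the whole null distribution that the actual value is compared against.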

**The procedure (following Pasquereau & Turner, 2016):**

The paper states: *"To test whether individual coefficients were significant, we shuffled spike counts 1000 times across trials and compared actual coefficients to the confidence intervals yielded by shuffling [P = 0.05/(52 independent time bins) to compensate for multiple comparisons]."*

1. Fit the GLM on real data to get actual coefficients
2. Shuffle spike counts across trials (independently for each time bin)
3. Refit the GLM on shuffled data to get null coefficients
4. Repeat steps 2-3 many times (1000 in the paper, 100 here for speed)
5. Compute p-value: what fraction of null coefficients are at least as extreme (in absolute value) as the actual coefficient?
6. Apply Bonferroni correction for multiple time bins

**Why this works:**

This is a non-parametric test. We do not assume coefficients follow a normal distribution or any other parametric form. Instead, we empirically construct the null distribution from the data itself. This is particularly valuable for GLMs where the sampling distribution of coefficients can be complex, especially with small sample sizes.
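
Steps 5 and 6 interact with the number of shuffles, which the sketch below illustrates on stand-in null values. The helper `permutation_p_value` is hypothetical, and its add-one variant, (b + 1) / (n + 1), is a common convention rather than the formula used in this notebook's code cell, which takes the plain fraction.

```python
import numpy as np

def permutation_p_value(actual, null, add_one=True):
    """Two-tailed p-value: share of |null| values at least as large as |actual|.

    With add_one=True the observed value is counted as one of the
    permutations, (b + 1) / (n + 1), so the result is never exactly 0.
    """
    null = np.asarray(null)
    n_extreme = np.sum(np.abs(null) >= np.abs(actual))
    if add_one:
        return (n_extreme + 1) / (null.size + 1)
    return n_extreme / null.size

n_time_bins = 45
threshold = 0.05 / n_time_bins   # Bonferroni-corrected alpha

rng = np.random.default_rng(0)
for n_shuffles in (100, 1000, 10000):
    null = rng.normal(0.0, 0.1, size=n_shuffles)     # stand-in null coefficients
    p = permutation_p_value(actual=5.0, null=null)   # far outside the null
    print(f"{n_shuffles:>6} shuffles: smallest p = {p:.5f}, "
          f"below threshold {threshold:.5f}: {p < threshold}")
```

The resolution of a permutation p-value is set by the number of shuffles, so a strict corrected threshold needs enough shuffles for a p-value that small to be attainable at all.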

In [None]:
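# Use 1000 for publication-quality results. With 100 shuffles the smallest
# nonzero p-value is 0.01, so a bin passes the Bonferroni threshold below
# only when no shuffled coefficient is as extreme as the actual one.
N_SHUFFLES = 100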

rng = np.random.default_rng(42)
null_coeffs = np.zeros((N_SHUFFLES, n_time_bins, n_features))

print(f"Running {N_SHUFFLES} shuffle iterations...")
for i in range(N_SHUFFLES):
    if (i + 1) % 50 == 0:
        print(f"  Shuffle {i + 1}/{N_SHUFFLES}")
    
    # Shuffle spike counts across trials (independently for each time bin)
    y_shuffled = np.array([rng.permutation(y[t]) for t in range(n_time_bins)])
    (null_coeffs[i], _), _ = vmap_fit((weights_init, intercept_init), X, y_shuffled)

print("Done!")

In [None]:
# Compute two-tailed p-values
p_values = np.zeros((n_time_bins, n_features))
for t in range(n_time_bins):
    for f in range(n_features):
        p_values[t, f] = (np.abs(null_coeffs[:, t, f]) >= np.abs(coefficients[t, f])).mean()

# Bonferroni correction for multiple time bins
bonferroni_threshold = 0.05 / n_time_bins

print(f"Bonferroni-corrected threshold: p < {bonferroni_threshold:.4f}")
print("\nSignificant time bins per feature:")
for f, name in enumerate(feature_names):
    n_sig = (p_values[:, f] < bonferroni_threshold).sum()
    print(f"  {name}: {n_sig}/{n_time_bins} ({100*n_sig/n_time_bins:.0f}%)")

### Coefficients with Significance Markers

Now we can add stars to show which time bins have statistically significant encoding.

In [None]:
# Plot coefficients with significance markers
fig, ax = plt.subplots(figsize=(14, 6))

for i, name in enumerate(feature_names):
    ax.plot(starts_ms, coefficients[:, i], color=colors[i], linewidth=3, label=name)
    
    # Mark significant time bins with stars
    sig_mask = p_values[:, i] < bonferroni_threshold
    ax.scatter(starts_ms[sig_mask], coefficients[sig_mask, i], 
               color=colors[i], s=150, marker="*", zorder=5)

ax.axhline(0, color="gray", linestyle="--", alpha=0.5, linewidth=1.5)
ax.axvline(0, color="black", linestyle="--", linewidth=2)
ax.set_xlabel("Time from movement onset (ms)", fontsize=18, fontweight="bold")
ax.set_ylabel("GLM Coefficient", fontsize=18, fontweight="bold")
ax.legend(loc="upper right", fontsize=14, frameon=False)
ax.set_title(f"GLM Coefficients with Significance (Bonferroni p<{bonferroni_threshold:.4f})\n{nwbfile.session_id}", fontsize=20, fontweight="bold")
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_linewidth(1.5)
ax.spines['bottom'].set_linewidth(1.5)
ax.tick_params(axis='both', which='major', labelsize=14, width=1.5, length=6)
ax.grid(True, alpha=0.3, linewidth=0.8)
plt.tight_layout()
plt.show()

### Examining the Null Distribution

Let's visualize what the shuffle test actually does. For one time bin, we can compare the actual coefficient to its null distribution.

In [None]:
# Pick the middle time bin (its window starts shortly after movement onset)
sample_time_bin = n_time_bins // 2

fig, axes = plt.subplots(1, n_features, figsize=(18, 5))

for i, (ax, name) in enumerate(zip(axes, feature_names)):
    null_dist = null_coeffs[:, sample_time_bin, i]
    actual_coef = coefficients[sample_time_bin, i]
    
    ax.hist(null_dist, bins=30, alpha=0.7, color=colors[i], edgecolor="black", linewidth=1.2)
    ax.axvline(actual_coef, color="red", linewidth=3, linestyle="--", label=f"Actual: {actual_coef:.2f}")
    ax.axvline(-actual_coef, color="red", linewidth=3, linestyle="--", alpha=0.5)
    
    ax.set_xlabel("Coefficient", fontsize=14, fontweight="bold")
    ax.set_ylabel("Count", fontsize=14, fontweight="bold")
    ax.set_title(f"{name}\np = {p_values[sample_time_bin, i]:.3f}", fontsize=16, fontweight="bold")
    ax.legend(fontsize=11, frameon=False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(1.5)
    ax.spines['bottom'].set_linewidth(1.5)
    ax.tick_params(axis='both', which='major', labelsize=12, width=1.5, length=5)

fig.suptitle(f"Null Distributions at t = {starts_ms[sample_time_bin]:.0f} ms", fontsize=20, fontweight="bold")
plt.tight_layout()
plt.show()

### Model Validation: Pseudo-R²

NeMoS can also score a fitted model with pseudo-R². Computed on held-out trials, as in the train/test split below, it tells us how well the model generalizes, not just how well it fits the training data.

In [None]:
# Split trials 80/20
rng_cv = np.random.default_rng(123)
n_train = int(n_trials * 0.8)
train_indices = rng_cv.choice(n_trials, n_train, replace=False)
test_indices = np.setdiff1d(np.arange(n_trials), train_indices)

X_train, X_test = X[:, train_indices, :], X[:, test_indices, :]
y_train, y_test = y[:, train_indices], y[:, test_indices]

# Fit on training data
(coeffs_train, intercepts_train), _ = vmap_fit((weights_init, intercept_init), X_train, y_train)

# Compute pseudo-R² on test data for each time bin
pseudo_r2 = np.zeros(n_time_bins)
for t in range(n_time_bins):
    model_eval = nmo.glm.GLM()
    model_eval.coef_ = coeffs_train[t]
    model_eval.intercept_ = intercepts_train[t]
    try:
        pseudo_r2[t] = model_eval.score(X_test[t], y_test[t], score_type="pseudo-r2-Cohen")
    except Exception:
        pseudo_r2[t] = np.nan

print(f"Train trials: {n_train}, Test trials: {len(test_indices)}")
print(f"Mean pseudo-R²: {np.nanmean(pseudo_r2):.4f}")
print(f"Max pseudo-R²: {np.nanmax(pseudo_r2):.4f}")

In [None]:
# Plot pseudo-R² across time
fig, ax = plt.subplots(figsize=(14, 5))

ax.plot(starts_ms, pseudo_r2, color="tab:purple", linewidth=3)
ax.fill_between(starts_ms, 0, pseudo_r2, alpha=0.3, color="tab:purple", where=(pseudo_r2 > 0))
ax.fill_between(starts_ms, 0, pseudo_r2, alpha=0.3, color="tab:red", where=(pseudo_r2 <= 0))
ax.axhline(0, color="gray", linestyle="--", alpha=0.5, linewidth=1.5)
ax.axvline(0, color="black", linestyle="--", linewidth=2)

ax.set_xlabel("Time from movement onset (ms)", fontsize=18, fontweight="bold")
ax.set_ylabel("Pseudo-R² (Cohen)", fontsize=18, fontweight="bold")
ax.set_title(f"Model Generalization: Train/Test Split ({n_train}/{len(test_indices)} trials)\n{nwbfile.session_id}", fontsize=20, fontweight="bold")
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_linewidth(1.5)
ax.spines['bottom'].set_linewidth(1.5)
ax.tick_params(axis='both', which='major', labelsize=14, width=1.5, length=6)
ax.grid(True, alpha=0.3, linewidth=0.8)
plt.tight_layout()
plt.show()

### Interpreting Pseudo-R²

**Positive R²**: Model predicts held-out data better than a mean-only (null) model.

**Negative R²**: Model is *worse* than just predicting the mean firing rate. This often indicates overfitting due to small sample sizes.

With only ~20 trials split 80/20, we have just 4 test trials, which is not enough for reliable cross-validation. The paper addressed this by using shuffle-based significance rather than cross-validation.
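
The sign of pseudo-R² is easiest to see from the deviances behind it. The sketch below computes a deviance-based pseudo-R², 1 - D_model / D_null, in plain NumPy. It is assumed, not verified here, to match what NeMoS reports for `"pseudo-r2-Cohen"`; the null model is taken to be the mean of the scored counts, and the four "held-out" counts are toy numbers.

```python
import numpy as np

def poisson_deviance(y, mu):
    """Poisson deviance 2 * sum(y * log(y / mu) - (y - mu)), with 0 * log(0) = 0."""
    y = np.asarray(y, dtype=float)
    mu = np.asarray(mu, dtype=float)
    ratio = np.where(y > 0, y / mu, 1.0)
    return 2.0 * np.sum(y * np.log(ratio) - (y - mu))

def pseudo_r2_deviance(y, mu_model):
    """1 - D_model / D_null, where the null model predicts the mean of y."""
    mu_null = np.full_like(np.asarray(y, dtype=float), np.mean(y))
    return 1.0 - poisson_deviance(y, mu_model) / poisson_deviance(y, mu_null)

y_test = np.array([8.0, 2.0, 7.0, 3.0])   # four held-out trials (toy counts)
good = np.array([7.5, 2.5, 7.5, 2.5])     # predictions that track the counts
bad = np.array([2.5, 7.5, 2.5, 7.5])      # systematically wrong predictions
mean_only = np.full(4, y_test.mean())

print(f"good model: {pseudo_r2_deviance(y_test, good):.2f}")
print(f"mean only:  {pseudo_r2_deviance(y_test, mean_only):.2f}")
print(f"bad model:  {pseudo_r2_deviance(y_test, bad):.2f}")
```

With only four test trials, a single poorly predicted trial can dominate the model deviance, which is why the score swings so widely from bin to bin.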

### Significance Heatmap

A compact way to visualize which features are significant at which times.

In [None]:
# Significance heatmap
fig, ax = plt.subplots(figsize=(14, 5))

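# Log-transform p-values for better visualization.
# The 1e-10 offset avoids log10(0) for bins where no shuffle was as extreme as the actual coefficient.
p_log = -np.log10(p_values + 1e-10)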
threshold_log = -np.log10(bonferroni_threshold)

im = ax.imshow(p_log.T, aspect="auto", cmap="viridis",
               extent=[starts_ms[0], starts_ms[-1], -0.5, n_features - 0.5])
for i in range(1, n_features):
    ax.axhline(i - 0.5, color="white", linewidth=1)
ax.axvline(0, color="white", linestyle="--", linewidth=2, alpha=0.8)

ax.set_yticks(range(n_features))
ax.set_yticklabels(feature_names, fontsize=14, fontweight="bold")
ax.set_xlabel("Time from movement onset (ms)", fontsize=18, fontweight="bold")

cbar = plt.colorbar(im, ax=ax)
cbar.set_label("-log10(p-value)", fontsize=14, fontweight="bold")
cbar.ax.tick_params(labelsize=12)
cbar.ax.axhline(threshold_log, color="red", linewidth=3)

ax.set_title(f"Significance Across Time (brighter = more significant)\n{nwbfile.session_id}", fontsize=20, fontweight="bold")
ax.spines['left'].set_linewidth(1.5)
ax.spines['bottom'].set_linewidth(1.5)
ax.tick_params(axis='x', which='major', labelsize=14, width=1.5, length=6)
plt.tight_layout()
plt.show()

### Summary: NeMoS Workflow

| Step | What | Code |
|------|------|------|
| 1 | Create model | `model = nmo.glm.GLM(regularizer="Ridge")` |
| 2 | Prepare solver | `model.instantiate_solver()` |
| 3 | Vectorize | `vmap_fit = jax.vmap(model.solver_run, in_axes=(None, 0, 0))` |
| 4 | Fit all bins | `(coefs, intercepts), _ = vmap_fit(init, X, y)` |
| 5 | Evaluate | `model.score(X_test, y_test, score_type="pseudo-r2-Cohen")` |

**Key advantages of NeMoS:**
- JAX backend enables GPU acceleration and vectorization
- Scikit-learn-like API (`.fit()`, `.score()`, `.predict()`)
- Built-in regularization options
- Compatible with pynapple for neural data handling

### What This Neuron Encodes

Based on the analysis above, this corticostriatal neuron shows:

| Feature | Pattern | Interpretation |
|---------|---------|----------------|
| **Direction** | Positive, sustained | Fires more for flexion movements |
| **Position** | Negative trend | Slight preference for flexed positions |
| **Velocity** | Weak/transient | Not strongly velocity-tuned |
| **Acceleration** | Near zero | Does not encode acceleration |
| **RT** | Near zero | Does not encode reaction time |

This is a **direction-selective neuron**. It primarily encodes *what* movement is being made (flexion vs extension), not the detailed kinematics of *how* it is being made.