
# Synthetic dataset + inference (pcalib): full walkthrough

**What this notebook does (big picture):**
- Builds (or loads) a synthetic *low-rank*, *time-varying* dataset using the `pcalib` library's `Potential` class.
- Runs **two parameter sweeps** to test inference quality:
  1. **Animals sweep:** vary the number of animals \(D\) while keeping trials fixed.
  2. **Trials sweep:** vary the number of trials while keeping animals fixed.
- For each sweep, repeatedly **simulate data** and **run inference** to estimate per-mode statistics, saving summaries to disk so the process is resumable and results are easy to aggregate later.

**Key ideas:**
- The latent signal has **K = 2** components that trace a "corridor" trajectory over time (a normalized path with controlled offsets).
- We use `fit_statistics_from_dataset_diagonal(...)` to infer mode-wise statistics from synthetic recordings.
- We extract accuracy measures via `make_predictions(...)` (specifically, `"epsilon"` and `"rho"`), and cache arrays in `cached_results/` to resume across runs.
- Each sweep is repeated up to a cap (default: 50 attempts), enabling mean/SEM plots later.

> This notebook is designed for **clarity**: each code cell has a preceding Markdown block that explains what it does (both conceptually and in detail).



## Imports & paths

**What this cell does:**
- Imports standard libraries and the required `pcalib` functions/classes.
- Creates (if necessary) a `cached_results/` directory to store all outputs.
- Optionally sets a NumPy random seed so data generation is reproducible across runs.


In [1]:

from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

from pcalib.functions import fit_statistics_from_dataset_diagonal, make_predictions
from pcalib.classes import Potential
from pcalib.utils import PCA_matlab_like, generate_gaussian_correlation_matrix

# Results directory (consistent spelling)
OUT = Path.cwd() / "cached_results"
OUT.mkdir(parents=True, exist_ok=True)

# Optional reproducibility
np.random.seed(0)



## Helper functions: resume/cap logic and scalar coercion

**What this cell does:**
- Defines small utilities used throughout the notebook:
  - `attempts_done(path)`: returns how many attempts (rows) are already saved in a given `.npy` file.
  - `append_rows_capped(path, new_block, cap)`: appends attempts along axis 0 but **never exceeds** `cap` total attempts in the saved file.
  - `_to_scalar(x)`: converts array-like values (including 0D/1D/2D NumPy/JAX arrays) to plain Python floats by taking the first element, for clean storage/logging.

**Details & rationale:**
- The sweeps are designed to be **resumable**. If you've already run 17 attempts out of 50, re-running the "attempts" cells will only append up to the cap.
- `_to_scalar` avoids shape/dtype surprises when moving values into NumPy arrays and then onto disk.
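
To see the cap in action, here is a minimal, self-contained sketch. It restates the append logic in compact form under the hypothetical name `append_capped_demo` and exercises it in a temporary directory; it illustrates the behaviour described above and is not a replacement for the helpers defined in the next cell.

```python
import tempfile
from pathlib import Path

import numpy as np


def append_capped_demo(path, new_block, cap):
    """Compact restatement of the capped-append idea (illustrative only)."""
    if path.exists():
        old = np.load(path)
        room = max(0, cap - old.shape[0])  # rows still allowed under the cap
        out = np.concatenate([old, new_block[:room]], axis=0)
    else:
        out = new_block[:cap]
    np.save(path, out)


with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "demo.npy"
    rows_after_each_call = []
    for _ in range(4):
        # Each call tries to add 2 attempts of shape (3, 2); the cap is 5.
        append_capped_demo(demo_path, np.ones((2, 3, 2)), cap=5)
        rows_after_each_call.append(np.load(demo_path).shape[0])

print(rows_after_each_call)  # [2, 4, 5, 5]
```

The third call is truncated to a single row and the fourth adds nothing, which is the behaviour that makes re-running the sweep cells safe.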


In [2]:

def attempts_done(path: Path):
    return np.load(path).shape[0] if path.exists() else 0

def append_rows_capped(path: Path, new_block, cap: int):
    """
    Append new rows on axis 0, but ensure the saved file has <= cap rows total.
    new_block must have shape [num_new, ...].
    """
    if new_block is None or len(new_block) == 0:
        return
    if path.exists():
        old = np.load(path)
        need = max(0, cap - old.shape[0])
        if need == 0:
            return  # already at cap
        out = np.concatenate([old, new_block[:need]], axis=0)
    else:
        out = new_block[:cap]

    np.save(path, out)  # overwrite existing file directly

def _to_scalar(x):
    a = np.asarray(x)
    if a.ndim == 0:
        return float(a)
    if a.ndim == 1:
        return float(a[0])
    if a.ndim == 2:
        return float(a[0, 0])
    return float(a.reshape(-1)[0])



## Latent signal generator: the 2D "corridor"

**What this cell does (conceptually):**
- Builds a **T×2** latent trajectory (`K=2`) that moves through a corridor-like path:
  - Linear drift in the first quarter (component 1).
  - A circular arc (via cos/sin) in the middle half (components 1 & 2).
  - Linear drift back in the last quarter.
  - Anti-symmetric vertical offsets (−ε for first half, +ε for second half) on component 2.

**What this cell does (details):**
- After constructing the piecewise path, we zero-mean each component, **variance-normalize**, then rescale to a target variance `var_array`.
- This becomes the **ground-truth** latent path `\bar{x}_t` used in the generative model.
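
The normalization step can be checked in isolation. The sketch below applies the same three operations (zero-mean, unit variance, rescale) to a toy two-column path; the helper name `rescale_to_variance` and the toy path are hypothetical and are not part of the notebook.

```python
import numpy as np


def rescale_to_variance(signal, target_var):
    """Zero-mean each column, normalize to unit variance, then rescale."""
    centered = signal - signal.mean(axis=0, keepdims=True)
    unit = centered / centered.std(axis=0, keepdims=True)
    return unit * np.sqrt(np.asarray(target_var, dtype=float))


# Toy two-component path (hypothetical; not the corridor itself).
t = np.linspace(0.0, 1.0, 100)
toy_path = np.column_stack([3.0 * t + 1.0, np.sin(2 * np.pi * t) - 0.5])

scaled = rescale_to_variance(toy_path, target_var=[2, 1])
print(scaled.mean(axis=0))  # approximately [0, 0]
print(scaled.var(axis=0))   # approximately [2, 1]
```

Note that `np.var` and `np.std` default to the population convention (`ddof=0`), so the target variances hold exactly under that convention.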


In [3]:

def corridor_signal(T, var_array, epsilon_corridor):
    signal_array = np.zeros([T, 2])  # two components: K=2
    # Break everything into quarters
    signal_array[:T//4, 0] = np.linspace(-2, 0, T//4)

    n_points_middle = np.shape(np.arange(T//4, (3*T)//4))[0]
    signal_array[T//4:(3*T)//4, 0] = -np.cos(2*np.pi*np.arange(n_points_middle)/n_points_middle) + 1
    signal_array[T//4:(3*T)//4, 1] = -np.sin(2*np.pi*np.arange(n_points_middle)/n_points_middle)

    n_points_end = np.shape(signal_array[(3*T)//4:, 0])[0]
    signal_array[(3*T)//4:, 0] = np.linspace(0, -2, n_points_end)
    signal_array[:T//2, 1] -= epsilon_corridor
    signal_array[T//2:, 1] += epsilon_corridor

    signal_array -= np.mean(signal_array, 0, keepdims=True)
    signal_array /= np.sqrt(np.var(signal_array, 0, keepdims=True))
    signal_array[:, 0] *= np.sqrt(var_array[0])
    signal_array[:, 1] *= np.sqrt(var_array[1])

    return signal_array



## Global configuration & sweep grids

**What this cell does:**
- Centralizes all key parameters so you can change them in one place.
- Defines the **sweep grids**: `D_array` (number of animals) and `n_trials_array` (number of trials) used later.
- Sets the **cap** (`n_attempts`) on how many times each sweep is repeated; the independent repetitions are what later allow mean/SEM estimates.

**Notes:**
- `load_potential` controls whether we **generate** ground-truth and save it, or **load** previously saved objects. The default below is `False` so the notebook is self-contained on first run.
- If you've already generated and saved the Potentials once, set `load_potential=True` to skip regeneration and use the saved `.npz` bundles.


In [None]:

# Toggle: generate truth vs. load existing
load_potential = False   # set True if you've already generated/saved once and want to reload

# Baseline sizes (when generating)
T = 100
N_per_animal = 50
K = 2
D = 2
n_trials = 40

# Sweep grids
n_trials_array = np.arange(5, 50, 5)   # 5,10,...,45
D_array = np.arange(1, 6)              # 1..5

# Temporal correlation scales
tau_sigma = 2                          # within-trial smoothing kernel width
tau_xi   = 5                           # trial-to-trial temporal correlation scale

# Latent signal amplitude targets
var_array = [2, 1]
epsilon_corridor = 0.1

# Attempts cap (per sweep)
n_attempts = 50



## Build or load the generative `Potential`s

**What this cell does (conceptually):**
- Either **generates** and **saves** two `Potential` objects or **loads** them from disk:
  - **Many-animals Potential**: sized to the *largest* `D` in the sweep, so that every smaller animal count can be obtained by subselection.
  - **Many-trials Potential**: sized to the baseline `D` (used for the trials sweep).

**What this cell does (details when generating):**
- Construct the ground-truth latent `bar_x = corridor_signal(...)`.
- Create a random orthonormal **mode basis** `\bar{e}` (QR-orthogonalized) and scale columns to have norm \(\sqrt{N}\).
- Define temporal correlation kernels:
  - `Z` (within-trial smoothing) via a Gaussian correlation matrix with scale `tau_sigma * sqrt(2)`. This is the correlation matrix of pure white noise smoothed with a Gaussian kernel of width `tau_sigma`.
  - `Delta` (trial-to-trial temporal correlation) with scale `tau_xi`.
- Draw per-neuron noise scales `\bar{\sigma}`.
- Build per-animal selector blocks `G` so each animal's neurons form identity blocks along the diagonal.
- Set `Xi = 0` to disable neuron-dependent kernels (as per the paper's assumptions).
- Save:
  - The two `Potential` bundles (`many_animals_potential.npz`, `many_trials_potential.npz`),
  - Reference arrays (`D_array`, `n_trials_array`, and their baselines),
  - `tau_sigma` and some truth summaries for later comparison.
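
The factor `sqrt(2)` in the scale of `Z` follows from a standard property of Gaussians: white noise smoothed with a Gaussian kernel of width `tau` has a Gaussian autocorrelation of width `tau * sqrt(2)`. The sketch below checks this numerically. It assumes the squared-exponential convention `exp(-lag**2 / (2 * scale**2))` for a correlation with a given scale; whether `generate_gaussian_correlation_matrix` uses exactly this convention is an assumption to verify against the `pcalib` source.

```python
import numpy as np

tau = 2.0                              # smoothing-kernel width (tau_sigma in the notebook)
t = np.arange(-40, 41)                 # kernel support, wide enough that the tails are negligible
kernel = np.exp(-t**2 / (2 * tau**2))  # Gaussian smoothing kernel

# For white-noise input, the autocovariance of the smoothed output is
# proportional to the kernel correlated with itself.
autocov = np.correlate(kernel, kernel, mode="full")
autocorr = autocov / autocov.max()     # normalize so that lag 0 equals 1
lags = np.arange(-(t.size - 1), t.size)

# Squared-exponential correlation with scale tau * sqrt(2) (assumed convention).
scale = tau * np.sqrt(2)
expected = np.exp(-lags**2 / (2 * scale**2))

print(np.max(np.abs(autocorr - expected)))  # close to zero
```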

**Why this separation matters:**
- The animals sweep needs the *largest* version so we can subselect animals cleanly.
- The trials sweep needs a fixed-`D` version so we can subselect trials cleanly.
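
Subselecting animals amounts to slicing every neuron axis to the first `N_per_animal * D` entries. The sketch below rebuilds the selector blocks and the scaled QR basis at toy sizes (all sizes and lowercase names are hypothetical) and shows what survives the slicing.

```python
import numpy as np

rng = np.random.default_rng(0)
n_per_animal, d_max, d_small, k = 4, 3, 2, 2   # toy sizes (hypothetical)
n_max, n_small = n_per_animal * d_max, n_per_animal * d_small

# Orthonormal columns via QR, then scaled so each column has norm sqrt(n_max).
basis_largest, _ = np.linalg.qr(rng.normal(size=(n_max, k)))
basis_largest *= np.sqrt(n_max)

# One selector per animal: an identity block on that animal's neurons.
g_largest = np.zeros((d_max, n_max, n_max))
for d in range(d_max):
    sl = slice(d * n_per_animal, (d + 1) * n_per_animal)
    g_largest[d, sl, sl] = np.eye(n_per_animal)

# Keeping the first d_small animals means slicing every neuron axis to n_small.
g_small = g_largest[:d_small, :n_small, :n_small]
basis_small = basis_largest[:n_small, :]

print(np.allclose(g_small.sum(axis=0), np.eye(n_small)))  # True
print(np.linalg.norm(basis_largest, axis=0))              # both equal sqrt(12)
print(np.linalg.norm(basis_small, axis=0))                # no longer sqrt(n_small) in general
```

The selectors remain a clean partition of the retained neurons, but the truncated basis columns lose their exact norm. This is why the animals sweep re-normalizes each truncated column of `bar_e` to norm `sqrt(current_N)` before using it for sign alignment.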


In [None]:

if not load_potential:
    # Derived sizes
    N = N_per_animal * D

    # Ground-truth latent
    bar_x = corridor_signal(T, var_array, epsilon_corridor)

    # Mode basis up to max D
    bar_e_largest = np.random.normal(0, 1, [N_per_animal * D_array[-1], K])
    bar_e_largest, _ = np.linalg.qr(bar_e_largest)  # orthogonalize
    bar_e_largest *= np.sqrt(N_per_animal * D_array[-1])  # scale columns
    bar_e = bar_e_largest[:N, :]

    # Temporal noise kernel (within-trial)
    Z = generate_gaussian_correlation_matrix(T, tau_sigma * np.sqrt(2))

    # Per-neuron noise scales up to max D
    bar_sigma_largest = np.abs(np.random.normal(1, 0.1, N_per_animal * D_array[-1]))
    bar_sigma = bar_sigma_largest[:N]

    # Trial-to-trial structure
    Delta = generate_gaussian_correlation_matrix(T, tau_xi)
    bar_xi_largest = np.zeros([D_array[-1], K])
    for k in range(K):
        bar_xi_largest[:, k] = np.sqrt(np.abs(np.random.normal(2/(k+1), 0.1/(k+1), D_array[-1])))

    G_largest = np.zeros([D_array[-1], N_per_animal * D_array[-1], N_per_animal * D_array[-1]])
    for d in range(D_array[-1]):
        sl = slice(d*N_per_animal, (d+1)*N_per_animal)
        G_largest[d, sl, sl] = np.eye(N_per_animal)

    bar_xi = bar_xi_largest[:D, :]
    G = G_largest[:D, :N, :N]

    # Neuron-dependent kernel disabled
    Xi = np.zeros([T, T])

    # Save truth summaries
    np.save(OUT / "true_mean_noise_variance.npy", np.sqrt(np.mean(bar_sigma**2 / n_trials)))
    np.save(OUT / "true_signal_variability.npy", np.var(bar_x, 0))

    # Save the two Potentials
    many_animals_potential = Potential(bar_sigma_largest, bar_e_largest, G_largest, bar_xi_largest, Z, Delta, bar_x, Xi)
    many_animals_potential.save_as_npz(str(OUT / "many_animals_potential.npz"))

    many_trials_potential = Potential(bar_sigma, bar_e, G, bar_xi, Z, Delta, bar_x, Xi)
    many_trials_potential.save_as_npz(str(OUT / "many_trials_potential.npz"))

    # Save references
    np.save(OUT / "D_array.npy", D_array)
    np.save(OUT / "D_reference.npy", D)
    np.save(OUT / "n_trials_array.npy", n_trials_array)
    np.save(OUT / "n_trials_reference.npy", n_trials)
    np.save(OUT / "tau_sigma.npy", tau_sigma)
else:
    # Load the saved bundles and references
    many_animals_potential = Potential.from_npz(str(OUT / "many_animals_potential.npz"))
    many_trials_potential = Potential.from_npz(str(OUT / "many_trials_potential.npz"))
    D_array = np.load(OUT / "D_array.npy")
    n_trials_array = np.load(OUT / "n_trials_array.npy")
    D = int(np.load(OUT / "D_reference.npy"))
    n_trials = int(np.load(OUT / "n_trials_reference.npy"))
    tau_sigma = np.load(OUT / "tau_sigma.npy").item()

    # Pull sizes/components from loaded objects
    bar_e_largest = many_animals_potential.bar_e
    bar_e = many_trials_potential.bar_e
    T, K = np.shape(many_animals_potential.bar_x)
    N = np.shape(many_trials_potential.bar_sigma)[0]
    G_largest = many_animals_potential.G
    bar_xi_largest = many_animals_potential.bar_xi
    N_per_animal = N // D



## Quick visualization: ground-truth latent trajectories (optional)

**What this cell does:**
- Plots the two latent components `\bar{x}_t` over time.
- Useful to verify the qualitative shape of the "corridor" signal (two components with different scalings and offsets).

**Details:**
- If you just generated the potentials, both `many_animals_potential` and `many_trials_potential` share the same `bar_x`, so plotting either is fine.


In [None]:

plt.figure()
plt.plot(many_animals_potential.bar_x)
plt.title("Ground-truth latent trajectories (K=2)")
plt.xlabel("Time")
plt.ylabel("Amplitude")
plt.show()



## Animals sweep — setup (resume-aware)

**What this cell does:**
- Declares file paths for the results of the **animals sweep** (varying `D` with fixed `n_trials`).
- Computes how many **attempts** have already been performed (by reading the number of rows in each `.npy` file).
- Derives how many attempts remain up to the cap `n_attempts`.

**Files produced (shapes across attempts):**
- `epsilon_animals.npy`: shape `(attempts, |D_array|, K)`
- `rho_animals.npy`: shape `(attempts, |D_array|, K)`
- `signal_variability_animals.npy`: shape `(attempts, |D_array|, K)`


In [None]:

animals_files = {
    "epsilon": OUT / "epsilon_animals.npy",
    "rho": OUT / "rho_animals.npy",
    "sigvar": OUT / "signal_variability_animals.npy",
}
done_animals = max(
    attempts_done(animals_files["epsilon"]),
    attempts_done(animals_files["rho"]),
    attempts_done(animals_files["sigvar"]),
)
remaining_animals = max(0, n_attempts - done_animals)
remaining_animals



## Animals sweep — run attempts (outer loop over attempts, inner loop over `D_array`)

**What this cell does (conceptually):**
- For each *remaining* attempt up to the cap:
  1. **Simulate** a dataset with `n_trials` trials from the **many-animals** Potential.
  2. For each `D_current` in `D_array`:
     - **Subselect neurons** so we keep only the first `N_per_animal * D_current` neurons (i.e., that many animals).
     - **Compute PCA** on the trial-averaged data to get eigenvectors/scores, then **sign-align** the first `K` components with the true `\bar{e}` (so plots/predictions are consistent across attempts).
     - On the **very first** (global) attempt at the **baseline** `D`, save the PCA scores to `inferred_y_{n_trials}_trials.npy` (handy for visualization).
     - **Run inference** via `fit_statistics_from_dataset_diagonal(current_data, K, current_G, tau_sigma, gamma=0.1)` to estimate per-mode parameters.
     - **Make predictions** for each mode using `make_predictions(...)` and extract `"epsilon"` and `"rho"`, coercing to float with `_to_scalar(...)`.
     - **Record** the inferred signal variance per mode as a diagnostic (variance of inferred `bar_x` along time).
  3. **Append** this attempt's results to the `.npy` files using `append_rows_capped(...)`.

**Why PCA and sign-alignment?**
- PCA eigenvectors are only defined up to sign. To ensure consistent orientation with the ground truth (and across attempts), we align signs using the dot product with the true `\bar{e}` columns (normalized to \(\sqrt{N}\)).

**Loop structure:**
- **Outer**: attempts (resume-aware, up to `n_attempts` rows).
- **Inner**: `D_current` in `D_array` (e.g., 1..5 animals). For each, run PCA, inference, predictive summaries.
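
The sign-alignment step can be sketched with plain NumPy. The example uses an SVD-based PCA as a stand-in for `PCA_matlab_like` (whose exact return conventions are assumed here, not verified) and toy data with hypothetical sizes.

```python
import numpy as np

rng = np.random.default_rng(1)
T, N, K = 200, 30, 2   # toy sizes (hypothetical)

# Ground truth: orthonormal modes and latent time courses with distinct variances.
true_e, _ = np.linalg.qr(rng.normal(size=(N, K)))
latent = rng.normal(size=(T, K)) * np.array([3.0, 1.5])
data = latent @ true_e.T + 0.05 * rng.normal(size=(T, N))

# SVD-based PCA on centered data: the rows of vt are the principal axes.
centered = data - data.mean(axis=0, keepdims=True)
_, _, vt = np.linalg.svd(centered, full_matrices=False)
coeff = vt[:K].T            # shape (N, K)
score = centered @ coeff    # shape (T, K)

# Each axis is only defined up to sign, so flip it to point along the true mode.
signs = np.sign(np.sum(coeff * true_e, axis=0))
coeff_aligned = coeff * signs[np.newaxis, :]
score_aligned = score * signs[np.newaxis, :]

overlaps = np.sum(coeff_aligned * true_e, axis=0)
print(overlaps)  # both entries are positive after alignment
```

Flipping a component's eigenvector and its scores together leaves the rank-`K` reconstruction `score @ coeff.T` unchanged, so alignment only fixes the orientation and does not alter the fit.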


In [None]:

if remaining_animals > 0:
    for attempt in range(remaining_animals):
        print(f"[animals] attempt {done_animals + attempt + 1} of {n_attempts}")
        synth_data_large = many_animals_potential.generate_sample_data(n_samples=n_trials)

        epsilon_animals_new = np.zeros([np.shape(D_array)[0], K])
        rho_animals_new = np.zeros([np.shape(D_array)[0], K])
        signal_variability_new = np.zeros([np.shape(D_array)[0], K])

        for i, D_current in enumerate(D_array):
            # Subselect to current number of animals
            current_N = N_per_animal * D_current
            current_data = synth_data_large[:, :, :current_N]
            current_bar_e = np.array(bar_e_largest[:current_N, :])  # ensure mutability
            current_G = G_largest[:D_current, :current_N, :current_N]

            # PCA on trial-averaged data (shape T x current_N)
            coeff, score, _ = PCA_matlab_like(np.mean(current_data, 0))
            sign_array = np.zeros(K)
            for k in range(K):
                # Normalize true vector for fair dot product, target norm sqrt(N)
                current_bar_e[:, k] = current_bar_e[:, k] / np.linalg.norm(current_bar_e[:, k]) * np.sqrt(current_N)
                sign_array[k] = np.sign(np.dot(coeff[:, k], current_bar_e[:, k]))

            # Rescale first K components to match sqrt(N) convention and apply signs
            coeff = coeff[:, :K] * sign_array[np.newaxis, :] * np.sqrt(current_N)
            score = score[:, :K] * sign_array[np.newaxis, :] / np.sqrt(current_N)

            # Save PCA scores once at baseline D on the very first global attempt
            if done_animals + attempt == 0 and D_current == D:
                np.save(OUT / f"inferred_y_{n_trials}_trials.npy", score)

            # Inference per mode (diagonal approx)
            inferred_potentials, _ = fit_statistics_from_dataset_diagonal(current_data, K, current_G, tau_sigma, gamma=0.1)

            # Predictions + diagnostics
            for k in range(K):
                prediction_dict = make_predictions(inferred_potentials[k])
                epsilon_animals_new[i, k] = _to_scalar(prediction_dict["epsilon"])
                rho_animals_new[i, k]     = _to_scalar(prediction_dict["rho"])
                signal_variability_new[i, k] = np.var(inferred_potentials[k].bar_x, 0)[0]

        # Append one attempt (row) to each file, capped
        append_rows_capped(animals_files["epsilon"], epsilon_animals_new[np.newaxis, :, :], n_attempts)
        append_rows_capped(animals_files["rho"], rho_animals_new[np.newaxis, :, :], n_attempts)
        append_rows_capped(animals_files["sigvar"], signal_variability_new[np.newaxis, :, :], n_attempts)
else:
    print(f"[animals] already at cap ({n_attempts}) attempts; skipping.")



## Trials sweep — setup (resume-aware)

**What this cell does:**
- Declares file paths for the results of the **trials sweep** (varying the number of trials while keeping animals fixed).
- Computes existing attempts and remaining attempts up to the cap.

**Files produced (shapes across attempts):**
- `epsilon_trials.npy`: shape `(attempts, |n_trials_array|, K)`
- `rho_trials.npy`: shape `(attempts, |n_trials_array|, K)`
- `sqrt_mean_sigma_squared.npy`: shape `(attempts, |n_trials_array|)`, a global noise summary \(\sqrt{\mathbb{E}_i[\bar{\sigma}_i^2]}\).


In [None]:

trials_files = {
    "epsilon": OUT / "epsilon_trials.npy",
    "rho": OUT / "rho_trials.npy",
    "sigma": OUT / "sqrt_mean_sigma_squared.npy",
}
done_trials = max(
    attempts_done(trials_files["epsilon"]),
    attempts_done(trials_files["rho"]),
    attempts_done(trials_files["sigma"]),
)
remaining_trials = max(0, n_attempts - done_trials)
remaining_trials



## Trials sweep — run attempts (outer loop over attempts, inner loop over `n_trials_array`)

**What this cell does (conceptually):**
- For each *remaining* attempt up to the cap:
  1. **Simulate** a dataset with the **maximum** number of trials from the **many-trials** Potential.
  2. For each `n_trials_current` in `n_trials_array`:
     - **Subselect trials**: keep the first `n_trials_current` trials.
     - **Compute PCA** on the trial-averaged data and **sign-align** to the baseline `\bar{e}`.
     - **Run inference** via `fit_statistics_from_dataset_diagonal(current_data, K, G, tau_sigma, gamma=0.1)`.
     - **Make predictions** for each mode (extract `"epsilon"` and `"rho"`).
     - **Compute a noise summary**: \(\sqrt{\mathbb{E}_i[\bar{\sigma}_i^2]}\) from the inferred Potential (mode 0 is used as representative).
  3. **Append** this attempt's results to disk, respecting the overall cap.

**Loop structure:**
- **Outer**: attempts (resume-aware).
- **Inner**: `n_trials_current` in `n_trials_array` (e.g., 5..45). For each, run PCA, inference, predictive summaries.
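
Because every grid point reuses the first `n_trials_current` trials of one simulated dataset, the trial counts are nested within an attempt. The sketch below uses pure white noise with a hypothetical noise scale to illustrate how the noise remaining in the trial average shrinks as `sigma / sqrt(n)`, the same scaling that appears in the saved truth summary `sqrt(mean(bar_sigma**2 / n_trials))`.

```python
import numpy as np

rng = np.random.default_rng(2)
sigma = 1.0                              # per-sample noise scale (hypothetical)
n_trials_grid = np.arange(5, 50, 5)      # same grid as n_trials_array
T, N = 200, 100                          # toy sizes (hypothetical)

# Simulate once at the largest trial count, then reuse nested prefixes.
noise = rng.normal(0.0, sigma, size=(n_trials_grid[-1], T, N))

# Standard deviation of the trial-averaged noise for each trial count.
empirical = np.array([noise[:n].mean(axis=0).std() for n in n_trials_grid])
predicted = sigma / np.sqrt(n_trials_grid)

for n, emp, pred in zip(n_trials_grid, empirical, predicted):
    print(f"n_trials={n:2d}  empirical={emp:.4f}  predicted={pred:.4f}")
```

Nesting means results at neighbouring grid points within one attempt are correlated; independence comes from repeating the whole sweep across attempts.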


In [None]:

if remaining_trials > 0:
    G = many_trials_potential.G
    N_const = many_trials_potential.bar_sigma.shape[0]
    for attempt in range(remaining_trials):
        print(f"[trials] attempt {done_trials + attempt + 1} of {n_attempts}")
        synth_data_large = many_trials_potential.generate_sample_data(n_samples=n_trials_array[-1])

        epsilon_trials_new = np.zeros([np.shape(n_trials_array)[0], K])
        rho_trials_new = np.zeros([np.shape(n_trials_array)[0], K])
        sqrt_mean_sigma_squared_new = np.zeros([np.shape(n_trials_array)[0]])

        for i, n_trials_current in enumerate(n_trials_array):
            current_data = synth_data_large[:n_trials_current, :, :]

            # PCA on trial-averaged data (shape T x N)
            coeff, score, _ = PCA_matlab_like(np.mean(current_data, 0))
            sign_array = np.zeros(K)
            for k in range(K):
                sign_array[k] = np.sign(np.dot(coeff[:, k], many_trials_potential.bar_e[:, k]))

            # Rescale first K components to match sqrt(N) convention and apply signs
            coeff = coeff[:, :K] * sign_array[np.newaxis, :] * np.sqrt(N_const)
            score = score[:, :K] * sign_array[np.newaxis, :] / np.sqrt(N_const)

            # Inference
            inferred_potentials, _ = fit_statistics_from_dataset_diagonal(current_data, K, G, tau_sigma, gamma=0.1)

            # Aggregate noise summary (representative from mode 0)
            sqrt_mean_sigma_squared_new[i] = np.sqrt(np.mean(inferred_potentials[0].bar_sigma**2))

            # Predictions per mode
            for k in range(K):
                prediction_dict = make_predictions(inferred_potentials[k])
                epsilon_trials_new[i, k] = _to_scalar(prediction_dict["epsilon"])
                rho_trials_new[i, k]     = _to_scalar(prediction_dict["rho"])

        # Append one attempt (row) to each file, capped
        append_rows_capped(trials_files["epsilon"], epsilon_trials_new[np.newaxis, :, :], n_attempts)
        append_rows_capped(trials_files["rho"], rho_trials_new[np.newaxis, :, :], n_attempts)
        append_rows_capped(trials_files["sigma"], sqrt_mean_sigma_squared_new[np.newaxis, :], n_attempts)
else:
    print(f"[trials] already at cap ({n_attempts}) attempts; skipping.")



## Load cached arrays for analysis

**What this cell does:**
- Provides a small helper to load a `.npy` array if it exists, returning `None` otherwise.
- Loads all animals-sweep and trials-sweep arrays for downstream summarization and plotting.

**Why this matters:**
- After running attempts, you'll typically want to **aggregate** results across attempts (e.g., compute mean and SEM per grid point).


In [None]:

def load_attempts(path): 
    return np.load(path) if Path(path).exists() else None

eps_anim = load_attempts(OUT / "epsilon_animals.npy")
rho_anim = load_attempts(OUT / "rho_animals.npy")
sigvar_anim = load_attempts(OUT / "signal_variability_animals.npy")

eps_trials = load_attempts(OUT / "epsilon_trials.npy")
rho_trials = load_attempts(OUT / "rho_trials.npy")
sigma_trials = load_attempts(OUT / "sqrt_mean_sigma_squared.npy")



## Mean/SEM helper

**What this cell does:**
- Defines a helper function `mean_sem(x, axis=0)` that returns the mean and standard error of the mean (SEM) along the specified axis.
- We'll use this to aggregate across **attempts** (axis 0) to visualize how estimates evolve with `D` or with the number of trials.

**Details:**
- SEM = standard deviation / sqrt(n).
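- We use `ddof=1` to get the sample standard deviation (Bessel's correction, which makes the *variance* estimate unbiased). At least two attempts are needed; with a single attempt the result is `nan`.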
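
As a small worked example of the formula, the sketch below computes the SEM by hand for five hypothetical attempt values and contrasts the two `ddof` conventions.

```python
import numpy as np

# Five hypothetical attempts of a single scalar metric.
values = np.array([0.9, 1.1, 1.0, 1.2, 0.8])
n = values.size

std_population = values.std(ddof=0)  # divides by n
std_sample = values.std(ddof=1)      # divides by n - 1 (Bessel's correction)
sem = std_sample / np.sqrt(n)

print(values.mean())    # approximately 1.0
print(std_population)   # approximately 0.1414
print(std_sample)       # approximately 0.1581
print(sem)              # approximately 0.0707
```

The gap between the two conventions shrinks as the number of attempts grows, so the choice matters most early in a sweep.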


In [None]:

def mean_sem(x, axis=0):
    m = x.mean(axis=axis)
    s = x.std(axis=axis, ddof=1) / np.sqrt(x.shape[axis])
    return m, s



## Example plots: inspecting sweep results (optional)

**What this cell does:**
- Produces a few example plots to visualize the dependence of predictions on the sweep variable:
  - `epsilon` (mode 1) vs **# animals**.
  - `rho` (mode 1) vs **# trials**.
- You can duplicate/adapt these blocks for mode 2 (index `1`) and for other metrics.

**Notes:**
- We guard plots with `if arr is not None` to avoid errors if you haven't run the sweeps yet.
- The arrays have the shape `(attempts, grid, K)`, so we take `mean_sem(..., axis=0)` to aggregate across attempts.
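
To adapt the plots to the second mode, only the last index changes. The sketch below uses a random stand-in array with hypothetical sizes to show the shapes involved; real results would come from the cached `.npy` files.

```python
import numpy as np

rng = np.random.default_rng(3)
attempts, grid, K = 6, 5, 2                          # toy sizes (hypothetical)
fake_results = rng.normal(size=(attempts, grid, K))  # stand-in for a loaded .npy array

# Aggregate across attempts (axis 0); both outputs have shape (grid, K).
mean = fake_results.mean(axis=0)
sem = fake_results.std(axis=0, ddof=1) / np.sqrt(attempts)

mode_index = 1                                       # second mode ("mode 2")
mode2_mean, mode2_sem = mean[:, mode_index], sem[:, mode_index]
print(mean.shape, mode2_mean.shape)  # (5, 2) (5,)
```

The one-dimensional `mode2_mean` and `mode2_sem` can be passed directly as `y` and `yerr` to `plt.errorbar`.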


In [None]:

# Epsilon (mode 1) vs number of animals
if eps_anim is not None:
    m, s = mean_sem(eps_anim, axis=0)  # shape: |D_array| x K
    plt.figure()
    plt.errorbar(D_array, m[:,0], yerr=s[:,0])
    plt.title("Epsilon (mode 1) vs # animals")
    plt.xlabel("# animals (D)")
    plt.ylabel("epsilon")
    plt.show()

# Rho (mode 1) vs number of trials
if rho_trials is not None:
    m, s = mean_sem(rho_trials, axis=0)  # shape: |n_trials_array| x K
    plt.figure()
    plt.errorbar(n_trials_array, m[:,0], yerr=s[:,0])
    plt.title("Rho (mode 1) vs # trials")
    plt.xlabel("# trials")
    plt.ylabel("rho")
    plt.show()



## Notes & next steps

- **Resetting the cache:** To start over, delete the contents of the `cached_results/` folder.
- **Changing the model:** Try different `tau_sigma`, `tau_xi`, or `var_array` to see how inference behaves under different temporal/noise conditions.
- **Adding widgets:** For interactive exploration, add sliders (e.g., `ipywidgets`) for parameters like `n_attempts`, `D_array`, `n_trials_array` and re-run the sweep cells.
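- **Parallelizing attempts:** If attempts are slow for your dataset size, consider distributing attempts across processes or machines. Because `append_rows_capped` loads, concatenates, and re-saves the whole file, concurrent writers to the same `.npy` file can overwrite each other; have each worker write to its own files and concatenate them along axis 0 afterwards.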
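- **Reproducibility:** Keep the `np.random.seed(...)` call in the imports cell so that a fresh run reproduces the same synthetic data and results (useful for debugging changes). When resuming a partially completed sweep, consider changing the seed: if the simulator draws from NumPy's global RNG, re-seeding with the same value may replay draws already used by earlier attempts.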
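
One way to combine attempts produced by separate workers is to have each worker save to its own file and merge afterwards. The sketch below shows the merge step only; the file naming (`epsilon_animals_worker*.npy`) and the helper `merge_worker_files` are hypothetical.

```python
import tempfile
from pathlib import Path

import numpy as np


def merge_worker_files(paths, cap):
    """Stack per-worker attempt arrays along axis 0 and keep at most `cap` rows."""
    blocks = [np.load(p) for p in sorted(paths)]
    return np.concatenate(blocks, axis=0)[:cap]


with tempfile.TemporaryDirectory() as tmp:
    out_dir = Path(tmp)
    # Three hypothetical workers, each contributing 4 attempts of shape (5, 2).
    for worker in range(3):
        np.save(out_dir / f"epsilon_animals_worker{worker}.npy",
                np.full((4, 5, 2), float(worker)))

    merged = merge_worker_files(out_dir.glob("epsilon_animals_worker*.npy"), cap=10)

print(merged.shape)  # (10, 5, 2)
```

If this pattern is used, each worker would presumably need its own random seed so that the merged attempts are independent draws.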
