# Imports

In [None]:
import os
import glob
import re
from pathlib import Path

import numpy as np
import scipy.io as sio
from scipy.io import loadmat

import matplotlib.pyplot as plt
from matplotlib import gridspec
from matplotlib.colors import LinearSegmentedColormap

from sklearn.manifold import Isomap
from ripser import ripser as tda

# Preprocessing data

Run this section only if you want to regenerate the data. This is optional: the precomputed files are already in the `data` folder. To regenerate everything, first point `save_path` and `dir_path` below (and the `data_dir` used further down) at your own directories, so the generated files have a destination.

In [None]:
# Destination directory for the regenerated embeddings -- set this to your own path.
save_path = r"YOUR PATH"
# Directory holding the simulation .mat files.
dir_path = r"C:\home\...\HD_attractor_model\data"

# All five runs share one base name; repeats 1-4 append "_<k>". The names keep a
# leading backslash because they are concatenated directly onto the directory
# strings above (Windows-style paths).
_base_run_name = r"\Inh_gain6Exc_gain9bimodal_vis_inp0gaus_ratio1visual_std0.262_rand_shuffle"
file_name_list = [_base_run_name] + [f"{_base_run_name}_{k}" for k in range(1, 5)]

In [None]:
# Regenerate the 10-D Isomap embedding for every run and cache it as
# "<save_path><run>_tmp_emb.npy".
for file_name in file_name_list:
    mat_path_str = dir_path + file_name + ".mat"
    emb_out_str = save_path + file_name + r'_tmp_emb.npy'

    mat_file = loadmat(mat_path_str)
    # rHD is stored as (neurons, time); transpose so each row is one time
    # sample, and take the square root of the activity before embedding.
    X = np.sqrt(mat_file['rHD'].T)

    embedding = Isomap(n_neighbors=200, n_components=10)
    tmp_emb = embedding.fit_transform(X)
    np.save(emb_out_str, tmp_emb)

In [None]:
def infer_suffix_from_stem(stem: str) -> str:
    """Return the trailing run suffix ``"_<digits>"`` of *stem*, or ``""`` if absent.

    Example: ``"..._rand_shuffle_3"`` -> ``"_3"``; ``"..._rand_shuffle"`` -> ``""``.
    """
    match = re.search(r"_(\d+)$", stem)
    if match is None:
        return ""
    return "_" + match.group(1)


def build_segments_to_exclude_list_notebook(mat_file) -> np.ndarray:
    """Return the stationary sample indices listed in a (non-squeezed) .mat dict.

    Reads row 0 of ``stat_startIndices`` / ``stat_endIndices`` and expands each
    pair into ``range(start, end)``. The values are used as-is (no 1-based
    conversion) and the end is exclusive, matching the original notebook.
    Repeated indices from overlapping segments are dropped while keeping
    first-seen order.

    Returns:
        1-D int array of indices (empty if there are no segments).
    """
    bounds = zip(mat_file["stat_startIndices"][0], mat_file["stat_endIndices"][0])
    expanded = [idx for lo, hi in bounds for idx in range(int(lo), int(hi))]
    # dict keeps insertion order, so this de-duplicates without reordering.
    return np.asarray(list(dict.fromkeys(expanded)), dtype=int)


def angular_differences_notebook(real_angles, decoded_angles):
    """Absolute circular distance in degrees, in [0, 180], between paired angles.

    Pairs are formed with ``zip``, so the result has the length of the shorter
    input. Each entry has the shape of ``real_angle - decoded_angle``; e.g. when
    ``decoded_angles`` is a (T, 1) array, each entry is a length-1 array (the
    callers rely on this and index ``[0]``).

    Returns:
        list with one circular distance per pair.
    """
    differences = []
    for real_angle, decoded_angle in zip(real_angles, decoded_angles):
        # Wrap into [0, 360) first: without this, separations of 360 deg or more
        # (angles outside one turn) would yield a negative "error".
        diff = abs(real_angle - decoded_angle) % 360
        differences.append(np.minimum(diff, 360 - diff))
    return differences


def decoding_error_notebook_exact(mat_file) -> dict:
    """Summarise circular decoding error separately for locomotion and stationary samples.

    Stationary samples are those returned by
    ``build_segments_to_exclude_list_notebook``; every other sample counts as
    locomotion. The true heading is ``(Az + 180) % 360`` (row 0) and the decoded
    heading is ``decodedHD + 180`` (notebook convention).

    Returns:
        dict with ``mean_loco_errors``, ``sem_loco_errors``, ``mean_stat_errors``,
        ``sem_stat_errors`` (SEM uses the population std, ddof=0; the stationary
        values are NaN when there are no stationary samples) and
        ``segments_to_exclude_list`` (the stationary index array).
    """
    true_heading = ((mat_file["Az"] + 180) % 360)[0]
    stat_idx = build_segments_to_exclude_list_notebook(mat_file)
    decoded = mat_file["decodedHD"]

    def _per_sample_errors(true_angles, decoded_angles):
        # Each entry is a length-1 array when decodedHD is (T, 1); take its scalar.
        errs = angular_differences_notebook(true_angles, decoded_angles)
        return [errs[i][0] for i in range(len(true_angles))]

    loco_errs = _per_sample_errors(
        np.delete(true_heading, stat_idx, axis=0),
        np.delete(decoded, stat_idx, axis=0) + 180,
    )
    stat_errs = _per_sample_errors(true_heading[stat_idx], decoded[stat_idx] + 180)

    n_loco = len(loco_errs)
    n_stat = len(stat_errs)

    mean_loco_errors = float(np.mean(loco_errs))
    sem_loco_errors = float(np.std(loco_errs) / np.sqrt(n_loco))

    if n_stat:
        mean_stat_errors = float(np.mean(stat_errs))
        sem_stat_errors = float(np.std(stat_errs) / np.sqrt(n_stat))
    else:
        mean_stat_errors = np.nan
        sem_stat_errors = np.nan

    return {
        "mean_loco_errors": mean_loco_errors,
        "sem_loco_errors": sem_loco_errors,
        "mean_stat_errors": mean_stat_errors,
        "sem_stat_errors": sem_stat_errors,
        "segments_to_exclude_list": stat_idx,
    }


def compute_betti(isomap_emb: np.ndarray, max_points_h2: int = 1500, seed=None):
    """Compute H0, H1 and H2 persistence diagrams of a point cloud with ripser (Z/2 coefficients).

    H0 and H1 are computed on all points. H2 is far more expensive, so it is
    computed on a random subsample of at most ``max_points_h2`` points.

    Args:
        isomap_emb: point cloud, one point per row.
        max_points_h2: maximum number of points used for the H2 computation
            (default 1500, the value that was previously hard-coded).
        seed: ``None`` (default) draws the subsample from NumPy's global RNG,
            exactly as before. Pass an int or a ``numpy.random.Generator`` to make
            the H2 diagram reproducible.

    Returns:
        ``(h0, h1, h2)``: the ripser diagrams (rows of [birth, death]). For an
        empty point cloud all three are empty (0, 2) arrays, because an empty
        space has no topological features.
    """
    X = np.asarray(isomap_emb)
    if len(X) == 0:
        # ripser cannot build a distance matrix from zero points.
        return np.empty((0, 2)), np.empty((0, 2)), np.empty((0, 2))

    barcodes_01 = tda(X, maxdim=1, coeff=2)["dgms"]
    h0 = barcodes_01[0]
    h1 = barcodes_01[1]

    if len(X) > max_points_h2:
        if seed is None:
            idx = np.random.choice(np.arange(len(X)), max_points_h2, replace=False)
        else:
            idx = np.random.default_rng(seed).choice(len(X), max_points_h2, replace=False)
        X2 = X[idx]
    else:
        X2 = X
    h2 = tda(X2, maxdim=2, coeff=2)["dgms"][2]
    return h0, h1, h2


def circular_diff_deg(a_deg, b_deg):
    """Signed circular difference ``a - b`` in degrees, wrapped into [-180, 180).

    Inputs are converted to float arrays and broadcast against each other.
    """
    delta = np.asarray(a_deg, dtype=float) - np.asarray(b_deg, dtype=float)
    return np.mod(delta + 180.0, 360.0) - 180.0


def compute_pref_nonpref_mean_activity(rHD: np.ndarray, real_Az_deg_0_360: np.ndarray, azimuth_range_deg=30.0):
    """Per-time-step mean activity of neurons tuned near vs. away from the current azimuth.

    Neuron ``i`` is assigned the preferred direction
    ``linspace(-180, 180, N, endpoint=False)[i]``, shifted by +180 deg into the
    same [0, 360) frame as ``real_Az_deg_0_360``. NOTE(review): this assumes the
    rows of ``rHD`` are ordered by evenly spaced preferred direction -- confirm
    against the model. At each time step, neurons whose circular distance to the
    azimuth is ``<= azimuth_range_deg`` are "preferred" and those with a distance
    ``> azimuth_range_deg`` are "non-preferred".

    Args:
        rHD: (N, T) activity matrix.
        real_Az_deg_0_360: azimuth per time step in degrees; at least T values.
        azimuth_range_deg: half-width of the preferred window in degrees.

    Returns:
        ``(pref_mean, nonpref_mean)``: float arrays of length T. An entry is NaN
        when its group is empty at that time step.

    Raises:
        ValueError: if fewer than T azimuth samples are given.
    """
    N, T = rHD.shape
    az = np.asarray(real_Az_deg_0_360, dtype=float).reshape(-1)
    if az.shape[0] < T:
        raise ValueError(f"Need at least T={T} azimuth samples, got {az.shape[0]}")

    # Loop-invariant: preferred directions shifted into the azimuth's frame.
    pd_deg = (np.linspace(-180.0, 180.0, N, endpoint=False) + 180.0) % 360.0
    # (N, T) circular distance between every neuron's preferred direction and the azimuth.
    dist = np.abs(circular_diff_deg(pd_deg[:, None], az[None, :T]))
    # Two explicit masks (not one mask and its negation), so NaN distances fall in neither group.
    pref_mask = dist <= azimuth_range_deg
    nonpref_mask = dist > azimuth_range_deg

    pref_mean = np.full(T, np.nan, dtype=float)
    nonpref_mean = np.full(T, np.nan, dtype=float)
    for t in range(T):
        column = rHD[:, t]
        if pref_mask[:, t].any():
            pref_mean[t] = float(np.mean(column[pref_mask[:, t]]))
        if nonpref_mask[:, t].any():
            nonpref_mean[t] = float(np.mean(column[nonpref_mask[:, t]]))

    return pref_mean, nonpref_mean


def split_by_exclude_list_timeaxis1(rHD_NT: np.ndarray, exclude_list: np.ndarray):
    """Split an (N, T) matrix along its time axis (axis 1).

    Returns:
        ``(locomotion, stationary)``: the columns not listed in ``exclude_list``,
        and the listed columns in list order.
    """
    return (
        np.delete(rHD_NT, exclude_list, axis=1),
        rHD_NT[:, exclude_list],
    )


def split_1d_by_exclude_list(arr_1d: np.ndarray, exclude_list: np.ndarray):
    """Split an array along axis 0 into (entries not listed, listed entries in list order)."""
    return (
        np.delete(arr_1d, exclude_list, axis=0),
        arr_1d[exclude_list],
    )


def process_one_run_notebook_exact(mat_path: Path, emb_path: Path, out_dir: Path, azimuth_range_deg=30.0):
    """Compute, save and summarise all per-run statistics for one simulation.

    Loads ``mat_path`` (keys read: rHD, decodedHD, Az, stat_startIndices,
    stat_endIndices) and the cached Isomap embedding ``emb_path``. Every time
    series is split into locomotion and stationary samples using
    ``build_segments_to_exclude_list_notebook``. The following ``.npy`` files are
    written to ``out_dir``, each name ending in the suffix inferred from the .mat
    stem:

    * ``mean_{loco,stat}_errors``: mean circular decoding error (deg).
    * ``mean_{loco,stat}_{pref,non_pref}_HD_mean_activities``: mean activity of
      neurons tuned within / outside ``azimuth_range_deg`` of the azimuth.
    * ``old_{loco,stat}_betti_h{0,1,2}``: persistence diagrams of the split embedding.

    Returns:
        dict with the run stem, its suffix, both mean errors and the sample counts.

    Raises:
        ValueError: if Az, decodedHD or the embedding do not have T samples
            (otherwise the index-based splits below would silently misalign).
    """
    run_stem = mat_path.stem
    suffix = infer_suffix_from_stem(run_stem)

    mat = loadmat(mat_path, squeeze_me=False)  # keep MATLAB-like 2-D shapes

    rHD = np.asarray(mat["rHD"], dtype=float)              # (N, T)
    decodedHD = np.asarray(mat["decodedHD"], dtype=float)  # assumed (T, 1); split along axis 0 below
    Az = np.asarray(mat["Az"], dtype=float)                # assumed (1, T)

    real_Az = (Az + 180) % 360  # (1, T), same heading convention as the notebook

    T = rHD.shape[1]
    if real_Az.shape[1] != T:
        raise ValueError(f"Az length mismatch in {run_stem}: rHD T={T}, Az T={real_Az.shape[1]}")
    if decodedHD.shape[0] != T:
        raise ValueError(f"decodedHD length mismatch in {run_stem}: rHD T={T}, decodedHD T={decodedHD.shape[0]}")

    # Stationary sample indices, exactly as in the notebook.
    segments_to_exclude_list = build_segments_to_exclude_list_notebook(mat)

    # ---- decoding error (notebook convention: decodedHD + 180) ----
    loco_real_Az, stat_real_Az = split_1d_by_exclude_list(real_Az[0], segments_to_exclude_list)
    loco_decodedHD = np.delete(decodedHD, segments_to_exclude_list, axis=0) + 180
    stat_decodedHD = decodedHD[segments_to_exclude_list] + 180

    # Each per-sample error is a length-1 array (one decodedHD row); take its scalar.
    loco_errors_list = [err[0] for err in angular_differences_notebook(loco_real_Az, loco_decodedHD)]
    stat_errors_list = [err[0] for err in angular_differences_notebook(stat_real_Az, stat_decodedHD)]

    mean_loco_errors = float(np.mean(loco_errors_list))
    mean_stat_errors = float(np.mean(stat_errors_list)) if stat_errors_list else np.nan

    # ---- embedding (computed by the preprocessing cell, one row per time sample) ----
    # NOTE(review): the embedding is a plain float array, so allow_pickle=True is not
    # needed; it lets np.load unpickle objects -- consider allow_pickle=False if these
    # files could ever come from an untrusted source.
    tmp_emb = np.asarray(np.load(emb_path, allow_pickle=True))
    if tmp_emb.shape[0] != T:
        raise ValueError(f"Embedding length mismatch in {run_stem}: rHD T={T}, embedding rows={tmp_emb.shape[0]}")
    loco_emb = np.delete(tmp_emb, segments_to_exclude_list, axis=0)
    stat_emb = tmp_emb[segments_to_exclude_list, :]

    # ---- pref / non-pref mean activity, split by the same index list ----
    pref_mean_t, nonpref_mean_t = compute_pref_nonpref_mean_activity(
        rHD=rHD,
        real_Az_deg_0_360=real_Az[0],
        azimuth_range_deg=azimuth_range_deg,
    )
    loco_pref_mean_t, stat_pref_mean_t = split_1d_by_exclude_list(pref_mean_t, segments_to_exclude_list)
    loco_nonpref_mean_t, stat_nonpref_mean_t = split_1d_by_exclude_list(nonpref_mean_t, segments_to_exclude_list)

    # ---- Betti barcodes (stationary first, then locomotion: same RNG draw order as before) ----
    h0_stat, h1_stat, h2_stat = compute_betti(stat_emb)
    h0_loco, h1_loco, h2_loco = compute_betti(loco_emb)

    # ---- save: file stem -> value ----
    outputs = {
        "mean_loco_errors": mean_loco_errors,
        "mean_stat_errors": mean_stat_errors,
        "mean_loco_pref_HD_mean_activities": float(np.nanmean(loco_pref_mean_t)),
        "mean_loco_non_pref_HD_mean_activities": float(np.nanmean(loco_nonpref_mean_t)),
        "mean_stat_pref_HD_mean_activities": float(np.nanmean(stat_pref_mean_t)),
        "mean_stat_non_pref_HD_mean_activities": float(np.nanmean(stat_nonpref_mean_t)),
        "old_stat_betti_h0": h0_stat,
        "old_stat_betti_h1": h1_stat,
        "old_stat_betti_h2": h2_stat,
        "old_loco_betti_h0": h0_loco,
        "old_loco_betti_h1": h1_loco,
        "old_loco_betti_h2": h2_loco,
    }
    out_dir.mkdir(parents=True, exist_ok=True)
    for name, value in outputs.items():
        np.save(out_dir / f"{name}{suffix}.npy", value)

    return {
        "run": run_stem,
        "suffix": suffix,
        "mean_loco_errors": mean_loco_errors,
        "mean_stat_errors": mean_stat_errors,
        "T": int(T),
        "n_stat": int(len(segments_to_exclude_list)),
        "n_loco": int(T - len(segments_to_exclude_list)),
    }


def process_directory_notebook_exact(data_dir: str, azimuth_range_deg=30.0):
    """Process every run in ``data_dir`` that has both a cached embedding and a .mat file.

    Each ``<stem>_tmp_emb.npy`` is paired with ``<stem>.mat`` in the same
    directory and passed to ``process_one_run_notebook_exact``. All outputs are
    written back into ``data_dir``. Embeddings without a matching .mat file are
    reported and skipped.

    Returns:
        list of per-run summary dicts, in sorted embedding-file order.

    Raises:
        FileNotFoundError: if ``data_dir`` contains no ``*_tmp_emb.npy`` files.
    """
    emb_marker = "_tmp_emb.npy"
    data_dir = Path(data_dir)
    out_dir = data_dir

    emb_files = sorted(data_dir.glob(f"*{emb_marker}"))
    if not emb_files:
        raise FileNotFoundError(f"No *_tmp_emb.npy files found in: {data_dir}")

    results = []
    for emb_path in emb_files:
        # Strip only the trailing marker; str.replace() would also delete any
        # earlier occurrence of it inside the run name.
        stem = emb_path.name[: -len(emb_marker)]
        mat_path = data_dir / f"{stem}.mat"
        if not mat_path.exists():
            print(f"[SKIP] No matching .mat for: {emb_path.name}")
            continue

        print(f"[RUN] {stem}")
        results.append(
            process_one_run_notebook_exact(mat_path, emb_path, out_dir, azimuth_range_deg=azimuth_range_deg)
        )

    return results


In [None]:
# Directory containing the .mat runs and their cached *_tmp_emb.npy embeddings;
# outputs are written back into it. Set this to your own copy of the data folder.
data_dir = r"C:\home\...\HD_attractor_model\data"

results = process_directory_notebook_exact(data_dir, azimuth_range_deg=30.0)
print("Done:", len(results))
# Every run may have been skipped (no matching .mat), so guard the preview.
if results:
    print(results[0])


# Plotting figures

In [None]:
def mean_and_sem(x):
    """Return ``(mean, standard error of the mean)`` of ``x`` as Python floats.

    The SEM uses the sample standard deviation (ddof=1) and is 0.0 when fewer
    than two values are given.
    """
    values = np.array(x, dtype=float)
    n = len(values)
    mean = float(np.mean(values))
    if n <= 1:
        return mean, 0.0
    return mean, float(np.std(values, ddof=1) / np.sqrt(n))


def max_h1_length(betti):
    """Length (death - birth) of the longest finite bar in a barcode; 0.0 if there is none.

    Bars with a non-finite endpoint (e.g. an infinite death) are ignored.
    """
    if len(betti) == 0:
        return 0.0
    bars = np.asarray(betti)
    if bars.size:
        bars = bars[np.isfinite(bars).all(axis=1)]
    if bars.size == 0:
        return 0.0
    lifetimes = bars[:, 1] - bars[:, 0]
    return float(np.max(lifetimes))


def load_results_only(data_dir):
    """Load the cached per-simulation summary scalars written by the pipeline.

    Each ``mean_loco_errors<suffix>.npy`` file in ``data_dir`` defines one
    simulation. All other quantities are read from the files with the same
    suffix.

    Returns:
        list of dicts (one per simulation, in sorted file order) with the mean
        decoding errors, pref/non-pref mean activities and the longest finite H1
        bar for locomotion and stationary samples.

    Raises:
        RuntimeError: if no ``mean_loco_errors*.npy`` file exists.
    """
    anchor_prefix = "mean_loco_errors"
    anchors = sorted(glob.glob(os.path.join(data_dir, "mean_loco_errors*.npy")))
    if not anchors:
        raise RuntimeError("No mean_loco_errors*.npy files found – run pipeline first.")

    # result key -> cached file stem, for the scalar quantities
    scalar_sources = {
        "mean_loco_errors": "mean_loco_errors",
        "mean_stat_errors": "mean_stat_errors",
        "mean_loco_pref": "mean_loco_pref_HD_mean_activities",
        "mean_loco_nonpref": "mean_loco_non_pref_HD_mean_activities",
        "mean_stat_pref": "mean_stat_pref_HD_mean_activities",
        "mean_stat_nonpref": "mean_stat_non_pref_HD_mean_activities",
    }

    results = []
    for anchor in anchors:
        # "mean_loco_errors_2.npy" -> "_2"; "mean_loco_errors.npy" -> ""
        suffix = os.path.basename(anchor)[len(anchor_prefix):-4]

        def _npy(stem, _suffix=suffix):
            return np.load(os.path.join(data_dir, f"{stem}{_suffix}.npy"), allow_pickle=True)

        record = {key: float(_npy(stem)) for key, stem in scalar_sources.items()}
        record["h1_len_loco"] = max_h1_length(_npy("old_loco_betti_h1"))
        record["h1_len_stat"] = max_h1_length(_npy("old_stat_betti_h1"))
        results.append(record)

    print(f"Loaded {len(results)} simulations from cached results.")
    return results


def stationary_index_list_from_ranges(starts, ends, T: int) -> np.ndarray:
    """Expand (start, end) segment bounds into sorted, unique 0-based sample indices.

    Index base is auto-detected: if any start or end equals 0, the values are
    taken as 0-based; otherwise they are treated as 1-based and shifted down by
    one. Ends are inclusive. Bounds are clipped to ``[0, T - 1]``, and segments
    whose clipped end precedes their start are dropped.

    NOTE(review): this is not the convention of
    ``build_segments_to_exclude_list_notebook`` (values used as-is, exclusive end),
    which produced the cached decoding errors and Betti diagrams. The two index
    sets can differ by one sample at segment edges.

    Returns:
        1-D int array (empty when no segment survives).
    """
    start_arr = np.asarray(starts).reshape(-1)
    end_arr = np.asarray(ends).reshape(-1)

    looks_zero_based = np.any(start_arr == 0) or np.any(end_arr == 0)
    shift = 0 if looks_zero_based else 1
    start_arr = start_arr.astype(int) - shift
    end_arr = end_arr.astype(int) - shift

    def _clip(v):
        return max(0, min(T - 1, v))

    chunks = []
    for raw_lo, raw_hi in zip(start_arr, end_arr):
        lo, hi = _clip(raw_lo), _clip(raw_hi)
        if hi < lo:
            continue
        chunks.append(np.arange(lo, hi + 1, dtype=int))

    if len(chunks) == 0:
        return np.array([], dtype=int)
    return np.unique(np.concatenate(chunks))


def load_representation(data_dir, rep_base):
    """Load one representative simulation for panels B-D of the figure.

    Reads ``<rep_base>.mat`` (``squeeze_me=True``), the cached embedding
    ``<rep_base>_tmp_emb.npy`` and the six cached persistence diagrams
    ``old_{loco,stat}_betti_h{0,1,2}<suffix>.npy`` from ``data_dir``. The suffix
    is derived with ``infer_suffix_from_stem``, the same rule the pipeline used
    when it wrote those files.

    Returns:
        dict with ``rep_base``, ``suffix``, ``rHD`` (N, T), ``Az`` (flattened),
        ``decodedHD`` (flattened, or None if absent), ``segments_to_exclude_list``
        (0-based stationary indices), ``tmp_emb`` (full embedding, one row per
        time sample) and ``h{0,1,2}_{loco,stat}`` diagrams.

    Raises:
        FileNotFoundError: if the .mat or embedding file is missing.
        ValueError: if the embedding is not 2-D with at least 2 columns, or its
            row count differs from T.
    """
    mat_path = os.path.join(data_dir, rep_base + ".mat")
    if not os.path.exists(mat_path):
        raise FileNotFoundError(f"Missing: {mat_path}")

    m = sio.loadmat(mat_path, squeeze_me=True)

    rHD = np.asarray(m["rHD"], dtype=float)              # (N, T)
    Az = np.asarray(m["Az"], dtype=float).reshape(-1)    # (T,)
    decodedHD = np.asarray(m["decodedHD"], dtype=float).reshape(-1) if "decodedHD" in m else None

    T = rHD.shape[1]
    # NOTE(review): stationary_index_list_from_ranges uses a different index
    # convention from build_segments_to_exclude_list_notebook (which produced the
    # cached Betti files); the sets can differ by one sample at segment edges.
    segments_to_exclude_list = stationary_index_list_from_ranges(
        m["stat_startIndices"], m["stat_endIndices"], T
    )  # 0-based

    emb_path = os.path.join(data_dir, rep_base + "_tmp_emb.npy")
    if not os.path.exists(emb_path):
        raise FileNotFoundError(f"Missing embedding: {emb_path}")
    tmp_emb = np.asarray(np.load(emb_path, allow_pickle=True))
    if tmp_emb.ndim != 2 or tmp_emb.shape[1] < 2:
        raise ValueError(f"Embedding has shape {tmp_emb.shape}, expected (N,>=2)")
    # The plot indexes the embedding with time indices, so its rows must be time samples.
    if tmp_emb.shape[0] != T:
        raise ValueError(f"Embedding has {tmp_emb.shape[0]} rows, but rHD has T={T} samples")

    suffix = infer_suffix_from_stem(rep_base)

    def load_npy(name):
        return np.load(os.path.join(data_dir, f"{name}{suffix}.npy"), allow_pickle=True)

    return {
        "rep_base": rep_base,
        "suffix": suffix,
        "rHD": rHD,
        "Az": Az,
        "decodedHD": decodedHD,
        "segments_to_exclude_list": segments_to_exclude_list,
        "tmp_emb": tmp_emb,  # full embedding (run + stat)
        "h0_loco": load_npy("old_loco_betti_h0"),
        "h1_loco": load_npy("old_loco_betti_h1"),
        "h2_loco": load_npy("old_loco_betti_h2"),
        "h0_stat": load_npy("old_stat_betti_h0"),
        "h1_stat": load_npy("old_stat_betti_h1"),
        "h2_stat": load_npy("old_stat_betti_h2"),
    }


def get_long_bettibars(h0, h1, h2):
    """Drop non-finite bars from the H0/H1/H2 diagrams and compute bar lengths.

    Despite the name, no length threshold is applied: every finite bar is kept.

    Returns:
        ``(bars, lengths)``: two 3-element lists indexed by homology dimension.
        ``bars[k]`` is the filtered (n, 2) diagram and ``lengths[k]`` is its
        death - birth values as floats. A diagram with no finite bars
        contributes ``[]`` to both lists.
    """
    bars_per_dim = []
    lengths_per_dim = []
    for diagram in (h0, h1, h2):
        D = np.asarray(diagram)
        if D.size > 0:
            D = D[np.isfinite(D).all(axis=1)]
        if D.size == 0:
            bars_per_dim.append([])
            lengths_per_dim.append([])
        else:
            bars_per_dim.append(D)
            lengths_per_dim.append((D[:, 1] - D[:, 0]).astype(float))
    return bars_per_dim, lengths_per_dim


def compute_corr_strip_like_notebook(rHD_loco_TN, rHD_stat_TN, seed=0, max_timepoints=None):
    """Compare time-by-time correlation structure between locomotion and stationary samples.

    The same number ``n`` of time points is drawn at random, without replacement,
    from each state. ``n`` is the smaller state length, optionally capped by
    ``max_timepoints``. For each state, ``np.corrcoef`` gives the (n, n) Pearson
    correlation matrix between population vectors (rows = time points), and its
    strict upper triangle is kept. Pairs are sorted by their locomotion
    correlation; the stationary column holds the value at the same (i, j)
    position of its own sample.

    Args:
        rHD_loco_TN, rHD_stat_TN: (time, neuron) activity matrices.
        seed: seed for ``np.random.default_rng``.
        max_timepoints: optional cap on ``n``. Memory grows as n**2; ``None``
            (default) keeps the original uncapped behaviour.

    Returns:
        (n*(n-1)/2, 3) array with columns [loco r, stat r, loco - stat].
        When n < 2 there is no pair of time points, and an empty (0, 3) array
        is returned.
    """
    T_loco = rHD_loco_TN.shape[0]
    T_stat = rHD_stat_TN.shape[0]
    n = min(T_loco, T_stat)
    if max_timepoints is not None:
        n = min(n, int(max_timepoints))
    if n < 2:
        return np.empty((0, 3), dtype=float)

    rng = np.random.default_rng(seed)
    loco_idx = rng.choice(T_loco, size=n, replace=False)
    stat_idx = rng.choice(T_stat, size=n, replace=False)

    upper = np.triu_indices(n, k=1)
    loco_vals = np.corrcoef(rHD_loco_TN[loco_idx])[upper]
    stat_vals = np.corrcoef(rHD_stat_TN[stat_idx])[upper]

    order = np.argsort(loco_vals)
    sorted_loco = loco_vals[order]
    sorted_stat = stat_vals[order]
    return np.column_stack([sorted_loco, sorted_stat, sorted_loco - sorted_stat])


def _despine(ax, linewidth=1.5):
    """Hide the top/right spines and thicken the remaining spines and tick marks."""
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    plt.setp(ax.spines.values(), linewidth=linewidth)
    ax.xaxis.set_tick_params(width=linewidth)
    ax.yaxis.set_tick_params(width=linewidth)


def _max_finite_endpoint(barsets):
    """Largest finite birth/death value across a list of (n, 2) barcodes; 0.0 if none."""
    mx = 0.0
    for bars in barsets:
        B = np.asarray(bars)
        if B.size == 0:
            continue
        B = B[np.isfinite(B).all(axis=1)]
        if B.size == 0:
            continue
        mx = max(mx, float(np.max(B)))
    return mx


def _draw_barcode(ax, bars, color, lw, xmax, hide_tick_labels):
    """Draw one barcode on ``ax`` (one horizontal segment per finite bar, stacked by row)."""
    bars = np.asarray(bars)
    if bars.size != 0:
        bars = bars[np.isfinite(bars).all(axis=1)]
        for i, (birth, death) in enumerate(bars):
            ax.plot([birth, death], [i, i], color=color, lw=lw)
    ax.set_xlim(0, xmax)
    ax.set_yticks([])
    _despine(ax)
    ax.set_xticks([])  # notebook style: no x ticks on any barcode row
    if hide_tick_labels:
        ax.tick_params(labelbottom=False)


def plot_panels_B_to_G(results_all, rep, dt=0.1, save_path=None, title=None, heatmap_xlim=(260, 360)):
    """Plot figure panels B-G.

    Panels B-D show one representative run: the rHD heatmap, the time-by-time
    correlation strip, and the Isomap scatter plus barcodes. Panels E-G show
    across-simulation means +/- SEM: pref/non-pref activity, decoding error and
    H1 length.

    Args:
        results_all: per-simulation dicts as returned by ``load_results_only``.
        rep: representative-run dict as returned by ``load_representation``.
        dt: unused; kept so existing calls keep working.
        save_path: if given, the figure is also saved there (dpi=200).
        title: optional figure suptitle.
        heatmap_xlim: time-sample window shown in panel B. The default is the
            previously hard-coded 260-360; ``None`` shows the whole run.
    """
    stat_color = "#8fd2d9"
    run_color = "#3b9099"

    # ---- Across-simulation mean +/- SEM (panels E-G) ----
    def _agg(key):
        return mean_and_sem([r[key] for r in results_all])

    loco_err_m, loco_err_sem = _agg("mean_loco_errors")
    stat_err_m, stat_err_sem = _agg("mean_stat_errors")
    loco_pref_m, loco_pref_sem = _agg("mean_loco_pref")
    loco_nonpref_m, loco_nonpref_sem = _agg("mean_loco_nonpref")
    stat_pref_m, stat_pref_sem = _agg("mean_stat_pref")
    stat_nonpref_m, stat_nonpref_sem = _agg("mean_stat_nonpref")
    h1_loco_m, h1_loco_sem = _agg("h1_len_loco")
    h1_stat_m, h1_stat_sem = _agg("h1_len_stat")

    # ---- Representative run: split time samples into locomotion / stationary ----
    rHD = rep["rHD"]  # (N, T)
    Az = rep["Az"].reshape(-1)  # (T,)
    real_Az = (Az + 180.0) % 360.0  # same heading convention as the pipeline

    stat_idx = rep["segments_to_exclude_list"].astype(int)
    T = rHD.shape[1]

    # Locomotion = complement of the stationary indices (equivalent to np.delete).
    mask = np.ones(T, dtype=bool)
    mask[stat_idx] = False
    loco_idx = np.where(mask)[0]

    tmp_emb = rep["tmp_emb"]  # time on axis 0
    loco_emb = tmp_emb[loco_idx]
    stat_emb = tmp_emb[stat_idx]
    loco_real_Az = real_Az[loco_idx]
    stat_real_Az = real_Az[stat_idx]

    # The correlation strip works on (time, neuron) matrices.
    loco_rHD_TN = rHD[:, loco_idx].T
    stat_rHD_TN = rHD[:, stat_idx].T

    fig = plt.figure(figsize=(12, 8))
    outer = gridspec.GridSpec(
        2, 3,
        width_ratios=[1.6, 1.0, 1.0],
        height_ratios=[1.0, 1.0],
        wspace=0.35, hspace=0.45
    )

    # ---------------- Panel B: heatmap ----------------
    axB = fig.add_subplot(outer[0, 0])
    imB = axB.imshow(rHD, aspect="auto", origin="lower")
    axB.set_title("B: rHD heatmap (rep sim)")
    axB.set_xlabel("Time (samples)")
    axB.set_ylabel("Neuron ID")
    if len(stat_idx) > 0:
        # Dashed marker at the first stationary sample.
        axB.axvline(int(stat_idx[0]), linestyle="--", linewidth=1)
    cbB = fig.colorbar(imB, ax=axB, fraction=0.046, pad=0.04)
    cbB.set_label("rHD")
    if heatmap_xlim is not None:
        axB.set_xlim(list(heatmap_xlim))

    # ---------------- Panel C: correlation strip ----------------
    axC = fig.add_subplot(outer[0, 1])
    corr_for_plot_list = compute_corr_strip_like_notebook(loco_rHD_TN, stat_rHD_TN, seed=0)
    print("T_loco:", loco_rHD_TN.shape[0], "T_stat:", stat_rHD_TN.shape[0])

    custom_cmap = LinearSegmentedColormap.from_list(
        'custom_colormap', [(0, '#3853A4'), (0.5, 'white'), (1, '#ED1F24')]
    )
    axC.imshow(
        corr_for_plot_list[:, :3],
        cmap=custom_cmap,
        aspect='auto',
        vmin=-1, vmax=1,
        interpolation='none',
        origin='lower'
    )
    # np.corrcoef (used by compute_corr_strip_like_notebook) computes Pearson
    # correlation, not Kendall tau, so the panel is labelled accordingly.
    axC.set_title("C: Pearson r")
    axC.set_xticks([0, 1, 2])
    axC.set_xticklabels(["Run", "Stat", "Diff"])
    axC.set_yticks([])
    _despine(axC)

    # ---------------- Panel D: embeddings + barcodes ----------------
    inner = gridspec.GridSpecFromSubplotSpec(
        6, 6, subplot_spec=outer[0, 2], wspace=0.0, hspace=0.25
    )

    # Top half: locomotion scatter (cols 0:3) and stationary scatter (cols 3:6), colored by heading.
    axD_loco = fig.add_subplot(inner[0:3, 0:3])
    axD_stat = fig.add_subplot(inner[0:3, 3:6])
    for ax, emb, heading in ((axD_loco, loco_emb, loco_real_Az), (axD_stat, stat_emb, stat_real_Az)):
        ax.scatter(emb[:, 0], emb[:, 1], s=1, c=heading, vmin=0, vmax=360, cmap='hsv')
        ax.axis('off')

    # Bottom half: one barcode row per homology dimension (H0, H1, H2).
    col_list = ['red', 'green', 'purple']
    to_plot_loco, _ = get_long_bettibars(rep["h0_loco"], rep["h1_loco"], rep["h2_loco"])
    to_plot_stat, _ = get_long_bettibars(rep["h0_stat"], rep["h1_stat"], rep["h2_stat"])

    # Common x-range: largest finite endpoint over all barcodes of both states.
    xmax = max(_max_finite_endpoint(to_plot_loco), _max_finite_endpoint(to_plot_stat))
    if not np.isfinite(xmax) or xmax <= 0:
        xmax = 1.0

    axL_base = None
    axS_base = None
    for curr_betti in range(3):
        hide_labels = curr_betti != 2
        row = inner[curr_betti + 3: curr_betti + 4, 0:3]
        axL = fig.add_subplot(row, sharex=axL_base)
        if axL_base is None:
            axL_base = axL
        # NOTE(review): locomotion bars use lw=3 but stationary bars lw=1.5, as in
        # the original -- confirm this asymmetry is intended.
        _draw_barcode(axL, to_plot_loco[curr_betti], col_list[curr_betti], 3, xmax, hide_labels)

        axS = fig.add_subplot(inner[curr_betti + 3: curr_betti + 4, 3:6], sharex=axS_base)
        if axS_base is None:
            axS_base = axS
        _draw_barcode(axS, to_plot_stat[curr_betti], col_list[curr_betti], 1.5, xmax, hide_labels)

    # ---------------- Panel E: mean firing rates (across sims) ----------------
    axE = fig.add_subplot(outer[1, 0])
    x = np.arange(2)
    width = 0.35
    axE.bar(x - width/2, [stat_pref_m, stat_nonpref_m], width,
            yerr=[stat_pref_sem, stat_nonpref_sem], capsize=4, label="Stat",
            color=stat_color)
    axE.bar(x + width/2, [loco_pref_m, loco_nonpref_m], width,
            yerr=[loco_pref_sem, loco_nonpref_sem], capsize=4, label="Run",
            color=run_color)
    axE.set_xticks(x)
    axE.set_xticklabels(["Az pref.", "Az non-pr."])
    axE.set_ylabel("Mean firing (a.u.)")
    axE.set_title("E: Mean firing rate (across sims)")
    axE.legend()

    # ---------------- Panel F: decoding error (across sims) ----------------
    axF = fig.add_subplot(outer[1, 1])
    axF.bar([0, 1], [stat_err_m, loco_err_m],
            yerr=[stat_err_sem, loco_err_sem], capsize=4,
            color=[stat_color, run_color])
    axF.set_xticks([0, 1])
    axF.set_xticklabels(["Stat", "Run"])
    axF.set_ylabel("Decoding err. (deg)")
    axF.set_title("F: Decoding error (across sims)")

    # ---------------- Panel G: H1 length (across sims) ----------------
    axG = fig.add_subplot(outer[1, 2])
    axG.bar([0, 1], [h1_stat_m, h1_loco_m],
            yerr=[h1_stat_sem, h1_loco_sem], capsize=4,
            color=[stat_color, run_color])
    axG.set_xticks([0, 1])
    axG.set_xticklabels(["Stat", "Run"])
    axG.set_ylabel("H1 length")
    axG.set_title("G: H1 length (across sims)")

    if title is not None:
        fig.suptitle(title, y=1.02)

    plt.tight_layout()

    if save_path is not None:
        plt.savefig(save_path, dpi=200, bbox_inches="tight")
    plt.show()


In [None]:
# Location of the cached pipeline outputs and the representative run's files.
data_dir = r"C:\home\...\HD_attractor_model\data"

# Across-simulation summary scalars (panels E-G).
results = load_results_only(data_dir)

# One representative simulation (panels B-D).
rep_base = "Inh_gain6Exc_gain9bimodal_vis_inp0gaus_ratio1visual_std0.262_rand_shuffle_1"
rep = load_representation(data_dir, rep_base)

plot_panels_B_to_G(results, rep, dt=0.1, save_path=None)
