# %%
import os

# Force CPU backend on Apple Silicon to avoid Metal issues. JAX reads this
# variable when it initialises its backends, so it has to be set before JAX is
# first imported -- keep it ahead of the third-party imports below (JAX may be
# pulled in indirectly, e.g. possibly while unpickling saved data).
os.environ["JAX_PLATFORMS"] = "cpu"

import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

# Disable LaTeX rendering in matplotlib and use a plain sans-serif font.
plt.rcParams["text.usetex"] = False
plt.rcParams["font.family"] = "sans-serif"

# %%
# Configuration
# BASE_PATH may be overridden with the VR_FORAGING_DATA environment variable so
# the notebook can run on machines other than the original author's; when the
# variable is unset, the original hard-coded location is used.
BASE_PATH = Path(
    os.environ.get(
        "VR_FORAGING_DATA",
        "/Users/laura.driscoll/Documents/data/VR foraging/vr_foraging_data",
    )
)
# Sub-directory names under BASE_PATH / "posterior_samples" selecting which
# trained model and which odor condition to plot.
MODEL_NAME = "snle_2M_lr0.0005_ts2000_h128_l8_b256_37feat"
ODOR_TYPE = "Methyl_Butyrate"


# %%
def plot_session_posteriors(posterior_data_path, save_fig=False, *, n_params=4):
    """Load posterior samples from a pickle file and plot per-parameter histograms.

    The pickle must hold a dict with the keys ``posterior_samples``,
    ``n_windows``, ``param_names``, ``prior_low``, ``prior_high``,
    ``session_name`` and ``odor_type``. Each element of ``posterior_samples``
    is indexed as ``samples[:, i]`` (presumably a 2-D array of draws x
    parameters -- confirm against the code that writes the file). Each element
    is drawn in its own colour from the ``rainbow`` colormap.

    For every plotted parameter, the function draws a 30-bin histogram. It then
    adds a dotted vertical line at the centre of the tallest bin, a
    histogram-based estimate of the posterior mode, labelled "MAP estimate" in
    the legend.

    Args:
        posterior_data_path: ``Path`` or ``str`` of the pickled posterior data.
        save_fig: If True, also write ``posterior_plot.png`` (150 dpi) into the
            directory containing ``posterior_data_path``.
        n_params: Number of leading parameters (columns of each sample array)
            to plot, one panel each. Defaults to 4, the original layout.

    Returns:
        The created ``matplotlib.figure.Figure``.

    Raises:
        KeyError: If one of the keys listed above is missing from the pickle.
        IndexError: If ``posterior_samples`` has more entries than
            ``n_windows`` colours, or if ``param_names`` / ``prior_low`` /
            ``prior_high`` / the sample arrays have fewer than ``n_params``
            entries or columns.
    """
    # Accept plain strings as well as Path objects: ``.parent`` is used below
    # when saving, which a str does not have.
    posterior_data_path = Path(posterior_data_path)

    # NOTE: pickle.load can execute arbitrary code; only open files produced by
    # this project's own pipeline, never untrusted input.
    with open(posterior_data_path, "rb") as f:
        data = pickle.load(f)

    posterior_samples = data["posterior_samples"]
    n_windows = data["n_windows"]
    param_names = data["param_names"]
    prior_low = data["prior_low"]
    prior_high = data["prior_high"]

    # One colour per window, spread evenly across the colormap. Assumes
    # len(posterior_samples) <= n_windows (presumably equal) -- TODO confirm.
    cmap = plt.colormaps["rainbow"]
    colors_rgba = cmap(np.linspace(0, 1, n_windows))

    # One panel per parameter, 2.5 in wide each (10 in for the default 4).
    # squeeze=False keeps ``axes`` a 2-D array even when n_params == 1.
    fig, axes = plt.subplots(
        1, n_params, figsize=(2.5 * n_params, 2), squeeze=False
    )
    axes = axes.ravel()

    for session_i, samples in enumerate(posterior_samples):
        color = colors_rgba[session_i]
        for i, ax in enumerate(axes):
            counts, bins, _ = ax.hist(
                samples[:, i], bins=30, color=color, edgecolor=None, alpha=0.3
            )

            # Posterior mode estimate: centre of the bin with the highest count.
            mode_index = np.argmax(counts)
            posterior_mode = (bins[mode_index] + bins[mode_index + 1]) / 2
            ax.axvline(posterior_mode, color=color, linestyle=":", alpha=0.5)

    # Labels and limits depend only on the parameter, so set them once per
    # panel instead of once per session. Fixing xlim to the prior bounds makes
    # the panels comparable across sessions.
    for i, ax in enumerate(axes):
        ax.set_xlabel(param_names[i])
        ax.set_ylabel("Frequency")
        ax.set_xlim(prior_low[i], prior_high[i])

    axes[-1].legend(["MAP estimate"], loc="center left", bbox_to_anchor=(1, 0.5))

    # Add title with session info
    fig.suptitle(f"{data['session_name']} - {data['odor_type']}", fontsize=10)

    fig.tight_layout()

    if save_fig:
        output_path = posterior_data_path.parent / "posterior_plot.png"
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
        print(f"Saved figure to {output_path}")

    return fig


# %%
# Find all sessions with saved posteriors
posterior_dir = BASE_PATH / "posterior_samples" / MODEL_NAME / ODOR_TYPE

# Default to "no sessions" so the plotting cell below is a harmless no-op
# (instead of a NameError) when the posterior directory does not exist.
session_dirs = []
if not posterior_dir.exists():
    print(f"No posteriors found at {posterior_dir}")
else:
    session_dirs = sorted(d for d in posterior_dir.iterdir() if d.is_dir())
    print(f"Found {len(session_dirs)} sessions")

# %%
# Plot all sessions
for session_dir in session_dirs:
    posterior_file = session_dir / "posterior_data.pkl"

    if not posterior_file.exists():
        print(f"No posterior_data.pkl found in {session_dir.name}")
        continue

    print(f"\nPlotting {session_dir.name}")
    fig = plot_session_posteriors(posterior_file, save_fig=True)
    plt.show()
    # Release the figure once shown so that looping over many sessions does
    # not keep every figure open in memory (a no-op if show() already closed it).
    plt.close(fig)