# Quantum kicked rotor (Floquet) via split-operator + FFT: PLOTS

## Tests

In [4]:
#!/usr/bin/env python3
from __future__ import annotations

import os
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt

In [None]:
# OUTDIR: directory the loaders below read simulation outputs from.
# FIGDIR: directory the plotting functions below write figures to.
OUTDIR = "../src/quantum/out_quantum"
# A single-argument os.path.join() returns its argument unchanged, so the
# literal is used directly.
FIGDIR = "../figs/quantum"
os.makedirs(FIGDIR, exist_ok=True)
# Apply the plotting style named by this style file to all figures below.
plt.style.use("science.mplstyle")

: 

In [35]:
def load_m2_csv(outdir: str) -> pd.DataFrame:
    """
    Read ``kr_m2_vs_n.csv`` from ``outdir`` and return it cleaned and ordered.

    The columns ``n``, ``m2``, ``p2`` and ``norm`` are coerced to numeric
    (unparseable entries become NaN). Rows missing ``n`` or ``p2`` are
    discarded and the remaining rows are sorted by ``n``.
    """
    table = pd.read_csv(os.path.join(outdir, "kr_m2_vs_n.csv"))

    numeric_columns = ("n", "m2", "p2", "norm")
    for name in numeric_columns:
        table[name] = pd.to_numeric(table[name], errors="coerce")

    usable = table.dropna(subset=["n", "p2"])
    return usable.sort_values("n")


def load_dist_csv(outdir: str, fname: str) -> pd.DataFrame:
    """Read the CSV file ``fname`` located in ``outdir``, without post-processing."""
    return pd.read_csv(os.path.join(outdir, fname))


def plot_p2_vs_n(outdir: str, figdir: str, title: str | None = None) -> str:
    """
    Plot <p^2> against the kick number and save it as ``kr_p2_vs_n.pdf``.

    Parameters
    ----------
    outdir : str
        Directory containing ``kr_m2_vs_n.csv`` (read through load_m2_csv).
    figdir : str
        Directory the PDF is written to.
    title : str or None
        Optional axes title; omitted when None or empty.

    Returns
    -------
    outpath : str
        Path of the saved figure.
    """
    df = load_m2_csv(outdir)

    fig, ax = plt.subplots(figsize=(4.3, 3.2), dpi=600)

    ax.plot(df["n"], df["p2"], lw=1.6, color="#3b528b")
    # Raw string: "\#" is not a valid Python escape sequence; the rendered
    # label text is unchanged.
    ax.set_xlabel(r" \# kicks")
    ax.set_ylabel(r"$\langle p^2\rangle$")
    ax.grid(True, alpha=0.25)

    # The deviation of the norm from 1 should be very small, so it is shown
    # in scientific notation (a fixed-point format would just print 0.00).
    norm_dev = float(np.max(np.abs(df["norm"].to_numpy() - 1.0)))
    ax.text(
        0.02, 0.98,
        f"max |norm-1| = {norm_dev:.1e}",
        transform=ax.transAxes,
        va="top", ha="left",
        fontsize=8,
        color="0.25",
    )

    if title:
        ax.set_title(title, fontsize=10)

    fig.tight_layout()
    outpath = os.path.join(figdir, "kr_p2_vs_n.pdf")
    fig.savefig(outpath, bbox_inches="tight")
    plt.close(fig)
    return outpath


def plot_pm_final(outdir: str, figdir: str, title: str | None = None) -> str:
    """
    Plot the final momentum distribution on a log y-axis.

    Reads ``kr_final_pm.csv`` from ``outdir`` (columns ``m`` and ``prob``),
    keeps only rows where both values are finite, and writes
    ``kr_pm_final.pdf`` into ``figdir``. Returns the path of the saved file.
    """
    table = load_dist_csv(outdir, "kr_final_pm.csv")
    momentum = pd.to_numeric(table["m"], errors="coerce").to_numpy()
    weight = pd.to_numeric(table["prob"], errors="coerce").to_numpy()

    # Keep only rows where both columns parsed to finite numbers.
    keep = np.isfinite(momentum) & np.isfinite(weight)
    momentum = momentum[keep]
    weight = weight[keep]

    figure, axis = plt.subplots(figsize=(4.3, 3.2), dpi=600)
    axis.plot(momentum, weight, lw=1.2, color="#3b528b")

    axis.set_xlabel("Discrete angular momentum m")
    axis.set_ylabel(r"$|\psi_m|^2$")
    axis.set_yscale("log")  # probabilities are shown on a log scale
    axis.grid(True, alpha=0.25, which="both")

    if title:
        axis.set_title(title, fontsize=10)

    figure.tight_layout()
    destination = os.path.join(figdir, "kr_pm_final.pdf")
    figure.savefig(destination, bbox_inches="tight")
    plt.close(figure)
    return destination


def plot_ptheta_final(outdir: str, figdir: str, title: str | None = None) -> str:
    """
    Plot the final angular distribution over [0, 2*pi].

    Reads ``kr_final_ptheta.csv`` from ``outdir`` (columns ``theta`` and
    ``prob``), keeps only rows where both values are finite, and writes
    ``kr_ptheta_final.pdf`` into ``figdir``. Returns the path of the saved
    file.
    """
    table = load_dist_csv(outdir, "kr_final_ptheta.csv")
    angle = pd.to_numeric(table["theta"], errors="coerce").to_numpy()
    density = pd.to_numeric(table["prob"], errors="coerce").to_numpy()

    # Keep only rows where both columns parsed to finite numbers.
    keep = np.isfinite(angle) & np.isfinite(density)
    angle = angle[keep]
    density = density[keep]

    figure, axis = plt.subplots(figsize=(4.3, 3.2), dpi=600)
    axis.plot(angle, density, lw=1.2, color="#3b528b")

    axis.set_xlabel(r"$\theta$")
    axis.set_ylabel(r"$|\psi(\theta)|^2$")
    axis.set_xlim(0, 2 * np.pi)
    axis.grid(True, alpha=0.25)

    if title:
        axis.set_title(title, fontsize=10)

    figure.tight_layout()
    destination = os.path.join(figdir, "kr_ptheta_final.pdf")
    figure.savefig(destination, bbox_inches="tight")
    plt.close(figure)
    return destination


def compare_runs_p2(runs: list[dict], figdir_global: str, fname: str = "kr_compare_p2_vs_n.pdf") -> str:
    """
    Overlay <p^2>(n) for several runs in one figure.

    Each entry of ``runs`` must provide an ``"outdir"`` key (directory read
    through load_m2_csv) and a ``"label"`` key (legend text). The figure is
    saved as ``fname`` inside ``figdir_global``; the saved path is returned.
    """
    figure, axis = plt.subplots(figsize=(4.6, 3.2), dpi=600)

    for run in runs:
        series = load_m2_csv(run["outdir"])
        axis.plot(series["n"], series["p2"], lw=1.6, label=run["label"])

    axis.set_xlabel("Nombre de kicks n")
    axis.set_ylabel(r"$\langle p^2\rangle$")
    axis.grid(True, alpha=0.25)
    axis.legend(frameon=False, fontsize=8)

    figure.tight_layout()
    destination = os.path.join(figdir_global, fname)
    figure.savefig(destination, bbox_inches="tight")
    plt.close(figure)
    return destination

In [36]:
if __name__ == "__main__":
    # Title of the <p^2>(n) figure.
    title = "Quantum kicked rotor (FFT) — exemple"

    print("Saving plots to:", FIGDIR)

    # Each plotting helper returns the path it wrote; print one path per line.
    plot_jobs = (
        (plot_p2_vs_n, title),
        (plot_pm_final, "Final m distribution"),
        (plot_ptheta_final, r"Final $\theta$ distribution"),
    )
    for make_plot, plot_title in plot_jobs:
        print(" -", make_plot(OUTDIR, FIGDIR, title=plot_title))



Saving plots to: ../figs/quantum
 - ../figs/quantum/kr_p2_vs_n.pdf
 - ../figs/quantum/kr_pm_final.pdf
 - ../figs/quantum/kr_ptheta_final.pdf


## Poster-style multiplot

In [24]:
def moving_average(y: np.ndarray, window: int) -> np.ndarray:
    """
    Compute a centered moving average (with edge handling).

    The signal is extended at both ends by repeating its first/last sample,
    then averaged with a uniform kernel, so the output always has exactly
    the same length as the input.

    Parameters
    ----------
    y : ndarray
        Input 1-D signal.
    window : int
        Window size in samples. If window <= 1, a copy of y is returned.
        For an odd window the average is centered on each sample. For an
        even window a perfectly centered window does not exist: it covers
        ``window // 2`` samples before and ``window // 2 - 1`` samples after
        the current one.

    Returns
    -------
    y_smooth : ndarray
        Smoothed signal with the same length as y.
    """
    window = int(window)
    # Nothing to smooth; an empty signal also cannot be edge-padded.
    if window <= 1 or y.size == 0:
        return y.copy()

    w = np.ones(window, dtype=float) / float(window)

    # A "valid" convolution shortens the signal by window - 1 samples, so
    # exactly window - 1 samples of padding are needed in total. Padding
    # window // 2 on both sides is only correct for odd windows; for even
    # windows it would return one sample too many.
    pad_left = window // 2
    pad_right = window - 1 - pad_left
    ypad = np.pad(y, (pad_left, pad_right), mode="edge")
    return np.convolve(ypad, w, mode="valid")


def plateau_stats(n: np.ndarray, y: np.ndarray, n_start: int) -> tuple[float, float]:
    """
    Return the mean and standard deviation of ``y`` on the late-time plateau.

    Parameters
    ----------
    n : ndarray
        Kick indices, aligned element-wise with ``y``.
    y : ndarray
        Observable values (e.g., <p^2>).
    n_start : int
        Samples with ``n >= n_start`` form the plateau region.

    Returns
    -------
    mean : float
        Mean of ``y`` over the plateau.
    std : float
        Standard deviation of ``y`` over the plateau (``np.std`` defaults).

    Raises
    ------
    ValueError
        If no sample satisfies ``n >= n_start``.
    """
    in_plateau = n >= n_start
    if np.count_nonzero(in_plateau) == 0:
        raise ValueError("plateau_stats: n_start is beyond the available range.")

    plateau = y[in_plateau]
    plateau_mean = float(np.mean(plateau))
    plateau_std = float(np.std(plateau))
    return plateau_mean, plateau_std


def fit_exponential_tail(m: np.ndarray, prob: np.ndarray, m_min: int, m_max: int, prob_floor: float = 1e-300):
    """
    Estimate a localization length from the exponential tail of a distribution.

    Model:
        prob(|m|) ~ exp(-|m|/xi)
    so that, on a semi-log scale,
        log(prob) = a + b*|m|, with b = -1/xi.
    A straight line is fitted by least squares to log(prob) against |m|.

    Parameters
    ----------
    m : ndarray (int)
        Momentum indices (may be negative; only |m| is used).
    prob : ndarray (float)
        Probability distribution |psi_m|^2.
    m_min, m_max : int
        Inclusive fit window in |m|: points with m_min <= |m| <= m_max.
        Pick m_min beyond the central peak and m_max before the numerical
        floor is reached.
    prob_floor : float
        Probabilities are raised to at least this value before the log,
        which avoids log(0).

    Returns
    -------
    result : dict
        Keys: slope, intercept, xi, r2, npoints. ``xi`` is ``inf`` when the
        fitted slope is not negative; ``r2`` is NaN when log(prob) is
        constant over the window.

    Raises
    ------
    ValueError
        If fewer than 10 points fall inside the fit window.
    """
    distance = np.abs(m.astype(int))
    clipped = np.maximum(prob.astype(float), prob_floor)

    inside_window = (distance >= m_min) & (distance <= m_max)
    usable = inside_window & np.isfinite(clipped) & (clipped > 0)

    xs = distance[usable].astype(float)
    if xs.size < 10:
        raise ValueError("fit_exponential_tail: not enough points in the selected fit window.")
    log_p = np.log(clipped[usable])

    # Degree-1 polyfit returns the coefficients highest power first.
    slope, intercept = np.polyfit(xs, log_p, 1)

    # Coefficient of determination of the linear fit.
    predicted = intercept + slope * xs
    residual_ss = float(np.sum((log_p - predicted) ** 2))
    total_ss = float(np.sum((log_p - np.mean(log_p)) ** 2))
    if total_ss > 0:
        r_squared = 1.0 - residual_ss / total_ss
    else:
        r_squared = np.nan

    if slope < 0:
        localization_length = float(-1.0 / slope)
    else:
        localization_length = np.inf

    return {
        "slope": float(slope),
        "intercept": float(intercept),
        "xi": localization_length,
        "r2": r_squared,
        "npoints": int(xs.size),
    }


def load_outputs(outdir: str):
    """
    Load the three CSV outputs found in ``outdir``.

    The original docstring names ``part2.py`` as the producer of these files.

    Returns
    -------
    df_m2 : DataFrame read from ``kr_m2_vs_n.csv``; the columns
        [n, m2, p2, norm] are coerced to numeric, rows missing n or p2 are
        dropped, and rows are sorted by n.
    df_pm : DataFrame read as-is from ``kr_final_pm.csv`` (columns [m, prob]).
    df_pt : DataFrame read as-is from ``kr_final_ptheta.csv``
        (columns [theta, prob]).
    """
    def _read(filename: str) -> pd.DataFrame:
        return pd.read_csv(os.path.join(outdir, filename))

    moments = _read("kr_m2_vs_n.csv")
    for column in ("n", "m2", "p2", "norm"):
        moments[column] = pd.to_numeric(moments[column], errors="coerce")
    moments = moments.dropna(subset=["n", "p2"]).sort_values("n")

    momentum_dist = _read("kr_final_pm.csv")
    angle_dist = _read("kr_final_ptheta.csv")
    return moments, momentum_dist, angle_dist


def poster_triptych(
    outdir: str,
    figdir: str,
    K: float | None = None,
    hbar_eff: float | None = None,
    ma_window: int = 101,
    plateau_start: int | None = None,
    tail_m_min: int = 50,
    tail_m_max: int = 400,
):
    """
    Create a single 3-panel poster-ready figure:
      (A) <p^2>(n) + moving average + plateau band,
      (B) log prob vs |m| with exponential tail fit (xi),
      (C) |psi(theta)|^2 vs theta.

    The figure is saved as ``kr_poster_triptych.png`` in ``figdir``. A small
    text file ``kr_localization_fit.txt`` with the fit window, slope, xi, R^2
    and number of fitted points is written next to it. Both paths are
    printed.

    Parameters
    ----------
    outdir : str
        Directory holding the simulation CSV outputs (read via load_outputs).
    figdir : str
        Directory the figure and the fit summary are written to.
    K, hbar_eff : float or None
        Currently unused (the global title that used them is disabled);
        accepted so that existing calls keep working.
    ma_window : int
        Window of the moving average drawn in panel (A).
    plateau_start : int or None
        First kick index of the plateau region. If None, it defaults to
        ``int(0.6 * n.max())``.
    tail_m_min, tail_m_max : int
        Fit window in |m| for the exponential tail of panel (B).

    Raises
    ------
    ValueError
        Propagated from plateau_stats (empty plateau) or from
        fit_exponential_tail (too few points in the fit window).
    """
    df_m2, df_pm, df_pt = load_outputs(outdir)

    n = df_m2["n"].to_numpy(dtype=int)
    p2 = df_m2["p2"].to_numpy(dtype=float)
    norm = df_m2["norm"].to_numpy(dtype=float)

    p2_smooth = moving_average(p2, window=ma_window)

    if plateau_start is None:
        plateau_start = int(0.6 * n.max())  # default: last 40%
    p2_mean, p2_std = plateau_stats(n, p2, n_start=plateau_start)

    # momentum distribution (finite rows only, renormalized to unit sum)
    m = pd.to_numeric(df_pm["m"], errors="coerce").to_numpy()
    pm = pd.to_numeric(df_pm["prob"], errors="coerce").to_numpy()
    mask = np.isfinite(m) & np.isfinite(pm)
    m = m[mask].astype(int)
    pm = pm[mask].astype(float)
    pm = pm / pm.sum()

    fit = fit_exponential_tail(m, pm, m_min=tail_m_min, m_max=tail_m_max)

    # theta distribution (finite rows only, renormalized to unit sum)
    theta = pd.to_numeric(df_pt["theta"], errors="coerce").to_numpy()
    ptheta = pd.to_numeric(df_pt["prob"], errors="coerce").to_numpy()
    mask = np.isfinite(theta) & np.isfinite(ptheta)
    theta = theta[mask].astype(float)
    ptheta = ptheta[mask].astype(float)
    ptheta = ptheta / ptheta.sum()

    outfig = os.path.join(figdir, "kr_poster_triptych.png")

    # Robust plotting: avoid external LaTeX surprises. The override is scoped
    # to this figure (creation through saving) so the text.usetex setting of
    # the caller's session is restored afterwards instead of being changed
    # for every later plot.
    with mpl.rc_context({"text.usetex": False}):
        fig, axes = plt.subplots(1, 3, figsize=(12.0, 3.4), dpi=600)
        ax0, ax1, ax2 = axes

        # (A) <p^2> vs n
        ax0.plot(n, p2, lw=1.0, alpha=0.35, color="#3b528b", label=r"$\langle p^2\rangle$")
        ax0.plot(n, p2_smooth, lw=1.6, color="#3b528b", label=f"MA({ma_window})")
        ax0.axhline(p2_mean, color="#5ec962", lw=1.2, ls="--", label="Plateau mean")
        ax0.fill_between(n, p2_mean - p2_std, p2_mean + p2_std, color="#5ec962", alpha=0.12, linewidth=0)

        ax0.set_xlabel("Number of kicks  n")
        ax0.set_ylabel(r"$\langle p^2\rangle$")
        ax0.grid(True, alpha=0.25)
        ax0.legend(frameon=False, fontsize=8, loc="best")

        norm_dev = float(np.max(np.abs(norm - 1.0)))
        ax0.text(
            0.02, 0.98,
            f"max|norm-1|={norm_dev:.1e}",
            transform=ax0.transAxes,
            va="top", ha="left",
            fontsize=8, color="0.25",
        )

        # (B) log prob vs |m| + fit
        abs_m = np.abs(m)
        order = np.argsort(abs_m)
        abs_m_sorted = abs_m[order]
        pm_sorted = pm[order]

        # The tiny offset keeps log() finite where the probability is zero.
        ax1.plot(abs_m_sorted, np.log(pm_sorted + 1e-300), lw=1.0, color="#3b528b")
        # fit line over [tail_m_min, tail_m_max]
        xfit = np.linspace(tail_m_min, tail_m_max, 200)
        yfit = fit["intercept"] + fit["slope"] * xfit
        ax1.plot(xfit, yfit, lw=2.0, color="#5ec962")

        ax1.set_xlabel(r"$|m|$")
        ax1.set_ylabel(r"$\log |\psi_m|^2$")
        ax1.grid(True, alpha=0.25)
        ax1.set_title(rf"$\xi \approx {fit['xi']:.1f}$,  $R^2={fit['r2']:.3f}$", fontsize=9)

        # (C) theta distribution
        ax2.plot(theta, ptheta, lw=1.2, color="#3b528b")
        ax2.set_xlim(0, 2 * np.pi)
        ax2.set_xlabel(r"$\theta$")
        ax2.set_ylabel(r"$|\psi(\theta)|^2$")
        ax2.grid(True, alpha=0.25)

        fig.tight_layout()
        fig.savefig(outfig, bbox_inches="tight")
        plt.close(fig)

    # write fit summary (fit["xi"] already holds -1/slope, or inf when the
    # slope is not negative)
    outf = os.path.join(figdir, "kr_localization_fit.txt")
    with open(outf, "w", encoding="utf-8") as f:
        f.write(f"Fit window: |m| in [{tail_m_min}, {tail_m_max}]\n")
        f.write(f"slope b = {fit['slope']:.6e}\n")
        f.write(f"xi = {fit['xi']:.6f}\n")
        f.write(f"R^2 = {fit['r2']:.6f}\n")
        f.write(f"npoints = {fit['npoints']}\n")

    print("Saved:", outfig)
    print("Saved:", outf)

In [None]:
if __name__ == "__main__":
    # Build the 3-panel poster figure from the outputs in OUTDIR.
    # NOTE: K and hbar_eff are accepted by poster_triptych but only appear in
    # its commented-out title code, so they do not affect the figure.
    poster_triptych(
        outdir=OUTDIR,
        figdir=FIGDIR,
        K=5.5,
        hbar_eff=1.0,
        ma_window=101,
        plateau_start=2000,   # or None (auto: int(0.6 * n.max()))
        tail_m_min=50,
        tail_m_max=400,
    )

Saved: ../figs/quantum/kr_poster_triptych.png
Saved: ../figs/quantum/kr_localization_fit.txt


: 

## 2nd version

In [25]:
def load_pm_snapshots_npz(outdir: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Load the momentum-distribution snapshots stored in ``kr_pm_snapshots.npz``.

    Parameters
    ----------
    outdir : str
        Directory expected to contain ``kr_pm_snapshots.npz``.

    Returns
    -------
    n_snap : ndarray
        Array stored under ``"n_snap"``, shape (T,).
    m : ndarray
        Array stored under ``"m"``, shape (M,).
    pm : ndarray
        Array stored under ``"pm"``, shape (T, M).

    Raises
    ------
    FileNotFoundError
        If the archive does not exist in ``outdir``.
    """
    path = os.path.join(outdir, "kr_pm_snapshots.npz")
    if not os.path.exists(path):
        raise FileNotFoundError(
            f"Missing {path}. Re-run the simulation with save_snapshots=True "
            "so kr_pm_snapshots.npz is created."
        )
    # np.load returns a lazy NpzFile that keeps the file open. Read the three
    # arrays inside a context manager so the handle is closed on return.
    with np.load(path) as archive:
        n_snap = archive["n_snap"]   # shape (T,)
        m = archive["m"]             # shape (M,)
        pm = archive["pm"]           # shape (T, M)
    return n_snap, m, pm


### with Viridis heatmap

In [39]:
def poster_triptych_heatmap(
    outdir: str,
    figdir: str,
    K: float | None = None,
    hbar_eff: float | None = None,
    ma_window: int = 101,
    plateau_start: int | None = None,
    tail_m_min: int = 50,
    tail_m_max: int = 400,
    mmax_vis: int = 300,
    vmin: float = -16,
    vmax: float = 0,
):
    """
    3-panel figure:
      (A) <p^2>(n) + moving average + plateau band,
      (B) log prob vs |m| with exponential tail fit (xi),
      (C) Heatmap of log10 P(m,n) with x = kicks and y = momentum (viridis).

    The figure is saved as ``kr_poster_triptych.pdf`` in ``figdir`` and the
    fit summary as ``kr_localization_fit.txt``; both paths are printed.

    Parameters
    ----------
    outdir : str
        Directory holding the simulation outputs (CSV files read via
        load_outputs and snapshots read via load_pm_snapshots_npz).
    figdir : str
        Directory the figure and the fit summary are written to.
    K, hbar_eff : float or None
        Currently unused; accepted so that existing calls keep working.
    ma_window : int
        Window of the moving average drawn in panel (A).
    plateau_start : int or None
        First kick index of the plateau region. If None, it defaults to
        ``int(0.6 * n.max())``.
    tail_m_min, tail_m_max : int
        Fit window in |m| for the exponential tail of panel (B).
    mmax_vis : int
        Panel (C) shows only momentum indices with |m| <= mmax_vis.
    vmin, vmax : float
        Color limits of the heatmap, in units of log10 probability.

    Raises
    ------
    FileNotFoundError
        Propagated from load_pm_snapshots_npz when the archive is missing.
    ValueError
        If no snapshot momentum index lies within [-mmax_vis, mmax_vis], or
        propagated from plateau_stats / fit_exponential_tail.
    """
    # The theta distribution returned by load_outputs is not used here.
    df_m2, df_pm, _ = load_outputs(outdir)

    n = df_m2["n"].to_numpy(dtype=int)
    p2 = df_m2["p2"].to_numpy(dtype=float)
    norm = df_m2["norm"].to_numpy(dtype=float)

    p2_smooth = moving_average(p2, window=ma_window)

    if plateau_start is None:
        plateau_start = int(0.6 * n.max())
    p2_mean, p2_std = plateau_stats(n, p2, n_start=plateau_start)

    # momentum distribution (final): finite rows only, renormalized
    m = pd.to_numeric(df_pm["m"], errors="coerce").to_numpy()
    pm = pd.to_numeric(df_pm["prob"], errors="coerce").to_numpy()
    mask = np.isfinite(m) & np.isfinite(pm)
    m = m[mask].astype(int)
    pm = pm[mask].astype(float)
    pm = pm / pm.sum()

    fit = fit_exponential_tail(m, pm, m_min=tail_m_min, m_max=tail_m_max)

    # snapshots for panel C
    n_snap, m_snap, pm_snap = load_pm_snapshots_npz(outdir)

    # sort both axes so the mesh edges below are monotonic
    idxn = np.argsort(n_snap)
    n_snap = n_snap[idxn]
    pm_snap = pm_snap[idxn, :]

    idxm = np.argsort(m_snap)
    m_snap = m_snap[idxm]
    pm_snap = pm_snap[:, idxm]

    # normalize each snapshot (all-zero rows are left untouched)
    row_sums = pm_snap.sum(axis=1, keepdims=True)
    row_sums = np.where(row_sums > 0, row_sums, 1.0)
    pm_snap = pm_snap / row_sums

    # Heatmap data: crop to |m| <= mmax_vis. Checked before any figure is
    # created so a bad crop cannot leave an open figure behind.
    eps = 1e-300  # keeps log10() finite where the probability is zero
    maskm = (m_snap >= -mmax_vis) & (m_snap <= mmax_vis)
    m_crop = m_snap[maskm].astype(float)
    if m_crop.size == 0:
        raise ValueError(
            "poster_triptych_heatmap: no momentum index lies within "
            f"[-{mmax_vis}, {mmax_vis}]."
        )
    z = np.log10(pm_snap[:, maskm].astype(float) + eps)  # (T, Mcrop)

    # x cell edges: midpoints between consecutive snapshots, extended by the
    # same half-spacing at both ends
    x = n_snap.astype(float)
    if x.size == 1:
        x_edges = np.array([x[0] - 0.5, x[0] + 0.5])
    else:
        mid = 0.5 * (x[1:] + x[:-1])
        x_edges = np.concatenate(
            [[x[0] - (mid[0] - x[0])], mid, [x[-1] + (x[-1] - mid[-1])]]
        )

    # m cell edges: each cell is centered on its momentum index, with the
    # spacing taken from the first two indices
    dm = float(m_crop[1] - m_crop[0]) if m_crop.size > 1 else 1.0
    m_edges = np.concatenate([[m_crop[0] - 0.5 * dm], m_crop + 0.5 * dm])

    outfig = os.path.join(figdir, "kr_poster_triptych.pdf")

    # The usetex override is scoped to this figure (creation through saving)
    # so the caller's text.usetex setting is restored afterwards.
    with mpl.rc_context({"text.usetex": False}):
        fig, axes = plt.subplots(1, 3, figsize=(12.0, 3.4), dpi=600)
        ax0, ax1, ax2 = axes

        # A bit more space between panels (default is ~0.2)
        fig.subplots_adjust(wspace=0.30)

        # (A)
        ax0.plot(n, p2, lw=1.0, alpha=0.35, color="#3b528b", label=r"$\langle p^2\rangle$")
        ax0.plot(n, p2_smooth, lw=1.6, color="#3b528b", label=f"MA({ma_window})")
        ax0.axhline(p2_mean, color="#5ec962", lw=1.2, ls="--", label="Plateau mean")
        ax0.fill_between(
            n, p2_mean - p2_std, p2_mean + p2_std,
            color="#5ec962", alpha=0.12, linewidth=0
        )
        ax0.set_xlabel("Number of kicks  n")
        ax0.set_ylabel(r"$\langle p^2\rangle$")
        ax0.grid(True, alpha=0.25)
        ax0.legend(frameon=False, fontsize=8, loc="best")

        # Scientific notation: the deviation is expected to be tiny, so a
        # fixed-point format would always print 0.00.
        norm_dev = float(np.max(np.abs(norm - 1.0)))
        ax0.text(
            0.02, 0.98,
            f"max|norm-1|={norm_dev:.1e}",
            transform=ax0.transAxes,
            va="top", ha="left",
            fontsize=8, color="0.25",
        )

        # (B)
        abs_m = np.abs(m)
        order = np.argsort(abs_m)
        abs_m_sorted = abs_m[order]
        pm_sorted = pm[order]

        ax1.plot(abs_m_sorted, np.log(pm_sorted + 1e-300), lw=1.0, color="#3b528b")
        xfit = np.linspace(tail_m_min, tail_m_max, 200)
        yfit = fit["intercept"] + fit["slope"] * xfit
        ax1.plot(xfit, yfit, lw=2.0, color="#5ec962")

        ax1.set_xlabel(r"$|m|$")
        ax1.set_ylabel(r"$\log(|\psi_m|^2)$")
        ax1.grid(True, alpha=0.25)
        ax1.set_title(rf"$\xi \approx {fit['xi']:.1f}$,  $R^2={fit['r2']:.3f}$", fontsize=9)

        # (C) heatmap with true x = n_snap
        im = ax2.pcolormesh(
            x_edges, m_edges, z.T,
            cmap="viridis",
            vmin=vmin, vmax=vmax,
            shading="auto",
        )

        ax2.set_xlabel("Number of kicks  n")
        ax2.set_ylabel(r"Momentum index  $m$")
        ax2.set_title(r"$\log_{10} P(m,n)$", fontsize=9)

        fig.colorbar(im, ax=ax2, fraction=0.046, pad=0.04)

        fig.savefig(outfig, bbox_inches="tight")
        plt.close(fig)

    # fit summary (fit["xi"] already holds -1/slope, or inf when the slope is
    # not negative)
    outf = os.path.join(figdir, "kr_localization_fit.txt")
    with open(outf, "w", encoding="utf-8") as f:
        f.write(f"Fit window: |m| in [{tail_m_min}, {tail_m_max}]\n")
        f.write(f"slope b = {fit['slope']:.6e}\n")
        f.write(f"xi = {fit['xi']:.6f}\n")
        f.write(f"R^2 = {fit['r2']:.6f}\n")
        f.write(f"npoints = {fit['npoints']}\n")

    print("Saved:", outfig)
    print("Saved:", outf)






In [40]:
if __name__ == "__main__":
    # Build the heatmap variant of the poster figure from the outputs in OUTDIR.
    # NOTE: K and hbar_eff are accepted by poster_triptych_heatmap but are not
    # referenced in its body, so they do not affect the figure.
    poster_triptych_heatmap(
        outdir=OUTDIR,
        figdir=FIGDIR,
        K=5.5,
        hbar_eff=1.0,
        ma_window=101,
        plateau_start=2000,   # or None (auto: int(0.6 * n.max()))
        tail_m_min=50,
        tail_m_max=400,
    )

Saved: ../figs/quantum/kr_poster_triptych.pdf
Saved: ../figs/quantum/kr_localization_fit.txt


### with 3D plot

In [None]:
def poster_triptych_3d(
    outdir: str,
    figdir: str,
    K: float | None = None,
    hbar_eff: float | None = None,
    ma_window: int = 101,
    plateau_start: int | None = None,
    tail_m_min: int = 50,
    tail_m_max: int = 400,
    mmax_vis: int = 200,
    n_curves: int = 40,          # kept for compatibility; not used (see stride_n)
    z_mode: str = "log10",       # "log10" or "linear"
    elev: float = 22,
    azim: float = -60,
    gap_ab: float = 0.22,
    gap_bc: float = 0.10,
    stride_n: int = 100,
):
    """
    3-panel figure:
      (A) <p^2>(n) + moving average + plateau band,
      (B) log prob vs |m| with exponential tail fit (xi),
      (C) 3D snapshots of P(m,n) stacked by kick number n (not snapshot index).

    The figure is saved as ``kr_poster_triptych_3D.pdf`` in ``figdir`` and
    the fit summary as ``kr_localization_fit.txt``; both paths are printed.

    IMPORTANT:
      Ensures that n=0 is plotted. If the snapshots file doesn't include n=0,
      it synthesizes an initial distribution in momentum (Gaussian in m,
      centered at 0 with width 2) and prepends it as the first snapshot.
      NOTE(review): these values are assumed to match the simulation's
      default initial state; confirm against the simulation code.

    Parameters
    ----------
    outdir : str
        Directory holding the simulation outputs (CSV files read via
        load_outputs and snapshots read via load_pm_snapshots_npz).
    figdir : str
        Directory the figure and the fit summary are written to.
    K, hbar_eff : float or None
        Currently unused; accepted so that existing calls keep working.
    ma_window : int
        Window of the moving average drawn in panel (A).
    plateau_start : int or None
        First kick index of the plateau region. If None, it defaults to
        ``int(0.6 * n.max())``.
    tail_m_min, tail_m_max : int
        Fit window in |m| for the exponential tail of panel (B).
    mmax_vis : int
        Panel (C) shows only momentum indices with |m| <= mmax_vis.
    n_curves : int
        Unused; curve selection is driven by ``stride_n``.
    z_mode : str
        ``"linear"`` plots P(m,n); any other value plots log10 P(m,n).
    elev, azim : float
        Viewing angles of the 3D panel.
    gap_ab, gap_bc : float
        Relative widths of the spacer columns between panels A/B and B/C.
    stride_n : int
        Panel (C) draws the snapshots whose kick number is a multiple of
        ``stride_n``, plus n=0 and the last snapshot. Must be >= 1.

    Raises
    ------
    FileNotFoundError
        Propagated from load_pm_snapshots_npz when the archive is missing.
    ValueError
        If ``stride_n`` < 1, or propagated from plateau_stats /
        fit_exponential_tail.
    """
    # Imported for its side effect: registers the "3d" projection on older
    # matplotlib versions.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    stride_n = int(stride_n)
    if stride_n < 1:
        raise ValueError("poster_triptych_3d: stride_n must be >= 1.")

    # The theta distribution returned by load_outputs is not used here.
    df_m2, df_pm, _ = load_outputs(outdir)

    n = df_m2["n"].to_numpy(dtype=int)
    p2 = df_m2["p2"].to_numpy(dtype=float)
    norm = df_m2["norm"].to_numpy(dtype=float)

    p2_smooth = moving_average(p2, window=ma_window)

    if plateau_start is None:
        plateau_start = int(0.6 * n.max())
    p2_mean, p2_std = plateau_stats(n, p2, n_start=plateau_start)

    # momentum distribution (final): finite rows only, renormalized
    m = pd.to_numeric(df_pm["m"], errors="coerce").to_numpy()
    pm = pd.to_numeric(df_pm["prob"], errors="coerce").to_numpy()
    mask = np.isfinite(m) & np.isfinite(pm)
    m = m[mask].astype(int)
    pm = pm[mask].astype(float)
    pm = pm / pm.sum()

    fit = fit_exponential_tail(m, pm, m_min=tail_m_min, m_max=tail_m_max)

    # snapshots for panel C (3D): n_snap, m_snap, pm_snap[T,M]
    n_snap, m_snap, pm_snap = load_pm_snapshots_npz(outdir)

    # sort axes robustly
    idxn = np.argsort(n_snap)
    n_snap = n_snap[idxn]
    pm_snap = pm_snap[idxn, :]

    idxm = np.argsort(m_snap)
    m_snap = m_snap[idxm]
    pm_snap = pm_snap[:, idxm]

    # normalize each snapshot (all-zero rows are left untouched)
    row_sums = pm_snap.sum(axis=1, keepdims=True)
    row_sums = np.where(row_sums > 0, row_sums, 1.0)
    pm_snap = pm_snap / row_sums

    # --- FORCE presence of n=0 snapshot ---
    # If not present, synthesize a Gaussian in momentum m centered at m0=0
    # with width sigmam0=2, normalized to unit sum.
    if not np.any(n_snap == 0):
        m0 = 0.0
        sigmam0 = 2.0

        pm0 = np.exp(-0.5 * ((m_snap.astype(float) - m0) / sigmam0) ** 2)
        pm0 = pm0.astype(float)
        pm0 = pm0 / pm0.sum()

        n_snap = np.concatenate(([0], n_snap))
        pm_snap = np.vstack([pm0[None, :], pm_snap])

    # Panel (C) data selection
    eps = 1e-300  # keeps log10() finite where the probability is zero
    maskm = (m_snap >= -mmax_vis) & (m_snap <= mmax_vis)
    m_crop = m_snap[maskm].astype(float)

    # Indices: always include n=0; then every stride_n kicks; always include last
    idx0 = np.where(n_snap == 0)[0]
    idx_every = np.where((n_snap % stride_n) == 0)[0]
    idx_last = np.array([len(n_snap) - 1], dtype=int)
    idx = np.unique(np.concatenate([idx0, idx_every, idx_last]))

    outfig = os.path.join(figdir, "kr_poster_triptych_3D.pdf")

    # The usetex override is scoped to this figure (creation through saving)
    # so the caller's text.usetex setting is restored afterwards.
    with mpl.rc_context({"text.usetex": False}):
        fig = plt.figure(figsize=(12.0, 3.4), dpi=600)

        # Five columns: panel, spacer, panel, spacer, panel. The spacers set
        # the gaps, so wspace itself is zero.
        gs = fig.add_gridspec(
            1, 5,
            width_ratios=[1.0, float(gap_ab), 1.0, float(gap_bc), 1.0],
            wspace=0.0,
        )

        ax0 = fig.add_subplot(gs[0, 0])
        ax1 = fig.add_subplot(gs[0, 2])
        ax2 = fig.add_subplot(gs[0, 4], projection="3d")

        # (A)
        ax0.plot(n, p2, lw=1.0, alpha=0.35, color="#3b528b", label=r"$\langle p^2\rangle$")
        ax0.plot(n, p2_smooth, lw=1.6, color="#3b528b", label=f"MA({ma_window})")
        ax0.axhline(p2_mean, color="#5ec962", lw=1.2, ls="--", label="Plateau mean")
        ax0.fill_between(
            n, p2_mean - p2_std, p2_mean + p2_std,
            color="#5ec962", alpha=0.12, linewidth=0
        )
        ax0.set_xlabel("Number of kicks  n")
        ax0.set_ylabel(r"$\langle p^2\rangle$")
        ax0.grid(True, alpha=0.25)
        ax0.legend(frameon=False, fontsize=8, loc="best")

        # Scientific notation: the deviation is expected to be tiny, so a
        # fixed-point format would always print 0.0.
        norm_dev = float(np.max(np.abs(norm - 1.0)))
        ax0.text(
            0.02, 0.98,
            f"max|norm-1|={norm_dev:.1e}",
            transform=ax0.transAxes,
            va="top", ha="left",
            fontsize=8, color="0.25",
        )

        # (B)
        abs_m = np.abs(m)
        order = np.argsort(abs_m)
        abs_m_sorted = abs_m[order]
        pm_sorted = pm[order]

        ax1.plot(abs_m_sorted, np.log(pm_sorted + 1e-300), lw=1.0, color="#3b528b")
        xfit = np.linspace(tail_m_min, tail_m_max, 200)
        yfit = fit["intercept"] + fit["slope"] * xfit
        ax1.plot(xfit, yfit, lw=2.0, color="#5ec962")

        ax1.set_xlabel(r"$|m|$")
        ax1.set_ylabel(r"$\log(|\psi_m|^2)$")
        ax1.grid(True, alpha=0.25)
        ax1.set_title(rf"$\xi \approx {fit['xi']:.1f}$,  $R^2={fit['r2']:.3f}$", fontsize=9)

        # (C) 3D snapshots: one curve per selected snapshot, placed at its
        # real kick number along the y-axis
        x = m_crop
        for it in idx:
            y = np.full_like(x, float(n_snap[it]))
            if z_mode == "linear":
                z = pm_snap[it, maskm].astype(float)
            else:
                z = np.log10(pm_snap[it, maskm].astype(float) + eps)
            ax2.plot(x, y, z, color="#3b528b", lw=0.8, alpha=0.9)

        zlabel = r"$\log_{10} P(m,n)$" if z_mode != "linear" else r"$P(m,n)$"
        ax2.set_xlabel(r"$m$")
        ax2.set_ylabel("Number of kicks  n")
        ax2.set_zlabel(zlabel)
        ax2.set_title(zlabel, fontsize=9)
        ax2.view_init(elev=elev, azim=azim)

        fig.tight_layout(pad=0.25)
        fig.savefig(outfig, bbox_inches="tight")
        plt.close(fig)

    # fit summary (fit["xi"] already holds -1/slope, or inf when the slope is
    # not negative)
    outf = os.path.join(figdir, "kr_localization_fit.txt")
    with open(outf, "w", encoding="utf-8") as f:
        f.write(f"Fit window: |m| in [{tail_m_min}, {tail_m_max}]\n")
        f.write(f"slope b = {fit['slope']:.6e}\n")
        f.write(f"xi = {fit['xi']:.6f}\n")
        f.write(f"R^2 = {fit['r2']:.6f}\n")
        f.write(f"npoints = {fit['npoints']}\n")

    print("Saved:", outfig)
    print("Saved:", outf)


In [30]:
if __name__ == "__main__":
    # Build the 3D-snapshot variant of the poster figure from the outputs in OUTDIR.
    # NOTE: K and hbar_eff are accepted by poster_triptych_3d but are not
    # referenced in its body, so they do not affect the figure.
    poster_triptych_3d(
        outdir=OUTDIR,
        figdir=FIGDIR,
        K=5.5,
        hbar_eff=1.0,
        ma_window=101,
        plateau_start=2000,   # or None (auto: int(0.6 * n.max()))
        tail_m_min=50,
        tail_m_max=400,
    )

Saved: ../figs/quantum/kr_poster_triptych_3D.pdf
Saved: ../figs/quantum/kr_localization_fit.txt


: 