In [None]:
# from your repo root
python graph/compare.py \
  --root data/results/Qubits8 \
  --out  graph/Qubits8 \
  --ema  0.9 \
  --metrics loss mmd kl grad_norm step_time
0.7
Step 999: loss=0.075449 mmd=0.002077 kl=0.246650 ||g||=0.001167 dt_ms=0.029ms
0.1
Step 999: loss=0.224564 mmd=0.002195 kl=0.249271 ||g||=0.003662 dt_ms=0.127ms
0.2
Step 999: loss=0.211111 mmd=0.002905 kl=0.263162 ||g||=0.005027 dt_ms=0.029ms
0.3
Step 999: loss=0.172192 mmd=0.002165 kl=0.245061 ||g||=0.003197 dt_ms=0.157ms
0.4
Step 999: loss=0.148880 mmd=0.002176 kl=0.246683 ||g||=0.002398 dt_ms=0.156ms
0.5
Step 999: loss=0.124427 mmd=0.002158 kl=0.246695 ||g||=0.002026 dt_ms=0.128ms

usage: ipykernel_launcher.py [-h] --root ROOT --out OUT [--runs [RUNS ...]]
                             [--labels [LABELS ...]] [--ema EMA]
                             [--metrics [METRICS ...]]
ipykernel_launcher.py: error: the following arguments are required: --root, --out


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [17]:
from pathlib import Path
import numpy as np
import jax, jax.numpy as jnp
import matplotlib.pyplot as plt
import os, sys

from pathlib import Path
import sys, os

# --- Find repo root (works in script or notebook) ---
try:
    REPO_ROOT = Path(__file__).resolve().parents[2]  # when running as a script
except NameError:
    # notebook/REPL: walk up until we find a 'src' folder
    REPO_ROOT = Path.cwd()
    for p in [REPO_ROOT, *REPO_ROOT.parents]:
        if (p / "src").exists():
            REPO_ROOT = p
            break

if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

# Define this since you use it below
report1_dir = str(REPO_ROOT)

# Data path (no ~, fully resolved)
data_path = (REPO_ROOT / "data_2d" / "Qubits8" / "train.csv")


from src.train.mmdagg_probs import mmdagg_prob
from src.train.qcbm import QCBM
from itertools import product

import pandas as pd
data_path = os.path.join(report1_dir, "data_2d/Qubits8/train.csv")
df = pd.read_csv(data_path)
n_bits = 8
L1 =4
L_M =3
bit_cols = [f"q{i}" for i in range(n_bits)]
bitstrings = (
    df[bit_cols]
    .astype(str)
    .agg("".join, axis=1)
)
counts = bitstrings.value_counts().sort_index()
all_bits = ["".join(seq) for seq in product("01", repeat=n_bits)]
probs_full = pd.Series(0.0, index=all_bits, dtype=float)   # float64
probs_full.update(counts / counts.sum())                   # 归一化

gpu = jax.devices("gpu")[0]
target_probs = jax.device_put(jnp.asarray(probs_full.values, dtype=jnp.float64), gpu)
from src.circuits.ansatz1 import hardware_efficient_ansatz
from src.circuits.ansatz2 import ising_structured_ansatz
from src.circuits.ansatz3 import eh2d_ansatz
from src.circuits.ansatz4 import mi_ansatz
## Control # of Params around 100


ansatz = hardware_efficient_ansatz
n_qubits= 8
mmd_fn = mmdagg_prob
R = 2
C = 4
keep_edges = 16
model = QCBM(ansatz=ansatz, n_qubits=n_bits, L=R*C, mmd_fn=mmd_fn, target_probs = target_probs)
model.build_circuits()


# ---------- pick your run ----------
RESULTS_ROOT = Path("/home/cx/Documents/qcbm-ansatz-benchmark/data/results/Qubits8")
RUN_NAME = "result1"   # <-- choose the run (e.g., "result1")
run_dir = RESULTS_ROOT / RUN_NAME

# ---------- center params p0 ----------
if (run_dir/"params.npy").exists():
    P = np.load(run_dir/"params.npy", mmap_mode="r")  # (n_steps, ...[batch], n_params)
    p_arr = P[int(0.6*len(P))-1]                      # mid/late snapshot; change if you like
elif (run_dir/"final_params.npy").exists():
    p_arr = np.load(run_dir/"final_params.npy")
else:
    raise FileNotFoundError("Need params.npy or final_params.npy")

p0 = jnp.asarray(p_arr[0]) if p_arr.ndim == 2 else jnp.asarray(p_arr)

# ---------- scalar loss ----------
def loss_scalar(p):
    val, _ = model.loss(p)    # your model.loss returns (loss, metrics)
    return val
loss_scalar = jax.jit(loss_scalar)

# ---------- helpers for directions ----------
def hvp_at(p, v):                   # H @ v via JVP of grad
    return jax.jvp(jax.grad(loss_scalar), (p,), (v,))[1]

def pick_hessian_top(p, key, iters=25):
    # power iteration for top curvature dir, then next orthogonal
    def power(p, key, iters, ortho=None):
        v = jax.random.normal(key, p.shape)
        if ortho is not None: v = v - jnp.dot(v, ortho) * ortho
        v = v / (jnp.linalg.norm(v) + 1e-12)
        for _ in range(iters):
            w = hvp_at(p, v)
            if ortho is not None: w = w - jnp.dot(w, ortho) * ortho
            v = w / (jnp.linalg.norm(w) + 1e-12)
        return v
    k1, k2 = jax.random.split(key)
    U = power(p, k1, iters)
    V = power(p, k2, iters, U)
    return U, V

def pick_low_curv(p, key, candidates=128):
    # sample many directions, pick two with smallest |v^T H v|
    vs = jax.random.normal(key, (candidates,) + p.shape)
    vs = vs / (jnp.linalg.norm(vs, axis=tuple(range(1, vs.ndim)), keepdims=True) + 1e-12)
    rq = jax.vmap(lambda v: jnp.dot(v, hvp_at(p, v)))(vs)
    idx = jnp.argsort(jnp.abs(rq))
    U = vs[idx[0]]
    V = vs[idx[1]] - jnp.dot(vs[idx[1]], U) * U
    U = U / (jnp.linalg.norm(U) + 1e-12)
    V = V / (jnp.linalg.norm(V) + 1e-12)
    return U, V

def make_slice(p0, mode="hessian-top", N=151, radius=None, key=0):
    key = jax.random.PRNGKey(key)
    if mode == "hessian-top":
        U, V = pick_hessian_top(p0, key, iters=30)
        # adapt radius to curvature so the surface has structure
        cu = float(jnp.dot(U, hvp_at(p0, U))); cv = float(jnp.dot(V, hvp_at(p0, V)))
        if radius is None:
            rad_u = min(6.0, 2.0 / (abs(cu)**0.5 + 1e-6))
            rad_v = min(6.0, 2.0 / (abs(cv)**0.5 + 1e-6))
        else:
            rad_u = rad_v = radius
    elif mode == "low-curv":
        U, V = pick_low_curv(p0, key)
        # wider window to expose flatness
        s = float(jnp.std(p0)) or 1.0
        rad_u = rad_v = 4.0 * s if radius is None else radius
    else:
        raise ValueError("mode must be 'hessian-top' or 'low-curv'")

    a = jnp.linspace(-rad_u, rad_u, N); b = jnp.linspace(-rad_v, rad_v, N)
    A, B = jnp.meshgrid(a, b, indexing="ij")

    def to_params(ab):
        a, b = ab
        return p0 + a * U + b * V

    @jax.jit
    def eval_batch(ab_batch):
        Ps = jax.vmap(to_params)(ab_batch)           # (M, D)
        return jax.vmap(loss_scalar)(Ps)              # (M,)

    AB = jnp.stack([A.ravel(), B.ravel()], axis=1)
    Z = eval_batch(AB).reshape(N, N)
    return np.asarray(A), np.asarray(B), np.asarray(Z)

def plot_contour(A, B, Z, title, out_png):
    ij_min = np.unravel_index(np.argmin(Z), Z.shape)
    ij_max = np.unravel_index(np.argmax(Z), Z.shape)
    plt.figure(figsize=(8,7))
    cf = plt.contourf(A, B, Z, levels=50)
    plt.colorbar(cf, label="Loss")
    plt.scatter(A[ij_min], B[ij_min], marker="v", s=80)  # valley
    plt.scatter(A[ij_max], B[ij_max], marker="^", s=80)  # peak
    plt.xlabel("α (dir 1)"); plt.ylabel("β (dir 2)")
    plt.title(title)
    plt.tight_layout(); plt.savefig(out_png, dpi=170); plt.close()
    print("[saved]", out_png)

# ---------- make TWO landscapes ----------
A, B, Z = make_slice(p0, mode="hessian-top", N=151)
plot_contour(A, B, Z, "Loss landscape (high-curvature slice)", run_dir/"figures"/"landscape_hessian_top.png")

A2, B2, Z2 = make_slice(p0, mode="low-curv", N=151)   # <- shows barren slice
plot_contour(A2, B2, Z2, "Loss landscape (low-curvature / plateau slice)", run_dir/"figures"/"landscape_low_curv.png")



[saved] /home/cx/Documents/qcbm-ansatz-benchmark/data/results/Qubits8/result1/figures/landscape_hessian_top.png
[saved] /home/cx/Documents/qcbm-ansatz-benchmark/data/results/Qubits8/result1/figures/landscape_low_curv.png


In [22]:
from pathlib import Path
import sys, os
import numpy as np
import jax, jax.numpy as jnp
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import Normalize
import optax
import pandas as pd
from itertools import product

# -------------------- locate repo root (script or notebook) --------------------
try:
    REPO_ROOT = Path(__file__).resolve().parents[2]
except NameError:
    REPO_ROOT = Path.cwd()
    for p in [REPO_ROOT, *REPO_ROOT.parents]:
        if (p / "src").exists():
            REPO_ROOT = p
            break
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

# -------------------- data -> target_probs --------------------
data_path = REPO_ROOT / "data_2d" / "Qubits8" / "train.csv"
df = pd.read_csv(data_path)
n_bits = 8
bit_cols = [f"q{i}" for i in range(n_bits)]
bitstrings = df[bit_cols].astype(str).agg("".join, axis=1)
counts = bitstrings.value_counts().sort_index()
all_bits = ["".join(seq) for seq in product("01", repeat=n_bits)]
probs_full = pd.Series(0.0, index=all_bits, dtype=float)
probs_full.update(counts / counts.sum())

devices = jax.devices("gpu")
device = devices[0] if devices else jax.devices("cpu")[0]
target_probs = jax.device_put(jnp.asarray(probs_full.values, dtype=jnp.float64), device)

# -------------------- build model --------------------
from src.train.mmdagg_probs import mmdagg_prob
from src.train.qcbm import QCBM
from src.circuits.ansatz1 import hardware_efficient_ansatz

R, C = 2, 4
model = QCBM(
    ansatz=hardware_efficient_ansatz,
    n_qubits=n_bits,
    L=R * C,
    mmd_fn=mmdagg_prob,
    target_probs=target_probs,
)
model.build_circuits()

def loss_scalar(p):
    val, _ = model.loss(p)  # your loss returns (loss, metrics)
    return val
loss_scalar = jax.jit(loss_scalar)

# -------------------- run & checkpoint selection --------------------
RESULTS_ROOT = Path("/home/cx/Documents/qcbm-ansatz-benchmark/data/results/Qubits8")
RUN_NAME = "result1"          # <-- pick the run explicitly
run_dir = RESULTS_ROOT / RUN_NAME
out_dir = run_dir / "figures"
out_dir.mkdir(parents=True, exist_ok=True)

P = None
if (run_dir / "params.npy").exists():
    P = np.load(run_dir / "params.npy", mmap_mode="r")  # (steps, [batch], D)
elif (run_dir / "final_params.npy").exists():
    P = np.load(run_dir / "final_params.npy")[None, ...]  # fake steps dim

# pick checkpoint: use early step for plateau; mid/late for mountains
idx_plateau = 0                      # init (best for barren plateau)
idx_mountain = max(0, int(0.6 * len(P)) - 1)

def get_p0(arr):
    return jnp.asarray(arr[0]) if arr.ndim == 2 else jnp.asarray(arr)

p0_plateau = get_p0(P[idx_plateau])
p0_mountain = get_p0(P[idx_mountain])

# -------------------- HVP helpers and direction pickers --------------------
def hvp_at(p, v):
    # Hessian-vector product H(p) @ v via JVP of grad
    return jax.jvp(jax.grad(loss_scalar), (p,), (v,))[1]

def pick_hessian_top(p, key, iters=25):
    def power(key, ortho=None):
        v = jax.random.normal(key, p.shape)
        if ortho is not None: v = v - jnp.dot(v, ortho) * ortho
        v = v / (jnp.linalg.norm(v) + 1e-12)
        for _ in range(iters):
            w = hvp_at(p, v)
            if ortho is not None: w = w - jnp.dot(w, ortho) * ortho
            v = w / (jnp.linalg.norm(w) + 1e-12)
        return v
    k1, k2 = jax.random.split(key)
    U = power(k1)
    V = power(k2, U)
    return U, V  # principal-curvature directions (ridges/valleys)

def pick_low_curv(p, key, candidates=256):
    vs = jax.random.normal(key, (candidates,) + p.shape)
    vs = vs / (jnp.linalg.norm(vs, axis=tuple(range(1, vs.ndim)), keepdims=True) + 1e-12)
    rq = jax.vmap(lambda v: jnp.dot(v, hvp_at(p, v)))(vs)  # Rayleigh quotients
    idx = jnp.argsort(jnp.abs(rq))
    U = vs[idx[0]]
    V = vs[idx[1]] - jnp.dot(vs[idx[1]], U) * U
    U = U / (jnp.linalg.norm(U) + 1e-12)
    V = V / (jnp.linalg.norm(V) + 1e-12)
    return U, V  # near-zero curvature directions (plateau slice)

# -------------------- slicing & evaluation --------------------
def make_slice(p0, mode="hessian-top", N=151, radius=None, key=0, return_dirs=False):
    key = jax.random.PRNGKey(key)
    if mode == "hessian-top":
        U, V = pick_hessian_top(p0, key, iters=30)
        cu = float(jnp.dot(U, hvp_at(p0, U))); cv = float(jnp.dot(V, hvp_at(p0, V)))
        if radius is None:
            rad_u = min(6.0, 2.0 / (abs(cu)**0.5 + 1e-6))
            rad_v = min(6.0, 2.0 / (abs(cv)**0.5 + 1e-6))
        else:
            rad_u = rad_v = radius
    elif mode == "low-curv":
        U, V = pick_low_curv(p0, key)
        s = float(jnp.std(p0)) or 1.0
        rad_u = rad_v = 4.0 * s if radius is None else radius
    else:
        raise ValueError("mode must be 'hessian-top' or 'low-curv'")

    a = jnp.linspace(-rad_u, rad_u, N); b = jnp.linspace(-rad_v, rad_v, N)
    A, B = jnp.meshgrid(a, b, indexing="ij")

    def to_params(ab):
        a, b = ab
        return p0 + a * U + b * V

    @jax.jit
    def eval_batch(ab_batch):
        Ps = jax.vmap(to_params)(ab_batch)
        return jax.vmap(loss_scalar)(Ps)

    AB = jnp.stack([A.ravel(), B.ravel()], axis=1)
    Z = eval_batch(AB).reshape(N, N)

    if return_dirs:
        return np.asarray(A), np.asarray(B), np.asarray(Z), U, V
    return np.asarray(A), np.asarray(B), np.asarray(Z)

def plot_contour(A, B, Z, title, out_png, vmin=None, vmax=None):
    ij_min = np.unravel_index(np.argmin(Z), Z.shape)
    ij_max = np.unravel_index(np.argmax(Z), Z.shape)
    plt.figure(figsize=(8,7))
    cf = plt.contourf(A, B, Z, levels=50, vmin=vmin, vmax=vmax)
    plt.colorbar(cf, label="Loss")
    plt.scatter(A[ij_min], B[ij_min], marker="v", s=80)
    plt.scatter(A[ij_max], B[ij_max], marker="^", s=80)
    plt.scatter([0],[0], s=40, facecolors="none", edgecolors="w")  # mark center
    plt.xlabel("α (dir 1)"); plt.ylabel("β (dir 2)")
    plt.title(title)
    plt.tight_layout(); plt.savefig(out_png, dpi=170); plt.close()
    print("[saved]", out_png)

from matplotlib import cm
from matplotlib.colors import Normalize

def surface3d(A, B, Z, title, out_png, zlabel="Value", elev=32, azim=45):
    # ensure numpy arrays & matching shapes
    A = np.asarray(A); B = np.asarray(B); Z = np.asarray(Z)
    assert A.shape == B.shape == Z.shape, "A, B, Z must have same shape"

    fig = plt.figure(figsize=(9, 7))
    ax = fig.add_subplot(111, projection="3d")

    # color == height
    norm = Normalize(vmin=float(Z.min()), vmax=float(Z.max()))
    colors = cm.viridis(norm(Z))
    ax.plot_surface(A, B, Z,
                    facecolors=colors,
                    rstride=1, cstride=1,
                    linewidth=0, antialiased=True,
                    shade=False)

    # colorbar anchored to THIS axes  ✅
    m = cm.ScalarMappable(norm=norm, cmap="viridis")
    m.set_array(Z)  # required for some Matplotlib versions
    cbar = fig.colorbar(m, ax=ax, shrink=0.65, aspect=16, pad=0.1)
    cbar.set_label(zlabel)

    ax.set_xlabel("α (dir 1)")
    ax.set_ylabel("β (dir 2)")
    ax.set_zlabel(zlabel)
    ax.set_title(title)
    ax.view_init(elev=elev, azim=azim)

    plt.tight_layout()
    plt.savefig(out_png, dpi=170)
    plt.close(fig)
    print("[saved]", out_png)


# -------------------- 1) Mountains: high-curvature loss --------------------
A_m, B_m, Z_m = make_slice(p0_mountain, mode="hessian-top", N=151)
plot_contour(A_m, B_m, Z_m, "Loss landscape — high-curvature slice",
             out_dir/"landscape_hcurv_contour.png")
surface3d(A_m, B_m, Z_m, "Loss surface — high-curvature slice",
          out_dir/"surface_loss_hcurv.png", zlabel="Loss")

# -------------------- 2) Plateau: low-curvature gradient norm --------------
# pick early checkpoint (you already did earlier)
idx_plateau = 0
p0_plateau = get_p0(P[idx_plateau])

# low-curvature plane at p0_plateau
key = jax.random.PRNGKey(0)
U_p, V_p = pick_low_curv(p0_plateau, key, candidates=512)

# === SMALL RADIUS near p0 to reveal flatness ===
# radius in "parameter RMS" units
param_rms = float(jnp.sqrt(jnp.mean(p0_plateau**2))) or 1.0
rad = 0.05 * param_rms      # try 0.02–0.10; this focuses on the local neighborhood
N = 151

N = 151
a = jnp.linspace(-rad, rad, N); b = jnp.linspace(-rad, rad, N)
A_p, B_p = jnp.meshgrid(a, b, indexing="ij")

def to_params_plateau(ab):
    a, b = ab
    return p0_plateau + a * U_p + b * V_p

loss_jit = jax.jit(loss_scalar)
grad_norm_fn = jax.jit(lambda p: optax.global_norm(jax.grad(loss_jit)(p)))

@jax.jit
def eval_grad(ab_batch):
    Ps = jax.vmap(to_params_plateau)(ab_batch)
    return jax.vmap(grad_norm_fn)(Ps)

AB_p = jnp.stack([A_p.ravel(), B_p.ravel()], axis=1)

# (optional) chunking
def run_chunks(fn, X, chunk=8000):
    outs = []
    for s0 in range(0, X.shape[0], chunk):
        outs.append(fn(X[s0:s0+chunk]))
    return jnp.concatenate(outs)

G = run_chunks(eval_grad, AB_p, chunk=8000).reshape(N, N)   # ||∇L||
# G: (N,N) gradient norms on the low-curvature slice
G_center = float(G[G.shape[0]//2, G.shape[1]//2])
G_med    = float(np.median(G))
G_p10    = float(np.percentile(G, 10))
G_p90    = float(np.percentile(G, 90))
frac_lt = {thr: float((G < thr).mean()) for thr in (1e-3, 1e-5, 1e-7, 1e-9)}

print("[plateau slice] ||grad|| center =", G_center)
print("[plateau slice] median =", G_med, "p10 =", G_p10, "p90 =", G_p90)
print("[plateau slice] fraction below thresholds:", frac_lt)

Glog = np.log10(np.asarray(G) + 1e-12)

# --------- QUANTILE CLIPPING to suppress edge spikes ---------
qlo, qhi = np.percentile(Glog, [5, 95])   # robust range
Glog_clip = np.clip(Glog, qlo, qhi)

# quick stats to confirm "barren"
print(f"[plateau] grad-norm stats: min={G.min():.2e}, median={np.median(G):.2e}, max={G.max():.2e}")
# (optional) also look at loss span on this small window:
@jax.jit
def eval_loss(ab_batch):
    Ps = jax.vmap(to_params_plateau)(ab_batch)
    return jax.vmap(loss_jit)(Ps)
Z_p = run_chunks(eval_loss, AB_p, chunk=8000).reshape(N, N)
print(f"[plateau] Δloss on slice: {Z_p.max()-Z_p.min():.3e}")

# --------- 2D & 3D plots with the clipped range ----------
A_np, B_np = np.asarray(A_p), np.asarray(B_p)

# contour (log10 ||∇L||)
plt.figure(figsize=(8,7))
cf = plt.contourf(A_np, B_np, Glog_clip, levels=50, vmin=qlo, vmax=qhi)
plt.colorbar(cf, label="log10 ||∇Loss||")
plt.scatter([0],[0], s=40, facecolors="none", edgecolors="w")
plt.xlabel("α (low-curv dir U)"); plt.ylabel("β (low-curv dir V)")
plt.title("Gradient norm (log10) — plateau slice (clipped)")
plt.tight_layout(); plt.savefig(out_dir/"plateau_gradnorm_log10_contour.png", dpi=170); plt.close()

# 3D surface (height=color=log10 ||∇L||)
def surface3d_with_limits(A, B, Z, vmin, vmax, title, out_png, zlabel="log10 ||∇Loss||"):
    A = np.asarray(A); B = np.asarray(B); Z = np.asarray(Z)
    fig = plt.figure(figsize=(9,7)); ax = fig.add_subplot(111, projection="3d")
    from matplotlib import cm
    from matplotlib.colors import Normalize
    norm = Normalize(vmin=vmin, vmax=vmax)
    colors = cm.viridis(norm(Z))
    ax.plot_surface(A, B, Z, facecolors=colors, rstride=1, cstride=1,
                    linewidth=0, antialiased=True, shade=False)
    m = cm.ScalarMappable(norm=norm, cmap="viridis"); m.set_array(Z)
    cbar = fig.colorbar(m, ax=ax, shrink=0.65, aspect=16, pad=0.1); cbar.set_label(zlabel)
    ax.set_xlabel("α (U)"); ax.set_ylabel("β (V)"); ax.set_zlabel(zlabel)
    ax.set_zlim(vmin, vmax)   # fix z-range so spikes don't dominate
    ax.set_title(title)
    plt.tight_layout(); plt.savefig(out_png, dpi=170); plt.close(fig)
    print("[saved]", out_png)

surface3d_with_limits(A_np, B_np, Glog_clip, qlo, qhi,
                      "Gradient norm (log10) surface — plateau slice (clipped)",
                      out_dir/"surface_gradnorm_lowcurv_log10_clipped.png")

def radial_profile(Z, A, B, nbins=20):
    r = np.sqrt(A**2 + B**2).ravel()
    y = Z.ravel()
    bins = np.linspace(0, r.max(), nbins+1)
    r_mid, med = [], []
    for i in range(nbins):
        m = (r >= bins[i]) & (r < bins[i+1])
        if m.any():
            r_mid.append(0.5*(bins[i]+bins[i+1]))
            med.append(np.median(y[m]))
    return np.array(r_mid), np.array(med)

# after computing Glog = log10(G + 1e-12)
r_mid, med_log = radial_profile(Glog, A_p, B_p, nbins=24)
plt.figure(figsize=(7,4))
plt.plot(r_mid, med_log)
plt.xlabel("radius ρ (in α,β units)"); plt.ylabel("median log10 ||∇Loss||")
plt.title("Radial profile on low-curvature slice")
plt.tight_layout(); plt.savefig(out_dir/"plateau_radial_profile.png", dpi=170); plt.close()


PH = np.load(run_dir/"params.npy", mmap_mode="r")
if PH.ndim == 3: PH = PH[:,0,:]
u_np, v_np, p0_np = np.asarray(U_p), np.asarray(V_p), np.asarray(p0_plateau)
d = PH - p0_np
A_tr, B_tr = d @ u_np, d @ v_np

plt.figure(figsize=(8,7))
cf = plt.contourf(np.asarray(A_p), np.asarray(B_p), Glog, levels=50)
plt.colorbar(cf, label="log10 ||∇Loss||")
plt.plot(A_tr, B_tr, "w-", lw=1.5, alpha=0.9)
plt.scatter([0],[0], c="w", s=40)
plt.xlabel("α (U)"); plt.ylabel("β (V)")
plt.title("Training path over gradient-norm (low-curv slice)")
plt.tight_layout(); plt.savefig(out_dir/"plateau_path_overlay.png", dpi=170); plt.close()


[saved] /home/cx/Documents/qcbm-ansatz-benchmark/data/results/Qubits8/result1/figures/landscape_hcurv_contour.png
[saved] /home/cx/Documents/qcbm-ansatz-benchmark/data/results/Qubits8/result1/figures/surface_loss_hcurv.png
[plateau slice] ||grad|| center = 1.4984982751806841
[plateau slice] median = 1.4999619120426002 p10 = 1.4654924875183264 p90 = 1.541528384494073
[plateau slice] fraction below thresholds: {0.001: 0.0, 1e-05: 0.0, 1e-07: 0.0, 1e-09: 0.0}
[plateau] grad-norm stats: min=1.45e+00, median=1.50e+00, max=1.57e+00
[plateau] Δloss on slice: 3.977e-02
[saved] /home/cx/Documents/qcbm-ansatz-benchmark/data/results/Qubits8/result1/figures/surface_gradnorm_lowcurv_log10_clipped.png


In [3]:
# ===== 2D global sweep: loss & gradient-norm surfaces  =====
import numpy as np
import jax, jax.numpy as jnp
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import Normalize
import optax
from pathlib import Path

from pathlib import Path
import sys, os
import numpy as np
import jax, jax.numpy as jnp
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import Normalize
import optax
import pandas as pd
from itertools import product

# -------------------- locate repo root (script or notebook) --------------------
try:
    REPO_ROOT = Path(__file__).resolve().parents[2]
except NameError:
    REPO_ROOT = Path.cwd()
    for p in [REPO_ROOT, *REPO_ROOT.parents]:
        if (p / "src").exists():
            REPO_ROOT = p
            break
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

# -------------------- data -> target_probs --------------------
data_path = REPO_ROOT / "data_2d" / "Qubits8" / "train.csv"
df = pd.read_csv(data_path)
n_bits = 8
bit_cols = [f"q{i}" for i in range(n_bits)]
bitstrings = df[bit_cols].astype(str).agg("".join, axis=1)
counts = bitstrings.value_counts().sort_index()
all_bits = ["".join(seq) for seq in product("01", repeat=n_bits)]
probs_full = pd.Series(0.0, index=all_bits, dtype=float)
probs_full.update(counts / counts.sum())

devices = jax.devices("gpu")
device = devices[0] if devices else jax.devices("cpu")[0]
target_probs = jax.device_put(jnp.asarray(probs_full.values, dtype=jnp.float64), device)

# -------------------- build model --------------------
from src.train.mmdagg_probs import mmdagg_prob
from src.train.qcbm import QCBM
from src.circuits.ansatz1 import hardware_efficient_ansatz

R, C = 2, 4
model = QCBM(
    ansatz=hardware_efficient_ansatz,
    n_qubits=n_bits,
    L=R * C,
    mmd_fn=mmdagg_prob,
    target_probs=target_probs,
)
model.build_circuits()

# --- choose base point near init (plateaus live there) ---
if (run_dir/"params.npy").exists():
    P = np.load(run_dir/"params.npy", mmap_mode="r")
    p_arr = P[0]                         # init (or a very early step)
elif (run_dir/"final_params.npy").exists():
    p_arr = np.load(run_dir/"final_params.npy")
else:
    raise FileNotFoundError("Need params.npy or final_params.npy")

p_base = jnp.asarray(p_arr[0]) if p_arr.ndim == 2 else jnp.asarray(p_arr)
D = p_base.size

# --- two global directions U,V (orthonormal) ---
U = jnp.ones(D); U = U / (jnp.linalg.norm(U) + 1e-12)
V = jnp.sign(jnp.arange(D) % 2 - 0.5)         # [+1,-1,+1,-1,...]
V = V - jnp.dot(V, U) * U;  V = V / (jnp.linalg.norm(V) + 1e-12)

# If you know masks per gate type, substitute them:
# U = mask_Ry / jnp.linalg.norm(mask_Ry);  V = mask_Rz / jnp.linalg.norm(mask_Rz)

# --- scalar loss & RMS gradient ---
def loss_scalar(p):
    val, _ = model.loss(p)
    return val
loss_scalar = jax.jit(loss_scalar)

# RMS grad (per-parameter scale): sqrt(mean(g^2)) = ||g|| / sqrt(D)
grad_rms_fn = jax.jit(lambda p: optax.global_norm(jax.grad(loss_scalar)(p)) / jnp.sqrt(D))

# --- tighter window around (0,0) (don’t go to ±π) ---
# try ±0.3π first; shrink to ±0.1π if edges still dominate
alpha = jnp.linspace(-0.3*jnp.pi, 0.3*jnp.pi, 201)
beta  = jnp.linspace(-0.3*jnp.pi, 0.3*jnp.pi, 201)
A, B = jnp.meshgrid(alpha, beta, indexing="ij")

def to_params(ab):
    a, b = ab
    return p_base + a * U + b * V

@jax.jit
def eval_surfaces(ab_batch):
    Ps = jax.vmap(to_params)(ab_batch)
    Ls = jax.vmap(loss_scalar)(Ps)
    G  = jax.vmap(grad_rms_fn)(Ps)            # RMS gradient
    return Ls, G

AB = jnp.stack([A.ravel(), B.ravel()], axis=1)

# chunk to avoid OOM if needed
def run_chunks(fn, X, chunk=10000):
    outs = []
    for s0 in range(0, X.shape[0], chunk):
        outs.append(fn(X[s0:s0+chunk]))
    return tuple(jnp.concatenate(parts) for parts in zip(*outs))

Z_loss, G_rms = run_chunks(eval_surfaces, AB)
Z_loss = np.asarray(Z_loss.reshape(A.shape))
G_rms  = np.asarray(G_rms.reshape(A.shape))
Glog   = np.log10(G_rms + 1e-12)

# --- print numbers (so you can claim plateau quantitatively) ---
ctr = (A.shape[0]//2, A.shape[1]//2)
print("[global sweep / RMS] center log10||g||_rms =", float(Glog[ctr]))
print("[global sweep / RMS] median log10||g||_rms =", float(np.median(Glog)))
for thr in (1e-3, 1e-5, 1e-7):
    print(f"fraction(RMS < {thr:g}) =", float((G_rms < thr).mean()))

# --- plot 3D with height==color (RMS gradient, not total norm) ---
from matplotlib import cm
from matplotlib.colors import Normalize
def surface3d(A, B, Z, title, out_png, zlabel):
    A = np.asarray(A); B = np.asarray(B); Z = np.asarray(Z)
    fig = plt.figure(figsize=(9,7)); ax = fig.add_subplot(111, projection="3d")
    norm = Normalize(vmin=Z.min(), vmax=Z.max())
    colors = cm.viridis(norm(Z))
    ax.plot_surface(A, B, Z, facecolors=colors, rstride=1, cstride=1,
                    linewidth=0, antialiased=True, shade=False)
    m = cm.ScalarMappable(norm=norm, cmap="viridis"); m.set_array(Z)
    cbar = fig.colorbar(m, ax=ax, shrink=0.65, aspect=16, pad=0.1); cbar.set_label(zlabel)
    ax.set_xlabel("α (global U)"); ax.set_ylabel("β (global V)"); ax.set_zlabel(zlabel)
    ax.set_title(title); plt.tight_layout(); plt.savefig(out_png, dpi=170); plt.close(fig)

out_dir = run_dir / "figures_global"; out_dir.mkdir(parents=True, exist_ok=True)
surface3d(A, B, Z_loss, "Loss surface — global sweep (tight window)",
          out_dir/"surface_loss_global_tight.png", zlabel="Loss")
surface3d(A, B, Glog,  "RMS gradient (log10) — global sweep (tight window)",
          out_dir/"surface_grad_rms_global_log10_tight.png", zlabel="log10 ||∇Loss||_RMS")


[global sweep / RMS] center log10||g||_rms = -0.8243437526019529
[global sweep / RMS] median log10||g||_rms = -0.8140525553996238
fraction(RMS < 0.001) = 0.0
fraction(RMS < 1e-05) = 0.0
fraction(RMS < 1e-07) = 0.0
