<a href="https://colab.research.google.com/github/mortonsguide/axis-model-suite/blob/main/Einstein_Cross_5_11.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
# === Einstein Cross (Q2237+0305)  ===
# Model: SIS + external quadrupole shear  (paper Eq. (154), §5.11)
# Outputs:
#   - einstein_cross.png  (figure; file name set by FIG_NAME)
#   - einstein_cross_params.json  (parameters, RMS, environment, mapping)
# Notes:
#   - Fully deterministic given fixed RNG seed
#   - Uses analytic deflection (no finite-difference noise)
#   - Uses Hungarian assignment to match images to observations robustly

import json, sys, platform, math, os
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import minimize, linear_sum_assignment

# ----------------------------
# Config
# ----------------------------
MODE = "fit"          # "fit": derive (b, gamma, theta) from OBSERVED; "eval": use FIXED_PARAMS unchanged
FIG_NAME = "einstein_cross.png"
JSON_NAME = "einstein_cross_params.json"
RNG_SEED = 42         # seeds every np.random.default_rng used for seed jitter / random restarts

# Paper parameters used in "eval" mode: b = 0.920, γ = 0.010, θγ = 2.329 rad
FIXED_PARAMS = {"b": 0.920, "gamma": 0.010, "theta": 2.329}

# Observed Q2237+0305 image positions [arcsec], one (x, y) row per image (cf. Fig. 15)
OBSERVED = np.array(
    (
        (0.68, 0.64),
        (-0.62, 0.67),
        (-0.66, -0.63),
        (0.70, -0.65),
    ),
    dtype=float,
)

# Source-plane position β; the paper's Einstein Cross test puts the source on the lens axis
BETA = np.zeros(2, dtype=float)

# ----------------------------
# Lensing model (Eq. 154) and analytic deflection
# ----------------------------
def deflection_xy(x: float, y: float, b: float, gamma: float, theta: float) -> np.ndarray:
    """
    Return the analytic deflection α = ∇ψ for ψ = ψ_SIS + ψ_shear.

      ψ_SIS   = b * sqrt(x^2 + y^2)
      ψ_shear = 0.5*γ*(x^2 - y^2) cos(2θ) + γ x y sin(2θ)

    Differentiating term by term:
      α_x = b x/r + γ (  x cos2θ + y sin2θ )
      α_y = b y/r + γ ( -y cos2θ + x sin2θ )

    The SIS gradient is undefined at exactly r = 0. There r is replaced by 1e-12
    so the division stays finite. Because x = y = 0 at that point, the SIS term
    evaluates to (0, 0); this is a convention, not a limit.
    """
    # hypot is 0.0 only at the origin; `or` substitutes the guard radius there
    radius = math.hypot(x, y) or 1e-12
    cos2t = math.cos(2.0 * theta)
    sin2t = math.sin(2.0 * theta)

    sis_x = b * x / radius
    sis_y = b * y / radius
    shear_x = gamma * (x * cos2t + y * sin2t)
    shear_y = gamma * (-y * cos2t + x * sin2t)
    return np.array([sis_x + shear_x, sis_y + shear_y], dtype=float)

def lens_eq_residual(pt: np.ndarray, beta: np.ndarray, b: float, gamma: float, theta: float) -> float:
    """
    Squared lens-equation mismatch ||θ - α(θ) - β||² at image-plane point ``pt``.

    The value is zero exactly when ``pt`` maps onto ``beta``. solve_images
    minimises it from many starting points to locate the images.
    """
    px, py = float(pt[0]), float(pt[1])
    alpha = deflection_xy(px, py, b, gamma, theta)
    # component-wise source-plane mismatch
    dx = px - alpha[0] - beta[0]
    dy = py - alpha[1] - beta[1]
    return dx * dx + dy * dy

# ----------------------------
# Image solver: robust multi-start + optional deduplication
# ----------------------------
def solve_images(b: float, gamma: float, theta: float,
                 beta: np.ndarray = BETA,
                 seeds: np.ndarray | None = None,
                 tol: float = 1e-12,
                 maxiter: int = 20000,
                 max_residual: float = 1e-10) -> np.ndarray:
    """
    Solve the lens equation θ - α(θ) = β for 4 images via multi-start Nelder–Mead.

    Each seed is minimised on ``lens_eq_residual``. Only terminal points that
    actually satisfy the lens equation are kept, and solutions closer than
    1e-3 arcsec to an already-kept one are merged.

    Args:
      b, gamma, theta: SIS scale, shear amplitude and shear angle [rad].
      beta: source position (default: BETA).
      seeds: (N,2) starting points. When None, a deterministic set is built:
        12 points on the unit circle plus the 4 cardinal directions at radius
        1.3, all jittered by 0.01·N(0,1) from an RNG seeded with RNG_SEED.
      tol: Nelder–Mead ``xatol`` and ``fatol``.
      maxiter: Nelder–Mead iteration cap per seed.
      max_residual: largest squared residual accepted as a genuine image.
        Nelder–Mead's success flag only means the simplex collapsed
        (xatol/fatol met), not that the residual reached zero.

    Returns:
      (4,2) array of image positions (unordered)
    Raises:
      RuntimeError if no valid image is found, or fewer than 4 distinct ones.
    """
    if seeds is None:
        # Deterministic seed ring: 12 angles around the unit circle + cardinals (scaled)
        rng = np.random.default_rng(RNG_SEED)
        thetas = np.linspace(0, 2*np.pi, 12, endpoint=False)
        ring = np.c_[np.cos(thetas), np.sin(thetas)]
        # cardinals for extra robustness; jitter breaks exact symmetry between seeds
        card = np.array([[1,0],[-1,0],[0,1],[0,-1]], dtype=float)
        seeds = np.vstack([ring, 1.3*card])
        seeds += 0.01 * rng.standard_normal(seeds.shape)

    minima = []
    for s in seeds:
        res = minimize(lens_eq_residual, x0=np.array(s, dtype=float),
                       args=(beta, b, gamma, theta),
                       method="Nelder-Mead",
                       options={"xatol": tol, "fatol": tol, "maxiter": maxiter, "disp": False})
        if not res.success:
            continue
        # deflection_xy guards r = 0 by returning zero SIS deflection. That makes the
        # exact origin a zero of the residual when β = 0. An SIS has no central
        # image, so such a point is an artefact of the guard, not a solution.
        if math.hypot(res.x[0], res.x[1]) <= 1e-9:
            continue
        # Reject stalled / spurious stationary points (and NaN) that do not satisfy
        # the lens equation.
        if not (res.fun <= max_residual):
            continue
        minima.append(res.x)

    if not minima:
        raise RuntimeError("No minima found from seeds — check parameters.")

    # Deduplicate minima by clustering near-identical points
    uniq = []
    for p in minima:
        if not any(np.linalg.norm(p - q) < 1e-3 for q in uniq):
            uniq.append(p)
    uniq = np.array(uniq, dtype=float)

    # Expect exactly 4 images for SIS + shear with β=0 in this configuration
    if len(uniq) < 4:
        raise RuntimeError(f"Only {len(uniq)} distinct images found; need 4.")
    # If more than 4 (rare due to jitter), keep the 4 with smallest residuals
    if len(uniq) > 4:
        vals = np.array([lens_eq_residual(p, beta, b, gamma, theta) for p in uniq])
        uniq = uniq[np.argsort(vals)[:4]]

    return uniq

# ----------------------------
# Matching & RMS with assignment (no angle sorting dependence)
# ----------------------------
def match_and_rms(sim: np.ndarray, obs: np.ndarray) -> tuple[np.ndarray, np.ndarray, float]:
    """
    Pair simulated and observed images to minimise the summed squared separation.

    The optimal one-to-one pairing comes from the Hungarian algorithm
    (scipy's linear_sum_assignment). The function then reports the RMS
    separation over the matched pairs.

    Returns (sim_ordered, obs_ordered, RMS); row k of both arrays is one matched pair.
    """
    offsets = sim[:, np.newaxis, :] - obs[np.newaxis, :, :]
    cost = (offsets ** 2).sum(axis=-1)          # (n_sim, n_obs) squared distances
    sim_idx, obs_idx = linear_sum_assignment(cost)
    paired_sim = sim[sim_idx]
    paired_obs = obs[obs_idx]
    sq_sep = ((paired_sim - paired_obs) ** 2).sum(axis=1)
    return paired_sim, paired_obs, np.sqrt(sq_sep.mean())

# ----------------------------
# Objective for parameter fitting (outer loop): b, gamma, theta
# ----------------------------
def objective_params(pvec: np.ndarray, obs: np.ndarray, seeds: np.ndarray | None = None) -> float:
    """
    Outer fitting objective for p = (b, gamma, theta): matched RMS plus a bounds penalty.

    The function solves the images for p (solve_images), pairs them with
    ``obs`` (Hungarian assignment in match_and_rms) and returns the RMS
    separation in arcsec.

    Bounds are enforced by a penalty wall, not a smooth soft bound. A value of at
    least 1e3 is added when b leaves [0.6, 1.4] or gamma leaves [0, 0.3]. It grows
    with the distance from 1.0 (b) or 0.15 (gamma), which steers Nelder–Mead back
    toward the physical region. theta is wrapped to [0, 2π).

    Returns 1e6 + penalty when the images cannot be solved or matched for p.
    """
    b, g, t = float(pvec[0]), float(pvec[1]), float(pvec[2])

    pen = 0.0
    if not (0.6 <= b <= 1.4):
        pen += 1e3 * (abs(b - 1.0) + 1.0)
    if not (0.0 <= g <= 0.3):
        pen += 1e3 * (abs(g - 0.15) + 1.0)
    # theta is periodic; wrapping keeps the value well conditioned
    t = t % (2*np.pi)

    try:
        sim = solve_images(b, g, t, beta=BETA, seeds=seeds)
        _, _, rms = match_and_rms(sim, obs)
    except (RuntimeError, ValueError, ArithmeticError):
        # Expected numerical failures: too few images found (RuntimeError) or a
        # non-finite cost matrix in the assignment (ValueError). Programming
        # errors are deliberately no longer swallowed here.
        return 1e6 + pen

    return float(rms + pen)

def fit_params(obs: np.ndarray,
               p0: tuple[float,float,float] = (0.95, 0.02, 2.2),
               restarts: int = 8,
               seeds: np.ndarray | None = None) -> tuple[dict, float, np.ndarray]:
    """
    Fit (b, gamma, theta) to ``obs`` with multi-start Nelder–Mead on objective_params.

    Starting points:
      - a 4×4×3 grid in (b, gamma, theta) built around ``p0`` and the paper solution;
      - ``restarts`` Gaussian draws around (0.92, 0.010, 2.329) from an RNG seeded
        with RNG_SEED.

    Candidates are ranked by the penalised objective value (``res.fun``) and
    compared only against the best penalised value seen so far. A start that
    ends outside the bounds enforced in objective_params (penalty ≥ 1e3) can
    therefore not outrank a valid in-bounds solution. The reported RMS is
    recomputed without the penalty.

    Returns:
      (best_params, rms, sim_images):
        - the params dict, with theta wrapped to [0, 2π);
        - the unpenalised RMS [arcsec];
        - the (4,2) simulated image positions.
    Raises:
      RuntimeError if no start produced a solvable parameter set.
    """
    # Coarse grid around p0 + a few canonical points near the paper's solution
    b_grid   = np.array([0.85, 0.92, 1.00, p0[0]], dtype=float)
    g_grid   = np.array([0.005, 0.010, 0.020, p0[1]], dtype=float)
    t_center = p0[2]
    t_grid   = np.array([t_center-0.4, 2.329, t_center+0.4], dtype=float)

    start_list = [(b, g, t) for b in b_grid for g in g_grid for t in t_grid]
    # Add a few randomized starts near the expected basin (deterministic RNG)
    rng = np.random.default_rng(RNG_SEED)
    for _ in range(restarts):
        start_list.append((
            0.92  + 0.05*rng.standard_normal(),
            0.010 + 0.01*rng.standard_normal(),
            2.329 + 0.3*rng.standard_normal(),
        ))

    best_params = None
    best_rms = np.inf
    best_score = np.inf   # penalised objective of the current best; the ranking key
    best_sim = None

    for s in start_list:
        res = minimize(objective_params, x0=np.array(s, dtype=float),
                       args=(obs, seeds), method="Nelder-Mead",
                       options={"xatol": 1e-6, "fatol": 1e-8, "maxiter": 2000, "disp": False})
        # `not (<)` also rejects NaN objective values
        if not (res.fun < best_score):
            continue
        b, g, t = float(res.x[0]), float(res.x[1]), float(res.x[2] % (2*np.pi))
        try:
            sim = solve_images(b, g, t, beta=BETA, seeds=seeds)
            _, _, rms = match_and_rms(sim, obs)
        except (RuntimeError, ValueError, ArithmeticError):
            continue
        best_params = dict(b=b, gamma=g, theta=t)
        best_rms = rms
        best_score = float(res.fun)
        best_sim = sim

    if best_params is None:
        raise RuntimeError("Parameter fit failed to converge to a valid solution.")

    return best_params, float(best_rms), best_sim

# ----------------------------
# Plot & save
# ----------------------------
def make_figure(sim: np.ndarray, obs: np.ndarray, params: dict, fname: str) -> None:
    """
    Plot simulated vs. observed image positions and save the figure to ``fname``.

    Images are paired with match_and_rms, and each pair is joined by a dashed
    grey line.

    The dashed black curve is the tangential critical curve of the SIS + shear
    model for ``params`` (keys ``b``, ``gamma``, ``theta``). Solving
    det(I - ∇∇ψ) = 0 for ψ = b r + ½γ(x²−y²)cos2θγ + γxy sin2θγ gives

        r(φ) = b [1 − γ cos 2(φ − θγ)] / (1 − γ²),

    which reduces to the Einstein ring r = b when γ = 0. The curve is omitted
    when ``b`` is missing or |γ| ≥ 1, where the formula degenerates.

    Prints the saved path and the matched RMS.
    """
    sim_ord, obs_ord, rms = match_and_rms(sim, obs)

    fig, ax = plt.subplots(figsize=(7,7))
    try:
        phi = np.linspace(0, 2*np.pi, 360)
        b = params.get("b")
        gamma = float(params.get("gamma", 0.0))
        theta = float(params.get("theta", 0.0))
        if b is not None and abs(gamma) < 1.0:
            r_crit = float(b) * (1.0 - gamma*np.cos(2.0*(phi - theta))) / (1.0 - gamma**2)
            ax.plot(r_crit*np.cos(phi), r_crit*np.sin(phi), ls="--", color="k", alpha=0.3,
                    label="Critical curve")

        ax.scatter(sim_ord[:,0], sim_ord[:,1], s=60, color="crimson", label="Simulated (SIS + shear)")
        ax.scatter(obs_ord[:,0], obs_ord[:,1], s=80, facecolors="none", edgecolors="royalblue",
                   linewidths=2, label="Observed (Q2237+0305)")
        ax.scatter(0.0, 0.0, marker="*", s=90, color="k", label="Source (β=0)")

        # Connect matched pairs
        for s, o in zip(sim_ord, obs_ord):
            ax.plot([s[0], o[0]], [s[1], o[1]], ls="--", lw=1.0, color="gray", alpha=0.7)

        ax.set_aspect("equal")
        ax.set_xlim(-1.6, 1.6)
        ax.set_ylim(-1.6, 1.6)
        ax.set_xlabel("x [arcsec]")
        ax.set_ylabel("y [arcsec]")
        ax.legend(loc="upper right", fontsize=9)
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(fname, dpi=300)
    finally:
        # release the figure even if drawing or saving fails
        plt.close(fig)
    print(f"Saved figure -> {fname}")
    print(f"RMS residual = {rms:.4f} arcsec")
    print(f"RMS residual = {rms:.4f} arcsec")

# ----------------------------
# Run
# ----------------------------
def env_meta():
    """
    Collect interpreter/library versions and, best-effort, the git commit for provenance.

    Returns a JSON-serialisable dict with keys python, numpy, scipy, matplotlib,
    platform and git_commit. git_commit is None in three cases:
      - git is not installed;
      - the working directory is not inside a repository;
      - the call takes longer than 5 s.
    git's own error output is suppressed.
    """
    import subprocess
    import numpy, scipy, matplotlib

    git_hash = None
    try:
        git_hash = subprocess.check_output(
            ["git", "rev-parse", "HEAD"],
            text=True,
            stderr=subprocess.DEVNULL,   # avoid "fatal: not a git repository" noise
            timeout=5,
        ).strip()
    except (OSError, subprocess.SubprocessError):
        # OSError: git binary missing; SubprocessError: non-zero exit or timeout
        pass
    return {
        "python": sys.version.split()[0],
        "numpy": numpy.__version__,
        "scipy": scipy.__version__,
        "matplotlib": matplotlib.__version__,
        "platform": platform.platform(),
        "git_commit": git_hash
    }

def main():
    """
    Run the Einstein Cross pipeline selected by MODE and write its outputs.

    "eval" solves the images for FIXED_PARAMS. "fit" fits (b, gamma, theta) to
    OBSERVED. In both modes the figure is saved to FIG_NAME, and a JSON
    provenance snapshot (params, matched positions, RMS, environment) is
    written to JSON_NAME.

    Raises:
      ValueError: if MODE is neither "fit" nor "eval" (case-insensitive).
    """
    # One jittered 16-point unit ring shared by every image solve (reproducible via RNG_SEED)
    jitter_rng = np.random.default_rng(RNG_SEED)
    angles = np.linspace(0, 2*np.pi, 16, endpoint=False)
    unit_ring = np.column_stack((np.cos(angles), np.sin(angles)))
    seeds = unit_ring + 0.005 * jitter_rng.standard_normal(unit_ring.shape)

    mode = MODE.lower()
    if mode == "eval":
        params = dict(FIXED_PARAMS)  # private copy; FIXED_PARAMS stays untouched
        sim = solve_images(params["b"], params["gamma"], params["theta"], beta=BETA, seeds=seeds)
        rms = match_and_rms(sim, OBSERVED)[2]
        print(f"[EVAL] b={params['b']:.6f}, gamma={params['gamma']:.6f}, theta={params['theta']:.6f} rad")
        print(f"[EVAL] RMS residual = {rms:.4f} arcsec")
    elif mode == "fit":
        print("[FIT] Solving for (b, gamma, theta) from observed positions …")
        params, rms, sim = fit_params(OBSERVED, p0=(0.95, 0.02, 2.2), restarts=10, seeds=seeds)
        print(f"[FIT] b={params['b']:.6f}, gamma={params['gamma']:.6f}, theta={params['theta']:.6f} rad")
        print(f"[FIT] RMS residual = {rms:.4f} arcsec")
    else:
        raise ValueError("MODE must be 'fit' or 'eval'.")

    make_figure(sim, OBSERVED, params, FIG_NAME)

    # Provenance snapshot with the matched (paired) positions
    sim_ord, obs_ord, rms_final = match_and_rms(sim, OBSERVED)
    snapshot = dict(
        mode=MODE,
        params=params,
        rms_arcsec=float(rms_final),
        observed_xy_arcsec=obs_ord.tolist(),
        simulated_xy_arcsec=sim_ord.tolist(),
        beta_xy_arcsec=BETA.tolist(),
        paper_reference={"equation": 154, "section": "5.11"},
        environment=env_meta(),
    )
    with open(JSON_NAME, "w", encoding="utf-8") as out_file:
        json.dump(snapshot, out_file, indent=2)
    print(f"Saved params -> {JSON_NAME}")

if __name__ == "__main__":
    # Run the pipeline when executed as a script (or notebook cell), not when imported.
    main()


[FIT] Solving for (b, gamma, theta) from observed positions …
[FIT] b=0.928132, gamma=0.005536, theta=2.342936 rad
[FIT] RMS residual = 0.0327 arcsec
Saved figure -> einstein_cross.png
RMS residual = 0.0327 arcsec
Saved params -> einstein_cross_params.json
