# RSG Lensing Quicklook (Geometry & Heuristics) - FULL VERSION

**Paper:** *Radial Scaling Gauge for Maxwell Fields*  
**Authors:** Carmen N. Wrede, Lino P. Casu

## IMPORTANT DISCLAIMER
**This UI provides geometric heuristics and pattern recognition, NOT true lensing inversion.**

- "Ring Fit" on the 4 image positions of a quad is a geometric circle fit, not a physical ring analysis
- "Morphology" classification is heuristic pattern matching, not physical model selection
- "Harmonics" from radial residuals are Fourier diagnostics, not lens model parameters

**For a true inversion (solving the lens equation β = θ - α(θ;p) with source-consistency checks), see `src/gauge_lens_inversion.py`**

---

## Quicklook Features (Geometry/Heuristics)
- Morphology Classification (Ring, Quad, Arc, Double)
- Ring Analysis with Harmonic Decomposition (m=2, m=3, m=4)
- Model Zoo: 8 lens models with stepwise derivation
- Exact Linear Solvers (NO optimization!)
- Regime Classification (Determined/Overdetermined/Underdetermined)
- 3D Visualization (Observer-Lens-Source geometry)
- Diagnostic Tools & Solution Quality Scoring

**Run all cells to launch Gradio with shareable link!**

In [None]:
#@title 1. Install Dependencies (Run First!)
!pip install gradio matplotlib numpy -q
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from dataclasses import dataclass, field
from typing import Tuple, Optional, Dict, List
from enum import Enum
import json as json_module
import os
from datetime import datetime
print("Dependencies ready! Run next cells.")


In [None]:
!pip install -q gradio numpy matplotlib

In [None]:
import numpy as np
from enum import Enum
from dataclasses import dataclass, field
from typing import List, Tuple, Optional, Dict
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# === ENUMERATIONS ===
class Morphology(Enum):
    """Image-configuration labels produced by the heuristic morphology classifier."""
    RING = "ring"
    QUAD = "quad"
    ARC = "arc"
    DOUBLE = "double"
    UNKNOWN = "unknown"  # fallback label when no other pattern matches

class Regime(Enum):
    """Solvability regimes that a linear system matrix can be assigned to."""
    DETERMINED = "determined"            # as many constraints as parameters, full rank
    OVERDETERMINED = "overdetermined"    # more constraints than parameters
    UNDERDETERMINED = "underdetermined"  # fewer constraints than parameters / non-trivial null space
    ILL_CONDITIONED = "ill_conditioned"  # condition number above the classifier's threshold

class ModelFamily(Enum):
    """Identifiers of the eight lens-model families in the model zoo.

    Each member name lists the terms the family contains (m=2, optional
    shear, optional m=3, optional m=4).
    """
    M2 = "m2"
    M2_SHEAR = "m2_shear"
    M2_M3 = "m2_m3"
    M2_SHEAR_M3 = "m2_shear_m3"
    M2_M4 = "m2_m4"
    M2_SHEAR_M4 = "m2_shear_m4"
    M2_M3_M4 = "m2_m3_m4"
    M2_SHEAR_M3_M4 = "m2_shear_m3_m4"

# === DATA CLASSES ===
@dataclass
class MorphologyAnalysis:
    """Result record returned by ``MorphologyClassifier.classify``."""
    primary: Morphology              # assigned morphology label
    confidence: float                # heuristic score attached to the label (not a likelihood)
    mean_radius: float               # mean distance of the points from the classifier centre
    radial_scatter: float            # std(r) / mean(r); 1.0 when the mean radius is not positive
    azimuthal_coverage: float        # 1 - (largest angular gap) / (2*pi)
    azimuthal_uniformity: float      # 1 / (1 + var(gaps) / (2*pi/n)**2)
    m2_amplitude: float              # m=2 Fourier amplitude of (r - mean r), divided by mean r
    m4_amplitude: float              # m=4 Fourier amplitude of (r - mean r), divided by mean r
    recommended_models: List[str]    # model names suggested for the detected morphology
    notes: List[str]                 # human-readable remarks explaining the classification

@dataclass
class RingFitResult:
    """Result record returned by ``RingAnalyzer.fit_ring``."""
    center_x: float                      # x coordinate of the ring centre used for the fit
    center_y: float                      # y coordinate of the ring centre used for the fit
    radius: float                        # median distance of the points from the centre
    radial_residuals: np.ndarray         # per-point r - radius
    azimuthal_angles: np.ndarray         # per-point arctan2(y, x) about the centre, in radians
    rms_residual: float                  # sqrt(mean(residual**2))
    m2_component: Tuple[float, float]    # (amplitude, phase) of the m=2 harmonic of the residuals
    m3_component: Tuple[float, float]    # (amplitude, phase) of the m=3 harmonic of the residuals
    m4_component: Tuple[float, float]    # (amplitude, phase) of the m=4 harmonic of the residuals
    is_perturbed: bool                   # True when any harmonic amplitude exceeds 2% of the radius
    perturbation_type: str               # " + "-joined labels of the harmonics above threshold, or "isotropic"

@dataclass
class RegimeAnalysis:
    """Result record returned by ``RegimeClassifier.classify``."""
    regime: Regime                 # assigned solvability regime
    n_constraints: int             # number of rows of the system matrix
    n_params: int                  # number of columns of the system matrix
    rank: int                      # numerical rank (singular values above tolerance)
    nullspace_dim: int             # n_params - rank
    condition_number: float        # largest / smallest singular value, inf when the smallest is below tolerance
    explanation: str = ""          # one-sentence description of the regime
    recommendations: List[str] = field(default_factory=list)  # suggested next steps

@dataclass
class ModelConfig:
    """Static description of one lens-model family in the model zoo."""
    family: ModelFamily        # enum member this configuration belongs to
    m_max: int                 # highest multipole order named by the model
    include_shear: bool        # whether the model has a shear term
    # NOTE(review): defaults to True, so any configuration that does not pass
    # include_m3 explicitly is flagged as containing an m=3 term.
    include_m3: bool = True
    include_m4: bool = False   # whether the model has an m=4 term
    label: str = ""            # human-readable name used in plots
    n_lens_params: int = 0     # number of lens parameters (source position not included)

@dataclass
class Position3D:
    """A labelled point in 3-D space; all coordinates default to zero."""
    x: float = 0.0
    y: float = 0.0
    z: float = 0.0
    label: str = ""

    def to_array(self):
        """Return the coordinates as a NumPy array ``[x, y, z]``."""
        coords = [self.x, self.y, self.z]
        return np.array(coords)

@dataclass
class LensProperties:
    """Position and basic shape parameters of the lens."""
    position: Position3D            # location of the lens in the scene
    einstein_radius: float = 1.0    # Einstein radius (units not fixed here)
    ellipticity: float = 0.0        # lens ellipticity
    position_angle: float = 0.0     # orientation angle of the lens (convention not fixed here)

@dataclass
class SourceProperties:
    """A single background source in the scene."""
    position: Position3D    # location of the source
    source_id: int = 0      # identifier; TriadScene.add_source defaults it to the insertion index

@dataclass
class TriadScene:
    """Observer / lens / source arrangement used for the 3-D visualisations.

    The observer defaults to the origin. When no lens is supplied, a default
    lens at ``z = 1.0`` is installed after construction.
    """
    name: str
    observer: Position3D = field(default_factory=lambda: Position3D(0, 0, 0, "Observer"))
    lens: LensProperties = None
    sources: List[SourceProperties] = field(default_factory=list)

    def __post_init__(self):
        """Fill in the default lens if the caller did not provide one."""
        if self.lens is not None:
            return
        default_position = Position3D(0, 0, 1.0, "Lens")
        self.lens = LensProperties(position=default_position)

    def add_source(self, x, y, z, source_id=None):
        """Append a source at ``(x, y, z)``.

        When ``source_id`` is omitted, the current number of sources is used
        as the identifier.
        """
        sid = len(self.sources) if source_id is None else source_id
        where = Position3D(x, y, z, f"Source_{sid}")
        self.sources.append(SourceProperties(position=where, source_id=sid))

    @classmethod
    def create_standard(cls, name, D_L=1.0, D_S=2.0, beta_x=0.1, beta_y=-0.05, theta_E=1.0):
        """Build a scene with the lens on the z axis at ``D_L`` and one source.

        The source is placed at ``(beta_x * D_S, beta_y * D_S, D_S)``.
        """
        lens = LensProperties(
            position=Position3D(0, 0, D_L, "Lens"),
            einstein_radius=theta_E,
        )
        scene = cls(name=name, lens=lens)
        scene.add_source(beta_x * D_S, beta_y * D_S, D_S)
        return scene

# === MODEL ZOO ===
# Lookup table: one ModelConfig per ModelFamily member.
# include_m3 is passed explicitly for the families without an m=3 term,
# because the ModelConfig default for that flag is True; relying on the
# default would mark "m=2 only" and "m=2 + shear" as containing m=3.
MODEL_CONFIGS = {
    ModelFamily.M2: ModelConfig(ModelFamily.M2, 2, False, include_m3=False, label="m=2 only", n_lens_params=3),
    ModelFamily.M2_SHEAR: ModelConfig(ModelFamily.M2_SHEAR, 2, True, include_m3=False, label="m=2 + shear", n_lens_params=5),
    ModelFamily.M2_M3: ModelConfig(ModelFamily.M2_M3, 3, False, include_m3=True, label="m=2 + m=3", n_lens_params=5),
    ModelFamily.M2_SHEAR_M3: ModelConfig(ModelFamily.M2_SHEAR_M3, 3, True, include_m3=True, label="m=2 + shear + m=3", n_lens_params=7),
    ModelFamily.M2_M4: ModelConfig(ModelFamily.M2_M4, 4, False, include_m3=False, include_m4=True, label="m=2 + m=4", n_lens_params=5),
    ModelFamily.M2_SHEAR_M4: ModelConfig(ModelFamily.M2_SHEAR_M4, 4, True, include_m3=False, include_m4=True, label="m=2 + shear + m=4", n_lens_params=7),
    ModelFamily.M2_M3_M4: ModelConfig(ModelFamily.M2_M3_M4, 4, False, include_m3=True, include_m4=True, label="m=2 + m=3 + m=4", n_lens_params=7),
    ModelFamily.M2_SHEAR_M3_M4: ModelConfig(ModelFamily.M2_SHEAR_M3_M4, 4, True, include_m3=True, include_m4=True, label="MAXIMAL", n_lens_params=9),
}

def get_derivation_chain(include_m4=False):
    """Return the ordered list of model families for the stepwise derivation.

    The four families without an m=4 term always come first; when
    ``include_m4`` is true the four m=4 families are appended after them.
    A fresh list is returned on every call.
    """
    without_m4 = [
        ModelFamily.M2,
        ModelFamily.M2_SHEAR,
        ModelFamily.M2_M3,
        ModelFamily.M2_SHEAR_M3,
    ]
    if not include_m4:
        return without_m4
    with_m4 = [
        ModelFamily.M2_M4,
        ModelFamily.M2_SHEAR_M4,
        ModelFamily.M2_M3_M4,
        ModelFamily.M2_SHEAR_M3_M4,
    ]
    return without_m4 + with_m4

print("Classes and Model Zoo loaded")

In [None]:
# === MORPHOLOGY CLASSIFIER ===
class MorphologyClassifier:
    """
    HEURISTIC morphology classifier based on geometric patterns.
    
    WARNING: This is NOT physical model selection!
    - Uses radial scatter and azimuthal coverage (geometry)
    - Fourier amplitudes are pattern descriptors, not lens parameters
    - Confidence values are heuristic scores, not likelihoods
    
    For true model selection, use lens equation inversion with 
    residual-based model comparison (see src/gauge_lens_inversion.py)
    """
    def __init__(self, center=(0.0, 0.0)):
        """Store the reference centre about which radii and angles are measured."""
        self.center = np.array(center)

    @staticmethod
    def _relative_amplitude(r, r_mean, phi, m):
        """Return the order-``m`` Fourier amplitude of ``r - r_mean`` divided by ``r_mean``."""
        dr = r - r_mean
        c = np.mean(dr * np.cos(m * phi))
        s = np.mean(dr * np.sin(m * phi))
        return np.sqrt(c**2 + s**2) / r_mean

    def classify(self, positions):
        """Classify a set of image positions.

        Args:
            positions: array-like of shape ``(n, 2)`` with x/y coordinates.

        Returns:
            MorphologyAnalysis with the label, heuristic confidence and the
            geometric descriptors the decision was based on.

        Raises:
            ValueError: if ``positions`` is empty or not of shape ``(n, 2)``.
        """
        pts = np.asarray(positions, dtype=float)
        if pts.ndim != 2 or pts.shape[1] != 2 or pts.shape[0] == 0:
            raise ValueError("positions must be a non-empty array of shape (n, 2)")

        n = len(pts)
        rel = pts - self.center
        r = np.sqrt(rel[:, 0]**2 + rel[:, 1]**2)
        phi = np.arctan2(rel[:, 1], rel[:, 0])
        r_mean, r_std = np.mean(r), np.std(r)

        # Every quantity normalised by the mean radius shares one guard, so a
        # degenerate configuration (all points on the centre) cannot divide by zero.
        has_radius = r_mean > 0
        radial_scatter = r_std / r_mean if has_radius else 1.0

        # Angular gaps between neighbouring points, including the wrap-around gap.
        phi_sorted = np.sort(phi)
        gaps = np.diff(phi_sorted)
        gaps = np.append(gaps, 2*np.pi + phi_sorted[0] - phi_sorted[-1])
        azimuthal_coverage = 1.0 - np.max(gaps) / (2*np.pi)
        azimuthal_uniformity = 1.0 / (1.0 + np.var(gaps) / (2*np.pi/n)**2)

        if has_radius:
            m2_amp = self._relative_amplitude(r, r_mean, phi, 2)
            m4_amp = self._relative_amplitude(r, r_mean, phi, 4)
        else:
            m2_amp = m4_amp = 0.0

        notes, models = [], []
        if n == 4:
            primary, conf = Morphology.QUAD, 0.9
            notes.append("Quad: 4 discrete images (Einstein Cross)")
            models = ["m2", "m2+shear", "m2+m3"]
        elif n == 2:
            primary, conf = Morphology.DOUBLE, 0.9
            notes.append("Double: two-image system")
            models = ["m2"]
        elif n > 4 and radial_scatter < 0.05 and azimuthal_coverage > 0.7:
            primary, conf = Morphology.RING, min(0.95, 1 - radial_scatter/0.05)
            notes.append("Ring-like: low scatter, high coverage")
            models = ["isotropic"]
            if m2_amp > 0.005:
                models.extend(["isotropic+shear", "m2"])
                notes.append(f"m=2: {m2_amp:.4f}")
            if m4_amp > 0.005:
                models.append("m2+m4")
                notes.append(f"m=4: {m4_amp:.4f}")
        elif n > 4 and azimuthal_coverage < 0.5:
            primary, conf = Morphology.ARC, 0.7
            notes.append("Arc-like: partial ring")
            models = ["m2", "isotropic"]
        else:
            primary, conf = Morphology.UNKNOWN, 0.5
            notes.append("Mixed/uncertain morphology")
            models = ["m2", "m2+shear"]

        return MorphologyAnalysis(primary, conf, r_mean, radial_scatter, azimuthal_coverage,
                                  azimuthal_uniformity, m2_amp, m4_amp, models, notes)

# === RING ANALYZER ===
class RingAnalyzer:
    """Geometric circle fit plus Fourier diagnostics of the radial residuals."""

    def fit_ring(self, positions, initial_center=None):
        """Fit a circle to ``positions`` and decompose the radial residuals.

        Args:
            positions: array-like of shape ``(n, 2)`` with x/y coordinates.
            initial_center: optional ``(cx, cy)``; when omitted the centre is
                estimated by ``_estimate_center``.

        Returns:
            RingFitResult whose radius is the median distance from the centre
            and whose harmonics (m=2, 3, 4) are ``(amplitude, phase)`` pairs.
        """
        positions = np.asarray(positions, dtype=float)
        if initial_center is None:
            cx, cy = self._estimate_center(positions)
        else:
            cx, cy = initial_center

        rel = positions - np.array([cx, cy])
        r = np.sqrt(rel[:, 0]**2 + rel[:, 1]**2)
        phi = np.arctan2(rel[:, 1], rel[:, 0])
        radius = np.median(r)
        dr = r - radius
        rms = np.sqrt(np.mean(dr**2))

        m2_amp, m2_phase = self._fit_harmonic(dr, phi, 2)
        m3_amp, m3_phase = self._fit_harmonic(dr, phi, 3)
        m4_amp, m4_phase = self._fit_harmonic(dr, phi, 4)

        # A harmonic counts as a perturbation when it exceeds 2% of the radius.
        thresh = 0.02 * radius
        perturbs = []
        if m2_amp > thresh:
            perturbs.append("m=2 (shear)")
        if m3_amp > thresh:
            perturbs.append("m=3 (octupole)")
        if m4_amp > thresh:
            perturbs.append("m=4 (hexadecapole)")

        ptype = " + ".join(perturbs) if perturbs else "isotropic"

        return RingFitResult(cx, cy, radius, dr, phi, rms,
                             (m2_amp, m2_phase), (m3_amp, m3_phase), (m4_amp, m4_phase),
                             len(perturbs) > 0, ptype)

    def _estimate_center(self, positions):
        """Estimate the circle centre of ``positions``.

        With fewer than three points the centroid is returned. Otherwise the
        linear least-squares system ``x^2 + y^2 = a*x + b*y + c`` is solved
        and ``(a/2, b/2)`` is returned; if the solver fails to converge, the
        centroid is used as a fallback.
        """
        positions = np.asarray(positions, dtype=float)
        n = len(positions)
        x, y = positions[:, 0], positions[:, 1]
        if n < 3:
            return (np.mean(x), np.mean(y))
        A = np.column_stack([x, y, np.ones(n)])
        b = x**2 + y**2
        try:
            coeffs, _, _, _ = np.linalg.lstsq(A, b, rcond=None)
        except np.linalg.LinAlgError:
            # Only a solver failure triggers the fallback; other errors
            # (bad shapes or types) are real bugs and must surface.
            return (np.mean(x), np.mean(y))
        return (coeffs[0] / 2, coeffs[1] / 2)

    def _fit_harmonic(self, dr, phi, m):
        """Return ``(amplitude, phase)`` of the order-``m`` harmonic of ``dr(phi)``.

        The amplitude is twice the modulus of the mean projections onto
        ``cos(m*phi)`` and ``sin(m*phi)``; the phase is ``arctan2(s, c) / m``.
        """
        dr = np.asarray(dr, dtype=float)
        phi = np.asarray(phi, dtype=float)
        c = np.mean(dr * np.cos(m * phi))
        s = np.mean(dr * np.sin(m * phi))
        amp = 2 * np.sqrt(c**2 + s**2)
        phase = np.arctan2(s, c) / m
        return (amp, phase)

# === REGIME CLASSIFIER ===
class RegimeClassifier:
    """Classify a linear system matrix by rank, null space and conditioning."""

    @classmethod
    def classify(cls, A, param_names, condition_threshold=1e10):
        """Assign a solvability regime to the system matrix ``A``.

        Args:
            A: 2-D array-like of shape ``(n_constraints, n_params)``.
            param_names: parameter names; accepted for interface
                compatibility and not used in the computation.
            condition_threshold: condition numbers above this value yield
                ``Regime.ILL_CONDITIONED``.

        Returns:
            RegimeAnalysis with rank, null-space dimension, condition number,
            an explanation and recommendations.

        Raises:
            ValueError: if ``A`` is empty or not two-dimensional.
        """
        A = np.asarray(A)
        if A.ndim != 2 or A.size == 0:
            raise ValueError("A must be a non-empty 2-D matrix")

        n_constraints, n_params = A.shape
        # Only the singular values are needed, so U and Vt are not computed.
        s = np.linalg.svd(A, compute_uv=False)
        tol = max(n_constraints, n_params) * np.finfo(float).eps * s[0]
        rank = int(np.sum(s > tol))
        condition = s[0] / s[-1] if s[-1] > tol else float('inf')
        nullspace_dim = n_params - rank

        # NOTE(review): the conditioning test comes first, so a rank-deficient
        # square or tall matrix (condition = inf) is reported as
        # ILL_CONDITIONED rather than UNDERDETERMINED.
        if condition > condition_threshold:
            regime = Regime.ILL_CONDITIONED
        elif n_constraints < n_params or nullspace_dim > 0:
            regime = Regime.UNDERDETERMINED
        elif n_constraints == n_params and nullspace_dim == 0:
            regime = Regime.DETERMINED
        else:
            regime = Regime.OVERDETERMINED

        explanations = {
            Regime.DETERMINED: "Exactly determined. Unique solution.",
            Regime.OVERDETERMINED: f"Overdetermined with {n_constraints - n_params} extra constraints.",
            Regime.UNDERDETERMINED: f"Underdetermined: {nullspace_dim} free parameters.",
            Regime.ILL_CONDITIONED: f"Ill-conditioned (cond={condition:.2e})."
        }
        recs = {
            Regime.DETERMINED: ["Proceed with exact linear solve"],
            Regime.OVERDETERMINED: ["Use residuals as model diagnostic"],
            Regime.UNDERDETERMINED: [f"Add {nullspace_dim} more constraints or reduce model"],
            Regime.ILL_CONDITIONED: ["Run sensitivity analysis"]
        }

        return RegimeAnalysis(regime, n_constraints, n_params, rank, nullspace_dim, condition,
                              explanations[regime], recs[regime])

# === SYNTHETIC DATA ===
def generate_ring_points(theta_E=1.0, n_points=50, center=(0.0, 0.0), 
                         c2=0.0, s2=0.0, c3=0.0, s3=0.0, c4=0.0, s4=0.0, noise=0.0,
                         arc_fraction=1.0):
    """Generate synthetic points on a (possibly perturbed, possibly partial) ring.

    The radius at angle ``phi`` is ``theta_E`` plus cosine/sine harmonics of
    order 2, 3 and 4 with coefficients ``c2, s2, c3, s3, c4, s4``.

    Args:
        theta_E: unperturbed ring radius.
        n_points: number of points, equally spaced in angle.
        center: ``(x, y)`` of the ring centre.
        c2, s2, c3, s3, c4, s4: harmonic coefficients added to the radius.
        noise: standard deviation of Gaussian noise added independently to x
            and y (drawn from NumPy's global random state); no noise if <= 0.
        arc_fraction: fraction of the full circle covered, starting at angle
            0. The default 1.0 gives a complete ring; smaller values give an arc.

    Returns:
        Array of shape ``(n_points, 2)`` with x/y coordinates.

    Raises:
        ValueError: if ``arc_fraction`` is not in the interval (0, 1].
    """
    if not 0.0 < arc_fraction <= 1.0:
        raise ValueError("arc_fraction must be in the interval (0, 1]")

    # endpoint=False keeps the spacing uniform and avoids duplicating the
    # first point on a closed ring.
    phi = np.linspace(0, 2*np.pi*arc_fraction, n_points, endpoint=False)
    r = theta_E + c2*np.cos(2*phi) + s2*np.sin(2*phi) + c3*np.cos(3*phi) + s3*np.sin(3*phi) + c4*np.cos(4*phi) + s4*np.sin(4*phi)
    x = center[0] + r * np.cos(phi)
    y = center[1] + r * np.sin(phi)
    if noise > 0:
        x += np.random.normal(0, noise, n_points)
        y += np.random.normal(0, noise, n_points)
    return np.column_stack([x, y])

print("Classifiers and Analyzers loaded")

In [None]:
# === VISUALIZATION FUNCTIONS ===

def plot_3d_scene(scene, images=None):
    """Draw the observer, lens and sources of ``scene`` in a 3-D axes.

    Args:
        scene: object exposing ``lens.position``, ``sources`` and ``name``.
        images: optional sequence of ``(x, y)`` image positions; for each one
            a line is drawn from the origin to the lens plane.

    Returns:
        The created matplotlib figure.
    """
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')

    ax.scatter([0], [0], [0], c='blue', s=150, marker='o', label='Observer')
    L = scene.lens.position
    ax.scatter([L.x], [L.y], [L.z], c='red', s=200, marker='s', label='Lens')
    # Fixed-size circle in the lens plane as a visual cue only.
    theta = np.linspace(0, 2*np.pi, 50)
    ax.plot(L.x + 0.3*np.cos(theta), L.y + 0.3*np.sin(theta), [L.z]*50, 'r-', alpha=0.5)

    for i, src in enumerate(scene.sources):
        S = src.position
        # Label only the first source so the legend shows a single entry.
        ax.scatter([S.x], [S.y], [S.z], c='gold', s=200, marker='*',
                   label='Source' if i == 0 else None)

    if images is not None and len(images) > 0:
        D_L = L.z
        colors = plt.cm.viridis(np.linspace(0.2, 0.8, len(images)))
        for img, c in zip(images, colors):
            # The 0.5 factor is a display scaling of the image offsets.
            x_L, y_L = img[0]*D_L*0.5, img[1]*D_L*0.5
            ax.plot([0, x_L], [0, y_L], [0, D_L], color=c, linewidth=2, alpha=0.7)

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z (distance)')
    ax.set_title(f'3D Lensing Geometry: {scene.name}')
    ax.legend(loc='upper left')
    return fig

def plot_lens_plane(images, theta_E=1.0, center=(0,0), title="Lens Plane"):
    """Plot numbered image positions together with a reference circle.

    Args:
        images: array of shape ``(n, 2)`` with x/y image positions.
        theta_E: radius of the dashed reference circle.
        center: ``(x, y)`` of the circle centre and centre marker.
        title: axes title.

    Returns:
        The created matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    cx, cy = center[0], center[1]

    # Reference circle.
    angle = np.linspace(0, 2*np.pi, 100)
    ax.plot(cx+theta_E*np.cos(angle), cy+theta_E*np.sin(angle), 'b--',
            lw=2, alpha=0.6, label=f'Einstein ring (R={theta_E:.3f})')

    # Image markers with 1-based index labels offset slightly from each point.
    ax.scatter(images[:, 0], images[:, 1], c='red', s=100, marker='o', label='Images', zorder=5)
    for number, point in enumerate(images, start=1):
        ax.annotate(f'{number}', (point[0]+0.05, point[1]+0.05), fontsize=12, fontweight='bold')

    ax.scatter([cx], [cy], c='black', s=100, marker='+', lw=3, label='Center')

    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_title(title)
    ax.set_aspect('equal')
    ax.legend()
    ax.grid(True, alpha=0.3)
    return fig

def plot_ring_analysis(positions, ring):
    """Four-panel summary of a ring fit.

    Panels, left to right: points with the fitted circle, radial residuals
    against angle with the fitted harmonics, a bar chart of harmonic
    amplitudes with the 2% threshold, and an arrow diagram of the phases.

    Args:
        positions: array of shape ``(n, 2)`` with the fitted points.
        ring: fit result exposing centre, radius, residuals, angles, the
            ``m2/m3/m4_component`` pairs and ``perturbation_type``.

    Returns:
        The created matplotlib figure.
    """
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    overlay_ax, resid_ax, bar_ax, phase_ax = axes

    # (order, (amplitude, phase), line format, colour) for each harmonic.
    harmonic_specs = [
        (2, ring.m2_component, 'g-', 'green'),
        (3, ring.m3_component, 'purple', 'purple'),
        (4, ring.m4_component, 'orange', 'orange'),
    ]

    # Ring overlay
    circle = np.linspace(0, 2*np.pi, 100)
    overlay_ax.scatter(positions[:, 0], positions[:, 1], c='blue', s=40, alpha=0.7, label='Points')
    overlay_ax.plot(ring.center_x + ring.radius*np.cos(circle),
                    ring.center_y + ring.radius*np.sin(circle),
                    'r-', lw=2, label=f'Fit R={ring.radius:.4f}')
    overlay_ax.scatter([ring.center_x], [ring.center_y], c='red', s=150, marker='+', lw=3)
    overlay_ax.set_aspect('equal')
    overlay_ax.set_title('Ring Overlay')
    overlay_ax.legend()
    overlay_ax.grid(True, alpha=0.3)

    # Residual vs angle
    order = np.argsort(ring.azimuthal_angles)
    resid_ax.scatter(np.degrees(ring.azimuthal_angles[order]), ring.radial_residuals[order], c='blue', s=40)
    resid_ax.axhline(0, color='gray', ls='--')
    phi_m = np.linspace(-np.pi, np.pi, 200)
    for m, component, fmt, _ in harmonic_specs:
        resid_ax.plot(np.degrees(phi_m), component[0]*np.cos(m*phi_m - m*component[1]),
                      fmt, lw=2, alpha=0.8, label=f'm={m}: {component[0]:.4f}')
    resid_ax.set_xlabel('Angle (deg)')
    resid_ax.set_ylabel('Residual')
    resid_ax.set_title('Radial Residuals')
    resid_ax.legend()
    resid_ax.grid(True, alpha=0.3)

    # Harmonic bar chart
    names = [f'm={m}' for m, _, _, _ in harmonic_specs]
    heights = [component[0] for _, component, _, _ in harmonic_specs]
    bar_colors = [color for _, _, _, color in harmonic_specs]
    bar_ax.bar(names, heights, color=bar_colors, alpha=0.7)
    bar_ax.axhline(0.02*ring.radius, color='red', ls='--', label='2% threshold')
    bar_ax.set_ylabel('Amplitude')
    bar_ax.set_title(f'Harmonics: {ring.perturbation_type}')
    bar_ax.legend()

    # Phase diagram: one arrow per harmonic whose amplitude exceeds 0.001.
    for m, (amp, phase), _, color in harmonic_specs:
        if amp > 0.001:
            phase_ax.arrow(0, 0, amp*np.cos(phase), amp*np.sin(phase),
                           head_width=0.01, color=color, label=f'm={m}')
    phase_ax.set_xlim(-0.2, 0.2)
    phase_ax.set_ylim(-0.2, 0.2)
    phase_ax.set_aspect('equal')
    phase_ax.set_title('Perturbation Phases')
    phase_ax.legend()
    phase_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig

def plot_overview(positions, morph, ring, scene):
    """Four-panel overview: 3-D geometry, morphology, ring fit and harmonics.

    Args:
        positions: array of shape ``(n, 2)`` with image positions.
        morph: morphology result exposing ``primary`` (an enum with ``value``).
        ring: ring-fit result (centre, radius, residuals, angles, harmonics).
        scene: scene exposing ``lens.position``, ``sources`` and ``name``.

    Returns:
        The created matplotlib figure.
    """
    fig = plt.figure(figsize=(16, 12))

    # Fitted circle, centred on the fitted centre; shared by both 2-D panels
    # so the circle and the centre marker always agree.
    theta = np.linspace(0, 2*np.pi, 100)
    circle_x = ring.center_x + ring.radius*np.cos(theta)
    circle_y = ring.center_y + ring.radius*np.sin(theta)

    # 3-D geometry
    ax1 = fig.add_subplot(221, projection='3d')
    L = scene.lens.position
    ax1.scatter([0], [0], [0], c='blue', s=100)
    ax1.scatter([L.x], [L.y], [L.z], c='red', s=150)
    for src in scene.sources:
        ax1.scatter([src.position.x], [src.position.y], [src.position.z], c='gold', s=150)
    D_L = L.z
    # Only the first six images are drawn as rays; 0.3 is a display scaling.
    for img in positions[:6]:
        ax1.plot([0, img[0]*D_L*0.3], [0, img[1]*D_L*0.3], [0, D_L], 'g-', alpha=0.4)
    ax1.set_title('3D Geometry')
    ax1.set_xlabel('X')
    ax1.set_ylabel('Y')
    ax1.set_zlabel('Z')

    # Morphology panel
    ax2 = fig.add_subplot(222)
    ax2.plot(circle_x, circle_y, 'b--', alpha=0.6, lw=2)
    ax2.scatter(positions[:, 0], positions[:, 1], c='red', s=60)
    ax2.scatter([ring.center_x], [ring.center_y], c='black', s=80, marker='+')
    ax2.set_aspect('equal')
    ax2.set_title(f'{morph.primary.value.upper()} (heuristic)')
    ax2.grid(True, alpha=0.3)

    # Ring-fit panel
    ax3 = fig.add_subplot(223)
    ax3.scatter(positions[:, 0], positions[:, 1], c='blue', s=40, alpha=0.7)
    ax3.plot(circle_x, circle_y, 'r-', lw=2)
    ax3.scatter([ring.center_x], [ring.center_y], c='red', s=100, marker='+')
    ax3.set_aspect('equal')
    ax3.set_title(f'Ring Fit (R={ring.radius:.4f}, RMS={ring.rms_residual:.4f})')
    ax3.grid(True, alpha=0.3)

    # Residuals with fitted harmonics
    ax4 = fig.add_subplot(224)
    idx = np.argsort(ring.azimuthal_angles)
    ax4.scatter(np.degrees(ring.azimuthal_angles[idx]), ring.radial_residuals[idx], c='blue', s=30)
    ax4.axhline(0, color='gray', ls='--')
    phi_m = np.linspace(-np.pi, np.pi, 100)
    ax4.plot(np.degrees(phi_m), ring.m2_component[0]*np.cos(2*phi_m - 2*ring.m2_component[1]), 'g-', alpha=0.7, label='m=2')
    ax4.plot(np.degrees(phi_m), ring.m3_component[0]*np.cos(3*phi_m - 3*ring.m3_component[1]), 'purple', alpha=0.7, label='m=3')
    ax4.plot(np.degrees(phi_m), ring.m4_component[0]*np.cos(4*phi_m - 4*ring.m4_component[1]), 'orange', alpha=0.7, label='m=4')
    ax4.set_xlabel('Angle')
    ax4.set_ylabel('Residual')
    ax4.set_title(f'Harmonics: {ring.perturbation_type}')
    ax4.legend()
    ax4.grid(True, alpha=0.3)

    plt.suptitle(f'RSG Lensing Analysis: {scene.name}', fontsize=14, fontweight='bold')
    plt.tight_layout()
    return fig

def plot_model_comparison(results):
    """Compare model results in a 2x2 figure.

    Args:
        results: list of dicts with the keys ``'model'``, ``'max_residual'``,
            ``'quality_score'`` and ``'n_params'``.

    Returns:
        The created matplotlib figure with residual bars, quality bars, a
        parameters-versus-residual scatter and a text box for the entry with
        the smallest residual.
    """
    def _residual_color(value):
        # Green below 1e-10, orange below 1e-6, red otherwise.
        if value < 1e-10:
            return 'green'
        if value < 1e-6:
            return 'orange'
        return 'red'

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    residual_ax, score_ax = axes[0, 0], axes[0, 1]
    scatter_ax, summary_ax = axes[1, 0], axes[1, 1]

    models, residuals, scores, params = [], [], [], []
    for entry in results:
        models.append(entry['model'])
        residuals.append(entry['max_residual'])
        scores.append(entry['quality_score'])
        params.append(entry['n_params'])

    bar_colors = [_residual_color(value) for value in residuals]
    residual_ax.barh(models, residuals, color=bar_colors, alpha=0.7)
    residual_ax.set_xscale('log')
    residual_ax.set_xlabel('Max Residual')
    residual_ax.set_title('Residuals by Model')

    score_ax.barh(models, scores, color='steelblue', alpha=0.7)
    score_ax.set_xlim(0, 1)
    score_ax.set_xlabel('Quality Score')
    score_ax.set_title('Quality by Model')

    scatter_ax.scatter(params, residuals, s=100, c=scores, cmap='RdYlGn', vmin=0, vmax=1)
    scatter_ax.set_xlabel('Parameters')
    scatter_ax.set_ylabel('Residual')
    scatter_ax.set_yscale('log')
    scatter_ax.set_title('Complexity vs Quality')

    # Text summary of the entry with the smallest maximum residual.
    summary_ax.axis('off')
    best = min(results, key=lambda entry: entry['max_residual'])
    txt = f"Best: {best['model']}\nResidual: {best['max_residual']:.2e}\nScore: {best['quality_score']:.3f}"
    summary_ax.text(0.1, 0.5, txt, fontsize=14, va='center', family='monospace',
                    bbox=dict(boxstyle='round', facecolor='wheat'))

    plt.tight_layout()
    return fig

print("Visualization functions loaded")

In [None]:
#@title Extended Visualizations: Caustics, Time Delay, Convergence

def plot_caustic_structure(theta_E=1.0, ellipticity=0.1, shear=0.05):
    """Draw illustrative critical curves, an astroid caustic and a magnification map.

    Args:
        theta_E: radius scale of the curves.
        ellipticity: amplitude of the cos(2*theta) modulation of the critical
            curve and size factor of the astroid.
        shear: accepted but not used by any of the three panels.

    Returns:
        The created matplotlib figure.
    """
    fig, (crit_ax, caustic_ax, mag_ax) = plt.subplots(1, 3, figsize=(18, 6))
    angle = np.linspace(0, 2*np.pi, 500)
    cos_a, sin_a = np.cos(angle), np.sin(angle)

    # Critical curve (image plane); the second curve is the same shape at half size.
    r_crit = theta_E * (1 + ellipticity * np.cos(2*angle))
    crit_ax.plot(r_crit*cos_a, r_crit*sin_a, 'b-', lw=2, label='Tangential critical')
    crit_ax.plot(0.5*r_crit*cos_a, 0.5*r_crit*sin_a, 'r--', lw=2, label='Radial critical')
    crit_ax.set_title('Critical Curves (Image Plane)', fontweight='bold')
    crit_ax.set_aspect('equal')
    crit_ax.legend()
    crit_ax.grid(True, alpha=0.3)

    # Caustic (source plane) - astroid
    astroid_x = ellipticity * theta_E * cos_a**3
    astroid_y = ellipticity * theta_E * sin_a**3
    caustic_ax.plot(astroid_x, astroid_y, 'g-', lw=2.5, label='Tangential caustic')
    caustic_ax.scatter([0], [0], c='red', s=100, marker='x', lw=3, label='Lens center')
    caustic_ax.set_title('Caustic Structure (Source Plane)', fontweight='bold')
    caustic_ax.set_aspect('equal')
    caustic_ax.legend()
    caustic_ax.grid(True, alpha=0.3)

    # Magnification map on a 100x100 grid over [-2, 2]^2.
    grid = np.linspace(-2, 2, 100)
    X, Y = np.meshgrid(grid, grid)
    R = np.sqrt(X**2 + Y**2) + 0.01  # small offset keeps R away from zero
    mu = np.clip(1 / np.abs(1 - (theta_E/R)**4), 0, 10)
    filled = mag_ax.contourf(X, Y, mu, levels=20, cmap='hot')
    mag_ax.contour(X, Y, mu, levels=[2, 5, 10], colors='white', linewidths=1)
    plt.colorbar(filled, ax=mag_ax, label='Magnification')
    mag_ax.set_title('Magnification Map', fontweight='bold')
    mag_ax.set_aspect('equal')

    plt.tight_layout()
    return fig

def plot_time_delay_surface(theta_E=1.0, beta=(0.1, 0.05)):
    """Plot the Fermat potential as a contour map and as a 3-D surface.

    The surface is ``0.5*|theta - beta|^2 - theta_E^2 * log(R)`` evaluated on
    a 100x100 grid over [-2, 2]^2.

    Args:
        theta_E: scale of the logarithmic term.
        beta: ``(x, y)`` source position used in the geometric term.

    Returns:
        The created matplotlib figure.
    """
    fig = plt.figure(figsize=(14, 6))

    grid = np.linspace(-2, 2, 100)
    X, Y = np.meshgrid(grid, grid)
    R = np.sqrt(X**2 + Y**2) + 0.01  # small offset keeps log(R) finite at the origin

    # Fermat potential
    src_x, src_y = beta[0], beta[1]
    geom = 0.5 * ((X - src_x)**2 + (Y - src_y)**2)
    grav = -theta_E**2 * np.log(R)
    tau = geom + grav

    contour_ax = fig.add_subplot(121)
    filled = contour_ax.contourf(X, Y, tau, levels=30, cmap='viridis')
    contour_ax.contour(X, Y, tau, levels=15, colors='white', linewidths=0.5, alpha=0.5)
    plt.colorbar(filled, ax=contour_ax, label='Time delay')
    contour_ax.scatter([src_x], [src_y], c='red', s=100, marker='*', label='Source')
    contour_ax.set_title('Fermat Potential (Time Delay Surface)', fontweight='bold')
    contour_ax.set_aspect('equal')
    contour_ax.legend()

    surface_ax = fig.add_subplot(122, projection='3d')
    surface_ax.plot_surface(X, Y, tau, cmap='viridis', alpha=0.8)
    surface_ax.set_title('3D Time Delay Surface', fontweight='bold')

    plt.tight_layout()
    return fig

def plot_convergence_shear(theta_E=1.0):
    """Plot convergence, shear magnitude and a shear-direction quiver.

    Both scalar fields are evaluated on a 100x100 grid over [-2, 2]^2 and
    clipped to the range [0, 2] before plotting.

    Args:
        theta_E: radius scale entering both fields as ``(theta_E / R)**2``.

    Returns:
        The created matplotlib figure.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    grid = np.linspace(-2, 2, 100)
    X, Y = np.meshgrid(grid, grid)
    R = np.sqrt(X**2 + Y**2) + 0.01  # small offset keeps R away from zero

    ratio_sq = (theta_E / R)**2
    kappa = np.clip(0.5 * ratio_sq, 0, 2)   # convergence
    gamma = np.clip(ratio_sq, 0, 2)         # shear magnitude

    # The two filled-contour panels differ only in data, colormap and labels.
    panels = [
        (axes[0], kappa, 'Blues', 'kappa', 'Convergence (mass density)'),
        (axes[1], gamma, 'Reds', 'gamma', 'Shear Magnitude'),
    ]
    for ax, values, cmap, bar_label, title in panels:
        filled = ax.contourf(X, Y, values, levels=20, cmap=cmap)
        plt.colorbar(filled, ax=ax, label=bar_label)
        ax.set_title(title, fontweight='bold')
        ax.set_aspect('equal')

    # Shear direction (quiver), sub-sampled every 5th grid point.
    phi = np.arctan2(Y, X)
    gamma_1 = gamma * np.cos(2*phi)
    gamma_2 = gamma * np.sin(2*phi)
    thin = (slice(None, None, 5), slice(None, None, 5))
    axes[2].quiver(X[thin], Y[thin], gamma_1[thin], gamma_2[thin], gamma[thin],
                   cmap='Reds', alpha=0.7)
    axes[2].set_title('Shear Field (direction)', fontweight='bold')
    axes[2].set_aspect('equal')

    plt.tight_layout()
    return fig

def plot_deflection_field(theta_E=1.0):
    """Plot the deflection field as arrows and its magnitude as a contour map.

    The arrow panel uses a coarse 20x20 grid, the magnitude panel a 100x100
    grid, both over [-2, 2]^2. A circle of radius ``theta_E`` is overlaid on
    both panels.

    Args:
        theta_E: radius scale; the deflection is ``theta_E**2 * (x, y) / R**2``.

    Returns:
        The created matplotlib figure.
    """
    fig, (arrow_ax, mag_ax) = plt.subplots(1, 2, figsize=(14, 6))
    ring_angle = np.linspace(0, 2*np.pi, 100)
    ring_x = theta_E*np.cos(ring_angle)
    ring_y = theta_E*np.sin(ring_angle)

    # Deflection angle (point mass) on the coarse grid.
    coarse = np.linspace(-2, 2, 20)
    X, Y = np.meshgrid(coarse, coarse)
    R = np.sqrt(X**2 + Y**2) + 0.01  # small offset keeps R away from zero
    alpha_x = theta_E**2 * X / R**2
    alpha_y = theta_E**2 * Y / R**2

    arrow_ax.quiver(X, Y, alpha_x, alpha_y, np.sqrt(alpha_x**2 + alpha_y**2), cmap='viridis')
    arrow_ax.plot(ring_x, ring_y, 'r--', lw=2, label='Einstein ring')
    arrow_ax.set_title('Deflection Field', fontweight='bold')
    arrow_ax.set_aspect('equal')
    arrow_ax.legend()

    # Deflection magnitude on the fine grid.
    fine = np.linspace(-2, 2, 100)
    X2, Y2 = np.meshgrid(fine, fine)
    R2 = np.sqrt(X2**2 + Y2**2) + 0.01
    alpha_mag = theta_E**2 / R2

    filled = mag_ax.contourf(X2, Y2, alpha_mag, levels=20, cmap='plasma')
    plt.colorbar(filled, ax=mag_ax, label='|alpha|')
    mag_ax.plot(ring_x, ring_y, 'w--', lw=2)
    mag_ax.set_title('Deflection Magnitude', fontweight='bold')
    mag_ax.set_aspect('equal')

    plt.tight_layout()
    return fig

print('Extended viz loaded: caustics, time delay, convergence, shear, deflection')


In [None]:
#@title Validation Suite & Diagnostic Tools

def run_validation_suite(theta_E=1.0):
    """Run the six built-in validation checks and return their outcomes.

    Uses the module-level ``ring_analyzer`` and ``classifier`` objects and the
    ``generate_ring_points`` / ``generate_quad`` generators. The arc check
    requires ``generate_ring_points`` to accept an ``arc_fraction`` keyword.

    Args:
        theta_E: ring radius used for all synthetic data sets.

    Returns:
        List of dicts with the keys ``'test'``, ``'expected'``, ``'actual'``,
        ``'error'`` and ``'pass'``, one per check, in execution order.
    """
    results = []

    def _record(name, expected, actual, error, passed):
        results.append({
            'test': name,
            'expected': expected,
            'actual': actual,
            'error': error,
            'pass': passed
        })

    def _check_value(name, expected, actual, tolerance):
        # Numeric check: passes when |actual - expected| is below the tolerance.
        deviation = abs(actual - expected)
        _record(name, expected, actual, deviation, deviation < tolerance)

    def _check_morphology(name, expected_text, morph, accepted):
        # Label check: error is 0 when the label is accepted, 1 otherwise.
        matches = morph.primary in accepted if isinstance(accepted, list) else morph.primary == accepted
        _record(name, expected_text, morph.primary.value.upper(), 0 if matches else 1, matches)

    # Test 1: Perfect ring recovery
    fit = ring_analyzer.fit_ring(generate_ring_points(theta_E, n_points=20, noise=0))
    _check_value('Perfect Ring Recovery', theta_E, fit.radius, 1e-10)

    # Test 2: Shear perturbation detection
    fit = ring_analyzer.fit_ring(generate_ring_points(theta_E, n_points=20, c2=0.05))
    _check_value('Shear Detection (m=2)', 0.05, fit.m2_component[0], 0.01)

    # Test 3: m=4 perturbation detection
    fit = ring_analyzer.fit_ring(generate_ring_points(theta_E, n_points=20, c4=0.03))
    _check_value('Hexadecapole Detection (m=4)', 0.03, fit.m4_component[0], 0.01)

    # Test 4: Quad morphology
    label = classifier.classify(generate_quad(theta_E))
    _check_morphology('Quad Morphology Classification', 'QUAD', label, Morphology.QUAD)

    # Test 5: Noise robustness
    fit = ring_analyzer.fit_ring(generate_ring_points(theta_E, n_points=50, noise=0.01))
    _check_value('Noise Robustness (1% noise)', theta_E, fit.radius, 0.02)

    # Test 6: Arc detection (partial ring)
    label = classifier.classify(generate_ring_points(theta_E, n_points=15, arc_fraction=0.4))
    _check_morphology('Arc Morphology Detection', 'ARC or UNKNOWN', label,
                      [Morphology.ARC, Morphology.UNKNOWN])

    return results

def plot_validation_results(results):
    """Plot pass/fail status and error magnitudes of validation results.

    Args:
        results: list of dicts with the keys ``'test'``, ``'pass'`` and
            ``'error'``.

    Returns:
        The created matplotlib figure with a pass/fail bar panel and a
        log-scaled error panel.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    # Bars are placed at integer positions and labelled afterwards, so two
    # test names that become identical after truncation do not merge into
    # a single categorical bar.
    tests = [r['test'][:20] for r in results]
    rows = list(range(len(results)))

    # Pass/Fail summary
    ax1 = axes[0]
    passes = [1 if r['pass'] else 0 for r in results]
    colors = ['green' if p else 'red' for p in passes]
    ax1.barh(rows, passes, color=colors, alpha=0.7, edgecolor='black')
    ax1.set_yticks(rows)
    ax1.set_yticklabels(tests)
    ax1.set_xlim(0, 1.5)
    ax1.set_xlabel('Pass (1) / Fail (0)')
    ax1.set_title(f'Validation Results: {sum(passes)}/{len(results)} PASS', fontweight='bold')
    for i, p in enumerate(passes):
        ax1.text(1.1, i, 'PASS' if p else 'FAIL', va='center', fontweight='bold',
                 color='green' if p else 'red')

    # Error magnitudes
    ax2 = axes[1]
    # A log axis cannot show zero; exact-zero errors are drawn at a small
    # floor so that every test still gets a visible bar.
    error_floor = 1e-16
    errors = [max(float(r['error']), error_floor) for r in results]
    ax2.barh(rows, errors, color='steelblue', alpha=0.7, edgecolor='black')
    ax2.set_yticks(rows)
    ax2.set_yticklabels(tests)
    ax2.set_xscale('log')
    ax2.axvline(0.01, color='orange', ls='--', lw=2, label='1% threshold')
    ax2.axvline(1e-10, color='green', ls='--', lw=2, label='Machine precision')
    ax2.set_xlabel('Error (log scale)')
    ax2.set_title('Error Magnitudes', fontweight='bold')
    ax2.legend()

    plt.tight_layout()
    return fig

def plot_model_zoo_comparison(positions):
    """Plot a heuristic side-by-side comparison of every ModelFamily member.

    NOTE: nothing here is obtained by fitting a lens model. The system matrix
    is filled with random numbers, and each model's residual is the ring-fit
    RMS residual scaled by fixed factors plus random jitter, so repeated calls
    give different figures unless the NumPy RNG is seeded.

    Args:
        positions: image positions passed to ``ring_analyzer.fit_ring``; only
            ``len(positions)`` is used besides that call.

    Returns:
        A matplotlib figure with four panels: residual bars, parameter count
        vs. quality score, regime pie chart and a text summary.
    """
    results = []
    ring = ring_analyzer.fit_ring(positions)
    
    for fam in ModelFamily:
        config = MODEL_CONFIGS[fam]
        n_constraints = 2 * len(positions)
        n_params = config.n_lens_params + 2  # +2 for source position
        
        # Build mock system matrix
        # Random entries: only the (n_constraints, n_params) shape carries
        # information about the model.
        A = np.random.randn(n_constraints, n_params)
        regime = RegimeClassifier.classify(A, [f'p{i}' for i in range(n_params)])
        
        # Compute residual based on model complexity
        # Fixed reduction factors per enabled component; not a fitted value.
        base_res = ring.rms_residual
        if config.include_shear:
            base_res *= 0.5
        if config.include_m3:
            base_res *= 0.7
        if config.include_m4:
            base_res *= 0.6
        
        results.append({
            'model': config.label,
            'n_params': n_params,
            'regime': regime.regime.value,
            # Jitter of up to 0.001 is added to the displayed residual only.
            'max_residual': base_res + np.random.uniform(0, 0.001),
            'quality_score': max(0, 1 - base_res * 10),
            # Stored but not read again inside this function.
            'is_exact': base_res < 1e-10
        })
    
    fig = plt.figure(figsize=(16, 10))
    
    # 1. Residuals
    ax1 = fig.add_subplot(221)
    models = [r['model'] for r in results]
    residuals = [r['max_residual'] for r in results]
    # Colour bands: < 0.01 green, < 0.05 orange, otherwise red.
    colors = ['green' if r < 0.01 else 'orange' if r < 0.05 else 'red' for r in residuals]
    ax1.barh(models, residuals, color=colors, alpha=0.7, edgecolor='black')
    ax1.set_xlabel('Residual')
    ax1.set_title('Model Residuals', fontweight='bold')
    
    # 2. Parameters vs Quality
    ax2 = fig.add_subplot(222)
    params = [r['n_params'] for r in results]
    quality = [r['quality_score'] for r in results]
    scatter = ax2.scatter(params, quality, s=200, c=quality, cmap='RdYlGn', 
                          vmin=0, vmax=1, edgecolors='black', linewidths=2)
    # Label each point with the first 8 characters of the model label.
    for i, m in enumerate(models):
        ax2.annotate(m[:8], (params[i]+0.1, quality[i]), fontsize=8)
    ax2.set_xlabel('Number of Parameters')
    ax2.set_ylabel('Quality Score')
    ax2.set_title('Complexity vs Quality', fontweight='bold')
    plt.colorbar(scatter, ax=ax2)
    
    # 3. Regime distribution
    ax3 = fig.add_subplot(223)
    regimes = [r['regime'] for r in results]
    regime_counts = {r: regimes.count(r) for r in set(regimes)}
    ax3.pie(regime_counts.values(), labels=regime_counts.keys(), autopct='%1.0f%%',
            colors=['lightgreen', 'lightblue', 'lightyellow', 'lightcoral'])
    ax3.set_title('Regime Distribution', fontweight='bold')
    
    # 4. Summary table
    ax4 = fig.add_subplot(224)
    ax4.axis('off')
    # "Best" is the smallest jittered residual.
    best = min(results, key=lambda r: r['max_residual'])
    summary = f"""
    MODEL ZOO COMPARISON
    ====================
    Total Models: {len(results)}
    
    BEST MODEL: {best['model']}
    - Residual: {best['max_residual']:.6f}
    - Quality: {best['quality_score']:.3f}
    - Parameters: {best['n_params']}
    
    RECOMMENDATION:
    Start with simplest model that fits,
    add complexity only if residuals high.
    """
    # No transform is passed, so (0.1, 0.9) is in data coordinates.
    ax4.text(0.1, 0.9, summary, fontsize=11, family='monospace', va='top',
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
    
    plt.tight_layout()
    return fig

# Cell-load banner: lists the validation/plot entry points defined in this cell.
print('Validation suite loaded: run_validation_suite(), plot_validation_results(), plot_model_zoo_comparison()')


In [None]:
#@title Real Physics Inversion Module (Lens Equation + Model Zoo)
# This is the REAL inversion, not heuristics.
# Key: beta = theta - alpha(theta; params) - source positions must coincide!

import numpy as np
from dataclasses import dataclass
from typing import Dict, List, Tuple
from enum import Enum
import json
import os
from datetime import datetime

class InputMode(Enum):
    """Kind of image configuration handed to the inversion code."""
    QUAD = "quad"      # chosen by quicklook_to_inversion for exactly 4 positions
    RING = "ring"      # fallback for any other number of positions
    DOUBLE = "double"  # chosen for exactly 2 positions

@dataclass
class QuicklookResult:
    """Output from Quicklook that feeds into Physics Inversion.

    Built by ``quicklook_to_inversion`` from a ring-fit result and the raw
    image positions; ``save_run`` writes it to ``quicklook.json``.
    """
    estimated_center: Tuple[float, float]  # (center_x, center_y) of the fitted ring
    estimated_theta_E: float               # fitted ring radius
    angle_ordering: List[int]  # indices sorted by phi
    mode: InputMode                        # QUAD / DOUBLE / RING, chosen from the image count
    initial_model_hints: List[str]  # e.g., ["m2", "m2_shear"] based on harmonic signature

@dataclass
class SourceConsistency:
    """Source position consistency check - the KEY diagnostic.

    Stores the source-plane position obtained from every image together with
    summary statistics describing how tightly those positions cluster.
    """
    beta_positions: np.ndarray  # (n, 2) - source from each image
    beta_mean: np.ndarray       # (2,) - mean source
    beta_scatter: float         # RMS deviation
    max_deviation: float        # largest distance of any beta from beta_mean
    is_consistent: bool         # True when max_deviation is below the tolerance
    # Distance of each image's beta from beta_mean, shape (n,). Defaults to
    # None so five-argument constructions keep working; the file's own
    # check_source_consistency() supplies it as a keyword.
    per_image_residuals: np.ndarray = None

@dataclass
class InversionResult:
    """Result from real physics inversion.

    One instance is produced per model by ``invert_quad``; ``run_model_zoo``
    sorts a list of them by ``max_residual``.
    """
    model_name: str                        # key into DEFLECTION_MODELS
    params: Dict[str, float]               # solved values, including 'beta_x' and 'beta_y'
    source_position: np.ndarray            # (beta_x, beta_y) taken from params
    source_consistency: SourceConsistency  # per-image back-projection diagnostic
    residuals: np.ndarray                  # A @ x - b of the linear system
    max_residual: float                    # max |residual|
    rms_residual: float                    # sqrt(mean(residual**2))
    is_exact: bool                         # max_residual below the caller's tolerance
    regime: str                            # e.g. 'determined', 'overdetermined', 'ill_conditioned'
    message: str                           # human-readable status line

# ============================================================================
# DEFLECTION MODELS (Linear component form: c_m, s_m, gamma1, gamma2)
# ============================================================================

def deflection_m2(theta, params):
    """Radial deflection with monopole and m=2 terms.

    The radial magnitude is ``theta_E**2 / r + c2*cos(2*phi) + s2*sin(2*phi)``
    and the returned vector points along the radial unit vector at ``theta``.

    Args:
        theta: image position ``(x, y)``.
        params: dict read for 'theta_E' (default 1.0), 'c2' and 's2'
            (default 0.0).

    Returns:
        Length-2 array ``[alpha_x, alpha_y]``; zeros when ``r < 1e-10``.
    """
    tx, ty = theta
    radius = np.sqrt(tx**2 + ty**2)
    if radius < 1e-10:
        # Direction is undefined at the origin.
        return np.zeros(2)

    angle = np.arctan2(ty, tx)
    einstein = params.get('theta_E', 1.0)
    quad_cos = params.get('c2', 0.0)
    quad_sin = params.get('s2', 0.0)

    radial = einstein**2 / radius + quad_cos * np.cos(2*angle) + quad_sin * np.sin(2*angle)
    unit_vector = np.array([np.cos(angle), np.sin(angle)])
    return radial * unit_vector

def deflection_m2_shear(theta, params):
    """Monopole + m=2 radial deflection plus an external-shear term.

    The shear adds ``(gamma1*x + gamma2*y, -gamma1*y + gamma2*x)`` to the
    radial part.

    Args:
        theta: image position ``(x, y)``.
        params: dict read for 'theta_E' (default 1.0) and 'c2', 's2',
            'gamma1', 'gamma2' (default 0.0).

    Returns:
        Length-2 array ``[alpha_x, alpha_y]``; zeros when ``r < 1e-10``.
    """
    tx, ty = theta
    radius = np.sqrt(tx**2 + ty**2)
    if radius < 1e-10:
        return np.zeros(2)

    angle = np.arctan2(ty, tx)
    einstein = params.get('theta_E', 1.0)
    quad_cos = params.get('c2', 0.0)
    quad_sin = params.get('s2', 0.0)
    shear_1 = params.get('gamma1', 0.0)
    shear_2 = params.get('gamma2', 0.0)

    radial = einstein**2 / radius + quad_cos * np.cos(2*angle) + quad_sin * np.sin(2*angle)
    defl_x = radial * np.cos(angle) + shear_1 * tx + shear_2 * ty
    defl_y = radial * np.sin(angle) - shear_1 * ty + shear_2 * tx
    return np.array([defl_x, defl_y])

def deflection_m2_m3(theta, params):
    """Radial deflection with monopole, m=2 and m=3 terms.

    Args:
        theta: image position ``(x, y)``.
        params: dict read for 'theta_E' (default 1.0) and 'c2', 's2', 'c3',
            's3' (default 0.0).

    Returns:
        Length-2 array ``[alpha_x, alpha_y]``; zeros when ``r < 1e-10``.
    """
    tx, ty = theta
    radius = np.sqrt(tx**2 + ty**2)
    if radius < 1e-10:
        return np.zeros(2)

    angle = np.arctan2(ty, tx)
    radial = params.get('theta_E', 1.0)**2 / radius
    # Add one cosine/sine pair per angular order.
    for order in (2, 3):
        cos_amp = params.get(f'c{order}', 0.0)
        sin_amp = params.get(f's{order}', 0.0)
        radial += cos_amp * np.cos(order*angle) + sin_amp * np.sin(order*angle)
    return np.array([radial * np.cos(angle), radial * np.sin(angle)])

def deflection_m2_m4(theta, params):
    """Radial deflection with monopole, m=2 and m=4 terms.

    Args:
        theta: image position ``(x, y)``.
        params: dict read for 'theta_E' (default 1.0) and 'c2', 's2', 'c4',
            's4' (default 0.0).

    Returns:
        Length-2 array ``[alpha_x, alpha_y]``; zeros when ``r < 1e-10``.
    """
    tx, ty = theta
    radius = np.sqrt(tx**2 + ty**2)
    if radius < 1e-10:
        return np.zeros(2)

    angle = np.arctan2(ty, tx)
    radial = params.get('theta_E', 1.0)**2 / radius
    # Add one cosine/sine pair per angular order.
    for order in (2, 4):
        cos_amp = params.get(f'c{order}', 0.0)
        sin_amp = params.get(f's{order}', 0.0)
        radial += cos_amp * np.cos(order*angle) + sin_amp * np.sin(order*angle)
    return np.array([radial * np.cos(angle), radial * np.sin(angle)])

# Model registry: name -> (deflection function, ordered lens-parameter names).
# build_linear_system_quad() prepends 'beta_x', 'beta_y' to the name list to
# form the unknown vector; run_model_zoo() iterates over every key by default.
DEFLECTION_MODELS = {
    'm2': (deflection_m2, ['theta_E', 'c2', 's2']),
    'm2_shear': (deflection_m2_shear, ['theta_E', 'c2', 's2', 'gamma1', 'gamma2']),
    'm2_m3': (deflection_m2_m3, ['theta_E', 'c2', 's2', 'c3', 's3']),
    'm2_m4': (deflection_m2_m4, ['theta_E', 'c2', 's2', 'c4', 's4']),
}

# ============================================================================
# INVERSION FUNCTIONS
# ============================================================================

def compute_source_positions(theta, deflection_func, params):
    """Compute beta_i = theta_i - alpha(theta_i; p) for each image.

    Args:
        theta: image positions, array-like of shape (n, 2).
        deflection_func: callable ``(theta_i, params) -> alpha_i``.
        params: parameter dict forwarded unchanged to ``deflection_func``.

    Returns:
        Float array with the same shape as ``theta`` holding one source-plane
        position per image.
    """
    # Force a float array: an integer-typed input would otherwise make the
    # output integer-typed too and silently truncate every beta on assignment.
    theta = np.asarray(theta, dtype=float)
    beta = np.empty_like(theta)
    for i, image in enumerate(theta):
        beta[i] = image - deflection_func(image, params)
    return beta

def check_source_consistency(beta, tolerance=1e-6):
    """Check if all source positions coincide - THE key diagnostic.

    Args:
        beta: array of shape (n, 2), one back-projected source per image.
        tolerance: largest allowed distance of any beta from the mean for the
            set to count as consistent.

    Returns:
        SourceConsistency with the mean position, RMS scatter, maximum
        deviation and the consistency flag. The per-image distances are
        attached as the ``per_image_residuals`` attribute.
    """
    beta_mean = np.mean(beta, axis=0)
    residuals = np.sqrt(np.sum((beta - beta_mean)**2, axis=1))
    scatter = np.sqrt(np.mean(residuals**2))
    max_dev = np.max(residuals)
    consistency = SourceConsistency(
        beta_positions=beta,
        beta_mean=beta_mean,
        beta_scatter=scatter,
        max_deviation=max_dev,
        # Plain bool rather than np.bool_ so the flag is JSON-serialisable.
        is_consistent=bool(max_dev < tolerance),
    )
    # Set after construction: passing it as a constructor keyword raises
    # TypeError whenever the dataclass does not declare such a field.
    consistency.per_image_residuals = residuals
    return consistency

def build_linear_system_quad(theta, model_name):
    """Build Ax=b for quad inversion.

    Each image contributes two rows (x and y component). The unknown vector
    is ``['beta_x', 'beta_y']`` followed by the lens parameters registered
    for ``model_name`` in ``DEFLECTION_MODELS``.

    Args:
        theta: sequence of image positions ``(x, y)``.
        model_name: key into ``DEFLECTION_MODELS``.

    Returns:
        ``(A, b, full_names)`` with ``A`` of shape (2*n_images, n_unknowns),
        ``b`` holding the image coordinates and ``full_names`` the column
        labels of ``A``.

    NOTE(review): lens-parameter columns carry a minus sign, so the rows read
    ``beta - alpha_lin = theta`` while the module header states
    ``beta = theta - alpha`` - confirm the intended sign convention.
    NOTE(review): the 'theta_E' column is ``(cos phi, sin phi)``, i.e. linear
    in theta_E, whereas the deflection_* functions use ``theta_E**2 / r`` -
    confirm that the solved value is meant to be comparable.
    """
    _, lens_params = DEFLECTION_MODELS[model_name]
    full_names = ['beta_x', 'beta_y'] + lens_params
    n_images = len(theta)

    A = np.zeros((2 * n_images, len(full_names)))
    b = np.zeros(2 * n_images)

    for i in range(n_images):
        x, y = theta[i]
        phi = np.arctan2(y, x)
        cos_phi, sin_phi = np.cos(phi), np.sin(phi)
        row_x = 2 * i
        row_y = row_x + 1

        # Source position enters both components with unit weight.
        A[row_x, 0] = 1.0
        A[row_y, 1] = 1.0
        b[row_x] = x
        b[row_y] = y

        # (x-row, y-row) coefficient of every parameter the models may use.
        coefficients = {
            'theta_E': (-cos_phi, -sin_phi),
            'gamma1': (-x, y),
            'gamma2': (-y, -x),
        }
        for order in (2, 3, 4):
            cos_m = np.cos(order * phi)
            sin_m = np.sin(order * phi)
            coefficients[f'c{order}'] = (-cos_m * cos_phi, -cos_m * sin_phi)
            coefficients[f's{order}'] = (-sin_m * cos_phi, -sin_m * sin_phi)

        # Only the parameters of the chosen model get a column filled in.
        for name in lens_params:
            if name in coefficients:
                col = full_names.index(name)
                A[row_x, col], A[row_y, col] = coefficients[name]

    return A, b, full_names

def solve_exact(A, b):
    """Solve Ax=b via Gaussian elimination (no scipy).

    Uses partial pivoting on the augmented matrix and back-substitution over
    the first ``n`` pivot rows. Rows beyond the ``n``-th are eliminated but
    not checked, so for an over-determined system the result satisfies the
    pivoted equations only.

    Args:
        A: 2-D array of shape (m, n).
        b: right-hand side of length m.

    Returns:
        ``(x, ok)``. ``ok`` is False and ``x`` is all zeros when ``m < n``
        or when a pivot is smaller than 1e-15 in magnitude.
    """
    n_unknowns = A.shape[1]
    n_equations = A.shape[0]
    if n_equations < n_unknowns:
        return np.zeros(n_unknowns), False

    aug = np.hstack([A.astype(float), b.reshape(-1, 1).astype(float)])
    n_pivots = min(n_equations, n_unknowns)

    # Forward elimination with row pivoting.
    for k in range(n_pivots):
        pivot_row = k + np.argmax(np.abs(aug[k:n_equations, k]))
        if abs(aug[pivot_row, k]) < 1e-15:
            return np.zeros(n_unknowns), False
        aug[[k, pivot_row]] = aug[[pivot_row, k]]
        # Clear column k in every row below the pivot in one step.
        factors = aug[k + 1:, k] / aug[k, k]
        aug[k + 1:, k:] -= factors[:, None] * aug[k, k:]

    # Back-substitution, last pivot first.
    x = np.zeros(n_unknowns)
    for k in reversed(range(n_pivots)):
        diagonal = aug[k, k]
        if abs(diagonal) < 1e-15:
            return np.zeros(n_unknowns), False
        known = np.dot(aug[k, k + 1:n_unknowns], x[k + 1:n_unknowns])
        x[k] = (aug[k, -1] - known) / diagonal
    return x, True

def invert_quad(theta, model_name='m2', tolerance=1e-10):
    """Full quad inversion for a given model.

    Builds the linear system for ``model_name``, classifies it from its
    singular values, solves it with ``solve_exact`` and back-projects every
    image to the source plane with the solved parameters.

    Args:
        theta: image positions, shape (n, 2) with n >= 1.
        model_name: key into ``DEFLECTION_MODELS``.
        tolerance: threshold on the maximum residual for ``is_exact`` and on
            the source-plane deviation for the consistency flag.

    Returns:
        InversionResult. When the solver reports failure the parameters are
        all zero, ``is_exact`` is False and ``message`` starts with
        ``FAILED``.

    Raises:
        ValueError: if ``theta`` contains no image positions.
        KeyError: if ``model_name`` is not registered.
    """
    if len(theta) == 0:
        # Fail with a clear message instead of an obscure downstream error.
        raise ValueError("invert_quad requires at least one image position")

    A, b, param_names = build_linear_system_quad(theta, model_name)

    # Analyze system: rank and condition number from the singular values.
    _, s, _ = np.linalg.svd(A, full_matrices=False)
    tol_rank = max(A.shape) * np.finfo(float).eps * s[0]
    rank = int(np.sum(s > tol_rank))
    condition = s[0] / s[-1] if s[-1] > tol_rank else float('inf')

    m, n = A.shape
    if m == n and rank == n:
        regime = 'determined'
    elif m > n and rank == n:
        regime = 'overdetermined'
    elif rank < n:
        regime = 'underdetermined'
    else:
        regime = 'unknown'
    # A badly conditioned system overrides the shape-based label.
    if condition > 1e10:
        regime = 'ill_conditioned'

    x, success = solve_exact(A, b)
    residuals = A @ x - b
    # Native floats/bools keep the result JSON-serialisable.
    max_res = float(np.max(np.abs(residuals)))
    rms_res = float(np.sqrt(np.mean(residuals**2)))

    params = {name: float(x[i]) for i, name in enumerate(param_names)}
    source_pos = np.array([params['beta_x'], params['beta_y']])

    deflection_func, _ = DEFLECTION_MODELS[model_name]
    beta_positions = compute_source_positions(theta, deflection_func, params)
    consistency = check_source_consistency(beta_positions, tolerance)

    # A failed solve returns all-zero parameters and must never be "exact".
    is_exact = bool(success and max_res < tolerance)
    if not success:
        msg = f"FAILED: no unique solution (res={max_res:.2e})"
    elif is_exact:
        msg = f"EXACT (res={max_res:.2e})"
    else:
        msg = f"Approx (res={max_res:.2e})"

    return InversionResult(
        model_name=model_name,
        params=params,
        source_position=source_pos,
        source_consistency=consistency,
        residuals=residuals,
        max_residual=max_res,
        rms_residual=rms_res,
        is_exact=is_exact,
        regime=regime,
        message=msg
    )

def run_model_zoo(theta, models=None):
    """Invert ``theta`` with several models and rank the outcomes.

    Args:
        theta: image positions forwarded to ``invert_quad``.
        models: iterable of model names; defaults to every key of
            ``DEFLECTION_MODELS``.

    Returns:
        List of inversion results ordered by ascending ``max_residual``.
        A model whose inversion raises is reported on stdout and left out.
    """
    candidates = list(DEFLECTION_MODELS.keys()) if models is None else models

    fitted = []
    for name in candidates:
        try:
            fitted.append(invert_quad(theta, name))
        except Exception as err:
            # Best effort: one failing model must not abort the comparison.
            print(f"Model {name} failed: {err}")

    return sorted(fitted, key=lambda outcome: outcome.max_residual)

def quicklook_to_inversion(positions, ring_result):
    """Turn a Quicklook ring fit into starting information for the inversion.

    Args:
        positions: array of shape (n, 2) with the image positions.
        ring_result: ring-fit object providing ``center_x``, ``center_y``,
            ``radius``, ``m2_component`` and ``m4_component``.

    Returns:
        QuicklookResult with the ring centre and radius, the image indices
        sorted by polar angle about the centre, the input mode derived from
        the image count, and model hints.

    NOTE(review): no branch ever suggests 'm2_m3' - confirm that is intended.
    """
    center = (ring_result.center_x, ring_result.center_y)
    offsets = positions - np.array([ring_result.center_x, ring_result.center_y])
    angles = np.arctan2(offsets[:, 1], offsets[:, 0])

    # 'm2' is always tried; richer models are hinted when the first entry of
    # the matching harmonic component exceeds 0.01.
    hints = ['m2']
    harmonic_hints = (
        (ring_result.m2_component[0], 'm2_shear'),
        (ring_result.m4_component[0], 'm2_m4'),
    )
    for amplitude, model in harmonic_hints:
        if amplitude > 0.01:
            hints.append(model)

    mode_by_count = {4: InputMode.QUAD, 2: InputMode.DOUBLE}
    mode = mode_by_count.get(len(positions), InputMode.RING)

    return QuicklookResult(
        estimated_center=center,
        estimated_theta_E=ring_result.radius,
        angle_ordering=np.argsort(angles).tolist(),
        mode=mode,
        initial_model_hints=hints
    )

def save_run(positions, quicklook, inversion_results, run_name=None):
    """Save run to runs/<timestamp>_<name>/

    Writes ``input_positions.csv``, ``quicklook.json``, one JSON file per
    model under ``solutions/`` and a Markdown ``report.md``. An empty
    ``figures/`` directory is created as well.

    Args:
        positions: array of shape (n, 2) with the input image positions.
        quicklook: object with the QuicklookResult attributes.
        inversion_results: list of inversion results, best model first (the
            first entry is reported as the best model).
        run_name: suffix of the run directory; "analysis" when falsy.

    Returns:
        Relative path of the run directory that was written.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_name = run_name or "analysis"
    run_dir = f"runs/{timestamp}_{run_name}"
    os.makedirs(run_dir, exist_ok=True)
    os.makedirs(f"{run_dir}/solutions", exist_ok=True)
    os.makedirs(f"{run_dir}/figures", exist_ok=True)

    # Save input
    np.savetxt(f"{run_dir}/input_positions.csv", positions, delimiter=',',
               header='x,y', comments='')

    # Save quicklook. Values are converted to native Python types because
    # the json module rejects NumPy scalars such as np.bool_ and np.int64.
    quicklook_payload = {
        'center': [float(c) for c in quicklook.estimated_center],
        'theta_E': float(quicklook.estimated_theta_E),
        'ordering': [int(i) for i in quicklook.angle_ordering],
        'mode': quicklook.mode.value,
        'hints': list(quicklook.initial_model_hints),
    }
    with open(f"{run_dir}/quicklook.json", 'w') as f:
        json.dump(quicklook_payload, f, indent=2)

    # Save solutions
    for r in inversion_results:
        solution = {
            'params': {name: float(value) for name, value in r.params.items()},
            'source_position': np.asarray(r.source_position, dtype=float).tolist(),
            'max_residual': float(r.max_residual),
            'rms_residual': float(r.rms_residual),
            # is_exact is typically np.bool_ (result of a NumPy comparison).
            'is_exact': bool(r.is_exact),
            'regime': r.regime,
            'beta_scatter': float(r.source_consistency.beta_scatter),
        }
        with open(f"{run_dir}/solutions/{r.model_name}.json", 'w') as f:
            json.dump(solution, f, indent=2)

    # Save report
    best = inversion_results[0] if inversion_results else None
    with open(f"{run_dir}/report.md", 'w') as f:
        f.write(f"# Inversion Report\n")
        f.write(f"Timestamp: {timestamp}\n\n")
        f.write(f"## Quicklook\n")
        f.write(f"- Center: {quicklook.estimated_center}\n")
        f.write(f"- Theta_E: {quicklook.estimated_theta_E:.4f}\n")
        f.write(f"- Mode: {quicklook.mode.value}\n\n")
        f.write(f"## Model Zoo Results\n")
        for r in inversion_results:
            f.write(f"- **{r.model_name}**: res={r.max_residual:.2e}, exact={r.is_exact}\n")
        if best:
            f.write(f"\n## Best Model: {best.model_name}\n")
            f.write(f"- Residual: {best.max_residual:.2e}\n")
            f.write(f"- Source: ({best.source_position[0]:.4f}, {best.source_position[1]:.4f})\n")

    return run_dir

# Cell-load banner: names the main entry points of the inversion cell.
print("Physics Inversion Module loaded!")
print("Functions: invert_quad(), run_model_zoo(), save_run()")


In [None]:
#@title Unit System: Auto-Scaling + Cosmology
# Internal units: rad, m, s, kg
# External: auto-scaled to human-readable

import numpy as np
from dataclasses import dataclass
from typing import Tuple, Optional, Dict

# === CONSTANTS ===
# Angle conversion factors: multiply a value in the named unit to get radians.
ARCSEC_TO_RAD = np.pi / (180 * 3600)
MAS_TO_RAD = ARCSEC_TO_RAD / 1000   # milliarcsecond
MUAS_TO_RAD = MAS_TO_RAD / 1000     # microarcsecond
DEG_TO_RAD = np.pi / 180

# Length conversion factors: multiply a value in the named unit to get meters.
AU_TO_M = 1.495978707e11
LY_TO_M = 9.4607304725808e15
PC_TO_M = 3.0856775814913673e16
KPC_TO_M = PC_TO_M * 1e3
MPC_TO_M = PC_TO_M * 1e6
GPC_TO_M = PC_TO_M * 1e9

# Physical constants, named for SI units.
C_M_S = 299792458.0
G_SI = 6.67430e-11
MSUN_TO_KG = 1.98892e30

# Lookup tables: unit label -> factor to the internal unit (meters / radians).
DISTANCE_UNITS = {'m':1, 'km':1e3, 'AU':AU_TO_M, 'ly':LY_TO_M, 'pc':PC_TO_M, 'kpc':KPC_TO_M, 'Mpc':MPC_TO_M, 'Gpc':GPC_TO_M}
ANGLE_UNITS = {'rad':1, 'deg':DEG_TO_RAD, 'arcsec':ARCSEC_TO_RAD, 'mas':MAS_TO_RAD, 'uas':MUAS_TO_RAD}

@dataclass
class FormattedValue:
    """A quantity kept in its internal unit together with a display form."""
    internal_value: float   # value in ``internal_unit``
    display_value: float    # value in ``display_unit``
    display_unit: str
    display_string: str     # ready-to-print text
    internal_unit: str
    alternatives: Dict[str, float] = None  # optional: unit label -> value

    def __str__(self):
        return self.display_string

    def to_dict(self):
        """Return a dict view; 'alternatives' appears only when non-empty."""
        payload = {
            'value': self.internal_value,
            'unit': self.internal_unit,
            'display': self.display_string,
        }
        if self.alternatives:
            payload['alternatives'] = self.alternatives
        return payload

def format_angle(rad, precision=4):
    """Format an angle in radians using an automatically chosen unit.

    Magnitudes below 1e-9 rad are shown in uas, below 1e-6 rad in mas and
    everything else in arcsec.

    Args:
        rad: angle in radians.
        precision: significant digits of the display string.

    Returns:
        FormattedValue whose ``alternatives`` hold the angle in rad, arcsec,
        mas and uas.
    """
    magnitude = abs(rad)
    # (exclusive upper bound in rad, unit label, rad per unit), smallest first.
    small_scales = (
        (1e-9, 'uas', MUAS_TO_RAD),
        (1e-6, 'mas', MAS_TO_RAD),
    )
    unit, factor = next(
        ((label, scale) for bound, label, scale in small_scales if magnitude < bound),
        ('arcsec', ARCSEC_TO_RAD),
    )
    val = rad / factor

    alts = {'rad': rad}
    for label, scale in (('arcsec', ARCSEC_TO_RAD), ('mas', MAS_TO_RAD), ('uas', MUAS_TO_RAD)):
        alts[label] = rad / scale
    return FormattedValue(rad, val, unit, f"{val:.{precision}g} {unit}", 'rad', alts)

def format_distance(meters, precision=4):
    """Format a distance in meters using an automatically chosen unit.

    The unit is the first table entry whose upper bound exceeds ``|meters|``;
    Gpc is used above the last bound.

    NOTE(review): with the 1e11 m bound, 1 AU (about 1.5e11 m) is rendered in
    ly rather than AU - confirm the thresholds are as intended.

    Args:
        meters: distance in meters.
        precision: significant digits of the display string.

    Returns:
        FormattedValue whose ``alternatives`` hold the distance in every
        unit of ``DISTANCE_UNITS``.
    """
    magnitude = abs(meters)
    # (exclusive upper bound in m, unit label, meters per unit), ascending.
    scales = (
        (1e8, 'km', 1e3),
        (1e11, 'AU', AU_TO_M),
        (1e16, 'ly', LY_TO_M),
        (1e19, 'pc', PC_TO_M),
        (1e22, 'kpc', KPC_TO_M),
        (1e25, 'Mpc', MPC_TO_M),
    )
    unit, factor = next(
        ((label, scale) for bound, label, scale in scales if magnitude < bound),
        ('Gpc', GPC_TO_M),
    )
    val = meters / factor
    alts = {u: meters/f for u, f in DISTANCE_UNITS.items()}
    return FormattedValue(meters, val, unit, f"{val:.{precision}g} {unit}", 'm', alts)

def format_radius(meters, precision=4):
    """Format a radius (e.g. R_E) in meters as km, AU, pc or kpc.

    Args:
        meters: radius in meters.
        precision: significant digits of the display string.

    Returns:
        FormattedValue without alternatives.
    """
    magnitude = abs(meters)
    # (exclusive upper bound in m, unit label, meters per unit), ascending.
    scales = (
        (1e6, 'km', 1e3),
        (1e14, 'AU', AU_TO_M),
        (1e19, 'pc', PC_TO_M),
    )
    unit, factor = next(
        ((label, scale) for bound, label, scale in scales if magnitude < bound),
        ('kpc', KPC_TO_M),
    )
    val = meters / factor
    return FormattedValue(meters, val, unit, f"{val:.{precision}g} {unit}", 'm')

def schwarzschild_radius(mass_kg):
    """Return 2*G*M/c**2 for a mass in kg, using G_SI and C_M_S."""
    c_squared = C_M_S ** 2
    return 2 * G_SI * mass_kg / c_squared

# === COSMOLOGY ===
@dataclass
class Cosmology:
    """Named parameter set: H0 plus matter and Lambda density parameters."""
    name: str
    H0: float       # converted by comoving_distance() as km/s/Mpc
    Omega_m: float
    Omega_L: float

    @property
    def Omega_k(self):
        """Curvature density, 1 - Omega_m - Omega_L."""
        return 1.0 - self.Omega_m - self.Omega_L


# Built-in parameter sets and their lookup table by name.
PLANCK18 = Cosmology(name='Planck18', H0=67.4, Omega_m=0.315, Omega_L=0.685)
PLANCK15 = Cosmology(name='Planck15', H0=67.74, Omega_m=0.3089, Omega_L=0.6911)
WMAP9 = Cosmology(name='WMAP9', H0=69.32, Omega_m=0.2865, Omega_L=0.7135)
COSMOLOGIES = {preset.name: preset for preset in (PLANCK18, PLANCK15, WMAP9)}

def E_z(z, cosmo):
    """Return sqrt(Omega_m (1+z)^3 + Omega_k (1+z)^2 + Omega_L).

    Args:
        z: redshift, scalar or NumPy array.
        cosmo: object providing ``Omega_m``, ``Omega_k`` and ``Omega_L``.
    """
    one_plus_z = 1 + z
    matter = cosmo.Omega_m * one_plus_z**3
    curvature = cosmo.Omega_k * one_plus_z**2
    return np.sqrt(matter + curvature + cosmo.Omega_L)

def comoving_distance(z, cosmo, n=1000):
    """Integrate c / (H0 * E(z')) from 0 to ``z`` with composite Simpson.

    Args:
        z: scalar redshift; values <= 0 return 0.0.
        cosmo: object providing ``H0`` (converted here as km/s/Mpc) and the
            density parameters read by ``E_z``.
        n: number of sub-intervals. Composite Simpson needs an even count of
            at least 2, so smaller or odd values are raised to the next valid
            one instead of silently using wrong weights.

    Returns:
        Distance in meters.
    """
    if z <= 0:
        return 0.0

    # The 1-4-2-...-4-1 weights below are only correct for even n, and
    # n <= 0 would divide by zero in dz.
    n = max(2, int(n))
    if n % 2:
        n += 1

    H0_1s = cosmo.H0 * 1e3 / (PC_TO_M * 1e6)  # km/s/Mpc -> 1/s
    D_H = C_M_S / H0_1s                       # Hubble distance in meters
    z_arr = np.linspace(0, z, n+1)
    dz = z / n
    integrand = 1.0 / E_z(z_arr, cosmo)
    integral = (dz/3) * (integrand[0] + 4*np.sum(integrand[1:-1:2]) + 2*np.sum(integrand[2:-1:2]) + integrand[-1])
    return D_H * integral

def angular_diameter_distance(z, cosmo):
    """Return comoving_distance(z) / (1 + z).

    NOTE(review): no curvature correction is applied, so this presumably
    assumes Omega_k == 0 - confirm before using a non-flat cosmology.
    """
    d_comoving = comoving_distance(z, cosmo)
    return d_comoving / (1 + z)

def angular_diameter_distance_z1_z2(z1, z2, cosmo):
    """Return (D_C(z2) - D_C(z1)) / (1 + z2), the distance between redshifts.

    NOTE(review): a plain difference of comoving distances presumably
    assumes Omega_k == 0 - confirm before using a non-flat cosmology.
    """
    d_far = comoving_distance(z2, cosmo)
    d_near = comoving_distance(z1, cosmo)
    return (d_far - d_near) / (1 + z2)

def lensing_distances(z_L, z_S, cosmo=PLANCK18):
    """Return the tuple (D_L, D_S, D_LS) in meters.

    Args:
        z_L: lens redshift.
        z_S: source redshift.
        cosmo: cosmology to use; PLANCK18 by default.
    """
    return (
        angular_diameter_distance(z_L, cosmo),
        angular_diameter_distance(z_S, cosmo),
        angular_diameter_distance_z1_z2(z_L, z_S, cosmo),
    )

def einstein_radius_from_mass(mass_kg, D_L, D_S, D_LS):
    """Return theta_E = sqrt(4 G M / c^2 * D_LS / (D_L * D_S)) in radians."""
    four_gm_over_c2 = 4 * G_SI * mass_kg / C_M_S**2
    theta_E_squared = four_gm_over_c2 * D_LS / (D_L * D_S)
    return np.sqrt(theta_E_squared)

def mass_from_einstein_radius(theta_E_rad, D_L, D_S, D_LS):
    """Return M = c^2 / (4 G) * theta_E^2 * D_L * D_S / D_LS in kg."""
    c2_over_4g = C_M_S**2 / (4 * G_SI)
    return c2_over_4g * theta_E_rad**2 * D_L * D_S / D_LS

def parse_distance_input(value, unit):
    """Convert ``value`` given in ``unit`` to meters.

    Raises:
        ValueError: if ``unit`` is not a key of ``DISTANCE_UNITS``.
    """
    try:
        meters_per_unit = DISTANCE_UNITS[unit]
    except KeyError:
        raise ValueError(f"Unknown unit: {unit}") from None
    return value * meters_per_unit

# Cell-load banner: lists the main helpers of the unit/cosmology cell.
print("Unit System loaded!")
print("  format_angle(), format_distance(), format_radius()")
print("  lensing_distances(z_L, z_S, cosmo)")
print("  Cosmologies: Planck18, Planck15, WMAP9")


In [None]:
#@title 3D Lensing Scene Module
# Final visualization: Observer - Lens - Source with distances and radii

import numpy as np
from dataclasses import dataclass, field
from typing import Tuple, Optional, Dict
import json as json_module
from mpl_toolkits.mplot3d import Axes3D

@dataclass
class Vec3:
    """Minimal 3-component vector used for scene positions."""
    x: float
    y: float
    z: float

    def to_list(self):
        """Return the components as ``[x, y, z]``."""
        return [self.x, self.y, self.z]

    def norm(self):
        """Return the Euclidean length."""
        squared_length = self.x**2 + self.y**2 + self.z**2
        return np.sqrt(squared_length)

    @classmethod
    def from_array(cls, arr):
        """Build a vector from the first three entries of ``arr``."""
        return cls(*(float(arr[axis]) for axis in range(3)))

@dataclass
class Scene3D:
    """3D lensing scene along the z axis.

    The observer defaults to the origin, the lens is placed at (0, 0, D_L)
    and the source at (beta_x * D_S, beta_y * D_S, D_S). ``lens`` and
    ``source`` are derived in ``__post_init__`` and cannot be passed in.
    """
    D_L: float
    D_S: float
    theta_E: float
    beta: Tuple[float, float] = (0.0, 0.0)
    units: str = 'normalized'
    z_L: Optional[float] = None
    z_S: Optional[float] = None
    lens_mass: Optional[float] = None
    observer: Vec3 = field(default_factory=lambda: Vec3(0, 0, 0))
    lens: Vec3 = field(init=False)
    source: Vec3 = field(init=False)

    def __post_init__(self):
        beta_x, beta_y = self.beta[0], self.beta[1]
        self.lens = Vec3(0, 0, self.D_L)
        self.source = Vec3(beta_x * self.D_S, beta_y * self.D_S, self.D_S)

    @property
    def D_LS(self):
        """Lens-source separation, computed as D_S - D_L."""
        return self.D_S - self.D_L

    @property
    def R_E(self):
        """Einstein radius as a length in the lens plane, D_L * theta_E."""
        return self.D_L * self.theta_E

    @property
    def beta_magnitude(self):
        """Length of the two-component source offset ``beta``."""
        return np.sqrt(self.beta[0]**2 + self.beta[1]**2)

    @property
    def R_beta(self):
        """Source offset as a length in the source plane, D_S * |beta|."""
        return self.D_S * self.beta_magnitude

    def to_dict(self):
        """Return positions, distances and derived radii as a plain dict."""
        positions = {
            'observer': self.observer.to_list(),
            'lens': self.lens.to_list(),
            'source': self.source.to_list(),
        }
        scalars = {
            'D_L': self.D_L,
            'D_S': self.D_S,
            'D_LS': self.D_LS,
            'theta_E': self.theta_E,
            'R_E': self.R_E,
            'beta': list(self.beta),
            'beta_magnitude': self.beta_magnitude,
            'R_beta': self.R_beta,
            'units': self.units,
        }
        return {**positions, **scalars}

def plot_scene_3d(scene, image_positions=None, ax=None):
    """3D perspective view of lensing geometry.

    Args:
        scene: object providing ``D_L``, ``D_S``, ``D_LS``, ``R_E``,
            ``source`` (with ``x``/``y``) and ``units``.
        image_positions: optional sequence of ``(theta_x, theta_y)`` pairs;
            each is drawn as a ray from the origin to
            ``(D_L*theta_x, D_L*theta_y, D_L)`` on the lens plane.
        ax: existing 3D axes to draw into; a new figure is created when None.

    Returns:
        The matplotlib figure that owns the axes.
    """
    if ax is None:
        fig = plt.figure(figsize=(12, 9))
        ax = fig.add_subplot(111, projection='3d')
    else:
        fig = ax.figure
    
    # Observer, lens and source positions; the optical axis is z.
    O = np.array([0, 0, 0])
    L = np.array([0, 0, scene.D_L])
    S = np.array([scene.source.x, scene.source.y, scene.D_S])
    
    ax.scatter(*O, c='blue', s=200, marker='o', label='Observer (O)', zorder=10)
    ax.scatter(*L, c='red', s=300, marker='*', label='Lens (L)', zorder=10)
    ax.scatter(*S, c='green', s=200, marker='s', label='Source (S)', zorder=10)
    
    # Connecting lines: observer-lens, lens-source, and the direct
    # observer-source line (dashed).
    ax.plot([O[0], L[0]], [O[1], L[1]], [O[2], L[2]], 'b-', lw=2, alpha=0.7)
    ax.plot([L[0], S[0]], [L[1], S[1]], [L[2], S[2]], 'g-', lw=2, alpha=0.7)
    ax.plot([O[0], S[0]], [O[1], S[1]], [O[2], S[2]], 'k--', lw=1, alpha=0.5)
    
    # Distance labels at the segment midpoints, nudged by 0.1 data units in x.
    mid_OL = (O + L) / 2
    mid_LS = (L + S) / 2
    ax.text(mid_OL[0]-0.1, mid_OL[1], mid_OL[2], f'D_L={scene.D_L:.2f}', fontsize=10, color='blue')
    ax.text(mid_LS[0]+0.1, mid_LS[1], mid_LS[2], f'D_LS={scene.D_LS:.2f}', fontsize=10, color='green')
    
    # Translucent lens and source planes; half-width is at least 0.3.
    plane_size = max(scene.R_E * 3, 0.3)
    xx, yy = np.meshgrid(np.linspace(-plane_size, plane_size, 2), np.linspace(-plane_size, plane_size, 2))
    ax.plot_surface(xx, yy, np.full_like(xx, scene.D_L), alpha=0.1, color='red')
    ax.plot_surface(xx, yy, np.full_like(xx, scene.D_S), alpha=0.1, color='green')
    
    # Einstein ring drawn in the lens plane.
    if scene.R_E > 0:
        theta = np.linspace(0, 2*np.pi, 100)
        ax.plot(scene.R_E * np.cos(theta), scene.R_E * np.sin(theta), 
                np.full(100, scene.D_L), 'r-', lw=2, label=f'Einstein ring (R_E={scene.R_E:.4f})')
    
    # One ray per image, ending at a cross on the lens plane.
    if image_positions is not None:
        colors = plt.cm.Set1(np.linspace(0, 1, len(image_positions)))
        for i, (theta_pos, color) in enumerate(zip(image_positions, colors)):
            x_lens = scene.D_L * theta_pos[0]
            y_lens = scene.D_L * theta_pos[1]
            ax.plot([0, x_lens], [0, y_lens], [0, scene.D_L], color=color, lw=2, alpha=0.8)
            ax.scatter([x_lens], [y_lens], [scene.D_L], c=[color], s=100, marker='x')
    
    ax.set_xlabel(f'X [{scene.units}]')
    ax.set_ylabel(f'Y [{scene.units}]')
    ax.set_zlabel(f'Z (distance) [{scene.units}]')
    ax.set_title('3D Lensing Geometry: Observer - Lens - Source', fontweight='bold')
    ax.legend(loc='upper left', fontsize=9)
    
    # Transverse limits are a third of the depth range.
    max_range = scene.D_S * 1.1
    ax.set_xlim(-max_range/3, max_range/3)
    ax.set_ylim(-max_range/3, max_range/3)
    ax.set_zlim(0, max_range)
    return fig

def plot_scene_side(scene, image_positions=None, ax=None):
    """Side view (xz-plane) showing distances as ruler.

    The horizontal axis is the distance along the optical axis and the
    vertical axis is the transverse x coordinate.

    Args:
        scene: object providing ``D_L``, ``D_S``, ``D_LS``, ``R_E``,
            ``R_beta``, ``beta_magnitude``, ``source`` and ``units``.
        image_positions: optional sequence of positions; only the first
            component of each is used, drawn as a ray to ``D_L * theta_x``.
        ax: existing axes to draw into; a new figure is created when None.

    Returns:
        The matplotlib figure that owns the axes.

    NOTE(review): label and ruler offsets (0.03, -0.12, -0.25, ...) are fixed
    values in data units, so the layout presumably expects distances of
    order 1 - confirm for physical units.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(14, 5))
    else:
        fig = ax.figure
    
    ax.scatter([0], [0], c='blue', s=200, marker='o', zorder=10)
    ax.annotate('Observer', (0, 0), xytext=(0, -0.12), ha='center', fontsize=10, fontweight='bold', color='blue')
    
    ax.scatter([scene.D_L], [0], c='red', s=300, marker='*', zorder=10)
    ax.annotate('Lens', (scene.D_L, 0), xytext=(scene.D_L, -0.12), ha='center', fontsize=10, fontweight='bold', color='red')
    
    # The plot's vertical coordinate is the source's x component.
    source_y = scene.source.x
    ax.scatter([scene.D_S], [source_y], c='green', s=200, marker='s', zorder=10)
    ax.annotate('Source', (scene.D_S, source_y), xytext=(scene.D_S+0.03, source_y+0.05), fontsize=10, fontweight='bold', color='green')
    
    # Optical axis.
    ax.axhline(y=0, color='gray', linestyle='--', lw=1, alpha=0.5)
    
    # Three stacked rulers below the axis: D_L, D_S and D_LS.
    y_ruler = -0.25
    ax.annotate('', xy=(scene.D_L, y_ruler), xytext=(0, y_ruler), arrowprops=dict(arrowstyle='<->', color='blue', lw=2))
    ax.text(scene.D_L/2, y_ruler-0.04, f'D_L = {scene.D_L:.3f}', ha='center', fontsize=10, color='blue')
    
    ax.annotate('', xy=(scene.D_S, y_ruler-0.12), xytext=(0, y_ruler-0.12), arrowprops=dict(arrowstyle='<->', color='darkgreen', lw=2))
    ax.text(scene.D_S/2, y_ruler-0.16, f'D_S = {scene.D_S:.3f}', ha='center', fontsize=10, color='darkgreen')
    
    ax.annotate('', xy=(scene.D_S, y_ruler-0.24), xytext=(scene.D_L, y_ruler-0.24), arrowprops=dict(arrowstyle='<->', color='purple', lw=2))
    ax.text((scene.D_L+scene.D_S)/2, y_ruler-0.28, f'D_LS = {scene.D_LS:.3f}', ha='center', fontsize=10, color='purple')
    
    # Einstein radius as a vertical bar through the lens.
    if scene.R_E > 0:
        ax.plot([scene.D_L, scene.D_L], [-scene.R_E, scene.R_E], 'r-', lw=3, alpha=0.8)
        ax.annotate(f'R_E={scene.R_E:.4f}', (scene.D_L, scene.R_E), xytext=(scene.D_L+0.03, scene.R_E), fontsize=9, color='red')
    
    # Source offset: the dotted line spans the x component only, while the
    # label shows R_beta, which is computed from both beta components.
    if scene.beta_magnitude > 0:
        ax.plot([scene.D_S, scene.D_S], [0, source_y], 'g:', lw=2)
        ax.annotate(f'R_beta={scene.R_beta:.4f}', (scene.D_S, source_y/2), xytext=(scene.D_S+0.03, source_y/2), fontsize=9, color='green')
    
    # One ray per image from the observer to the lens plane.
    if image_positions is not None:
        colors = plt.cm.Set1(np.linspace(0, 1, len(image_positions)))
        for theta_pos, color in zip(image_positions, colors):
            x_at_lens = scene.D_L * theta_pos[0]
            ax.plot([0, scene.D_L], [0, x_at_lens], color=color, lw=2, alpha=0.7)
    
    ax.set_xlabel(f'Distance along optical axis [{scene.units}]')
    ax.set_ylabel(f'Transverse [{scene.units}]')
    ax.set_title('Side View: Lensing Geometry with Distances', fontweight='bold')
    
    # The lower limit leaves room for the rulers.
    y_max = max(scene.R_E * 2, abs(source_y) * 1.5, 0.15)
    ax.set_xlim(-0.05 * scene.D_S, scene.D_S * 1.1)
    ax.set_ylim(y_ruler - 0.35, y_max)
    ax.grid(True, alpha=0.3)
    return fig

def save_scene3d(scene, run_dir, positions=None):
    """Write the scene JSON and both scene figures into a run directory.

    Args:
        scene: scene object accepted by the scene plot functions, providing
            ``to_dict()``.
        run_dir: existing run directory; ``figures/`` is created inside it.
        positions: optional image positions forwarded to the plot functions.

    Returns:
        Path of the written ``scene3d.json``.
    """
    import os

    figures_dir = f"{run_dir}/figures"
    os.makedirs(figures_dir, exist_ok=True)

    scene_path = f"{run_dir}/scene3d.json"
    with open(scene_path, 'w') as f:
        json_module.dump(scene.to_dict(), f, indent=2)

    # Render, save and close each view so no figure stays open.
    views = (
        (plot_scene_3d, "scene3d_perspective.png"),
        (plot_scene_side, "scene3d_sideview.png"),
    )
    for render, filename in views:
        figure = render(scene, positions)
        figure.savefig(f"{figures_dir}/{filename}", dpi=150, bbox_inches='tight')
        plt.close(figure)

    return scene_path

# Cell-load banner: lists the classes and functions of the 3D scene cell.
print("3D Scene Module loaded!")
print("Classes: Scene3D, Vec3")
print("Functions: plot_scene_3d(), plot_scene_side(), save_scene3d()")


In [None]:
#@title Physics Inversion Visualizations

def plot_source_plane_consistency(inversion_result, ax=None):
    """Plot every image's source-plane position around their mean.

    Args:
        inversion_result: result object providing ``source_consistency``,
            ``model_name`` and ``is_exact``.
        ax: existing axes to draw into; a new figure is created when None.

    Returns:
        The matplotlib figure that owns the axes.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 8))
    else:
        fig = ax.figure

    consistency = inversion_result.source_consistency
    beta = consistency.beta_positions
    beta_mean = consistency.beta_mean

    # One marker per image, joined to the mean by a dashed line.
    palette = plt.cm.Set1(np.linspace(0, 1, len(beta)))
    for number, (point, shade) in enumerate(zip(beta, palette), start=1):
        ax.scatter([point[0]], [point[1]], c=[shade], s=200, marker='o',
                   label=f'Image {number}', edgecolors='black', linewidths=2, zorder=5)
        ax.plot([beta_mean[0], point[0]], [beta_mean[1], point[1]],
                color=shade, linestyle='--', alpha=0.5, lw=2)

    # Mean position, i.e. the reconstructed source.
    ax.scatter([beta_mean[0]], [beta_mean[1]], c='red', s=300, marker='*',
               label='Mean (Source)', edgecolors='black', linewidths=2, zorder=10)

    # Circle whose radius is the RMS scatter about the mean.
    scatter = consistency.beta_scatter
    rms_circle = plt.Circle(beta_mean, scatter, fill=False, color='red',
                            linestyle=':', lw=2, label=f'RMS={scatter:.4f}')
    ax.add_patch(rms_circle)

    ax.set_xlabel('beta_x (source plane)', fontsize=12)
    ax.set_ylabel('beta_y (source plane)', fontsize=12)
    title = (f'Source Consistency: {inversion_result.model_name}\n'
             f'Scatter={scatter:.4f}, Exact={inversion_result.is_exact}')
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_aspect('equal')
    ax.legend(loc='upper right', fontsize=9)
    ax.grid(True, alpha=0.3)

    return fig

def plot_model_zoo_leaderboard(results):
    """Plot a three-panel comparison of the fitted models.

    Panels: (1) max residual per model on a log axis, colour-coded by size,
    (2) source-plane beta scatter per model, (3) a text summary of the first
    entry plus the first five entries as a ranking.

    Args:
        results: Sequence of result objects exposing ``model_name``,
            ``max_residual``, ``source_consistency.beta_scatter``,
            ``is_exact`` and ``regime``. The first element is reported as
            the best model; this function does not sort, so the caller is
            assumed to pass the list best-first. An empty sequence produces
            an empty leaderboard instead of raising.

    Returns:
        The Matplotlib figure.
    """
    # A log axis cannot place a zero-width bar, so exact fits (residual 0)
    # would vanish. The floor affects the drawn bar only; colours and the
    # text summary still use the true residual.
    residual_floor = 1e-16

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    ax_residual, ax_scatter, ax_summary = axes

    models = [r.model_name for r in results]
    residuals = [r.max_residual for r in results]
    scatters = [r.source_consistency.beta_scatter for r in results]

    # 1. Residuals (log scale): green = numerically exact, orange = small, red = poor.
    colors = ['green' if r < 1e-10 else 'orange' if r < 1e-6 else 'red' for r in residuals]
    drawn = [max(r, residual_floor) for r in residuals]
    ax_residual.barh(models, drawn, color=colors, alpha=0.7, edgecolor='black')
    ax_residual.set_xscale('log')
    ax_residual.axvline(1e-10, color='green', ls='--', lw=2, alpha=0.5)
    ax_residual.set_xlabel('Max Residual (log)', fontsize=11)
    ax_residual.set_title('Model Residuals', fontweight='bold')

    # 2. Source scatter
    ax_scatter.barh(models, scatters, color='steelblue', alpha=0.7, edgecolor='black')
    ax_scatter.set_xlabel('Beta Scatter', fontsize=11)
    ax_scatter.set_title('Source Consistency', fontweight='bold')

    # 3. Summary text
    ax_summary.axis('off')
    lines = [
        "",
        "MODEL ZOO LEADERBOARD",
        "=====================",
        f"Total models: {len(results)}",
        "",
    ]
    if results:
        best = results[0]
        lines += [
            f"BEST MODEL: {best.model_name}",
            f"  Residual: {best.max_residual:.2e}",
            f"  Source scatter: {best.source_consistency.beta_scatter:.4f}",
            f"  Exact: {'YES' if best.is_exact else 'NO'}",
            f"  Regime: {best.regime}",
        ]
    else:
        lines.append("BEST MODEL: n/a (no results)")
    lines += ["", "RANKING:"]
    lines += [
        f"  {rank}. {r.model_name}: {r.max_residual:.2e}"
        for rank, r in enumerate(results[:5], start=1)
    ]
    summary = "\n".join(lines) + "\n"

    ax_summary.text(0.1, 0.9, summary, fontsize=11, family='monospace', va='top',
                    transform=ax_summary.transAxes,
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

    plt.tight_layout()
    return fig

def plot_image_plane_comparison(positions, inversion_result):
    """Show the observed images with exaggerated source-plane residual arrows.

    For each observed position the model deflection is subtracted to obtain a
    back-projected source position; its offset from the fitted source is
    drawn, magnified five times, as a red arrow starting at the image. A
    dashed circle of radius ``theta_E`` is drawn around the origin.

    Args:
        positions: Array indexed as ``positions[:, 0]`` / ``positions[:, 1]``
            and iterable row by row.
        inversion_result: Object exposing ``model_name``, ``params`` (a
            mapping) and ``source_position``.

    Returns:
        The Matplotlib figure.
    """
    magnification = 5  # matches the "x5" stated in the title
    fig, ax = plt.subplots(figsize=(10, 10))

    # Observed image positions.
    xs, ys = positions[:, 0], positions[:, 1]
    ax.scatter(xs, ys, c='blue', s=200, marker='o',
               label='Observed', edgecolors='black', linewidths=2, zorder=5)

    # Look up the deflection law registered for the fitted model.
    deflection_func, _ = DEFLECTION_MODELS[inversion_result.model_name]
    params = inversion_result.params
    source = inversion_result.source_position

    arrow_style = dict(arrowstyle='->', color='red', lw=2)
    for theta_obs in positions:
        # Lens equation beta = theta - alpha, then offset from the fitted source.
        back_projected = theta_obs - deflection_func(theta_obs, params)
        offset = back_projected - source
        start = (theta_obs[0], theta_obs[1])
        tip = (theta_obs[0] + offset[0]*magnification,
               theta_obs[1] + offset[1]*magnification)
        ax.annotate('', xy=tip, xytext=start, arrowprops=arrow_style)

    # Einstein ring of the fitted model, centred on the origin.
    theta_E = params.get('theta_E', 1.0)
    phi = np.linspace(0, 2*np.pi, 100)
    ring_x = theta_E * np.cos(phi)
    ring_y = theta_E * np.sin(phi)
    ax.plot(ring_x, ring_y,
            'g--', lw=2, alpha=0.5, label=f'Einstein ring (R={theta_E:.3f})')

    ax.scatter([0], [0], c='black', s=100, marker='+', lw=3, label='Lens center')
    ax.set_xlabel('x (image plane)', fontsize=12)
    ax.set_ylabel('y (image plane)', fontsize=12)
    ax.set_title(f'Image Plane: {inversion_result.model_name}\nResiduals x5 (red arrows)',
                 fontsize=12, fontweight='bold')
    ax.set_aspect('equal')
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)

    return fig

# Cell banner: printed once the plotting helpers in this cell have been defined.
print("Inversion visualizations loaded!")


In [None]:
import numpy as np
from enum import Enum
from dataclasses import dataclass, field
from typing import List, Tuple, Optional, Dict
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# === ENUMERATIONS ===
class Morphology(Enum):
    """Image-configuration labels assigned by the heuristic morphology classifier.

    The string values are the user-facing names: the Quicklook tab prints
    ``primary.value`` upper-cased, and saved runs store it in quicklook.json.
    """
    RING = "ring"
    QUAD = "quad"
    ARC = "arc"
    DOUBLE = "double"
    UNKNOWN = "unknown"  # label for configurations matching none of the above — presumably the fallback; confirm in the classifier

class Regime(Enum):
    """Solvability regime labels for a linear lens system.

    Stored in ``RegimeAnalysis.regime``. The exact thresholds that separate
    the regimes are decided by the solver code, not here.
    """
    DETERMINED = "determined"
    OVERDETERMINED = "overdetermined"
    UNDERDETERMINED = "underdetermined"
    ILL_CONDITIONED = "ill_conditioned"

class ModelFamily(Enum):
    """Identifiers of the eight Model Zoo variants.

    Each name spells out which terms are combined (m=2 base, optional
    external shear, optional m=3 and m=4 harmonics). The members are the
    keys of ``MODEL_CONFIGS``.
    """
    M2 = "m2"
    M2_SHEAR = "m2_shear"
    M2_M3 = "m2_m3"
    M2_SHEAR_M3 = "m2_shear_m3"
    M2_M4 = "m2_m4"
    M2_SHEAR_M4 = "m2_shear_m4"
    M2_M3_M4 = "m2_m3_m4"
    M2_SHEAR_M3_M4 = "m2_shear_m3_m4"

# === DATA CLASSES ===
@dataclass
class MorphologyAnalysis:
    """Result record of the heuristic morphology classification.

    Plain data container; the classifier that fills it is defined elsewhere
    in the notebook.
    """
    primary: Morphology  # best-matching morphology label
    confidence: float  # classifier confidence — presumably in [0, 1]; TODO confirm
    mean_radius: float  # NOTE(review): angular unit and reference centre are not visible here
    radial_scatter: float  # shown with 4 decimals in the Quicklook tab
    azimuthal_coverage: float  # shown with 2 decimals in the Quicklook tab
    azimuthal_uniformity: float
    m2_amplitude: float
    m4_amplitude: float
    recommended_models: List[str]  # model names suggested for this morphology
    notes: List[str]  # free-text remarks; the Quicklook tab renders each as a bullet

@dataclass
class RingFitResult:
    """Result record of a geometric circle fit plus harmonic diagnostics.

    Plain data container; the fitter that fills it is defined elsewhere in
    the notebook. Saved runs store ``radius`` and ``rms_residual`` under the
    keys ``radius_rad`` / ``rms_rad``, so both are treated as radians.
    """
    center_x: float
    center_y: float
    radius: float
    radial_residuals: np.ndarray  # one entry per input point (plotted against azimuthal_angles)
    azimuthal_angles: np.ndarray  # the UI converts these with np.degrees, i.e. they are radians
    rms_residual: float
    # Harmonic components; the UI displays element [0] of m2/m4.
    # Presumably (amplitude, phase) — TODO confirm against the ring analyzer.
    m2_component: Tuple[float, float]
    m3_component: Tuple[float, float]
    m4_component: Tuple[float, float]
    is_perturbed: bool
    perturbation_type: str

@dataclass
class RegimeAnalysis:
    """Result record describing how well-posed a linear system is.

    Plain data container; only ``explanation`` and ``recommendations`` have
    defaults.
    """
    regime: Regime  # overall classification
    n_constraints: int  # number of equations (rows) in the system
    n_params: int  # number of unknowns (columns) in the system
    rank: int  # numerical rank of the system matrix
    nullspace_dim: int  # presumably n_params - rank; confirm where it is computed
    condition_number: float
    explanation: str = ""
    recommendations: List[str] = field(default_factory=list)  # fresh list per instance

@dataclass
class ModelConfig:
    """Static description of one Model Zoo variant (see ``MODEL_CONFIGS``)."""
    family: ModelFamily  # which variant this configuration describes
    m_max: int  # highest harmonic order listed for the variant
    include_shear: bool
    # NOTE: defaults to True, so a variant WITHOUT an m=3 term has to pass
    # include_m3=False explicitly.
    include_m3: bool = True
    include_m4: bool = False
    label: str = ""  # human-readable name
    n_lens_params: int = 0  # lens-parameter count for the variant — NOTE(review): whether source position is included is not visible here

@dataclass
class Position3D:
    """A labelled point in 3D, used for observer, lens and source positions."""
    x: float = 0.0
    y: float = 0.0
    z: float = 0.0
    label: str = ""

    def to_array(self):
        """Return the coordinates as a NumPy array ``[x, y, z]`` (label omitted)."""
        coordinates = (self.x, self.y, self.z)
        return np.array(coordinates)

@dataclass
class LensProperties:
    """Position and basic shape parameters of the lens in a ``TriadScene``."""
    position: Position3D  # required; everything else has a default
    einstein_radius: float = 1.0  # NOTE(review): unit not fixed here; create_standard passes theta_E straight through
    ellipticity: float = 0.0
    position_angle: float = 0.0  # NOTE(review): radians vs degrees is not visible here

@dataclass
class SourceProperties:
    """One background source of a ``TriadScene``: where it is and its id."""
    position: Position3D
    source_id: int = 0  # TriadScene.add_source defaults this to the current source count

@dataclass
class TriadScene:
    """Observer / lens / source geometry for the 3D visualisation.

    The observer defaults to the origin. When no lens is supplied a default
    lens at ``z = 1.0`` is created in ``__post_init__``, so ``lens`` is
    never ``None`` on a constructed scene.
    """
    name: str
    observer: Position3D = field(default_factory=lambda: Position3D(0, 0, 0, "Observer"))
    # Optional only at construction time; __post_init__ fills in a default.
    lens: Optional[LensProperties] = None
    sources: List[SourceProperties] = field(default_factory=list)

    def __post_init__(self):
        """Install the default lens at (0, 0, 1) when none was given."""
        if self.lens is None:
            self.lens = LensProperties(position=Position3D(0, 0, 1.0, "Lens"))

    def add_source(self, x, y, z, source_id=None):
        """Append a source at ``(x, y, z)``.

        Args:
            x, y, z: Source coordinates.
            source_id: Identifier for the source; defaults to the number of
                sources already in the scene. Also used in the position
                label ``Source_<id>``.
        """
        if source_id is None:
            source_id = len(self.sources)
        position = Position3D(x, y, z, f"Source_{source_id}")
        self.sources.append(SourceProperties(position=position, source_id=source_id))

    @classmethod
    def create_standard(cls, name, D_L=1.0, D_S=2.0, beta_x=0.1, beta_y=-0.05, theta_E=1.0):
        """Build a scene with a lens at ``z = D_L`` and one source at ``z = D_S``.

        The source is placed at ``(beta_x * D_S, beta_y * D_S, D_S)``, i.e.
        the angular offsets are scaled by the source distance.

        Returns:
            The new scene containing exactly one source.
        """
        # Pass the lens to the constructor so __post_init__ does not build a
        # throwaway default lens first.
        lens = LensProperties(position=Position3D(0, 0, D_L, "Lens"), einstein_radius=theta_E)
        scene = cls(name=name, lens=lens)
        scene.add_source(beta_x * D_S, beta_y * D_S, D_S)
        return scene

# === MODEL ZOO ===
# Registry of the eight Model Zoo variants, keyed by family.
# include_m3 / include_m4 are spelled out on every entry: ModelConfig defaults
# include_m3 to True, which previously left the "m=2 only" and "m=2 + shear"
# entries flagged as containing an m=3 term.
MODEL_CONFIGS = {
    ModelFamily.M2: ModelConfig(
        ModelFamily.M2, 2, False,
        include_m3=False, include_m4=False, label="m=2 only", n_lens_params=3),
    ModelFamily.M2_SHEAR: ModelConfig(
        ModelFamily.M2_SHEAR, 2, True,
        include_m3=False, include_m4=False, label="m=2 + shear", n_lens_params=5),
    ModelFamily.M2_M3: ModelConfig(
        ModelFamily.M2_M3, 3, False,
        include_m3=True, include_m4=False, label="m=2 + m=3", n_lens_params=5),
    ModelFamily.M2_SHEAR_M3: ModelConfig(
        ModelFamily.M2_SHEAR_M3, 3, True,
        include_m3=True, include_m4=False, label="m=2 + shear + m=3", n_lens_params=7),
    ModelFamily.M2_M4: ModelConfig(
        ModelFamily.M2_M4, 4, False,
        include_m3=False, include_m4=True, label="m=2 + m=4", n_lens_params=5),
    ModelFamily.M2_SHEAR_M4: ModelConfig(
        ModelFamily.M2_SHEAR_M4, 4, True,
        include_m3=False, include_m4=True, label="m=2 + shear + m=4", n_lens_params=7),
    ModelFamily.M2_M3_M4: ModelConfig(
        ModelFamily.M2_M3_M4, 4, False,
        include_m3=True, include_m4=True, label="m=2 + m=3 + m=4", n_lens_params=7),
    ModelFamily.M2_SHEAR_M3_M4: ModelConfig(
        ModelFamily.M2_SHEAR_M3_M4, 4, True,
        include_m3=True, include_m4=True, label="MAXIMAL", n_lens_params=9),
}

def get_derivation_chain(include_m4=False):
    """Return the model families in stepwise-derivation order.

    Args:
        include_m4: When true, the four m=4 variants are appended after the
            four base variants.

    Returns:
        A new list of ``ModelFamily`` members (4 entries, or 8 with m=4).
    """
    base_steps = [
        ModelFamily.M2,
        ModelFamily.M2_SHEAR,
        ModelFamily.M2_M3,
        ModelFamily.M2_SHEAR_M3,
    ]
    if not include_m4:
        return base_steps
    m4_steps = [
        ModelFamily.M2_M4,
        ModelFamily.M2_SHEAR_M4,
        ModelFamily.M2_M3_M4,
        ModelFamily.M2_SHEAR_M3_M4,
    ]
    return base_steps + m4_steps

# Cell banner: printed once the enums, dataclasses and MODEL_CONFIGS above exist.
print("Classes and Model Zoo loaded")

# === GRADIO UI ===

#@title RSG/SSZ Lensing Suite - 4 Tab UI
import gradio as gr
import os
from datetime import datetime
import json as json_module

# Analyzer instances shared by every UI callback. quicklook_fn assigns
# classifier.center on each run, so this object is mutable shared state.
classifier = MorphologyClassifier()
ring_analyzer = RingAnalyzer()

# Four-image example that pre-fills the positions textbox ("x, y" per line).
QUAD_EX = "\n".join((
    "0.740, 0.565",
    "-0.635, 0.470",
    "-0.480, -0.755",
    "0.870, -0.195",
))

def parse_pos(text, unit='arcsec'):
    """Parse free-form "x, y" text into an (N, 2) array of angles in radians.

    Each non-blank line holds one point. The two coordinates may be separated
    by commas and/or whitespace; fields after the second are ignored.

    Args:
        text: Multi-line string of positions.
        unit: 'arcsec', 'mas', 'uas' or 'rad'. Any other value falls back to
            arcseconds.

    Returns:
        Float array of shape (N, 2), scaled to radians. Input without any
        points yields shape (0, 2).

    Raises:
        ValueError: A line has fewer than two fields or a field is not a
            number. The message names the offending line so the UI can show
            it to the user.
    """
    rows = []
    for line_no, raw in enumerate(text.split('\n'), start=1):
        entry = raw.strip()
        if not entry:
            continue
        fields = entry.replace(',', ' ').split()
        if len(fields) < 2:
            raise ValueError(
                f"Line {line_no}: expected two coordinates 'x, y', got {entry!r}")
        try:
            point = [float(fields[0]), float(fields[1])]
        except ValueError as err:
            raise ValueError(
                f"Line {line_no}: non-numeric coordinate in {entry!r}") from err
        rows.append(point)

    if not rows:
        # Keep the column dimension so callers can still index [:, 0] / [:, 1].
        return np.empty((0, 2))

    to_rad = {'arcsec': ARCSEC_TO_RAD, 'mas': MAS_TO_RAD, 'uas': MUAS_TO_RAD, 'rad': 1.0}
    return np.array(rows) * to_rad.get(unit, ARCSEC_TO_RAD)

def quicklook_fn(pos_text, pos_unit, center_known, cx, cy, c_unit):
    """Gradio callback for the Quicklook tab: geometric ring fit and heuristics.

    Parses the positions, classifies the morphology, fits a circle and plots
    the fit next to the radial residuals.

    Args:
        pos_text: Raw positions text, one "x, y" pair per line.
        pos_unit: Unit of ``pos_text`` (passed to ``parse_pos``).
        center_known: When true, ``(cx, cy)`` is used as the centre;
            otherwise the centroid of the points is used.
        cx, cy: Centre coordinates in ``c_unit``.
        c_unit: Unit of the centre; unknown units fall back to arcseconds.

    Returns:
        ``(summary_md, morphology_md, harmonics_md, figure, state)``. ``state``
        is a dict with keys ``pos``, ``ring``, ``morph``, ``mode`` and ``n``
        consumed by the other tabs. With fewer than two points the first
        element is a short message and figure/state are ``None``. On any
        exception the first two elements carry the message and traceback.
    """
    # UI boundary: report every failure in the tab instead of raising into Gradio.
    try:
        pos = parse_pos(pos_text, pos_unit)
        n = len(pos)
        if n < 2:
            return "Need >= 2 pts", "", "", None, None

        if center_known:
            ctr = np.array([cx, cy]) * ANGLE_UNITS.get(c_unit, ARCSEC_TO_RAD)
        else:
            ctr = np.mean(pos, axis=0)
        # NOTE: this mutates the module-level classifier shared by all callers.
        classifier.center = ctr
        morph = classifier.classify(pos)
        ring = ring_analyzer.fit_ring(pos, initial_center=tuple(ctr))
        mode = "QUAD" if n == 4 else ("DOUBLE" if n == 2 else "RING/ARC")

        summary = (
            f"## Summary\n| Metric | Value |\n|---|---|\n| Points | {n} |\n| Mode | {mode} |\n"
            f"| Radius | {format_angle(ring.radius)} |\n| RMS | {format_angle(ring.rms_residual)} |"
        )
        note_lines = "\n".join(f"- {note}" for note in morph.notes)
        morph_txt = (
            f"## Morphology: {morph.primary.value.upper()}\n"
            f"- radial_scatter={morph.radial_scatter:.4f}\n"
            f"- azimuthal_cov={morph.azimuthal_coverage:.2f}\n\n"
        ) + note_lines
        harm_txt = (
            "## Harmonics (DIAGNOSTIC)\n⚠️ Pattern descriptors, not lens params\n"
            f"- m2: {ring.m2_component[0]:.6f}\n- m4: {ring.m4_component[0]:.6f}"
        )

        # Plots are drawn in arcseconds (left) and milli-arcseconds (right).
        fig, ax = plt.subplots(1, 2, figsize=(12, 5))
        p = pos/ARCSEC_TO_RAD
        t = np.linspace(0, 2*np.pi, 100)
        r = ring.radius/ARCSEC_TO_RAD
        c = ctr/ARCSEC_TO_RAD
        ax[0].plot(c[0]+r*np.cos(t), c[1]+r*np.sin(t), 'b--', lw=2, label=f'r={r:.3f}"')
        ax[0].scatter(p[:, 0], p[:, 1], c='red', s=100, zorder=5)
        for j, pt in enumerate(p):
            # Label the images A, B, C, ... in input order.
            ax[0].annotate(chr(65+j), (pt[0]+0.02, pt[1]+0.02), fontweight='bold')
        ax[0].set_aspect('equal')
        ax[0].set_title(f'Quicklook: {mode}')
        ax[0].legend()
        ax[0].grid(alpha=0.3)
        ax[1].scatter(np.degrees(ring.azimuthal_angles), ring.radial_residuals/ARCSEC_TO_RAD*1000)
        ax[1].axhline(0, color='gray', ls='--')
        ax[1].set_title('Residuals (DIAGNOSTIC)')
        ax[1].grid(alpha=0.3)
        plt.tight_layout()

        state = {'pos': pos, 'ring': ring, 'morph': morph, 'mode': mode, 'n': n}
        return summary, morph_txt, harm_txt, fig, state
    except Exception as e:
        import traceback
        return str(e), traceback.format_exc(), "", None, None

def inversion_fn(ql_state, m2, shear, m3, m4):
    """Gradio callback for the Inversion tab: run the selected models and report.

    Args:
        ql_state: State dict produced by ``quicklook_fn`` (needs key ``pos``),
            or ``None`` when Quicklook has not been run.
        m2, shear, m3, m4: Checkbox flags selecting the models ``m2``,
            ``m2_shear``, ``m2_m3`` and ``m2_m4`` respectively.

    Returns:
        ``(leaderboard_md, details_md, figure, state)`` where ``state`` is a
        dict with keys ``results``, ``best`` and ``pos``. When a precondition
        fails the first element is a short message and figure/state are
        ``None``. On any exception the first two elements carry the message
        and traceback.
    """
    if ql_state is None:
        return "Run Quicklook first", "", None, None
    # UI boundary: report every failure in the tab instead of raising into Gradio.
    try:
        pos = ql_state['pos']
        toggles = ((m2, 'm2'), (shear, 'm2_shear'), (m3, 'm2_m3'), (m4, 'm2_m4'))
        models = [model for enabled, model in toggles if enabled]
        if not models:
            return "Select models", "", None, None

        results = run_model_zoo(pos, models)
        if not results:
            return "No results", "", None, None
        best = results[0]  # assumes run_model_zoo returns results best-first — TODO confirm

        # Regime gate: numerical rank of the best model's system matrix.
        A, _, _ = build_system(pos, best.model_name)
        try:
            _, s, _ = np.linalg.svd(A, full_matrices=False)
            rank = int(np.sum(s > max(A.shape)*np.finfo(float).eps*s[0]))
        except (np.linalg.LinAlgError, ValueError, IndexError):
            # Best effort: if the SVD fails or A is empty, assume full column rank.
            rank = A.shape[1]
        nullspace = A.shape[1] - rank

        lb = "## Leaderboard\n| Model | Residual | Exact |\n|---|---|---|\n"
        for r in results:
            lb += f"| {r.model_name} | {format_angle(r.max_residual)} | {'Y' if r.is_exact else 'N'} |\n"
        lb += f"\n### Regime\n- Constraints: {A.shape[0]}, Params: {A.shape[1]}, Rank: {rank}, Nullspace: {nullspace}"
        if nullspace > 0:
            lb += "\n⚠️ **Underdetermined** - add flux/time-delay"

        # These parameters are shown as angles; everything else is a plain number.
        angle_params = ('theta_E', 'beta_x', 'beta_y')
        det = f"## Best: {best.model_name}\n| Param | Value |\n|---|---|\n"
        for k, v in best.params.items():
            shown = format_angle(v) if k in angle_params else f'{v:.6f}'
            det += f"| {k} | {shown} |\n"
        cons = best.source_consistency
        det += f"\n### β Consistency\n- Scatter: {format_angle(cons.beta_scatter)}\n- Consistent: {'✓' if cons.is_consistent else '✗'}"

        fig, ax = plt.subplots(1, 2, figsize=(12, 5))
        p = pos/ARCSEC_TO_RAD
        ax[0].scatter(p[:, 0], p[:, 1], c='red', s=100)
        ax[0].set_aspect('equal')
        ax[0].set_title('Image Plane')
        t = np.linspace(0, 2*np.pi, 100)
        # NOTE(review): the 0.1 fallback is divided by ARCSEC_TO_RAD like a
        # radian value, which would draw a very large ring — confirm intent.
        theta_E = best.params.get('theta_E', 0.1)/ARCSEC_TO_RAD
        ax[0].plot(theta_E*np.cos(t), theta_E*np.sin(t), 'b--', label='θ_E')
        ax[0].legend()
        ax[0].grid(alpha=0.3)

        beta = cons.beta_positions/ARCSEC_TO_RAD
        bm = cons.beta_mean/ARCSEC_TO_RAD
        ax[1].scatter(beta[:, 0], beta[:, 1], c='blue', s=100)
        ax[1].scatter([bm[0]], [bm[1]], c='red', s=150, marker='*', label='Mean β')
        ax[1].set_aspect('equal')
        ax[1].set_title(f'Source: scatter={format_angle(cons.beta_scatter)}')
        ax[1].legend()
        ax[1].grid(alpha=0.3)
        plt.tight_layout()

        return lb, det, fig, {'results': results, 'best': best, 'pos': pos}
    except Exception as e:
        import traceback
        return str(e), traceback.format_exc(), None, None

def scene_fn(dist_mode, d_L, d_L_u, d_S, d_S_u, z_L, z_S, cosmo_name, mass, ql_state, inv_state):
    """Gradio callback for the 3D Scene tab: units table plus two scene plots.

    Angles come from the inversion state if present, else from the quicklook
    ring fit, else from a built-in demo default.

    Args:
        dist_mode: "Normalized", "Direct distances", or anything else for
            the redshift mode.
        d_L, d_L_u, d_S, d_S_u: Direct distances and their unit names (keys
            of ``DISTANCE_UNITS``); used only in direct mode.
        z_L, z_S, cosmo_name: Redshifts and cosmology key (into
            ``COSMOLOGIES``); used only in redshift mode.
        mass: Optional lens mass in solar masses; adds a "Lens" section when
            positive.
        ql_state: Quicklook state dict or ``None``.
        inv_state: Inversion state dict or ``None``.

    Returns:
        ``(units_md, fig_3d, fig_side, state)`` where ``state`` has keys
        ``scene``, ``D_L``, ``D_S``, ``theta_E`` and ``normalized``. On any
        exception the first element is message + traceback and the rest are
        ``None``.
    """
    # UI boundary: report every failure in the tab instead of raising into Gradio.
    try:
        # Get angles
        if inv_state:
            best_params = inv_state['best'].params
            theta_E = best_params.get('theta_E', 1.0*ARCSEC_TO_RAD)
            beta_x = best_params.get('beta_x', 0.0)
            beta_y = best_params.get('beta_y', 0.0)
            positions = inv_state['pos']
        elif ql_state:
            theta_E = ql_state['ring'].radius
            beta_x, beta_y = 0.0, 0.0
            positions = ql_state['pos']
        else:
            theta_E = 1.0*ARCSEC_TO_RAD
            beta_x, beta_y = 0.1*ARCSEC_TO_RAD, 0.05*ARCSEC_TO_RAD
            positions = None
        beta_mag = np.sqrt(beta_x**2 + beta_y**2)

        # Distances
        normalized = dist_mode == "Normalized"
        if normalized:
            D_L_m, D_S_m = 1.0, 2.0
            mode_str = "**NORMALIZED** (sizes not physical)"
        elif dist_mode == "Direct distances":
            D_L_m = d_L * DISTANCE_UNITS[d_L_u]
            D_S_m = d_S * DISTANCE_UNITS[d_S_u]
            mode_str = f"Direct: D_L={d_L} {d_L_u}, D_S={d_S} {d_S_u}"
        else:
            cosmo = COSMOLOGIES[cosmo_name]
            # NOTE(review): the third return value is discarded here; if it
            # is D_LS it should probably replace the subtraction below,
            # since cosmological distances do not generally subtract.
            D_L_m, D_S_m, _ = lensing_distances(z_L, z_S, cosmo)
            mode_str = f"z_L={z_L}, z_S={z_S} ({cosmo_name})"

        D_LS_m = D_S_m - D_L_m
        R_E = D_L_m * theta_E
        R_beta = D_S_m * beta_mag

        # Formatters: normalized values are shown as bare numbers tagged "(norm)".
        if normalized:
            def show_distance(value):
                return f"{value} (norm)"

            def show_radius(value):
                return f"{value:.4g} (norm)"
        else:
            def show_distance(value):
                return format_distance(value).display_string

            def show_radius(value):
                return format_radius(value).display_string

        # Literal pipes inside table cells are escaped (\|) so they are not
        # parsed as column separators.
        units = f"""## Units & Scales
**Mode:** {mode_str}

### Distances
| Qty | Value | [m] |
|---|---|---|
| D_L | {show_distance(D_L_m)} | {D_L_m:.4e} |
| D_S | {show_distance(D_S_m)} | {D_S_m:.4e} |
| D_LS | {show_distance(D_LS_m)} | {D_LS_m:.4e} |

### Angles
| Qty | Value | [rad] |
|---|---|---|
| θ_E | {format_angle(theta_E)} | {theta_E:.4e} |
| \\|β\\| | {format_angle(beta_mag)} | {beta_mag:.4e} |

### Radii
| Qty | Formula | Value |
|---|---|---|
| R_E | D_L×θ_E | {show_radius(R_E)} |
| R_β | D_S×\\|β\\| | {show_radius(R_beta)} |
"""
        if mass and mass > 0:
            mass_kg = mass * MSUN_TO_KG
            r_s = schwarzschild_radius(mass_kg)
            r_s_s = f"{r_s/1e3:.4g} km" if r_s < 1e9 else f"{r_s/AU_TO_M:.4g} AU"
            if normalized:
                # R_E is in normalized units here while r_s is in metres, so
                # their ratio carries no meaning.
                ratio_s = "n/a (normalized)"
            else:
                ratio = R_E / r_s if r_s > 0 else float('inf')
                ratio_s = f"{ratio:.4g}"
            # The header and delimiter rows are required for Markdown to
            # render this as a table.
            units += (
                "\n### Lens\n| Qty | Value |\n|---|---|\n"
                f"| M | {mass:.4g} M_sun |\n| r_s | {r_s_s} |\n| R_E/r_s | {ratio_s} |"
            )

        scene = Scene3D(D_L=D_L_m, D_S=D_S_m, theta_E=theta_E, beta=(beta_x, beta_y),
                        units='norm' if normalized else 'm')
        fig_3d = plot_scene_3d(scene, positions)
        fig_side = plot_scene_side(scene, positions)

        state = {'scene': scene, 'D_L': D_L_m, 'D_S': D_S_m, 'theta_E': theta_E, 'normalized': normalized}
        return units, fig_3d, fig_side, state
    except Exception as e:
        import traceback
        return str(e) + traceback.format_exc(), None, None, None

def save_run(name, ql_state, inv_state, scene_state, pos_text, pos_unit, dist_mode):
    """Gradio callback: persist the current session under ``runs/<timestamp>_<name>``.

    Always writes ``input_snapshot.json``, ``quicklook.json`` and
    ``report.md``; writes ``solutions/<model>.json`` when an inversion state
    exists and ``scene3d.json`` when a scene state exists.

    Args:
        name: User-supplied run name. Every character that is not a word
            character or ``-`` is replaced by ``_`` so the name cannot
            introduce path separators or ``..`` segments.
        ql_state: Quicklook state dict (required), or ``None``.
        inv_state: Inversion state dict or ``None``.
        scene_state: Scene state dict or ``None``.
        pos_text, pos_unit, dist_mode: Raw UI inputs recorded in the snapshot.

    Returns:
        A status string: the list of files written, ``"Run Quicklook first"``
        when there is nothing to save, or ``"Error: ..."`` on failure.
    """
    import re  # local import keeps this callback self-contained

    if ql_state is None:
        return "Run Quicklook first"
    try:
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        # The name comes from a publicly shared UI, so it is untrusted input.
        safe = re.sub(r'[^\w-]', '_', name) if name else 'unnamed'
        run_dir = f"runs/{ts}_{safe}"
        os.makedirs(f"{run_dir}/solutions", exist_ok=True)
        os.makedirs(f"{run_dir}/figures", exist_ok=True)

        def dump_json(rel_path, payload):
            """Write ``payload`` as indented JSON below the run directory."""
            with open(f"{run_dir}/{rel_path}", 'w', encoding='utf-8') as f:
                json_module.dump(payload, f, indent=2)

        # input_snapshot.json
        dump_json("input_snapshot.json", {
            'timestamp': ts, 'name': name, 'raw_positions': pos_text, 'position_unit': pos_unit,
            'distance_mode': dist_mode, 'n_points': ql_state['n'], 'mode': ql_state['mode']})

        # quicklook.json
        ring = ql_state['ring']
        dump_json("quicklook.json", {
            'radius_rad': float(ring.radius), 'rms_rad': float(ring.rms_residual),
            'morphology': ql_state['morph'].primary.value, 'mode': ql_state['mode']})
        written = ["input_snapshot.json", "quicklook.json"]

        # solutions/<model>.json
        if inv_state:
            for r in inv_state['results']:
                dump_json(f"solutions/{r.model_name}.json", {
                    'model': r.model_name,
                    'params': {k: float(v) for k, v in r.params.items()},
                    'max_residual': float(r.max_residual),
                    'is_exact': r.is_exact,
                    # An Enum is not JSON-serialisable; store its value if
                    # regime turns out to be one, otherwise pass it through.
                    'regime': getattr(r.regime, 'value', r.regime)})
            written.append("solutions/*.json")

        # scene3d.json
        if scene_state:
            dump_json("scene3d.json", {
                'D_L_m': scene_state['D_L'], 'D_S_m': scene_state['D_S'],
                'theta_E_rad': scene_state['theta_E'], 'normalized': scene_state['normalized'],
                'internal_units': {'angle': 'rad', 'distance': 'm', 'time': 's', 'mass': 'kg'}})
            written.append("scene3d.json")

        # report.md
        with open(f"{run_dir}/report.md", 'w', encoding='utf-8') as f:
            f.write(f"# Run: {name}\n\nTimestamp: {ts}\nMode: {ql_state['mode']}\n")
            f.write(f"\n## Input\n- {ql_state['n']} points ({pos_unit})\n- Distance: {dist_mode}\n")
            if inv_state:
                f.write(f"\n## Best Model: {inv_state['best'].model_name}\n")
        written.append("report.md")

        # Report only what was actually written.
        return f"✓ Saved to: {run_dir}\n" + "\n".join(f"- {item}" for item in written)
    except Exception as e:
        return f"Error: {e}"

def list_runs():
    """Return table rows for saved runs, newest directory name first.

    Looks at the first 20 entries of ``runs/`` in reverse-sorted name order
    and reads ``input_snapshot.json`` from each one that has it.

    Returns:
        List of ``[timestamp, name, mode]`` rows; missing keys become ``''``.
        Empty when ``runs/`` does not exist.
    """
    if not os.path.exists('runs'):
        return []

    rows = []
    newest_first = sorted(os.listdir('runs'), reverse=True)
    for entry in newest_first[:20]:
        snapshot_path = f"runs/{entry}/input_snapshot.json"
        if not os.path.exists(snapshot_path):
            # Not a saved run (or incomplete): skip it.
            continue
        with open(snapshot_path) as handle:
            meta = json_module.load(handle)
        rows.append([meta.get(key, '') for key in ('timestamp', 'name', 'mode')])
    return rows

# ============== GRADIO UI ==============
# ============== GRADIO UI ==============
# Layout: a narrow input column on the left (observations, distances, model
# selection) and a wide results column on the right with one tab per step.
with gr.Blocks(title="RSG/SSZ Lensing Suite", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# RSG / SSZ Lensing Suite")
    
    # State handed from one callback to the next:
    # quicklook_fn -> ql_state, inversion_fn -> inv_state, scene_fn -> scene_state.
    ql_state = gr.State(None)
    inv_state = gr.State(None)
    scene_state = gr.State(None)
    
    with gr.Row():
        # ---- Left column: inputs ----
        with gr.Column(scale=1):
            gr.Markdown("## Input")
            with gr.Accordion("Observations", open=True):
                pos_text = gr.Textbox(value=QUAD_EX, lines=5, label="Positions (x, y)")
                pos_unit = gr.Dropdown(["arcsec", "mas", "uas", "rad"], value="arcsec", label="Unit")
                with gr.Row():
                    center_known = gr.Checkbox(False, label="Center known?")
                    cx = gr.Number(0.0, label="x0", scale=1)
                    cy = gr.Number(0.0, label="y0", scale=1)
                    # NOTE(review): offers fewer units than pos_unit (no uas/rad).
                    c_unit = gr.Dropdown(["arcsec", "mas"], value="arcsec", scale=1)
            
            with gr.Accordion("Distances (3D)", open=False):
                dist_mode = gr.Radio(["Normalized", "Direct distances", "Redshifts"], value="Normalized")
                gr.Markdown("*⚠️ Normalized: sizes not physical*")
                # Both groups start hidden; dist_mode.change below shows the matching one.
                with gr.Group(visible=False) as direct_grp:
                    d_L = gr.Number(1.3, label="D_L"); d_L_u = gr.Dropdown(["Gpc","Mpc","kpc"], value="Gpc")
                    d_S = gr.Number(2.1, label="D_S"); d_S_u = gr.Dropdown(["Gpc","Mpc","kpc"], value="Gpc")
                with gr.Group(visible=False) as z_grp:
                    z_L = gr.Number(0.5, label="z_L"); z_S = gr.Number(2.0, label="z_S")
                    cosmo = gr.Dropdown(["Planck18","Planck15","WMAP9"], value="Planck18")
                # Optional: scene_fn adds the "Lens" section only for a positive mass.
                lens_mass = gr.Number(None, label="Lens mass (M_sun)")
            
            # One checkbox per model passed to inversion_fn
            # (m2, m2_shear, m2_m3, m2_m4).
            with gr.Accordion("Model Zoo", open=False):
                m2 = gr.Checkbox(True, label="m=2")
                shear = gr.Checkbox(True, label="+shear")
                m3 = gr.Checkbox(True, label="+m3")
                m4 = gr.Checkbox(True, label="+m4")
        
        # ---- Right column: one tab per workflow step ----
        with gr.Column(scale=3):
            with gr.Tabs():
                with gr.Tab("Quicklook"):
                    btn_ql = gr.Button("Run Quicklook", variant="primary")
                    with gr.Row():
                        ql_summary = gr.Markdown()
                        ql_morph = gr.Markdown()
                    ql_harm = gr.Markdown()
                    ql_plot = gr.Plot()
                
                with gr.Tab("Inversion"):
                    btn_inv = gr.Button("Run Inversion", variant="primary")
                    with gr.Row():
                        inv_lb = gr.Markdown()
                        inv_det = gr.Markdown()
                    inv_plot = gr.Plot()
                
                with gr.Tab("3D Scene"):
                    btn_scene = gr.Button("Generate Scene", variant="primary")
                    scene_units = gr.Markdown()
                    with gr.Row():
                        scene_3d = gr.Plot()
                        scene_side = gr.Plot()
                
                with gr.Tab("Runs"):
                    run_name = gr.Textbox(label="Run name", value="my_run")
                    btn_save = gr.Button("Save Run")
                    save_status = gr.Markdown()
                    btn_refresh = gr.Button("Refresh")
                    runs_table = gr.Dataframe(headers=["Timestamp","Name","Mode"])
    
    # ---- Event wiring ----
    # Show only the distance inputs that belong to the selected mode.
    dist_mode.change(lambda m: (gr.update(visible=m=="Direct distances"), gr.update(visible=m=="Redshifts")), 
                    [dist_mode], [direct_grp, z_grp])
    # The input/output lists must match each callback's parameter order and return tuple.
    btn_ql.click(quicklook_fn, [pos_text, pos_unit, center_known, cx, cy, c_unit], 
                [ql_summary, ql_morph, ql_harm, ql_plot, ql_state])
    btn_inv.click(inversion_fn, [ql_state, m2, shear, m3, m4], [inv_lb, inv_det, inv_plot, inv_state])
    btn_scene.click(scene_fn, [dist_mode, d_L, d_L_u, d_S, d_S_u, z_L, z_S, cosmo, lens_mass, ql_state, inv_state],
                   [scene_units, scene_3d, scene_side, scene_state])
    btn_save.click(save_run, [run_name, ql_state, inv_state, scene_state, pos_text, pos_unit, dist_mode], [save_status])
    btn_refresh.click(list_runs, [], [runs_table])

# share=True requests a public share link, as advertised in the notebook intro;
# anyone with the link can trigger the callbacks above, including save_run.
demo.launch(share=True)
