# 🔭 RSG Lensing Inversion Framework

**Authors:** Carmen N. Wrede, Lino P. Casu

## Features
- Morphology Classification (Ring/Quad/Arc/Double)
- Ring Analysis + Harmonics (m=2, m=4)
- **3D Scene Visualization** (Observer → Lens → Source)
- Synthetic Data Generation

**Run all cells to get a shareable Gradio link!**

In [None]:
!pip install -q gradio numpy matplotlib

In [None]:
import numpy as np
from enum import Enum
from dataclasses import dataclass, field
from typing import List, Tuple, Optional, Dict, Any
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

class Morphology(Enum):
    """Coarse image-configuration classes produced by MorphologyClassifier."""

    RING = "ring"
    QUAD = "quad"
    ARC = "arc"
    DOUBLE = "double"
    UNKNOWN = "unknown"

@dataclass
class MorphologyAnalysis:
    """Result of MorphologyClassifier.classify()."""

    primary: Morphology  # winning morphology class
    confidence: float  # heuristic score in [0, 1]; not a calibrated probability
    mean_radius: float  # mean distance of the positions from the classifier center
    radial_scatter: float  # std(r) / mean(r); 1.0 when mean(r) == 0
    azimuthal_coverage: float  # 1 - (largest angular gap) / 2π
    azimuthal_uniformity: float  # 1 / (1 + var(gaps) / (2π/n)^2); 1.0 for evenly spaced angles
    # |mean((r - r̄)·e^{imφ})| / r̄.  For uniformly sampled angles this is HALF the
    # fractional cos(mφ) amplitude (RingAnalyzer reports the full amplitude).
    m2_amplitude: float
    m4_amplitude: float
    recommended_models: List[str]
    notes: List[str]

class MorphologyClassifier:
    """Heuristic classifier for a set of lensed-image positions.

    Decision order (first match wins):
      * exactly 4 points                                  -> QUAD
      * exactly 2 points                                  -> DOUBLE
      * > 4 points, radial scatter < RING_RADIAL_SCATTER
        and azimuthal coverage > RING_AZIMUTHAL_COV       -> RING
      * azimuthal coverage < ARC_AZIMUTHAL_COV            -> ARC
      * otherwise                                         -> UNKNOWN

    Thresholds are class attributes so they can be tuned on a subclass or an
    instance; classify() reads them through ``self``.
    """

    RING_RADIAL_SCATTER = 0.05  # max std(r)/mean(r) for a RING
    RING_AZIMUTHAL_COV = 0.7    # min azimuthal coverage for a RING
    ARC_AZIMUTHAL_COV = 0.5     # coverage below which a non-quad/double set is an ARC
    RING_SHEAR_M2 = 0.005       # m2_amplitude above which "isotropic+shear" is also recommended

    def __init__(self, center=(0.0, 0.0)):
        # Reference point for all radii and angles.
        self.center = np.array(center)

    @staticmethod
    def _harmonic_amplitude(dr, phi, m, r_mean):
        """Return |mean(dr·e^{imφ})| / r_mean, or 0.0 when r_mean is not positive."""
        if r_mean <= 0:
            # All points coincide with the center: no meaningful normalisation.
            return 0.0
        c = np.mean(dr * np.cos(m * phi))
        s = np.mean(dr * np.sin(m * phi))
        return np.sqrt(c**2 + s**2) / r_mean

    def classify(self, positions):
        """Classify image positions measured about ``self.center``.

        Args:
            positions: array-like of shape (N, 2) with N >= 1.

        Returns:
            MorphologyAnalysis with the winning class, a heuristic confidence,
            geometric statistics, recommended models and notes.

        Raises:
            ValueError: if ``positions`` is empty or not of shape (N, 2).
        """
        positions = np.asarray(positions, dtype=float)
        if positions.ndim != 2 or positions.shape[1] != 2 or len(positions) == 0:
            raise ValueError("classify() expects a non-empty (N, 2) array of positions")

        n = len(positions)
        rel = positions - self.center
        r = np.hypot(rel[:, 0], rel[:, 1])
        phi = np.arctan2(rel[:, 1], rel[:, 0])
        r_mean, r_std = np.mean(r), np.std(r)
        radial_scatter = r_std / r_mean if r_mean > 0 else 1.0

        # Angular gaps between consecutive images, including the wrap-around gap.
        phi_sorted = np.sort(phi)
        gaps = np.append(np.diff(phi_sorted), 2 * np.pi + phi_sorted[0] - phi_sorted[-1])
        azimuthal_coverage = 1.0 - np.max(gaps) / (2 * np.pi)
        azimuthal_uniformity = 1.0 / (1.0 + np.var(gaps) / (2 * np.pi / n) ** 2)

        dr = r - r_mean
        m2_amp = self._harmonic_amplitude(dr, phi, 2, r_mean)
        m4_amp = self._harmonic_amplitude(dr, phi, 4, r_mean)

        notes, models = [], []
        if n == 4:
            primary, conf = Morphology.QUAD, 0.9
            notes.append("Quad: 4 discrete images")
            models = ["m2", "m2+shear"]
        elif n == 2:
            primary, conf = Morphology.DOUBLE, 0.9
            notes.append("Double: two-image system")
            models = ["m2"]
        elif (n > 4 and radial_scatter < self.RING_RADIAL_SCATTER
              and azimuthal_coverage > self.RING_AZIMUTHAL_COV):
            primary = Morphology.RING
            # Confidence drops linearly as the scatter approaches the ring threshold.
            conf = min(0.95, 1 - radial_scatter / self.RING_RADIAL_SCATTER)
            notes.append("Ring-like")
            models = ["isotropic"]
            if m2_amp > self.RING_SHEAR_M2:
                models.append("isotropic+shear")
        elif azimuthal_coverage < self.ARC_AZIMUTHAL_COV:
            primary, conf = Morphology.ARC, 0.7
            notes.append("Arc-like")
            models = ["m2", "isotropic"]
        else:
            primary, conf = Morphology.UNKNOWN, 0.5
            notes.append("Mixed morphology")
            models = ["m2", "m2+shear"]

        return MorphologyAnalysis(primary, conf, r_mean, radial_scatter, azimuthal_coverage,
                                  azimuthal_uniformity, m2_amp, m4_amp, models, notes)

print("✅ Morphology classes loaded!")

In [None]:
@dataclass
class RingFitResult:
    """Output of RingAnalyzer.fit_ring()."""

    center_x: float
    center_y: float
    radius: float  # median distance of the points from (center_x, center_y)
    radial_residuals: np.ndarray  # per-point distance minus `radius`
    azimuthal_angles: np.ndarray  # per-point angle about the center (arctan2, radians)
    rms_residual: float  # sqrt(mean(radial_residuals**2))
    m2_component: Tuple[float, float]  # (amplitude, phase) of A·cos(2φ - phase) in the residuals
    m4_component: Tuple[float, float]  # (amplitude, phase) of A·cos(4φ - phase) in the residuals
    is_perturbed: bool  # True if either harmonic exceeds the threshold
    perturbation_type: str  # "quadrupole (m=2)", "hexadecapole (m=4)" or "isotropic"

class RingAnalyzer:
    """Fit a circle to image positions and measure m=2 / m=4 radial harmonics."""

    # A harmonic is significant when its amplitude exceeds this fraction of the radius.
    PERTURBATION_THRESHOLD = 0.02

    def fit_ring(self, positions, initial_center=None):
        """Fit a ring to ``positions`` and decompose the radial residuals.

        Args:
            positions: array-like of shape (N, 2), N >= 1.
            initial_center: optional (x, y); if omitted the center is estimated
                with an algebraic least-squares circle fit.

        Returns:
            RingFitResult. The radius is the median distance from the center.
            The perturbation type is the dominant significant harmonic; m=4 wins
            ties.

        Raises:
            ValueError: if ``positions`` is empty or not of shape (N, 2).
        """
        positions = np.asarray(positions, dtype=float)
        if positions.ndim != 2 or positions.shape[1] != 2 or len(positions) == 0:
            raise ValueError("fit_ring() expects a non-empty (N, 2) array of positions")

        if initial_center is None:
            cx, cy = self._estimate_center(positions)
        else:
            cx, cy = initial_center

        rel = positions - np.array([cx, cy])
        r = np.hypot(rel[:, 0], rel[:, 1])
        phi = np.arctan2(rel[:, 1], rel[:, 0])
        radius = np.median(r)
        dr = r - radius
        rms = np.sqrt(np.mean(dr**2))

        m2_amp, m2_phase = self._fit_harmonic(dr, phi, 2)
        m4_amp, m4_phase = self._fit_harmonic(dr, phi, 4)

        threshold = self.PERTURBATION_THRESHOLD * radius
        m2_significant = m2_amp > threshold
        m4_significant = m4_amp > threshold
        is_perturbed = m2_significant or m4_significant
        if m2_significant and m2_amp > m4_amp:
            ptype = "quadrupole (m=2)"
        elif m4_significant:
            ptype = "hexadecapole (m=4)"
        else:
            ptype = "isotropic"

        return RingFitResult(cx, cy, radius, dr, phi, rms, (m2_amp, m2_phase), (m4_amp, m4_phase), is_perturbed, ptype)

    def _estimate_center(self, positions):
        """Estimate a ring center (x, y).

        With fewer than 3 points, return the centroid. Otherwise solve
        x² + y² = a·x + b·y + c by least squares; the center is (a/2, b/2).
        """
        n = len(positions)
        if n < 3:
            return (np.mean(positions[:, 0]), np.mean(positions[:, 1]))
        x, y = positions[:, 0], positions[:, 1]
        A = np.column_stack([x, y, np.ones(n)])
        b = x**2 + y**2
        try:
            coeffs = np.linalg.lstsq(A, b, rcond=None)[0]
        except np.linalg.LinAlgError:
            # SVD did not converge: fall back to the centroid.
            return (np.mean(x), np.mean(y))
        return (coeffs[0] / 2, coeffs[1] / 2)

    def _fit_harmonic(self, dr, phi, m):
        """Return (amplitude, phase) so that dr ≈ amplitude·cos(m·phi - phase).

        The factor 2 recovers the full amplitude from the projection
        mean(dr·cos), which equals A/2 for uniformly sampled angles.
        """
        c = np.mean(dr * np.cos(m * phi))
        s = np.mean(dr * np.sin(m * phi))
        amp = 2 * np.sqrt(c**2 + s**2)
        phase = np.arctan2(s, c)
        return (amp, phase)

def generate_ring_points(theta_E=1.0, n_points=50, center=(0.0, 0.0), c2=0.0, s2=0.0, c4=0.0, s4=0.0, noise=0.0, rng=None):
    """Sample ``n_points`` equally spaced angles on a harmonically perturbed ring.

    r(φ) = theta_E + c2·cos2φ + s2·sin2φ + c4·cos4φ + s4·sin4φ. Gaussian noise of
    standard deviation ``noise`` is added independently to x and y.

    Args:
        rng: optional numpy ``Generator`` for reproducible noise. When omitted,
            the global ``np.random`` state is used, as before.

    Returns:
        Array of shape (n_points, 2).
    """
    phi = np.linspace(0, 2*np.pi, n_points, endpoint=False)
    r = theta_E + c2*np.cos(2*phi) + s2*np.sin(2*phi) + c4*np.cos(4*phi) + s4*np.sin(4*phi)
    x = center[0] + r * np.cos(phi)
    y = center[1] + r * np.sin(phi)
    if noise > 0:
        normal = np.random.normal if rng is None else rng.normal
        x += normal(0, noise, n_points)
        y += normal(0, noise, n_points)
    return np.column_stack([x, y])

print("✅ Ring analysis loaded!")

In [None]:
@dataclass
class Position3D:
    """A labelled 3D point; z is used as the line-of-sight distance (D_L, D_S)."""

    x: float = 0.0
    y: float = 0.0
    z: float = 0.0
    label: str = ""

    def to_array(self):
        """Return the coordinates as ``np.array([x, y, z])``."""
        return np.array([self.x, self.y, self.z])

@dataclass
class LensProperties:
    """Lens location plus the parameters the plots read from it."""

    position: Position3D
    einstein_radius: float = 1.0
    ellipticity: float = 0.0

@dataclass
class SourceProperties:
    """A background source and its integer id."""

    position: Position3D
    source_id: int = 0

@dataclass
class TriadScene:
    """Observer–lens–source geometry used by the 3D visualisation.

    The observer defaults to the origin. When no lens is supplied,
    __post_init__ places a default lens at z = 1.0.
    """

    name: str
    observer: Position3D = field(default_factory=lambda: Position3D(0, 0, 0, "Observer"))
    # Optional only at construction time; __post_init__ replaces None with a default lens.
    lens: Optional[LensProperties] = None
    sources: List[SourceProperties] = field(default_factory=list)

    def __post_init__(self):
        if self.lens is None:
            self.lens = LensProperties(position=Position3D(0, 0, 1.0, "Lens"))

    def add_source(self, x, y, z, source_id=None):
        """Append a source at (x, y, z).

        When ``source_id`` is omitted, it defaults to the current number of sources.
        """
        if source_id is None:
            source_id = len(self.sources)
        self.sources.append(SourceProperties(position=Position3D(x, y, z, f"Source_{source_id}"), source_id=source_id))

    @classmethod
    def create_standard(cls, name, D_L=1.0, D_S=2.0, beta_x=0.1, beta_y=-0.05, theta_E=1.0):
        """Build a scene with a lens at z=D_L and one source at z=D_S.

        The source's lateral offset is the angular position (beta_x, beta_y)
        multiplied by D_S (small-angle approximation).
        """
        scene = cls(name=name)
        scene.lens = LensProperties(position=Position3D(0, 0, D_L, "Lens"), einstein_radius=theta_E)
        scene.add_source(beta_x * D_S, beta_y * D_S, D_S)
        return scene

print("✅ 3D Scene classes loaded!")

In [None]:
def plot_3d_scene(scene, images=None):
    """3D visualization of Observer-Lens-Source geometry.

    Draws the following:
    - the observer at the origin;
    - the lens;
    - every source;
    - the optical axis;
    - the Einstein circle in the lens plane.

    When ``images`` is given, each image also gets a ray drawn from the observer
    to the lens plane (solid) and from the lens plane to the source plane
    (dashed). Angles become lateral offsets via ``angle * distance * ray_scale``.
    ``ray_scale`` is a purely visual compression factor.

    Args:
        scene: TriadScene; reads ``lens``, ``sources`` and ``name``.
        images: optional (N, 2) array-like of image positions.

    Returns:
        The matplotlib Figure.
    """
    ray_scale = 0.5
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')

    ax.scatter([0], [0], [0], c='blue', s=150, marker='o', label='Observer')
    L = scene.lens.position
    ax.scatter([L.x], [L.y], [L.z], c='red', s=200, marker='s', label='Lens')

    # Einstein circle in the lens plane, drawn at the same scale as the rays so
    # that rays of on-ring images land on it.
    theta_E = scene.lens.einstein_radius
    theta = np.linspace(0, 2*np.pi, 50)
    circle_r = theta_E * L.z * ray_scale
    ax.plot(L.x + circle_r*np.cos(theta), L.y + circle_r*np.sin(theta), [L.z]*50, 'r-', alpha=0.5)

    for i, src in enumerate(scene.sources):
        S = src.position
        # Label only the first source so the legend has a single "Source" entry.
        ax.scatter([S.x], [S.y], [S.z], c='gold', s=200, marker='*',
                   label='Source' if i == 0 else '_nolegend_')
    if scene.sources:
        # Optical axis, drawn once out to the farthest source.
        z_max = max(src.position.z for src in scene.sources)
        ax.plot([0, 0], [0, 0], [0, z_max], 'k--', alpha=0.3)

    if images is not None:
        images = np.asarray(images, dtype=float)
        D_L = L.z
        D_S = scene.sources[0].position.z if scene.sources else 2*D_L
        colors = plt.cm.viridis(np.linspace(0.2, 0.8, len(images)))
        for img, c in zip(images, colors):
            x_L, y_L = img[0]*D_L*ray_scale, img[1]*D_L*ray_scale
            ax.plot([0, x_L], [0, y_L], [0, D_L], color=c, linewidth=2, alpha=0.7)
            r = np.hypot(img[0], img[1])
            if r > 0.1:  # skip near-axis images, where the 1/r term blows up
                # Point-mass lens equation β = θ - θ_E²/θ (radial component).
                # Images on the Einstein ring (θ = θ_E) trace back to β = 0.
                beta_r = r - theta_E**2 / r
                ang = np.arctan2(img[1], img[0])
                x_S = beta_r*np.cos(ang)*D_S*ray_scale
                y_S = beta_r*np.sin(ang)*D_S*ray_scale
                ax.plot([x_L, x_S], [y_L, y_S], [D_L, D_S], color=c, linestyle='--', linewidth=2, alpha=0.7)

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z (distance)')
    ax.set_title(f'3D Lensing Geometry: {scene.name}')
    ax.legend()
    return fig

def plot_lens_plane(images, theta_E=1.0, center=(0,0)):
    """Plot image positions in the lens plane with a dashed Einstein circle.

    Args:
        images: (N, 2) array of image positions; the points are numbered 1..N.
        theta_E: radius of the dashed circle.
        center: (x, y) of the circle and of the black '+' lens marker.

    Returns:
        The matplotlib Figure.
    """
    cx, cy = center[0], center[1]
    fig, ax = plt.subplots(figsize=(8, 8))

    phase = np.linspace(0, 2*np.pi, 100)
    ax.plot(cx + theta_E*np.cos(phase), cy + theta_E*np.sin(phase),
            'b--', lw=2, alpha=0.6, label='Einstein ring')
    ax.scatter(images[:, 0], images[:, 1], c='red', s=100, marker='o', label='Images', zorder=5)

    # Number each image, offset slightly up and to the right of its marker.
    for number, img in enumerate(images, start=1):
        ax.annotate(str(number), (img[0] + 0.05, img[1] + 0.05), fontsize=12, fontweight='bold')

    ax.scatter([cx], [cy], c='black', s=100, marker='+', lw=3, label='Lens center')

    ax.set_xlabel('θ_x')
    ax.set_ylabel('θ_y')
    ax.set_title('Lens Plane')
    ax.set_aspect('equal')
    ax.legend()
    ax.grid(True, alpha=0.3)
    return fig

def plot_ring_analysis(positions, ring, threshold=None):
    """Three-panel diagnostic plot for a RingFitResult.

    Panels: points with the fitted circle; radial residuals vs angle with the
    fitted m=2 / m=4 cosines; bar chart of the harmonic amplitudes.

    Args:
        positions: (N, 2) array-like of image positions.
        ring: RingFitResult from RingAnalyzer.fit_ring.
        threshold: perturbation threshold as a fraction of the ring radius.
            Defaults to RingAnalyzer.PERTURBATION_THRESHOLD, so the red line
            matches the threshold the analyzer classifies with.

    Returns:
        The matplotlib Figure.
    """
    if threshold is None:
        threshold = RingAnalyzer.PERTURBATION_THRESHOLD
    positions = np.asarray(positions, dtype=float)
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 5))

    # Panel 1: data with the fitted circle and center
    ax1.scatter(positions[:, 0], positions[:, 1], c='blue', s=40, alpha=0.7, label='Points')
    theta = np.linspace(0, 2*np.pi, 100)
    ax1.plot(ring.center_x + ring.radius*np.cos(theta), ring.center_y + ring.radius*np.sin(theta), 'r-', lw=2, label=f'Fit (R={ring.radius:.4f})')
    ax1.scatter([ring.center_x], [ring.center_y], c='red', s=150, marker='+', lw=3)
    ax1.set_aspect('equal'); ax1.set_title('Ring Overlay'); ax1.legend(); ax1.grid(True, alpha=0.3)

    # Panel 2: residuals sorted by angle, overlaid with amplitude·cos(mφ - phase)
    idx = np.argsort(ring.azimuthal_angles)
    ax2.scatter(np.degrees(ring.azimuthal_angles[idx]), ring.radial_residuals[idx], c='blue', s=40)
    ax2.axhline(0, color='gray', ls='--')
    phi_m = np.linspace(-np.pi, np.pi, 200)
    ax2.plot(np.degrees(phi_m), ring.m2_component[0]*np.cos(2*phi_m - ring.m2_component[1]), 'g-', lw=2, label=f'm=2: {ring.m2_component[0]:.4f}')
    ax2.plot(np.degrees(phi_m), ring.m4_component[0]*np.cos(4*phi_m - ring.m4_component[1]), 'orange', lw=2, label=f'm=4: {ring.m4_component[0]:.4f}')
    ax2.set_xlabel('Angle (deg)'); ax2.set_ylabel('Residual'); ax2.set_title('Residual vs Angle'); ax2.legend(); ax2.grid(True, alpha=0.3)

    # Panel 3: harmonic amplitudes against the absolute threshold (fraction × radius)
    ax3.bar(['m=2', 'm=4'], [ring.m2_component[0], ring.m4_component[0]], color=['green', 'orange'], alpha=0.7)
    ax3.axhline(threshold*ring.radius, color='red', ls='--', label=f'{threshold:.0%} threshold')
    ax3.set_ylabel('Amplitude'); ax3.set_title(f'Harmonics: {ring.perturbation_type}'); ax3.legend()

    fig.tight_layout()
    return fig

def plot_overview(positions, morph, ring, scene):
    """2x2 summary figure: 3D geometry, lens plane, ring fit and harmonics.

    Args:
        positions: (N, 2) array-like of image positions.
        morph: MorphologyAnalysis (class and confidence go in a title).
        ring: RingFitResult (center, radius, residuals, harmonics).
        scene: TriadScene (lens, sources, Einstein radius, name).

    Returns:
        The matplotlib Figure.
    """
    positions = np.asarray(positions, dtype=float)
    fig = plt.figure(figsize=(16, 12))
    theta = np.linspace(0, 2*np.pi, 100)

    # (1) 3D geometry: observer, lens, sources and up to six observer→lens rays
    ax1 = fig.add_subplot(221, projection='3d')
    L = scene.lens.position
    ax1.scatter([0], [0], [0], c='blue', s=100)
    ax1.scatter([L.x], [L.y], [L.z], c='red', s=150)
    for src in scene.sources:
        ax1.scatter([src.position.x], [src.position.y], [src.position.z], c='gold', s=150)
    D_L = L.z
    for img in positions[:6]:
        ax1.plot([0, img[0]*D_L*0.3], [0, img[1]*D_L*0.3], [0, D_L], 'g-', alpha=0.4)
    ax1.set_title('3D Geometry')

    # (2) Lens plane: Einstein circle centred on the fitted ring center, as in plot_lens_plane
    ax2 = fig.add_subplot(222)
    theta_E = scene.lens.einstein_radius
    ax2.plot(ring.center_x + theta_E*np.cos(theta), ring.center_y + theta_E*np.sin(theta), 'b--', alpha=0.6)
    ax2.scatter(positions[:, 0], positions[:, 1], c='red', s=60)
    ax2.scatter([ring.center_x], [ring.center_y], c='black', s=80, marker='+')
    ax2.set_aspect('equal'); ax2.set_title(f'Lens Plane - {morph.primary.value.upper()} ({morph.confidence:.0%})'); ax2.grid(True, alpha=0.3)

    # (3) Ring fit overlay
    ax3 = fig.add_subplot(223)
    ax3.scatter(positions[:, 0], positions[:, 1], c='blue', s=40, alpha=0.7)
    ax3.plot(ring.center_x + ring.radius*np.cos(theta), ring.center_y + ring.radius*np.sin(theta), 'r-', lw=2)
    ax3.scatter([ring.center_x], [ring.center_y], c='red', s=100, marker='+')
    ax3.set_aspect('equal'); ax3.set_title(f'Ring Fit (R={ring.radius:.4f})'); ax3.grid(True, alpha=0.3)

    # (4) Residuals vs angle with fitted m=2 / m=4 harmonics
    ax4 = fig.add_subplot(224)
    idx = np.argsort(ring.azimuthal_angles)
    ax4.scatter(np.degrees(ring.azimuthal_angles[idx]), ring.radial_residuals[idx], c='blue', s=30)
    ax4.axhline(0, color='gray', ls='--')
    phi_m = np.linspace(-np.pi, np.pi, 100)
    ax4.plot(np.degrees(phi_m), ring.m2_component[0]*np.cos(2*phi_m - ring.m2_component[1]), 'g-', alpha=0.7)
    ax4.plot(np.degrees(phi_m), ring.m4_component[0]*np.cos(4*phi_m - ring.m4_component[1]), 'orange', alpha=0.7)
    ax4.set_xlabel('Angle'); ax4.set_ylabel('Residual'); ax4.set_title(f'Harmonics: {ring.perturbation_type}'); ax4.grid(True, alpha=0.3)

    # Act on this figure explicitly rather than on pyplot's "current" figure.
    fig.suptitle(f'RSG Lensing Analysis: {scene.name}', fontsize=14, fontweight='bold')
    fig.tight_layout()
    return fig

print("✅ Visualization functions loaded!")

In [None]:
import gradio as gr

# Module-level analysis objects shared by the Gradio callbacks below.
classifier = MorphologyClassifier()
ring_analyzer = RingAnalyzer()

# Example inputs loaded by the "Examples" buttons (one "x, y" pair per line).
# Ten points on the unit circle, every 36° starting at 18° (2-decimal rounding).
EXAMPLE_RING = """0.95, 0.31
0.59, 0.81
0.00, 1.00
-0.59, 0.81
-0.95, 0.31
-0.95, -0.31
-0.59, -0.81
0.00, -1.00
0.59, -0.81
0.95, -0.31"""

# Four image positions (Einstein-cross-like); classified as QUAD because N == 4.
EXAMPLE_QUAD = """0.740, 0.565
-0.635, 0.470
-0.480, -0.755
0.870, -0.195"""

# Eight points: radius 1.10 on the x axis and 0.90 on the y axis (an elongated ring).
EXAMPLE_SHEAR = """1.10, 0.00
0.78, 0.78
0.00, 0.90
-0.78, 0.78
-1.10, 0.00
-0.78, -0.78
0.00, -0.90
0.78, -0.78"""

def parse_positions(text):
    """Parse one "x, y" pair per line into an (N, 2) float array.

    Commas, semicolons and whitespace all act as separators. Blank lines and
    lines starting with '#' are skipped. Values after the first two on a line
    are ignored.

    Args:
        text: user-entered text (``None`` is treated as empty).

    Returns:
        ``np.ndarray`` of shape (N, 2). Empty input gives shape (0, 2).

    Raises:
        ValueError: naming the offending 1-based line when a line has fewer
            than two fields or a field is not a number.
    """
    rows = []
    for lineno, raw in enumerate((text or "").splitlines(), start=1):
        line = raw.strip()
        if not line or line.startswith('#'):
            continue
        fields = line.replace(',', ' ').replace(';', ' ').split()
        if len(fields) < 2:
            raise ValueError(f"Line {lineno}: expected two numbers 'x, y', got {line!r}")
        try:
            rows.append((float(fields[0]), float(fields[1])))
        except ValueError as err:
            raise ValueError(f"Line {lineno}: could not parse {line!r} as numbers") from err
    # reshape keeps the (N, 2) shape even when no rows were parsed
    return np.array(rows, dtype=float).reshape(-1, 2)

def _analysis_report(morph, ring):
    """Render the Markdown summary that analyze() shows next to the input box."""
    report_lines = [
        "# Analysis Result",
        "",
        f"## Morphology: **{morph.primary.value.upper()}** ({morph.confidence:.0%})",
        ', '.join(morph.notes),
        "",
        "## Ring Fit",
        "| Metric | Value |",
        "|--------|-------|",
        f"| Center | ({ring.center_x:.4f}, {ring.center_y:.4f}) |",
        f"| Radius | {ring.radius:.4f} |",
        f"| RMS Residual | {ring.rms_residual:.4f} |",
        f"| m=2 Amplitude | {ring.m2_component[0]:.4f} |",
        f"| m=4 Amplitude | {ring.m4_component[0]:.4f} |",
        f"| Perturbation | {ring.perturbation_type} |",
        "",
        "## Recommended Models",
        ', '.join(morph.recommended_models),
    ]
    return "\n".join(report_lines)

def analyze(text):
    """Gradio callback for the "Analyze Positions" tab.

    Returns a 5-tuple ``(markdown, overview_fig, scene_fig, lens_fig, ring_fig)``.
    With fewer than two positions, or on any exception, the four figures are
    ``None`` and the markdown carries the message. Exceptions are reported as
    "Error: ..." followed by the traceback.
    """
    import traceback
    try:
        positions = parse_positions(text)
        if len(positions) < 2:
            return "Need >= 2 positions", None, None, None, None

        morph = classifier.classify(positions)
        ring = ring_analyzer.fit_ring(positions)
        scene = TriadScene.create_standard("Analysis", theta_E=ring.radius)
        report = _analysis_report(morph, ring)

        figures = (
            plot_overview(positions, morph, ring, scene),
            plot_3d_scene(scene, positions),
            plot_lens_plane(positions, ring.radius, (ring.center_x, ring.center_y)),
            plot_ring_analysis(positions, ring),
        )
        return (report, *figures)
    except Exception as err:
        return f"Error: {err}\n{traceback.format_exc()}", None, None, None, None

def generate(stype, n, noise, c2, c4):
    """Produce synthetic image positions as "x, y" text, one pair per line.

    "Quad" returns four fixed points on the unit circle and ignores the other
    arguments. Any other type samples ``int(n)`` points on a unit ring with
    cos(2φ)/cos(4φ) perturbations ``c2``/``c4`` and Gaussian noise ``noise``.
    """
    count = int(n)
    if stype != "Quad":
        points = generate_ring_points(1.0, count, (0,0), c2, 0, c4, 0, noise)
    else:
        quad_angles = np.array([0.3, 1.8, 3.5, 5.2])
        points = np.column_stack([np.cos(quad_angles), np.sin(quad_angles)])
    return '\n'.join(f'{x:.4f}, {y:.4f}' for x, y in points)

# Gradio front end with three tabs:
#   "Analyze Positions"  - text input -> Markdown report + four figures via analyze()
#   "Generate Synthetic" - writes positions produced by generate() into a textbox
#   "About"              - static reference text
with gr.Blocks(title="RSG Lensing Inversion Framework") as demo:
    gr.Markdown("# 🔭 RSG Lensing Inversion Framework\nGravitational Lensing Analysis Tool\n\nAnalyze Einstein Rings, Crosses, and Arcs with automatic morphology classification.")
    
    with gr.Tabs():
        # --- Tab 1: paste or load positions, then run analyze() ---
        with gr.Tab("📊 Analyze Positions"):
            with gr.Row():
                # Left column: input textbox, action buttons and example loaders
                with gr.Column(scale=1):
                    inp = gr.Textbox(label="Image Positions (x, y per line)", lines=10, value=EXAMPLE_RING)
                    with gr.Row():
                        btn_analyze = gr.Button("🔍 Analyze", variant="primary")
                        btn_clear = gr.Button("Clear")
                    gr.Markdown("### Examples")
                    with gr.Row():
                        btn_ring = gr.Button("Einstein Ring")
                        btn_quad = gr.Button("Einstein Cross")
                        btn_shear = gr.Button("Ring + Shear")
                # Right column: Markdown report returned by analyze()
                with gr.Column(scale=2):
                    out_md = gr.Markdown()
            
            gr.Markdown("### Visualizations")
            with gr.Row():
                out_overview = gr.Plot(label="Overview")
                out_3d = gr.Plot(label="3D Scene")
            with gr.Row():
                out_lens = gr.Plot(label="Lens Plane")
                out_ring = gr.Plot(label="Ring Analysis")
            
            # analyze() returns (report, overview, 3D scene, lens plane, ring analysis),
            # in the same order as the output components listed here.
            btn_analyze.click(analyze, inp, [out_md, out_overview, out_3d, out_lens, out_ring])
            # Clear and example buttons only overwrite the input textbox.
            btn_clear.click(lambda: "", None, inp)
            btn_ring.click(lambda: EXAMPLE_RING, None, inp)
            btn_quad.click(lambda: EXAMPLE_QUAD, None, inp)
            btn_shear.click(lambda: EXAMPLE_SHEAR, None, inp)
        
        # --- Tab 2: synthetic data via generate() ---
        # NOTE: for "Quad", generate() ignores the Points/Noise/m=2/m=4 sliders.
        with gr.Tab("⚙️ Generate Synthetic"):
            with gr.Row():
                stype = gr.Dropdown(["Ring", "Quad"], value="Ring", label="Type")
                n = gr.Slider(4, 100, 20, step=1, label="Points")
                noise = gr.Slider(0, 0.1, 0.01, label="Noise")
            with gr.Row():
                c2 = gr.Slider(0, 0.3, 0, label="m=2 (shear)")
                c4 = gr.Slider(0, 0.2, 0, label="m=4")
            out_gen = gr.Textbox(label="Generated Positions", lines=10)
            btn_gen = gr.Button("Generate")
            btn_gen.click(generate, [stype, n, noise, c2, c4], out_gen)
        
        # --- Tab 3: static reference text ---
        with gr.Tab("ℹ️ About"):
            gr.Markdown("""# RSG Lensing Inversion Framework

**Paper:** Radial Scaling Gauge for Maxwell Fields  
**Authors:** Carmen N. Wrede, Lino P. Casu

## Morphology Classification
- **RING**: Full Einstein ring (azimuthal coverage > 70%, radial scatter < 5%)
- **QUAD**: Einstein cross (4 discrete images)
- **ARC**: Partial ring structure
- **DOUBLE**: Two-image system

## Harmonic Analysis
- **m=2**: Quadrupole/shear perturbation
- **m=4**: Hexadecapole perturbation

## Model Zoo
- isotropic, isotropic+shear, m2, m2+shear, m2+m3, m2+m4
""")

# share=True requests a public Gradio share link.
demo.launch(share=True)