In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Difference Hologram (Signed Boundary Advance Map) with Auto-Registration.

1) Load A (untr) and B (vacv); bake scene transforms.
2) Co-register: translate B so COM(B) -> COM(A). (Optional ICP if Open3D present.)
3) Pick joint interior center; ray-cast to get r_A, r_B; map Δr = r_B - r_A on sphere.
4) Detect salient lobes and plot.

Outputs:
  <label>_diff_hologram.png
  <label>_diff_hologram_{delta_map,valid_mask}.npy
  <label>_diff_hologram_lobes.csv
"""

import os, csv, warnings
import numpy as np
import matplotlib.pyplot as plt
import trimesh
from dataclasses import dataclass
from typing import Tuple, Optional, Dict
from scipy.ndimage import gaussian_filter
from skimage.measure import label
from skimage.segmentation import find_boundaries

# ---------------------------- Utilities ----------------------------

def load_mesh(path: str) -> trimesh.Trimesh:
    """Load a mesh file and return it as one cleaned ``trimesh.Trimesh``.

    A ``trimesh.Scene`` is flattened with ``dump(concatenate=True)`` so the
    scene transforms are baked into the vertices. Anything that still is not
    a ``Trimesh`` is concatenated from its ``geometry`` mapping.

    Args:
        path: Path of the mesh file to load.

    Returns:
        The mesh after degenerate faces and unreferenced vertices have been
        removed and ``process(validate=True)`` has run.

    Raises:
        ValueError: If the loaded object cannot be converted to a Trimesh.
    """
    loaded = trimesh.load(path)
    if isinstance(loaded, trimesh.Scene):
        m = loaded.dump(concatenate=True)  # bake transforms
    else:
        m = loaded
    if not isinstance(m, trimesh.Trimesh):
        try:
            m = trimesh.util.concatenate(list(m.geometry.values()))
        except Exception as e:
            # Chain the cause so the original failure stays visible.
            raise ValueError(f"Could not convert to Trimesh for {path}: {e}") from e
    if not m.is_watertight:
        # Only a warning: ray casting does not require a closed surface.
        warnings.warn(f"[warn] Not watertight: {os.path.basename(path)} (rays still OK).")
    # `remove_degenerate_faces()` is deprecated in trimesh; masking with
    # `nondegenerate_faces()` is the documented replacement.
    m.update_faces(m.nondegenerate_faces())
    m.remove_unreferenced_vertices()
    m.process(validate=True)
    return m

def diag_mesh(name, m: trimesh.Trimesh):
    """Print bounding box, extent, bbox diagonal and center of mass of ``m``.

    Args:
        name: Label used to identify the mesh in the printed line.
        m: Mesh whose ``bounds`` and ``center_mass`` are reported.
    """
    lower, upper = m.bounds
    extent = upper - lower
    diagonal = float(np.linalg.norm(extent))
    print(
        f"[diag] {name}: bbox min={lower}, max={upper}, size={extent}, "
        f"diag={diagonal:.3f}, center_mass={m.center_mass}"
    )

def try_icp_open3d(moving_pts: np.ndarray, fixed_pts: np.ndarray):
    """Run one point-to-point ICP with Open3D, if it can be imported.

    Args:
        moving_pts: Points to be registered (the ICP source cloud).
        fixed_pts: Reference points (the ICP target cloud).

    Returns:
        The 4x4 transformation reported by Open3D as a NumPy array, or
        ``None`` when Open3D cannot be imported.
    """
    try:
        import open3d as o3d
    except Exception:
        return None

    def as_cloud(points):
        return o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points))

    source_cloud = as_cloud(moving_pts)
    target_cloud = as_cloud(fixed_pts)

    # Correspondence distance: 20% of the target's bounding-box diagonal.
    target_extent = target_cloud.get_max_bound() - target_cloud.get_min_bound()
    max_corr_dist = float(np.linalg.norm(target_extent)) * 0.2

    registration = o3d.pipelines.registration
    result = registration.registration_icp(
        source_cloud, target_cloud, max_corr_dist, np.eye(4),
        registration.TransformationEstimationPointToPoint(),
    )
    if result.transformation is None:
        return None
    return np.asarray(result.transformation)

def coregister_B_to_A(mA: trimesh.Trimesh, mB: trimesh.Trimesh, use_icp: bool = True):
    """Align mesh B to mesh A in place.

    B is first translated so its center of mass coincides with A's. When
    ``use_icp`` is true, 5000 surface samples per mesh are handed to
    ``try_icp_open3d`` and any returned transform is applied to B as well.

    Args:
        mA: Reference mesh (not modified).
        mB: Mesh to move; modified in place.
        use_icp: Whether to attempt the ICP refinement step.
    """
    com_offset = mA.center_mass - mB.center_mass
    mB.apply_translation(com_offset)
    print(f"[align] COM-align B by translation: {com_offset}")

    if not use_icp:
        return

    n_samples = 5000
    fixed_pts, _ = trimesh.sample.sample_surface(mA, n_samples)
    moving_pts, _ = trimesh.sample.sample_surface(mB, n_samples)
    transform = try_icp_open3d(moving_pts, fixed_pts)
    if transform is None:
        print("[align] ICP skipped (Open3D not available).")
        return

    rotation, translation = transform[:3, :3], transform[:3, 3]
    mB.apply_transform(transform)
    print(f"[align] ICP refinement applied.\n        R≈\n{rotation}\n        t={translation}")

def build_intersector(mesh: trimesh.Trimesh):
    """Return a ray/mesh intersector for ``mesh``.

    The pyembree-backed intersector is tried first; if importing or
    constructing it raises, the ``ray_triangle`` implementation is used.
    """
    try:
        from trimesh.ray.ray_pyembree import RayMeshIntersector as Intersector
        return Intersector(mesh)
    except Exception:
        from trimesh.ray.ray_triangle import RayMeshIntersector as Intersector
    return Intersector(mesh)

def sdf_sign(mesh: trimesh.Trimesh, pts: np.ndarray) -> np.ndarray:
    """Return per-point distances from ``pts`` to ``mesh``.

    The primary path forwards ``trimesh.proximity.signed_distance``
    unchanged. Per trimesh's documentation that function is POSITIVE for
    points inside the mesh and NEGATIVE outside, which is the opposite of
    the usual SDF convention.

    If the signed query raises, the unsigned distance from
    ``trimesh.proximity.closest_point`` is returned instead, and a warning
    is emitted because those values carry no inside/outside information.

    Args:
        mesh: Mesh to query.
        pts: Query points, passed straight to trimesh.

    Returns:
        One distance per query point (signed when possible).
    """
    try:
        from trimesh.proximity import signed_distance
        return signed_distance(mesh, pts)
    except Exception as exc:
        # Make the degradation visible instead of silently changing what
        # the returned numbers mean.
        warnings.warn(
            f"[warn] signed_distance failed ({exc!r}); "
            "falling back to unsigned distances."
        )
        from trimesh.proximity import closest_point
        _, d, _ = closest_point(mesh, pts)
        return d  # unsigned fallback

def pick_joint_center(meshA: trimesh.Trimesh, meshB: trimesh.Trimesh) -> np.ndarray:
    """Pick a ray origin lying as deep as possible inside both meshes.

    Candidates are 41 points on the segment between the two centers of mass.
    If none of them is interior to both meshes, a 5x5x5 grid around the
    midpoint is searched instead (grid step: 1% of the mean bounding-box
    diagonal, at least 1e-2).

    Args:
        meshA: First mesh.
        meshB: Second mesh.

    Returns:
        The selected point as a float64 array.
    """
    cA = meshA.center_mass
    cB = meshB.center_mass
    mid = 0.5 * (cA + cB)

    def joint_cost(points: np.ndarray) -> np.ndarray:
        # sdf_sign forwards trimesh's signed_distance, which is documented
        # as POSITIVE inside the mesh. Negating the shallower of the two
        # depths gives a cost that is < 0 exactly when the point is inside
        # both meshes and is smallest at the deepest joint-interior point.
        # The previous max(sA, sB) assumed negative-inside and therefore
        # preferred the LEAST interior candidate.
        # NOTE: if sdf_sign fell back to unsigned distances, every cost is
        # <= 0 and the point farthest from both surfaces is chosen.
        depth_a = np.asarray(sdf_sign(meshA, points)).reshape(-1)
        depth_b = np.asarray(sdf_sign(meshB, points)).reshape(-1)
        return -np.minimum(depth_a, depth_b)

    seg = np.linspace(0, 1, 41)[:, None]
    line = cA[None, :] * (1 - seg) + cB[None, :] * seg
    cost = joint_cost(line)
    best = line[int(np.argmin(cost))]

    if cost.min() > 0:
        # No segment point is inside both meshes: probe around the midpoint.
        bbA, bbB = meshA.bounds, meshB.bounds
        scale = float(np.linalg.norm((bbA[1] - bbA[0]) + (bbB[1] - bbB[0])) / 2.0)
        step = max(1e-2, scale * 0.01)
        offsets = np.array([-1, -0.5, 0, 0.5, 1.0]) * step
        Gx, Gy, Gz = np.meshgrid(offsets, offsets, offsets, indexing='ij')
        grid = mid + np.stack([Gx, Gy, Gz], axis=-1).reshape(-1, 3)
        cost = joint_cost(grid)
        best = grid[int(np.argmin(cost))]
        if cost.min() > 0:
            warnings.warn(
                "[warn] No candidate center lies inside both meshes; "
                "ray hit-rates may be low."
            )
    return best.astype(np.float64)

def spherical_grid(H: int, W: int):
    """Build an H x W grid of spherical angles and matching unit directions.

    Theta runs over ``[1e-4, pi - 1e-4]`` (poles excluded) along the rows;
    phi runs over ``[-pi, pi)`` along the columns.

    Args:
        H: Number of theta samples (rows).
        W: Number of phi samples (columns).

    Returns:
        ``(Th, Ph, dirs)`` where ``Th`` and ``Ph`` are (H, W) angle grids and
        ``dirs`` is an (H*W, 3) array of unit vectors in row-major order.
    """
    pole_margin = 1e-4
    colatitudes = np.linspace(pole_margin, np.pi - pole_margin, H, dtype=np.float64)
    longitudes = np.linspace(-np.pi, np.pi, W, endpoint=False, dtype=np.float64)
    Th, Ph = np.meshgrid(colatitudes, longitudes, indexing='ij')

    sin_theta = np.sin(Th)
    unit = np.empty((H, W, 3), dtype=np.float64)
    unit[..., 0] = sin_theta * np.cos(Ph)
    unit[..., 1] = sin_theta * np.sin(Ph)
    unit[..., 2] = np.cos(Th)
    return Th, Ph, unit.reshape(-1, 3)

def ray_first_hit_dists(intersector, origin: np.ndarray, directions: np.ndarray) -> np.ndarray:
    """Cast one ray per direction from ``origin`` and return first-hit distances.

    Args:
        intersector: Object exposing
            ``intersects_location(origins, directions, multiple_hits=False)``.
        origin: Common ray origin.
        directions: One direction per ray.

    Returns:
        A float64 array with one entry per direction: the Euclidean distance
        from the origin to the reported hit, or NaN where no hit was reported.
    """
    n_rays = len(directions)
    origins = np.tile(origin, (n_rays, 1))
    hit_points, hit_ray_ids, _ = intersector.intersects_location(
        origins, directions, multiple_hits=False
    )

    dists = np.full(n_rays, np.nan, dtype=np.float64)
    if len(hit_ray_ids) == 0:
        return dists
    dists[hit_ray_ids] = np.linalg.norm(hit_points - origins[hit_ray_ids], axis=1)
    return dists

def robust_threshold(abs_map: np.ndarray, valid: np.ndarray, percentile: float = 85.0) -> float:
    """Return the given percentile of ``abs_map`` over the ``valid`` pixels.

    Args:
        abs_map: Array of magnitudes.
        valid: Boolean mask selecting the entries to consider.
        percentile: Percentile in [0, 100] passed to ``np.percentile``.

    Returns:
        The percentile value as a float, or 0.0 when the mask selects nothing.
    """
    selected = abs_map[valid]
    return float(np.percentile(selected, percentile)) if selected.size else 0.0

def spherical_pixel_solid_angle(theta_low, theta_high, dphi) -> float:
    """Solid angle of the band ``theta_low..theta_high`` spanning ``dphi`` in phi.

    Evaluates ``(cos(theta_low) - cos(theta_high)) * dphi``.
    """
    cos_drop = np.cos(theta_low) - np.cos(theta_high)
    return cos_drop * dphi

# ---------------------------- Core Compute ----------------------------

@dataclass
class DiffHologramResult:
    """Output bundle of ``compute_difference_hologram``."""
    # (H, W) float32 map of r_B - r_A per direction; NaN where undefined.
    delta_map: np.ndarray
    # (H, W) boolean mask, True where delta_map is finite.
    valid_mask: np.ndarray
    # (H, W) theta grid from spherical_grid (varies along rows).
    theta: np.ndarray
    # (H, W) phi grid from spherical_grid (varies along columns).
    phi: np.ndarray
    # Percentage (0-100) of rays that hit mesh A.
    hit_rate_A: float
    # Percentage (0-100) of rays that hit mesh B.
    hit_rate_B: float
    # Ray origin used for both meshes.
    center: np.ndarray

def compute_difference_hologram(meshA_path: str,
                                meshB_path: str,
                                H: int = 256,
                                W: int = 512,
                                smooth_sigma_px: Optional[float] = 0.0,
                                fixed_center: Optional[np.ndarray] = None,
                                do_coregister: bool = True) -> DiffHologramResult:
    """Compute the signed radial difference map r_B - r_A on a spherical grid.

    Both meshes are loaded, optionally co-registered (B moved onto A), and
    ray-cast from one common origin along an H x W grid of directions. The
    map is NaN wherever either mesh was missed by the ray.

    Args:
        meshA_path: Path of the baseline mesh (A).
        meshB_path: Path of the condition mesh (B).
        H: Number of theta rows.
        W: Number of phi columns.
        smooth_sigma_px: Gaussian sigma in pixels; falsy or <= 0 disables
            smoothing.
        fixed_center: Optional ray origin (any length-3 array-like). When
            omitted, ``pick_joint_center`` chooses one.
        do_coregister: Whether to align B to A before ray casting.

    Returns:
        A ``DiffHologramResult`` with the map, its validity mask, the angle
        grids, per-mesh hit rates (percent) and the ray origin.

    Raises:
        ValueError: If ``fixed_center`` is given but is not a length-3 vector.
    """
    mA = load_mesh(meshA_path)
    mB = load_mesh(meshB_path)
    print("== BEFORE ALIGNMENT ==")
    diag_mesh("A", mA); diag_mesh("B", mB)
    print(f"[diag] |center_mass(A)-center_mass(B)| = {np.linalg.norm(mA.center_mass - mB.center_mass):.3f}")

    if do_coregister:
        coregister_B_to_A(mA, mB, use_icp=True)
        print("== AFTER ALIGNMENT ==")
        diag_mesh("A", mA); diag_mesh("B", mB)
        print(f"[diag] |center_mass(A)-center_mass(B)| = {np.linalg.norm(mA.center_mass - mB.center_mass):.3f}")

    if fixed_center is not None:
        # np.array (not .astype) so lists/tuples are accepted as well; the
        # copy keeps the result independent of the caller's array.
        center = np.array(fixed_center, dtype=np.float64)
        if center.shape != (3,):
            raise ValueError(f"fixed_center must be a length-3 vector, got shape {center.shape}")
    else:
        center = pick_joint_center(mA, mB)
    print(f"[debug] Center candidate: {center}")

    iA = build_intersector(mA); iB = build_intersector(mB)
    Th, Ph, dirs = spherical_grid(H, W)

    dA = ray_first_hit_dists(iA, center, dirs)
    dB = ray_first_hit_dists(iB, center, dirs)
    hitA = float(np.isfinite(dA).mean()) * 100.0
    hitB = float(np.isfinite(dB).mean()) * 100.0
    print(f"[debug] Ray hit-rate: A={hitA:.1f}%  B={hitB:.1f}%")

    # NaN in either distance propagates, marking the direction invalid.
    delta = (dB - dA).reshape(H, W).astype(np.float32)
    valid = np.isfinite(delta)

    if smooth_sigma_px and smooth_sigma_px > 0:
        # Normalized convolution: blur values and weights separately so
        # invalid pixels do not drag the average toward zero.
        tmp = np.where(valid, delta, 0.0)
        wts = valid.astype(np.float32)
        # Only phi (axis 1) is periodic. Wrapping theta (axis 0) would blend
        # the two poles, so that axis is edge-replicated instead.
        axis_modes = ('nearest', 'wrap')
        tmp_blur = gaussian_filter(tmp, smooth_sigma_px, mode=axis_modes)
        wts_blur = gaussian_filter(wts, smooth_sigma_px, mode=axis_modes)
        with np.errstate(invalid='ignore', divide='ignore'):
            delta = np.where(wts_blur > 1e-6, tmp_blur / wts_blur, np.nan).astype(np.float32)
        # Pixels near valid data are filled in by the blur and become valid.
        valid = np.isfinite(delta)

    return DiffHologramResult(delta_map=delta, valid_mask=valid, theta=Th, phi=Ph,
                              hit_rate_A=hitA, hit_rate_B=hitB, center=center)

# ---------------------------- Lobe Detection ----------------------------

@dataclass
class LobeStats:
    """Summary of one connected lobe found by ``detect_lobes``."""
    # Signed id: positive component label for 'advance', negated for 'retreat'.
    label: int
    # Either 'advance' or 'retreat'.
    sign: str
    # Sum of the per-pixel solid angles of the lobe (steradians).
    area_sr: float
    # Mean of delta_map over the lobe's pixels.
    mean_delta: float
    # Extreme value: maximum for 'advance' lobes, minimum for 'retreat' lobes.
    max_delta: float
    # Unweighted mean of theta over the lobe's pixels.
    centroid_theta: float
    # Circular mean of phi over the lobe's pixels.
    centroid_phi: float

def _merge_labels_across_seam(labels: np.ndarray) -> np.ndarray:
    """Union components that touch across the phi seam (first/last column)."""
    left, right = labels[:, 0], labels[:, -1]
    touching = (left > 0) & (right > 0) & (left != right)
    if not np.any(touching):
        return labels

    # Small union-find over label ids; 0 (background) is never merged.
    parent = list(range(int(labels.max()) + 1))

    def find(i: int) -> int:
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    for a, b in zip(left[touching].tolist(), right[touching].tolist()):
        root_a, root_b = find(a), find(b)
        if root_a != root_b:
            parent[max(root_a, root_b)] = min(root_a, root_b)

    roots = np.array([find(i) for i in range(len(parent))])
    # Renumber surviving roots consecutively; background stays 0 because it
    # is the smallest root.
    _, compact = np.unique(roots, return_inverse=True)
    return compact.reshape(-1)[labels]


def detect_lobes(delta_map: np.ndarray,
                 theta: np.ndarray,
                 phi: np.ndarray,
                 valid: np.ndarray,
                 percentile: float = 85.0) -> Tuple[np.ndarray, np.ndarray, Dict[int, LobeStats]]:
    """Segment salient positive ('advance') and negative ('retreat') lobes.

    The threshold tau is the given percentile of |delta_map| over the valid
    pixels. Pixels with delta >= tau and delta <= -tau are labelled
    separately with 4-connectivity. Because phi is periodic, components that
    meet across the first/last column are merged into one lobe.

    Args:
        delta_map: (H, W) signed difference map.
        theta: (H, W) theta grid; only column 0 is used for row edges.
        phi: (H, W) phi grid.
        valid: (H, W) boolean mask of usable pixels.
        percentile: Percentile of |delta| used as the lobe threshold.

    Returns:
        ``(labels_pos, labels_neg, stats)``: two integer label maps (0 is
        background) and a dict keyed by signed lobe id (positive for advance,
        negative for retreat). All three are empty/zero when tau is 0.
    """
    H, W = delta_map.shape
    absd = np.abs(delta_map)
    tau = robust_threshold(absd, valid, percentile=percentile)
    if tau == 0:
        return np.zeros_like(delta_map, int), np.zeros_like(delta_map, int), {}

    pos_mask = valid & (delta_map >= tau)
    neg_mask = valid & (delta_map <= -tau)

    labels_pos = _merge_labels_across_seam(label(pos_mask, connectivity=1))
    labels_neg = _merge_labels_across_seam(label(neg_mask, connectivity=1))

    # Row edges in theta: midpoints between rows, clamped to [0, pi] at the ends.
    dphi = 2 * np.pi / W
    theta_edges = np.zeros(H + 1)
    theta_edges[1:-1] = 0.5 * (theta[:-1, 0] + theta[1:, 0])
    theta_edges[0] = max(0.0, theta[0, 0] - (theta[1, 0] - theta[0, 0]) / 2)
    theta_edges[-1] = min(np.pi, theta[-1, 0] + (theta[-1, 0] - theta[-2, 0]) / 2)
    # Every pixel in a row has the same solid angle.
    row_solid = np.asarray(
        spherical_pixel_solid_angle(theta_edges[:-1], theta_edges[1:], dphi)
    )

    stats: Dict[int, LobeStats] = {}

    def accumulate(labels_map: np.ndarray, sign_name: str):
        is_advance = sign_name == 'advance'
        for lab in np.unique(labels_map):
            if lab == 0:
                continue
            mask = (labels_map == lab)
            vals = delta_map[mask]
            rows, _ = np.where(mask)
            stats_id = int(lab) if is_advance else -int(lab)
            stats[stats_id] = LobeStats(
                label=stats_id,
                sign=sign_name,
                area_sr=float(row_solid[rows].sum()),
                mean_delta=float(np.mean(vals)),
                max_delta=float(np.max(vals) if is_advance else np.min(vals)),
                centroid_theta=float(np.mean(theta[mask])),
                # Circular mean so lobes spanning the seam average correctly.
                centroid_phi=float(np.angle(np.mean(np.exp(1j * phi[mask])))),
            )

    accumulate(labels_pos, 'advance')
    accumulate(labels_neg, 'retreat')
    return labels_pos, labels_neg, stats

# ---------------------------- Plot & Save ----------------------------

def plot_difference_hologram(delta_map: np.ndarray,
                             valid: np.ndarray,
                             labels_pos: np.ndarray,
                             labels_neg: np.ndarray,
                             out_png: str,
                             vlim: Optional[float] = None,
                             title: Optional[str] = None):
    """Render the difference map with lobe outlines and save it as an image.

    Args:
        delta_map: (H, W) signed difference map.
        valid: (H, W) boolean mask; invalid pixels are drawn as NaN.
        labels_pos: Label map of positive lobes (0 is background).
        labels_neg: Label map of negative lobes (0 is background).
        out_png: Output image path.
        vlim: Symmetric colour limit; defaults to the 95th percentile of
            |delta| over valid pixels (at least 1e-6).
        title: Optional axes title.
    """
    H, W = delta_map.shape
    if not np.any(valid):
        print("[warn] No valid Δr samples; image may be blank (check alignment/center).")
    if vlim is None:
        m = np.nanpercentile(np.abs(delta_map[valid]), 95) if np.any(valid) else 1.0
        vlim = max(m, 1e-6)

    fig, ax = plt.subplots(figsize=(12, 6), dpi=150)
    try:
        im = ax.imshow(np.where(valid, delta_map, np.nan), origin='upper',
                       extent=[-180, 180, 180, 0],
                       cmap='RdBu_r', vmin=-vlim, vmax=+vlim, interpolation='nearest')
        ax.set_xlabel('Longitude (°)')
        # The vertical axis runs 0..180 from the top, i.e. colatitude
        # (polar angle), not latitude.
        ax.set_ylabel('Colatitude θ (°)')
        if title:
            ax.set_title(title)

        for labs in (labels_pos, labels_neg):
            b = find_boundaries(labs, mode='outer')
            yy, xx = np.where(b)
            if len(xx) > 0:
                # +0.5 puts each marker on its pixel centre; the raw index
                # maps to the pixel's top-left corner in extent coordinates.
                ax.plot((xx + 0.5) * 360 / W - 180, (yy + 0.5) * 180 / H,
                        '.', ms=0.2, color='k', alpha=0.7)

        cb = fig.colorbar(im, ax=ax, fraction=0.03, pad=0.02)
        cb.set_label('Signed boundary advance Δr (mesh units)')
        fig.tight_layout()
        fig.savefig(out_png, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)

def save_arrays(delta_map: np.ndarray, valid: np.ndarray, out_npy_prefix: str):
    """Write the map and its mask as ``<prefix>_delta_map.npy`` / ``<prefix>_valid_mask.npy``."""
    outputs = (("_delta_map.npy", delta_map), ("_valid_mask.npy", valid))
    for suffix, array in outputs:
        np.save(out_npy_prefix + suffix, array)

def save_lobe_csv(stats: Dict[int, 'LobeStats'], out_csv: str):
    """Write one CSV row per lobe, preceded by a header row.

    Rows are ordered with non-negative ids first, then negative ids, each
    group by increasing magnitude. Float fields use six decimals.

    Args:
        stats: Mapping from signed lobe id to its statistics.
        out_csv: Output CSV path.
    """
    header = ["label", "sign", "area_sr", "mean_delta", "max_delta",
              "centroid_theta", "centroid_phi"]
    with open(out_csv, 'w', newline='') as fh:
        writer = csv.writer(fh)
        writer.writerow(header)
        ordered_ids = sorted(stats.keys(), key=lambda lobe_id: (lobe_id < 0, abs(lobe_id)))
        for lobe_id in ordered_ids:
            lobe = stats[lobe_id]
            numeric = (lobe.area_sr, lobe.mean_delta, lobe.max_delta,
                       lobe.centroid_theta, lobe.centroid_phi)
            writer.writerow([lobe.label, lobe.sign] + [f"{value:.6f}" for value in numeric])

# ---------------------------- Driver ----------------------------

def run_one(time_label: str,
            meshA_path: str,  # baseline/control (untr)
            meshB_path: str,  # condition (vacv)
            out_dir: str = "data/green_monkey/va_testing/diff_hologram",
            H: int = 256, W: int = 512,
            smooth_sigma_px: float = 0.0,
            lobe_percentile: float = 85.0,
            fixed_center: Optional[np.ndarray] = None,
            do_coregister: bool = True):
    """Run the full difference-hologram pipeline for one A/B mesh pair.

    Computes the map, detects lobes, then writes
    ``<out_dir>/<time_label>_diff_hologram.png``, the two ``.npy`` arrays and
    ``..._lobes.csv``, and prints a short summary.

    Args:
        time_label: Label used in output file names and the plot title.
        meshA_path: Path of the baseline mesh.
        meshB_path: Path of the condition mesh.
        out_dir: Output directory (created if missing).
        H: Number of theta rows.
        W: Number of phi columns.
        smooth_sigma_px: Gaussian smoothing sigma in pixels (0 disables).
        lobe_percentile: Percentile of |delta| used as lobe threshold.
        fixed_center: Optional fixed ray origin.
        do_coregister: Whether B is aligned to A before ray casting. Set to
            False for meshes that already share a coordinate frame.
    """
    os.makedirs(out_dir, exist_ok=True)
    res = compute_difference_hologram(meshA_path, meshB_path, H=H, W=W,
                                      smooth_sigma_px=smooth_sigma_px,
                                      fixed_center=fixed_center,
                                      do_coregister=do_coregister)

    labels_pos, labels_neg, stats = detect_lobes(
        res.delta_map, res.theta, res.phi, res.valid_mask, percentile=lobe_percentile
    )

    base = os.path.join(out_dir, f"{time_label}_diff_hologram")
    plot_difference_hologram(
        res.delta_map, res.valid_mask, labels_pos, labels_neg,
        out_png=base + ".png",
        title=f"Difference Hologram Δr (B–A): {time_label}"
    )
    save_arrays(res.delta_map, res.valid_mask, base)
    save_lobe_csv(stats, base + "_lobes.csv")

    print(f"[{time_label}] saved:")
    print(f"  • Map image  : {base}.png")
    print(f"  • NPY arrays : {base}_delta_map.npy, {base}_valid_mask.npy")
    print(f"  • Lobe stats : {base}_lobes.csv")
    if stats:
        # Rank by |mean displacement| x solid angle and show the top five.
        top = sorted(stats.values(), key=lambda s: abs(s.mean_delta) * s.area_sr, reverse=True)[:5]
        print("  • Top lobes by |mean|×area:")
        for s in top:
            print(f"    - {s.sign:7s} | area={s.area_sr:.3f} sr | mean={s.mean_delta:+.4f} | peak={s.max_delta:+.4f}")
    else:
        print("  • No salient lobes at current percentile; try lowering lobe_percentile or smoothing.")
    print(f"[final] Hit-rates: A={res.hit_rate_A:.1f}%  B={res.hit_rate_B:.1f}%  |  Center: {res.center}")

if __name__ == "__main__":
    # One entry per timepoint: a label plus the baseline (A) and condition (B) meshes.
    jobs = [
        {
            "time_label": "chr1_12h",
            "meshA_path": "data/green_monkey/all_structure_files/chr1/spatial_data/overall_shapes/chr1_12hrs_untr_metaball.obj",
            "meshB_path": "data/green_monkey/all_structure_files/chr1/spatial_data/overall_shapes/chr1_12hrs_vacv_metaball.obj",
        },
        # Add more timepoints...
    ]
    for j in jobs:
        run_one(
            H=256,
            W=512,
            smooth_sigma_px=0.0,
            lobe_percentile=85.0,
            fixed_center=None,
            **j,
        )


  m.remove_degenerate_faces()


== BEFORE ALIGNMENT ==
[diag] A: bbox min=[-44.25700794  37.24437723 -55.32718145], max=[-36.44841934  45.55498291 -45.78098316], size=[7.8085886  8.31060568 9.54619829], diag=14.872, center_mass=[-40.01394241  41.32698472 -50.64148952]
[diag] B: bbox min=[-50.29174394  11.33644292 -74.17730233], max=[-40.96338338  20.40990785 -62.67100909], size=[ 9.32836056  9.07346493 11.50629324], diag=17.371, center_mass=[-45.76308494  16.2818578  -67.91130967]
[diag] |center_mass(A)-center_mass(B)| = 30.961
[align] COM-align B by translation: [ 5.74914253 25.04512692 17.26982016]
[align] ICP skipped (Open3D not available).
== AFTER ALIGNMENT ==
[diag] A: bbox min=[-44.25700794  37.24437723 -55.32718145], max=[-36.44841934  45.55498291 -45.78098316], size=[7.8085886  8.31060568 9.54619829], diag=14.872, center_mass=[-40.01394241  41.32698472 -50.64148952]
[diag] B: bbox min=[-44.54260141  36.38156984 -56.90748217], max=[-35.21424085  45.45503477 -45.40118893], size=[ 9.32836056  9.07346493 11.5062