In [5]:
# === KDE Hotspots: 3D density field → nested shells + visualization ===
# - Input: CSV with columns ['middle_x','middle_y','middle_z']
# - Method: voxel histogram → Gaussian blur (σ in voxels) = KDE approximation
# - Output: GLB meshes for top-quantile shells (e.g., 95/90/80), plus an optional interactive Plotly 3D view

import os
import numpy as np
import pandas as pd
from dataclasses import dataclass
from typing import Sequence, Tuple, Dict, Optional
from scipy.ndimage import gaussian_filter
from skimage.measure import marching_cubes
import trimesh
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import plotly.graph_objects as go


# Config
@dataclass
class KDEConfig:
    """Settings for the KDE hotspot pipeline.

    Attributes:
        grid_base: Base voxel resolution assigned to the shortest axis of the
            padded bounding box (increase for finer detail).
        margin: Padding added around the points, in input-coordinate units.
        sigma_vox: KDE bandwidth, i.e. the Gaussian sigma expressed in voxels.
        quantiles: Percentiles that define the shells (higher = smaller/hotter).
        min_component_area: Intended as an optional filter for tiny components
            after iso-surfacing (set > 0 to enable). NOTE(review): nothing in
            the visible pipeline reads this field yet.
        csv_cols: Names of the x, y, z coordinate columns in the input CSV.
        seed: Value handed to ``np.random.seed`` at the start of a run.
    """

    grid_base: int = 160
    margin: float = 5.0
    sigma_vox: float = 2.0
    quantiles: Tuple[float, ...] = (95, 90, 80)
    min_component_area: int = 0
    csv_cols: Tuple[str, str, str] = ("middle_x", "middle_y", "middle_z")
    seed: int = 7

# I/O 
def load_points(csv_path: Optional[str] = None, df: Optional[pd.DataFrame] = None,
                cols: Tuple[str, str, str] = ("middle_x", "middle_y", "middle_z")) -> np.ndarray:
    """Return the coordinate columns of a table as a float32 array.

    Args:
        csv_path: CSV file to read; only used when ``df`` is not supplied.
        df: Already-loaded table. Takes precedence over ``csv_path``.
        cols: Names of the three coordinate columns to extract.

    Returns:
        Array with one row per record that has no missing value in ``cols``
        and one column per entry of ``cols``, in that order.

    Raises:
        ValueError: If neither source is given, or fewer than 10 complete
            rows remain after dropping missing values.
    """
    if df is None:
        if csv_path is None:
            raise ValueError("Provide csv_path or df.")
        df = pd.read_csv(csv_path)

    # Rows with a missing coordinate are discarded before conversion.
    complete_rows = df[list(cols)].dropna()
    pts = complete_rows.to_numpy().astype(np.float32)

    n_pts = len(pts)
    if n_pts < 10:
        raise ValueError(f"Too few points ({n_pts}).")
    return pts

# Grid & KDE 
def compute_grid(pts: np.ndarray, grid_base: int, margin: float):
    """Build a voxel grid covering the points plus a margin on every side.

    The shortest padded axis receives ``grid_base`` voxels and the other axes
    are scaled by the same factor so voxels stay (near-)isotropic. Every axis
    gets at least 8 voxels.

    Args:
        pts: Point coordinates, one row per point and one column per axis.
        grid_base: Voxel count for the shortest padded axis.
        margin: Padding added below the minimum and above the maximum of
            each axis, in the units of ``pts``.

    Returns:
        ``(edges, spacing, origin, shape)`` where ``edges`` is a list of
        per-axis bin-edge arrays (the form ``np.histogramdd`` expects),
        ``spacing`` is the voxel size per axis, ``origin`` is the lower corner
        of the padded box and ``shape`` is a tuple of Python ints.
    """
    lower = pts.min(axis=0) - margin
    upper = pts.max(axis=0) + margin
    span = upper - lower

    # Voxels per coordinate unit; the floor guards against a zero-length axis.
    voxels_per_unit = grid_base / max(float(span.min()), 1e-6)
    dims = np.maximum(np.round(span * voxels_per_unit).astype(int), 8)

    # One array of n + 1 edges for each axis with n voxels.
    edges = [np.linspace(lo, hi, n + 1) for lo, hi, n in zip(lower, upper, dims)]
    voxel_size = span / dims
    return edges, voxel_size, lower, tuple(int(n) for n in dims)

def voxelize_counts(pts: np.ndarray, edges) -> np.ndarray:
    """Count how many points fall into each voxel.

    Args:
        pts: Point coordinates, one row per point.
        edges: Per-axis bin edges, passed to ``np.histogramdd`` as ``bins``.

    Returns:
        The histogram as a float32 array.
    """
    histogram = np.histogramdd(pts, bins=edges)[0]
    return histogram.astype(np.float32)

def kde_gaussian_splat(counts: np.ndarray, sigma_vox: float) -> np.ndarray:
    """Smooth voxel counts with a Gaussian (approximate KDE, sigma in voxels).

    Args:
        counts: Voxel histogram.
        sigma_vox: Gaussian sigma in voxel units. Values <= 0 disable
            smoothing.

    Returns:
        The blurred field, or ``counts`` itself (same object, no copy) when
        smoothing is disabled.
    """
    if sigma_vox <= 0:
        return counts
    # mode="constant": the volume is padded with the filter's default fill value.
    blurred = gaussian_filter(counts, sigma=sigma_vox, mode="constant")
    return blurred

def quantile_levels(field: np.ndarray, q_percents: Sequence[float]) -> Dict[float, float]:
    """Map each requested percentile to its density threshold.

    Only strictly positive voxels enter the percentile computation, so the
    empty background does not dominate the result.

    Args:
        field: Density volume.
        q_percents: Percentiles to evaluate, as accepted by ``np.percentile``.

    Returns:
        Dict keyed by the entries of ``q_percents``, in the order given.

    Raises:
        ValueError: If ``field`` contains no positive value.
    """
    positive = field[field > 0]
    if not positive.size:
        raise ValueError("Density field is empty (all zeros). Increase sigma/grid or check data.")

    thresholds = {}
    for q in q_percents:
        thresholds[q] = np.percentile(positive, q)
    return thresholds

# Isosurfaces & Export
def extract_isosurface(field: np.ndarray, level: float, spacing: np.ndarray, origin: np.ndarray):
    """Extract the iso-surface of ``field`` at ``level`` as a Trimesh.

    Args:
        field: Scalar volume handed to ``marching_cubes``.
        level: Iso-value at which the surface is extracted.
        spacing: Voxel size per axis, forwarded to ``marching_cubes``.
        origin: Offset added to every vertex, i.e. the coordinate the caller
            assigns to array index (0, 0, 0).

    Returns:
        A ``trimesh.Trimesh`` built from the surface with ``process=True``.
    """
    vertices, triangles, _normals, _values = marching_cubes(field, level=level, spacing=spacing)
    # marching_cubes output starts at (0, 0, 0); translate it to the caller's frame.
    placed = vertices + origin
    return trimesh.Trimesh(vertices=placed, faces=triangles, process=True)

def export_glb(mesh: trimesh.Trimesh, out_path: str):
    """Write ``mesh`` to ``out_path``, creating the parent directory if needed.

    Args:
        mesh: Mesh to export.
        out_path: Destination file. trimesh picks the format from the suffix,
            so a path ending in '.glb' is written as GLB.
    """
    parent = os.path.dirname(out_path)
    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError; in that case the file simply goes to the current dir.
    if parent:
        os.makedirs(parent, exist_ok=True)
    mesh.export(out_path)

# Outputs & KDE hotspots function
@dataclass
class KDEOutputs:
    """Everything produced by one ``run_kde_hotspots`` call.

    Attributes:
        counts: Raw voxel histogram of the input points.
        density: ``counts`` after Gaussian smoothing.
        edges: Per-axis bin edges of the voxel grid.
        spacing: Voxel size along each axis.
        origin: Lower corner of the voxel grid.
        shape: Number of voxels along each axis.
        levels: Density threshold for each requested percentile.
        meshes: Iso-surface mesh for each requested percentile.
    """

    counts: np.ndarray
    density: np.ndarray
    edges: Tuple[np.ndarray, np.ndarray, np.ndarray]
    spacing: np.ndarray
    origin: np.ndarray
    shape: Tuple[int, int, int]
    levels: Dict[float, float]
    meshes: Dict[float, trimesh.Trimesh]

def run_kde_hotspots(pts: np.ndarray, cfg: KDEConfig, out_dir: Optional[str]=None,
                     tag: str="kde") -> KDEOutputs:
    """Run the full pipeline: voxelize, smooth, threshold, extract shells.

    Args:
        pts: Point coordinates, one row per point.
        cfg: Pipeline settings (grid resolution, margin, sigma, percentiles, seed).
        out_dir: If given (non-empty), every shell is also exported there as
            ``{tag}_shell_q{percentile}.glb``.
        tag: Filename prefix for exported shells.

    Returns:
        KDEOutputs holding the histogram, the smoothed density, the grid
        description, the per-percentile thresholds and the shell meshes.
        ``levels`` and ``meshes`` are keyed by the entries of ``cfg.quantiles``.
    """
    # Seeds NumPy's global RNG as a side effect (kept from the original behaviour).
    np.random.seed(cfg.seed)
    edges, spacing, origin, shape = compute_grid(pts, cfg.grid_base, cfg.margin)
    counts = voxelize_counts(pts, edges)
    density = kde_gaussian_splat(counts, cfg.sigma_vox)
    levels = quantile_levels(density, cfg.quantiles)

    # density[i, j, k] summarises the histogram bin between edges i and i + 1,
    # so its sample point is the bin centre. marching_cubes places array index
    # 0 at coordinate 0, hence the surface must be translated to the centre of
    # the first voxel rather than to the grid corner; using the corner would
    # shift every shell by half a voxel towards the minimum.
    first_voxel_center = origin + 0.5 * spacing

    meshes = {}
    for q in cfg.quantiles:
        mesh = extract_isosurface(density, level=levels[q], spacing=spacing,
                                  origin=first_voxel_center)
        if out_dir:
            # ':g' keeps fractional percentiles distinct (99 vs 99.5) while
            # whole numbers still print as before, e.g. 'q95'.
            glb_path = os.path.join(out_dir, f"{tag}_shell_q{q:g}.glb")
            export_glb(mesh, glb_path)
        meshes[q] = mesh
    return KDEOutputs(counts, density, tuple(edges), spacing, origin, shape, levels, meshes)


# === Interactive Plotly viewer for KDE shells + gene points ===
def _trimesh_to_mesh3d(mesh, name="shell", color="#1f77b4", opacity=0.28, showlegend=True):
    """Convert a mesh into a Plotly ``Mesh3d`` trace.

    Args:
        mesh: Object exposing ``vertices`` and ``faces`` arrays, each indexed
            by three columns.
        name: Trace name; also used as the hover text.
        color: Surface colour.
        opacity: Surface opacity.
        showlegend: Whether the trace gets a legend entry.

    Returns:
        A flat-shaded ``go.Mesh3d`` trace without a colour scale.
    """
    vertices, triangles = mesh.vertices, mesh.faces
    positions = dict(x=vertices[:, 0], y=vertices[:, 1], z=vertices[:, 2])
    corners = dict(i=triangles[:, 0], j=triangles[:, 1], k=triangles[:, 2])
    return go.Mesh3d(
        **positions,
        **corners,
        name=name,
        opacity=opacity,
        color=color,
        flatshading=True,
        showscale=False,
        showlegend=showlegend,
        hovertemplate=f"{name}<extra></extra>",
    )

def _points_scatter3d(pts, size=2.3, name="genes", alpha=0.85):
    """Build a Plotly ``Scatter3d`` marker trace from point coordinates.

    Args:
        pts: Array whose first three columns are x, y, z.
        size: Marker size.
        name: Trace name.
        alpha: Marker opacity.

    Returns:
        A ``go.Scatter3d`` trace whose hover text shows the coordinates with
        two decimals.
    """
    marker_style = dict(size=size, opacity=alpha)
    hover_text = "x=%{x:.2f}<br>y=%{y:.2f}<br>z=%{z:.2f}<extra></extra>"
    return go.Scatter3d(
        x=pts[:, 0],
        y=pts[:, 1],
        z=pts[:, 2],
        mode="markers",
        marker=marker_style,
        name=name,
        hovertemplate=hover_text,
    )

def _auto_bounds(pts):
    mins = pts.min(0); maxs = pts.max(0)
    ctr = (mins + maxs) / 2.0
    r = float((maxs - mins).max()) / 2.0
    return dict(
        x=[ctr[0]-r, ctr[0]+r],
        y=[ctr[1]-r, ctr[1]+r],
        z=[ctr[2]-r, ctr[2]+r],
    )

def visualize_plotly_kde(pts, meshes_by_quantile, point_stride=1, point_size=2.2,
                         shell_opacity=0.30, colors=None, title="KDE Hotspot Shells + Gene Points"):
    """Show the points and the KDE shells together in a Plotly 3D figure.

    Args:
        pts: Point coordinates (x, y, z columns).
        meshes_by_quantile: Mapping from percentile to shell mesh.
        point_stride: Plot every n-th point when > 1 (speeds up rendering).
        point_size: Marker size for the points.
        shell_opacity: Opacity shared by all shells.
        colors: Optional mapping from integer percentile to colour; percentiles
            without an entry are drawn grey. A built-in palette is used when
            omitted.
        title: Figure title.

    Returns:
        None. The figure is displayed via ``fig.show()``.
    """
    palette = colors
    if palette is None:
        palette = {60:"#8c564b", 80:"#2ca02c", 90:"#ff7f0e", 95:"#1f77b4", 99:"#9467bd"}

    fig = go.Figure()

    # Point cloud, optionally thinned.
    shown_pts = pts[::point_stride] if point_stride > 1 else pts
    fig.add_trace(_points_scatter3d(shown_pts, size=point_size))

    # Ascending percentile order: the outermost shell is added first, the innermost last.
    for q in sorted(meshes_by_quantile):
        shell_trace = _trimesh_to_mesh3d(
            meshes_by_quantile[q],
            name=f"Shell q{int(q)}",
            color=palette.get(int(q), "#7f7f7f"),
            opacity=shell_opacity,
        )
        fig.add_trace(shell_trace)

    # Axis ranges always come from the full point set, not the thinned one.
    bounds = _auto_bounds(pts)
    scene = dict(
        aspectmode="data",
        xaxis=dict(range=bounds["x"], title="X"),
        yaxis=dict(range=bounds["y"], title="Y"),
        zaxis=dict(range=bounds["z"], title="Z"),
    )
    fig.update_layout(
        width=900,
        height=800,
        scene=scene,
        legend=dict(itemsizing="constant"),
        margin=dict(l=0, r=0, t=30, b=0),
        title=title,
    )
    fig.show()



if __name__ == "__main__":
    # Script entry point: load the points, extract the KDE shells (exporting each
    # one as a GLB file) and print a short summary of the result.

    # point loading
    # Input CSV; it must contain the coordinate columns named in cfg.csv_cols.
    CSV = "data/green_monkey/all_structure_files/chr1/12hrs/vacv/structure_12hrs_vacv_gene_info.csv"
    # Per-run settings; fields not listed here keep their KDEConfig defaults.
    cfg = KDEConfig(
        grid_base=180,
        margin=10.0,
        sigma_vox=2.5,
        quantiles=(95, 90, 80, 60),
        csv_cols=("middle_x","middle_y","middle_z")
    )
    pts = load_points(csv_path=CSV, cols=cfg.csv_cols)

    # run KDE hotspot extraction
    # One GLB file per percentile is written into out_dir, prefixed with the tag.
    out_dir = "data/green_monkey/va_testing/kde_hotspots/chr1_12h_vacv"  
    outs = run_kde_hotspots(pts, cfg, out_dir=out_dir, tag="chr1_12h_vacv")  

    # Summary: density threshold per percentile, then vertex/face counts per shell.
    print("Levels (density thresholds by percentile):", outs.levels)
    for q, m in outs.meshes.items():
        print(f"q{int(q)} shell: V={len(m.vertices):,}, F={len(m.faces):,}")

    # Plotly view (disabled by default; uncomment to open the interactive viewer)
    # visualize_plotly_kde(
    #     pts, outs.meshes,
    #     point_stride=1,       # increase to 3/5 if notebook is slow
    #     shell_opacity=0.30
    # )




Levels (density thresholds by percentile): {95: 0.039810808002948755, 90: 0.026215554401278495, 80: 0.01241540145128966, 60: 0.001114151254296303}
q95 shell: V=9,678, F=19,332
q90 shell: V=14,284, F=28,584
q80 shell: V=17,694, F=35,392
q60 shell: V=21,026, F=42,048
