In [8]:
# === HDBSCAN Hotspots: auto-tuned clustering → clusters + optional shells + Plotly ===
# Input: CSV with ['middle_x','middle_y','middle_z']
# Output:
#   - per-gene labels + probabilities
#   - cluster summary DataFrame (size, stability, prob stats)
#   - optional GLB shells per cluster (KDE-on-cluster)
#   - interactive Plotly view

import os, json, math
import numpy as np
import pandas as pd
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, List
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score
import plotly.graph_objects as go

try:
    import hdbscan
    from hdbscan.validity import validity_index as dbcv_index
except Exception as e:
    raise ImportError("Please install hdbscan: pip install hdbscan") from e

from scipy.ndimage import gaussian_filter
from skimage.measure import marching_cubes
import trimesh

import warnings, logging
warnings.filterwarnings("ignore")            # hide sklearn/numba/hdbscan chatter
logging.getLogger().setLevel(logging.ERROR)  # quiet generic loggers

# Config 
@dataclass
class HDBSCANConfig:
    scale_coords: bool = True         # standardize coords before clustering
    min_cluster_size_strategy: str = "auto"  # "auto" or "sqrtN" or "fixed"
    fixed_min_cluster_size: int = 20  # only used if strategy=="fixed"
    min_cluster_size_grid: Tuple[float, float, int] = (0.6, 2.2, 9)
    # grid over √N * [low..high] with 'steps' values -> tested mcs
    min_samples_options: Tuple[Optional[float], ...] = (None, 1, 0.5, 1.0)
    # None->equals mcs; numbers: absolute if >=2 else fraction of mcs
    allow_single_cluster: bool = False
    cluster_selection_method: str = "eom"    # "eom" (default) or "leaf"
    metric: str = "euclidean"
    # per‑cluster shells (KDE-on-cluster)
    make_shells: bool = True
    shell_quantiles: Tuple[float, ...] = (95, 90, 80, 60)  
    shell_sigma_vox: float = 1.8
    shell_grid_base: int = 96
    shell_margin: float = 3.0
    out_dir: Optional[str] = None
    seed: int = 7


# Data loading
def load_points(csv: Optional[str]=None, df: Optional[pd.DataFrame]=None,
                cols=("middle_x","middle_y","middle_z")) -> np.ndarray:
    if df is None and csv is None:
        raise ValueError("Provide csv or df")
    if df is None:
        df = pd.read_csv(csv)
    pts = df[list(cols)].dropna().values.astype(np.float32)
    if len(pts) < 10:
        raise ValueError(f"Too few points: {len(pts)}")
    return pts


# Auto-tuning
def _candidate_mcs(n: int, cfg: HDBSCANConfig) -> List[int]:
    if cfg.min_cluster_size_strategy == "fixed":
        return [max(5, int(cfg.fixed_min_cluster_size))]
    base = math.sqrt(n) if cfg.min_cluster_size_strategy in ("auto", "sqrtN") else cfg.fixed_min_cluster_size
    lo, hi, steps = cfg.min_cluster_size_grid
    vals = np.unique(np.clip(np.rint(base * np.geomspace(lo, hi, steps)), 5, max(5, n//5)).astype(int))
    return [int(v) for v in vals.tolist()]

def _candidate_min_samples(mcs: int, options: Tuple[Optional[float], ...]) -> List[Optional[int]]:
    out: List[Optional[int]] = []
    for opt in options:
        if opt is None:
            cand = None  # special case: HDBSCAN uses mcs
        elif opt >= 2:
            cand = int(opt)           # absolute value
        else:
            cand = max(1, int(round(mcs * float(opt))))  # fraction of mcs
        if cand not in out:           # preserve order, dedupe
            out.append(cand)
    return out


def _score_clustering(labels: np.ndarray, X: np.ndarray, clusterer: hdbscan.HDBSCAN) -> float:
    # Components: DBCV (density-based validity), silhouette on non-noise if valid,
    # penalty for high noise fraction, and reward for high mean membership prob.
    labs = labels
    n = len(labs)
    n_clusters = len(set(labs)) - (1 if -1 in labs else 0)
    noise_frac = np.mean(labs == -1)
    mean_prob = float(clusterer.probabilities_[labs != -1].mean()) if np.any(labs != -1) else 0.0

    # DBCV requires >= 2 clusters with non-noise
    dbcv = 0.0
    try:
        if n_clusters >= 2:
            dbcv = dbcv_index(X, labs)
    except Exception:
        dbcv = 0.0

    sil = 0.0
    try:
        valid = labs != -1
        if n_clusters >= 2 and valid.sum() > n_clusters:
            sil = silhouette_score(X[valid], labs[valid], metric="euclidean")
    except Exception:
        sil = 0.0

    # Composite score: primary=DBCV, then silhouette, then penalties/rewards
    score = (1.00 * dbcv) + (0.25 * sil) + (0.20 * (1.0 - noise_frac)) + (0.10 * mean_prob)
    return float(score)

def auto_hdbscan(X: np.ndarray, cfg: HDBSCANConfig):
    # Optional scaling
    Xs = StandardScaler().fit_transform(X) if cfg.scale_coords else X.copy()

    best = None
    tried = []

    for mcs in _candidate_mcs(len(X), cfg):
        for ms in _candidate_min_samples(mcs, cfg.min_samples_options):
            clusterer = hdbscan.HDBSCAN(
                min_cluster_size=int(mcs),
                min_samples=None if ms is None else int(ms),
                metric=cfg.metric,
                cluster_selection_method=cfg.cluster_selection_method,
                allow_single_cluster=cfg.allow_single_cluster,
                core_dist_n_jobs=1
            ).fit(Xs)
            labels = clusterer.labels_
            score = _score_clustering(labels, Xs, clusterer)
            tried.append(dict(min_cluster_size=mcs, min_samples=ms, score=score,
                              n_clusters=len(set(labels)) - (1 if -1 in labels else 0),
                              noise_frac=float(np.mean(labels==-1))))
            if best is None or score > best["score"]:
                best = dict(model=clusterer, labels=labels, score=score,
                            min_cluster_size=mcs, min_samples=ms)

    # pack diagnostics
    diag = pd.DataFrame(tried).sort_values("score", ascending=False).reset_index(drop=True)
    return best["model"], best["labels"], Xs, diag


# Optional cluster shells 
def _grid_from_points(P: np.ndarray, grid_base=96, margin=3.0):
    mins = P.min(0) - margin; maxs = P.max(0) + margin
    extent = maxs - mins
    scale = grid_base / max(float(extent.min()), 1e-6)
    shape = np.maximum(np.round(extent * scale).astype(int), 8)
    edges = [np.linspace(mins[i], maxs[i], int(shape[i]) + 1) for i in range(3)]
    spacing = extent / shape
    origin = mins
    return edges, spacing, origin, tuple(int(s) for s in shape)

def _voxelize(P: np.ndarray, edges):
    H, _ = np.histogramdd(P, bins=edges)
    return H.astype(np.float32)

def _cluster_shells(P: np.ndarray, cfg: HDBSCANConfig):
    """Return {q:mesh}, {q:level} for all requested quantiles on this cluster."""
    if len(P) < 10:
        return {}, {}
    # grid
    mins = P.min(0) - cfg.shell_margin; maxs = P.max(0) + cfg.shell_margin
    extent = maxs - mins
    scale = cfg.shell_grid_base / max(float(extent.min()), 1e-6)
    shape = np.maximum(np.round(extent * scale).astype(int), 8)
    edges = [np.linspace(mins[i], maxs[i], int(shape[i]) + 1) for i in range(3)]
    spacing = extent / shape
    origin = mins
    # density
    H, _ = np.histogramdd(P, bins=edges)
    dens = gaussian_filter(H.astype(np.float32), sigma=cfg.shell_sigma_vox, mode="constant")
    vals = dens[dens > 0]
    if vals.size == 0:
        return {}, {}
    levels = {q: float(np.percentile(vals, q)) for q in cfg.shell_quantiles}
    # meshes
    meshes = {}
    for q, lvl in levels.items():
        V, F, _, _ = marching_cubes(dens, level=lvl, spacing=spacing)
        V = V + origin
        meshes[q] = trimesh.Trimesh(vertices=V, faces=F, process=True)
    return meshes, levels




# outputs and run pipeline
@dataclass
class HDBSCANOutputs:
    labels: np.ndarray
    probabilities: np.ndarray
    clusterer: object
    X_scaled: np.ndarray
    tuning_diagnostics: pd.DataFrame
    cluster_summary: pd.DataFrame
    cluster_meshes: Dict[int, Dict[float, trimesh.Trimesh]]   # per cluster, per quantile
    cluster_levels: Dict[int, Dict[float, float]]             # per cluster, per quantile


def run_hdbscan_hotspots(pts: np.ndarray, cfg: HDBSCANConfig) -> HDBSCANOutputs:
    model, labels, Xs, diag = auto_hdbscan(pts, cfg)

    clusters = sorted([c for c in np.unique(labels) if c != -1])
    probs = model.probabilities_
    rows = []
    meshes_all: Dict[int, Dict[float, trimesh.Trimesh]] = {}
    levels_all: Dict[int, Dict[float, float]] = {}

    for cid in clusters:
        idx = (labels == cid)
        P = pts[idx]
        pr = probs[idx]

        rows.append(dict(
            cluster_id=int(cid),
            size=int(idx.sum()),
            prob_mean=float(pr.mean()),
            prob_median=float(np.median(pr)),
            stability=float(getattr(model, "cluster_persistence_", [None]*(cid+1))[cid]
                            if hasattr(model, "cluster_persistence_") else None),
            bbox_min_x=float(P[:,0].min()), bbox_min_y=float(P[:,1].min()), bbox_min_z=float(P[:,2].min()),
            bbox_max_x=float(P[:,0].max()), bbox_max_y=float(P[:,1].max()), bbox_max_z=float(P[:,2].max()),
        ))

        if cfg.make_shells:
            qmesh, qlevels = _cluster_shells(P, cfg)
            if qmesh:
                meshes_all[cid] = qmesh
                levels_all[cid] = qlevels
                if cfg.out_dir:
                    os.makedirs(cfg.out_dir, exist_ok=True)
                    for q, mesh in qmesh.items():
                        mesh.export(os.path.join(cfg.out_dir, f"cluster_{cid}_q{int(q)}.glb"))

    summary = pd.DataFrame(rows).sort_values("size", ascending=False).reset_index(drop=True)

    return HDBSCANOutputs(
        labels=labels,
        probabilities=probs,
        clusterer=model,
        X_scaled=Xs,
        tuning_diagnostics=diag,
        cluster_summary=summary,
        cluster_meshes=meshes_all,
        cluster_levels=levels_all
    )



# Plotly visualization
def _scatter_points_colored(pts, labels, probs, size=2.4):
    # map labels → colors; noise=-1 gray
    uniq = sorted(np.unique(labels).tolist())
    palette = ["#1f77b4","#ff7f0e","#2ca02c","#d62728","#9467bd",
               "#8c564b","#e377c2","#7f7f7f","#bcbd22","#17becf"]
    color_map = {-1: "#bdbdbd"}
    ci = 0
    for u in uniq:
        if u == -1: 
            continue
        color_map[u] = palette[ci % len(palette)]
        ci += 1

    traces = []
    for u in uniq:
        mask = (labels == u)
        if not mask.any():
            continue
        if u == -1:
            traces.append(go.Scatter3d(
                x=pts[mask,0], y=pts[mask,1], z=pts[mask,2],
                mode="markers", name="noise",
                marker=dict(size=size, opacity=0.6, color=color_map[-1]),
                hovertemplate="noise<extra></extra>",
            ))
        else:
            traces.append(go.Scatter3d(
                x=pts[mask,0], y=pts[mask,1], z=pts[mask,2],
                mode="markers", name=f"cluster {u}",
                marker=dict(size=size, opacity=0.9, color=color_map[u]),
                # no %-formatting here; leave Plotly’s %{...} placeholders intact
                hovertemplate=f"cluster={u}<br>prob=%{{customdata:.2f}}<extra></extra>",
                customdata=probs[mask],
            ))
    return traces

def _concat_meshes(meshes: list[trimesh.Trimesh]):
    """Merge multiple trimesh objects into one Mesh3d payload for Plotly."""
    if not meshes:
        return None
    xs, ys, zs, is_, js_, ks_ = [], [], [], [], [], []
    v_offset = 0
    for m in meshes:
        V, F = m.vertices, m.faces
        xs.append(V[:, 0]); ys.append(V[:, 1]); zs.append(V[:, 2])
        is_.append(F[:, 0] + v_offset)
        js_.append(F[:, 1] + v_offset)
        ks_.append(F[:, 2] + v_offset)
        v_offset += V.shape[0]
    return dict(
        x=np.concatenate(xs),
        y=np.concatenate(ys),
        z=np.concatenate(zs),
        i=np.concatenate(is_),
        j=np.concatenate(js_),
        k=np.concatenate(ks_)
    )



def visualize_hdbscan_plotly(pts: np.ndarray, outs: HDBSCANOutputs,
                             show_shells=True, shell_opacity=0.28,
                             shell_colors=None):
    if shell_colors is None:
        shell_colors = {60:"#8c564b", 80:"#2ca02c", 90:"#ff7f0e", 95:"#1f77b4"}

    fig = go.Figure()

    # 1) points — one trace per label (noise, cluster 0, cluster 1, …)
    for tr in _scatter_points_colored(pts, outs.labels, outs.probabilities):
        tr.showlegend = True
        fig.add_trace(tr)

    # 2) shells — combine all clusters per quantile into one trace: q60/q80/q90/q95
    if show_shells and outs.cluster_meshes:
        # collect meshes by quantile
        by_q: Dict[int, list[trimesh.Trimesh]] = {}
        for _, qdict in outs.cluster_meshes.items():
            for q, mesh in qdict.items():
                by_q.setdefault(int(q), []).append(mesh)

        # draw in outer→inner order for visibility
        for q in sorted(by_q.keys()):  # 60,80,90,95
            payload = _concat_meshes(by_q[q])
            if payload is None:
                continue
            fig.add_trace(go.Mesh3d(
                **payload,
                name=f"q{q}",
                opacity=shell_opacity,
                color=shell_colors.get(int(q), "#7f7f7f"),
                flatshading=True,
                showscale=False,
                showlegend=True
            ))

    # bounds & layout
    mins, maxs = pts.min(0), pts.max(0)
    ctr = (mins+maxs)/2; r = float((maxs-mins).max())/2
    fig.update_layout(
        width=1000, height=800,
        scene=dict(
            aspectmode="data",
            xaxis=dict(range=[ctr[0]-r, ctr[0]+r], title="X"),
            yaxis=dict(range=[ctr[1]-r, ctr[1]+r], title="Y"),
            zaxis=dict(range=[ctr[2]-r, ctr[2]+r], title="Z"),
        ),
        legend=dict(itemsizing="constant"),
        margin=dict(l=0,r=0,t=30,b=0),
        title="HDBSCAN Hotspots: clusters (color) + 60/80/90/95% shells"
    )
    fig.show()



# run
if __name__ == "__main__":
    CSV = "data/green_monkey/all_structure_files/chr1/12hrs/vacv/structure_12hrs_vacv_gene_info.csv"
    pts = load_points(csv=CSV)

    cfg = HDBSCANConfig(
        scale_coords=True,
        min_cluster_size_strategy="auto",          # auto over √N * [0.6..2.2]
        min_cluster_size_grid=(0.6, 2.2, 9),
        min_samples_options=(None, 1, 0.5, 1.0), # None=mcs; 1; 0.5*mcs; 1.0*mcs
        cluster_selection_method="eom",
        metric="euclidean",
        make_shells=True,
        shell_quantiles=(95, 90, 80, 60),
        shell_sigma_vox=1.8,
        shell_grid_base=96,
        shell_margin=3.0,
        out_dir="data/green_monkey/va_testing/hdbscan/chr1_12h_vacv"
    )


    outs = run_hdbscan_hotspots(pts, cfg)

    # Diagnostics (which params won)
    print(outs.tuning_diagnostics.head(8))
    print("\nBest model:",
          "min_cluster_size=", outs.clusterer.min_cluster_size,
          "min_samples=", outs.clusterer.min_samples,
          "n_clusters=", len(set(outs.labels))- (1 if -1 in outs.labels else 0),
          "noise_frac=", float(np.mean(outs.labels==-1)))

    # Cluster summary table
    print("\nCluster summary:")
    print(outs.cluster_summary)

    # Interactive view
    # visualize_hdbscan_plotly(pts, outs, show_shells=True)


   min_cluster_size  min_samples     score  n_clusters  noise_frac
0                82         41.0  0.340441           2    0.352149
1                97         48.0  0.322484           2    0.448990
2                59          NaN  0.310227           2    0.504920
3                59         59.0  0.310227           2    0.504920
4                97         97.0  0.302751           2    0.644744
5                97          NaN  0.302751           2    0.644744
6                70         70.0  0.299722           2    0.608493
7                70          NaN  0.299722           2    0.608493

Best model: min_cluster_size= 82 min_samples= 41 n_clusters= 2 noise_frac= 0.35214914552045573

Cluster summary:
   cluster_id  size  prob_mean  prob_median  stability  bbox_min_x  \
0           0   664   0.945127     0.990688   0.152932  -48.649597   
1           1   587   0.980024     1.000000   0.057115  -48.488594   

   bbox_min_y  bbox_min_z  bbox_max_x  bbox_max_y  bbox_max_z  
0   12.4