In [4]:
# === HDBSCAN Hotspots — batch over *_aligned.json
# Updates JSON with cluster labels (0=noise, 1..K) and exports GLB shells.
# IMPORTANT: Clusters & shells are computed on CSV gene centers when available
#            , then labels are mapped back to JSON.
#
# Inputs:
#   data/green_monkey/structure_genes_aligned/<chr>/<chr>_*_aligned.json  (or *_alinged.json)
#   data/green_monkey/all_structure_files/<chr>/<time>/<cond>/structure_<time>_<cond>_gene_info.csv
#
# Outputs:
#   1) Updates each *_aligned.json → {"position":[[x,y,z]...], "clusters":[...]}
#   2) GLBs: data/green_monkey/density_shells_hdbscan/<chr>/<time>_<cond>/<chr>_<time>_<cond>_c{cid}_q{q}.glb
#
# Shell params (MATCH TEST PIPELINE): q=(95,90,80,60), sigma=1.8, grid=96, margin=3.0

import os, re, json, math, traceback
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, List

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score

import hdbscan
from hdbscan.validity import validity_index as dbcv_index
from scipy.ndimage import gaussian_filter
from skimage.measure import marching_cubes
import trimesh

import warnings, logging
warnings.filterwarnings("ignore")
logging.getLogger().setLevel(logging.ERROR)

# ---------- Roots ----------
BASE_DIR        = Path("data/green_monkey/structure_genes_aligned")
ALL_STRUCT_ROOT = Path("data/green_monkey/all_structure_files")     # CSVs live here
SHELLS_ROOT     = Path("data/green_monkey/density_shells_hdbscan")

VALIDCONDS = {"untr", "vacv"}   # extend if needed

# ---------- Config (matches test pipeline) ----------
@dataclass
class HDBSCANConfig:
    # clustering
    scale_coords: bool = True
    min_cluster_size_strategy: str = "auto"             # "auto" | "sqrtN" | "fixed"
    fixed_min_cluster_size: int = 20
    min_cluster_size_grid: Tuple[float, float, int] = (0.6, 2.2, 9)   # √N * [lo..hi]
    min_samples_options: Tuple[Optional[float], ...] = (None, 1, 0.5, 1.0)
    allow_single_cluster: bool = False
    cluster_selection_method: str = "eom"
    metric: str = "euclidean"

    # shells — SAME as test
    make_shells: bool = True
    shell_quantiles: Tuple[float, ...] = (95, 90, 80, 60)
    shell_sigma_vox: float = 1.8
    shell_grid_base: int = 96
    shell_margin: float = 3.0

    # export
    export_format: str = "glb"
    decimate_ratio: Optional[float] = 0.0   # 0 to skip; else 0<r<1 keeps r*faces

# ---------- HDBSCAN utilities ----------
def _candidate_mcs(n: int, cfg: HDBSCANConfig) -> List[int]:
    if cfg.min_cluster_size_strategy == "fixed":
        return [max(5, int(cfg.fixed_min_cluster_size))]
    base = math.sqrt(n)
    lo, hi, steps = cfg.min_cluster_size_grid
    vals = np.unique(np.clip(np.rint(base * np.geomspace(lo, hi, steps)), 5, max(5, n//5)).astype(int))
    return [int(v) for v in vals.tolist()]

def _candidate_min_samples(mcs: int, options: Tuple[Optional[float], ...]) -> List[Optional[int]]:
    out: List[Optional[int]] = []
    for opt in options:
        if opt is None: cand = None
        elif opt >= 2:  cand = int(opt)
        else:           cand = max(1, int(round(mcs * float(opt))))
        if cand not in out: out.append(cand)
    return out

def _score_clustering(labels: np.ndarray, X: np.ndarray, clusterer: hdbscan.HDBSCAN) -> float:
    k = len(set(labels)) - (1 if -1 in labels else 0)
    noise = float(np.mean(labels == -1))
    meanp = float(clusterer.probabilities_[labels != -1].mean()) if np.any(labels != -1) else 0.0
    dbcv = 0.0
    try:
        if k >= 2: dbcv = dbcv_index(X, labels)
    except Exception: pass
    sil = 0.0
    try:
        mask = labels != -1
        if k >= 2 and mask.sum() > k: sil = silhouette_score(X[mask], labels[mask], metric="euclidean")
    except Exception: pass
    return 1.00*dbcv + 0.25*sil + 0.20*(1.0 - noise) + 0.10*meanp

def auto_hdbscan(X: np.ndarray, cfg: HDBSCANConfig):
    Xs = StandardScaler().fit_transform(X) if cfg.scale_coords else X.copy()
    best = None; tried = []
    for mcs in _candidate_mcs(len(X), cfg):
        for ms in _candidate_min_samples(mcs, cfg.min_samples_options):
            model = hdbscan.HDBSCAN(
                min_cluster_size=int(mcs),
                min_samples=None if ms is None else int(ms),
                metric=cfg.metric,
                cluster_selection_method=cfg.cluster_selection_method,
                allow_single_cluster=cfg.allow_single_cluster,
                core_dist_n_jobs=1
            ).fit(Xs)
            labs = model.labels_
            score = _score_clustering(labs, Xs, model)
            tried.append(dict(min_cluster_size=mcs, min_samples=ms, score=score,
                              n_clusters=len(set(labs))-(1 if -1 in labs else 0),
                              noise=float(np.mean(labs==-1))))
            if best is None or score > best["score"]:
                best = dict(model=model, labels=labs, score=score,
                            min_cluster_size=mcs, min_samples=ms)
    diag = pd.DataFrame(tried).sort_values("score", ascending=False).reset_index(drop=True)
    return best["model"], best["labels"], Xs, diag

# ---------- Shell helpers (match test) ----------
def _cluster_shells(P: np.ndarray, cfg: HDBSCANConfig):
    if len(P) < 10: return {}, {}
    mins = P.min(0) - cfg.shell_margin
    maxs = P.max(0) + cfg.shell_margin
    extent = maxs - mins
    scale  = cfg.shell_grid_base / max(float(extent.min()), 1e-6)
    shape  = np.maximum(np.round(extent * scale).astype(int), 8)
    edges  = [np.linspace(mins[i], maxs[i], int(shape[i]) + 1) for i in range(3)]
    spacing = extent / shape
    origin  = mins

    H, _ = np.histogramdd(P, bins=edges)
    dens = gaussian_filter(H.astype(np.float32), sigma=cfg.shell_sigma_vox, mode="constant")
    vals = dens[dens > 0]
    if vals.size == 0: return {}, {}
    levels = {q: float(np.percentile(vals, q)) for q in cfg.shell_quantiles}

    meshes = {}
    for q, lvl in levels.items():
        V, F, _, _ = marching_cubes(dens, level=lvl, spacing=spacing)
        V = V + origin
        meshes[q] = trimesh.Trimesh(vertices=V, faces=F, process=True)
    return meshes, levels

# ---------- JSON + CSV helpers ----------
# Accept BOTH: ..._aligned.json  /  ..._gene_aligned.json  (and typo _alinged)
FNAME_RE = re.compile(
    r"""^
    (?P<chr>chr[0-9XYM]+)
    _(?P<time>\d+hr?s?)
    _(?P<cond>[A-Za-z]+)
    (?:_gene)?_
    (?:aligned|alinged)\.json$
    """, re.VERBOSE | re.IGNORECASE
)

def parse_fname(p: Path) -> Optional[Tuple[str, str, str]]:
    m = FNAME_RE.match(p.name)
    if not m: return None
    return m.group("chr").lower(), m.group("time").lower(), m.group("cond").lower()

def _peer_csv_path(chr_name: str, time_name: str, cond_name: str) -> Path:
    # time_name like "12hrs" or "12hr"
    return (ALL_STRUCT_ROOT / chr_name / time_name / cond_name /
            f"structure_{time_name}_{cond_name}_gene_info.csv")

def _load_gene_csv_points(csv_path: Path) -> Tuple[np.ndarray, List[str]]:
    """Load CSV gene centers and gene_name column."""
    df = pd.read_csv(csv_path)
    pts_df = df[["middle_x", "middle_y", "middle_z"]].dropna()
    if "gene_name" not in df.columns:
        raise ValueError(f"'gene_name' column not found in {csv_path}")
    names = df.loc[pts_df.index, "gene_name"].astype(str).tolist()
    pts = pts_df.values.astype(np.float32)
    if len(pts) < 10:
        raise ValueError(f"Too few CSV gene centers: {len(pts)} in {csv_path}")
    return pts, names

def _load_positions_from_json(jpath: Path) -> Tuple[np.ndarray, Optional[List[str]]]:
    """Robust loader: dict{'position'|'positions','ids'?} | list[[x,y,z]...] | list[{x,y,z}|{middle_x,...}]"""
    with jpath.open("r") as f:
        data = json.load(f)

    ids: Optional[List[str]] = None

    if isinstance(data, dict):
        pts = data.get("positions", data.get("position"))
        if pts is None: raise ValueError(f"No 'positions' or 'position' in {jpath}")
        arr = np.asarray(pts, dtype=float)
        # optional ids parallel to positions
        if "ids" in data:
            ids = [str(x) for x in data["ids"]]
            if len(ids) != len(arr):
                raise ValueError(f"'ids' length != positions length in {jpath}")
    elif isinstance(data, list):
        if not data: raise ValueError(f"Empty list in {jpath}")
        first = data[0]
        if isinstance(first, (list, tuple)) and len(first) == 3:
            arr = np.asarray(data, dtype=float)
        elif isinstance(first, dict):
            if all(k in first for k in ("x","y","z")):
                arr = np.asarray([[d["x"], d["y"], d["z"]] for d in data], dtype=float)
            elif all(k in first for k in ("middle_x","middle_y","middle_z")):
                arr = np.asarray([[d["middle_x"], d["middle_y"], d["middle_z"]] for d in data], dtype=float)
            else:
                raise ValueError(f"List of dicts in {jpath} missing (x,y,z) or (middle_x, middle_y, middle_z)")
        else:
            raise ValueError(f"Unsupported list element type {type(first)} in {jpath}")
    else:
        raise ValueError(f"Unsupported JSON root type {type(data)} in {jpath}")

    if arr.ndim != 2 or arr.shape[1] != 3:
        raise ValueError(f"Expected Nx3 positions in {jpath}, got {arr.shape}")
    return arr.astype(np.float32), ids

def _write_clusters_back_json(jpath: Path, positions: np.ndarray, labels: np.ndarray):
    remapped = np.where(labels == -1, 0, labels + 1)  # 0=noise, 1..K
    with jpath.open("r") as f:
        data = json.load(f)
    data["position"] = np.asarray(positions, dtype=float).tolist()
    if "positions" in data:
        try: del data["positions"]
        except Exception: pass
    data["clusters"] = remapped.astype(int).tolist()
    tmp = jpath.with_suffix(".json.tmp")
    with tmp.open("w") as f:
        json.dump(data, f, separators=(",", ":"), ensure_ascii=False, indent=2)
    tmp.replace(jpath)

# ---------- Per-file run ----------
def process_aligned_json(jpath: Path, cfg: HDBSCANConfig):
    parsed = parse_fname(jpath)
    if not parsed:
        raise ValueError(f"Unrecognized filename pattern: {jpath.name}")
    chr_name, time_name, cond_name = parsed
    if cond_name not in VALIDCONDS:
        raise ValueError(f"Unknown condition '{cond_name}' in {jpath.name} (allowed: {sorted(VALIDCONDS)})")

    # Load JSON beads (+ ids) and CSV gene centers (+ names)
    json_pts, json_ids = _load_positions_from_json(jpath)
    csv_path = _peer_csv_path(chr_name, time_name, cond_name)
    if not csv_path.exists():
        raise FileNotFoundError(f"Peer CSV not found for {jpath.name}: {csv_path}")

    csv_pts, csv_names = _load_gene_csv_points(csv_path)
    print(f"[DATA] {jpath.name}: JSON N={len(json_pts)}  CSV N={len(csv_pts)}  (clustering on CSV; mapping by gene_name)")

    if json_ids is None:
        raise ValueError(
            f"{jpath.name} missing 'ids'. "
            f"This pipeline requires JSON 'ids' to match CSV 'gene_name' for label mapping."
        )

    # Cluster on CSV gene centers (matches test)
    model, base_labels, _, _ = auto_hdbscan(csv_pts, cfg)

    # Map labels to JSON by gene_name (NO nearest-neighbor)
    # Build name -> label dict (case-sensitive by default)
    name_to_label: Dict[str, int] = {str(n): int(l) for n, l in zip(csv_names, base_labels)}
    # If case-insensitive mapping is needed, uncomment the next line and the .lower() usage below:
    # name_to_label = {str(n).lower(): int(l) for n, l in zip(csv_names, base_labels)}

    json_labels = np.array([name_to_label.get(str(g)) for g in json_ids], dtype=object)
    # Any missing names → mark as noise (-1)
    missing_mask = pd.isna(json_labels)
    if missing_mask.any():
        nmiss = int(missing_mask.sum())
        print(f"[WARN] {jpath.name}: {nmiss} / {len(json_labels)} ids not found in CSV gene_name — set to noise.")
        json_labels[missing_mask] = -1
    json_labels = json_labels.astype(int)

    # Write clusters to JSON
    _write_clusters_back_json(jpath, json_pts, json_labels)

    # Export GLB shells from the same base set we clustered (matches test visuals)
    if cfg.make_shells:
        uniq = sorted([u for u in np.unique(base_labels) if u != -1])
        out_dir = SHELLS_ROOT / chr_name / f"{time_name}_{cond_name}"
        out_dir.mkdir(parents=True, exist_ok=True)
        for cid in uniq:
            P = csv_pts[base_labels == cid]
            meshes, levels = _cluster_shells(P, cfg)
            for q, mesh in meshes.items():
                # optional decimation
                if cfg.decimate_ratio and cfg.decimate_ratio > 0.0:
                    try:
                        target = int(len(mesh.faces) * float(cfg.decimate_ratio))
                        if 50 < target < len(mesh.faces):
                            mesh = mesh.simplify_quadratic_decimation(target)
                    except Exception:
                        pass
                suffix = cfg.export_format.lower().strip(".")
                opath = out_dir / f"{chr_name}_{time_name}_{cond_name}_c{cid}_q{int(q)}.{suffix}"
                mesh.export(opath)
                try:
                    bmin, bmax = mesh.bounds
                    print(f"[SHELL] {opath.name} V={len(mesh.vertices)} F={len(mesh.faces)} "
                          f"bbox=({np.round(bmin,2)}→{np.round(bmax,2)}) level={levels[q]:.4f}")
                except Exception:
                    print(f"[SHELL] {opath.name} V={len(mesh.vertices)} F={len(mesh.faces)}")

    n_clusters = int(len(set(base_labels)) - (1 if -1 in base_labels else 0))
    noise_frac = float(np.mean(base_labels == -1))
    return dict(chr=chr_name, time=time_name, cond=cond_name,
                n_points=int(len(csv_pts)), n_clusters=n_clusters,
                noise_frac=noise_frac,
                mcs=int(model.min_cluster_size),
                ms=None if model.min_samples is None else int(model.min_samples))

# ---------- Driver ----------
def main():
    SHELLS_ROOT.mkdir(parents=True, exist_ok=True)
    cfg = HDBSCANConfig()  # uses the test settings
    print(f"[CFG] shells q={cfg.shell_quantiles} sigma={cfg.shell_sigma_vox} grid={cfg.shell_grid_base}")

    chrom_dirs = [d for d in BASE_DIR.iterdir() if d.is_dir()]
    done = skipped = failed = 0

    for chrom_dir in sorted(chrom_dirs, key=lambda p: p.name):
        chr_name = chrom_dir.name
        print(f"\n=== {chr_name} ===")

        # pick up aligned + typo, with or without _gene
        jsons = sorted({
            *chrom_dir.glob(f"{chr_name}_*_aligned.json"),
            *chrom_dir.glob(f"{chr_name}_*_alinged.json"),
            *chrom_dir.glob(f"{chr_name}_*_gene_aligned.json"),
            *chrom_dir.glob(f"{chr_name}_*_gene_alinged.json"),
        }, key=lambda p: p.name)

        if not jsons:
            print(f"[SKIP] No *_aligned.json in {chrom_dir}")
            skipped += 1
            continue

        ok = 0
        for jp in jsons:
            try:
                info = process_aligned_json(jp, cfg)
                print(f"[OK] {jp.name}: pts={info['n_points']} "
                      f"clusters={info['n_clusters']} noise={info['noise_frac']:.2f} "
                      f"mcs={info['mcs']} ms={info['ms']}")
                ok += 1
            except Exception as e:
                print(f"[FAIL] {jp.name}: {e}")
                traceback.print_exc()
                failed += 1

        done += (1 if ok > 0 else 0)

    print("\n=== Summary ===")
    print(f"Completed chromosomes: {done}")
    print(f"Skipped chromosomes:   {skipped}")
    print(f"Failures:              {failed}")

if __name__ == "__main__":
    main()


[CFG] shells q=(95, 90, 80, 60) sigma=1.8 grid=96

=== chr1 ===
[DATA] chr1_12hrs_untr_aligned.json: JSON N=1931  CSV N=1931  (clustering on CSV; mapping by gene_name)
[SHELL] chr1_12hrs_untr_c0_q95.glb V=6280 F=12508 bbox=([-43.03  38.23 -54.54]→[-36.87  43.83 -49.96]) level=0.0485
[SHELL] chr1_12hrs_untr_c0_q90.glb V=9324 F=18648 bbox=([-43.11  38.13 -54.76]→[-36.79  43.93 -49.85]) level=0.0323
[SHELL] chr1_12hrs_untr_c0_q80.glb V=12475 F=24962 bbox=([-43.24  37.92 -54.97]→[-36.68  44.06 -48.76]) level=0.0145
[SHELL] chr1_12hrs_untr_c0_q60.glb V=13888 F=27764 bbox=([-43.54  37.4  -55.23]→[-36.48  44.29 -48.32]) level=0.0019
[SHELL] chr1_12hrs_untr_c1_q95.glb V=6120 F=12172 bbox=([-41.18  39.87 -50.69]→[-37.65  44.86 -46.91]) level=0.0240
[SHELL] chr1_12hrs_untr_c1_q90.glb V=9627 F=19258 bbox=([-41.24  39.82 -50.75]→[-37.56  44.93 -46.7 ]) level=0.0159
[SHELL] chr1_12hrs_untr_c1_q80.glb V=13169 F=26362 bbox=([-41.33  39.74 -50.83]→[-37.45  45.02 -46.5 ]) level=0.0075
[SHELL] chr1_12hr