In [4]:
# === HDBSCAN Hotspots (Batch for flat JSON filenames) ===
# Layout:
#   data/green_monkey/structure_genes_aligned/   (BASE_DIR below)
#     chr1/
#       chr1_12hr_UNTR_aligned.json
#       chr1_12hr_VACV_aligned.json
#     chr2/
#       chr2_18hrs_UNTR_aligned.json
#       ...
#
# Outputs:
#   1) Updates each *_aligned.json → {"position":[[x,y,z],...], "clusters":[...]} with 0=noise, 1..K=clusters
#   2) OBJ shells in: data/green_monkey/density_shells_hdbscan/<chr>/<time>_<cond>/
#        e.g., .../chr1/12hr_untr/12hr_untr_q95.obj

import os, re, json, math, traceback
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, List

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score

import hdbscan
from hdbscan.validity import validity_index as dbcv_index
from scipy.ndimage import gaussian_filter
from skimage.measure import marching_cubes
import trimesh

import warnings, logging
# Process-wide: "ignore" with no category silences *every* warning (including numpy
# RuntimeWarnings such as divide-by-zero or invalid-value), not only library chatter.
warnings.filterwarnings("ignore")            # hide sklearn/numba/hdbscan chatter
# Root logger at ERROR: loggers that do not set their own level stop emitting INFO/WARNING.
logging.getLogger().setLevel(logging.ERROR)  # quiet generic loggers

# ---------- Roots & ordering ----------
# Input tree: one sub-directory per chromosome, each holding <chr>_*_aligned.json files.
BASE_DIR = Path("data/green_monkey/structure_genes_aligned")
# Output tree for OBJ shells: <SHELLS_ROOT>/<chr>/<time>_<cond>/
SHELLS_ROOT = Path("data/green_monkey/density_shells_hdbscan")

# Canonical time-point labels. NOTE(review): TIMES_ORDER is not referenced elsewhere in
# this cell; acceptance of both "hr" and "hrs" spellings comes from FNAME_RE instead.
TIMES_ORDER = ["12hrs", "18hrs", "24hrs"]
# Recognised (lower-cased) conditions; files with any other condition are rejected.
CONDS_ORDER = ["untr", "vacv"]
VALIDCONDS = set(CONDS_ORDER)

# ---------- Config ----------
@dataclass
class HDBSCANConfig:
    """Settings for the HDBSCAN hyper-parameter search and the per-cluster OBJ shells.

    Field order is also the positional-argument order of the generated ``__init__``.
    """
    # Standardize coordinates (StandardScaler) before clustering.
    scale_coords: bool = True         # standardize coords before clustering
    # "auto" and "sqrtN" behave identically (grid around sqrt(N)); "fixed" uses one value.
    min_cluster_size_strategy: str = "auto"  # "auto" or "sqrtN" or "fixed"
    # "fixed": the single candidate (raised to >= 5). For an unrecognised strategy
    # string it is used as the grid base instead of sqrt(N).
    fixed_min_cluster_size: int = 20  # only used if strategy=="fixed"
    # (low, high, steps): candidates are base * geomspace(low, high, steps), rounded,
    # clipped to [5, max(5, N // 5)] and de-duplicated.
    min_cluster_size_grid: Tuple[float, float, int] = (0.6, 2.2, 9)
    # None -> HDBSCAN's default (min_samples = min_cluster_size); a number >= 2 is an
    # absolute value; anything smaller is a fraction of min_cluster_size. With the
    # defaults, 1 and 1.0 both resolve to min_cluster_size and are de-duplicated.
    min_samples_options: Tuple[Optional[float], ...] = (None, 1, 0.5, 1.0)
    # Passed straight through to hdbscan.HDBSCAN.
    allow_single_cluster: bool = False
    cluster_selection_method: str = "eom"    # "eom" (default) or "leaf"
    metric: str = "euclidean"
    # Per-cluster shells: density isosurfaces built separately for each cluster.
    make_shells: bool = True
    # Percentiles of the positive smoothed-density values used as iso-levels.
    shell_quantiles: Tuple[float, ...] = (95, 90, 80, 60)
    # Gaussian smoothing sigma, in voxels.
    shell_sigma_vox: float = 1.8
    # Number of voxels along the SHORTEST padded bounding-box axis.
    shell_grid_base: int = 96
    # Padding added around each cluster's bounding box, in coordinate units.
    shell_margin: float = 3.0
    # NOTE(review): not read anywhere in this cell (shells go under SHELLS_ROOT).
    out_dir: Optional[str] = None
    # NOTE(review): not read anywhere in this cell.
    seed: int = 7

# ---------- HDBSCAN utilities ----------
def _candidate_mcs(n: int, cfg: HDBSCANConfig) -> List[int]:
    """List the ``min_cluster_size`` values to try for a cloud of ``n`` points.

    - ``"fixed"`` strategy: one value, ``cfg.fixed_min_cluster_size`` raised to at least 5.
    - ``"auto"`` / ``"sqrtN"``: ``sqrt(n)`` times ``steps`` geometrically spaced factors
      taken from ``cfg.min_cluster_size_grid = (low, high, steps)``.
    - any other strategy string: the same grid, anchored on
      ``cfg.fixed_min_cluster_size`` instead of ``sqrt(n)``.

    Grid values are rounded, clipped to ``[5, max(5, n // 5)]``, de-duplicated and
    returned in ascending order as plain Python ints.
    """
    strategy = cfg.min_cluster_size_strategy
    if strategy == "fixed":
        return [max(5, int(cfg.fixed_min_cluster_size))]

    if strategy in ("auto", "sqrtN"):
        anchor = math.sqrt(n)
    else:
        anchor = cfg.fixed_min_cluster_size

    low, high, n_steps = cfg.min_cluster_size_grid
    # Cap at a fifth of the points, but never below the floor of 5.
    upper = max(5, n // 5)
    rounded = np.rint(anchor * np.geomspace(low, high, n_steps))
    return np.unique(np.clip(rounded, 5, upper).astype(int)).tolist()

def _candidate_min_samples(mcs: int, options: Tuple[Optional[float], ...]) -> List[Optional[int]]:
    out: List[Optional[int]] = []
    for opt in options:
        if opt is None:
            cand = None  # special case: HDBSCAN uses mcs
        elif opt >= 2:
            cand = int(opt)           # absolute value
        else:
            cand = max(1, int(round(mcs * float(opt))))  # fraction of mcs
        if cand not in out:           # preserve order, dedupe
            out.append(cand)
    return out

def _score_clustering(labels: np.ndarray, X: np.ndarray, clusterer: hdbscan.HDBSCAN) -> float:
    # Components: DBCV (density-based validity), silhouette on non-noise if valid,
    # penalty for high noise fraction, and reward for high mean membership prob.
    labs = labels
    n = len(labs)
    n_clusters = len(set(labs)) - (1 if -1 in labs else 0)
    noise_frac = np.mean(labs == -1)
    mean_prob = float(clusterer.probabilities_[labs != -1].mean()) if np.any(labs != -1) else 0.0

    # DBCV requires >= 2 clusters with non-noise
    dbcv = 0.0
    try:
        if n_clusters >= 2:
            dbcv = dbcv_index(X, labs)
    except Exception:
        dbcv = 0.0

    sil = 0.0
    try:
        valid = labs != -1
        if n_clusters >= 2 and valid.sum() > n_clusters:
            sil = silhouette_score(X[valid], labs[valid], metric="euclidean")
    except Exception:
        sil = 0.0

    # Composite score: primary=DBCV, then silhouette, then penalties/rewards
    score = (1.00 * dbcv) + (0.25 * sil) + (0.20 * (1.0 - noise_frac)) + (0.10 * mean_prob)
    return float(score)

def auto_hdbscan(X: np.ndarray, cfg: HDBSCANConfig):
    """Fit HDBSCAN over a small hyper-parameter grid and keep the best-scoring fit.

    The grid is every ``min_cluster_size`` from ``_candidate_mcs`` crossed with the
    ``min_samples`` values from ``_candidate_min_samples``. Each fit is scored with
    ``_score_clustering``, and a candidate replaces the current best only if its score
    is strictly greater, so ties keep the earliest candidate.

    Args:
        X: point coordinates, one row per point.
        cfg: search and HDBSCAN settings.

    Returns:
        ``(model, labels, Xs, diag)``: the winning fitted model, its labels (-1 =
        noise), the matrix that was clustered (standardized when
        ``cfg.scale_coords``), and a DataFrame with one row per candidate, sorted by
        descending score.
    """
    Xs = StandardScaler().fit_transform(X) if cfg.scale_coords else X.copy()

    rows = []
    winner = None  # (score, fitted model, labels) of the best candidate so far
    for mcs in _candidate_mcs(len(X), cfg):
        for ms in _candidate_min_samples(mcs, cfg.min_samples_options):
            fitted = hdbscan.HDBSCAN(
                min_cluster_size=int(mcs),
                min_samples=None if ms is None else int(ms),
                metric=cfg.metric,
                cluster_selection_method=cfg.cluster_selection_method,
                allow_single_cluster=cfg.allow_single_cluster,
                core_dist_n_jobs=1,
            ).fit(Xs)
            labs = fitted.labels_
            score = _score_clustering(labs, Xs, fitted)
            rows.append({
                "min_cluster_size": mcs,
                "min_samples": ms,
                "score": score,
                "n_clusters": len(set(labs)) - (1 if -1 in labs else 0),
                "noise_frac": float(np.mean(labs == -1)),
            })
            if winner is None or score > winner[0]:
                winner = (score, fitted, labs)

    diag = pd.DataFrame(rows).sort_values("score", ascending=False).reset_index(drop=True)
    _, best_model, best_labels = winner
    return best_model, best_labels, Xs, diag

def _cluster_shells(P: np.ndarray, cfg: HDBSCANConfig):
    """Build smoothed-density isosurface shells for one cluster's points.

    The points are binned into a 3D histogram over their bounding box, padded by
    ``cfg.shell_margin``. The histogram is blurred with a Gaussian
    (``cfg.shell_sigma_vox`` voxels) and iso-surfaced with marching cubes at each
    percentile in ``cfg.shell_quantiles`` of the positive density values.

    Args:
        P: point coordinates of a single cluster, one row per point (3 columns).
        cfg: reads shell_margin, shell_grid_base, shell_sigma_vox, shell_quantiles.

    Returns:
        ``(meshes, levels)``, two dicts keyed by quantile. ``levels[q]`` is the density
        threshold and ``meshes[q]`` a ``trimesh.Trimesh`` in the same coordinate
        frame as ``P``. Both are empty when ``P`` has fewer than 10 points or the
        density grid is empty. A quantile whose surface cannot be extracted is
        left out of ``meshes`` (but kept in ``levels``) instead of aborting.
    """
    if len(P) < 10:
        return {}, {}
    mins = P.min(0) - cfg.shell_margin
    maxs = P.max(0) + cfg.shell_margin
    extent = maxs - mins
    # The SHORTEST axis gets ~shell_grid_base voxels; longer axes get proportionally more.
    # NOTE(review): for elongated clusters this grows quickly (a 10:1 aspect ratio
    # gives roughly 96 x 960 x 960 voxels with the default base).
    scale = cfg.shell_grid_base / max(float(extent.min()), 1e-6)
    shape = np.maximum(np.round(extent * scale).astype(int), 8)
    edges = [np.linspace(mins[i], maxs[i], int(shape[i]) + 1) for i in range(3)]
    spacing = extent / shape
    # marching_cubes reports vertex positions as (array index * spacing), and array
    # index i is histogram bin i, whose centre is at mins + (i + 0.5) * spacing.
    # Offsetting by half a voxel keeps the shells aligned with the input points.
    origin = mins + 0.5 * spacing

    H, _ = np.histogramdd(P, bins=edges)
    dens = gaussian_filter(H.astype(np.float32), sigma=cfg.shell_sigma_vox, mode="constant")
    vals = dens[dens > 0]
    if vals.size == 0:
        return {}, {}
    levels = {q: float(np.percentile(vals, q)) for q in cfg.shell_quantiles}

    voxel_spacing = tuple(float(s) for s in spacing)
    meshes = {}
    for q, lvl in levels.items():
        try:
            V, F, _, _ = marching_cubes(dens, level=lvl, spacing=voxel_spacing)
        except (ValueError, RuntimeError):
            # Level outside the data range or no surface found: skip this shell
            # rather than failing the whole file.
            continue
        meshes[q] = trimesh.Trimesh(vertices=V + origin, faces=F, process=True)
    return meshes, levels

# ---------- JSON helpers ----------
def _load_positions_from_json(jpath: Path) -> np.ndarray:
    with jpath.open("r") as f:
        data = json.load(f)
    pts = data.get("positions", data.get("position", None))
    if pts is None:
        raise ValueError(f"No 'positions' or 'position' in {jpath}")
    arr = np.asarray(pts, dtype=float)
    if arr.ndim != 2 or arr.shape[1] != 3:
        raise ValueError(f"Expected Nx3 positions in {jpath}, got shape {arr.shape}")
    return arr.astype(np.float32)

def _write_clusters_back_json(jpath: Path, positions: np.ndarray, labels: np.ndarray):
    # Remap to 0=noise, 1..K clusters
    remapped = np.where(labels == -1, 0, labels + 1)

    with jpath.open("r") as f:
        data = json.load(f)

    data["position"] = np.asarray(positions, dtype=float).tolist()  # singular
    if "positions" in data:
        del data["positions"]
    data["clusters"] = remapped.astype(int).tolist()

    tmp = jpath.with_suffix(".json.tmp")
    with tmp.open("w") as f:
        json.dump(data, f, separators=(",", ":"), ensure_ascii=False, indent=2)
    tmp.replace(jpath)

# ---------- Filename parsing ----------
# Matches "<chr>_<time>_<cond>_aligned.json", e.g. chr1_12hr_UNTR_aligned.json or
# chr1_12hrs_VACV_aligned.json. The time token is digits + "h" + optional "r" + optional
# "s" (so "12h", "12hr" and "12hrs" all match) and is case-sensitive; the condition is
# any run of ASCII letters in either case.
FNAME_RE = re.compile(r"^(?P<chr>chr[^_]+)_(?P<time>\d+hr?s?)_(?P<cond>[A-Za-z]+)_aligned\.json$")


def parse_fname(p: Path) -> Optional[Tuple[str, str, str]]:
    """Split an aligned-structure filename into ``(chromosome, time, condition)``.

    Only the final path component is inspected. The chromosome keeps its original
    case. The time and condition are lower-cased, and the "hr"/"hrs" spelling is
    left as written. Returns ``None`` when the name does not match ``FNAME_RE``.
    """
    match = FNAME_RE.match(p.name)
    if match is None:
        return None
    parts = match.groupdict()
    return parts["chr"], parts["time"].lower(), parts["cond"].lower()

# ---------- Per-file run ----------
def process_aligned_json(jpath: Path, cfg: HDBSCANConfig):
    """Cluster one ``<chr>_<time>_<cond>_aligned.json`` file and write the results.

    Steps:
      1. Parse chromosome/time/condition from the filename and validate the condition.
      2. Load the coordinates and auto-tune HDBSCAN on them (``auto_hdbscan``).
      3. Rewrite the same JSON file with ``"clusters"`` (0 = noise, 1..K = clusters).
      4. If ``cfg.make_shells``, build density shells for every non-noise cluster and
         write one OBJ per quantile, containing the shells of all clusters, to
         ``SHELLS_ROOT/<chr>/<time>_<cond>/<time>_<cond>_q<q>.obj``.

    The JSON is rewritten before the shells are built, so a failure in step 4
    still leaves the updated labels on disk.

    Returns:
        dict with keys chr, time, cond, n_points, n_clusters, noise_frac,
        mcs (chosen min_cluster_size) and ms (chosen min_samples, or None).

    Raises:
        ValueError: unrecognised filename, unknown condition, or fewer than 10
            points (plus anything raised while loading or writing).
    """
    parsed = parse_fname(jpath)
    if not parsed:
        raise ValueError(f"Unrecognized filename pattern: {jpath.name}")
    chr_name, time_name, cond_name = parsed
    if cond_name not in VALIDCONDS:
        raise ValueError(f"Unknown condition '{cond_name}' in {jpath.name}")

    # Load points
    pts = _load_positions_from_json(jpath)
    if len(pts) < 10:
        raise ValueError(f"Too few points ({len(pts)}) in {jpath}")

    # Cluster (auto-tuned)
    model, labels, _, _ = auto_hdbscan(pts, cfg)

    # Write clusters back to the SAME file
    _write_clusters_back_json(jpath, pts, labels)

    # Shells (skip noise)
    if cfg.make_shells:
        out_dir = SHELLS_ROOT / chr_name / f"{time_name}_{cond_name}"
        out_dir.mkdir(parents=True, exist_ok=True)
        # All clusters share one output file per quantile, so collect each cluster's
        # mesh first and export the concatenation once. Exporting inside the
        # cluster loop would overwrite the file and keep only the last cluster.
        per_quantile = {}
        for cid in sorted(int(u) for u in np.unique(labels) if u != -1):
            meshes, _ = _cluster_shells(pts[labels == cid], cfg)
            for q, mesh in meshes.items():
                per_quantile.setdefault(q, []).append(mesh)
        for q, parts in per_quantile.items():
            combined = trimesh.util.concatenate(parts)
            combined.export(out_dir / f"{time_name}_{cond_name}_q{int(q)}.obj")

    n_clusters = int(len(set(labels)) - (1 if -1 in labels else 0))
    noise_frac = float(np.mean(labels == -1))
    return dict(chr=chr_name, time=time_name, cond=cond_name,
                n_points=int(len(pts)),
                n_clusters=n_clusters,
                noise_frac=noise_frac,
                mcs=int(model.min_cluster_size),
                ms=None if model.min_samples is None else int(model.min_samples))

# ---------- Driver ----------
def main():
    """Cluster every chromosome directory under BASE_DIR and print a summary.

    Each sub-directory ``<chr>`` of BASE_DIR is visited in name order. Every file in
    it matching ``<chr>_*_aligned.json`` (sorted) is passed to
    ``process_aligned_json``. A per-file error is printed with its traceback and
    counted, without stopping the batch.

    The summary reports three counts:
      * chromosomes with at least one successful file;
      * chromosome directories that had no matching files;
      * individual files that failed.
    """
    SHELLS_ROOT.mkdir(parents=True, exist_ok=True)
    settings = dict(
        scale_coords=True,
        min_cluster_size_strategy="auto",
        min_cluster_size_grid=(0.6, 2.2, 9),
        min_samples_options=(None, 1, 0.5, 1.0),
        cluster_selection_method="eom",
        metric="euclidean",
        make_shells=True,
        shell_quantiles=(95, 90, 80, 60),
        shell_sigma_vox=1.8,
        shell_grid_base=96,
        shell_margin=3.0,
    )
    cfg = HDBSCANConfig(**settings)

    chrom_dirs = sorted(
        (entry for entry in BASE_DIR.iterdir() if entry.is_dir()),
        key=lambda entry: entry.name,
    )

    completed = skipped = failed = 0
    for chrom_dir in chrom_dirs:
        chr_name = chrom_dir.name
        print(f"\n=== {chr_name} ===")
        jsons = sorted(chrom_dir.glob(f"{chr_name}_*_aligned.json"))
        if not jsons:
            print(f"[SKIP] No *_aligned.json in {chrom_dir}")
            skipped += 1
            continue

        n_ok = 0
        for jp in jsons:
            try:
                info = process_aligned_json(jp, cfg)
                print(
                    f"[OK] {jp.name}: pts={info['n_points']} "
                    f"clusters={info['n_clusters']} noise={info['noise_frac']:.2f} "
                    f"mcs={info['mcs']} ms={info['ms']}"
                )
                n_ok += 1
            except Exception as e:
                print(f"[FAIL] {jp.name}: {e}")
                traceback.print_exc()
                failed += 1

        if n_ok:
            completed += 1

    print("\n=== Summary ===")
    print(f"Completed chromosomes: {completed}")
    print(f"Skipped chromosomes:   {skipped}")
    print(f"Failures:              {failed}")


if __name__ == "__main__":
    main()



=== chr1 ===
[OK] chr1_12hrs_UNTR_aligned.json: pts=1931 clusters=2 noise=0.26 mcs=26 ms=None
[OK] chr1_12hrs_VACV_aligned.json: pts=1931 clusters=5 noise=0.22 mcs=26 ms=13
[OK] chr1_18hrs_UNTR_aligned.json: pts=1931 clusters=3 noise=0.42 mcs=59 ms=30
[OK] chr1_18hrs_VACV_aligned.json: pts=1931 clusters=2 noise=0.55 mcs=97 ms=None
[OK] chr1_24hrs_UNTR_aligned.json: pts=1931 clusters=3 noise=0.27 mcs=31 ms=16
[OK] chr1_24hrs_VACV_aligned.json: pts=1931 clusters=5 noise=0.25 mcs=31 ms=16

=== chr10 ===
[OK] chr10_12hrs_UNTR_aligned.json: pts=1136 clusters=2 noise=0.06 mcs=20 ms=10
[OK] chr10_12hrs_VACV_aligned.json: pts=1136 clusters=2 noise=0.00 mcs=28 ms=14
[OK] chr10_18hrs_UNTR_aligned.json: pts=1136 clusters=2 noise=0.03 mcs=20 ms=10
[OK] chr10_18hrs_VACV_aligned.json: pts=1136 clusters=2 noise=0.08 mcs=54 ms=27
[OK] chr10_24hrs_UNTR_aligned.json: pts=1136 clusters=2 noise=0.06 mcs=28 ms=None
[OK] chr10_24hrs_VACV_aligned.json: pts=1136 clusters=2 noise=0.01 mcs=39 ms=20

=== chr12 