In [None]:
# Cell 3 — robust bright-percentile detection + verbose logging + per-lumen OUTER + one INNER-ALL
import os, shutil, numpy as np, torch, SimpleITK as sitk
from huggingface_hub import snapshot_download
from nnInteractive.inference.inference_session import nnInteractiveInferenceSession
import torch.nn.functional as F
from scipy.ndimage import binary_fill_holes, binary_closing, generate_binary_structure
from skimage.morphology import convex_hull_image
from scipy.ndimage import binary_fill_holes, binary_closing, generate_binary_structure
from PIL import Image, ImageDraw

def log(*a):
    """Print the given values on a single line, prefixed with the "[LOG]" tag."""
    tagged = ("[LOG]", *a)
    print(*tagged)
def clipi(v, lo, hi):
    """Clamp ``v`` into [lo, hi] and return it as an int.

    The upper bound is applied first and the lower bound second, so ``lo``
    wins if ``lo > hi``. ``int()`` truncates a fractional result toward zero.
    """
    upper_bounded = min(hi, v)
    return int(max(lo, upper_bounded))

# Allocator setting for PyTorch CUDA memory (expandable segments).
# NOTE(review): torch is already imported by the time this runs — confirm the
# setting is still picked up.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# --- Download model ---
# Only the files under MODEL_NAME/ are fetched from the repo.
MODEL_REPO = "nnInteractive/nnInteractive"
MODEL_NAME = "nnInteractive_v1.0"
model_root = snapshot_download(
    repo_id=MODEL_REPO,
    allow_patterns=[f"{MODEL_NAME}/*"],
    local_dir="/content/nnint",
)
model_path = os.path.join(model_root, MODEL_NAME)

# --- Load image ---
# SimpleITK hands back a (z,y,x) array; everything below works in (x,y,z).
img_itk = sitk.ReadImage(IMAGE_PATH)
arr_zyx = sitk.GetArrayFromImage(img_itk)
arr_xyz = arr_zyx.transpose(2, 1, 0).astype(np.float32)
img = arr_xyz[np.newaxis]  # (1, x, y, z)
X, Y, Z = arr_xyz.shape
log(
    "Loaded:", IMAGE_PATH,
    "shape (X,Y,Z):", arr_xyz.shape,
    "min/max:", float(arr_xyz.min()), float(arr_xyz.max()),
)

# --- Output dir ---
# The output folder is named after the image file (last extension stripped)
# and is wiped and recreated on every run.
base = os.path.splitext(os.path.basename(IMAGE_PATH))[0]
out_dir = f"/content/{base}_nnInteractive"
shutil.rmtree(out_dir, ignore_errors=True)
os.makedirs(out_dir, exist_ok=True)
dbg_dir = os.path.join(out_dir, "debug")
os.makedirs(dbg_dir, exist_ok=True)

# ---------------------------
#  A) DETECTION (bright rings)
# ---------------------------
# Approach: binarise the whole volume at a high intensity percentile, label the
# connected components, and keep the N_DESIRED largest ones.
PCT = 99.6           # percentile cut-off: higher is stricter, lower is looser
N_DESIRED = 3        # number of lumens expected (one OUTER file is saved per kept candidate)
MIN_VOX = 2000       # size floor for a candidate; drops tiny specks

thr = np.percentile(arr_xyz, PCT)
mask3d = arr_xyz >= thr

# SimpleITK takes arrays in (z,y,x) order, so transpose back before labelling.
mask3d_zyx = mask3d.astype(np.uint8).transpose(2, 1, 0)
mask_itk = sitk.GetImageFromArray(mask3d_zyx)
cc = sitk.ConnectedComponent(mask_itk)
stats = sitk.LabelShapeStatisticsImageFilter()
stats.Execute(cc)

# Each candidate is (label, x0, y0, z0, sx, sy, sz, size); the bounding box is
# reported in (x,y,z) order because of how mask_itk was built.
# NOTE(review): "size" is the bounding-box volume sx*sy*sz, not the component's
# voxel count, even though it is compared with MIN_VOX and logged as "vox=".
cands = []
for label in stats.GetLabels():
    bbox = tuple(stats.GetBoundingBox(label))
    box_volume = bbox[3] * bbox[4] * bbox[5]
    if box_volume < MIN_VOX:
        continue
    cands.append((label, *bbox, box_volume))

log(f"Detector threshold @ P{PCT:.1f} = {thr:.2f}")
log("Detected components:", len(cands))

# Largest first, then truncate to the expected number of lumens.
cands.sort(key=lambda cand: cand[-1], reverse=True)
cands = cands[:N_DESIRED]

if len(cands) == 0:
    raise RuntimeError("Detector found 0 candidates. Try lowering PCT or MIN_VOX.")

for j,(l,x0,y0,z0,sx,sy,sz,vol) in enumerate(cands,1):
    log(f"[Cand {j}] lbl={l} bbox(x[{x0},{x0+sx}) y[{y0},{y0+sy}) z[{z0},{z0+sz})) vox={vol}")

# Debug images for the middle z-plane: the raw slice scaled by the volume
# maximum, and the thresholded mask.
midZ = Z//2
raw_scale = max(1e-6, arr_xyz.max())
raw_unit = np.clip(arr_xyz[:,:,midZ] / raw_scale, 0, 1)
im_raw = Image.fromarray(np.uint8(raw_unit*255))
im_raw.save(os.path.join(dbg_dir, f"raw_midZ_{midZ:03d}.png"))
mask_slice = mask3d[:,:,midZ] * 255
im_mask = Image.fromarray(mask_slice.astype(np.uint8))
im_mask.save(os.path.join(dbg_dir, f"mask_midZ_{midZ:03d}.png"))

# ---------------------------
#  B) NNINTERACTIVE SESSION
# ---------------------------
device = torch.device("cuda:0")
# Use half of the available CPU cores for torch threads (at least one).
# os.cpu_count() returns None when the core count cannot be determined, so
# fall back to 1 instead of failing on `None // 2`.
sess = nnInteractiveInferenceSession(
    device=device, use_torch_compile=False, verbose=False,
    torch_n_threads=max(1, (os.cpu_count() or 1)//2), do_autozoom=False, use_pinned_memory=True
)
sess.initialize_from_trained_model_folder(model_path)

# Downsample factors applied to each crop before it is given to the model
SCALE_XY = 0.5   # in-plane (x and y) scale
SCALE_Z  = 1.0   # through-plane (z) scale
MIN_DHW  = 64    # lower bound on the resampled x and y sizes

def downsample_xyz(vol_xyz_f32):
    """Resample an (x, y, z) volume with trilinear interpolation.

    The target size is the input size scaled by SCALE_XY (x, y) and SCALE_Z
    (z), rounded, but never below MIN_DHW for x/y or 8 for z — so a small
    input can end up larger than it started.

    Returns a tuple ``(resampled, (nx, ny, nz))`` where ``resampled`` is a
    NumPy array of shape (nx, ny, nz).
    """
    in_x, in_y, in_z = vol_xyz_f32.shape
    nx = max(MIN_DHW, int(round(in_x * SCALE_XY)))
    ny = max(MIN_DHW, int(round(in_y * SCALE_XY)))
    nz = max(8, int(round(in_z * SCALE_Z)))

    # F.interpolate wants (N, C, D, H, W); reorder (x,y,z) -> (z,y,x) and add
    # the two leading singleton axes.
    src = torch.from_numpy(vol_xyz_f32).permute(2, 1, 0)[None, None].to(device)
    resized = F.interpolate(src, size=(nz, ny, nx), mode='trilinear', align_corners=False)

    # Drop the singleton axes and go back to (x,y,z) order.
    out_xyz = resized[0, 0].permute(2, 1, 0).contiguous().detach()
    return out_xyz.cpu().numpy(), (nx, ny, nz)

def upsample_mask_xyz(mask_lr_xyz_u8, target_shape_xyz):
    """Resize an (x, y, z) mask to ``target_shape_xyz`` with nearest-neighbour sampling.

    The mask is converted to float32 for interpolation and re-binarised at 0.5,
    so the result is a uint8 NumPy array containing only 0 and 1.
    """
    out_x, out_y, out_z = target_shape_xyz

    # (x,y,z) -> (1, 1, z, y, x), the layout F.interpolate expects.
    src = torch.from_numpy(mask_lr_xyz_u8.astype(np.float32))
    src = src.permute(2, 1, 0)[None, None].to(device)
    resized = F.interpolate(src, size=(out_z, out_y, out_x), mode='nearest')

    # Back to (x,y,z), then threshold.
    resized_xyz = resized[0, 0].permute(2, 1, 0).contiguous().detach()
    binary_xyz = resized_xyz > 0.5
    return binary_xyz.cpu().numpy().astype(np.uint8)

def add_border_negatives(session, nx, ny, zlist):
    """Add negative point interactions on a 3x3 grid of every listed z-plane.

    The grid combines x in {0, nx-1, nx//2} with y in {0, ny-1, ny//2}: the
    four corners, the four edge midpoints and the plane centre. Points are
    submitted plane by plane, row (y) by row, as (x, y, z) integer tuples.
    """
    grid_x = (0, nx - 1, nx // 2)
    grid_y = (0, ny - 1, ny // 2)
    points = (
        (int(px), int(py), int(pz))
        for pz in zlist
        for py in grid_y
        for px in grid_x
    )
    for point in points:
        session.add_point_interaction(point, include_interaction=False)

def fill_ball_from_ring(mask_xyz_u8):
    """
    Make a filled ball per z-slice from a sparse 'ring' mask: convex hull + hole fill,
    then a gentle 3D closing to smooth tiny z-gaps.

    mask_xyz_u8: (sx, sy, sz) mask; any non-zero voxel counts as foreground.
    Returns a new uint8 array of the same shape with values in {0, 1}; the
    input is not modified.
    """
    m = mask_xyz_u8.astype(bool)  # astype() already returns a fresh copy
    for z in range(m.shape[2]):
        sl = m[:, :, z]
        if sl.any():
            hull = convex_hull_image(sl)
            # Union with the original pixels so a slice can never lose
            # evidence (guards against the hull coming back empty for
            # degenerate inputs such as a single pixel or a straight line).
            m[:, :, z] = binary_fill_holes(hull | sl)

    s3 = generate_binary_structure(3, 1)  # light z smoothing
    # binary_closing treats everything outside the array as background
    # (border_value=0), so its erosion step would strip foreground voxels that
    # touch a face of the crop — e.g. the first and last z-slice. Pad by one
    # voxel (= iterations) so the closing only ever adds voxels, then crop.
    padded = np.pad(m, 1, mode="constant", constant_values=False)
    closed = binary_closing(padded, structure=s3, iterations=1)
    return closed[1:-1, 1:-1, 1:-1].astype(np.uint8)

def upsample_bool_xyz(mask_lr_xyz_u8, target_shape_xyz):
    """Nearest-neighbour upsample of a binary (x, y, z) mask to ``target_shape_xyz``.

    Thin wrapper around ``upsample_mask_xyz``: the two functions used to be
    line-for-line copies, so this one now delegates to keep a single
    implementation. Returns a uint8 array with values in {0, 1}.
    """
    return upsample_mask_xyz(mask_lr_xyz_u8, target_shape_xyz)

def solidify_2d_slices(mask_xyz_u8):
    """
    Make a filled ball: fill 2D holes slice-by-slice, then a gentle 3D closing.

    mask_xyz_u8: (sx, sy, sz) mask; any non-zero voxel counts as foreground.
    Returns a new uint8 array of the same shape with values in {0, 1}; the
    input is not modified.
    """
    m = mask_xyz_u8.astype(bool)  # astype() already returns a fresh copy

    # fill enclosed holes per z-slice (fast & robust for ringy signals)
    for z in range(m.shape[2]):
        m[:, :, z] = binary_fill_holes(m[:, :, z])

    # light 3D closing to bridge tiny gaps through z
    s3 = generate_binary_structure(3, 1)
    # binary_closing treats everything outside the array as background
    # (border_value=0), so its erosion step would strip foreground voxels that
    # touch a face of the array — e.g. the first and last z-slice. Pad by one
    # voxel (= iterations) so the closing only ever adds voxels, then crop.
    padded = np.pad(m, 1, mode="constant", constant_values=False)
    closed = binary_closing(padded, structure=s3, iterations=1)

    return closed[1:-1, 1:-1, 1:-1].astype(np.uint8)


def save_overlay(img2d, seeds_pos, seeds_neg, path_png, scale_to_255=True):
    """Save ``img2d`` as an RGB PNG with seed markers drawn on top.

    seeds_pos / seeds_neg are iterables of (row, col) index pairs into
    ``img2d`` (for the [x, y]-indexed slices used in this script that means
    (x, y)); positives are outlined in green, negatives in red.
    With scale_to_255=True the image is assumed to be in [0, 1] and is clipped
    and scaled; otherwise it is cast to uint8 as-is.
    """
    if scale_to_255:
        im = Image.fromarray(np.uint8(np.clip(img2d, 0, 1)*255))
    else:
        im = Image.fromarray(np.uint8(img2d))
    im = im.convert("RGB")
    dr = ImageDraw.Draw(im)
    for seeds, colour in ((seeds_pos, (0, 255, 0)), (seeds_neg, (255, 0, 0))):
        for (row, col) in seeds:
            # PIL coordinates are (horizontal, vertical) = (col, row)
            dr.ellipse((col-2, row-2, col+2, row+2), outline=colour, width=1)
    im.save(path_png)


# Combined inner across lumens
inner_all = np.zeros((X, Y, Z), dtype=np.uint8)

for i,(lbl, x0,y0,z0, sx,sy,sz, vol) in enumerate(cands, start=1):
    # Expand bbox slightly (clipped to the volume)
    pad_xy, pad_z = 12, 4
    x0p = clipi(x0-pad_xy, 0, X);   x1p = clipi(x0+sx+pad_xy, 0, X)
    y0p = clipi(y0-pad_xy, 0, Y);   y1p = clipi(y0+sy+pad_xy, 0, Y)
    z0p = clipi(z0-pad_z, 0, Z);    z1p = clipi(z0+sz+pad_z, 0, Z)

    roi = arr_xyz[x0p:x1p, y0p:y1p, z0p:z1p]
    if roi.size == 0:
        log(f"[Obj {i}] Empty ROI; skipping")
        continue
    sxr, syr, szr = roi.shape
    cz = szr // 2

    # Tighten XY on mid slice via bright-ring pixels (same percentile as detector)
    mid = roi[:,:,cz]                      # indexed [x, y]
    thr_mid = np.percentile(mid, PCT)
    ring = mid >= thr_mid
    if ring.sum() < 50:
        # backoff if too few pixels
        thr_mid = np.percentile(mid, PCT - 0.3)
        ring = mid >= thr_mid
    # np.where returns indices in axis order, i.e. (x, y) for an [x, y] slice.
    # (These were previously unpacked as (ys, xs), which transposed the tight
    # crop bounds and the lumen centre.)
    if ring.any():
        xs, ys = np.where(ring)
    else:
        xs, ys = np.array([sxr//2]), np.array([syr//2])

    # Lumen centre estimate: centroid of the ring pixels (ROI coordinates)
    cx_hr, cy_hr = int(np.mean(xs)), int(np.mean(ys))

    xmin, xmax = int(xs.min()), int(xs.max())+1
    ymin, ymax = int(ys.min()), int(ys.max())+1
    mxy, mz = 12, 6
    xmin = clipi(xmin-mxy, 0, sxr-1); xmax = clipi(xmax+mxy, 1, sxr)
    ymin = clipi(ymin-mxy, 0, syr-1); ymax = clipi(ymax+mxy, 1, syr)
    z0t = clipi(cz-mz, 0, szr-1);     z1t = clipi(cz+mz, 1, szr)

    tight = roi[xmin:xmax, ymin:ymax, z0t:z1t]
    sx_t, sy_t, sz_t = tight.shape
    log(f"[Obj {i}] ROI(x[{x0p},{x1p}) y[{y0p},{y1p}) z[{z0p},{z1p})) -> tight(sx,sy,sz)=({sx_t},{sy_t},{sz_t})")

    # Where the tight crop sits inside the full (X, Y, Z) volume
    dst = (
        slice(x0p+xmin, x0p+xmax),
        slice(y0p+ymin, y0p+ymax),
        slice(z0p+z0t, z0p+z1t),
    )

    # Downsample
    tight_lr, (nx,ny,nz) = downsample_xyz(tight)
    cz_lr = nz // 2

    # Normalize per-crop
    p_lo, p_hi = np.percentile(tight_lr, (1, 99.5))
    tight_lr = np.clip((tight_lr - p_lo) / max(1e-6, (p_hi - p_lo)), 0, 1).astype(np.float32)
    mid_lr = tight_lr[:,:,cz_lr]           # indexed [x, y]
    # Build a sparse 'rim' mask directly from intensity (robust to spokes/puncta)
    P_RING = 99.2  # tweak 98.8–99.6 if needed
    ring_lr = (tight_lr >= np.percentile(tight_lr, P_RING)).astype(np.uint8)

    # Seeds: brightest K on mid slice (as ring positives).
    # unravel_index on an [x, y] slice yields (x, y) pairs.
    K = 10
    flat_idx = np.argpartition(mid_lr.ravel(), -K)[-K:]
    ring_seeds_xy = [
        tuple(int(v) for v in np.unravel_index(fi, mid_lr.shape))
        for fi in flat_idx
    ]

    # Background: darkest corner of the mid slice, as (x, y)
    corners_xy = [(0, 0), (0, ny-1), (nx-1, 0), (nx-1, ny-1)]
    bg_xy = min(corners_xy, key=lambda c: mid_lr[c])

    # Map the lumen centre: ROI -> tight crop -> low-res crop
    cx_hr = clipi(cx_hr - xmin, 0, sx_t-1); cy_hr = clipi(cy_hr - ymin, 0, sy_t-1)
    cx_l = clipi(int(round(cx_hr * (nx / max(1, sx_t)))), 0, nx-1)
    cy_l = clipi(int(round(cy_hr * (ny / max(1, sy_t)))), 0, ny-1)

    # z planes for context
    z_ctx = [cz_lr]
    if nz >= 5:
        z_ctx = sorted(set([cz_lr, max(0, cz_lr-2), min(nz-1, cz_lr+2)]))

    # ---- DEBUG OVERLAYS ----
    # All seeds are passed as (x, y) = (row, col) of mid_lr
    save_overlay(mid_lr, ring_seeds_xy, [bg_xy, (cx_l, cy_l)],
                 os.path.join(dbg_dir, f"obj{i:02d}_midLR_seeds.png"))

    # ----------------- OUTER (epithelium region) -----------------
    sess.reset_interactions()
    sess.set_image(tight_lr[None])
    sess.set_target_buffer(torch.zeros((nx, ny, nz), dtype=torch.uint8, device=device))
    # negative bbox & border negatives to prevent flooding
    # NOTE(review): this is a full-volume 3D box — confirm the model accepts 3D boxes.
    sess.add_bbox_interaction([[0, nx], [0, ny], [0, nz]], include_interaction=False)
    for (rx, ry) in ring_seeds_xy:
        sess.add_point_interaction((rx, ry, int(cz_lr)), include_interaction=True)
    sess.add_point_interaction((int(bg_xy[0]), int(bg_xy[1]), int(cz_lr)), include_interaction=False)  # (x,y,z)
    sess.add_point_interaction((int(cx_l), int(cy_l), int(cz_lr)), include_interaction=False)
    add_border_negatives(sess, nx, ny, z_ctx)

    outer_lr = sess.target_buffer.detach().cpu().numpy().astype(np.uint8)
    nonz_outer = int(outer_lr.sum())
    log(f"[Obj {i}] OUTER nz vox (LR) = {nonz_outer}")
    if nonz_outer == 0:
        log(f"[Obj {i}] WARN: OUTER produced empty mask")

    # Upsample model's outer prediction AND the intensity rim, then fuse
    outer_tight_pred = upsample_mask_xyz(outer_lr, (sx_t, sy_t, sz_t))
    ring_tight       = upsample_bool_xyz(ring_lr, (sx_t, sy_t, sz_t))

    # Use (prediction ∪ rim) as evidence, then convex-hull + fill to get a ball
    # (both inputs are uint8 {0,1}, so the bitwise OR is already a valid mask)
    outer_tight = fill_ball_from_ring(outer_tight_pred | ring_tight)
    outer_tight = solidify_2d_slices(outer_tight)
    outer_full = np.zeros((X,Y,Z), dtype=np.uint8)
    outer_full[dst] = outer_tight
    sitk.WriteImage(
        sitk.GetImageFromArray(np.transpose((outer_full*255).astype(np.uint8),(2,1,0))),
        f"{out_dir}/{base}_outer_{i:02d}.tif",
        True
    )

    # ----------------- INNER (lumen region, merged) -----------------
    sess.reset_interactions()
    sess.set_image(tight_lr[None])
    sess.set_target_buffer(torch.zeros((nx, ny, nz), dtype=torch.uint8, device=device))
    sess.add_bbox_interaction([[0, nx], [0, ny], [0, nz]], include_interaction=False)
    # positives: lumen center
    sess.add_point_interaction((int(cx_l), int(cy_l), int(cz_lr)), include_interaction=True)
    # negatives: ring pixels
    for (rx, ry) in ring_seeds_xy:
        sess.add_point_interaction((rx, ry, int(cz_lr)), include_interaction=False)
    # NOTE(review): add_border_negatives also places a negative at the crop
    # centre (nx//2, ny//2) on every z in z_ctx, which can land on or next to
    # the positive lumen-centre click above — confirm this is intended.
    add_border_negatives(sess, nx, ny, z_ctx)

    inner_lr = sess.target_buffer.detach().cpu().numpy().astype(np.uint8)
    nonz_inner = int(inner_lr.sum())
    log(f"[Obj {i}] INNER nz vox (LR) = {nonz_inner}")
    if nonz_inner == 0:
        log(f"[Obj {i}] WARN: INNER produced empty mask")

    inner_tight = upsample_mask_xyz(inner_lr, (sx_t, sy_t, sz_t))
    inner_tight = fill_ball_from_ring(inner_tight)
    # Merge with lumens already written (crops of different objects may overlap)
    inner_all[dst] = np.maximum(inner_all[dst], inner_tight)

    # Debug save mid-slice predictions
    midZ_full = (z0p+z0t) + (sz_t//2)
    Image.fromarray((outer_full[:,:,midZ_full]*255).astype(np.uint8)).save(
        os.path.join(dbg_dir, f"obj{i:02d}_outer_midZ_{midZ_full:03d}.png"))
    Image.fromarray((inner_tight[:,:,sz_t//2]*255).astype(np.uint8)).save(
        os.path.join(dbg_dir, f"obj{i:02d}_inner_tight_mid.png"))

    torch.cuda.empty_cache()

# Save combined INNER once
sitk.WriteImage(
    sitk.GetImageFromArray(np.transpose((inner_all*255).astype(np.uint8),(2,1,0))),
    f"{out_dir}/{base}_inner_all.tif",
    True
)
print("Saved outputs to:", out_dir)
