In [None]:
pip install napari tifffile trimesh

In [3]:
"""
3D Visualization of Neuron Z-Stack Using Napari
-----------------------------------------------
• Loads z-stack TIFF (Z, Y, X)
• Optional downsampling to reduce memory usage
• Opens interactive 3D viewer
"""

import numpy as np
from tifffile import imread
import napari

# ✅ User settings
tiff_path = "/Users/joeyroberts/Desktop/Spines/Processed/processed_3D.tif"   # Update if needed
downsample_z = 2          # Reduce slices by factor
downsample_xy = 2         # Reduce resolution by factor

print("Loading TIFF stack...")
volume = imread(tiff_path)  # Shape: (Z, Y, X)
print("Original volume shape:", volume.shape)

volume_ds = volume[::downsample_z, ::downsample_xy, ::downsample_xy]
print("Downsampled shape:", volume_ds.shape)

# ✅ View in 3D
viewer = napari.view_image(
    volume_ds,
    name="Neuron Z-Stack",
    rendering="attenuated_mip",  # good for neuron imaging
    contrast_limits=[np.min(volume_ds), np.max(volume_ds)]
)

viewer.dims.ndisplay = 3  # Force 3D mode

print("✅ Viewer opened: rotate with right mouse, zoom with scroll")
napari.run()


Loading TIFF stack...
Original volume shape: (60, 1024, 1024)
Downsampled shape: (30, 512, 512)


Use `viewer = napari.Viewer(); viewer.add_image(...)` instead.
  viewer = napari.view_image(


✅ Viewer opened: rotate with right mouse, zoom with scroll


In [12]:
import os
import numpy as np
from tifffile import imread, imsave
from skimage.filters import frangi, threshold_otsu
from skimage.morphology import skeletonize_3d, binary_dilation, remove_small_objects, binary_closing
from skimage.feature import peak_local_max
from scipy.ndimage import distance_transform_edt, gaussian_filter
from sklearn.cluster import DBSCAN
from scipy.spatial import cKDTree
import napari
import warnings
warnings.filterwarnings("ignore")

# =====================================================
# USER SETTINGS — (Adjust these first if needed)
# =====================================================
tiff_path = "/Users/joeyroberts/Desktop/Spines/Processed/processed_3D.tif"

px_xy = 0.096  # microns per pixel (XY)
px_z = 1.0     # microns per slice (Z)

# Expected physical spine sizes (for filtering)
MIN_SPINE_DIAM_UM = 0.25
MAX_SPINE_DIAM_UM = 1.2

# Detection scales (start conservative)
sigma_xy = 1.2       # smoothing / Frangi upper scale (in XY voxels)
sigma_z = 0.4        # smoothing Z scale (in voxels)
spine_min_dist = 4   # minimum voxel distance between peaks (initial)
min_prominence = 0.06  # initial feature threshold (0-1)

# SECONDARY pass: lower threshold to recover faint heads near shaft
SECONDARY_MIN_PROM = 0.02
SECONDARY_MIN_DIST = 2

downsample = (2, 2, 2)  # Z,Y,X downsample factors to speed up

# DBSCAN settings (will be relaxed if no detections)
dbscan_eps_um = 0.35  # in microns
dbscan_min_samples = 1

# clustering to merge detections per-skeleton
NEIGHBOR_DIST_UM = 0.25  # how close along skeleton two detections count as same (um)

# misc
save_dir = "/Users/joeyroberts/Desktop/Spines/Detected"
os.makedirs(save_dir, exist_ok=True)


# =====================================================
# Utility helpers
# =====================================================
def relax_params(params):
    """Slightly relax parameters (called when no candidates)."""
    params['min_prominence'] = max(0.005, params['min_prominence'] * 0.6)
    params['spine_min_dist'] = max(1, int(params['spine_min_dist'] * 0.8))
    params['dbscan_eps_um'] = min(1.0, params['dbscan_eps_um'] * 1.3)
    return params


# =====================================================
# LOAD + DOWNSAMPLE
# =====================================================
print("Loading volume...")
vol = imread(tiff_path).astype(np.float32)
print("Original shape:", vol.shape)

z_ds, y_ds, x_ds = downsample
if tuple(downsample) != (1,1,1):
    vol = vol[::z_ds, ::y_ds, ::x_ds]
    print("Downsampled shape:", vol.shape)

Z, Y, X = vol.shape

# update effective pixel sizes after downsample
eff_px_z = px_z * z_ds
eff_px_y = px_xy * y_ds
eff_px_x = px_xy * x_ds

# =====================================================
# INITIAL PARAMS (we keep them in a dict to relax if needed)
# =====================================================
params = {
    'sigma_xy': sigma_xy,
    'sigma_z': sigma_z,
    'spine_min_dist': spine_min_dist,
    'min_prominence': min_prominence,
    'dbscan_eps_um': dbscan_eps_um,
    'dbscan_min_samples': dbscan_min_samples,
}

print("Effective pixel sizes (um): Z=%.3f, Y=%.3f, X=%.3f" % (eff_px_z, eff_px_y, eff_px_x))
print("Initial params:", params)


# =====================================================
# MAIN PROCESSING (with automatic relaxation loop)
# =====================================================
MAX_TRIES = 4
final_coords = None

for attempt in range(MAX_TRIES):
    print("\n--- Attempt", attempt+1, "with params:", params)

    # 1) Denoise (fast Gaussian)
    print("Denoising with gaussian (sigma_z=%.2f, sigma_xy=%.2f)..." % (params['sigma_z']*2, params['sigma_xy']))
    den = gaussian_filter(vol, sigma=(params['sigma_z']*2, params['sigma_xy'], params['sigma_xy']))
    den = (den - den.min()) / (den.max() - den.min() + 1e-12)

    # 2) Frangi enhancement (tubular structure)
    print("Applying Frangi (scale_range 1 -> %.2f)..." % max(1.0, params['sigma_xy']))
    try:
        vessel = frangi(den, scale_range=(1, max(1.0, params['sigma_xy'])),
                        scale_step=1, alpha=0.7, beta=0.5, gamma=15, black_ridges=False)
    except TypeError:
        vessel = frangi(den)
    vessel = (vessel - vessel.min()) / (vessel.max() - vessel.min() + 1e-12)

    # 3) Dendrite mask
    try:
        th = threshold_otsu(vessel)
    except Exception:
        th = np.percentile(vessel, 50)
    binary = vessel > (th * 1.05)  # slightly less aggressive to keep thin shafts
    binary = remove_small_objects(binary, min_size=150)
    binary = binary_closing(binary, footprint=np.ones((1,3,3)))
    binary = binary_dilation(binary, footprint=np.ones((1,3,3)))

    print("Binary dendrite voxels:", int(binary.sum()))

    # 4) Skeletonize dendrite for grouping
    print("Skeletonizing (for grouping)...")
    sk = skeletonize_3d(binary).astype(np.uint8)
    sk_coords = np.argwhere(sk > 0)
    if sk_coords.size == 0:
        print("Warning: skeleton empty — relaxing and retrying.")
        params = relax_params(params)
        continue
    sk_tree = cKDTree(sk_coords * np.array([eff_px_z, eff_px_y, eff_px_x]))

    # 5) Feature score (curvature * proximity to dendrite * intensity)
    curv = gaussian_filter(vessel, sigma=(params['sigma_z'], params['sigma_xy']*0.7, params['sigma_xy']*0.7))
    dist = distance_transform_edt(~binary)
    local_int = den
    feat = curv * (1.0 / (1.0 + dist)) * (local_int ** 1.0)
    feat = (feat - feat.min()) / (feat.max() - feat.min() + 1e-12)

    # zero out low-intensity areas
    low_mask = den < np.percentile(den, 55)
    feat[low_mask] = 0.0

    print("Feature score max:", float(feat.max()), "mean:", float(feat.mean()))

    # 6) Primary local maxima detection (conservative)
    coords = peak_local_max(
        feat,
        min_distance=max(1, params['spine_min_dist']),
        threshold_abs=params['min_prominence'],
        footprint=np.ones((3, 5, 5))
    )
    print("Primary raw peak candidates:", len(coords))

    # 7) Secondary pass: lower-threshold peaks near dendrite to recover faint heads
    coords_secondary = peak_local_max(
        feat,
        min_distance=max(1, SECONDARY_MIN_DIST),
        threshold_abs=SECONDARY_MIN_PROM,
        footprint=np.ones((3, 5, 5))
    )
    # keep secondary only if they are within some micron distance to dendrite (dist < 2 um)
    if len(coords_secondary) > 0:
        coords_sec_near = []
        for c in coords_secondary:
            zc,yc,xc = c
            if dist[zc,yc,xc] * eff_px_y <= 2.0:  # use approx physical distance (um)
                coords_sec_near.append(c)
        coords_sec_near = np.array(coords_sec_near, dtype=int) if len(coords_sec_near) else np.zeros((0,3),int)
    else:
        coords_sec_near = np.zeros((0,3),int)
    print("Secondary nearby candidates:", len(coords_sec_near))

    # Merge primary + secondary unique coords
    if coords.size == 0:
        all_coords = coords_sec_near
    elif coords_sec_near.size == 0:
        all_coords = coords
    else:
        # stack and unique by integer positions
        all_coords = np.vstack((coords, coords_sec_near))
        all_coords = np.unique(all_coords, axis=0)
    print("All combined candidates:", len(all_coords))

    if len(all_coords) == 0:
        print("No combined candidates — relaxing and retrying...")
        params = relax_params(params)
        continue

    # 8) DBSCAN clustering on um-scaled coords to merge tiny jittered peaks
    uv = all_coords.astype(float)
    uv[:,0] *= eff_px_z; uv[:,1] *= eff_px_y; uv[:,2] *= eff_px_x
    clustering = DBSCAN(eps=params['dbscan_eps_um'], min_samples=params['dbscan_min_samples']).fit(uv)
    labels = clustering.labels_
    keep_mask = labels != -1
    coords_kept = all_coords[keep_mask]
    labels_kept = labels[keep_mask]
    print("After DBSCAN keep:", len(coords_kept))

    if len(coords_kept) == 0:
        print("All removed by DBSCAN — relaxing and retrying...")
        params = relax_params(params)
        continue

    # 9) Merge DBSCAN clusters into centroids, but then group **by nearest skeleton voxel**
    merged = []
    for lab in np.unique(labels_kept):
        pts = coords_kept[labels_kept == lab]
        # choose the highest-feature point in the cluster
        best_idx = None
        best_val = -np.inf
        for p in pts:
            v = feat[tuple(p)]
            if v > best_val:
                best_val = v
                best_idx = p
        merged.append(best_idx)
    merged = np.array(merged).astype(int)
    print("Clusters merged ->", len(merged))

    # build mapping: each merged point -> nearest skeleton index
    if merged.size == 0:
        print("Merged empty — relaxing and retrying...")
        params = relax_params(params)
        continue

    merged_float = merged.astype(float)
    merged_um = merged_float.copy()
    merged_um[:,0] *= eff_px_z; merged_um[:,1] *= eff_px_y; merged_um[:,2] *= eff_px_x
    dists, idxs = sk_tree.query(merged_um, k=1)
    # now group merged points by skeleton index and keep the single best per skeleton neighborhood
    sk_group = {}
    for i, sk_idx in enumerate(idxs):
        sk_group.setdefault(int(sk_idx), []).append(merged[i])

    final_by_skel = []
    for sk_idx, pts in sk_group.items():
        # if multiple merged pts map to same sk voxel, keep the one with larger feat value
        best = None; best_val = -np.inf
        for p in pts:
            val = feat[tuple(p)]
            if val > best_val:
                best_val = val; best = p
        final_by_skel.append(best)
    final_by_skel = np.array(final_by_skel).astype(int)
    print("After grouping by skeleton:", len(final_by_skel))

    # 10) Size & intensity filter around each final candidate
    min_vox_area = np.pi * ((MIN_SPINE_DIAM_UM / 2.0) / eff_px_y) ** 2
    max_vox_area = np.pi * ((MAX_SPINE_DIAM_UM / 2.0) / eff_px_y) ** 2
    final = []
    for (z,y,x) in final_by_skel:
        z0 = max(0, z-2); z1 = min(Z, z+3)
        y0 = max(0, y-3); y1 = min(Y, y+4)
        x0 = max(0, x-3); x1 = min(X, x+4)
        window = den[z0:z1, y0:y1, x0:x1]
        local_thr = np.percentile(window, 60)
        area_count = (window > local_thr).sum()
        # accept if area within plausible range (allow some leeway)
        if area_count >= max(1, min_vox_area * 0.5) and area_count <= max_vox_area * 8:
            final.append((z,y,x))
    final = np.array(final)
    print("After size/intensity filtering:", len(final))

    if final.size == 0:
        print("No final candidates after filtering — relaxing and retrying...")
        params = relax_params(params)
        continue

    # success
    final_coords = final
    print("Detection succeeded on attempt", attempt+1)
    break

# end loop

if final_coords is None or len(final_coords) == 0:
    print("WARNING: No spines detected after all relaxation attempts. Try manual parameter tuning or provide a crop for tuning.")
    final_coords = np.zeros((0,3), dtype=int)

# =====================================================
# SAVE OUTPUTS
# =====================================================
print("Saving outputs...")
imsave(os.path.join(save_dir, "vessel_3D.tif"), (vessel * 255).astype(np.uint8))

spine_mask = np.zeros_like(vol, dtype=np.uint8)
for (z,y,x) in final_coords:
    spine_mask[int(z), int(y), int(x)] = 255
imsave(os.path.join(save_dir, "spine_candidates_pruned.tif"), spine_mask)
print("Saved to", save_dir)

# =====================================================
# VISUALIZATION
# =====================================================
print("Launching Napari...")
viewer = napari.Viewer()
viewer.add_image(den, name="Denoised")
viewer.add_image(vessel, name="Dendrite Vesselness", opacity=0.45)
viewer.add_image((binary.astype(np.uint8) * 255), name="DendriteMask", opacity=0.25)
if len(final_coords) > 0:
    viewer.add_points(final_coords, name="Spine Detections", size=4, face_color='yellow')
viewer.dims.ndisplay = 3
napari.run()

print("Done.")


Loading volume...
Original shape: (625, 1024, 1024)
Downsampled shape: (313, 512, 512)
Effective pixel sizes (um): Z=2.000, Y=0.192, X=0.192
Initial params: {'sigma_xy': 1.2, 'sigma_z': 0.4, 'spine_min_dist': 4, 'min_prominence': 0.06, 'dbscan_eps_um': 0.35, 'dbscan_min_samples': 1}

--- Attempt 1 with params: {'sigma_xy': 1.2, 'sigma_z': 0.4, 'spine_min_dist': 4, 'min_prominence': 0.06, 'dbscan_eps_um': 0.35, 'dbscan_min_samples': 1}
Denoising with gaussian (sigma_z=0.80, sigma_xy=1.20)...
Applying Frangi (scale_range 1 -> 1.20)...
Binary dendrite voxels: 198317
Skeletonizing (for grouping)...
Feature score max: 0.9999999999982547 mean: 0.0002873278590336595
Primary raw peak candidates: 672
Secondary nearby candidates: 974
All combined candidates: 974
After DBSCAN keep: 974
Clusters merged -> 974
After grouping by skeleton: 933
After size/intensity filtering: 933
Detection succeeded on attempt 1
Saving outputs...
Saved to /Users/joeyroberts/Desktop/Spines/Detected
Launching Napari...
