In [3]:
import os

# Input stack for the nuclei channel (the file sits in a `405` folder).
# The literal below is only a default: set the NUCLEI_CHANNEL_PATH environment
# variable to point the notebook at another stack without editing this cell.
nucleiChannelPath = os.environ.get(
    "NUCLEI_CHANNEL_PATH",
    '/Users/eliasguan/Desktop/DL_210_data_analysis/Example_Image_analysis/smallworm/below_the_eyes/worm1/405/Image1_405.tif',
)

# Voxel size in (Z, Y, X) order. Only ratios are used downstream (anisotropy),
# so any consistent unit works — presumably nm; TODO confirm against the
# acquisition metadata.
voxel_size = (500, 75, 75)

In [None]:
import random
from pathlib import Path

import numpy as np
# tqdm.auto picks the Jupyter widget renderer when available; plain text bars
# nested inside a notebook leave "[A" cursor codes in the saved output.
from tqdm.auto import tqdm

# Sanity check: the lightweight dependencies import cleanly before the heavy ones.
print("Basic imports OK")

# Heavy dependencies are imported after the check, so a failure here is easy to
# tell apart from a basic environment problem.
import imageio  # for saving snapshots
import torch
from cellpose import models, io

# ----------------------------
# USER INPUTS
# ----------------------------
tile_size = 512               # Y/X tile size
tile_overlap = 64             # overlap in pixels
qc_sample_prob = 0.05         # probability of saving a QC snapshot per tile
qc_seed = 0                   # seed for QC sampling, so re-runs pick the same tiles

# Fail fast on settings that would break the tile grid: the stride
# (tile_size - tile_overlap) must be positive.
if not 0 <= tile_overlap < tile_size:
    raise ValueError(
        f"tile_overlap ({tile_overlap}) must be in [0, tile_size={tile_size})"
    )

# The QC draw below uses the global `random` generator; seed it for reproducibility.
random.seed(qc_seed)

# ----------------------------
# Load image
# ----------------------------
img = io.imread(nucleiChannelPath)
# The rest of the notebook indexes img as (Z, Y, X); reject non-3D input with a
# readable message instead of an opaque tuple-unpacking error.
if img.ndim != 3:
    raise ValueError(
        f"Expected a single-channel 3D stack shaped (Z, Y, X), got shape {img.shape}"
    )
Z, Y, X = img.shape
print("Original image shape:", img.shape)

# ----------------------------
# Compute anisotropy
# ----------------------------
# Ratio of the Z step to the mean in-plane (Y/X) step; passed to model.eval below.
z, y, x = voxel_size
xy_spacing = (y + x) / 2
anisotropy = z / xy_spacing
print("Anisotropy:", anisotropy)

# ----------------------------
# Create results folder
# ----------------------------
# Outputs live next to the input stack: results_dir holds the stitched masks,
# qc_dir (nested inside it) holds the randomly sampled QC tiles.
input_path = Path(nucleiChannelPath)
results_dir = input_path.parent / "cpsam_3D_tiled_results"
qc_dir = results_dir / "QC_tiles"
for out_dir in (results_dir, qc_dir):  # parent first, so qc_dir's parent exists
    out_dir.mkdir(exist_ok=True)

# ----------------------------
# Initialize SAM model
# ----------------------------
model = models.CellposeModel(gpu=True)

# ----------------------------
# Compute tile start positions
# ----------------------------
def _tile_starts(length, size, overlap):
    """Return start offsets of `size`-wide windows that together cover [0, length).

    Neighbouring windows share at least `overlap` pixels. The last window is
    placed flush with the far edge, so every tile is full-size (when
    length >= size) and no tile lies entirely inside its predecessor. A plain
    ``range(0, length, size - overlap)`` can emit such a redundant sliver tile,
    and each tile costs a full 3D Cellpose run.

    Raises:
        ValueError: if overlap >= size (non-positive stride).
    """
    step = size - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than size")
    if length <= size:
        return [0]
    starts = list(range(0, length - size, step))
    starts.append(length - size)  # last tile ends exactly at the image edge
    return starts


y_starts = _tile_starts(Y, tile_size, tile_overlap)
x_starts = _tile_starts(X, tile_size, tile_overlap)

# ----------------------------
# Prepare final mask
# ----------------------------
# Allocate straight from the shape: wrapping `img` in a tensor first would copy
# the whole stack just to read its shape.
final_masks = torch.zeros(img.shape, dtype=torch.int32)
current_max_label = 0

# ----------------------------
# Tile loop with QC
# ----------------------------
for y0 in tqdm(y_starts, desc="Tiles Y"):
    y1 = min(y0 + tile_size, Y)
    for x0 in tqdm(x_starts, desc="Tiles X", leave=False):
        x1 = min(x0 + tile_size, X)

        # Extract 3D tile (full Z, small XY)
        tile = img[:, y0:y1, x0:x1].astype(np.float32)

        # ----------------------------
        # Run 3D SAM on this tile
        # ----------------------------
        masks_tile, flows, styles = model.eval(
            tile,
            do_3D=True,
            channel_axis=None,
            z_axis=0,
            diameter=None,
            anisotropy=anisotropy,
            progress=False  # internal progress off
        )

        # Convert through numpy: int32 leaves room for the global label offset,
        # and torch.tensor() on unsigned numpy dtypes is not supported by every
        # torch version.
        masks_tile = torch.from_numpy(np.asarray(masks_tile, dtype=np.int32))

        # ----------------------------
        # QC: randomly save a tile + its (tile-local) masks
        # ----------------------------
        if random.random() < qc_sample_prob:
            # Filenames record the tile depth and its Y/X start offsets.
            tag = f"Z{tile.shape[0]}_Y{y0}_X{x0}"
            # Save the raw tile in the image's own dtype (no lossy float -> uint16 cast)
            imageio.volsave(qc_dir / f"tile_raw_{tag}.tif", img[:, y0:y1, x0:x1])
            # Save segmentation
            imageio.volsave(qc_dir / f"tile_mask_{tag}.tif", masks_tile.numpy().astype(np.uint16))

        # ----------------------------
        # Merge tile masks into final 3D mask
        # ----------------------------
        # Shift this tile's labels above every label used so far. Only advance
        # the counter when the tile actually produced labels: an empty tile has
        # max 0, and resetting the counter to 0 would make later tiles reuse IDs.
        foreground = masks_tile > 0
        if foreground.any():
            masks_tile[foreground] += current_max_label
            current_max_label = int(masks_tile.max())

        # Element-wise maximum. This tile's labels exceed all earlier ones, so in
        # the overlap the new tile wins wherever it has a nucleus, and earlier
        # labels survive only where it is background. Nuclei crossing a tile
        # boundary can therefore end up split — TODO: match/merge labels across
        # overlaps (e.g. by IoU) for true stitching.
        final_masks[:, y0:y1, x0:x1] = torch.maximum(final_masks[:, y0:y1, x0:x1], masks_tile)

# ----------------------------
# Save final merged 3D masks
# ----------------------------
save_path = results_dir / (input_path.stem + "_SAM3D_tiled")
torch.save(final_masks, save_path)  # torch tensor save (file has no extension)

# uint16 holds labels only up to 65535; casting larger IDs would wrap around and
# silently merge unrelated nuclei, so fall back to uint32 when needed.
max_label = int(final_masks.max())
label_dtype = np.uint16 if max_label <= np.iinfo(np.uint16).max else np.uint32
imageio.volsave(str(save_path) + ".tif", final_masks.numpy().astype(label_dtype))

print("✔ DONE — 3D SAM tiled + stitched masks saved to:", save_path)
print("QC tiles saved to:", qc_dir)



Basic imports OK


100%|███████████████████████████████████████████████████████████████████████| 38/38 [00:00<00:00, 172.61it/s]


Original image shape: (38, 2168, 11136)
Anisotropy: 6.666666666666667


Tiles Y:   0%|                                                                         | 0/5 [00:00<?, ?it/s]
Tiles X:   0%|                                                                        | 0/25 [00:00<?, ?it/s][AMPS does not support 3D post-processing, switching to CPU

Tiles X:   4%|██▎                                                        | 1/25 [40:49<16:19:47, 2449.46s/it][AMPS does not support 3D post-processing, switching to CPU

Tiles X:   8%|████▌                                                    | 2/25 [1:21:47<15:40:48, 2454.30s/it][AMPS does not support 3D post-processing, switching to CPU

Tiles X:  12%|██████▊                                                  | 3/25 [2:02:43<15:00:16, 2455.32s/it][AMPS does not support 3D post-processing, switching to CPU

Tiles X:  16%|█████████                                                | 4/25 [2:43:41<14:19:38, 2456.12s/it][AMPS does not support 3D post-processing, switching to CPU

Tiles X:  20%|███████████▍         