In [1]:
# ========== CONFIGURATION ==========
# This cell sets up device (CUDA/CPU) and asks for file paths

import torch
import os
from tkinter import filedialog
import tkinter as tk

# --- Check for CUDA and set device ---
# Inference runs on the GPU when CUDA is available, otherwise on the CPU;
# `device` is later assigned to cfg.MODEL.DEVICE in the model-loading cell.
if torch.cuda.is_available():
    device = "cuda"
    print("CUDA available. Using GPU for inference.")
    print(f"  GPU: {torch.cuda.get_device_name(0)}")
else:
    device = "cpu"
    print("CUDA not available. Falling back to CPU.")
    print("  Note: Detection will be slower on CPU.")

# --- File pickers ---
# One hidden, always-on-top Tk root owns both dialogs. Passing it as `parent`
# ties each dialog to that root (intended to keep them in front of the notebook
# window). The `finally` destroys the root even when a dialog is cancelled, so
# re-running this cell does not leave orphaned Tk instances behind.
print("\n--- Select Model Weights File ---")
root = tk.Tk()
root.withdraw()
root.attributes('-topmost', True)
try:
    model_weights_path = filedialog.askopenfilename(
        parent=root,
        title="Select Model Weights (.pth file)",
        filetypes=[("PyTorch Model", "*.pth"), ("All Files", "*.*")],
        initialdir=os.getcwd()
    )
    # A cancelled dialog returns an empty (falsy) value.
    if not model_weights_path:
        raise ValueError("No model weights selected. Please run this cell again.")

    print(f"Selected weights: {model_weights_path}")

    # --- Get image path ---
    print("\n--- Select Input Image ---")
    image_path = filedialog.askopenfilename(
        parent=root,
        title="Select Input Image",
        filetypes=[
            ("TIFF files", "*.tif *.tiff"),
            ("Image files", "*.png *.jpg *.jpeg"),
            ("All Files", "*.*")
        ],
        initialdir=os.getcwd()
    )

    if not image_path:
        raise ValueError("No image selected. Please run this cell again.")

    print(f"Selected image: {image_path}")
finally:
    root.destroy()

print("\n" + "="*60)
print("Configuration complete")
print(f"Device: {device}")
print(f"Weights: {os.path.basename(model_weights_path)}")
print(f"Image: {os.path.basename(image_path)}")
print("="*60)

CUDA not available. Falling back to CPU.
  Note: Detection will be slower on CPU.

--- Select Model Weights File ---
Selected weights: C:/Users/danny/OneDrive - University of Arkansas/Research/Projects/BubbleID/models/instance_segmentation_weights.pth

--- Select Input Image ---


ValueError: No image selected. Please run this cell again.

In [None]:
# ------------- LOADING MODEL ----------------------
import cv2
import os
import matplotlib.pyplot as plt
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
from detectron2 import model_zoo
import numpy as np
import torch
from detectron2.data import MetadataCatalog, DatasetCatalog
from scipy.ndimage import distance_transform_edt, maximum_filter, label as ndi_label
from skimage.measure import regionprops
from skimage.segmentation import watershed

# Build a Mask R-CNN R50-FPN config from the detectron2 model zoo, adapted for
# single-class inference with the weights chosen in the configuration cell.
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1

# Training hyperparameters. NOTE(review): these should not affect inference via
# DefaultPredictor; they presumably record how the weights were trained — confirm.
cfg.OUTPUT_DIR = "./"
cfg.DATALOADER.NUM_WORKERS = 2
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025
cfg.SOLVER.MAX_ITER = 1000
cfg.SOLVER.STEPS = []
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 256

# Use model weights path from configuration cell. Checked up front so a moved or
# deleted file fails with a clear message before the predictor is built.
if not os.path.isfile(model_weights_path):
    raise FileNotFoundError(f"Model weights not found: {model_weights_path}")
cfg.MODEL.WEIGHTS = model_weights_path
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.4  # detections scoring below 0.4 are discarded

# Set device from configuration cell
cfg.MODEL.DEVICE = device

print(f"Loading model on {device}...")
predictor = DefaultPredictor(cfg)
print("✓ Model loaded successfully")

# --------------------LOAD IMAGE --------------------------
# Use image_path from configuration cell.
# With its default flag, cv2.imread returns a 3-channel BGR array. It returns
# None (rather than raising) when the file cannot be read, hence the check below.
print(f"\nLoading image: {image_path}")
image = cv2.imread(image_path)

if image is None:
    raise ValueError(f"Could not load image: {image_path}")

print(f"✓ Image loaded Shape: {image.shape}")

import tifffile

# --- Define crop box (modify these if you want to crop) ---
# Pixel coordinates in the loaded image. Array rows grow downward, so
# (x_min, y_min) is the TOP-LEFT corner and (x_max, y_max) the BOTTOM-RIGHT
# corner; the max values are exclusive slice ends.
# Set all four to None to skip cropping.
x_min, y_min = None, None  # TOP LEFT (x, y)
x_max, y_max = None, None  # BOTTOM RIGHT (x, y)

crop_values = (x_min, y_min, x_max, y_max)
if all(v is not None for v in crop_values):
    img_h, img_w = image.shape[:2]
    # Negative indices would silently wrap and oversized ones would silently clip.
    if not (0 <= x_min < x_max <= img_w and 0 <= y_min < y_max <= img_h):
        raise ValueError(
            f"Crop box x:{x_min}-{x_max}, y:{y_min}-{y_max} is empty or outside "
            f"the {img_w}x{img_h} image."
        )
    print(f"\nCropping: x from {x_min} to {x_max}, y from {y_min} to {y_max}")
    cropped_image = image[y_min:y_max, x_min:x_max]
    print(f"Cropped shape: {cropped_image.shape}")

    # Save the cropped image next to the input image
    input_dir = os.path.dirname(image_path)
    base_name = os.path.splitext(os.path.basename(image_path))[0]
    cropped_img_path = os.path.join(input_dir, f"{base_name}_cropped.tif")
    # cv2.imwrite signals failure by returning False, not by raising.
    if not cv2.imwrite(cropped_img_path, cropped_image):
        raise OSError(f"Failed to write cropped image: {cropped_img_path}")
    print(f"Saved cropped image: {cropped_img_path}")

    image = cropped_image
elif any(v is not None for v in crop_values):
    # A partial box would skip the crop here, while the FOV cell (which only
    # checks x_min/y_min) would still treat coordinates as cropped.
    raise ValueError("Set all of x_min, y_min, x_max, y_max to crop, or leave all four as None.")
else:
    print("\nSkipping crop - using full image")

# ------------------ MULTI-PASS TILED DETECTION ------------------------------
# The image is split into tiles, and each tile is run through the predictor on
# its own. Detections whose box comes within EDGE_MARGIN_PX of a tile edge are
# dropped, because they are likely cut by the tile. To compensate, each grid is
# run four times: unshifted, shifted half a tile right, half a tile down, and
# both. This is done for a fine 6x8 grid and a coarse 3x4 grid. The eight
# per-pass boolean masks are then OR-ed into `full_mask`.
FINE_GRID = (6, 8)      # (rows, cols) for passes 1-4
COARSE_GRID = (3, 4)    # (rows, cols) for passes 5-8
EDGE_MARGIN_PX = 10     # drop boxes closer than this to any tile edge
MAX_AREA_DIVISOR = 3    # keep masks with fewer than (tile_h * tile_w) // 3 pixels

# (shift_down, shift_right) combinations, in pass order.
_SHIFTS = ((False, False), (False, True), (True, False), (True, True))
_SHIFT_SUFFIX = {
    (False, False): "",
    (False, True): ", shift right",
    (True, False): ", shift down",
    (True, True): ", shift both",
}


def _tile_mask(predictor, tile, edge_margin=EDGE_MARGIN_PX, max_area_divisor=MAX_AREA_DIVISOR):
    """Run `predictor` on one tile and return the union of its accepted masks.

    A detection is accepted when its box lies at least `edge_margin` px inside
    every tile edge, and its mask has fewer than
    ``(tile_h * tile_w) // max_area_divisor`` pixels.

    Returns:
        Boolean numpy array with the tile's (height, width).
    """
    tile_h, tile_w = tile.shape[:2]
    instances = predictor(tile)["instances"]
    masks = instances.pred_masks.cpu().to(torch.bool)
    boxes = instances.pred_boxes.tensor.cpu()  # columns: x0, y0, x1, y1

    inside = (
        (boxes[:, 0] >= edge_margin)
        & (boxes[:, 1] >= edge_margin)
        & (boxes[:, 2] <= tile_w - edge_margin)
        & (boxes[:, 3] <= tile_h - edge_margin)
    )
    small = masks.flatten(start_dim=1).sum(dim=1) < (tile_h * tile_w) // max_area_divisor
    selected = masks[inside & small]
    if len(selected) == 0:
        return np.zeros((tile_h, tile_w), dtype=bool)
    return selected.any(dim=0).numpy()


def _detect_pass(predictor, image, n_rows, n_cols, shift_down=False, shift_right=False):
    """Tile `image` into an n_rows x n_cols grid and run detection on every tile.

    A shifted pass offsets the grid by half a tile down and/or right, and drops
    the last tile row/column so that every tile stays inside the image. Pixels
    not covered by any tile stay False. These are the half-tile margins of
    shifted passes, plus the remainder when the image size is not divisible by
    the grid.

    Returns:
        Boolean numpy array with the same height and width as `image`.
    """
    img_h, img_w = image.shape[:2]
    tile_h, tile_w = img_h // n_rows, img_w // n_cols
    off_y = tile_h // 2 if shift_down else 0
    off_x = tile_w // 2 if shift_right else 0

    pass_mask = np.zeros((img_h, img_w), dtype=bool)
    for k in range(n_rows - int(shift_down)):
        for j in range(n_cols - int(shift_right)):
            y0 = k * tile_h + off_y
            x0 = j * tile_w + off_x
            window = (slice(y0, y0 + tile_h), slice(x0, x0 + tile_w))
            pass_mask[window] = _tile_mask(predictor, image[window])
    return pass_mask


detection_passes = [
    (grid, shift_down, shift_right)
    for grid in (FINE_GRID, COARSE_GRID)
    for shift_down, shift_right in _SHIFTS
]

pass_masks = []
for pass_no, ((n_rows, n_cols), shift_down, shift_right) in enumerate(detection_passes, start=1):
    lead = "\n" if pass_no == 1 else ""
    grid_desc = f"{n_rows - int(shift_down)}x{n_cols - int(shift_right)} grid"
    print(f"{lead}Running detection (Pass {pass_no}/{len(detection_passes)}: "
          f"{grid_desc}{_SHIFT_SUFFIX[(shift_down, shift_right)]})...")
    pass_masks.append(_detect_pass(predictor, image, n_rows, n_cols, shift_down, shift_right))

# Keep each pass available for inspection under its original name.
(full_image, full_image1, full_image2, full_image3,
 full_image4, full_image5, full_image6, full_image7) = pass_masks

# Combine all 8 passes
print("\nCombining all 8 detection passes...")
full_mask = np.logical_or.reduce(pass_masks)
plt.imshow(full_mask)

# ---- SAVE the full_mask with custom colormap ----
# Purple background and yellow objects are built in RGB, then converted to BGR for cv2.imwrite.
is_object = np.asarray(full_mask, dtype=bool)
mask_rgb = np.zeros((*is_object.shape, 3), dtype=np.uint8)
mask_rgb[~is_object] = [128, 0, 128]  # Purple background
mask_rgb[is_object] = [255, 255, 0]   # Yellow objects

input_dir = os.path.dirname(image_path)
base_name = os.path.splitext(os.path.basename(image_path))[0]
fullmask_path = os.path.join(input_dir, f"{base_name}_fullmask.png")

# cv2.imwrite reports failure by returning False rather than raising.
if not cv2.imwrite(fullmask_path, cv2.cvtColor(mask_rgb, cv2.COLOR_RGB2BGR)):
    raise OSError(f"Failed to write combined mask: {fullmask_path}")
print(f"✓ Saved combined mask: {fullmask_path}")

# -------------- MARKER-BASED WATERSHED SPLITTING --------------
# Each connected cluster of foreground pixels is split into individual droplets.
# Peaks of the Euclidean distance transform become watershed markers, and a
# cluster with at most one peak is kept whole.
print("\nApplying watershed segmentation...")
PEAK_WINDOW_PX = 15  # maximum_filter window used to find distance-transform peaks

mask_uint8 = (full_mask > 0).astype(np.uint8) * 255
contour_data = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
# OpenCV 3 returns (image, contours, hierarchy); OpenCV 4 returns (contours, hierarchy).
clusters = contour_data[0] if len(contour_data) == 2 else contour_data[1]

split_label_mask = np.zeros_like(mask_uint8, dtype=np.int32)
current_label = 1
mask_h, mask_w = mask_uint8.shape

for cluster in clusters:
    # Work only inside the cluster's bounding box, grown by a 1-px background
    # ring (clipped at the image border), instead of allocating full-image
    # arrays for every cluster. Everything outside the cluster is 0, so the ring
    # keeps the distance transform and peak search the same as a full-image pass.
    bx, by, bw, bh = cv2.boundingRect(cluster)
    y0, y1 = max(by - 1, 0), min(by + bh + 1, mask_h)
    x0, x1 = max(bx - 1, 0), min(bx + bw + 1, mask_w)

    cluster_mask = np.zeros((y1 - y0, x1 - x0), dtype=np.uint8)
    # Filled contour (thickness -1), shifted into the local window.
    cv2.drawContours(cluster_mask, [cluster], -1, 255, -1, offset=(-x0, -y0))
    cluster_bool = cluster_mask > 0
    dist = distance_transform_edt(cluster_bool)
    local_max = (dist == maximum_filter(dist, size=PEAK_WINDOW_PX)) & (dist > 0)
    markers_labeled, n_markers = ndi_label(local_max.astype(np.uint8))

    label_window = split_label_mask[y0:y1, x0:x1]  # a view: writes land in split_label_mask
    if n_markers <= 1:
        label_window[cluster_bool] = current_label
        current_label += 1
        continue
    labels = watershed(-dist, markers=markers_labeled, mask=cluster_bool)
    for sublabel in range(1, labels.max() + 1):
        label_window[labels == sublabel] = current_label
        current_label += 1

# -------------- REGIONPROPS & SAVE RESULTS --------------
print("Calculating region properties...")
regions = regionprops(split_label_mask)
# Diameter of the circle whose area equals each labeled region's pixel area.
diameters = [region.equivalent_diameter for region in regions]

print(f"✓ Detected {len(regions)} droplets")
if diameters:
    print(f"  Mean diameter: {np.mean(diameters):.2f} pixels")
    print(f"  Std diameter: {np.std(diameters):.2f} pixels")
else:
    # np.mean/np.std of an empty list emit a RuntimeWarning and return nan.
    print("  No droplets detected - diameter statistics unavailable.")

# Output file paths (written next to the input image)
# NOTE(review): tif_path is defined but nothing in this notebook writes to it.
tif_path = os.path.join(input_dir, f"{base_name}_mask.tif")
circles_img_path = os.path.join(input_dir, f"{base_name}_circles.png")
diameter_txt_path = os.path.join(input_dir, f"{base_name}_diameters.txt")

# --- CREATE & SAVE CIRCLE-FITTING/OVERLAY IMAGE ---
# Draw an equal-area circle at each region centroid over the binary label mask.
print("\nGenerating visualizations...")
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(split_label_mask > 0, cmap='gray')
for region in regions:
    y, x = region.centroid  # regionprops centroids are (row, col)
    ax.add_patch(plt.Circle((x, y), region.equivalent_diameter / 2,
                            edgecolor='red', facecolor='none', linewidth=2))
ax.set_axis_off()
fig.tight_layout()
fig.savefig(circles_img_path, bbox_inches='tight', pad_inches=0)
plt.close(fig)  # keep this figure out of the inline output and free its memory

# --- SAVE DIAMETERS AS TEXT FILE ---
# One equivalent diameter (pixels) per line, in the same order as `regions`.
np.savetxt(diameter_txt_path, diameters)

print(f"\n{'='*60}")
print("PROCESSING COMPLETE")
print(f"{'='*60}")
print(f"✓ Mask (PNG): {fullmask_path}")
print(f"✓ Overlay with circles (PNG): {circles_img_path}")
print(f"✓ Droplet diameters (TXT): {diameter_txt_path}")
print(f"{'='*60}")

In [None]:
# Check if cropping was done - skip FOV processing if not
if x_min is None or y_min is None:
    print("Skipping FOV processing - no crop was applied to the image.")
    print("To enable FOV processing, set x_min, y_min, x_max, y_max values in the detection cell.")
else:
    import numpy as np
    import cv2
    import tifffile
    import os
    import matplotlib.pyplot as plt

    # --- USER-SPECIFIED FOV, in pixel coordinates of the ORIGINAL (uncropped) image ---
    fov_xmin_orig, fov_xmax_orig = 839, 955
    fov_ymin_orig, fov_ymax_orig = 1460, 1579

    # --- Convert to cropped image coordinates ---
    # `image`, `regions` and `split_label_mask` are all in the cropped frame.
    # The origin of that frame is (x_min, y_min) from the detection cell; do
    # NOT redefine it here.
    fov_xmin = fov_xmin_orig - x_min
    fov_xmax = fov_xmax_orig - x_min
    fov_ymin = fov_ymin_orig - y_min
    fov_ymax = fov_ymax_orig - y_min

    print(f"Working with FOV on cropped image: x:{fov_xmin}:{fov_xmax}, y:{fov_ymin}:{fov_ymax}")
    print(f"Image shape: {image.shape}")

    # --- Check bounds ---
    H, W = image.shape[:2]
    if not (0 <= fov_xmin < fov_xmax <= W) or not (0 <= fov_ymin < fov_ymax <= H):
        raise ValueError("FOV crop indices are out of bounds! Please check input.")

    # ========== 1. Crop the image in the FOV region ==========
    fov_img = image[fov_ymin:fov_ymax, fov_xmin:fov_xmax]
    # cv2.imread loads color images as BGR, while tifffile and matplotlib treat
    # 3-channel data as RGB. Convert so the saved TIFF and the overlay show true colors.
    if fov_img.ndim == 3 and fov_img.shape[2] == 3:
        fov_img_rgb = cv2.cvtColor(fov_img, cv2.COLOR_BGR2RGB)
    else:
        fov_img_rgb = fov_img
    fov_img_path = os.path.join(input_dir, f"{base_name}_FOV_image.tif")
    tifffile.imwrite(fov_img_path, fov_img_rgb)

    # ========== 2. Find regions whose centroid is inside the FOV ==========
    fov_regions = []
    fov_indices = []
    for idx, region in enumerate(regions):
        y, x = region.centroid
        if fov_ymin <= y < fov_ymax and fov_xmin <= x < fov_xmax:
            fov_regions.append(region)
            fov_indices.append(idx)

    # Select the chosen regions' pixels by label. The labels live in
    # split_label_mask; full_mask is only a boolean union, so comparing it with
    # region.label matched nothing for labels above 1.
    fov_mask = np.isin(split_label_mask, [region.label for region in fov_regions])

    # Crop the FOV mask to the FOV region
    cropped_fov_mask = fov_mask[fov_ymin:fov_ymax, fov_xmin:fov_xmax]

    # Color FOV mask: purple background, yellow objects
    mask_rgb_fov = np.zeros((*cropped_fov_mask.shape, 3), dtype=np.uint8)
    mask_rgb_fov[~cropped_fov_mask] = [128, 0, 128]
    mask_rgb_fov[cropped_fov_mask] = [255, 255, 0]
    fov_mask_path = os.path.join(input_dir, f"{base_name}_FOV_mask.png")
    # cv2.imwrite reports failure by returning False rather than raising.
    if not cv2.imwrite(fov_mask_path, cv2.cvtColor(mask_rgb_fov, cv2.COLOR_RGB2BGR)):
        raise OSError(f"Failed to write FOV mask: {fov_mask_path}")

    # ========== 3. FOV circles overlay ==========
    fov_circles_path = os.path.join(input_dir, f"{base_name}_FOV_circles.png")
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(cropped_fov_mask, cmap="gray", alpha=0.3)
    ax.imshow(fov_img_rgb, alpha=0.8)
    for region in fov_regions:
        y, x = region.centroid
        # Shift centroids into FOV-local coordinates.
        circ = plt.Circle((x - fov_xmin, y - fov_ymin), region.equivalent_diameter / 2,
                          edgecolor='red', facecolor='none', linewidth=2)
        ax.add_patch(circ)
    ax.set_axis_off()
    fig.tight_layout()
    fig.savefig(fov_circles_path, bbox_inches='tight', pad_inches=0)
    plt.close(fig)

    # ========== 4. Save FOV diameters ==========
    fov_diameters = np.array([region.equivalent_diameter for region in fov_regions])
    fov_diam_path = os.path.join(input_dir, f"{base_name}_FOV_diameters.txt")
    np.savetxt(fov_diam_path, fov_diameters)

    print(f"Saved FOV cropped image (tif): {fov_img_path}")
    print(f"Saved FOV mask (png): {fov_mask_path}")
    print(f"Saved FOV overlay with fitted circles (png): {fov_circles_path}")
    print(f"Saved FOV droplet diameters (txt): {fov_diam_path}")