In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2
import supervision as sv
import importlib
import torch
import sys

In [None]:
libraries = ["numpy", "torch", "matplotlib", "cv2", "math", "supervision"]
# Report which of the notebook's dependencies can be imported.
# NOTE: the first cell already imports most of these unconditionally, so a
# missing package will fail there first; this cell is only a summary.
for lib in libraries:
    try:
        importlib.import_module(lib)
    except ImportError:
        print(f"{lib} ❌ NOT Installed")
    else:
        print(f"{lib} ✅ Installed")

In [None]:
def show_anns(anns, ax=None, rng=None, alpha=0.35):
    """Overlay SAM automatic-mask annotations on a matplotlib axes.

    Each annotation is painted in a random RGB colour at ``alpha`` opacity.
    Annotations are painted largest-area first, so smaller masks drawn later
    stay visible on top of larger ones.

    Parameters
    ----------
    anns : list of dict
        Mask records; each must provide ``'area'`` and a 2-D ``'segmentation'``
        mask. All masks are assumed to share the first mask's shape.
    ax : matplotlib Axes, optional
        Axes to draw on. Defaults to the current axes (``plt.gca()``).
    rng : int, numpy.random.Generator or None, optional
        Seed or generator for the mask colours, for reproducible figures.
        ``None`` (default) keeps the previous behaviour of drawing from
        numpy's global random state.
    alpha : float, optional
        Opacity of every coloured mask (default 0.35).

    Returns
    -------
    None. Nothing is drawn when ``anns`` is empty.
    """
    if len(anns) == 0:
        return

    # Largest first: later (smaller) masks overwrite earlier ones in the overlay.
    sorted_anns = sorted(anns, key=lambda a: a['area'], reverse=True)
    if ax is None:
        ax = plt.gca()
    ax.set_autoscale_on(False)

    colour_source = np.random if rng is None else np.random.default_rng(rng)
    height, width = np.shape(sorted_anns[0]['segmentation'])[:2]
    overlay = np.ones((height, width, 4))
    overlay[:, :, 3] = 0  # fully transparent wherever no mask is painted
    for ann in sorted_anns:
        # Cast to bool: a 0/1 integer array would otherwise be treated as
        # fancy row indices instead of a per-pixel mask.
        mask = np.asarray(ann['segmentation'], dtype=bool)
        overlay[mask] = np.concatenate([colour_source.random(3), [alpha]])
    ax.imshow(overlay)

In [None]:
image_path = '../data/refined/img-stones/FSE_35_004.jpg'  # relative to the notebook's working directory

image_bgr = cv2.imread(image_path)
# cv2.imread does not raise on a missing/unreadable file; it returns None,
# which would only surface later as an opaque cvtColor assertion.
if image_bgr is None:
    raise FileNotFoundError(
        f"Could not read image at {image_path!r}; check the path relative to the notebook's working directory."
    )
image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; convert to RGB for display and SAM

plt.figure(figsize=(10,10))
plt.imshow(image)
plt.axis('off')
plt.title("Original Stone Image")
plt.show()

In [None]:
# Make the vendored segment-anything checkout importable. Guard the append so
# re-running this cell does not keep adding duplicate sys.path entries.
sam_repo_path = "../third_party/segment-anything/"
if sam_repo_path not in sys.path:
    sys.path.append(sam_repo_path)
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

sam_checkpoint = "../models/sam_vit_h_4b8939.pth"
model_type = "vit_h"
# Deliberately pinned to CPU. To use a GPU when one is available, set instead:
#   device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
print("Using device:", device)

In [None]:
# Load the SAM backbone from the checkpoint and place it on the chosen device
# (nn.Module.to moves the module in place and returns the same object).
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)

# Automatic (prompt-free) mask generator. min_mask_region_area is not passed,
# so the library default applies.
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    pred_iou_thresh=0.86,
    stability_score_thresh=0.92,
)

In [None]:
# Run SAM automatic mask generation on the full RGB image with the generator
# configured above. The result is a list of per-mask dicts; later cells read
# their 'segmentation' and 'area' entries.
masks = mask_generator.generate(image)

In [None]:
print("Number of masks generated:", len(masks))
if masks:
    print("Keys of the first mask:", masks[0].keys())
else:
    # Avoid an IndexError here; note build_totalmask below needs at least one mask.
    print("No masks generated: check the input image and the generator thresholds.")

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image)
show_anns(masks)  # draws onto the current axes, which is `ax`
ax.axis('off')
ax.set_title("Color-Coded Segmentation (Automatic)")
plt.show()

In [None]:
def build_totalmask(pred, kernel_size=2, show=True):
    """
    Build a binary mask (stone=255/white, mortar=0/black) from SAM predictions.

    The number of masks covering each pixel is thresholded with Otsu's method.
    The result is then optionally cleaned with a morphological close to fill
    small holes.

    Parameters
    ----------
    pred : list of dict
        Mask records, each with a 2-D ``'segmentation'`` array. All masks are
        assumed to share the first mask's shape.
    kernel_size : int or None, default 2
        Side length of the square structuring element used for the closing.
        ``None`` or a value below 1 skips the closing step.
    show : bool, default True
        Display the resulting mask with matplotlib.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array of shape (H, W) containing only 0 and 255.

    Raises
    ------
    ValueError
        If ``pred`` is empty (there is no mask to take the image size from).
    """
    if len(pred) == 0:
        raise ValueError("build_totalmask needs at least one SAM mask; got an empty prediction list.")

    height, width = pred[0]['segmentation'].shape

    # Per-pixel coverage count. Accumulate in a wide integer type: summing
    # into uint8 would silently wrap around past 255 overlapping masks.
    coverage = np.zeros((height, width), dtype=np.int32)
    for seg in pred:
        coverage += np.asarray(seg['segmentation'], dtype=bool)

    # cv2.threshold's Otsu mode is documented for 8-bit single-channel input;
    # saturate at 255 instead of wrapping.
    coverage_u8 = np.minimum(coverage, 255).astype(np.uint8)
    # NOTE(review): Otsu picks the cut from the coverage histogram. If it lands
    # at >= 1, pixels covered by exactly one mask become mortar (black), which
    # is not a plain union. Confirm this is intended.
    _, total_mask_bin = cv2.threshold(coverage_u8, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)

    # Morphological close to fill small holes.
    if kernel_size is not None and kernel_size >= 1:
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        total_mask_bin = cv2.morphologyEx(total_mask_bin, cv2.MORPH_CLOSE, kernel)

    if show:
        plt.figure(figsize=(8,8))
        plt.imshow(total_mask_bin, cmap='gray')
        plt.title("Binary Mask of Stones (Automatic)")
        plt.axis('off')
        plt.show()

    return total_mask_bin

In [None]:
# Collapse all automatic SAM masks into one 0/255 stone mask (build_totalmask also displays it).
final_mask_bin = build_totalmask(masks)

In [None]:
# Original image next to the automatic binary mask
fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(15, 8))

ax_img.imshow(image)
ax_img.set_title("Original Image")
ax_img.axis('off')

ax_mask.imshow(final_mask_bin, cmap='gray')
ax_mask.set_title("Binary Mask of Stones (Auto)")
ax_mask.axis('off')

plt.show()

In [None]:
# Step 7: Bounding Box Prompts for Missing Stones
# We copy the bounding boxes from the "create-boxes.py" script output.
# Paste that list into boxes_raw below.
# (SamPredictor, numpy and torch are already imported in earlier cells.)

# 1) Initialize SamPredictor and set the image once; every box reuses it.
mask_predictor = SamPredictor(sam)
mask_predictor.set_image(image)  # 'image' is the RGB image loaded earlier

# 2) Bounding boxes from the local script output, one (x0, y0, x1, y1) tuple per stone.
#    NOTE(review): SAM box prompts are corner (XYXY) pixel boxes; confirm
#    create-boxes.py emits corners rather than (x, y, w, h).
boxes_raw = [
    (3, 14, 172, 93), (175, 13, 378, 88), (92, 110, 300, 200), (225, 307, 408, 390), (411, 304, 496, 391), (501, 304, 557, 384)
]

# Stack into a single (N, 4) array; reshape keeps an empty list as shape (0, 4).
boxes = np.asarray(boxes_raw, dtype=np.float32).reshape(-1, 4)
print("Boxes from local script:", boxes)

# Reject swapped corners or (x, y, w, h)-style rows before they reach the model.
invalid_rows = (boxes[:, 2] <= boxes[:, 0]) | (boxes[:, 3] <= boxes[:, 1])
if invalid_rows.any():
    raise ValueError(
        f"Boxes must be (x0, y0, x1, y1) with x1 > x0 and y1 > y0; invalid: {boxes[invalid_rows].tolist()}"
    )

if len(boxes) == 0:
    # No manual prompts: an empty (0, H, W) stack leaves the merge step a no-op.
    masks_box = np.zeros((0, *image.shape[:2]), dtype=bool)
    scores = logits = None
else:
    # One tensor on the predictor's device (avoids building it from a list of ndarrays).
    input_boxes = torch.as_tensor(boxes, device=mask_predictor.device)

    # 3) Transform boxes for SAM, given the original (height, width).
    transformed_boxes = mask_predictor.transform.apply_boxes_torch(
        input_boxes,
        image.shape[:2]  # (height, width)
    )

    # 4) Predict one mask per bounding box.
    masks_box, scores, logits = mask_predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False
    )

    # masks_box shape: [num_boxes, 1, H, W] -> NumPy (num_boxes, H, W) for merging
    masks_box = masks_box.squeeze(1).cpu().numpy()

print("Bounding box-based masks shape:", masks_box.shape)
print("Scores:", scores)


In [None]:
# Inspect every box-prompted mask on its own figure
for idx, box_mask in enumerate(masks_box):
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(box_mask, cmap='gray')
    ax.set_title(f"Box Mask {idx}")
    ax.axis('off')
    plt.show()

In [None]:
# Step 8: Merge the Bounding Box Masks with the Automatic Mask

# Union: a pixel is stone if the automatic mask OR any box-prompted mask marks it.
auto_mask_bool = final_mask_bin.astype(bool)
box_union = masks_box.astype(bool).any(axis=0)
final_mask_bool = np.logical_or(auto_mask_bool, box_union)

In [None]:
# Boolean union -> 8-bit image: stone pixels 255, everything else 0.
combined_mask_bin = np.where(final_mask_bool, 255, 0).astype(np.uint8)
# To get the opposite polarity, apply cv2.bitwise_not(combined_mask_bin).


In [None]:
# Original image next to the merged (automatic + box prompt) mask
fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(15, 8))

ax_img.imshow(image)
ax_img.set_title("Original Image")
ax_img.axis('off')

ax_mask.imshow(combined_mask_bin, cmap='gray')
ax_mask.set_title("Final Mask (Auto + Box Prompts)")
ax_mask.axis('off')

plt.show()

# %%
# Completion message; the merged result is held in combined_mask_bin.
print("Done! This final mask includes automatic segmentation plus bounding box corrections.")