### Delete the patches

In [None]:
import os

def delete_files_with_extension(folders, extension=".png"):
    """Delete every regular file in *folders* whose name ends with *extension*.

    The match is case-insensitive and non-recursive: only the direct children
    of each folder are considered, and sub-directories are never removed.

    Args:
        folders: Iterable of directory paths to clean.
        extension: File-name suffix to delete (compared case-insensitively).

    Returns:
        List of the file paths that were deleted.
    """
    suffix = extension.lower()
    deleted = []
    for folder_path in folders:
        # isdir (not exists): a plain file at this path would make os.listdir
        # raise, so report it like a missing folder instead.
        if not os.path.isdir(folder_path):
            print(f"Folder not found: {folder_path}")
            continue

        for file_name in os.listdir(folder_path):
            file_path = os.path.join(folder_path, file_name)
            # Only regular files: os.remove raises on a directory named "*.png".
            if file_name.lower().endswith(suffix) and os.path.isfile(file_path):
                os.remove(file_path)
                print(f"Deleted: {file_path}")
                deleted.append(file_path)
    return deleted


# Output folders of the patch-extraction cell, which saves its patches as .png.
folders_to_clean = [
    "../data/patches/images",
    "../data/patches/masks",
]

delete_files_with_extension(folders_to_clean, ".png")
print("All .png files have been deleted from specified folders.")

### Segment the refined images with SAM and cut each image (and its mask) into 256x256 patches.

In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import sys
import subprocess
import json

# -------------------------------------------------------------
# 1) Environment Setup
# -------------------------------------------------------------
# Make the local Segment Anything checkout under ../third_party importable.
# The path is relative, so it resolves against the current working directory
# of the notebook/kernel, not against this file's location.
sys.path.append("../third_party/segment-anything/")

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# -------------------------------------------------------------
# 2) Load or Initialize Annotations from JSON
# -------------------------------------------------------------
# Manual boxes for stones the automatic SAM pass misses, keyed by image file
# name (looked up per image in step 5). No file yet -> no manual annotations.
annotations_path = "../data/annotations/boxes.json"
if os.path.exists(annotations_path):
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(annotations_path, "r", encoding="utf-8") as f:
        boxes_dict = json.load(f)
else:
    boxes_dict = {}  # Start with an empty dictionary if no file exists

# -------------------------------------------------------------
# 3) Define build_totalmask() function
# -------------------------------------------------------------
def build_totalmask(pred) -> np.ndarray:
    """
    Builds a binary mask from SAM predictions.
    Stones = white (255), mortar = black (0).
    Fills small holes using morphological closing.

    Input:
      - pred: list of dicts (output of SAM segmentation); each dict must have
        a 'segmentation' entry holding a 2-D mask, and all masks must share
        the same shape.
    Output:
      - Binary mask (np.uint8, values 0/255) or None if no masks generated.
    """
    if not pred:
        return None

    height, width = pred[0]['segmentation'].shape

    # Count how many masks cover each pixel. Accumulate in int32: summing
    # straight into uint8 silently wraps around once a pixel is covered by
    # 256+ masks, turning heavily covered stone pixels back into 0.
    coverage = np.zeros((height, width), dtype=np.int32)
    for seg in pred:
        coverage += seg['segmentation'].astype(np.int32)

    # OpenCV's Otsu mode needs an 8-bit image: saturate instead of wrapping.
    total_mask = np.clip(coverage, 0, 255).astype(np.uint8)

    # NOTE(review): Otsu chooses the cut-off from the coverage histogram, so when
    # masks overlap it may drop pixels covered by only one mask. If a plain union
    # of all masks is intended, threshold on (coverage > 0) instead — confirm.
    _, total_mask_bin = cv2.threshold(total_mask, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)

    # Small 2x2 closing to fill tiny holes inside the white regions.
    kernel = np.ones((2, 2), np.uint8)
    total_mask_bin = cv2.morphologyEx(total_mask_bin, cv2.MORPH_CLOSE, kernel)

    return total_mask_bin

# -------------------------------------------------------------
# 4) Load the SAM Model and Create Generators/Predictors
# -------------------------------------------------------------
# Registry key and checkpoint file for the ViT-H SAM variant.
model_type = "vit_h"
sam_checkpoint = "../models/sam_vit_h_4b8939.pth"
device = "cpu"  # Switch to "cuda" when a GPU is available

print("Loading SAM model on device:", device)
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Tuning knobs for the automatic mask generator used in step 5A.
auto_mask_settings = {
    "points_per_side": 48,
    "pred_iou_thresh": 0.9,
    "stability_score_thresh": 0.95,
    "crop_n_layers": 1,
    "crop_n_points_downscale_factor": 2,
    "min_mask_region_area": 80,
}
mask_generator = SamAutomaticMaskGenerator(model=sam, **auto_mask_settings)

# Box-prompted predictor (step 5D); it reuses the same loaded `sam` model.
mask_predictor = SamPredictor(sam)

# -------------------------------------------------------------
# 5) Process Images in the Refined Folder and Extract Patches
# -------------------------------------------------------------
# For every image: build a stone mask with SAM (automatic masks, optionally
# refined with manually drawn boxes), then cut the image and its mask into
# patch_size x patch_size tiles and keep the tiles with enough stone.
refined_folder = "../data/refined/img-stones"  # Adjust as needed
output_dir_img = "../data/patches/images"
output_dir_mask = "../data/patches/masks"
os.makedirs(output_dir_img, exist_ok=True)
os.makedirs(output_dir_mask, exist_ok=True)

patch_size = 256
stride = 128  # smaller than patch_size, so neighbouring tiles overlap
stone_threshold = 0.1  # 10% of the patch must be stone

# os.listdir returns entries in arbitrary order; sort so the global patch
# numbering is reproducible across runs and machines.
image_files = sorted(
    f for f in os.listdir(refined_folder)
    if f.lower().endswith((".jpg", ".jpeg", ".png"))
)

patch_count = 0  # global counter -> unique patch file names across all images

for img_file in image_files:
    img_path = os.path.join(refined_folder, img_file)
    image_bgr = cv2.imread(img_path)
    if image_bgr is None:
        print(f"Skipping {img_file} -> Unable to load image.")
        continue

    # cv2.imread yields BGR; everything below works on the RGB copy.
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    height, width, _ = image_rgb.shape
    print(f"\n--- Processing {img_file} ---")
    print(f"Image path: {img_path}")
    print(f"Image size: {width} x {height}")

    # 5A) Generate Automatic Masks with SAM
    masks_auto = mask_generator.generate(image_rgb)
    print(f"  Number of auto masks: {len(masks_auto)}")
    if len(masks_auto) == 0:
        print("  No auto masks found; skipping.")
        continue

    auto_mask_bin = build_totalmask(masks_auto)
    if auto_mask_bin is None:
        print("  auto_mask_bin is None; skipping.")
        continue

    # 5B) Display the image and its auto mask
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(image_rgb)
    plt.title("Original Image")
    plt.axis("off")
    plt.subplot(1, 2, 2)
    plt.imshow(auto_mask_bin, cmap="gray")
    plt.title("Automatic Mask")
    plt.axis("off")
    plt.show()

    # 5C) No manual annotation for this image yet: offer to run the box
    # annotation tool, which is expected to write its boxes to annotations_path.
    if img_file not in boxes_dict:
        user_response = input(f"Annotation for {img_file} not found. Do you want to annotate missing stones? (y/n): ")
        if user_response.lower().startswith("y"):
            # Use the kernel's own interpreter (sys.executable) rather than
            # whatever "python" resolves to on PATH, so the tool runs in the same
            # environment. The script path is relative to the working directory.
            result = subprocess.run([sys.executable, "create-boxes.py", img_path])
            if result.returncode != 0:
                print(f"  create-boxes.py exited with code {result.returncode}.")
            # Reload the JSON file to pick up any boxes the tool wrote.
            if os.path.exists(annotations_path):
                with open(annotations_path, "r", encoding="utf-8") as f:
                    boxes_dict = json.load(f)
            else:
                print("Annotation file not updated; proceeding without manual boxes.")

    # 5D) Refine the automatic mask with one SAM prediction per manual box.
    # NOTE(review): boxes are presumably [x0, y0, x1, y1] in original-image
    # pixel coordinates — confirm against create-boxes.py.
    final_mask_bin = auto_mask_bin
    manual_boxes = boxes_dict.get(img_file, [])
    if len(manual_boxes) > 0:
        mask_predictor.set_image(image_rgb)
        input_boxes = torch.tensor(manual_boxes, device=mask_predictor.device)
        transformed_boxes = mask_predictor.transform.apply_boxes_torch(
            input_boxes, image_rgb.shape[:2]
        )
        masks_box, scores, logits = mask_predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False
        )
        masks_box = masks_box.squeeze(1).cpu().numpy()  # (num_boxes, H, W)
        # Stone wherever the automatic mask OR any box-prompted mask says so.
        combined_mask_bool = np.logical_or(auto_mask_bin.astype(bool), masks_box.any(axis=0))
        final_mask_bin = combined_mask_bool.astype(np.uint8) * 255

    # 5E) Check that the final mask contains stone areas.
    nz_count = cv2.countNonZero(final_mask_bin)
    print(f"  Non-zero pixels in final mask: {nz_count}")
    if nz_count == 0:
        print("  Final mask is all black; skipping.")
        continue

    # 5F) Cut patch_size x patch_size tiles from the image and final mask.
    # Tiles start every `stride` pixels; when (height - patch_size) or
    # (width - patch_size) is not a multiple of stride, the last few
    # rows/columns are never covered by any tile.
    if height < patch_size or width < patch_size:
        print(f"  {img_file} is too small for {patch_size}x{patch_size} patches; skipping.")
        continue

    # A tile is kept only if strictly more than this many pixels are stone.
    min_stone_pixels = stone_threshold * (patch_size * patch_size)
    local_patch_count = 0
    for y in range(0, height - patch_size + 1, stride):
        for x in range(0, width - patch_size + 1, stride):
            patch_mask = final_mask_bin[y:y+patch_size, x:x+patch_size]
            if cv2.countNonZero(patch_mask) <= min_stone_pixels:
                continue  # not enough stone in this tile

            patch_img = image_rgb[y:y+patch_size, x:x+patch_size]
            # Convert back to BGR because cv2.imwrite writes BGR channel order.
            patch_img_bgr = cv2.cvtColor(patch_img, cv2.COLOR_RGB2BGR)
            patch_img_name = f"patch_img_{patch_count}.png"
            patch_mask_name = f"patch_mask_{patch_count}.png"
            cv2.imwrite(os.path.join(output_dir_img, patch_img_name), patch_img_bgr)
            cv2.imwrite(os.path.join(output_dir_mask, patch_mask_name), patch_mask)
            patch_count += 1
            local_patch_count += 1

    print(f"  -> Patches saved from {img_file}: {local_patch_count}")

print("\nDone!")
print("Total patches saved:", patch_count)
