### Create 256x256 patches from the images in the refined folder.

In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import sys
import subprocess
import json

# -------------------------------------------------------------
# 1) Environment Setup
# -------------------------------------------------------------
# Make the local 'segment-anything' checkout importable. The path is relative,
# so it resolves against the current working directory.
sys.path.append("../third_party/segment-anything/")

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# -------------------------------------------------------------
# 2) Load or Initialize Annotations from JSON
# -------------------------------------------------------------
# boxes_dict is looked up by image file name further down; it stays empty
# when no annotation file has been written yet.
# NOTE(review): the file name is "boxe.json" - confirm this is not a typo
# for "boxes.json".
annotations_path = "../data/annotations/boxe.json"
boxes_dict = {}
if os.path.exists(annotations_path):
    with open(annotations_path, "r") as f:
        boxes_dict = json.load(f)

# -------------------------------------------------------------
# 2.1) Define path to create_boxes.py
# -------------------------------------------------------------
# Script that is launched as a subprocess to annotate missing stones.
create_boxes_script = "../scripts/create_boxes.py"

# -------------------------------------------------------------
# 3) Define build_totalmask() function
# -------------------------------------------------------------
def build_totalmask(pred) -> "np.ndarray | None":
    """
    Builds a binary mask from SAM predictions.
    Stones = white (255), mortar = black (0).
    Fills small holes using morphological closing.

    Args:
        pred: Sequence of mask records. Each record's 'segmentation' entry
            must be a 2-D array; it is cast to uint8 before being summed
            (presumably a boolean mask as produced by
            SamAutomaticMaskGenerator.generate - confirm for other callers).

    Returns:
        A uint8 array with the height/width of the first segmentation,
        or None when `pred` is empty.
    """
    if len(pred) == 0:
        return None

    height, width = pred[0]['segmentation'].shape

    # Count overlaps in a wide integer type. Summing directly into uint8
    # wraps around to 0 once 256 masks cover the same pixel, which would
    # turn a heavily covered pixel into background.
    overlap_count = np.zeros((height, width), dtype=np.uint32)
    for seg in pred:
        overlap_count += seg['segmentation'].astype(np.uint8)

    # Saturate at 255 so the result fits the 8-bit input the threshold needs.
    total_mask = np.minimum(overlap_count, 255).astype(np.uint8)

    # Convert summed mask to binary with Otsu threshold.
    # NOTE(review): Otsu picks a data-dependent cut on the overlap counts, so
    # pixels covered by only one mask may end up black when overlaps are
    # common - confirm this is intended rather than a plain "> 0" test.
    _, total_mask_bin = cv2.threshold(total_mask, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)

    # Fill small holes with morphological closing
    kernel = np.ones((2, 2), np.uint8)
    total_mask_bin = cv2.morphologyEx(total_mask_bin, cv2.MORPH_CLOSE, kernel)

    return total_mask_bin

# -------------------------------------------------------------
# 4) Load the SAM Model and Create Generators/Predictors
# -------------------------------------------------------------
# The registry key must match the architecture stored in the checkpoint file.
model_type = "vit_h"
sam_checkpoint = "../models/sam_vit_h_4b8939.pth"
device = "cpu"  # Change to "cuda" if you have a GPU

print("Loading SAM model on device:", device)
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device)

# Prompt-free generator: used to get the automatic masks for each image.
mask_generator = SamAutomaticMaskGenerator(
    sam,
    points_per_side=32,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    pred_iou_thresh=0.86,
    stability_score_thresh=0.92,
    min_mask_region_area=100,
)

# Prompted predictor: used with the manually annotated bounding boxes.
mask_predictor = SamPredictor(sam)

# -------------------------------------------------------------
# 5) Process Images in the Refined Folder and Extract Patches
# -------------------------------------------------------------
refined_folder = "../data/refined/img-stones"
output_dir_img = "../data/patches/images"
output_dir_mask = "../data/patches/masks"
os.makedirs(output_dir_img, exist_ok=True)
os.makedirs(output_dir_mask, exist_ok=True)

patch_size = 256   # 256x256 patches
stride = 32        # 32 pixel stride (not too low because we want different patches)
stone_threshold = 0.1  # At least 10% of patch must be stone

# os.listdir returns entries in arbitrary order. Sorting keeps the
# patch_count-based output file names reproducible between runs.
# The name is lower-cased before the suffix test, so only lower-case
# suffixes are needed to match ".JPG", ".PNG", etc.
image_files = sorted(
    f for f in os.listdir(refined_folder)
    if f.lower().endswith((".jpg", ".jpeg", ".png"))
)

# Running index used to name the saved patch files across all images.
patch_count = 0

# Upper bound on how many patches are drawn per image. The preview figure is
# 4 inches wide per patch, and with a small stride one image can yield
# hundreds of patches, which makes the figure too large to render.
max_display_patches = 8

# Working file exchanged with the annotation script (presumably written by
# create_boxes.py into the current working directory - verify against it).
temp_boxes_path = "temp_boxes.json"

for img_file in image_files:
    img_path = os.path.join(refined_folder, img_file)
    image_bgr = cv2.imread(img_path)
    if image_bgr is None:
        print(f"Skipping {img_file} -> Unable to load image.")
        continue

    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    height, width, _ = image_rgb.shape
    print(f"\n--- Processing {img_file} ---")
    print(f"Image path: {img_path}")
    print(f"Image size: {width} x {height}")

    # An image smaller than one patch can never produce output, so skip it
    # before spending time on mask generation and annotation prompts.
    if height < patch_size or width < patch_size:
        print(f"  {img_file} is too small for {patch_size}x{patch_size} patches; skipping.")
        continue

    # 5A) Generate Automatic Masks with SAM
    masks_auto = mask_generator.generate(image_rgb)
    print(f"  Number of auto masks: {len(masks_auto)}")
    if len(masks_auto) == 0:
        print("  No auto masks found; skipping.")
        continue

    auto_mask_bin = build_totalmask(masks_auto)
    if auto_mask_bin is None:
        print("  auto_mask_bin is None; skipping.")
        continue

    # 5B) Display the image and its auto mask for inspection
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(image_rgb)
    plt.title("Original Image")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(auto_mask_bin, cmap="gray")
    plt.title("Automatic Mask (Stones=White, BG=Black)")
    plt.axis("off")
    plt.show()

    # 5C) Check if annotation for missing stones exists
    if img_file not in boxes_dict:
        user_response = input(f"Annotation for {img_file} not found. Annotate missing stones? (y/n): ")
        if user_response.strip().lower().startswith("y"):
            # Remove a leftover file from an earlier, interrupted run so that
            # old boxes cannot be attributed to the current image.
            if os.path.exists(temp_boxes_path):
                os.remove(temp_boxes_path)

            # Run the script with the interpreter executing this code, so it
            # sees the same installed packages; fall back to "python" if the
            # interpreter path is not available.
            subprocess.run([sys.executable or "python", create_boxes_script, img_path])

            if os.path.exists(temp_boxes_path):
                with open(temp_boxes_path, "r") as f:
                    manual_boxes = json.load(f)
                # Update the boxes_dict for the current image
                boxes_dict[img_file] = manual_boxes
                # Save the updated boxes_dict to the annotations file; the
                # folder may not exist yet on the first save.
                os.makedirs(os.path.dirname(annotations_path) or ".", exist_ok=True)
                with open(annotations_path, "w") as f:
                    json.dump(boxes_dict, f, indent=2)
                # Remove the temporary file
                os.remove(temp_boxes_path)
            else:
                print("Annotation file not updated; proceeding without manual boxes.")

    # 5D) Merge bounding box masks if they exist
    manual_boxes = boxes_dict.get(img_file)
    if manual_boxes:
        mask_predictor.set_image(image_rgb)
        input_boxes = torch.tensor(manual_boxes, device=mask_predictor.device)
        transformed_boxes = mask_predictor.transform.apply_boxes_torch(
            input_boxes, image_rgb.shape[:2]
        )
        masks_box, _, _ = mask_predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False
        )
        # Drop the per-box mask axis, then OR every box mask into the
        # automatic mask.
        masks_box = masks_box.squeeze(1).cpu().numpy()
        combined_mask_bool = np.logical_or(
            auto_mask_bin.astype(bool), masks_box.astype(bool).any(axis=0)
        )
        final_mask_bin = combined_mask_bool.astype(np.uint8) * 255
    else:
        # No entry, or an empty box list: the automatic mask is used as-is.
        final_mask_bin = auto_mask_bin

    # 5E) Verify that the final mask contains stone areas
    nz_count = cv2.countNonZero(final_mask_bin)
    print(f"  Non-zero pixels in final mask: {nz_count}")
    if nz_count == 0:
        print("  Final mask is all black; skipping.")
        continue

    # 5F) Extract 256x256 patches from the final mask and original image
    # Only the first few patches are kept for display; all are saved to disk.
    patch_imgs_list = []
    patch_masks_list = []
    patches_from_image = 0

    # A patch is kept only when strictly more than this many pixels are stone.
    min_stone_pixels = stone_threshold * (patch_size * patch_size)

    for y in range(0, height - patch_size + 1, stride):
        for x in range(0, width - patch_size + 1, stride):
            patch_mask = final_mask_bin[y:y+patch_size, x:x+patch_size]
            patch_img = image_rgb[y:y+patch_size, x:x+patch_size]

            # Count how many stone pixels are in this patch
            if cv2.countNonZero(patch_mask) <= min_stone_pixels:
                continue

            # Save patch to disk (convert image back to BGR)
            patch_img_bgr = cv2.cvtColor(patch_img, cv2.COLOR_RGB2BGR)
            patch_img_name = f"patch_img_{patch_count}.png"
            patch_mask_name = f"patch_mask_{patch_count}.png"

            cv2.imwrite(os.path.join(output_dir_img, patch_img_name), patch_img_bgr)
            cv2.imwrite(os.path.join(output_dir_mask, patch_mask_name), patch_mask)

            # Store for visualization (bounded, see max_display_patches)
            if len(patch_imgs_list) < max_display_patches:
                patch_imgs_list.append(patch_img)
                patch_masks_list.append(patch_mask)

            patches_from_image += 1
            patch_count += 1

    print(f"  -> Patches saved from {img_file}: {patches_from_image}")

    # -------------------------------------------------------------
    # Display the valid patches in a 2-row layout:
    #  - Top row: original patches
    #  - Bottom row: mask patches (inverted so stones=black, BG=white)
    # -------------------------------------------------------------
    if patch_imgs_list:
        n_show = len(patch_imgs_list)
        if n_show < patches_from_image:
            print(f"  Displaying the first {n_show} of {patches_from_image} patches.")

        # squeeze=False keeps `axes` two-dimensional even for a single
        # column, so axes[row, col] indexing is always valid.
        fig, axes = plt.subplots(
            2, n_show,
            figsize=(4 * n_show, 6),
            squeeze=False
        )

        for col, (shown_img, shown_mask) in enumerate(zip(patch_imgs_list, patch_masks_list)):
            # Top row: original patch
            axes[0, col].imshow(shown_img)
            axes[0, col].axis('off')

            # Bottom row: invert the mask so stones=0(black), BG=255(white)
            inverted_mask = cv2.bitwise_not(shown_mask)
            axes[1, col].imshow(inverted_mask, cmap='gray')
            axes[1, col].axis('off')

        plt.suptitle(f"Patches for {img_file}", fontsize=16)
        plt.tight_layout()
        plt.show()

print("\nDone!")
print("Total patches saved:", patch_count)
