### Delete the patches

In [None]:
import os

# Folders whose previously generated patch files should be removed
folders_to_clean = [
    "../data/patches/images",
    "../data/patches/masks",
]

for folder_path in folders_to_clean:
    # A missing path, or one that is not a directory, is reported and skipped
    if not os.path.isdir(folder_path):
        print(f"Folder not found: {folder_path}")
        continue

    for file_name in os.listdir(folder_path):
        # Only .png files are removed (extension check is case-insensitive)
        if not file_name.lower().endswith(".png"):
            continue

        file_path = os.path.join(folder_path, file_name)

        # os.remove cannot delete directories, so skip anything that is
        # not a regular file even if its name ends in .png
        if not os.path.isfile(file_path):
            continue

        os.remove(file_path)
        print(f"Deleted: {file_path}")

print("All .png files have been deleted from specified folders.")

### Create 256x256 patches from the images in the refined folder.

In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import sys

# -------------------------------------------------------------
# 1) Environment Setup
# -------------------------------------------------------------
# Add the path to the Segment Anything library
sys.path.append("../third_party/segment-anything/")

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# -------------------------------------------------------------
# 2) Define build_totalmask() function
# -------------------------------------------------------------
def build_totalmask(pred) -> "np.ndarray | None":
    """Merge SAM mask predictions into a single binary mask.

    Stones = white (255), mortar = black (0).

    Args:
        pred: Sequence of dicts in the segment-anything output format. Each
            entry must hold a 2-D array under the 'segmentation' key, and
            all arrays must share the shape of the first one.

    Returns:
        A uint8 array containing only 0 and 255, or None when ``pred`` is
        empty.
    """
    if len(pred) == 0:
        # If no masks were generated, return None
        return None

    # Get image dimensions from the first mask
    height, width = pred[0]['segmentation'].shape

    # Count how many masks cover each pixel. A wide integer type is used so
    # the count cannot wrap around the way a uint8 accumulator would once a
    # pixel is covered by 256 or more masks.
    overlap_count = np.zeros((height, width), dtype=np.uint32)
    for seg in pred:
        overlap_count += seg['segmentation'].astype(np.uint8)

    # Otsu thresholding needs an 8-bit image, so saturate the counts at 255
    total_mask = np.minimum(overlap_count, 255).astype(np.uint8)

    # With THRESH_OTSU the threshold is computed from the image histogram and
    # the explicit threshold argument (0) is ignored.
    # NOTE(review): because the threshold is data-driven, pixels covered by
    # few masks may end up as background -- confirm this is intended.
    _, total_mask_bin = cv2.threshold(
        total_mask, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU
    )

    # Apply morphological closing to fill small holes
    kernel = np.ones((2, 2), np.uint8)
    total_mask_bin = cv2.morphologyEx(total_mask_bin, cv2.MORPH_CLOSE, kernel)

    return total_mask_bin

# -------------------------------------------------------------
# 3) Load the SAM model
# -------------------------------------------------------------
# Weights file, backbone variant and compute device for SAM
sam_checkpoint = "../models/sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cpu"  # switch to "cuda" to run on a GPU

print("Loading SAM model on device:", device)

# Pick the builder registered for the chosen backbone, load the checkpoint
# through it, then move the model to the selected device.
_sam_builder = sam_model_registry[model_type]
sam = _sam_builder(checkpoint=sam_checkpoint)
sam.to(device=device)

# Settings for the automatic mask generator (used for the initial segmentation)
_auto_mask_settings = {
    "points_per_side": 32,
    "pred_iou_thresh": 0.86,
    "stability_score_thresh": 0.92,
    "crop_n_layers": 1,
    "crop_n_points_downscale_factor": 2,
}
mask_generator = SamAutomaticMaskGenerator(model=sam, **_auto_mask_settings)

# Predictor sharing the same model, used for bounding-box refinement
mask_predictor = SamPredictor(sam)

# -------------------------------------------------------------
# 4) Define bounding boxes for each image :
#   
# -------------------------------------------------------------
boxes_dict = {
    # Entry format -> "filename.jpg": [(x0, y0, x1, y1), (x0, y0, x1, y1), ...]
    "FSE_35_004.jpg": [
        (4, 12, 178, 107),
        (182, 12, 382, 94),
        (92, 111, 298, 203),
        (220, 303, 409, 405),
        (414, 302, 492, 395),
        (498, 303, 561, 394),
        (520, 199, 561, 301),
        (3, 319, 23, 414),
    ],
    # Add further images the same way:
    # "FSE_24_010.jpg": [...],
    # "FNE_100_136.jpg": [...],
    # etc.
}

# -------------------------------------------------------------
# 5) Loop over images in the refined folder
#    Then refine segmentation with bounding boxes
# -------------------------------------------------------------
refined_folder = "../data/refined/img-stones"  # Adjust path if needed
output_dir_img = "../data/patches/images"
output_dir_mask = "../data/patches/masks"
os.makedirs(output_dir_img, exist_ok=True)
os.makedirs(output_dir_mask, exist_ok=True)

patch_size = 256
stride = 128
stone_threshold = 0.1  # 10% of the patch must be stone

# The minimum stone-pixel count is identical for every patch, so compute it once
total_pixels = patch_size * patch_size
min_stone_pixels = stone_threshold * total_pixels

# Gather image files. os.listdir returns entries in arbitrary order, so the
# list is sorted to make the patch numbering reproducible between runs.
image_files = sorted(
    f for f in os.listdir(refined_folder)
    if f.lower().endswith((".jpg", ".jpeg", ".png"))
)

patch_count = 0

for img_file in image_files:
    img_path = os.path.join(refined_folder, img_file)
    image_bgr = cv2.imread(img_path)

    if image_bgr is None:
        print(f"Skipping {img_file} -> Unable to load image.")
        continue

    height, width = image_bgr.shape[:2]
    print(f"\n--- Processing {img_file} ---")
    print(f"Image path: {img_path}")
    print(f"Image size: {width} x {height}")

    # Reject undersized images before running SAM: no patch could be cut
    # from them, so the segmentation pass would be wasted work.
    if height < patch_size or width < patch_size:
        print(f"  {img_file} is too small for {patch_size}x{patch_size} patches; skipping.")
        continue

    # SAM is fed the RGB version; the BGR original is kept for writing patches
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    # 5A) Generate Automatic Masks with SAM
    masks_auto = mask_generator.generate(image_rgb)
    print(f"  Number of auto masks: {len(masks_auto)}")

    if len(masks_auto) == 0:
        print("  No auto masks found; skipping.")
        continue

    # Build the initial mask from auto
    auto_mask_bin = build_totalmask(masks_auto)
    if auto_mask_bin is None:
        print("  auto_mask_bin is None; skipping.")
        continue

    # 5B) Refine with bounding boxes when this image has any.
    #     A missing entry and an empty box list both fall back to the auto mask.
    raw_boxes = boxes_dict.get(img_file)
    if raw_boxes:
        # 1) Set the image on the predictor
        mask_predictor.set_image(image_rgb)

        # 2) Convert bounding boxes to a Torch tensor on the predictor's device
        input_boxes = torch.tensor(raw_boxes, device=mask_predictor.device)

        # 3) Transform boxes for SAM
        transformed_boxes = mask_predictor.transform.apply_boxes_torch(
            input_boxes,
            image_rgb.shape[:2]
        )

        # 4) Predict one mask per bounding box (scores and logits are unused)
        masks_box, _, _ = mask_predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False
        )
        # shape: (num_boxes, 1, H, W)
        masks_box = masks_box.squeeze(1).cpu().numpy()  # -> (num_boxes, H, W)

        # 5) OR all bounding-box masks into the auto mask
        box_union = masks_box.astype(bool).any(axis=0)
        combined_mask_bool = np.logical_or(auto_mask_bin.astype(bool), box_union)

        # Convert back to 0/255
        final_mask_bin = combined_mask_bool.astype(np.uint8) * 255
    else:
        # If no bounding boxes for this image, just use the auto mask
        final_mask_bin = auto_mask_bin

    # 5C) Check how many nonzero pixels in the final mask
    nz_count = cv2.countNonZero(final_mask_bin)
    print(f"  Non-zero pixels in final_mask_bin: {nz_count}")

    if nz_count == 0:
        print("  final_mask_bin is all black; skipping.")
        continue

    # 5D) Slide over the image to extract patches
    local_patch_count = 0
    for y in range(0, height - patch_size + 1, stride):
        for x in range(0, width - patch_size + 1, stride):
            patch_mask = final_mask_bin[y:y+patch_size, x:x+patch_size]

            # Keep only patches that meet the stone threshold
            if cv2.countNonZero(patch_mask) <= min_stone_pixels:
                continue

            # Slice the BGR original directly: cv2.imwrite expects BGR, so no
            # round trip through RGB is needed.
            patch_img_bgr = image_bgr[y:y+patch_size, x:x+patch_size]
            patch_img_name = f"patch_img_{patch_count}.png"
            patch_mask_name = f"patch_mask_{patch_count}.png"

            # cv2.imwrite reports failure through its return value; a patch
            # only counts as saved when both files were written.
            img_ok = cv2.imwrite(os.path.join(output_dir_img, patch_img_name), patch_img_bgr)
            mask_ok = cv2.imwrite(os.path.join(output_dir_mask, patch_mask_name), patch_mask)
            if not (img_ok and mask_ok):
                print(f"  Warning: could not write patch {patch_count} from {img_file}; not counted.")
                continue

            patch_count += 1
            local_patch_count += 1

    print(f"  -> Patches saved from {img_file}: {local_patch_count}")

print("\nDone!")
print("Total patches saved:", patch_count)
