In [None]:
# IMPORT PACKAGES
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
import tkinter as tk
from matplotlib.patches import Patch
from segment_anything import sam_model_registry, SamPredictor
from tqdm import tqdm
import os
# Show the working directory: the relative paths used later in this notebook
# ("weights/...", "dataset_mixture/...") are resolved against it.
print(os.getcwd())


In [None]:
# Load SAM
# Build the checkpoint path with os.path.join so it works on every OS. A literal
# like "weights\sam_vit_b.pth" only works because "\s" is an unrecognised escape
# (kept as a backslash, with a SyntaxWarning on Python 3.12+), and the backslash
# separator is Windows-only.
sam_checkpoint = os.path.join("weights", "sam_vit_b.pth")
model_type = "vit_b"  # key into sam_model_registry
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
sam.to(device)
# Used later as: predictor.set_image(image) once per image, then
# predictor.predict(...) once per grid point.
predictor = SamPredictor(sam)

In [None]:
# DEFINE LABELS AND COLORS
# Class id -> display name. The insertion order (5 down to 0) is the order in
# which choose_class creates its buttons, so keep it.
label_map = dict([
    (5, "Weed"),
    (4, "Tall Fescue"),
    (3, "White Clover"),
    (2, "Alfalfa"),
    (1, "Soil"),
    (0, "Background"),
])

# Class id -> RGB colour used for the colour-coded legend previews.
color_map = dict([
    (0, (0, 0, 0)),         # Background - black
    (1, (139, 69, 19)),     # Soil - brown
    (2, (128, 0, 128)),     # Alfalfa - purple
    (3, (128, 128, 128)),   # White Clover - grey
    (4, (255, 165, 0)),     # Tall Fescue - orange
    (5, (255, 192, 203)),   # Weed - pink
])

# --------- TRACK LAST POPUP POSITION ---------
# Tk geometry offset ("+x+y") where the class popup was last closed; the next
# popup is opened at the same place.
last_popup_position = "+100+100"  # Default position (top-left)


def choose_class(image):
    """Open a popup with one button per entry of ``label_map`` and return the choice.

    Parameters
    ----------
    image : array-like
        Not used by this function; kept so existing call sites keep working.

    Returns
    -------
    int
        The chosen class id (a key of ``label_map``), or -1 if the window was
        closed without clicking a class button.
    """
    root = tk.Tk()
    root.title("Select Class")
    root.geometry(last_popup_position)

    # Plain Python holder so the result can be read after root.destroy()
    # without touching the Tcl interpreter again.
    choice = {"class_id": -1}

    def remember_position():
        # The `global` must be declared in THIS nested function: an assignment
        # here would otherwise create a local, and the position would never be
        # saved for the next popup.
        global last_popup_position
        last_popup_position = f"+{root.winfo_x()}+{root.winfo_y()}"

    def select_and_close(class_id):
        choice["class_id"] = class_id
        remember_position()
        root.destroy()

    def close_without_choice():
        remember_position()
        root.destroy()

    tk.Label(root, text="Click to assign a class:").pack(pady=10)

    for class_id, name in label_map.items():
        tk.Button(root, text=name, width=30, height=2,
                  command=lambda c=class_id: select_and_close(c)).pack(padx=10, pady=5)

    root.protocol("WM_DELETE_WINDOW", close_without_choice)
    root.mainloop()

    return choice["class_id"]

def show_masks_and_get_selection(image, masks, scores):
    """Show the image next to every candidate mask and ask the user to pick one.

    Parameters
    ----------
    image : array-like
        Image drawn in every panel.
    masks : sequence of 2-D arrays
        Candidate masks; one overlay panel and one button are created per mask.
    scores : sequence of float
        Score for each mask, shown in the panel titles and button labels.

    Returns
    -------
    int
        Index into ``masks`` of the chosen mask, or -1 if the user skipped or
        closed the window.
    """
    n_masks = len(masks)
    # One panel for the original plus one per mask, sized from len(masks) so the
    # panels always match the buttons below. squeeze=False keeps `axs` a 2-D
    # array even when there is only one panel.
    fig, axs = plt.subplots(1, n_masks + 1, figsize=(5 * (n_masks + 1), 5), squeeze=False)
    axs = axs[0]

    axs[0].imshow(image)
    axs[0].set_title("Original Image")
    axs[0].axis('off')
    for i, (mask, score) in enumerate(zip(masks, scores)):
        ax = axs[i + 1]
        ax.imshow(image)
        ax.imshow(mask, alpha=0.5, cmap='jet')
        ax.set_title(f"Mask {i} (Score: {score:.2f})")
        ax.axis('off')
    fig.tight_layout()
    plt.show(block=False)

    root = tk.Tk()
    root.title("Select a Mask or Skip")
    tk.Label(root, text="Click the mask you want to save, or skip if none are good:").pack(pady=10)

    # Plain Python holder so the result can be read after root.destroy().
    choice = {"index": -1}

    def finish(index):
        choice["index"] = index
        root.quit()
        root.destroy()

    for i in range(n_masks):
        tk.Button(root, text=f"✅ Select Mask {i} (Score: {scores[i]:.2f})",
                  width=40, height=2, command=lambda i=i: finish(i)).pack(padx=10, pady=4)

    tk.Button(root, text="❌ None of these masks are good — skip this point",
              width=40, height=2, bg='red', fg='white',
              command=lambda: finish(-1)).pack(pady=15)

    root.mainloop()
    # Close the preview so figures do not pile up across hundreds of grid points.
    plt.close(fig)
    return choice["index"]

def show_grid_overlay(image, points):
    """Display ``image`` with every grid prompt point drawn as a small blue dot.

    ``points`` is an iterable of (x, y) pixel coordinates. The figure is 5 units
    wide and its height follows the image aspect ratio.
    """
    img_h, img_w = image.shape[:2]
    fig_w = 5
    fig, ax = plt.subplots(figsize=(fig_w, fig_w * (img_h / img_w)))
    ax.imshow(image)
    for px, py in points:
        ax.plot(px, py, marker='o', color='blue', markersize=4)
    ax.set_title("BLUE = Grid Points (No Borders)")
    ax.axis('off')
    plt.tight_layout()
    plt.show()

def show_active_grid_point(image, points, active_index):
    """Display ``image`` with all grid points, highlighting ``points[active_index]``.

    The active point is drawn as a larger red dot; every other point is a
    small light-grey dot.
    """
    img_h, img_w = image.shape[:2]
    fig, ax = plt.subplots(figsize=(5, 5 * (img_h / img_w)))
    ax.imshow(image)
    for point_no, (px, py) in enumerate(points):
        is_active = point_no == active_index
        ax.plot(px, py, marker='o',
                color='red' if is_active else 'lightgray',
                markersize=8 if is_active else 4)
    ax.set_title(f"RED = Active Grid Point: {points[active_index]}")
    ax.axis('off')
    plt.tight_layout()
    plt.show()


In [None]:
# SET PATHS AND LOAD PROCESSED LOG
image_folder = "dataset_mixture/images/train/"
mask_output_dir = "dataset_mixture/masks/train/sam_masks"
log_file = "dataset_mixture/masks/train/processed_log.txt"

# Make sure both output locations exist before anything is written.
os.makedirs(mask_output_dir, exist_ok=True)
os.makedirs(os.path.dirname(log_file), exist_ok=True)

# Names (without extension) of images already finished in earlier runs; one
# name per line of the log. Empty when no log has been written yet.
processed_images = set()
if os.path.exists(log_file):
    with open(log_file, "r") as f:
        processed_images = {line.strip() for line in f}

In [None]:
# PROCESS ALL IMAGES IN FOLDER
import re  # local to this cell: only used to parse saved-mask filenames

DESIRED_NUM_POINTS = 90  # target size of the SAM prompt grid per image


def _grid_points(width, height, desired_num_points=DESIRED_NUM_POINTS):
    """Return (x, y) prompt points on a regular grid that skips the image border.

    Rows/cols are derived from ``desired_num_points`` and the aspect ratio;
    because border rows/columns are dropped, the result typically has
    somewhat fewer than ``desired_num_points`` points.
    """
    aspect_ratio = width / height
    rows = int(np.sqrt(desired_num_points / aspect_ratio)) or 1
    cols = max(int(desired_num_points / rows), 1)
    # max(..., 1): a zero step would make range() raise on very small images.
    step_x = max(width // cols, 1)
    step_y = max(height // rows, 1)
    return [(x, y)
            for y in range(step_y, height - step_y + 1, step_y)
            for x in range(step_x, width - step_x + 1, step_x)]


def _existing_point_indices(image_name, folder):
    """Return the grid-point indices that already have a saved mask for ``image_name``.

    Only files named exactly ``<image_name>_p<idx>_x<x>_y<y>_class<c>_mask<m>.png``
    count, so images whose name merely starts with ``image_name`` (``img1`` vs
    ``img10``) and the ``*_rgb_legend.png`` previews are ignored.
    """
    pattern = re.compile(
        rf"^{re.escape(image_name)}_p(\d+)_x\d+_y\d+_class\d+_mask\d+\.png$"
    )
    indices = set()
    for fname in os.listdir(folder):
        match = pattern.match(fname)
        if match:
            indices.add(int(match.group(1)))
    return indices


image_files = sorted(f for f in os.listdir(image_folder)
                     if f.lower().endswith(('.png', '.jpg', '.jpeg')))

for image_file in image_files:
    image_name = os.path.splitext(image_file)[0]

    if image_name in processed_images:
        print(f"✅ Already processed: {image_name} — skipping.")
        continue

    image_path = os.path.join(image_folder, image_file)
    image = cv2.imread(image_path)
    if image is None:
        print(f"❌ Skipping unreadable image: {image_path}")
        continue

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # cv2.imread returns BGR
    predictor.set_image(image)

    # GENERATE GRID
    height, width = image.shape[:2]
    points = _grid_points(width, height)
    max_masks = len(points)
    show_grid_overlay(image, points)

    # Resume support: grid points that already have a saved mask are skipped.
    existing_mask_indices = _existing_point_indices(image_name, mask_output_dir)
    # Number of masks saved for this image so far (earlier sessions included).
    # This must be a count, not max(point index) + 1, otherwise the cap below
    # fires early after a resume and unvisited points are silently skipped.
    accepted_masks = len(existing_mask_indices)
    print(f"🔄 Resuming {image_name}: {accepted_masks} point(s) already have masks")

    for idx, (x, y) in tqdm(enumerate(points), total=len(points), desc=f"{image_name} - Masking Progress"):
        if idx in existing_mask_indices:
            continue
        if accepted_masks >= max_masks:
            break

        show_active_grid_point(image, points, idx)

        print(f"\n🔍 Trying point ({x}, {y})")
        masks, scores, _ = predictor.predict(
            point_coords=np.array([[x, y]]),
            point_labels=np.array([1]),
            multimask_output=True,
        )

        selected_idx = show_masks_and_get_selection(image, masks, scores)
        if selected_idx == -1:
            print("⏩ Skipped this point.")
            continue

        chosen_mask = masks[selected_idx]
        ys, xs = np.where(chosen_mask)
        if ys.size == 0:
            # An empty mask has no bounding box (min() of an empty array raises).
            print("⏩ Selected mask is empty — skipped this point.")
            continue

        # --- SHOW CROPPED MASKED REGION (TIGHT BBOX) ---
        y_min, y_max = ys.min(), ys.max()
        x_min, x_max = xs.min(), xs.max()

        cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
        cropped_mask = chosen_mask[y_min:y_max + 1, x_min:x_max + 1]

        # Keep only the masked pixels; everything else stays black.
        cropped_masked_region = np.zeros_like(cropped_image)
        cropped_masked_region[cropped_mask] = cropped_image[cropped_mask]

        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(cropped_masked_region)
        ax.set_title(
            f"Cropped Masked Region\nBBox: x[{x_min},{x_max}] y[{y_min},{y_max}]",
            fontsize=12
        )
        ax.axis("off")
        fig.tight_layout()
        plt.show()
        plt.close(fig)

        class_label = choose_class(image)
        if class_label not in label_map:
            # choose_class returns -1 when its window is closed without a choice;
            # -1 is not a valid uint8 mask value and label_map[-1] would raise
            # KeyError when drawing the legend below.
            print("⏩ No class chosen — skipped this point.")
            continue

        mask_out = chosen_mask.astype(np.uint8) * class_label
        base_name = f"{image_name}_p{idx}_x{x}_y{y}_class{class_label}_mask{selected_idx}"
        mask_path = os.path.join(mask_output_dir, base_name + ".png")
        rgb_legend_path = os.path.join(mask_output_dir, base_name + "_rgb_legend.png")

        if os.path.exists(mask_path):
            print(f"⚠️ Mask already exists, skipping save: {mask_path}")
            continue
        # Any stale legend already at rgb_legend_path (one whose mask is missing)
        # is simply overwritten by savefig below.

        if not cv2.imwrite(mask_path, mask_out):
            print(f"❌ Failed to write mask: {mask_path}")
            continue
        print(f"✅ Saved grayscale mask: {mask_path}")

        # Colour-code the class mask for a human-readable preview.
        rgb_mask = np.zeros((*mask_out.shape, 3), dtype=np.uint8)
        for class_val, color in color_map.items():
            rgb_mask[mask_out == class_val] = color

        fig, ax = plt.subplots()
        ax.imshow(rgb_mask)
        ax.set_title(f"{image_name} — Mask {idx + 1} (Class: {label_map[class_label]})", fontsize=12)
        ax.axis('off')

        legend_elements = [
            Patch(facecolor=np.array(color_map[k]) / 255.0, label=label_map[k])
            for k in sorted(label_map)
        ]
        ax.legend(handles=legend_elements, loc='lower center', bbox_to_anchor=(0.5, -0.05), ncol=3)

        fig.tight_layout()
        fig.savefig(rgb_legend_path, dpi=150, bbox_inches='tight')
        plt.close(fig)
        print(f"🖼️ Saved RGB mask with legend: {rgb_legend_path}")

        accepted_masks += 1

    print(f"\n📦 Done with image: {image_name}. Saved {accepted_masks} masks.")
    # Record the image as finished so later runs skip it entirely.
    with open(log_file, "a") as f:
        f.write(f"{image_name}\n")