In [1]:
import sys
import math
import numpy as np
import cv2
from PyQt6.QtWidgets import (
    QApplication,
    QGraphicsView,
    QGraphicsScene,
    QGraphicsPixmapItem,
    QVBoxLayout,
    QWidget,
    QPushButton
)
from PyQt6.QtGui import (
    QPixmap,
    QPen,
    QColor,
    QTransform,
    QPainter,
    QBrush,
    QImage,
    QPainterPath
)
from PyQt6.QtCore import Qt, QRectF, QPointF, QSizeF
from PIL import Image
import torch
from segment_anything import SamPredictor, SamAutomaticMaskGenerator, sam_model_registry

# ---------------------------------------------------------------------------
# 1) Adjust SAM auto mask generator parameters for better performance
# ---------------------------------------------------------------------------
SAM_CHECKPOINT = "models/sam_vit_h_4b8939.pth"  # Update path if needed
MODEL_TYPE = "vit_h"  # Options: "vit_h", "vit_b", "vit_l"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

sam_model = sam_model_registry[MODEL_TYPE](checkpoint=SAM_CHECKPOINT)
sam_model.to(DEVICE)
predictor = SamPredictor(sam_model)

# Tuned parameters for finer masks
auto_mask_generator = SamAutomaticMaskGenerator(
    sam_model,
    points_per_side=64,          # More points yield finer details (more compute)
    pred_iou_thresh=0.80,        # Lower threshold yields more candidate masks
    stability_score_thresh=0.90, # Filter out unstable masks
    min_mask_region_area=500     # Skip very small masks
)

class CustomGraphicsView(QGraphicsView):
    def __init__(self):
        super().__init__()
        self.scene = QGraphicsScene()
        self.setScene(self.scene)
        self.main_pixmap_item = None
        self.original_pixmap = None
        self.selected_pixmap_item = None
        self.selection_feedback_item = None  # For drawing the final outline
        self.dragging = False

        # Modes: "selection" (prompt-based), "auto" (auto-based), "transform" (move/scale)
        self.mode = "selection"

        # For prompt-based mode: store positive/negative points
        self.positive_points = []  # Left-click => add
        self.negative_points = []  # Right-click => remove

        # We'll always merge selections into one union mask.
        # Auto-based selections are maintained here.
        self.auto_selection_mask = None  # numpy array (uint8) with 0 or 255 values
        self.image_shape = None  # (height, width)

        # 3) Toggle for morphological post-processing (smoothing edges)
        self.use_morphology = True

        self.setRenderHint(QPainter.RenderHint.Antialiasing)
        self.setRenderHint(QPainter.RenderHint.SmoothPixmapTransform)

    def load_image(self, image_path):
        self.image_path = image_path
        self.original_pixmap = QPixmap(image_path)
        if self.original_pixmap.isNull():
            print(f"Error: Could not load image from {image_path}")
            return

        # Load image with cv2 to get its shape
        img = cv2.imread(image_path)
        if img is not None:
            self.image_shape = (img.shape[0], img.shape[1])
            # Initialize union mask as all zeros (no selection)
            self.auto_selection_mask = np.zeros(self.image_shape, dtype=np.uint8)

        self.main_pixmap_item = QGraphicsPixmapItem(self.original_pixmap)
        self.scene.addItem(self.main_pixmap_item)
        self.setSceneRect(self.main_pixmap_item.boundingRect())

    def set_mode(self, mode):
        self.mode = mode
        print(f"Mode set to: {mode}")
        # Clear prompt points if leaving prompt-based mode.
        if mode != "selection":
            self.positive_points = []
            self.negative_points = []

    # -----------------------------------------------------------------------
    # 2) Always merge prompt and auto selections
    # -----------------------------------------------------------------------
    def ai_salient_object_selection(self):
        """Perform prompt-based SAM selection and automatically merge it."""
        if not self.image_path:
            print("No image loaded")
            return

        img = cv2.imread(self.image_path)
        if img is None:
            print("Error: Could not load image")
            return

        # Set up predictor
        image_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        predictor.set_image(image_rgb)

        # Require at least one prompt point
        if not self.positive_points and not self.negative_points:
            print("No selection points provided")
            return

        # Build point and label arrays
        points = []
        labels = []
        if self.positive_points:
            points.extend(self.positive_points)
            labels.extend([1] * len(self.positive_points))
        if self.negative_points:
            points.extend(self.negative_points)
            labels.extend([0] * len(self.negative_points))
        points_array = np.array(points)
        labels_array = np.array(labels)

        masks, scores, logits = predictor.predict(
            point_coords=points_array,
            point_labels=labels_array,
            multimask_output=False
        )

        mask = masks[0]
        mask_uint8 = (mask.astype(np.uint8)) * 255

        # 3) Optional morphological closing to smooth the mask
        if self.use_morphology:
            kernel = np.ones((5, 5), np.uint8)
            mask_uint8 = cv2.morphologyEx(mask_uint8, cv2.MORPH_CLOSE, kernel)

        # Merge the prompt-based mask into the union (auto) mask
        self.auto_selection_mask = cv2.bitwise_or(self.auto_selection_mask, mask_uint8)
        print("Automatically merged prompt-based selection into union mask.")

        # Clear prompt-based points after merging (optional)
        self.positive_points = []
        self.negative_points = []

        self.update_auto_selection_display()

    # -----------------------------------------------------------------------
    # Auto mode: add or remove an object via auto mask generator
    # -----------------------------------------------------------------------
    def auto_salient_object_update(self, click_point, action="add"):
        """Update auto mode selection: add or remove object at the clicked point."""
        if not self.image_path:
            print("No image loaded")
            return

        img = cv2.imread(self.image_path)
        if img is None:
            print("Error: Could not load image")
            return

        image_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Generate candidate masks automatically
        masks = auto_mask_generator.generate(image_rgb)
        if not masks:
            print("No masks generated")
            return

        x = int(click_point.x())
        y = int(click_point.y())
        selected_mask = None
        best_area = 0

        # Find a candidate mask covering the clicked point
        for m in masks:
            seg = m["segmentation"]  # Boolean mask
            if seg[y, x]:
                area = m.get("area", np.sum(seg))
                if area > best_area:
                    best_area = area
                    selected_mask = seg

        # If no mask covers the point, choose the largest candidate mask
        if selected_mask is None:
            selected_mask = max(masks, key=lambda m: m.get("area", np.sum(m["segmentation"])))["segmentation"]

        new_mask = (selected_mask.astype(np.uint8)) * 255

        # 3) Optional morphological closing
        if self.use_morphology:
            kernel = np.ones((5, 5), np.uint8)
            new_mask = cv2.morphologyEx(new_mask, cv2.MORPH_CLOSE, kernel)

        # Add or remove from the union mask based on the action
        if action == "add":
            self.auto_selection_mask = cv2.bitwise_or(self.auto_selection_mask, new_mask)
            print("Added object to selection (auto).")
        elif action == "remove":
            inv = cv2.bitwise_not(new_mask)
            self.auto_selection_mask = cv2.bitwise_and(self.auto_selection_mask, inv)
            print("Removed object from selection (auto).")
        else:
            print("Unknown action")

        self.update_auto_selection_display()

    def update_auto_selection_display(self):
        """Update the display based on the union mask (black outline)."""
        if self.image_path is None or self.auto_selection_mask is None:
            return

        img = cv2.imread(self.image_path)
        if img is None:
            return

        mask_uint8 = self.auto_selection_mask.copy()

        # Draw a black outline around the union mask
        contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        path = QPainterPath()
        for cnt in contours:
            if len(cnt) > 0:
                cnt = cnt.squeeze()
                if cnt.ndim < 2:
                    continue
                start = cnt[0]
                path.moveTo(start[0], start[1])
                for pt in cnt[1:]:
                    path.lineTo(pt[0], pt[1])
                path.closeSubpath()
        if self.selection_feedback_item:
            self.scene.removeItem(self.selection_feedback_item)
        self.selection_feedback_item = self.scene.addPath(path, QPen(QColor("black"), 2))

        # Create a pixmap for the union mask overlay
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_rgba = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2RGBA)
        img_rgba[:, :, 3] = mask_uint8  # Set alpha channel to mask
        result = cv2.bitwise_and(img_rgba, img_rgba, mask=mask_uint8)

        h, w, ch = result.shape
        bytes_per_line = ch * w
        q_img = QImage(result.data, w, h, bytes_per_line, QImage.Format.Format_RGBA8888)
        selected_pixmap = QPixmap.fromImage(q_img)

        if self.selected_pixmap_item:
            self.scene.removeItem(self.selected_pixmap_item)
        self.selected_pixmap_item = QGraphicsPixmapItem(selected_pixmap)
        self.scene.addItem(self.selected_pixmap_item)

        # Update the main image to show transparency where the union mask exists
        img_rgba[:, :, 3] = cv2.bitwise_not(mask_uint8)
        q_img_main = QImage(img_rgba.data, w, h, bytes_per_line, QImage.Format.Format_RGBA8888)
        self.original_pixmap = QPixmap.fromImage(q_img_main)
        self.main_pixmap_item.setPixmap(self.original_pixmap)

    # -----------------------------------------------------------------------
    # Event handling
    # -----------------------------------------------------------------------
    def mousePressEvent(self, event):
        pos = self.mapToScene(event.pos())
        if self.mode == "selection":
            # In prompt-based mode, left-click adds a positive point, right-click a negative point.
            if event.button() == Qt.MouseButton.LeftButton:
                self.positive_points.append([pos.x(), pos.y()])
                print(f"Added positive point: ({pos.x()}, {pos.y()})")
            elif event.button() == Qt.MouseButton.RightButton:
                self.negative_points.append([pos.x(), pos.y()])
                print(f"Added negative point: ({pos.x()}, {pos.y()})")
            self.ai_salient_object_selection()

        elif self.mode == "auto":
            # In auto mode, left-click adds object; right-click removes object.
            if event.button() == Qt.MouseButton.LeftButton:
                self.auto_salient_object_update(pos, action="add")
            elif event.button() == Qt.MouseButton.RightButton:
                self.auto_salient_object_update(pos, action="remove")

        elif self.mode == "transform" and self.selected_pixmap_item:
            self.dragging = True
            self.drag_start = pos

        super().mousePressEvent(event)

    def mouseMoveEvent(self, event):
        pos = self.mapToScene(event.pos())
        if self.dragging and self.selected_pixmap_item:
            delta = pos - self.drag_start
            self.selected_pixmap_item.moveBy(delta.x(), delta.y())
            self.drag_start = pos
        super().mouseMoveEvent(event)

    def mouseReleaseEvent(self, event):
        if event.button() == Qt.MouseButton.LeftButton:
            self.dragging = False
        super().mouseReleaseEvent(event)

    def wheelEvent(self, event):
        if self.mode == "transform" and self.selected_pixmap_item:
            scale_factor = 1.1 if event.angleDelta().y() > 0 else 0.9
            self.selected_pixmap_item.setScale(self.selected_pixmap_item.scale() * scale_factor)
        super().wheelEvent(event)

class MainWindow(QWidget):
    def __init__(self):
        super().__init__()
        self.init_ui()

    def init_ui(self):
        layout = QVBoxLayout()
        self.view = CustomGraphicsView()

        self.auto_selection_button = QPushButton("Auto Selection Mode")
        self.auto_selection_button.clicked.connect(lambda: self.view.set_mode("auto"))

        self.prompt_selection_button = QPushButton("Prompt Selection Mode")
        self.prompt_selection_button.clicked.connect(lambda: self.view.set_mode("selection"))

        self.transform_button = QPushButton("Transform Mode")
        self.transform_button.clicked.connect(lambda: self.view.set_mode("transform"))

        # The combine button is removed since merging is done automatically.
        layout.addWidget(self.view)
        layout.addWidget(self.prompt_selection_button)
        layout.addWidget(self.auto_selection_button)
        layout.addWidget(self.transform_button)

        self.setLayout(layout)
        self.setWindowTitle("SAM: Always Merged Selection & Transformation Tool")
        self.resize(800, 600)
        self.view.load_image("images/test/2_people_together.png")

if __name__ == "__main__":
    app = QApplication(sys.argv)
    window = MainWindow()
    window.show()
    sys.exit(app.exec())


Mode set to: auto
Added object to selection (auto).
Added object to selection (auto).
Mode set to: auto
Added object to selection (auto).
Mode set to: selection
Added positive point: (1068.0, 247.0)
Automatically merged prompt-based selection into union mask.
Added positive point: (1074.0, 311.0)
Automatically merged prompt-based selection into union mask.
Added negative point: (1093.0, 191.0)
Automatically merged prompt-based selection into union mask.
Added negative point: (1098.0, 219.0)
Automatically merged prompt-based selection into union mask.
Added negative point: (1003.0, 225.0)
Automatically merged prompt-based selection into union mask.
Mode set to: transform


SystemExit: 0

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
