In [1]:
import cv2
import os
import numpy as np
from os.path import join, isfile

# Directory paths
image_dir = './v3_val/images/'
output_mask_dir = './v3_val/modified_masks/'
os.makedirs(output_mask_dir, exist_ok=True)

# File extensions that cv2.imread is documented to decode. Anything else in
# the directory (hidden files, thumbnails caches, notes) is skipped so the
# viewer never tries to load a non-image file.
IMAGE_EXTENSIONS = (
    '.bmp', '.dib', '.jpeg', '.jpg', '.jpe', '.jp2', '.png', '.webp',
    '.pbm', '.pgm', '.ppm', '.pxm', '.pnm', '.sr', '.ras', '.tiff', '.tif',
    '.exr', '.hdr', '.pic',
)

# Get sorted list of image files (regular files with a known image extension;
# the extension match is case-insensitive)
image_files = sorted(
    f for f in os.listdir(image_dir)
    if isfile(join(image_dir, f)) and f.lower().endswith(IMAGE_EXTENSIONS)
)

# Target display area, in pixels, that the shown crop is scaled to fit
# (1920x1080 is only an example value).
screen_width, screen_height = 1920, 1080

# Position in image_files of the image currently being edited.
current_index = 0

# View state: zoom level, pan offsets, and whether the mask is overlaid.
zoom_factor = 1.0
pan_x = pan_y = 0
show_mask_overlay = True

# Stroke state: whether a stroke is in progress, whether strokes erase,
# and the previous stroke point used to draw continuous lines.
is_drawing = is_erasing = False
last_point = None

# Brush thickness (adjustable at runtime) and the class id painted by default.
brush_size, current_class = 3, 1

# Visualization color per class id: green for 1, blue for 2, red for 3.
class_colors = dict(zip(
    (1, 2, 3),
    ((0, 255, 0), (255, 0, 0), (0, 0, 255)),
))

# Currently loaded image and its mask, plus the display-to-mask mapping
# (width/height resize factors and the x1/y1 offsets of the visible crop).
cached_image = cached_mask = None
resize_factors = (1.0,) * 2
crop_offsets = (0,) * 2

# Mask snapshots used by the undo feature.
undo_stack = list()

def save_state_to_undo_stack():
    """Push a copy of the current mask onto the undo stack.

    Does nothing when no mask is loaded (``cached_mask`` is ``None``).
    """
    if cached_mask is None:
        return
    snapshot = cached_mask.copy()
    undo_stack.append(snapshot)

# Mouse callback for drawing/erasing
def draw_on_mask(event, x, y, flags, param):
    """Paint or erase on ``cached_mask`` in response to mouse events.

    Window coordinates are mapped back to mask coordinates using the current
    ``resize_factors`` and ``crop_offsets``. A left-button press starts a
    stroke (after saving an undo snapshot) and stamps a dot at the press
    position, so a click without any mouse movement still marks the mask.
    Moving with the button held extends the stroke with line segments; the
    value written is 0 in eraser mode, otherwise ``current_class``.

    Args:
        event: OpenCV mouse event code.
        x: Cursor x position in window coordinates.
        y: Cursor y position in window coordinates.
        flags: OpenCV event flags (unused).
        param: User data passed by OpenCV (unused).
    """
    global is_drawing, last_point

    if cached_mask is None:
        # Nothing to draw on yet.
        return

    resize_factor_x, resize_factor_y = resize_factors
    crop_x1, crop_y1 = crop_offsets

    # Undo the display scaling, then the crop offset.
    mask_x = int(x / resize_factor_x) + crop_x1
    mask_y = int(y / resize_factor_y) + crop_y1
    point = (mask_x, mask_y)

    if event == cv2.EVENT_LBUTTONDOWN:
        save_state_to_undo_stack()  # Save the initial state before drawing/erasing
        is_drawing = True
        color = 0 if is_erasing else current_class
        # A zero-length line stamps a brush-sized dot at the click position.
        cv2.line(cached_mask, point, point, (color), brush_size)
        last_point = point
    elif event == cv2.EVENT_LBUTTONUP:
        is_drawing = False
        last_point = None
    elif event == cv2.EVENT_MOUSEMOVE and is_drawing:
        inside_x = 0 <= mask_x < cached_mask.shape[1]
        inside_y = 0 <= mask_y < cached_mask.shape[0]
        if inside_x and inside_y:
            color = 0 if is_erasing else current_class
            # Draw a line between the last point and the current point
            if last_point is not None:
                cv2.line(cached_mask, last_point, point, (color), brush_size)
            last_point = point

# Function to load and cache the image and start a fresh mask
def load_image_with_fresh_mask(image_path):
    """Load an image into the cache and reset the mask and undo history.

    The image is validated before any global state is touched, so a failed
    load leaves the previously cached image, mask and undo stack intact.

    Args:
        image_path: Path of the image file to read with ``cv2.imread``.

    Raises:
        ValueError: If ``cv2.imread`` cannot read the file (it returns
            ``None`` for missing or undecodable files).
    """
    global cached_image, cached_mask, undo_stack

    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not read image file: {image_path}")

    cached_image = image
    # Start with an empty single-channel mask matching the image height/width
    cached_mask = np.zeros(image.shape[:2], dtype=np.uint8)
    undo_stack = []  # Clear undo stack when loading a new image

# Function to display the image with optional mask overlay
def display_image():
    """Render the cached image in the 'Image Viewer' window.

    Steps:
      1. Copy the cached image and, when ``show_mask_overlay`` is enabled,
         color every masked pixel with its class color.
      2. Crop a region sized by ``zoom_factor`` and positioned by
         ``pan_x``/``pan_y``, clamped to the image bounds.
      3. Scale the crop to fit ``screen_width`` x ``screen_height`` while
         keeping its aspect ratio.
      4. Draw the key help lines and the current class on top, then show it.

    Side effects: updates ``resize_factors`` and ``crop_offsets`` (used by the
    mouse callback to map window coordinates to mask coordinates) and writes
    the clamped pan values back to ``pan_x``/``pan_y`` so that panning past an
    image edge does not accumulate an offset that has to be undone later.
    """
    global pan_x, pan_y, resize_factors, crop_offsets

    # Create the overlay (only when the mask overlay is toggled on)
    overlay = cached_image.copy()
    if show_mask_overlay:
        for class_id, color in class_colors.items():
            overlay[cached_mask == class_id] = color

    # Apply zoom; keep at least one pixel so the crop is never empty
    height, width = overlay.shape[:2]
    crop_width = max(1, int(width / zoom_factor))
    crop_height = max(1, int(height / zoom_factor))

    # Top-left corner of a centered crop, before panning
    base_x = width // 2 - crop_width // 2
    base_y = height // 2 - crop_height // 2

    # Crop coordinates, clamped so the crop stays inside the image
    x1 = max(0, min(base_x + pan_x, width - crop_width))
    y1 = max(0, min(base_y + pan_y, height - crop_height))
    x2 = x1 + crop_width
    y2 = y1 + crop_height

    # Keep the stored pan in the range that actually moves the view
    pan_x, pan_y = x1 - base_x, y1 - base_y

    cropped = overlay[y1:y2, x1:x2]
    cropped_height, cropped_width = cropped.shape[:2]

    # Scale the cropped image to fit the screen
    scale_factor = min(screen_width / cropped_width, screen_height / cropped_height)
    new_width = max(1, int(cropped_width * scale_factor))
    new_height = max(1, int(cropped_height * scale_factor))
    resized = cv2.resize(cropped, (new_width, new_height), interpolation=cv2.INTER_AREA)

    # Update resize factors and crop offsets
    resize_factors = (new_width / cropped_width, new_height / cropped_height)
    crop_offsets = (x1, y1)

    # Display keys at the top of the window
    eraser_status = "Eraser ON" if is_erasing else "Eraser OFF"
    keys_info = (
        f"Keys: Q-Quit | N-Next | P-Previous | +/- Zoom | IJKL-Pan | T-Toggle Mask | S-Save | Z-Undo | E-{eraser_status} | F-Increase Brush | G-Decrease Brush (Current: {brush_size}) | 1-3 Classes"
    )
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    font_color = (255, 255, 255)
    thickness = 1
    y0, dy = 20, 20

    # One help entry per text row
    help_lines = keys_info.split(" | ")
    for i, line in enumerate(help_lines):
        y = y0 + i * dy
        cv2.putText(resized, line, (10, y), font, font_scale, font_color, thickness, cv2.LINE_AA)

    # Highlight the selected class on the row below the help text
    class_info = f"Current Class: {current_class}"
    class_color = class_colors[current_class]
    cv2.putText(resized, class_info, (10, y0 + len(help_lines) * dy), font, font_scale, class_color, thickness, cv2.LINE_AA)

    # Display the image
    cv2.imshow('Image Viewer', resized)

# Load the first image with a fresh mask
if not image_files:
    # Fail with a clear message instead of an IndexError on the first lookup
    raise FileNotFoundError(f"No image files found in {image_dir}")
load_image_with_fresh_mask(join(image_dir, image_files[current_index]))
cv2.namedWindow('Image Viewer')
cv2.setMouseCallback('Image Viewer', draw_on_mask)

# Main loop for navigation. The window is redrawn on every iteration so that
# strokes made by the mouse callback show up immediately.
try:
    while True:
        # Display the current image
        display_image()

        # Wait for key press
        key = cv2.waitKey(1) & 0xFF

        if key == ord('q'):  # Quit
            break
        elif key in (ord('n'), ord('p')):  # Next / Previous (wraps around)
            step = 1 if key == ord('n') else -1
            current_index = (current_index + step) % len(image_files)
            pan_x, pan_y, zoom_factor = 0, 0, 1.0
            load_image_with_fresh_mask(join(image_dir, image_files[current_index]))
        elif key in (ord('+'), ord('=')):  # Zoom in ('=' is '+' without Shift)
            zoom_factor = min(zoom_factor * 1.2, 10.0)  # Limit max zoom
        elif key == ord('-'):  # Zoom out
            zoom_factor = max(zoom_factor / 1.2, 1.0)  # Limit min zoom
        elif key == ord('i'):  # Pan up
            pan_y -= int(100 / zoom_factor)  # Faster panning
        elif key == ord('k'):  # Pan down
            pan_y += int(100 / zoom_factor)
        elif key == ord('j'):  # Pan left
            pan_x -= int(100 / zoom_factor)
        elif key == ord('l'):  # Pan right
            pan_x += int(100 / zoom_factor)
        elif key == ord('t'):  # Toggle mask overlay
            show_mask_overlay = not show_mask_overlay
        elif key == ord('s'):  # Save modified mask
            output_path = join(output_mask_dir, f"{os.path.splitext(image_files[current_index])[0]}_mask.png")
            # cv2.imwrite reports failure through its return value
            if cv2.imwrite(output_path, cached_mask):
                print(f"Modified mask saved to {output_path}")
            else:
                print(f"Failed to save mask to {output_path}")
        elif key == ord('z'):  # Undo last drawing action
            if undo_stack:
                cached_mask = undo_stack.pop()
                print("Undo performed.")
        elif key == ord('e'):  # Toggle eraser mode
            is_erasing = not is_erasing
            print("Eraser mode toggled.")
        elif key == ord('f'):  # Increase brush size
            brush_size = min(brush_size + 1, 50)  # Limit max brush size
            print(f"Brush size increased to {brush_size}")
        elif key == ord('g'):  # Decrease brush size
            brush_size = max(brush_size - 1, 1)  # Limit min brush size
            print(f"Brush size decreased to {brush_size}")
        elif key in [ord('1'), ord('2'), ord('3')]:
            current_class = int(chr(key))
            print(f"Switched to class {current_class}")
finally:
    # Close the viewer window even if an error interrupts the loop
    cv2.destroyAllWindows()


Switched to class 2
Switched to class 3
