In [31]:
import pygame
import sys
import os

In [32]:
from ultralytics import SAM
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch

class ImageSegmentationTool:
    """Wrapper around an ultralytics ``SAM`` model (e.g. MobileSAM weights).

    Offers point-, box- and prompt-free ("segment everything") segmentation,
    each returning the masks of the first result as a NumPy array, plus a
    helper that renders masks over the source image and saves the figure.
    """

    def __init__(self, model_path='mobile_sam.pt', device='cuda'):
        """
        Initialize the segmentation tool, preferring the GPU when requested.

        Args:
            model_path: Path to the SAM / MobileSAM weights.
            device: 'cuda' for the default GPU, a specific GPU such as
                'cuda:0', or 'cpu'. Any CUDA request falls back to 'cpu'
                when CUDA is not available.
        """
        resolved = self._resolve_device(device)
        if resolved != device:
            print("CUDA not available, falling back to CPU")

        self.device = resolved
        self.model = SAM(model_path)
        print(f"Model loaded on: {resolved}")
        # Only name a GPU when one is actually the device in use.
        if str(resolved).startswith('cuda'):
            gpu_name = torch.cuda.get_device_name(resolved)
        else:
            gpu_name = 'None'
        print(f"GPU: {gpu_name}")

    @staticmethod
    def _resolve_device(device):
        """Return ``device``, or 'cpu' if it requests CUDA ('cuda', 'cuda:N') without CUDA."""
        if str(device).startswith('cuda') and not torch.cuda.is_available():
            return 'cpu'
        return device

    @staticmethod
    def _masks_to_numpy(result):
        """Return ``result.masks.data`` as a host NumPy array.

        ultralytics leaves ``Results.masks`` as None when nothing was
        segmented; in that case an empty boolean array of shape ``(0, H, W)``
        (H, W from ``result.orig_shape``) is returned instead of raising
        ``AttributeError``.
        """
        if result.masks is None:
            height, width = result.orig_shape[:2]
            return np.zeros((0, height, width), dtype=bool)
        return result.masks.data.cpu().numpy()

    @staticmethod
    def _mask_overlay(mask, color=(1.0, 0.0, 0.0), alpha=0.7):
        """Build an RGBA float image (H, W, 4) tinting ``mask`` pixels with ``color``.

        Pixels outside the mask get alpha 0, so drawing the overlay leaves
        them untouched (a plain RGB overlay with a global ``alpha`` would
        darken every unmasked pixel with black).
        """
        mask = np.asarray(mask).astype(bool)
        overlay = np.zeros(mask.shape + (4,), dtype=float)
        overlay[mask] = (*color, alpha)
        return overlay

    def segment_with_points(self, image_path, points, labels):
        """Segment the image from point prompts.

        Args:
            image_path: Image source forwarded to ``SAM.predict``.
            points: Prompt point coordinates, forwarded as ``points=``.
            labels: Per-point labels, forwarded as ``labels=`` (ultralytics
                SAM convention: 1 = foreground, 0 = background).

        Returns:
            NumPy array of masks for the first result (``(0, H, W)`` if none).
        """
        results = self.model.predict(
            image_path,
            points=points,
            labels=labels,
            device=self.device,
        )
        return self._masks_to_numpy(results[0])

    def segment_with_box(self, image_path, bbox):
        """Segment the object inside one box prompt.

        Args:
            image_path: Image source forwarded to ``SAM.predict``.
            bbox: ``[x1, y1, x2, y2]`` in image pixels.

        Returns:
            NumPy array of masks for the first result (``(0, H, W)`` if none).
        """
        results = self.model.predict(
            image_path,
            bboxes=[bbox],
            device=self.device,
        )
        return self._masks_to_numpy(results[0])

    def segment_everything(self, image_path, conf=0.4, iou=0.9):
        """Run prompt-free segmentation over the whole image.

        Args:
            image_path: Image source passed to the model.
            conf: Confidence threshold forwarded to the model.
            iou: IoU threshold forwarded to the model.

        Returns:
            NumPy array of masks for the first result (``(0, H, W)`` if none).
        """
        results = self.model(
            image_path,
            device=self.device,
            conf=conf,
            iou=iou,
        )
        return self._masks_to_numpy(results[0])

    def visualize_masks(self, image_path, masks, output_path='output.png'):
        """
        Save a figure of ``masks`` overlaid in semi-transparent red on the image.

        Args:
            image_path: Path of the image the masks were predicted on.
            masks: Iterable of 2-D masks (e.g. the array returned by the
                ``segment_*`` methods); each must match the image's H x W.
            output_path: Where to write the rendered PNG.

        Returns:
            output_path.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``image_path``.
        """
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            ax.imshow(image)
            for mask in masks:
                ax.imshow(self._mask_overlay(mask))
            ax.axis('off')
            fig.savefig(output_path, bbox_inches='tight', dpi=150)
        finally:
            # Always release the figure; this runs on a worker thread per click.
            plt.close(fig)

        return output_path


# Initialize segmentation tool: the shared instance that perform_segmentation
# uses. With the default device='cuda' it falls back to CPU when CUDA is
# unavailable (see the captured cell output below).
seg_tool = ImageSegmentationTool(model_path='mobile_sam.pt')


CUDA not available, falling back to CPU
Model loaded on: cpu
GPU: None


In [33]:
# Window width in pixels. pygame_image centres images at screen_width // 2,
# so keep this in sync with pygame.display.set_mode((1280, 720)) in the
# main-loop cell.
screen_width = 1280

In [34]:
from PIL import Image
class pygame_image():
    """An image kept both as an RGBA PIL image and as a pygame surface for display.

    Attributes:
        image_path: Path the image was loaded from, or None if built from ``image``.
        PIL_image: Full-resolution RGBA PIL image.
        pygame_image: Full-resolution pygame surface.
        pygame_img: Surface actually drawn; downscaled to fit ``max_size``.
        image_rect: Screen rect of ``pygame_img``.
        og_width, og_height: Full-resolution size in pixels (used to map
            screen coordinates back to image pixels).
    """

    def __init__(self, image_path=None, image=None, max_size=(700, 400), center=None):
        """
        Args:
            image_path: File to load; takes precedence over ``image``.
            image: PIL image to wrap when ``image_path`` is None.
            max_size: ``(max_width, max_height)`` of the displayed surface;
                larger images are downscaled keeping their aspect ratio.
            center: Screen position of the displayed image's centre;
                defaults to ``(screen_width // 2, 350)``.

        Raises:
            ValueError: if neither ``image_path`` nor ``image`` is given.
            OSError: (from PIL) if ``image_path`` cannot be opened or decoded.
        """
        if image_path is None and image is None:
            raise ValueError("pygame_image needs an image_path or an image")

        self.image_path = image_path
        if image_path is not None:
            with Image.open(image_path) as source:
                self.PIL_image = source.convert('RGBA')
        else:
            self.PIL_image = image.convert('RGBA')

        # Build the surface from the same decoded pixels as PIL_image (instead
        # of decoding the file a second time) so og_width/og_height always
        # describe PIL_image exactly.
        self.pygame_image = pygame.image.fromstring(
            self.PIL_image.tobytes(),
            self.PIL_image.size,
            self.PIL_image.mode,
        ).convert_alpha()

        self.og_width, self.og_height = self.pygame_image.get_size()
        display_size = self._fit_size(self.og_width, self.og_height, *max_size)

        displayed = self.pygame_image
        if display_size != (self.og_width, self.og_height):
            displayed = pygame.transform.scale(displayed, display_size)

        image_rect = displayed.get_rect()
        image_rect.center = center if center is not None else (screen_width // 2, 350)
        self.pygame_img = displayed
        self.image_rect = image_rect

    @staticmethod
    def _fit_size(width, height, max_width, max_height):
        """Return ``(width, height)`` shrunk, aspect ratio kept, to fit the bounds.

        Sizes already within the bounds are returned unchanged (never upscaled).
        """
        if width <= max_width and height <= max_height:
            return width, height
        scale = min(max_width / width, max_height / height)
        return int(width * scale), int(height * scale)

In [35]:
def dist(a, b):
    """Return the *squared* Euclidean distance between 2-D points ``a`` and ``b``.

    No square root is taken, so thresholds compared against this value must
    be squared too (e.g. ``200 * 200`` for a 200 px radius).
    """
    dx = a[0] - b[0]
    dy = a[1] - b[1]
    return dx * dx + dy * dy

class bbox():
    """A draggable rectangle defined by two opposite corner handles.

    ``c1`` and ``c2`` are screen coordinates. Nothing keeps them ordered, so
    after dragging ``c1`` may lie right of / below ``c2``.
    """

    def __init__(self):
        # Initial handle positions near the top-left of the window.
        self.c1 = (10, 10)
        self.c2 = (30, 30)

    def draw(self, screen):
        """Draw both handles as rings and the four box edges in green."""
        green = (0, 210, 0)
        for handle in (self.c1, self.c2):
            pygame.draw.circle(screen, green, handle, 10, 7)

        # The two remaining corners of the rectangle, mixing c1/c2 coordinates.
        corner_a = (self.c1[0], self.c2[1])
        corner_b = (self.c2[0], self.c1[1])
        for start in (self.c1, self.c2):
            for end in (corner_a, corner_b):
                pygame.draw.line(screen, green, start, end, 3)

    def on(self, moues_pos):
        """Return True if ``moues_pos`` is within ~21 px of either handle.

        450 is a squared-distance threshold, matching what ``dist`` returns.
        """
        grab_radius_sq = 450
        return any(dist(handle, moues_pos) < grab_radius_sq
                   for handle in (self.c1, self.c2))

In [36]:
class SimpleButton:
    """A filled rectangle with an optional centred label that reports clicks.

    Args:
        color: RGB fill colour.
        x, y: Top-left corner in screen pixels.
        width, height: Size in pixels.
        text: Optional label, rendered in black 30 pt Arial.
    """

    def __init__(self, color, x, y, width, height, text=''):
        self.color = color
        self.x = x
        self.y = y
        self.width = width
        self.height = height
        self.text = text
        # Created lazily on first draw: pygame.font may not be initialised when
        # the button is built, and constructing a SysFont every frame is
        # needlessly expensive. NOTE: the cached font belongs to the current
        # pygame session; build a new button after pygame.quit()/init().
        self._font = None

    def draw(self, screen, outline=None):
        """Draw the button, plus a 2 px ``outline`` border colour if given."""
        if outline:
            pygame.draw.rect(screen, outline,
                             (self.x - 2, self.y - 2, self.width + 4, self.height + 4))

        pygame.draw.rect(screen, self.color, (self.x, self.y, self.width, self.height))

        if self.text:
            if self._font is None:
                self._font = pygame.font.SysFont('Arial', 30)
            label = self._font.render(self.text, True, (0, 0, 0))
            screen.blit(label, (self.x + (self.width / 2 - label.get_width() / 2),
                                self.y + (self.height / 2 - label.get_height() / 2)))

    def on(self, pos):
        """Return True if ``pos`` lies strictly inside the button (edges excluded)."""
        return (self.x < pos[0] < self.x + self.width
                and self.y < pos[1] < self.y + self.height)


In [37]:
def visualize_masks(self, image_path, masks, output_path='output.png'):
    """
    Module-level duplicate of ``ImageSegmentationTool.visualize_masks``.

    Kept for backward compatibility only. It delegates to the method so the
    overlay logic lives in a single place and cannot drift. ``self`` is used
    only if it is an ``ImageSegmentationTool``; otherwise it is ignored.

    Returns:
        output_path, as returned by the method.
    """
    # The method reads no model state, so a bare instance (``__new__`` skips
    # __init__ and therefore skips loading weights) suffices when no tool is given.
    if isinstance(self, ImageSegmentationTool):
        tool = self
    else:
        tool = ImageSegmentationTool.__new__(ImageSegmentationTool)
    return tool.visualize_masks(image_path, masks, output_path)

In [38]:

# Background worker: the main loop starts it on a daemon thread when the
# "Segment" button is clicked, so SAM inference does not freeze the UI.
def perform_segmentation(image_path, pil_image, bounding_box, result_queue):
    """Split ``pil_image`` into two layers using a box-prompted SAM mask.

    Args:
        image_path: Path of the image on disk, or None when ``pil_image`` only
            exists in memory (e.g. a layer produced by a previous segmentation).
        pil_image: The PIL image being segmented; expected to be at the same
            resolution as the masks SAM returns.
        bounding_box: ``[x1, y1, x2, y2]`` in original-image pixels.
        result_queue: Receives exactly one tuple, either
            ``('success', (img_without, img_with))`` with two ``pygame_image``
            layers (object cut out / object only), or ``('error', message)``.

    Side effects:
        Writes ``easy_output.png`` (debug overlay, only when ``image_path`` is
        set) and ``edited_img_mask.png`` (white = object, black = background).
    """
    try:
        # SAM accepts in-memory images too; fall back to the PIL image so a
        # layer without a backing file can still be segmented again.
        source = image_path if image_path is not None else pil_image
        masks = seg_tool.segment_with_box(source, bounding_box)
        if len(masks) == 0:
            raise ValueError("SAM returned no mask for the selected box")

        if image_path is not None:
            # Debug overlay of the image that was actually segmented.
            seg_tool.visualize_masks(image_path, masks, 'easy_output.png')

        rgba = np.array(pil_image.convert('RGBA'))
        # Truthiness of each mask value decides membership (works for bool or 0/1).
        mask = np.asarray(masks[0]).astype(bool)
        if mask.shape != rgba.shape[:2]:
            raise ValueError(
                f"mask shape {mask.shape} does not match image shape {rgba.shape[:2]}"
            )

        # Vectorised with boolean indexing instead of per-pixel Python loops.
        mask_rgba = np.empty_like(rgba)
        mask_rgba[...] = (0, 0, 0, 255)
        mask_rgba[mask] = (255, 255, 255, 255)
        Image.fromarray(mask_rgba).convert('RGBA').save("edited_img_mask.png")

        without_obj = rgba.copy()
        without_obj[mask] = 0        # fully transparent where the object was
        only_obj = rgba.copy()
        only_obj[~mask] = 0          # fully transparent everywhere else

        img_without = pygame_image(image_path=None,
                                   image=Image.fromarray(without_obj).convert('RGBA'))
        img_with = pygame_image(image_path=None,
                                image=Image.fromarray(only_obj).convert('RGBA'))

        # Put result back to main thread
        result_queue.put(('success', (img_without, img_with)))
    except Exception as e:
        # Report any failure to the UI thread instead of dying silently.
        result_queue.put(('error', str(e)))

In [39]:
class movableImage():
    """Wraps a ``pygame_image`` layer so it can be drawn and hit-tested."""

    def __init__(self, pygame_image):
        self.pygame_image = pygame_image

    def draw(self, screen):
        """Blit the wrapped surface at its current screen rect."""
        layer = self.pygame_image
        screen.blit(layer.pygame_img, layer.image_rect)

    def on(self, mouse_pos):
        """Return True if ``mouse_pos`` is within 200 px of the image centre."""
        rect = self.pygame_image.image_rect
        centre_x = rect[0] + rect[2] / 2
        centre_y = rect[1] + rect[3] / 2
        # dist() is squared, hence the squared radius.
        return dist((centre_x, centre_y), mouse_pos) < 200 * 200




In [None]:
import threading
import queue

pygame.init()

# Worker threads post ('success' | 'error', payload) tuples here; the UI loop
# polls it once per frame so inference never blocks event handling.
segmentation_queue = queue.Queue()
# True while a segmentation thread is running; prevents starting a second one.
is_processing = False

screen = pygame.display.set_mode((1280, 720))
clock = pygame.time.Clock()
running = True
pygame.display.set_caption("Image Display with Text Input")

# Colors
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
ACTIVE_COLOR = pygame.Color('lightskyblue3')
PASSIVE_COLOR = pygame.Color('gray')

# Font setup
base_font = pygame.font.Font(None, 32)
label_font = pygame.font.Font(None, 24)

# Text input variables
user_text = ''
input_rect = pygame.Rect(50, 50, 700, 40)
color = PASSIVE_COLOR
active = False

# Image variables
input_img = None
edited_img = None
is_image_loaded = False

mousedown = False

# Movable (segmented-out) layers. index_lock is the index of the layer being
# dragged (None when nothing is grabbed); original_offset is the cursor
# position relative to that layer's top-left corner at grab time.
movable_images = []
index_lock = None
original_offset = (0, 0)

# bbox tool
box = bbox()
segmentButton = SimpleButton((255, 0, 0), 10, 720 - 50 - 10, 1280 - 2 * 10, 50, "Segment")


def _screen_box_to_image_box(c1, c2, img):
    """Convert on-screen box corners to ``[x1, y1, x2, y2]`` in ``img`` pixels.

    The corners may be in any order (either handle can be dragged past the
    other), so they are sorted first. The result is clamped to the image so a
    box hanging off the displayed image still yields a valid prompt.
    """
    rect = img.image_rect
    scale_x = img.og_width / rect.w
    scale_y = img.og_height / rect.h

    def to_pixels(value, origin, scale, limit):
        # int() truncates toward zero; out-of-image values are then clamped.
        return min(max(int((value - origin) * scale), 0), limit)

    left, right = sorted((c1[0], c2[0]))
    top, bottom = sorted((c1[1], c2[1]))
    return [
        to_pixels(left, rect.x, scale_x, img.og_width),
        to_pixels(top, rect.y, scale_y, img.og_height),
        to_pixels(right, rect.x, scale_x, img.og_width),
        to_pixels(bottom, rect.y, scale_y, img.og_height),
    ]


try:
    while running:
        mouse_pos = pygame.mouse.get_pos()
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False

            elif event.type == pygame.MOUSEBUTTONDOWN:
                mousedown = True
                active = bool(input_rect.collidepoint(event.pos))

                if segmentButton.on(event.pos):
                    if input_img is None:
                        print("Load an image before segmenting")
                    elif is_processing:
                        print("Segmentation already running")
                    else:
                        bounding_box = _screen_box_to_image_box(box.c1, box.c2, input_img)
                        print(bounding_box)
                        is_processing = True
                        thread = threading.Thread(
                            target=perform_segmentation,
                            args=(input_img.image_path, input_img.PIL_image,
                                  bounding_box, segmentation_queue),
                            daemon=True,
                        )
                        thread.start()
                elif not box.on(event.pos):
                    # Grab the top-most (last drawn) layer under the cursor,
                    # remembering where inside it the click landed.
                    for idx in range(len(movable_images) - 1, -1, -1):
                        if movable_images[idx].on(event.pos):
                            grabbed_rect = movable_images[idx].pygame_image.image_rect
                            index_lock = idx
                            original_offset = (event.pos[0] - grabbed_rect.x,
                                               event.pos[1] - grabbed_rect.y)
                            break

            elif event.type == pygame.MOUSEBUTTONUP:
                mousedown = False
                index_lock = None

            elif event.type == pygame.KEYDOWN and active:
                # Check for backspace
                if event.key == pygame.K_BACKSPACE:
                    user_text = user_text[:-1]
                # Check for Enter key to load image
                elif event.key == pygame.K_RETURN:
                    try:
                        loaded = pygame_image(user_text, image=None)
                    except (OSError, ValueError, pygame.error) as exc:
                        # A bad path must not take down the whole UI loop.
                        print(f"Could not load image {user_text!r}: {exc}")
                    else:
                        input_img = loaded
                        is_image_loaded = True
                # Add typed character
                else:
                    user_text += event.unicode

        if mousedown:
            if index_lock is not None:
                # Drag the grabbed layer, keeping the original grab offset.
                drag_rect = movable_images[index_lock].pygame_image.image_rect
                drag_rect.x = mouse_pos[0] - original_offset[0]
                drag_rect.y = mouse_pos[1] - original_offset[1]
            elif box.on(mouse_pos):
                # Move whichever box handle is closer to the cursor.
                if dist(box.c1, mouse_pos) < dist(box.c2, mouse_pos):
                    box.c1 = mouse_pos
                else:
                    box.c2 = mouse_pos

        screen.fill((30, 30, 30))

        # Collect a finished worker's result, if any, without blocking.
        try:
            result_type, result_data = segmentation_queue.get_nowait()
        except queue.Empty:
            pass
        else:
            is_processing = False
            if result_type == 'success':
                img_without, img_with = result_data
                input_img = img_without
                movable_images.append(movableImage(img_with))
                print('Segmentation Done')
            else:
                print(f"Segmentation error: {result_data}")

        # Draw label
        label_surface = label_font.render("Enter image path (press Enter to load):", True, (255, 255, 255))
        screen.blit(label_surface, (50, 20))

        # Change input box color based on active state
        color = ACTIVE_COLOR if active else PASSIVE_COLOR

        # Draw input box rectangle
        pygame.draw.rect(screen, color, input_rect, 2)

        # Render and draw text
        text_surface = base_font.render(user_text, True, (255, 255, 255))
        screen.blit(text_surface, (input_rect.x + 5, input_rect.y + 5))

        # Adjust input box width
        input_rect.w = max(700, text_surface.get_width() + 10)

        # Display loaded image if available
        if is_image_loaded:
            screen.blit(input_img.pygame_img, input_img.image_rect)

        for mimg in movable_images:
            mimg.draw(screen)

        box.draw(screen)
        segmentButton.draw(screen)

        # Update the display
        pygame.display.flip()

        # Control frame rate
        clock.tick(60)
finally:
    # Always release the window, even if the loop raises; otherwise the
    # pygame window is left hanging in the notebook kernel.
    pygame.quit()

[190, 155, 2885, 1890]

image 1/1 c:\Users\ayush\OneDrive\Desktop\adobe_mock\easy_input.jpg: 1024x1024 1 0, 784.9ms
Speed: 12.9ms preprocess, 784.9ms inference, 27.2ms postprocess per image at shape (1, 3, 1024, 1024)
Segmentation Done
