In [None]:
# Install the project in the parent directory (`..`) in editable mode, presumably so the
# `iris` package imported below resolves to the working tree. Run once per environment.
%pip install -e ..

In [1]:
import os
import sys

# Fallback so the project package is importable even if the editable install in the
# first cell was skipped. The project root is taken to be the parent of the current
# working directory -- the same `..` that `%pip install -e ..` installs -- instead of a
# machine-specific absolute path. NOTE(review): assumes the notebook runs from a
# directory directly under the project root; confirm.
PROJECT_ROOT = os.path.abspath(os.pardir)
# Membership check keeps re-runs of this cell from appending duplicate entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import distinctipy

from iris.config.segmentation_pipeline_config_manager import SegmentationPipelineConfigManager
from iris.config.data_pipeline_config_manager import DataPipelineConfigManager
from iris.data_pipeline.mongodb_manager import MongoDBManager
from iris.segmentation_pipeline.segmentation_handler import SegmentationHandler
from iris.utils import load_image
from iris.segmentation_pipeline.utils import convert_mask_format

In [4]:
def display_image_with_masks(image, masks, title="Segmentation Results", alpha=0.5, pastel_factor=0.7):
    """Show an image side by side with the same image overlaid by its segmentation masks.

    Args:
        image: Image array; its first two dimensions (height, width) size every overlay.
            Presumably an RGB array as returned by ``load_image(..., ensure_rgb=True)`` -- confirm.
        masks: Sized sequence of mask records. Each has a ``"segmentation"`` entry that
            ``convert_mask_format(..., 'binary')`` turns into a mask indexable against
            ``image.shape[:2]``.
        title: Figure-level title.
        alpha: Opacity of each mask overlay (0 = invisible, 1 = opaque).
        pastel_factor: Forwarded to ``distinctipy.get_colors``; controls how pale the colours are.
    """
    fig, (ax_original, ax_masks) = plt.subplots(1, 2, figsize=(15, 5))

    # Left panel: the untouched image.
    ax_original.imshow(image)
    ax_original.set_title("Original Image")
    ax_original.axis("off")

    # Right panel: the image with masks drawn on top. Title/axis are set once, outside
    # the mask loop, so they are applied even when `masks` is empty.
    ax_masks.imshow(image)
    ax_masks.set_title("Image with Masks")
    ax_masks.axis("off")

    if len(masks) > 0:
        # One visually distinct RGB colour per mask, extended to RGBA with the given opacity.
        rgba_colors = [
            (*rgb, alpha)
            for rgb in distinctipy.get_colors(len(masks), pastel_factor=pastel_factor)
        ]
        height, width = image.shape[:2]
        for color, mask_data in zip(rgba_colors, masks):
            mask = convert_mask_format(mask_data["segmentation"], 'binary').astype(bool)
            # Fully transparent RGBA layer, coloured only where this mask is set.
            overlay = np.zeros((height, width, 4))
            overlay[mask] = color
            ax_masks.imshow(overlay)

    fig.suptitle(title)
    fig.tight_layout()
    plt.show()

In [None]:
# Run settings: which shop to read from and how many images to sample for this test run.
SHOP_NAME = "nikolaj_storm"
N_TEST_IMAGES = 5

# Initialize configuration managers
segmentation_config = SegmentationPipelineConfigManager()
data_config = DataPipelineConfigManager()
shop_config = data_config.shop_configs[SHOP_NAME]
mongodb_config = data_config.mongodb_config

# Create the MongoDB manager. Everything after this point runs inside try/finally so
# the connection is closed even if a later step -- including building the
# segmentation handler -- raises.
mongodb_manager = MongoDBManager(shop_config, mongodb_config)

try:
    segmentation_handler = SegmentationHandler(segmentation_config.sam2_config)

    # Sample a few image-metadata records to segment.
    image_metadata = mongodb_manager.get_collection(mongodb_config.image_metadata_collection)
    test_images = list(image_metadata.find().limit(N_TEST_IMAGES))

    for image_data in test_images:
        # Validate the path before using it: indexing a missing 'local_path' here would
        # raise outside the per-image try below and abort the whole run.
        local_path = image_data.get('local_path')
        if not local_path:
            print(f"No local path found for image {image_data.get('image_hash', '<unknown>')}")
            continue

        print(f"Processing image: {local_path}")
        image_path = Path(local_path)

        try:
            image = load_image(image_path, target_format="numpy", ensure_rgb=True)
            masks = segmentation_handler.segment_image(image)

            # Store this image's mask metadata via the MongoDB manager.
            segmentation_handler.save_segmentation_metadata(
                image_data['image_hash'],
                masks,
                mongodb_manager,
            )

            display_image_with_masks(image, masks, f"Segmentation Results - {image_path.name}")
        except Exception as e:
            # Best-effort batch: report which image failed and move on to the next one.
            print(f"Error processing image {local_path}: {type(e).__name__}: {e}")
finally:
    mongodb_manager.close()