In [None]:
import sys
sys.path.append(r'c:\Users\ice\projects\iris')

from pathlib import Path
import numpy as np
from tqdm.notebook import tqdm

from iris.config.object_localization_pipeline_config_manager import ObjectLocalizationPipelineConfigManager
from iris.config.data_pipeline_config_manager import DataPipelineConfigManager
from iris.data_pipeline.mongodb_manager import MongoDBManager
from iris.object_localization_pipeline.segmentation_handler import SegmentationHandler
from iris.utils.image_utils import load_image
from iris.utils.data_utils import display_image_summary

In [None]:
# Initialize configuration managers
segmentation_config = ObjectLocalizationPipelineConfigManager()
data_config = DataPipelineConfigManager()
shop_config = data_config.shop_configs["nikolaj_storm"]  # Select shop
mongodb_config = data_config.mongodb_config


def _has_transparency(image: np.ndarray) -> bool:
    """Return True if *image* has an alpha channel containing any non-opaque pixel.

    Only 3-D arrays with exactly 4 channels are treated as carrying alpha.
    For integer dtypes "fully opaque" is the dtype's maximum (255 for uint8,
    65535 for uint16); for any other dtype the historical threshold of 255
    is used.
    """
    if image.ndim != 3 or image.shape[2] != 4:
        return False
    alpha = image[:, :, 3]
    if np.issubdtype(alpha.dtype, np.integer):
        opaque = np.iinfo(alpha.dtype).max
    else:
        opaque = 255
    return bool(np.any(alpha < opaque))


# Create MongoDB manager
mongodb_manager = MongoDBManager(shop_config, mongodb_config)

try:
    # Build the segmentation handler inside the try block so the MongoDB
    # connection is still closed if model initialisation fails.
    segmentation_handler = SegmentationHandler(segmentation_config.sam2_config)

    # Fetch images that have no masks yet and are not already flagged as
    # having transparency.
    image_metadata = mongodb_manager.get_collection(mongodb_config.image_metadata_collection)
    test_images = list(image_metadata.find({
        "masks": {"$exists": False},
        "has_transparency": {"$ne": True}  # Skip images marked as having transparency
    }))

    if not test_images:
        print("No new images found that need segmentation")
    else:
        print(f"Found {len(test_images)} images that need segmentation")

    # Process each test image with progress bar
    for image_data in tqdm(test_images, desc='Processing images'):
        # Validate the path before using it; a document without 'local_path'
        # is skipped rather than aborting the whole run with a KeyError.
        local_path = image_data.get('local_path')
        if not local_path:
            print("No local path found for this image")
            continue

        print(f"Processing image: {local_path}")
        image_path = Path(local_path)

        try:
            # Load without RGB conversion so a possible alpha channel is kept.
            original_image = load_image(image_path, target_format="numpy")
            if _has_transparency(original_image):
                print(f"Skipping image with transparency: {image_path}")
                # Persist the flag so the query above excludes this image next time.
                image_metadata.update_one(
                    {"_id": image_data["_id"]},
                    {"$set": {"has_transparency": True}}
                )
                continue

            # Reload with RGB conversion for the segmentation model.
            image = load_image(image_path, target_format="numpy", ensure_rgb=True)

            # Generate masks
            masks = segmentation_handler.segment_image(image)

            # Save segmentation metadata
            segmentation_handler.save_segmentation_metadata(
                image_data['image_hash'],
                masks,
                mongodb_manager
            )

            # Display results
            display_image_summary(
                mongodb_manager=mongodb_manager,
                image_hash=image_data['image_hash']
            )

        except Exception as e:
            # Best-effort batch: report which image failed and move on.
            print(f"Error processing image {image_path}: {e}")

finally:
    # Close MongoDB connection
    mongodb_manager.close()