In [1]:
import os

# Tell the OpenMP runtime to tolerate a second copy of itself in this process
# instead of aborting. Set before the heavy imports below so it is already in
# the environment when they load.
# NOTE(review): presumably needed because two dependencies each bundle their
# own OpenMP library - confirm it is still required.
os.environ.update({'KMP_DUPLICATE_LIB_OK': 'TRUE'})

In [2]:
import os
import sys

# Location of the ``iris`` checkout that is put on the import path. Defaults to
# the original developer path; set the IRIS_PROJECT_ROOT environment variable
# to run this notebook from a different machine or checkout.
IRIS_PROJECT_ROOT = os.environ.get('IRIS_PROJECT_ROOT', r'c:\Users\ice\projects\iris')

# Guard the append so re-running this cell does not stack duplicate entries.
if IRIS_PROJECT_ROOT not in sys.path:
    sys.path.append(IRIS_PROJECT_ROOT)

from iris.config.data_pipeline_config_manager import DataPipelineConfigManager
from iris.config.embedding_pipeline_config_manager import EmbeddingPipelineConfigManager
from iris.data_pipeline.mongodb_manager import MongoDBManager
from iris.embedding_pipeline.embedding_handler import EmbeddingHandler
from iris.embedding_pipeline.embedding_database import EmbeddingDatabase
from tqdm.notebook import tqdm

In [3]:
# Shop whose images/masks this notebook processes; change here to target another shop.
SHOP_NAME = "nikolaj_storm"

# Configuration managers for the data pipeline and the embedding pipeline.
data_config = DataPipelineConfigManager()
embedding_config = EmbeddingPipelineConfigManager()

# Per-shop settings plus the shared MongoDB settings.
shop_config = data_config.shop_configs[SHOP_NAME]
mongodb_config = data_config.mongodb_config

# MongoDB access for the selected shop.
mongodb_manager = MongoDBManager(shop_config, mongodb_config)

# Embedding handler (built from the CLIP config) and the embedding database,
# which is loaded immediately so later cells can query it.
embedding_handler = EmbeddingHandler(embedding_config.clip_config)
embedding_db = EmbeddingDatabase(embedding_config.database_config, shop_config)
embedding_db.load()

In [4]:
def get_best_product_for_mask(mask_hash, top_k=20):
    """Pick the most likely product for a mask via distance-weighted voting.

    The mask's stored embedding is used to query ``embedding_db`` for its
    ``top_k`` nearest entries. Every neighbour that resolves to a document in
    the image-metadata collection votes for that document's
    ``source_product`` with a weight of ``1 / distance``; the product with the
    largest accumulated weight wins.

    Args:
        mask_hash: Identifier of the mask; must be present in
            ``embedding_db.ids``.
        top_k: Number of results requested from ``embedding_db.search``. The
            query mask itself normally occupies one of these slots.

    Returns:
        The product hash with the highest total weight, or ``None`` when no
        neighbour could be resolved to a product.

    Raises:
        ValueError: If ``mask_hash`` is not found by ``embedding_db.ids.index``
            (assuming ``ids`` is a list).
    """
    # Look up the mask's own embedding and search around it.
    hash_index = embedding_db.ids.index(mask_hash)
    query_embedding = embedding_db.embeddings[hash_index]
    results = embedding_db.search(query_embedding, k=top_k)

    collection = mongodb_manager.mongodb_config.image_metadata_collection
    scores = {}

    for hash_val, dist in results:
        # Drop the query mask by identity rather than by position: with ties or
        # an approximate index it is not guaranteed to be ranked first, and
        # blindly skipping results[0] could discard the best real neighbour.
        if hash_val == mask_hash:
            continue

        source_img = mongodb_manager.find_one(collection, {'image_hash': hash_val})
        if not source_img:
            # Neighbour is not a known source image; it cannot vote.
            continue

        # A document without a product link is skipped instead of raising
        # KeyError and aborting the whole vote.
        product_hash = source_img.get('source_product')
        if not product_hash:
            continue

        # Inverse-distance weight; the epsilon avoids division by zero. The
        # clamp keeps a slightly negative distance (rounding error) from
        # producing a negative or exploding weight.
        # NOTE(review): assumes ``search`` returns distances (smaller = closer),
        # not similarity scores - confirm.
        weight = 1.0 / (max(dist, 0.0) + 1e-6)
        scores[product_hash] = scores.get(product_hash, 0.0) + weight

    if not scores:
        return None
    return max(scores, key=scores.get)

In [6]:
# Fetch every image document that carries a non-empty ``masks`` array.
image_collection = mongodb_manager.get_collection(
    mongodb_manager.mongodb_config.image_metadata_collection
)
images_with_masks = list(
    image_collection.find({'masks': {'$exists': True, '$ne': []}})
)
print(f"Found {len(images_with_masks)} images with masks")

# Report the total amount of work before the loop starts.
total_masks = sum(map(len, (img.get('masks', []) for img in images_with_masks)))
print(f"Total masks to process: {total_masks}")

# Walk every mask of every image and attach a product hash plus a mask point.
for image in tqdm(images_with_masks, desc="Processing images"):
    for mask in image.get('masks', []):
        mask_hash = mask.get('mask_hash')

        # Nothing to do for masks without a hash, or for masks that already
        # carry both a product_hash and a mask_point.
        if not mask_hash or (mask.get('product_hash') and mask.get('mask_point')):
            continue

        try:
            product_hash = get_best_product_for_mask(mask_hash)
            if not product_hash:
                # No product could be determined; leave the mask untouched.
                continue

            # Address this mask inside its parent image document.
            filter_query = {
                'image_hash': image['image_hash'],
                'masks.mask_hash': mask_hash
            }

            # The first entry of the mask's point_coords becomes the mask point.
            point_coords = mask['point_coords'][0]

            # ``masks.$`` is MongoDB's positional operator for the array element
            # matched by the filter.
            # NOTE(review): presumably update_one wraps this in ``$set`` - confirm.
            update_data = {
                'masks.$.mask_point': {
                    'x': float(point_coords[0]),
                    'y': float(point_coords[1])
                },
                'masks.$.product_hash': product_hash
            }

            success = mongodb_manager.update_one(
                mongodb_manager.mongodb_config.image_metadata_collection,
                filter_query,
                update_data,
                upsert=False
            )

            if not success:
                print(f"Failed to update mask {mask_hash}")
            else:
                print(f"Updated mask {mask_hash} with product {product_hash[:8]}...")

        except Exception as e:
            # Best effort: report the failure and carry on with the next mask.
            print(f"Error processing mask {mask_hash}: {str(e)}")

print("\nProcessing complete!")

Found 366 images with masks
Total masks to process: 4403


Processing images:   0%|          | 0/366 [00:00<?, ?it/s]

Updated mask eab52cd1cda39c117bb63293ae8341aa with product 7462886b...
Updated mask d032c850fea4f285a202451826af7ff7 with product 7462886b...
Updated mask d032c850fea4f285a202451826af7ff7 with product 7462886b...
Updated mask 6d082230e3185cf04d0b17c4cc8cdaea with product c7f9e163...
Updated mask 6d082230e3185cf04d0b17c4cc8cdaea with product c7f9e163...
Updated mask 878e499614a64e8631dd940a754eabaf with product c7f9e163...
Updated mask 878e499614a64e8631dd940a754eabaf with product c7f9e163...
Updated mask ee7965f75187aedb85b0dd503e3f654f with product c7f9e163...
Updated mask ee7965f75187aedb85b0dd503e3f654f with product c7f9e163...
Updated mask c20aa9253818877d1ed1bf16e3d92554 with product c7f9e163...
Updated mask c20aa9253818877d1ed1bf16e3d92554 with product c7f9e163...
Updated mask 5359190ef50f22b3ffa27a9a6eb18c8e with product c7f9e163...
Updated mask 5359190ef50f22b3ffa27a9a6eb18c8e with product c7f9e163...
Updated mask 8a01affd0bb3e7be7b13a70149069ba1 with product ed833bf4...
Update