In [None]:
import os

# NOTE(review): presumably a workaround for Intel OpenMP's "libiomp5 already
# initialized" abort, which happens when two imported libraries each bundle
# their own OpenMP runtime (common with torch + MKL-linked packages on
# Windows). It only takes effect if set before those libraries are imported,
# which is why it lives in the very first cell. It masks the duplicate runtime
# rather than fixing it — confirm it is still needed.
os.environ.update(KMP_DUPLICATE_LIB_OK='TRUE')

In [None]:
import sys
sys.path.append(r'c:\Users\ice\projects\iris')

from iris.config.data_pipeline_config_manager import DataPipelineConfigManager
from iris.config.embedding_pipeline_config_manager import EmbeddingPipelineConfigManager
from iris.data_pipeline.mongodb_manager import MongoDBManager
from iris.embedding_pipeline.embedding_handler import EmbeddingHandler
from iris.embedding_pipeline.embedding_database import EmbeddingDatabase
from tqdm.notebook import tqdm

In [None]:
# Shop whose images/products this notebook processes; change here to run
# against another configured shop.
SHOP_NAME = "nikolaj_storm"

# Initialize configuration managers
data_config = DataPipelineConfigManager()
embedding_config = EmbeddingPipelineConfigManager()

# Fail early with the list of valid choices instead of a bare KeyError.
if SHOP_NAME not in data_config.shop_configs:
    raise KeyError(
        f"Unknown shop {SHOP_NAME!r}; available shops: {sorted(data_config.shop_configs)}"
    )
shop_config = data_config.shop_configs[SHOP_NAME]
mongodb_config = data_config.mongodb_config

# Create MongoDB manager
mongodb_manager = MongoDBManager(shop_config, mongodb_config)

# Initialize the EmbeddingHandler and Database.
# NOTE(review): embedding_handler is not used by the cells visible here;
# presumably kept for interactive use — verify before removing.
embedding_handler = EmbeddingHandler(embedding_config.clip_config)
embedding_db = EmbeddingDatabase(embedding_config.database_config, shop_config)
embedding_db.load()

In [None]:
def get_product_predictions(localization_hash, top_k=20, eps=1e-6):
    """Score candidate products for a localization by inverse-distance kNN voting.

    Looks up the stored embedding for ``localization_hash`` in the global
    ``embedding_db``, retrieves its ``top_k`` nearest neighbours, resolves each
    neighbour to the ``source_product`` of the image-metadata document that
    contains it (as an image or as one of its localizations), and adds
    ``1 / (distance + eps)`` to that product's score.

    Args:
        localization_hash: Id of the query localization; must be present in
            ``embedding_db.ids``.
        top_k: Number of results requested from ``embedding_db.search``. The
            query itself is excluded from voting if it appears in the results.
        eps: Small constant added to every distance so exact matches
            (distance 0) do not divide by zero.

    Returns:
        dict mapping product hash -> score (Python ``float``), ordered by
        descending score (ties keep first-seen order). Empty if no neighbour
        resolves to a product.

    Raises:
        ValueError: if ``localization_hash`` is not in the embedding database.
    """
    # NOTE(review): list.index is O(n) per call; if ids is a plain list and this
    # is called for every localization, a precomputed {id: index} map would help.
    try:
        hash_index = embedding_db.ids.index(localization_hash)
    except ValueError:
        raise ValueError(
            f"Localization {localization_hash!r} not found in embedding database"
        ) from None
    query_embedding = embedding_db.embeddings[hash_index]
    results = embedding_db.search(query_embedding, k=top_k)

    collection_name = mongodb_manager.mongodb_config.image_metadata_collection
    scores = {}

    for hash_val, dist in results:
        # Exclude the query by identity rather than assuming it is results[0]:
        # with ties at distance 0 (e.g. duplicate crops) another item can rank
        # first, and positional skipping would drop it and let the query vote.
        if hash_val == localization_hash:
            continue

        # The neighbour may be a whole image or a localization inside one.
        query = {
            '$or': [
                {'image_hash': hash_val},
                {'localizations.localization_hash': hash_val}
            ]
        }
        source_img = mongodb_manager.find_one(collection_name, query)

        # Documents without a source product cannot vote; skip instead of
        # raising KeyError or counting votes under a None key.
        product_hash = source_img.get('source_product') if source_img else None
        if product_hash is None:
            continue

        weight = 1.0 / (dist + eps)
        scores[product_hash] = scores.get(product_hash, 0.0) + weight

    # float() converts possible numpy scalars (from search distances) into
    # plain Python floats so the dict can be stored in MongoDB.
    return {
        product_hash: float(score)
        for product_hash, score in sorted(scores.items(), key=lambda item: item[1], reverse=True)
    }

In [None]:
# When True, localizations that already carry both a `product_hash` and a
# `localization_point` are left untouched (e.g. when resuming an interrupted
# run). False re-scores everything, matching the previous behaviour where the
# skip had been disabled.
# NOTE(review): this cell writes `product_predictions`, not `product_hash`;
# confirm which field should mark a localization as already processed.
SKIP_EXISTING = False

# Get all images with localizations
image_collection = mongodb_manager.get_collection(mongodb_manager.mongodb_config.image_metadata_collection)
images_with_localizations = list(image_collection.find({'localizations': {'$exists': True, '$ne': []}}))

print(f"Found {len(images_with_localizations)} images with localizations")

# `or []` guards against a `localizations` field that exists but is null,
# which the `$exists` / `$ne: []` filter above does not exclude.
total_localizations = sum(len(img.get('localizations') or []) for img in images_with_localizations)
print(f"Total localizations to process: {total_localizations}")

n_updated = n_failed = n_errors = n_skipped = n_no_predictions = 0

for image in tqdm(images_with_localizations, desc="Processing images"):
    for localization in image.get('localizations') or []:
        # .get() so an entry without a hash is skipped, instead of raising a
        # KeyError outside the try-block below and aborting the whole run.
        localization_hash = localization.get('localization_hash')
        if not localization_hash:
            continue

        if SKIP_EXISTING and localization.get('product_hash') and localization.get('localization_point'):
            n_skipped += 1
            continue

        try:
            # Weighted kNN vote over neighbouring embeddings -> {product_hash: score}
            product_predictions = get_product_predictions(localization_hash)
            if not product_predictions:
                n_no_predictions += 1
                continue

            # The `$` in the update paths below is MongoDB's positional
            # operator: it targets the array element matched by the
            # `localizations.localization_hash` condition in this filter.
            filter_query = {
                'image_hash': image['image_hash'],
                'localizations.localization_hash': localization_hash
            }

            # Centre of the bounding box.
            # NOTE(review): assumes bbox is (x, y, width, height) — confirm;
            # for (x1, y1, x2, y2) boxes the centre would be the corner midpoint.
            bbox = localization['bbox']
            center_x = bbox[0] + bbox[2] / 2
            center_y = bbox[1] + bbox[3] / 2

            update_data = {
                'localizations.$.localization_point': {
                    'x': float(center_x),
                    'y': float(center_y)
                },
                'localizations.$.product_predictions': product_predictions
            }

            # presumably MongoDBManager.update_one wraps update_data in `$set`
            # and returns a truthy value on success — confirm.
            success = mongodb_manager.update_one(
                mongodb_manager.mongodb_config.image_metadata_collection,
                filter_query,
                update_data,
                upsert=False
            )

            if success:
                n_updated += 1
                print(f"Updated localization {localization_hash} with {len(product_predictions)} predictions")
            else:
                n_failed += 1
                print(f"Failed to update localization {localization_hash}")

        except Exception as e:
            # Best-effort batch: report and continue with the next localization.
            n_errors += 1
            print(f"Error processing localization {localization_hash}: {str(e)}")

print("\nProcessing complete!")
print(
    f"Updated: {n_updated}, failed updates: {n_failed}, errors: {n_errors}, "
    f"no predictions: {n_no_predictions}, skipped: {n_skipped}"
)