In [None]:
import os
# Allow more than one copy of the OpenMP runtime to be loaded into this process.
# NOTE(review): presumably a workaround for the "OMP: Error #15 ... already
# initialized" abort raised when the libraries imported in the next cell each
# bundle their own OpenMP runtime — confirm it is still required. The runtime's
# own error message describes this switch as an unsafe, unsupported workaround.
# It is set in the first cell so that it is in place before those imports run.
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'

In [None]:
import sys
# Make the local `iris` checkout importable for the `iris.*` imports below.
# NOTE(review): hard-coded, machine-specific absolute path — the notebook will
# not import on any other machine. Consider installing the package into the
# environment or reading the location from configuration instead.
sys.path.append(r'c:\Users\ice\projects\iris')

from PIL import Image
from tqdm.notebook import tqdm
import plotly.graph_objects as go
from IPython.display import display, HTML
import base64
from io import BytesIO

from iris.config.data_pipeline_config_manager import DataPipelineConfigManager
from iris.config.embedding_pipeline_config_manager import EmbeddingPipelineConfigManager
from iris.data_pipeline.mongodb_manager import MongoDBManager
from iris.embedding_pipeline.embedding_handler import EmbeddingHandler
from iris.embedding_pipeline.embedding_database import EmbeddingDatabase

In [None]:
# Load the two pipeline configurations: one for the data pipeline, one for
# the embedding pipeline.
data_config = DataPipelineConfigManager()
embedding_config = EmbeddingPipelineConfigManager()

# Pick the shop to work on. The MongoDB manager and the embedding database
# below are both constructed with this shop's configuration.
shop_config = data_config.shop_configs["nikolaj_storm"]  # key into the configured shops
mongodb_config = data_config.mongodb_config

# MongoDB access object, built from the shop config and the MongoDB settings
mongodb_manager = MongoDBManager(shop_config, mongodb_config)

# Embedding handler (built from the CLIP config) and the embedding database
# (built from the database config plus the selected shop)
embedding_handler = EmbeddingHandler(embedding_config.clip_config)
embedding_db = EmbeddingDatabase(embedding_config.database_config, shop_config)

In [None]:
# Materialise every image-metadata document up front: a cursor can only be
# iterated once, and `image_dataset` is reused by later cells to map result
# hashes back to their source images.
image_dataset = list(mongodb_manager.get_collection(
    mongodb_manager.mongodb_config.image_metadata_collection
).find())


# tqdm reads the total from len() of the list, so no explicit `total=` is needed.
for image_data in tqdm(image_dataset, desc="Processing images"):
    # Image.open() is lazy and owns the underlying file handle, while convert()
    # returns a separate in-memory copy. Both are managed by the `with` so the
    # file is always closed, not just the converted copy.
    with Image.open(image_data['local_path']) as source_img, \
            source_img.convert("RGBA") as img:
        width, height = img.size  # pixel dimensions, used to scale bboxes below

        # One augmented embedding for the full image, stored under the image hash
        embedding = embedding_handler.get_augmented_embedding(img)
        embedding_db.add_embedding(embedding, id=image_data['image_hash'])

        # One additional embedding per localization, when the document has any
        if 'localizations' in image_data:
            for localization in image_data['localizations']:
                # bbox is [x, y, width, height]; it is scaled by the image size
                # here, i.e. treated as fractions of the image dimensions.
                bbox = localization['bbox']
                bbox_img = img.crop((
                    bbox[0] * width,
                    bbox[1] * height,
                    (bbox[0] + bbox[2]) * width,
                    (bbox[1] + bbox[3]) * height
                ))

                # Cropped region is embedded the same way and stored under
                # the localization's own hash
                mask_embedding = embedding_handler.get_augmented_embedding(bbox_img)
                embedding_db.add_embedding(mask_embedding, id=localization['localization_hash'])

# Save the embeddings database to disk
embedding_db.save()
print(f"Saved {len(embedding_db.ids)} embeddings to {embedding_db.database_directory}")
print(f"Complete! Database created with {len(embedding_db.ids)} embeddings")

In [None]:
# Hash of the stored embedding to use as the query.
# NOTE(review): hard-coded id — it must already be present in
# embedding_db.ids, otherwise the lookup below fails (ValueError if `ids` is a
# plain list — confirm the container type).
query_hash = '3068271dd98042007498c8ee1a4604dc'

# Position of the query hash in the database's id sequence; the embedding at
# the same position is used as the query vector. This relies on `ids` and
# `embeddings` being index-aligned — TODO confirm against EmbeddingDatabase.
hash_index = embedding_db.ids.index(query_hash)
query_embedding = embedding_db.embeddings[hash_index]

# Ask the database for the 50 nearest entries. Later cells iterate `results`
# as (hash, distance) pairs.
# NOTE(review): the query embedding is itself stored in the database, so it
# presumably shows up among its own neighbours — confirm.
results = embedding_db.search(query_embedding, k=50)
print("Nearest neighbors for query hash", query_hash, ":", results)

In [None]:
# Sub-image records an image document may carry, as
# (list key, hash key, bbox is normalized). 'localizations' is what the
# embedding cell indexes (bbox scaled by the image size there); 'masks' is the
# older layout this helper originally supported, with bbox in pixels.
_REGION_SCHEMAS = (
    ('localizations', 'localization_hash', True),
    ('masks', 'mask_hash', False),
)


def get_hash_source(hash_val, dataset=None):
    """Resolve an embedding id back to the record it was computed from.

    Args:
        hash_val: Id to look up; either an image hash or the hash of a
            region (localization / mask) inside an image.
        dataset: Iterable of image metadata dicts to search. Defaults to the
            notebook-level ``image_dataset``.

    Returns:
        ``{'type': 'image', 'path': ...}`` for a whole image, or
        ``{'type': 'mask', 'mask_data': ..., 'parent_image': ...,
        'normalized': bool}`` for a region, where ``normalized`` says whether
        the region's bbox is expressed as fractions of the image size.
        ``None`` when the hash is not found.
    """
    if dataset is None:
        dataset = image_dataset  # list loaded from MongoDB in an earlier cell

    for data in dataset:
        # Whole-image hash takes precedence within a document
        if data['image_hash'] == hash_val:
            return {'type': 'image', 'path': data['local_path']}

        # Otherwise look through every kind of region the document may hold
        for list_key, hash_key, normalized in _REGION_SCHEMAS:
            for region in data.get(list_key) or ():
                if hash_key in region and region[hash_key] == hash_val:
                    return {
                        'type': 'mask',
                        'mask_data': region,
                        'parent_image': data,
                        'normalized': normalized,
                    }
    return None


def _thumbnail_img_tag(img, size):
    """Shrink ``img`` to fit ``size`` and return it as an inline JPEG <img> tag."""
    # JPEG cannot store alpha or palette data; saving such modes raises.
    if img.mode not in ('RGB', 'L'):
        img = img.convert('RGB')
    img.thumbnail(size)

    buffered = BytesIO()
    img.save(buffered, format='JPEG', quality=70)
    img_b64 = base64.b64encode(buffered.getvalue()).decode()
    return f'<img src="data:image/jpeg;base64,{img_b64}" style="max-width:none">'


def get_image_html(source, size=(100, 140)):
    """Create an HTML img tag for a source returned by ``get_hash_source``.

    Args:
        source: Dict with ``'type'`` equal to ``'image'`` (plus ``'path'``)
            or anything else for a region (plus ``'mask_data'`` and
            ``'parent_image'``; optional ``'normalized'`` flag, default
            ``False`` meaning the bbox is in pixels).
        size: Maximum (width, height) of the thumbnail in pixels.

    Returns:
        An ``<img>`` tag string with the thumbnail embedded as base64 JPEG.
    """
    if source['type'] == 'image':
        with Image.open(source['path']) as img:
            return _thumbnail_img_tag(img, size)

    # Region: crop the bounding box out of the parent image first
    region = source['mask_data']
    parent = source['parent_image']
    with Image.open(parent['local_path']) as img:
        x, y, w, h = region['bbox'][:4]  # [x, y, width, height]
        if source.get('normalized', False):
            # Same scaling as the embedding cell: fractions -> pixels
            width, height = img.size
            box = (x * width, y * height, (x + w) * width, (y + h) * height)
        else:
            box = (x, y, x + w, y + h)
        return _thumbnail_img_tag(img.crop(box), size)

In [None]:
# Neighbour rank on the x-axis, distance to the query on the y-axis
distances = [dist for _, dist in results]
indices = list(range(len(distances)))

# One hover label per neighbour: rank, hash and distance
hover_text = []
for i, (hash_val, dist) in enumerate(results):
    hover_text.append(f'Index: {i}<br>Hash: {hash_val}<br>Distance: {dist:.4f}')

trace = go.Scatter(
    x=indices,
    y=distances,
    mode='lines+markers',
    marker={'size': 8},
    text=hover_text,
    hovertemplate='%{text}<extra></extra>',
)
layout_options = {
    'title': 'Image Similarity Plot',
    'xaxis_title': 'Index',
    'yaxis_title': 'Distance',
    'width': 800,
    'height': 400,
    'showlegend': False,
}

fig = go.Figure(data=trace)
fig.update_layout(**layout_options)
fig.show()

# Thumbnail grid rendered underneath the plot; hashes that cannot be traced
# back to a source record are left out.
html = ['<div style="display: flex; flex-wrap: wrap; gap: 10px;">']

for i, (hash_val, dist) in enumerate(results):
    source = get_hash_source(hash_val)
    if source is None:
        continue

    img_html = get_image_html(source, size=(150, 210))
    html.append(f"""
    <div style='text-align: center; border: 1px solid #ddd; padding: 5px;'>
        {img_html}
        <br>
        <small>Index: {i}</small><br>
        <small>Distance: {dist:.4f}</small><br>
        <small>Hash: {hash_val}</small>
        <small>{'(Mask)' if source['type'] == 'mask' else ''}</small>
    </div>
    """)

html.append('</div>')
display(HTML(''.join(html)))