In [None]:
import os

# Set before torch/numpy are imported in the next cell, presumably so the
# OpenMP runtime sees it at load time. This is the usual workaround for the
# Intel OpenMP "OMP: Error #15 ... libiomp5 already initialized" abort that
# occurs when two libraries each bundle their own OpenMP runtime.
# NOTE(review): Intel describes this flag as an unsafe workaround that may
# cause crashes or silently incorrect results; prefer fixing the environment.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")

# Text Embeddings with OpenCLIP

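This notebook generates text embeddings with OpenCLIP, similar to how we handle image embeddings in our existing codebase. It then embeds the descriptions of a few random products. For one image of each product, it ranks that image's localizations by their distance to the description embedding, so the results can be inspected by eye.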

In [None]:
import open_clip
import torch
from pymongo import MongoClient
import numpy as np
from iris.config.embedding_pipeline_config_manager import EmbeddingPipelineConfigManager
from iris.config.data_pipeline_config_manager import DataPipelineConfigManager
from iris.embedding_pipeline.embedding_database import EmbeddingDatabase
from iris.data_pipeline.mongodb_manager import MongoDBManager
import iris.utils.data_utils as data_utils



# Initialize CLIP model.
# NOTE(review): distances computed later are only meaningful if the stored
# localization embeddings came from this same architecture + checkpoint —
# confirm against the embedding pipeline's model config.
model_name = "ViT-B-32"
device = "cuda" if torch.cuda.is_available() else "cpu"

# open_clip returns (model, train-time image transform, eval-time image
# transform); the train transform is discarded. `preprocess` is not
# referenced by the cells shown in this notebook.
model, _, preprocess = open_clip.create_model_and_transforms(
    model_name,
    pretrained="laion2b_s34b_b79k",
    device=device
)
model.eval()  # inference mode (disables dropout etc.); no training happens here

# Initialize embedding database from the pipeline configs.
db_config = EmbeddingPipelineConfigManager().database_config
data_config = DataPipelineConfigManager()
shop_config = data_config.shop_config
mongodb_manager = MongoDBManager(data_config.mongodb_config)
embedding_db = EmbeddingDatabase(db_config, shop_config)
# Presumably loads the stored embeddings so that get_embedding() can be
# queried per localization hash later — confirm in EmbeddingDatabase.
embedding_db.load()

# Collection handles: products (sampled in the last cell) and per-image
# metadata (looked up by image_hash to find each image's localizations).
product_collection = mongodb_manager.get_collection(
    mongodb_manager.config.product_collection
)

image_collection = mongodb_manager.get_collection(
    mongodb_manager.config.image_metadata_collection
)

## Basic Text Embedding Generation

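Let's define a function that returns the L2-normalized CLIP embedding of a single string. The cells after it use this embedding to rank the localizations of a product image by Euclidean distance to the product description: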

In [None]:
def get_text_embedding(text: str) -> np.ndarray:
    """Encode one string with the notebook's OpenCLIP text tower.

    Args:
        text: Input text. NOTE(review): ``open_clip.tokenize`` maps text to a
            fixed context length, so long product descriptions are presumably
            truncated — confirm against the installed open_clip version.

    Returns:
        The embedding of ``text`` (row 0 of a batch of one), scaled to unit
        L2 norm and moved to host memory as a NumPy array.
    """
    token_ids = open_clip.tokenize([text]).to(device)
    with torch.no_grad():
        raw_features = model.encode_text(token_ids)
        # Unit-normalize so distances/dot products behave like cosine similarity.
        unit_features = raw_features / raw_features.norm(dim=-1, keepdim=True)
    return unit_features[0].cpu().numpy()

In [None]:
def test_product(product: dict) -> None:
    """Rank one random image's localizations by distance to the product description.

    Steps:
      1. Pick one image hash at random from ``product['images']`` and fetch its
         metadata document from ``image_collection``.
      2. Embed ``product['description']`` with ``get_text_embedding``.
      3. For every entry in the image's ``localizations``, fetch its embedding
         from ``embedding_db``, L2-normalize it, and compute the Euclidean
         distance to the (already unit-norm) text embedding.
      4. Print localization hashes sorted by ascending distance and render the
         image summary via ``data_utils.display_image_summary``.

    Records that cannot be evaluated (no description, no images, no metadata
    document, missing or zero-norm localization embeddings) are reported and
    skipped rather than raising, so one bad product does not abort a loop over
    several products.

    Args:
        product: Product document; reads ``description`` and ``images``
            (presumably a list of image hashes — confirm against the schema).

    Raises:
        ValueError: if a localization embedding and the text embedding have
            different sizes (e.g. they were produced by different models).
    """
    description = product.get('description')
    image_hashes = product.get('images') or []
    if not description or not image_hashes:
        print("Skipping product: no description or no images")
        return

    # Choose one image at random; its localizations are the ranking candidates.
    image_hash = np.random.choice(image_hashes)
    image_data = image_collection.find_one({"image_hash": image_hash})
    if image_data is None:
        print(f"Skipping image {image_hash}: no metadata document found")
        return

    # Only embed once we know there is something to compare against.
    text_embedding = get_text_embedding(description)

    distances = {}
    for loc in image_data.get('localizations') or []:
        loc_hash = loc['hash']
        loc_embedding = embedding_db.get_embedding(loc_hash)
        # NOTE(review): assumes get_embedding returns None for unknown hashes —
        # confirm; if it raises instead, this guard is simply never hit.
        if loc_embedding is None:
            print(f"\tNo embedding for localization {loc_hash}; skipping")
            continue

        loc_vector = np.asarray(loc_embedding).ravel()
        if loc_vector.size != text_embedding.size:
            raise ValueError(
                f"Embedding size mismatch for localization {loc_hash}: "
                f"{loc_vector.size} vs text {text_embedding.size}; were both "
                "produced by the same CLIP model?"
            )
        loc_norm = np.linalg.norm(loc_vector)
        if loc_norm == 0:
            print(f"\tZero-norm embedding for localization {loc_hash}; skipping")
            continue
        # The text embedding is unit-norm; normalizing the localization vector
        # as well makes the distance a monotone function of cosine similarity
        # (||a - b||^2 = 2 - 2 cos(a, b)), independent of how vectors are stored.
        distances[loc_hash] = float(np.linalg.norm(loc_vector / loc_norm - text_embedding))

    ranked = sorted(distances.items(), key=lambda item: item[1])
    print(f"Top localizations for image {image_hash}:")
    if not ranked:
        print("\t(no localizations with embeddings)")
    for loc_hash, distance in ranked:
        print(f"\tLocalization: {loc_hash}, Distance: {distance:.4f}")

    data_utils.display_image_summary(mongodb_manager, image_hash)

In [None]:
# Spot-check a few random products that have a usable description.
import random

# Number of products to test.
N_PRODUCTS = 5
# Set to an int for a reproducible run. Both RNGs are seeded because
# test_product picks its image with np.random.choice.
SEED = None

if SEED is not None:
    random.seed(SEED)
    np.random.seed(SEED)

# `{"$ne": "NOT_FOUND"}` alone also matches documents whose `description` is
# missing or null, which test_product cannot use; `$nin` with None excludes
# both, and empty strings are excluded too. Only the fields used below and in
# test_product are fetched.
products_with_desc = list(product_collection.find(
    {"description": {"$nin": ["NOT_FOUND", None, ""]}},
    {"title": 1, "hash": 1, "description": 1, "images": 1},
))

# random.sample raises ValueError if asked for more items than exist.
sample_size = min(N_PRODUCTS, len(products_with_desc))
random_products = random.sample(products_with_desc, sample_size)
if not random_products:
    print("No products with a description found.")

for product in random_products:
    print("\n" + "="*80)
    print(f"\nTesting product: {product.get('title')}: {product.get('hash')}")
    print("="*80 + "\n")
    test_product(product)