# Vector Database Preparation for Wardrobe AI

This notebook prepares the vector database for the Wardrobe AI application by processing clothing images and creating embeddings for similarity search.

## Overview
- Process t-shirt and pant images from data samples
- Generate AI-powered descriptions using OpenAI's vision models
- Create embeddings and populate ChromaDB vector store
- Set up a MultiVectorRetriever that links each summary embedding to its stored image
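The multi-vector pattern above can be sketched without any external services. In this toy, a bag-of-words count vector stands in for OpenAI embeddings, and two plain Python containers stand in for ChromaDB and the file store; all names here are hypothetical. The key idea it shows: only the short text summary is embedded and searched, and a shared `doc_id` resolves each hit to the raw image payload.

```python
import math
import uuid
from collections import Counter

def toy_embed(text):
    """Hypothetical stand-in for an embedding model: a bag-of-words count vector."""
    return Counter(text.lower().split())

def cosine(a, b):
    # Counter returns 0 for missing keys, so absent words contribute nothing
    dot = sum(a[t] * b[t] for t in a)
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)

vector_index = []  # (doc_id, summary, vector): plays the role of a Chroma collection
raw_store = {}     # doc_id -> raw payload: plays the role of the file store

def add_item(summary, raw_payload):
    doc_id = str(uuid.uuid4())
    raw_store[doc_id] = raw_payload
    vector_index.append((doc_id, summary, toy_embed(summary)))
    return doc_id

def retrieve(query, k=1):
    q = toy_embed(query)
    ranked = sorted(vector_index, key=lambda row: cosine(q, row[2]), reverse=True)
    # Search runs over summaries; results are resolved to raw payloads via doc_id
    return [raw_store[doc_id] for doc_id, _, _ in ranked[:k]]

add_item("red crew neck t-shirt with short sleeves", b"<red tee image bytes>")
add_item("blue slim fit denim jeans", b"<jeans image bytes>")
print(retrieve("blue jeans"))  # [b'<jeans image bytes>']
```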

## 1. Environment Setup and Library Imports

Import all necessary libraries and load environment variables.

In [None]:
import os
import base64
import glob
from pathlib import Path
import uuid

# Image processing
import cv2
from PIL import Image
import numpy as np

# AI and ML libraries
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.messages import HumanMessage
from langchain_core.documents import Document
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import LocalFileStore
# Chroma lives in langchain_community; the old `langchain.vectorstores` path is deprecated
from langchain_community.vectorstores import Chroma

# Load environment variables
from dotenv import load_dotenv
load_dotenv()

# Verify OpenAI API key is loaded
if not os.getenv("OPENAI_API_KEY"):
    print("⚠️  Warning: OPENAI_API_KEY not found in environment variables")
    print("Please add your OpenAI API key to a .env file")
else:
    print("✅ OpenAI API key loaded successfully")

## 2. Configuration and Constants

Define all configuration variables and file paths.

In [None]:
# Configuration
CONFIG = {
    "embeddings_model": "text-embedding-3-small",
    "vision_model": "gpt-4o-mini",
    "image_formats": [".jpg", ".jpeg", ".png", ".bmp", ".webp"],
    "target_size": (512, 512),
    "chroma_persist_dir": "../chroma_langchain_db",
    "docstore_dir": "../TSHIRT_DOCSTORE"
}

# File paths
PATHS = {
    "tshirt_samples": "../data/tshirt_samples",
    "pant_samples": "../data/pant_samples",
    "tshirt_dataset": "../tshirt",
    "pant_dataset": "../pant-dataset"
}

# Collection names
COLLECTIONS = {
    "tshirt": "tshirt",
    "pant": "pant"
}

print("Configuration loaded:")
print(f"📁 T-shirt samples: {PATHS['tshirt_samples']}")
print(f"📁 Pant samples: {PATHS['pant_samples']}")
print(f"🔍 Embeddings model: {CONFIG['embeddings_model']}")
print(f"👁️ Vision model: {CONFIG['vision_model']}")

## 3. Image Processing Functions

Implement functions for image loading, validation, and preprocessing.

In [None]:
def validate_image(image_path):
    """
    Validate if an image file is readable and has proper format.
    
    Args:
        image_path (str): Path to the image file
        
    Returns:
        bool: True if image is valid, False otherwise
    """
    try:
        with Image.open(image_path) as img:
            img.verify()
        return True
    except Exception as e:
        print(f"❌ Invalid image {image_path}: {e}")
        return False

def resize_image(image_path, target_size=(512, 512)):
    """
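    Shrink the image to fit within target_size (never enlarging it) while
    preserving its aspect ratio, then center it on a white RGB canvas of
    exactly target_size.
    
    Args:
        image_path (str): Path to the input image
        target_size (tuple): Target size (width, height)
        
    Returns:
        PIL.Image: Letterboxed RGB image, or None if the image cannot be read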
    """
    try:
        with Image.open(image_path) as img:
            # Convert to RGB if necessary
            if img.mode != 'RGB':
                img = img.convert('RGB')
            
            # Resize with aspect ratio preservation
            img.thumbnail(target_size, Image.Resampling.LANCZOS)
            
            # Create new image with target size and paste resized image
            new_img = Image.new('RGB', target_size, (255, 255, 255))
            paste_x = (target_size[0] - img.width) // 2
            paste_y = (target_size[1] - img.height) // 2
            new_img.paste(img, (paste_x, paste_y))
            
            return new_img
    except Exception as e:
        print(f"❌ Error resizing {image_path}: {e}")
        return None

def get_image_files(directory):
    """
    Get all valid image files from a directory.
    
    Args:
        directory (str): Directory path
        
    Returns:
        list: List of valid image file paths
    """
    if not os.path.exists(directory):
        print(f"⚠️  Directory not found: {directory}")
        return []
    
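    image_files = set()
    for ext in CONFIG['image_formats']:
        # Match lowercase and uppercase extensions. The set removes duplicates on
        # case-insensitive filesystems (macOS, Windows), where both patterns hit the same file.
        for pattern_ext in (ext, ext.upper()):
            pattern = os.path.join(directory, f"*{pattern_ext}")
            image_files.update(glob.glob(pattern))
    image_files = sorted(image_files)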
    
    # Validate images
    valid_images = []
    for img_path in image_files:
        if validate_image(img_path):
            valid_images.append(img_path)
    
    print(f"📸 Found {len(valid_images)} valid images in {directory}")
    return valid_images

# Test the functions
print("\n🧪 Testing image processing functions...")
tshirt_files = get_image_files(PATHS['tshirt_samples'])
pant_files = get_image_files(PATHS['pant_samples'])

# Also check larger datasets if they exist
tshirt_dataset_files = get_image_files(PATHS['tshirt_dataset'])
pant_dataset_files = get_image_files(PATHS['pant_dataset'])
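To see what `resize_image` produces for a given input, the geometry can be computed by hand. The helper below is a hypothetical sketch of the same arithmetic: `Image.thumbnail` keeps the aspect ratio and only shrinks, so the scale factor is capped at 1.0, and the result is centered with the same integer division used above. Pillow's own rounding may differ by a pixel for some sizes.

```python
def letterbox_geometry(width, height, target_size=(512, 512)):
    """Approximate the thumbnail size and paste offset produced by resize_image."""
    target_w, target_h = target_size
    # thumbnail() never enlarges, so the scale is capped at 1.0
    scale = min(target_w / width, target_h / height, 1.0)
    new_w = max(1, round(width * scale))
    new_h = max(1, round(height * scale))
    # Same centering arithmetic as resize_image
    paste_x = (target_w - new_w) // 2
    paste_y = (target_h - new_h) // 2
    return new_w, new_h, paste_x, paste_y

print(letterbox_geometry(1024, 512))  # (512, 256, 0, 128): wide photo shrunk to fit the width
print(letterbox_geometry(200, 100))   # (200, 100, 156, 206): small photo kept at original size
```

One consequence worth knowing: small source images end up as a small garment surrounded by a lot of white padding, which the vision model then sees.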

## 4. Vector Database Setup

Initialize ChromaDB collections and configure the retrieval system.

In [None]:
def setup_vector_stores():
    """
    Initialize ChromaDB vector stores and document store.
    
    Returns:
        tuple: (tshirt_retriever, pant_retriever, file_store)
    """
    # Initialize embeddings
    embeddings = OpenAIEmbeddings(model=CONFIG['embeddings_model'])
    
    # Initialize document store
    file_store = LocalFileStore(CONFIG['docstore_dir'])
    
    # Initialize ChromaDB collections
    tshirt_vectorstore = Chroma(
        collection_name=COLLECTIONS['tshirt'],
        embedding_function=embeddings,
        persist_directory=CONFIG['chroma_persist_dir']
    )
    
    pant_vectorstore = Chroma(
        collection_name=COLLECTIONS['pant'],
        embedding_function=embeddings,
        persist_directory=CONFIG['chroma_persist_dir']
    )
    
    # Initialize retrievers. The file store holds raw base64 image bytes rather than
    # serialized Documents, so this notebook reads images back with file_store.mget()
    # instead of calling retriever.invoke().
    tshirt_retriever = MultiVectorRetriever(
        vectorstore=tshirt_vectorstore,
        docstore=file_store,
        id_key="doc_id"
    )
    
    pant_retriever = MultiVectorRetriever(
        vectorstore=pant_vectorstore,
        docstore=file_store,
        id_key="doc_id"
    )
    
    print("✅ Vector stores initialized successfully")
    print(f"📂 ChromaDB directory: {CONFIG['chroma_persist_dir']}")
    print(f"📂 Document store: {CONFIG['docstore_dir']}")
    
    return tshirt_retriever, pant_retriever, file_store

# Initialize vector stores
tshirt_retriever, pant_retriever, file_store = setup_vector_stores()
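The document store only needs two operations in this notebook: `mset` to write `(key, bytes)` pairs and `mget` to read them back. The class below is a hypothetical, minimal stand-in written for illustration (it is not LangChain's `LocalFileStore`), assuming the behavior this notebook relies on: one file per key under a root directory, and `None` returned for keys that were never written. That `None` case is why the retrieval test later checks `not image_bytes[0]`.

```python
import os
import tempfile

class TinyFileStore:
    """Hypothetical minimal byte store: one file per key, None for missing keys."""

    def __init__(self, root):
        self.root = root
        os.makedirs(root, exist_ok=True)

    def _path(self, key):
        return os.path.join(self.root, key)

    def mset(self, pairs):
        for key, value in pairs:
            with open(self._path(key), "wb") as f:
                f.write(value)

    def mget(self, keys):
        values = []
        for key in keys:
            path = self._path(key)
            if os.path.exists(path):
                with open(path, "rb") as f:
                    values.append(f.read())
            else:
                values.append(None)  # missing keys yield None rather than raising
        return values

with tempfile.TemporaryDirectory() as tmp:
    store = TinyFileStore(tmp)
    store.mset([("doc-1", b"aGVsbG8=")])
    print(store.mget(["doc-1", "missing"]))  # [b'aGVsbG8=', None]
```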

## 5. Image Encoding and Base64 Conversion

Functions to encode images for storage and AI processing.

In [None]:
def encode_image_to_base64(image_path):
    """
    Encode an image file to base64 string.
    
    Args:
        image_path (str): Path to the image file
        
    Returns:
        str: Base64 encoded string or None if error
    """
    try:
        with open(image_path, "rb") as img_file:
            base64_string = base64.b64encode(img_file.read()).decode("utf-8")
        return base64_string
    except Exception as e:
        print(f"❌ Error encoding {image_path}: {e}")
        return None

def pil_to_base64(pil_image, format='PNG'):
    """
    Convert PIL Image to base64 string.
    
    Args:
        pil_image (PIL.Image): PIL Image object
        format (str): Image format (PNG, JPEG, etc.)
        
    Returns:
        str: Base64 encoded string
    """
    from io import BytesIO
    
    buffer = BytesIO()
    pil_image.save(buffer, format=format)
    img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return img_str

def create_base64_data_url(base64_string, image_format='png'):
    """
    Create a data URL from base64 string.
    
    Args:
        base64_string (str): Base64 encoded image
        image_format (str): Image format
        
    Returns:
        str: Data URL string
    """
    return f"data:image/{image_format};base64,{base64_string}"

# Test encoding function
print("\n🧪 Testing image encoding...")
if tshirt_files:
    test_image = tshirt_files[0]
    encoded = encode_image_to_base64(test_image)
    if encoded:
        print(f"✅ Successfully encoded {os.path.basename(test_image)}")
        print(f"📏 Encoded length: {len(encoded)} characters")
    else:
        print("❌ Failed to encode test image")
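A data URL should declare the MIME type that matches the encoded bytes; `pil_to_base64` writes PNG by default, for example. The sketch below uses hypothetical helper names and sniffs the type from well-known file signatures (PNG, JPEG, WebP, BMP) before building the URL. Base64 output is 4 characters per 3 input bytes, so a 16-byte input becomes 24 characters.

```python
import base64

def sniff_image_mime(data):
    """Guess an image MIME type from its leading magic bytes; None if unrecognized."""
    if data.startswith(b"\x89PNG\r\n\x1a\n"):
        return "image/png"
    if data.startswith(b"\xff\xd8\xff"):
        return "image/jpeg"
    if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
        return "image/webp"
    if data.startswith(b"BM"):
        return "image/bmp"
    return None

def to_data_url(data):
    mime = sniff_image_mime(data)
    if mime is None:
        raise ValueError("unrecognized image format")
    return f"data:{mime};base64,{base64.b64encode(data).decode('ascii')}"

fake_png = b"\x89PNG\r\n\x1a\n" + b"\x00" * 8  # PNG signature plus padding, not a full image
url = to_data_url(fake_png)
print(url.split(",", 1)[0])  # data:image/png;base64
```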

## 6. Embedding Generation and Storage

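Generate an AI description for each image and store it. The embedding itself is computed by Chroma, through the `OpenAIEmbeddings` function configured in Section 4, when each description is added to the vector store.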

In [None]:
def generate_image_description(image_base64, clothing_type, model_name=None):
    """
    Generate description of clothing item using OpenAI's vision model.
    
    Args:
        image_base64 (str): Base64 encoded image
        clothing_type (str): Type of clothing ('tshirt' or 'pant')
        model_name (str): Model name to use
        
    Returns:
        str: Generated description or None if error
    """
    if model_name is None:
        model_name = CONFIG['vision_model']
    
    try:
        chat = ChatOpenAI(model=model_name, api_key=os.getenv("OPENAI_API_KEY"))
        
        if clothing_type.lower() == 'tshirt':
            prompt = """You are an assistant tasked with summarizing t-shirt images for retrieval.
            These summaries will be embedded and used to retrieve the raw image.
            Give a concise summary of the t-shirt that is well optimized for retrieval in a single line.
            Focus on style, color, patterns, neckline, sleeves, and any distinctive features.
            Do not talk about anything else just the t-shirt you see in the image."""
        else:  # pant
            prompt = """You are an assistant tasked with summarizing pant images for retrieval.
            These summaries will be embedded and used to retrieve the raw image.
            Give a concise summary of the pants that is well optimized for retrieval in a single line.
            Focus on style, color, fit, material, and any distinctive features.
            Do not talk about anything else just the pants you see in the image."""
        
        msg = chat.invoke([
            HumanMessage(
                content=[
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        # pil_to_base64() encodes PNG by default, so declare the matching MIME type
                        "image_url": {"url": create_base64_data_url(image_base64, "png")}
                    }
                ]
            )
        ])
        
        return msg.content.strip()
        
    except Exception as e:
        print(f"❌ Error generating description: {e}")
        return None

def process_single_image(image_path, clothing_type, retriever, file_store):
    """
    Process a single image and add it to the vector store.
    
    Args:
        image_path (str): Path to the image
        clothing_type (str): Type of clothing
        retriever: Vector retriever object
        file_store: Document store object
        
    Returns:
        bool: True if successful, False otherwise
    """
    try:
        # Generate unique document ID
        doc_id = str(uuid.uuid4())
        
        # Resize and encode image
        resized_img = resize_image(image_path, CONFIG['target_size'])
        if resized_img is None:
            return False
        
        base64_encoded = pil_to_base64(resized_img)
        if not base64_encoded:
            return False
        
        # Generate description
        description = generate_image_description(base64_encoded, clothing_type)
        if not description:
            print(f"⚠️  Could not generate description for {image_path}")
            return False
        
        # Store in document store
        file_store.mset([(doc_id, base64_encoded.encode('utf-8'))])
        
        # Create document with metadata
        doc = Document(
            page_content=description,
            metadata={
                "doc_id": doc_id,
                "source": image_path,
                "clothing_type": clothing_type,
                "description": description
            }
        )
        
        # Add to vector store
        retriever.vectorstore.add_documents([doc])
        
        print(f"✅ Processed: {os.path.basename(image_path)} -> {description[:50]}...")
        return True
        
    except Exception as e:
        print(f"❌ Error processing {image_path}: {e}")
        return False

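# Test with a single image. Note: this stores the image, so the batch run in
# Section 7 adds it a second time under a new doc_id.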
print("\n🧪 Testing single image processing...")
if tshirt_files:
    success = process_single_image(
        tshirt_files[0], 
        'tshirt', 
        tshirt_retriever, 
        file_store
    )
    if success:
        print("✅ Single image processing test passed")
    else:
        print("❌ Single image processing test failed")
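Because `process_single_image` assigns a fresh `uuid.uuid4()` on every call, processing the same file twice stores it twice. One possible alternative, sketched here as an assumption rather than the notebook's method, is to derive the ID from the original file bytes so a re-run can check the document store first and skip the vision call. `content_doc_id` and `should_process` are hypothetical names; the lookup argument mimics the shape of `file_store.mget`.

```python
import hashlib
import uuid

def content_doc_id(image_bytes):
    """Hypothetical deterministic ID: identical bytes always map to the same UUID."""
    digest = hashlib.sha256(image_bytes).hexdigest()
    # uuid5 keeps the ID UUID-shaped like the original; the namespace choice is arbitrary
    return str(uuid.uuid5(uuid.NAMESPACE_URL, digest))

def should_process(image_bytes, lookup):
    """Return (doc_id, needs_processing); lookup(keys) behaves like file_store.mget."""
    doc_id = content_doc_id(image_bytes)
    existing = lookup([doc_id])[0]
    return doc_id, existing is None

stored = {}

def lookup(keys):
    return [stored.get(k) for k in keys]

doc_id, needed = should_process(b"same image bytes", lookup)
stored[doc_id] = b"<base64 image>"  # what mset would write after a successful run
print(needed, should_process(b"same image bytes", lookup)[1])  # True False
```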

## 7. Vector Store Population

Batch process all sample images and populate the vector database.

In [None]:
def populate_vector_store(image_files, clothing_type, retriever, file_store, max_images=None):
    """
    Populate vector store with multiple images.
    
    Args:
        image_files (list): List of image file paths
        clothing_type (str): Type of clothing
        retriever: Vector retriever object
        file_store: Document store object
        max_images (int): Maximum number of images to process (None for all)
        
    Returns:
        dict: Processing statistics
    """
    if max_images is not None:
        image_files = image_files[:max_images]
    
    stats = {
        'total': len(image_files),
        'successful': 0,
        'failed': 0,
        'descriptions': []
    }
    
    print(f"\n🚀 Processing {stats['total']} {clothing_type} images...")
    
    for i, image_path in enumerate(image_files, 1):
        print(f"Progress: {i}/{stats['total']} - {os.path.basename(image_path)}")
        
        success = process_single_image(image_path, clothing_type, retriever, file_store)
        
        if success:
            stats['successful'] += 1
        else:
            stats['failed'] += 1
        
        # Progress update every 10 images
        if i % 10 == 0:
            print(f"📊 Progress: {i}/{stats['total']} ({stats['successful']} successful, {stats['failed']} failed)")
    
    print(f"\n✅ {clothing_type.capitalize()} processing complete!")
    print(f"📊 Successful: {stats['successful']}/{stats['total']}")
    print(f"❌ Failed: {stats['failed']}/{stats['total']}")
    
    return stats

# Process T-shirt images: all samples plus up to 20 dataset images, capped at 25 for the demo
print("=" * 60)
print("PROCESSING T-SHIRT IMAGES")
print("=" * 60)

# Samples come first, so they are kept when the cap is applied
all_tshirt_files = tshirt_files + tshirt_dataset_files[:20]
tshirt_stats = populate_vector_store(
    all_tshirt_files,
    'tshirt',
    tshirt_retriever,
    file_store,
    max_images=25  # Limit for demo
)

# Process pant images: same selection rule and cap as for t-shirts
print("\n" + "=" * 60)
print("PROCESSING PANT IMAGES")
print("=" * 60)

# Samples come first, so they are kept when the cap is applied
all_pant_files = pant_files + pant_dataset_files[:20]
pant_stats = populate_vector_store(
    all_pant_files,
    'pant',
    pant_retriever,
    file_store,
    max_images=25  # Limit for demo
)
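Batch runs make many API calls in a row, so transient failures and rate limits become likely. Below is a hypothetical retry helper with exponential backoff and jitter, written as a sketch. Note that `generate_image_description` catches every exception and returns `None`, so a wrapper like this would need to go around the inner `chat.invoke(...)` call, which does raise. `ChatOpenAI` also accepts a `max_retries` argument, which may be sufficient on its own.

```python
import random
import time

def call_with_retries(fn, *args, attempts=3, base_delay=1.0, **kwargs):
    """Call fn, retrying on any exception with exponential backoff plus jitter.

    Catching every exception is deliberately broad for a sketch; a real pipeline
    would retry only rate-limit and other transient errors.
    """
    for attempt in range(1, attempts + 1):
        try:
            return fn(*args, **kwargs)
        except Exception:
            if attempt == attempts:
                raise  # out of attempts: surface the last error
            # Waits base_delay, 2 * base_delay, 4 * base_delay, ... plus up to 10% jitter
            delay = base_delay * (2 ** (attempt - 1))
            time.sleep(delay + random.uniform(0, 0.1 * delay))

calls = {"n": 0}

def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient")
    return "ok"

print(call_with_retries(flaky, attempts=3, base_delay=0.01))  # ok (after two failures)
```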

## 8. Database Validation and Testing

Test the vector database functionality with sample queries.

In [None]:
def test_similarity_search(retriever, query, collection_name, top_k=3):
    """
    Test similarity search functionality.
    
    Args:
        retriever: Vector retriever object
        query (str): Search query
        collection_name (str): Name of the collection
        top_k (int): Number of results to return
        
    Returns:
        list: Search results
    """
    try:
        print(f"\n🔍 Testing {collection_name} search with query: '{query}'")
        
        results = retriever.vectorstore.similarity_search(query, k=top_k)
        
        print(f"📋 Found {len(results)} results:")
        for i, result in enumerate(results, 1):
            print(f"{i}. {result.page_content}")
            if 'doc_id' in result.metadata:
                print(f"   Doc ID: {result.metadata['doc_id']}")
        
        return results
        
    except Exception as e:
        print(f"❌ Error during search: {e}")
        return []

def test_image_retrieval(retriever, file_store, query, collection_name):
    """
    Test complete image retrieval workflow.
    
    Args:
        retriever: Vector retriever object
        file_store: Document store object
        query (str): Search query
        collection_name (str): Name of the collection
        
    Returns:
        str: Base64 encoded image or None
    """
    try:
        print(f"\n🖼️  Testing image retrieval for {collection_name}: '{query}'")
        
        # Search for similar items
        results = retriever.vectorstore.similarity_search(query, k=1)
        
        if not results:
            print("❌ No results found")
            return None
        
        # Get document ID
        doc_id = results[0].metadata.get('doc_id')
        if not doc_id:
            print("❌ No document ID found")
            return None
        
        # Retrieve image from document store
        image_bytes = file_store.mget([doc_id])
        if not image_bytes or not image_bytes[0]:
            print("❌ Could not retrieve image from document store")
            return None
        
        image_base64 = image_bytes[0].decode("utf-8")
        print(f"✅ Successfully retrieved image (length: {len(image_base64)})")
        print(f"📝 Description: {results[0].page_content}")
        
        return image_base64
        
    except Exception as e:
        print(f"❌ Error during image retrieval: {e}")
        return None

# Test queries
test_queries = {
    'tshirt': [
        "red polo shirt",
        "casual white t-shirt",
        "vintage band tee",
        "athletic wear"
    ],
    'pant': [
        "blue jeans",
        "black formal pants",
        "casual khakis",
        "athletic shorts"
    ]
}

print("\n" + "=" * 60)
print("TESTING VECTOR DATABASE")
print("=" * 60)

# Test T-shirt searches
print("\n🔍 Testing T-shirt Collection:")
for query in test_queries['tshirt'][:2]:  # Test first 2 queries
    results = test_similarity_search(tshirt_retriever, query, "T-shirt", top_k=2)
    if results:
        # Test full retrieval for first result
        image_data = test_image_retrieval(tshirt_retriever, file_store, query, "T-shirt")

# Test Pant searches
print("\n🔍 Testing Pant Collection:")
for query in test_queries['pant'][:2]:  # Test first 2 queries
    results = test_similarity_search(pant_retriever, query, "Pant", top_k=2)
    if results:
        # Test full retrieval for first result
        image_data = test_image_retrieval(pant_retriever, file_store, query, "Pant")

# Database statistics
print("\n" + "=" * 60)
print("DATABASE STATISTICS")
print("=" * 60)

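try:
    # `_collection` is the wrapper's private handle to the underlying Chroma collection;
    # it is not a public API and may change between versions
    tshirt_count = tshirt_retriever.vectorstore._collection.count()
    pant_count = pant_retriever.vectorstore._collection.count()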
    
    print(f"📊 T-shirt collection: {tshirt_count} items")
    print(f"📊 Pant collection: {pant_count} items")
    print(f"📊 Total items: {tshirt_count + pant_count}")
    
except Exception as e:
    print(f"⚠️  Could not get collection counts: {e}")

print("\n✅ Vector database preparation complete!")
print("🚀 Ready to run the Wardrobe AI application!")
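When reading the search results above, it helps to know how ranking relates to cosine similarity. Assuming the collections use Chroma's default squared-L2 distance, and given that OpenAI documents its embeddings as normalized to length 1, the two orderings agree, because for unit vectors $\lVert a-b \rVert^2 = 2 - 2\cos(a,b)$. The toy vectors below are made up purely to illustrate the identity.

```python
import math

def normalize(v):
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def cosine_similarity(a, b):
    return sum(x * y for x, y in zip(a, b))  # a plain dot product, valid for unit vectors

def squared_l2(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))

query = normalize([1.0, 0.2, 0.0])
docs = {name: normalize(v) for name, v in {
    "red tee": [0.9, 0.1, 0.1],
    "blue jeans": [0.1, 0.9, 0.3],
    "black trousers": [0.2, 0.5, 0.8],
}.items()}

by_cosine = sorted(docs, key=lambda n: cosine_similarity(query, docs[n]), reverse=True)
by_l2 = sorted(docs, key=lambda n: squared_l2(query, docs[n]))
print(by_cosine == by_l2)  # True: smallest distance corresponds to highest cosine
```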

## Summary

The vector database has been successfully prepared with:

1. **Image Processing**: Resized and validated clothing images
2. **AI Descriptions**: Generated using OpenAI's vision models
3. **Vector Embeddings**: Created using OpenAI's embedding models
4. **ChromaDB Storage**: Organized in separate collections for t-shirts and pants
5. **Document Store**: Base64 encoded images stored for retrieval

### Next Steps

1. Run the main Streamlit application: `streamlit run src/app.py`
2. Upload a pant image to get t-shirt recommendations
3. Try the AR overlay feature with your webcam

### Notes

- The database is persistent and will be saved to disk
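- You can add more images to the data folders and re-run this notebook. Each run assigns new `doc_id` values, so previously processed images are added again as duplicates unless you first delete `../chroma_langchain_db` and `../TSHIRT_DOCSTORE`
- Adjust the `max_images` argument in Section 7 to process more or fewer images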
- Monitor OpenAI API usage when processing large image datasets