<a href="https://colab.research.google.com/github/mzafir/End_to_end_MLOPS_project/blob/master/clip_shoptalk.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torch torchvision clip-by-openai requests pandas chromadb milvus


Collecting clip-by-openai
  Downloading clip_by_openai-1.1-py3-none-any.whl.metadata (369 bytes)
Collecting chromadb
  Downloading chromadb-0.6.3-py3-none-any.whl.metadata (6.8 kB)
Collecting milvus
  Downloading milvus-2.3.5-py3-none-manylinux2014_x86_64.whl.metadata (6.7 kB)
Collecting ftfy (from clip-by-openai)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
INFO: pip is looking at multiple versions of clip-by-openai to determine which version is compatible with other requirements. This could take a while.
Collecting clip-by-openai
  Downloading clip_by_openai-1.0.1-py3-none-any.whl.metadata (407 bytes)
  Downloading clip_by_openai-0.1.1.5-py3-none-any.whl.metadata (8.6 kB)
  Downloading clip_by_openai-0.1.1.4-py3-none-any.whl.metadata (8.6 kB)
  Downloading clip_by_openai-0.1.1.3-py3-none-any.whl.metadata (8.7 kB)
  Downloading clip_by_openai-0.1.1.2-py3-none-any.whl.metadata (9.0 kB)
  Downloading clip_by_openai-0.1.1-py3-none-any.whl.metadata (9.0 kB)
  Downloading cl

In [None]:
import os
from io import BytesIO

import chromadb
import clip
import pandas as pd
import requests
import torch
from PIL import Image
from chromadb.config import Settings
from chromadb.utils.embedding_functions import EmbeddingFunction
# NOTE(review): recent chromadb releases expose Client at the package top level
# (chromadb.Client); confirm that this module path resolves. Kept last so the
# imports above are already bound if it does not.
from chromadb.client import Client


In [None]:
class DataPreprocessor:
    """Load a product table from a CSV file and validate its columns."""

    def __init__(self, data_path):
        # Location of the CSV file that holds the product records.
        self.data_path = data_path

    def load_data(self):
        """Read the CSV at ``self.data_path`` and return it as a DataFrame."""
        return pd.read_csv(self.data_path)

    def preprocess(self):
        """Return the loaded table with every missing value replaced by ''.

        Raises:
            ValueError: if a required column is absent. Only the first
                missing column (in the order listed below) is reported.
        """
        frame = self.load_data()

        expected = ('category', 'brand', 'description', 'image_link', 'tags')
        first_missing = next(
            (name for name in expected if name not in frame.columns), None
        )
        if first_missing is not None:
            raise ValueError(f"Missing required column: {first_missing}")

        # Missing cells in *all* columns become empty strings.
        return frame.fillna('')


In [None]:
class EmbeddingGenerator:
    """Produce CLIP embeddings for text, for images fetched by URL, or both.

    Args:
        device: torch device string the model and inputs are placed on.
        model_name: CLIP checkpoint name passed to ``clip.load``.
    """

    def __init__(self, device='cpu', model_name="ViT-B/32"):
        self.device = device
        self.model, self.preprocess = clip.load(model_name, device=self.device)

    @staticmethod
    def _to_vector(features):
        """Convert a model output tensor into a flat float32 NumPy vector."""
        # The model may emit half-precision features on GPU; casting keeps the
        # dtype (and the precision of later arithmetic) the same on every device.
        return features.float().cpu().numpy().flatten()

    def generate_text_embedding(self, text):
        """Return the CLIP text embedding of ``text`` as a 1-D NumPy array."""
        text_tokens = clip.tokenize([text]).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text_tokens)
        return self._to_vector(text_features)

    def generate_image_embedding(self, image_url, timeout=10):
        """Download the image at ``image_url`` and return its CLIP embedding.

        Args:
            image_url: HTTP(S) address of the image.
            timeout: seconds to wait for the server before giving up.

        Raises:
            ValueError: if ``image_url`` is empty.
            requests.RequestException: on connection problems, timeouts or a
                non-success HTTP status.
        """
        if not image_url:
            raise ValueError("image_url is empty")

        # Without a timeout a stalled server blocks the notebook indefinitely.
        response = requests.get(image_url, timeout=timeout)
        # Fail on 4xx/5xx here instead of handing an error page to the image decoder.
        response.raise_for_status()

        image = Image.open(BytesIO(response.content)).convert("RGB")
        image_input = self.preprocess(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            image_features = self.model.encode_image(image_input)
        return self._to_vector(image_features)

    def generate_combined_embedding(self, text, image_url):
        """Return the element-wise mean of the text and image embeddings."""
        text_embedding = self.generate_text_embedding(text)
        image_embedding = self.generate_image_embedding(image_url)
        return (text_embedding + image_embedding) / 2  # Combine by averaging


In [None]:
class VectorDB:
    """Small wrapper around a persistent Chroma collection of embeddings.

    Args:
        persist_directory: folder in which Chroma stores its data.
        collection_name: name of the collection that holds the embeddings.
    """

    def __init__(self, persist_directory="chroma_db", collection_name="product_embeddings"):
        # PersistentClient replaces the legacy
        # Settings(chroma_db_impl="duckdb+parquet", ...) configuration, which
        # Chroma releases from 0.4 onwards refuse to start with.
        self.client = chromadb.PersistentClient(path=persist_directory)
        self.collection_name = collection_name
        self.collection = None

    def create_collection(self):
        """Open the collection, creating it first if it does not exist yet."""
        self.collection = self.client.get_or_create_collection(
            name=self.collection_name
        )

    def _require_collection(self):
        """Return the collection, opening it lazily on first use."""
        if self.collection is None:
            self.create_collection()
        return self.collection

    @staticmethod
    def _as_float_list(vector):
        """Convert any iterable of numbers (e.g. a NumPy array) to a list of floats."""
        return [float(value) for value in vector]

    def insert_embeddings(self, embeddings, metadata, ids=None):
        """Store embeddings together with their metadata.

        Args:
            embeddings: sequence of 1-D vectors (lists or NumPy arrays).
            metadata: sequence of dicts, one per embedding.
            ids: optional sequence of string ids, one per embedding. When
                omitted, the position in ``embeddings`` ("0", "1", ...) is used.

        Raises:
            ValueError: if the argument lengths do not match.
        """
        vectors = [self._as_float_list(vector) for vector in embeddings]
        metadatas = list(metadata)
        if len(vectors) != len(metadatas):
            raise ValueError(
                f"Got {len(vectors)} embeddings but {len(metadatas)} metadata entries"
            )

        if ids is None:
            ids = [str(position) for position in range(len(vectors))]
        else:
            ids = [str(item_id) for item_id in ids]
            if len(ids) != len(vectors):
                raise ValueError(
                    f"Got {len(vectors)} embeddings but {len(ids)} ids"
                )

        if not vectors:
            return  # nothing to store

        # upsert rather than add: the store is persistent, and add() leaves
        # records whose id already exists untouched, so re-running the
        # notebook would silently keep stale vectors.
        self._require_collection().upsert(
            ids=ids,
            embeddings=vectors,
            metadatas=metadatas
        )

    def query_similar(self, query_embedding, n_results=5):
        """Return Chroma's query result for the ``n_results`` nearest embeddings."""
        return self._require_collection().query(
            query_embeddings=[self._as_float_list(query_embedding)],
            n_results=n_results
        )


In [None]:
# Load and preprocess data
data_path = "product_data.csv"
preprocessor = DataPreprocessor(data_path)
data = preprocessor.preprocess()

# Initialize embedding generator
embedding_gen = EmbeddingGenerator(device="cuda" if torch.cuda.is_available() else "cpu")

# Generate embeddings and metadata
embeddings = []
metadata = []
skipped_rows = []  # (row index, error) for rows whose embedding could not be built

for row_index, row in data.iterrows():
    try:
        combined_embedding = embedding_gen.generate_combined_embedding(
            str(row['description']), row['image_link']
        )
    except (requests.RequestException, OSError, ValueError) as exc:
        # preprocess() turns missing image links into '', and remote images can
        # be unreachable or undecodable; skip such rows instead of aborting
        # the whole run and losing every embedding computed so far.
        skipped_rows.append((row_index, exc))
        continue

    # Store plain Python floats so the vector DB receives the same type
    # regardless of the NumPy dtype the model produced.
    embeddings.append([float(value) for value in combined_embedding])
    metadata.append({
        "category": row['category'],
        "brand": row['brand'],
        "tags": row['tags']
    })

if skipped_rows:
    print(f"Skipped {len(skipped_rows)} of {len(data)} rows; first error: {skipped_rows[0][1]!r}")

# Store embeddings in Chroma vector DB
vector_db = VectorDB()
vector_db.create_collection()
vector_db.insert_embeddings(embeddings, metadata)

# Query the vector DB for similar products
query_text = "A stylish red handbag"
query_embedding = [float(value) for value in embedding_gen.generate_text_embedding(query_text)]
similar_items = vector_db.query_similar(query_embedding)
similar_items
