In [14]:
from transformers import CLIPProcessor, CLIPModel
import torch
from PIL import Image
import numpy as np

In [3]:
class ProductModel:
    def __init__(self):
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    def get_embeddings(self, image_path):
        image = Image.open(image_path).convert('RGB')
        inputs = self.processor(images=image, return_tensors="pt")
        with torch.no_grad():
            return self.model(**inputs).image_embeds.numpy()

In [4]:
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

In [6]:
image = Image.open('data/image2.jpg').convert('RGB')

In [7]:
inputs = processor(images=image, return_tensors="pt")

In [8]:
with torch.no_grad():
    # Use only the vision model to get image embeddings
    vision_outputs = model.vision_model(pixel_values=inputs['pixel_values'])
    image_embeds = vision_outputs.pooler_output  # Pooled output from vision transformer
    image_embeds = model.visual_projection(image_embeds)  # Project to common space

In [9]:
em = image_embeds.numpy()

In [12]:
metadata = {'name': 'image2', 'category': 'image2', 'price': 2.0, 'filename': 'image2.jpg'}

In [15]:
product_id = metadata.get("id", str(np.random.randint(10000)))

In [16]:
product_id

'3428'

In [17]:
from pymongo import MongoClient
import chromadb
from config import Config

class Database:
    def __init__(self):
        # MongoDB Atlas
        self.mongo_client = MongoClient(Config.MONGO_URI)
        self.db = self.mongo_client[Config.DB_NAME]
        self.products = self.db['products']
        self.logs = self.db['logs']
        
        # Local ChromaDB
        self.chroma_client = chromadb.Client()
        self.collection = self.chroma_client.create_collection(Config.CHROMA_COLLECTION, get_or_create=True)

    def add_product(self, embedding, metadata, product_id):
        try:
            self.collection.add(embeddings=embedding.tolist(), ids=[product_id])
            metadata["embedding_id"] = product_id
            self.products.insert_one(metadata)
            return True
        except Exception as e:
            self.logs.insert_one({"error": str(e)})
            return False

    def find_product(self, product_id):
        return self.products.find_one({"embedding_id": product_id})

    def query_vector(self, embedding):
        return self.collection.query(query_embeddings=embedding.tolist(), n_results=1)

    def log_error(self, error):
        self.logs.insert_one({"error": str(error)})

In [18]:
db = Database()

In [None]:
db.add_product(em, metadata, product_id)