In [14]:
from pydantic import BaseModel
from tqdm import tqdm
import json
from secretstuff.secret import OPENAI_API_KEY, OPENAI_ORG_ID, OPENAI_PROJ_ID
from services.mongodb import catalogue, CatalogueItem
from services.metadata import get_catalogue_metadata
from openai import OpenAI
from typing import Literal
import random
import sys
import os

# Add the notebook's parent directory to the module search path.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
# Guard so re-running this cell does not append yet another duplicate entry.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)
# NOTE(review): this runs *after* the `secretstuff` / `services` imports above. If those
# packages are only importable via parent_dir, a fresh kernel fails at the imports and
# they only work after this cell has already run once — consider moving this path
# setup above those imports.


# Shared OpenAI client used by the tagging and embedding helpers.
client = OpenAI(api_key=OPENAI_API_KEY, organization=OPENAI_ORG_ID, project=OPENAI_PROJ_ID)
# Number of catalogue buckets; add_catalogue_item draws bucket_num from [1, bucket_count].
bucket_count = get_catalogue_metadata().bucket_count


def add_catalogue_item(name: str, category: Literal['Tops', 'Bottoms', 'Shoes', 'Dresses'], price: float, image_url: str, product_url: str, retailer: str):
    """Tag, embed and store a single product in the ``catalogue`` collection.

    Products are keyed by ``image_url``: when a document with the same image URL
    already exists, a message is printed and nothing is written.

    Args:
        name: product name (also interpolated into the tagging prompt).
        category: one of 'Tops', 'Bottoms', 'Shoes', 'Dresses'.
        price: product price.
        image_url: product image URL; used both for de-duplication and tagging.
        product_url: link to the product page.
        retailer: retailer name.
    """
    already_stored = catalogue.find_one({"image_url": image_url}) is not None
    if already_stored:
        print(f"Image URL exist in DB {image_url}")
        return

    tags = get_openai_tags(name, category, image_url)
    type_vec, color_vec, material_vec, other_vecs = get_openai_embedding(tags)

    item = CatalogueItem(
        name=name,
        category=category,
        clothing_type=tags['clothing_type'],
        clothing_type_embed=type_vec,
        color=tags['color'],
        color_embed=color_vec,
        material=tags['material'],
        material_embed=material_vec,
        other_tags=tags['other'],
        other_tags_embed=other_vecs,
        price=price,
        image_url=image_url,
        product_url=product_url,
        retailer=retailer,
        # NOTE(review): "U" is hard-coded for every item — presumably unisex/unspecified; confirm.
        gender="U",
        # Bucket drawn uniformly at random from 1..bucket_count inclusive.
        bucket_num=random.randint(1, bucket_count),
    )
    catalogue.insert_one(dict(item))


# Structured-output schema for get_openai_tags: passed as ``response_format`` to
# constrain the model's reply to exactly these fields.
# NOTE(review): documented with comments rather than a class docstring because pydantic
# will likely copy a docstring into the generated JSON schema's description, which would
# change what is sent to the model — confirm before adding one.
class ClothingTags(BaseModel):
    clothing_type: str  # garment type tag (prompt asks for all-lowercase tags)
    color: str
    material: str
    other: list[str]  # free-form adjectives, e.g. occasion, fit, sleeve, brand


def get_openai_tags(name: str, category: Literal['Tops', 'Bottoms', 'Shoes', 'Dresses'], image_url: str) -> dict:
    """Ask gpt-4o-mini to tag a product image with structured clothing attributes.

    The image (at ``detail="low"``) and a text prompt naming the product are sent as
    two user messages; the reply is constrained to the ``ClothingTags`` schema.

    Args:
        name: product name, interpolated into the prompt.
        category: product category, interpolated into the prompt.
        image_url: image URL handed to the model.

    Returns:
        dict with str values under ``clothing_type``, ``color``, ``material`` and a
        list of str under ``other``, as produced by the model.

    Raises:
        ValueError: if the model refused or returned no content (instead of the
            opaque ``TypeError`` that ``json.loads(None)`` would raise).
    """
    user_prompt = f"Generate tags for this {name} ({category}), including clothing type, color, material, other adjectives (eg. occasion, fit, sleeve, brand). Give tags in all lowercase."
    output = client.beta.chat.completions.parse(model="gpt-4o-mini",
                                                messages=[
                                                    {"role": "user", "content": [
                                                        {"type": "image_url", "image_url": {"url": image_url, "detail": "low"}}]},
                                                    {"role": "user", "content": user_prompt},
                                                ],
                                                response_format=ClothingTags
                                                )
    # Default n = 1, so the first (and only) choice carries the answer.
    message = output.choices[0].message
    # On a refusal, content is empty and the reason is in `refusal`.
    refusal = getattr(message, "refusal", None)
    if refusal:
        raise ValueError(f"Model refused to tag {image_url}: {refusal}")
    if not message.content:
        raise ValueError(f"Model returned no content for {image_url}")
    return json.loads(message.content)


def get_openai_embedding(tags: dict, model: str = "text-embedding-3-large"):
    """Embed every tag produced by ``get_openai_tags`` in one batched API request.

    All tag strings go into a single ``embeddings.create`` call (the endpoint accepts
    a list of inputs). This replaces one request per tag, cutting calls per product
    from ``3 + len(tags['other'])`` to 1 and easing pressure on the rate limit.

    Args:
        tags: dict with str values under ``clothing_type``, ``color``, ``material``
            and a list of str under ``other``.
        model: embedding model name.

    Returns:
        4-tuple ``(clothing_type_embed, color_embed, material_embed, other_embed)``.
        The first three are single embedding vectors; ``other_embed`` is a list with
        one vector per entry of ``tags['other']``, in the same order (possibly empty).
    """
    inputs = [tags['clothing_type'], tags['color'], tags['material'], *tags['other']]
    response = client.embeddings.create(input=inputs, model=model)
    # Each result carries the index of its input; sort so vectors line up with `inputs`.
    vectors = [item.embedding for item in sorted(response.data, key=lambda item: item.index)]
    return (vectors[0], vectors[1], vectors[2], vectors[3:])

In [15]:
# Load product data from json.
# Decode as UTF-8 explicitly: JSON text is UTF-8, while open()'s default is the
# platform locale encoding (e.g. cp1252 on Windows), which can garble or reject
# non-ASCII product names.
with open('products.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

# Remove duplicates based on 'image_url' key (the same key add_catalogue_item uses).
# Later records overwrite earlier ones, so the LAST record per image_url wins, kept at
# the position where that URL first appeared.
unique_data = {item["image_url"]: item for item in data}.values()
data = list(unique_data)

In [16]:
# Insert into DB with multithreading

import threading


def insert_data(data):
    """Insert each product dict in ``data`` into the catalogue, one at a time.

    Failures are reported per item and do not stop the rest of the batch, so a single
    bad product (malformed record, API error, ...) cannot end a worker thread early.

    Args:
        data: iterable of product dicts with keys ``name``, ``category``, ``price``,
            ``image_url``, ``product_link`` and ``retailer``.
    """
    for item in data:
        try:
            add_catalogue_item(name=item['name'], category=item['category'], price=item['price'],
                               image_url=item['image_url'], product_url=item['product_link'], retailer=item['retailer'])
        except Exception as e:
            # .get(): if the record lacks 'name', item['name'] here would raise a second
            # KeyError that escapes the handler and abandons the rest of this chunk.
            print(f"Error with {item.get('name')}: {e}")


number_of_threads = 20  # Having too many threads will exceed rate limit
threads = []
for i in range(number_of_threads):
    # Contiguous slices whose boundaries cover every index of `data` exactly once
    # (some slices are empty when len(data) < number_of_threads).
    chunk = data[(i * len(data) // number_of_threads):((i + 1) * len(data) // number_of_threads)]
    thread = threading.Thread(target=insert_data, args=[chunk])
    thread.start()
    threads.append(thread)

# Wait for every worker, so this cell (and Restart & Run All) only finishes once
# all inserts are done instead of returning while threads still run in the background.
for thread in threads:
    thread.join()