In [1]:
import html
import os
import random
import sys
import time
from datetime import timedelta
from typing import List

import pandas as pd
from dotenv import load_dotenv
from IPython.display import display, Image, HTML

# Make the ETL service's helper modules (utils) importable from this notebook
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '../services/etl')))

from utils import EncoderClient, Minio, VectorDatabase, ETL, LLMEncoderClient, ReRankerClient, BM25Client

[nltk_data] Downloading package punkt to /home/taile/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
# Load service endpoints and credentials from the ETL service's .env file.
load_dotenv("../services/etl/.env")

# Only override NO_PROXY when NO_PROXY_HOST is set. Falling back to the literal
# string "None" would clobber any proxy exemptions already present in NO_PROXY.
no_proxy_host = os.getenv("NO_PROXY_HOST")
if no_proxy_host is not None:
    os.environ["NO_PROXY"] = no_proxy_host

# Initialize the CLIP encoder
print("Initializing the CLIP encoder...")
encoder = EncoderClient(
    host=os.getenv("MODEL_SERVING_HOST"),
    port=os.getenv("MODEL_SERVING_PORT"),
)

# Initialize the Milvus client and load the existing collections into memory
print("Initializing the Milvus client...")
milvus_client = VectorDatabase(
    host=os.getenv("MILVUS_HOST"),
    port=os.getenv("MILVUS_PORT")
)

# Order matters: later cells search milvus_collections[0] (images) and
# milvus_collections[2] (categories) by index.
milvus_collections = [
    os.getenv("MILVUS_COLLECTION"),
    os.getenv("MILVUS_NAME_COLLECTION"),
    os.getenv("MILVUS_CATEGORY_COLLECTION"),
]
for collection_name in milvus_collections:
    load_state = milvus_client.load_collection(collection_name)
    if not load_state:
        raise RuntimeError(f"Failed to load the Milvus collection {collection_name} into memory")
milvus_index_type = os.getenv("MILVUS_INDEX_TYPE")
milvus_metric_type = os.getenv("MILVUS_METRIC_TYPE")

# Initialize the Minio client and check that the bucket already exists (it is not created here)
print("Initializing the Minio client...")
bucket_name = os.getenv("MINIO_BUCKET_NAME")
minio_client = Minio(
    endpoint=os.getenv("MINIO_ENDPOINT"),
    access_key=os.getenv("MINIO_ACCESS_KEY_ID"),
    secret_key=os.getenv("MINIO_SECRET_ACCESS_KEY"),
    secure=False
)
if not minio_client.bucket_exists(bucket_name):
    raise RuntimeError(f"Bucket {bucket_name} does not exist!")

# Initialize the LLM encoder (used for category detection)
print("Initializing the LLM encoder...")
LLMEncoder = LLMEncoderClient(
    host=os.getenv("MODEL_SERVING_HOST"),
    port=os.getenv("EMBEDDER_MODEL_SERVING_PORT"),
)

# Initialize the ReRanker client
print("Initializing the ReRanker client...")
re_ranker = ReRankerClient(
    host=os.getenv("MODEL_SERVING_HOST"),
    port=os.getenv("RE_RANKER_MODEL_SERVING_PORT"),
)

# Initialize the BM25 client (backed by the Minio bucket above)
print("Initializing the BM25 client...")
bm25_client = BM25Client(
    storage=minio_client,
    bucket_name=bucket_name,
)

# NOTE(review): df_processor is not referenced by any later cell in this notebook;
# confirm it is still needed.
df_processor = ETL(
    folder_name="",
    encoder=None,
    BM25_encoder=None,
    storage_client=None,
    db_client=None,
    collection_name="",
    bucket_name="",
    collection_type=""
)

Initializing the CLIP encoder...
Initializing the Milvus client...
Collection text_to_image_retrieval is loaded successfully!
Collection name_retrieval is loaded successfully!
Collection category_retrieval is loaded successfully!
Initializing the Minio client...
Initializing the LLM encoder...
Initializing the ReRanker client...
Initializing the BM25 client...


In [43]:
def category_detection(query: str, top_k: int = 5) -> List[List[str]]:
    """Predict the most likely product categories for a text query.

    The query is embedded with the LLM-based encoder (``LLMEncoder``) and searched
    against the category collection (``milvus_collections[2]``) in Milvus.

    Args:
        query: Free-text search query.
        top_k: Number of categories to return per query vector. Defaults to 5,
            the value that was previously hard-coded.

    Returns:
        One inner list per query vector in the search results (a single list here,
        since exactly one query is encoded), holding the ``category_name`` of each
        hit in the order ``search_vectors`` returned them.
    """
    query_embedding = LLMEncoder.encode_text([query])

    results = milvus_client.search_vectors(
        collection_type="category",
        collection_name=milvus_collections[2],
        vectors=query_embedding,
        top_k=top_k,
        index_type=milvus_index_type,
        metric_type=milvus_metric_type
    )

    return [[hit["entity"]["category_name"] for hit in hits] for hits in results]

def image_retrieval(query: str, top_categories: List[str], top_p: int):
    """Retrieve up to ``top_p`` candidate images for ``query`` with the CLIP encoder.

    The search over the image collection (``milvus_collections[0]``) is restricted to
    products whose ``product_category`` equals one of ``top_categories``; one equality
    clause per category is OR-ed into a single Milvus filter expression.

    Returns:
        ``(image_id_list, product_name_list)``: each holds one inner list per query
        vector in the search results, with the ``id`` / ``product_name`` of every hit.
    """
    clip_embedding = encoder.encode_text([query])

    category_clauses = (f'product_category == "{category}"' for category in top_categories)
    hits_per_query = milvus_client.search_vectors(
        collection_type="image",
        collection_name=milvus_collections[0],
        vectors=clip_embedding,
        top_k=top_p,
        index_type=milvus_index_type,
        metric_type=milvus_metric_type,
        filtering_expr=" or ".join(category_clauses),
    )

    def collect(field):
        # Pull one entity field out of every hit, keeping the per-query grouping.
        return [[hit["entity"][field] for hit in hits] for hits in hits_per_query]

    return collect("id"), collect("product_name")

def image_reranking(query: str, image_id_list: List[str], top_k: int):
    # Get the images from the Minio storage
    retrieved_vectors_limit = 16300
    product_names_list = []
    for i in range(0, len(image_id_list), retrieved_vectors_limit):
        batch_ids = image_id_list[i:i + retrieved_vectors_limit]
        entities = milvus_client.get_vectors(milvus_collections[0], batch_ids)
        names = [entity["product_name"] for entity in entities]
        product_names_list.extend(names)

    # Rerank the images based on the query and the product names
    sorted_indices = re_ranker.rerank(
        query=query,
        documents=product_names_list,
        top_k=top_k
    )

    top_k_image_names = [product_names_list[i] for i in sorted_indices]
    top_k_image_ids = [image_id_list[i] for i in sorted_indices]

    return top_k_image_ids, top_k_image_names


In [205]:
# Load the GLAMI-1M English test queries and sample one query to walk through the pipeline.
N_QUERIES = 10000    # only the first N_QUERIES rows are eligible for sampling
SAMPLE_SEED = None   # set to an int to make the sampled query reproducible

df = pd.read_csv('./dataset/GLAMI-1M-test-en-queries.csv')
df = df.iloc[:N_QUERIES]
if df.empty:
    raise ValueError("No queries loaded from ./dataset/GLAMI-1M-test-en-queries.csv")

# randrange excludes its upper bound; random.randint(0, len(df)) is inclusive and
# could pick len(df), one past the last row.
sampler = random.Random(SAMPLE_SEED) if SAMPLE_SEED is not None else random
random_index = sampler.randrange(len(df))
query = df['query'].iloc[random_index]
labels = df['name_en'].iloc[random_index]

print(f"Query: {query}")
print(f"Labels: {labels}")

Query: Speed ​​Racer Motorcycles Women's Dark Gray T-Shirt.
Labels: Speed ​​Racer Motorcycles - Pure Women's T-Shirt - XS Dark gray highlights


In [206]:
# Detect the top categories for the sampled query and time the call.
# Periods are stripped from the query before detection (the sample above ends with one).
start_time = time.perf_counter()
query = query.replace(".", "")
top_5_categories = category_detection(query)[0]
category_time = time.perf_counter() - start_time
print(f"Category detection time: {category_time:.2f} seconds\n")

print(f"The top 5 categories for the query '{query}' are:")
for category in top_5_categories:
    print(f"- {category}")

Category detection time: 0.04 seconds

The top 5 categories for the query 'Speed ​​Racer Motorcycles Women's Dark Gray T-Shirt' are:
- women's tops tank tops or t-shirts
- men's t-shirts or tank tops
- girl's t-shirts or shirts
- women's undershirts
- women's sweatshirts


In [207]:
# Retrieve top_p candidate images restricted to the detected categories;
# only the first top_images_shown product names are printed.
top_p = 100
top_images_shown = 10
start_time = time.perf_counter()
top_p_image_ids, top_p_image_name = image_retrieval(query, top_5_categories, top_p)
image_retrieval_time = time.perf_counter() - start_time
print(f"\nImage retrieval time: {image_retrieval_time:.2f} seconds\n")

print(f"The top {top_images_shown} images in all of the retrieved images for the query '{query}' are:")
for product_name in top_p_image_name[0][:top_images_shown]:
    print(f"- {product_name}")


Image retrieval time: 0.03 seconds

The top 10 images in all of the retrieved images for the query 'Speed ​​Racer Motorcycles Women's Dark Gray T-Shirt' are:
- Speed ​​Racer 2 - Viper FIT Men's T-Shirt - S Dark gray highlights
- Full Speed ​​- Pure women's t-shirt - XS Dark gray highlights
- Speed ​​Racer Motorcycles - Pure Women's T-Shirt - XS Dark gray highlights
- Two running horses - Pure women's t-shirt - XS Dark gray highlights
- Skate above the channel - Pure women's shirt - XS Black
- The Dirty Wheel - Pure women's shirt - XS Black
- Moto Racer Classic - Viper FIT Men's T-Shirt - S Dark gray highlights
- Ride The Waves - Pure women's t-shirt - XS Dark gray highlights
- Iron Rider - Viper FIT men's t-shirt - S Dark gray highlights
- Born in Czechoslovakia - Pure women's t-shirt - XS Dark gray highlights


In [208]:
# Rerank the retrieved candidates by product name and keep the best top_k.
top_k = 3
start_time = time.perf_counter()
reranked_image_ids, reranked_image_names = image_reranking(query, top_p_image_ids[0], top_k)
reranking_time = time.perf_counter() - start_time
print(f"\nImage reranking time: {reranking_time:.2f} seconds\n")

print(f"The top {top_k} reranked images for the query '{query}' are:")
for image_name in reranked_image_names:
    print(f'"{image_name}"')


Image reranking time: 0.53 seconds

The top 3 reranked images for the query 'Speed ​​Racer Motorcycles Women's Dark Gray T-Shirt' are:
"Speed ​​Racer Motorcycles - Pure Women's T-Shirt - XS Dark gray highlights"
"Full Speed ​​- Pure women's t-shirt - XS Dark gray highlights"
"Off Road Racing - Pure women's t-shirt - XS Dark gray highlights"


In [None]:
# Render the reranked images as an HTML flex grid using presigned Minio URLs
# that expire after one hour.
IMAGES_PER_ROW = 3  # keep in sync with the 30% flex-basis below

html_content = "<div style='display: flex; flex-wrap: wrap;'>"

for i, (image_id, image_name) in enumerate(zip(reranked_image_ids, reranked_image_names)):
    image_url = minio_client.presigned_get_object(
        bucket_name=bucket_name,
        object_name=f"images/{image_id}.jpg",
        expires=timedelta(hours=1)
    )

    # Escape values before embedding them in single-quoted attributes: presigned URLs
    # contain '&' separators, and product names contain apostrophes (e.g. "Women's").
    html_content += (
        "<div style='flex: 1 0 30%; margin: 5px;'>"
        f"<img src='{html.escape(image_url)}' alt='{html.escape(str(image_name))}' "
        "style='width: 100%; height: auto;'></div>"
    )

    # Force a line break after every IMAGES_PER_ROW images
    if (i + 1) % IMAGES_PER_ROW == 0:
        html_content += "<div style='flex-basis: 100%; height: 0;'></div>"

html_content += "</div>"

# Display the HTML content
display(HTML(html_content))

In [210]:
# Sum of the three timed stages above (category detection, retrieval, reranking);
# building and rendering the image grid is not included.
print(
    "\nTotal time taken: "
    f"{sum((category_time, image_retrieval_time, reranking_time)):.2f} seconds"
)


Total time taken: 0.60 seconds


In [1]:
# Show GPU status (memory usage, utilization, running processes) via the nvidia-smi
# shell command. Note: this cell ran in a fresh kernel (execution count reset to 1).
!nvidia-smi

Wed Oct 30 09:05:31 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A6000    On   | 00000000:01:00.0 Off |                  Off |
| 30%   53C    P8    32W / 300W |  30554MiB / 49140MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000    On   | 00000000:21:00.0 Off |                  Off |
| 30%   45C    P8    25W / 300W |      8MiB / 49140MiB |      0%      Default |
|       