# Testing search in the Weaviate H&M catalogue

In [None]:
import pandas as pd

from pathlib import Path

from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

import requests
from io import BytesIO
from PIL import Image

In [None]:
# Enable IPython's autoreload so edits to the project modules imported below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import my_mirror_on_cloud.embedding_manager as em
import my_mirror_on_cloud.weaviate_manager as wm

In [None]:
from my_mirror_on_cloud.utils import clean_name

### Search by Text Query using FashionCLIP

In [None]:
## Query parameters
query = "red dress with white dots"  # free-text description to search for
model = "fashion-clip"               # embedder name; also selects the target vector `embedding_<model>`
collection = "catalogue_HM"          # Weaviate collection to search
limit = 5                            # number of results to return

## Embed the query text: encode_texts takes a list, so pass one item and keep the first vector
embedder = em.create_embedder(model_name=model)
query_vector = embedder.encode_texts([query], batch_size=1)[0]

## Vector search against the named vector derived from the model name
with wm.WeaviateManager() as wmgr:
    results = wmgr.search_by_vector(
        query_vector,
        collection_name=collection,
        target_vector=f"embedding_{clean_name(model)}",
        limit=limit,
    )

In [None]:
def get_url_from_image_path(
    image_path,
    base='../data/h-and-m-personalized-fashion-recommendations',
    base_url="https://storage.googleapis.com/catalogue_hm/",
):
    """Map a local catalogue image path to its URL in the cloud bucket.

    Args:
        image_path: Path of the image as stored in the Weaviate ``image_path``
            property; it must lie under ``base``.
        base: Local dataset root that corresponds to the bucket root.
        base_url: Public URL of the bucket root.

    Returns:
        ``base_url`` joined with the path of ``image_path`` relative to ``base``.

    Raises:
        ValueError: If ``image_path`` is not located under ``base``.
    """
    relative_path = Path(image_path).relative_to(Path(base))
    # as_posix() keeps forward slashes in the URL even when running on Windows,
    # where str(Path) would insert backslashes.
    return base_url.rstrip("/") + "/" + relative_path.as_posix()

def load_image_from_url(url, timeout=30):
    """Download an image over HTTP and open it with PIL.

    Args:
        url: Address of the image file.
        timeout: Seconds to wait for the server; without a timeout
            ``requests.get`` can block indefinitely.

    Returns:
        The ``PIL.Image.Image`` decoded from the response body.

    Raises:
        requests.HTTPError: If the server replies with an error status. Raising here
            gives a clear message instead of PIL later failing to decode an error page.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))

In [None]:
import math

## Show the search results in a grid
cols = 5
# Enough rows for `limit` images; ceil so a partial last row is still drawn
rows = max(1, math.ceil(limit / cols))
# squeeze=False always returns a 2-D array of Axes, so flatten() also works for a 1x1 grid
fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False)
axes = axes.flatten()

# Result images are fetched from the cloud bucket. For local files use
# images = [item.properties["image_path"] for item in results.objects] with mpimg.imread instead.
images = [get_url_from_image_path(item.properties["image_path"]) for item in results.objects]

# zip stops at the shorter sequence, so fewer results than grid cells does not raise IndexError
for ax, image_url in zip(axes, images):
    ax.imshow(load_image_from_url(image_url))
    ax.set_title(Path(image_url).name, fontsize=8)

# Hide ticks and frames on every cell, including cells left empty
for ax in axes:
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
## Query parameters
# NOTE(review): this cell repeats the text vector search above verbatim; the hybrid
# search cell below reassigns `results` before anything plots it.
query = "red dress with white dots"  # free-text description to search for
model = "fashion-clip"               # embedder name; also selects the target vector `embedding_<model>`
collection = "catalogue_HM"          # Weaviate collection to search
limit = 5                            # number of results to return

## Embed the query text: encode_texts takes a list, so pass one item and keep the first vector
embedder = em.create_embedder(model_name=model)
query_vector = embedder.encode_texts([query], batch_size=1)[0]

## Vector search against the named vector derived from the model name
with wm.WeaviateManager() as wmgr:
    results = wmgr.search_by_vector(
        query_vector,
        collection_name=collection,
        target_vector=f"embedding_{clean_name(model)}",
        limit=limit,
    )

In [None]:
## Query parameters for a hybrid (keyword + vector) search
query = "red dress with white dots"
model = "fashion-clip"
collection = "catalogue_HM"
limit = 5
# Properties to look for exact (keyword) matches in; presumably forwarded as Weaviate's
# hybrid `query_properties` — confirm in WeaviateManager.search_hybrid
properties = ["product_type_original", "colour_original"]
# Presumably Weaviate's hybrid `alpha` (0 = keyword only, 1 = vector only) — confirm
alpha = 0.5

## The same text is used both as the keyword query and, once embedded, as the query vector
embedder = em.create_embedder(model_name=model)
query_vector = embedder.encode_texts([query], batch_size=1)[0]

with wm.WeaviateManager() as wmgr:
    results = wmgr.search_hybrid(
        query=query,
        vector=query_vector,
        query_properties=properties,
        alpha=alpha,
        limit=limit,
        collection_name=collection,
        target_vector=f"embedding_{clean_name(model)}",
    )

In [None]:
import math

## Show the hybrid search results in a grid
cols = 5
# Enough rows for `limit` images; ceil so a partial last row is still drawn
rows = max(1, math.ceil(limit / cols))
# squeeze=False always returns a 2-D array of Axes, so flatten() also works for a 1x1 grid
fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False)
axes = axes.flatten()

# Result images are fetched from the cloud bucket. For local files use
# images = [item.properties["image_path"] for item in results.objects] with mpimg.imread instead.
images = [get_url_from_image_path(item.properties["image_path"]) for item in results.objects]

# zip stops at the shorter sequence, so fewer results than grid cells does not raise IndexError
for ax, image_url in zip(axes, images):
    ax.imshow(load_image_from_url(image_url))
    ax.set_title(Path(image_url).name, fontsize=8)

# Hide ticks and frames on every cell, including cells left empty
for ax in axes:
    ax.axis('off')

plt.tight_layout()
plt.show()

### Search using an image as the query

In [None]:
## Query parameters: a local image used as the query
query = "../data/farfetch/images/12809784_3.jpg"  # image path, relative to the notebook
model = "fashion-clip"                            # embedder name; also selects the target vector
collection = "catalogue_HM"                       # Weaviate collection to search
limit = 5                                         # number of results to return

## Embed the query image: encode_images takes a list, so pass one path and keep the first vector
embedder = em.create_embedder(model_name=model)
query_vector = embedder.encode_images([query], batch_size=1)[0]

## Vector search against the named vector derived from the model name
with wm.WeaviateManager() as wmgr:
    results = wmgr.search_by_vector(
        query_vector,
        collection_name=collection,
        target_vector=f"embedding_{clean_name(model)}",
        limit=limit,
    )

In [None]:
import math

## Show the query image followed by the search results
cols = 6
# One extra cell for the query image; ceil so a partial last row is still drawn
rows = max(1, math.ceil((limit + 1) / cols))
# squeeze=False always returns a 2-D array of Axes, so flatten() also works for a 1x1 grid
fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False)
axes = axes.flatten()

# First cell is the local query image; the results come from the cloud bucket, as in the
# other result plots. For local result files use item.properties["image_path"] with mpimg.imread.
images = [query]
images += [get_url_from_image_path(item.properties["image_path"]) for item in results.objects]

# zip stops at the shorter sequence, so fewer results than grid cells does not raise IndexError
for i, (ax, image_ref) in enumerate(zip(axes, images)):
    if i == 0:
        ax.imshow(mpimg.imread(str(image_ref)))
        ax.set_title("Query", fontsize=8)
    else:
        ax.imshow(load_image_from_url(image_ref))
        ax.set_title(Path(image_ref).name, fontsize=8)

# Hide ticks and frames on every cell, including cells left empty
for ax in axes:
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
## Query parameters: hybrid search with a text for the keyword part and an image for the vector part
text = "red dress"
image_query = "../data/farfetch/images/12809784_3.jpg"  # image path, relative to the notebook
model = "fashion-clip"
collection = "catalogue_HM"
limit = 5
properties = ["product_type_original", "colour_original"]  # properties to look for exact match
alpha = 0.5  # presumably Weaviate's hybrid alpha (0 = keyword only, 1 = vector only) — confirm

## Embed this cell's query image. It must be `image_query`, not `query`: `query` is only set
## by an earlier cell, so using it would silently search with whatever that cell last assigned.
embedder = em.create_embedder(model_name=model)
query_vector = embedder.encode_images([image_query], batch_size=1)[0]

## Hybrid search: `text` is the keyword query, the image embedding is the vector query
with wm.WeaviateManager() as wmgr:
    results = wmgr.search_hybrid(
        query=text,
        vector=query_vector,
        collection_name=collection,
        target_vector=f"embedding_{clean_name(model)}",
        alpha=alpha,
        limit=limit,
        query_properties=properties
    )

In [None]:
import math

## Show the query image followed by the hybrid search results
cols = 6
# One extra cell for the query image; ceil so a partial last row is still drawn
rows = max(1, math.ceil((limit + 1) / cols))
# squeeze=False always returns a 2-D array of Axes, so flatten() also works for a 1x1 grid
fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False)
axes = axes.flatten()

# First cell is the local image used by the hybrid search (`image_query`, not `query`,
# which is set by an earlier cell); the results are URLs in the cloud bucket.
images = [image_query]
images += [get_url_from_image_path(item.properties["image_path"]) for item in results.objects]

# zip stops at the shorter sequence, so fewer results than grid cells does not raise IndexError
for i, (ax, image_ref) in enumerate(zip(axes, images)):
    if i == 0:
        ax.imshow(mpimg.imread(str(image_ref)))
        ax.set_title("Query", fontsize=8)
    else:
        ax.imshow(load_image_from_url(image_ref))
        ax.set_title(Path(image_ref).name, fontsize=8)

# Hide ticks and frames on every cell, including cells left empty
for ax in axes:
    ax.axis('off')

plt.tight_layout()
plt.show()