# Qdrant (Setup)

In [None]:
%pip install -q qdrant-client

In [None]:
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import VectorParams
from qdrant_client.http.models import PointStruct
from qdrant_client.http.models import SearchRequest

In [None]:
import os

# Connect to the Qdrant Cloud cluster.
# The API key is read from the environment so the secret is never stored in
# the notebook. NOTE(review): the key that was previously committed here
# should be considered leaked and rotated in the Qdrant Cloud console.
qdrant_api_key = os.environ.get("QDRANT_API_KEY")
if not qdrant_api_key:
    raise RuntimeError(
        "Set the QDRANT_API_KEY environment variable before running this cell."
    )

client = QdrantClient(
    # QDRANT_URL may override the cluster endpoint; the default is the
    # cluster this notebook was written against.
    url=os.environ.get(
        "QDRANT_URL",
        "https://7a80db84-f2a1-4663-9fd1-cc5618c1d30a.us-east-1-0.aws.cloud.qdrant.io:6333",
    ),
    api_key=qdrant_api_key,
    prefer_grpc=True,
)

In [None]:
# Create a new collection
# For tuning high-precision / high-speed search see:
# https://qdrant.tech/documentation/tutorials/optimize/
#
# Distance is imported by name (and VectorParams is used via its direct
# import) instead of going through the bare `models` name: a later cell
# rebinds `models` to `vit_jax.models`, which would make this cell fail with
# an AttributeError when it is re-run after the LiT setup.
from qdrant_client.http.models import Distance

# NOTE: recreate_collection replaces an existing "products" collection, so
# re-running this cell discards previously upserted points.
client.recreate_collection(
    collection_name="products",
    vectors_config=VectorParams(size=1024, distance=Distance.COSINE),
)

# print("Create collection response:", client)

# Optional check that the collection was created:
# collection_info = client.get_collection(collection_name="products")
# print(collection_info)

# LiT (Setup)

In [None]:
%pip install -q git+https://github.com/google-research/vision_transformer
%pip install tensorstore

In [None]:
# Import libraries
import os
import jax
import numpy as np
import pandas as pd
from vit_jax import models

In [None]:
# Initialize model (currently available models: LiT-B16B, LiT-B16B_2, LiT-L16L, LiT-L16S, LiT-L16Ti)
#
# vit_jax.models is imported under an alias here because the bare name
# `models` is also bound to qdrant_client.http.models by the Qdrant setup
# cell; whichever import ran last wins, so using the bare name would make
# this cell depend on the order in which notebook cells were executed.
from vit_jax import models as lit_models

model_name = 'LiT-L16Ti'

lit_model = lit_models.get_model(model_name)
# Callable applied to a list of prompt strings in the query cells.
tokenizer = lit_model.get_tokenizer()
# Callable applied to the stacked image array in the upsert cell.
image_preprocessing = lit_model.get_image_preprocessing()
# Passed as the first argument of every lit_model.apply(...) call.
lit_variables = lit_model.load_variables()

# Supabase

In [None]:
%pip install -q supabase

In [None]:
from supabase import create_client, Client

In [None]:
# Connect to Supabase.
# The project URL keeps its previous value as a default but can now be
# overridden through the environment instead of being overwritten on every run.
os.environ.setdefault("SUPABASE_URL", "https://kdybpofgbqvrpbsoorkx.supabase.co")

# The key must come from the environment so it is not stored in the notebook.
# NOTE(review): the key that was previously committed here should be treated
# as exposed; rotate it if the project relies on it being private.
if not os.environ.get("SUPABASE_KEY"):
    raise RuntimeError(
        "Set the SUPABASE_KEY environment variable before running this cell."
    )

# Indexing (rather than .get) guarantees both values really are `str`.
url: str = os.environ["SUPABASE_URL"]
key: str = os.environ["SUPABASE_KEY"]
supabase: Client = create_client(url, key)

# Qdrant (Insert Data)

In [None]:
# Upsert data
import requests
from io import BytesIO
from PIL import Image
import uuid

def _load_image(image_url, size=(500, 500), timeout=30):
    """Download one image and return it as an array of the resized RGB image.

    Args:
        image_url: Address the image is fetched from.
        size: Target (width, height) passed to ``Image.resize``.
        timeout: Seconds to wait for the HTTP request before giving up.

    Raises:
        requests.RequestException: The download failed or returned an HTTP
            error status.
        OSError: The downloaded bytes could not be opened as an image.
    """
    response = requests.get(image_url, timeout=timeout)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))

    # Convert every non-RGB mode (RGBA, P, L, CMYK, ...), not only RGBA, so
    # all arrays of a row have the same shape and can be stacked together.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    return np.array(image.resize(size))


data = supabase.table('tayho').select('*').execute()

for row in data.data:
    image_urls = row['image_urls']
    # Skip rows without images (covers both an empty list and None).
    if not image_urls:
        continue

    # Keep URLs and arrays aligned: an image that fails to load is skipped
    # on its own instead of aborting the whole import.
    loaded_urls = []
    converted_images = []
    for image_url in image_urls:
        try:
            converted_images.append(_load_image(image_url))
        except (requests.RequestException, OSError) as e:
            print(f"Skipping {image_url}: {e}")
            continue
        loaded_urls.append(image_url)

    if not converted_images:
        continue

    preprocessed_images = image_preprocessing(np.array(converted_images))
    image_features, _, _ = lit_model.apply(lit_variables, images=preprocessed_images)

    # Convert the embeddings to plain lists once per row rather than once
    # per point.
    vectors = image_features.tolist()

    points = [
        PointStruct(
            id=uuid.uuid1().int >> 64,
            vector=vector,
            payload={"url": f"{image_url}"},
        )
        for image_url, vector in zip(loaded_urls, vectors)
    ]

    # One upsert per row instead of one blocking round trip per image.
    try:
        client.upsert(collection_name="products", wait=True, points=points)
    except Exception as e:
        # Best effort: report the failure and carry on with the next row.
        print(f"Exception: {e}")

# Query (Single Vector Search)
- Docs: https://qdrant.tech/documentation/concepts/search/
- Geospatial search: https://geo.rocks/post/geospatial-vector-search-qdrant/#6-semantic-queries-with-geospatial-filters
- Recommendation (reward/punish): https://qdrant.tech/documentation/concepts/search/#recommendation-api

In [None]:
def single_query(prompt, *, limit=5):
    """Search the "products" collection with one text prompt and print the hits.

    The prompt is tokenized and passed to ``lit_model.apply``; the second
    value it returns is used as the query vector. For every hit the score
    and the ``url`` stored in its payload are printed, in the order returned
    by the search.

    Args:
        prompt: Text to search for.
        limit: Maximum number of hits to request (keyword-only, default 5).
    """
    query_tokens = tokenizer([prompt])
    _, query_features, _ = lit_model.apply(lit_variables, tokens=query_tokens)

    result = client.search(
        collection_name="products",
        # Only one prompt was embedded, so the first row is the query vector.
        query_vector=query_features.tolist()[0],
        limit=limit,
    )

    for item in result:
        print(item.score)
        print(item.payload['url'])

In [None]:
single_query("sushi on black table")

# Query (Batch Search)
- Blog post: https://blog.qdrant.tech/batch-vector-search-with-qdrant-8c4d598179d5

In [None]:
def batch_query(prompt1, prompt2, *more_prompts, limit=2):
    """Search the "products" collection for several prompts in one batch.

    All prompts are tokenized and embedded with a single ``lit_model.apply``
    call and sent to Qdrant as one ``search_batch`` request. The hits of all
    prompts are merged, sorted by descending score and printed as score / url
    pairs. Hits are not de-duplicated, so an image matching more than one
    prompt is printed once per match.

    Args:
        prompt1: First text prompt.
        prompt2: Second text prompt.
        *more_prompts: Optional additional text prompts.
        limit: Maximum number of hits requested per prompt (keyword-only,
            default 2).
    """
    prompts = [prompt1, prompt2, *more_prompts]
    query_tokens = tokenizer(prompts)
    _, query_features, _ = lit_model.apply(lit_variables, tokens=query_tokens)

    # Convert the embeddings once instead of once per request.
    query_vectors = query_features.tolist()

    result = client.search_batch(
        collection_name="products",
        requests=[
            SearchRequest(vector=vector, with_payload=True, limit=limit)
            for vector in query_vectors
        ],
    )

    # `result` holds one list of hits per prompt; flatten them into one list.
    hits = []
    for prompt_hits in result:
        for product in prompt_hits:
            hits.append({"score": product.score, "url": product.payload['url']})

    for hit in sorted(hits, key=lambda x: x["score"], reverse=True):
        print(hit['score'])
        print(hit['url'])

In [None]:
batch_query("noodles", "meat")