# LiT Model

In [None]:
%pip install -q git+https://github.com/google-research/vision_transformer
%pip install -q tensorstore

In [None]:
# Import libraries
import os
import jax
import numpy as np
import pandas as pd
from vit_jax import models

import requests
from io import BytesIO
from PIL import Image

In [None]:
# Initialize model (currently available models: LiT-B16B, LiT-B16B_2, LiT-L16L, LiT-L16S, LiT-L16Ti)
# Name of the LiT variant to use; change this string to switch models.
model_name = 'LiT-L16Ti'

# Build the model object, then fetch the tokenizer, the image preprocessing
# callable and the model variables that belong to it. These four globals are
# read by image_embedding() and text_embeddings() below.
lit_model = models.get_model(model_name)
tokenizer = lit_model.get_tokenizer()
image_preprocessing = lit_model.get_image_preprocessing()
# NOTE(review): presumably loads the pretrained weights for `model_name` -- confirm.
lit_variables = lit_model.load_variables()

In [None]:
# Function that downloads an image from a URL and returns its LiT image embedding
def image_embedding(url, timeout=30):
    """Download the image at *url* and return its LiT image embedding.

    Args:
        url: HTTP(S) address of the image to embed.
        timeout: Seconds to wait for the server before giving up, so that an
            unresponsive host cannot hang the notebook indefinitely.

    Returns:
        The first row of the image features produced by ``lit_model.apply``,
        converted with ``.tolist()``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not answer within *timeout*.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of handing an HTML error page to PIL.
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))

    # Normalize every non-RGB mode (RGBA, palette, grayscale, CMYK, ...) so
    # the array handed to preprocessing always has three colour channels.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    image = image.resize((500, 500))
    preprocessed_images = image_preprocessing([np.array(image)])

    # apply() yields three values; only the first (image features) is used.
    image_features, _, _ = lit_model.apply(lit_variables, images=preprocessed_images)

    return image_features.tolist()[0]

In [None]:
def text_embeddings(texts):
    """Return the LiT text features for *texts* as nested Python lists.

    Args:
        texts: The texts to embed; passed straight to the model's tokenizer.

    Returns:
        The text features produced by ``lit_model.apply``, converted with
        ``.tolist()`` (one entry per input text, presumably -- confirm).
    """
    tokens = tokenizer(texts)
    outputs = lit_model.apply(lit_variables, tokens=tokens)
    # apply() yields three values; only the second (text features) is used.
    _, text_features, _ = outputs
    return text_features.tolist()

# Setup

In [None]:
%pip install -q -U pymilvus

In [None]:
from pymilvus import connections
from pymilvus import Collection
from pprint import pprint
import json

In [None]:
# Connect to the Zilliz Cloud cluster under the "default" alias.
#
# Credentials are read from the environment so that secrets are not stored in
# the notebook. Set ZILLIZ_PASSWORD (and optionally ZILLIZ_URI / ZILLIZ_USER)
# before running this cell; a missing password raises KeyError.
# NOTE: the password that used to be hard-coded here should be rotated.
connections.connect(
    alias="default",
    # Endpoint URI obtained from Zilliz Cloud
    uri=os.environ.get(
        "ZILLIZ_URI",
        "https://in01-1efb60df0cf919e.aws-us-west-2.vectordb.zillizcloud.com:19535",
    ),
    secure=True,
    # Username specified when you created this database
    user=os.environ.get("ZILLIZ_USER", "db_admin"),
    # Password specified when you created this database (required, never hard-code it)
    password=os.environ["ZILLIZ_PASSWORD"],
)

In [None]:
# Handle to the "Product" collection; used by the insert, load, search and
# release cells below.
# NOTE(review): no schema is passed, so the collection presumably must already exist -- confirm.
collection = Collection("Product")
# Optional: create a partition inside the collection.
# collection.create_partition("eateries")

# Insert Data

In [None]:
# Template rows to insert: fill in "vector" (the embedding) and "name" for each
# row before running this cell. The field names match the ones that
# vector_search() queries ("vector") and retrieves ("name").
# NOTE(review): the empty placeholder vectors will presumably be rejected by the
# collection as written -- replace them, e.g. with image_embedding(url) output.
data = [
    {"vector": [], "name": ""},
    {"vector": [], "name": ""},
    {"vector": [], "name": ""},
    {"vector": [], "name": ""}
]

# Insert the rows into the "Product" collection.
collection.insert(data)

# Query

# Load the collection into memory before conducting a vector similarity search;
# it is released again by collection.release() at the end of the notebook.
collection.load()

In [None]:
def vector_search(texts):  # list of text
    """Embed *texts* and search the collection for their nearest vectors.

    Each text is embedded with ``text_embeddings`` and searched against the
    ``"vector"`` field using the inner-product ("IP") metric, asking for up to
    3 hits per text and retrieving the ``name`` field of each hit.

    Args:
        texts: List of query strings.

    Returns:
        The object returned by ``collection.search``. It is also printed, as
        before. Individual hits can be read like this::

            for hit in results[0]:
                print(hit.distance)
                print(hit.entity.get('name'))
    """
    vectors = text_embeddings(texts)

    results = collection.search(
        data=vectors,
        anns_field="vector",
        param={"metric_type": "IP", "params": {"ef": 3}, "offset": 0},
        limit=3,
        expr=None,
        # set the names of the fields you want to retrieve from the search result.
        output_fields=['name'],
        consistency_level="Strong",
    )
    print(results)

    # Hand the results back so callers can use them instead of only seeing a printout.
    return results

In [None]:
# Example text queries. Note that this rebinds `data`, which the insert cell
# above used for the rows to insert.
data = [
    "white table",
    "hamburger",
]

# Embed the queries and print their nearest matches from the collection.
vector_search(data)

In [None]:
# Release the collection loaded in Milvus to reduce memory consumption when the search is completed.
# Run collection.load() again before issuing further searches.
collection.release()