In [5]:
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType
import glob, os
from tqdm import tqdm
import torch
from PIL import Image

from transformers import AutoProcessor, CLIPModel, CLIPProcessor
from sentence_transformers import SentenceTransformer

In [None]:
# --- Milvus collection settings ---
COLLECTION_NAME = 'ayam_collection'  # Collection that the cells below insert into and search
DIMENSION = 512  # Length of each embedding vector in this example

# Host/port of a Milvus server. The cells visible in this notebook do not
# reference these two values; the client below is built from "demo.db" instead.
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"

# --- Ingestion / search tuning knobs ---
BATCH_SIZE = 128
TOP_K = 3

# --- Embedding model ---
MODEL_NAME = 'clip-ViT-B-32'  # CLIP checkpoint name handed to SentenceTransformer

# --- Dataset location ---
# The ingestion cell uses each image's parent directory name as its category.
data_dir = "../../can_food"

# --- Shared runtime objects used by every later cell ---
client = MilvusClient("demo.db")
model = SentenceTransformer(MODEL_NAME)

In [None]:
# Collection schema: integer primary key supplied by the caller (auto_id=False),
# a category label, and the image embedding vector.
fields = [
    FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=False),
    # VARCHAR fields need a maximum length; 200 characters is used in this example
    FieldSchema(name='category', dtype=DataType.VARCHAR, max_length=200),
    FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=int(DIMENSION)),
]
schema = CollectionSchema(fields=fields)

# Index on the embedding field; COSINE matches the metric the search cells request.
index_params = client.prepare_index_params()
index_params.add_index(
    field_name="image_embedding",
    index_name="image_embedding_index",
    index_type="IVF_FLAT",
    metric_type="COSINE",
    params={"nlist": 1000},  # Can be adjusted depending on the total amount of data
)

# Create the collection from the schema and index params, but only when it does
# not exist yet. The guard makes this cell safe to re-run and lets the notebook
# work on a fresh kernel / fresh database without manually uncommenting a line.
if not client.has_collection(COLLECTION_NAME):
    client.create_collection(
        COLLECTION_NAME,
        schema=schema,
        index_params=index_params,
        dimension=DIMENSION,
    )

In [None]:
# Uncomment the line below to delete (drop) the collection
# client.drop_collection(COLLECTION_NAME)

In [13]:
# Show the stored schema/metadata of the collection as a sanity check
collection_info = client.describe_collection(COLLECTION_NAME)
print(collection_info)
# print(client.list_collections())

{'collection_name': 'ayam_collection', 'auto_id': False, 'num_shards': 0, 'description': '', 'fields': [{'field_id': 100, 'name': 'id', 'description': '', 'type': <DataType.INT64: 5>, 'params': {}, 'is_primary': True}, {'field_id': 101, 'name': 'category', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 200}}, {'field_id': 102, 'name': 'image_embedding', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 512}}], 'functions': [], 'aliases': [], 'collection_id': 0, 'consistency_level': 0, 'properties': {}, 'num_partitions': 0, 'enable_dynamic_field': False}


In [14]:
# Load images from directory and generate embedding.
# glob is called without recursive=True, so "**" behaves like "*": this matches
# .jpg files exactly one directory below data_dir, and that directory's name is
# used as the category.
# Sorting makes the id assigned to each file independent of the OS listing order.
image_paths = sorted(glob.glob(os.path.join(data_dir, "**/*.jpg")))
data = []

for i, filepath in enumerate(tqdm(image_paths, desc= "Generating embedding ..")):
    try:
        # The context manager releases the file handle even when encoding fails
        with Image.open(filepath) as image:
            image_embedding = model.encode(image)
        category = os.path.basename(os.path.dirname(filepath))
        # Assign the ids: position in the sorted path list (skipped files leave gaps)
        image_id = i
        data.append({"id" : image_id, "image_embedding" : image_embedding, "category" : category})
    except Exception as e:
        # Best effort: report the unreadable/unencodable file and keep going
        print(f"Skipping file: {filepath} due to an error occurs during an embedding process: \n{e}" )
        continue

# Inserting data into milvus
# NOTE(review): ids restart at 0 on every run, so re-running this cell against a
# populated collection presumably adds the same ids again - confirm whether
# client.upsert is the better fit here.
if data:
    mr = client.insert(collection_name= COLLECTION_NAME, data=data)
    print(f"Total number of images inserted : {mr['insert_count']} ")
else:
    # Nothing matched (wrong data_dir?) or every file failed; say so explicitly
    print(f"No images were embedded from {data_dir!r}; nothing was inserted.")

Generating embedding ..: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [00:02<00:00,  7.40it/s]


Total number of images inserted : 21 


In [16]:
# Image Query: embed a query image and look up the most similar stored images
query_image = "./test.jpg"

# Encode inside the context manager so the file handle is released afterwards
with Image.open(query_image) as query_image_read:
    query_embedding = model.encode(query_image_read)

results = client.search(
    collection_name=COLLECTION_NAME,
    data=[query_embedding],  # search takes a list of query vectors; one here
    output_fields=["category"],
    search_params={"metric_type": "COSINE"},
    # NOTE(review): TOP_K is defined in the config cell but this cell uses 2 - confirm intent
    limit=2,
)

# One result list per query vector; only one query was sent
search_results = results[0]
print(search_results)
     


[{'id': 13, 'distance': 0.8256827592849731, 'entity': {'category': 'sardines'}}, {'id': 14, 'distance': 0.8064610362052917, 'entity': {'category': 'sardines'}}]


In [17]:
# Text Query: embed a free-text description with the same model and search the
# image collection with it
query_text = "mackerel in tin"

# A one-element list is encoded; the result is handed to search as-is as the
# batch of query vectors
vector_embedding = model.encode(
    [query_text],
    convert_to_numpy=True,
    show_progress_bar=False,
)

results = client.search(
    collection_name=COLLECTION_NAME,
    data=vector_embedding,
    output_fields=["category"],
    search_params={"metric_type": "COSINE"},
    limit=2,
)

# Hits for the first (and only) query
text_hits = results[0]
print(text_hits)



[{'id': 3, 'distance': 0.32279837131500244, 'entity': {'category': 'mackerel'}}, {'id': 2, 'distance': 0.3210201561450958, 'entity': {'category': 'mackerel'}}]
