In [1]:
# ! pip install pymilvus==2.3.1

In [2]:
import pandas as pd
import numpy as np
from tqdm import tqdm
from dotenv import load_dotenv
import os
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.utils import load_img, save_img, img_to_array
from tensorflow.keras.applications.vgg19 import preprocess_input
from tensorflow.keras.layers import GlobalAveragePooling2D
from pymilvus import connections, utility
from pymilvus import Collection, DataType, FieldSchema, CollectionSchema

### Read images

In [3]:
# Collect up to 100 image paths per class from the test split.
# os.listdir() returns entries in arbitrary order, so both listings are sorted
# to make the selected subset (and the order of `images`) reproducible across runs.
test_dir = os.path.join("Vegetable Images", "test")
images = []
for class_name in sorted(os.listdir(test_dir)):
    class_dir = os.path.join(test_dir, class_name)
    if not os.path.isdir(class_dir):  # skip stray files (e.g. .DS_Store) next to the class folders
        continue
    for file_name in sorted(os.listdir(class_dir))[:100]:  # consider 100 images from each class
        images.append(os.path.join(class_dir, file_name))  # appending image paths to images list

In [4]:
# total number of image paths collected
len(images)

1500

### Vectorization using a VGG19 model fine-tuned on vegetable images

In [5]:
class ImageVectorizer:
    '''
    Get vector representation of an image using VGG19 model fine tuned on vegetable images for classification.

    The saved classifier is cut at its 'block5_pool' layer and a global average
    pooling layer is placed on top, so every image is mapped to a single 1-D vector.
    '''

    # Default location of the saved classifier (the path this class used to hard-code)
    DEFAULT_MODEL_PATH = 'vegetable_classification_model_vgg.h5'
    # (height, width) every image is resized to before it is fed to the network
    TARGET_SIZE = (224, 224)

    def __init__(self, model_path: str = DEFAULT_MODEL_PATH):
        '''
        Build the feature extractor once so it is reused by every vectorize() call.

        model_path: path of the saved Keras model to load; defaults to DEFAULT_MODEL_PATH.
        '''
        self.__model = self.get_model(model_path)

    @staticmethod
    def get_model(model_path: str = DEFAULT_MODEL_PATH):
        '''
        Load the saved classifier and return a model that outputs pooled 'block5_pool' features.

        model_path: path of the saved Keras model to load; defaults to DEFAULT_MODEL_PATH.
        Returns a Model sharing the classifier's input, whose output is the
        global-average-pooled activation of the 'block5_pool' layer.
        '''
        classifier = load_model(model_path)  # saved VGG model finetuned on vegetable images for classification
        features = classifier.get_layer('block5_pool').output
        pooled = GlobalAveragePooling2D()(features)
        return Model(inputs=classifier.input, outputs=pooled)

    def vectorize(self, img_path: str) -> np.ndarray:
        '''
        Return the feature vector of the image stored at img_path.

        The image is read as RGB, resized to TARGET_SIZE, run through the VGG19
        preprocess_input, and passed through the model as a batch of one.
        '''
        image = load_img(img_path, color_mode="rgb", target_size=self.TARGET_SIZE)
        pixels = preprocess_input(img_to_array(image))
        batch = np.expand_dims(pixels, axis=0)  # the model expects a leading batch axis
        return self.__model(batch).numpy()[0]  # drop the batch axis again

In [6]:
# build the vectorizer once; its constructor calls get_model(), which loads the saved model
vectorizer = ImageVectorizer()

In [7]:
# getting max length of image path to be used for VARCHAR while defining schema.
# Measured in UTF-8 bytes instead of characters: Milvus is understood to enforce
# VARCHAR max_length as a byte limit (NOTE(review): confirm against the Milvus schema docs).
# For pure-ASCII paths this equals len(s); for non-ASCII paths it can only be larger,
# so the field is never under-sized.
max_path_len = max(len(s.encode("utf-8")) for s in images)
max_path_len

43

In [8]:
# Pull the Milvus endpoint URI and API token out of secrets.env
# (each value is None when the corresponding variable is not set)
load_dotenv('secrets.env')
uri, token = os.getenv("URI"), os.getenv("TOKEN")

In [9]:
# Register the "default" connection alias using the URI and token read from the environment
connections.connect(alias="default", uri=uri, token=token)
print("Connected to DB")

Connected to DB


In [10]:
# Target collection name comes from the environment; record whether it already exists
collection_name = os.getenv("COLLECTION_NAME")
check_collection = utility.has_collection(collection_name=collection_name)

In [11]:
# Start from a clean slate: drop the collection if it currently exists.
# Existence is queried here, at drop time, instead of reusing the flag computed in an
# earlier cell; a stale flag would skip the drop when this cell is re-run after the
# collection has been created, and the later insert would then add duplicate rows.
if utility.has_collection(collection_name):
    drop_result = utility.drop_collection(collection_name)
    print("Droped Existing collection")

Droped Existing collection


In [12]:
# Collection schema: integer primary key, image embedding, and the source image path
dim = 512 # image vector dim

# primary key
image_id = FieldSchema(
    name="image_id",
    dtype=DataType.INT64,
    is_primary=True,
    description="primary id",
)

# image vector
image_embed_field = FieldSchema(
    name="image_vector",
    dtype=DataType.FLOAT_VECTOR,
    dim=dim,
)

# path of image; VARCHAR length is the longest known path plus 50 of headroom
image_desc = FieldSchema(
    name="image_path",
    dtype=DataType.VARCHAR,
    max_length=(max_path_len + 50),
    is_primary=False,
    description="path of the image",
)

schema = CollectionSchema(
    fields=[image_id, image_embed_field, image_desc],
    auto_id=False,
    description="collection of vegetable images",
)

print("Creating the collection")
collection = Collection(name=collection_name, schema=schema)
print(f"Schema: {schema}")
print("Success!")

Creating the collection
Schema: {'auto_id': False, 'description': 'collection of vegetable images', 'fields': [{'name': 'image_id', 'description': 'primary id', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': False}, {'name': 'image_vector', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 512}}, {'name': 'image_path', 'description': 'path of the image', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 93}}]}
Success!


In [13]:
# Preparing data to load: three parallel columns (ids, vectors, paths)
image_id = list(range(len(images)))  # position in `images` doubles as the id
image_path = list(images)
image_vector = [vectorizer.vectorize(path) for path in tqdm(image_path)]
docs = [image_id, image_vector, image_path]

100%|██████████| 1500/1500 [06:08<00:00,  4.07it/s]


In [16]:
# Insert the prepared columns into the collection in a single call
ins_resp = collection.insert(data=docs)
ins_resp # insertion result

(insert count: 1500, delete count: 0, upsert count: 0, timestamp: 444634998929883137, success count: 1500, err count: 0)

In [17]:
# Build an index on the vector field (image_vector)
# metric type: L2 (euclidean dist). supported: [L2 IP]
index_params = dict(index_type="AUTOINDEX", metric_type="L2", params={})
collection.create_index(field_name="image_vector", index_params=index_params)

alloc_timestamp unimplemented, ignore it


Status(code=0, message=)