# Data preparation


I'm using a [small plant disease dataset](https://www.kaggle.com/datasets/vipoooool/new-plant-diseases-dataset) from Kaggle. Although the types of plants are not extensive, there are sufficient images for each type of plant. In total, there 87K images of healthy and diseased crop leaves which is categorized into 38 different classes (i.e. plant - disease pair). 

In [35]:
import dotenv
import os
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility, MilvusClient
from transformers import AutoModel
import matplotlib.pyplot as plt
import base64
import io
from PIL import Image
import math
from transformers import AutoModel

model = AutoModel.from_pretrained('jinaai/jina-clip-v2', trust_remote_code=True)

def get_image_embedding(image_path):
    image_embeddings = model.encode_image([image_path], truncate_dim=None)
    return image_embeddings[0]





In [34]:
# Define the path to the Kaggle dataset's training directory
train_dir = 'kaggle-data/train'

# Initialize a list to hold dictionaries of plant type and disease
plant_disease_data = []
plant_disease_embeddings = []

# Loop through all folders in the training directory
for folder_name in os.listdir(train_dir):
    folder_path = os.path.join(train_dir, folder_name)
    if '.DS_Store' in folder_path:
        continue
    if os.path.isdir(folder_path) and '__' in folder_name:
        try:
            # Split the folder name into plant type and disease
            plant, disease = folder_name.split('__')
            base_dict = {'plant': plant, 'disease': disease.removeprefix('_')}
            plant_disease_data.append({'plant': plant, 'disease': disease.removeprefix('_')})
        except ValueError:
            print(f"Skipping folder with unexpected format: {folder_name}")

    image_files = [f for f in os.listdir(folder_path) if os.path.isfile(os.path.join(folder_path, f))]
    image_files = image_files[:20]
    for filename in image_files:
        image_path = os.path.join(folder_path, filename)
        
        embedding = get_image_embedding(image_path)
        
        new_dict = base_dict.copy()
        new_dict['embedding'] = embedding
        
        plant_disease_embeddings.append(new_dict)
    
# Output the result
for entry in plant_disease_data:
    print(entry)

{'plant': 'Strawberry', 'disease': 'healthy'}
{'plant': 'Grape', 'disease': 'Black_rot'}
{'plant': 'Potato', 'disease': 'Early_blight'}
{'plant': 'Blueberry', 'disease': 'healthy'}
{'plant': 'Corn_(maize)', 'disease': 'healthy'}
{'plant': 'Tomato', 'disease': 'Target_Spot'}
{'plant': 'Peach', 'disease': 'healthy'}
{'plant': 'Potato', 'disease': 'Late_blight'}
{'plant': 'Tomato', 'disease': 'Late_blight'}
{'plant': 'Tomato', 'disease': 'Tomato_mosaic_virus'}
{'plant': 'Pepper,_bell', 'disease': 'healthy'}
{'plant': 'Orange', 'disease': 'Haunglongbing_(Citrus_greening)'}
{'plant': 'Tomato', 'disease': 'Leaf_Mold'}
{'plant': 'Grape', 'disease': 'Leaf_blight_(Isariopsis_Leaf_Spot)'}
{'plant': 'Cherry_(including_sour)', 'disease': 'Powdery_mildew'}
{'plant': 'Apple', 'disease': 'Cedar_apple_rust'}
{'plant': 'Tomato', 'disease': 'Bacterial_spot'}
{'plant': 'Grape', 'disease': 'healthy'}
{'plant': 'Tomato', 'disease': 'Early_blight'}
{'plant': 'Corn_(maize)', 'disease': 'Common_rust_'}
{'plan

In [12]:
len(plant_disease_embeddings[:5][-1]['embedding'])

1024

In [17]:
from pymilvus import MilvusClient, DataType

dotenv.load_dotenv()

client = MilvusClient(
    uri=os.getenv("ZILLIZ_URI"),
    token=os.getenv("ZILLIZ_TOKEN")
)

collection_name = "plant_disease_embeddings"
embedding_dim = 1024  # jinaai/jina-clip-v2

if client.has_collection(collection_name):
    client.drop_collection(collection_name)

schema = MilvusClient.create_schema()

schema.add_field(
    field_name="_id",
    datatype=DataType.INT64,
    is_primary=True,
    auto_id=True,
)

schema.add_field(
    field_name="embedding",
    datatype=DataType.FLOAT_VECTOR,
    dim=1024
)
schema.add_field(
    field_name="plant",
    datatype=DataType.VARCHAR,
    max_length=1024
)
schema.add_field(
    field_name="disease",
    datatype=DataType.VARCHAR,
    max_length=1024
)
index_params = client.prepare_index_params()

index_params.add_index(
    field_name="embedding",
    index_name="embedding_index",
    index_type="AUTOINDEX",
    metric_type="COSINE"
)
client.create_collection(
    collection_name=collection_name,
    schema=schema,
    index_params=index_params
)

insert_data = plant_disease_embeddings
insert_result = client.insert(collection_name=collection_name, data=insert_data)


In [36]:
image_path = 'kaggle-data/train/Apple___Black_rot/0bc40cc3-6a85-480e-a22f-967a866a56a1___JR_FrgE.S 2784.JPG'
#image_path = 'kaggle-data/train/Apple___Cedar_apple_rust/0a41c25a-f9a6-4c34-8e5c-7f89a6ac4c40___FREC_C.Rust 9807_90deg.JPG'

vectors = get_image_embedding(image_path)
result = client.search(collection_name, filter="plant=='Apple'", data=[vectors], output_fields=['plant','disease']) 
for r in result[0]:
    print(r)

{'_id': 461508499762535689, 'distance': 0.9810512065887451, 'entity': {'plant': 'Apple', 'disease': 'Black_rot'}}
{'_id': 461508499762535688, 'distance': 0.9670436382293701, 'entity': {'plant': 'Apple', 'disease': 'Black_rot'}}
{'_id': 461508499762535648, 'distance': 0.9554774165153503, 'entity': {'plant': 'Apple', 'disease': 'Apple_scab'}}
{'_id': 461508499762535691, 'distance': 0.9520385265350342, 'entity': {'plant': 'Apple', 'disease': 'Black_rot'}}
{'_id': 461508499762535600, 'distance': 0.9468621015548706, 'entity': {'plant': 'Apple', 'disease': 'Cedar_apple_rust'}}
{'_id': 461508499762535601, 'distance': 0.9433561563491821, 'entity': {'plant': 'Apple', 'disease': 'Cedar_apple_rust'}}
{'_id': 461508499762535690, 'distance': 0.9401673674583435, 'entity': {'plant': 'Apple', 'disease': 'Black_rot'}}
{'_id': 461508499762535687, 'distance': 0.9359273910522461, 'entity': {'plant': 'Apple', 'disease': 'Black_rot'}}
{'_id': 461508499762535649, 'distance': 0.9349789619445801, 'entity': {'p