In [1]:
import csv
from glob import glob
from pathlib import Path
from statistics import mean

from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
from towhee.dc2 import pipe, ops, DataCollection
from PIL import Image
from datasets import load_dataset

# Towhee parameters
MODEL = 'resnet50'
DEVICE = None # if None, use default device (cuda is enabled if available)

# Milvus parameters
HOST = '127.0.0.1'
PORT = '19530'
TOPK = 10
DIM = 2048 # dimension of feature vectors for "resnet50"
COLLECTION_NAME = 'reverse_image_search'
INDEX_TYPE = 'IVF_FLAT'
METRIC_TYPE = 'L2'

# path to csv (column_1 indicates image path) OR a pattern of image paths
INSERT_SRC = 'reverse_image_search.csv'
INSERT_SRC_TINY = "tiny_imagenet_200_reverse_image_search.csv"
QUERY_SRC = './test/*/*.JPEG'

## Create Image Embeddings

In [2]:
# Load image path
def load_image(x):
    if x.endswith('csv'):
        with open(x) as f:
            reader = csv.reader(f)
            next(reader)
            for item in reader:
                yield item[1]
    else:
        for item in glob(x):
            yield item

def display_img_vec(img):
    display(img)
    return img

p_embed = (
    pipe.input('img')
        .map('img', 'vec', ops.image_embedding.timm(model_name=MODEL, device=DEVICE))
)

In [3]:
tiny_imagenet = load_dataset('Maysee/tiny-imagenet', split='train')[:1]["image"][0]
p_display = p_embed.output('vec')
DataCollection(p_display(tiny_imagenet)).show()



vec
"[0.0, 0.0, 0.00839724, ...] shape=(2048,)"


## Create Database Collection

In [4]:
# Create milvus collection (delete first if exists)
def create_milvus_collection(collection_name, dim):
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)
    
    fields = [
        FieldSchema(name='path', dtype=DataType.VARCHAR, description='path to image', max_length=500, 
                    is_primary=True, auto_id=False),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='image embedding vectors', dim=dim)
    ]
    schema = CollectionSchema(fields=fields, description='reverse image search')
    collection = Collection(name=collection_name, schema=schema)

    index_params = {
        'metric_type': METRIC_TYPE,
        'index_type': INDEX_TYPE,
        'params': {"nlist": 2048}
    }
    collection.create_index(field_name='embedding', index_params=index_params)
    return collection

In [5]:
# Connect to Milvus service
connections.connect(host=HOST, port=PORT)

# Create collection
collection = create_milvus_collection(COLLECTION_NAME, DIM)
print(f'A new collection created: {COLLECTION_NAME}')



A new collection created: reverse_image_search


## Insert Image embeddings into Milvus

In [6]:
tiny_imagenet = load_dataset('Maysee/tiny-imagenet', split='train')

def write_csv():
    with open('tiny_imagenet.csv', 'w') as f:
        writer = csv.writer(f)
        writer.writerow(['id', 'image'])
        # for i, path in enumerate(glob('./dataset/tiny-imagenet-200/train/*/*/*.JPEG')):
        #     writer.writerow([i, path])
        for i, img in enumerate(tiny_imagenet):
            writer.writerow([i, img["image"]])

write_csv()




In [None]:
# Load image path
def load_image(x):
    if x.endswith('csv'):
        with open(x) as f:
            reader = csv.reader(f)
            next(reader)
            for item in reader:
                yield item[1]
    else:
        for item in glob(x):
            yield item

In [13]:
from PIL import Image
# Insert pipeline
def load_Image(x):
    yield x["image"]

p_embed_insert = (
    pipe.input('dataset')
        .flat_map('dataset', 'img', load_Image)
        .map('img', 'vec', ops.image_embedding.timm(model_name=MODEL, device=DEVICE))
)

p_insert = (
        p_embed_insert.map(('img', 'vec'), 'mr', ops.ann_insert.milvus_client(
                    host=HOST,
                    port=PORT,
                    collection_name=COLLECTION_NAME
                    ))
          .output('mr')
)

In [9]:
tiny_imagenet = load_dataset('Maysee/tiny-imagenet', split='train')[:5]
tiny_imagenet



{'image': [<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=64x64>,
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=64x64>,
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=64x64>,
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=64x64>,
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=64x64>],
 'label': [0, 0, 0, 0, 0]}

In [11]:
for i in tiny_imagenet:
    DataCollection(p_insert(i)).show()

RuntimeError: Node-load_Image-0 runs failed, error msg: string indices must be integers, Traceback (most recent call last):
  File "/Users/nilsertle/Code/bachelor-thesis/venv/lib/python3.9/site-packages/towhee/runtime/nodes/node.py", line 170, in process
    if self.process_step():
  File "/Users/nilsertle/Code/bachelor-thesis/venv/lib/python3.9/site-packages/towhee/runtime/nodes/_flat_map.py", line 54, in process_step
    for output in outputs:
  File "/var/folders/kg/v9q1cm1x6q9cqyr9zrc3f9480000gn/T/ipykernel_99402/2565511194.py", line 4, in load_Image
    yield x["image"]
TypeError: string indices must be integers



In [15]:
# Insert data to Milvus from dataset
tiny_imagenet = load_dataset('Maysee/tiny-imagenet', split='train')[0]
print(tiny_imagenet)
# print(tiny_imagenet[0])




{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=64x64 at 0x1698E9400>, 'label': 0}


In [16]:
# Dataset({
#     features: ['image', 'label'],
#     num_rows: 100000
# })
p_insert(tiny_imagenet)

# Check collection
collection = Collection(name=COLLECTION_NAME)
print('Number of data inserted:', collection.num_entities)

RuntimeError: Node-ann-insert/milvus-client-2 runs failed, error msg: 'JpegImageFile' object is not subscriptable, Traceback (most recent call last):
  File "/Users/nilsertle/Code/bachelor-thesis/venv/lib/python3.9/site-packages/towhee/runtime/nodes/node.py", line 156, in _call
    return True, self._op(*inputs), None
  File "/Users/nilsertle/.towhee/operators/ann-insert/milvus_client/main/milvus_client.py", line 45, in __call__
    mr = self._collection.insert(row)
  File "/Users/nilsertle/Code/bachelor-thesis/venv/lib/python3.9/site-packages/pymilvus/orm/collection.py", line 424, in insert
    check_insert_data_schema(self._schema, data)
  File "/Users/nilsertle/Code/bachelor-thesis/venv/lib/python3.9/site-packages/pymilvus/orm/schema.py", line 312, in check_insert_data_schema
    infer_fields = parse_fields_from_data(data)
  File "/Users/nilsertle/Code/bachelor-thesis/venv/lib/python3.9/site-packages/pymilvus/orm/schema.py", line 343, in parse_fields_from_data
    fields = [FieldSchema("", infer_dtype_bydata(d[0])) for d in data]
  File "/Users/nilsertle/Code/bachelor-thesis/venv/lib/python3.9/site-packages/pymilvus/orm/schema.py", line 343, in <listcomp>
    fields = [FieldSchema("", infer_dtype_bydata(d[0])) for d in data]
  File "/Users/nilsertle/Code/bachelor-thesis/venv/lib/python3.9/site-packages/pymilvus/orm/types.py", line 123, in infer_dtype_bydata
    elem = data[0]
TypeError: 'JpegImageFile' object is not subscriptable



In [8]:
# Check collection
collection = Collection(name=COLLECTION_NAME)
print('Number of data inserted:', collection.num_entities)

Number of data inserted: 0


## Search the databse

In [None]:
# Search pipeline
p_search_pre = (
        p_embed.map('vec', ('search_res'), ops.ann_search.milvus_client(
                    host=HOST, port=PORT, limit=TOPK,
                    collection_name=COLLECTION_NAME))
               .map('search_res', 'pred', lambda x: [str(Path(y[0]).resolve()) for y in x])
#                .output('img_path', 'pred')
)
p_search = p_search_pre.output('img_path', 'pred')

In [None]:
# Search for example query image(s)
collection.load()
dc = p_search('test/goldfish/*.JPEG')

# Display search results with image paths
DataCollection(dc).show()

In [None]:
# Display search results with images, no need for implementation

import cv2
from towhee.types.image import Image

def read_images(img_paths):
    imgs = []
    for p in img_paths:
        imgs.append(Image(cv2.imread(p), 'BGR'))
    return imgs

p_search_img = (
    p_search_pre.map('pred', 'pred images', read_images)
                .output('img', 'pred images')
)
DataCollection(p_search_img('test/dishwasher/*.JPEG')).show()

## Measure performance

In [None]:
# Get ground truth by path of query image
def ground_truth(path):
    train_path = str(Path(path).parent).replace('test', 'train')
    return [str(Path(x).resolve()) for x in glob(train_path + '/*.JPEG')]

# Calculate Average Precision by a list of predictions and a list of ground truths
def get_ap(pred: list, gt: list):
    ct = 0
    score = 0.
    for i, n in enumerate(pred):
        if n in gt:
            ct += 1
            score += (ct / (i + 1))
    if ct == 0:
        ap = 0
    else:
        ap = score / ct
    return ap

In [None]:
# Evaluation pipeline returns AP
p_eval = (
    p_search_pre.map('img_path', 'gt', ground_truth)
                .map(('pred', 'gt'), 'ap', get_ap)
                .output('ap')
)

In [None]:
import time
import psutil
import os

# Get the current process ID
pid = os.getpid()

# Create a Process object for the current process
process = psutil.Process(pid)

# Run evaluation pipeline over all test data
start = time.time()
bm = p_eval('test/*/*.JPEG')
end = time.time()

# Get the CPU and memory usage
cpu_percent = process.cpu_percent()
mem_percent = process.memory_percent()

# Group AP in a list
res = DataCollection(bm).to_list()

# Calculate mAP
mAP = mean([x.ap for x in res])

# mean average precision at 10
print(f'mAP@{TOPK}: {mAP}')
# queries per second
print(f'qps: {len(res) / (end - start)}')
# time per query
print(f'time per query: {(end - start) / len(res)}')
# Print the results
print(f"CPU usage: {cpu_percent}%")
print(f"Memory usage: {mem_percent}%")