In [2]:
import os
import chromadb
from tqdm import tqdm
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction

In [3]:
ROOT = './data'
CLASS_NAME = sorted(list(os.listdir(f'{ROOT}/train')))
HNSW_SPACE = "hnsw:space"

In [4]:
def get_files_path(path):
    file_path = []
    for label in CLASS_NAME:
        label_path = path + "/" + label
        filenames = os.listdir(label_path)
        for filename in filenames:
            filepath = label_path + "/" + filename
            file_path.append(filepath)

    return file_path

data_path = "data/train"
files_path = get_files_path(data_path)
files_path

['data/train/African_crocodile/n01697457_10393.JPEG',
 'data/train/African_crocodile/n01697457_104.JPEG',
 'data/train/African_crocodile/n01697457_1331.JPEG',
 'data/train/African_crocodile/n01697457_14906.JPEG',
 'data/train/African_crocodile/n01697457_18587.JPEG',
 'data/train/African_crocodile/n01697457_260.JPEG',
 'data/train/African_crocodile/n01697457_5586.JPEG',
 'data/train/African_crocodile/n01697457_8136.JPEG',
 'data/train/African_crocodile/n01697457_8331.JPEG',
 'data/train/African_crocodile/n01697457_85.JPEG',
 'data/train/American_egret/n02009912_1358.JPEG',
 'data/train/American_egret/n02009912_13895.JPEG',
 'data/train/American_egret/n02009912_15872.JPEG',
 'data/train/American_egret/n02009912_16896.JPEG',
 'data/train/American_egret/n02009912_26245.JPEG',
 'data/train/American_egret/n02009912_36395.JPEG',
 'data/train/American_egret/n02009912_4403.JPEG',
 'data/train/American_egret/n02009912_5700.JPEG',
 'data/train/American_egret/n02009912_7609.JPEG',
 'data/train/Ame

In [5]:
def plot_results(image_path, files_path, results):
    query_image = Image.open(image_path).resize((448,448))
    images = [query_image]
    class_name = []
    for id_img in results['ids'][0]:
        id_img = int(id_img.split('_')[-1])
        img_path = files_path[id_img]
        img = Image.open(img_path).resize((448,448))
        images.append(img)
        class_name.append(img_path.split('/')[2])

    fig, axes = plt.subplots(2, 3, figsize=(12, 8))

    for i, ax in enumerate(axes.flat):
        ax.imshow(images[i])
        if i == 0:
            ax.set_title(f"Query Image: {image_path.split('/')[2]}")
        else:
            ax.set_title(f"Top {i+1}: {class_name[i-1]}")
        ax.axis('off')
    plt.show()

In [6]:
embedding_function = OpenCLIPEmbeddingFunction()

def get_single_image_embedding(image):
    embedding = embedding_function._encode_image(image=np.array(image))
    return embedding

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
img = Image.open('data/train/African_crocodile/n01697457_260.JPEG')
get_single_image_embedding(image=img)

[0.020757220685482025,
 0.01700526848435402,
 -0.054676514118909836,
 0.019303275272250175,
 0.0003828889166470617,
 0.012467755936086178,
 -0.01628330908715725,
 -0.002152826404199004,
 0.050817959010601044,
 0.013185952790081501,
 -0.012089519761502743,
 0.007774751633405685,
 0.01889338158071041,
 -0.03941907733678818,
 0.011428206227719784,
 0.022981470450758934,
 -0.06464022397994995,
 0.02643793448805809,
 0.01488641556352377,
 0.0027106276247650385,
 -0.04403828829526901,
 -0.0064188637770712376,
 0.0056654540821909904,
 -0.03082825243473053,
 -0.01642991043627262,
 -0.00035373438731767237,
 0.005325534380972385,
 -0.015266201458871365,
 -0.00775753753259778,
 -0.025533467531204224,
 -0.012705416418612003,
 0.01918025128543377,
 0.0023712213151156902,
 0.01473577506840229,
 -0.026996096596121788,
 -0.03940664231777191,
 0.015418018214404583,
 -0.014308112673461437,
 -0.032431453466415405,
 0.010327630676329136,
 -0.016555285081267357,
 0.041197482496500015,
 -0.03550374135375023

In [8]:
def add_embedding(collection, files_path):
    ids = []
    embeddings = []
    for id_filepath, filepath in tqdm(enumerate(files_path)):
        ids.append(f'id_{id_filepath}')
        image = Image.open(filepath)
        embedding = get_single_image_embedding(image=image)
        embeddings.append(embedding)
    collection.add(
        embeddings=embeddings,
        ids=ids
    )

In [9]:
chroma_client = chromadb.Client()

l2_collection = chroma_client.get_or_create_collection(name="l2_collection",
                                                metadata={HNSW_SPACE: "l2"})
add_embedding(collection=l2_collection, files_path=files_path)

595it [00:56, 10.48it/s]


: 

In [None]:
def search(image_path, collection, n_results):
    query_image = Image.open(image_path)
    query_embedding = get_single_image_embedding(query_image)
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=n_results
    )
    return results

In [None]:
test_path = f'{ROOT}/test'
test_files_path = get_files_path(path=test_path)
test_path = test_files_path[1]
l2_results = search(image_path=test_path, collection=l2_collection, n_results=5)

In [None]:
l2_results

In [None]:
plot_results(image_path=test_path, files_path=files_path, results=l2_results)