In [None]:
import os
import chromadb
import numpy as np
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction

# Get Data

In [9]:
ROOT = 'data'
CLASS_NAME = sorted(list(os.listdir(f'{ROOT}/train')))
HNSW_SPACE = "hnsw:space"

In [10]:
def get_files_path(path):
    files_path = []
    for label in CLASS_NAME:
        label_path = path + "/" + label
        filenames = os.listdir(label_path)
        for filename in filenames:
            filepath = label_path + '/' + filename
            files_path.append(filepath)
    return files_path

In [11]:
data_path = f'{ROOT}/train'
files_path = get_files_path(path=data_path)
files_path

['data/train/African_crocodile/n01697457_10393.JPEG',
 'data/train/African_crocodile/n01697457_104.JPEG',
 'data/train/African_crocodile/n01697457_1331.JPEG',
 'data/train/African_crocodile/n01697457_14906.JPEG',
 'data/train/African_crocodile/n01697457_18587.JPEG',
 'data/train/African_crocodile/n01697457_260.JPEG',
 'data/train/African_crocodile/n01697457_5586.JPEG',
 'data/train/African_crocodile/n01697457_8136.JPEG',
 'data/train/African_crocodile/n01697457_8331.JPEG',
 'data/train/African_crocodile/n01697457_85.JPEG',
 'data/train/American_egret/n02009912_1358.JPEG',
 'data/train/American_egret/n02009912_13895.JPEG',
 'data/train/American_egret/n02009912_15872.JPEG',
 'data/train/American_egret/n02009912_16896.JPEG',
 'data/train/American_egret/n02009912_26245.JPEG',
 'data/train/American_egret/n02009912_36395.JPEG',
 'data/train/American_egret/n02009912_4403.JPEG',
 'data/train/American_egret/n02009912_5700.JPEG',
 'data/train/American_egret/n02009912_7609.JPEG',
 'data/train/Ame

# Find Difference

In [12]:
def plot_results(image_path, files_path, results):
    query_image = Image.open(image_path).resize((448,448))
    images = [query_image]
    class_name = []
    for id_img in results['ids'][0]:
        id_img = int(id_img.split('_')[-1])
        img_path = files_path[id_img]
        img = Image.open(img_path).resize((448,448))
        images.append(img)
        class_name.append(img_path.split('/')[2])

    fig, axes = plt.subplots(2, 3, figsize=(12, 8))

    # Iterate through images and plot them
    for i, ax in enumerate(axes.flat):
        ax.imshow(images[i])
        if i == 0:
            ax.set_title(f"Query Image: {image_path.split('/')[2]}")
        else:
            ax.set_title(f"Top {i+1}: {class_name[i-1]}")
        ax.axis('off')  # Hide axes
    # Display the plot
    plt.show()

## Image Embedding

In [8]:
embedding_function = OpenCLIPEmbeddingFunction()

def get_single_image_embedding(image):
    embedding = embedding_function._encode_image(image=np.array(image))
    return embedding

In [9]:
img = Image.open('data/train/African_crocodile/n01697457_260.JPEG')
get_single_image_embedding(image=img)

## Chromadb L2 Embedding Collection
Tạo embedding collection từ các đường dẫn file ảnh data ở bước trên

In [13]:
def add_embedding(collection, file_path):
    ids = []
    embeddings = []
    
    for id_filepath, filepath in tqdm(enumerate(file_path)):
        ids.append(f'id_{id_filepath}')
        image = Image.open(filepath)
        embedding = get_single_image_embedding(image=image)
        embeddings.append(embedding)
        collection.add(embeddings=embeddings,ids=ids)

In [None]:
# Create a Chroma Client
chroma_client = chromadb.Client()
# Create a collection
l2_collection = chroma_client.get_or_create_collection(name="l2_collection", metadata={HNSW_SPACE: "l2"})
add_embedding(collection=l2_collection, files_path=files_path)

## Search Image With L2 Collection

In [3]:
def search(image_path, collection, n_results):
    query_image = Image.open(image_path)
    query_embedding = get_single_image_embedding(query_image)
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=n_results # how many results to return
    )
    return results

In [None]:
test_path = f'{ROOT}/test'
test_files_path = get_files_path(path=test_path)
test_path = test_files_path[1]
l2_results = search(image_path=test_path, collection=l2_collection, n_results=5)

In [14]:
l2_results

[('data/train/barn_spider/n01773549_701.JPEG', 19470219.0),
 ('data/train/American_egret/n02009912_8563.JPEG', 19809658.0),
 ('data/train/ambulance/n02701002_3315.JPEG', 20909429.0),
 ('data/train/kit_fox/n02119789_2049.JPEG', 21489281.0),
 ('data/train/flatworm/n01924916_4424.JPEG', 21598659.0),
 ('data/train/guillotine/n03467068_11567.JPEG', 21609776.0),
 ('data/train/vine_snake/n01739381_8285.JPEG', 21644385.0),
 ('data/train/castle/n02980441_5253.JPEG', 21714061.0),
 ('data/train/cornet/n03110669_95746.JPEG', 22241626.0),
 ('data/train/vine_snake/n01739381_5263.JPEG', 22545474.0),
 ('data/train/barn_spider/n01773549_2680.JPEG', 23089251.0),
 ('data/train/kit_fox/n02119789_10086.JPEG', 23578968.0),
 ('data/train/barn_spider/n01773549_10106.JPEG', 23769198.0),
 ('data/train/horizontal_bar/n03535780_18270.JPEG', 23799134.0),
 ('data/train/brain_coral/n01917289_1022.JPEG', 24331681.0),
 ('data/train/theater_curtain/n04418357_13381.JPEG', 24665234.0),
 ('data/train/horizontal_bar/n03535

In [None]:
plot_results(image_path=test_path, files_path=files_path, results=l2_results)

## Search Image With Cosine similarity Collection

In [None]:
# Create a collection
cosine_collection = chroma_client.get_or_create_collection(name="Cosine_collection",
                                                           metadata={HNSW_SPACE: "cosine"})
add_embedding(collection=cosine_collection, files_path=files_path)

In [None]:
test_path = f'{ROOT}/test'
test_files_path = get_files_path(path=test_path)
test_path = test_files_path[1]
cosine_results = search(image_path=test_path, collection=cosine_collection, n_results=5)

In [None]:
cosine_results

In [None]:
plot_results(image_path=test_path, files_path=files_path, results=cosine_results)