In [None]:
!pip install chromadb datasets

In [1]:
import chromadb
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
from chromadb.utils.data_loaders import ImageLoader

import os
from datasets import load_dataset
from matplotlib import pyplot as plt

### Load dataset

In [None]:
# Load the COCO training split from the Hugging Face Hub.
# NOTE(review): only N_IMAGES samples are consumed below; streaming=True would
# likely avoid fetching the whole split — confirm before changing.
dataset = load_dataset(path="detection-datasets/coco", name="coco", split="train", streaming=False)

IMAGE_FOLDER = "images"
N_IMAGES = 20

# Grid layout for the preview: round the row count up so every image gets an axis
# even when N_IMAGES is not a multiple of plot_cols.
plot_cols = 5
plot_rows = -(-N_IMAGES // plot_cols)
# figsize is (width, height), i.e. (columns, rows); squeeze=False keeps `axes` 2-D
# so .flatten() works for any grid shape.
fig, axes = plt.subplots(plot_rows, plot_cols, figsize=(plot_cols*2, plot_rows*2), squeeze=False)
axes = axes.flatten()

dataset_iter = iter(dataset)
os.makedirs(IMAGE_FOLDER, exist_ok=True)

for i in range(N_IMAGES):
  image = next(dataset_iter)['image']
  axes[i].imshow(image)
  axes[i].axis("off")

  # Save into IMAGE_FOLDER (the directory created above and read back later).
  image.save(os.path.join(IMAGE_FOLDER, f"{i}.jpg"))

# Hide any grid cells left over after the last image.
for unused_ax in axes[N_IMAGES:]:
  unused_ax.axis("off")

plt.tight_layout()
plt.show()

Resolving data files:   0%|          | 0/40 [00:00<?, ?it/s]

In [None]:
# Loader object that is handed to the collection as its data_loader.
image_loader = ImageLoader()
# Embedding function that is handed to the collection as its embedding_function.
embedding_function = OpenCLIPEmbeddingFunction()

# Chroma client created with default settings.
client = chromadb.Client()

In [None]:
# get_or_create_collection keeps this cell re-runnable: a plain create_collection
# fails when a collection with this name already exists on the client.
collection = client.get_or_create_collection(
    name="mutimodal_collection",
    embedding_function=embedding_function,
    data_loader=image_loader
)

In [None]:
# Only pick up the saved .jpg files; os.listdir also returns stray entries
# (e.g. notebook checkpoint folders) that are not images.
image_names = sorted(
    image_name for image_name in os.listdir(IMAGE_FOLDER)
    if image_name.lower().endswith(".jpg")
)
image_uris = [os.path.join(IMAGE_FOLDER, image_name) for image_name in image_names]
# Use the file stem as the record id so each id refers to the file it was saved as.
ids = [os.path.splitext(image_name)[0] for image_name in image_names]

collection.add(ids=ids, uris=image_uris)

### Querying a multi-modal collection

Using text

In [None]:
### Helper for showing retrieved images
def display_result(retrived):
  """Display every image returned for the first query of a query result.

  Args:
    retrived: Mapping returned by ``collection.query(..., include=['data'])``;
      ``retrived['data'][0]`` is iterated and each item is passed to
      ``plt.imshow``.

  Returns:
    None. Each image is rendered in its own figure.
  """
  for img in retrived['data'][0]:
    plt.imshow(img)
    plt.axis("off")
    # Show inside the loop without returning, so all results are rendered
    # instead of stopping after the first one.
    plt.show()

In [None]:
# Querying for "Animals"
# The keyword is `query_texts` (a list of query strings); 'data' is included so
# the result carries the loaded images for display.
retrived = collection.query(query_texts=["animals"], include=['data'], n_results=3)

display_result(retrived)

In [None]:
# Querying for "Vehicles"
# The keyword is `query_texts` (a list of query strings); 'data' is included so
# the result carries the loaded images for display.
retrived = collection.query(query_texts=["Vehicles"], include=['data'], n_results=3)

display_result(retrived)

In [None]:
# Querying for "Street Scenes"
# The keyword is `query_texts` (a list of query strings); 'data' is included so
# the result carries the loaded images for display.
retrived = collection.query(query_texts=["Street Scenes"], include=['data'], n_results=3)

display_result(retrived)

Using query image

In [None]:
from PIL import Image
import numpy as np

# Read the query image into an array; the context manager closes the file
# handle as soon as the pixels have been copied.
with Image.open(f"{IMAGE_FOLDER}/1.jpg") as query_file:
    query_image = np.array(query_file)

print("Query Image")
plt.imshow(query_image)
plt.axis('off')
plt.show()

print("Results")
retrieved = collection.query(query_images=[query_image], include=['data'], n_results=5)
# [1:] skips the first hit — presumably the query image itself, since it is
# part of the collection (confirm against the rendered output).
for img in retrieved['data'][0][1:]:
    plt.imshow(img)
    plt.axis("off")
    plt.show()

Using query URIs

In [None]:
# Query the collection with the URI of the second stored image.
query_uri = image_uris[1]
query_result = collection.query(
    query_uris=query_uri,
    include=['data'],
    n_results=5,
)

# Drop the first hit of the first (only) query and render the rest,
# one figure per image.
first_query_hits = query_result['data'][0]
for img in first_query_hits[1:]:
    plt.imshow(img)
    plt.axis("off")
    plt.show()