## Referencesm

https://docs.trychroma.com/guides/multimodal

https://docs.trychroma.com/guides#using-collections

https://cookbook.chromadb.dev/embeddings/gpu-support/#openclipm

## Prepare Image Data

### Download Image Data

In [None]:
# %pip install gdown

In [None]:
import gdown

# !gdown 1msLVo0g0LFmL9-qZ73vq9YEVZwbzOePF
url = "https://drive.google.com/uc?id=1msLVo0g0LFmL9-qZ73vq9YEVZwbzOePF"
output = "image_data.zip"

gdown.download(url, output, quiet=False)

In [None]:
!unzip -q image_data.zip

In [None]:
rm image_data.zip

### Load Embedding

In [1]:
import os 
root = 'data/train'

def get_image_uris(root):
    image_uris = []
    for class_name in os.listdir(root):
        class_path = os.path.join(root, class_name)
        images_name = os.listdir(class_path)
        image_uris += [ os.path.join(class_path, fn) for fn in images_name ]
    return image_uris


In [2]:
# %pip install matplotlib

In [3]:
# import matplotlib.pyplot as plt

# image = np.array(Image.open(img_path[0]))
# print(image.shape)
# plt.imshow(image)
# plt.axis('off')
# plt.show()

In [4]:
# dir(clip_embedding_function)
# _PILImage,  _encode_image, _encode_text, _model,_preprocess ,_tokenizer , _torch

In [5]:
# import torch
# from PIL import Image
# import numpy as np
from typing import  cast
from chromadb.api.types import Embeddings
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
clip_embedding_function = OpenCLIPEmbeddingFunction(device='cuda')

def prepocessing_image(img_path_list, size):
    img_batch = []
    for img_path in img_path_list:
        # OpenClip receive PIL image and pre-process single image as pre-processing in Pytorch
        # img = Image.open(img_path).resize(size=size)
        img = clip_embedding_function._PILImage.open(img_path).resize(size=size)
        img = clip_embedding_function._preprocess(img)
        img_batch.append(img)
    return img_batch


def get_batch_embedding(img_batch):
    # img_batch = torch.stack(img_batch, axis=0)
    img_batch = clip_embedding_function._torch.stack(img_batch, axis=0)
    img_batch = img_batch.to('cuda')
    with clip_embedding_function._torch.no_grad():
        image_features = clip_embedding_function._model.encode_image(img_batch)
        image_features /= image_features.norm(dim=-1, keepdim=True)
    return cast(Embeddings, image_features.to('cpu').tolist())


def get_embedding_loader(img_path_list, batch_size, size):
    for i in range(0, len(img_path_list), batch_size):
        start = i
        end = min(i+batch_size, len(img_path_list))
        # preprocessing images
        img_batch = prepocessing_image(img_path_list[start:end], size=size)
        
        # get embeddings
        image_features = get_batch_embedding(img_batch)
        
        # generate ids and metadata
        image_ids = [f"img_{idx_}" for idx_ in range(start, end, 1)]
        image_metadata = [{'ver': idx_%10} for idx_ in range(start, end, 1)]
        
        yield image_features, image_ids, image_metadata

# def get_batch_embedding(img_path_list, batch_size, size):
#     for i in range(0, len(img_path_list), batch_size):
#         start = i
#         end = min(i+batch_size, len(img_path_list))
#         img_batch = [ Image.open(img_path).resize(size=size) for img_path in img_path_list[start:end] ]
#         img_batch = [ torch.tensor(np.array(img)) for img in img_batch ]
#         img_batch = torch.stack(img_batch)
#         img_batch = img_batch.permute(0, 3, 1, 2).to('cuda')
#         img_batch = clip_embedding_function(img_batch)
#         yield img_batch


  from .autonotebook import tqdm as notebook_tqdm
  checkpoint = torch.load(checkpoint_path, map_location=map_location)


In [81]:
# image_uris = get_image_uris(root)
# embedding_loader = get_embedding_loader(image_uris, batch_size=32, size=(224,224))
# a_embedd, a_ids, a_meta = next(embedding_loader)
# print(len(a_embedd), len(a_ids), len(a_meta))

32 32 32


## In-memory Vector Database

### Create Client

In [6]:
import chromadb
client = chromadb.Client()

### Create Collection

In [7]:
collection = client.create_collection(name='image_clip_embedding_collection') 

### Add embeddings to the collection

In [10]:
def add_embeddings(image_uris, size, batch_size, collection):
    embedding_loader = get_embedding_loader(image_uris, batch_size=batch_size, size=size)
    for embeddings, ids, metadata in embedding_loader:
        collection.add(embeddings=embeddings, ids=ids, metadatas=metadata)


In [11]:
image_uris = get_image_uris(root)
add_embeddings(image_uris=image_uris, size=(224,224), batch_size=32, collection=collection)

### Query

In [14]:
def query_embeddings(test_image_uris, size, collection):
    # preprocessing images
    test_img_batch = prepocessing_image(test_image_uris, size=size)
    # get embeddings
    test_image_features = get_batch_embedding(test_img_batch)
    # query
    results = collection.query(query_embeddings=test_image_features, n_results=5)
    return results 

In [15]:
test_image_uris = get_image_uris('data/test')
print(test_image_uris[0:1])
query_embeddings(test_image_uris[0:1], size=(224,224), collection=collection)

['data/test/ambulance/n02701002_2311.JPEG']


{'ids': [['img_4', 'img_7', 'img_6', 'img_5', 'img_1']],
 'distances': [[0.5078729391098022,
   0.5386553406715393,
   0.5713851451873779,
   0.5825256705284119,
   0.7087713479995728]],
 'metadatas': [[{'ver': 4}, {'ver': 7}, {'ver': 6}, {'ver': 5}, {'ver': 1}]],
 'embeddings': None,
 'documents': [[None, None, None, None, None]],
 'uris': None,
 'data': None,
 'included': ['metadatas', 'documents', 'distances']}

In [16]:
image_uris[4:7]

['data/train/ambulance/n02701002_15786.JPEG',
 'data/train/ambulance/n02701002_18950.JPEG',
 'data/train/ambulance/n02701002_1264.JPEG']

In [17]:
test_image_uris = get_image_uris('data/test')
print(test_image_uris[0:5])
query_embeddings(test_image_uris[0:5], size=(224,224), collection=collection)

['data/test/ambulance/n02701002_2311.JPEG', 'data/test/horizontal_bar/n03535780_9902.JPEG', 'data/test/dugong/n02074367_5140.JPEG', 'data/test/killer_whale/n02071294_20475.JPEG', 'data/test/flatworm/n01924916_6615.JPEG']


{'ids': [['img_4', 'img_7', 'img_6', 'img_5', 'img_1'],
  ['img_16', 'img_12', 'img_13', 'img_14', 'img_10'],
  ['img_25', 'img_27', 'img_24', 'img_26', 'img_28'],
  ['img_35', 'img_30', 'img_31', 'img_38', 'img_34'],
  ['img_49', 'img_42', 'img_43', 'img_45', 'img_127']],
 'distances': [[0.5078730583190918,
   0.538655161857605,
   0.5713850259780884,
   0.5825255513191223,
   0.7087711095809937],
  [0.5071476697921753,
   0.6265128254890442,
   0.667927086353302,
   0.6868309378623962,
   0.7616097331047058],
  [0.4300346374511719,
   0.549135684967041,
   0.5771130323410034,
   0.6138494610786438,
   0.6673237085342407],
  [0.3649950623512268,
   0.4591969847679138,
   0.5159907341003418,
   0.5380780100822449,
   0.7076171636581421],
  [0.27509841322898865,
   0.401824414730072,
   0.4599449038505554,
   0.47627371549606323,
   0.5155505537986755]],
 'metadatas': [[{'ver': 4}, {'ver': 7}, {'ver': 6}, {'ver': 5}, {'ver': 1}],
  [{'ver': 6}, {'ver': 2}, {'ver': 3}, {'ver': 4}, {'ver'