In [None]:
import sys

In [None]:
import kagglehub

In [None]:
# Fetch the Unsplash Lite 5k colorization dataset via kagglehub; `path` is the
# local directory it returns.
DATASET_HANDLE = "matthewjansen/unsplash-lite-5k-colorization"
path = kagglehub.dataset_download(DATASET_HANDLE)

In [None]:
path  # display the local directory kagglehub returned for the dataset

In [None]:
import clip

In [None]:
import torch

In [None]:
# Run CLIP on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Load the CLIP "RN50" encoder together with its matching image preprocessing
# transform onto `device`.
CLIP_MODEL_NAME = "RN50"
model, preprocess = clip.load(CLIP_MODEL_NAME, device=device)
# Inference only. The trailing ";" stops Jupyter from echoing the entire
# network repr, which would otherwise flood the cell output.
model.eval();

In [None]:
import chromadb
from chromadb.config import Settings

In [None]:
# Keep the vector store on disk under ./chroma_db.
# chromadb.Client(Settings(persist_directory=...)) creates an in-memory client
# (Settings.is_persistent defaults to False in chromadb >= 0.4), so the directory
# was silently ignored; PersistentClient actually writes to it.
client = chromadb.PersistentClient(path="./chroma_db")
# Cosine distance is the natural metric for CLIP image/text embeddings.
# NOTE: the metric is fixed when the collection is first created; an already
# existing "clip_embeddings" collection keeps whatever space it was built with.
collection = client.get_or_create_collection(
    "clip_embeddings",
    metadata={"hnsw:space": "cosine"},
)

In [None]:
from PIL import Image
from tqdm.auto import tqdm
import os
import numpy as np

In [None]:
# The `test/color` image split inside the dataset directory kagglehub returned.
# Derived from `path` rather than a hard-coded C:/Users/... location so the
# notebook runs on any machine and follows whichever dataset version was fetched.
image_folder_path = os.path.join(path, "test", "color")

In [None]:
# Embed every image in `image_folder_path` with CLIP and store it in ChromaDB.
# `image_paths` collects the paths that were actually ingested.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

image_paths = []

# sorted() gives a deterministic ingestion order across platforms.
for image_name in tqdm(sorted(os.listdir(image_folder_path)), desc="embedding images"):
    image_path = os.path.join(image_folder_path, image_name)
    # Match the real suffix case-insensitively: ".JPG" counts, "notapng" does not.
    if not image_name.lower().endswith(IMAGE_EXTENSIONS):
        continue

    # The context manager closes the file handle once the pixels are decoded.
    with Image.open(image_path) as img:
        image_input = preprocess(img.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        image_features = model.encode_image(image_input).float()
    # Unit-normalise so nearest-neighbour ranking matches cosine similarity
    # even when the collection uses ChromaDB's default L2 distance.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    image_embedding = image_features.cpu().numpy().flatten()

    # upsert (not add) keeps the cell idempotent: re-running it overwrites the
    # entries for existing IDs instead of colliding with them.
    collection.upsert(
        ids=[image_name],
        embeddings=[image_embedding.tolist()],
        metadatas=[{"filename": image_name, "path": image_path}],
    )
    image_paths.append(image_path)

print("All image embeddings have been added to ChromaDB.")

In [None]:
def get_text_embedding(query_text):
    """Encode ``query_text`` with CLIP's text encoder and return a flattened numpy array.

    Queries longer than CLIP's token context are truncated rather than raising,
    so arbitrary user text can be passed in.
    """
    with torch.no_grad():
        # truncate=True: without it clip.tokenize raises RuntimeError for long input.
        text_tokens = clip.tokenize([query_text], truncate=True).to(device)
        text_embedding = model.encode_text(text_tokens).cpu().numpy().flatten()
    return text_embedding

In [None]:
def search_similar_images_by_text(query_text, top_k=5):
    """Return the file paths of the ``top_k`` stored images nearest to ``query_text``.

    The query is embedded with ``get_text_embedding`` and sent to the ChromaDB
    ``collection``. A header line and one numbered line per hit are printed, and
    the ``path`` metadata of every hit is returned in rank order.
    """
    query_vector = get_text_embedding(query_text).tolist()
    response = collection.query(query_embeddings=[query_vector], n_results=top_k)

    print(f"Top {top_k} similar images for text query '{query_text}':")
    hits = response["metadatas"][0]
    for rank, meta in enumerate(hits, start=1):
        print(f"{rank}: {meta['filename']} - Path: {meta['path']}")
    return [meta["path"] for meta in hits]

In [None]:
# Example query: fetch the five stored images ranked nearest to this caption.
sample_query = "a beautiful sunset over water"
image_results = search_similar_images_by_text(sample_query, top_k=5)

In [None]:
import matplotlib.pyplot as plt

# One panel per search hit. The column count comes from the results, so the
# cell works for any top_k (a fixed 5 raises for more hits and leaves blank
# panels for fewer).
n_results = len(image_results)
if n_results == 0:
    print("No images to display.")
else:
    fig, axes = plt.subplots(1, n_results, figsize=(4 * n_results, 5), squeeze=False)
    for ax, img_path in zip(axes[0], image_results):
        # Close each file after matplotlib has copied its pixels.
        with Image.open(img_path) as img:
            ax.imshow(img)
        ax.set_title(os.path.basename(img_path))
        ax.axis("off")
    plt.show()