In [1]:
import matplotlib.pyplot as plt
import os
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
import torch
import os
from qdrant_client import QdrantClient
from qdrant_client import models
# from qdrant_client.http import models
import tqdm


In [None]:
# Preprocess data by renaming each file
def preprocess_data(directory):
    """Rename every regular file in ``directory`` to ``<index>.<extension>``.

    Files are numbered in sorted-name order, so the numbering is deterministic
    for a given set of file names. Subdirectories are left untouched.

    The rename happens in two phases: every file is first moved into a
    temporary staging folder created inside ``directory`` and only then moved
    back under its numeric name. This way a new name such as ``0.jpg`` can
    never overwrite a file that already carries that name but has not been
    processed yet.

    Note that the operation is destructive: the original file names are lost,
    and running it again on an already-numbered folder may renumber the files.

    Args:
        directory: Path of the folder whose files are renamed in place.

    Raises:
        OSError: If listing the directory or renaming a file fails. Files that
            were already staged remain in the ``.renaming-*`` folder then.
    """
    # Function-scope import so the cell does not depend on extra top-level imports.
    import tempfile

    # Only regular files take part; sorting makes the index assignment reproducible.
    files = sorted(
        name for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name))
    )
    if not files:
        return

    # Phase 1: park every file under its old name in a staging folder that lives
    # inside ``directory`` (same filesystem, so os.rename is a plain move).
    staging = tempfile.mkdtemp(prefix=".renaming-", dir=directory)
    for file in files:
        os.rename(os.path.join(directory, file), os.path.join(staging, file))

    # Phase 2: move the files back under their numeric names.
    for i, file in tqdm.tqdm(enumerate(files), total=len(files), desc="Preprocessing data"):
        # Keep whatever follows the last dot; names without a dot get no extension.
        _, dot, extension = file.rpartition(".")
        new_name = f"{i}.{extension}" if dot else str(i)
        os.rename(os.path.join(staging, file), os.path.join(directory, new_name))

    os.rmdir(staging)
        
# Preprocess the data
# NOTE: this renames the files under data/Images on disk, so the original file
# names are gone afterwards and re-running the cell may assign different numbers.
preprocess_data("data/Images")


In [4]:
def setup_qdrant(path="database/", collection_name="images", vector_size=384):
    """Open the local Qdrant store and make sure the image collection exists.

    The collection is only created when it is not present yet, so the function
    can be called again against an existing ``path`` without failing on the
    already-existing collection.

    Args:
        path: Folder of the local Qdrant storage.
        collection_name: Name of the collection that holds the image vectors.
        vector_size: Dimensionality of the stored vectors; it has to equal the
            length of the vectors that are upserted later.

    Returns:
        The ``QdrantClient`` connected to ``path``.
    """
    client = QdrantClient(path=path)

    # Creating a collection that already exists is an error, so check first.
    existing = {collection.name for collection in client.get_collections().collections}
    if collection_name not in existing:
        client.create_collection(
            collection_name=collection_name,
            vectors_config=models.VectorParams(
                size=vector_size,
                distance=models.Distance.COSINE,
                on_disk=True,
            ),
        )
    return client

def index_images(directory, client, processor, model, device,
                 indexed_dir="data/indexed", collection_name="images"):
    """Embed every ``.jpg``/``.png`` file in ``directory`` and store it in Qdrant.

    For each image the mean over dimension 1 of the model's
    ``last_hidden_state`` is upserted as one point whose id is the integer
    before the first dot of the file name (``15.jpg`` -> ``15``). Files that
    share that number therefore end up on the same point id. After a file has
    been indexed it is moved to ``indexed_dir``; when it already lives there
    it is left where it is.

    Args:
        directory: Folder containing the images to index.
        client: Qdrant client used for the upserts.
        processor: Image processor that turns a PIL image into model inputs.
        model: Model whose ``last_hidden_state`` is pooled into the vector.
        device: Device the model inputs are moved to.
        indexed_dir: Folder that indexed files are moved to; created if missing.
        collection_name: Name of the collection the points are written to.

    Raises:
        ValueError: If an image file name does not start with an integer.
    """
    os.makedirs(indexed_dir, exist_ok=True)

    # List the folder once so the progress total matches what is iterated.
    filenames = os.listdir(directory)
    for filename in tqdm.tqdm(filenames, desc="Indexing images", total=len(filenames)):
        if not filename.endswith((".jpg", ".png")):
            continue

        # Resolve the point id before the expensive forward pass.
        try:
            point_id = int(filename.split('.')[0])
        except ValueError as exc:
            raise ValueError(
                f"Cannot derive a numeric point id from file name {filename!r}"
            ) from exc

        image_path = os.path.join(directory, filename)

        # The context manager closes the file handle once the embedding is done,
        # which also means no handle is open any more when the file is moved.
        with Image.open(image_path) as image, torch.no_grad():
            inputs = processor(images=image, return_tensors="pt").to(device)
            outputs = model(**inputs)
            features = outputs.last_hidden_state.mean(dim=1).cpu().numpy()

        client.upsert(
            collection_name=collection_name,
            points=[
                models.PointStruct(
                    id=point_id,
                    vector=features.flatten().tolist(),
                )
            ],
        )

        # Move the file to the indexed folder unless it is already there.
        destination = os.path.join(indexed_dir, filename)
        if os.path.abspath(destination) != os.path.abspath(image_path):
            os.rename(image_path, destination)

    print("Indexing complete")

# Setup
# Use the GPU when one is available; otherwise everything runs on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
# The image processor and the model are loaded from the same checkpoint name.
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-small')
model = AutoModel.from_pretrained('facebook/dinov2-small').to(device)
# Switch the model to evaluation mode; it is only used for inference here.
model.eval()

# setup_qdrant() opens the local store under database/ and creates the "images" collection.
client = setup_qdrant()

# Index images
# NOTE(review): preprocess_data was run on "data/Images", but indexing reads from
# "data/indexed", which is also the folder index_images moves processed files to.
# On a fresh checkout this folder is presumably empty or missing — confirm which
# source folder is intended before relying on Restart & Run All.
directory = 'data/indexed'
index_images(directory, client, processor, model, device)


Indexing images: 100%|██████████| 8091/8091 [13:45<00:00,  9.81it/s]

Indexing complete





In [3]:
# Rough per-image latency check. Every timed iteration re-opens data/indexed/0.jpg,
# runs the image processor, moves the inputs to `device` and runs the model, so the
# figure covers the whole embedding path and not only the forward pass.
# Relies on `model`, `processor` and `device` from the setup cell.
with torch.no_grad():
    %timeit model(processor(images=Image.open("data/indexed/0.jpg"), return_tensors="pt").to(device).pixel_values).last_hidden_state.mean(dim=1)


118 ms ± 25.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
# test upload time
# Micro-benchmark of a single-point upsert against a separate ":memory:" client,
# so the collection behind `client` is not touched by the measurement.
c = QdrantClient(":memory:")
c.create_collection(collection_name="images",
                               vectors_config=models.VectorParams(size=384, distance=models.Distance.COSINE, on_disk=True))

# Every timed iteration upserts the same point (id=0, an all-zero vector of length 384).
%timeit c.upsert(collection_name="images", points = [models.PointStruct(id=0, vector=[0]*384)])


In [None]:
def find_most_similar_image(client, reference_image_path, processor, model, device,
                            limit=5, collection_name="images"):
    """Search the collection for the images closest to a reference image.

    The reference image is embedded the same way the indexed images are
    (mean over dimension 1 of the model's ``last_hidden_state``) and the
    resulting vector is used as the query.

    Args:
        client: Qdrant client to search with.
        reference_image_path: Path of the image to search for.
        processor: Image processor that turns a PIL image into model inputs.
        model: Model whose ``last_hidden_state`` is pooled into the vector.
        device: Device the model inputs are moved to.
        limit: Maximum number of hits to return.
        collection_name: Name of the collection to search.

    Returns:
        Whatever ``client.search`` returns for the query (the search hits).
    """
    # The context manager closes the image file once the embedding is computed.
    with Image.open(reference_image_path) as image_ref, torch.no_grad():
        inputs_ref = processor(images=image_ref, return_tensors="pt").to(device)
        outputs_ref = model(**inputs_ref)
        features_ref = outputs_ref.last_hidden_state.mean(dim=1).cpu().numpy()

    # Perform the search
    return client.search(
        collection_name=collection_name,
        query_vector=features_ref.flatten().tolist(),
        limit=limit,
    )


# Find most similar image
# The query image is taken from data/indexed; `reference_image_path` and `result`
# are reused by the plotting cells further down.
reference_image_path = 'data/indexed/15.jpg'
result = find_most_similar_image(client, reference_image_path, processor, model, device)
print(result)


In [None]:
# show reference image
# Open the file in a context manager so the handle is released after drawing,
# and finish with plt.show() so the cell does not echo the AxesImage repr.
with Image.open(reference_image_path) as reference_image:
    fig, ax = plt.subplots()
    ax.imshow(reference_image)
ax.set_title("Reference Image")
ax.axis('off')
plt.show()


In [None]:
# Display the most similar images
# One panel per hit, so the figure also works when the search returns fewer than 5 points.
n_panels = max(len(result), 1)
fig, axes_grid = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4), squeeze=False)
axes = axes_grid[0]
for ax in axes:
    ax.axis('off')

# A hit is only labelled as the reference when its id matches the reference file name.
reference_stem = os.path.splitext(os.path.basename(reference_image_path))[0]
rank = 0
for ax, point in zip(axes, result):
    # Both .jpg and .png files get indexed, so look for whichever file exists for this id.
    candidates = [os.path.join('data/indexed', f"{point.id}{ext}") for ext in (".jpg", ".png")]
    image_path = next((path for path in candidates if os.path.exists(path)), candidates[0])

    # Close the file handle as soon as the image has been drawn.
    with Image.open(image_path) as image:
        ax.imshow(image)

    if str(point.id) == reference_stem:
        ax.set_title("Reference Image")
    else:
        rank += 1
        ax.set_title(f"Similar Image {rank}")

plt.show()
