In [18]:
import os
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
import chromadb
from chromadb.config import DEFAULT_TENANT, DEFAULT_DATABASE, Settings

# Ensure the variables are defined
persist_directory = '../data/chromadb'
slides_path = '../data/slides'

collection_name = 'CLIP_slides_collection'

# Initialize ChromaDB client with the existing settings
client = chromadb.PersistentClient(
    path=persist_directory,
    settings=Settings(),
    tenant=DEFAULT_TENANT,
    database=DEFAULT_DATABASE,
)

# List all collections in ChromaDB
collections = client.list_collections()
print("Existing collections:")
for collection_n in collections:
    collection = client.get_collection(collection_n)
    print(collection.name)

# Create or get the collection in ChromaDB
collection = client.create_collection(collection_name, get_or_create=True)
print("Created or got collection 'CLIP_slides_collection'")

# Load the CLIP model and processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
print("Loaded CLIP model and processor")

# Function to embed an image using CLIP
def embed_image(image_path):
    image = Image.open(image_path)
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        embeddings = model.get_image_features(**inputs)
    return embeddings.squeeze().tolist()

# Iterate over each image in the slides folder and embed it
for filename in os.listdir(slides_path):
    if filename.endswith('.png'):
        image_path = os.path.join(slides_path, filename)
        embeddings = embed_image(image_path)
        collection.add(
            ids=[filename],  # Use filename as the unique ID
            documents=[filename],
            embeddings=[embeddings]
        )
        print(f"Added {filename} to collection")


Existing collections:
Created or got collection 'CLIP_slides_collection'
Loaded CLIP model and processor
Added AirBnB_Pitch_Deck_slide10.png to collection
Added AirBnB_Pitch_Deck_slide11.png to collection
Added AirBnB_Pitch_Deck_slide13.png to collection
Added AirBnB_Pitch_Deck_slide12.png to collection
Added AirBnB_Pitch_Deck_slide9.png to collection
Added AirBnB_Pitch_Deck_slide16.png to collection
Added AirBnB_Pitch_Deck_slide17.png to collection
Added AirBnB_Pitch_Deck_slide8.png to collection
Added AirBnB_Pitch_Deck_slide15.png to collection
Added AirBnB_Pitch_Deck_slide14.png to collection
Added AirBnB_Pitch_Deck_slide6.png to collection
Added AirBnB_Pitch_Deck_slide18.png to collection
Added AirBnB_Pitch_Deck_slide7.png to collection
Added AirBnB_Pitch_Deck_slide5.png to collection
Added AirBnB_Pitch_Deck_slide4.png to collection
Added AirBnB_Pitch_Deck_slide1.png to collection
Added AirBnB_Pitch_Deck_slide3.png to collection
Added AirBnB_Pitch_Deck_slide2.png to collection


In [17]:
client.delete_collection(collection_name)

In [21]:
# Retrieve and print the metadata of the embeddings
documents = collection.get(ids=[filename for filename in os.listdir(slides_path) if filename.endswith('.png')], include=['documents', 'embeddings'])
for doc_id, doc, embedding in zip(documents['ids'], documents['documents'], documents['embeddings']):
    print(f"ID: {doc_id}, Document: {doc}, Embedding: {embedding[:5]}...")  # Print first 5 elements of the embedding for brevity

ID: AirBnB_Pitch_Deck_slide10.png, Document: AirBnB_Pitch_Deck_slide10.png, Embedding: [-0.29650497 -0.45568323  0.12390557  0.14443952 -0.04381865]...
ID: AirBnB_Pitch_Deck_slide11.png, Document: AirBnB_Pitch_Deck_slide11.png, Embedding: [-0.31927934 -0.19493011  0.05249168  0.12463375  0.28774643]...
ID: AirBnB_Pitch_Deck_slide13.png, Document: AirBnB_Pitch_Deck_slide13.png, Embedding: [-0.20980214 -0.07775372  0.08181545 -0.63865143  0.54446042]...
ID: AirBnB_Pitch_Deck_slide12.png, Document: AirBnB_Pitch_Deck_slide12.png, Embedding: [-0.13213493 -0.36214155  0.06294962 -0.22132587  0.12163597]...
ID: AirBnB_Pitch_Deck_slide9.png, Document: AirBnB_Pitch_Deck_slide9.png, Embedding: [-0.02860507  0.09930226 -0.01386034  0.28498939  0.31792754]...
ID: AirBnB_Pitch_Deck_slide16.png, Document: AirBnB_Pitch_Deck_slide16.png, Embedding: [-0.49150854 -0.40084404  0.01018639  0.23208091  0.17551343]...
ID: AirBnB_Pitch_Deck_slide17.png, Document: AirBnB_Pitch_Deck_slide17.png, Embedding: [-0

In [16]:
# Retrieve all documents and their embeddings from the collection
documents = collection.get(include=['embeddings'])

# Print the embeddings
for doc_id, embedding in zip(documents['ids'], documents['embeddings']):
    print(f"ID: {doc_id}, Embedding: {embedding}")

ID: AirBnB_Pitch_Deck_slide10.png
ID: AirBnB_Pitch_Deck_slide11.png
ID: AirBnB_Pitch_Deck_slide13.png
ID: AirBnB_Pitch_Deck_slide12.png
ID: AirBnB_Pitch_Deck_slide9.png
ID: AirBnB_Pitch_Deck_slide16.png
ID: AirBnB_Pitch_Deck_slide17.png
ID: AirBnB_Pitch_Deck_slide8.png
ID: AirBnB_Pitch_Deck_slide15.png
ID: AirBnB_Pitch_Deck_slide14.png
ID: AirBnB_Pitch_Deck_slide6.png
ID: AirBnB_Pitch_Deck_slide18.png
ID: AirBnB_Pitch_Deck_slide7.png
ID: AirBnB_Pitch_Deck_slide5.png
ID: AirBnB_Pitch_Deck_slide4.png
ID: AirBnB_Pitch_Deck_slide1.png
ID: AirBnB_Pitch_Deck_slide3.png
ID: AirBnB_Pitch_Deck_slide2.png
