In [None]:
from transformers import CLIPProcessor, CLIPModel
# Load the CLIP model and its matching processor from one pretrained
# checkpoint identifier, so the two can never drift apart.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

In [None]:
import torch
import numpy as np
from pathlib import Path
from PIL import Image
import tensorflow as tf

# Folder holding the photos to embed, and the file suffixes treated as images.
demo_directory = Path("California/Photos/")
image_suffixes = {".jpg", ".jpeg", ".png", ".gif"}

# Map image id (file stem) -> path. The directory listing is sorted because
# iterdir() yields entries in arbitrary order, and the row order of the
# embedding matrix built from this dict should be reproducible between runs.
# NOTE(review): two files sharing a stem (e.g. a.jpg and a.png) collapse into
# one entry, as they did before -- confirm this cannot happen in the data.
images_to_paths = {
    image_path.stem: image_path
    for image_path in sorted(demo_directory.iterdir())
    if image_path.suffix.lower() in image_suffixes
}

if not images_to_paths:
    raise ValueError(f"No image files found in {demo_directory}")


def _load_rgb_image(path):
    """Open ``path`` and return a fully loaded RGB copy of the image.

    Converting to RGB gives every image the same 3-channel layout regardless
    of its on-disk mode (RGBA / grayscale PNGs, palette GIFs, ...). The
    ``with`` block closes the underlying file as soon as the pixels are read.
    """
    with Image.open(path) as image:
        return image.convert("RGB")


# The processor accepts PIL images directly, so no numpy / torch / TensorFlow
# round trip is needed before preprocessing.
images = [_load_rgb_image(path) for path in images_to_paths.values()]

# `padding` is a tokenizer option and is meaningless for an image-only call,
# so it is not passed here.
inputs = processor(images=images, return_tensors="pt")

# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    outputs = model.get_image_features(**inputs)

# Iterating `outputs` yields one embedding row per input image, in the same
# order as the dict keys, so the two can be zipped back together.
images_to_embeddings = {
    image_id: tensor_embedding.detach().numpy()
    for image_id, tensor_embedding in zip(images_to_paths.keys(), outputs)
}


In [None]:
import numpy as np
from sklearn.cluster import DBSCAN
from collections import defaultdict

# Keep ids and vectors as two parallel lists: row i of the stacked matrix
# below belongs to image_ids[i].
image_ids = list(images_to_embeddings)
embeddings = [images_to_embeddings[image_id] for image_id in image_ids]

# Density-based clustering over the stacked embedding matrix.
embedding_matrix = np.stack(embeddings)
clustering = DBSCAN(eps=3, min_samples=2)
clustering.fit(embedding_matrix)

# cluster label -> ids of the images in that cluster
image_id_communities = defaultdict(set)
# ids of images that received the -1 label (no cluster)
independent_image_ids = set()

# .tolist() turns the numpy label array into plain Python ints.
for image_id, cluster_idx in zip(image_ids, clustering.labels_.tolist()):
    if cluster_idx != -1:
        image_id_communities[cluster_idx].add(image_id)
    else:
        independent_image_ids.add(image_id)


In [None]:
# Only a cell's final expression is displayed automatically, so the count is
# printed explicitly instead of being computed and silently discarded.
print(f"independent images: {len(independent_image_ids)}")

# Final expression: rendered as the cell output.
image_id_communities

In [None]:
import matplotlib.pyplot as plt
def _show_image(image_id, title):
    """Draw the image registered under ``image_id`` in its own titled figure."""
    _fig, ax = plt.subplots()
    # imshow reads the pixel data immediately, so the file can be closed as
    # soon as the call returns.
    with Image.open(images_to_paths[image_id]) as image:
        ax.imshow(image)
    ax.set_title(title)
    ax.axis("off")


# One figure per clustered image, titled with its cluster so group boundaries
# are visible. Sorting keeps the figure order stable between runs (sets
# iterate in arbitrary order).
for cluster_idx, image_id_community in sorted(image_id_communities.items()):
    for image_id in sorted(image_id_community):
        _show_image(image_id, f"cluster {cluster_idx}: {image_id}")

# Images that were not assigned to any cluster of similar images.
for image_id in sorted(independent_image_ids):
    _show_image(image_id, f"unclustered: {image_id}")