# Image analysis and search using CLIP

[OpenAI CLIP Model](https://github.com/openai/CLIP) is an embedding model that maps both text and images into the same vector space, so a sentence and a picture can be compared directly.

We will work with a dataset of 25k images from the [Unsplash Lite Dataset](https://unsplash.com/data), or with your own images.
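Because text and images end up as vectors in one shared space, "which image matches this sentence?" becomes "which image vector has the highest cosine similarity to the text vector?". The following toy sketch illustrates that idea with made-up 3-dimensional vectors; the names `text_embs`, `image_embs` and `best_image` are hypothetical, and the numbers are not real CLIP outputs (`clip-ViT-B-32` produces 512-dimensional vectors).

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two vectors: dot(a, b) / (|a| * |b|)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Hypothetical 3-dimensional embeddings standing in for CLIP outputs
# (clip-ViT-B-32 actually produces 512-dimensional vectors).
text_embs = {
    "a photo of a cat": [0.9, 0.1, 0.0],
    "a mountain at sunset": [0.0, 0.2, 0.95],
}
image_embs = {
    "cat.jpg": [0.85, 0.2, 0.05],
    "alps.jpg": [0.05, 0.1, 0.9],
}


def best_image(query):
    """Return the image whose embedding is closest (by cosine similarity) to the query's."""
    query_emb = text_embs[query]
    return max(image_embs, key=lambda name: cosine_similarity(query_emb, image_embs[name]))


print(best_image("a photo of a cat"))      # cat.jpg
print(best_image("a mountain at sunset"))  # alps.jpg
```

In the notebook, the model produces both kinds of vectors, so the same comparison works for text-to-image and image-to-image search.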

## Outline

1.   Install libraries and download data
2.   Search for clusters of similar images
3.   Search images by text or another image



In [None]:
#@markdown ▶ Install libraries and load model

!pip install -q sentence_transformers
!pip install -q mediapy
!pip install -q gradio

from sentence_transformers import SentenceTransformer, util
from PIL import Image
import glob
import torch
import pickle
import zipfile
from IPython.display import display
from IPython.display import Image as IPImage
import os
import gradio as gr
from tqdm.autonotebook import tqdm
import mediapy as media

# First, we load the CLIP ViT-B/32 model through sentence-transformers
model = SentenceTransformer('clip-ViT-B-32')


**Google Drive**

Only required if you are going to use your own images.



In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
#@markdown Select if you want to use your own images or Unsplash dataset
use_own_dataset = False # @param {type:"boolean"}

#@markdown Specify the path to your images on Drive
dataset_path = '' # @param {type:"string"}

img_folder = 'photos/' if not use_own_dataset else dataset_path

In [None]:
#@markdown ▶ Download images

#@markdown Only required if you don't use your own images

# Next, we get about 25k images from Unsplash
if not use_own_dataset and (not os.path.exists(img_folder) or len(os.listdir(img_folder)) == 0):
    os.makedirs(img_folder, exist_ok=True)

    photo_filename = 'unsplash-25k-photos.zip'
    if not os.path.exists(photo_filename):   # Download the dataset if it does not exist
        util.http_get('http://sbert.net/datasets/'+photo_filename, photo_filename)

    #Extract all images
    with zipfile.ZipFile(photo_filename, 'r') as zf:
        for member in tqdm(zf.infolist(), desc='Extracting'):
            zf.extract(member, img_folder)


In [None]:
#@markdown ▶ Download precomputed embeddings or compute your own embeddings

# Now, we need to compute the embeddings.
# To speed things up, we distribute pre-computed embeddings for the Unsplash images.
# Otherwise, you can encode the images yourself.
# To encode an image, you can use the following code:
# from PIL import Image
# img_emb = model.encode(Image.open(filepath))

if not use_own_dataset:
    emb_filename = 'unsplash-25k-photos-embeddings.pkl'
    if not os.path.exists(emb_filename):   # Download the embeddings if they do not exist
        util.http_get('http://sbert.net/datasets/'+emb_filename, emb_filename)

    with open(emb_filename, 'rb') as fIn:
        img_names, img_emb = pickle.load(fIn)
    print("Images:", len(img_names))
else:
    emb_filename = 'embeddings.pkl'
    emb_path = os.path.join(img_folder, emb_filename)
    print(emb_path)

    if os.path.exists(emb_path):
        with open(emb_path, 'rb') as fIn:
            img_names, img_emb = pickle.load(fIn)
        print("Images:", len(img_names))
    else:
        img_paths = sorted(glob.glob(os.path.join(img_folder, '*.jpg')) + glob.glob(os.path.join(img_folder, '*.png')))
        # Store file names relative to img_folder (like the Unsplash embeddings),
        # because the cells below open images with os.path.join(img_folder, name)
        img_names = [os.path.basename(path) for path in img_paths]
        print("Images:", len(img_names))
        img_emb = model.encode([Image.open(path) for path in img_paths], batch_size=128, convert_to_tensor=True, show_progress_bar=True)

        # Cache the embeddings next to the images so later runs can skip encoding
        with open(emb_path, 'wb') as handle:
            pickle.dump([img_names, img_emb], handle, protocol=pickle.HIGHEST_PROTOCOL)
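The cell above follows a "compute once, cache, reload" pattern: encoding thousands of images is slow, so the embeddings are pickled and read back on later runs. The following sketch isolates that pattern and also encodes in batches. It is an illustration under assumptions: `load_or_compute_embeddings` and `fake_encode` are hypothetical names, and `encode_batch` stands in for `model.encode` so the example runs without downloading the model.

```python
import os
import pickle
import tempfile


def load_or_compute_embeddings(image_paths, encode_batch, cache_path, batch_size=128):
    """Return (names, embeddings), loading them from cache_path when it exists.

    encode_batch is a hypothetical stand-in for model.encode: it receives a list
    of file paths and returns one embedding per path.
    """
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            names, embeddings = pickle.load(f)
        return names, embeddings

    # Keep only file names, so they can later be joined with the image folder
    names = [os.path.basename(path) for path in image_paths]
    embeddings = []
    # Hand the encoder batch_size files at a time instead of the whole folder
    for start in range(0, len(image_paths), batch_size):
        embeddings.extend(encode_batch(image_paths[start:start + batch_size]))

    with open(cache_path, "wb") as f:
        pickle.dump([names, embeddings], f, protocol=pickle.HIGHEST_PROTOCOL)
    return names, embeddings


# Usage with a fake encoder that records how many paths each call received
calls = []


def fake_encode(paths):
    calls.append(len(paths))
    return [[float(len(path))] for path in paths]


with tempfile.TemporaryDirectory() as tmp:
    cache = os.path.join(tmp, "embeddings.pkl")
    paths = [f"/data/img{i}.jpg" for i in range(5)]
    first = load_or_compute_embeddings(paths, fake_encode, cache, batch_size=2)
    second = load_or_compute_embeddings(paths, fake_encode, cache, batch_size=2)

print(calls)  # [2, 2, 1]: the second call read the cache instead of encoding again
```

With the real model, `encode_batch` could wrap `model.encode(..., convert_to_tensor=True)` on the opened images; the resulting list of row tensors would then need to be combined (for example with `torch.stack`) before searching. Batching also limits how many image files are open at once, which the single list comprehension in the cell above does not.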


In [None]:
#@markdown ▶ Declare needed functions

# We have implemented our own, efficient method
# to find high density regions in vector space
def community_detection(embeddings, threshold, min_community_size=10, init_max_size=1000):
    """
    Function for Fast Community Detection

    Finds all communities in the embeddings, i.e. groups of embeddings whose cosine
    similarity to a central embedding is at least threshold.

    Returns only communities with at least min_community_size members, sorted by
    decreasing size. The first element in each list is the central point in the community.
    """

    # topk() fails if k is larger than the number of embeddings (e.g. small own datasets)
    if len(embeddings) < min_community_size:
        return []
    init_max_size = min(init_max_size, len(embeddings))

    # Compute cosine similarity scores
    cos_scores = util.cos_sim(embeddings, embeddings)

    # Minimum size for a community
    top_k_values, _ = cos_scores.topk(k=min_community_size, largest=True)

    # Step 1) Keep rows whose min_community_size-th highest score is >= threshold
    extracted_communities = []
    for i in range(len(top_k_values)):
        if top_k_values[i][-1] >= threshold:
            new_cluster = []

            # Only check top k most similar entries
            top_val_large, top_idx_large = cos_scores[i].topk(k=init_max_size, largest=True)
            top_idx_large = top_idx_large.tolist()
            top_val_large = top_val_large.tolist()

            if top_val_large[-1] < threshold:
                for idx, val in zip(top_idx_large, top_val_large):
                    if val < threshold:
                        break

                    new_cluster.append(idx)
            else:
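                # Iterate over all entries (slow), most similar first so the central point stays first
                for idx, val in sorted(enumerate(cos_scores[i].tolist()), key=lambda item: item[1], reverse=True):
                    if val >= threshold:
                        new_cluster.append(idx)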

            extracted_communities.append(new_cluster)

    # Largest cluster first
    extracted_communities = sorted(extracted_communities, key=lambda x: len(x), reverse=True)

    # Step 2) Remove overlapping communities
    unique_communities = []
    extracted_ids = set()

    for community in extracted_communities:
        add_cluster = True
        for idx in community:
            if idx in extracted_ids:
                add_cluster = False
                break

        if add_cluster:
            unique_communities.append(community)
            for idx in community:
                extracted_ids.add(idx)

    return unique_communities

# Next, we define a search function.
def search(query, k=3):
    # First, we encode the query (which can either be an image or a text string)
    query_emb = model.encode([query], convert_to_tensor=True, show_progress_bar=False)

    # Then, we use the util.semantic_search function, which computes the cosine similarity
    # between the query embedding and all image embeddings.
    # It returns a list of hits per query; we keep the top k hits for our single query
    hits = util.semantic_search(query_emb, img_emb, top_k=k)[0]

    output = []
    for hit in hits:
        image_path = os.path.join(img_folder, img_names[hit['corpus_id']])
        image = Image.open(image_path)

        output.append((image, f"{hit['score']:.2f}"))

    return output
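`util.semantic_search` scores the query against every image embedding and keeps the `top_k` best matches. The following pure-Python sketch mimics that idea for a single query and returns hits in the same `{'corpus_id': ..., 'score': ...}` shape that `search()` reads. The names `top_k_search`, `toy_corpus` and `toy_hits` and the 2-D vectors are hypothetical; the real function works on batched tensors rather than Python lists.

```python
import heapq
import math


def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def top_k_search(query_emb, corpus_embs, top_k=3):
    """Score every corpus embedding against the query and return the top_k hits,
    highest score first, as dicts with 'corpus_id' and 'score' keys."""
    scores = ((cosine_similarity(query_emb, emb), idx) for idx, emb in enumerate(corpus_embs))
    best = heapq.nlargest(top_k, scores)  # no need to sort the whole corpus
    return [{"corpus_id": idx, "score": score} for score, idx in best]


toy_corpus = [[1.0, 0.0], [0.7, 0.7], [0.0, 1.0], [-1.0, 0.0]]
toy_hits = top_k_search([1.0, 0.1], toy_corpus, top_k=2)
print([hit["corpus_id"] for hit in toy_hits])  # [0, 1]
```

`heapq.nlargest` keeps only `top_k` candidates while scanning, which is the same reason the notebook's `community_detection` uses `topk` instead of sorting full rows.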

In [None]:
#@title ▶ Cluster the images and show samples from the 10 largest clusters

# Now we run the clustering algorithm.
# The threshold parameter defines the cosine similarity at which we consider
# two images similar. Lower the threshold to get larger clusters whose images are
# less similar to each other (e.g. black cat images vs. cat images vs. animal images).
# With min_community_size, we only keep clusters with at least that many images.
clusters = community_detection(img_emb, threshold=0.9, min_community_size=10)

num_clusters = 10
num_images = 5

images = []

# Now we output the first 10 (largest) clusters
for cluster in clusters[0:num_clusters]:

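    # Take the first num_images images of each cluster
    for idx in cluster[0:num_images]:
        image_path = os.path.join(img_folder, img_names[idx])
        images.append(Image.open(image_path))

# One row per cluster
if images:
    media.show_images(images, height=256, columns=num_images)
else:
    print("No clusters found; try a lower threshold or a smaller min_community_size")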
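To see how `threshold` changes the result, the following sketch reproduces the greedy idea behind `community_detection` in plain Python: collect every point's neighbours above the threshold, keep groups with at least `min_community_size` members, then accept groups from largest to smallest while skipping any that overlap an accepted one. It is a simplified, brute-force version (no `topk` shortcut, lists instead of tensors), and `simple_community_detection`, `toy_embeddings` and `toy_clusters` are hypothetical names with made-up 2-D vectors.

```python
import math


def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def simple_community_detection(embeddings, threshold, min_community_size=2):
    """Brute-force sketch of greedy threshold-based community detection.

    1. For every point, collect all points whose cosine similarity to it is
       >= threshold, most similar first.
    2. Keep candidates with at least min_community_size members.
    3. Walk the candidates from largest to smallest and skip any candidate that
       shares a member with a community that was already kept.
    """
    candidates = []
    for anchor in embeddings:
        scored = [(cosine_similarity(anchor, other), j) for j, other in enumerate(embeddings)]
        members = [j for sim, j in sorted(scored, reverse=True) if sim >= threshold]
        if len(members) >= min_community_size:
            candidates.append(members)

    candidates.sort(key=len, reverse=True)  # stable: equal sizes keep their order
    communities, used = [], set()
    for members in candidates:
        if used.isdisjoint(members):
            communities.append(members)
            used.update(members)
    return communities


# Hypothetical 2-D embeddings: two tight pairs plus one vector halfway between them
toy_embeddings = [[1.0, 0.0], [0.98, 0.2], [0.0, 1.0], [0.2, 0.98], [0.7, 0.7]]

for threshold in (0.95, 0.8):
    toy_clusters = simple_community_detection(toy_embeddings, threshold)
    print(threshold, [sorted(c) for c in toy_clusters])
# 0.95 [[0, 1], [2, 3]]
# 0.8 [[0, 1, 4], [2, 3]]
```

Lowering the threshold from 0.95 to 0.8 lets the in-between vector `[0.7, 0.7]` join a cluster, which is the "larger but less similar clusters" effect described in the comments of the clustering cell.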

In [None]:
#@title  ▶  Search the dataset by text or image

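def search_fn(input_file, input_text, number_of_results):
    # gr.Number can return a float, but the number of results must be an integer
    k = int(number_of_results)
    # An uploaded image takes priority over the text query
    if input_file is not None:
        return search(Image.open(input_file), k=k)
    else:
        return search(input_text, k=k)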

with gr.Blocks() as demo:
  with gr.Column():
    input_file = gr.File(file_count="single", file_types=[".jpg", ".png"], label="Search image")
    input_text = gr.Textbox(label="Search text")
    number_of_results = gr.Number(label="Number of results", minimum=1, value=3, maximum=10)
    search_button = gr.Button(value="Search")

  with gr.Column():
    gallery = gr.Gallery(label="Results", show_label=True, columns=[3], object_fit="contain", height="auto")

  search_button.click(search_fn, [input_file, input_text, number_of_results], [gallery])

demo.launch(quiet=True, debug=False, height=768)

# Finalizing

When you finish working, remember to **stop the runtime**: Colab limits how long a runtime can run, and stopping it avoids wasting resources. To stop the runtime, click Manage sessions on the Runtime menu. Once the dialog opens, click Terminate on the current runtime.

> When you stop the runtime, everything you have not saved is ⚠ **lost** ⚠, so be sure to **download** everything you want to keep before stopping it.


# Credits

Taller Estampa https://tallerestampa.com / https://github.com/estampa

### Based on

[Sentence Transformers](https://github.com/UKPLab/sentence-transformers/tree/master/examples/applications/image-search)
