<a href="https://colab.research.google.com/github/jarvisx17/OpenAI-Clip-Image-Search/blob/main/OpenAI_Clip_Faiss_Image_Search_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
from PIL import Image
import PIL
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import pandas as pd
import pickle
import torch
from datasets import Dataset, Image
from torch.utils.data import DataLoader
from typing import List, Union, Tuple
from transformers import CLIPProcessor, CLIPModel
import faiss

## **Download DATA**

In [None]:
# Kaggle credentials are placed in environment variables before the CLI is invoked.
# Keep the placeholders in any shared copy of this notebook; never commit a real key.
import os

os.environ['KAGGLE_USERNAME'] = "xxxxxxx"  # Replace with your Kaggle username
os.environ['KAGGLE_KEY'] = "xxxxxxxxxxxx"  # Replace with your Kaggle API key

# Download the Flickr30k archive with the Kaggle CLI, then extract it.
# NOTE(review): the unzip path assumes the archive was downloaded into /content
# (the Colab working directory) — confirm before running elsewhere.
!kaggle datasets download -d adityajn105/flickr30k
!unzip /content/flickr30k.zip

## **Creating Image Embeddings Using OpenAI CLIP**

In [None]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# model still loads on runtimes without CUDA instead of failing in `.to(...)`.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
# `preprocess` is the processor paired with the checkpoint above; the cells
# below call it with `images=` and with `text=` to build model inputs.
preprocess = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

In [None]:
# Collect the dataset's JPEG files in a stable, sorted order: row i of the
# embedding matrix / FAISS index is later mapped back through image_path[i].
image_dir = '/content/Images/'
image_path = sorted(
    image_dir + name
    for name in os.listdir(image_dir)
    # Test the extension rather than a substring, so names such as
    # "photo.jpg.tmp" are not picked up as images.
    if name.endswith('.jpg')
)
# Caption table shipped with the dataset (not used by the cells visible here).
captions_df = pd.read_csv('/content/captions.txt')

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def encode_images(images: Union[List[str], List[PIL.Image.Image]], batch_size: int):
    """Embed images with the CLIP model's image encoder.

    Args:
        images: Non-empty list of image file paths, or of already-loaded PIL images.
        batch_size: Number of images sent through the model per forward pass.

    Returns:
        np.ndarray with one embedding row per input image, in input order.
    """
    def transform_fn(el):
        # The transform receives a *batch*: el['image'] is a list, so the type
        # check has to be made per element. Entries are either PIL images or
        # the undecoded records produced by the `Image(decode=False)` cast below.
        # (`Image` here is `datasets.Image`, which shadows PIL's at module level.)
        imgs = [
            item if isinstance(item, PIL.Image.Image) else Image().decode_example(item)
            for item in el['image']
        ]
        return preprocess(images=imgs, return_tensors='pt')

    dataset = Dataset.from_dict({'image': images})
    if isinstance(images[0], str):
        # Keep paths undecoded so files are only opened lazily inside transform_fn.
        dataset = dataset.cast_column('image', Image(decode=False))
    # set_transform defines the output format on its own; a preceding
    # set_format('torch') would simply be overridden, so it is omitted.
    dataset.set_transform(transform_fn)
    dataloader = DataLoader(dataset, batch_size=batch_size)

    # Send inputs to wherever the model actually lives instead of trusting a global.
    target_device = model.device
    image_embeddings = []
    # len(dataloader) counts the final partial batch too (the floor division
    # used before did not); the context manager closes the bar on interruption.
    with torch.no_grad(), tqdm(total=len(dataloader), position=0) as pbar:
        for batch in dataloader:
            batch = {k: v.to(target_device) for k, v in batch.items()}
            image_embeddings.extend(model.get_image_features(**batch).cpu().numpy())
            pbar.update(1)
    return np.stack(image_embeddings)

# encode_images already returns a fresh ndarray, so no extra np.array() copy is needed.
vector_embedding = encode_images(image_path, batch_size=128)

  2%|▏         | 5/248 [00:15<10:58,  2.71s/it]

KeyboardInterrupt: ignored

In [None]:
def encode_text(text: List[str], batch_size: int):
    """Embed text strings with the CLIP model's text encoder.

    Args:
        text: Non-empty list of strings to embed.
        batch_size: Number of strings sent through the model per forward pass.

    Returns:
        np.ndarray with one embedding row per input string, in input order.
    """
    # Follow the model's actual device rather than hard-coding "cuda".
    target_device = model.device
    dataset = Dataset.from_dict({'text': text})
    # Tokenize up front; every sequence is padded/truncated to 77 tokens,
    # the text context length of this CLIP checkpoint.
    dataset = dataset.map(
        lambda el: preprocess(text=el['text'], return_tensors="pt",
                              max_length=77, padding="max_length", truncation=True),
        batched=True,
        remove_columns=['text'],
    )
    dataset.set_format('torch')
    dataloader = DataLoader(dataset, batch_size=batch_size)
    text_embeddings = []
    # len(dataloader) is the true batch count; `len(text) // batch_size` is 0
    # for a single query. The context manager closes the bar even on errors.
    with torch.no_grad(), tqdm(total=len(dataloader), position=0) as pbar:
        for batch in dataloader:
            batch = {k: v.to(target_device) for k, v in batch.items()}
            text_embeddings.extend(model.get_text_features(**batch).cpu().numpy())
            pbar.update(1)
    return np.stack(text_embeddings)

## **Saving Image Embeddings as a Pickle File**

In [None]:
import pickle

embeddings_file = 'flicker30k_image_embeddings.pkl'

# Persist the image embeddings so the encoding step need not be repeated.
with open(embeddings_file, 'wb') as sink:
    pickle.dump(vector_embedding, sink)

# Read the embeddings back from disk.
# Only unpickle files you created yourself: pickle.load can execute arbitrary code.
with open(embeddings_file, 'rb') as source:
    vector_embedding = pickle.load(source)

## **Building Index Using FAISS**

In [None]:
# IndexFlatIP ranks by raw inner product. Search() queries with a unit-length
# text embedding, so the image rows are scaled to unit L2 norm as well: the
# inner product then equals cosine similarity instead of being skewed by the
# magnitude of each image embedding.
image_vectors = np.asarray(vector_embedding, dtype=np.float32)
image_vectors = image_vectors / np.linalg.norm(image_vectors, ord=2, axis=-1, keepdims=True)

index = faiss.IndexFlatIP(image_vectors.shape[1])
# FAISS expects a C-contiguous float32 matrix; `vector_embedding` itself is left untouched.
index.add(np.ascontiguousarray(image_vectors, dtype=np.float32))

## **Text to Image Search**

In [None]:
def Search(search_text, results):
  """Show the images whose embeddings best match a text query.

  Args:
      search_text: Query string.
      results: Number of nearest neighbours to request from the FAISS index.

  Prints each matching image path and displays the image, best match first.
  Returns None.
  """
  # Local import on purpose: at module level `Image` is rebound by
  # `from datasets import Dataset, Image`, so PIL's Image is re-imported here.
  from PIL import Image

  # encode_text already runs under torch.no_grad(); no extra context is needed.
  text_search_embedding = encode_text([search_text], batch_size=32)
  text_search_embedding = text_search_embedding/np.linalg.norm(text_search_embedding, ord=2, axis=-1, keepdims=True)
  # FAISS expects a C-contiguous float32 query matrix.
  query = np.ascontiguousarray(text_search_embedding.reshape(1, -1), dtype=np.float32)
  distances, indices = index.search(query, results)

  # The index is an inner-product index, so the returned "distances" are
  # similarity scores: larger means closer. FAISS pads with label -1 when it
  # has fewer than `results` vectors; drop those so image_path[-1] is never shown.
  hits = [(int(idx), float(score)) for idx, score in zip(indices[0], distances[0]) if idx >= 0]
  hits.sort(key=lambda hit: hit[1], reverse=True)  # Best match first

  fixed_size = (300, 300)  # Fixed dimension (width, height) for the displayed images
  for idx, score in hits:
      path = image_path[idx]
      print(path)
      # Close the file handle as soon as the resized copy exists.
      with Image.open(path) as im:
          im_resized = im.resize(fixed_size)
      plt.imshow(im_resized)
      plt.title(f"similarity: {score:.3f}")
      plt.show()

In [None]:
# Example query: request ten matches for the word "football".
search_text, results = "football", 10
Search(search_text=search_text, results=results)