# Text-to-Image and Image-to-Image Search Using CLIP

# Import libraries

In [1]:
import aiohttp
import asyncio
import torch
import requests
import os
import time
import json
import dill
import gzip
import hashlib
import pandas as pd
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
from concurrent.futures import ThreadPoolExecutor, as_completed
from sklearn.metrics.pairwise import cosine_similarity
from datasets import load_dataset
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
def save_pickle(pkl, fname:str=""):
	"""Serialize ``pkl`` to ``fname`` and print the elapsed time and file size.

	The on-disk format depends on the object's type:
	  * ``pandas.DataFrame`` / ``pandas.Series`` -> ``to_pickle``
	  * ``dict`` -> JSON text
	  * anything else -> gzip-compressed ``dill`` dump

	Args:
		pkl: object to store.
		fname: destination file path.
	"""
	print(f"\nSaving {type(pkl)}\n{fname}")
	started_at = time.time()
	if isinstance(pkl, (pd.DataFrame, pd.Series)):
		pkl.to_pickle(path=fname)
	elif isinstance(pkl, dict):
		with open(fname, mode="w") as json_file:
			json.dump(pkl, json_file)
	else:
		with gzip.open(fname, mode="wb") as gz_file:
			dill.dump(pkl, gz_file)
	elapsed = time.time() - started_at
	size_mb = os.path.getsize(fname) / 1e6
	summary = f"Elapsed_t: {elapsed:.3f} s | {size_mb:.2f} MB"
	print(summary.center(120, " "))

def load_pickle(fpath: str) -> object:
	"""Load an object from ``fpath``, trying several on-disk formats in turn.

	Readers are attempted in this order and the first one that succeeds wins:
	JSON text, gzip-compressed dill, uncompressed dill, pandas pickle.
	Every reader is now tried; previously a failure of the uncompressed-dill
	retry escaped without ever reaching the pandas fallback.

	NOTE: dill/pickle loading executes code embedded in the file; only use
	this on files you created yourself.

	Args:
		fpath: path of the file to load.

	Returns:
		The deserialized object.

	Raises:
		FileNotFoundError: if ``fpath`` does not exist.
		Exception: the error of the last reader when none of them succeeds.
	"""
	print(f"Loading {fpath}")
	if not os.path.exists(fpath):
		raise FileNotFoundError(f"File not found: {fpath}")
	start_time = time.time()

	def _read_json():
		with open(fpath, mode='r') as f:
			return json.load(f)

	def _read_gzip_dill():
		with gzip.open(fpath, mode='rb') as f:
			return dill.load(f)

	def _read_plain_dill():
		with open(fpath, mode='rb') as f:
			return dill.load(f)

	def _read_pandas():
		return pd.read_pickle(fpath)

	# (label used in the diagnostic message, reader); a JSON failure is expected
	# for binary files and therefore stays silent.
	readers = (
		(None, _read_json),
		("dill", _read_gzip_dill),
		("dill", _read_plain_dill),
		("pandas pkl", _read_pandas),
	)
	last_error = None
	for label, reader in readers:
		try:
			pickle_obj = reader()
			break
		except Exception as err:
			last_error = err
			if label is not None:
				if isinstance(err, gzip.BadGzipFile):
					label = "BadGzipFile"
				print(f"Error {label}: {err}")
	else:
		# No reader understood the file: surface the final failure.
		raise last_error
	elapsed_time = time.time() - start_time
	file_size_mb = os.path.getsize(fpath) / 1e6
	print(f"Elapsed_t: {elapsed_time:.3f} s | {type(pickle_obj)} | {file_size_mb:.3f} MB".center(150, " "))
	return pickle_obj

def plot_images_by_side(top_images, topK:int=5):
	"""Show the retrieved images in a single row with caption and similarity.

	Args:
		top_images: DataFrame with ``image_path``, ``caption`` and ``cos_sim``
			columns (as returned by the search helpers in this notebook).
		topK: number of subplot columns to create.

	Rows are read positionally, so the frame's index labels do not matter.
	Nothing is drawn when ``topK`` is below 1 or ``top_images`` is empty.
	"""
	if topK < 1 or len(top_images) == 0:
		return
	# squeeze=False always yields a 2-D array of axes, so topK == 1 works too.
	_, axs = plt.subplots(1, topK, figsize=(18, 8), squeeze=False)
	axs = axs.flatten()
	for ax in axs:
		# Hide frames up front so unused slots (topK > number of rows) stay blank.
		ax.axis("off")
	rows = top_images[["image_path", "caption", "cos_sim"]].itertuples(index=False, name=None)
	for ax, (img_path, caption, sim_score) in zip(axs, rows):
		# Context manager closes the file handle once the pixels are copied.
		with Image.open(img_path) as img:
			ax.imshow(img.convert("RGB"))
		ax.set_title(f"{caption}\nSimilarity: {sim_score:.3f}", fontsize=7)
	plt.show()



# Image Exploration

## Load Data

In [3]:
# Dataset card: https://huggingface.co/datasets/google-research-datasets/conceptual_captions
# Switch the split to "train" to work on the training set instead.
DATASET_NAME = "validation"
ds = load_dataset(path="conceptual_captions", split=DATASET_NAME)

In [None]:
# Number of samples and container type, followed by the dataset's own summary.
print(len(ds), type(ds))
print(ds)

In [5]:
# Define the image directory
# DATASET_DIRECTORY is a machine-specific absolute path; adjust it before running elsewhere.
DATASET_DIRECTORY = "/home/farid/WS_Farid/ImACCESS/datasets/WW_DATASETs/conceptualized_captions"
# Downloaded images are stored per split, e.g. "<DATASET_DIRECTORY>/validation_images".
IMAGE_DIRECTORY = os.path.join(DATASET_DIRECTORY, f"{DATASET_NAME}_images")
# IMAGE_DIRECTORY = os.path.join(DATASET_DIRECTORY, f"images")
# Create the directory (and any missing parents) if it does not exist yet.
os.makedirs(IMAGE_DIRECTORY, exist_ok=True)

def safe_get_image(url):
	"""Download ``url`` into ``IMAGE_DIRECTORY`` and return the local file path.

	The file name is the MD5 hex digest of the URL plus ``.jpg``, so a given
	URL always maps to the same file. If that file already exists it is
	returned immediately, without any network request (previously the image
	was downloaded again before the cache was consulted).

	Args:
		url: image URL.

	Returns:
		Path of the saved image, or ``None`` when anything goes wrong
		(bad URL, HTTP error, timeout, undecodable image, ...). Failures are
		deliberately swallowed so one bad URL does not abort a bulk download.
	"""
	try:
		# Unique, deterministic filename derived from the URL.
		url_hash = hashlib.md5(url.encode()).hexdigest()
		image_name = os.path.join(IMAGE_DIRECTORY, f"{url_hash}.jpg")

		# Cache hit: skip the download entirely.
		if os.path.exists(image_name):
			return image_name

		# Add a user agent to prevent potential blocking
		headers = {
			'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
		}
		response = requests.get(url, timeout=20, headers=headers)
		response.raise_for_status()

		image = Image.open(BytesIO(response.content)).convert("RGB")
		image.save(image_name)
		return image_name
	except Exception:
		return None

def parallel_download(df, max_workers=8):
	"""Download every URL in ``df['image_url']`` concurrently.

	Args:
		df: DataFrame with an ``image_url`` column. It is not modified.
		max_workers: size of the thread pool.

	Returns:
		A copy of ``df`` with an added ``image_path`` column holding the local
		path for each row, or ``None`` where the download failed.

	Results are matched to rows by position, so the outcome is correct for any
	index (not only ``0..n-1``) and for duplicated URLs, and no per-URL
	DataFrame scan is needed.
	"""
	df = df.copy() # Create a copy of dataframe to avoid SettingWithCopyWarning
	urls = df['image_url'].tolist()
	image_paths = [None] * len(urls) # One slot per row, in row order
	with ThreadPoolExecutor(max_workers=max_workers) as executor:
		# One future per row, kept in row order.
		futures = [executor.submit(safe_get_image, url) for url in urls]
		# All futures must finish anyway, so collecting in submit order costs nothing.
		for position, (url, future) in enumerate(zip(urls, futures)):
			try:
				image_paths[position] = future.result()
			except Exception as e:
				print(f"Error processing {url}: {e}")
	df['image_path'] = image_paths
	return df

def download_images(dataset, slice:int=1000):
	"""Download the images of the first ``slice`` samples of ``dataset``.

	Args:
		dataset: sliceable collection whose ``dataset[:slice]`` can be turned
			into a DataFrame containing an ``image_url`` column.
		slice: number of leading samples to process (the name shadows the
			builtin but is kept for backward compatibility).

	Returns:
		DataFrame of the samples whose image was downloaded successfully, with
		an ``image_path`` column.

	Raises:
		Exception: whatever the download step raised. The previous code tried
			to return a variable that was never assigned in that case, which
			masked the real error behind an ``UnboundLocalError``.
	"""
	df_raw = pd.DataFrame(dataset[:slice])
	try:
		df = parallel_download(df_raw, max_workers=18)
		# Drop rows with failed downloads; copy so later column assignments
		# on the result do not operate on a view of the unfiltered frame.
		return df[df['image_path'].notnull()].copy()
	except Exception as e:
		print(f"Error in download_images: {e}")
		raise

In [None]:
NUM_SAMPLEs = len(ds) #500
metadata_fpath = os.path.join(DATASET_DIRECTORY, f"metadata_{DATASET_NAME}_x_{NUM_SAMPLEs}_raw_samples.gz")
csv_metadata_fname = os.path.join(DATASET_DIRECTORY, f"metadata_{DATASET_NAME}.csv")
# The pickle is the only cache `df` is built from. The CSV is just a
# human-readable export, so it is no longer read here (its content used to be
# discarded immediately, and a missing CSV forced a full re-download).
try:
	df = load_pickle(fpath=metadata_fpath)
except Exception as e:
	# No usable cache: download everything and write both artifacts.
	print(e)
	df = download_images(dataset=ds, slice=NUM_SAMPLEs)
	df.to_csv(csv_metadata_fname, index=False)
	save_pickle(pkl=df, fname=metadata_fpath)

In [None]:
# (rows, columns) of the metadata frame.
df.shape

In [None]:
# Count of missing values per column.
df.isna().sum()

In [9]:
# del df
# Keep an untouched copy of the metadata: `df` gets embedding columns from
# Approach 1, while `df_` is reused for Approach 2 further down.
df_ = df.copy()

In [None]:
# Preview one downloaded image. Select it by position: rows with failed
# downloads were dropped, so the index label 100 is not guaranteed to exist.
img_0_path = df["image_path"].iloc[min(100, len(df) - 1)]
img_0_path

In [None]:
# Display the selected image inline.
Image.open(img_0_path)

In [None]:
# (rows, columns) of the metadata frame before embeddings are added.
df.shape

## Approach 1: Hugging Face Transformers Library

In [None]:
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
def get_model_info(model_ID, device):
	"""Load the CLIP model, processor and tokenizer for ``model_ID``.

	All three are fetched with ``from_pretrained`` at revision ``"main"``;
	only the model is moved to ``device``.

	Returns:
		Tuple ``(model, processor, tokenizer)``.
	"""
	pretrained_kwargs = {"revision": "main"}
	tokenizer = CLIPTokenizer.from_pretrained(model_ID, **pretrained_kwargs)
	processor = CLIPProcessor.from_pretrained(model_ID, **pretrained_kwargs)
	clip_model = CLIPModel.from_pretrained(model_ID, **pretrained_kwargs)
	clip_model = clip_model.to(device)
	return clip_model, processor, tokenizer

# Choose the CLIP checkpoint and load model, processor and tokenizer onto `device`
model_ID = "openai/clip-vit-base-patch32"
transformers_clip_model, transformers_clip_processor, transformers_clip_tokenizer = get_model_info(model_ID, device)

### Text Embeddings

In [14]:
def get_single_text_embedding(text):
	"""Embed ``text`` with the Transformers CLIP text encoder.

	Args:
		text: caption or query string.

	Returns:
		NumPy array produced by ``get_text_features`` for this single text
		(one row per input text).
	"""
	# truncation=True cuts over-long captions to the tokenizer's maximum
	# length instead of failing inside the model's position embeddings.
	inputs = transformers_clip_tokenizer(text, return_tensors = "pt", truncation=True).to(device)
	# Inference only: skip autograd bookkeeping to save memory and time.
	with torch.no_grad():
		text_embeddings = transformers_clip_model.get_text_features(**inputs)
	embedding_as_np = text_embeddings.cpu().detach().numpy()
	return embedding_as_np
# Embed every caption, one row at a time, and store the arrays in a new column.
df["text_embeddings"] = df["caption"].apply(get_single_text_embedding)

In [None]:
# Preview the first rows, now including the text_embeddings column.
df.head()

### Image Embeddings

In [None]:
def get_single_image_embedding(input_image, model=transformers_clip_model, processor=transformers_clip_processor, ):
	"""Embed one image with the Transformers CLIP image encoder.

	Args:
		input_image: path to an image file (``str``) or an already opened
			image object.
		model: CLIP model exposing ``get_image_features``.
		processor: CLIP processor used to build ``pixel_values``.

	Returns:
		NumPy array with the image embedding, or ``None`` when the image is
		1x1 pixels (treated as an invalid placeholder image).
	"""
	if isinstance(input_image, str): # image path
		# Context manager releases the file handle; convert() returns a
		# fully loaded copy that stays valid afterwards.
		with Image.open(input_image) as opened:
			my_image = opened.convert("RGB")
	else:
		my_image = input_image
	if my_image.size == (1, 1):
		print(f"Skipping image {my_image} due to invalid size")
		return None
	pixel_values = processor(
		text=None,
		images=my_image,
		return_tensors="pt"
	)["pixel_values"].to(device)
	# Inference only: skip autograd bookkeeping to save memory and time.
	with torch.no_grad():
		embedding = model.get_image_features(pixel_values)
	embedding_as_np = embedding.cpu().detach().numpy()
	return embedding_as_np
# Get all image embeddings:
df["image_embeddings"] = df["image_path"].apply(get_single_image_embedding)
# Rows whose image was skipped (embedding is None) are removed.
df = df.dropna(subset=["image_embeddings"])

In [None]:
# Preview the first rows, now including the image_embeddings column.
df.head()

In [None]:
# Inspect the first remaining row by position: after dropna() the index label 0
# is not guaranteed to exist, so label-based df.loc[0, ...] could raise KeyError.
first_row = df.iloc[0]
print(first_row["text_embeddings"].shape, type(first_row["text_embeddings"]))
print(first_row["image_embeddings"].shape, type(first_row["image_embeddings"]))

### Cosine Similarity Search 

In [19]:
def get_top_N_images(query, data, search_criterion="text", top_K=5, ):
	"""Return the ``top_K`` rows of ``data`` most similar to ``query``.

	Args:
		query: text when ``search_criterion`` is ``"text"`` (case-insensitive);
			otherwise anything accepted by ``get_single_image_embedding``.
		data: DataFrame with ``caption``, ``image_path`` and ``image_embeddings``
			columns. As before, a ``cos_sim`` column is written to it as a
			side effect.
		search_criterion: ``"text"`` for text-to-image search; any other value
			selects image-to-image search.
		top_K: number of results to return.

	Returns:
		DataFrame with ``caption``, ``image_path`` and ``cos_sim`` columns,
		sorted by descending cosine similarity, with a fresh 0..k-1 index.

	Raises:
		ValueError: if no embedding could be computed for the query.
	"""
	import numpy as np # local import: numpy is not imported at file level

	print(type(query))
	if(search_criterion.lower()=="text"):
		query_vect = get_single_text_embedding(query) # Text to image Search
	else:
		query_vect = get_single_image_embedding(query) # Image to image Search
	if query_vect is None:
		raise ValueError("Could not compute an embedding for the query")

	# Run similarity search: stack the per-row embeddings into one matrix and
	# score it in a single call instead of one cosine_similarity call per row.
	if len(data):
		image_matrix = np.vstack(data["image_embeddings"].to_list())
		similarities = cosine_similarity(query_vect, image_matrix)[0]
	else:
		similarities = []
	data["cos_sim"] = similarities
	topK_results = data.sort_values(by='cos_sim', ascending=False).iloc[:top_K]
	relevant_cols = ["caption", "image_path", "cos_sim"]
	return topK_results[relevant_cols].reset_index(drop=True)

### Text-to-Image(s) search

In [None]:
# Free-text query for the text-to-image search (alternative: reuse a dataset caption).
# query_caption = df.iloc[555].caption
query_caption = "aircraft"
print(f"Query: {query_caption}")

In [None]:
# Rank the dataset images against the text query and show the 5 best matches.
top_images = get_top_N_images(query=query_caption, data=df, search_criterion="text", top_K=5)
top_images

In [None]:
# One subplot per returned row.
plot_images_by_side(top_images, topK=len(top_images))

### Image-to-Image search

In [23]:
# Query image for the image-to-image search. IMG_DIR is a machine-specific
# absolute path; exactly one of the Image.open lines below should be active.
# query_image = df.iloc[7].image
IMG_DIR = "/home/farid/WS_Farid/ImACCESS/TEST_IMGs"
############################# LOCAL COMPUTER #############################
# query_image = Image.open(os.path.join(IMG_DIR, "bikini.jpeg"))
query_image = Image.open(os.path.join(IMG_DIR, "6018_107470.jpg"))
# query_image = Image.open(os.path.join(IMG_DIR, "6170_107622.jpg"))
# query_image = Image.open(os.path.join(IMG_DIR, "6655_108104.jpg"))
# query_image = Image.open(os.path.join(IMG_DIR, "6247_107698.jpg"))
############################# LOCAL COMPUTER #############################

############################# URL #############################
# img_url = "https://www.finna.fi/Cover/Show?source=Solr&id=sa-kuva.sa-kuva-165318"
# img_url = "https://www.finna.fi/Cover/Show?source=Solr&id=sa-kuva.sa-kuva-37527"
# img_url = "https://digitalcollections.smu.edu/digital/api/singleitem/image/stn/779/default.jpg"
# img_url = "https://www.finna.fi/Cover/Show?source=Solr&id=sa-kuva.sa-kuva-37129"
# img_url = "https://www.finna.fi/Cover/Show?source=Solr&id=sa-kuva.sa-kuva-37599" # not so good results
# img_url = "https://www.finna.fi/Cover/Show?source=Solr&id=sa-kuva.sa-kuva-56563"
# try:
# 	response = requests.get(img_url)
# 	response.raise_for_status()
# 	resp_content = BytesIO(response.content)
# 	query_image = Image.open(resp_content)
# except Exception as e:
# 	print(e)


In [None]:
# Display the query image inline.
query_image

In [None]:
# Image-to-image search with the default top_K, then plot the matches.
top_images = get_top_N_images(query=query_image, data=df, search_criterion="image")
# NOTE(review): a bare expression that is not the last statement of a notebook
# cell is not rendered, so this line has no visible effect here.
top_images
plot_images_by_side(top_images, topK=len(top_images))

## Approach 2: Local CLIP | Cloned from GitHub

In [26]:
def get_topK_images(query, data, top_K=5, ):
	"""Return the ``top_K`` rows of ``data`` most similar to ``query``.

	The search mode is inferred from the query type: a ``str`` is embedded as
	text, anything else is embedded as an image. Both cases use
	``get_single_text_embedding`` / ``get_single_image_embedding``.

	Args:
		query: query text, or an image accepted by ``get_single_image_embedding``.
		data: DataFrame with ``caption``, ``image_path`` and ``image_embeddings``
			columns. As before, a ``cos_sim`` column is written to it as a
			side effect.
		top_K: number of results to return.

	Returns:
		DataFrame with ``caption``, ``image_path`` and ``cos_sim`` columns,
		sorted by descending cosine similarity, with a fresh 0..k-1 index.

	Raises:
		ValueError: if no embedding could be computed for the query.
	"""
	import numpy as np # local import: numpy is not imported at file level

	print(type(query))
	print(type(query), query)
	if isinstance(query, str):
		query_vect = get_single_text_embedding(query) # Text to image Search
	else:
		query_vect = get_single_image_embedding(query) # Image to image Search
	if query_vect is None:
		raise ValueError("Could not compute an embedding for the query")
	print(data.columns)
	# Run similarity search: stack the per-row embeddings into one matrix and
	# score it in a single call instead of one cosine_similarity call per row.
	if len(data):
		image_matrix = np.vstack(data["image_embeddings"].to_list())
		similarities = cosine_similarity(query_vect, image_matrix)[0]
	else:
		similarities = []
	data["cos_sim"] = similarities
	topK_results = data.sort_values(by='cos_sim', ascending=False).iloc[:top_K]
	relevant_cols = ["caption", "image_path", "cos_sim"]
	return topK_results[relevant_cols].reset_index(drop=True)

In [None]:
# Sanity-check the untouched metadata copy before adding local-CLIP embeddings.
print(df_.shape)
df_.head()


In [28]:
import clip
# Load the ViT-B/32 CLIP model and its image preprocessing transform onto `device`.
model, preprocess = clip.load("ViT-B/32", device=device)

def get_txt_embedding(text):
	"""Embed ``text`` with the local CLIP model and L2-normalize the result.

	Args:
		text: caption or query string.

	Returns:
		NumPy array with the unit-length text embedding (one row per text).
	"""
	# truncate=True cuts over-long captions to the context length instead of
	# raising, so a single long caption cannot abort a whole-column apply().
	tokenized_text = clip.tokenize([text], truncate=True).to(device)
	# Inference only: skip autograd bookkeeping to save memory and time.
	with torch.no_grad():
		text_embedding = model.encode_text(tokenized_text)
		text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
	embedding_as_np = text_embedding.cpu().detach().numpy()
	return embedding_as_np

def get_img_embedding(img_path):
	"""Embed an image with the local CLIP model and L2-normalize the result.

	Args:
		img_path: path to an image file, or an already opened image object
			(accepted for parity with ``get_single_image_embedding``).

	Returns:
		NumPy array with the unit-length image embedding (one row per image).
	"""
	if isinstance(img_path, (str, os.PathLike)):
		# Context manager releases the file handle once preprocessing has
		# read the pixels.
		with Image.open(img_path) as img:
			image = preprocess(img).unsqueeze(0).to(device)
	else:
		image = preprocess(img_path).unsqueeze(0).to(device)
	# Inference only: skip autograd bookkeeping to save memory and time.
	with torch.no_grad():
		image_embedding = model.encode_image(image)
		image_embedding = image_embedding / image_embedding.norm(dim=-1, keepdim=True)
	embedding_as_np = image_embedding.cpu().detach().numpy()
	return embedding_as_np

In [None]:
# Recompute both embedding columns on the untouched copy with the local CLIP model.
df_["text_embeddings"] = df_["caption"].apply(get_txt_embedding)
df_["image_embeddings"] = df_["image_path"].apply(get_img_embedding)
# Defensive: drops rows without an image embedding (get_img_embedding itself never returns None).
df_ = df_.dropna(subset=["image_embeddings"])
df_.head()

#### image-to-image using local clip

In [None]:
# Search against df_, the frame holding the local-CLIP embeddings; the earlier
# call passed df, which holds the Transformers embeddings from Approach 1.
# NOTE(review): get_topK_images still embeds the query with the Transformers
# model — confirm that mixing the two encoders is intended.
top_images = get_topK_images(query=query_image, data=df_)
plot_images_by_side(top_images, topK=len(top_images))

#### text-to-images using local clip

In [None]:
# Text query ranked against the local-CLIP image embeddings stored in df_.
# NOTE(review): the query itself is embedded by get_single_text_embedding
# (Transformers model), not get_txt_embedding — confirm this is intended.
topk_images_ = get_topK_images(query=query_caption, data=df_)
topk_images_

In [None]:
# Plot the text-to-image matches, one subplot per returned row.
plot_images_by_side(topk_images_, topK=len(topk_images_))

---