In [71]:
import pandas as pd
# Load your DataFrame
pew = pd.read_csv('../dataset/pew_dataset/metadata.csv')
pew['imgPath'] = pew['imgPath'].str.replace('imgs', '../dataset/pew_dataset/pew_imgs')
statista = pd.read_csv('../dataset/statista_dataset/metadata.csv')
statista['imgPath'] = statista['imgPath'].str.replace('out/two_col/imgs', '../dataset/statista_dataset/statista_imgs')
columns = ['title','caption','imgPath']

# Filtering the DataFrame to include only the specified columns
pew_df = pew[columns]
statista_df = statista[columns]
combined_df = pd.concat([pew_df, statista_df], ignore_index=True)

# Add a new column 'ID' to the DataFrame at the first position
combined_df.insert(0, 'id', combined_df.reset_index().index + 1)

combined_df.head(100)


Unnamed: 0,id,title,caption,imgPath
0,1,"Foreign-born population in the United States, ...",The foreign-born population residing in the U....,../dataset/pew_dataset/pew_imgs/1.png
1,2,"English proficiency among U.S. immigrants, 198...","Since 1980, the share of immigrants who are pr...",../dataset/pew_dataset/pew_imgs/2.png
2,3,"Languages spoken among U.S. immigrants, 2018","Among the nation’s immigrants, Spanish is by f...",../dataset/pew_dataset/pew_imgs/3.png
3,4,"Hispanic population in the U.S., 2000-2017",There were nearly 60 million Latinos in the Un...,../dataset/pew_dataset/pew_imgs/4.png
4,5,Weekly broadcast audience for top 20 NPR-affil...,The top 20 NPR-affiliated public radio station...,../dataset/pew_dataset/pew_imgs/5.png
...,...,...,...,...
95,96,About four-in-ten Hispanics reported experienc...,"Shortly before the outbreak, about four-in-ten...",../dataset/pew_dataset/pew_imgs/96.png
96,97,Far more Americans favor keeping spending on p...,The survey finds little support for reducing s...,../dataset/pew_dataset/pew_imgs/97.png
97,98,U.S Hispanic population reached nearly 61 mill...,The U.S. Hispanic population reached a record ...,../dataset/pew_dataset/pew_imgs/98.png
98,99,U.S Hispanic population growth has slowed Aver...,Population growth among U.S. Hispanics has slo...,../dataset/pew_dataset/pew_imgs/99.png


In [84]:
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

# Initialize CLIP model and processor
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

def generate_embedding(image, text):
    # Prepare the input for one image and one text
    inputs = processor(text=[text], images=[image], return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        output = model(**inputs)
        image_embedding = output.image_embeds.squeeze(0)  # Remove batch dimension if needed
        text_embedding = output.text_embeds.squeeze(0)    # Remove batch dimension if needed
        return image_embedding, text_embedding
    
image_embeddings = []
text_embeddings = []

# Extract image paths and titles for the first 100 items
for image_path, title in tqdm(zip(combined_df['imgPath'][:100], combined_df['title'][:100]), total=100):
    image = Image.open(image_path)
    img_emb, txt_emb = generate_embedding(image, title)
    image_embeddings.append(img_emb)
    text_embeddings.append(txt_emb)

# Convert list of tensors to a single tensor for further processing if necessary
image_embeddings = torch.stack(image_embeddings)
text_embeddings = torch.stack(text_embeddings)

# Example text query
text_query = "What is the Hispanic population in the United States?"
dummy_image = Image.new("RGB", (224, 224), color=(255, 255, 255))

# Generate embedding for the text query with a dummy image
_, query_embedding = generate_embedding(dummy_image, text_query)
query_embedding = query_embedding.unsqueeze(0)  # Add batch dimension for cosine similarity calculation

# Normalize embeddings
normalized_image_embeddings = image_embeddings / torch.norm(image_embeddings, dim=1, keepdim=True)
normalized_text_embeddings = text_embeddings / torch.norm(text_embeddings, dim=1, keepdim=True)
normalized_query_embedding = query_embedding / torch.norm(query_embedding, dim=1, keepdim=True)

# Calculate cosine similarities between the query embedding and all image embeddings
# Calculate cosine similarities
image_similarities = cosine_similarity(normalized_query_embedding.cpu().numpy(), normalized_image_embeddings.cpu().numpy())
text_similarities = cosine_similarity(normalized_query_embedding.cpu().numpy(), normalized_text_embeddings.cpu().numpy())

# Now, you can use these similarities separately or combine them in some way depending on your application.
# For example, you might want to average them or use a weighted sum to find the most relevant items.
combined_similarities = (image_similarities + text_similarities) / 2

# Find and print the most similar images based on combined similarities
top_indices = np.argsort(combined_similarities[0])[::-1][:3]
for index in top_indices:
    print(f"Image {combined_df['imgPath'].iloc[index]} has a combined similarity score of {combined_similarities[0][index]}")

100%|██████████| 100/100 [00:15<00:00,  6.50it/s]


Image ../dataset/pew_dataset/pew_imgs/99.png has a combined similarity score of 0.6033253073692322
Image ../dataset/pew_dataset/pew_imgs/98.png has a combined similarity score of 0.574662446975708
Image ../dataset/pew_dataset/pew_imgs/71.png has a combined similarity score of 0.574662446975708
