In [41]:
# Importing necessary modules

import csv
import json
import os
import time

import clip
import requests
import torch
from PIL import Image
from transformers import (
    AutoModel,
    AutoProcessor,
    AutoTokenizer,
    CLIPModel,
    CLIPProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPVisionModel,
    CLIPVisionModelWithProjection,
)

In [42]:
# Prefer the GPU when one is present.
# NOTE(review): `device` is not applied to the models loaded in this cell.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

# Text tower only: `pooler_output` is the pooled EOS hidden state (512-d for this checkpoint),
# WITHOUT CLIP's text projection applied.
text_model = CLIPTextModel.from_pretrained(CLIP_CHECKPOINT)

# Vision tower followed by CLIP's visual projection: `image_embeds` are 512-d.
# (The bare CLIPVisionModel would instead return the 768-d un-projected pooled output.)
vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_CHECKPOINT)

tokenizer = AutoTokenizer.from_pretrained(CLIP_CHECKPOINT)  # text -> input_ids / attention_mask
processor = AutoProcessor.from_pretrained(CLIP_CHECKPOINT)  # images -> pixel_values (CLIPProcessor also handles text)

In [50]:
# Load every .jpg image in image_dir, in sorted filename order.
# - Sorting makes the order deterministic across runs/machines. images_ids[i] is later
#   paired with row i of the image embedding matrix, so the order must be stable.
# - PIL's Image.open is lazy and keeps the file open until the pixels are read. Converting
#   inside a `with` block decodes the image and closes the handle right away, so large
#   directories don't exhaust the open-file limit.
# - os.path.splitext keeps ids containing dots intact ("a.b.jpg" -> "a.b").
image_dir = './mm_retrieval_project/images'
images = []
images_ids = []
for image_name in sorted(os.listdir(image_dir)):
    image_id, extension = os.path.splitext(image_name)
    if extension != '.jpg':
        continue
    with Image.open(os.path.join(image_dir, image_name)) as image:
        # convert() returns a fully loaded copy that no longer references the file.
        # (The CLIP image processor converts to RGB anyway.)
        images.append(image.convert('RGB'))
    images_ids.append(image_id)

In [44]:
documents_file = './mm_retrieval_project/document.csv'
query_file = './mm_retrieval_project/query.csv'


def read_csv_rows(path):
    """Return every row of the CSV file at `path` (header row included) as a list of string lists."""
    # newline='' is required by the csv module so quoted fields with embedded newlines parse correctly.
    with open(path, 'r', newline='', encoding='utf-8') as f:
        return list(csv.reader(f))


def rows_to_dict(rows):
    """Map column 0 -> column 1 for every row after the header; later duplicate keys win."""
    return {row[0]: row[1] for row in rows[1:]}


# `documents` / `queries` keep the header as row 0; later cells rely on that offset.
documents = read_csv_rows(documents_file)
documents_dict = rows_to_dict(documents)

queries = read_csv_rows(query_file)
queries_dict = rows_to_dict(queries)


In [72]:
# Getting text representations for documents and queries.
# CLIPTextModel.pooler_output is the un-projected EOS hidden state. It is NOT in the joint
# image-text space of CLIPVisionModelWithProjection.image_embeds, so query -> image scores
# built from it are not meaningful. CLIPTextModelWithProjection applies CLIP's text
# projection and returns `text_embeds`, which are directly comparable with image embeddings
# (and remain a standard choice for text -> text retrieval).
text_projection_model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32").to(device)

docs = [doc[1] for doc in documents[1:]]  # skip the CSV header row


def get_text_representations(texts, batch_size=64, normalize=True):
    """Embed text(s) into CLIP's joint image-text space.

    Args:
        texts: a single string or a sequence of strings. CLIP's text context window is
            77 tokens; longer inputs are truncated.
        batch_size: number of texts per forward pass (bounds peak memory on large corpora).
        normalize: L2-normalize each embedding so dot products are cosine similarities,
            the similarity CLIP is trained with.

    Returns:
        CPU float tensor of shape (num_texts, projection_dim); a single string gives 1 row.
    """
    texts = [texts] if isinstance(texts, str) else list(texts)
    chunks = []
    with torch.no_grad():  # inference only: don't build an autograd graph
        for start in range(0, len(texts), batch_size):
            inputs = tokenizer(
                texts[start:start + batch_size], padding=True, truncation=True, return_tensors="pt"
            ).to(text_projection_model.device)
            embeds = text_projection_model(**inputs).text_embeds
            if normalize:
                embeds = torch.nn.functional.normalize(embeds, dim=-1)
            chunks.append(embeds.cpu())
    if not chunks:
        return torch.empty(0, text_projection_model.config.projection_dim)
    return torch.cat(chunks)


# Creating docs_representation var to store embeddings: (num_docs, projection_dim)
docs_representation = get_text_representations(docs)

# Score all queries against all documents in one matrix product and keep each query's best doc.
predictions = []
if queries_dict:
    query_ids = list(queries_dict.keys())
    query_representation = get_text_representations(list(queries_dict.values()))
    document_scores = query_representation @ docs_representation.T  # (num_queries, num_docs)
    # argmax returns the first maximal index on ties, same as list.index(max(...)).
    best_doc_indices = document_scores.argmax(dim=1).tolist()
    # docs skips the header row, so docs[i] corresponds to documents[i + 1].
    predictions = [
        [query_id, documents[doc_idx + 1][0]]
        for query_id, doc_idx in zip(query_ids, best_doc_indices)
    ]

In [71]:
# Getting representations for images, then the top-k images for every query.
top_k_images = 5  # number of images retrieved per query


def get_image_representations(images, batch_size=32, normalize=True):
    """Embed PIL images into CLIP's joint image-text space.

    Args:
        images: list of PIL images.
        batch_size: images per forward pass (bounds peak memory for large image sets).
        normalize: L2-normalize each embedding so dot products are cosine similarities.

    Returns:
        CPU float tensor of shape (len(images), projection_dim); row i belongs to images[i].
    """
    chunks = []
    with torch.no_grad():  # inference only: don't build an autograd graph
        for start in range(0, len(images), batch_size):
            inputs = processor(
                images=images[start:start + batch_size], return_tensors="pt"
            ).to(vision_model.device)
            embeds = vision_model(**inputs).image_embeds
            if normalize:
                embeds = torch.nn.functional.normalize(embeds, dim=-1)
            chunks.append(embeds.cpu())
    if not chunks:
        return torch.empty(0, vision_model.config.projection_dim)
    return torch.cat(chunks)


image_representation = get_image_representations(images)  # (num_images, projection_dim)

# Rank every image for every query with one matrix product.
# NOTE: the text vectors must be CLIP's projected text embeddings (text_embeds) to be
# comparable with image_embeds.
query_image_pairs = []
if queries_dict:
    k = min(top_k_images, len(images_ids))  # topk errors if k exceeds the number of images
    with torch.no_grad():
        query_representation = get_text_representations(list(queries_dict.values()))
        image_scores = query_representation @ image_representation.T  # (num_queries, num_images)
        top_k_indices = image_scores.topk(k, dim=1).indices.tolist()  # best score first
    for query_id, indices in zip(queries_dict, top_k_indices):
        query_image_pairs.append([query_id, [images_ids[idx] for idx in indices]])

In [None]:
# NOTE(review): this cell repeats the previous cell's query -> top-k image ranking and
# overwrites query_image_pairs with an equivalent result. It is kept so Run-All behaves as
# before; consider deleting one of the two cells.
SHOW_RETRIEVED_IMAGES = False  # set True to display each query's retrieved images

query_image_pairs = []
for query_id, query in queries_dict.items():
    with torch.no_grad():  # inference only
        query_representation = get_text_representations(query)  # (1, dim)
        # (num_images, dim) @ (dim, 1) -> (num_images, 1) -> (num_images,)
        image_scores = (image_representation @ query_representation.T).squeeze(1)
    k = min(top_k_images, image_scores.numel())
    top_k_indices = image_scores.topk(k).indices.tolist()  # best score first
    top_k_similar_images = [images_ids[idx] for idx in top_k_indices]

    query_image_pairs.append([query_id, top_k_similar_images])

    if SHOW_RETRIEVED_IMAGES:
        print(f'{query_id} -- {query}')
        for image_id in top_k_similar_images:
            # display() renders immediately, so the file can be closed right after.
            with Image.open(f'{image_dir}/{image_id}.jpg') as image:
                display(image)

In [73]:
# Append each query's top-k image ids to its [qid, doc_id] prediction row.
# Rows are matched by query id, not by list position. Each row is rebuilt from its first
# two columns only, so re-running this cell does not append the image ids a second time.
image_ids_by_query = dict(query_image_pairs)
predictions = [row[:2] + list(image_ids_by_query[row[0]]) for row in predictions]

In [74]:
# Write one row per query: qid, best doc_id, then the retrieved image ids.
# NOTE(review): the header names 3 columns, but each row carries 2 + top_k_images values.
# Confirm the expected submission format (one column per image id vs. one joined field).
# newline='' is required by the csv module; otherwise Windows gets blank lines between rows.
with open('./mm_retrieval_project/clip-predictions.csv', 'w', newline='', encoding='utf-8') as f:
    writer = csv.writer(f)
    writer.writerow(['qid', 'doc_id', 'image_ids'])
    writer.writerows(predictions)

In [None]:
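# Reference notes from interactive exploration:
# keys of a CLIPProcessor(text=..., images=...) output:
# dict_keys(['input_ids', 'attention_mask', 'pixel_values'])
# keys of a CLIPModel forward output:
# dict_keys(['logits_per_image', 'logits_per_text', 'text_embeds', 'image_embeds', 'text_model_output', 'vision_model_output'])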