In [1]:
import faiss
import json
import pandas as pd
from PIL import Image
from sentence_transformers import SentenceTransformer
from IPython.display import display
import os
import sys

sys.path.append("../")

from data.coco import get_caption
from services.settings import settings

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Sentence-embedding model shared by every encode_sentences() call below.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# NOTE(review): absolute, user-specific path -- consider deriving it from
# `settings` so the notebook can run on another machine.
output_file = "/usr/prakt/s0070/vlm-based-image-search/outputs/response_dict.json"

# Directory prefix that display_file() combines with an image file name.
coco_path = os.path.join(settings.data_dir, "coco/images/val2017/")

# Read the JSON mapping (file name -> predicted caption); the context manager
# closes the file handle as soon as parsing is done.
with open(output_file) as response_file:
    response = json.load(response_file)
# Parallel lists: captions[i] is the caption stored under files[i].
captions = list(response.values())
files = list(response.keys())



# Methods

In [15]:
def encode_sentences(sentences):
    """Embed ``sentences`` with the notebook-level ``model``.

    Args:
        sentences: Input passed unchanged to ``model.encode``.

    Returns:
        Whatever ``model.encode`` produces for ``sentences``.
    """
    return model.encode(sentences)

def get_filenames(indices):
    """Look up entries of the notebook-level ``files`` list by position.

    Args:
        indices: Iterable of positions into ``files``.

    Returns:
        list: ``files[i]`` for each ``i`` in ``indices``, in the same order.
    """
    names = []
    for position in indices:
        names.append(files[position])
    return names

def create_gt_captions_dict(files):
    """Build a mapping from each file to the result of ``get_caption`` for it.

    Args:
        files: Iterable of file identifiers accepted by ``get_caption``.

    Returns:
        dict: ``{file: get_caption(file)}``, keyed in iteration order of ``files``.
    """
    captions_by_file = {}
    for name in files:
        captions_by_file[name] = get_caption(name)
    return captions_by_file

def display_file(filename):
    """Render one image inline in the notebook.

    Args:
        filename: Image file name, resolved against the notebook-level
            ``coco_path`` directory.
    """
    # os.path.join adds the separator only when it is missing, so this keeps
    # working if ``coco_path`` is ever configured without a trailing slash.
    image_path = os.path.join(coco_path, filename)
    # The context manager releases the underlying file once the image has
    # been rendered, instead of leaving the handle to the garbage collector.
    with Image.open(image_path) as image:
        display(image)
    
def create_index(sentence_embeddings):
    """Build a flat L2 FAISS index holding ``sentence_embeddings``.

    Args:
        sentence_embeddings: Array-like whose ``shape[1]`` is the embedding
            dimension; it is passed to ``index.add`` as-is.

    Returns:
        The populated ``faiss.IndexFlatL2`` instance.
    """
    flat_index = faiss.IndexFlatL2(sentence_embeddings.shape[1])
    flat_index.add(sentence_embeddings)
    # Report the index size so the cell output confirms how many rows were added.
    summary = "Index created with {} sentences".format(flat_index.ntotal)
    print(summary)
    return flat_index

def search_index(index, query, k):
    """Embed ``query`` and search ``index`` with it.

    Args:
        index: Object exposing ``search(embeddings, k)`` that returns a pair.
        query: Single query; it is wrapped in a one-element list before encoding.
        k: Passed through to ``index.search``.

    Returns:
        tuple: The two values returned by ``index.search``, in the same order
        (the notebook names them ``D`` and ``I``).
    """
    embedded_query = encode_sentences([query])
    distances, neighbours = index.search(embedded_query, k)
    return distances, neighbours

# Search

In [7]:
# Embed the captions loaded from the response file and index them.
# Embeddings are added in `captions` order, which is parallel to `files`,
# so a search hit at position i can be mapped back through get_filenames().
caption_embeddings = encode_sentences(captions)
index = create_index(caption_embeddings)

Index created with 500 sentences


In [16]:
# Ground-truth captions for the same files, keyed in `files` order so that
# position i in `gt_index` corresponds to files[i], exactly as in `index`.
# NOTE(review): presumably get_caption returns one caption string per file --
# confirm against data.coco.
gt_dict = create_gt_captions_dict(files)
# Renamed from the typo `git_embeddings` to match `gt_dict` / `gt_index`.
gt_embeddings = encode_sentences(list(gt_dict.values()))
gt_index = create_index(gt_embeddings)

Index created with 500 sentences


In [29]:
query = "Bathroom environment"
# One shared cut-off so both result lists have the same length and the
# overlap computed in the evaluation cell compares like with like.
TOP_K = 10
# D/I: search results against the predicted-caption index.
D, I = search_index(index, query, k=TOP_K)
# gt_D/gt_I: the same query against the ground-truth-caption index.
gt_D, gt_I = search_index(gt_index, query, k=TOP_K)

# Search with predicted captions

In [None]:
# I[0] holds the hits for the single query searched against `index`;
# map those positions back to file names.
retrieved_files = get_filenames(I[0])
for file in retrieved_files:
    # Show the image followed by both captions for a side-by-side check.
    display_file(file)
    print("Predicted caption: ", response[file])
    print("Actual caption: ", get_caption(file))
    # Blank lines to visually separate consecutive results.
    print()
    print("\n")

# Search with ground truth captions

In [None]:
# gt_I[0] holds the hits for the single query searched against `gt_index`;
# map those positions back to file names.
gt_retrieved_files = get_filenames(gt_I[0])
for file in gt_retrieved_files:
    # Show the image followed by both captions for a side-by-side check.
    display_file(file)
    print("Predicted caption: ", response[file])
    print("Actual caption: ", get_caption(file))
    # Blank lines to visually separate consecutive results.
    print()
    print("\n")

# Eval

In [38]:
# Number of files returned by both the predicted-caption search and the
# ground-truth-caption search.
len(set(retrieved_files).intersection(gt_retrieved_files))

6