In [None]:
import os
import pickle
import tarfile
import textwrap
from collections import namedtuple

import gdown
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from matplotlib import gridspec
from PIL import Image
from tqdm import tqdm

# Utility Classes

In [None]:
# Encoding used to decode the byte-string text fields of a Product.
STRINGS_ENCODING = 'ISO-8859-1'

class Product:
    """A catalogue product with its raw fields and model-derived vectors.

    Text fields (``name``, ``caption``, ``category``, ``subcategory``) are kept
    as raw bytes and decoded with ``STRINGS_ENCODING`` on demand.
    ``image`` is stored as given (in this notebook presumably the HxWx3 uint8
    array read from the dataset — confirm if reused elsewhere).

    Equality and hashing use the ``(p_id, pose)`` pair; ordering (``<``) uses
    ``p_id`` only. Comparisons with non-Product objects return
    ``NotImplemented`` so Python applies its default fallback.
    """

    def __init__(self, p_id, name, caption, image, category, subcategory, pose="id_gridfs_1"):
        self.p_id = p_id
        self.pose = pose
        self.name = name
        self.caption = caption
        self.image = image
        self.category = category
        self.subcategory = subcategory
        # Vectors attached after construction through the properties below.
        self._image_features = None
        self._embedding = None
        self._caption_embedding = None

    @property
    def image_features(self):
        """Image feature vector, or ``None`` until assigned."""
        return self._image_features

    @image_features.setter
    def image_features(self, value):
        self._image_features = value

    @property
    def embedding(self):
        """Image embedding vector, or ``None`` until assigned."""
        return self._embedding

    @embedding.setter
    def embedding(self, value):
        self._embedding = value

    @property
    def caption_embedding(self):
        """Caption embedding vector, or ``None`` until assigned."""
        return self._caption_embedding

    @caption_embedding.setter
    def caption_embedding(self, value):
        self._caption_embedding = value

    def decoded_caption(self):
        """Return the caption decoded to ``str``, with a known mojibake sequence stripped."""
        return self.caption.decode(STRINGS_ENCODING).replace('ÃÂÃÂÃÂÃÂ©', '')

    def __lt__(self, other):
        # Return NotImplemented for foreign types instead of raising AttributeError.
        if not isinstance(other, Product):
            return NotImplemented
        return self.p_id < other.p_id

    def __eq__(self, other):
        # Return NotImplemented for foreign types (e.g. ``product == None``) so Python
        # falls back to its default comparison instead of raising AttributeError.
        if not isinstance(other, Product):
            return NotImplemented
        return self.p_id == other.p_id and self.pose == other.pose

    def __hash__(self):
        # Must agree with __eq__: hash exactly the fields compared there.
        return hash((self.p_id, self.pose))

    def __str__(self):
        return f"product_id: {self.p_id}\nname: {self.name.decode(STRINGS_ENCODING)}\ncaption: {self.caption.decode(STRINGS_ENCODING)}\ncategory: {self.category.decode(STRINGS_ENCODING)} \nsubcategory: {self.subcategory.decode(STRINGS_ENCODING)}"


# Load Dataset

In [None]:
# Download and unpack the validation dataset unless the folder already exists locally.
dataset_folder = "validation_dataset_w_embeddings"
dataset_archive = f"{dataset_folder}.tar.gz"
if not os.path.exists(dataset_folder):
    gdown.download("https://drive.google.com/uc?id=1KcDFzv4JjuEQyyIvC7BkHgxN7LzbInPl", dataset_archive, quiet=False)
    # Extract with the stdlib rather than a `!tar` shell call so the cell does not
    # depend on a tar binary; mode "r" auto-detects the gzip compression.
    with tarfile.open(dataset_archive) as archive:
        if hasattr(tarfile, "data_filter"):
            # Python builds with extraction filters: reject absolute paths and
            # members that would escape the extraction directory.
            archive.extractall(filter="data")
        else:
            archive.extractall()
    # The loading cell reads from `dataset_folder`; fail here with a clear message if it is missing.
    assert os.path.isdir(dataset_folder), f"{dataset_archive} did not unpack into {dataset_folder}/"

In [None]:
# load dataset
# The TensorSpec tuple lists the stored fields in order (see the trailing comments).
validation_dataset = tf.data.experimental.load(
    dataset_folder,
    (
        tf.TensorSpec(shape=(), dtype=tf.int32),             # id
        tf.TensorSpec(shape=(), dtype=tf.string),            # name
        tf.TensorSpec(shape=(), dtype=tf.string),            # category
        tf.TensorSpec(shape=(), dtype=tf.string),            # subcategory
        tf.TensorSpec(shape=(), dtype=tf.string),            # caption
        tf.TensorSpec(shape=(256, 256, 3), dtype=tf.uint8),  # image
        tf.TensorSpec(shape=(131072,), dtype=tf.float32),    # image features
        tf.TensorSpec(shape=(768,), dtype=tf.float32),       # image_embedding
        tf.TensorSpec(shape=(768,), dtype=tf.float32),       # caption embedding
    ),
    compression="GZIP",
)

# Build a lookup from product id to Product, attaching the stored vectors to each one.
products_dict = {}
for record in tqdm(tfds.as_numpy(validation_dataset)):
    (p_id, name, category, subcategory, caption, image,
     image_features, image_embedding, caption_embedding) = record
    product = Product(p_id=p_id, name=name, caption=caption, image=image,
                      category=category, subcategory=subcategory)
    product.image_features = image_features
    product.embedding = image_embedding
    product.caption_embedding = caption_embedding
    products_dict[p_id] = product

In [None]:
# Load candidate sets: dicts mapping query_id -> [(candidate_id, label), ...].
# The relevant document has label=1, the other ones 0
txt2img_file = "txt2img.pkl"
img2txt_file = "img2txt.pkl"
# download files (skipped when a local copy already exists)
if not os.path.exists(txt2img_file):
    gdown.download("https://drive.google.com/uc?id=101zaYhWws6CkWePEdBg5WVb9C8HXflC5", txt2img_file, quiet=False)
if not os.path.exists(img2txt_file):
    gdown.download("https://drive.google.com/uc?id=1IV61fyjUt7Pgve2t4ZuqS7XpAPw9vDLO", img2txt_file, quiet=False)

# load
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only run this on files obtained from a trusted source.
# `with` closes each file handle once it has been read.
with open(txt2img_file, "rb") as fh:
    txt2img_candidate_sets = pickle.load(fh)
with open(img2txt_file, "rb") as fh:
    img2txt_candidate_sets = pickle.load(fh)

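# Rank@K Evaluation

Rank@K is the fraction of queries whose relevant candidate (label 1) appears among the K candidates closest to the query, measured by Euclidean distance between embeddings.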

In [None]:
def euclidean_distance(x, y, axis=None):
    """Return the Euclidean (L2) norm of ``x - y``.

    With ``axis=None`` the result is the 2-norm of the flattened difference.
    With ``axis=1``, a 1-D ``x`` broadcasts against a 2-D ``y``, giving one
    distance per row of ``y``.
    """
    difference = x - y
    return np.linalg.norm(difference, axis=axis)

def compute_rank_at_k(sorted_documents, Ks):
    """Count, for each cutoff k in ``Ks``, how many queries rank their relevant document within the top k.

    Args:
        sorted_documents: dict mapping query_id -> ranked sequence of
            (document_id, label) pairs, best first. A truthy label marks the
            relevant document; only the first such document is considered.
        Ks: iterable of integer cutoffs.

    Returns:
        dict mapping each k to the number of queries whose first relevant
        document sits at a 1-based position <= k. A query with no relevant
        document is not counted for any k.
    """
    hits = dict.fromkeys(Ks, 0)
    for ranking in sorted_documents.values():
        # 1-based position of the first relevant document, or None if there is none.
        first_hit = next((pos for pos, doc in enumerate(ranking, start=1) if doc[1]), None)
        if first_hit is None:
            continue
        for k in hits:
            if first_hit <= k:
                hits[k] += 1
    return hits

Ks = [1, 5, 10]


def _rank_candidates(candidate_sets, products, query_attr, document_attr):
    """Order each query's candidates by increasing Euclidean distance to the query.

    Args:
        candidate_sets: dict mapping query_id -> [(candidate_id, label), ...].
        products: dict mapping product id -> Product.
        query_attr: Product attribute holding the query-side embedding.
        document_attr: Product attribute holding the candidate-side embedding.

    Returns:
        dict mapping query_id -> the same (candidate_id, label) pairs, closest first.
    """
    sorted_documents = {}
    for query_id, candidates in candidate_sets.items():
        if not candidates:
            # Nothing to rank. Stacking zero embeddings would give a 1-D array,
            # which the axis=1 distance below cannot handle.
            sorted_documents[query_id] = []
            continue
        query_emb = np.asarray(getattr(products[query_id], query_attr))
        document_embs = np.array([getattr(products[candidate_id], document_attr) for candidate_id, _ in candidates])
        distances = euclidean_distance(query_emb, document_embs, axis=1)
        sorted_documents[query_id] = [candidates[idx] for idx in np.argsort(distances)]
    return sorted_documents


def _report_rank_at_k(sorted_documents, ks):
    """Print Rank@k for each k in ``ks``; return the raw hit counts from compute_rank_at_k."""
    hits = compute_rank_at_k(sorted_documents, ks)
    n_queries = len(sorted_documents)
    for k in ks:
        # Exact fraction of queries. Only the empty case is guarded, because an
        # epsilon added to the denominator would bias every score downward.
        rate = hits[k] / n_queries if n_queries else 0.0
        print(f"Rank @ {k}: {rate}")
    return hits


print("=== Text-to-Image Retrieval ===")
# Query: caption embedding; candidates: image embeddings.
txt2img_sorted_documents = _rank_candidates(txt2img_candidate_sets, products_dict, "caption_embedding", "embedding")
rank_at_k = _report_rank_at_k(txt2img_sorted_documents, Ks)

print("=== Image-to-Text Retrieval ===")
# Query: image embedding; candidates: caption embeddings.
img2txt_sorted_documents = _rank_candidates(img2txt_candidate_sets, products_dict, "embedding", "caption_embedding")
rank_at_k = _report_rank_at_k(img2txt_sorted_documents, Ks)

# Query Visualization

In [None]:
def visualize_product_image(p, ax, color_id):
    """Draw a product's image on ``ax``, framed in green when ``p.p_id == color_id``.

    Args:
        p: object exposing ``image`` (an RGB array; the dataset declares
            (256, 256, 3) uint8) and ``p_id``.
        ax: matplotlib Axes to draw on.
        color_id: id to highlight; ``None`` never matches a real id.
    """
    # imshow renders an (H, W, 3) uint8 array as RGB directly. Going through
    # PIL's Image.fromarray(..., 'RGB') is unnecessary, and its `mode`
    # argument is deprecated in recent Pillow releases.
    ax.imshow(p.image)
    if p.p_id == color_id:
        # Turn the axis frame on (it may have been switched off) but hide the
        # ticks, so only the coloured spines are visible.
        ax.axis('on')
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_color('green')
            spine.set_linewidth(3.5)
    
def visualize_product_caption(p, ax, color_id):
    """Write a product's decoded caption on ``ax`` inside a rounded text box.

    The caption is wrapped at 32 characters and anchored near the top-left of
    the axes. The box is semi-transparent green when ``p.p_id == color_id``,
    otherwise white.
    """
    wrapped = textwrap.fill(p.decoded_caption(), 32)
    highlighted = p.p_id == color_id
    box_style = {"boxstyle": "round", "facecolor": "green" if highlighted else "white"}
    if highlighted:
        box_style["alpha"] = 0.5
    ax.text(0.1, 0.9, wrapped, transform=ax.transAxes, fontsize=20,
            verticalalignment='top', bbox=box_style)


## Visualize Retrieval Examples

In [None]:
NUM_QUERIES = 5       # number of queries to plot (one figure per query)
N_RESULTS = 5         # top-ranked documents shown next to each query
TASK = "txt2img"      # change with "img2txt"
SAVE_RESULT = False   # when True, also save each figure as f"{TASK}_{i}.svg"

if TASK == "txt2img":
    documents_dict = txt2img_sorted_documents
    query_visualization_function = visualize_product_caption
    document_visualization_function = visualize_product_image
elif TASK == "img2txt":
    documents_dict = img2txt_sorted_documents
    query_visualization_function = visualize_product_image
    document_visualization_function = visualize_product_caption
else:
    # Fail here with a clear message instead of a NameError on `documents_dict` below.
    raise ValueError(f"Unknown TASK {TASK!r}; expected 'txt2img' or 'img2txt'")


for i, (query_id, documents) in zip(range(NUM_QUERIES), documents_dict.items()):
    # squeeze=False keeps `axes` 2-D, so iterating axes[0] works even for a single column.
    fig, axes = plt.subplots(ncols=N_RESULTS+1, nrows=1, constrained_layout=True, sharex=True, sharey=True, figsize=(32,32), squeeze=False)
    product = products_dict[query_id]
    # Only the first N_RESULTS documents are drawn, so only those are looked up.
    retrieved = [products_dict[d] for d, _ in documents[:N_RESULTS]]
    for col, ax in enumerate(axes[0]):
        ax.axis('off')
        ax.set_aspect('equal')
        if col == 0:
            query_visualization_function(product, ax, None)
        elif col - 1 < len(retrieved):
            # Highlight the document that shares the query's id (presumably the relevant match).
            document_visualization_function(retrieved[col-1], ax, product.p_id)
        # Otherwise there are fewer candidates than N_RESULTS: leave the panel empty.

    if SAVE_RESULT:
        fig.savefig(f"{TASK}_{i}.svg", format="svg", bbox_inches = 'tight')
    # Display this figure now and release it, so memory does not grow with NUM_QUERIES.
    plt.show()
    plt.close(fig)
