In [None]:
import os
import random as rnd
import tarfile
import textwrap

import gdown
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from matplotlib.offsetbox import AnnotationBbox, OffsetImage, TextArea
from sklearn.manifold import TSNE
from tqdm import tqdm

In [None]:
# Encoding used to decode the byte-string text fields of a Product.
STRINGS_ENCODING = 'ISO-8859-1'

class Product:
    """A catalogue item with its metadata, image and optional pre-computed vectors.

    Text fields (name, caption, category, subcategory) are decoded with
    STRINGS_ENCODING when they are ``bytes``; already-decoded ``str`` values are
    passed through unchanged. Equality and hashing use ``(p_id, pose)``, while
    ordering (``<``) uses ``p_id`` only.
    """

    def __init__(self, p_id, name, caption, image, category, subcategory, pose="id_gridfs_1"):
        self.p_id = p_id
        self.pose = pose
        self.name = name
        self.caption = caption
        self.image = image
        self.category = category
        self.subcategory = subcategory
        # Optional vectors, attached after construction via the properties below.
        self._image_features = None
        self._embedding = None
        self._caption_embedding = None

    @property
    def image_features(self):
        return self._image_features

    @image_features.setter
    def image_features(self, value):
        self._image_features = value

    @property
    def embedding(self):
        return self._embedding

    @embedding.setter
    def embedding(self, value):
        self._embedding = value

    @property
    def caption_embedding(self):
        return self._caption_embedding

    @caption_embedding.setter
    def caption_embedding(self, value):
        self._caption_embedding = value

    @staticmethod
    def _decode(value):
        """Decode ``bytes`` with STRINGS_ENCODING; return anything else unchanged."""
        return value.decode(STRINGS_ENCODING) if isinstance(value, bytes) else value

    def decoded_caption(self):
        """Return the caption as text with the 'ÃÂÃÂÃÂÃÂ©' sequence removed.

        NOTE(review): that sequence looks like a repeatedly mis-encoded character
        in the source data; confirm against the dataset.
        """
        return self._decode(self.caption).replace('ÃÂÃÂÃÂÃÂ©', '')

    def __lt__(self, other):
        """Order by product id only (pose is ignored)."""
        if not isinstance(other, Product):
            return NotImplemented
        return self.p_id < other.p_id

    def __eq__(self, other):
        """Two products are equal when both id and pose match."""
        # Returning NotImplemented (instead of touching other.p_id) lets
        # comparisons with non-Products, e.g. `product == None`, yield False.
        if not isinstance(other, Product):
            return NotImplemented
        return self.p_id == other.p_id and self.pose == other.pose

    def __hash__(self):
        # Must stay consistent with __eq__: hash the same (p_id, pose) pair.
        return hash((self.p_id, self.pose))

    def __str__(self):
        return (f"product_id: {self.p_id}\n"
                f"name: {self._decode(self.name)}\n"
                f"caption: {self._decode(self.caption)}\n"
                f"category: {self._decode(self.category)} \n"
                f"subcategory: {self._decode(self.subcategory)}")


In [None]:
dataset_folder = "validation_dataset_w_embeddings"
if not os.path.exists(dataset_folder):
    # Download and unpack the archive only once; later runs reuse the extracted folder.
    archive_path = f"{dataset_folder}.tar.gz"
    gdown.download("https://drive.google.com/uc?id=1KcDFzv4JjuEQyyIvC7BkHgxN7LzbInPl", archive_path, quiet=False)
    # Pure-Python extraction (instead of the `!tar` shell magic) keeps the cell
    # valid Python and portable; extracts into the working directory like `tar -xf`.
    with tarfile.open(archive_path, "r:gz") as archive:
        # Use the safe "data" extraction filter where this Python supports it.
        extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
        archive.extractall(**extract_kwargs)

# Element layout of the saved dataset; order and shapes must match how it was saved.
element_spec = (tf.TensorSpec(shape=(), dtype=tf.int32),              # 0: id
                tf.TensorSpec(shape=(), dtype=tf.string),             # 1: name
                tf.TensorSpec(shape=(), dtype=tf.string),             # 2: category
                tf.TensorSpec(shape=(), dtype=tf.string),             # 3: subcategory
                tf.TensorSpec(shape=(), dtype=tf.string),             # 4: caption
                tf.TensorSpec(shape=(256, 256, 3), dtype=tf.uint8),   # 5: image
                tf.TensorSpec(shape=(131072,), dtype=tf.float32),     # 6: image features
                tf.TensorSpec(shape=(768,), dtype=tf.float32),        # 7: image embedding
                tf.TensorSpec(shape=(768,), dtype=tf.float32))        # 8: caption embedding
validation_dataset = tf.data.experimental.load(dataset_folder, element_spec, compression="GZIP")

products_list = []  # one Product per dataset element
for (p_id, name, category, subcategory, caption, image,
     _image_features, image_embedding, caption_embedding) in tqdm(tfds.as_numpy(validation_dataset)):
    product = Product(p_id=p_id, name=name, caption=caption, image=image,
                      category=category, subcategory=subcategory)
    # Image features (element 6) are intentionally not attached; presumably to
    # save memory (131072 floats per item) — confirm before enabling.
    product.embedding = image_embedding
    product.caption_embedding = caption_embedding
    products_list.append(product)

In [None]:
# Products whose caption embeddings are projected together with all image embeddings.
captions_ids = [91765, 108990, 88120, 1582273]
caption_id_set = set(captions_ids)
# Collected in products_list order; the plotting cell labels captions in that same order.
caption_embs = [p.caption_embedding for p in products_list if p.p_id in caption_id_set]

# scikit-learn renamed TSNE's `n_iter` to `max_iter` (the old name was later removed),
# so pass the iteration budget under whichever name this installation accepts.
iter_param = "max_iter" if "max_iter" in TSNE().get_params() else "n_iter"
tsne_model = TSNE(n_components=2, perplexity=65, early_exaggeration=12.0, random_state=42,
                  learning_rate=200, init='pca', **{iter_param: 2000})

# Rows [0, len(products_list)) are image embeddings; any trailing rows are caption
# embeddings. vstack also copes with an empty caption list.
tsne = tsne_model.fit_transform(np.vstack([p.embedding for p in products_list] + caption_embs))

In [None]:

def textscatter(x, y, queries, box_alignments, ax=None, text_size=15, wrap_width=42, marker_size=120):
    """Mark each (x, y) point with a red dot and an arrow-linked, word-wrapped text label.

    Parameters
    ----------
    x, y : sequences of float
        Data coordinates of the points.
    queries : sequence of str
        Label text per point, wrapped to ``wrap_width`` characters per line.
    box_alignments : sequence of (float, float)
        Per-label ``box_alignment`` passed to AnnotationBbox, i.e. where the text
        box is placed relative to its point.
    ax : matplotlib Axes, optional
        Target axes; defaults to the current axes.
    text_size : float
        Font size of the labels.
    wrap_width : int
        Maximum characters per wrapped label line.
    marker_size : float
        Size of the red point marker.

    Returns
    -------
    list
        The AnnotationBbox artists added to ``ax``. Only as many labels are
        drawn as the shortest of x, y, queries and box_alignments (zip truncation).
    """
    if ax is None:
        ax = plt.gca()

    artists = []
    for x0, y0, query, box_align in zip(x, y, queries, box_alignments):
        ax.plot(x0, y0, ".r", markersize=marker_size)
        # `minimumdescent` is no longer passed: it was deprecated in Matplotlib 3.3
        # and later removed (passing it raises TypeError). NOTE(review): modern
        # TextArea layout should match the old minimumdescent=False — confirm visually.
        offsetbox = TextArea(textwrap.fill(query, width=wrap_width), textprops={"size": text_size})
        ab = AnnotationBbox(offsetbox, (x0, y0),
                            xycoords='data',
                            boxcoords=None,
                            box_alignment=box_align,
                            arrowprops=dict(arrowstyle="->"))
        # Negative z-order: labels are drawn before (beneath) default-z artists.
        ab.set_zorder(-1)
        artists.append(ax.add_artist(ab))

    ax.update_datalim(np.column_stack([x, y]))
    ax.autoscale()
    return artists


def imscatter(x, y, images, ax=None, zoom=1):
    """Place each image at its (x, y) data coordinate on the axes.

    Images are drawn frameless at z-order -2 (beneath the caption labels), and
    the axes data limits are expanded to include every (x, y) point.

    Returns the list of AnnotationBbox artists added, one per (x, y, image)
    triple; ``zip`` stops at the shortest of the three inputs.
    """
    target_ax = plt.gca() if ax is None else ax
    added = []
    for px, py, picture in zip(x, y, images):
        thumbnail = OffsetImage(picture, zoom=zoom)
        box = AnnotationBbox(thumbnail, (px, py), xycoords='data', frameon=False)
        box.set_zorder(-2)
        added.append(target_ax.add_artist(box))
    target_ax.update_datalim(np.column_stack([x, y]))
    target_ax.autoscale()
    return added

ROUND_COORDINATES = True  # snap image positions to integer t-SNE coordinates
fig, ax = plt.subplots(figsize=(160, 100))
ax.axis('off')

# The first num_images t-SNE rows belong to product images; the rest are the caption queries.
num_images = len(products_list)
image_xy = tsne[:num_images]
caption_xy = tsne[num_images:]
if ROUND_COORDINATES:
    # New array (tsne itself is not modified); np.round rounds half to even like round().
    image_xy = np.round(image_xy)

imscatter(image_xy[:, 0], image_xy[:, 1], [p.image for p in products_list], ax, zoom=0.3)

# Same selection and order as the caption embeddings passed to t-SNE.
caption_texts = [p.decoded_caption() for p in products_list if p.p_id in captions_ids]
# One box alignment per caption; presumably hand-tuned for this layout.
textscatter(caption_xy[:, 0], caption_xy[:, 1], caption_texts,
            [(1.5, 1.5), (-0.3, -0.3), (1.5, 5.5), (0.5, 4)], ax, text_size=100)
plt.show()