Takeaway: perfect long_form_product_title_text <> product_image does not mean it can be directly used for query <> product

In [None]:
from tqdm import tqdm

In [31]:
from sentence_transformers import SentenceTransformer, util
from PIL import Image, ImageFile
import requests
import torch

# We use the original clip-ViT-B-32 for encoding images
img_model = SentenceTransformer('clip-ViT-B-32')

# Our text embedding model is aligned to the img_model and maps 50+
# languages to the same vector space
text_model = SentenceTransformer('sentence-transformers/clip-ViT-B-32-multilingual-v1')


# Now we load and encode the images
def load_image(url_or_path):
    if url_or_path.startswith("http://") or url_or_path.startswith("https://"):
        return Image.open(requests.get(url_or_path, stream=True).raw)
    else:
        return Image.open(url_or_path)


ftfy or spacy is not installed using BERT BasicTokenizer instead of ftfy.


In [32]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [33]:
img_model.cuda()
text_model.cuda()
tmp = 2

In [35]:
import pandas as pd 
import dvc.api 

df_clipmore_test = next(pd.read_json(dvc.api.get_url(
    "data/wish_clipmore/Wish_Clipmore_Tahoe_Train_Dedup.json",
    repo='git@github.com:ContextLogic/multitask-llm-rnd.git'
), lines=True, chunksize=10000))

df_clipmore_test['img_url'] = df_clipmore_test['product_id'].apply(
    lambda x: f"https://canary.contestimg.wish.com/api/webimage/{x}-large.jpg")

In [39]:
df_clipmore_test = df_clipmore_test.head(1000)

In [None]:

# We load 3 images. You can either pass URLs or
# a path on your disc
img_paths = df_clipmore_test['img_url'].tolist()

images = [load_image(img) for img in tqdm(img_paths)]

# Map images to the vector space
img_embeddings = img_model.encode(images, show_progress_bar=True)

# Now we encode our text:
texts = df_clipmore_test.title.tolist()
text_embeddings = text_model.encode(texts, show_progress_bar=True)


In [41]:

# Compute cosine similarities:
cos_sim = util.cos_sim(text_embeddings, img_embeddings)
c = 0
for text, scores in zip(texts, cos_sim):
    max_img_idx = torch.argmax(scores)
    print("Text:", text)
    print("Score:", scores[max_img_idx] )
    print("Path:", img_paths[max_img_idx], "\n")
    c += 1
    if c == 5:
        break

Text: 3D Printed Tokyo Ghoul T-shirt Summer Mens Anime Large Size Short-sleeve T-shirt
Score: tensor(0.3213)
Path: https://canary.contestimg.wish.com/api/webimage/60335ab7e8cd71d755304a79-large.jpg 

Text: Eyebrow Trimmer with Eyebrow Comb Eyebrow Trimmer Makeup Scissors Beauty Scissors
Score: tensor(0.2861)
Path: https://canary.contestimg.wish.com/api/webimage/5ee85d62255ab0063d8bf67a-large.jpg 

Text: 925 Sterling silver Natural Malachite Peridot Oval Pendant Pure Jewelry
Score: tensor(0.3046)
Path: https://canary.contestimg.wish.com/api/webimage/6176b0b655a3e8e305baa5a3-large.jpg 

Text: Pistola Para Pintar 0.9 Litros Bdph1200- B3 Black And Decker
Score: tensor(0.2926)
Path: https://canary.contestimg.wish.com/api/webimage/61b74eed7e09afd6837142b4-large.jpg 

Text: 100Pcs/Set Ballerina Beauty Tools DIY UV Gel Coffin Fake Nails Manicure False Nail Tips Full Cover
Score: tensor(0.3579)
Path: https://canary.contestimg.wish.com/api/webimage/5efec170190d93004929df75-large.jpg 



In [42]:
lab = df_clipmore_test['v121_category'].apply(lambda x: tuple(x))

In [43]:
from sklearn.metrics import silhouette_score

In [44]:
silhouette_score(text_embeddings, lab)

-0.078640245

In [45]:
silhouette_score(img_embeddings, lab)

-0.06835164

In [46]:
silhouette_score(img_embeddings + text_embeddings, lab)

-0.048485573

In [48]:
from sklearn.preprocessing import normalize

In [None]:
import umap.plot
import umap
umap.plot.output_notebook()
import numpy as np
from bokeh.plotting import figure, output_file, show, ColumnDataSource
from bokeh.models import HoverTool
from bokeh.transform import factor_cmap
from bokeh.palettes import Category20

hidden_states = normalize(text_embeddings + img_embeddings)

mapper = umap.UMAP().fit(hidden_states)
proj_data = mapper.transform(hidden_states)


output_file("toolbar_clip.html")

source = ColumnDataSource(
        data=dict(
            x=proj_data[:,0],
            y=proj_data[:,1],
            desc=df_clipmore_test['title'].tolist(),
            cat=df_clipmore_test['v121_category'].apply(lambda x: " > ".join(x)).tolist(),
            cat_zero=df_clipmore_test['v121_category'].apply(lambda x: x[0]).tolist(),
            imgs = df_clipmore_test['img_url'].tolist()
        )
    )

hover = HoverTool(
        tooltips="""
        <div>
            <div>
                <img
                    src="@imgs" height="100" alt="@imgs" width="100"
                    style="float: left; margin: 0px 15px 15px 0px;"
                    border="2"
                ></img>
            </div>
            <div>
                <span style="font-size: 17px; font-weight: bold;">@desc</span>
                <span style="font-size: 17px; font-weight: bold;">>>>></span>
                <span style="font-size: 17px; font-weight: bold;">@cat</span>
                <span style="font-size: 15px; color: #966;">[$index]</span>
            </div>
            <div>
                <span style="font-size: 15px;">Location</span>
                <span style="font-size: 10px; color: #696;">($x, $y)</span>
            </div>
        </div>
        """
    )

p = figure(plot_width=1200, plot_height=800, tools=[hover],
           title="Mouse over the dots")
cat0 = list(set(df_clipmore_test['v121_category'].apply(lambda x: x[0]).tolist()))
p.circle('x', 'y', size=5, color=factor_cmap('cat_zero', palette=Category20[20], factors=cat0), source=source)


In [None]:
show(p)

# manual search

In [51]:
img_embeddings


array([[ 0.06057785,  0.44280058,  0.2951861 , ...,  0.40780723,
        -0.01515003,  0.28453058],
       [-0.69512296, -0.2774524 ,  0.3453142 , ...,  0.38775447,
        -0.14922294,  0.08063488],
       [-0.858289  ,  0.2887728 , -0.11065876, ...,  0.6694223 ,
         0.1941432 ,  0.36045936],
       ...,
       [-0.669177  , -0.00928982,  0.13870558, ...,  0.15982851,
        -0.35338193, -0.03364283],
       [-0.35870048,  0.01317909, -0.0221952 , ...,  0.7985848 ,
        -0.62996936,  0.35887843],
       [-0.04864804, -0.09387708, -0.19785263, ...,  0.40523928,
         0.4519784 , -0.393081  ]], dtype=float32)

In [113]:
qs = [ 
    'Black Tee',
    'Black T shirt',
    'White Painting',
    'Underwear'
]
qs = [f'{i}' for i in qs]

In [114]:
query_embeddings = text_model.encode(qs, show_progress_bar=True)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [115]:
from rank_bm25 import BM25Okapi

corpus = texts

tokenized_corpus = [doc.split(" ") for doc in corpus]

bm25 = BM25Okapi(tokenized_corpus)

In [116]:
bm_sim = np.vstack([bm25.get_scores(q.split(" ")) for q in qs])

In [118]:
cos_sim = util.cos_sim(query_embeddings, text_embeddings) + util.cos_sim(query_embeddings, img_embeddings) + 0 * bm_sim / bm_sim.max() / 8
c = 0
for text, scores in zip(qs, cos_sim):
    max_img_idx = torch.argmax(scores)
    print("Query:", text)
    print("Title:", texts[max_img_idx])
    print("Score:", scores[max_img_idx] )
    print("Path:", img_paths[max_img_idx], "\n")
    

Query: Black Tee
Title: Sold Out
Score: tensor(1.1393, dtype=torch.float64)
Path: https://canary.contestimg.wish.com/api/webimage/5822d0f703bcd11b5be4fea7-large.jpg 

Query: Black T shirt
Title: Mom Sold My Bike Black T-shirt - Super - Motorcycle T-shirt
Score: tensor(1.1202, dtype=torch.float64)
Path: https://canary.contestimg.wish.com/api/webimage/59fb2cdf86ac5b2bd86219b4-large.jpg 

Query: White Painting
Title: Sade White T shirt
Score: tensor(1.0927, dtype=torch.float64)
Path: https://canary.contestimg.wish.com/api/webimage/5e832c560e84c35e251d8f42-large.jpg 

Query: Underwear
Title: King Men Casual Drawstring Joggers Sweatpants Cotton Pants 4XL
Score: tensor(1.1035, dtype=torch.float64)
Path: https://canary.contestimg.wish.com/api/webimage/60ae106bd81c4a382c6c3d1d-large.jpg 

