In [3]:
import os
import numpy as np
import torch
from clip_retrieval.clip_client import ClipClient
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
import clip

In [4]:
# Client for the remote CLIP kNN service (knn.laion.ai, index "laion5B-H-14").
# NOTE(review): not used by the cells visible here — confirm it is still needed.
client = ClipClient(url="https://knn.laion.ai/knn-service", indice_name="laion5B-H-14")

# CLIP backbones used as an ensemble: every prompt and every image is embedded by each model.
# The list gets its own name so the loop variable below does not overwrite it; otherwise
# re-running this cell would iterate over the characters of the last model name.
model_names = ['ViT-L/14', 'RN50x16']

device = 'cuda' if torch.cuda.is_available() else 'cpu'
models = []        # models[k] is the CLIP model for model_names[k]
preprocesses = []  # preprocesses[k] is the matching image transform for models[k]
for model_name in model_names:
    model, preprocess = clip.load(model_name, device=device)
    models.append(model)
    preprocesses.append(preprocess)

In [5]:
# Vocabulary for zero-shot tagging: every (item, brand, color) combination becomes one
# text prompt. Prompts are ordered item-major, then brand, then color.
items = ['shoe', 'handbag', 'nail polish', 'hat', 't shirt', 'coat', 'perfume']
brands = ['gucci', 'prada', 'chanel', 'dior', 'versace', 'nike', 'puma',
          'adidas', 'ralph lauren', 'armani', 'dolce & gabbana',
          'max factor', 'loreal', '']  # '' yields the unbranded prompt, e.g. "red shoe"
colors = ['red', 'blue', 'green', 'yellow', 'orange', 'purple', 'pink', 'black',
          'white', 'grey', 'brown', 'beige', 'gold', 'silver', 'multicolor']

# Join only the non-empty parts so unbranded prompts (also used as displayed tags)
# read "red shoe" rather than "red  shoe".
full_description = [
    " ".join(part for part in (color, brand, item) if part)
    for item in items
    for brand in brands
    for color in colors
]

# Tokenize every prompt in one batch and move the token tensor to the models' device.
descriptions_tokens = clip.tokenize(full_description).to(device)

In [6]:
# Embed every prompt with each CLIP model, n prompts per forward pass to bound memory.
n = 2048  # prompts per encode_text call
text_embeddings = []  # one (num_prompts, embed_dim) tensor per model, in `models` order
for i, model in enumerate(models):
    with torch.no_grad():
        temp_embeddings = []
        for j in tqdm(range(0, len(descriptions_tokens), n)):
            temp_embeddings.append(model.encode_text(descriptions_tokens[j:j + n]))
    text_embeddings.append(torch.cat(temp_embeddings, dim=0))

# Stack per-model embeddings into one tensor indexed first by model;
# this requires every model to produce the same embedding width.
data_embeddings = torch.stack(text_embeddings)

100%|██████████| 1/1 [00:02<00:00,  2.69s/it]
100%|██████████| 1/1 [00:01<00:00,  1.31s/it]


In [13]:
# Keep the text embeddings on the same device as the models; a hard-coded 'cuda'
# fails on CPU-only machines even though `device` already falls back to 'cpu'.
data_embeddings = data_embeddings.to(device)

In [9]:
# Release intermediates: torch.stack made a separate copy in `data_embeddings`,
# so the per-chunk/per-model lists and the token tensor are no longer needed.
del temp_embeddings
del text_embeddings
del descriptions_tokens

In [20]:
# os.listdir returns names in arbitrary order; sorting makes imgs (and therefore the
# row order of img_scores / img_tags computed below) reproducible across runs and machines.
imgs = sorted(os.listdir('temp_imgs'))
len(imgs)

8366

In [29]:
# Zero-shot tag every image with each CLIP model and record a confidence proxy:
#   img_tags[k]   -> best-matching prompt from each model for imgs[k]
#   img_scores[k] -> sum over models of the TOP_K highest cosine similarities for imgs[k]
TOP_K = 3  # best matches whose similarities are summed per model

img_scores = []
img_tags = []
for fname in tqdm(imgs):
    # Context manager closes each file handle deterministically instead of leaving
    # thousands of open images to the garbage collector.
    with Image.open(f'temp_imgs/{fname}') as img:
        top_3_scores = 0
        tags = []
        for model, preprocess, text_emb in zip(models, preprocesses, data_embeddings):
            with torch.no_grad():
                pixels = preprocess(img).unsqueeze(0).to(device)
                query_embedding = model.encode_image(pixels)
                # Similarity of this image against every prompt for this model.
                similarity = torch.nn.functional.cosine_similarity(text_emb, query_embedding)
                # One topk call provides both the best tag and the top-K scores.
                top = torch.topk(similarity, TOP_K)
            tags.append(full_description[int(top.indices[0])])
            top_3_scores += top.values.sum().item()
    img_tags.append(tags)
    img_scores.append(top_3_scores)

100%|██████████| 8366/8366 [10:48<00:00, 12.90it/s]


In [34]:
# Percentage of images whose summed top-3 similarity (added across all models) is below
# LOW_SCORE_THRESHOLD — presumably images that no prompt describes well.
LOW_SCORE_THRESHOLD = 1.0
n_low = sum(score < LOW_SCORE_THRESHOLD for score in img_scores)
# Report 0.0 for an empty image folder instead of raising ZeroDivisionError.
pct_low = round(n_low / len(img_scores) * 100, 2) if img_scores else 0.0
print(pct_low)

0.57
