In [2]:
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
import gc
from PIL import ImageFile
import torch
import open_clip
from PIL import Image


# Let PIL load image files whose data is truncated instead of raising on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# This notebook only runs inference, so gradient tracking is switched off globally.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x16e8a35b0>

In [None]:
# Load the CLIP model, its image preprocessing transform and the matching tokenizer.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
tokenizer = open_clip.get_tokenizer('ViT-B-32')
# open_clip hands the model back in train mode; inference should run in eval mode.
model.eval()

batch_size = 100
max_images = 20000  # upper bound on how many entries of `all_images` get embedded
# Slice once so the loop bound and the per-batch slices always agree,
# whatever batch_size is set to.
image_paths = all_images[:max_images]

all_images_features = []
# Encode on the accelerator when there is one; features are stored on the CPU.
model.to(device)
for i in tqdm(range(0, len(image_paths), batch_size)):
    batch_tensors = []
    for img_path in image_paths[i:i+batch_size]:
        # The context manager closes the file handle as soon as the image is preprocessed.
        with Image.open('assets/' + img_path) as image:
            batch_tensors.append(preprocess(image))
    batch_images = torch.stack(batch_tensors).to(device)
    with torch.no_grad():
        batch_features = model.encode_image(batch_images).cpu()
    # Iterating the 2-D batch yields one feature row per image.
    all_images_features.extend(batch_features)
    del batch_images, batch_features  # release the batch before the next one is built
    gc.collect()
# The text-query helpers feed CPU token tensors to the model, so return it to the CPU.
model.to("cpu")

# Print a short summary instead of dumping every feature tensor into the output.
print(f"Embedded {len(all_images_features)} images")

In [None]:
# Stack the per-image feature vectors into a single matrix (one row per image),
# then rescale every row in place to unit L2 norm.
all_images_features_tensor = torch.stack(all_images_features)
all_images_features_tensor.div_(
    all_images_features_tensor.norm(dim=-1, keepdim=True)
)

def plotKNN(base_embedding, nrow=5, ncol=5, figsize=(10, 10), title=""):
    """Plot the images closest to ``base_embedding`` and print their product URLs.

    Cosine similarity is computed between ``base_embedding`` and every row of
    the module-level ``all_images_features_tensor``. The best matches are drawn
    on an ``nrow`` x ``ncol`` grid, each titled with its truncated score, its
    index and ``products[ix]['price']``. After the figure is shown, one
    ``index<TAB>url`` line is printed per match.

    Args:
        base_embedding: Tensor holding a single embedding; it is reshaped to
            one row, so both ``(dim,)`` and ``(1, dim)`` inputs work.
        nrow: Number of grid rows.
        ncol: Number of grid columns.
        figsize: Figure size forwarded to ``plt.subplots``.
        title: Figure title.
    """
    # Reshape to one row of whatever width the embedding has, rather than
    # hard-coding 512.
    query = base_embedding.reshape(1, -1)
    probs = torch.nn.functional.cosine_similarity(all_images_features_tensor, query)

    # topk raises when asked for more neighbours than there are images.
    k = min(nrow * ncol, probs.numel())
    best = probs.topk(k)
    top_scores = best.values
    # Plain ints index lists/dicts cleanly and print as `493`, not `tensor(493)`.
    top_indices = best.indices.tolist()

    # squeeze=False keeps `axs` two-dimensional even for a single row or column.
    fig, axs = plt.subplots(nrow, ncol, figsize=figsize, squeeze=False)
    fig.suptitle(f"{title}", fontsize=20, fontname='Comic Sans MS')

    for i, (score, ix) in enumerate(zip(top_scores, top_indices)):
        ax = axs[i // ncol][i % ncol]
        # Close the file handle once the pixels have been handed to matplotlib.
        with Image.open('assets/' + all_images[ix]) as image:
            ax.imshow(image)
        ax.set_title(f"{int(score * 100)/100}, {ix}, {products[ix]['price']}")

    # Hide the frame on every cell, including ones left empty because there
    # were fewer matches than grid slots.
    for row in axs:
        for ax in row:
            ax.axis('off')

    plt.show()
    for ix in top_indices:
        url = 'https://www.myntra.com/' + products[ix]['landingPageUrl']
        print(f'{ix}\t{url}')

def mashup(image_ixs, nrow=5, ncol=5, figsize=(10, 10)):
    """Show the neighbours of the averaged embedding of several images.

    The feature rows selected by ``image_ixs`` are averaged, the average is
    scaled to unit length, and the result is passed to ``plotKNN``.
    """
    selected_rows = all_images_features_tensor[torch.tensor(image_ixs)]
    centroid = selected_rows.mean(dim=0)
    centroid = centroid / centroid.norm()

    heading = 'Mashup :' + str(image_ixs)
    plotKNN(centroid, nrow, ncol, figsize, title=heading)

def ask(query, nrow=5, ncol=5, figsize=(10, 10)):
    """Show the images that best match a free-text query.

    The query is tokenized, encoded with ``model.encode_text``, scaled to unit
    length, and passed to ``plotKNN`` with the query text as the figure title.
    """
    query_tokens = tokenizer([query])
    query_embedding = model.encode_text(query_tokens)
    query_embedding.div_(query_embedding.norm(dim=-1, keepdim=True))

    plotKNN(query_embedding, nrow, ncol, figsize, title=query)

def similar(image_ix, nrow=5, ncol=5, figsize=(10, 10)):
    """Show the images most similar to the image at ``image_ix``.

    Args:
        image_ix: Row index into ``all_images_features_tensor``.
        nrow: Number of grid rows forwarded to ``plotKNN``.
        ncol: Number of grid columns forwarded to ``plotKNN``.
        figsize: Figure size forwarded to ``plotKNN``.
    """
    base_image_feature = all_images_features_tensor[image_ix]
    # Integer indexing returns a view of the shared matrix; normalise
    # out-of-place so the stored feature row is never overwritten.
    base_image_feature = base_image_feature / base_image_feature.norm()
    plotKNN(base_image_feature, nrow, ncol, figsize, title=f'Similar: {image_ix}')

def mashup_all(image_ixs, queries, nrow=5, ncol=5, figsize=(20, 20)):
    """Show neighbours of a blend of image embeddings and text-query embeddings.

    The image side is the unit-length mean of the feature rows selected by
    ``image_ixs``; the text side is the unit-length mean of the normalised
    ``model.encode_text`` embeddings of ``queries``. Both sides are averaged
    with equal weight, renormalised, and passed to ``plotKNN``.

    Args:
        image_ixs: Row indices into ``all_images_features_tensor``; may be
            empty if ``queries`` is not.
        queries: Text queries; may be empty if ``image_ixs`` is not.
        nrow: Number of grid rows forwarded to ``plotKNN``.
        ncol: Number of grid columns forwarded to ``plotKNN``.
        figsize: Figure size forwarded to ``plotKNN``.

    Raises:
        ValueError: If both ``image_ixs`` and ``queries`` are empty.
    """
    if len(image_ixs) == 0 and len(queries) == 0:
        raise ValueError("mashup_all needs at least one image index or one query")

    components = []

    if len(image_ixs) > 0:
        image_mean_embedding = all_images_features_tensor[torch.tensor(image_ixs)].mean(dim=0)
        components.append(image_mean_embedding / image_mean_embedding.norm())

    if len(queries) > 0:
        query_embedding = model.encode_text(tokenizer(queries))
        query_embedding = query_embedding / query_embedding.norm(dim=-1, keepdim=True)
        query_embedding = query_embedding.mean(dim=0)
        # Renormalise the averaged queries, as is done for the image side, so
        # several queries do not shrink the text side's weight in the blend.
        components.append(query_embedding / query_embedding.norm())

    final_embedding = torch.stack(components).mean(dim=0)
    final_embedding = final_embedding / final_embedding.norm(dim=0)
    plotKNN(final_embedding, nrow, ncol, figsize, title='Mashup: ' + str(image_ixs) + ', ' + ','.join(queries))

In [None]:
# Example searches kept for reference; uncomment a line to run it.
# ask("work wear tops")
# ask("basic tops", figsize=(20, 20))
# similar(14998, figsize=(20, 20))
# similar(16819, figsize=(20, 20))
# mashup_all([14998, 15334, 14552], ['corset top'], figsize=(20, 20))
# similar(7745, figsize=(20, 20))
# mashup_all([14248], ['sunflower pattern'], figsize=(20, 20))
# mashup([7745, 5036], figsize=(20, 20))
# ask("light turtleneck with dark color", figsize=(20, 20))
# ask("sexy fine dine date dress corset top", figsize=(20, 20))
similar(image_ix=493, figsize=(20, 20))
# ask("casual blazer", figsize=(20, 20))