In [None]:
import ast
from collections.abc import Sequence
from pathlib import Path
from typing import Union

import numpy as np
import onnxruntime as ort
import pandas as pd
import torch
from hydra import compose, initialize
from sklearn.metrics.pairwise import cosine_similarity
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from tqdm import tqdm
from transformers import CLIPModel, CLIPProcessor, CLIPTextModelWithProjection

from sneakers_ml.models.onnx_utils import predict_clip, save_clip_model

In [None]:
class CLIPSimilaritySearchCreator:
    """Prepare the artefacts used by CLIP text-to-image search.

    Constructing an instance hands the CLIP text model to ``save_clip_model`` (with
    ``onnx_path`` as the save path) and loads the full CLIP model on ``device`` so that
    image folders can be embedded with :meth:`get_image_features_folder`.
    """

    def __init__(self, device: str, clip_model_name: str, onnx_path: str) -> None:
        """Load the processor, export the text model and load the CLIP model.

        Args:
            device: Torch device string the CLIP model is moved to.
            clip_model_name: Checkpoint name or path passed to ``from_pretrained``.
            onnx_path: Save path handed to ``save_clip_model`` for the text model.
        """
        self.device = device
        self.clip_model_name = clip_model_name
        self.onnx_path = onnx_path
        self.processor = CLIPProcessor.from_pretrained(clip_model_name)

        self._create_onnx_model(self.onnx_path)

        self.clip_model = CLIPModel.from_pretrained(clip_model_name).to(self.device)
        self.clip_model.eval()

    def _create_onnx_model(self, save_path: str) -> None:
        """Load the text model with projection and pass it to ``save_clip_model``.

        A one-sentence dummy input is run through the processor and its tensors are
        supplied to ``save_clip_model`` as the example inputs.
        """
        text_model = CLIPTextModelWithProjection.from_pretrained(self.clip_model_name)
        text_model.eval()

        dummy_inputs = self.processor(text=["a dummy sentence"], return_tensors="pt", padding=True)

        save_clip_model(text_model, tuple(dummy_inputs.values()), save_path)

    def get_image_features_folder(
        self, folder_path: str, batch_size: int = 128, num_workers: int = 6
    ) -> tuple[np.ndarray, np.ndarray, dict[str, int]]:
        """Embed every image of an ``ImageFolder``-style directory with the CLIP model.

        Args:
            folder_path: Root directory read by ``ImageFolder``.
            batch_size: Number of images per forward pass.
            num_workers: Number of ``DataLoader`` worker processes.

        Returns:
            A tuple ``(features, image_paths, class_to_idx)``: the image features as a
            numpy array, ``dataset.imgs`` converted to a numpy array, and the dataset's
            ``class_to_idx`` mapping.
        """
        dataset = ImageFolder(folder_path, transform=lambda x: self.processor(images=x, return_tensors="pt"))
        # shuffle=False keeps the feature rows in the same order as dataset.imgs
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=num_workers,
            pin_memory=False,
        )

        image_features = []
        with torch.inference_mode():
            for data in tqdm(dataloader, desc=folder_path):
                # presumably the processor adds a size-1 axis per sample, removed here after collation
                images = data[0].to(self.device)["pixel_values"].squeeze(1)
                outputs = self.clip_model.get_image_features(pixel_values=images)
                image_features.append(outputs.cpu())

        numpy_image_features = torch.cat(image_features, dim=0).numpy()
        image_paths = np.array(dataset.imgs)

        return numpy_image_features, image_paths, dataset.class_to_idx

    @staticmethod
    def save_features(path: str, numpy_features: np.ndarray, classes: np.ndarray, class_to_idx: dict[str, int]) -> None:
        """Write features, classes and the class mapping into a single file.

        The three arrays are stored back to back with ``np.save`` (pickling disabled);
        missing parent directories are created.

        Args:
            path: Destination file path.
            numpy_features: Feature matrix to store.
            classes: Array stored alongside the features (as produced by
                :meth:`get_image_features_folder`).
            class_to_idx: Mapping from class name to class index.
        """
        save_path = Path(path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        # reshape(-1, 2) keeps the (name, index) table two-dimensional even for an empty mapping,
        # so that load_features can always slice its two columns
        class_table = np.array(list(class_to_idx.items()), dtype=str).reshape(-1, 2)
        with save_path.open("wb") as save_file:
            np.save(save_file, numpy_features, allow_pickle=False)
            np.save(save_file, classes, allow_pickle=False)
            np.save(save_file, class_table, allow_pickle=False)

    @staticmethod
    def load_features(path: str) -> tuple[np.ndarray, np.ndarray, dict[str, int]]:
        """Read back a file written by :meth:`save_features`.

        Args:
            path: File to read.

        Returns:
            A tuple ``(numpy_features, classes, class_to_idx)`` where ``class_to_idx``
            uses plain ``str`` keys and ``int`` values.
        """
        with Path(path).open("rb") as file:
            # the arrays must be read in the order they were written
            numpy_features = np.load(file, allow_pickle=False)
            classes = np.load(file, allow_pickle=False)
            class_table = np.load(file, allow_pickle=False)

        # reshape also tolerates an empty mapping that was stored as a 1-D array
        class_to_idx = {str(name): int(index) for name, index in class_table.reshape(-1, 2)}
        return numpy_features, classes, class_to_idx

In [None]:
# with initialize(version_base=None, config_path="config", job_name="text2image-features-create"):
#     cfg = compose(config_name="cfg_text_to_image")
#     creator = CLIPSimilaritySearchCreator(cfg.device, cfg.base_model, cfg.model_path)
#     numpy_features_, classes_, class_to_idx_ = creator.get_image_features_folder(cfg.images_path)
#     creator.save_features(cfg.embeddings_path, numpy_features_, classes_, class_to_idx_)

In [None]:
class CLIPTextToImageSimilaritySearch:
    """Text-to-image search over precomputed CLIP image embeddings.

    Loads the stored image features, an ONNX inference session for the text model and a
    metadata CSV, then ranks the stored images against a text query by cosine similarity.
    """

    # Metadata columns removed right after loading the CSV.
    _DROPPED_COLUMNS = (
        "brand_merge",
        "images_path",
        "collection_name",
        "color",
        "images_flattened",
        "title_without_color",
    )
    # Columns whose cells are parsed from text and then exploded to one row per element
    # (presumably stored as Python list literals — confirm against the CSV).
    _LIST_COLUMNS = ("title", "brand", "price", "pricecurrency", "url", "website")
    # Width of the metadata table built by get_similar: 2 group keys + 5 aggregated columns.
    _RESULT_WIDTH = 7

    def __init__(self, embeddings_path: str, onnx_path: str, metadata_path: str, clip_model_name: str) -> None:
        """Load embeddings, the ONNX session, the processor and the metadata table.

        Args:
            embeddings_path: File read by ``CLIPSimilaritySearchCreator.load_features``.
            onnx_path: Model file opened with ``onnxruntime.InferenceSession``.
            metadata_path: CSV file with the product metadata.
            clip_model_name: Checkpoint name or path for ``CLIPProcessor.from_pretrained``.
        """
        self.embeddings_path = embeddings_path
        self.onnx_path = onnx_path
        self.metadata_path = metadata_path

        self.numpy_features, self.classes, self.class_to_idx = CLIPSimilaritySearchCreator.load_features(
            self.embeddings_path
        )
        # keys are stringified because the class column of self.classes is looked up as text
        self.idx_to_class = {str(v): k for k, v in self.class_to_idx.items()}

        self.onnx_session = ort.InferenceSession(self.onnx_path)
        self.processor = CLIPProcessor.from_pretrained(clip_model_name)

        metadata = pd.read_csv(self.metadata_path).drop(columns=list(self._DROPPED_COLUMNS))
        for column in self._LIST_COLUMNS:
            metadata[column] = metadata[column].apply(self._parse_literal)
        self.df = metadata.explode(list(self._LIST_COLUMNS))

    @staticmethod
    def _parse_literal(value: str):
        """Parse a CSV cell holding a Python literal.

        ``ast.literal_eval`` is used instead of ``eval`` so that text coming from the
        metadata file can never execute code; non-literal content raises ``ValueError``.
        """
        return ast.literal_eval(value)

    def get_text_features(self, text_queries: Union[Sequence[str], str]) -> np.ndarray:
        """Run the processor on the text and return the output of ``predict_clip``.

        Args:
            text_queries: A single query string or a sequence of query strings.
        """
        inputs = self.processor(text=text_queries, return_tensors="np", padding=True)
        return predict_clip(self.onnx_session, inputs)

    def get_similar(self, text_query: str, top_k: int) -> tuple[np.ndarray, np.ndarray]:
        """Find the stored images most similar to a text query.

        Args:
            text_query: Query text to embed.
            top_k: Number of best-scoring images to select. Values ``<= 0`` yield empty results.

        Returns:
            A tuple ``(metadata, images)``. ``images`` holds the first column of
            ``self.classes`` for the ``top_k`` best matches, best first. ``metadata`` is a
            2-D array with one row per ``(title, website)`` group among the listings whose
            ``title_merge`` equals one of the matched model names; its rows follow the
            group-by key order rather than the similarity ranking, and there may be more
            or fewer than ``top_k`` of them.
        """
        if top_k <= 0:
            # `[-0:]` would select every row and a negative value would drop the best ones
            return np.empty((0, self._RESULT_WIDTH), dtype=object), self.classes[:0, 0]

        text_features = self.get_text_features(text_query)
        similarity_matrix = cosine_similarity(self.numpy_features, text_features.reshape(1, -1)).flatten()
        # argsort is ascending: take the tail and reverse it to get best-first order
        top_k_indices = np.argsort(similarity_matrix)[-top_k:][::-1]

        similar_objects = self.classes[top_k_indices]
        similar_images = similar_objects[:, 0]
        similar_models = np.vectorize(self.idx_to_class.get)(similar_objects[:, 1])

        matching_rows = self.df[self.df["title_merge"].isin(set(similar_models.tolist()))]
        similar_metadata_dump = (
            matching_rows.groupby(["title", "website"])
            .agg(
                {
                    "title_merge": "first",
                    "brand": "first",
                    "price": lambda x: f"{min(x)} - {max(x)}",
                    "pricecurrency": "first",
                    "url": "first",
                }
            )
            .reset_index()
            .to_numpy()
        )
        return similar_metadata_dump, similar_images

In [None]:
# Build the searcher from the Hydra config and run one query as a smoke test.
with initialize(version_base=None, config_path="config", job_name="text2image-features-predict"):
    cfg = compose(config_name="cfg_text_to_image")
    temp = CLIPTextToImageSimilaritySearch(
        embeddings_path=cfg.embeddings_path,
        onnx_path=cfg.model_path,
        metadata_path=cfg.metadata_path,
        clip_model_name=cfg.base_model,
    )
    print(temp.get_similar(text_query="blue sneakers", top_k=3))

(array([["CHUCK TAYLOR ALL STAR HIGH TOP 'LITE BLUE'", 'superkicks',
        'chuck taylor all star high top', 'Converse', '4299.0 - 4299.0',
        'INR',
        'https://www.superkicks.in/products/chuck-taylor-all-star-high-top-lite-blue'],
       ["GEL-LYTE V GODAI 'ARCTIC BLUE/SKY'", 'superkicks',
        'gel lyte v godai', 'Asics', '11999.0 - 11999.0', 'INR',
        'https://www.superkicks.in/products/gel-lyte-v-godai-arctic-blue-sky'],
       ["GEL-LYTE V GODAI 'EGGPLANT/PURPLE'", 'superkicks',
        'gel lyte v godai', 'Asics', '11999.0 - 11999.0', 'INR',
        'https://www.superkicks.in/products/gel-lyte-v-godai-eggplant-purple'],
       ["GEL-LYTE V GODAI 'SOOTHING SEA/SEAFOAM'", 'superkicks',
        'gel lyte v godai', 'Asics', '11999.0 - 11999.0', 'INR',
        'https://www.superkicks.in/products/gel-lyte-v-godai-soothing-sea-seafoam'],
       ["WMN'S STAN SMITH CS 'CLEAR SKY/WHITE/CORE WHITE'", 'superkicks',
        'wmns stan smith cs', 'Adidas Originals', '9999.

In [None]:
# Manual walkthrough of the ranking done inside `get_similar`, for a single query.
top_k = 3
text_features = temp.get_text_features("dark blue sneakers")
# Cosine score of every stored image embedding against the query, flattened to 1-D.
similarity_matrix = cosine_similarity(temp.numpy_features, text_features).flatten()

In [None]:
# Scores of the top_k best matches, highest first (argsort is ascending, hence the tail slice and reversal).
similarity_matrix[np.argsort(similarity_matrix)[-top_k:][::-1]]

array([0.3153295 , 0.3138626 , 0.31370813], dtype=float32)

In [None]:
# Positions of the top_k best-scoring entries of `similarity_matrix`, best first.
np.argsort(similarity_matrix)[-top_k:][::-1]

array([ 9263, 22967, 10474])

In [None]:
# Indices of the top_k most similar items, ordered from most to least similar
# (ascending argsort, last top_k entries, reversed).
top_k_indices = np.argsort(similarity_matrix)[-top_k:][::-1]

In [None]:
# Rows of `temp.classes` for the selected matches: column 0 is read as the image path,
# column 1 as the class index that `idx_to_class` maps back to a model name.
similar_objects = temp.classes[top_k_indices]
similar_images = similar_objects[:, 0]
similar_models = np.vectorize(temp.idx_to_class.get)(similar_objects[:, 1])

In [None]:
# Embed a second query; `get_text_features` is the method the class actually defines.
text_features = temp.get_text_features("Dark blue sneakers")

In [None]:
# Re-score every stored image embedding against the current `text_features`.
similarity_matrix = cosine_similarity(temp.numpy_features, text_features).flatten()

In [None]:
# Threshold-based alternative to top-k: keep every index whose score lies in [0.3, 1].
similar_indices = np.argwhere((similarity_matrix >= 0.3) & (similarity_matrix <= 1)).flatten()

In [None]:
# Rows of `temp.classes` for the indices that passed the threshold.
similar_objects = temp.classes[similar_indices]

In [None]:
# First column of the selected rows, read as the image paths.
similar_images = similar_objects[:, 0]

In [None]:
# Display the selected image paths.
similar_images

array(['data/merged-with-footshop/images/by-models/2002r/11.jpeg',
       'data/merged-with-footshop/images/by-models/2002r/31.jpeg',
       'data/merged-with-footshop/images/by-models/2002r gore tex/1.jpeg',
       'data/merged-with-footshop/images/by-models/550/165.jpeg',
       'data/merged-with-footshop/images/by-models/576/5.jpeg',
       'data/merged-with-footshop/images/by-models/adifom climacool/39.jpeg',
       'data/merged-with-footshop/images/by-models/air max terrascape 90/21.jpeg',
       'data/merged-with-footshop/images/by-models/albatross 82/1.jpeg',
       'data/merged-with-footshop/images/by-models/all star high trainers/1.jpeg',
       'data/merged-with-footshop/images/by-models/basket butter goods/0.jpeg',
       'data/merged-with-footshop/images/by-models/blazer low 77 jumbo/19.jpeg',
       'data/merged-with-footshop/images/by-models/blazer low 77 jumbo/2.jpeg',
       'data/merged-with-footshop/images/by-models/campus 80s/62.jpeg',
       'data/merged-with-footsh

In [None]:
# Inspect the mapping from stringified class index to class name.
temp.idx_to_class

{'0': '01 low',
 '1': '01 low m',
 '2': '01 low man',
 '3': '01 low w',
 '4': '01 low wom',
 '5': '01 low wom leat draw',
 '6': '101',
 '7': '101 bex 6 eye boot',
 '8': '101 ys 6 eye boot',
 '9': '1460',
 '10': '1460 8 eye boot',
 '11': '1460 bex 8 eye boot',
 '12': '1460 bex squared 8 eye boot',
 '13': '1460 for pride 8 eye boot',
 '14': '1460 pascal 8 eye boot',
 '15': '1460 pascal mono 8 eye boot',
 '16': '1460 patent leather lace up boots',
 '17': '1460 serena 8 eye boot',
 '18': '1460 smooth mono',
 '19': '1460 vonda 8 eye boot',
 '20': '1461 3 eye shoe',
 '21': '1461 for pride',
 '22': '1461 patent lamper',
 '23': '1461 quad 3 eye shoe',
 '24': '1461 smooth',
 '25': '1490 bloom',
 '26': '1490 quad squa',
 '27': '1500',
 '28': '180',
 '29': '180 corduroy',
 '30': '180 pop',
 '31': '180 tones',
 '32': '180 w',
 '33': '1906',
 '34': '1906 protection pack',
 '35': '2001 2 high',
 '36': '2002',
 '37': '2002 protection pack',
 '38': '2002r',
 '39': '2002r gore tex',
 '40': '2976',
 '41

In [None]:
# Translate the class-index column of the selected rows into class names.
similar_models = np.vectorize(temp.idx_to_class.get)(similar_objects[:, 1])

In [None]:
# Sanity check: embed an arbitrary sentence and view the result as a flat vector.
temp.get_text_features("Hello world").flatten()

[[ 1.02045454e-01  1.27308190e-01 -1.69991583e-01  2.45203316e-01
  -8.50420967e-02 -2.42741793e-01  1.64409459e-01 -1.47837305e+00
   2.04973847e-01 -1.41165629e-02 -4.08323824e-01 -9.25241262e-02
  -2.32954949e-01 -9.10157710e-02  2.28031278e-01 -4.48102802e-02
   3.77843618e-01  3.15139443e-02 -2.41764374e-02 -9.05932188e-02
   2.54764646e-01 -2.41974741e-01  2.13344395e-01  1.66678190e-01
  -4.16782796e-01 -3.47045422e-01 -1.37874022e-01  2.72287309e-01
  -1.38795346e-01  2.60674775e-01  1.62942946e-01 -1.24393567e-01
  -4.74403799e-03 -1.19121030e-01  2.56806463e-02  1.27570763e-01
  -8.20921957e-02 -3.71039510e-02  1.55443341e-01 -1.69893831e-01
   1.10771656e-02 -1.65026128e-01 -9.52544063e-02 -1.75989807e-01
   1.98973075e-01  1.38817102e-01  3.96730825e-02 -1.14678860e-01
   1.23869166e-01  1.70168519e-01  3.67715359e-01 -3.34401369e-01
   9.82569158e-03 -3.86678040e-01  1.93235818e-02 -5.23265749e-02
  -1.64434686e-01  6.37686253e-03 -1.24753460e-01 -1.38096929e-01
   3.74642

array([ 1.02045454e-01,  1.27308190e-01, -1.69991583e-01,  2.45203316e-01,
       -8.50420967e-02, -2.42741793e-01,  1.64409459e-01, -1.47837305e+00,
        2.04973847e-01, -1.41165629e-02, -4.08323824e-01, -9.25241262e-02,
       -2.32954949e-01, -9.10157710e-02,  2.28031278e-01, -4.48102802e-02,
        3.77843618e-01,  3.15139443e-02, -2.41764374e-02, -9.05932188e-02,
        2.54764646e-01, -2.41974741e-01,  2.13344395e-01,  1.66678190e-01,
       -4.16782796e-01, -3.47045422e-01, -1.37874022e-01,  2.72287309e-01,
       -1.38795346e-01,  2.60674775e-01,  1.62942946e-01, -1.24393567e-01,
       -4.74403799e-03, -1.19121030e-01,  2.56806463e-02,  1.27570763e-01,
       -8.20921957e-02, -3.71039510e-02,  1.55443341e-01, -1.69893831e-01,
        1.10771656e-02, -1.65026128e-01, -9.52544063e-02, -1.75989807e-01,
        1.98973075e-01,  1.38817102e-01,  3.96730825e-02, -1.14678860e-01,
        1.23869166e-01,  1.70168519e-01,  3.67715359e-01, -3.34401369e-01,
        9.82569158e-03, -