# Download Dataset

**Kaggle Dataset**

In [None]:
!pip install opendatasets

Collecting opendatasets
  Downloading opendatasets-0.1.20-py3-none-any.whl (14 kB)
Installing collected packages: opendatasets
Successfully installed opendatasets-0.1.20


In [None]:
import opendatasets as od

Для следующей ячейки потребуются данные из Kaggle аккаунта:

Your Profile -> Account -> Create New API Token (username и key берутся из скачанного файла kaggle.json)

In [None]:
# Downloads the STL-10 dataset from Kaggle; requires Kaggle API credentials (username + key from "Create New API Token")
od.download("https://www.kaggle.com/jessicali9530/stl10")

# Work with SearchModel

In [None]:
!pip install ruclip==0.0.1 > /dev/null

In [None]:
import numpy as np
# from faiss import Indexer


class DummyIndexer():
    """
    Brute-force in-memory index: stores all embeddings in one 2-D numpy array and
    ranks them by dot product with the query. Stand-in for the (commented-out) faiss Indexer.
    """

    def __init__(self):
        """
        Creates an empty index object
        """
        # None until the first add()/load(); afterwards an array with one embedding per row
        self.index = None

    def add(self, embs: np.ndarray):
        """
        Appends embeddings to the index, creating it on the first call.
        :param embs: embeddings with one row per item, shape (n, dim); a single
                     vector of shape (dim,) is accepted and stored as one row
        :return: None
        """
        # ndmin=2 turns a lone vector into a row, and np.array copies, so the caller
        # mutating its own array afterwards cannot silently alter the index.
        rows = np.array(embs, ndmin=2)
        if self.index is None:
            self.index = rows
        else:
            self.index = np.concatenate((self.index, rows), axis=0)

    def train(self):
        """
        No-op: a brute-force index needs no training. Kept for compatibility with
        the abstract Indexer interface.
        :return: None
        """
        pass

    def find(self, query: np.ndarray, topn: int) -> (np.ndarray, np.ndarray):
        """
        Returns the topn indexed vectors with the largest dot product with query.
        :param query: query vector, shape (dim,) or (1, dim)
        :param topn: number of results; clipped to [0, number of indexed vectors]
        :return: (D, I) - scores sorted in descending order and the row indices they belong to
        :raises ValueError: if nothing has been added or loaded yet
        """
        if self.index is None:
            raise ValueError("Index is empty: call add() or load() first")
        similarities = self.index @ np.ravel(query)
        # Without clamping, a negative topn would slice to "all but the last |topn|" rows
        topn = max(0, min(int(topn), similarities.shape[0]))
        I = np.argsort(-similarities)[:topn]
        D = similarities[I]
        return D, I

    def save(self, file: str):
        """
        Saves the index array to a .npy file
        :param file: destination path
        :return: None
        """
        np.save(file, self.index)

    def load(self, file: str):
        """
        Replaces the index with the array stored in a .npy file
        :param file: path of a file written by save() (or any np.save'd embedding matrix)
        :return: None
        """
        self.index = np.load(file)

In [None]:
import abc

import torch
import ruclip
import numpy as np

from PIL import Image

from numbers import Number
from typing import List

class Embedder(abc.ABC):
    """
    Abstract base for text/image embedders. Subclasses implement the two
    encoders; the cosine-similarity helper is shared.
    """

    @abc.abstractmethod
    def encode_text(self, text):
        """Returns the embedding of `text`; implemented by subclasses."""

    @abc.abstractmethod
    def encode_imgs(self, imgs):
        """Returns the embeddings of the image batch `imgs`; implemented by subclasses."""

    def cos(self, emb1: np.ndarray, emb2: np.ndarray) -> Number:
        """
        Cosine similarity of two embeddings.
        :param emb1: embedding; singleton axes (e.g. a leading batch axis of size 1) are dropped
        :param emb2: embedding, treated the same way as emb1
        :return: dot(emb1, emb2) / (|emb1| * |emb2|)
        """
        first = emb1.squeeze()   # (1, N) -> (N,)
        second = emb2.squeeze()
        norm_product = np.linalg.norm(first) * np.linalg.norm(second)
        return np.dot(first, second) / norm_product


class EmbedderRuCLIP(Embedder):
    """
    Embedder backed by a ruCLIP model: encodes text queries and PIL images through
    ruclip.Predictor and returns the latents as numpy arrays.
    """

    # Default prompt templates for text queries. Kept immutable and copied into a
    # fresh list per instance, so no default list object is shared between instances.
    DEFAULT_TEMPLATES = ('{}', 'это {}', 'на картинке {}')

    def __init__(self, ruclip_model_name='ruclip-vit-base-patch32-384',
                 device='cpu', templates=None, bs=8):
        """
        Loads the ruCLIP model and wraps it in a ruclip.Predictor.
        :param ruclip_model_name: model name passed to ruclip.load
        :param device: device string handed to the Predictor, e.g. 'cpu' or 'cuda'
        :param templates: prompt templates for text queries; None means DEFAULT_TEMPLATES
        :param bs: batch size handed to the Predictor
        """
        if templates is None:
            templates = list(self.DEFAULT_TEMPLATES)
        clip, processor = ruclip.load(ruclip_model_name)
        self.predictor = ruclip.Predictor(clip, processor, device, bs=bs, templates=templates)

    def _tonumpy(self, tensor: torch.Tensor) -> np.ndarray:
        """
        Detaches a tensor from the autograd graph, moves it to CPU and converts it to numpy
        :param tensor: torch tensor, possibly on GPU
        :return: numpy array
        """
        return tensor.detach().cpu().numpy()

    def encode_text(self, text: str) -> np.ndarray:
        """
        Returns the text latent of a single text input
        :param text: query string
        :return: numpy array of text latents (one row per input text, here a single row)
        """
        classes = [text, ]  # the predictor API takes a list of texts
        with torch.no_grad():
            text_latent = self.predictor.get_text_latents(classes)
        return self._tonumpy(text_latent)

    def encode_imgs(self, pil_imgs: List[Image.Image]) -> np.ndarray:
        """
        Returns image latents of an image batch
        :param pil_imgs: list of PIL images
        :return img_latents: numpy array of image latents
        """
        with torch.no_grad():
            img_latents = self.predictor.get_image_latents(pil_imgs)
        return self._tonumpy(img_latents)

In [None]:
import os
import glob
import numpy as np
import pandas as pd
import math
from PIL import Image
from typing import List
from pathlib import Path


class SearchModel():
    """
    Text-to-image search over a directory of images. Image embeddings are computed
    once, cached on disk as features.npy + photo_ids.csv, and queried through an indexer.
    """

    def __init__(self, embedder, indexer):
        """
        :param embedder: object whose encode_imgs(list of PIL images) returns a numpy array
        :param indexer: object with add(embs=...) and find(query, k) -> (distances, indices)
        """
        self.embedder = embedder
        self.indexer = indexer
        self.indexed_imgs_path = []   # image path of every indexer row, in row order
        self.images_dir = None        # directory passed to load_imgs()
        self.imgs_path = None         # image files found in images_dir by load_imgs()
        self.features_path = None     # <parent of images_dir>/features/<prefix> (feature cache)

    def load_imgs(self, path: str, prefix: str):
        """
        Points the model at an image directory and prepares its feature cache.

        Collects every file matching "*.*" in `path`, makes sure the cache directory
        <parent of path>/features/<prefix> exists, and, if that cache already holds a
        photo_ids.csv, restores `indexed_imgs_path` from it so a previously saved
        features.npy can be loaded straight into the indexer.
        :param path: directory with the images
        :param prefix: cache sub-folder name, one per embedder model (e.g. 'CLIP', 'RuCLIP')
        :return: None
        """
        self.images_dir = path
        photos_path = Path(path)
        self.features_path = photos_path.parent / 'features' / prefix
        # Sorted so the embedding order does not depend on filesystem listing order
        self.imgs_path = sorted(p for p in photos_path.glob("*.*") if p.is_file())

        # Works whether or not 'features' and/or '<prefix>' already exist
        self.features_path.mkdir(parents=True, exist_ok=True)

        ids_file = self.features_path / 'photo_ids.csv'
        if ids_file.exists():
            self.indexed_imgs_path = list(pd.read_csv(ids_file)['photo_id'])
        else:
            self.indexed_imgs_path = []

    def load_img_urls(self):
        """
        In case we want to load imgs from a list of url (not implemented yet)
        :return:
        """
        pass

    def add_photo_path(self, name):
        """Builds '<images_dir>/<name>.png' for a bare image name (no extension)."""
        return f'{self.images_dir}/{name}.png'

    def _batch_file(self, i: int, ext: str) -> Path:
        """Path of the per-batch cache file number i with extension ext ('npy' or 'csv')."""
        return self.features_path / f"{i:010d}.{ext}"

    def _embed_batch(self, i: int, batch_files: List[Path]) -> bool:
        """Embeds one batch into its per-batch cache files unless both already exist; True on success."""
        features_file = self._batch_file(i, 'npy')
        ids_file = self._batch_file(i, 'csv')
        # Resume support: a batch counts as done only if BOTH of its files were written
        if features_file.exists() and ids_file.exists():
            return True
        pil_batch = []
        try:
            for photo_file in batch_files:
                pil_batch.append(Image.open(photo_file))
            batch_features = self.embedder.encode_imgs(pil_batch)
            if len(batch_features) != len(batch_files):
                raise ValueError(f"embedder returned {len(batch_features)} rows "
                                 f"for {len(batch_files)} images")
            np.save(features_file, batch_features)
            # Full paths are stored, so photo_ids.csv is right for any file name or extension
            pd.DataFrame({'photo_id': [str(p) for p in batch_files]}).to_csv(ids_file, index=False)
            return True
        except Exception as exc:
            # Best effort: report and skip this batch instead of aborting the whole run
            print(f'Problem with batch {i}: {exc!r}')
            return False
        finally:
            for img in pil_batch:
                img.close()

    def save_embs(self, batch_size: int = 512) -> None:
        """
        Embeds all images found by load_imgs(), caches the result and adds it to the indexer.

        Images are embedded in batches. Each batch is written to <i>.npy / <i>.csv in
        features_path, so an interrupted run can resume: batches whose files already exist
        are skipped. The batches of THIS run are then merged into features.npy and
        photo_ids.csv, overwriting any earlier ones. The per-batch files are removed and
        the merged features are passed to indexer.add(). Afterwards `indexed_imgs_path`
        lists the path of every added row, in row order.

        Failed batches (e.g. unreadable image, embedder error) are reported and skipped.
        Their images are left out of both the features and the path list, so the two stay
        aligned. NOTE(review): this assumes the indexer was empty before the call - confirm
        for indexers that were pre-loaded.
        :param batch_size: number of images per embedding batch
        :return: None
        :raises ValueError: if load_imgs() found no images or batch_size < 1
        :raises RuntimeError: if every batch failed
        """
        if not self.imgs_path:
            raise ValueError("No images to embed: call load_imgs() on a non-empty directory first")
        if batch_size < 1:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        batches = math.ceil(len(self.imgs_path) / batch_size)
        done = []
        for i in range(batches):
            print(f"Processing batch {i+1}/{batches}")
            batch_files = self.imgs_path[i * batch_size:(i + 1) * batch_size]
            if self._embed_batch(i, batch_files):
                done.append(i)

        if not done:
            raise RuntimeError("Every batch failed; nothing was embedded")

        # Merge only this run's batch files (never a stale features.npy / photo_ids.csv)
        features = np.concatenate([np.load(self._batch_file(i, 'npy')) for i in done])
        photo_ids = pd.concat([pd.read_csv(self._batch_file(i, 'csv')) for i in done],
                              ignore_index=True)
        np.save(self.features_path / "features.npy", features)
        photo_ids.to_csv(self.features_path / "photo_ids.csv", index=False)

        for i in range(batches):
            for ext in ('npy', 'csv'):
                batch_file = self._batch_file(i, ext)
                if batch_file.exists():
                    batch_file.unlink()

        self.indexed_imgs_path = list(photo_ids['photo_id'])
        self.indexer.add(embs=features)

    def get_k_imgs(self, emb: np.ndarray, k: int):
        """
        Returns the k indexed images nearest to emb together with their scores
        :param emb: query embedding (e.g. from embedder.encode_text)
        :param k: number of images to return
        :return: (distances, paths) - the scores returned by indexer.find() and an object
                 array with the image path of each returned index
        """
        distances, indices = self.indexer.find(emb, k)
        return distances, np.array(self.indexed_imgs_path, dtype=object)[indices]

Для создания индексов в Colab

In [None]:
# Use the GPU when the runtime has one; fall back to CPU so the cell also runs on a CPU-only runtime
test_model = SearchModel(
    embedder=EmbedderRuCLIP(device='cuda' if torch.cuda.is_available() else 'cpu'),
    indexer=DummyIndexer(),
)

Значение prefix зависит от того, какую модель (эмбеддер) вы передали в SearchModel: для каждой модели признаки хранятся в своей подпапке features/<prefix>.

In [None]:
# Collect the STL-10 train images and embed them; features are cached under /content/stl10/features/RuCLIP
test_model.load_imgs(path='/content/stl10/train_images', prefix='RuCLIP')
test_model.save_embs()

Processing batch 1/10


512it [00:10, 48.35it/s]


Processing batch 2/10


512it [00:10, 48.37it/s]


Processing batch 3/10


512it [00:10, 48.96it/s]


Processing batch 4/10


512it [00:10, 49.17it/s]


Processing batch 5/10


512it [00:10, 49.09it/s]


Processing batch 6/10


512it [00:10, 49.26it/s]


Processing batch 7/10


512it [00:10, 49.21it/s]


Processing batch 8/10


512it [00:10, 49.32it/s]


Processing batch 9/10


512it [00:10, 49.32it/s]


Processing batch 10/10


392it [00:08, 47.37it/s]


In [None]:
# Query text means "A monkey is playing with a ball"; returns (scores, image paths) of the top 3 matches
query = test_model.embedder.encode_text(text="Обезьяна играет с мячиком")
test_model.get_k_imgs(emb=query, k=3)

(array([0.53042364, 0.4278646 , 0.41783807], dtype=float32),
 array([PosixPath('/content/stl10/train_images/train_image_png_1815.png'),
        PosixPath('/content/stl10/train_images/train_image_png_2566.png'),
        PosixPath('/content/stl10/train_images/train_image_png_2900.png')],
       dtype=object))

Для работы в Streamlit

In [None]:
# CPU-only instance, as used when serving from Streamlit
test_streamlit = SearchModel(embedder=EmbedderRuCLIP(device='cpu'), indexer=DummyIndexer())

'/content/stl10/train_images' - Путь к самим изображениям

'/content/stl10/features/<prefix>' — путь, где лежат файлы photo_ids.csv и features.npy

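На сервере в папке датасета должна быть папка "features", а в ней — папка "CLIP"/"RuCLIP" с файлами features.npy и photo_ids.csv. Их можно либо сгенерировать прямо на сервере (load_imgs + save_embs), либо скачать из Colab и перенести по нужному пути.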

Выбор prefix зависит от подаваемого текста или от того, на какой модели пользователь хочет получить результат. В данный момент поддерживаются "CLIP" и "RuCLIP".


In [None]:
# Point the model at the image folder and its RuCLIP feature cache (image paths come from photo_ids.csv when it exists)
test_streamlit.load_imgs(path='/content/stl10/train_images', prefix='RuCLIP')

In [None]:
# Load the embeddings previously saved by save_embs() into the indexer
test_streamlit.indexer.load(str(test_streamlit.features_path / 'features.npy'))

In [None]:
# Same query as in the Colab section, now answered from the cached features on CPU
query = test_streamlit.embedder.encode_text(text="Обезьяна играет с мячиком")
test_streamlit.get_k_imgs(emb=query, k=3)

CPU times: user 469 ms, sys: 5.93 ms, total: 475 ms
Wall time: 469 ms
