In [1]:
import faiss
import numpy as np
import clip
from typing import List, Optional
import torch
from lavis.models import load_model_and_preprocess


In [3]:
import time
import json
from PIL import Image
import os
import glob

In [5]:
class Faiss: 
    def __init__(self, database_path: str, use_gpu: bool = True) -> None:
        database = np.load(database_path)
        
        db_size,_, dim = database.shape
        
        print('Indexing database...')
        print('database shape:', database.shape)
        # self.index_flat = self._get_indexer(dim, 'IP')
        self.index_flat = faiss.read_index(database_path)


        if use_gpu:
            res = faiss.StandardGpuResources()  # use a single GPU

            # make it into a gpu index
            self.index_flat = faiss.index_cpu_to_gpu(res, 0, self.index_flat)

        self.index_flat.add(database)


        print('Finish indexing database')
        
    def _get_indexer(self, dim: int, id_type: str):
        # if id_type == 'L2':
        #     return faiss.IndexFlatL2(dim)

        return faiss.read_index("/mmlabworkspace/Students/AIC/MMLAB-UIT-AIC2023/faiss/info.json")
    
    def search(self, encoded_queries: np.array, top_k: int) -> np.array:
        """_Return indexes of every query in queries

        Args:
            queries (np.array): (n_queries, dim)
            top_k (int): top k nearest

        Returns:
            np.array: (n_queries, top_k)
        """
        print('query.shape:', encoded_queries.shape)
        
        distances, indices = self.index_flat.search(encoded_queries, top_k)
        return distances, indices

class Encoder:
    def __init__(self, model_name: str='ViT-B/16',project=None, use_gpu: bool = True) -> None:
        if use_gpu:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = 'cpu'
        self.model_name = model_name
        self.project = project
        print('Loading model...')
        if self.model_name == 'BLIP':
            self.model, self.preprocess, self.tokenize = load_model_and_preprocess(
                                                        name="blip2_feature_extractor",
                                                        model_type="pretrain",
                                                        is_eval=True,
                                                        device=self.device)
        else:
            self.tokenize = clip.tokenize
            self.model, self.preprocess = clip.load(model_name, device=self.device)

        print('Finish loading model.')
            
    def encode_texts(self, text: List[str]) -> np.array:
        
                
        if self.model_name == 'BLIP':
            text_input = self.tokenize["eval"](text[0])
            sample = {"image": "", "text_input": [text_input]}
            features_text = self.model.extract_features(sample, mode="text")
            if self.project:
                text_features = features_text.text_embeds_proj[:,0,:].t().cpu().numpy().astype(np.float32)
            else:
                text_features = features_text.text_embeds[:,0,:].t().cpu().numpy().astype(np.float32)
        else:
            tokenized_text = self.tokenize(text).to(self.device)
            with torch.no_grad():
                text_features = self.model.encode_text(tokenized_text)
                text_features = text_features.cpu().numpy()
                text_features = text_features / np.linalg.norm(text_features)
        return text_features
            
    def encode_image(self, image) -> np.array:
        if self.model_name == 'ViT_B/16':
            image = self.preprocess(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                image_feature = self.model.encode(image)
        # if self.model_name == 'BLIP':
        #     from PIL import Image
        #     image = Image.fromarray(cv2.imread(img_file))
        #     image = vis_processors["eval"](image).unsqueeze(0).cuda()
        #     sample = {"image": image, "text_input": [""]}
        #     feature_image = model.extract_features(sample, mode="image")
        #     if self.project:
        #         text_features = features_text.image_embeds_proj
        #         print('project=True:', text_features.shape)
        #     else:
        #         text_features = features_text.image_embeds
        #         print('project=False:', text_features.shape)
        return image_feature
    

In [8]:



# Testing

database = np.load('/mmlabworkspace/Students/AIC/MMLAB-UIT-AIC2023/data/Merge/Feature_data/BLIP_256/feature.npy')
print('Database shape', database.shape)



Database shape (202148, 32, 256)


In [6]:
database_path = '/mmlabworkspace/Students/AIC/MMLAB-UIT-AIC2023/data/Merge/Feature_data/BLIP_256/feature.npy'
start_time = time.time()
faiss_searcher = Faiss(database_path=database_path)
print('Index time:', time.time() - start_time)

Indexing database...
database shape: (202148, 32, 256)


RuntimeError: Error in faiss::Index* faiss::read_index(faiss::IOReader*, int) at /home/circleci/miniconda/conda-bld/faiss-pkg_1681998300314/work/faiss/impl/index_read.cpp:1027: Index type 0x4d554e93 ("\x93NUM") not recognized

In [None]:
start_time = time.time()
encoder = Encoder(model_name="BLIP", project=True)
print('Load model time:', time.time() - start_time)
start_time = time.time()
queries = np.array(['A big cat'])
encoded_text = encoder.encode_texts(queries)
print('Encode time:', time.time() - start_time)



In [None]:
start_time = time.time()
result_indices = faiss_searcher.search(encoded_text, top_k=10).ravel()
print('Search time:', time.time() - start_time)
print(result_indices[0])

# with open('data/keyframe_names') as f:
#     keyframe_names = json.load(f)
# result = [keyframe_names[i] for i in result_indices]
# print(result)

# for file in glob.glob('faiss/result/*'):
#     os.remove(file)
# os.makedirs('faiss/result/', exist_ok=True)
# for image_file in result:
#     img = Image.open(image_file)
#     img.save(os.path.join('faiss/result', os.path.basename(image_file)))

In [None]:
import argparse
import os
import glob
import pickle
from unittest.result import failfast

import faiss
import numpy as np
import torch
from tqdm.auto import tqdm

import info


def load_embeddings(embeddings_dir):
    if os.path.exists(f'{embeddings_dir}/embeddings.pkl'):
        with open(f'{embeddings_dir}/embeddings.pkl', 'rb') as f:
            embeddings = pickle.load(f)
    else:
        embeddings = []
        for embedding_file in tqdm(sorted(glob.glob(f"{embeddings_dir}/*.npz"))):
            data = np.load(embedding_file)
            embeddings.append(np.array(data.get("feature_lst")[:, 0, 0, :]).astype(np.float32))
        embeddings = np.vstack(embeddings)
        
        with open(f'{embeddings_dir}/embeddings.pkl', 'wb') as f:
            pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)
        
    print(f"Loaded embedding with shape: {embeddings.shape}")
    return embeddings


def set_nprobe(index, nprobe):
    changed = False
    if hasattr(index, "nprobe") and nprobe and index.nprobe != nprobe:
        index.nprobe = nprobe
        print(f"Set nprobe = {nprobe}")
        changed = True
    return changed


def auto_nprobe(nlist):
    nprobe = min(max(round(2e-3 * nlist), 128), nlist)
    print(f"Automatic nprobe: {nprobe}")
    return nprobe


def auto_ivf_sq(nembeddings):
    # refering to https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index
    ncentroids = 0
    if nembeddings <= 1e6:
        ncentroids = int(np.ceil(16 * np.sqrt(nembeddings)))
    elif 1e6 < nembeddings <= 10e6:
        ncentroids = 65536
    elif 10e6 < nembeddings <= 100e6:
        ncentroids = 262144
    elif 100e6 <= nembeddings <= 1e9:
        ncentroids = 1048576
    else:
        raise ValueError(
            "Too many embeddings! Please set the index factory string yourself"
        )

    ncentroids = min(nembeddings // 39, ncentroids)
    index_factory_string = f"IVF{ncentroids},SQ4"
    print(f"Automatic index factory string: {index_factory_string}")
    return index_factory_string


def run(
    embeddings_dir,
    output_dir,
    index_factory_string=None,
    distance="IP",
    nprobe=None,
    use_gpu=False,
):
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"AIC_db_{distance}.index")
    if not os.path.exists(output_path):
        use_gpu = use_gpu and torch.cuda.is_available()
        embeddings = load_embeddings(embeddings_dir)
        ndim = embeddings.shape[1]
        if len(embeddings) > 1e7:
            print("WARNING: #embeddings > 10M, please use GPU(s) for saving time!!!")

        if index_factory_string is None:
            index_factory_string = auto_ivf_sq(len(embeddings))
        if distance == "IP":
            distance = faiss.METRIC_INNER_PRODUCT
        elif distance == "L2":
            distance = faiss.METRIC_L2
        else:
            raise NotImplementedError
        index = faiss.index_factory(ndim, index_factory_string, distance)

        if use_gpu:
            ngpus = faiss.get_num_gpus()
            print(f"Using {ngpus} gpus")
            index = faiss.index_cpu_to_all_gpus(index)
        else:
            print("Using cpu")

        index.train(embeddings)
        index.add(embeddings)
        if use_gpu:
            index = faiss.index_gpu_to_cpu(index)

        if nprobe is None and index.nlist:
            nprobe = auto_nprobe(index.nlist)
        set_nprobe(index, nprobe)
        faiss.write_index(index, output_path)
    else:
        print("Found existing index")
        index = faiss.read_index(output_path)
        if set_nprobe(index, nprobe):
            faiss.write_index(index, output_path)

    info.run(output_dir, output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--embeddings_dir")
    parser.add_argument("--output_dir")
    parser.add_argument(
        "--index_factory_string",
        default=None,
        help="By default, it will use IVFSQ index",
    )
    parser.add_argument("--distance", default="IP", choices=["L2", "IP"])
    parser.add_argument(
        "--nprobe",
        type=int,
        default=None,
        help="How many clusters you want to dive in for each search. By default it will be 2e-3 * nlist",
    )
    parser.add_argument(
        "--use_gpu", action="store_true", help="Please use GPU(s) if #embeddings >= 10M"
    )
    args = parser.parse_args()
    run(**vars(args))