In [None]:
from deepface import DeepFace
from pathlib import Path
import numpy as np
import pickle
import os
import faiss
from typing import Dict, List
from functools import reduce
from sklearn.metrics import average_precision_score
from tqdm import tqdm
from time import time, sleep
import contextlib
from distributed_faiss.client import IndexClient, IndexCfg, IndexState
import json
from elasticsearch import Elasticsearch, helpers
from pymilvus import (
    connections,
    utility,
    FieldSchema, CollectionSchema, DataType,
    Collection,
)

# Blank out any HTTP(S) proxy for this process — presumably so requests to the
# local Elasticsearch (localhost:9200) and Milvus (localhost:19530) services
# used below are not routed through a proxy.
os.environ.update(http_proxy="", https_proxy="")

In [None]:
def random_float32_array(num, dimention):
    """Return a ``(num, dimention)`` float32 array of uniform [0, 1) samples.

    Rows are generated in chunks of 100,000. This keeps the temporary
    float64 array produced by ``np.random.rand`` bounded to one chunk, and a
    tqdm progress bar reports chunk progress.
    """
    chunk_rows = 100000
    out = np.zeros((num, dimention), np.float32)
    for start in tqdm(range(0, num, chunk_rows)):
        stop = min(start + chunk_rows, num)
        out[start:stop] = np.random.rand(stop - start, dimention).astype(np.float32)
    return out

In [None]:
@contextlib.contextmanager
def timeit(comment):
    """Context manager that prints how long its ``with`` body took.

    Prints ``start <comment> ...`` on entry and ``<comment> : <seconds>s``
    (rounded to 2 decimals) on exit.

    The interval is measured with ``time.perf_counter``, a monotonic clock,
    so it is unaffected by wall-clock adjustments. The time is reported even
    when the body raises; the exception is then re-raised unchanged.
    """
    from time import perf_counter

    print(f"start {comment} ...")
    start = perf_counter()
    try:
        yield
    finally:
        used = perf_counter() - start
        print(comment, f": {round(used, 2)}s")

# 生成 Embedding

In [None]:
# embedding = DeepFace.represent(img_path="/home/featurize/work/vec/cfp-dataset/Data/Images/001/frontal/03.jpg")

In [None]:
# path = Path("/home/featurize/work/vec/cfp-dataset/Data/Images")

In [None]:
# all_embeddings = {}
# for person_id in os.listdir("/home/featurize/work/vec/cfp-dataset/Data/Images"):
#     all_embeddings[person_id] = []
#     for k, image_path in enumerate(Path(f"/home/featurize/work/vec/cfp-dataset/Data/Images/{person_id}/frontal/").glob("*.jpg")):
#         try:
#             embedding = DeepFace.represent(img_path=image_path.as_posix())
#         except ValueError:
#             continue
#         all_embeddings[person_id].append(embedding)

In [None]:
# Path("./embeddings.pkl").write_bytes(pickle.dumps(all_embeddings))

# Flat indexes 计算指标

In [None]:
from socket import timeout


class Index:
    """Common interface for the vector-index backends benchmarked here.

    Every hook is a no-op that returns ``None``; a backend overrides the
    hooks it supports.
    """

    def prepare(self, *args, **kwargs):
        """Build the index from an initial batch of embeddings."""
        return None

    def add(self, *args, **kwargs):
        """Append more embeddings to an already-prepared index."""
        return None

    def query(self, *args, **kwargs):
        """Look up the nearest stored embeddings for some query vectors."""
        return None

    def clean(self, *args, **kwargs):
        """Release whatever the backend holds."""
        return None


class EsIndex(Index):
    """Elasticsearch backend using a ``dense_vector`` field and kNN search.

    Connects to the node at ``http://localhost:9200``. Each embedding is
    stored as one document whose ``_id`` is its row position, so search hits
    map straight back to rows of the arrays passed to ``prepare`` / ``add``.
    """

    def __init__(self, dimention):
        self.dim = dimention
        self.es = Elasticsearch(hosts="http://localhost:9200", timeout=30)
        self.index_name = "vec-search-index"
        # Document id that the next appended embedding will receive.
        self._next_id = 0

    def _bulk_insert(self, embeddings, start_id):
        """Index each row of ``embeddings`` as a document, ids from ``start_id``.

        Actions are streamed to ``helpers.bulk`` from a generator rather than
        materialised as one list. The index is refreshed afterwards so the new
        documents are visible to the next search.
        """
        actions = (
            {
                "_index": self.index_name,
                "_id": start_id + i,
                "_source": {
                    "vector": embedding.tolist()
                },
            }
            for i, embedding in enumerate(tqdm(embeddings))
        )
        helpers.bulk(self.es, actions)
        self.es.indices.refresh(index=self.index_name)
        self._next_id = start_id + len(embeddings)

    def prepare(self, embeddings):
        """Drop and recreate the index, then bulk-load ``embeddings``.

        Args:
            embeddings: 2-D NumPy array, one vector of ``self.dim`` components
                per row; row ``i`` is stored under ``_id`` ``i``.
        """
        print("delete index...")
        # 400/404 are ignored so deleting a non-existent index is not an error.
        self.es.options(ignore_status=[400, 404], request_timeout=999).indices.delete(
            index=self.index_name, master_timeout="5m")
        print("create index...")
        index_body = {
            "mappings": {
                "properties": {
                    "vector": {
                        "type": "dense_vector",
                        # Must equal the length of every vector indexed/queried.
                        "dims": self.dim,
                        "index": True,
                        "similarity": "l2_norm"
                    },
                },
            },
        }
        self.es.indices.create(index=self.index_name, body=index_body, timeout="500m", master_timeout="500m")
        with timeit("prepare embedding"):
            print("start buck insert data...")
            self._bulk_insert(embeddings, 0)

    def add(self, embeddings):
        """Append ``embeddings`` after the documents already indexed."""
        self._bulk_insert(embeddings, self._next_id)

    def query(self, embeddings, num=10):
        """Return the ids of the ``num`` nearest documents for each query row.

        Args:
            embeddings: 2-D array of query vectors, one per row.
            num: number of neighbours to return per query.

        Returns:
            int64 array of shape ``(len(embeddings), num)``. Rows with fewer
            than ``num`` hits are padded with ``-1``.
        """
        all_ids = np.full((len(embeddings), num), -1, dtype=np.int64)
        for row, vector in enumerate(embeddings):
            res = self.es.knn_search(index=self.index_name, knn={
                "field": "vector",
                "query_vector": np.asarray(vector[:self.dim], dtype=np.float32).tolist(),
                "k": num,
                # Elasticsearch requires num_candidates >= k.
                "num_candidates": max(num, 10),
            }, source=[""])
            hits = [int(hit["_id"]) for hit in res["hits"]["hits"]][:num]
            all_ids[row, :len(hits)] = hits
        return all_ids

    def clean(self):
        return


class MilvusIndex(Index):
    """Milvus backend storing vectors in the ``hello_milvus`` collection.

    Constructing an instance connects to ``localhost:19530`` and drops any
    existing ``hello_milvus`` collection before defining a fresh one. Each
    vector's primary key ``pk`` is its row position across everything
    inserted since the last ``prepare``.
    """

    # Rows sent per ``Collection.insert`` call.
    _insert_batch = 10000

    def __init__(self, dimention):
        self.dim = dimention
        # Primary key that the next inserted vector will receive.
        self._next_pk = 0
        connections.connect("default", host="localhost", port="19530")
        utility.drop_collection("hello_milvus")
        fields = [
            FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=False),
            FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=self.dim)
        ]
        schema = CollectionSchema(fields, "hello_milvus is the simplest demo to introduce the APIs")
        self.index = Collection("hello_milvus", schema, consistency_level="Strong")

    def _insert(self, embeddings):
        """Insert ``embeddings`` in batches, assigning consecutive pks."""
        total = len(embeddings)
        for start in tqdm(range(0, total, self._insert_batch)):
            stop = min(start + self._insert_batch, total)
            pks = list(range(self._next_pk + start, self._next_pk + stop))
            self.index.insert([pks, embeddings[start:stop, :]])
        self._next_pk += total

    def prepare(self, embeddings):
        """Insert ``embeddings`` (pk = row number), build an IVF_FLAT/L2 index and load it."""
        self._next_pk = 0
        with timeit("prepare embedding"):
            self._insert(embeddings)
            self.index.create_index("embeddings", {
                "index_type": "IVF_FLAT",
                "metric_type": "L2",
                "params": {"nlist": 4096},
            })
            self.index.load()

    def add(self, embeddings):
        """Append ``embeddings`` with pks continuing after the existing rows."""
        self._insert(embeddings)

    def query(self, embeddings, num=10):
        """Return the pks of the ``num`` nearest vectors for each query row.

        Returns:
            int64 array of shape ``(len(embeddings), num)``; one row per query
            (previously all queries were flattened into a single row). Rows
            with fewer than ``num`` hits are padded with ``-1``.
        """
        search_params = {
            "metric_type": "L2",
            "params": {"nprobe": 10},
        }
        res = self.index.search(embeddings, "embeddings", search_params, limit=num, output_fields=["pk"])
        all_pks = np.full((len(embeddings), num), -1, dtype=np.int64)
        for row, hits in enumerate(res):
            pks = [hit.entity.get('pk') for hit in hits][:num]
            all_pks[row, :len(pks)] = pks
        return all_pks

    def clean(self):
        pass


class FaissIndex(Index):
    """FAISS backend: a local (CPU/GPU) index or a ``distributed_faiss`` cluster.

    Args:
        use_gpu: move the local index to GPU 0 after creation.
        index_str: ``faiss.index_factory`` description for the local index.
            NOTE(review): distributed mode ignores it and always uses
            ``"OPQ16_64,IMI2x8,PQ8+16"``.
        dimention: vector dimensionality; ``prepare`` asserts it.
        distributed: use an ``IndexClient`` configured by
            ``./faiss_distribute_config.txt`` instead of a local index.
    """

    def __init__(self, use_gpu=False, index_str="Flat", dimention=512, distributed=False):
        self.dimention = dimention
        self.use_gpu = use_gpu
        self.distributed = distributed
        self.distributed_index_name = "dd_index"
        self.index_str = index_str
        # GPU resources backing the index. They are kept on the instance so
        # they live as long as the GPU index does.
        self.gpu_res = None

    def prepare(self, embeddings):
        """Create the index, train it if needed and add ``embeddings`` (2-D)."""
        assert len(embeddings.shape) == 2
        assert embeddings.shape[1] == self.dimention
        if self.distributed:
            self._prepare_distributed(embeddings)
        else:
            self._prepare_local(embeddings)

    def _prepare_distributed(self, embeddings, max_wait_s=30):
        """Create, fill and train the index on the distributed_faiss servers."""
        self.index = IndexClient(
            "./faiss_distribute_config.txt",
        )
        print("start to train index in distributed mode...")
        with timeit("train index in distributed mode"):
            self.index.create_index(self.distributed_index_name, IndexCfg(
                faiss_factory="OPQ16_64,IMI2x8,PQ8+16",
                dim=self.dimention,
                index_storage_dir="/home/featurize/faiss_client_index",
                metric="l2"
            ))
            self.index.add_index_data(self.distributed_index_name, embeddings)
            self.index.async_train(self.distributed_index_name)
            # Poll once per second and stop as soon as the servers report TRAINED.
            for _ in range(max_wait_s):
                if self.index.get_state(self.distributed_index_name) == IndexState.TRAINED:
                    break
                sleep(1)
            else:
                print(f"warning: index not TRAINED after {max_wait_s}s, continuing anyway")

    def _prepare_local(self, embeddings):
        """Build an in-process index (optionally on GPU 0), train it and add data."""
        self.index = faiss.index_factory(self.dimention, self.index_str)
        if self.use_gpu:
            self.gpu_res = faiss.StandardGpuResources()
            self.index = faiss.index_cpu_to_gpu(self.gpu_res, 0, self.index)
        if not self.index.is_trained:
            print("start to train index...")
            with timeit("training time"):
                self.index.train(embeddings)
        print("start to add index...")
        print(embeddings.shape)
        self.index.add(embeddings)

    def add(self, embeddings):
        """Append ``embeddings`` to the index built by ``prepare``."""
        if self.distributed:
            self.index.add_index_data(self.distributed_index_name, embeddings)
        else:
            self.index.add(embeddings)

    def query(self, embeddings, num=10):
        """Return the ids of the ``num`` nearest neighbours of each query row.

        NOTE(review): for distributed_faiss the second value returned by
        ``search`` may be per-hit metadata rather than row ids — confirm.
        """
        if self.distributed:
            _, ids = self.index.search(embeddings, num, self.distributed_index_name, return_embeddings=False)
        else:
            _, ids = self.index.search(embeddings, k=num)
        return ids

In [None]:
class VectorDB:
    """Benchmark harness that indexes face embeddings through an ``Index`` backend.

    The array handed to the backend is laid out as follows:
      - Its first rows are the real face embeddings, users in dict order and
        embeddings in list order.
      - Any further rows are random filler vectors that belong to no user.
    """

    def __init__(self, index: Index, embeddings: Dict[str, List[List[float]]], dimention=512):
        # embeddings: {"user_id": [embedding, embedding, embedding]}
        self.raw_embeddings = embeddings
        self.dimention = dimention
        self.index = index
        # user_lookup_table[i] is the owner of index row i (real rows only).
        self.user_lookup_table = [
            user_id
            for user_id, embeddings_list in self.raw_embeddings.items()
            for _ in embeddings_list
        ]

    def build_index(self, number_of_extra_embedding):
        """Fill the backend with the real embeddings padded by random vectors.

        Args:
            number_of_extra_embedding: total number of rows to index (real +
                random). It is raised to the number of real embeddings when
                smaller, so the real embeddings always fit.

        Raises:
            ValueError: if no user has any embedding.
        """
        face_embeddings = [
            embeddings_list
            for embeddings_list in self.raw_embeddings.values()
            if len(embeddings_list) > 0
        ]
        if not face_embeddings:
            raise ValueError("no face embeddings to index")
        face_embeddings = np.concatenate(face_embeddings)[:, :self.dimention].astype(np.float32)
        total = max(int(number_of_extra_embedding), len(face_embeddings))

        print("create random embeddings...")
        embeddings = random_float32_array(total, self.dimention)
        embeddings[:len(face_embeddings)] = face_embeddings
        print("random embeddings created")

        print("embeddings shape: ", embeddings.shape)
        print("user_lookup_table len: ", len(self.user_lookup_table))
        self.index.prepare(embeddings)

    def search(self, embeddings, num=10):
        """Query the backend with ``embeddings`` truncated to ``self.dimention`` columns."""
        ret = self.index.query(embeddings[:, :self.dimention], num)
        return ret

    def test_add_embeddings(self, num):
        """Time appending ``num`` random vectors to the backend and print it."""
        start = time()
        new_tokens = np.random.rand(num, self.dimention).astype(np.float32)
        self.index.add(new_tokens)
        time_used = time() - start
        print(f"add {num} tokens used {time_used}")

    def index2user_id(self, index):
        """Map an index row id to its user id.

        Returns ``"unknown"`` for random filler rows and for negative ids
        (e.g. the ``-1`` padding backends use for missing results).
        """
        if 0 <= index < len(self.user_lookup_table):
            return self.user_lookup_table[index]
        return "unknown"

    def evaluate_map(self):
        """Print the mean average precision of same-user retrieval and query latency.

        Every embedding of every user with at least two embeddings is used as
        a query for its ``k`` nearest rows, where ``k`` is the largest number
        of embeddings any single user has. The first hit is dropped (assumed
        to be the query itself). Latency is the time spent in ``search`` only,
        reported as seconds per 1000 queries.
        """
        k = max(len(x) for x in self.raw_embeddings.values())
        scores = []
        num_queries = 0
        search_time = 0.0
        for user_id, embeddings in tqdm(self.raw_embeddings.items()):
            if len(embeddings) < 2:
                continue
            gt = [True] * (len(embeddings) - 1) + [False] * (k - len(embeddings))
            for embedding in embeddings:
                query = np.array(embedding, dtype=np.float32)[np.newaxis, :]
                start = time()
                results = self.search(query, num=k)[0]
                search_time += time() - start
                num_queries += 1
                # Drop the first hit, assumed to be the query embedding itself.
                hits = [self.index2user_id(int(pred)) == user_id for pred in results[1:]]
                # NOTE(review): `gt` (the ideal ranking mask) is passed as
                # y_true and the binary same-user hits as y_score; confirm this
                # is the intended AP definition.
                scores.append(average_precision_score(gt, hits))
        if not scores:
            print("mAP: no user has >= 2 embeddings, nothing to evaluate")
            return
        mean_ap = np.mean(scores)
        per_1k_queries = search_time / num_queries * 1000
        print("mAP: ", round(mean_ap, 4))
        print(f"performance: {round(per_1k_queries, 4)}s / 1kquerys")

    def benchmark(self, embedding_num=None):
        """Build the index, evaluate mAP/latency, then time two bulk additions.

        Args:
            embedding_num: total rows to index (real + random filler). When
                ``None``, only the real embeddings are indexed.
        """
        if embedding_num is None:
            print(f"benchmark of original embeddings")
            embedding_num = len(self.user_lookup_table)
        else:
            embedding_num = int(embedding_num)
            print(f"benchmark of {embedding_num} embeddings")
        self.build_index(embedding_num)
        self.evaluate_map()
        self.test_add_embeddings(1000)
        self.test_add_embeddings(10000)

In [None]:
# Load the {user_id: [embedding, ...]} dict written by the (commented-out)
# DeepFace cell above. pickle can execute arbitrary code while loading, so
# only open a file you produced yourself.
with open("./embeddings.pkl", "rb") as pkl_file:
    embeddings = pickle.load(pkl_file)

In [None]:
# Vector length used by every backend; VectorDB slices each embedding to
# this many leading components before indexing or querying.
EMBEDDING_DIM = 512

# Pick exactly one backend to benchmark:
# index = FaissIndex(index_str="OPQ16_64,IMI2x8,PQ8+16", dimention=EMBEDDING_DIM)
# index = FaissIndex(index_str="IVF4096,Flat", dimention=EMBEDDING_DIM)
# index = MilvusIndex(dimention=EMBEDDING_DIM)
index = EsIndex(dimention=EMBEDDING_DIM)
vdb = VectorDB(index, embeddings, dimention=EMBEDDING_DIM)

In [None]:
# 10_000 rows in total: the real face embeddings followed by random filler vectors.
vdb.benchmark(10_000)

In [None]:
# # 单样本预测
# lookup_index = 220
# res = vdb.search(vdb.embeddings[lookup_index:lookup_index+1], num=5)
# print("Query user: ", vdb.user_lookup_table[lookup_index])
# print("Top 10 res: ", ", ".join([vdb.user_lookup_table[k] for k in res[0]]))