In [1]:
import os
from dataclasses import dataclass

import requests
from elasticsearch import Elasticsearch

In [20]:
import os

# Switch the working directory to the parent of the directory the kernel
# started in (presumably the project root, so that project-relative imports
# and output paths resolve -- TODO confirm).
#
# The start directory is remembered on first execution so that re-running this
# cell is idempotent; the plain `os.chdir(os.path.dirname(os.getcwd()))` form
# climbs one additional level on every re-run.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.dirname(NOTEBOOK_START_DIR))

In [2]:
@dataclass
class Config:
    """Settings for the Elasticsearch connection used in this notebook.

    Every field has a plain class-level default, so the values can be read
    either from an instance (``Config().ES_URL``) or directly from the class
    (``Config.ES_URL``).

    The login id and password are taken from the ``ES_LOGIN_ID`` and
    ``ES_LOGIN_PWD`` environment variables when those are set; the environment
    is read once, when the class is defined.
    """

    # Base URL and port; they are joined as "<ES_URL>:<ES_PORT>" by the client wrapper.
    ES_URL: str = "http://flush-es-v3.dev.omnious.com"
    ES_PORT: int = 30313
    # SECURITY: credentials should come from the environment. The literal
    # fallbacks are kept only so existing runs keep working -- rotate the
    # password and remove the fallback before sharing this notebook.
    ES_LOGIN_ID: str = os.environ.get("ES_LOGIN_ID", "elastic")
    ES_LOGIN_PWD: str = os.environ.get("ES_LOGIN_PWD", "5P44CV414Pxb4f7R7nfU1rkq")
    # Index queried by the export cell.
    ES_INDEX: str = "01hz64dqt6h8fqkmm78eyz52sd"

    # Passed to the client as `max_retries` and `timeout` respectively.
    ES_RETRY_COUNT: int = 3
    ES_TIMEOUT: int = 2
    # Not referenced by any of the visible cells.
    ES_MAX_COUNT: int = 1000

In [4]:
class ElasticsearchService:
    """Small wrapper that builds an ``Elasticsearch`` client from a config object
    and exposes a field-filter search helper. The client is available as ``self.es``.
    """

    def __init__(self, config, timeout: int = 30):
        """Create the underlying client.

        Args:
            config: Object exposing ``ES_URL``, ``ES_PORT``, ``ES_LOGIN_ID``,
                ``ES_LOGIN_PWD`` and ``ES_RETRY_COUNT``. ``ES_TIMEOUT`` is used
                as the client timeout when the attribute is present.
            timeout: Fallback client timeout, used only when ``config`` does
                not define ``ES_TIMEOUT``.
        """
        self.es = Elasticsearch(
            f"{config.ES_URL}:{config.ES_PORT}",
            http_auth=(config.ES_LOGIN_ID, config.ES_LOGIN_PWD),
            # The config value takes precedence, so configs that define
            # ES_TIMEOUT behave exactly as before; `timeout` is no longer dead.
            timeout=getattr(config, "ES_TIMEOUT", timeout),
            max_retries=config.ES_RETRY_COUNT,
            retry_on_timeout=True,
        )

    def search(self, index: str, query: list, size: int = 10):
        """Search ``index`` with a ``bool``/``should`` query built from field filters.

        Args:
            index: Index name forwarded to the client.
            query: Two-item sequence ``(context, filters)``. ``context`` is only
                echoed in the error message; ``filters`` is a mapping of field
                name to value. A list value becomes a ``terms`` clause, any
                other value a ``term`` clause.
            size: Value forwarded as the ``size`` of the search request.

        Returns:
            The list found under ``hits.hits`` in the response, or an empty
            list when that path is missing.

        Raises:
            Exception: Wraps any error raised while building or running the
                search; the original error is attached as ``__cause__``.
        """
        context, filters = query
        try:
            should_clauses = [
                {"terms": {field: value}}
                if isinstance(value, list)
                else {"term": {field: value}}
                for field, value in filters.items()
            ]
            response = self.es.search(
                index=index,
                query={"bool": {"should": should_clauses}},
                # These fields are excluded from the returned `_source`.
                source_excludes=[
                    "clip_features_l2_norm",
                    "embedded_features_l2_norm",
                    "duplicate_features_l2_norm",
                ],
                size=size,
            )
            return response.get("hits", dict()).get("hits", list())
        except Exception as e:
            # Chain explicitly so the traceback shows the original client error as the cause.
            raise Exception(
                f"{e}, check the result : {context}, query : {filters}"
            ) from e

In [3]:
# Instantiate the settings instead of aliasing the dataclass itself, so later
# attribute assignments on `config` cannot silently mutate the class defaults.
config = Config()

# Client wrapper used by the cells below; the raw client is `es.es`.
es = ElasticsearchService(config=config)

In [6]:
# Fetch every document that has the `semantic_vectors_l2_norm` field, paging
# through the result set with the scroll API and accumulating hits in `hits`.
query = {
    "bool": {
        "must": {
            "exists": {
                "field": "semantic_vectors_l2_norm"
            }
        }
    }
}
hits = []
scroll_id_list = []

# First page: opens the scroll context (kept alive for "1m" per request).
es_results = es.es.search(
    index=config.ES_INDEX,
    query=query,
    size=10000,
    scroll="1m",
    track_total_hits=True
)
tmp_hits = es_results.get("hits", dict()).get("hits", list())
hits.extend(tmp_hits)

scroll_id = es_results["_scroll_id"]
scroll_id_list.append(scroll_id)
try:
    # Keep requesting pages until one comes back empty.
    while len(tmp_hits):
        es_results = es.es.scroll(
            scroll_id=scroll_id_list[-1],
            scroll="1m"
        )
        tmp_scroll_id = es_results["_scroll_id"]
        scroll_id_list.append(tmp_scroll_id)
        tmp_hits = es_results.get("hits", dict()).get("hits", list())
        hits.extend(tmp_hits)
finally:
    # Release the server-side scroll context instead of leaving it to expire.
    # Cleanup is best-effort: a failure here is reported but must not mask
    # the outcome of the scroll loop itself.
    try:
        es.es.clear_scroll(scroll_id=scroll_id_list[-1])
    except Exception as cleanup_error:
        print(f"warning: could not clear scroll context: {cleanup_error}")

In [13]:
# Split the collected hits into two parallel lists: product ids and vectors.
# NOTE(review): the search cell filters on `semantic_vectors_l2_norm`, while
# the vector read here is `embedded_features_l2_norm` -- confirm this is the
# intended field (a hit without it raises KeyError).
sources = [hit["_source"] for hit in hits]
pids = [source["productId"] for source in sources]
embeddings = [source["embedded_features_l2_norm"] for source in sources]

In [18]:
from jovis_model.utils.helper import build_faiss_index

In [21]:
# Build the index from the collected vectors and their product ids; the save
# location and file name are forwarded to the helper as given below.
faiss_index_kwargs = dict(
    embeddings=embeddings,
    save_path="outputs/skb",
    save_name="KB_with_des_cohere",
    pids=pids,
)
build_faiss_index(**faiss_index_kwargs)