In [39]:
import json
import os
from pathlib import Path
from typing import Dict, List

import numpy as np
from loguru import logger
from tqdm.auto import tqdm

## Calculate centroids for all collections

In [15]:
# Directory holding one `.jsonl` embeddings file per collection.
embeddings_dir = Path("../database/embeddings")

# Validate the directory BEFORE iterating it: `iterdir()` raises on a missing
# directory, so checking afterwards would never reach the log message.
if not embeddings_dir.exists():
    logger.error(f"Embeddings dir not found: {embeddings_dir}.")
    raise FileNotFoundError(f"Embeddings dir not found: {embeddings_dir}.")

# Sorted so the processing order (and therefore the key order of the resulting
# centroid dict) is the same on every run, regardless of filesystem ordering.
embeddings_paths = sorted(
    path for path in embeddings_dir.iterdir() if path.suffix == ".jsonl"
)

In [21]:
def load_embeddings(path: Path) -> List:
    """Load every embedding vector stored in a JSON Lines file.

    Each non-blank line must be a JSON array whose second element has the
    shape ``{"data": [{"embedding": [...]}, ...]}``; the first entry of
    ``data`` supplies the vector for that line.

    Args:
        path: Location of the ``.jsonl`` file to read.

    Returns:
        A list with one embedding (as decoded from JSON) per non-blank line,
        in file order.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    if not path.exists():
        logger.error(f"File does not exist: {path}.")
        # Fail here with a clear message instead of falling through to open().
        raise FileNotFoundError(f"File does not exist: {path}.")

    embeddings = []
    with open(path, "r", encoding="utf-8") as file:
        for line in file:
            # Blank lines (e.g. a trailing newline at EOF) are not records.
            if not line.strip():
                continue
            embeddings.append(json.loads(line)[1]["data"][0]["embedding"])

    return embeddings

In [34]:
# Map each collection name (the file stem) to the mean of its embeddings.
collection_centroids = {}
for file_path in tqdm(embeddings_paths, desc="Calculating centroids"):
    embeddings = load_embeddings(file_path)

    # The mean of an empty array is NaN; a NaN centroid would end up in the
    # saved JSON and corrupt similarity ranking later, so skip such files.
    if not embeddings:
        logger.warning(f"No embeddings found in {file_path}; skipping.")
        continue

    centroid = np.asarray(embeddings, dtype=float).mean(axis=0)
    collection_centroids[file_path.stem] = centroid.tolist()

Calculating centroids:   0%|          | 0/5 [00:00<?, ?it/s]

In [36]:
# Output file for the computed centroids; the path is relative, so it resolves
# against the current working directory.
centroid_path = Path("./collection_centroids.json")

Save centroids

In [38]:
# Write the centroid mapping to disk as indented JSON.
with centroid_path.open("w", encoding="utf-8") as centroid_file:
    serialized_centroids = json.dumps(collection_centroids, indent=4)
    centroid_file.write(serialized_centroids)

### Query router dev

In [46]:
def rout_query(centroids: Dict, query_embedding: List) -> str:
    """Pick the collection whose centroid is most similar to a query embedding.

    Similarity is the cosine of the angle between the query vector and each
    centroid.

    Args:
        centroids: Mapping of collection name to its centroid vector. All
            centroids must have the same length as ``query_embedding``.
        query_embedding: The embedding vector of the query.

    Returns:
        The key of ``centroids`` with the highest cosine similarity. Ties go
        to the collection that appears first in ``centroids``.

    Raises:
        ValueError: If ``centroids`` is empty.
    """
    if not centroids:
        raise ValueError("centroids must contain at least one collection.")

    collection_names = list(centroids.keys())
    centroid_matrix = np.asarray(list(centroids.values()), dtype=float)
    query_vector = np.asarray(query_embedding, dtype=float)

    dot_products = centroid_matrix @ query_vector
    norm_products = np.linalg.norm(centroid_matrix, axis=1) * np.linalg.norm(
        query_vector
    )

    # Cosine similarity is undefined for zero-length (or NaN) vectors. A plain
    # division would yield NaN there, and np.argmax treats NaN as the maximum,
    # silently selecting a bogus collection. Rank those entries last instead.
    cosine_similarities = np.full(dot_products.shape, -np.inf)
    np.divide(
        dot_products,
        norm_products,
        out=cosine_similarities,
        where=norm_products > 0,
    )

    best_index = int(np.argmax(cosine_similarities))
    return collection_names[best_index]

In [47]:
# Sanity check: a collection's own centroid should be routed back to it.
rout_query(collection_centroids, collection_centroids["porodicni_zakon"])

'porodicni_zakon'