In [1]:
import numpy as np
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

In [2]:
def l2_normalize(X: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """
    Scale vectors to unit L2 norm.

    X: (n, d) matrix whose rows are normalized independently, or a single
        (d,) vector.
    eps: lower bound applied to every norm before dividing, so an all-zero
        vector comes back as zeros instead of NaN/inf.

    Returns an array with the same shape as X.
    """
    X = np.asarray(X)
    # A 1-D input is one vector, so its norm is taken over its only axis;
    # 2-D input keeps the original row-wise behaviour.
    axis = 0 if X.ndim == 1 else 1
    norms = np.linalg.norm(X, axis=axis, keepdims=True)
    return X / np.maximum(norms, eps)

def cosine_sim(a: np.ndarray, b: np.ndarray) -> float:
    """
    Cosine similarity between two vectors.

    a, b: arrays of shape (d,).
    A constant 1e-12 is added to the product of the norms so that a zero
    vector yields 0.0 instead of a division by zero.
    """
    dot = a @ b
    scale = np.linalg.norm(a) * np.linalg.norm(b) + 1e-12
    return float(dot / scale)

In [3]:
def spherical_kmeans(X: np.ndarray, k: int, iters: int = 20, seed: int = 0) -> Tuple[np.ndarray, np.ndarray]:
    """
    Spherical k-means: cluster rows of X by dot-product similarity.

    X: (n, d) array, assumed to be L2-normalized already (not checked here).
    k: number of clusters; k distinct rows are sampled as seeds, so k must
        not exceed n.
    iters: maximum number of assign/update rounds.
    seed: seed for the random generator used for seeding and for re-seeding
        empty clusters.

    Returns (labels, centroids): labels is (n,) with the index of the most
    similar centroid for every row; centroids is (k, d), L2-normalized.
    """
    num_points, _ = X.shape
    rng = np.random.default_rng(seed)

    # Seeding: k distinct rows drawn uniformly at random (plain random
    # initialisation, not k-means++).
    seed_rows = rng.choice(num_points, size=k, replace=False)
    centroids = X[seed_rows].copy()

    for _ in range(iters):
        # Assignment step: most similar centroid per row.
        assignment = np.argmax(X @ centroids.T, axis=1)
        cluster_sizes = np.bincount(assignment, minlength=k)

        # Update step.
        updated = np.zeros_like(centroids)
        for cluster in range(k):
            if cluster_sizes[cluster] > 0:
                mean_vec = X[assignment == cluster].mean(axis=0)
                # Re-project the mean onto the unit sphere so dot products
                # against it remain cosine similarities.
                updated[cluster] = mean_vec / (np.linalg.norm(mean_vec) + 1e-12)
            else:
                # Empty cluster: restart it from a random data point.
                updated[cluster] = X[rng.integers(0, num_points)]

        converged = np.allclose(updated, centroids, atol=1e-6)
        centroids = updated
        if converged:
            break

    # Final assignment against the centroids that are actually returned.
    final_labels = np.argmax(X @ centroids.T, axis=1)
    return final_labels, l2_normalize(centroids)

In [18]:
@dataclass
class Node:
    """Tree node: an inner node carries centroids and children, a leaf carries point indices."""
    is_leaf: bool
    centroids: Optional[np.ndarray] = None     # (B,d) when the node is inner
    children: List["Node"] = field(default_factory=list)
    idxs: Optional[np.ndarray] = None          # point ids stored in a leaf
    nid: int = -1                              # -1 until KMeansTree._new_node assigns an id

class KMeansTree:
    """
    Hierarchical spherical k-means index for approximate cosine search.

    fit() recursively partitions the (L2-normalized) points with
    spherical_kmeans; query() descends greedily to a single leaf and ranks
    only the points stored there, so results are approximate and a leaf
    smaller than topN yields fewer than topN results.
    """

    def __init__(self, B: int = 8, max_depth: int = 3, min_leaf_size: int = 256, kmeans_iters: int = 20, seed: int = 0):
        """
        B: branching factor (number of clusters per inner node)
        max_depth: maximum depth of the tree (the root is depth 0)
        min_leaf_size: a node holding this many points or fewer is not split
            any further; leaves produced by a split can be smaller than this
        kmeans_iters: iteration cap passed to spherical_kmeans at every split
        seed: base random seed; child splits use seeds derived from it
        """
        self.B = B
        self.max_depth = max_depth
        self.min_leaf_size = min_leaf_size
        self.kmeans_iters = kmeans_iters
        self.seed = seed
        self._next_id = 0
        self.root: Optional["Node"] = None
        self.X: Optional[np.ndarray] = None  # normalized embeddings

    # -------- build --------
    def fit(self, X: np.ndarray):
        """
        Build the tree.

        X: (n,d) embeddings (cast to float32 and L2-normalized internally)
        """
        X = l2_normalize(X.astype(np.float32))
        self.X = X
        # Restart numbering so a refit does not continue the ids of the
        # previous tree.
        self._next_id = 0
        idxs = np.arange(X.shape[0], dtype=np.int64)
        self.root = self._build_recursive(idxs, depth=0, seed=self.seed)

    def _new_node(self, **kwargs):
        """Create a Node from kwargs and give it the next sequential id."""
        node = Node(**kwargs)
        node.nid = self._next_id
        self._next_id += 1
        return node

    def _build_recursive(self, idxs: np.ndarray, depth: int, seed: int) -> "Node":
        """
        Build the subtree for the points in idxs.

        Children are created before their parent, so node ids are assigned
        in post-order (the root gets the largest id).
        """
        # Leaf condition: depth budget exhausted or few enough points.
        if depth >= self.max_depth or len(idxs) <= self.min_leaf_size:
            return self._new_node(is_leaf=True, idxs=idxs)

        # Never ask for more clusters than points; nothing to split with k <= 1.
        k = min(self.B, len(idxs))
        if k <= 1:
            return self._new_node(is_leaf=True, idxs=idxs)

        labels, centroids = spherical_kmeans(self.X[idxs], k=k, iters=self.kmeans_iters, seed=seed)
        children = []
        for j in range(k):
            child_idxs = idxs[labels == j]
            if len(child_idxs) == 0:
                # Placeholder leaf keeps children aligned with centroid rows;
                # routing skips it (see _route_greedy).
                children.append(self._new_node(is_leaf=True, idxs=np.array([], dtype=np.int64)))
            else:
                children.append(self._build_recursive(child_idxs, depth + 1, seed + j + 1))

        return self._new_node(is_leaf=False, centroids=centroids, children=children)

    # -------- query --------
    def query(self, q: np.ndarray, topN: int = 10) -> List[Tuple[int, float]]:
        """
        q: (d,) query embedding (array or sequence; normalized internally)
        topN: maximum number of results; values <= 0 give an empty list

        Returns a list of (idx, similarity) sorted by descending similarity,
        restricted to the single leaf the query is routed to.
        """
        assert self.root is not None and self.X is not None, "Primero llama a fit()"
        if topN <= 0:
            return []
        q = l2_normalize(np.asarray(q, dtype=np.float32).reshape(1, -1))[0]
        leaf = self._route_greedy(q, self.root)
        if leaf.idxs.size == 0:
            return []
        sims = self.X[leaf.idxs] @ q  # rows are already normalized
        # Partial selection first, exact ordering of the survivors after.
        k = min(topN, sims.size)
        top_idx = np.argpartition(-sims, k - 1)[:k]
        pairs = list(zip(leaf.idxs[top_idx].tolist(), sims[top_idx].tolist()))
        pairs.sort(key=lambda t: -t[1])
        return pairs

    @staticmethod
    def _is_empty_leaf(node) -> bool:
        """True for a leaf that stores no point indices."""
        return node.is_leaf and (node.idxs is None or len(node.idxs) == 0)

    def _route_greedy(self, q: np.ndarray, node: "Node") -> "Node":
        """
        Descend from node to a leaf, following the most similar centroid at
        every level. Empty placeholder leaves are skipped so the search does
        not dead-end while a sibling still holds points.
        """
        while not node.is_leaf:
            sims = node.centroids @ q  # cosine (centroids and q are unit norm)
            usable = np.array([not self._is_empty_leaf(ch) for ch in node.children], dtype=bool)
            if usable.any():
                sims = np.where(usable, sims, -np.inf)
            node = node.children[int(np.argmax(sims))]
        return node

    def print_tree(self, max_children: int = 8):
        """
        Print an indented outline of the tree.

        max_children: at most this many children are shown per inner node
            (first half and last half); the rest are summarized in one line.
        """
        assert self.root is not None
        limit = max(0, max_children)
        n_head = limit // 2
        n_tail = limit - n_head

        def _rec(node, depth):
            indent = "  " * depth
            if node.is_leaf:
                size = 0 if node.idxs is None else len(node.idxs)
                print(f"{indent}• [Leaf nid={node.nid}] size={size}")
                return
            k = 0 if node.centroids is None else node.centroids.shape[0]
            print(f"{indent}◦ [Inner nid={node.nid}] k={k}")
            children = node.children
            if len(children) > limit:
                head = children[:n_head]
                # Explicit start index: a negative slice with n_tail == 0
                # would select every child instead of none.
                tail = children[len(children) - n_tail:]
                for ch in head:
                    _rec(ch, depth + 1)
                print(f"{indent}  ... ({len(children)-len(head)-len(tail)} hijos omitidos) ...")
                for ch in tail:
                    _rec(ch, depth + 1)
            else:
                for ch in children:
                    _rec(ch, depth + 1)

        _rec(self.root, 0)

    def tree_stats(self):
        """
        Return a dict with node count, leaf count, deepest level reached and
        min / median / max leaf size (empty leaves count as size 0).
        """
        assert self.root is not None
        num_nodes = 0
        num_leaves = 0
        max_depth = 0
        leaf_sizes = []

        def _rec(node, depth):
            nonlocal num_nodes, num_leaves, max_depth
            num_nodes += 1
            max_depth = max(max_depth, depth)
            if node.is_leaf:
                num_leaves += 1
                leaf_sizes.append(0 if node.idxs is None else len(node.idxs))
            else:
                for ch in node.children:
                    _rec(ch, depth + 1)

        _rec(self.root, 0)
        leaf_sizes = np.array(leaf_sizes)
        return {
            "num_nodes": num_nodes,
            "num_leaves": num_leaves,
            "max_depth": max_depth,
            "leaf_size_min": int(leaf_sizes.min()) if len(leaf_sizes) else 0,
            "leaf_size_med": float(np.median(leaf_sizes)) if len(leaf_sizes) else 0,
            "leaf_size_max": int(leaf_sizes.max()) if len(leaf_sizes) else 0,
        }

In [10]:
# Synthetic data: 5 "topic" clusters, each one a normalized random center
# plus Gaussian noise, re-normalized row by row.
rng = np.random.default_rng(123)
n, d = 5000, 128
centers = l2_normalize(rng.normal(size=(5, d)))
points_per_cluster = n // 5
X = np.vstack([
    l2_normalize(center + 0.2 * rng.normal(size=(points_per_cluster, d)))
    for center in centers
]).astype(np.float32)
print("Datos creados:", X.shape)

Datos creados: (5000, 128)


In [54]:
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings

# Embedding model and on-disk FAISS store built with that same model.
EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
VECTORSTORE_DIR = "../data/db/parliament_db/parliament_all_docs_embeddings_sentence-transformers_paraphrase-multilingual-mpnet-base-v2"

# NOTE(review): the device is pinned to "cuda", so this cell needs a GPU.
model = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME, model_kwargs={"device": "cuda"})

# NOTE(review): allow_dangerous_deserialization=True opts in to unsafe
# deserialization of the stored files; only load stores you created yourself.
loaded_vectorstore = FAISS.load_local(VECTORSTORE_DIR, model, allow_dangerous_deserialization=True)

print(f"Loaded vector store contains {loaded_vectorstore.index.ntotal} vectors")

# Index size and embedding dimensionality, reused by later cells.
n = loaded_vectorstore.index.ntotal
d = loaded_vectorstore.index.d

Loaded vector store contains 11162 vectors


In [59]:
# Map every index position to a document id taken from the stored metadata.
ID_METADATA_KEYS = ("id", "doc_id", "document_id", "uid")  # checked in this order

pos_to_store_id = [loaded_vectorstore.index_to_docstore_id[pos] for pos in range(n)]


def _doc_id_for(store_id):
    """Return the first id-like metadata value of the stored document, or store_id if none is present."""
    doc = loaded_vectorstore.docstore.search(store_id)
    md = {} if doc is None else getattr(doc, "metadata", {})
    return next((md[key] for key in ID_METADATA_KEYS if key in md), store_id)


doc_ids = [_doc_id_for(store_id) for store_id in pos_to_store_id]
pos_to_doc_id = dict(enumerate(doc_ids))

In [61]:
# Pull every stored vector back out of the index, in position order.
faiss_index = loaded_vectorstore.index
X = np.array(list(faiss_index.reconstruct_n(0, faiss_index.ntotal)))
print(f"All embeddings shape: {X.shape}")

All embeddings shape: (11162, 768)


In [62]:
# Tree hyper-parameters gathered in one place so they are easy to tune.
tree_params = dict(B=8, max_depth=3, min_leaf_size=200, kmeans_iters=15, seed=42)
tree = KMeansTree(**tree_params)
tree.fit(X)

In [63]:
# Outline of the fitted tree; nodes with more than 10 children are abbreviated.
tree.print_tree(max_children=10)

◦ [Inner nid=224] k=8
  ◦ [Inner nid=48] k=8
    ◦ [Inner nid=8] k=8
      • [Leaf nid=0] size=30
      • [Leaf nid=1] size=41
      • [Leaf nid=2] size=42
      • [Leaf nid=3] size=72
      • [Leaf nid=4] size=53
      • [Leaf nid=5] size=53
      • [Leaf nid=6] size=53
      • [Leaf nid=7] size=40
    ◦ [Inner nid=17] k=8
      • [Leaf nid=9] size=19
      • [Leaf nid=10] size=4
      • [Leaf nid=11] size=3
      • [Leaf nid=12] size=21
      • [Leaf nid=13] size=30
      • [Leaf nid=14] size=59
      • [Leaf nid=15] size=39
      • [Leaf nid=16] size=33
    ◦ [Inner nid=26] k=8
      • [Leaf nid=18] size=24
      • [Leaf nid=19] size=41
      • [Leaf nid=20] size=58
      • [Leaf nid=21] size=31
      • [Leaf nid=22] size=47
      • [Leaf nid=23] size=30
      • [Leaf nid=24] size=24
      • [Leaf nid=25] size=51
    • [Leaf nid=27] size=193
    • [Leaf nid=28] size=49
    ◦ [Inner nid=37] k=8
      • [Leaf nid=29] size=1
      • [Leaf nid=30] size=66
      • [Leaf nid=31] size=61
 

In [64]:
# Summary of the fitted tree: node/leaf counts, depth reached and leaf-size spread.
tree.tree_stats()

{'num_nodes': 225,
 'num_leaves': 197,
 'max_depth': 3,
 'leaf_size_min': 1,
 'leaf_size_med': 45.0,
 'leaf_size_max': 193}

In [65]:
from datasets import load_from_disk

# QA dataset saved on disk; the bare name on the last line displays it.
QA_DATASET_DIR = "../data/processed/parliament_qa"
dataset = load_from_disk(QA_DATASET_DIR)
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'question', 'response', 'cost', 'documents', 'type', 'retrieved_pks', 'oracle_context', 'formatted_context'],
        num_rows: 614
    })
    validation: Dataset({
        features: ['id', 'question', 'response', 'cost', 'documents', 'type', 'retrieved_pks', 'oracle_context', 'formatted_context'],
        num_rows: 161
    })
    test: Dataset({
        features: ['question', 'id', 'response', 'type', 'retrieved_pks', 'oracle_context', 'injected_oracle', 'formatted_context', 'documents'],
        num_rows: 205
    })
})

In [70]:
# Run one test question through the tree and list the retrieved document ids.
sample = dataset['test'][1]
query = sample['question']
id_real = sample['id']
print("Query:", query)
print("Real id:", id_real)

query_emb = np.array(model.embed_query(query))
results = tree.query(query_emb, topN=10)

print("Top-10 (idx, cos):")
for pos, score in results:
    # tree.query returns index positions; translate them to document ids.
    print(pos_to_doc_id[pos], f"{score:.4f}")
    # TODO: fetch the corresponding document from loaded_vectorstore
    

Query: ¿Qué argumentos presenta el presidente del Gobierno de Canarias, Clavijo Batlle, para justificar la necesidad urgente de recibir los fondos adeudados antes del cierre del presupuesto de 2024?
Real id: 6600_6
Top-10 (idx, cos):
6598_7 0.7116
6051_27 0.7111
6529_6 0.7007
5930_30 0.6925
6176_7 0.6867
6135_9 0.6859
6581_4 0.6842
6230_26 0.6835
6049_30 0.6775
6532_27 0.6718


In [71]:
# Baseline for comparison: search the vector store directly with the same
# query embedding used for the tree above.
retrival_dense = loaded_vectorstore.similarity_search_by_vector(query_emb, k=10)
print("\nTop-10 from dense retrival (store_id):")
for doc in retrival_dense:
    # Prints the "id" metadata field of each hit ("N/A" when it is missing).
    print(doc.metadata.get("id", "N/A"))


Top-10 from dense retrival (store_id):
5862_28
6598_7
6051_27
5472_1
5401_20
5418_11
6529_6
5687_7
5886_22
5418_21
