In [2]:
import numpy as np
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

In [3]:
def l2_normalize(X: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    return X / np.maximum(norms, eps)

def cosine_sim(a: np.ndarray, b: np.ndarray) -> float:
    # a, b: (d,)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))

In [4]:
def spherical_kmeans(X: np.ndarray, k: int, iters: int = 20, seed: int = 0) -> Tuple[np.ndarray, np.ndarray]:
    """
    X: (n,d) asumimos L2-normalizado; devuelve (labels, centroids(L2=1))
    """
    n, d = X.shape
    rng = np.random.default_rng(seed)
    # init++ simple: elige k puntos aleatorios
    centroids = X[rng.choice(n, size=k, replace=False)].copy()

    for _ in range(iters):
        # asignación
        sims = X @ centroids.T               # (n,k)
        labels = np.argmax(sims, axis=1)     # (n,)
        # actualización
        new_centroids = np.zeros_like(centroids)
        counts = np.bincount(labels, minlength=k)
        for j in range(k):
            idx = (labels == j)
            if counts[j] == 0:
                # reinit a un punto aleatorio si vacío
                new_centroids[j] = X[rng.integers(0, n)]
            else:
                c = X[idx].mean(axis=0)
                # normaliza para mantener métrica de coseno
                nc = np.linalg.norm(c)
                new_centroids[j] = c / (nc + 1e-12)
        if np.allclose(new_centroids, centroids, atol=1e-6):
            centroids = new_centroids
            break
        centroids = new_centroids
    # asignación final
    sims = X @ centroids.T
    labels = np.argmax(sims, axis=1)
    return labels, l2_normalize(centroids)

In [5]:
import numpy as np

def _cluster_anisotropy(Xc: np.ndarray, eps: float = 1e-8) -> float:
    """
    Xc: puntos del cluster centrados (n_k, d) en float32.
    Devuelve A = 1 - sphericity in [0,1).
    """
    if Xc.shape[0] <= 1:
        return 0.0
    # Covarianza vía momentos (sin formar la matriz completa)
    # tr(Sigma) = sum_j Var_j ;  tr(Sigma^2) = ||Sigma||_F^2
    # Calculamos Sigma explícita solo si d es moderada; para d muy grande,
    # puedes estimar ||Sigma||_F^2 por muestreo.
    C = np.cov(Xc, rowvar=False)
    tr1 = np.trace(C)
    tr2 = np.sum(C * C)  # Frobenius^2
    d = C.shape[0]
    if tr2 <= eps:
        return 0.0
    sphericity = (tr1 * tr1) / (d * tr2 + eps)
    sphericity = float(np.clip(sphericity, 0.0, 1.0))
    return 1.0 - sphericity  # anisotropía

def spherical_kmeans_iso(X: np.ndarray, k: int, iters: int = 20, seed: int = 0,
                         lambda_iso: float = 0.02, reg_mu: float = 1e-3):
    """
    Spherical K-Means con sesgo de asignación por anisotropía de cluster.
    X: (n,d) L2-normalizado.
    k: nº de clusters.
    lambda_iso: peso del regularizador (0 => K-Means estándar).
    """
    n, d = X.shape
    rng = np.random.default_rng(seed)
    # init: elige k puntos aleatorios como centroides
    centroids = X[rng.choice(n, size=k, replace=False)].copy()

    for _ in range(iters):
        # ----- ASIGNACIÓN con sesgo por anisotropía -----
        # Similitud coseno equivale a distancia euclídea en L2=1:
        # argmax cos <=> argmin ||x - c||^2 = 2(1 - cos) (constante por x)
        sims = X @ centroids.T                    # (n,k)
        # sesgo por cluster (igual para todos los puntos, cambia cada iter.)
        # estimamos anisotropía del cluster "actual" usando las asignaciones previas;
        # en la 1ª iteración no tenemos labels: usa 0.
        if _ == 0 or 'labels' not in locals():
            bias = np.zeros(k, dtype=np.float32)
        else:
            bias = np.zeros(k, dtype=np.float32)
            for j in range(k):
                idx = (labels == j)
                if not np.any(idx):
                    bias[j] = 0.0
                else:
                    Xj = X[idx]
                    mu = Xj.mean(axis=0, dtype=np.float32)
                    Xc = (Xj - mu).astype(np.float32)
                    A = _cluster_anisotropy(Xc)
                    bias[j] = lambda_iso * A
        # Convertimos sesgo a espacio “distancia”: distancia_eff ≈ 2(1 - cos) + bias
        # Como 2 y el término constante no afectan al argmin, basta con:
        eff = -sims + bias  # menor es mejor
        labels = np.argmin(eff, axis=1)

        # ----- ACTUALIZACIÓN de centroides -----
        new_centroids = np.zeros_like(centroids)
        counts = np.bincount(labels, minlength=k)
        for j in range(k):
            idx = (labels == j)
            if counts[j] == 0:
                new_centroids[j] = X[rng.integers(0, n)]
            else:
                c = X[idx].mean(axis=0)
                # opcional: pequeña retracción hacia 0 para evitar medios muy sesgados
                if reg_mu > 0:
                    c = c * (1.0 - reg_mu)
                new_centroids[j] = c / (np.linalg.norm(c) + 1e-12)

        # parada por convergencia de centroides
        if np.allclose(new_centroids, centroids, atol=1e-6):
            centroids = new_centroids
            break
        centroids = new_centroids

    # asignación final
    sims = X @ centroids.T
    labels = np.argmax(sims, axis=1)
    return labels, (centroids / (np.linalg.norm(centroids, axis=1, keepdims=True) + 1e-12))


In [6]:
@dataclass
class Node:
    is_leaf: bool
    centroids: Optional[np.ndarray] = None     # (B,d) si interno
    children: List["Node"] = field(default_factory=list)
    idxs: Optional[np.ndarray] = None          # ids en hoja
    nid: int = -1

class KMeansTree:
    def __init__(self, B: int = 8, max_depth: int = 3, min_leaf_size: int = 256, kmeans_iters: int = 20, seed: int = 0, iso: bool = False):
        """
        B: branching factor
        max_depth: profundidad máxima del árbol
        min_leaf_size: tamaño mínimo para no seguir dividiendo
        """
        self.B = B
        self.max_depth = max_depth
        self.min_leaf_size = min_leaf_size
        self.kmeans_iters = kmeans_iters
        self.seed = seed
        self._next_id = 0
        self.iso = iso
        self.root: Optional[Node] = None
        self.X: Optional[np.ndarray] = None  # embeddings normalizados
        self._id2leaf = None   # dict: int -> Node
        self._id2path = None   # dict: int -> tuple(int, ...)

    # -------- build --------
    # --- helpers internos ---
    def _build_id_maps(self):
        """Construye mapas doc_id -> hoja y doc_id -> path (tupla de índices hijo)."""
        self._id2leaf = {}
        self._id2path = {}

        def dfs(node, path_prefix):
            if node.is_leaf:
                if node.idxs is None:
                    return
                for doc_id in node.idxs.tolist():
                    self._id2leaf[doc_id] = node
                    self._id2path[doc_id] = tuple(path_prefix)
                return
            # recorre hijos guardando el índice de hijo en la ruta
            for child_idx, ch in enumerate(node.children):
                dfs(ch, path_prefix + [child_idx])

        dfs(self.root, [])

    def fit(self, X: np.ndarray):
        """
        X: (n,d) embeddings (se normalizan internamente a L2=1)
        """
        X = l2_normalize(X.astype(np.float32))
        self.X = X
        n = X.shape[0]
        idxs = np.arange(n, dtype=np.int64)
        self.root = self._build_recursive(idxs, depth=0, seed=self.seed)
        self._build_id_maps()

    def _new_node(self, **kwargs):
        node = Node(**kwargs)
        node.nid = self._next_id
        self._next_id += 1
        return node 

    def _build_recursive(self, idxs: np.ndarray, depth: int, seed: int) -> Node:
        # condición de hoja
        if depth >= self.max_depth or len(idxs) <= self.min_leaf_size:
            return self._new_node(is_leaf=True, idxs=idxs)

        X_sub = self.X[idxs]
        k = min(self.B, max(1, len(idxs)))  # por si hay pocos puntos
        if k == 1:
            return self._new_node(is_leaf=True, idxs=idxs)

        if self.iso:
            labels, centroids = spherical_kmeans_iso(X_sub, k=k, iters=self.kmeans_iters, seed=seed)
        else:
            labels, centroids = spherical_kmeans(X_sub, k=k, iters=self.kmeans_iters, seed=seed)
        children = []
        for j in range(k):
            child_idxs = idxs[labels == j]
            if len(child_idxs) == 0:
                # crea hoja vacía para mantener aridad (opcional)
                children.append(self._new_node(is_leaf=True, idxs=np.array([], dtype=np.int64)))
            else:
                children.append(self._build_recursive(child_idxs, depth + 1, seed + j + 1))

        # Si hay menos hijos que B por pocos datos, centramos en k real
        node = self._new_node(is_leaf=False, centroids=centroids[:k], children=children)
        return node
    
    # --- API pública ---
    def get_leaf_path(self, doc_id: int):
        """
        Devuelve la ruta desde la raíz hasta la hoja que contiene doc_id
        como lista de índices de hijo [i0, i1, ..., iL-1].
        Lanza KeyError si doc_id no existe.
        """
        if self._id2path is None:
            raise RuntimeError("El índice aún no ha sido construido. Llama a fit() primero.")
        try:
            return list(self._id2path[int(doc_id)])
        except KeyError:
            raise KeyError(f"doc_id {doc_id} no existe en este árbol")

    def get_leaf_node(self, doc_id: int):
        """
        Devuelve el objeto Node (hoja) que contiene doc_id.
        Lanza KeyError si doc_id no existe.
        """
        if self._id2leaf is None:
            raise RuntimeError("El índice aún no ha sido construido. Llama a fit() primero.")
        try:
            return self._id2leaf[int(doc_id)]
        except KeyError:
            raise KeyError(f"doc_id {doc_id} no existe en este árbol")

    # -------- query --------
    def query(self, q: np.ndarray, topN: int = 10) -> List[Tuple[int, float]]:
        """
        q: (d,) embedding de consulta (se normaliza)
        Devuelve lista de (idx, similitud) ordenada desc.
        """
        assert self.root is not None and self.X is not None, "Primero llama a fit()"
        q = l2_normalize(q.reshape(1, -1).astype(np.float32))[0]
        leaf = self._route_greedy(q, self.root)
        if leaf.idxs.size == 0:
            return []
        X_leaf = self.X[leaf.idxs]  # ya normalizados
        sims = X_leaf @ q
        # topN parcial eficiente
        k = min(topN, sims.size)
        top_idx = np.argpartition(-sims, k - 1)[:k]
        pairs = list(zip(leaf.idxs[top_idx].tolist(), sims[top_idx].tolist()))
        # ordena exacto
        pairs.sort(key=lambda t: -t[1])
        return pairs

    def _route_greedy(self, q: np.ndarray, node: Node) -> Node:
        while not node.is_leaf:
            C = node.centroids  # (k,d)
            sims = C @ q        # coseno (centroides y q están L2=1)
            j = int(np.argmax(sims))
            node = node.children[j]
        return node
    
    def print_tree(self, max_children: int = 8):
        assert self.root is not None
        def _rec(node, depth):
            indent = "  " * depth
            if node.is_leaf:
                size = 0 if node.idxs is None else len(node.idxs)
                print(f"{indent}• [Leaf nid={node.nid}] size={size}")
            else:
                k = 0 if node.centroids is None else node.centroids.shape[0]
                print(f"{indent}◦ [Inner nid={node.nid}] k={k}")
                # por si hay muchos hijos, recorta la muestra visual
                children = node.children
                if len(children) > max_children:
                    head = children[:max_children//2]
                    tail = children[-max_children//2:]
                    for ch in head: _rec(ch, depth+1)
                    print(f"{indent}  ... ({len(children)-len(head)-len(tail)} hijos omitidos) ...")
                    for ch in tail: _rec(ch, depth+1)
                else:
                    for ch in children:
                        _rec(ch, depth+1)
        _rec(self.root, 0)

    def tree_stats(self):
        assert self.root is not None
        num_nodes = 0
        num_leaves = 0
        max_depth = 0
        leaf_sizes = []

        def _rec(node, depth):
            nonlocal num_nodes, num_leaves, max_depth
            num_nodes += 1
            max_depth = max(max_depth, depth)
            if node.is_leaf:
                num_leaves += 1
                leaf_sizes.append(0 if node.idxs is None else len(node.idxs))
            else:
                for ch in node.children:
                    _rec(ch, depth+1)
        _rec(self.root, 0)
        leaf_sizes = np.array(leaf_sizes)
        return {
            "num_nodes": num_nodes,
            "num_leaves": num_leaves,
            "max_depth": max_depth,
            "leaf_size_min": int(leaf_sizes.min()) if len(leaf_sizes) else 0,
            "leaf_size_med": float(np.median(leaf_sizes)) if len(leaf_sizes) else 0,
            "leaf_size_max": int(leaf_sizes.max()) if len(leaf_sizes) else 0,
        }

In [7]:
class KMeansTreeBeamSearch(KMeansTree):
    def __init__(
        self,
        B: int = 8,
        max_depth: int = 3,
        min_leaf_size: int = 256,
        kmeans_iters: int = 20,
        seed: int = 0,
        iso: bool = False,
        # --- nuevos parámetros de búsqueda ---
        beam: int = 3,              # ancho máximo del beam
        tau_margin: float = 0.02,   # margen adaptativo s1-s2
        last_level_probe: int = 2,  # nº extra de hojas a probar en el último nivel
        last_level_delta: float = 0.01  # umbral de similitud respecto al mejor centroide
    ):
        super().__init__(B, max_depth, min_leaf_size, kmeans_iters, seed, iso)
        # búsqueda mejorada
        self.beam = max(1, beam)
        self.tau_margin = float(tau_margin)
        self.last_level_probe = max(0, last_level_probe)
        self.last_level_delta = float(last_level_delta)

     # -------- routing mejorado --------
    def _route_adaptive_beam(self, q: np.ndarray) -> List[Node]:
        assert self.root is not None
        frontier = [self.root]
        depth = 0

        while True:
            new_frontier: List[Node] = []
            all_leaves = True

            for node in frontier:
                if node.is_leaf:
                    new_frontier.append(node)
                    continue

                all_leaves = False
                C = node.centroids  # (k,d)
                sims = C @ q
                order = np.argsort(-sims)

                s1 = sims[order[0]]
                s2 = sims[order[1]] if len(order) > 1 else -1.0
                local_beam = 1 if (s1 - s2) >= self.tau_margin else self.beam

                # ¿sus hijos son hojas? (último salto)
                next_is_leaf = all(ch.is_leaf for ch in node.children)

                if not next_is_leaf:
                    chosen = order[:local_beam]
                    new_frontier.extend([node.children[i] for i in chosen])
                else:
                    # multi-probe: añade hasta last_level_probe extras
                    chosen = order[:local_beam]
                    # añade vecinos cercanos por umbral Δ
                    if self.last_level_probe > 0:
                        s_best = sims[order[0]]
                        extra = []
                        for j in order[local_beam:]:
                            if s_best - sims[j] <= self.last_level_delta:
                                extra.append(j)
                                if len(extra) >= self.last_level_probe:
                                    break
                        chosen = np.concatenate([chosen, np.array(extra, dtype=int)]) if len(extra) else chosen
                    new_frontier.extend([node.children[i] for i in chosen])

            frontier = new_frontier
            depth += 1
            if all_leaves or depth >= self.max_depth:
                break

            # si por beam nos quedamos sin nodos (poco probable), salimos
            if len(frontier) == 0:
                break

        # aquí frontier debe ser una lista de hojas candidatas
        # quitamos duplicadas por si algún nodo se añadió dos veces
        unique_leaves = []
        seen = set()
        for lf in frontier:
            key = id(lf)
            if key not in seen:
                seen.add(key)
                unique_leaves.append(lf)
        return unique_leaves

    # -------- búsqueda sobre hojas candidatas --------
    def query(self, q: np.ndarray, topN: int = 10) -> List[Tuple[int, float]]:
        assert self.root is not None and self.X is not None, "Primero llama a fit()"
        q = l2_normalize(q.reshape(1, -1).astype(np.float32))[0]
        leaves = self._route_adaptive_beam(q)

        cand_ids = []
        for leaf in leaves:
            if leaf.idxs is None or leaf.idxs.size == 0:
                continue
            cand_ids.extend(leaf.idxs.tolist())
        if not cand_ids:
            return []

        cand_ids = np.unique(np.array(cand_ids, dtype=np.int64))
        sims = self.X[cand_ids] @ q
        k = min(topN, sims.size)
        top_idx = np.argpartition(-sims, k - 1)[:k]
        pairs = list(zip(cand_ids[top_idx].tolist(), sims[top_idx].tolist()))
        pairs.sort(key=lambda t: -t[1])
        return pairs

In [9]:
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings

model_embeddings = HuggingFaceEmbeddings(
    model_name="Qwen/Qwen3-Embedding-0.6B",
    model_kwargs={"device": "cuda"}
)

loaded_vectorstore=FAISS.load_local(
    "../data/db/parliament_db/parliament_all_docs_embeddings_Qwen_Qwen3-Embedding-0.6B_chunked_max_length-512",
    model_embeddings,
    allow_dangerous_deserialization=True
)

print(f"Loaded vector store contains {loaded_vectorstore.index.ntotal} vectors")

n = loaded_vectorstore.index.ntotal
d = loaded_vectorstore.index.d  # dimensión de los embeddings

Loaded vector store contains 128518 vectors


In [10]:
# get metadata of the index
pos_to_store_id = [loaded_vectorstore.index_to_docstore_id[i] for i in range(n)]

doc_ids = []
for store_id in pos_to_store_id:
    doc = loaded_vectorstore.docstore.search(store_id)  # recupera el Document
    md = getattr(doc, "metadata", {}) if doc is not None else {}
    # Ajusta la prioridad de claves según cómo lo guardaste
    for key in ("id", "doc_id", "document_id", "uid"):
        if key in md:
            doc_ids.append(md[key])
            break
    else:
        # Si no hay id en metadata, usa el store_id como fallback
        doc_ids.append(store_id)
pos_to_doc_id = {i: doc_ids[i] for i in range(n)}

In [11]:
X = np.array([emb for emb in loaded_vectorstore.index.reconstruct_n(0, loaded_vectorstore.index.ntotal)])
print(f"All embeddings shape: {X.shape}")

All embeddings shape: (128518, 1024)


In [None]:
tree = KMeansTreeBeamSearch(B=8, max_depth=2, min_leaf_size=5, kmeans_iters=20, seed=42, iso=False)
tree.fit(X)
tree.print_tree(max_children=10)
tree.tree_stats()

◦ [Inner nid=72] k=8
  ◦ [Inner nid=8] k=8
    • [Leaf nid=0] size=594
    • [Leaf nid=1] size=3136
    • [Leaf nid=2] size=1249
    • [Leaf nid=3] size=1909
    • [Leaf nid=4] size=479
    • [Leaf nid=5] size=2469
    • [Leaf nid=6] size=1801
    • [Leaf nid=7] size=2191
  ◦ [Inner nid=17] k=8
    • [Leaf nid=9] size=3439
    • [Leaf nid=10] size=2498
    • [Leaf nid=11] size=1334
    • [Leaf nid=12] size=3851
    • [Leaf nid=13] size=3697
    • [Leaf nid=14] size=2002
    • [Leaf nid=15] size=2383
    • [Leaf nid=16] size=3642
  ◦ [Inner nid=26] k=8
    • [Leaf nid=18] size=2447
    • [Leaf nid=19] size=2663
    • [Leaf nid=20] size=2466
    • [Leaf nid=21] size=2789
    • [Leaf nid=22] size=1950
    • [Leaf nid=23] size=2146
    • [Leaf nid=24] size=2012
    • [Leaf nid=25] size=1211
  ◦ [Inner nid=35] k=8
    • [Leaf nid=27] size=2009
    • [Leaf nid=28] size=1512
    • [Leaf nid=29] size=859
    • [Leaf nid=30] size=1144
    • [Leaf nid=31] size=910
    • [Leaf nid=32] size=1664
 

{'num_nodes': 73,
 'num_leaves': 64,
 'max_depth': 2,
 'leaf_size_min': 405,
 'leaf_size_med': 1887.0,
 'leaf_size_max': 4477}

In [15]:
doc_id = 3
path = tree.get_leaf_path(doc_id)   # p.ej. [0, 3, 1]
leaf = tree.get_leaf_node(doc_id)   # nodo hoja que contiene ese id
print("ruta:", path, "id:", "".join(str(x) for x in path))
print("tamaño hoja:", len(leaf.idxs))

ruta: [7, 3] id: 73
tamaño hoja: 1959


In [16]:
from datasets import load_from_disk

dataset = load_from_disk("../data/processed/parliament_qa")
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'question', 'response', 'cost', 'documents', 'type', 'retrieved_pks', 'oracle_context', 'formatted_context'],
        num_rows: 614
    })
    validation: Dataset({
        features: ['id', 'question', 'response', 'cost', 'documents', 'type', 'retrieved_pks', 'oracle_context', 'formatted_context'],
        num_rows: 161
    })
    test: Dataset({
        features: ['question', 'id', 'response', 'type', 'retrieved_pks', 'oracle_context', 'injected_oracle', 'formatted_context', 'documents'],
        num_rows: 205
    })
})

In [17]:
# create labels in dataset using the tree
def assign_labels(example):
    doc_id = example['id']
    id = next((k for k, v in pos_to_doc_id.items() if v == doc_id), None)
    try:
        path = tree.get_leaf_path(id)
        label = "".join(str(x) for x in path)
    except KeyError:
        label = "unknown"
    return {"label": label}

dataset = dataset.map(assign_labels)

Map: 100%|██████████| 614/614 [00:00<00:00, 870.58 examples/s]
Map: 100%|██████████| 161/161 [00:00<00:00, 875.39 examples/s]
Map: 100%|██████████| 205/205 [00:00<00:00, 746.72 examples/s]


In [18]:
len(set(dataset['train'][:]['label']))

37

In [20]:
from tqdm import tqdm

labels_real = []
labels_predicted = []
topN = 100
for idx in tqdm(range(len(dataset['test']))):
    query = dataset['test'][idx]['question']
    id_real = dataset['test'][idx]['id']
    query_emb = np.array(model_embeddings.embed_query(query))
    results = tree.query(query_emb, topN=topN)
    labels_predicted.append([pos_to_doc_id[idx] for idx, sim in results])
    labels_real.append([id_real])    

100%|██████████| 205/205 [00:12<00:00, 16.74it/s]


In [21]:
from ranking_metrics import calc_ranking_metrics

metrics = calc_ranking_metrics(labels_predicted, labels_real, ks=[1, 5, 10, 20, 100], one_relevant_per_query=True)

print("Ranking Metrics:")
for k, v in metrics.items():
    print(f"  {k}: {v:.4f}")

Ranking Metrics:
  MRR: 0.5204
  mAP: 0.5204
  AvgRank: 4.5036
  CMC@1: 0.4732
  Recall@k (macro)@1: 0.4732
  Precision@k (macro)@1: 0.4732
  Accuracy@1: 0.4732
  F1@k (macro)@1: 0.4732
  CMC@5: 0.5902
  Recall@k (macro)@5: 0.5902
  Precision@k (macro)@5: 0.1180
  Accuracy@5: 0.5902
  F1@k (macro)@5: 0.1967
  CMC@10: 0.6000
  Recall@k (macro)@10: 0.6000
  Precision@k (macro)@10: 0.0600
  Accuracy@10: 0.6000
  F1@k (macro)@10: 0.1091
  CMC@20: 0.6293
  Recall@k (macro)@20: 0.6293
  Precision@k (macro)@20: 0.0315
  Accuracy@20: 0.6293
  F1@k (macro)@20: 0.0599
  CMC@100: 0.6780
  Recall@k (macro)@100: 0.6780
  Precision@k (macro)@100: 0.0068
  Accuracy@100: 0.6780
  F1@k (macro)@100: 0.0134


## LLM for routing

In [18]:
from datasets import DatasetDict, concatenate_datasets
dataset_clf = {}
dataset_clf['train'] = dataset['train'].remove_columns([col for col in dataset['train'].column_names if col not in ['question', 'label']])
dataset_clf['validation'] = dataset['validation'].remove_columns([col for col in dataset['validation'].column_names if col not in ['question', 'label']])
dataset_clf['test'] = dataset['test'].remove_columns([col for col in dataset['test'].column_names if col not in ['question', 'label']])

dataset_clf = DatasetDict(dataset_clf)

In [19]:
dataset_clf = dataset_clf.rename_column("question", "text")
dataset_clf

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 614
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 161
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 205
    })
})

In [20]:
all_data = concatenate_datasets([dataset_clf['train'], dataset_clf['validation'], dataset_clf['test']])
labels_list = all_data.unique('label')
print(f"Labels: {labels_list}")
num_labels = len(labels_list)
print(f"Number of labels: {num_labels}")

# map labels to integers
label_to_id = {label: i for i, label in enumerate(labels_list)}
def map_labels(example):
    return {
        "label": label_to_id[example['label']]
    }
dataset_clf = dataset_clf.map(map_labels)

Labels: ['4', '3', '0', '1', '2']
Number of labels: 5


In [21]:
dataset_clf["train"] = concatenate_datasets([dataset_clf['train'], dataset_clf['validation']]).shuffle(seed=42)                                            

In [40]:
from transformers import (
    AutoTokenizer, AutoConfig, AutoModelForSequenceClassification,
    TrainingArguments, Trainer, DataCollatorWithPadding, EarlyStoppingCallback
)
from peft import LoraConfig, get_peft_model

model_name = "Qwen/Qwen3-0.6B"
MAX_LENGTH = 512

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [41]:
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
    low_cpu_mem_usage=True,
    device_map={"": 0}
)
model.config.pad_token_id = tokenizer.pad_token_id

Some weights of Qwen3ForSequenceClassification were not initialized from the model checkpoint at Qwen/Qwen3-0.6B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [42]:
lora_r = 64
lora_alpha = lora_r * 2
lora_dropout = 0.
lora_bias = "none"
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj"]

In [43]:
config = LoraConfig(
    task_type="SEQ_CLS",
    inference_mode=False,
    r=lora_r,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    bias=lora_bias,
    target_modules=target_modules
)
model = get_peft_model(model, config)

In [44]:
for p in model.base_model.model.score.parameters():
    p.requires_grad_(True)

In [45]:
from datasets import load_from_disk
FOLDER_AUTORE = "../data/processed/parliament_all_docs"
dataset_indexing = load_from_disk(FOLDER_AUTORE)["all"]
dataset_indexing = dataset_indexing.rename_column("PK", "id")

In [46]:
dataset_indexing = dataset_indexing.map(assign_labels)
# labels to integers
dataset_indexing = dataset_indexing.map(lambda example: {"label": int(example['label'])})

In [47]:
set(dataset_indexing['label'][:])

{0, 1, 2, 3, 4}

In [48]:
def preprocess(examples):
    return tokenizer(examples["text"], truncation=True, max_length=MAX_LENGTH, padding=False)

# tokenize test from dataset
tokenized_datasets = dataset_clf.map(preprocess, batched=True)
tokenized_datasets_indexing = dataset_indexing.map(preprocess, batched=True)

Map: 100%|██████████| 161/161 [00:00<00:00, 13892.42 examples/s]


In [49]:
tokenized_datasets_indexing = tokenized_datasets_indexing.remove_columns(["id"])

In [50]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, pad_to_multiple_of=8)

def softmax(x, axis=-1):
    x = np.asarray(x, dtype=np.float64)
    # Restar el máximo para evitar overflow
    x_shift = x - np.max(x, axis=axis, keepdims=True)
    exps = np.exp(x_shift)
    return exps / np.sum(exps, axis=axis, keepdims=True)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    acc = (preds == labels).mean()

    return {
        "accuracy": acc
    }

In [51]:
SEED = 42
EPOCHS = 10
BATCH_SIZE = 8

In [53]:
training_args_indexing = TrainingArguments(
    output_dir=f"models/parlamento_clf_{model_name.replace('/', '_')}_indexing",
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=1,
    num_train_epochs=2,
    learning_rate=5e-5,
    weight_decay=0.05,
    warmup_ratio=0.2,
    lr_scheduler_type="cosine",
    eval_strategy="steps",
    save_strategy="no",        # <-- no guarda checkpoints ni el modelo final
    eval_steps=10,
    logging_steps=10,
    load_best_model_at_end=False,  # <-- desactivado porque no hay checkpoints
    fp16=True,
    report_to="none",
    seed=SEED,
    label_smoothing_factor=0.01,
)

training_args_query = TrainingArguments(
    output_dir=f"models/parlamento_clf_{model_name.replace('/', '_')}_query",
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=1,
    num_train_epochs=EPOCHS,
    learning_rate=5e-5,
    weight_decay=0.05,
    warmup_ratio=0.2,
    lr_scheduler_type="cosine",
    eval_strategy="steps",
    save_strategy="no",        # <-- no guarda checkpoints ni el modelo final
    eval_steps=10,
    logging_steps=10,
    load_best_model_at_end=False,  # <-- desactivado porque no hay checkpoints
    fp16=True,
    report_to="none",
    seed=SEED,
    label_smoothing_factor=0.01,
)

# IDs de tokens
model.config.pad_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.eos_token_id
model.config.bos_token_id = tokenizer.bos_token_id

trainer_indexing = Trainer(
    model=model,
    args=training_args_indexing,
    train_dataset=tokenized_datasets_indexing,
    eval_dataset=tokenized_datasets["test"],
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    #callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

trainer_query = Trainer(
    model=model,
    args=training_args_query,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    #callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

In [54]:
model.print_trainable_parameters()  # Verificar parámetros entrenables

trainable params: 33,040,384 || all params: 629,090,304 || trainable%: 5.2521


In [58]:
for epoch in range(1):
    print(f"=== Indexing Epoch {epoch+1}/{2} ===")
    trainer_indexing.train()
    print(f"=== Query Epoch {epoch+1}/{2} ===")
    trainer_query.train()

=== Indexing Epoch 1/2 ===


Step,Training Loss,Validation Loss,Accuracy
10,0.6524,2.162805,0.478049
20,0.6955,2.097874,0.492683
30,0.5988,1.993405,0.492683
40,0.4278,1.882646,0.497561
50,0.5121,1.806089,0.492683
60,0.4461,1.801856,0.502439
70,0.3278,1.808368,0.502439
80,0.204,1.822186,0.487805
90,0.4254,1.83712,0.492683
100,0.2453,1.843272,0.492683


=== Query Epoch 1/2 ===


Step,Training Loss,Validation Loss,Accuracy
10,1.2995,2.29879,0.326829
20,0.6756,2.028748,0.42439
30,0.2368,2.404466,0.37561
40,0.1443,2.297593,0.453659
50,0.0714,2.120533,0.536585
60,0.067,2.122047,0.570732
70,0.0631,2.096563,0.536585
80,0.0638,2.127107,0.502439
90,0.0637,2.113527,0.55122
100,0.0824,1.946852,0.560976


In [59]:
metrics = trainer_query.evaluate()
print("Evaluation metrics after training:", metrics)

Evaluation metrics after training: {'eval_loss': 2.5644822120666504, 'eval_accuracy': 0.5073170731707317, 'eval_runtime': 1.5834, 'eval_samples_per_second': 129.465, 'eval_steps_per_second': 16.42, 'epoch': 10.0}


In [62]:
idxs = tree.get_leaf_node(0).idxs
# get embeddings of those idxs
X_leaf = X[idxs]
X_leaf.shape

(2595, 1024)

In [66]:
import torch


def query_llm(llm_clf, q: np.ndarray, q_text: str, tokenizer, topN: int = 10, topK: int = 1) -> List[Tuple[int, float]]:
    q = l2_normalize(q.reshape(1, -1).astype(np.float32))[0]
    # predice la etiqueta con el clasificador
    with torch.no_grad():
        inputs = tokenizer(q_text, return_tensors="pt", truncation=True, max_length=512).to(llm_clf.model.device)
        outputs = llm_clf.model(**inputs)
        logits = outputs.logits
        pred_label_id = int(torch.argmax(logits, dim=-1).cpu().numpy()[0])

    leaves = [tree.get_leaf_node(pred_label_id)]
    cand_ids = []
    for leaf in leaves:
        if leaf.idxs is None or leaf.idxs.size == 0:
            continue
        cand_ids.extend(leaf.idxs.tolist())
    if not cand_ids:
        return []

    cand_ids = np.unique(np.array(cand_ids, dtype=np.int64))
    sims = X[cand_ids] @ q
    k = min(topN, sims.size)
    top_idx = np.argpartition(-sims, k - 1)[:k]
    pairs = list(zip(cand_ids[top_idx].tolist(), sims[top_idx].tolist()))
    pairs.sort(key=lambda t: -t[1])
    return pairs

In [71]:
from tqdm import tqdm

labels_real = []
labels_predicted = []
topN = 100
for idx in tqdm(range(len(dataset['test']))):
    query = dataset['test'][idx]['question']
    id_real = dataset['test'][idx]['id']
    query_emb = np.array(model_embeddings.embed_query(query))
    results = query_llm(trainer_query, query_emb, query, tokenizer, topN=topN)
    labels_predicted.append([pos_to_doc_id[idx] for idx, sim in results])
    labels_real.append([id_real])  

100%|██████████| 205/205 [00:24<00:00,  8.54it/s]


In [72]:
metrics = calc_ranking_metrics(labels_predicted, labels_real, ks=[1, 5, 10, 20, 100], one_relevant_per_query=True)

print("Ranking Metrics:")
for k, v in metrics.items():
    print(f"  {k}: {v:.4f}")

Ranking Metrics:
  MRR: 0.3385
  mAP: 0.3385
  AvgRank: 6.0698
  CMC@1: 0.3122
  Recall@k (macro)@1: 0.3122
  Precision@k (macro)@1: 0.3122
  Accuracy@1: 0.3122
  F1@k (macro)@1: 0.3122
  CMC@5: 0.3659
  Recall@k (macro)@5: 0.3659
  Precision@k (macro)@5: 0.0732
  Accuracy@5: 0.3659
  F1@k (macro)@5: 0.1220
  CMC@10: 0.3756
  Recall@k (macro)@10: 0.3756
  Precision@k (macro)@10: 0.0376
  Accuracy@10: 0.3756
  F1@k (macro)@10: 0.0683
  CMC@20: 0.3854
  Recall@k (macro)@20: 0.3854
  Precision@k (macro)@20: 0.0193
  Accuracy@20: 0.3854
  F1@k (macro)@20: 0.0367
  CMC@100: 0.4195
  Recall@k (macro)@100: 0.4195
  Precision@k (macro)@100: 0.0042
  Accuracy@100: 0.4195
  F1@k (macro)@100: 0.0083


## Other

In [73]:
tree_beam = KMeansTreeBeamSearch(B=5, max_depth=1, min_leaf_size=5, kmeans_iters=20, seed=42, iso=True)
tree_beam.fit(X)
tree_beam.print_tree(max_children=10)
stats = tree_beam.tree_stats()
print("Estadísticas del árbol con beam search:", stats)

◦ [Inner nid=5] k=5
  • [Leaf nid=0] size=1662
  • [Leaf nid=1] size=2758
  • [Leaf nid=2] size=2595
  • [Leaf nid=3] size=2504
  • [Leaf nid=4] size=1643
Estadísticas del árbol con beam search: {'num_nodes': 6, 'num_leaves': 5, 'max_depth': 1, 'leaf_size_min': 1643, 'leaf_size_med': 2504.0, 'leaf_size_max': 2758}


In [75]:
labels_real = []
labels_predicted = []
topN = 100
for idx in tqdm(range(len(dataset['test']))):
    query = dataset['test'][idx]['question']
    id_real = dataset['test'][idx]['id']
    query_emb = np.array(model_embeddings.embed_query(query))
    results = tree_beam.query(query_emb, topN=topN)
    labels_predicted.append([pos_to_doc_id[idx] for idx, sim in results])
    labels_real.append([id_real])

100%|██████████| 205/205 [00:09<00:00, 22.60it/s]


In [76]:
metrics = calc_ranking_metrics(labels_predicted, labels_real, ks=[1, 5, 10, 20, 100], one_relevant_per_query=True)

print("Ranking Metrics:")
for k, v in metrics.items():
    print(f"  {k}: {v:.4f}")

Ranking Metrics:
  MRR: 0.5339
  mAP: 0.5339
  AvgRank: 6.3406
  CMC@1: 0.4927
  Recall@k (macro)@1: 0.4927
  Precision@k (macro)@1: 0.4927
  Accuracy@1: 0.4927
  F1@k (macro)@1: 0.4927
  CMC@5: 0.5805
  Recall@k (macro)@5: 0.5805
  Precision@k (macro)@5: 0.1161
  Accuracy@5: 0.5805
  F1@k (macro)@5: 0.1935
  CMC@10: 0.6098
  Recall@k (macro)@10: 0.6098
  Precision@k (macro)@10: 0.0610
  Accuracy@10: 0.6098
  F1@k (macro)@10: 0.1109
  CMC@20: 0.6293
  Recall@k (macro)@20: 0.6293
  Precision@k (macro)@20: 0.0315
  Accuracy@20: 0.6293
  F1@k (macro)@20: 0.0599
  CMC@100: 0.6732
  Recall@k (macro)@100: 0.6732
  Precision@k (macro)@100: 0.0067
  Accuracy@100: 0.6732
  F1@k (macro)@100: 0.0133


In [None]:
from itertools import product
VALUES_B = [5, 10, 20, 40, 80]
VALUES_MAX_DEPTH = [1, 2, 3, 4]
VALUES_MIN_LEAF_SIZE = [5, 10, 20, 40, 80, 100, 200]
VALUES_KMEANS_ITERS = [5, 10, 20, 40, 80]
# contruir combinaciones en una lista
param_combinations = list(product(VALUES_B, VALUES_MAX_DEPTH, VALUES_MIN_LEAF_SIZE, VALUES_KMEANS_ITERS))
print(f"Total parameter combinations to try: {len(param_combinations)}")

Total parameter combinations to try: 700


In [None]:
best_mrr = 0.0
for b, depth, min_size, iters in tqdm(param_combinations):
    #print(f"Probando B={b}, max_depth={depth}, min_leaf_size={min_size}, kmeans_iters={iters}")
    tree = KMeansTreeBeamSearch(B=b, max_depth=depth, min_leaf_size=min_size, kmeans_iters=iters, seed=42)
    tree.fit(X)
    
    labels_predicted = []

    for idx in range(len(dataset['test'])):
        query = dataset['test'][idx]['question']
        query_emb = np.array(model.embed_query(query))
        results = tree.query(query_emb, topN=topN)
        labels_predicted.append([pos_to_doc_id[idx] for idx, sim in results])
    
    metrics = calc_ranking_metrics(labels_predicted, labels_real, ks=[1, 5, 10, 20, 100], one_relevant_per_query=True)
    
    if metrics['MRR'] > best_mrr:
        best_mrr = metrics['MRR']
        print(f"MRR:{metrics['MRR']:.4f} @ B={b}, max_depth={depth}, min_leaf_size={min_size}, kmeans_iters={iters}")

print(f"Best MRR: {best_mrr:.4f}")

  0%|          | 1/700 [00:10<2:01:32, 10.43s/it]

MRR:0.4763 @ B=5, max_depth=1, min_leaf_size=5, kmeans_iters=5


  0%|          | 2/700 [00:20<2:01:21, 10.43s/it]

MRR:0.4910 @ B=5, max_depth=1, min_leaf_size=5, kmeans_iters=10


  0%|          | 3/700 [00:30<1:56:27, 10.02s/it]

MRR:0.5315 @ B=5, max_depth=1, min_leaf_size=5, kmeans_iters=20


  1%|          | 6/700 [01:09<2:14:19, 11.61s/it]


KeyboardInterrupt: 

In [None]:
labels_predicted = []
retriever=loaded_vectorstore.as_retriever(search_kwargs={"k":topN})

for idx in tqdm(range(len(dataset['test']))):
    query = dataset['test'][idx]['question']
    docs = retriever.get_relevant_documents(query)
    labels_predicted.append([doc.metadata['id'] for doc in docs])



  docs = retriever.get_relevant_documents(query)
100%|██████████| 205/205 [00:06<00:00, 33.28it/s]


In [None]:
metrics = calc_ranking_metrics(labels_predicted, labels_real, ks=[1, 5, 10, 20, 100], one_relevant_per_query=True)

print("Ranking Metrics:")
for k, v in metrics.items():
    print(f"  {k}: {v:.4f}")

Ranking Metrics:
  MRR: 0.6506
  mAP: 0.6506
  AvgRank: 6.8743
  CMC@1: 0.5854
  Recall@k (macro)@1: 0.5854
  Precision@k (macro)@1: 0.5854
  Accuracy@1: 0.5854
  F1@k (macro)@1: 0.5854
  CMC@5: 0.7268
  Recall@k (macro)@5: 0.7268
  Precision@k (macro)@5: 0.1454
  Accuracy@5: 0.7268
  F1@k (macro)@5: 0.2423
  CMC@10: 0.7610
  Recall@k (macro)@10: 0.7610
  Precision@k (macro)@10: 0.0761
  Accuracy@10: 0.7610
  F1@k (macro)@10: 0.1384
  CMC@20: 0.7951
  Recall@k (macro)@20: 0.7951
  Precision@k (macro)@20: 0.0398
  Accuracy@20: 0.7951
  F1@k (macro)@20: 0.0757
  CMC@100: 0.8537
  Recall@k (macro)@100: 0.8537
  Precision@k (macro)@100: 0.0085
  Accuracy@100: 0.8537
  F1@k (macro)@100: 0.0169
