In [1]:
import os
import pickle
import sys

import numpy as np
import torch
import yaml
import json
import pandas as pd

# ChromaDB imports
import chromadb
from chromadb import Documents, EmbeddingFunction, Embeddings
from typing import Dict, List, Optional, Tuple, Union

# Resolve the project root as the parent of the folder that holds this
# script (or, in a notebook, of the current working directory), then make
# the sibling "src" folder importable.
try:
    # Scripts define __file__; notebook kernels do not.
    current_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    # In notebooks __file__ is not defined: assume we're in notebooks/
    _anchor_dir = os.getcwd()
else:
    _anchor_dir = current_dir

project_root = os.path.abspath(os.path.join(_anchor_dir, ".."))
src_path = os.path.join(project_root, "src")

# Append only once so re-running the cell does not grow sys.path.
if src_path not in sys.path:
    sys.path.append(src_path)

from main.node_embedding_models import GraphSAGE, GraphSAGE_V2
from main.ollama_utils import get_ollama_embedding
from main.sage_utils import get_new_sage_embedding


## Retrieval

In [2]:
# Name of the retrieval bundle (graph + config + trained GraphSAGE weights).
bundle_directory = "bsard_V2_3hop_15epochs_1024-896-768-640"
bundle_path = os.path.join(
    project_root, "data", "retrieval_bundles", bundle_directory
)

# Load Graph
# NOTE(security): pickle.load can execute arbitrary code; only open bundles
# that come from a trusted source.
graph_path = os.path.join(bundle_path, "graph.pkl")
with open(graph_path, "rb") as f:
    graph = pickle.load(f)

# Load sage config
config_file_path = os.path.join(bundle_path, "config.yaml")
with open(config_file_path, "r") as f:
    config = yaml.safe_load(f)

# Path of the trained SAGE weights
sage_model_path = os.path.join(bundle_path, "graphsage.pth")

# Set device for model inference. It is chosen *before* loading the weights
# so that torch.load can remap tensors saved on another device directly,
# instead of relying on a bare `except:` retry that would also hide real
# errors (missing file, mismatching state dict, ...).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define model and load weights
model = GraphSAGE_V2(
    channels=config["model_params"]["channels"],
)
model.load_state_dict(torch.load(sage_model_path, map_location=device))

# Move the model to the device and switch to inference mode
model = model.to(device)
model.eval()

# Load queries
queries_path = os.path.join(
    project_root, "data", "BSARD_dataset", "bsard_train_questions.csv"
)
queries_df = pd.read_csv(queries_path)

In [3]:
# Query subset used for the rest of the notebook. Currently the full set.
# NOTE: this is an alias of queries_df, not a copy, so any column added to
# queries_subset_df is added to queries_df as well.
#queries_subset_df = queries_df.head(10)
queries_subset_df = queries_df

# Question texts, in row order, as a plain Python list.
queries = [question for question in queries_subset_df['question']]

In [4]:
def _reset_collection(chroma_client, name):
    """Return an empty collection called ``name``, dropping a leftover one first."""
    try:
        chroma_client.delete_collection(name=name)
    except Exception:
        # Nothing to delete. The exception raised for a missing collection
        # differs between chromadb releases, hence the broad catch; a genuine
        # client failure resurfaces in create_collection just below.
        pass
    return chroma_client.create_collection(name=name)


def _fill_article_collection(collection, graph, embedding_key, batch_size=1000):
    """Add every node whose ``node_type`` is "Article" to ``collection``.

    The node id (as a string) becomes the document id, ``article_text`` the
    document and ``graph.nodes[node_id][embedding_key]`` the stored embedding.
    Documents are added in chunks of ``batch_size`` rather than one per call.

    Raises:
        ValueError: if an Article node has no ``embedding_key`` attribute.
    """
    ids, documents, embeddings = [], [], []
    for node_id in graph.nodes:
        node_data = graph.nodes[node_id]
        if node_data.get("node_type") != "Article":
            continue
        doc_embedding = node_data.get(embedding_key)
        if doc_embedding is None:
            raise ValueError(
                f"Article node {node_id!r} has no '{embedding_key}' attribute"
            )
        ids.append(str(node_id))
        documents.append(node_data.get("article_text", ""))
        embeddings.append(doc_embedding)

    for start in range(0, len(ids), batch_size):
        stop = start + batch_size
        collection.add(
            ids=ids[start:stop],
            documents=documents[start:stop],
            embeddings=embeddings[start:stop],
        )


def parallel_retrieval_pipeline(graph, queries, sage_model, semantic_embedder=get_ollama_embedding,
                                top_k=10, device=None):
    """Retrieve articles for each query with semantic and with SAGE embeddings.

    Two in-memory Chroma collections are built from the "Article" nodes of
    ``graph``: one indexed by the node attribute ``embedding`` and one by
    ``hybrid_embedding``. Every query is embedded with ``semantic_embedder``,
    that embedding is passed through ``get_new_sage_embedding`` with
    ``sage_model``, and each collection is queried with the matching vector.

    Args:
        graph: graph object exposing ``graph.nodes`` with per-node attribute
            dicts (``node_type``, ``article_text``, ``embedding``,
            ``hybrid_embedding``).
        queries: iterable of query strings.
        sage_model: GraphSAGE model; ``sage_model.convs[0].in_channels`` must
            equal the size of the semantic embedding.
        semantic_embedder: callable returning a mapping with an
            ``'embedding'`` entry for a query string.
        top_k: number of results requested from each collection per query.
        device: torch device handed to ``get_new_sage_embedding``. Defaults to
            the device of ``sage_model``'s parameters (CPU if it has none).

    Returns:
        ``(semantic_results, sage_results)``: two lists with one entry per
        query, each entry being the ``"ids"`` field of the Chroma query
        result (a nested list; the ranked ids are at index 0).
    """
    if device is None:
        first_param = next(sage_model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # Get queries semantic embedding
    query_semantic_embeddings = [
        semantic_embedder(query)['embedding'] for query in queries
    ]

    # Get queries sage embedding (uses the model passed in, not a notebook global)
    in_channels = sage_model.convs[0].in_channels
    query_sage_embeddings = [
        get_new_sage_embedding(
            model=sage_model,
            new_feature=torch.tensor(embedding, dtype=torch.float32).reshape(in_channels),
            device=device,
        )
        for embedding in query_semantic_embeddings
    ]

    # Create two chromadb collections. They are recreated from scratch so the
    # function can be called again in the same kernel without a
    # "collection already exists" failure.
    chroma_client = chromadb.Client()

    # One for the semantic embeddings
    collection_semantic = _reset_collection(chroma_client, "semantic_collection")
    _fill_article_collection(collection_semantic, graph, "embedding")

    # One for the sage embeddings
    collection_sage = _reset_collection(chroma_client, "sage_collection")
    _fill_article_collection(collection_sage, graph, "hybrid_embedding")

    # Retrieve top k documents from the semantic collection for each query
    semantic_results = [
        collection_semantic.query(query_embeddings=[embedding], n_results=top_k)["ids"]
        for embedding in query_semantic_embeddings
    ]

    # Retrieve top k documents from the sage collection for each query
    sage_results = [
        collection_sage.query(
            query_embeddings=embedding.detach().cpu().unsqueeze(0).tolist(),
            n_results=top_k,
        )["ids"]
        for embedding in query_sage_embeddings
    ]

    return semantic_results, sage_results

In [5]:
# Run both retrievers over every query with the loaded graph and SAGE model.
semantic_results, sage_results = parallel_retrieval_pipeline(
    graph=graph,
    queries=queries,
    sage_model=model,
)

In [6]:
# semantic_results

In [7]:
# sage_results

## Evaluation

In [8]:
# Load bsard corpus lean: the table whose 'id' and 'article_code' columns are
# used to translate the queries' article ids into article codes.
bsard_corpus_lean_path = os.path.join(
    project_root, "data", "BSARD_dataset", "bsard_corpus_lean_V2.csv"
)
bsard_corpus_lean = pd.read_csv(bsard_corpus_lean_path)

In [9]:
# Translate each query's comma-separated ground-truth article ids into the
# corresponding article codes of the corpus.

# Build the id -> article_code lookup once instead of scanning the whole
# corpus for every single id. drop_duplicates keeps the first row per id,
# which matches the previous "first match" behaviour.
id_to_article_code = (
    bsard_corpus_lean
    .drop_duplicates(subset='id')
    .set_index('id')['article_code']
    .to_dict()
)

article_ids_code = []
for article_ids_list in queries_subset_df['article_ids']:
    codes_for_query = []
    # str() keeps this working if pandas parsed the column as numeric
    # (i.e. every row holds a single id).
    for article_id in str(article_ids_list).split(','):
        article_id = int(article_id)
        if article_id not in id_to_article_code:
            raise KeyError(
                f"Article id {article_id} not found in bsard_corpus_lean['id']"
            )
        codes_for_query.append(id_to_article_code[article_id])
    article_ids_code.append(codes_for_query)

queries_subset_df['article_ids_codes'] = article_ids_code


In [10]:
# Display the queries table (pandas shows a truncated view for large frames).
queries_subset_df

Unnamed: 0,id,category,subcategory,question,extra_description,article_ids,article_ids_codes
0,1102,Travail,Travail et parentalité,Je suis travailleur salarié(e). Puis-je refuse...,Pendant la grossesse,"22225,22226,22227,22228,22229,22230,22231,2223...","[33.10.5.1, 33.10.5.2, 33.10.5.3, 33.10.5.4, 3..."
1,91,Argent,Dettes,Peut-on saisir tous mes revenus ?,"Procédures de récupération des dettes, Récupér...",585358545855,"[10.0.4.82, 10.0.4.83, 10.0.4.84]"
2,474,Famille,Situation de couples,Je suis marié(e). Nous sommes mariés. Dois-je ...,Mariage,109610971098110811091110,"[4.2.8.4, 4.2.8.5, 4.2.8.6, 4.2.8.16, 4.2.8.17..."
3,836,Logement,Location en Wallonie,Je mets un kot en location (bail de droit comm...,"Mettre un logement en location (Wallonie), Doi...",12012120301203112032120331203412035,"[21.0.1.1, 21.0.2.15, 21.0.2.16, 21.0.2.17, 21..."
4,1079,Travail,Maladie - incapacité de travail,Suis-je payé pendant la procédure du trajet de...,Rupture du contrat de travail pour force majeu...,"21114,21115,21116,21117,21118,21119,21120,2112...","[33.1.4.72, 33.1.4.73, 33.1.4.74, 33.1.4.75, 3..."
...,...,...,...,...,...,...,...
881,308,Famille,Lien parents/enfants,Quel est le rôle du tuteur d'un enfant mineur ?,"Tutelle, Rôle du tuteur",13211322132313241325132613271328,"[4.2.12.18, 4.2.12.19, 4.2.12.20, 4.2.12.21, 4..."
882,387,Famille,Personnes à l'autonomie fragilisée,Peut-on changer d'administrateur de biens et/o...,Administration de biens et/ou de la personne (...,1381,[4.2.13.28]
883,940,Logement,Location à Bruxelles,Mon propriétaire ne fait pas les réparations n...,"Bail de résidence principale (Bruxelles), Loye...",2494,[4.0.9.26]
884,364,Famille,Personnes à l'autonomie fragilisée,Que faire si le montant des frais et honoraire...,Administration de biens et/ou de la personne (...,1387,[4.2.13.34]


In [11]:
# Include the retrieved articles into the queries frame as new columns.
# For each query only the first R retrieved ids are kept, where R is the
# number of ground-truth articles of that query.
if not (len(semantic_results) == len(sage_results) == len(queries_subset_df)):
    raise ValueError(
        "Retrieval results and queries_subset_df must have the same length: "
        f"{len(semantic_results)} semantic, {len(sage_results)} sage, "
        f"{len(queries_subset_df)} queries"
    )

ret_semantic = []
ret_sage = []

# Iterate positionally over the subset itself (not over queries_df by label),
# so this also works for a sliced, filtered or re-indexed subset.
for true_codes, semantic_ids, sage_ids in zip(
    queries_subset_df['article_ids_codes'], semantic_results, sage_results
):
    n_articles = len(true_codes)
    # Each result is a nested list; the ranked ids of the query are at index 0.
    ret_semantic.append(semantic_ids[0][:n_articles])
    ret_sage.append(sage_ids[0][:n_articles])

queries_subset_df['ret_semantic'] = ret_semantic
queries_subset_df['ret_sage'] = ret_sage

In [12]:
# Recall for a single row
def recall_at_k(true_ids, pred_ids):
    """Return the fraction of distinct ``true_ids`` that appear in ``pred_ids``.

    Duplicates on either side are ignored. No cut-off is applied here: the
    caller decides how many predictions to pass. When there are no relevant
    ids the recall is defined as 0.0.
    """
    relevant = set(true_ids)
    # No relevant items: define recall as 0 (1 would be the other convention)
    if len(relevant) == 0:
        return 0.0
    retrieved = set(pred_ids)
    # Count the relevant items that were actually retrieved
    n_hits = sum(1 for item in relevant if item in retrieved)
    return n_hits / len(relevant)

# Apply the metric row by row, once per retriever, and store each result in
# its own column.
for _recall_col, _retrieved_col in (
    ('recall_semantic', 'ret_semantic'),
    ('recall_sage', 'ret_sage'),
):
    queries_subset_df[_recall_col] = queries_subset_df.apply(
        # Bind the column name as a default so each lambda uses its own column.
        lambda row, col=_retrieved_col: recall_at_k(row['article_ids_codes'], row[col]),
        axis=1
    )

In [13]:
# Macro recall: unweighted mean of the per-query recall of each retriever.
macro_recall_semantic = queries_subset_df.loc[:, 'recall_semantic'].mean()
print("Macro recall semantic: ", macro_recall_semantic)

macro_recall_sage = queries_subset_df.loc[:, 'recall_sage'].mean()
print("Macro recall sage: ", macro_recall_sage)

Macro recall semantic:  0.09967588011287074
Macro recall sage:  0.003466623669783941


In [14]:
# Inspect a random sample of rows. The seed makes the displayed rows
# reproducible across runs, and the size is capped so small subsets do not
# raise.
queries_subset_df.sample(n=min(10, len(queries_subset_df)), random_state=42)

Unnamed: 0,id,category,subcategory,question,extra_description,article_ids,article_ids_codes,ret_semantic,ret_sage,recall_semantic,recall_sage
197,527,Famille,Vivre en couple,Je deviens cohabitant légal. Quelles démarches...,Cohabitation légale,2329,[4.0.6.2],[4.0.6.1],[16.8.4.195],0.0,0.0
80,789,Logement,Location en Wallonie,Dois-je partir avant la mise en oeuvre de l'ex...,"Bail de résidence principale (Wallonie), Fin e...",55625563556455655566,"[10.6.0.360, 10.6.0.361, 10.6.0.362, 10.6.0.36...","[29.1.3.51, 29.1.2.17, 29.1.5.19, 16.6.14.80, ...","[16.8.4.195, 10.6.0.364, 16.8.4.239, 16.8.4.23...",0.0,0.2
147,587,Justice,Au tribunal,Je suis convoqué en justice. Qu'est-ce qu'une ...,"Au civil, Déroulement des audiences","4826,4827,4828,4829,4830,4831,4832,4833,4834,4...","[10.3.2.74, 10.3.2.75, 10.3.2.76, 10.3.2.77, 1...","[10.3.6.187, 10.0.0.14, 20.0.1.27, 10.0.0.12, ...","[16.8.4.195, 10.6.0.364, 16.8.4.239, 16.8.4.23...",0.0,0.0
783,1052,Protection sociale,Grossesse et naissance,A-t-on droit à l'allocation de naissance en ca...,Allocations familiales et allocation de naissa...,947948,"[4.2.2.54, 4.2.2.55]","[4.2.2.54, 4.2.2.55]","[16.8.4.195, 10.6.0.364]",1.0,0.0
589,277,Famille,Lien parents/enfants,Les parents ne sont pas mariés ensemble. La mè...,"Lien de filiation, Etablir un lien de filiatio...","931,932,933,1096,1097,1098,1099,1100,1101,1102...","[4.2.2.38, 4.2.2.39, 4.2.2.40, 4.2.8.4, 4.2.8....","[4.2.8.7, 4.2.8.12, 4.2.8.4, 4.2.11.7, 4.2.8.1...","[16.8.4.195, 10.6.0.364, 16.8.4.238, 16.8.4.23...",0.166667,0.0
497,662,Justice,Infractions,Je suis victime d’une infraction. La police es...,"Procédure pénale, Procédure au poste de police",13219132201322113222132231322413225,"[24.1.0.18, 24.1.0.19, 24.1.0.20, 24.1.0.21, 2...","[12.1.0.7, 12.1.0.1, 24.2.1.60, 24.1.0.216, 24...","[16.8.4.195, 10.6.0.364, 16.8.4.238, 16.8.4.23...",0.0,0.0
680,682,Justice,L'avocat,Puis-je avoir l'assistance gratuite d'un avoca...,Aide juridique (ex pro deo),"4478,4479,4480,4481,4482,4483,4484,4485,4486,4...","[10.5.0.1, 10.5.0.2, 10.5.0.3, 10.5.0.4, 10.5....","[10.5.0.17, 10.5.0.13, 10.5.0.25, 10.5.0.7, 10...","[16.8.4.195, 10.6.0.364, 16.8.4.239, 16.8.4.23...",0.157895,0.0
64,121,Argent,Dettes,Que faire si je ne suis pas d'accord avec une ...,"Solutions pour lutter contre l'endettement, Rè...",5152515351545961,"[10.4.4.22, 10.4.4.23, 10.4.4.24, 10.0.5.32]","[10.3.2.148, 10.0.5.7, 10.3.6.175, 18.0.12.57]","[16.8.4.195, 10.6.0.364, 16.8.4.239, 16.8.4.238]",0.0,0.0
769,842,Logement,Location à Bruxelles,Quelles sont les conséquences du bail de résid...,"Bail de résidence principale (Bruxelles), Cham...",833834835,"[3.0.11.1, 3.0.11.2, 3.0.11.3]","[4.0.9.14, 4.0.9.12, 3.0.7.2]","[16.8.4.195, 10.6.0.364, 16.8.4.238]",0.0,0.0
365,479,Famille,Situation de couples,Je me marie. Quelles démarches dois-je effectu...,Mariage,102210231024,"[4.2.6.14, 4.2.6.15, 4.2.6.16]","[26.0.0.49, 4.2.6.21, 4.2.6.4]","[16.8.4.195, 10.6.0.364, 16.8.4.239]",0.0,0.0


In [15]:
#def recall_at_k(true_ids, pred_ids, k):
#    """Recall@K para UNA sola consulta."""
#    true_set = set(true_ids)
#    topk = set(pred_ids[:k])
#    return len(true_set & topk) / len(true_set)
#
#
#def average_precision(true_ids, pred_ids, k=None):
#    """Average Precision para UNA sola consulta."""
#    if k is not None:
#        pred_ids = pred_ids[:k]
#    true_set = set(true_ids)
#    hits = 0
#    sum_prec = 0.0
#    for i, pid in enumerate(pred_ids, start=1):
#        if pid in true_set:
#            hits += 1
#            sum_prec += hits / i
#    # normalizamos por el total de relevantes
#    return sum_prec / len(true_set) if true_set else 0.0
#
#
#def r_precision(true_ids, pred_ids):
#    """Precision@R para UNA sola consulta."""
#    R = len(true_ids)
#    return recall_at_k(true_ids, pred_ids, R)
#
#
## --- Cálculo agregado sobre todo el testset ---------------------
#
## supongamos:
##   q2true[q] = lista de IDs relevantes para la pregunta q
##   q2pred[q] = lista de IDs devuelta por tu sistema RAG (ordenada por similitud)
#
#recalls, aps, rps = [], [], []
#
#for q in q2true:
#    true_ids = q2true[q]
#    pred_ids = q2pred[q]
#
#    recalls.append(recall_at_k(true_ids, pred_ids, k=10))
#    aps.append(average_precision(true_ids, pred_ids))
#    rps.append(r_precision(true_ids, pred_ids))
#
#mean_recall10 = sum(recalls) / len(recalls)
#map_score      = sum(aps)     / len(aps)
#mean_rp        = sum(rps)     / len(rps)

#print(f"Recall@10: {mean_recall10:.3f}")
#print(f"MAP:        {map_score:.3f}")
#print(f"mRP:        {mean_rp:.3f}")
