In [1]:
import json
import os
import pickle
import sys
import uuid
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch
import yaml

# ChromaDB imports
import chromadb
from chromadb import Documents, EmbeddingFunction, Embeddings

# Locate the repository root so the project's `src/` directory becomes importable.
try:
    # Script run: `__file__` exists, and the root is one level above this file's folder.
    _anchor_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    # Notebook run: `__file__` is undefined; assume the kernel was started in notebooks/.
    _anchor_dir = os.getcwd()
project_root = os.path.abspath(os.path.join(_anchor_dir, ".."))

# Make `src/` importable exactly once, without duplicating the entry on re-runs.
src_path = os.path.join(project_root, "src")
if not any(entry == src_path for entry in sys.path):
    sys.path.append(src_path)

from main.node_embedding_models import GraphSAGE
from main.ollama_utils import get_ollama_embedding
from main.sage_utils import get_new_sage_embedding


## Retrieval

In [2]:
# Every artefact of one retrieval bundle lives in the same directory.
bundle_directory = "bsard_2hop_20e_1024-768-512"
bundle_path = os.path.join(project_root, "data", "retrieval_bundles", bundle_directory)

# Load the graph.
# NOTE: pickle.load can execute arbitrary code -- only load bundles you built or trust.
graph_path = os.path.join(bundle_path, "graph.pkl")
with open(graph_path, "rb") as f:
    graph = pickle.load(f)

# Load the GraphSAGE config (its `model_params.channels` define the architecture).
config_file_path = os.path.join(bundle_path, "config.yaml")
with open(config_file_path, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Choose the inference device first so the checkpoint can be mapped straight onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model and load the trained weights. `map_location` lets a checkpoint
# saved on a GPU load on a CPU-only machine, without a blanket `except:` that
# would also hide real errors (missing file, state-dict key/shape mismatch).
sage_model_path = os.path.join(bundle_path, "graphsage.pth")
model = GraphSAGE(
    channels=config["model_params"]["channels"],
)
model.load_state_dict(torch.load(sage_model_path, map_location=device))
model = model.to(device)
model.eval()

# Load the BSARD training questions.
queries_path = os.path.join(
    project_root, "data", "BSARD_dataset", "bsard_train_questions.csv"
)
queries_df = pd.read_csv(queries_path)

In [3]:
# Number of training questions (taken from the top of queries_df) to run through
# the retrieval pipeline; later cells pair result i with row i of queries_df.
N_EVAL_QUERIES = 10
queries_subset_df = queries_df.head(N_EVAL_QUERIES)
queries = queries_subset_df["question"].tolist()

In [5]:
def _add_article_nodes(collection, graph, embedding_key: str, batch_size: int = 1000) -> None:
    """Insert every ``"Article"`` node of ``graph`` into a Chroma ``collection``.

    Each node is stored with ``str(node_id)`` as id, its ``article_text``
    attribute (default ``""``) as document and its ``embedding_key`` attribute
    as the precomputed embedding. Records are sent in chunks of ``batch_size``
    instead of one ``add`` call per node.

    Raises:
        ValueError: if an Article node lacks the ``embedding_key`` attribute.
    """
    ids, documents, embeddings = [], [], []
    for node_id in graph.nodes:
        node_data = graph.nodes[node_id]
        if node_data.get("node_type") != "Article":
            continue
        embedding = node_data.get(embedding_key)
        if embedding is None:
            raise ValueError(
                f"Article node {node_id!r} has no {embedding_key!r} attribute"
            )
        ids.append(str(node_id))
        documents.append(node_data.get("article_text", ""))
        embeddings.append(embedding)

    # Chroma rejects batches above its maximum batch size, so add in chunks.
    for start in range(0, len(ids), batch_size):
        stop = start + batch_size
        collection.add(
            ids=ids[start:stop],
            documents=documents[start:stop],
            embeddings=embeddings[start:stop],
        )


def _query_top_k(collection, query_embeddings, top_k: int) -> list:
    """Return the ``top_k`` nearest ids for each query embedding.

    All queries are sent to Chroma in a single call. Each per-query id list is
    wrapped as ``[ids]`` so every entry keeps the ``[[id, ...]]`` shape that a
    single-query ``collection.query`` returns (downstream cells index ``[0]``).
    """
    response = collection.query(query_embeddings=query_embeddings, n_results=top_k)
    return [[ids] for ids in response["ids"]]


def parallel_retrieval_pipeline(
    graph,
    queries,
    sage_model,
    semantic_embedder=get_ollama_embedding,
    top_k: int = 10,
    device: Optional[torch.device] = None,
):
    """Retrieve the top-k Article nodes for each query with two retrievers.

    * Semantic retriever: ranks Article nodes by their ``"embedding"`` node
      attribute against the query's semantic embedding.
    * SAGE retriever: passes the query's semantic embedding through
      ``sage_model`` (via ``get_new_sage_embedding``) and ranks Article nodes by
      their ``"hybrid_embedding"`` node attribute.

    Args:
        graph: Graph whose nodes carry ``node_type``, ``article_text``,
            ``embedding`` and ``hybrid_embedding`` attributes; only nodes with
            ``node_type == "Article"`` are indexed.
        queries: Iterable of query strings.
        sage_model: Trained GraphSAGE model used to embed the queries.
        semantic_embedder: Callable taking a query string and returning a
            mapping whose ``"embedding"`` entry is the semantic embedding.
        top_k: Number of articles retrieved per query and per retriever.
        device: Device for SAGE inference; defaults to the device holding
            ``sage_model``'s parameters (CPU if it has none).

    Returns:
        ``(semantic_results, sage_results)``: two lists with one entry per query,
        each entry shaped ``[[id_1, ..., id_k]]`` with ids equal to
        ``str(node_id)``.

    Raises:
        ValueError: if an Article node lacks the embedding attribute needed by a
            retriever.
    """
    queries = list(queries)
    if not queries:
        return [], []

    # Use the model argument (not a notebook-level global) both for the device
    # and for the input width expected by the first convolution.
    if device is None:
        first_param = next(sage_model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    in_channels = sage_model.convs[0].in_channels

    # Semantic embedding of every query.
    query_semantic_embeddings = [semantic_embedder(query)["embedding"] for query in queries]

    # SAGE embedding of every query, flattened to a plain list of floats for Chroma.
    query_sage_embeddings = []
    for semantic_embedding in query_semantic_embeddings:
        new_feature = torch.tensor(semantic_embedding, dtype=torch.float32).reshape(in_channels)
        sage_embedding = get_new_sage_embedding(
            model=sage_model, new_feature=new_feature, device=device
        )
        query_sage_embeddings.append(sage_embedding.detach().cpu().reshape(-1).tolist())

    chroma_client = chromadb.Client()
    # In-memory chromadb clients created in one process share state, so fixed
    # collection names make a second call (e.g. re-running the cell) fail with
    # "collection already exists". Use per-call names and drop the collections
    # once the queries have been answered.
    run_suffix = uuid.uuid4().hex
    built_collections = []
    try:
        collection_semantic = chroma_client.create_collection(
            name=f"semantic_collection_{run_suffix}"
        )
        built_collections.append(collection_semantic)
        _add_article_nodes(collection_semantic, graph, "embedding")

        collection_sage = chroma_client.create_collection(
            name=f"sage_collection_{run_suffix}"
        )
        built_collections.append(collection_sage)
        _add_article_nodes(collection_sage, graph, "hybrid_embedding")

        semantic_results = _query_top_k(collection_semantic, query_semantic_embeddings, top_k)
        sage_results = _query_top_k(collection_sage, query_sage_embeddings, top_k)
    finally:
        for collection in built_collections:
            chroma_client.delete_collection(name=collection.name)

    return semantic_results, sage_results

In [6]:
# Run both retrievers over the selected questions with the loaded GraphSAGE model.
semantic_results, sage_results = parallel_retrieval_pipeline(
    graph=graph,
    queries=queries,
    sage_model=model,
)

Je suis travailleur salarié(e). Puis-je refuser de faire des heures supplémentaires ou de travailler de nuit ?
Peut-on saisir tous mes revenus ?
Je suis marié(e). Nous sommes mariés. Dois-je reconnaître mon enfant ?
Je mets un kot en location (bail de droit commun). Pour quels logements le permis de location est-il nécessaire en Wallonie ?
Suis-je payé pendant la procédure du trajet de réintégration ?
[-0.9013447165489197, 0.2166704535484314, -0.8958656191825867, 0.7593201994895935, -1.0512361526489258, -0.6891542077064514, -0.7547259330749512, 0.37134474515914917, -0.15967929363250732, -0.18797999620437622, -0.1272975653409958, 0.14231307804584503, -1.022972583770752, -0.6737866997718811, 1.1148275136947632, -0.5734367966651917, 0.050500936806201935, -1.329225778579712, -0.7811251282691956, -0.04244634136557579, -0.9014313220977783, -0.542297899723053, -1.4358255863189697, 0.7186563014984131, 0.7013106346130371, 0.512730062007904, -0.6749077439308167, -0.1051284670829773, 1.4610470533

In [7]:
# Semantic-retriever output: one entry per query, each entry the [[id, ...]]
# list of retrieved article codes returned by Chroma for that query.
semantic_results

[[['33.10.58',
   '14.2.48',
   '14.2.50',
   '14.2.32',
   '29.3.28',
   '33.10.1',
   '33.10.2',
   '33.10.6',
   '33.10.3',
   '33.5.98']],
 [['16.6.615',
   '16.6.494',
   '16.6.562',
   '4.3.57',
   '34.2.111',
   '34.2.120',
   '27.0.338',
   '16.8.152',
   '10.2.229',
   '16.8.161']],
 [['4.1.204',
   '4.1.231',
   '4.1.212',
   '4.1.222',
   '4.1.255',
   '4.1.233',
   '4.1.221',
   '4.1.213',
   '4.1.214',
   '4.1.224']],
 [['21.0.144',
   '17.13.6',
   '21.0.20',
   '21.0.141',
   '21.0.21',
   '21.0.81',
   '21.0.51',
   '21.0.97',
   '17.9.48',
   '3.0.150']],
 [['33.1.176',
   '33.1.172',
   '29.4.34',
   '16.6.974',
   '33.1.173',
   '16.6.986',
   '16.6.1502',
   '16.6.995',
   '16.6.985',
   '33.1.167']]]

In [8]:
# SAGE-retriever output: one entry per query, each entry the [[id, ...]]
# list of retrieved article codes returned by Chroma for that query.
sage_results

[[['33.1.150',
   '33.1.179',
   '33.1.174',
   '33.1.160',
   '33.1.148',
   '33.1.172',
   '33.1.187',
   '33.1.161',
   '33.1.189',
   '33.1.147']],
 [['10.0.373',
   '10.0.184',
   '10.0.549',
   '10.0.201',
   '10.0.293',
   '10.0.189',
   '10.0.291',
   '10.0.87',
   '10.0.149',
   '10.0.290']],
 [['25.18.167',
   '25.18.98',
   '25.18.52',
   '25.18.102',
   '25.18.121',
   '25.18.134',
   '25.18.87',
   '25.18.153',
   '25.18.185',
   '25.18.171']],
 [['3.0.147',
   '3.0.152',
   '3.0.114',
   '3.0.234',
   '3.0.201',
   '3.0.146',
   '3.0.184',
   '3.0.176',
   '3.0.259',
   '3.0.158']],
 [['16.6.965',
   '16.6.713',
   '16.6.1378',
   '16.6.417',
   '16.6.46',
   '16.6.709',
   '16.6.1001',
   '16.6.558',
   '16.6.874',
   '16.6.886']]]

## Evaluation

In [4]:
# Load the lean BSARD corpus; its `id` and `article_code` columns are used to
# translate each question's gold article ids into article codes.
bsard_corpus_lean_path = os.path.join(
    project_root, "data", "BSARD_dataset", "bsard_corpus_lean.csv"
)
bsard_corpus_lean = pd.read_csv(bsard_corpus_lean_path)

# Show the questions table. It must be the cell's last expression: Jupyter only
# renders the final expression, and a bare expression earlier in a cell is discarded.
queries_df

Unnamed: 0,id,category,subcategory,question,extra_description,article_ids
0,1102,Travail,Travail et parentalité,Je suis travailleur salarié(e). Puis-je refuse...,Pendant la grossesse,"22225,22226,22227,22228,22229,22230,22231,2223..."
1,91,Argent,Dettes,Peut-on saisir tous mes revenus ?,"Procédures de récupération des dettes, Récupér...",585358545855
2,474,Famille,Situation de couples,Je suis marié(e). Nous sommes mariés. Dois-je ...,Mariage,109610971098110811091110
3,836,Logement,Location en Wallonie,Je mets un kot en location (bail de droit comm...,"Mettre un logement en location (Wallonie), Doi...",12012120301203112032120331203412035
4,1079,Travail,Maladie - incapacité de travail,Suis-je payé pendant la procédure du trajet de...,Rupture du contrat de travail pour force majeu...,"21114,21115,21116,21117,21118,21119,21120,2112..."
...,...,...,...,...,...,...
881,308,Famille,Lien parents/enfants,Quel est le rôle du tuteur d'un enfant mineur ?,"Tutelle, Rôle du tuteur",13211322132313241325132613271328
882,387,Famille,Personnes à l'autonomie fragilisée,Peut-on changer d'administrateur de biens et/o...,Administration de biens et/ou de la personne (...,1381
883,940,Logement,Location à Bruxelles,Mon propriétaire ne fait pas les réparations n...,"Bail de résidence principale (Bruxelles), Loye...",2494
884,364,Famille,Personnes à l'autonomie fragilisée,Que faire si le montant des frais et honoraire...,Administration de biens et/ou de la personne (...,1387


In [15]:
# Translate each question's comma-separated gold `article_ids` into corpus
# `article_code`s. The id -> code table is built once, so each id resolves with a
# dict lookup instead of a full scan of the corpus per id. keep="first" keeps the
# first corpus row for a repeated id, matching the previous `.values[0]` lookup.
# An id absent from the corpus raises KeyError naming that id.
corpus_unique_ids = bsard_corpus_lean.drop_duplicates(subset="id", keep="first")
article_id_to_code = dict(zip(corpus_unique_ids["id"], corpus_unique_ids["article_code"]))

article_ids_code = [
    [article_id_to_code[int(article_id)] for article_id in str(article_ids).split(",")]
    for article_ids in queries_df["article_ids"]
]

queries_df["article_ids_codes"] = article_ids_code


In [17]:
# Questions table, now with the gold article codes in the `article_ids_codes` column.
queries_df

Unnamed: 0,id,category,subcategory,question,extra_description,article_ids,article_ids_codes
0,1102,Travail,Travail et parentalité,Je suis travailleur salarié(e). Puis-je refuse...,Pendant la grossesse,"22225,22226,22227,22228,22229,22230,22231,2223...","[33.10.50, 33.10.51, 33.10.52, 33.10.53, 33.10..."
1,91,Argent,Dettes,Peut-on saisir tous mes revenus ?,"Procédures de récupération des dettes, Récupér...",585358545855,"[10.0.446, 10.0.447, 10.0.448]"
2,474,Famille,Situation de couples,Je suis marié(e). Nous sommes mariés. Dois-je ...,Mariage,109610971098110811091110,"[4.1.206, 4.1.207, 4.1.208, 4.1.218, 4.1.219, ..."
3,836,Logement,Location en Wallonie,Je mets un kot en location (bail de droit comm...,"Mettre un logement en location (Wallonie), Doi...",12012120301203112032120331203412035,"[21.0.1, 21.0.19, 21.0.20, 21.0.21, 21.0.22, 2..."
4,1079,Travail,Maladie - incapacité de travail,Suis-je payé pendant la procédure du trajet de...,Rupture du contrat de travail pour force majeu...,"21114,21115,21116,21117,21118,21119,21120,2112...","[33.1.167, 33.1.168, 33.1.169, 33.1.170, 33.1...."
...,...,...,...,...,...,...,...
881,308,Famille,Lien parents/enfants,Quel est le rôle du tuteur d'un enfant mineur ?,"Tutelle, Rôle du tuteur",13211322132313241325132613271328,"[4.1.431, 4.1.432, 4.1.433, 4.1.434, 4.1.435, ..."
882,387,Famille,Personnes à l'autonomie fragilisée,Peut-on changer d'administrateur de biens et/o...,Administration de biens et/ou de la personne (...,1381,[4.1.491]
883,940,Logement,Location à Bruxelles,Mon propriétaire ne fait pas les réparations n...,"Bail de résidence principale (Bruxelles), Loye...",2494,[4.0.507]
884,364,Famille,Personnes à l'autonomie fragilisée,Que faire si le montant des frais et honoraire...,Administration de biens et/ou de la personne (...,1387,[4.1.497]


In [None]:
# Attach the retrieved article codes to the questions that went through the
# pipeline. Only the first len(semantic_results) rows of queries_df were queried,
# so the new columns are aligned on those rows' index labels (other rows get NaN).
# Assigning a plain short list would raise a length-mismatch ValueError.
# Each retrieved list is cut to the number of gold articles of its question.
n_evaluated = len(semantic_results)
evaluated_index = queries_df.index[:n_evaluated]

ret_semantic = []
ret_sage = []
for row_pos in range(n_evaluated):
    n_articles = len(queries_df["article_ids_codes"].iloc[row_pos])
    ret_semantic.append(semantic_results[row_pos][0][:n_articles])
    ret_sage.append(sage_results[row_pos][0][:n_articles])

queries_df["ret_semantic"] = pd.Series(ret_semantic, index=evaluated_index, dtype=object)
queries_df["ret_sage"] = pd.Series(ret_sage, index=evaluated_index, dtype=object)

In [None]:
def recall_at_k(true_ids, pred_ids, k):
    """Recall@k for a single query.

    Fraction of the distinct relevant ids found among the first ``k``
    predictions. Returns 0.0 when there are no relevant ids.
    """
    true_set = set(true_ids)
    if not true_set:
        return 0.0
    return len(true_set & set(pred_ids[:k])) / len(true_set)


def average_precision(true_ids, pred_ids, k=None):
    """Average precision for a single query.

    Sums precision@i over every rank i (within the first ``k`` predictions, or
    all of them when ``k`` is None) holding a relevant id, normalised by the
    number of distinct relevant ids. Returns 0.0 when there are none.
    """
    if k is not None:
        pred_ids = pred_ids[:k]
    true_set = set(true_ids)
    if not true_set:
        return 0.0
    hits = 0
    sum_prec = 0.0
    for rank, pid in enumerate(pred_ids, start=1):
        if pid in true_set:
            hits += 1
            sum_prec += hits / rank
    return sum_prec / len(true_set)


def r_precision(true_ids, pred_ids):
    """R-precision for a single query, R being the number of distinct relevant ids.

    With R relevant ids, precision@R equals recall@R.
    """
    return recall_at_k(true_ids, pred_ids, len(set(true_ids)))


def evaluate_retrieval(true_lists, pred_lists, k=10):
    """Mean Recall@k, MAP and mean R-precision over paired per-query id lists.

    Raises:
        ValueError: if the two lists do not have the same length.
    """
    if len(true_lists) != len(pred_lists):
        raise ValueError(
            f"{len(true_lists)} gold lists but {len(pred_lists)} prediction lists"
        )
    if not true_lists:
        return {f"Recall@{k}": float("nan"), "MAP": float("nan"), "mRP": float("nan")}
    pairs = list(zip(true_lists, pred_lists))
    n = len(pairs)
    return {
        f"Recall@{k}": sum(recall_at_k(t, p, k) for t, p in pairs) / n,
        "MAP": sum(average_precision(t, p) for t, p in pairs) / n,
        "mRP": sum(r_precision(t, p) for t, p in pairs) / n,
    }


# Gold article codes of the evaluated questions, in the same order as the
# pipeline results (result i belongs to row i of queries_df). Each result entry is
# [[id, ...]], hence the [0]. The full retrieved lists are used (not the
# R-truncated ret_* columns); with the pipeline's default of 10 results, AP and
# R-precision are bounded for questions with more than 10 gold articles.
gold_article_codes = list(queries_df["article_ids_codes"].iloc[:len(semantic_results)])

pd.DataFrame({
    "semantic": evaluate_retrieval(gold_article_codes, [res[0] for res in semantic_results]),
    "sage": evaluate_retrieval(gold_article_codes, [res[0] for res in sage_results]),
})
