In [1]:
import os
import pickle
import sys

import numpy as np
import torch
import yaml
import json
import pandas as pd

# ChromaDB imports
import chromadb
from chromadb import Documents, EmbeddingFunction, Embeddings
from typing import Dict, List, Optional, Tuple, Union

try:
    # This will work in scripts where __file__ is defined
    current_dir = os.path.dirname(os.path.abspath(__file__))
    # Assuming "src" is parallel to the script folder
    project_root = os.path.abspath(os.path.join(current_dir, ".."))
except NameError:
    # In notebooks __file__ is not defined: assume we're in notebooks/
    project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))

src_path = os.path.join(project_root, "src")
if src_path not in sys.path:
    sys.path.append(src_path)

from main.node_embedding_models import GraphSAGE
from main.ollama_utils import get_ollama_embedding
from main.sage_utils import get_new_sage_embedding


In [2]:
#config

In [3]:
bundle_name = "bsard_2hop_20e_1024-768-512"

# Load Graph
graph_path = os.path.join(
    project_root, "data", "retrieval_bundles", bundle_name, "graph.pkl"
    )
with open(graph_path, "rb") as f:
    graph = pickle.load(f)

# Load sage config
config_file_path = os.path.join(
    project_root, "data", "retrieval_bundles", bundle_name, "config.yaml"
    )
with open(config_file_path, "r") as f:
    config = yaml.safe_load(f)


# Load SAGE model
sage_model_path = os.path.join(
    project_root, "data", "retrieval_bundles", bundle_name, "graphsage.pth"
    )

# Define model and load weights
model = GraphSAGE(
    channels=config["model_params"]["channels"],   
    )
try:
    model.load_state_dict(torch.load(sage_model_path))
except:
    model.load_state_dict(torch.load(sage_model_path, map_location=torch.device('cpu')))
model.eval() 

# Set device for model inference
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to the device
model = model.to(device)

# Load queries
queries_path = os.path.join(
    project_root, "data", "BSARD_dataset", "bsard_train_questions.csv"
    )

queries_df = pd.read_csv(queries_path)


In [4]:
queries = [queries_df.head()['question'][0]]

In [5]:
def parallel_retrieval_pipeline(graph, queries, sage_model, semantic_embedder=get_ollama_embedding):

    # Get queries semantic embedding
    query_semantic_embeddings = []
    for query in queries:
        print(query)
        query_semantic_embeddings.append(semantic_embedder(query)['embedding'] )
    
    # Get queries sage embedding
    query_sage_embeddings = []
    for query in query_semantic_embeddings:
        print(query)
        query_sage_embeddings.append(get_new_sage_embedding(model=model, new_feature=torch.tensor(query, dtype=torch.float32).reshape(model.convs[0].in_channels), device=device))

    # Create two chromadb collections
    chroma_client = chromadb.Client()

        # One for the semantic embeddings
    collection_semantic = chroma_client.create_collection(name="semantic_collection")
        # Fill collection
    for node_id in graph.nodes:
        node_data = graph.nodes[node_id]
        if node_data.get("node_type") == "Article":
            doc_text = node_data.get("article_text", "")
            doc_embedding = node_data.get("embedding", None)
        
            # Ensure that doc_embedding is a list (or array) of floats
            # Add the node to the collection by specifying IDs, documents, and embeddings
            collection_semantic.add(
                ids=[str(node_id)],         # The ID will be the node's ID (converted to a string)
                documents=[doc_text],         # The node's text
                embeddings=[doc_embedding]    # The precomputed embedding you already have
                )
            
        # One for the sage embeddings
    collection_sage = chroma_client.create_collection(name="sage_collection")
        # Fill collection
    for node_id in graph.nodes:
        node_data = graph.nodes[node_id]
        if node_data.get("node_type") == "Article":
            doc_text = node_data.get("article_text", "")
            doc_embedding = node_data.get("hybrid_embedding", None)
        
            # Ensure that doc_embedding is a list (or array) of floats
            # Add the node to the collection by specifying IDs, documents, and embeddings
            collection_sage.add(
                ids=[str(node_id)],         # The ID will be the node's ID (converted to a string)
                documents=[doc_text],         # The node's text
                embeddings=[doc_embedding]    # The precomputed embedding you already have
                )

    # Retrieve top k documents from the semantic collection based on the query semantic embedding for each query (based on semantic embedding)
    semantic_results = []
    for query in query_semantic_embeddings:
        topk_semantic    = collection_semantic.query(
            query_embeddings=[query] ,
            n_results=10
        )
        semantic_results.append(topk_semantic["ids"])


    # Retrieve top k documents from the sage collection based on the query sage embedding for each query (based on sage embedding)
    sage_results = []
    for query in query_sage_embeddings:
        topk_sage    = collection_sage.query(
            query_embeddings=query.detach().cpu().unsqueeze(0).tolist(),
            n_results=10
        )
        sage_results.append(topk_sage["ids"])

    
    return semantic_results, sage_results

In [6]:
semantic, sage = parallel_retrieval_pipeline(graph, queries, model)

Je suis travailleur salarié(e). Puis-je refuser de faire des heures supplémentaires ou de travailler de nuit ?
[-0.9013447165489197, 0.2166704535484314, -0.8958656191825867, 0.7593201994895935, -1.0512361526489258, -0.6891542077064514, -0.7547259330749512, 0.37134474515914917, -0.15967929363250732, -0.18797999620437622, -0.1272975653409958, 0.14231307804584503, -1.022972583770752, -0.6737866997718811, 1.1148275136947632, -0.5734367966651917, 0.050500936806201935, -1.329225778579712, -0.7811251282691956, -0.04244634136557579, -0.9014313220977783, -0.542297899723053, -1.4358255863189697, 0.7186563014984131, 0.7013106346130371, 0.512730062007904, -0.6749077439308167, -0.1051284670829773, 1.4610470533370972, 0.9967589378356934, -1.044190764427185, -0.5095573663711548, 0.7193320989608765, -1.8120445013046265, 0.08646604418754578, -0.6776453852653503, 0.3242875933647156, 0.8090276122093201, -0.7873283624649048, -0.11631450802087784, -0.29564133286476135, -0.6130490899085999, 0.45040744543075

In [7]:
semantic

[[['33.10.58',
   '14.2.48',
   '14.2.50',
   '14.2.32',
   '29.3.28',
   '33.10.1',
   '33.10.2',
   '33.10.6',
   '33.10.3',
   '33.5.98']]]

In [8]:
sage

[[['33.1.150',
   '33.1.179',
   '33.1.174',
   '33.1.160',
   '33.1.148',
   '33.1.172',
   '33.1.187',
   '33.1.161',
   '33.1.189',
   '33.1.147']]]