In [1]:
import sys
sys.path.append("..")
import csv
import pandas as pd
from pathlib import Path
from collections import defaultdict

from src.pipeline.kg_pipeline import KGPipeline
from src.components.ner import OpenAINER
from src.components.retrieval.triplets import BM25TripletsRetriever
from src.components.subgraph_creation import FirstShortestPathSubgraphCreator, ConstrainedShortestPathSubgraphCreator
from src.components.pruning import PageRankPruner
from src.components.reasoning_path_generation import ShortestPathReasoningPathGenerator
from src.components.explanation import MainExplainer
from src.utils import load_graph
from logging import info

from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [15]:
# Map each dataset prefix (the file stem up to its first "_") to the path of
# its training split. Only files whose stem contains "train" are kept; if two
# such files share a prefix, the one iterated last wins.
csv_paths = {
    candidate.stem.split("_")[0]: candidate
    for candidate in Path("../data/qna/combined/preprocessed").iterdir()
    if "train" in candidate.stem
}

In [21]:
# Load the knowledge graph from ../data/primekg together with its node and
# edge tables. NOTE(review): `remove_node_types` presumably drops all
# gene/protein nodes (and their edges) before the graph is built — confirm
# against src.utils.load_graph.
G, nodes_df, edges_df = load_graph(
    "../data/primekg",
    return_df=True,
    remove_node_types=["gene/protein"],
)

In [22]:
# Sanity check: log the size of the loaded graph.
info(
    f"Vertices Count = {G.vcount()}, "
    f"Edges Count = {G.ecount()}"
)

[2025-03-26 18:03:05,980] [INFO] [782782662] Vertices Count = 101704, Edges Count = 1739958


In [23]:
# --- Pipeline components ----------------------------------------------------
ner = OpenAINER()

# The retriever is built from the full edge table; the log lines bracket its
# construction so the load time is visible in the cell output.
info("Loading Retriever...")
triplets_retriever = BM25TripletsRetriever(edges_df)
info("Retriever Loaded")

# Alternative subgraph strategy: ConstrainedShortestPathSubgraphCreator()
subgraph_creator = FirstShortestPathSubgraphCreator()
pruner = PageRankPruner()
reasoning_path_generator = ShortestPathReasoningPathGenerator()
explainer = MainExplainer()

# --- Assemble the end-to-end pipeline ---------------------------------------
pipeline = KGPipeline(
    G=G,
    ner=ner,
    triplets_retriever=triplets_retriever,
    subgraph_creator=subgraph_creator,
    pruner=pruner,
    reasoning_path_generator=reasoning_path_generator,
    explainer=explainer,
)



[2025-03-26 18:03:11,827] [INFO] [3783104294] Loading Retriever...


[2025-03-26 18:03:29,087] [INFO] [3783104294] Retriever Loaded


In [None]:
# For every training CSV: sample 100 rows (fixed seed), run the KG pipeline on
# each question/answer pair, and write the original columns plus the generated
# explanation and reasoning paths to "<name>_100_exp_rp.csv" in the current
# working directory.
for name, path in tqdm(csv_paths.items()):
    # Read the input first so an unreadable CSV does not truncate an existing
    # results file.
    df = pd.read_csv(path)
    df = df.sample(100, random_state=42)

    # Multiple-choice datasets carry option1..option4 columns; others do not.
    has_options = 'option1' in df.columns

    # `with` guarantees the file is flushed and closed even if the pipeline
    # raises part-way through, so rows already written are not lost.
    # newline="" is required by the csv module to avoid doubled line endings
    # on Windows and to round-trip embedded newlines correctly.
    with open(name + "_100_exp_rp.csv", "w", newline="") as out_file:
        csv_writer = csv.writer(out_file)
        csv_writer.writerow(df.columns.tolist() + ['explanation', 'reasoning_paths'])

        for _, row in tqdm(df.iterrows(), desc=f"Running {name}", total=len(df)):
            question = row['question']
            answer = row['answer']

            options = []
            if has_options:
                options = [row[f"option{opt_idx}"] for opt_idx in range(1, 5)]

            explanation, reasoning_paths = pipeline.run(
                question,
                answer,
                qna_context_prefix="Options:" if len(options) > 0 else "",
                qna_context="\n".join(options) if len(options) > 0 else "",
                top_k_triplets=20,
                pruned_top_k_nodes=20,
            )

            csv_writer.writerow(row.values.tolist() + [explanation, reasoning_paths])

  0%|          | 0/7 [00:00<?, ?it/s][2025-03-26 18:04:28,760] [INFO] [kg_pipeline] Running pipeline...
[2025-03-26 18:04:29,941] [INFO] [_client] HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
[2025-03-26 18:04:31,205] [INFO] [_client] HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
[2025-03-26 18:04:31,940] [INFO] [_client] HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
[2025-03-26 18:04:31,944] [INFO] [kg_pipeline] Retrieving triplets...
[2025-03-26 18:04:37,650] [INFO] [kg_pipeline] # Unique nodes: 31
[2025-03-26 18:04:37,652] [INFO] [kg_pipeline] # Triplets: 20
[2025-03-26 18:04:37,652] [INFO] [kg_pipeline] Unique nodes: ['heart defects-limb shortening syndrome', 'insulin autoimmune syndrome', 'aspirin resistance', 'Mandibular pain', 'chest bone', 'insulin metabolic process', 'behavioral response to pain', 'benign hypertension', 'hypertension', 'Abnormal insulin level', 'renal hyperte