In [1]:
import random

import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split

# Seed every RNG this notebook touches (Python, NumPy, PyTorch) for reproducibility.
# NOTE: train_test_split below also passes random_state=42 explicitly; LLM-based
# generation/scoring may still be nondeterministic regardless of these seeds.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


In [2]:
icliniq_df = pd.read_csv('data/disease_csv_files/icliniq_medical_qa_cleaned.csv')
# Later cells read row.Question / row.Answer; fail fast here if the CSV schema changed.
_missing_cols = {"Question", "Answer"} - set(icliniq_df.columns)
assert not _missing_cols, f"icliniq CSV is missing columns: {sorted(_missing_cols)}"


In [3]:
# Hold out 25% of the Q&A pairs for evaluation; random_state pins the split.
icliniq_df_train, icliniq_df_test = train_test_split(
    icliniq_df,
    test_size=0.25,
    random_state=42,
)

In [4]:
# Test-set size on the first line, train-set size on the second.
print(len(icliniq_df_test), len(icliniq_df_train), sep="\n")

9885
29655


In [5]:
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import networkx as nx
from sentence_transformers import SentenceTransformer
# import faiss
import numpy as np
# from utils import encode_text_and_match_diseases, update_graph_data
from vector_db_files.searcher_class import MilvusSearcher
from gnn_files.gnn import DiseaseSymptomGraphGNN
import pandas as pd
from llm_files.llm_model import run_generation
from time import time
import os
from dotenv import load_dotenv

# Load environment variables (e.g. PATH_TO_MILVUS_DB, read below) from a .env file.
load_dotenv()

############# Create graph model instance #############
# NOTE(review): dataset_path is not referenced elsewhere in this notebook; confirm before removing.
dataset_path = "all_disease_graphs.pkl"
disease_symptom_df_path = "data/disease_csv_files/diseases_symptoms_merged.csv"

# Input files consumed by DiseaseSymptomGraphGNN. "main_graph_path" is the same CSV as
# disease_symptom_df_path, so the constant is reused to keep the two in sync.
dict_paths = {
    "main_graph_path": disease_symptom_df_path,
    "disease_symptom_csv_path": "data/Final_Augmented_dataset_Diseases_and_Symptoms.csv",
    "patient_doctor_csv_path": "data/patient-doctor.csv",
    "disease_list_path": "data/disease_csv_files/unique_aliases.csv",
    "disease_aliases_path": "data/disease_aliases.json",
    "symptom_list_path": "data/disease_csv_files/unique_symptoms.csv",
}

df = pd.read_csv(disease_symptom_df_path)

# Announce before construction (the constructor appears to print graph statistics itself).
print("creating graph model instance")
graph_model = DiseaseSymptomGraphGNN(
    df=df,
    dict_paths=dict_paths,
    hidden_channels=128,
    out_channels=1,
    num_layers=2,
    dropout=0.5,
    seed=42,
)

# map_location loads the checkpoint tensors onto CPU regardless of where they were saved.
# torch.load unpickles its input: only load trusted checkpoint files.
graph_model.load_state_dict(torch.load("gnn_files/best_model.pt", map_location=torch.device('cpu')))
# Inference only: switch off dropout (dropout=0.5 above). Assumes DiseaseSymptomGraphGNN is a
# torch.nn.Module (it exposes load_state_dict); harmless if text_forward already calls eval().
graph_model.eval()
collection_name = "pmc_trec_2016"

############# Create searcher instance #############
print("creating searcher instance")
# os.getenv returns None when PATH_TO_MILVUS_DB is unset (e.g. missing .env);
# NOTE(review): MilvusSearcher's handling of uri=None is not visible here.
searcher = MilvusSearcher(uri=os.getenv("PATH_TO_MILVUS_DB"), collection_name=collection_name)

def run_pipeline(query, reference=None, testing=False, limit=5):
    """Answer a patient query with the GNN-augmented RAG pipeline.

    Steps:
      1. ``graph_model.text_forward(query)`` proposes candidate diseases.
      2. The candidates are prepended to the query to build the retrieval
         prompt, and ``searcher.search`` fetches up to ``limit`` passages.
      3. When ``testing`` is True, a second retrieval is run on the bare
         query (same ``limit``, so the RAG-only baseline gets the same amount
         of context) and is passed to ``run_generation`` for comparison.

    Args:
        query: Free-text patient question.
        reference: Reference answer, forwarded to ``run_generation`` as
            ``reference_diagnosis``.
        testing: If True, also retrieve without GNN candidates for the
            RAG-only baseline.
        limit: Number of passages requested from each search (default 5).

    Returns:
        Whatever ``run_generation`` returns. The evaluation cells in this
        notebook index it with "baseline_results", "rag_gnn_results" and
        "baseline_rag_results" when ``testing=True``.
    """
    graph_results = graph_model.text_forward(query)

    # Join with ", " so the prompt reads "a, b, c" (not "a ,b ,c").
    candidate_diseases = ", ".join(graph_results)
    prompt = f"potential diseases: {candidate_diseases} \n query: {query}"
    search_results_with_gnn = searcher.search(prompt, limit=limit)

    search_results_without_gnn = None
    if testing:
        # Same limit as the GNN-augmented search so the ablation compares like with like.
        search_results_without_gnn = searcher.search(f"query: {query}", limit=limit)

    return run_generation(
        query,
        graph_results,
        search_results_with_gnn,
        testing=testing,
        reference_diagnosis=reference,
        retrieved_contexts_no_gnn=search_results_without_gnn,
    )

  from .autonotebook import tqdm as notebook_tqdm


Number of unique diseases: 773
Number of disease nodes in model graph: 773
creating graph model instance
creating searcher instance
Collection 'pmc_trec_2016' exists.
Using device: cuda


In [6]:
# Smoke test: run the full pipeline on the first test question and collect the
# three evaluated variants side by side with the reference answer.
result_columns = ["Question", "True_Answer", "Rag_GNN", "Rag_only", "Basic_llm"]
records = []

for row in icliniq_df_test.head(1).itertuples():
    query = row.Question
    reference = row.Answer

    cur_res_dict = run_pipeline(query, reference=reference, testing=True)
    records.append({
        "Question": query,
        "True_Answer": reference,
        "Basic_llm": cur_res_dict["baseline_results"],
        "Rag_GNN": cur_res_dict["rag_gnn_results"],
        "Rag_only": cur_res_dict["baseline_rag_results"],
    })
    print("Query completed:", query)

# Build the frame once: concat-in-a-loop is quadratic, and concatenating onto an
# empty frame triggers pandas' (>= 2.1) empty-entry FutureWarning.
res_df = pd.DataFrame(records, columns=result_columns)

display(res_df)

Query completed: Hi doctor,
I am a 51-year-old female with a height of 5 feet 8 inches and a weight of 145 lbs. I have Eagle syndrome with left internal jugular vein compression, but the right internal jugular vein is slightly compressed. I have had a sharp stabbing pain in my head for the past month. I got three TIAs Transient ischemic Attack with stroke- like symptoms like slurred speech, right-sided weakness, and right side facial drooping in the past 18 months. CT Computed Tomography report shows no blood clot. I got an appointment with the doctor next month. Please help.


Unnamed: 0,Question,True_Answer,Rag_GNN,Rag_only,Basic_llm
0,"Hi doctor,\r\nI am a 51-year-old female with a...","Hi,\r\nWelcome to icliniq.com.\r\nSevere jugul...","{'score': 97, 'clinical_accuracy': 18, 'use_of...","{'score': 100, 'clinical_accuracy': 20, 'use_o...","{'score': 97, 'clinical_accuracy': 19, 'use_of..."


In [8]:
# Raw per-row score dicts for the RAG + GNN variant.
print(res_df["Rag_GNN"].to_numpy())

[{'score': 97, 'clinical_accuracy': 18, 'use_of_retrieved_evidence': 9, 'transparency': 10}]


In [None]:
import os
# Turn off HuggingFace tokenizers' internal parallelism, presumably to avoid its
# warnings/contention when run_pipeline is called from the worker threads below.
os.environ.update(TOKENIZERS_PARALLELISM='false')

from concurrent.futures import ThreadPoolExecutor, as_completed
import pandas as pd

def process_single_row(row):
    """Run the evaluation pipeline on one test-set row.

    Args:
        row: A row from ``DataFrame.itertuples()`` exposing ``.Question`` and
            ``.Answer``; ``.Index`` (if present) is used only in the error message.

    Returns:
        A dict with keys "Question", "True_Answer", "Basic_llm", "Rag_GNN" and
        "Rag_only", or None if processing raised. Errors are reported rather
        than re-raised so that one bad row does not abort the whole batch.
    """
    row_id = getattr(row, "Index", "?")
    try:
        query = row.Question
        reference = row.Answer
        cur_res_dict = run_pipeline(query, reference=reference, testing=True)
        result = {
            "Question": query,
            "True_Answer": reference,
            "Basic_llm": cur_res_dict["baseline_results"],
            "Rag_GNN": cur_res_dict["rag_gnn_results"],
            "Rag_only": cur_res_dict["baseline_rag_results"],
        }
        print(f"Query completed: {query[:50]}...")
        return result
    except Exception as e:
        # Name the row and exception type so dropped rows can be traced and re-run.
        print(f"Error processing row {row_id}: {type(e).__name__}: {e}")
        return None

# Evaluate the first N test questions concurrently.
# NOTE(review): all workers share one graph_model and one searcher; their
# thread-safety is not visible here. Set max_workers = 1 if results look inconsistent.
n_eval_rows = 20
max_workers = 4  # at most 4 worker threads

test_data = list(icliniq_df_test.head(n_eval_rows).itertuples())
results_by_position = [None] * len(test_data)

with ThreadPoolExecutor(max_workers=max_workers) as executor:
    future_to_pos = {
        executor.submit(process_single_row, row): pos
        for pos, row in enumerate(test_data)
    }
    for future in as_completed(future_to_pos):
        results_by_position[future_to_pos[future]] = future.result()

# as_completed yields in completion order; rebuild in input order so res_df rows
# line up with icliniq_df_test and re-runs are comparable.
results = [r for r in results_by_position if r is not None]
n_failed = len(test_data) - len(results)
if n_failed:
    print(f"{n_failed} row(s) failed and were skipped")

# Fixed column schema, so an all-failed run still yields the expected columns.
res_df = pd.DataFrame(results, columns=["Question", "True_Answer", "Basic_llm", "Rag_GNN", "Rag_only"])
print(f"Processing completed. Total results: {len(res_df)}")
display(res_df)