In [None]:
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.document_loaders import JSONLoader, DirectoryLoader
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document

from langchain_ollama import OllamaEmbeddings, OllamaLLM
from langchain.prompts import ChatPromptTemplate
from tqdm.notebook import tqdm
import faiss
import os
import pickle
import random
import pandas as pd
import re
import json
import ast
import glob

from typing import Dict, List
import torch
from prompt import *

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# import os

# def count_chars_in_files(folder_path):
#     file_lengths = []
#     small_files = []

#     for filename in os.listdir(folder_path):
#         if filename.endswith(".txt"):  # Process only .txt files
#             file_path = os.path.join(folder_path, filename)
#             with open(file_path, "r", encoding="utf-8") as file:
#                 char_count = len(file.read())
#                 file_lengths.append(char_count)
#                 if char_count > 10000:
#                     small_files.append((filename, char_count))

#     if file_lengths:
#         min_chars = min(file_lengths)
#         max_chars = max(file_lengths)
#         mean_chars = sum(file_lengths) / len(file_lengths)

#         print(f"Min: {min_chars} chars")
#         print(f"Max: {max_chars} chars")
#         print(f"Mean: {mean_chars:.2f} chars")

#         if small_files:
#             print("\nFiles with fewer than 10000 characters:")
#             for filename, char_count in small_files:
#                 print(f"{filename}: {char_count} chars")
#         else:
#             print("\nNo files with fewer than 10000 characters.")
#     else:
#         print("No .txt files found in the folder.")

# # Example usage
# folder_path = "/home/cc/PHD/ragkg/data/kgbase-new"
# count_chars_in_files(folder_path)

In [None]:
# import re

# import pandas as pd

# # Load the input CSV
# df = pd.read_csv('/home/cc/PHD/ragkg/closed_questions/merged_output.csv')

# # Define a function to transform the 'qea' string
# def transform_qea(qea_str):
#     question_match = re.search(r'Question:\s*(.*?)\nA:', qea_str, re.S)
#     question = question_match.group(1).strip() if question_match else ""

#     # Extract choices and convert to a list (removing letters A:, B:, etc.)
#     choices_match = re.findall(r'[A-D]:\s*(.*?)(?=\n[A-D]:|\nCorrect answer:|\Z)', qea_str, re.S)
#     choices = [choice.strip() for choice in choices_match]

#     # Extract correct answer
#     correct_match = re.search(r'Correct answer:\s*([A-D])', qea_str)
#     correct_option = correct_match.group(1).strip() if correct_match else ""
    
#     return question, choices, correct_option

# # Create new columns based on the 'qea' column
# df[['question', 'choices', 'correct_option']] = df['qea'].apply(lambda x: pd.Series(transform_qea(x)))

# # Now, you can save the new CSV with the required columns
# df = df[['question', 'choices', 'correct_option', 'paths', 'name']]

# # Save the result to a new CSV
# df.to_csv('/home/cc/PHD/ragkg/closed_questions/old_quiz.csv', index=False)

In [None]:
class VectorStore:
    """FAISS vector store persisted as three files inside ``index_path``.

    The store wraps a LangChain ``FAISS`` object whose embeddings come from an
    Ollama embedding model. On construction it loads the persisted index when
    all three artefact files are present, and otherwise creates an empty
    ``IndexFlatL2`` index (and the target directory).

    Args:
        index_path: Directory that holds (or will hold) the persisted files.
        embedder_name: Name of the Ollama embedding model.
    """

    # File names of the persisted artefacts inside ``index_path``.
    _INDEX_FILE = 'index.faiss'
    _ID_MAP_FILE = 'doc_to_id.pkl'
    _DOCSTORE_FILE = 'docstore.pkl'

    def __init__(self, index_path: str, embedder_name: str = "mxbai-embed-large"):
        self.index_path = index_path
        self.embedder_name = embedder_name
        self.embedder = OllamaEmbeddings(model=embedder_name)
        self._load_vector_store()

    def _artifact_path(self, filename: str) -> str:
        """Return the location of ``filename`` inside the index directory."""
        # os.path.join makes ``index_path`` work with or without a trailing separator.
        return os.path.join(self.index_path, filename)

    def _load_vector_store(self):
        """Load the persisted store, or create an empty one if nothing is persisted."""
        index_file = self._artifact_path(self._INDEX_FILE)
        id_map_file = self._artifact_path(self._ID_MAP_FILE)
        docstore_file = self._artifact_path(self._DOCSTORE_FILE)

        # The directory alone is not proof of a saved index: the create branch
        # below makes the directory without writing any file, so a second run
        # would otherwise try to read files that were never saved.
        is_persisted = all(
            os.path.isfile(path) for path in (index_file, id_map_file, docstore_file)
        )

        if is_persisted:
            print("### LOAD VECTOR DB ###")

            self.index = faiss.read_index(index_file)

            # NOTE(security): pickle.load can execute arbitrary code; only point
            # ``index_path`` at directories written by this class.
            with open(id_map_file, "rb") as f:
                self.index_to_doc_id = pickle.load(f)

            with open(docstore_file, "rb") as f:
                self.docstore = pickle.load(f)
        else:
            print("### CREATE VECTOR DB ###")

            # Embed a probe sentence once to discover the embedding dimension.
            embedding_dim = len(self.embedder.embed_query('hello world'))
            self.index = faiss.IndexFlatL2(embedding_dim)
            self.index_to_doc_id = {}
            self.docstore = InMemoryDocstore()

            os.makedirs(self.index_path, exist_ok=True)

        self.vector_store = FAISS(
            embedding_function=self.embedder,
            index=self.index,
            docstore=self.docstore,
            index_to_docstore_id=self.index_to_doc_id
        )

    def _load_documents(self, doc_path: str, doc_type: str = "*.txt") -> list[Document]:
        """Load every file under ``doc_path`` matching the glob ``doc_type``."""
        loader = DirectoryLoader(doc_path, glob=doc_type)
        return loader.load()

    def _split_text(self, documents: list[Document], chunk_size: int = 500, chunk_overlap: int = 150):
        """Split ``documents`` into overlapping chunks.

        Args:
            documents: Documents to split.
            chunk_size: Maximum chunk length, measured with ``len``.
            chunk_overlap: Number of characters shared by consecutive chunks.

        Returns:
            The chunks produced by ``RecursiveCharacterTextSplitter``.
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            add_start_index=True,
        )
        chunks = text_splitter.split_documents(documents)
        print(f"Split {len(documents)} documents into {len(chunks)} chunks.")

        return chunks

    def add_documents(self, doc_path: str, doc_type: str = "*.txt",
                      chunk_size: int = 500, chunk_overlap: int = 150):
        """Load, chunk, embed and persist the documents found under ``doc_path``.

        Args:
            doc_path: Directory to read documents from.
            doc_type: Glob selecting the files to load.
            chunk_size: Forwarded to ``_split_text``.
            chunk_overlap: Forwarded to ``_split_text``.
        """
        documents = self._load_documents(doc_path, doc_type)
        chunks = self._split_text(documents, chunk_size=chunk_size, chunk_overlap=chunk_overlap)

        if not chunks:
            # Nothing to embed: leave the in-memory store and the files on disk untouched.
            print(f"No content found in {doc_path} for pattern {doc_type}; vector DB unchanged.")
            return

        self.vector_store.add_documents(documents=chunks)
        self._update_vector_db()

    def search(self, query: str, k: int = 3):
        """Return the ``k`` stored chunks most similar to ``query``."""
        return self.vector_store.similarity_search(query=query, k=k)

    def _update_class(self):
        """Re-point the instance attributes at the wrapped store's current state."""
        self.index = self.vector_store.index
        self.index_to_doc_id = self.vector_store.index_to_docstore_id
        self.docstore = self.vector_store.docstore

    def _update_vector_db(self):
        """Write the index, the id mapping and the docstore to ``index_path``."""
        # Recreate the directory if it was removed after construction.
        os.makedirs(self.index_path, exist_ok=True)

        faiss.write_index(self.vector_store.index, self._artifact_path(self._INDEX_FILE))

        with open(self._artifact_path(self._ID_MAP_FILE), "wb") as f:
            pickle.dump(self.vector_store.index_to_docstore_id, f)

        with open(self._artifact_path(self._DOCSTORE_FILE), "wb") as f:
            pickle.dump(self.vector_store.docstore, f)

        self._update_class()

        print("### UPDATE VECTOR DB ###")

In [None]:
# Open the vector store rooted at the directory below: an existing store is
# loaded from it, otherwise a new empty one is created there.
vector_store = VectorStore(index_path='/home/cc/PHD/ragkg/indexes/kgbase/')

# Populate the store on the first run only (keep this line commented afterwards)
# vector_store.add_documents('/home/cc/PHD/ragkg/data/kgbase')

In [None]:
class LLMinference:
    """Wrapper around an Ollama-served LLM used to answer multiple-choice questions.

    Args:
        llm_name: Name of the Ollama model to query.
    """

    def __init__(self, llm_name):
        self.llm_name = llm_name
        self.model = OllamaLLM(model=llm_name)

    def _transform_query(self, query: str) -> str:
        """Prefix ``query`` with the passage-retrieval instruction sentence."""
        return f'Represent this sentence for searching relevant passages: {query}'

    def single_inference(self, query: str, template: str, path: str, choices: List[str], cond: str, context) -> tuple[str, list]:
        """Fill ``template``, query the model once and return its cleaned answer.

        Args:
            query: Question text, passed to the template as ``question``.
            template: Prompt template string given to ``ChatPromptTemplate``.
            path: Passed to the template as ``path``; omitted when it is ``""``.
            choices: Answer options, passed as ``o1``, ``o2``, ... (at most the
                first five are used). The template must not reference an
                option that is missing from ``choices``.
            cond: Passed to the template as ``condition``.
            context: Iterable of retrieved documents exposing ``page_content``
                and ``metadata``.

        Returns:
            ``(response_text, sources)`` where ``sources`` holds the ``source``
            metadata entry of every context document (``None`` when absent).
        """
        context_text = "\n\n---\n\n".join(doc.page_content for doc in context)

        # Only the options that exist are passed, so questions with fewer than
        # five choices no longer fail with an IndexError before the prompt is built.
        fields = {f"o{i}": choice for i, choice in enumerate(choices[:5], start=1)}
        fields.update(context=context_text, question=query, condition=cond)
        if path != "":
            fields["path"] = path

        prompt = ChatPromptTemplate.from_template(template).format(**fields)

        response_text = self.model.invoke(prompt)
        # NOTE(review): newlines and double spaces are deleted rather than
        # replaced by a single space, so adjacent words can be fused - confirm
        # this is intended.
        response_text = response_text.strip().replace("\n", "").replace("  ", "")

        sources = [doc.metadata.get("source", None) for doc in context]

        return response_text, sources

    def qea_evaluation(self, query: str, template: str, path: str, choices: List[str], cond: str, vector_store, k: int = 5):
        """Retrieve context for ``query`` and return the model's answer text.

        Args:
            query: Question text, also used as the retrieval query.
            template: Prompt template string.
            path: Forwarded to ``single_inference``.
            choices: Answer options.
            cond: Forwarded to ``single_inference`` as the condition.
            vector_store: Object exposing ``search(query=..., k=...)``.
            k: Number of chunks to retrieve.

        Returns:
            The cleaned response text; the retrieved sources are discarded.
        """
        results = vector_store.search(query=query, k=k)

        response, _sources = self.single_inference(query, template, path, choices, cond, results)

        return response

In [None]:
# Root folder of the MedQA question files.
folder_path = "/home/cc/PHD/ragkg/MedQA"

# Ollama models to evaluate; the full candidate list is kept below for reference.
# models = ["mistral", "llama3.1:8b", "llama2:7b", "medllama2:7b", "gemma:7b", "gemma2:9b", "phi4:14b", "qwen2.5:7b", "mixtral:8x7b", "deepseek-r1:7b"]
models = ["mistral"]

# Prompt templates, evaluated in this order for every question.
# templates = [PROMPT_TEMPLATE, PROMPT_TEMPLATE_ONE, PROMPT_TEMPLATE_RAG]
templates = [PROMPT_MED, PROMPT_MED_RAG]

# Question files: every `top*` entry one directory level below `folder_path`.
# `recursive=True` was dropped because it only affects patterns containing `**`.
# glob returns matches in arbitrary order, so sort them to keep the order of
# the result rows reproducible between runs.
med_files = sorted(glob.glob(f"{folder_path}/*/top*"))

In [None]:
cnt_rag = 0
cnt = 0

# Layout of one result row: condition, one answer per prompt template,
# gold answer index, question text.
result_columns = ["name", "zero_shot", "zero_shot_rag", "real", "question"]

# Fail before any model is queried if the templates no longer match the columns.
if len(result_columns) != len(templates) + 3:
    raise ValueError(
        f"{len(templates)} templates do not fit the {len(result_columns)} result columns"
    )

rows = []

for model in models:
    model_name = model
    llm = LLMinference(llm_name=model_name)

    cnt = 0
    rows = []

    for jso in tqdm(med_files):
        questions = pd.read_json(jso, lines=True)

        # The parent folder name of the question file is the condition label.
        cond = jso.split('/')[-2].lower()

        for _, row in questions.iterrows():
            choices = list(row['options'].values())

            answers = [
                llm.qea_evaluation(row['question'], template, "", choices, cond, vector_store)
                # llm.qea_evaluation(row['question'], template, row['reasoning_trace'], ast.literal_eval(row['choices']), row['name'].lower(), vector_store) # Baseline
                for template in templates
            ]

            rows.append([cond, *answers, row["answer_idx"], row['question']])

    # `rows` is reset for every model, so each model's answers are written here;
    # saving only after the loop would keep the last model's results alone.
    pd.DataFrame(rows, columns=result_columns).to_csv(
        f"/home/cc/PHD/ragkg/results_medqa_{model_name}.csv", index=False
    )

In [None]:
# Tabulate the collected MedQA rows: condition, the answer produced with each
# of the two prompt templates, the gold answer and the question text.
df = pd.DataFrame(
    data=rows,
    columns=["name", "zero_shot", "zero_shot_rag", "real", "question"],
)

In [None]:
# Write the results table to a CSV named after the evaluated model, without the index column.
df.to_csv(
    path_or_buf=f"/home/cc/PHD/ragkg/results_medqa_{model_name}.csv",
    index=False,
)