In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import re
import os
import pandas as pd
import ast
import argparse
from typing import Dict, List
from alive_progress import alive_bar

from classes.vector_store import VectorStore
# from classes.llm_inference import LLMinference
from classes.utils import check_results
from prompt import *

# parser = argparse.ArgumentParser(description="LLM inference with different modalities.")
# parser.add_argument("-base", action="store_true", help="Run in baseline mode.")
# parser.add_argument("-quiz", action="store_true", help="Run in quiz mode.")
# args = parser.parse_args()

# Experiment switches. To drive them from the command line instead, re-enable the
# argparse block above and use:
# BASELINE = args.base
# QUIZ = args.quiz
BASELINE = False  # True: run only the *_BASELINE template instead of zero-shot + RAG
QUIZ = True  # True: multiple-choice prompts; False: open-ended prompts
# Project root. Can be overridden with the HEALTHBRANCHES_PATH environment variable so the
# notebook is not tied to one machine. os.path.join(..., "") guarantees a trailing
# separator, which later cells rely on when building paths as f"{PATH}subdir/...".
PATH = os.path.join(os.environ.get("HEALTHBRANCHES_PATH", "/home/cc/PHD/HealthBranches/"), "")
EXT = "QUIZ" if QUIZ else "OPEN"  # tag used in result file names

print("##### BASELINE MODE #####\n" if BASELINE else "##### BENCHMARK MODE #####\n")
print("##### QUIZ EXP #####\n" if QUIZ else "##### OPEN EXP #####\n")

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain.prompts import ChatPromptTemplate
from typing import List
import torch

class LLMinference:
    """Question-answering wrapper around a Hugging Face causal LM.

    Loads the tokenizer and model for ``llm_name`` (model weights in FP16) and wraps them
    in a ``text-generation`` pipeline. Prompts are built from LangChain
    ``ChatPromptTemplate`` strings, and any text up to a closing ``</think>`` tag is
    stripped from the model output.
    """

    # Instruction prepended to every prompt sent to the model by ``invoke``.
    _SYSTEM_PREFIX = (
        "You are a professional assistant. Answer the question directly and do not "
        "include any internal reasoning or chain-of-thought.\n "
    )

    def __init__(self, llm_name: str, temperature: float = 0.01, num_predict: int = 128, device: int = 0):
        """Load the model and build the generation pipeline.

        Args:
            llm_name: Hugging Face model id or local path.
            temperature: Sampling temperature passed to the pipeline.
            num_predict: Maximum number of new tokens generated per call.
            device: Pipeline device (0 = first GPU, -1 = CPU; for CPU also drop the
                FP16 dtype below).
        """
        self.llm_name = llm_name
        self.temperature = temperature
        self.num_predict = num_predict
        self.tokenizer = AutoTokenizer.from_pretrained(llm_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            llm_name,
            torch_dtype=torch.float16,
            trust_remote_code=True
        )
        self.generator = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device=device
        )

    def _remove_reasoning(self, text):
        """Return the text after the first ``</think>`` tag, stripped.

        If the tag is absent, the whole text is returned (stripped).
        """
        _, tag, after = text.partition("</think>")
        return after.strip() if tag else text.strip()

    def invoke(self, prompt: str) -> str:
        """Generate an answer for ``prompt`` and return it without reasoning text.

        Returns an empty string if the pipeline produces no output.
        """
        outputs = self.generator(
            self._SYSTEM_PREFIX + prompt,
            max_new_tokens=self.num_predict,
            temperature=self.temperature,
            do_sample=True,
            # Return only the completion. Otherwise the prompt is echoed back, and when the
            # model never emits "</think>" (e.g. truncated at max_new_tokens) the whole
            # prompt would end up in the saved answer.
            return_full_text=False,
        )
        if not outputs:
            return ""
        return self._remove_reasoning(outputs[0]["generated_text"])

    def single_inference(self, query: str, template: str, path: str, text: str, choices: List[str], cond: str, context):
        """Fill ``template`` with the query and context, run the model, and return the answer.

        Args:
            query: The question.
            template: LangChain template string. It may reference ``context``,
                ``question``, ``condition``, and optionally ``path``/``text`` and
                ``o1``..``oN``.
            path, text: Extra template fields; both are supplied only when both are
                non-empty.
            choices: Answer options for quiz mode (exposed as ``o1``, ``o2``, ...); empty
                for open questions.
            cond: Condition name.
            context: Retrieved documents; each must have ``page_content`` and
                ``metadata``.

        Returns:
            ``(response_text, sources)``. ``response_text`` has all whitespace runs
            collapsed to single spaces. ``sources`` holds each document's
            ``metadata["source"]`` (or None).
        """
        context_text = "\n\n---\n\n".join(doc.page_content for doc in context)
        prompt_template = ChatPromptTemplate.from_template(template)

        fields = {"context": context_text, "question": query, "condition": cond}
        if path != "" and text != "":
            fields.update(path=path, text=text)
        if choices:  # quiz mode
            fields.update({f"o{i}": choice for i, choice in enumerate(choices, start=1)})

        prompt = prompt_template.format(**fields)
        # Flatten to a single line. Newlines become spaces so words on adjacent lines are
        # not glued together.
        response_text = " ".join(self.invoke(prompt).split())
        sources = [doc.metadata.get("source", None) for doc in context]

        return response_text, sources

    def qea_evaluation(self, query: str, template: str, path: str, txt: str, choices: List[str], cond: str, vector_store, k: int = 3) -> str:
        """Retrieve ``k`` documents for ``query`` from ``vector_store`` and return the model's answer."""
        results = vector_store.search(query=query, k=k)
        response, _sources = self.single_inference(query, template, path, txt, choices, cond, results)
        return response

In [None]:
# Create an empty vector store in the indicated path. If the path already exists, load the vector store
vector_store = VectorStore(f'{PATH}indexes/kgbase-new/')

# Add documents in vector store (comment this line after the first add)
# vector_store.add_documents('/home/cc/PHD/ragkg/data/kgbase')

# folder_path = f"{PATH}questions_pro/ultimate_questions_v3_full_balanced.csv"
folder_path = f"{PATH}questions_pro/dataset_updated.csv"
questions = pd.read_csv(folder_path)

# Number of questions run per model (looks like a debug limit); set to None for the full dataset.
MAX_QUESTIONS = 5
# Folder holding one "<CONDITION>.txt" file per condition.
TXT_FOLDER = f"{PATH}data/kgbase-new/"

models = ["deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"]
models = check_results(PATH + "results/", f"results_{EXT}_baseline_*.csv" if BASELINE else f"results_{EXT}_bench_*.csv", models)

templates = [PROMPT_QUIZ, PROMPT_QUIZ_RAG] if QUIZ else [PROMPT_OPEN, PROMPT_OPEN_RAG]
if BASELINE:
    templates = [PROMPT_QUIZ_BASELINE] if QUIZ else [PROMPT_OPEN_BASELINE]


def parse_options(raw_options, index):
    """Parse a row's stringified option list.

    Returns the list when it has exactly 5 items. Otherwise prints why the row is
    skipped and returns None.
    """
    try:
        # Re-quote the list delimiters with double quotes, presumably so apostrophes inside
        # an option do not terminate a single-quoted literal.
        opts = ast.literal_eval(
            raw_options.replace("['", '["').replace("']", '"]').replace("', '", '", "')
        )
    except (ValueError, SyntaxError, AttributeError):
        # AttributeError: the cell is not a string (e.g. NaN for a missing value).
        print(f"Skipping row {index} due to value/syntax error")
        return None
    if not isinstance(opts, list) or len(opts) != 5:
        print(f"Skipping row {index} due to invalid options")
        return None
    return opts


def result_csv_path(model_name):
    """Return the results CSV path for ``model_name`` under the current BASELINE/EXT settings."""
    mode = "baseline" if BASELINE else "bench"
    # Keep only the part after the last "/" of an "org/name" model id; the raw id would be
    # treated as a (non-existent) sub-directory.
    short_name = model_name.split("/")[-1]
    return os.path.join(PATH, "results", f"results_{EXT}_{mode}_{short_name}.csv")


for model_name in models:
    print(f"Running model {model_name}...")
    llm = LLMinference(llm_name=model_name)

    subset = questions if MAX_QUESTIONS is None else questions.head(MAX_QUESTIONS)
    rows = []
    with alive_bar(len(subset)) as bar:
        for index, row in subset.iterrows():
            bar()  # advance for every row, skipped ones included, so the total matches

            opts = parse_options(row['options'], index)
            if opts is None:
                continue

            txt_name = row['condition'].upper() + ".txt"
            txt_path = os.path.join(TXT_FOLDER, txt_name)
            try:
                with open(txt_path, 'r') as file:
                    text = file.read()  # single string, matching LLMinference's `text: str`
            except (OSError, UnicodeDecodeError):
                print(txt_path)
                print(f"{txt_name} text is EMPTY!")
                continue

            cond = row['condition'].lower()
            # Baseline prompts receive the ground-truth path and condition text; benchmark
            # prompts receive neither.
            path_arg, text_arg = (row['path'], text) if BASELINE else ("", "")

            res = [cond]
            for template in templates:
                try:
                    res.append(llm.qea_evaluation(row['question'], template, path_arg, text_arg, opts, cond, vector_store))
                except Exception as e:
                    # Keep a placeholder so this row's answers stay aligned with the columns.
                    print(f"Row {index}: inference failed ({e!r})")
                    print(row)
                    res.append(None)

            if QUIZ:
                res.append(row["correct_option"])
            else:
                res.append(opts[ord(row["correct_option"].strip().upper()) - ord('A')])

            res.extend([row['question'], row['path']])
            rows.append(res)

    if BASELINE:
        columns = ["name", "zero_shot", "real", "question", "path"]
    else:
        columns = ["name", "zero_shot", "zero_shot_rag", "real", "question", "path"]
    results_df = pd.DataFrame(rows, columns=columns)
    results_df.to_csv(result_csv_path(model_name), index=False)

    print(f"Model {model_name} done!\n")
    # Release the model before loading the next one so GPU memory does not accumulate.
    del llm
    torch.cuda.empty_cache()

In [None]:
# def remove_reasoning(text):
#     """
#     Remove any chain-of-thought reasoning enclosed in <think>...</think> tags.
#     """
#     return re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()

# # Specify the model ID from Hugging Face Hub.
# model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"

# # Load the tokenizer.
# tokenizer = AutoTokenizer.from_pretrained(model_id, torch_dtype=torch.float16, trust_remote_code=True)

# # Load the model in FP16 (half-precision) mode.
# model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, trust_remote_code=True)

# # Create a text-generation pipeline; set device=0 if using GPU.
# generator = pipeline(
#     "text-generation",
#     model=model,
#     tokenizer=tokenizer,
#     device=0,
#     trust_remote_code=True,
#     torch_dtype=torch.float16,
# )

# # Define the prompt with an instruction to not reveal internal reasoning.
# prompt = (
#     "You are a professional assistant. Answer the question directly and do not include any internal reasoning or chain-of-thought. "
#     "Who are you?"
# )

# # Generate the model response.
# outputs = generator(prompt, max_new_tokens=100, do_sample=False)

# # Process and print the final answer, filtering out any chain-of-thought text.
# for output in outputs:
#     answer = output['generated_text']
#     final_answer = remove_reasoning(answer)
#     print("Final Answer:")
#     print(final_answer)