In [1]:
import torch

In [2]:
import os
import copy
import json
import threading
import logging

from transformers import HfArgumentParser, AutoTokenizer, PreTrainedTokenizerFast
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict
from datasets import Dataset, load_dataset

In [3]:
import sys
sys.path.append('src/')

In [4]:
from config import Arguments
from logger_config import logger
from data_utils import log_random_samples, load_corpus, format_documents_for_final_answer
from vllm_client_local import VllmClient, get_vllm_model_id
from utils import save_json_to_file, AtomicCounter
from agent import CoRagAgent, RagPath
from inference.metrics import compute_metrics_dict

In [5]:
vllm_ip = "10.197.17.39"
e5_ip = "10.197.17.38"

In [6]:
vllm_client: VllmClient = VllmClient(get_vllm_model_id(host=vllm_ip), host=vllm_ip)
corpus: Dataset = load_corpus()

[2025-04-09 15:16:41,024 INFO] HTTP Request: GET http://10.197.17.39:8000/v1/models "HTTP/1.1 200 OK"


Resolving data files:   0%|          | 0/37 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/37 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/37 [00:00<?, ?it/s]

[2025-04-09 15:29:02,658 INFO] Loaded 35678076 passages from corag/kilt-corpus


In [9]:
corag_agent: CoRagAgent = CoRagAgent(vllm_client=vllm_client, corpus=corpus, e5_ip=e5_ip, vllm_ip=vllm_ip)

[2025-04-07 22:49:30,482 INFO] HTTP Request: GET http://10.197.17.39:8000/v1/models "HTTP/1.1 200 OK"


In [7]:
import heapq
import math
from copy import deepcopy
from typing import Optional, List, Dict

from agent.agent_utils import RagPath
from openai.types.chat import ChatCompletion
from prompts import get_generate_subquery_prompt, get_generate_intermediate_answer_prompt, get_generate_final_answer_prompt

from agent.corag_agent import CoRagAgent
from agent.corag_agent import _normalize_subquery

from vllm_client_local import VllmClient
from datasets import Dataset

In [21]:
class TreeNode:
    def __init__(self, path: RagPath, logprob: float, parent: Optional["TreeNode"] = None):
        self.path = deepcopy(path)
        self.logprob = logprob  # cumulative log probability of subqueries
        self.depth = len(path.past_subqueries)
        self.parent = parent
        self.levin_cost = 0.0

    def __lt__(self, other):
        return self.levin_cost < other.levin_cost
    
class CoRagAgentWithPHS(CoRagAgent):
    
    def __init__(self, vllm_client: VllmClient, corpus: Dataset, e5_ip: str, vllm_ip: str, 
                 confidence_threshold: float = 0.5):
        """Initializes the CoRagAgentWithPHS class.
        This class is a specialized version of the CoRagAgent that implements a tree search algorithm
        called 'Policy-guided heuristic search' (PHS) to quickly find a good path for answering a query.

        Args:
            vllm_client (VllmClient): VLLM client for answering queries.
            corpus (Dataset): Dataset containing the documents to be searched.
            e5_ip (str): IP address of the E5 server.
            vllm_ip (str): IP address of the VLLM server.
            confidence_threshold (float, optional): Confidence threshold to determine if completed. Defaults to 0.5.
        """
        super().__init__(vllm_client, corpus, e5_ip, vllm_ip)
        self.confidence_threshold = confidence_threshold

    def tree_search(
        self, query: str, 
        task_desc: str,
        max_path_length: int = 3,
        max_message_length: int = 4096,
        temperature: float = 0.7,
        expand_size: int = 4,
        max_tree_size = 100,
        **kwargs
    ) -> RagPath:
        root_path = RagPath(query=query, past_subqueries=[], past_subanswers=[], past_doc_ids=[])
        root_node = TreeNode(path=root_path, logprob=0.0)

        open_list = []
        heapq.heappush(open_list, root_node)
        explored_num = 0
        while open_list and explored_num < max_tree_size:
            explored_num += 1
            if explored_num % 10 == 0:
                logger.info(f"Explored nodes: {explored_num}")
            node = heapq.heappop(open_list)
            current_path = node.path

            if self._is_solution(current_path, task_desc, max_message_length):
                return current_path

            messages = get_generate_subquery_prompt(
                query=query,
                past_subqueries=current_path.past_subqueries,
                past_subanswers=current_path.past_subanswers,
                task_desc=task_desc
            )
            self._truncate_long_messages(messages, max_length=max_message_length)

            completion: ChatCompletion = self.vllm_client.call_chat(
                messages=messages,
                return_str=False,
                n=expand_size,
                # extra_body={"prompt_logprobs": 1},
                logprobs=True,
                temperature=temperature,
                **kwargs
            )

            for choice in completion.choices:
                subquery = _normalize_subquery(choice.message.content)
                if subquery in current_path.past_subqueries:
                    continue
                
                token_logprobs = [c.logprob for c in choice.logprobs.content]
                
                sub_logprob = sum(token_logprobs) / max(len(token_logprobs), 1)

                subanswer, doc_ids = self._get_subanswer_and_doc_ids(
                    subquery=subquery, max_message_length=max_message_length
                )

                new_path = RagPath(
                    query=query,
                    past_subqueries=current_path.past_subqueries + [subquery],
                    past_subanswers=current_path.past_subanswers + [subanswer],
                    past_doc_ids=current_path.past_doc_ids + [doc_ids]
                )

                running_cost = len(new_path.past_subqueries) # number of nodes (g(n) in the paper)
                policy_logprob = node.logprob + sub_logprob # cumulative log probability (log(pi(n)) in the paper)
                heuristic_cost = self._estimate_heuristic_llm(new_path, task_desc, max_message_length, **kwargs) # h(n) in the paper
                levin_cost = math.log(running_cost + heuristic_cost + 1e-5) - policy_logprob

                new_node = TreeNode(path=new_path, logprob=policy_logprob, parent=node)
                new_node.levin_cost = levin_cost
                heapq.heappush(open_list, new_node)

        # use this only if nothing worked idk?? can use logger instead
        print(f"Did not find a solution within {max_tree_size} nodes. Returning the root path.")
        return self.sample_path(
                    query=query,
                    task_desc=task_desc,
                    max_path_length=max_path_length,
                    max_message_length=max_message_length,
                    temperature=temperature,
                    **kwargs
                )

    def _is_solution(self, path: RagPath, task_desc: str, max_message_length: int) -> bool:
        log_prob = self._eval_single_path(
            path,
            max_message_length=max_message_length
        )
        # This is the log probability of the string 'No relevant information found'
        # So, if the log probability of the path is less than this, we consider it a solution.
        # This is a heuristic, and the threshold can be adjusted based on the model's behavior.
        return log_prob < math.log(self.confidence_threshold)

    def _estimate_heuristic_llm(self, path: RagPath, task_desc: str, max_message_length: int, **kwargs) -> int:
        """Heuristic function to estimate the number of remaining subqueries.
        This function uses the LLM to predict how many subqueries are needed to fully answer the original query.

        Args:
            path (RagPath): RagPath to the current node in the search tree
            task_desc (str): Task description
            max_message_length (int): Maximum message length for the LLM
            **kwargs: Additional arguments for the LLM call

        Returns:
            int: Estimated number of remaining subqueries
        """
        # Ask the LLM how many subqueries it expects are remaining
        messages: List[Dict] = get_generate_intermediate_answer_prompt(
            subquery=path.query,
            documents=[f'Q: {q}\nA: {a}' for q, a in zip(path.past_subqueries, path.past_subanswers)],
        )
        # messages.append({'role': 'user', 'content': 'How many more subqueries are needed to fully answer the original query. Respond with a single integer.'})
        messages.append({'role': 'user', 
                         'content': f'What subqueries should be asked to fully answer the original query: "{path.query}". Separate subqueries with question marks "?".'})
        self._truncate_long_messages(messages, max_length=max_message_length)

        response: ChatCompletion = self.vllm_client.call_chat(
            messages=messages,
            return_str=False,
            # max_tokens=5,
            **kwargs
        )

        text = response.choices[0].message.content.strip()
        
        est_remaining = max(1,len(text.split("?")))

        # response: ChatCompletion = self.vllm_client.call_chat(
        #     messages=messages,
        #     return_str=False,
        #     max_tokens=5,
        #     **kwargs
        # )

        # # idk how to ensure its an int, this was an attempt
        # try:
        #     est_remaining = int(text.split()[0])
        # except Exception:
        #     est_remaining = max(1, 3 - len(path.past_subqueries))  # fallback

        return max(0, est_remaining)

In [9]:
tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(get_vllm_model_id(host=vllm_ip))
tokenizer_lock: threading.Lock = threading.Lock()
processed_cnt: AtomicCounter = AtomicCounter()
total_cnt: int = 0

[2025-04-09 15:32:08,004 INFO] HTTP Request: GET http://10.197.17.39:8000/v1/models "HTTP/1.1 200 OK"


In [10]:
ds: Dataset = load_dataset('corag/multihopqa', "musique", split="validation")
ds = ds.remove_columns([name for name in ['subqueries', 'subanswers', 'predictions'] if name in ds.column_names])
ds = ds.add_column('task_desc', ['answer multi-hop questions' for _ in range(len(ds))])

In [11]:
ds = ds.select(range(16))
args = Arguments()
ex = ds[0]

In [22]:
phs_agent = CoRagAgentWithPHS(
    vllm_client=vllm_client,
    corpus=corpus,
    e5_ip=e5_ip,
    vllm_ip=vllm_ip,
    confidence_threshold=0.05
)

[2025-04-09 16:03:17,680 INFO] HTTP Request: GET http://10.197.17.39:8000/v1/models "HTTP/1.1 200 OK"


In [23]:
path: RagPath = phs_agent.tree_search(
        query=ex['query'], 
        task_desc=ex['task_desc'],
        max_path_length=6,
        temperature=0.2,
        max_tree_size=50
    )

[2025-04-09 16:03:20,430 INFO] HTTP Request: POST http://10.197.17.39:8000/v1/chat/completions "HTTP/1.1 200 OK"
[2025-04-09 16:03:20,762 INFO] HTTP Request: POST http://10.197.17.39:8000/v1/chat/completions "HTTP/1.1 200 OK"
[2025-04-09 16:03:21,134 INFO] HTTP Request: POST http://10.197.17.39:8000/v1/chat/completions "HTTP/1.1 200 OK"
[2025-04-09 16:03:21,567 INFO] HTTP Request: POST http://10.197.17.39:8000/v1/chat/completions "HTTP/1.1 200 OK"
[2025-04-09 16:03:21,786 INFO] HTTP Request: POST http://10.197.17.39:8000/v1/chat/completions "HTTP/1.1 200 OK"
[2025-04-09 16:03:22,118 INFO] HTTP Request: POST http://10.197.17.39:8000/v1/chat/completions "HTTP/1.1 200 OK"
[2025-04-09 16:03:22,336 INFO] HTTP Request: POST http://10.197.17.39:8000/v1/chat/completions "HTTP/1.1 200 OK"
[2025-04-09 16:03:22,871 INFO] HTTP Request: POST http://10.197.17.39:8000/v1/chat/completions "HTTP/1.1 200 OK"
[2025-04-09 16:03:23,091 INFO] HTTP Request: POST http://10.197.17.39:8000/v1/chat/completions "

Did not find a solution within 50 nodes. Returning the root path.


[2025-04-09 16:03:24,874 INFO] HTTP Request: POST http://10.197.17.39:8000/v1/chat/completions "HTTP/1.1 200 OK"
[2025-04-09 16:03:25,032 INFO] HTTP Request: POST http://10.197.17.39:8000/v1/chat/completions "HTTP/1.1 200 OK"
[2025-04-09 16:03:25,189 INFO] HTTP Request: POST http://10.197.17.39:8000/v1/chat/completions "HTTP/1.1 200 OK"
[2025-04-09 16:03:25,346 INFO] HTTP Request: POST http://10.197.17.39:8000/v1/chat/completions "HTTP/1.1 200 OK"
[2025-04-09 16:03:25,504 INFO] HTTP Request: POST http://10.197.17.39:8000/v1/chat/completions "HTTP/1.1 200 OK"
[2025-04-09 16:03:25,691 INFO] HTTP Request: POST http://10.197.17.39:8000/v1/chat/completions "HTTP/1.1 200 OK"
[2025-04-09 16:03:26,062 INFO] HTTP Request: POST http://10.197.17.39:8000/v1/chat/completions "HTTP/1.1 200 OK"
[2025-04-09 16:03:26,224 INFO] HTTP Request: POST http://10.197.17.39:8000/v1/chat/completions "HTTP/1.1 200 OK"
[2025-04-09 16:03:26,367 INFO] HTTP Request: POST http://10.197.17.39:8000/v1/chat/completions "

In [None]:
path: RagPath = corag_agent.sample_path(
        query=ex['query'], 
        task_desc=ex['task_desc'],
        max_path_length=6,
        temperature=0., 
        max_tokens=64
    )

[2025-04-07 17:27:01,817 INFO] HTTP Request: POST http://10.197.17.38:8000/v1/chat/completions "HTTP/1.1 200 OK"
[2025-04-07 17:27:02,093 INFO] HTTP Request: POST http://10.197.17.38:8000/v1/chat/completions "HTTP/1.1 200 OK"
[2025-04-07 17:27:02,251 INFO] HTTP Request: POST http://10.197.17.38:8000/v1/chat/completions "HTTP/1.1 200 OK"
[2025-04-07 17:27:02,394 INFO] HTTP Request: POST http://10.197.17.38:8000/v1/chat/completions "HTTP/1.1 200 OK"
[2025-04-07 17:27:02,719 INFO] HTTP Request: POST http://10.197.17.38:8000/v1/chat/completions "HTTP/1.1 200 OK"
[2025-04-07 17:27:02,864 INFO] HTTP Request: POST http://10.197.17.38:8000/v1/chat/completions "HTTP/1.1 200 OK"
[2025-04-07 17:27:03,019 INFO] HTTP Request: POST http://10.197.17.38:8000/v1/chat/completions "HTTP/1.1 200 OK"
[2025-04-07 17:27:03,175 INFO] HTTP Request: POST http://10.197.17.38:8000/v1/chat/completions "HTTP/1.1 200 OK"
[2025-04-07 17:27:03,317 INFO] HTTP Request: POST http://10.197.17.38:8000/v1/chat/completions "

In [17]:
documents: List[str] = format_documents_for_final_answer(
    args=args,
    context_doc_ids=ex['context_doc_ids'],
    tokenizer=tokenizer, corpus=corpus,
    lock=tokenizer_lock
)

In [18]:
prediction: str = phs_agent.generate_final_answer(
        corag_sample=path,
        task_desc=ex['task_desc'],
        documents=documents,
        max_message_length=args.max_len,
        temperature=0., max_tokens=128
    )

ex_with_path = copy.deepcopy(ex)
ex_with_path['subqueries'] = path.past_subqueries
ex_with_path['subanswers'] = path.past_subanswers
ex_with_path['path_doc_ids'] = path.past_doc_ids
if 'title' in corpus.column_names:
    ex_with_path['path_doc_titles'] = [
        [corpus[int(doc_id)]['title'] for doc_id in doc_ids] for doc_ids in path.past_doc_ids
    ]
ex_with_path['prediction'] = prediction

[2025-04-09 15:34:31,439 INFO] HTTP Request: POST http://10.197.17.39:8000/v1/chat/completions "HTTP/1.1 200 OK"


In [19]:
ex_with_path['subqueries']

['What is the name of the Green performer?',
 'Is the Green performer male?',
 'Is the Green performer married?',
 'Is the Green performer married to a woman?',
 'Is the Green performer married to a woman named Sarah?',
 "What is the name of the Green performer's spouse?"]

In [20]:
ex_with_path['subanswers']

['No relevant information found',
 'No relevant information found.',
 'Yes',
 'Yes',
 'No relevant information found.',
 'Laura Bayley']