In [1]:
from vllm import LLM, SamplingParams
import json, os, random
import pdb
from datasets import load_dataset
import time
import faiss

import hydra 
from omegaconf import DictConfig, OmegaConf
from FlagEmbedding import FlagModel

INFO 09-30 23:40:10 [__init__.py:216] Automatically detected platform cuda.


In [2]:
from vllm import LLM, SamplingParams
import json, os, random
import pdb
from datasets import load_dataset
import tqdm
import requests
import time
import hydra
from omegaconf import DictConfig

def mystrip(one_str):
    """Trim surrounding whitespace and markdown '#' markers from an extracted field.

    Whitespace is stripped both before and after the '#' characters, so inputs
    such as "\\n## Paris\\n" reduce to "Paris".

    Args:
        one_str: raw text slice taken from a model response.

    Returns:
        The trimmed string (possibly empty).
    """
    # str.strip() already removes real newline characters. Do NOT pass "\\n" to
    # strip(): its argument is a *character set*, so it would remove any leading
    # or trailing backslashes and letter 'n's (e.g. "London" -> "Londo").
    return one_str.strip().strip("#").strip()

def extract_substring2(text, start_str, stop_strs):
    """Return the cleaned text between the first ``start_str`` and the nearest stop marker after it.

    Args:
        text: text to search.
        start_str: label that opens the wanted section; its first occurrence is used.
        stop_strs: markers that end the section; the earliest one found after the
            label wins, and the end of ``text`` is used when none is found.

    Returns:
        The section passed through ``mystrip``, or None when ``start_str`` is absent
        or nothing lies between it and the nearest stop marker.
    """
    anchor = text.find(start_str)
    if anchor == -1:
        return None
    begin = anchor + len(start_str)

    # Earliest position of any stop marker at or after `begin`; defaults to end of text.
    positions = (text.find(marker, begin) for marker in stop_strs)
    finish = min((pos for pos in positions if pos != -1), default=len(text))

    if begin >= finish:
        return None
    return mystrip(text[begin:finish])

def split_response(response):
    """Parse a labelled model response into its sections.

    The response is expected to start a section with "The problem analysis:" and
    follow it with either "The retrieval query:" or "The final answer:".

    Returns:
        A dict that always has ``original`` (the raw response) and ``analysis``.
        ``analysis`` is None when the analysis label is missing, or when neither a
        query nor an answer label follows it. Otherwise the dict also gets exactly
        one of ``query`` / ``answer``: whichever label appears first after the
        analysis label. Its value is the extracted text or None.
    """
    label_analysis = "The problem analysis:"
    label_query = "The retrieval query:"
    label_answer = "The final answer:"
    common_stops = [label_analysis, label_query, label_answer, "The retrieval documents:", "###", "####"]
    # Queries additionally stop at a step marker or the first question mark.
    query_stops = common_stops + ["\nStep", "?"]
    answer_stops = common_stops + ["\nStep"]

    parsed = {"original": response}

    analysis_pos = response.find(label_analysis)
    if analysis_pos == -1:
        parsed['analysis'] = None
        return parsed
    parsed["analysis"] = extract_substring2(response, label_analysis, common_stops)

    search_from = analysis_pos + len(label_analysis)
    query_pos = response.find(label_query, search_from)
    answer_pos = response.find(label_answer, search_from)

    if query_pos == -1 and answer_pos == -1:
        # No follow-up section: the analysis alone is treated as unusable.
        parsed['analysis'] = None
        return parsed

    # Take the query when it exists and is not preceded by an answer label.
    take_query = query_pos != -1 and (answer_pos == -1 or query_pos < answer_pos)
    if take_query:
        parsed['query'] = extract_substring2(response[query_pos:], label_query, query_stops)
    else:
        parsed['answer'] = extract_substring2(response[answer_pos:], label_answer, answer_stops)
    return parsed

def GetRetrieval(retrieve_url: str, querys: list, cfg: DictConfig):
    """Fetch retrieval results for every query from a batch search service.

    For each chunk of ``cfg.retrieval.post_batch_size`` queries, POSTs
    ``{"queries": [...]}`` as JSON to ``retrieve_url``. The JSON list responses are
    concatenated in order. Each response is expected to hold one entry per query
    in the chunk; a response with a different count is treated as a failed attempt.

    Config keys read:
        retrieval.post_batch_size: queries per request.
        retrieval.ssl_retry: attempts per chunk before giving up.
        retrieval.timeout (optional): seconds to wait per request. When missing or
            null, requests wait indefinitely, as before.

    Returns:
        list: the concatenated per-query results.

    Raises:
        ValueError: if a chunk still fails after ``retrieval.ssl_retry`` attempts.
    """
    batch_size = cfg.retrieval.post_batch_size
    max_attempts = cfg.retrieval.ssl_retry
    timeout = cfg.retrieval.get("timeout")
    retry_delay = 2  # seconds to wait before retrying a failed chunk

    res = []
    for i in tqdm.tqdm(range(0, len(querys), batch_size), desc="Retrieving documents"):
        subset = querys[i:i + batch_size]
        # Track this chunk's last failure explicitly rather than peeking at a
        # possibly stale `response` left over from an earlier chunk.
        fail_info = "No response"
        for _ in range(max_attempts):
            try:
                response = requests.post(
                    retrieve_url,
                    json={"queries": subset},
                    headers={"Content-Type": "application/json"},
                    timeout=timeout,
                )
                if response.status_code == 200:
                    payload = response.json()  # parse once
                    if payload and len(payload) == len(subset):
                        res.extend(payload)
                        break
                fail_info = response.text
                print(f"Unusable response (status {response.status_code}), retrying...")
            except (requests.exceptions.RequestException, ValueError) as e:
                # ValueError covers JSON decoding errors on older `requests` versions.
                fail_info = str(e)
                print(f"Request failed: {e}, retrying...")
            # Back off after every failed attempt, not only after exceptions, so
            # bad-status responses do not hammer the server in a tight loop.
            time.sleep(retry_delay)
        else:
            # Every attempt for this chunk failed.
            print(f"Fail info: {fail_info}")
            raise ValueError(f"Failed to retrieve query:{i} ~ {i + batch_size}!!!!!!!!!!")
    return res

def solve(cfg: DictConfig):
    """Run the naive-RAG pipeline end to end and write the results to ``records.jsonl``.

    The steps are:
    1. Load the model, questions and documents (solve_init).
    2. Answer every question with the CoT prompt (solve_main).
    3. Re-ask any record still lacking an 'answer' with the plain prompt (solve_directly).
    4. Write one JSON object per line, UTF-8, keeping non-ASCII characters.
    """
    ckpt, records = solve_init(cfg)
    solve_main(cfg, ckpt, records)

    unanswered = sum('answer' not in record for record in records)
    print(f"Remain records: {unanswered}")
    if unanswered:
        solve_directly(cfg, ckpt, records)

    with open("records.jsonl", "w", encoding='utf-8') as out_file:
        out_file.writelines(json.dumps(record, ensure_ascii=False) + '\n' for record in records)

def solve_init(cfg: DictConfig):
    """Load the model and the HotpotQA (fullwiki) validation questions, then attach retrieved documents.

    Args:
        cfg: run config. Reads ``model.path``, ``model.tensor_parallel_size``,
            ``debug``, ``retrieval.url`` and ``retrieval.num_of_docs``, plus the keys
            that GetRetrieval reads.

    Returns:
        ``(ckpt, records)``: the vLLM ``LLM`` and a list with one dict per question.
        Each dict holds ``question``, ``golden_answers``, ``state`` ("undo"),
        ``resample_times`` (0) and ``doc``. ``doc`` is the ``contents`` of the first
        ``retrieval.num_of_docs`` retrieved entries, joined by newlines.

    Raises:
        ValueError: if retrieval returns a different number of results than questions.
    """
    llm_args = {
        'model': cfg.model.path,
        # Debug runs always use a single GPU.
        'tensor_parallel_size': 1 if cfg.debug else cfg.model.tensor_parallel_size,
    }
    ckpt = LLM(**llm_args)
    print("CKPT is ready.")

    dataset = load_dataset('hotpotqa/hotpot_qa', 'fullwiki')['validation']

    if cfg.debug:
        # Small random subset for quick checks (uses the global `random` state; not seeded here).
        sample_size = min(8, len(dataset))
        dataset = dataset.select(random.sample(range(len(dataset)), sample_size))

    records = [
        {
            'question': data['question'],
            'golden_answers': data['answer'],
            'state': "undo",
            'resample_times': 0,
        }
        for data in dataset
    ]
    query_list = [record['question'] for record in records]

    doc_list = GetRetrieval(cfg.retrieval.url, query_list, cfg)
    if len(doc_list) != len(records):
        # Without this check, zip() below would silently leave trailing records
        # without a 'doc', surfacing only later as a KeyError in solve_main.
        raise ValueError(
            f"Retrieval returned {len(doc_list)} results for {len(records)} questions."
        )

    num_docs = cfg.retrieval.num_of_docs
    for doc_one, record in zip(doc_list, records):
        record['doc'] = "\n".join(doc_one_one['contents'] for doc_one_one in doc_one[:num_docs])

    return ckpt, records

def generate_naive_rag_prompt(question, doc):
    """Build a plain (non-CoT) RAG chat prompt that asks for a bare answer.

    Args:
        question: the question text.
        doc: the retrieved documents, already joined into one string.

    Returns:
        A two-element list of ``{"role", "content"}`` dicts: a system message with
        the instructions followed by ``doc``, then a user message with ``question``.
    """
    header = (
        "Answer the question based on the given document. "
        "Only give me the answer and do not output any other words.\n"
        "The following are given documents.\n"
    )
    return [
        {"role": "system", "content": f"{header}{doc}\n"},
        {"role": "user", "content": f"The question: {question}"},
    ]

def generate_naive_rag_cot_prompt(question, doc):
    """Build a chain-of-thought RAG chat prompt.

    The system message does three things:
    - describes the expected response format ("The problem analysis: ..." then
      "The final answer: ...", the same labels split_response looks for);
    - gives one worked example;
    - appends ``doc``.
    The user message carries ``question``.

    Returns:
        A two-element list of ``{"role", "content"}`` dicts (system, then user).
    """
    instructions = """You are a helpful assistant that answers questions based on document retrieval with step-by-step reasoning.

For any question, please structure your response in this format:
The problem analysis: [Provide detailed step-by-step reasoning]
The final answer: [Provide the concise final answer]

Example:
User: The question: What was the company's revenue in 2023?
Assistant:
The problem analysis: I need to find information about the company's revenue in 2023. Looking at the provided document, I can see in the third paragraph that "the company's total revenue for fiscal year 2023 reached $128 million." This clearly states the exact revenue figure I'm looking for.
The final answer: The company's revenue in 2023 was $128 million.

Please carefully analyze the provided documents, ensure your answer is fully based on the document content, and use step-by-step reasoning to reach accurate conclusions."""
    documents_block = f"\n\nThe following are the provided documents:\n{doc}\n"
    return [
        {"role": "system", "content": instructions + documents_block},
        {"role": "user", "content": f"The question: {question}"},
    ]


def solve_main(cfg: DictConfig, ckpt: LLM, records: list):
    """Answer every record with the chain-of-thought RAG prompt and parse the replies.

    Mutates ``records`` in place. Each record gets ``output``, the raw generation
    (the first sample's text). When split_response extracts a non-empty final
    answer, the record also gets ``answer`` and ``state`` "done"; otherwise its
    ``state`` becomes "wrong".
    """
    params = SamplingParams(temperature=cfg.params.temperature_main, max_tokens=cfg.params.max_tokens)
    prompts = [generate_naive_rag_cot_prompt(rec['question'], rec['doc']) for rec in records]

    generations = ckpt.chat(prompts, params)
    for rec, generation in zip(records, generations):
        parsed = split_response(generation.outputs[0].text)
        rec['output'] = parsed['original']
        answer = parsed.get('answer')
        if answer:
            rec['answer'] = answer
            rec['state'] = "done"
        else:
            rec['state'] = "wrong"

def solve_directly(cfg: DictConfig, ckpt: LLM, records: list):
    """Fallback pass: answer every record still lacking an 'answer' with the plain prompt.

    Mutates ``records`` in place. Each unanswered record gets:
    - ``answer``: the whitespace-trimmed model output;
    - ``state``: "done";
    - ``resample_times``: incremented by one.

    Sampling uses ``cfg.params.temperature_fallback`` when it is set, and otherwise
    ``cfg.params.temperature_main``.
    """
    remain_idxs = [i for i, record in enumerate(records) if 'answer' not in record]
    if not remain_idxs:
        # Nothing left to answer; skip an empty generation call.
        return

    # The config has a dedicated fallback temperature; previously it was ignored.
    temperature = cfg.params.get("temperature_fallback")
    if temperature is None:
        temperature = cfg.params.temperature_main
    sampling_params = SamplingParams(temperature=temperature, max_tokens=cfg.params.max_tokens)

    messages = [
        generate_naive_rag_prompt(records[remain_idx]['question'], records[remain_idx]['doc'])
        for remain_idx in remain_idxs
    ]
    outputs = ckpt.chat(messages, sampling_params)

    for remain_idx, output in zip(remain_idxs, outputs):
        record = records[remain_idx]
        # Trim whitespace so fallback answers are as clean as those parsed in solve_main.
        record['answer'] = output.outputs[0].text.strip()
        record['state'] = "done"
        record['resample_times'] = record.get('resample_times', 0) + 1


In [None]:
# Load the run config. Set NAIVE_RAG_CONFIG to point at a different file; the
# default is the original author's absolute path.
file_path = os.environ.get(
    "NAIVE_RAG_CONFIG",
    '/home/kangjh/Research/ParametricReasoning/RPRAG/benchmark/NaiveRAG/conf/config.yaml',
)
cfg = OmegaConf.load(file_path)
print(OmegaConf.to_yaml(cfg))

start = time.time()
print(f"Start at {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start))}")
try:
    solve(cfg)
finally:
    # Report timing even when the run fails or is interrupted (e.g. KeyboardInterrupt).
    end = time.time()
    print(f"End at {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end))}")
    elapsed_time = end - start
    print(f"Elapsed time: {elapsed_time:.2f} seconds")

defaults:
- _self_
debug: false
paths:
  dev_dataset: /mnt/raid5/kangjh/downloads/datasets/hotpotqa/dev/dev.json
  log_dir: ./logs
model:
  path: meta-llama/Meta-Llama-3-8B-Instruct
  tensor_parallel_size: 1
retrieval:
  url: http://10.0.12.120:8001/search_batch
  post_batch_size: 2048
  ssl_retry: 8
  num_of_docs: 10
params:
  max_tokens: 512
  temperature_main: 0.0
  temperature_fallback: 0.0

Start at 2025-09-30 23:40:12
INFO 09-30 23:40:12 [utils.py:328] non-default args: {'disable_log_stats': True, 'model': 'meta-llama/Meta-Llama-3-8B-Instruct'}


INFO 09-30 23:40:21 [__init__.py:742] Resolved architecture: LlamaForCausalLM


`torch_dtype` is deprecated! Use `dtype` instead!


INFO 09-30 23:40:22 [__init__.py:1815] Using max model len 8192
INFO 09-30 23:40:25 [scheduler.py:222] Chunked prefill is enabled with max_num_batched_tokens=8192.
[1;36m(EngineCore_DP0 pid=2338256)[0;0m INFO 09-30 23:40:26 [core.py:654] Waiting for init message from front-end.
[1;36m(EngineCore_DP0 pid=2338256)[0;0m INFO 09-30 23:40:26 [core.py:76] Initializing a V1 LLM engine (v0.10.2) with config: model='meta-llama/Meta-Llama-3-8B-Instruct', speculative_config=None, tokenizer='meta-llama/Meta-Llama-3-8B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=Fal



[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[1;36m(EngineCore_DP0 pid=2338256)[0;0m INFO 09-30 23:40:30 [gpu_model_runner.py:2370] Loading model from scratch...
[1;36m(EngineCore_DP0 pid=2338256)[0;0m INFO 09-30 23:40:30 [cuda.py:362] Using Flash Attention backend on V1 engine.
[1;36m(EngineCore_DP0 pid=2338256)[0;0m INFO 09-30 23:40:30 [weight_utils.py:348] Using model weights format ['*.safetensors']


Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]


[1;36m(EngineCore_DP0 pid=2338256)[0;0m INFO 09-30 23:40:35 [default_loader.py:268] Loading weights took 3.89 seconds
[1;36m(EngineCore_DP0 pid=2338256)[0;0m INFO 09-30 23:40:35 [gpu_model_runner.py:2392] Model loading took 14.9596 GiB and 5.016109 seconds
[1;36m(EngineCore_DP0 pid=2338256)[0;0m INFO 09-30 23:40:41 [backends.py:539] Using cache directory: /home/kangjh/.cache/vllm/torch_compile_cache/ec52f37144/rank_0_0/backbone for vLLM's torch.compile
[1;36m(EngineCore_DP0 pid=2338256)[0;0m INFO 09-30 23:40:41 [backends.py:550] Dynamo bytecode transform time: 5.35 s
[1;36m(EngineCore_DP0 pid=2338256)[0;0m INFO 09-30 23:40:43 [backends.py:161] Directly load the compiled graph(s) for dynamic shape from the cache, took 1.839 s
[1;36m(EngineCore_DP0 pid=2338256)[0;0m INFO 09-30 23:40:44 [monitor.py:34] torch.compile takes 5.35 s in total
[1;36m(EngineCore_DP0 pid=2338256)[0;0m INFO 09-30 23:40:46 [gpu_worker.py:298] Available KV cache memory: 26.48 GiB
[1;36m(EngineCore_DP0

Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|██████████| 67/67 [00:04<00:00, 14.77it/s]


[1;36m(EngineCore_DP0 pid=2338256)[0;0m INFO 09-30 23:40:51 [gpu_model_runner.py:3118] Graph capturing finished in 5 secs, took 0.53 GiB
[1;36m(EngineCore_DP0 pid=2338256)[0;0m INFO 09-30 23:40:51 [gpu_worker.py:391] Free memory on device (47.13/47.43 GiB) on startup. Desired GPU memory utilization is (0.9, 42.69 GiB). Actual usage is 14.96 GiB for weight, 1.24 GiB for peak activation, 0.02 GiB for non-torch memory, and 0.53 GiB for CUDAGraph memory. Replace gpu_memory_utilization config with `--kv-cache-memory=27701056921` to fit into requested memory, or `--kv-cache-memory=32465419264` to fully utilize gpu memory. Current kv cache memory in use is 28428768665 bytes.
[1;36m(EngineCore_DP0 pid=2338256)[0;0m INFO 09-30 23:40:51 [core.py:218] init engine (profile, create kv cache, warmup model) took 16.23 seconds
INFO 09-30 23:40:53 [llm.py:295] Supported_tasks: ['generate']
INFO 09-30 23:40:53 [__init__.py:36] No IOProcessor plugins requested by the model
CKPT is ready.


Retrieving documents:  25%|██▌       | 1/4 [49:04<2:27:14, 2944.79s/it]

Retrieving documents:  25%|██▌       | 1/4 [1:18:23<3:55:09, 4703.21s/it]

ERROR 10-01 00:59:22 [core_client.py:564] Engine core proc EngineCore_DP0 died unexpectedly, shutting down client.





KeyboardInterrupt: 

In [None]:
# Smoke-test the retrieval service configured in `cfg` instead of a hardcoded URL.
ret_result = GetRetrieval(cfg.retrieval.url, ['What is the capital of France?', 'Who is the president of the United States?'], cfg)

Retrieving documents:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving documents: 100%|██████████| 1/1 [00:14<00:00, 14.35s/it]
