ProRAG is a process-supervised reinforcement learning framework designed to resolve the credit assignment problem in multi-hop RAG tasks.
- [January 30, 2026] 📄 Our paper is now available on arXiv.
- [January 29, 2026] 🤗 We have released our models on Hugging Face.
Retrieval-Augmented Generation (RAG) models often suffer from reward sparsity and inefficient credit assignment when optimized with traditional outcome-based Reinforcement Learning (RL). Coarse-grained scalar rewards fail to identify specific erroneous steps within long-horizon trajectories, leading to "process hallucinations"—where models reach correct answers through flawed logic.
ProRAG addresses these challenges by integrating learned step-level supervision directly into the online optimization loop.
Our framework consists of four progressive stages:
- Supervised Policy Warmup (SFT): Initialize the model with a structured reasoning format.
- MCTS-based Process Reward Model (PRM): Quantify intermediate reasoning quality using Monte Carlo Tree Search.
- PRM-Guided Reasoning Refinement (RFT): Align the policy with fine-grained process preferences to mitigate the cold-start problem.
- Process-Supervised Reinforcement Learning: Optimize with a dual-granularity advantage mechanism that aggregates step-level process rewards with global outcome signals.
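The dual-granularity idea in the final stage can be illustrated with a minimal sketch: each step's advantage blends its mean-centered process reward with the trajectory-level outcome reward. The function name, the blending weight `lam`, and the mean-centering choice are assumptions for illustration, not the paper's exact formulation.

```python
import statistics

def dual_granularity_advantages(step_rewards, outcome_reward, lam=0.5):
    """Blend step-level process rewards (centered within the trajectory)
    with the global outcome reward; `lam` weights the two signals."""
    mean_r = statistics.mean(step_rewards)
    return [lam * (r - mean_r) + (1 - lam) * outcome_reward
            for r in step_rewards]

# Two good steps, one weak step, correct final answer (outcome = 1.0):
advs = dual_granularity_advantages([0.9, 0.2, 0.8], outcome_reward=1.0)
```

Under this construction the weak middle step receives the smallest advantage even though the trajectory ends correctly, which is precisely the per-step credit assignment that a single outcome reward cannot provide.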
- Python 3.13+
- CUDA 12.x (Recommended)
```bash
# 1. Create and activate conda environment
conda create -n prorag python=3.13.11
conda activate prorag

# 2. Install vLLM
pip install vllm==0.11.0

# 3. Install requirements
pip install -e .

# 4. Install Flash Attention 2
pip install flash-attn==2.8.3 --no-build-isolation

# 5. Install W&B
pip install wandb
wandb login
```

If you would like to call a local retriever as the search engine, you can install the environment as follows. (We recommend using a separate conda environment.)
```bash
conda create -n retriever python=3.10
conda activate retriever

# We recommend installing torch via conda for faiss-gpu compatibility
conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=12.1 -c pytorch -c nvidia
pip install transformers datasets pyserini

# Install the GPU version of faiss to guarantee efficient RL rollouts
conda install -c pytorch -c nvidia faiss-gpu=1.8.0

# API serving
pip install uvicorn fastapi
```

Our training pipeline corresponds strictly to the four stages described in the paper.
Before running any training or generation tasks, you need to start the retrieval service and prepare the data for training.
```bash
conda activate retriever
export RETRIEVAL_PATH="data/indices/wikipedia"

# 1. Download the index
bash search/download.sh

# 2. Launch the service
bash search/retrieval_launch.sh
```

Tip: Keep this terminal open. Open a new terminal and activate `prorag` for the next steps.
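Once the service is up, other processes query it over HTTP. The sketch below shows what such a request might look like; the endpoint path, port, and payload fields are illustrative assumptions, not the repo's actual API.

```python
import json
import urllib.request

def build_retrieval_request(query: str, topk: int = 3,
                            url: str = "http://127.0.0.1:8000/retrieve"):
    # Hypothetical payload shape; check the launched service for the real schema.
    payload = {"queries": [query], "topk": topk}
    return urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )

req = build_retrieval_request("Who wrote The Left Hand of Darkness?")
# resp = urllib.request.urlopen(req)  # requires the service above to be running
```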
```bash
conda activate prorag

# Data preprocessing requires an API key (OpenAI/DeepSeek/vLLM)
export OPENAI_API_KEY="YOUR_KEY"
bash scripts/preprocess.sh
```

Fine-tune the model on the constructed datasets with structured reasoning-action formats to establish a reference policy.
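A structured reasoning-action trajectory might look like the following minimal sketch; the tag names (`think`, `search`, `answer`) are illustrative assumptions, not necessarily ProRAG's actual format.

```python
def format_step(thought: str, action: str, content: str) -> str:
    """Render one reasoning-action step as tagged text (hypothetical tags)."""
    assert action in {"search", "answer"}
    return f"<think>{thought}</think>\n<{action}>{content}</{action}>"

traj = "\n".join([
    format_step("I need the author of Dune first.", "search", "author of Dune"),
    format_step("Frank Herbert wrote Dune.", "answer", "Frank Herbert"),
])
```

A fixed, parseable format like this is what lets the PRM and the RL stages score individual steps later in the pipeline.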
```bash
bash scripts/sft.sh
```

Train the Process Reward Model (PRM) using contrastive pairs collected via Monte Carlo Tree Search (MCTS). This model provides step-level feedback.
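The core idea behind Monte Carlo step supervision can be sketched as follows: estimate each step's quality as the fraction of rollouts from that prefix that end in a correct answer. The rollout function here is a mock standing in for policy sampling plus answer checking; names and counts are illustrative assumptions.

```python
import random

def mc_step_values(steps, rollout_correct, n_rollouts=200, seed=0):
    """For each prefix of `steps`, estimate its value as the empirical
    success rate of `n_rollouts` continuations from that prefix."""
    rng = random.Random(seed)
    values = []
    for t in range(1, len(steps) + 1):
        prefix = steps[:t]
        wins = sum(rollout_correct(prefix, rng) for _ in range(n_rollouts))
        values.append(wins / n_rollouts)
    return values

# Mock rollout: trajectories containing a bad search query rarely recover
def mock_rollout(prefix, rng):
    p = 0.2 if "bad query" in prefix else 0.8
    return rng.random() < p

values = mc_step_values(["good query", "bad query", "recover"], mock_rollout)
```

A drop in the estimated value between consecutive steps (as between steps 1 and 2 here) is exactly the signal used to label a step as erroneous and to build contrastive training pairs for the PRM.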
```bash
# Ensure you have sufficient GPU memory for the vLLM servers
export OPENAI_API_KEY="YOUR_KEY"
bash scripts/prm.sh
```
⚠️ Note: Ensure your GPUs have sufficient memory. The script automatically spins up vLLM servers on GPUs 0-3 for parallel MCTS generation, and then releases resources for the subsequent PRM training.
Perform Rejection Sampling Fine-Tuning (RFT) using high-quality trajectories filtered by the PRM. This step bridges the gap between SFT and RL.
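The filtering criterion can be sketched as: keep a trajectory only if its weakest step, as scored by the PRM, clears a threshold. The threshold value and the toy PRM below are illustrative assumptions.

```python
def reject_sample(trajectories, prm_score, threshold=0.6):
    """Keep trajectories whose minimum step-level PRM score >= threshold.
    `prm_score(step) -> float` stands in for the trained PRM."""
    kept = []
    for traj in trajectories:
        if min(prm_score(step) for step in traj) >= threshold:
            kept.append(traj)
    return kept

# Toy data: each step carries its (mock) PRM score
trajs = [
    [("s1", 0.9), ("s2", 0.8)],  # all steps strong -> kept
    [("s1", 0.9), ("s2", 0.3)],  # one weak step -> rejected
]
kept = reject_sample(trajs, prm_score=lambda step: step[1])
```

Using the minimum (rather than the mean) over steps rejects trajectories that reach the right answer through a single flawed step, which is the "process hallucination" failure mode described above.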
```bash
bash scripts/rft.sh
```

Finally, run the online reinforcement learning with the Dual-Granularity Advantage mechanism, which combines outcome rewards and process rewards.
```bash
bash scripts/rl.sh
```

```bibtex
@misc{wang2026proragprocesssupervisedreinforcementlearning,
      title={ProRAG: Process-Supervised Reinforcement Learning for Retrieval-Augmented Generation},
      author={Zhao Wang and Ziliang Zhao and Zhicheng Dou},
      year={2026},
      eprint={2601.21912},
      archivePrefix={arXiv},
      primaryClass={cs.AI},
      url={https://arxiv.org/abs/2601.21912},
}
```
This project is licensed under the MIT License.
