Retrieval-Grounded Reinforcement Learning for Judgment Document Generation
Judgment document generation aims to draft a complete court judgment from a case fact description, requiring both accurate grounding in legal knowledge and multi-step legal reasoning. Prior retrieval-augmented generation (RAG) improves coverage of relevant statutes, but a fixed top-k retrieval strategy often introduces irrelevant articles, leading to over-citation and degraded legal faithfulness. To address this issue, we propose a retrieval-grounded framework that strengthens both evidence acquisition and long-form drafting. On the retrieval side, an LLM first plans multiple legally oriented queries based on the fact description. On the generation side, we apply retrieval-augmented supervised fine-tuning to stabilize document structure and employ multi-objective Group Relative Policy Optimization to jointly optimize legal correctness, writing quality, and reasoning-form compliance. Experiments on the JuDGE benchmark show consistent improvements over strong retrieval-augmented baselines, particularly in statute referencing precision and overall legal coherence.
JuDGE_RL is an end-to-end system for automatic Chinese criminal judgment document generation. Given a case fact description, the system generates a complete judgment including factual findings, legal analysis, reasoning, and sentencing conclusions.
Traditional RAG for judgment generation suffers from two issues:
- Low recall: using only the case fact as query misses statutes related to sentencing circumstances, supplementary penalties, etc.
- Low precision: fixed top-k retrieval introduces irrelevant statutes that mislead the generation model into over-citation.
Our framework addresses these through three layers:
| Layer | Method | Problem Solved |
|---|---|---|
| Retrieval | Dual-path retrieval (MRAG + LLM Agent) | Improve statute recall and precision |
| SFT | Retrieval-Augmented Supervised Fine-Tuning | Learn to use retrieved statutes to produce well-structured documents |
| RL | Multi-objective GRPO | Further optimize legal correctness, text quality, and reasoning form |
Why dual-path retrieval? MRAG (Dense Retriever + Reranker) achieves high recall, while the LLM Agent (QueryGen + LawSelect) achieves high precision. Fusion combines both strengths.
Why GRPO over PPO/DPO? GRPO (Group Relative Policy Optimization) does not require a separate critic model: it estimates advantages by comparing multiple generations sampled for the same prompt within a group. For long-form judgment generation, GRPO offers better memory efficiency and training stability than PPO.
Why multi-objective reward? Judgment quality cannot be measured by a single metric. Optimizing only text fluency (BERTScore) leads to legally inaccurate outputs; optimizing only statute F1 leads to outputs that list statutes without proper analysis. Multi-objective weighting forces the model to balance across dimensions.
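To make the group-relative idea concrete, here is a minimal sketch (illustrative reward values, not the repo's training code, which runs through ms-swift GRPO): each sampled judgment's reward is normalized against the other samples for the same case, so no critic model is needed.

```python
import statistics

# Sketch of the GRPO idea: rewards for a group of judgments sampled from the same
# case fact are normalized against each other, so no separate critic model is needed.
# Reward values here are placeholders for the weighted multi-objective score below.
def group_relative_advantages(rewards: list[float]) -> list[float]:
    mean = statistics.mean(rewards)
    std = statistics.pstdev(rewards) or 1.0   # guard against a zero-variance group
    return [(r - mean) / std for r in rewards]

# Four sampled judgments for one case: the legally correct, well-written ones
# get positive advantages, the others negative.
print(group_relative_advantages([0.92, 0.55, 0.71, 0.30]))
```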
Case Fact Description
|
+-------------------------------------+
| |
v v
+-------------------+ +-------------------------------+
| MRAG Path | | LLM Agent Path |
| | | |
| Fact -> Dense -> | | Fact -> QueryGen(LLM) -> |
| Reranker -> K | | Dense -> Reranker -> |
| | | LawSelect(LLM) -> refined |
| Strength: | | Strength: |
| High recall | | High precision |
+--------+----------+ +--------------+----------------+
| |
+---------------+---------------+
v
Hybrid Fusion (RRF / Agent-First)
|
v
Top-10 Statutes + Similar Cases
|
v
+--------------------------------------+
| Judgment Generation Model |
| |
| Stage 1: SFT (LoRA) |
| Supervised fine-tuning with |
| reference judgments |
| |
| Stage 2: GRPO (full-parameter) |
| Reward = 0.60 * Legal Accuracy |
| + 0.30 * Text Quality |
| + 0.10 * Reasoning Form |
| |
| Base: Qwen3-4B / Qwen2.5-3B |
+------------------+-------------------+
v
Complete Judgment
(Facts + Legal Analysis + Reasoning + Sentencing)
The MRAG path uses the case fact directly as the query for two-stage retrieval:
Case Fact -> Dense Retriever (top-50) -> Reranker (top-10) -> Statutes
- Dense Retriever: Fine-tuned chinese-roberta-wwm-ext with contrastive learning + hard negatives, trained with K-Fold cross-validation.
- Reranker: Fine-tuned chinese-roberta-wwm-ext as cross-encoder with pairwise ranking loss.
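A minimal sketch of the two-stage MRAG path under stated assumptions: CLS pooling for the bi-encoder, a single-logit cross-encoder reranker, and placeholder checkpoint paths; batching and model-loading details differ from the repo's actual retriever/reranker modules.

```python
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification

# Sketch of the two-stage MRAG path (dense top-50 -> rerank top-10).
# Pooling choice and checkpoint paths are assumptions, not the repo's exact code.
tok = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
encoder = AutoModel.from_pretrained("hfl/chinese-roberta-wwm-ext").eval()        # fine-tuned bi-encoder
reranker = AutoModelForSequenceClassification.from_pretrained("path/to/reranker").eval()  # cross-encoder

@torch.no_grad()
def embed(texts):
    batch = tok(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")
    cls = encoder(**batch).last_hidden_state[:, 0]          # CLS pooling (assumption)
    return torch.nn.functional.normalize(cls, dim=-1)

@torch.no_grad()
def retrieve(fact, statutes, k_dense=50, k_final=10):
    q, d = embed([fact]), embed(statutes)
    top = (q @ d.T).squeeze(0).topk(min(k_dense, len(statutes))).indices.tolist()
    pairs = tok([fact] * len(top), [statutes[i] for i in top],
                padding=True, truncation=True, max_length=512, return_tensors="pt")
    scores = reranker(**pairs).logits[:, 0]                  # assumes a relevance logit in column 0
    order = scores.argsort(descending=True)[:k_final]
    return [statutes[top[i]] for i in order.tolist()]
```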
In the LLM Agent path, the LLM understands the case, plans multi-angle retrieval queries, and filters candidate statutes:
Case Fact -> QueryGen (5-8 queries) -> Dense (top-50) -> Reranker (top-20) -> LawSelect (5-10 statutes)
| Component | Model | Role | RL Reward |
|---|---|---|---|
| QueryGen | Qwen2.5-7B | Generate diverse legal queries | 0.05*Format + 0.25*Diversity + 0.70*DenseScore |
| LawSelect | Qwen2.5-7B | Filter truly relevant statutes | 0.45*R@5 + 0.35*P@5 + 0.15*R@10 + 0.05*Quantity |
Why multiple queries? A case involves multiple legal dimensions (crime definition, sentencing range, supplementary penalties, mitigating factors). A single query cannot cover all. For example, a theft case requires statutes for: theft elements (Art. 264), imprisonment range (Art. 45), fines (Art. 52/53), and voluntary confession (Art. 67).
Why LawSelect? Dense + Reranker only ranks by textual similarity. LLM can perform legal reasoning to distinguish "textually similar but inapplicable" statutes (e.g., distinguishing theft vs. robbery).
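The agent path can be sketched as follows; the prompts, helper names (chat, dense_rerank), and query counts are illustrative placeholders, not the repo's RL-trained QueryGen/LawSelect implementation.

```python
# Illustrative flow of the LLM Agent path; prompt wording, helper names,
# and the `chat` call are placeholders for the repo's RL-trained Qwen2.5-7B agents.
def agent_retrieve(fact, chat, dense_rerank):
    # 1) QueryGen: plan 5-8 legally oriented queries covering distinct dimensions.
    queries = chat(
        "Read the case facts and write 5-8 short retrieval queries, one per line, "
        "covering crime elements, sentencing range, fines, and mitigating factors.\n" + fact
    ).splitlines()

    # 2) Dense retrieval + reranking for each query, pooled into ~top-20 candidates.
    candidates = []
    for q in queries:
        candidates += dense_rerank(q, top_k=20)
    candidates = list(dict.fromkeys(candidates))  # dedupe, keep order

    # 3) LawSelect: keep only the statutes the LLM judges legally applicable (5-10).
    selected = chat(
        "Given the case facts and candidate statutes, list the 5-10 statutes that "
        "actually apply, one per line.\n" + fact + "\n" + "\n".join(candidates)
    ).splitlines()
    return [s for s in selected if s in candidates]
```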
Two retrieval paths are merged via output-level fusion:
- RRF (recommended): Reciprocal Rank Fusion weighted by source reliability (see the sketch after this list).
- Agent-First: Keep all Agent outputs (filtered by LawSelect), supplement with MRAG results not covered by Agent.
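A minimal sketch of source-weighted Reciprocal Rank Fusion; the fusion weights and the constant k are illustrative, not the repo's tuned values.

```python
from collections import defaultdict

def weighted_rrf(ranked_lists, weights, k=60, top_n=10):
    """Reciprocal Rank Fusion: each source contributes weight / (k + rank) for every
    statute it ranks, so higher-ranked items from more reliable sources dominate."""
    scores = defaultdict(float)
    for source, ranking in ranked_lists.items():
        for rank, statute in enumerate(ranking, start=1):
            scores[statute] += weights[source] / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)[:top_n]

# Example with illustrative weights favoring the high-precision Agent path.
fused = weighted_rrf(
    {"mrag": ["Art.264", "Art.52", "Art.263"], "agent": ["Art.264", "Art.67", "Art.52"]},
    weights={"mrag": 0.4, "agent": 0.6},
)
print(fused)
```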
Stage 1 — SFT (LoRA):
- Fine-tune with reference judgments as labels.
- LoRA config: r=128, alpha=256, target modules: q/k/v/o_proj (see the sketch below).
- Purpose: learn document structure and format.
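For reference, the LoRA configuration above corresponds roughly to the following Hugging Face peft setup; this is a sketch for illustration (the repo actually trains through ms-swift, and the dropout value is an assumption not stated in this README).

```python
from peft import LoraConfig, TaskType

# LoRA setup mirroring the SFT config above (r=128, alpha=256, attention projections).
# Expressed with Hugging Face peft for illustration; the repo trains via ms-swift.
lora_config = LoraConfig(
    r=128,
    lora_alpha=256,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,          # assumption: dropout is not specified in this README
    task_type=TaskType.CAUSAL_LM,
)
```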
Stage 2 — GRPO (full-parameter):
- Further optimize on top of SFT model with multi-objective reward.
- Purpose: SFT learns "how to write", GRPO teaches "how to write correctly".
Reward function (train/src/rl_plugin1.py):
Total = 0.60 * Legal Accuracy + 0.30 * Text Quality + 0.10 * Reasoning Form
Legal Accuracy = 0.35 * Statute_F1 # Correct statute citations
+ 0.30 * Crime_F1 # Correct crime identification
+ 0.20 * Prison_Match # Reasonable prison term
+ 0.15 * Fine_Match # Reasonable fine amount
Text Quality = BERTScore # Computed separately for reasoning and sentencing sections
Reasoning Form = <think> format + length + no repetition # Only for Thinking models
Why does Statute F1 carry the highest weight (35%)? Statute citation is the legal foundation of a judgment: incorrect citations invalidate the judgment entirely. In contrast, minor variations in prison terms or fine amounts within a reasonable range are acceptable.
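To make the composition concrete, here is a sketch mirroring the weighting above; the score inputs are hypothetical stand-ins, and the actual scoring logic lives in train/src/rl_plugin1.py.

```python
# Mirrors the weighting described above; the scoring functions that produce the
# inputs (statute F1, BERTScore, etc.) are implemented in train/src/rl_plugin1.py.
def legal_accuracy(statute_f1, crime_f1, prison_match, fine_match):
    return (0.35 * statute_f1 + 0.30 * crime_f1
            + 0.20 * prison_match + 0.15 * fine_match)

def total_reward(statute_f1, crime_f1, prison_match, fine_match,
                 bertscore, reasoning_form):
    return (0.60 * legal_accuracy(statute_f1, crime_f1, prison_match, fine_match)
            + 0.30 * bertscore
            + 0.10 * reasoning_form)

# Example: correct statutes and crime, slightly-off prison term and fine,
# good text quality, valid <think> format.
print(total_reward(1.0, 1.0, 0.8, 0.9, 0.85, 1.0))  # ≈ 0.92
```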
data/train.json (raw: text_id + case facts + reference judgment)
|
+-- script/sft_data.py --> data/train_sft.json (SFT: prompt + reference)
+-- script/sft_data.py --> data/test_sft.json (Inference: prompt only)
+-- script/rl_data.py --> data/rl_train/train.jsonl (RL: messages + reference)
# With retrieval results, MRAG-augmented versions are generated:
data/train.json + retrieval --> data/train_sft_mrag.json
--> data/test_sft_mrag.json
--> data/rl_train_mrag/train.jsonl
Prompt consistency: sft_data.py, rl_data.py, and inf.py share identical prompt templates, ensuring no distribution shift between training and inference.
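As a rough illustration of the conversion step, the sketch below builds SFT samples from the raw records. The field names (fd, reference) and the English template are assumptions based on this README; the real logic and the shared Chinese prompt template live in script/sft_data.py.

```python
import json

# Sketch of the train.json -> train_sft.json conversion. Field names ("fd" for case
# facts, "reference" for the gold judgment) are assumptions, and this placeholder
# template stands in for the one shared by sft_data.py / rl_data.py / inf.py.
PROMPT_TEMPLATE = "Draft a complete criminal judgment based on the following case facts:\n{facts}"

def build_sft_samples(raw_path, out_path):
    samples = []
    with open(raw_path, encoding="utf-8") as f:
        for line in f:
            obj = json.loads(line)
            samples.append({
                "text_id": obj["text_id"],
                "prompt": PROMPT_TEMPLATE.format(facts=obj["fd"]),
                "reference": obj["reference"],
            })
    with open(out_path, "w", encoding="utf-8") as out:
        json.dump(samples, out, ensure_ascii=False, indent=2)
```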
JuDGE_RL/
├── bash/ # Shell scripts
│ ├── agent/ # LLM Agent
│ ├── retriever/ # Dense Retriever
│ ├── reranker/ # Reranker
│ ├── data_train.sh # Generate SFT/RL training data
│ ├── train_sft.sh # SFT training
│ ├── train_rl.sh # RL (GRPO) training
│ ├── loramerge.sh # SFT LoRA merge
│ ├── gen.sh # Inference (9 modes)
│ ├── convert.sh # Format conversion
│ └── eval.sh # Evaluation
├── data/ # Data files (included in repo)
├── evaluation/ # Evaluation scripts
├── mrag/ # Retrieval modules
│ └── agent/ # LLM Agent (QueryGen, LawSelect, Fusion)
├── reranker/ # Reranker module
├── train/ # Training & inference
│ ├── src/ # SFT training, RL reward functions
│ └── deploy/ # vLLM inference, LoRA merge
└── script/ # Data generation scripts
Three conda environments are needed:
| Environment | Purpose |
|---|---|
| swift | SFT/RL training, inference (based on ms-swift) |
| judge | Retriever/Reranker training, evaluation |
| vllm | Inference acceleration (optional) |
conda create -n swift python=3.10 -y && conda activate swift && pip install -r requirements_swift.txt
conda create -n judge python=3.10 -y && conda activate judge && pip install -r requirements_judge.txt
conda create -n vllm python=3.10 -y && conda activate vllm && pip install -r requirements_vllm.txt

| Model | Purpose | Link |
|---|---|---|
| Qwen2.5-3B-Instruct | Generation base model | HuggingFace |
| chinese-roberta-wwm-ext | Retriever / Reranker | HuggingFace |
| Qwen3-4B | Thinking model experiments | HuggingFace |
| Qwen2.5-7B-Instruct | Agent (QueryGen/LawSelect) | HuggingFace |
All model paths are centralized in a single file: bash/paths.sh. Before running any script, you must edit this file to point to your local model directories:
# bash/paths.sh — edit these paths to match your environment
export QWEN3_MODEL_PATH="/path/to/Qwen3-4B"
export QWEN25_MODEL_PATH="/path/to/Qwen2.5-3B-Instruct"
export QWEN25_7B_MODEL_PATH="/path/to/Qwen2.5-7B-Instruct"
export ROBERTA_MODEL_PATH="/path/to/chinese-roberta-wwm-ext"
export BERT_MODEL_PATH="/path/to/bert-base-chinese"

Every shell script automatically sources bash/paths.sh and validates that the required model directory exists before proceeding. If a path is wrong, the script will print a clear error message and exit.
You can also override any path via environment variable without editing the file:
QWEN3_MODEL_PATH=/my/models/Qwen3-4B bash bash/train_sft.sh

bash/paths.sh manages paths for downloaded base models only. In addition, bash/gen.sh contains paths for trained model outputs (SFT merged models and RL checkpoints), which point to directories under output/ created by the training scripts.
SFT model paths default to output/sft_*/merge and typically do not need manual adjustment, because the merge script (loramerge.sh) always writes to the same merge/ subdirectory:
# gen.sh — SFT paths (usually no changes needed)
SFT_QWEN3="${SFT_QWEN3:-output/sft_qwen3-4b_lora/merge}"
SFT_MRAG_QWEN3="${SFT_MRAG_QWEN3:-output/sft_qwen3-4b_lora_mrag/merge}"

RL checkpoint paths require manual configuration. RL training (ms-swift GRPO) generates output directories with version timestamps, for example:
output/rl_qwen3-4b_grpo_sft_full/v19-20260116-061030/checkpoint-501
The path must point to a specific checkpoint directory with full model weights. The default values in gen.sh will NOT work unless your training happens to produce the exact same version directory names.
How to find your checkpoint path — after RL training completes, inspect the output:
ls output/rl_qwen3-4b_grpo_sft_full/
# v19-20260116-061030/
ls output/rl_qwen3-4b_grpo_sft_full/v19-20260116-061030/
# checkpoint-167/  checkpoint-334/  checkpoint-501/

Pick the final (or best) checkpoint and set the full path in gen.sh.
Configuration methods — edit gen.sh directly or override via environment variables:
# Method 1: Edit gen.sh directly
RL_SFT_QWEN3_PATH="${RL_SFT_QWEN3_PATH:-output/rl_qwen3-4b_grpo_sft_full/v19-20260116-061030/checkpoint-501}"
# Method 2: Override via environment variable at runtime
RL_SFT_QWEN3_PATH=output/rl_qwen3-4b_grpo_sft_full/v19-20260116-061030/checkpoint-501 \
RL_SFT_QWEN25_PATH=output/rl_qwen2.5-3b_grpo_sft_full/v17-20260117-091241/checkpoint-501 \
MODES=sft_rl bash bash/gen.sh

All RL path variables in gen.sh:
| Variable | Description | Example Value |
|---|---|---|
| RL_BASE_QWEN3_PATH | Qwen3 Base→RL | output/rl_qwen3-4b_grpo_full/<version>/checkpoint-<step> |
| RL_BASE_QWEN25_PATH | Qwen2.5 Base→RL | output/rl_qwen2.5-3b_grpo_full/<version>/checkpoint-<step> |
| RL_SFT_QWEN3_PATH | Qwen3 SFT→RL | output/rl_qwen3-4b_grpo_sft_full/<version>/checkpoint-<step> |
| RL_SFT_QWEN25_PATH | Qwen2.5 SFT→RL | output/rl_qwen2.5-3b_grpo_sft_full/<version>/checkpoint-<step> |
| RL_BASE_MRAG_QWEN3_PATH | Qwen3 Base→RL + MRAG | output/rl_qwen3-4b_grpo_mrag_full/<version>/checkpoint-<step> |
| RL_BASE_MRAG_QWEN25_PATH | Qwen2.5 Base→RL + MRAG | output/rl_qwen2.5-3b_grpo_mrag_full/<version>/checkpoint-<step> |
| RL_SFT_MRAG_QWEN3_PATH | Qwen3 SFT+MRAG→RL | output/rl_qwen3-4b_grpo_sft_mrag_full/<version>/checkpoint-<step> |
| RL_SFT_MRAG_QWEN25_PATH | Qwen2.5 SFT+MRAG→RL | output/rl_qwen2.5-3b_grpo_sft_mrag_full/<version>/checkpoint-<step> |
If you only want to reproduce the main experiment results without training, follow these steps. The pre-trained model checkpoint is available on Google Drive.
| Item | How to Get |
|---|---|
| Model checkpoint (JuDGE_RL.tar.gz) | Download from Google Drive (~8GB, SFT+MRAG+RL trained Qwen3-4B) |
| bert-base-chinese | Download from HuggingFace (required for BERTScore in evaluation) |
| GPU | 1x GPU with >= 16GB VRAM (inference only) |
Only two environments are needed for evaluation:
# For inference (vLLM)
conda create -n vllm python=3.10 -y && conda activate vllm
pip install -r requirements_vllm.txt
# For evaluation metrics (BERTScore, METEOR)
conda create -n judge python=3.10 -y && conda activate judge
pip install -r requirements_judge.txt

Download JuDGE_RL.tar.gz from Google Drive and extract it:
# Method 1: Download via browser, then extract
tar -xzf JuDGE_RL.tar.gz -C .
# Method 2: Download via gdown (pip install gdown)
gdown 1lquq4EePHRQWE8wOWdsFwEUpNzyiZolx
tar -xzf JuDGE_RL.tar.gz -C .

After extraction, check the model directory path and note it for the next step:
# Verify the directory contains model files
ls JuDGE_R1/release_model # or your extracted directory name
# config.json model-00001-of-00002.safetensors model-00002-of-00002.safetensors
# model.safetensors.index.json tokenizer.json added_tokens.json
# merges.txt special_tokens_map.json chat_template.jinja
# vocab.json generation_config.json

cd Judge-R1
mkdir -p outputs
conda activate vllm
export CUDA_VISIBLE_DEVICES=0
# Main experiment: SFT+MRAG+RL model on MRAG test set
# Replace <MODEL_PATH> with your extracted model directory (e.g., checkpoint-501)
python train/deploy/inf.py \
--model_path <MODEL_PATH> \
--dataset_path data/test_sft_mrag.json \
--output_path outputs/qwen3_sft_mrag_rl_raw.json \
--mode rl \
--tensor_parallel_size 1 \
--gpu_memory_utilization 0.85

# Convert inference output to evaluation format
python -c "
import json
fd2id = {}
with open('data/test.json', 'r') as f:
    for line in f:
        obj = json.loads(line)
        fd2id[obj['fd']] = obj['text_id']
data = json.load(open('outputs/qwen3_sft_mrag_rl_raw.json', 'r'))
with open('outputs/qwen3_sft_mrag_rl.jsonl', 'w') as out:
    for item in data:
        cid = item.get('text_id') or fd2id.get(item.get('exp_ans'))
        gen = item.get('gen_ans')
        if cid and gen is not None:
            out.write(json.dumps({'id': cid, 'document': gen}, ensure_ascii=False) + '\n')
print('Converted to outputs/qwen3_sft_mrag_rl.jsonl')
"

conda activate judge
export BERT_MODEL_PATH="/path/to/bert-base-chinese" # Required for BERTScore
cd evaluation
# Legal accuracy (Crime F1, Law Article F1, Prison Score, Fine Score)
python calc.py \
--gen_file ../outputs/qwen3_sft_mrag_rl.jsonl \
--exp_file ../data/expected.jsonl
# Text quality (METEOR, BERTScore)
python calc_rel.py \
--gen_file ../outputs/qwen3_sft_mrag_rl.jsonl \
--exp_file ../data/expected.jsonl

The MRAG test set (data/test_sft_mrag.json) already contains pre-computed retrieval results embedded in each prompt (top-10 statutes + similar cases). No retrieval models are needed for evaluation.
To inspect the retrieval quality and explainability:
- mrag/retriever_output/ablation_both_rl_eval.txt: Retrieval evaluation metrics (Recall@K, MRR, etc.)
- mrag/retriever_output/ablation_both_rl_details.json: Per-case explainability (generated queries, selected statutes with reasons, rejected statutes with reasons)
conda activate swift
bash bash/data_train.sh # Standard mode

conda activate judge
bash bash/retriever/kfold_train_retriever.sh # Dense Retriever
bash bash/retriever/encode_corpus.sh # Encode corpus
bash bash/retriever/retrieve.sh # Run retrieval
bash bash/reranker/kfold_train_reranker.sh # Reranker
bash bash/reranker/run_reranker.sh # Run reranking
bash bash/retriever/eval_retriever.sh # Evaluate retrieval
# Generate MRAG training data
conda activate swift
USE_MRAG=true bash bash/data_train.sh

conda activate swift
bash bash/agent/prepare_agent_rl_data.sh
CUDA_VISIBLE_DEVICES=0,1,2,3 bash bash/agent/train_rl_querygen.sh
CUDA_VISIBLE_DEVICES=0,1,2,3 bash bash/agent/train_rl_lawselect.sh
bash bash/agent/merge_agent_lora.sh querygen
bash bash/agent/merge_agent_lora.sh lawselect
CUDA_VISIBLE_DEVICES=0 bash bash/agent/eval_ablation.sh

conda activate swift
# Qwen3-4B
MODEL_NAME=qwen3 bash bash/train_sft.sh
MODEL_NAME=qwen3 USE_MRAG=true bash bash/train_sft.sh
MERGE_CONFIG=sft_qwen3 bash bash/loramerge.sh
MERGE_CONFIG=sft_qwen3_mrag bash bash/loramerge.sh
MODEL_NAME=qwen3 bash bash/train_rl.sh
MODEL_NAME=qwen3 USE_MRAG=true bash bash/train_rl.sh
# Qwen2.5-3B
MODEL_NAME=qwen2 bash bash/train_sft.sh
MODEL_NAME=qwen2 USE_MRAG=true bash bash/train_sft.sh
MERGE_CONFIG=sft_qwen2 bash bash/loramerge.sh
MERGE_CONFIG=sft_qwen2_mrag bash bash/loramerge.sh
MODEL_NAME=qwen2 bash bash/train_rl.sh
MODEL_NAME=qwen2 USE_MRAG=true bash bash/train_rl.sh

conda activate swift
MODES=all bash bash/gen.sh # All models x all 9 modes

9 inference modes:
| Mode | Model | Data | Description |
|---|---|---|---|
| direct | Base | Raw | Zero-shot |
| icl | Base | Raw | Few-shot |
| sft | SFT | Standard | Supervised fine-tuned |
| mrag | Base | MRAG | Base + retrieval |
| rl | Base->RL | Standard | RL only |
| sft_mrag | SFT+MRAG | MRAG | SFT + retrieval |
| sft_rl | SFT->RL | Standard | SFT + RL |
| mrag_rl | Base->RL | MRAG | RL + retrieval |
| sft_mrag_rl | SFT+MRAG->RL | MRAG | Full pipeline (best) |
conda activate swift
bash bash/convert.sh
conda activate judge
bash bash/eval.sh
cat result/eval_summary.txt

All scripts are controlled via environment variables.
train_sft.sh variables:

| Variable | Default | Options |
|---|---|---|
| MODEL_NAME | qwen2 | qwen3 (Qwen3-4B), qwen2 (Qwen2.5-3B) |
| USE_MRAG | false | true to use MRAG training data |
train_rl.sh variables:

| Variable | Default | Options |
|---|---|---|
| MODEL_NAME | qwen3 | qwen3 or qwen2 |
| EXPERIMENT | sft_full | sft_full (GRPO on SFT model), base_full (GRPO on base model) |
| USE_MRAG | false | true to use MRAG data |
| USE_VLLM | false | true to use external vLLM server |
gen.sh variables:

| Variable | Default | Options |
|---|---|---|
| MODEL_NAME | qwen3,qwen2 | Comma-separated model names |
| MODES | all | Comma-separated from the 9 modes above |
| Metric | Description |
|---|---|
| Crime F1 | Crime identification F1 score |
| Law Article F1 | Statute citation F1 score |
| Prison Time Score | Sentence term matching (closer = better) |
| Fine Amount Score | Fine amount matching (closer = better) |
| Metric | Description |
|---|---|
| METEOR | Text similarity (segment-level: reasoning + sentencing) |
| BERTScore | Semantic similarity (segment-level: reasoning + sentencing) |
Evaluation first segments the judgment into "reasoning" and "sentencing" sections via evaluation/segment/, then computes metrics for each section separately.
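A minimal sketch of segment-level scoring, assuming the reasoning/sentencing split is already available and using jieba tokenization for METEOR (an assumption); the repo's segmentation lives in evaluation/segment/ and its exact metric settings may differ.

```python
from bert_score import score as bert_score
from nltk.translate.meteor_score import meteor_score
import jieba

# Sketch of segment-level text-quality scoring; assumes the judgment has already
# been split into "reasoning" and "sentencing" parts, as evaluation/segment/ does.
def segment_scores(gen: dict, ref: dict, sections=("reasoning", "sentencing")):
    results = {}
    for sec in sections:
        # BERTScore on the raw Chinese text (lang="zh" selects a Chinese BERT by default).
        _, _, f1 = bert_score([gen[sec]], [ref[sec]], lang="zh")
        # METEOR needs tokenized input; jieba segmentation is an assumption here.
        meteor = meteor_score([list(jieba.cut(ref[sec]))], list(jieba.cut(gen[sec])))
        results[sec] = {"bertscore_f1": float(f1.mean()), "meteor": meteor}
    return results
```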
| Experiment | Comparison | Output Files |
|---|---|---|
| Base model | Qwen2.5 vs Qwen3 | qwen25_* / qwen3_* |
| Training stage | Direct -> ICL -> SFT -> SFT+RL | *_direct / *_icl / *_sft / *_sft_rl |
| Retrieval augmentation | w/o retrieval vs MRAG | *_sft vs *_sft_mrag |
| Retrieval + RL | SFT+RL vs SFT+MRAG+RL | *_sft_rl vs *_sft_mrag_rl |
| Retrieval components | Dense only vs Dense+Reranker | eval_retriever.sh output |
| Agent components | +/-QueryGen RL x +/-LawSelect RL | eval_ablation.sh output |
| Retrieval source | MRAG vs Agent vs Hybrid | eval_ablation.sh + fuse_results.py output |
| Script | Scope | Notes |
|---|---|---|
| bash/loramerge.sh | SFT models (Qwen3/Qwen2.5) | Includes extract_lora step for DeepSpeed |
| bash/agent/merge_agent_lora.sh | Agent RL models | Auto-finds latest checkpoint (ms-swift) |
Do not mix: use loramerge.sh for SFT models, merge_agent_lora.sh for Agent RL models.
This project is licensed under the MIT License. See LICENSE for details.