# Training a Long-Horizon Search Agent with Reinforcement Learning on AReaL

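This tutorial uses AReaL's basic components to quickly build a reinforcement learning pipeline for training an agent capable of long-horizon search.

The tutorial covers the following steps:
1. Prepare the experiment (load the experiment config from YAML, set environment variables, launch the SGLang server, launch the local RAG server, and load the training dataset);
2. Define a simple workflow that calls the search tool multiple times;
3. Generate **multiple** trajectories per question (i.e., GRPO);
4. Test the workflow;
5. Plug the workflow into end-to-end GRPO reinforcement learning training.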

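## Experiment Preparation
### Load the Experiment Config

Use `load_expr_config` to load the predefined YAML experiment config template for training ASearcher with local RAG.

The template already sets the optimizer, model, learning rate, and other parameters, so it can be used as is.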

In [None]:
from dataclasses import asdict, dataclass, field

from areal.api.cli_args import GRPOConfig, load_expr_config


@dataclass
class AgentRLConfig(GRPOConfig):
    max_turns: int = field(
        default=128, metadata={"help": "maximum number of turns per trajectory"}
    )


args = ["--config", "examples/configs/search-agent/local_1.5b_example.yaml"]
config, _ = load_expr_config(args, AgentRLConfig)
config: AgentRLConfig
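
The cell above extends the base config with one extra field by subclassing a dataclass. As a minimal, self-contained sketch of that pattern (standard library only), the snippet below uses `BaseConfig` and `DemoAgentConfig` as hypothetical stand-ins for `GRPOConfig` and `AgentRLConfig`, and shows how the `help` metadata attached to each field can be read back.

```python
from dataclasses import dataclass, field, fields


@dataclass
class BaseConfig:
    # Hypothetical stand-in for GRPOConfig, used only for illustration.
    experiment_name: str = field(
        default="demo", metadata={"help": "name of the experiment"}
    )


@dataclass
class DemoAgentConfig(BaseConfig):
    # Mirrors the `max_turns` field that AgentRLConfig adds.
    max_turns: int = field(
        default=128, metadata={"help": "maximum number of turns per trajectory"}
    )


def describe(config_cls):
    """Return {field name: (default, help text)} for a dataclass type."""
    return {
        f.name: (f.default, f.metadata.get("help", "")) for f in fields(config_cls)
    }


summary = describe(DemoAgentConfig)
```

Fields inherited from the base class come first in `summary`, followed by the fields added in the subclass.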

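### Set Environment Variables

We preassign the IP addresses and ports for the SGLang server and the PyTorch distributed launch, and set the corresponding environment variables.

The engines read these environment variables when they are initialized.

***Outside a notebook, the launcher sets these environment variables, so users do not need to set them manually.***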

In [None]:
from areal.utils.network import find_free_ports

SGLANG_PORT, MASTER_PORT = 11451, 14514

SGLANG_HOST = "127.0.0.1"

# ----------------------------------------------------------------------------
# Environment variables used by inference/train engines
import os
import subprocess
import sys

os.environ["AREAL_LLM_SERVER_ADDRS"] = f"{SGLANG_HOST}:{SGLANG_PORT}"
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(MASTER_PORT)
os.environ["RANK"] = str(0)
os.environ["WORLD_SIZE"] = str(1)
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["LOCAL_RANK"] = str(0)
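
The cell above hardcodes the two ports, which fails if either one is already in use on the machine. One possible alternative, sketched here with the standard library only, is to ask the operating system for an unused port. `pick_free_port` is a hypothetical helper written for this illustration and is not part of AReaL.

```python
import socket


def pick_free_port(host="127.0.0.1"):
    """Ask the OS for a currently unused TCP port by binding to port 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((host, 0))
        # The OS replaces port 0 with a concrete free port number.
        return s.getsockname()[1]


candidate_port = pick_free_port()
```

The port is released again when the socket closes, so another process could in principle claim it before the server starts. For a single-user notebook this race is usually acceptable.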

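### Launch the SGLang Server

By default, AReaL uses an architecture that decouples training from inference and runs the two asynchronously, which keeps the GPUs fully utilized and speeds up end-to-end training.

In this example, the RL algorithm orchestration (GRPO) runs on GPU 0.

GPU 1 runs an inference server, to which the orchestration code sends generation requests.

The code cell below launches that inference service on GPU 1.

This tutorial uses `Qwen/Qwen2.5-1.5B` as the example model.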

In [None]:
# Launch the SGLang server
from areal.api.cli_args import SGLangConfig

config.sglang.log_level = "info"
config.sglang.decode_log_interval = 10
sglang_cmd = SGLangConfig.build_cmd(
    config.sglang,
    tp_size=1,
    base_gpu_id=1,
    host=SGLANG_HOST,
    port=SGLANG_PORT,
)
sglang_process = subprocess.Popen(
    sglang_cmd,
    shell=True,
    stdout=sys.stdout,
    stderr=sys.stderr,
)

print("sglang process is launched")
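
`subprocess.Popen` returns immediately, so the server may still be starting when the next cells run. The sketch below, using only the standard library, polls until the TCP port accepts connections. `wait_for_port` is a hypothetical helper, and an open port does not by itself guarantee that the model has finished loading, so treat it as a coarse readiness check.

```python
import socket
import time


def wait_for_port(host, port, timeout=300.0, interval=1.0):
    """Return True once a TCP connection succeeds, False after `timeout` seconds."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            # A successful connect means something is listening on the port.
            with socket.create_connection((host, port), timeout=interval):
                return True
        except OSError:
            # Connection refused or timed out: wait a little and retry.
            time.sleep(interval)
    return False
```

In this notebook it could be called as `wait_for_port(SGLANG_HOST, SGLANG_PORT)` before the first generation request is sent.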

### Load the Training Dataset

Load the training dataset with the HuggingFace `datasets` package and inspect its format.

In [None]:
# load search dataset
from datasets import load_dataset

print("dataset is at {}".format(config.train_dataset.path))
dataset = load_dataset(
    path="json",
    split="train",
    data_files=config.train_dataset.path,
)
print(f">>> dataset column names: {dataset.column_names}")
print(f">>> example data: {dataset[0]}")

### Import the Required Packages and Modules


In [None]:
import asyncio
import json
import os
import sys
import time
import uuid

import numpy as np
import torch
import torch.distributed as dist
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from tensordict import TensorDict
from transformers import AutoTokenizer, PreTrainedTokenizerFast

from areal.api.cli_args import (
    GenerationHyperparameters,
    load_expr_config,
)
from areal.api.engine_api import InferenceEngine
from areal.api.io_struct import (
    AllocationMode,
    FinetuneSpec,
    ModelRequest,
    WeightUpdateMeta,
)
from areal.engine.ppo.actor import FSDPPPOActor
from areal.engine.sglang_remote import RemoteSGLangEngine
from areal.utils.data import concat_padded_tensors

tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

Use `torchdata.stateful_dataloader.StatefulDataLoader` as the dataloader.

In [None]:
# setup dataloader

from torchdata.stateful_dataloader import StatefulDataLoader

dataloader = StatefulDataLoader(
    dataset,
    batch_size=config.train_dataset.batch_size,
    shuffle=True,
    collate_fn=lambda x: x,
    drop_last=True,
)

from itertools import cycle

data_generator = cycle(dataloader)

ft_spec = FinetuneSpec(
    total_train_epochs=config.total_train_epochs,
    dataset_size=len(dataloader) * config.train_dataset.batch_size,
    train_batch_size=config.train_dataset.batch_size,
)

batch = next(data_generator)
print(f">>> The type of a batch is: {type(batch)}\n")
print(f">>> Each piece of data has keys: {batch[0].keys()}\n")
print(f">>> Example input question: {batch[0]['question']}\n")
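
One detail of `itertools.cycle` is worth knowing for longer runs: it saves a copy of every item from the first pass and replays that copy afterwards, so later epochs repeat the first epoch's batch order instead of reshuffling. The sketch below illustrates this with a toy loader (`ShufflingLoader`, a hypothetical stand-in for a shuffling dataloader) and shows a generator, `infinite`, that restarts the loader on every pass instead.

```python
import random
from itertools import cycle


class ShufflingLoader:
    """Toy stand-in for a shuffling dataloader: new order on every pass."""

    def __init__(self, items, seed=0):
        self.items = list(items)
        self.rng = random.Random(seed)
        self.passes = 0  # how many times iteration has been (re)started

    def __iter__(self):
        self.passes += 1
        order = self.items[:]
        self.rng.shuffle(order)
        return iter(order)


def infinite(loader):
    """Yield batches forever, restarting the loader after each pass."""
    while True:
        yield from loader


def take(iterator, n):
    """Pull the next `n` items from an iterator."""
    return [next(iterator) for _ in range(n)]


cycled_loader = ShufflingLoader(range(8))
cycled = cycle(cycled_loader)
first_epoch, second_epoch = take(cycled, 8), take(cycled, 8)

restarted_loader = ShufflingLoader(range(8))
restarted = infinite(restarted_loader)
epoch_a, epoch_b = take(restarted, 8), take(restarted, 8)
```

With `cycle`, `first_epoch` and `second_epoch` are identical and the loader is iterated only once. With `infinite`, the loader is iterated again for the second epoch. For the five training steps in this tutorial the difference does not matter.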

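### Configure the Search Tool

For how to deploy the local RAG server, see Step 2 in the [ASearcher repository](https://github.com/inclusionAI/ASearcher/blob/main/docs/training.md#b-training-a-search-agent-with-local-knowledge-base).

Queries are sent to the local RAG server on port 5001, and the results come back in the response.

The example below searches for the keyword "China".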

In [None]:
# setup tool

import asyncio
import json

import aiohttp

TOOL_SERVER_ADDR = "localhost:5001"


async def call_search_tool(**req_meta):
    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"http://{TOOL_SERVER_ADDR}/retrieve",
            json=req_meta,
            timeout=aiohttp.ClientTimeout(total=120, sock_connect=120),
        ) as response:
            response.raise_for_status()
            res = await response.json()
            return res["result"]


result = (await call_search_tool(queries=["China"], topk=5, return_scores=False))[0]
print(json.dumps(result, indent=4))
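
During long rollouts a single failed HTTP request would abort the whole trajectory. As a hedged sketch, the wrapper below retries an awaitable call with exponential backoff. `call_with_retries` is a hypothetical helper, not part of AReaL or ASearcher, and it catches every `Exception` only to keep the example short.

```python
import asyncio


async def call_with_retries(make_call, attempts=3, base_delay=0.5):
    """Await `make_call()` up to `attempts` times with exponential backoff."""
    for attempt in range(attempts):
        try:
            return await make_call()
        except Exception:
            # Re-raise the error once the last attempt has failed.
            if attempt == attempts - 1:
                raise
            # Wait base_delay, 2 * base_delay, 4 * base_delay, ...
            await asyncio.sleep(base_delay * 2**attempt)
```

In the notebook it could wrap the search tool as `await call_with_retries(lambda: call_search_tool(queries=["China"], topk=5, return_scores=False))`. In practice, narrowing the `except` clause to network errors such as `aiohttp.ClientError` and `asyncio.TimeoutError` is likely a better choice.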

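## Define a Simple Agent Workflow

### Model Output

The prompt steers the model's output, which should follow a specific format:
- `<think></think>` contains the model's reasoning
- `<search></search>` contains the query sent to the local RAG server
- `<answer></answer>` contains the model's final answer

In addition, `<information></information>` wraps the results returned by the RAG server.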

In [None]:
PROMPT_TEMPLATE = """A conversation between User and Assistant. The user asks a question, and the Assistant answers it. The Assistant analyzes the given question and information in the mind, retains important relevant information, calls a search engine to find necessary information, accesses web pages with certain urls, and provides the user with the answer. The Assistant conducts search by <search> query </search> and the top search results will be returned between <information> and </information>. The reasoning processes are enclosed within <think> </think>. Finally, the Assistant provides answer inside <answer> and </answer>, i.e. <answer> answer here </answer>. If there are multiple queries, ensure all answers are enclosed within <answer> </answer>, seperated with comma. 

User: 
{question}

Assistant:
<think>"""

batch = next(data_generator)
prompt = PROMPT_TEMPLATE.format(question=batch[0]["question"])

print(f">>> PROMPT: {prompt}")

Send a generation request to the running SGLang server through `RemoteSGLangEngine` to test the model's output under the prompt above.

In [None]:
asyncio.get_running_loop()

In [None]:
# initialize inference engine
rollout_engine = RemoteSGLangEngine(config.rollout)
rollout_engine.initialize(None, None)

# generation config
gconfig = GenerationHyperparameters(
    max_new_tokens=512, stop=["</search>", "</answer>", "</access>"]
)

# tokenize the prompt
input_ids = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
req = ModelRequest(rid=uuid.uuid4().hex, input_ids=input_ids, gconfig=gconfig)

# generate rollout with inference engine
resp = await rollout_engine.agenerate(req)
completion_str = tokenizer.decode(resp.output_tokens)

# logging
print(f">>> prompt str: {tokenizer.decode(resp.input_tokens)}")
print(f">>> generated: {tokenizer.decode(resp.output_tokens)}")

#### Parse the Agent's Tool Calls

Define `parse_search_query` to extract the search-tool query from the model output.

In [None]:
# parse tool calling

import re


def parse_search_query(text):
    pattern = r"<search>(.*?)</search>"
    matches = re.findall(pattern, text, re.DOTALL)
    if matches:
        return matches[-1].strip()
    return None


test_tool_str = "<think> I would like to search for AI.</think>\n<search> Artificial Intelligence </search>"
print(">>> input: ", test_tool_str)
print(">>> search query: ", parse_search_query(test_tool_str))
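
`parse_search_query` only matches when the closing `</search>` tag is present. Whether the stop string itself is kept at the end of a completion depends on the inference backend and its settings, so a more tolerant parser may be useful. The sketch below is one possible variant. `parse_tag_tolerant` is a hypothetical helper that also accepts the end of the text as the terminator.

```python
import re


def parse_tag_tolerant(text, tag):
    """Return the content of the last <tag> block, closed or not, else None."""
    # Accept either the closing tag or the end of the text as the terminator.
    pattern = rf"<{tag}>(.*?)(?:</{tag}>|\Z)"
    matches = re.findall(pattern, text, re.DOTALL)
    return matches[-1].strip() if matches else None


closed = parse_tag_tolerant("<search> Artificial Intelligence </search>", "search")
unclosed = parse_tag_tolerant(
    "<think> look it up </think>\n<search> Artificial Intelligence", "search"
)
```

Both calls return the same query string. The same function can be reused for `answer` tags.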

Test the tool-call parser `parse_search_query` on model output.

In [None]:
# generate rollout with inference engine
resp = await rollout_engine.agenerate(req)
completion_str = tokenizer.decode(resp.output_tokens)

# logging
print(f">>> prompt str: {tokenizer.decode(resp.input_tokens)}")
print(f">>> generated: {tokenizer.decode(resp.output_tokens)}")
print(f">>> search query: {parse_search_query(completion_str)}")

#### Parse the Agent's Answer

Define `parse_answer` to extract the model's answer from its output.

In [None]:
# parse answer


def parse_answer(text):
    pattern = r"<answer>(.*?)</answer>"
    matches = re.findall(pattern, text, re.DOTALL)
    if matches:
        return matches[-1].strip()
    return None


test_answer_str = (
    "<think> I already found the answer! </think>\n<answer> 1997 </answer>"
)
print(">>> input: ", test_answer_str)
print(">>> answer: ", parse_answer(test_answer_str))

Test the answer parser `parse_answer` on model output.

In [None]:
# generate rollout with inference engine
resp = await rollout_engine.agenerate(req)
completion_str = tokenizer.decode(resp.output_tokens)

# logging
print(f">>> prompt str: {tokenizer.decode(resp.input_tokens)}")
print(f">>> generated: {tokenizer.decode(resp.output_tokens)}")
print(f">>> answer: {parse_answer(completion_str)}")

### Reward Function

We use the F1 score as the default reward function.

In [None]:
# F1 reward


def f1_score(pred_ans, gt):
    # Preprocess the text (simplified version)
    pred_ans = pred_ans.strip().lower()
    gt = gt.strip().lower()

    # Compare unique whitespace-separated tokens
    pred_tokens = set(pred_ans.split())
    gt_tokens = set(gt.split())

    if not gt_tokens or not pred_tokens:
        return 0

    # Tokens shared by the prediction and the ground truth
    common_tokens = pred_tokens & gt_tokens

    # Compute precision and recall
    precision = len(common_tokens) / len(pred_tokens) if pred_tokens else 0
    recall = len(common_tokens) / len(gt_tokens) if gt_tokens else 0

    # Compute the F1 score
    f1 = 0
    if precision + recall > 0:
        f1 = 2 * (precision * recall) / (precision + recall)

    return f1


print(
    "f1_score('James Bond', 'James Bond'): {:.2f}".format(
        f1_score("James Bond", "James Bond")
    )
)
print(
    "f1_score('James Smith', 'James Bond'): {:.2f}".format(
        f1_score("James Smith", "James Bond")
    )
)
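
The preprocessing inside `f1_score` is described as a simplified version. A fuller normalization that is common in question-answering evaluation lowercases the text, removes punctuation and English articles, and collapses whitespace. The sketch below is an assumption about what such preprocessing could look like, not the normalization used by this tutorial or by ASearcher.

```python
import re
import string

_PUNCTUATION = set(string.punctuation)


def normalize_answer(text):
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    # Remove ASCII punctuation characters.
    text = "".join(ch for ch in text if ch not in _PUNCTUATION)
    # Remove the articles "a", "an", "the" when they appear as whole words.
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    # Collapse runs of whitespace into single spaces.
    return " ".join(text.split())


normalized = normalize_answer("The  Eiffel Tower!")
```

Applying such a function to both the prediction and the ground truth before tokenizing would make the reward less sensitive to surface differences such as a trailing period.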

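### Define the Search Agent Workflow

The search agent workflow is simple to implement. Starting from an initial question, each turn does the following:
1. Call the inference engine to generate, stopping at EOS, `</search>`, or `</answer>`
2. If a search query is detected, call the search tool and append the results to the history
3. If an answer is detected, compute the reward and exit the loop

Finally, the data is assembled into the form required for training.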

In [None]:
# Search agent workflow


class SearchAgentWorkflow:
    def __init__(self, gconfig, tokenizer, max_tokens, max_turns, verbose):
        self.gconfig = gconfig
        self.tokenizer = tokenizer
        self.max_tokens = max_tokens
        self.max_turns = max_turns
        self.verbose = verbose

    async def arun_episode(self, engine: InferenceEngine, data):
        prompt = PROMPT_TEMPLATE.format(question=data["question"])

        # a unique trajectory rid, so that all requests of this trajectory go to the same SGLang server
        rid = uuid.uuid4().hex

        # trajectory (input ids/logprobs/loss mask)
        input_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        logprobs = [0.0] * len(input_ids)
        loss_mask = [0] * len(input_ids)

        answer, reward = None, 0

        num_turns = 0
        while num_turns < self.max_turns and len(input_ids) < self.max_tokens:
            num_turns += 1

            # LLM Request
            req = ModelRequest(
                rid=rid,
                input_ids=input_ids,
                gconfig=self.gconfig.new(n_samples=1),
            )
            resp = await engine.agenerate(req)
            completion_str = self.tokenizer.decode(resp.output_tokens)

            input_ids += resp.output_tokens
            logprobs += resp.output_logprobs
            loss_mask += [1] * resp.output_len

            # parse search query & trigger tool call
            search_query = parse_search_query(completion_str)
            if search_query:
                search_results = (
                    await call_search_tool(
                        queries=[search_query], topk=3, return_scores=False
                    )
                )[0]
                search_results_str = (
                    "\n\n<information>\n"
                    + "\n\n".join(
                        [
                            '<p title="{}">\n{}\n</p>'.format(
                                r["wikipedia_title"], r["contents"]
                            )
                            for r in search_results
                        ]
                    )
                    + "\n</information>"
                )

                search_token_ids = self.tokenizer.encode(
                    search_results_str, add_special_tokens=False
                )
                input_ids += search_token_ids
                logprobs += [0.0] * len(search_token_ids)
                loss_mask += [0] * len(search_token_ids)

            # parse answer
            answer = parse_answer(completion_str)
            if answer:
                reward = max([f1_score(answer, gt) for gt in data["answer"]])
                break

            if input_ids[-1] in [
                self.tokenizer.pad_token_id,
                self.tokenizer.eos_token_id,
            ]:
                break

        if self.verbose:
            print(f"[LOGGING] turns={num_turns} length={len(input_ids)}")

        res = dict(
            input_ids=torch.tensor(input_ids),
            logprobs=torch.tensor(logprobs),
            loss_mask=torch.tensor(loss_mask),
            rewards=torch.tensor(float(reward)),
            attention_mask=torch.ones(len(input_ids), dtype=torch.bool),
        )
        res = {k: v.unsqueeze(0) for k, v in res.items()}
        return TensorDict(res, batch_size=[1])
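
The workflow keeps three parallel per-token lists: `input_ids`, `logprobs`, and `loss_mask`. Prompt tokens and tool-result tokens get mask 0 and log-probability 0.0, while generated tokens get mask 1. The sketch below shows a small consistency check on plain Python lists. `check_trajectory` is a hypothetical debugging helper and the token ids are made up.

```python
def check_trajectory(input_ids, logprobs, loss_mask):
    """Validate the per-token bookkeeping of one trajectory (plain lists)."""
    if not (len(input_ids) == len(logprobs) == len(loss_mask)):
        raise ValueError(
            f"length mismatch: {len(input_ids)} ids, "
            f"{len(logprobs)} logprobs, {len(loss_mask)} mask entries"
        )
    if any(m not in (0, 1) for m in loss_mask):
        raise ValueError("loss_mask must contain only 0 and 1")
    # Number of tokens that contribute to the loss.
    return sum(loss_mask)


# Toy trajectory: 3 prompt tokens, 2 generated tokens, 4 tool-result tokens.
prompt_ids, generated_ids, tool_ids = [11, 12, 13], [21, 22], [31, 32, 33, 34]
toy_input_ids = prompt_ids + generated_ids + tool_ids
toy_logprobs = [0.0] * 3 + [-0.4, -1.2] + [0.0] * 4
toy_loss_mask = [0] * 3 + [1] * 2 + [0] * 4
trained_tokens = check_trajectory(toy_input_ids, toy_logprobs, toy_loss_mask)
```

Calling such a check just before the tensors are built would surface bookkeeping mistakes early, for example a list that was extended twice in one turn.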

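#### Test the Search Agent Workflow

1. Create the inference engine;
2. Create the workflow and set `max_new_tokens`, `max_turns`, and `max_tokens`;
3. Run the workflow with the inference engine to generate a batch of trajectories.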

In [None]:
# initialize inference engine
rollout = RemoteSGLangEngine(config.rollout)
rollout.initialize(None, None)

# create the workflow
workflow = SearchAgentWorkflow(
    gconfig=GenerationHyperparameters(
        max_new_tokens=512, stop=["</answer>", "</search>"]
    ),
    tokenizer=tokenizer,
    max_tokens=4096,
    max_turns=32,
    verbose=True,
)
sample_data = next(data_generator)[:4]
res = await asyncio.gather(
    *[workflow.arun_episode(rollout, sample_data[i]) for i in range(4)]
)
res = concat_padded_tensors(res)
print(res)

rollout.destroy()

# log the trajectories
traj_lens = res["attention_mask"].sum(dim=1).numpy().tolist()
for i in range(4):
    token_ids = res["input_ids"][i, : traj_lens[i]]
    print(f">>> Trajectory {i} >>>\n{tokenizer.decode(token_ids)}")

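### Generate Multiple Trajectories per Question

GRPO-style algorithms need a group of trajectories for each question.

With `asyncio`, the trajectories of a group can run concurrently, which makes generating them efficient.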

In [None]:
# Group generation for GRPO


class GroupedSearchAgentWorkflow:
    def __init__(self, gconfig, tokenizer, max_tokens, max_turns, group_size, verbose):
        self.gconfig = gconfig
        self.tokenizer = tokenizer
        self.max_tokens = max_tokens
        self.max_turns = max_turns
        self.group_size = group_size
        self.verbose = verbose

    async def arun_episode(self, engine, data):
        workflows = [
            SearchAgentWorkflow(
                self.gconfig.new(n_samples=1),
                self.tokenizer,
                self.max_tokens,
                self.max_turns,
                self.verbose,
            )
            for _ in range(self.group_size)
        ]
        tasks = [workflow.arun_episode(engine, data) for workflow in workflows]
        results = await asyncio.gather(*tasks)
        return concat_padded_tensors(results)
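
The grouped workflow relies on `asyncio.gather` to start all episodes of a group before any of them finishes. The self-contained sketch below illustrates that behaviour with a toy episode. `run_group` and `ConcurrencyProbe` are hypothetical names, and the short sleep stands in for waiting on a generation request.

```python
import asyncio


async def run_group(make_episode, group_size):
    """Start `group_size` episodes concurrently and return results in order."""
    return await asyncio.gather(*[make_episode(i) for i in range(group_size)])


class ConcurrencyProbe:
    """Toy episode that records how many episodes are in flight at once."""

    def __init__(self):
        self.active = 0
        self.peak = 0

    async def episode(self, index):
        self.active += 1
        self.peak = max(self.peak, self.active)
        await asyncio.sleep(0.01)  # stands in for waiting on generation
        self.active -= 1
        return index
```

In a notebook, `await run_group(ConcurrencyProbe().episode, 4)` returns the results in submission order, even though all four episodes were in flight at the same time.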

In [None]:
# initialize inference engine
rollout = RemoteSGLangEngine(config.rollout)
rollout.initialize(None, None)
try:
    # create the workflow
    workflow = GroupedSearchAgentWorkflow(
        gconfig=GenerationHyperparameters(
            max_new_tokens=512, stop=["</answer>", "</search>"]
        ),
        tokenizer=tokenizer,
        max_tokens=4096,
        max_turns=32,
        group_size=4,
        verbose=True,
    )
    sample_data = next(data_generator)[:2]
    res = rollout.rollout_batch(sample_data, workflow=workflow)
    print(res)
finally:
    rollout.destroy()

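## Plug the Agent Workflow into RL Training

Now that the inference workflow has been tested, the next step is to connect it to training.

This requires additionally creating a training engine dedicated to PPO and alternating between rollout and training inside the training loop.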

In [None]:
# Training for 5 steps

workflow = GroupedSearchAgentWorkflow(
    gconfig=GenerationHyperparameters(
        max_new_tokens=512, stop=["</answer>", "</search>"]
    ),
    tokenizer=tokenizer,
    max_tokens=4096,
    max_turns=32,
    group_size=4,
    verbose=True,
)
actor = FSDPPPOActor(config=config.actor)
actor.initialize(None, ft_spec)

rollout = RemoteSGLangEngine(config.rollout)
rollout.initialize(None, None)

weight_update_meta = WeightUpdateMeta.from_fsdp_nccl(
    AllocationMode.from_str("sglang.d1p1t1+d1p1t1"), actor
)

warmup_steps = 1
times = []
for global_step in range(5):
    if global_step >= warmup_steps:
        tik = time.perf_counter()
    batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
    print(batch)
    batch = batch.to(actor.device)

    logp = actor.compute_logp(batch)
    batch["prox_logp"] = logp

    actor.compute_advantages(batch)

    stats = actor.ppo_update(batch)
    actor.step_lr_scheduler()

    rollout.pause()
    future = rollout.update_weights(weight_update_meta)
    actor.upload_weights(weight_update_meta)
    future.result()
    torch.cuda.synchronize()
    rollout.resume()

    actor.set_version(global_step + 1)
    rollout.set_version(global_step + 1)
    if global_step >= warmup_steps:
        times.append(time.perf_counter() - tik)
print(times)
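
To give some intuition for why each question needs a group of trajectories, the sketch below computes group-relative advantages in the GRPO style: each reward is normalized by the mean and standard deviation of its own group. This is an illustrative assumption. The computation inside `actor.compute_advantages` depends on the config and may differ, and `group_relative_advantages` is a hypothetical function.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, group_size, eps=1e-6):
    """Normalize each reward by the mean and std of its own group."""
    if len(rewards) % group_size != 0:
        raise ValueError("rewards must contain whole groups")
    advantages = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start : start + group_size]
        mu, sigma = mean(group), pstdev(group)
        # eps avoids division by zero when all rewards in a group are equal.
        advantages.extend((r - mu) / (sigma + eps) for r in group)
    return advantages


# Two questions with two trajectories each (made-up F1 rewards).
toy_advantages = group_relative_advantages([1.0, 0.0, 0.5, 0.5], group_size=2)
```

In the first group the better trajectory gets a positive advantage and the worse one a negative advantage. In the second group both rewards are equal, so both advantages are zero and that question gives no learning signal under this scheme.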