Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Lint workflow: runs ruff's lint and format checks on every push to main
# and on every pull request.
name: Lint

on:
  push:
    branches: [main]
  pull_request:

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: astral-sh/setup-uv@v5
      - run: uv python install 3.10
      # Run ruff as an isolated uv tool instead of `uv pip install` + `uv run`:
      #   * `uv pip install` needs an existing virtual environment (or --system),
      #     which a fresh runner does not have;
      #   * `uv run` inside a directory with a pyproject.toml syncs the whole
      #     project environment first, which is unnecessary for linting.
      # The version is pinned to match `rev` in .pre-commit-config.yaml so CI and
      # local hooks agree; bump both together.
      - run: uvx ruff@0.12.2 check .
      - run: uvx ruff@0.12.2 format --check .
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# pre-commit hooks: ruff linter (with autofix) followed by the ruff formatter.
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Tags of ruff-pre-commit track ruff releases.
    rev: v0.12.2
    hooks:
      # Lint and apply safe fixes in place.
      - id: ruff
        args: [--fix]
      # Format the files.
      - id: ruff-format
2 changes: 0 additions & 2 deletions data_preprocess/aime2024_rstar2_agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
import os

import datasets

from verl.utils.hdfs_io import copy, makedirs


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local_dir", default="~/data/rstar2-agent/aime2024")
Expand Down
2 changes: 0 additions & 2 deletions data_preprocess/aime2025_rstar2_agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
import os

import datasets

from verl.utils.hdfs_io import copy, makedirs


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local_dir", default="~/data/rstar2-agent/aime2025")
Expand Down
2 changes: 0 additions & 2 deletions data_preprocess/dapo_rstar2_agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
import os

import datasets

from verl.utils.hdfs_io import copy, makedirs


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local_dir", default="~/data/rstar2-agent/dapo-math-17k-en")
Expand Down
2 changes: 0 additions & 2 deletions data_preprocess/math500_rstar2_agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
import os

import datasets

from verl.utils.hdfs_io import copy, makedirs


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local_dir", default="~/data/rstar2-agent/math500")
Expand Down
9 changes: 4 additions & 5 deletions examples/chat_with_tool_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,21 @@

curl http://localhost:8000/v1/models
"""
import aiohttp
import argparse
import asyncio
import json
import requests
import yaml

from pathlib import Path

import aiohttp
import requests
import yaml
from transformers import AutoTokenizer, PreTrainedTokenizer
from verl.tools.schemas import ToolResponse

from rstar2_agent.tools.code_judge_utils import (
run_tool_calls_on_server_async,
generate_tool_call_code,
generate_tool_call_input,
run_tool_calls_on_server_async,
)
from rstar2_agent.tools.tool_parser import (
RStar2AgentHermesToolParser,
Expand Down
3 changes: 2 additions & 1 deletion fused_compute_score/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .prime_math import compute_score as prime_compute_score
from .math_verify import compute_score as math_verify_compute_score
from .prime_math import compute_score as prime_compute_score


def compute_score(model_output: str, ground_truth: str) -> bool:
try:
Expand Down
6 changes: 3 additions & 3 deletions fused_compute_score/prime_math/grader.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,12 +381,12 @@ def format_intervals(prediction):
return prediction


import multiprocessing
import os
import signal
import queue
import multiprocessing
import signal
from functools import wraps
from typing import Callable, Any
from typing import Any, Callable


def _mp_target_wrapper(target_func: Callable, mp_queue: multiprocessing.Queue, args: tuple, kwargs: dict[str, Any]):
Expand Down
13 changes: 13 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,16 @@ dependencies = [

[tool.setuptools]
packages = ["rstar2_agent", "fused_compute_score"]

# Ruff configuration, read by both `ruff check` and `ruff format`.
[tool.ruff]
# Target width for the formatter; the linter's own length check (E501) is
# disabled below, so overlong lines the formatter cannot wrap are tolerated.
line-length = 120

[tool.ruff.lint]
# E = pycodestyle errors, F = Pyflakes, I = isort-style import sorting.
select = ["E", "F", "I"]
# Rules switched off:
#   E501 line too long            E402 module-level import not at top of file
#   E731 lambda assignment        E722 bare `except`
#   E741 ambiguous variable name  E721 type comparison with `==`
#   E701 multiple statements on one line (colon)
#   F841 unused local variable    F403 `from x import *`
#   F405 name may come from a star import
ignore = ["E501", "E402", "E731", "E722", "E741", "E721", "E701", "F841", "F403", "F405"]

[tool.ruff.lint.per-file-ignores]
# F401 = unused import; package __init__ files import names only to re-export them.
"__init__.py" = ["F401"]

[tool.uv]
# Install with: uv sync
2 changes: 1 addition & 1 deletion rstar2_agent/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .reward import CodeJudgeRewardManager
from .rollout.rstar2_agent_loop import RStar2AgentLoop
from .tools import RStar2AgentHermesToolParser
from .reward import CodeJudgeRewardManager
2 changes: 1 addition & 1 deletion rstar2_agent/down_sample/reject_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import numpy as np
import torch

from verl.protocol import DataProto

from .utils import filter_by_mask


Expand Down
9 changes: 5 additions & 4 deletions rstar2_agent/down_sample/roc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
# Licensed under the MIT license.

import re
import numpy as np
import torch
from pprint import pprint
from typing import List
from transformers import PreTrainedTokenizerFast

import numpy as np
import torch
from transformers import PreTrainedTokenizerFast
from verl.protocol import DataProto
from .utils import filter_by_mask, decode_prompt_response_str

from .utils import decode_prompt_response_str, filter_by_mask


def resample_of_correct(batch: DataProto, tokenizer: PreTrainedTokenizerFast, config: dict, do_sample=True, world_size=None):
Expand Down
2 changes: 1 addition & 1 deletion rstar2_agent/down_sample/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT license.

import torch
from verl.protocol import DataProto, DataProtoItem
from verl.protocol import DataProto


def filter_by_mask(batch: DataProto, mask: torch.Tensor, num_trainer_replicas: int) -> DataProto:
Expand Down
4 changes: 1 addition & 3 deletions rstar2_agent/main_rstar2_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import hydra
import ray
from omegaconf import OmegaConf

from verl.trainer.constants_ppo import get_ppo_ray_runtime_env
from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
from verl.trainer.ppo.reward import load_reward_manager
Expand Down Expand Up @@ -199,7 +198,6 @@ def run(self, config):
from pprint import pprint

from omegaconf import OmegaConf

from verl.utils.fs import copy_to_local

print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
Expand Down Expand Up @@ -244,7 +242,7 @@ def run(self, config):
tool_config_path = config.actor_rollout_ref.rollout.multi_turn.tool_config_path
tool_list = []
if tool_config_path is not None:
from verl.tools.utils.tool_registry import ToolType, get_tool_class, OpenAIFunctionToolSchema
from verl.tools.utils.tool_registry import OpenAIFunctionToolSchema, ToolType
tools_config = OmegaConf.load(tool_config_path)
for tool_config in tools_config.tools:
tool_type = ToolType(tool_config.config.type)
Expand Down
3 changes: 2 additions & 1 deletion rstar2_agent/reward/compute_score.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from verl.utils.reward_score.prime_math import compute_score as prime_compute_score
from verl.utils.reward_score.math_verify import compute_score as math_verify_compute_score
from verl.utils.reward_score.prime_math import compute_score as prime_compute_score


def compute_score(model_output: str, ground_truth: str) -> bool:
try:
Expand Down
9 changes: 4 additions & 5 deletions rstar2_agent/reward/server.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import asyncio
import base64
import re
from collections import defaultdict
from typing import Any

import aiohttp
import asyncio
import base64
import re
import torch

from verl import DataProto
from verl.workers.reward_manager import register
from verl.workers.reward_manager.abstract import AbstractRewardManager

from rstar2_agent.tools.code_judge_utils import run_tool_calls_on_server_async
from verl import DataProto

verify_math_prefix = """
from fused_compute_score import compute_score
Expand Down
5 changes: 2 additions & 3 deletions rstar2_agent/rstar2_agent_ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
import ray
import torch
from tqdm import tqdm

from verl import DataProto
from verl.experimental.dataset.sampler import AbstractCurriculumSampler
from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss
from verl.trainer.ppo.metric_utils import (
Expand All @@ -30,6 +28,8 @@
from verl.utils.metric import reduce_metrics
from verl.utils.rollout_skip import RolloutSkip

from verl import DataProto

from .down_sample import reject_equal_reward, resample_of_correct


Expand Down Expand Up @@ -79,7 +79,6 @@ def fit(self):
Most logic is same with RayPPOTrainer, mainly add down sample related.
"""
from omegaconf import OmegaConf

from verl.utils.tracking import Tracking

logger = Tracking(
Expand Down
2 changes: 1 addition & 1 deletion rstar2_agent/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .code_judge_tool import CodeJudgeTool, SimJupyterTool, PythonTool
from .code_judge_tool import CodeJudgeTool, PythonTool, SimJupyterTool
from .tool_parser import RStar2AgentHermesToolParser
7 changes: 3 additions & 4 deletions rstar2_agent/tools/code_judge_tool.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
from functools import partial
from typing import Any, Optional
from uuid import uuid4

import aiohttp
import json

from verl.utils.rollout_trace import rollout_trace_op
from verl.tools.base_tool import BaseTool
from verl.tools.schemas import OpenAIFunctionToolSchema, ToolResponse
from verl.utils.rollout_trace import rollout_trace_op

from .code_judge_utils import generate_tool_call_code, generate_tool_call_input, run_tool_calls_on_server_async
from .request_processor import RequestProcessor
from .code_judge_utils import run_tool_calls_on_server_async, generate_tool_call_code, generate_tool_call_input


class CodeJudgeTool(BaseTool):
Expand Down
10 changes: 5 additions & 5 deletions rstar2_agent/tools/code_judge_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import aiohttp
import asyncio
import traceback
import os
import datetime
import json
import os
import traceback
from typing import Callable, Dict, List, Literal, Optional

from typing import Dict, List, Literal, Callable, Optional
import aiohttp

# Global variable to store the path for failed submissions
_failed_submissions_path = os.path.expanduser("~")
Expand Down
7 changes: 4 additions & 3 deletions rstar2_agent/tools/request_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
# Licensed under the MIT license.

import asyncio
import collections
import time
import traceback
import uuid
import collections
from typing import Any, Awaitable, Callable, Dict, List

import aiohttp
import traceback
from typing import List, Dict, Any, Callable, Awaitable

# Define the expected signature for the batch submission function
# It should be an async callable that takes:
Expand Down
3 changes: 1 addition & 2 deletions rstar2_agent/tools/tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
import os

import regex as re

from verl.experimental.agent_loop.tool_parser import ToolParser, FunctionCall
from verl.experimental.agent_loop.tool_parser import FunctionCall, ToolParser
from verl.utils.rollout_trace import rollout_trace_op

logger = logging.getLogger(__file__)
Expand Down