From 2d0922c13f99e6ee079aab8d54bab536e9777e4a Mon Sep 17 00:00:00 2001 From: Alex Korbonits Date: Sun, 12 Apr 2026 22:22:34 -0700 Subject: [PATCH] ci: add Ruff config, pre-commit hooks, and GitHub Actions lint workflow - Add Ruff config (E/F/I rules, per-file ignores for __init__.py) - Add .pre-commit-config.yaml with ruff lint + format hooks - Add .github/workflows/lint.yml triggering on PRs - Auto-fix 24 violations (unsorted imports, unused imports, f-strings) Co-Authored-By: Claude Sonnet 4.6 --- .github/workflows/lint.yml | 17 +++++++++++++++++ .pre-commit-config.yaml | 7 +++++++ data_preprocess/aime2024_rstar2_agent_loop.py | 2 -- data_preprocess/aime2025_rstar2_agent_loop.py | 2 -- data_preprocess/dapo_rstar2_agent_loop.py | 2 -- data_preprocess/math500_rstar2_agent_loop.py | 2 -- examples/chat_with_tool_call.py | 9 ++++----- fused_compute_score/__init__.py | 3 ++- fused_compute_score/prime_math/grader.py | 6 +++--- pyproject.toml | 13 +++++++++++++ rstar2_agent/__init__.py | 2 +- rstar2_agent/down_sample/reject_sampling.py | 2 +- rstar2_agent/down_sample/roc.py | 9 +++++---- rstar2_agent/down_sample/utils.py | 2 +- rstar2_agent/main_rstar2_agent.py | 4 +--- rstar2_agent/reward/compute_score.py | 3 ++- rstar2_agent/reward/server.py | 9 ++++----- rstar2_agent/rstar2_agent_ray_trainer.py | 5 ++--- rstar2_agent/tools/__init__.py | 2 +- rstar2_agent/tools/code_judge_tool.py | 7 +++---- rstar2_agent/tools/code_judge_utils.py | 10 +++++----- rstar2_agent/tools/request_processor.py | 7 ++++--- rstar2_agent/tools/tool_parser.py | 3 +-- 23 files changed, 77 insertions(+), 51 deletions(-) create mode 100644 .github/workflows/lint.yml create mode 100644 .pre-commit-config.yaml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..fe58f73 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,17 @@ +name: Lint + +on: + push: + branches: [main] + pull_request: + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/setup-uv@v5 + - run: uv python install 3.10 + - run: uv pip install ruff + - run: uv run ruff check . + - run: uv run ruff format --check . diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..700cbc4 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,7 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.12.2 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/data_preprocess/aime2024_rstar2_agent_loop.py b/data_preprocess/aime2024_rstar2_agent_loop.py index fcb7d6d..b6f6f29 100644 --- a/data_preprocess/aime2024_rstar2_agent_loop.py +++ b/data_preprocess/aime2024_rstar2_agent_loop.py @@ -9,10 +9,8 @@ import os import datasets - from verl.utils.hdfs_io import copy, makedirs - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--local_dir", default="~/data/rstar2-agent/aime2024") diff --git a/data_preprocess/aime2025_rstar2_agent_loop.py b/data_preprocess/aime2025_rstar2_agent_loop.py index e226f89..e2b5189 100644 --- a/data_preprocess/aime2025_rstar2_agent_loop.py +++ b/data_preprocess/aime2025_rstar2_agent_loop.py @@ -9,10 +9,8 @@ import os import datasets - from verl.utils.hdfs_io import copy, makedirs - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--local_dir", default="~/data/rstar2-agent/aime2025") diff --git a/data_preprocess/dapo_rstar2_agent_loop.py b/data_preprocess/dapo_rstar2_agent_loop.py index 1cd7f90..5b1b520 100644 --- a/data_preprocess/dapo_rstar2_agent_loop.py +++ b/data_preprocess/dapo_rstar2_agent_loop.py @@ -9,10 +9,8 @@ import os import datasets - from verl.utils.hdfs_io import copy, makedirs - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--local_dir", default="~/data/rstar2-agent/dapo-math-17k-en") diff --git a/data_preprocess/math500_rstar2_agent_loop.py b/data_preprocess/math500_rstar2_agent_loop.py index 76bf014..431e5d0 100644 --- a/data_preprocess/math500_rstar2_agent_loop.py +++ b/data_preprocess/math500_rstar2_agent_loop.py @@ -9,10 +9,8 @@ import os import datasets - from verl.utils.hdfs_io import copy, makedirs - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--local_dir", default="~/data/rstar2-agent/math500") diff --git a/examples/chat_with_tool_call.py b/examples/chat_with_tool_call.py index eafc471..a959a1a 100644 --- a/examples/chat_with_tool_call.py +++ b/examples/chat_with_tool_call.py @@ -14,22 +14,21 @@ curl http://localhost:8000/v1/models """ -import aiohttp import argparse import asyncio import json -import requests -import yaml - from pathlib import Path +import aiohttp +import requests +import yaml from transformers import AutoTokenizer, PreTrainedTokenizer from verl.tools.schemas import ToolResponse from rstar2_agent.tools.code_judge_utils import ( - run_tool_calls_on_server_async, generate_tool_call_code, generate_tool_call_input, + run_tool_calls_on_server_async, ) from rstar2_agent.tools.tool_parser import ( RStar2AgentHermesToolParser, diff --git a/fused_compute_score/__init__.py b/fused_compute_score/__init__.py index c0c8477..66aed43 100644 --- a/fused_compute_score/__init__.py +++ b/fused_compute_score/__init__.py @@ -1,5 +1,6 @@ -from .prime_math import compute_score as prime_compute_score from .math_verify import compute_score as math_verify_compute_score +from .prime_math import compute_score as prime_compute_score + def compute_score(model_output: str, ground_truth: str) -> bool: try: diff --git a/fused_compute_score/prime_math/grader.py b/fused_compute_score/prime_math/grader.py index 1a12ad5..5dbcd69 100644 --- a/fused_compute_score/prime_math/grader.py +++ b/fused_compute_score/prime_math/grader.py @@ -381,12 +381,12 @@ def format_intervals(prediction): return prediction +import multiprocessing import os -import signal import queue -import multiprocessing +import signal from functools import wraps -from typing import Callable, Any +from typing import Any, Callable def _mp_target_wrapper(target_func: Callable, mp_queue: multiprocessing.Queue, args: tuple, kwargs: dict[str, Any]): diff --git a/pyproject.toml b/pyproject.toml index d3ba134..71ac06b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,3 +15,16 @@ dependencies = [ [tool.setuptools] packages = ["rstar2_agent", "fused_compute_score"] + +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = ["E501", "E402", "E731", "E722", "E741", "E721", "E701", "F841", "F403", "F405"] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401"] + +[tool.uv] +# Install with: uv sync diff --git a/rstar2_agent/__init__.py b/rstar2_agent/__init__.py index b019cab..d6b517a 100644 --- a/rstar2_agent/__init__.py +++ b/rstar2_agent/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from .reward import CodeJudgeRewardManager from .rollout.rstar2_agent_loop import RStar2AgentLoop from .tools import RStar2AgentHermesToolParser -from .reward import CodeJudgeRewardManager diff --git a/rstar2_agent/down_sample/reject_sampling.py b/rstar2_agent/down_sample/reject_sampling.py index d6e0a5c..8df5d83 100644 --- a/rstar2_agent/down_sample/reject_sampling.py +++ b/rstar2_agent/down_sample/reject_sampling.py @@ -3,8 +3,8 @@ import numpy as np import torch - from verl.protocol import DataProto + from .utils import filter_by_mask diff --git a/rstar2_agent/down_sample/roc.py b/rstar2_agent/down_sample/roc.py index d552365..da1a27a 100644 --- a/rstar2_agent/down_sample/roc.py +++ b/rstar2_agent/down_sample/roc.py @@ -2,14 +2,15 @@ # Licensed under the MIT license. import re -import numpy as np -import torch from pprint import pprint from typing import List -from transformers import PreTrainedTokenizerFast +import numpy as np +import torch +from transformers import PreTrainedTokenizerFast from verl.protocol import DataProto -from .utils import filter_by_mask, decode_prompt_response_str + +from .utils import decode_prompt_response_str, filter_by_mask def resample_of_correct(batch: DataProto, tokenizer: PreTrainedTokenizerFast, config: dict, do_sample=True, world_size=None): diff --git a/rstar2_agent/down_sample/utils.py b/rstar2_agent/down_sample/utils.py index c473a91..2a69a9a 100644 --- a/rstar2_agent/down_sample/utils.py +++ b/rstar2_agent/down_sample/utils.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import torch -from verl.protocol import DataProto, DataProtoItem +from verl.protocol import DataProto def filter_by_mask(batch: DataProto, mask: torch.Tensor, num_trainer_replicas: int) -> DataProto: diff --git a/rstar2_agent/main_rstar2_agent.py b/rstar2_agent/main_rstar2_agent.py index 6773b49..3524abb 100644 --- a/rstar2_agent/main_rstar2_agent.py +++ b/rstar2_agent/main_rstar2_agent.py @@ -11,7 +11,6 @@ import hydra import ray from omegaconf import OmegaConf - from verl.trainer.constants_ppo import get_ppo_ray_runtime_env from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler from verl.trainer.ppo.reward import load_reward_manager @@ -199,7 +198,6 @@ def run(self, config): from pprint import pprint from omegaconf import OmegaConf - from verl.utils.fs import copy_to_local print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") @@ -244,7 +242,7 @@ def run(self, config): tool_config_path = config.actor_rollout_ref.rollout.multi_turn.tool_config_path tool_list = [] if tool_config_path is not None: - from verl.tools.utils.tool_registry import ToolType, get_tool_class, OpenAIFunctionToolSchema + from verl.tools.utils.tool_registry import OpenAIFunctionToolSchema, ToolType tools_config = OmegaConf.load(tool_config_path) for tool_config in tools_config.tools: tool_type = ToolType(tool_config.config.type) diff --git a/rstar2_agent/reward/compute_score.py b/rstar2_agent/reward/compute_score.py index f05b8f7..6a0a83e 100644 --- a/rstar2_agent/reward/compute_score.py +++ b/rstar2_agent/reward/compute_score.py @@ -1,8 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from verl.utils.reward_score.prime_math import compute_score as prime_compute_score from verl.utils.reward_score.math_verify import compute_score as math_verify_compute_score +from verl.utils.reward_score.prime_math import compute_score as prime_compute_score + def compute_score(model_output: str, ground_truth: str) -> bool: try: diff --git a/rstar2_agent/reward/server.py b/rstar2_agent/reward/server.py index 1ff9553..6df44ba 100644 --- a/rstar2_agent/reward/server.py +++ b/rstar2_agent/reward/server.py @@ -1,20 +1,19 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import asyncio +import base64 +import re from collections import defaultdict from typing import Any import aiohttp -import asyncio -import base64 -import re import torch - -from verl import DataProto from verl.workers.reward_manager import register from verl.workers.reward_manager.abstract import AbstractRewardManager from rstar2_agent.tools.code_judge_utils import run_tool_calls_on_server_async +from verl import DataProto verify_math_prefix = """ from fused_compute_score import compute_score diff --git a/rstar2_agent/rstar2_agent_ray_trainer.py b/rstar2_agent/rstar2_agent_ray_trainer.py index 02f5da2..aa4be0b 100644 --- a/rstar2_agent/rstar2_agent_ray_trainer.py +++ b/rstar2_agent/rstar2_agent_ray_trainer.py @@ -9,8 +9,6 @@ import ray import torch from tqdm import tqdm - -from verl import DataProto from verl.experimental.dataset.sampler import AbstractCurriculumSampler from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss from verl.trainer.ppo.metric_utils import ( @@ -30,6 +28,8 @@ from verl.utils.metric import reduce_metrics from verl.utils.rollout_skip import RolloutSkip +from verl import DataProto + from .down_sample import reject_equal_reward, resample_of_correct @@ -79,7 +79,6 @@ def fit(self): Most logic is same with RayPPOTrainer, mainly add down sample related. """ from omegaconf import OmegaConf - from verl.utils.tracking import Tracking logger = Tracking( diff --git a/rstar2_agent/tools/__init__.py b/rstar2_agent/tools/__init__.py index b613806..bed9d2d 100644 --- a/rstar2_agent/tools/__init__.py +++ b/rstar2_agent/tools/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from .code_judge_tool import CodeJudgeTool, SimJupyterTool, PythonTool +from .code_judge_tool import CodeJudgeTool, PythonTool, SimJupyterTool from .tool_parser import RStar2AgentHermesToolParser diff --git a/rstar2_agent/tools/code_judge_tool.py b/rstar2_agent/tools/code_judge_tool.py index f8f2b26..83c46e5 100644 --- a/rstar2_agent/tools/code_judge_tool.py +++ b/rstar2_agent/tools/code_judge_tool.py @@ -1,19 +1,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import json from functools import partial from typing import Any, Optional from uuid import uuid4 import aiohttp -import json - -from verl.utils.rollout_trace import rollout_trace_op from verl.tools.base_tool import BaseTool from verl.tools.schemas import OpenAIFunctionToolSchema, ToolResponse +from verl.utils.rollout_trace import rollout_trace_op +from .code_judge_utils import generate_tool_call_code, generate_tool_call_input, run_tool_calls_on_server_async from .request_processor import RequestProcessor -from .code_judge_utils import run_tool_calls_on_server_async, generate_tool_call_code, generate_tool_call_input class CodeJudgeTool(BaseTool): diff --git a/rstar2_agent/tools/code_judge_utils.py b/rstar2_agent/tools/code_judge_utils.py index ef8a2e6..73472fe 100644 --- a/rstar2_agent/tools/code_judge_utils.py +++ b/rstar2_agent/tools/code_judge_utils.py @@ -1,14 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import json -import aiohttp import asyncio -import traceback -import os import datetime +import json +import os +import traceback +from typing import Callable, Dict, List, Literal, Optional -from typing import Dict, List, Literal, Callable, Optional +import aiohttp # Global variable to store the path for failed submissions _failed_submissions_path = os.path.expanduser("~") diff --git a/rstar2_agent/tools/request_processor.py b/rstar2_agent/tools/request_processor.py index 4f9eca3..c3fff66 100644 --- a/rstar2_agent/tools/request_processor.py +++ b/rstar2_agent/tools/request_processor.py @@ -2,12 +2,13 @@ # Licensed under the MIT license. import asyncio +import collections import time +import traceback import uuid -import collections +from typing import Any, Awaitable, Callable, Dict, List + import aiohttp -import traceback -from typing import List, Dict, Any, Callable, Awaitable # Define the expected signature for the batch submission function # It should be an async callable that takes: diff --git a/rstar2_agent/tools/tool_parser.py b/rstar2_agent/tools/tool_parser.py index 9bd20bf..0485ae3 100644 --- a/rstar2_agent/tools/tool_parser.py +++ b/rstar2_agent/tools/tool_parser.py @@ -7,8 +7,7 @@ import os import regex as re - -from verl.experimental.agent_loop.tool_parser import ToolParser, FunctionCall +from verl.experimental.agent_loop.tool_parser import FunctionCall, ToolParser from verl.utils.rollout_trace import rollout_trace_op logger = logging.getLogger(__file__)