diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml
index 800d2e973..d65f68d4d 100644
--- a/apps/grpo/qwen3_1_7b.yaml
+++ b/apps/grpo/qwen3_1_7b.yaml
@@ -37,7 +37,7 @@ policy:
     tensor_parallel_size: 1
     pipeline_parallel_size: 1
     enforce_eager: false
-  sampling_config:
+  sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
     n: ${group_size}
     max_tokens: ${max_res_tokens}
     temperature: 1.0
diff --git a/apps/grpo/qwen3_32b.yaml b/apps/grpo/qwen3_32b.yaml
index 8100a988b..c8b73325d 100644
--- a/apps/grpo/qwen3_32b.yaml
+++ b/apps/grpo/qwen3_32b.yaml
@@ -40,7 +40,7 @@ policy:
     tensor_parallel_size: 4
     pipeline_parallel_size: 1
     enforce_eager: false
-  sampling_config:
+  sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
     n: ${group_size}
     max_tokens: ${max_res_tokens}
     temperature: 1.0
diff --git a/apps/grpo/qwen3_8b.yaml b/apps/grpo/qwen3_8b.yaml
index de855d1cb..30bd98bee 100644
--- a/apps/grpo/qwen3_8b.yaml
+++ b/apps/grpo/qwen3_8b.yaml
@@ -33,7 +33,7 @@ policy:
     tensor_parallel_size: 2
     pipeline_parallel_size: 1
     enforce_eager: false
-  sampling_config:
+  sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
     n: ${group_size}
     max_tokens: ${max_res_tokens}
     temperature: 1.0
diff --git a/apps/mast/qwen3_14b_mast.yaml b/apps/mast/qwen3_14b_mast.yaml
index 484a71538..5cc781eee 100644
--- a/apps/mast/qwen3_14b_mast.yaml
+++ b/apps/mast/qwen3_14b_mast.yaml
@@ -42,7 +42,7 @@ policy:
     # TODO: Had to disable this becasue vLLm wouldn't like
     # needs to revisited.
     disable_custom_all_reduce: true
-  sampling_config:
+  sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
     n: ${group_size}
     max_tokens: ${max_res_tokens}
     temperature: 1.0
diff --git a/apps/mast/qwen3_1_7b_mast.yaml b/apps/mast/qwen3_1_7b_mast.yaml
index 58d879579..c691e2098 100644
--- a/apps/mast/qwen3_1_7b_mast.yaml
+++ b/apps/mast/qwen3_1_7b_mast.yaml
@@ -42,7 +42,7 @@ policy:
     # TODO: Had to disable this becasue vLLm wouldn't like
     # needs to revisited.
     disable_custom_all_reduce: true
-  sampling_config:
+  sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
     n: ${group_size}
     max_tokens: ${max_res_tokens}
     temperature: 1.0
diff --git a/apps/mast/qwen3_32b_mast.yaml b/apps/mast/qwen3_32b_mast.yaml
index 47368becd..5bf9890a5 100644
--- a/apps/mast/qwen3_32b_mast.yaml
+++ b/apps/mast/qwen3_32b_mast.yaml
@@ -42,7 +42,7 @@ policy:
     # TODO: Had to disable this becasue vLLm wouldn't like
     # needs to revisited.
     disable_custom_all_reduce: true
-  sampling_config:
+  sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
     n: ${group_size}
     max_tokens: ${max_res_tokens}
     temperature: 1.0
diff --git a/apps/mast/qwen3_4b_mast.yaml b/apps/mast/qwen3_4b_mast.yaml
index 92119055a..323ae3d63 100644
--- a/apps/mast/qwen3_4b_mast.yaml
+++ b/apps/mast/qwen3_4b_mast.yaml
@@ -42,7 +42,7 @@ policy:
     # TODO: Had to disable this becasue vLLm wouldn't like
     # needs to revisited.
     disable_custom_all_reduce: true
-  sampling_config:
+  sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
     n: ${group_size}
     max_tokens: ${max_res_tokens}
     temperature: 1.0
diff --git a/apps/mast/qwen3_8b_mast.yaml b/apps/mast/qwen3_8b_mast.yaml
index 7f2f99694..122219873 100644
--- a/apps/mast/qwen3_8b_mast.yaml
+++ b/apps/mast/qwen3_8b_mast.yaml
@@ -42,7 +42,7 @@ policy:
     # TODO: Had to disable this becasue vLLm wouldn't like
     # needs to revisited.
     disable_custom_all_reduce: true
-  sampling_config:
+  sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
     n: ${group_size}
     max_tokens: ${max_res_tokens}
     temperature: 1.0
diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index 9793021d7..d55d23dd4 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -12,7 +12,7 @@ import sys
 
 from collections.abc import Mapping
 from copy import copy
-from dataclasses import asdict, dataclass, field, fields
+from dataclasses import dataclass, field, fields
 
 import torch
 import torch.distributed.checkpoint as dcp
@@ -26,7 +26,7 @@
 from vllm.executor.multiproc_worker_utils import set_multiprocessing_worker_envs
 from vllm.lora.request import LoRARequest
 from vllm.outputs import CompletionOutput, RequestOutput
-from vllm.sampling_params import GuidedDecodingParams, RequestOutputKind, SamplingParams
+from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import get_distributed_init_method
@@ -62,49 +62,6 @@
 logger.setLevel(logging.INFO)
 
 
-@dataclass
-class SamplingConfig:
-    """
-    Overrides for vLLMs sampling params.
-
-    Note: We'll want to tie this closer to or directly use vllm's
-    SamplingParams. It is currently used to track a supported
-    subset
-
-    Args:
-        n: Number of samples to generate.
-        guided_decoding: Whether to use guided decoding.
-        max_tokens: Maximum number of tokens to generate.
-    """
-
-    n: int = 1
-    guided_decoding: bool = False
-    max_tokens: int = 512
-    temperature: float = 1.0
-    top_p: float = 1.0
-    logprobs: int = 1
-
-    def __post_init__(self):
-        super().__init__()
-        gd_params = None
-        if self.guided_decoding:
-            gd_params = GuidedDecodingParams(choice=["Positive", "Negative"])
-        self.guided_decoding = gd_params
-
-    @classmethod
-    def from_dict(cls, d: Mapping):
-        d = dict(d)
-        all_fields = set(cls.__dataclass_fields__.keys())
-        valid_args = {k: v for k, v in d.items() if k in all_fields}
-        return cls(**valid_args)
-
-    def asdict(self):
-        # Use the full object instead of a Dict
-        ret = asdict(self)
-        ret["guided_decoding"] = self.guided_decoding
-        return ret
-
-
 @dataclass
 class EngineConfig(EngineArgs):
     """
@@ -138,11 +95,10 @@ def create_vllm_config(self) -> VllmConfig:
 @dataclass
 class Policy(PolicyInterface):
     engine_config: EngineConfig | Mapping = field(default_factory=EngineConfig)
-    sampling_config: SamplingConfig | Mapping = field(default_factory=SamplingConfig)
+    sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams)
     available_devices: str | None = None
     use_dcp: bool = True
     # Gets set up by setup
-    sampling_params: SamplingParams | None = None
     lora_request: LoRARequest | None = None
     tokenization_kwargs: dict = field(default_factory=dict)
     policy_worker: "PolicyWorker" = None
@@ -154,18 +110,20 @@ def __post_init__(self):
         self._policy_proc: ProcMesh | None = None
         self._worker_procs: ProcMesh | None = None
         self.running = False
+
         if isinstance(self.engine_config, Mapping):
             self.engine_config = EngineConfig.from_dict(self.engine_config)
-        if isinstance(self.sampling_config, Mapping):
-            self.sampling_config = SamplingConfig.from_dict(self.sampling_config)
-        # No conversion needed for boolean flag
+
+        if isinstance(self.sampling_params, Mapping):
+            self.sampling_params = SamplingParams.from_optional(**self.sampling_params)
+        self.sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
 
     @classmethod
     async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
         cls: type["Policy"],
         *,
         engine_config: EngineConfig | Mapping = EngineConfig(),
-        sampling_config: SamplingConfig | Mapping = SamplingConfig(),
+        sampling_params: SamplingParams | Mapping = SamplingParams(),
         available_devices: str | None = None,
         use_dcp: bool = True,
         **kwargs,
@@ -200,8 +158,10 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
             "vllm_worker", PolicyWorker, vllm_config=vllm_config, use_dcp=use_dcp
         )
 
-        if isinstance(sampling_config, Mapping):
-            sampling_config = SamplingConfig(**sampling_config)
+        if isinstance(sampling_params, Mapping):
+            sampling_params = SamplingParams.from_optional(**sampling_params)
+        sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
+        logger.debug(f"Resolved sampling params: {sampling_params}")
 
         # TODO - expand support so name can stick within kwargs
         actor_name = kwargs.pop("name", cls.__name__)
@@ -209,7 +169,7 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
             actor_name,
             cls,
             engine_config=engine_config,
-            sampling_config=sampling_config,
+            sampling_params=sampling_params,
             available_devices=available_devices,
             policy_worker=workers,
             **kwargs,
@@ -256,11 +216,6 @@ async def setup(self):
 
         self.vllm_config: VllmConfig = self.engine_config.create_vllm_config()
 
-        # Setup sampling params
-        self.sampling_params = get_default_sampling_params(
-            self.vllm_config, overrides=self.sampling_config.asdict()
-        )
-
         # Setup processors
         # TODO: move all processing to the Environment
         # TODO: add support for `log_stats` and `mm_registry`
@@ -736,16 +691,3 @@ def convert_input(prompt=None, prompt_token_ids=None) -> dict:
     if prompt is not None:
         return {"prompt": prompt}
     return {"prompt_token_ids": prompt_token_ids}
-
-
-def get_default_sampling_params(vllm_config, overrides=None) -> SamplingParams:
-    default_params = vllm_config.model_config.get_diff_sampling_param()
-    if overrides is not None:
-        default_params |= overrides
-    if default_params:
-        params = SamplingParams.from_optional(**default_params)
-    else:
-        params = SamplingParams()
-    # We only care about the final output
-    params.output_kind = RequestOutputKind.FINAL_ONLY
-    return params
diff --git a/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml b/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml
index 4d3a56d04..c461bab08 100644
--- a/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml
+++ b/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml
@@ -14,7 +14,7 @@ policy:
     tensor_parallel_size: 1
     pipeline_parallel_size: 1
     enforce_eager: false
-  sampling_config:
+  sampling_params:
     n: ${group_size}
     max_tokens: ${max_res_tokens}
     temperature: 1.0
diff --git a/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml b/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml
index 0ac915d2a..04f450a98 100644
--- a/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml
+++ b/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml
@@ -16,7 +16,7 @@ policy:
     tensor_parallel_size: 4
     pipeline_parallel_size: 1
    enforce_eager: false
-  sampling_config:
+  sampling_params:
     n: ${group_size}
     max_tokens: ${max_res_tokens}
     temperature: 1.0
diff --git a/tests/integration_tests/test_vllm_policy_correctness.py b/tests/integration_tests/test_vllm_policy_correctness.py
index 806b80f64..d1edf1496 100644
--- a/tests/integration_tests/test_vllm_policy_correctness.py
+++ b/tests/integration_tests/test_vllm_policy_correctness.py
@@ -61,7 +61,7 @@ async def test_same_output():
             "gpu_memory_utilization": GPU_MEMORY_UTILIZATION,
             "enable_prefix_caching": ENABLE_PREFIX_CACHING,
         },
-        sampling_config={
+        sampling_params={
             "n": N_SAMPLES,
             "max_tokens": MAX_TOKENS,
             "temperature": TEMPERATURE,
@@ -152,7 +152,7 @@ async def test_cache_usage():
             "enable_prefix_caching": ENABLE_PREFIX_CACHING,
             "block_size": 16,
         },
-        sampling_config={
+        sampling_params={
             "n": N_SAMPLES,
             "max_tokens": MAX_TOKENS,
             "temperature": TEMPERATURE,
diff --git a/tests/sandbox/toy_rl/sumdigits-tp.yaml b/tests/sandbox/toy_rl/sumdigits-tp.yaml
index f859b1e7c..4ce3e44cb 100644
--- a/tests/sandbox/toy_rl/sumdigits-tp.yaml
+++ b/tests/sandbox/toy_rl/sumdigits-tp.yaml
@@ -18,7 +18,7 @@ policy:
     tensor_parallel_size: 2
     pipeline_parallel_size: 1
     enforce_eager: false
-  sampling_config:
+  sampling_params:
     n: ${group_size}
     max_tokens: ${max_res_tokens}
     temperature: 1.0
diff --git a/tests/sandbox/toy_rl/sumdigits.yaml b/tests/sandbox/toy_rl/sumdigits.yaml
index 767bf7f3b..b9f349d40 100644
--- a/tests/sandbox/toy_rl/sumdigits.yaml
+++ b/tests/sandbox/toy_rl/sumdigits.yaml
@@ -19,7 +19,7 @@ policy:
     tensor_parallel_size: 1
     pipeline_parallel_size: 1
     enforce_eager: false
-  sampling_config:
+  sampling_params:
     n: ${group_size}
     max_tokens: ${max_res_tokens}
     temperature: 1.0
diff --git a/tests/sandbox/vllm/deepseek_r1.yaml b/tests/sandbox/vllm/deepseek_r1.yaml
index 2255a5c03..ed5a59b36 100644
--- a/tests/sandbox/vllm/deepseek_r1.yaml
+++ b/tests/sandbox/vllm/deepseek_r1.yaml
@@ -8,9 +8,8 @@ policy:
     pipeline_parallel_size: 1
     enable_expert_parallel: true
     # enforce_eager: true
-  sampling_config:
+  sampling_params:
     n: 2
-    guided_decoding: false
     max_tokens: 512
 
 provisioner:
diff --git a/tests/sandbox/vllm/llama3_8b.yaml b/tests/sandbox/vllm/llama3_8b.yaml
index 0e9a00607..55104a673 100644
--- a/tests/sandbox/vllm/llama3_8b.yaml
+++ b/tests/sandbox/vllm/llama3_8b.yaml
@@ -6,9 +6,8 @@ policy:
     tensor_parallel_size: 2
     pipeline_parallel_size: 1
     enforce_eager: true
-  sampling_config:
+  sampling_params:
     n: 2
-    guided_decoding: false
     max_tokens: 512
 
 services:
diff --git a/tests/sandbox/vllm/main.py b/tests/sandbox/vllm/main.py
index 0f3ce662c..54b093841 100644
--- a/tests/sandbox/vllm/main.py
+++ b/tests/sandbox/vllm/main.py
@@ -37,8 +37,7 @@ async def run(cfg: DictConfig):
     await mlogger.init_backends.call_one(metric_logging_cfg)
 
     if (prompt := cfg.get("prompt")) is None:
-        gd = cfg.policy.get("sampling_config", {}).get("guided_decoding", False)
-        prompt = "What is 3+5?" if gd else "Tell me a joke"
+        prompt = "Tell me a joke"
 
     print("Spawning service...")
     policy = await Policy.options(**cfg.services.policy).as_service(**cfg.policy)
diff --git a/tests/sandbox/vllm/qwen2_5_32b.yaml b/tests/sandbox/vllm/qwen2_5_32b.yaml
index 3edfaa9d3..a1decdb33 100644
--- a/tests/sandbox/vllm/qwen2_5_32b.yaml
+++ b/tests/sandbox/vllm/qwen2_5_32b.yaml
@@ -6,9 +6,8 @@ policy:
     tensor_parallel_size: 4
     pipeline_parallel_size: 1
     enforce_eager: true
-  sampling_config:
+  sampling_params:
     n: 2
-    guided_decoding: false
     max_tokens: 512
 
 services:
diff --git a/tests/unit_tests/test_policy_config.py b/tests/unit_tests/test_policy_config.py
index 08de4f907..6225fcf0f 100644
--- a/tests/unit_tests/test_policy_config.py
+++ b/tests/unit_tests/test_policy_config.py
@@ -30,13 +30,14 @@ class TestPolicyConfig(unittest.TestCase):
     )
     def test_policy_default_initialization(self):
         """Policy initializes with default values."""
-        from forge.actors.policy import EngineConfig, Policy, SamplingConfig
+        from forge.actors.policy import EngineConfig, Policy
+        from vllm.sampling_params import SamplingParams
 
         policy = Policy()
 
         # Default factories
         self.assertIsInstance(policy.engine_config, EngineConfig)
-        self.assertIsInstance(policy.sampling_config, SamplingConfig)
+        self.assertIsInstance(policy.sampling_params, SamplingParams)
         self.assertIsNone(policy.available_devices)
 
         # Worker defaults
@@ -47,17 +48,18 @@ def test_policy_default_initialization(self):
         self.assertTrue(policy.engine_config._is_v1_supported_oracle())
 
         # Sampling defaults
-        self.assertEqual(policy.sampling_config.n, 1)
-        self.assertFalse(policy.sampling_config.guided_decoding)
-        self.assertEqual(policy.sampling_config.max_tokens, 512)
+        self.assertEqual(policy.sampling_params.n, 1)
+        self.assertFalse(policy.sampling_params.guided_decoding)
+        self.assertEqual(policy.sampling_params.max_tokens, 512)
 
     @pytest.mark.skipif(
         _import_error(),
         reason="Import error, likely due to missing dependencies on CI.",
     )
     def test_policy_with_dict_configs(self):
-        """Policy accepts dicts for engine_config and sampling_config, including nested dicts."""
-        from forge.actors.policy import EngineConfig, Policy, SamplingConfig
+        """Policy accepts dicts for engine_config and sampling_params, including nested dicts."""
+        from forge.actors.policy import EngineConfig, Policy
+        from vllm.sampling_params import SamplingParams
 
         # Test with nested dict structure
         engine_dict = {
@@ -74,18 +76,17 @@ def test_policy_with_dict_configs(self):
 
         sampling_dict = {
             "n": 1357,
-            "guided_decoding": True,
             "max_tokens": 2468,
         }
 
         policy = Policy(
             engine_config=engine_dict,
-            sampling_config=sampling_dict,
+            sampling_params=sampling_dict,
             available_devices="test-gpu-device-abcd",
         )
 
         self.assertIsInstance(policy.engine_config, EngineConfig)
-        self.assertIsInstance(policy.sampling_config, SamplingConfig)
+        self.assertIsInstance(policy.sampling_params, SamplingParams)
 
         # Test basic fields
         self.assertEqual(policy.engine_config.model, "test-model-6789")
@@ -94,10 +95,8 @@ def test_policy_with_dict_configs(self):
         self.assertTrue(policy.engine_config.enforce_eager)
         self.assertTrue(policy.engine_config._is_v1_supported_oracle())
 
-        self.assertEqual(policy.sampling_config.n, 1357)
-        # After __post_init__, guided_decoding becomes GuidedDecodingParams object when True
-        self.assertIsNotNone(policy.sampling_config.guided_decoding)
-        self.assertEqual(policy.sampling_config.max_tokens, 2468)
+        self.assertEqual(policy.sampling_params.n, 1357)
+        self.assertEqual(policy.sampling_params.max_tokens, 2468)
 
         # Test that engine_dict accepts and preserves nested dict structure
         # The original engine_dict should remain unchanged and accessible
@@ -124,9 +123,8 @@ def test_policy_yaml_config_loading(self):
             pipeline_parallel_size: 5678
             enforce_eager: true
 
-          sampling_config:
+          sampling_params:
             n: 2468
-            guided_decoding: true
             max_tokens: 1357
 
           available_devices: "yaml-test-device-xyz"
@@ -147,10 +145,8 @@ def test_policy_yaml_config_loading(self):
         self.assertTrue(policy.engine_config.enforce_eager)
         self.assertTrue(policy.engine_config._is_v1_supported_oracle())
 
-        self.assertEqual(policy.sampling_config.n, 2468)
-        # After __post_init__, guided_decoding becomes GuidedDecodingParams object when True
-        self.assertIsNotNone(policy.sampling_config.guided_decoding)
-        self.assertEqual(policy.sampling_config.max_tokens, 1357)
+        self.assertEqual(policy.sampling_params.n, 2468)
+        self.assertEqual(policy.sampling_params.max_tokens, 1357)
 
         self.assertEqual(policy.available_devices, "yaml-test-device-xyz")
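
For reviewers, a minimal sketch (not part of the patch; the override values are made up) of how a YAML `sampling_params` mapping is now resolved straight into vLLM's own SamplingParams, mirroring the Policy.__post_init__ change above and the removed get_default_sampling_params helper:

from vllm.sampling_params import RequestOutputKind, SamplingParams

# Illustrative overrides, shaped like a `sampling_params:` block from one of the YAML configs.
overrides = {"n": 8, "max_tokens": 512, "temperature": 1.0, "top_p": 1.0}

# Build vLLM's SamplingParams directly from the mapping, as Policy now does,
params = SamplingParams.from_optional(**overrides)
# and only keep the final output per request (the old helper set this the same way).
params.output_kind = RequestOutputKind.FINAL_ONLY

print(params.n, params.max_tokens)  # 8 512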