diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml
index d65f68d4d..8ff427ad4 100644
--- a/apps/grpo/qwen3_1_7b.yaml
+++ b/apps/grpo/qwen3_1_7b.yaml
@@ -32,7 +32,7 @@ dataset:
 
 # Policy configuration
 policy:
-  engine_config:
+  engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
     model: ${model}
     tensor_parallel_size: 1
     pipeline_parallel_size: 1
@@ -115,7 +115,7 @@ ref_model:
 # All resource allocations
 services:
   policy:
-    procs: ${policy.engine_config.tensor_parallel_size}
+    procs: ${policy.engine_args.tensor_parallel_size}
     num_replicas: 1
     with_gpus: true
   ref_model:
diff --git a/apps/grpo/qwen3_32b.yaml b/apps/grpo/qwen3_32b.yaml
index c8b73325d..8fc056a6d 100644
--- a/apps/grpo/qwen3_32b.yaml
+++ b/apps/grpo/qwen3_32b.yaml
@@ -35,7 +35,7 @@ dataset:
 
 # Policy configuration
 policy:
-  engine_config:
+  engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
     model: ${model}
     tensor_parallel_size: 4
     pipeline_parallel_size: 1
@@ -118,7 +118,7 @@ ref_model:
 # All resource allocations
 services:
   policy:
-    procs: ${policy.engine_config.tensor_parallel_size}
+    procs: ${policy.engine_args.tensor_parallel_size}
     num_replicas: 1
     hosts: 1
     with_gpus: true
diff --git a/apps/grpo/qwen3_8b.yaml b/apps/grpo/qwen3_8b.yaml
index 30bd98bee..fedf2f36a 100644
--- a/apps/grpo/qwen3_8b.yaml
+++ b/apps/grpo/qwen3_8b.yaml
@@ -28,7 +28,7 @@ dataset:
 
 # Policy configuration
 policy:
-  engine_config:
+  engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
     model: ${model}
     tensor_parallel_size: 2
     pipeline_parallel_size: 1
@@ -114,7 +114,7 @@ ref_model:
 # All resource allocations
 services:
   policy:
-    procs: ${policy.engine_config.tensor_parallel_size}
+    procs: ${policy.engine_args.tensor_parallel_size}
     num_replicas: 1
     with_gpus: true
   ref_model:
diff --git a/apps/mast/qwen3_14b_mast.yaml b/apps/mast/qwen3_14b_mast.yaml
index 5cc781eee..d9e9d7edd 100644
--- a/apps/mast/qwen3_14b_mast.yaml
+++ b/apps/mast/qwen3_14b_mast.yaml
@@ -34,7 +34,7 @@ dataset:
 
 # Policy configuration
 policy:
-  engine_config:
+  engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
     model: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-14B/snapshots/8268fe3026cb304910457689366670e803a6fd56
     tensor_parallel_size: 2
     pipeline_parallel_size: 1
@@ -129,7 +129,7 @@ ref_model:
 # All resource allocations
 services:
   policy:
-    procs: ${policy.engine_config.tensor_parallel_size}
+    procs: ${policy.engine_args.tensor_parallel_size}
     num_replicas: 2
     with_gpus: true
     mesh_name: policy
diff --git a/apps/mast/qwen3_1_7b_mast.yaml b/apps/mast/qwen3_1_7b_mast.yaml
index c691e2098..5c1033db2 100644
--- a/apps/mast/qwen3_1_7b_mast.yaml
+++ b/apps/mast/qwen3_1_7b_mast.yaml
@@ -34,7 +34,7 @@ dataset:
 
 # Policy configuration
 policy:
-  engine_config:
+  engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
     model: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-1.7B/snapshots/0060bc56d46589041c1048efd1a397421b1142b5
     tensor_parallel_size: 1
     pipeline_parallel_size: 1
@@ -125,7 +125,7 @@ ref_model:
 # All resource allocations
 services:
   policy:
-    procs: ${policy.engine_config.tensor_parallel_size}
+    procs: ${policy.engine_args.tensor_parallel_size}
     num_replicas: 2
     with_gpus: true
     mesh_name: policy
diff --git a/apps/mast/qwen3_32b_mast.yaml b/apps/mast/qwen3_32b_mast.yaml
index 5bf9890a5..f0e57edac 100644
--- a/apps/mast/qwen3_32b_mast.yaml
+++ b/apps/mast/qwen3_32b_mast.yaml
@@ -34,7 +34,7 @@ dataset:
 
 # Policy configuration
 policy:
-  engine_config:
+  engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
     model: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-32B/snapshots/d47b0d4ae4b48fde975756bf360a63a9cca8d470
     tensor_parallel_size: 2
     pipeline_parallel_size: 1
@@ -128,7 +128,7 @@ ref_model:
 # All resource allocations
 services:
   policy:
-    procs: ${policy.engine_config.tensor_parallel_size}
+    procs: ${policy.engine_args.tensor_parallel_size}
     num_replicas: 2
     with_gpus: true
     mesh_name: policy
diff --git a/apps/mast/qwen3_4b_mast.yaml b/apps/mast/qwen3_4b_mast.yaml
index 323ae3d63..2a8d2b864 100644
--- a/apps/mast/qwen3_4b_mast.yaml
+++ b/apps/mast/qwen3_4b_mast.yaml
@@ -34,7 +34,7 @@ dataset:
 
 # Policy configuration
 policy:
-  engine_config:
+  engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
     model: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-4B-Base/snapshots/a81b894c2624d21c88a3ad737ce4f837424b7eed
     tensor_parallel_size: 2
     pipeline_parallel_size: 1
@@ -125,7 +125,7 @@ ref_model:
 # All resource allocations
 services:
   policy:
-    procs: ${policy.engine_config.tensor_parallel_size}
+    procs: ${policy.engine_args.tensor_parallel_size}
     num_replicas: 2
     with_gpus: true
     mesh_name: policy
diff --git a/apps/mast/qwen3_8b_mast.yaml b/apps/mast/qwen3_8b_mast.yaml
index 122219873..81c1f75dd 100644
--- a/apps/mast/qwen3_8b_mast.yaml
+++ b/apps/mast/qwen3_8b_mast.yaml
@@ -34,7 +34,7 @@ dataset:
 
 # Policy configuration
 policy:
-  engine_config:
+  engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
     model: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-8B/snapshots/model
     tensor_parallel_size: 2
     pipeline_parallel_size: 1
@@ -125,7 +125,7 @@ ref_model:
 # All resource allocations
 services:
   policy:
-    procs: ${policy.engine_config.tensor_parallel_size}
+    procs: ${policy.engine_args.tensor_parallel_size}
     num_replicas: 2
     with_gpus: true
     mesh_name: policy
diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index d55d23dd4..686ec973b 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -12,7 +12,7 @@ import sys
 
 from collections.abc import Mapping
 from copy import copy
-from dataclasses import dataclass, field, fields
+from dataclasses import dataclass, field
 
 import torch
 import torch.distributed.checkpoint as dcp
@@ -62,39 +62,9 @@ logger.setLevel(logging.INFO)
 
 
-@dataclass
-class EngineConfig(EngineArgs):
-    """
-    EngineConfig extends EngineArgs with worker-specific fields.
-    Overlapping keys in input dict will override EngineArgs defaults.
- """ - - model: str = "meta-llama/Llama-3.1-8B-Instruct" - tensor_parallel_size: int = 1 - pipeline_parallel_size: int = 1 - enforce_eager: bool = False - enable_expert_parallel: bool = False - - # Original method returns False when not run in the main thread - _is_v1_supported_oracle = lambda *_: True - - @classmethod - def from_dict(cls, d: Mapping): - d = dict(d) - all_fields = [f.name for f in fields(cls)] - valid_args = {k: v for k, v in d.items() if k in all_fields} - return cls(**valid_args) - - def create_vllm_config(self) -> VllmConfig: - """Converts the current EngineConfig into vLLM's vLLMConfig.""" - # Note: EngineArgs.create_engine_config - # creates a VllmConfig - return self.create_engine_config(UsageContext.LLM_CLASS) - - @dataclass class Policy(PolicyInterface): - engine_config: EngineConfig | Mapping = field(default_factory=EngineConfig) + engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs) sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams) available_devices: str | None = None use_dcp: bool = True @@ -111,8 +81,9 @@ def __post_init__(self): self._worker_procs: ProcMesh | None = None self.running = False - if isinstance(self.engine_config, Mapping): - self.engine_config = EngineConfig.from_dict(self.engine_config) + if isinstance(self.engine_args, Mapping): + self.engine_args = EngineArgs(**self.engine_args) + self.engine_args._is_v1_supported_oracle = lambda *_: True if isinstance(self.sampling_params, Mapping): self.sampling_params = SamplingParams.from_optional(**self.sampling_params) @@ -122,7 +93,7 @@ def __post_init__(self): async def launch( # pyright: ignore[reportIncompatibleMethodOverride] cls: type["Policy"], *, - engine_config: EngineConfig | Mapping = EngineConfig(), + engine_args: EngineArgs | Mapping = EngineArgs(), sampling_params: SamplingParams | Mapping = SamplingParams(), available_devices: str | None = None, use_dcp: bool = True, @@ -150,10 +121,12 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] policy_proc_config.with_gpus = False policy_proc = await get_proc_mesh(process_config=policy_proc_config) - if isinstance(engine_config, Mapping): - engine_config = EngineConfig.from_dict(engine_config) + if isinstance(engine_args, Mapping): + engine_args = EngineArgs(**engine_args) + engine_args._is_v1_supported_oracle = lambda *_: True # Always default on + logger.debug(f"Resolved engine args: {engine_args}") - vllm_config = engine_config.create_vllm_config() + vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS) workers = worker_procs.spawn( "vllm_worker", PolicyWorker, vllm_config=vllm_config, use_dcp=use_dcp ) @@ -168,7 +141,7 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] policy = policy_proc.spawn( actor_name, cls, - engine_config=engine_config, + engine_args=engine_args, sampling_params=sampling_params, available_devices=available_devices, policy_worker=workers, @@ -214,7 +187,9 @@ async def setup(self): # Guard for updating requests self.update_lock = asyncio.Condition() - self.vllm_config: VllmConfig = self.engine_config.create_vllm_config() + self.vllm_config: VllmConfig = self.engine_args.create_engine_config( + UsageContext.LLM_CLASS + ) # Setup processors # TODO: move all processing to the Environment diff --git a/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml b/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml index c461bab08..8b64b83ca 100644 --- a/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml +++ 
@@ -9,7 +9,7 @@ off_by_n: 1 # Off by one by default
 
 # Policy configuration
 policy:
-  engine_config:
+  engine_args:
     model: ${model}
     tensor_parallel_size: 1
     pipeline_parallel_size: 1
@@ -63,7 +63,7 @@ trainer:
 # All resource allocations
 services:
   policy:
-    procs: ${policy.engine_config.tensor_parallel_size}
+    procs: ${policy.engine_args.tensor_parallel_size}
     num_replicas: 1
     with_gpus: true
   trainer:
diff --git a/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml b/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml
index 04f450a98..5d754c3ad 100644
--- a/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml
+++ b/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml
@@ -11,7 +11,7 @@ off_by_n: 1 # Off by one by default
 
 # Policy configuration
 policy:
-  engine_config:
+  engine_args:
     model: ${model}
     tensor_parallel_size: 4
     pipeline_parallel_size: 1
@@ -65,7 +65,7 @@ trainer:
 # All resource allocations
 services:
   policy:
-    procs: ${policy.engine_config.tensor_parallel_size}
+    procs: ${policy.engine_args.tensor_parallel_size}
     num_replicas: 1
     with_gpus: true
   trainer:
diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py
index 506fc5553..bbc41bc5c 100644
--- a/tests/integration_tests/test_policy_update.py
+++ b/tests/integration_tests/test_policy_update.py
@@ -166,7 +166,7 @@ async def test_sanity_check(self, request):
         cfg = self._load_config(config_path=config_path)
 
         trainer_proc_size = cfg.actors.trainer.procs
-        policy_tp_size = cfg.policy.engine_config.tensor_parallel_size
+        policy_tp_size = cfg.policy.engine_args.tensor_parallel_size
 
         if policy_tp_size != cfg.services.policy.procs:
             pytest.fail(
diff --git a/tests/integration_tests/test_vllm_policy_correctness.py b/tests/integration_tests/test_vllm_policy_correctness.py
index d1edf1496..b512591ba 100644
--- a/tests/integration_tests/test_vllm_policy_correctness.py
+++ b/tests/integration_tests/test_vllm_policy_correctness.py
@@ -53,7 +53,7 @@ async def test_same_output():
     policy = await Policy.options(
         procs=1, num_replicas=1, with_gpus=True
     ).as_service(
-        engine_config={
+        engine_args={
             "model": MODEL_NAME,
             "tensor_parallel_size": TENSOR_PARALLEL_SIZE,
             "enforce_eager": ENFORCE_EAGER,
@@ -143,7 +143,7 @@ async def test_cache_usage():
     policy = await Policy.options(
         procs=1, num_replicas=1, with_gpus=True
     ).as_service(
-        engine_config={
+        engine_args={
             "model": MODEL_NAME,
             "tensor_parallel_size": TENSOR_PARALLEL_SIZE,
             "enforce_eager": ENFORCE_EAGER,
diff --git a/tests/sandbox/toy_rl/sumdigits-tp.yaml b/tests/sandbox/toy_rl/sumdigits-tp.yaml
index 4ce3e44cb..74fb57e4a 100644
--- a/tests/sandbox/toy_rl/sumdigits-tp.yaml
+++ b/tests/sandbox/toy_rl/sumdigits-tp.yaml
@@ -13,7 +13,7 @@ dataset:
 
 # Policy configuration
 policy:
-  engine_config:
+  engine_args:
     model: ${model}
     tensor_parallel_size: 2
     pipeline_parallel_size: 1
diff --git a/tests/sandbox/toy_rl/sumdigits.py b/tests/sandbox/toy_rl/sumdigits.py
index f6e66a141..e862ac60d 100644
--- a/tests/sandbox/toy_rl/sumdigits.py
+++ b/tests/sandbox/toy_rl/sumdigits.py
@@ -427,8 +427,6 @@ async def main(cfg: DictConfig):
     group_size = cfg.group_size
     max_req_tokens = cfg.max_req_tokens
     max_res_tokens = cfg.max_res_tokens
-    # TODO: delete this logic after we are confident on the vllm weight sync long term fix PR #184
-    policy_tp_size = cfg.policy.engine_config.tensor_parallel_size
 
     # ---- Setup services ---- #
     print(f"{cfg.policy=}")
diff --git a/tests/sandbox/toy_rl/sumdigits.yaml b/tests/sandbox/toy_rl/sumdigits.yaml
index b9f349d40..06a192431 100644
--- a/tests/sandbox/toy_rl/sumdigits.yaml
+++ b/tests/sandbox/toy_rl/sumdigits.yaml
@@ -14,7 +14,7 @@ dataset:
 # Policy configuration
 policy:
   use_dcp: false
-  engine_config:
+  engine_args:
     model: ${model}
     tensor_parallel_size: 1
     pipeline_parallel_size: 1
diff --git a/tests/sandbox/vllm/deepseek_r1.yaml b/tests/sandbox/vllm/deepseek_r1.yaml
index ed5a59b36..fd4228d5a 100644
--- a/tests/sandbox/vllm/deepseek_r1.yaml
+++ b/tests/sandbox/vllm/deepseek_r1.yaml
@@ -2,7 +2,7 @@
 # NOTE - this won't work until we have proper HostMesh support
 
 policy:
-  engine_config:
+  engine_args:
     model: "deepseek-ai/DeepSeek-R1-0528"
     tensor_parallel_size: 16
     pipeline_parallel_size: 1
diff --git a/tests/sandbox/vllm/llama3_8b.yaml b/tests/sandbox/vllm/llama3_8b.yaml
index 55104a673..95a2ad53a 100644
--- a/tests/sandbox/vllm/llama3_8b.yaml
+++ b/tests/sandbox/vllm/llama3_8b.yaml
@@ -1,7 +1,7 @@
 # >>> python -m tests.sandbox.vllm.main --config tests/sandbox/vllm/llama3_8b.yaml
 
 policy:
-  engine_config:
+  engine_args:
     model: "meta-llama/Llama-3.1-8B-Instruct"
     tensor_parallel_size: 2
     pipeline_parallel_size: 1
@@ -12,7 +12,7 @@ policy:
 
 services:
   policy:
-    procs: ${policy.engine_config.tensor_parallel_size}
+    procs: ${policy.engine_args.tensor_parallel_size}
     num_replicas: 4
     with_gpus: true
 
diff --git a/tests/sandbox/vllm/qwen2_5_32b.yaml b/tests/sandbox/vllm/qwen2_5_32b.yaml
index a1decdb33..6590b791a 100644
--- a/tests/sandbox/vllm/qwen2_5_32b.yaml
+++ b/tests/sandbox/vllm/qwen2_5_32b.yaml
@@ -1,7 +1,7 @@
 # >>> python -m tests.sandbox.vllm.main --config tests/sandbox/vllm/qwen2_5_32b.yaml
 
 policy:
-  engine_config:
+  engine_args:
     model: "Qwen/Qwen2.5-32B"
     tensor_parallel_size: 4
     pipeline_parallel_size: 1
diff --git a/tests/unit_tests/test_policy_config.py b/tests/unit_tests/test_policy_config.py
index 6225fcf0f..c1a68e540 100644
--- a/tests/unit_tests/test_policy_config.py
+++ b/tests/unit_tests/test_policy_config.py
@@ -30,22 +30,23 @@ class TestPolicyConfig(unittest.TestCase):
     )
     def test_policy_default_initialization(self):
         """Policy initializes with default values."""
-        from forge.actors.policy import EngineConfig, Policy
+        from forge.actors.policy import Policy
+        from vllm.engine.arg_utils import EngineArgs
         from vllm.sampling_params import SamplingParams
 
         policy = Policy()
 
         # Default factories
-        self.assertIsInstance(policy.engine_config, EngineConfig)
+        self.assertIsInstance(policy.engine_args, EngineArgs)
         self.assertIsInstance(policy.sampling_params, SamplingParams)
         self.assertIsNone(policy.available_devices)
 
         # Worker defaults
-        self.assertEqual(policy.engine_config.model, "meta-llama/Llama-3.1-8B-Instruct")
-        self.assertEqual(policy.engine_config.tensor_parallel_size, 1)
-        self.assertEqual(policy.engine_config.pipeline_parallel_size, 1)
-        self.assertFalse(policy.engine_config.enforce_eager)
-        self.assertTrue(policy.engine_config._is_v1_supported_oracle())
+        self.assertEqual(policy.engine_args.model, "meta-llama/Llama-3.1-8B-Instruct")
+        self.assertEqual(policy.engine_args.tensor_parallel_size, 1)
+        self.assertEqual(policy.engine_args.pipeline_parallel_size, 1)
+        self.assertFalse(policy.engine_args.enforce_eager)
+        self.assertTrue(policy.engine_args._is_v1_supported_oracle())
 
         # Sampling defaults
         self.assertEqual(policy.sampling_params.n, 1)
@@ -57,8 +58,9 @@ def test_policy_default_initialization(self):
         reason="Import error, likely due to missing dependencies on CI.",
     )
     def test_policy_with_dict_configs(self):
-        """Policy accepts dicts for engine_config and sampling_params, including nested dicts."""
-        from forge.actors.policy import EngineConfig, Policy
+        """Policy accepts dicts for engine_args and sampling_params, including nested dicts."""
+        from forge.actors.policy import Policy
+        from vllm.engine.arg_utils import EngineArgs
         from vllm.sampling_params import SamplingParams
 
         # Test with nested dict structure
@@ -80,20 +82,20 @@ def test_policy_with_dict_configs(self):
         }
 
         policy = Policy(
-            engine_config=engine_dict,
+            engine_args=engine_dict,
             sampling_params=sampling_dict,
             available_devices="test-gpu-device-abcd",
         )
 
-        self.assertIsInstance(policy.engine_config, EngineConfig)
+        self.assertIsInstance(policy.engine_args, EngineArgs)
         self.assertIsInstance(policy.sampling_params, SamplingParams)
 
         # Test basic fields
-        self.assertEqual(policy.engine_config.model, "test-model-6789")
-        self.assertEqual(policy.engine_config.tensor_parallel_size, 7777)
-        self.assertEqual(policy.engine_config.pipeline_parallel_size, 8888)
-        self.assertTrue(policy.engine_config.enforce_eager)
-        self.assertTrue(policy.engine_config._is_v1_supported_oracle())
+        self.assertEqual(policy.engine_args.model, "test-model-6789")
+        self.assertEqual(policy.engine_args.tensor_parallel_size, 7777)
+        self.assertEqual(policy.engine_args.pipeline_parallel_size, 8888)
+        self.assertTrue(policy.engine_args.enforce_eager)
+        self.assertTrue(policy.engine_args._is_v1_supported_oracle())
 
         self.assertEqual(policy.sampling_params.n, 1357)
         self.assertEqual(policy.sampling_params.max_tokens, 2468)
@@ -117,7 +119,7 @@ def test_policy_yaml_config_loading(self):
         from forge.actors.policy import Policy
 
         yaml_content = """
-engine_config:
+engine_args:
   model: "yaml-test-model-9876"
   tensor_parallel_size: 1234
   pipeline_parallel_size: 5678
@@ -139,37 +141,17 @@ def test_policy_yaml_config_loading(self):
 
         policy = Policy(**config)
 
-        self.assertEqual(policy.engine_config.model, "yaml-test-model-9876")
-        self.assertEqual(policy.engine_config.tensor_parallel_size, 1234)
-        self.assertEqual(policy.engine_config.pipeline_parallel_size, 5678)
-        self.assertTrue(policy.engine_config.enforce_eager)
-        self.assertTrue(policy.engine_config._is_v1_supported_oracle())
+        self.assertEqual(policy.engine_args.model, "yaml-test-model-9876")
+        self.assertEqual(policy.engine_args.tensor_parallel_size, 1234)
+        self.assertEqual(policy.engine_args.pipeline_parallel_size, 5678)
+        self.assertTrue(policy.engine_args.enforce_eager)
+        self.assertTrue(policy.engine_args._is_v1_supported_oracle())
 
         self.assertEqual(policy.sampling_params.n, 2468)
         self.assertEqual(policy.sampling_params.max_tokens, 1357)
         self.assertEqual(policy.available_devices, "yaml-test-device-xyz")
 
-    @pytest.mark.skipif(
-        _import_error(),
-        reason="Import error, likely due to missing dependencies on CI.",
-    )
-    def test_engineconfig_ignores_invalid_keys(self):
-        """EngineConfig.from_dict ignores unexpected keys."""
-        from forge.actors.policy import EngineConfig
-
-        engine_config = {
-            "model": "custom-model",
-            "tensor_parallel_size": 2,
-            "invalid_key_123": "should be ignored",
-        }
-
-        config = EngineConfig.from_dict(engine_config)
-
-        self.assertEqual(config.model, "custom-model")
-        self.assertEqual(config.tensor_parallel_size, 2)
-        self.assertFalse(hasattr(config, "invalid_key_123"))
-
 
 if __name__ == "__main__":
     unittest.main()
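
Net effect of the patch: the bespoke EngineConfig dataclass is removed and Policy consumes vLLM's EngineArgs directly, with the _is_v1_supported_oracle monkey-patch applied at the instance level instead of via subclassing. A minimal sketch of the resulting flow, mirroring Policy.__post_init__ and Policy.launch above (the config values are illustrative only; assumes a vLLM v0.10-era install with the module paths the patch itself references):

    from collections.abc import Mapping

    from vllm.engine.arg_utils import EngineArgs
    from vllm.usage.usage_lib import UsageContext

    # The kind of mapping a YAML `policy.engine_args` section resolves to.
    raw_args = {
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "tensor_parallel_size": 1,
        "pipeline_parallel_size": 1,
        "enforce_eager": False,
    }

    # Mappings are expanded straight into vLLM's EngineArgs dataclass.
    engine_args = EngineArgs(**raw_args) if isinstance(raw_args, Mapping) else raw_args

    # Patched because the original oracle can return False when not run in
    # the main thread (see the removed EngineConfig comment above).
    engine_args._is_v1_supported_oracle = lambda *_: True

    # create_engine_config builds the VllmConfig that the deleted
    # EngineConfig.create_vllm_config helper used to wrap.
    vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)

One behavioral difference worth noting: EngineConfig.from_dict silently dropped unknown keys, while EngineArgs(**mapping) raises TypeError on them, which is why test_engineconfig_ignores_invalid_keys is deleted rather than ported.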