2 changes: 1 addition & 1 deletion apps/grpo/qwen3_1_7b.yaml
@@ -37,7 +37,7 @@ policy:
tensor_parallel_size: 1
pipeline_parallel_size: 1
enforce_eager: false
sampling_config:
sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
2 changes: 1 addition & 1 deletion apps/grpo/qwen3_32b.yaml
@@ -40,7 +40,7 @@ policy:
tensor_parallel_size: 4
pipeline_parallel_size: 1
enforce_eager: false
sampling_config:
sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
2 changes: 1 addition & 1 deletion apps/grpo/qwen3_8b.yaml
@@ -33,7 +33,7 @@ policy:
tensor_parallel_size: 2
pipeline_parallel_size: 1
enforce_eager: false
sampling_config:
sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
2 changes: 1 addition & 1 deletion apps/mast/qwen3_14b_mast.yaml
@@ -42,7 +42,7 @@ policy:
# TODO: Had to disable this because vLLM wouldn't like it;
# needs to be revisited.
disable_custom_all_reduce: true
sampling_config:
sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
2 changes: 1 addition & 1 deletion apps/mast/qwen3_1_7b_mast.yaml
@@ -42,7 +42,7 @@ policy:
# TODO: Had to disable this because vLLM wouldn't like it;
# needs to be revisited.
disable_custom_all_reduce: true
sampling_config:
sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
2 changes: 1 addition & 1 deletion apps/mast/qwen3_32b_mast.yaml
@@ -42,7 +42,7 @@ policy:
# TODO: Had to disable this because vLLM wouldn't like it;
# needs to be revisited.
disable_custom_all_reduce: true
sampling_config:
sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
2 changes: 1 addition & 1 deletion apps/mast/qwen3_4b_mast.yaml
@@ -42,7 +42,7 @@ policy:
# TODO: Had to disable this because vLLM wouldn't like it;
# needs to be revisited.
disable_custom_all_reduce: true
sampling_config:
sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
2 changes: 1 addition & 1 deletion apps/mast/qwen3_8b_mast.yaml
@@ -42,7 +42,7 @@ policy:
# TODO: Had to disable this because vLLM wouldn't like it;
# needs to be revisited.
disable_custom_all_reduce: true
sampling_config:
sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
86 changes: 14 additions & 72 deletions src/forge/actors/policy.py
@@ -12,7 +12,7 @@
import sys
from collections.abc import Mapping
from copy import copy
from dataclasses import asdict, dataclass, field, fields
from dataclasses import dataclass, field, fields

import torch
import torch.distributed.checkpoint as dcp
@@ -26,7 +26,7 @@
from vllm.executor.multiproc_worker_utils import set_multiprocessing_worker_envs
from vllm.lora.request import LoRARequest
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import GuidedDecodingParams, RequestOutputKind, SamplingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import get_distributed_init_method
@@ -62,49 +62,6 @@
logger.setLevel(logging.INFO)


@dataclass
class SamplingConfig:
"""
Overrides for vLLMs sampling params.

Note: We'll want to tie this closer to or directly use vllm's
SamplingParams. It is currently used to track a supported
subset

Args:
n: Number of samples to generate.
guided_decoding: Whether to use guided decoding.
max_tokens: Maximum number of tokens to generate.
"""

n: int = 1
guided_decoding: bool = False
max_tokens: int = 512
temperature: float = 1.0
top_p: float = 1.0
logprobs: int = 1

def __post_init__(self):
super().__init__()
gd_params = None
if self.guided_decoding:
gd_params = GuidedDecodingParams(choice=["Positive", "Negative"])
self.guided_decoding = gd_params

@classmethod
def from_dict(cls, d: Mapping):
d = dict(d)
all_fields = set(cls.__dataclass_fields__.keys())
valid_args = {k: v for k, v in d.items() if k in all_fields}
return cls(**valid_args)

def asdict(self):
# Use the full object instead of a Dict
ret = asdict(self)
ret["guided_decoding"] = self.guided_decoding
return ret


@dataclass
class EngineConfig(EngineArgs):
"""
@@ -138,11 +95,10 @@ def create_vllm_config(self) -> VllmConfig:
@dataclass
class Policy(PolicyInterface):
engine_config: EngineConfig | Mapping = field(default_factory=EngineConfig)
sampling_config: SamplingConfig | Mapping = field(default_factory=SamplingConfig)
sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams)
available_devices: str | None = None
use_dcp: bool = True
# Gets set up by setup
sampling_params: SamplingParams | None = None
lora_request: LoRARequest | None = None
tokenization_kwargs: dict = field(default_factory=dict)
policy_worker: "PolicyWorker" = None
@@ -154,18 +110,20 @@ def __post_init__(self):
self._policy_proc: ProcMesh | None = None
self._worker_procs: ProcMesh | None = None
self.running = False

if isinstance(self.engine_config, Mapping):
self.engine_config = EngineConfig.from_dict(self.engine_config)
if isinstance(self.sampling_config, Mapping):
self.sampling_config = SamplingConfig.from_dict(self.sampling_config)
# No conversion needed for boolean flag

if isinstance(self.sampling_params, Mapping):
self.sampling_params = SamplingParams.from_optional(**self.sampling_params)
self.sampling_params.output_kind = RequestOutputKind.FINAL_ONLY

@classmethod
async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
cls: type["Policy"],
*,
engine_config: EngineConfig | Mapping = EngineConfig(),
sampling_config: SamplingConfig | Mapping = SamplingConfig(),
sampling_params: SamplingParams | Mapping = SamplingParams(),
available_devices: str | None = None,
use_dcp: bool = True,
**kwargs,
@@ -200,16 +158,18 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
"vllm_worker", PolicyWorker, vllm_config=vllm_config, use_dcp=use_dcp
)

if isinstance(sampling_config, Mapping):
sampling_config = SamplingConfig(**sampling_config)
if isinstance(sampling_params, Mapping):
sampling_params = SamplingParams.from_optional(**sampling_params)
Contributor comment (illustrated in the sketch after this file's diff):
Note for whoever in the future needs to use rich fields of vllm/v1 in yaml:

from_optional doesn't play well with rich fields (nesting)

sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
logger.debug(f"Resolved sampling params: {sampling_params}")

# TODO - expand support so name can stick within kwargs
actor_name = kwargs.pop("name", cls.__name__)
policy = policy_proc.spawn(
actor_name,
cls,
engine_config=engine_config,
sampling_config=sampling_config,
sampling_params=sampling_params,
available_devices=available_devices,
policy_worker=workers,
**kwargs,
@@ -256,11 +216,6 @@ async def setup(self):

self.vllm_config: VllmConfig = self.engine_config.create_vllm_config()

# Setup sampling params
self.sampling_params = get_default_sampling_params(
self.vllm_config, overrides=self.sampling_config.asdict()
)

# Setup processors
# TODO: move all processing to the Environment
# TODO: add support for `log_stats` and `mm_registry`
@@ -736,16 +691,3 @@ def convert_input(prompt=None, prompt_token_ids=None) -> dict:
if prompt is not None:
return {"prompt": prompt}
return {"prompt_token_ids": prompt_token_ids}


def get_default_sampling_params(vllm_config, overrides=None) -> SamplingParams:
default_params = vllm_config.model_config.get_diff_sampling_param()
if overrides is not None:
default_params |= overrides
if default_params:
params = SamplingParams.from_optional(**default_params)
else:
params = SamplingParams()
# We only care about the final output
params.output_kind = RequestOutputKind.FINAL_ONLY
return params
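
For reference, here is a minimal sketch of the configuration path this diff moves to: the sampling_params mapping taken from YAML is passed straight to vLLM's SamplingParams.from_optional and then restricted to final-only outputs. It assumes vLLM v0.10.x; the literal values stand in for the ${group_size} and ${max_res_tokens} placeholders, and the guided-decoding call at the end only illustrates the reviewer's note that nested ("rich") fields are easier to build in code than to express in YAML.

```python
# Sketch of the new sampling_params flow (assumes vLLM v0.10.x is installed).
from collections.abc import Mapping

from vllm.sampling_params import GuidedDecodingParams, RequestOutputKind, SamplingParams

# Stand-in for the YAML `sampling_params` block; 8 and 512 are hypothetical
# values for ${group_size} and ${max_res_tokens}.
yaml_sampling_params = {
    "n": 8,
    "max_tokens": 512,
    "temperature": 1.0,
    "top_p": 1.0,
}

if isinstance(yaml_sampling_params, Mapping):
    # Keys map 1:1 onto vLLM's SamplingParams fields, so no local
    # SamplingConfig wrapper or get_default_sampling_params helper is needed.
    sampling_params = SamplingParams.from_optional(**yaml_sampling_params)

# The policy only consumes the final result of each request.
sampling_params.output_kind = RequestOutputKind.FINAL_ONLY

# Per the review note, nested fields such as guided decoding are simpler to
# construct programmatically than to nest inside the YAML mapping.
guided = SamplingParams.from_optional(
    n=2,
    max_tokens=512,
    guided_decoding=GuidedDecodingParams(choice=["Positive", "Negative"]),
)
guided.output_kind = RequestOutputKind.FINAL_ONLY
```

In the policy itself this conversion now happens in Policy.__post_init__ and Policy.launch, as shown in the hunks above.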
2 changes: 1 addition & 1 deletion tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml
@@ -14,7 +14,7 @@ policy:
tensor_parallel_size: 1
pipeline_parallel_size: 1
enforce_eager: false
sampling_config:
sampling_params:
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
2 changes: 1 addition & 1 deletion tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml
@@ -16,7 +16,7 @@ policy:
tensor_parallel_size: 4
pipeline_parallel_size: 1
enforce_eager: false
sampling_config:
sampling_params:
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
4 changes: 2 additions & 2 deletions tests/integration_tests/test_vllm_policy_correctness.py
@@ -61,7 +61,7 @@ async def test_same_output():
"gpu_memory_utilization": GPU_MEMORY_UTILIZATION,
"enable_prefix_caching": ENABLE_PREFIX_CACHING,
},
sampling_config={
sampling_params={
"n": N_SAMPLES,
"max_tokens": MAX_TOKENS,
"temperature": TEMPERATURE,
@@ -152,7 +152,7 @@ async def test_cache_usage():
"enable_prefix_caching": ENABLE_PREFIX_CACHING,
"block_size": 16,
},
sampling_config={
sampling_params={
"n": N_SAMPLES,
"max_tokens": MAX_TOKENS,
"temperature": TEMPERATURE,
2 changes: 1 addition & 1 deletion tests/sandbox/toy_rl/sumdigits-tp.yaml
@@ -18,7 +18,7 @@ policy:
tensor_parallel_size: 2
pipeline_parallel_size: 1
enforce_eager: false
sampling_config:
sampling_params:
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
2 changes: 1 addition & 1 deletion tests/sandbox/toy_rl/sumdigits.yaml
@@ -19,7 +19,7 @@ policy:
tensor_parallel_size: 1
pipeline_parallel_size: 1
enforce_eager: false
sampling_config:
sampling_params:
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
3 changes: 1 addition & 2 deletions tests/sandbox/vllm/deepseek_r1.yaml
@@ -8,9 +8,8 @@ policy:
pipeline_parallel_size: 1
enable_expert_parallel: true
# enforce_eager: true
sampling_config:
sampling_params:
n: 2
guided_decoding: false
max_tokens: 512

provisioner:
3 changes: 1 addition & 2 deletions tests/sandbox/vllm/llama3_8b.yaml
@@ -6,9 +6,8 @@ policy:
tensor_parallel_size: 2
pipeline_parallel_size: 1
enforce_eager: true
sampling_config:
sampling_params:
n: 2
guided_decoding: false
max_tokens: 512

services:
3 changes: 1 addition & 2 deletions tests/sandbox/vllm/main.py
@@ -37,8 +37,7 @@ async def run(cfg: DictConfig):
await mlogger.init_backends.call_one(metric_logging_cfg)

if (prompt := cfg.get("prompt")) is None:
gd = cfg.policy.get("sampling_config", {}).get("guided_decoding", False)
prompt = "What is 3+5?" if gd else "Tell me a joke"
prompt = "Tell me a joke"
Contributor comment on lines 39 to +40:
Suggested change:
-    if (prompt := cfg.get("prompt")) is None:
-        gd = cfg.policy.get("sampling_config", {}).get("guided_decoding", False)
-        prompt = "What is 3+5?" if gd else "Tell me a joke"
-    prompt = "Tell me a joke"
+    prompt := cfg.get("prompt", "Tell me a joke")

Member Author reply:
Ugh can't apply suggestion on deleted lines :/


print("Spawning service...")
policy = await Policy.options(**cfg.services.policy).as_service(**cfg.policy)
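
As a side note on the review suggestion above: a bare walrus assignment is a syntax error at statement level, so the same simplification can be written as an ordinary default lookup. A sketch, assuming cfg is the OmegaConf DictConfig passed into run():

```python
# Equivalent to the suggested one-liner, valid as a standalone statement
# (cfg is assumed to be an OmegaConf DictConfig).
prompt = cfg.get("prompt", "Tell me a joke")
```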
3 changes: 1 addition & 2 deletions tests/sandbox/vllm/qwen2_5_32b.yaml
@@ -6,9 +6,8 @@ policy:
tensor_parallel_size: 4
pipeline_parallel_size: 1
enforce_eager: true
sampling_config:
sampling_params:
n: 2
guided_decoding: false
max_tokens: 512

services: