Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions apps/grpo/qwen3_32b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
# NOTE - This has not been tested for correctness yet! All testing so far has been only for infrastructure stability

# Global configuration
group_size: 2
local_batch_size: 8 # per-device batch size
max_req_tokens: 512
max_res_tokens: 512
group_size: 16
local_batch_size: 32 # per-device batch size
max_req_tokens: 1024
max_res_tokens: 1024
model: "Qwen/Qwen3-32B"
off_by_n: 1 # Off by one by default

provisioner:
launcher: slurm

# Main loop configuration
rollout_threads: 1 # Recommended to set equal to policy.num_replicas
rollout_threads: 32 # setting this to 4x the number of policy replicas seems to work well

# Observability configuration
metric_logging:
Expand Down Expand Up @@ -69,8 +69,8 @@ trainer:
enable: false
parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: -1
tensor_parallel_degree: 1
data_parallel_shard_degree: 1
tensor_parallel_degree: 8
pipeline_parallel_degree: 1
context_parallel_degree: 1
expert_parallel_degree: 1
Expand All @@ -90,7 +90,7 @@ replay_buffer:
batch_size: ${local_batch_size}
max_policy_age: ${off_by_n}
# dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree
dp_size: 8
dp_size: 1

# Reference model configuration
ref_model:
Expand Down Expand Up @@ -119,7 +119,7 @@ ref_model:
services:
policy:
procs: ${policy.engine_args.tensor_parallel_size}
num_replicas: 1
num_replicas: 4
hosts: 1
with_gpus: true
mesh_name: policy
Expand Down
Loading