Merged
3 changes: 2 additions & 1 deletion .meta/mast/qwen3_14b_mast.yaml
```diff
@@ -61,7 +61,7 @@ trainer:
     warmup_steps: 1
   training:
     local_batch_size: ${local_batch_size}
-    seq_len: 2048
+    seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
     max_norm: 1.0
     steps: 1000000
     dtype: bfloat16
@@ -106,6 +106,7 @@ ref_model:
     flavor: 14B
     hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_14b
   training:
+    seq_len: ${trainer.training.seq_len}
     dtype: bfloat16
     gc_freq: 1
   compile:
```
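Taken together, the two hunks in each of these mast configs follow one pattern: `seq_len` is derived from the request and response token budgets instead of being hardcoded, and the ref_model inherits the trainer's value through node interpolation, so the two models can never silently diverge. Below is a minimal sketch of how the two interpolations resolve, using illustrative budgets of 1024 tokens each (the actual mast values are defined elsewhere in these files and are not part of this diff):

```python
from omegaconf import OmegaConf

# Same resolver this PR registers in src/forge/util/config.py.
OmegaConf.register_new_resolver("sum", lambda *args: sum(args), replace=True)

cfg = OmegaConf.create(
    """
    max_req_tokens: 1024  # illustrative values, not taken from the diff
    max_res_tokens: 1024
    trainer:
      training:
        seq_len: ${sum:${max_req_tokens},${max_res_tokens}}
    ref_model:
      training:
        seq_len: ${trainer.training.seq_len}
    """
)
assert cfg.trainer.training.seq_len == 2048   # 1024 + 1024
assert cfg.ref_model.training.seq_len == 2048  # always tracks the trainer
```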
3 changes: 2 additions & 1 deletion .meta/mast/qwen3_1_7b_mast.yaml
```diff
@@ -62,7 +62,7 @@ trainer:
     warmup_steps: 1
   training:
     local_batch_size: ${local_batch_size}
-    seq_len: 2048
+    seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
     max_norm: 1.0
     steps: 1000000
     dtype: bfloat16
@@ -108,6 +108,7 @@ ref_model:
     hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_1.7b
     # hf_assets_path: hf://${model}
   training:
+    seq_len: ${trainer.training.seq_len}
     dtype: bfloat16
     gc_freq: 1
   compile:
```
3 changes: 2 additions & 1 deletion .meta/mast/qwen3_32b_mast.yaml
```diff
@@ -61,7 +61,7 @@ trainer:
     warmup_steps: 1
   training:
     local_batch_size: ${local_batch_size}
-    seq_len: 2048
+    seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
     max_norm: 1.0
     steps: 1000000
     dtype: bfloat16
@@ -106,6 +106,7 @@ ref_model:
     flavor: 32B
     hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_32b
   training:
+    seq_len: ${trainer.training.seq_len}
     dtype: bfloat16
     gc_freq: 1
   compile:
```
3 changes: 2 additions & 1 deletion .meta/mast/qwen3_4b_mast.yaml
```diff
@@ -62,7 +62,7 @@ trainer:
     warmup_steps: 1
   training:
     local_batch_size: ${local_batch_size}
-    seq_len: 2048
+    seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
     max_norm: 1.0
     steps: 1000000
     dtype: bfloat16
@@ -108,6 +108,7 @@ ref_model:
     hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_4b
     # hf_assets_path: hf://${model}
   training:
+    seq_len: ${trainer.training.seq_len}
     dtype: bfloat16
     gc_freq: 1
   compile:
```
3 changes: 2 additions & 1 deletion .meta/mast/qwen3_8b_mast.yaml
```diff
@@ -61,7 +61,7 @@ trainer:
     warmup_steps: 1
   training:
     local_batch_size: ${local_batch_size}
-    seq_len: 2048
+    seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
     max_norm: 1.0
     steps: 1000000
     dtype: bfloat16
@@ -106,6 +106,7 @@ ref_model:
     flavor: 8B
     hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_8b
   training:
+    seq_len: ${trainer.training.seq_len}
     dtype: bfloat16
     gc_freq: 1
   compile:
```
6 changes: 3 additions & 3 deletions apps/grpo/qwen3_1_7b.yaml
```diff
@@ -4,8 +4,8 @@
 # Global configuration
 group_size: 8
 local_batch_size: 16 # per-device batch size
-max_req_tokens: 512
-max_res_tokens: 512
+max_req_tokens: 1024
+max_res_tokens: 1024
 model: "Qwen/Qwen3-1.7B"
 off_by_n: 1 # Off by one by default
 
@@ -57,7 +57,7 @@ trainer:
     warmup_steps: 1
   training:
     local_batch_size: ${local_batch_size}
-    seq_len: 2048
+    seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
     max_norm: 1.0
     steps: 1000000
     dtype: bfloat16
```
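The arithmetic here is worth spelling out: with both budgets raised from 512 to 1024, the derived value resolves to 1024 + 1024 = 2048, exactly the `seq_len` this file previously hardcoded. The trainer's effective sequence length is therefore unchanged; what changes is that the whole window is now usable, where before the budgets summed to only 1024 against a fixed window of 2048.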
2 changes: 1 addition & 1 deletion apps/grpo/qwen3_32b.yaml
```diff
@@ -60,7 +60,7 @@ trainer:
     warmup_steps: 1
   training:
     local_batch_size: ${local_batch_size}
-    seq_len: 2048
+    seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
     max_norm: 1.0
     steps: 1000000
     dtype: bfloat16
```
6 changes: 3 additions & 3 deletions apps/grpo/qwen3_8b.yaml
```diff
@@ -4,8 +4,8 @@
 # Global configuration
 group_size: 8
 local_batch_size: 16 # per-device batch size
-max_req_tokens: 512
-max_res_tokens: 512
+max_req_tokens: 1024
+max_res_tokens: 1024
 model: "Qwen/Qwen3-8B"
 off_by_n: 1 # Off by one by default
 
@@ -53,7 +53,7 @@ trainer:
     warmup_steps: 1
   training:
     local_batch_size: ${local_batch_size}
-    seq_len: 2048
+    seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
     max_norm: 1.0
     steps: 1000000
     dtype: bfloat16
```
3 changes: 3 additions & 0 deletions src/forge/util/config.py
```diff
@@ -15,6 +15,9 @@
 
 from omegaconf import DictConfig, OmegaConf
 
+# Add support for summing numbers, e.g. ${sum:${max_req_tokens},${max_res_tokens}}
+OmegaConf.register_new_resolver("sum", lambda *args: sum(args), replace=True)
+
 
 def _has_component(node: Any) -> bool:
     """Check if a node has a _component_ field."""
```
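A note on the `replace=True` flag: `register_new_resolver` runs here as a module-level side effect, and OmegaConf raises an error if the same resolver name is registered twice without it (which can happen under test re-imports or interactive reloads). A small sketch of the behavior, assuming nothing beyond the public OmegaConf API:

```python
from omegaconf import OmegaConf

# Safe to run twice only because of replace=True; with the default
# replace=False the second registration raises a ValueError.
OmegaConf.register_new_resolver("sum", lambda *args: sum(args), replace=True)
OmegaConf.register_new_resolver("sum", lambda *args: sum(args), replace=True)

cfg = OmegaConf.create({"a": 512, "b": 1536, "total": "${sum:${a},${b}}"})
print(cfg.total)  # 2048 -- interpolations resolve lazily, on first access
```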
2 changes: 1 addition & 1 deletion tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml
```diff
@@ -34,7 +34,7 @@ trainer:
     warmup_steps: 1
   training:
     local_batch_size: ${batch_size}
-    seq_len: 2048
+    seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
     max_norm: 1.0
     steps: 1000000
     dtype: bfloat16
```
2 changes: 1 addition & 1 deletion tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml
```diff
@@ -36,7 +36,7 @@ trainer:
     warmup_steps: 1
   training:
     local_batch_size: ${batch_size}
-    seq_len: 2048
+    seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
     max_norm: 1.0
     steps: 1000000
     dtype: bfloat16
```