Merged
3 changes: 2 additions & 1 deletion README.md
@@ -260,7 +260,8 @@ More example config files can be found in `examples`.

For more detailed examples of how to use Trinity-RFT, please refer to the following tutorials:
+ [A quick example with GSM8k](./docs/sphinx_doc/source/tutorial/example_reasoning_basic.md);
+ [Off-policy / asynchronous modes of RFT](./docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md);
+ [Off-policy mode of RFT](./docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md);
+ [Asynchronous mode of RFT](./docs/sphinx_doc/source/tutorial/example_async_mode.md);
+ [Multi-turn tasks](./docs/sphinx_doc/source/tutorial/example_multi_turn.md);
+ [Data processing pipelines](./docs/sphinx_doc/source/tutorial/example_data_functionalities.md);
+ [Offline learning by DPO](./docs/sphinx_doc/source/tutorial/example_dpo.md).
Binary file added docs/sphinx_doc/assets/async-curve.png
3 changes: 2 additions & 1 deletion docs/sphinx_doc/source/main.md
@@ -240,7 +240,8 @@ More example config files can be found in `examples`.

For more detailed examples of how to use Trinity-RFT, please refer to the following documents:
+ [A quick example with GSM8k](tutorial/example_reasoning_basic.md);
+ [Off-policy / asynchronous modes of RFT](tutorial/example_reasoning_advanced.md);
+ [Off-policy mode of RFT](tutorial/example_reasoning_advanced.md);
+ [Asynchronous mode of RFT](tutorial/example_async_mode.md);
+ [Multi-turn tasks](tutorial/example_multi_turn.md);
+ [Data processing pipelines](tutorial/example_data_functionalities.md);
+ [Offline learning by DPO](tutorial/example_dpo.md).
41 changes: 41 additions & 0 deletions docs/sphinx_doc/source/tutorial/example_async_mode.md
@@ -0,0 +1,41 @@
# A quick example for asynchronous mode

This example shows how to run RFT in asynchronous mode, using the GRPO algorithm, the Qwen2.5-1.5B-Instruct model, and the GSM8K dataset.

Trinity-RFT supports an asynchronous mode by running the trainer and explorer in separate processes.
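In this mode, the explorer and trainer communicate through a shared experience buffer (the SQLite-backed queue configured below). As a rough illustration only, and not Trinity-RFT's actual internals, the producer/consumer pattern can be sketched with the standard library, with hypothetical table and column names:

```python
import sqlite3

# Toy sketch of the decoupled explorer/trainer pattern (names are
# hypothetical, not Trinity-RFT internals): the explorer appends rollouts
# to a shared SQLite-backed queue; the trainer pops them in FIFO order.
conn = sqlite3.connect(":memory:")  # stands in for sqlite:///gsm8k.db
conn.execute("CREATE TABLE gsm8k_buffer (id INTEGER PRIMARY KEY, rollout TEXT)")

def explorer_put(rollout):
    """Explorer side: append one rollout to the queue."""
    conn.execute("INSERT INTO gsm8k_buffer (rollout) VALUES (?)", (rollout,))
    conn.commit()

def trainer_get():
    """Trainer side: pop the oldest rollout, or None if the queue is empty."""
    row = conn.execute(
        "SELECT id, rollout FROM gsm8k_buffer ORDER BY id LIMIT 1").fetchone()
    if row is None:
        return None
    conn.execute("DELETE FROM gsm8k_buffer WHERE id = ?", (row[0],))
    conn.commit()
    return row[1]

explorer_put("question -> answer A")
explorer_put("question -> answer B")
print(trainer_get())  # question -> answer A
```

Because the queue decouples the two sides, the explorer can keep generating experience while the trainer consumes it at its own pace.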

For this purpose, we prepare two main config files: `trainer.yaml` and `explorer.yaml`.
The main difference between them is the `mode` parameter: `trainer.yaml` sets `mode: train`, while `explorer.yaml` sets `mode: explore`.
In addition, the following parameters must be set to the same values in both files.
The model weights of the explorer and trainer are synchronized once every `sync_iteration_interval * batch_size` tasks.

```yaml
data:
  batch_size: <batch_size>

# The same checkpoint path
model:
  checkpoint_path: /PATH/TO/CHECKPOINT

# The same database path
buffer:
  train_dataset:
    name: gsm8k_buffer
    storage_type: queue
    path: 'sqlite:///gsm8k.db'

synchronizer:
  sync_method: 'checkpoint'
  sync_iteration_interval: <sync_iteration_interval>
```
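Given the sentence above, the synchronization cadence is easy to check with a back-of-the-envelope calculation, here using the values from this example (`batch_size: 96`, `sync_iteration_interval: 10`):

```python
# Back-of-the-envelope check of the synchronization cadence: the model
# weights are synced once every sync_iteration_interval * batch_size tasks.
batch_size = 96                # data.batch_size in both config files
sync_iteration_interval = 10   # synchronizer.sync_iteration_interval

tasks_per_sync = sync_iteration_interval * batch_size
print(tasks_per_sync)  # 960 tasks between consecutive weight syncs
```

A larger interval lets the explorer run further ahead of the trainer at the cost of staler rollout policies.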

You may run this example with the following command:

```bash
bash examples/async_gsm8k/run.sh
```

The following plot shows the learning curve of GRPO in asynchronous mode.
> This result should be regarded merely as a baseline, since GRPO is designed as an on-policy algorithm.
> We are continuously investigating other RL algorithms (e.g., [OPMD](./example_reasoning_advanced.md)) in asynchronous mode.

![async](../../assets/async-curve.png)
16 changes: 1 addition & 15 deletions docs/sphinx_doc/source/tutorial/example_reasoning_advanced.md
@@ -1,4 +1,4 @@
# Example: off-policy / asynchronous RFT mode
# Example: off-policy RFT mode


Let's continue with the [previous GSM8k example](./example_reasoning_basic.md) and show an advanced feature provided by Trinity-RFT, namely, the off-policy RFT mode.
@@ -35,17 +35,3 @@ A similar performance boost is shown at step 21, which leads to a converged scor


![opmd](../../assets/opmd-curve.png)





## Asynchronous mode


Trinity-RFT supports the asynchronous and decoupled mode of RFT, where explorer and trainer act independently and asynchronously.
To run this mode, the explorer and trainer need to be launched separately, with the `mode` parameter in the config file set to `explore` and `train` respectively.



*We are still testing this mode more thoroughly. A concrete example is coming soon!*
13 changes: 13 additions & 0 deletions examples/async_gsm8k/README.md
@@ -0,0 +1,13 @@
# Asynchronous mode on GSM8K dataset

This example shows the usage of GRPO on the GSM8K dataset in asynchronous mode.

For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_async_mode.md).

The config files are [`trainer.yaml`](trainer.yaml), [`explorer.yaml`](explorer.yaml), and [`verl_config.yaml`](verl_config.yaml).

You can run this example with the following command:

```bash
bash examples/async_gsm8k/run.sh
```
58 changes: 58 additions & 0 deletions examples/async_gsm8k/explorer.yaml
@@ -0,0 +1,58 @@
```yaml
mode: explore
data:
  # basic info
  dataset_path: /PATH/TO/DATASET/
  subset_name: ''
  train_split: 'train'
  eval_split: 'test'
  format_config:
    prompt_key: 'question'
    response_key: 'answer'
  # downstream loading related
  total_epochs: 20
  batch_size: 96
  default_workflow_type: 'math_workflow'
model:
  model_path: /PATH/TO/MODEL/
  max_prompt_tokens: 256
  max_response_tokens: 1024
  checkpoint_path: 'checkpoints/qwen2.5-1.5B-gsm8k'
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  max_retry_times: 3
  max_retry_interval: 1
  train_dataset:
    name: gsm8k_buffer
    storage_type: queue
    path: 'sqlite:///gsm8k.db'
explorer:
  engine_type: vllm_async
  engine_num: 2
  runner_num: 32
  tensor_parallel_size: 1
  enable_prefix_caching: false
  enforce_eager: true
  dtype: bfloat16
  temperature: 1.0
  seed: 42
  logprobs: 0
  repeat_times: 8
  use_ray: false
  backend: 'nccl'
  max_pending_requests: 32
  max_waiting_steps: 4
synchronizer:
  sync_method: 'checkpoint'
  sync_iteration_interval: 10
trainer:
  trainer_type: 'verl'
  algorithm_type: ppo
  trainer_config_path: examples/async_gsm8k/verl_config.yaml
  sft_warmup_iteration: 0  # set to a positive integer to enable SFT warmup
  eval_interval: 10
monitor:
  cache_root_dir: ""
  project: "Trinity-RFT-gsm8k"
  name: "async-qwen2.5-1.5B-gsm8k"
```
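A few numbers are implied by the explorer config above. As a rough estimate, assuming `repeat_times` is the number of responses sampled per task (i.e., the GRPO group size):

```python
# Rough throughput arithmetic for the explorer config above (assumption:
# repeat_times = responses sampled per task, i.e. the GRPO group size).
batch_size = 96      # tasks per exploration step (data.batch_size)
repeat_times = 8     # responses per task (explorer.repeat_times)
engine_num = 2       # vLLM engines (explorer.engine_num)

responses_per_step = batch_size * repeat_times
print(responses_per_step)                # 768 rollouts per step
print(responses_per_step // engine_num)  # 384 rollouts per engine, on average
```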
4 changes: 4 additions & 0 deletions examples/async_gsm8k/run.sh
@@ -0,0 +1,4 @@
```bash
#!/bin/bash
# Launch the explorer first, give it time to start filling the buffer,
# then launch the trainer; both run in the background.
trinity run --config examples/async_gsm8k/explorer.yaml 2>&1 | tee explorer.log &
sleep 30
trinity run --config examples/async_gsm8k/trainer.yaml 2>&1 | tee trainer.log &
```
58 changes: 58 additions & 0 deletions examples/async_gsm8k/trainer.yaml
@@ -0,0 +1,58 @@
```yaml
mode: train
data:
  # basic info
  dataset_path: /PATH/TO/DATASET/
  subset_name: ''
  train_split: 'train'
  eval_split: 'test'
  format_config:
    prompt_key: 'question'
    response_key: 'answer'
  # downstream loading related
  total_epochs: 20
  batch_size: 96
  default_workflow_type: 'math_workflow'
model:
  model_path: /PATH/TO/MODEL/
  max_prompt_tokens: 256
  max_response_tokens: 1024
  checkpoint_path: ""
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  max_retry_times: 3
  max_retry_interval: 1
  train_dataset:
    name: gsm8k_buffer
    storage_type: queue
    path: 'sqlite:///gsm8k.db'
explorer:
  engine_type: vllm_async
  engine_num: 2
  runner_num: 32
  tensor_parallel_size: 1
  enable_prefix_caching: false
  enforce_eager: true
  dtype: bfloat16
  temperature: 1.0
  seed: 42
  logprobs: 0
  repeat_times: 8
  use_ray: false
  backend: 'nccl'
  max_pending_requests: 32
  max_waiting_steps: 4
synchronizer:
  sync_method: 'checkpoint'
  sync_iteration_interval: 10
trainer:
  trainer_type: 'verl'
  algorithm_type: ppo
  trainer_config_path: examples/async_gsm8k/verl_config.yaml
  sft_warmup_iteration: 0  # set to a positive integer to enable SFT warmup
  eval_interval: 10
monitor:
  cache_root_dir: ""
  project: "Trinity-RFT-gsm8k"
  name: "async-qwen2.5-1.5B-gsm8k"
```
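The trainer and explorer configs are intentionally identical except for `mode` and `model.checkpoint_path`. A small sketch of how one might verify that invariant, with plain dicts standing in for the parsed YAML (a real check would load the actual files):

```python
def dict_diff(a, b, prefix=""):
    """Return the dotted key paths at which two nested dicts differ."""
    paths = []
    for key in sorted(set(a) | set(b)):
        path = prefix + key
        va, vb = a.get(key), b.get(key)
        if isinstance(va, dict) and isinstance(vb, dict):
            paths.extend(dict_diff(va, vb, path + "."))
        elif va != vb:
            paths.append(path)
    return paths

# Minimal stand-ins for the parsed trainer.yaml / explorer.yaml above.
trainer = {"mode": "train",
           "model": {"checkpoint_path": ""},
           "synchronizer": {"sync_iteration_interval": 10}}
explorer = {"mode": "explore",
            "model": {"checkpoint_path": "checkpoints/qwen2.5-1.5B-gsm8k"},
            "synchronizer": {"sync_iteration_interval": 10}}

print(dict_diff(trainer, explorer))  # ['mode', 'model.checkpoint_path']
```

Any extra path in the diff would indicate a mismatch (e.g., in the buffer path or synchronizer settings) that could silently desynchronize the two processes.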