Skip to content
8 changes: 5 additions & 3 deletions scripts/rft/rft.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ def do_sample(model: str, model_type: str, dataset: List[str], iter: int):
f'--sampler_engine lmdeploy '
f'--max_new_tokens 768 '
f'--override_exist_file true '
f'--num_sampling_per_gpu_batch_size 2 '
f'--num_sampling_per_gpu_batch_size 1 '
f'--num_return_sequences 64 '
f'--cache_files sample_output/iter_{iter}_proc_{device}_cache.jsonl '
f'--output_file iter_{iter}_proc_{device}_cache.jsonl '
f'--top_p 1.0 '
f'--temperature 1.0 ')
print(f'Sampling caches of iter {iter}, part {device}.', flush=True)
env = os.environ.copy()
Expand Down Expand Up @@ -60,10 +61,11 @@ def do_sample(model: str, model_type: str, dataset: List[str], iter: int):
f'--load_args false '
f'--sampler_engine no '
f'--orm_model math '
f'--prm_model AI-ModelScope/GRM-llama3.2-3B-rewardmodel-ft '
f'--prm_model Qwen/Qwen2.5-Math-PRM-7B '
f'--prm_threshold 0.7 '
f'--max_new_tokens 768 '
f'--override_exist_file true '
f'--num_sampling_per_gpu_batch_size 2 '
f'--num_sampling_per_gpu_batch_size 1 '
f'--num_return_sequences 64 '
f'--output_file iter_{iter}_proc_{device}_sampling.jsonl '
f'--cache_files sample_output/iter_{iter}_proc_{device}_cache.jsonl ')
Expand Down
Loading