Skip to content

Commit

Permalink
[chatgpt] Support saving ckpt in examples (#2846)
Browse files Browse the repository at this point in the history
* [chatgpt] fix train_rm bug with lora

* [chatgpt] support colossalai strategy to train rm

* fix pre-commit

* fix pre-commit 2

* [chatgpt] fix rm eval typo

* fix rm eval

* fix pre-commit

* add support of saving ckpt in examples

* fix single-gpu save
  • Loading branch information
ht-zhou committed Feb 22, 2023
1 parent 5979143 commit 34ca324
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
7 changes: 7 additions & 0 deletions applications/ChatGPT/examples/train_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ def main(args):
max_timesteps=args.max_timesteps,
update_timesteps=args.update_timesteps)

# save model checkpoint after fitting on only rank0
strategy.save_model(actor, 'actor_checkpoint_dummy.pt', only_rank0=True)
# save optimizer checkpoint on all ranks
strategy.save_optimizer(actor_optim,
'actor_optim_checkpoint_dummy_%d.pt' % (torch.cuda.current_device()),
only_rank0=False)


if __name__ == '__main__':
parser = argparse.ArgumentParser()
Expand Down
7 changes: 7 additions & 0 deletions applications/ChatGPT/examples/train_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from copy import deepcopy

import pandas as pd
import torch
from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
from chatgpt.trainer import PPOTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
Expand Down Expand Up @@ -95,6 +96,12 @@ def tokenize_fn(texts):
num_episodes=args.num_episodes,
max_timesteps=args.max_timesteps,
update_timesteps=args.update_timesteps)
# save model checkpoint after fitting on only rank0
strategy.save_model(actor, 'actor_checkpoint_prompts.pt', only_rank0=True)
# save optimizer checkpoint on all ranks
strategy.save_optimizer(actor_optim,
'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
only_rank0=False)


if __name__ == '__main__':
Expand Down

0 comments on commit 34ca324

Please sign in to comment.