[chatgpt] Support saving ckpt in examples #2846

Merged (10 commits) on Feb 22, 2023. The diff below shows changes from all commits.
3 changes: 2 additions & 1 deletion applications/ChatGPT/chatgpt/dataset/__init__.py
@@ -1,3 +1,4 @@
 from .reward_dataset import RewardDataset
+from .utils import is_rank_0
 
-__all__ = ['RewardDataset']
+__all__ = ['RewardDataset', 'is_rank_0']
4 changes: 3 additions & 1 deletion applications/ChatGPT/chatgpt/dataset/reward_dataset.py
@@ -3,6 +3,8 @@
 from torch.utils.data import Dataset
 from tqdm import tqdm
 
+from .utils import is_rank_0
+
 
 class RewardDataset(Dataset):
     """
@@ -18,7 +20,7 @@ def __init__(self, dataset, tokenizer: Callable, max_length: int) -> None:
         super().__init__()
         self.chosen = []
         self.reject = []
-        for data in tqdm(dataset):
+        for data in tqdm(dataset, disable=not is_rank_0()):
             prompt = data['prompt']
 
             chosen = prompt + data['chosen'] + "<|endoftext|>"
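For readers unfamiliar with the dataset: each record supplies a prompt plus a preferred and a non-preferred completion, and RewardDataset turns them into a (chosen, rejected) pair of texts terminated with the EOS token. A rough illustration of that pairing follows (field names mimic the Dahoas/rm-static layout used by the examples; the rejected-side key is an assumption, since that part of the file is collapsed above):

# Illustration only, not library code: how one preference record becomes a pair.
data = {
    'prompt': "Human: How do I boil an egg?\n\nAssistant:",
    'chosen': " Cover the egg with cold water, bring to a boil, then let it rest for 10 minutes.",
    'rejected': " I can't help with that.",   # assumed key name for the non-preferred answer
}
chosen_text = data['prompt'] + data['chosen'] + "<|endoftext|>"
reject_text = data['prompt'] + data['rejected'] + "<|endoftext|>"
# Both strings are then tokenized to a fixed max_length and stored in
# self.chosen / self.reject for the pairwise reward-model loss.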
5 changes: 5 additions & 0 deletions applications/ChatGPT/chatgpt/dataset/utils.py
@@ -0,0 +1,5 @@
+import torch.distributed as dist
+
+
+def is_rank_0() -> bool:
+    return not dist.is_initialized() or dist.get_rank() == 0
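This helper exists so that per-process output (progress bars, logging) is emitted once under distributed launchers rather than once per GPU. A small usage sketch, assuming the package is importable as chatgpt and that torch.distributed is initialized by the launcher:

# Sketch: gate console output on rank 0 during distributed training.
from chatgpt.dataset import is_rank_0   # re-exported by this PR
from tqdm import tqdm

def print_rank_0(msg: str) -> None:
    # print once for the whole job instead of once per process
    if is_rank_0():
        print(msg)

for _ in tqdm(range(100), disable=not is_rank_0()):
    pass   # only rank 0 renders the bar, mirroring the RewardDataset change

print_rank_0("dataset preprocessing done")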
4 changes: 2 additions & 2 deletions applications/ChatGPT/chatgpt/nn/reward_model.py
@@ -23,7 +23,7 @@ def __init__(self,
                  lora_rank: int = 0,
                  lora_train_bias: str = 'none') -> None:
         super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
-        self.model = model
+        self.body = model
         if value_head is not None:
             if value_head.out_features != 1:
                 raise ValueError("The value head of reward model's output dim should be 1!")
@@ -34,7 +34,7 @@ def __init__(self,
             self.convert_to_lora()
 
     def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        outputs = self.model(sequences, attention_mask=attention_mask)
+        outputs = self.body(sequences, attention_mask=attention_mask)
         last_hidden_states = outputs['last_hidden_state']
         values = self.value_head(last_hidden_states)[:, :-1]
         value = values.mean(dim=1).squeeze(1)    # ensure shape is (B)
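The forward pass above maps a token sequence to a single scalar reward: the backbone's last hidden states pass through a one-unit value head and are averaged over positions. A self-contained sketch of the same shape logic (the tiny backbone and class names are illustrative, not the library's classes):

import torch
import torch.nn as nn

class TinyBackbone(nn.Module):
    # stand-in for a transformer that returns {'last_hidden_state': (B, T, H)}
    def __init__(self, vocab_size: int = 100, hidden: int = 16) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)

    def forward(self, sequences, attention_mask=None):
        return {'last_hidden_state': self.embed(sequences)}

class TinyRewardModel(nn.Module):
    def __init__(self, hidden: int = 16) -> None:
        super().__init__()
        self.body = TinyBackbone(hidden=hidden)
        self.value_head = nn.Linear(hidden, 1)

    def forward(self, sequences, attention_mask=None):
        last_hidden_states = self.body(sequences, attention_mask)['last_hidden_state']
        values = self.value_head(last_hidden_states)[:, :-1]   # per-token values, last position dropped
        return values.mean(dim=1).squeeze(1)                   # one scalar reward per sequence, shape (B,)

rewards = TinyRewardModel()(torch.randint(0, 100, (2, 8)))     # -> tensor of shape (2,)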
29 changes: 18 additions & 11 deletions applications/ChatGPT/chatgpt/trainer/rm.py
@@ -1,6 +1,7 @@
 from abc import ABC
 
 import loralib as lora
+import torch
 from chatgpt.dataset import RewardDataset
 from chatgpt.nn import PairWiseLoss
 from torch.optim import Adam, Optimizer
@@ -55,7 +56,8 @@ def fit(self, use_lora):
             # train
             if use_lora > 0:
                 print("Using Lora")
-                lora.mark_only_lora_as_trainable(self.model.model)
+                lora.mark_only_lora_as_trainable(self.model.body)
+
             else:
                 self.model.train()
             for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
@@ -74,16 +76,21 @@
 
             # eval
             self.model.eval()
-            for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
+            with torch.no_grad():
                 dist = 0
-                chosen_ids = chosen_ids.squeeze(1).cuda()
-                c_mask = c_mask.squeeze(1).cuda()
-                reject_ids = reject_ids.squeeze(1).cuda()
-                r_mask = r_mask.squeeze(1).cuda()
-                chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
-                reject_reward = self.model(reject_ids, attention_mask=r_mask)
-                dist += (chosen_reward - reject_reward)
-            dist_mean = dist / self.eval_dataloader.__len__()
+                loss_sum = 0
+                for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
+                    chosen_ids = chosen_ids.squeeze(1).cuda()
+                    c_mask = c_mask.squeeze(1).cuda()
+                    reject_ids = reject_ids.squeeze(1).cuda()
+                    r_mask = r_mask.squeeze(1).cuda()
+                    chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
+                    reject_reward = self.model(reject_ids, attention_mask=r_mask)
+                    dist += (chosen_reward - reject_reward).mean().item()
+                    loss = self.loss_fn(chosen_reward, reject_reward)
+                    loss_sum += loss.item()
+                dist_mean = dist / self.eval_dataloader.__len__()
+                loss_mean = loss_sum / self.eval_dataloader.__len__()
             epoch_bar.update()
-            step_bar.set_postfix({'loss': loss.item(), 'dist_mean': dist_mean.item()})
+            step_bar.set_postfix({'loss': loss_mean, 'dist_mean': dist_mean})
             step_bar.close()
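For context on the two metrics the rewritten eval loop reports: loss is the pairwise ranking objective averaged over batches, and dist is the mean reward margin between chosen and rejected responses. PairWiseLoss is expected to compute the standard -log sigmoid(r_chosen - r_reject) objective; the snippet below is a hedged stand-in for illustration, not the library source:

import torch
import torch.nn.functional as F

def pairwise_loss(chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
    # Bradley-Terry style ranking loss: -log sigmoid(r_chosen - r_reject), batch-averaged
    return -F.logsigmoid(chosen_reward - reject_reward).mean()

# toy aggregation mirroring the eval loop above (two fake "batches" of rewards)
batches = [(torch.tensor([1.2, 0.3]), torch.tensor([0.1, 0.5])),
           (torch.tensor([0.9, 0.7]), torch.tensor([0.2, -0.1]))]
dist, loss_sum = 0.0, 0.0
for chosen_reward, reject_reward in batches:
    dist += (chosen_reward - reject_reward).mean().item()            # reward margin for this batch
    loss_sum += pairwise_loss(chosen_reward, reject_reward).item()
dist_mean, loss_mean = dist / len(batches), loss_sum / len(batches)
print({'loss': loss_mean, 'dist_mean': dist_mean})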
7 changes: 7 additions & 0 deletions applications/ChatGPT/examples/train_dummy.py
@@ -97,6 +97,13 @@ def main(args):
                 max_timesteps=args.max_timesteps,
                 update_timesteps=args.update_timesteps)
 
+    # save model checkpoint after fitting on only rank0
+    strategy.save_model(actor, 'actor_checkpoint_dummy.pt', only_rank0=True)
+    # save optimizer checkpoint on all ranks
+    strategy.save_optimizer(actor_optim,
+                            'actor_optim_checkpoint_dummy_%d.pt' % (torch.cuda.current_device()),
+                            only_rank0=False)
+
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
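A note on what these two calls produce: one model state dict written by rank 0 only, and one optimizer state dict per rank (the %d suffix is the local CUDA device index). Under the naive strategy these are presumably plain torch.save checkpoints, so restoring them in a later run would look roughly like the sketch below; this is an assumption about the file format, and actor / actor_optim refer to the objects built earlier in train_dummy.py:

import torch

# hypothetical restore of the checkpoints written above (single-node, naive/DDP case)
actor_state = torch.load('actor_checkpoint_dummy.pt', map_location='cpu')
actor.load_state_dict(actor_state)                       # actor from train_dummy.py

rank = torch.cuda.current_device()
optim_state = torch.load('actor_optim_checkpoint_dummy_%d.pt' % rank, map_location='cpu')
actor_optim.load_state_dict(optim_state)                 # actor_optim from train_dummy.py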
7 changes: 7 additions & 0 deletions applications/ChatGPT/examples/train_prompts.py
@@ -2,6 +2,7 @@
 from copy import deepcopy
 
 import pandas as pd
+import torch
 from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
@@ -95,6 +96,12 @@ def tokenize_fn(texts):
                 num_episodes=args.num_episodes,
                 max_timesteps=args.max_timesteps,
                 update_timesteps=args.update_timesteps)
+    # save model checkpoint after fitting on only rank0
+    strategy.save_model(actor, 'actor_checkpoint_prompts.pt', only_rank0=True)
+    # save optimizer checkpoint on all ranks
+    strategy.save_optimizer(actor_optim,
+                            'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
+                            only_rank0=False)
 
 
 if __name__ == '__main__':
7 changes: 4 additions & 3 deletions applications/ChatGPT/examples/train_reward_model.py
@@ -29,7 +29,8 @@ def train(args):
     # configure model
     tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
     tokenizer.pad_token = tokenizer.eos_token
-    model = BLOOMRM(pretrained=args.pretrain).cuda()
+    with strategy.model_init_context():
+        model = BLOOMRM(pretrained=args.pretrain).cuda()
     max_len = 1024
 
     # configure optimizer
@@ -71,8 +72,8 @@
     parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
    parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')
-    parser.add_argument('--max_epochs', type=int, default=2)
-    parser.add_argument('--batch_size', type=int, default=1)
+    parser.add_argument('--max_epochs', type=int, default=10)
+    parser.add_argument('--batch_size', type=int, default=4)
     parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
     args = parser.parse_args()
     train(args)
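On the model_init_context() change in this file: wrapping model construction in the strategy's init context lets strategies that shard or place parameters take control of initialization, while a plain single-GPU strategy can simply return a no-op context. The pattern, sketched with made-up names under that assumption (this is not ColossalAI's Strategy API):

from contextlib import nullcontext
import torch.nn as nn

class SketchStrategy:
    # illustrative stand-in for a training strategy object
    def model_init_context(self):
        return nullcontext()    # naive case: construct the model normally

strategy = SketchStrategy()
with strategy.model_init_context():
    model = nn.Linear(8, 1)     # the real script builds BLOOMRM(pretrained=...) here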