[chatgpt]Reward Model Training Process update #3133
Merged: 18 commits, Mar 20, 2023
4 changes: 2 additions & 2 deletions applications/ChatGPT/chatgpt/dataset/__init__.py
@@ -1,4 +1,4 @@
from .reward_dataset import RewardDataset
from .reward_dataset import RmStaticDataset, HhRlhfDataset
from .utils import is_rank_0

__all__ = ['RewardDataset', 'is_rank_0']
__all__ = ['RmStaticDataset', 'HhRlhfDataset','is_rank_0']
65 changes: 60 additions & 5 deletions applications/ChatGPT/chatgpt/dataset/reward_dataset.py
@@ -5,25 +5,80 @@

from .utils import is_rank_0


class RewardDataset(Dataset):
# Dahaos/rm-static
class RmStaticDataset(Dataset):
"""
Dataset for reward model

Args:
dataset: dataset for reward model
tokenizer: tokenizer for reward model
max_length: max length of input
special_token: special token at the end of sentence
"""

def __init__(self, dataset, tokenizer: Callable, max_length: int) -> None:
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
super().__init__()
self.chosen = []
self.reject = []
if special_token is None:
self.end_token = tokenizer.eos_token
else:
self.end_token = special_token
for data in tqdm(dataset, disable=not is_rank_0()):
prompt = data['prompt']

chosen = prompt + data['chosen'] + "<|endoftext|>"
chosen = prompt + data['chosen'] + self.end_token
chosen_token = tokenizer(chosen,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.chosen.append({
"input_ids": chosen_token['input_ids'],
"attention_mask": chosen_token['attention_mask']
})

reject = prompt + data['rejected'] + self.end_token
reject_token = tokenizer(reject,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.reject.append({
"input_ids": reject_token['input_ids'],
"attention_mask": reject_token['attention_mask']
})

def __len__(self):
length = len(self.chosen)
return length

def __getitem__(self, idx):
return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
"input_ids"], self.reject[idx]["attention_mask"]

# Anthropic/hh-rlhf
class HhRlhfDataset(Dataset):
"""
Dataset for reward model

Args:
dataset: dataset for reward model
tokenizer: tokenizer for reward model
max_length: max length of input
special_token: special token at the end of sentence
"""
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
super().__init__()
self.chosen = []
self.reject = []
if special_token is None:
self.end_token = tokenizer.eos_token
else:
self.end_token = special_token
for data in tqdm(dataset, disable=not is_rank_0()):
chosen = data['chosen'] + self.end_token
chosen_token = tokenizer(chosen,
max_length=max_length,
padding="max_length",
@@ -34,7 +89,7 @@ def __init__(self, dataset, tokenizer: Callable, max_length: int) -> None:
"attention_mask": chosen_token['attention_mask']
})

reject = prompt + data['rejected'] + "<|endoftext|>"
reject = data['rejected'] + self.end_token
reject_token = tokenizer(reject,
max_length=max_length,
padding="max_length",
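For context, a minimal sketch of how the two new dataset classes might be constructed, assuming the Hugging Face `datasets` and `transformers` packages and the column names used in the code above ('prompt', 'chosen', 'rejected' for rm-static; 'chosen', 'rejected' for hh-rlhf):

```python
# Minimal sketch (assumptions: Hugging Face datasets/transformers are installed
# and the column names match those used in the diff above).
from datasets import load_dataset
from transformers import AutoTokenizer

from chatgpt.dataset import RmStaticDataset, HhRlhfDataset

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

# Dahoas/rm-static: rows carry 'prompt', 'chosen' and 'rejected' columns.
rm_static = load_dataset("Dahoas/rm-static")
train_ds = RmStaticDataset(rm_static["train"], tokenizer, max_length=512)

# Anthropic/hh-rlhf: rows carry full 'chosen' and 'rejected' conversations;
# the subset name mirrors the --subset flag used in test_ci.sh below.
hh_rlhf = load_dataset("Anthropic/hh-rlhf", data_dir="harmless-base")
eval_ds = HhRlhfDataset(hh_rlhf["test"], tokenizer, max_length=512)

# Each item is (chosen_ids, chosen_mask, reject_ids, reject_mask),
# all padded/truncated to max_length and ending with tokenizer.eos_token
# unless a custom special_token is passed.
chosen_ids, c_mask, reject_ids, r_mask = train_ds[0]
```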
4 changes: 2 additions & 2 deletions applications/ChatGPT/chatgpt/models/__init__.py
@@ -1,4 +1,4 @@
from .base import Actor, Critic, RewardModel
from .loss import PairWiseLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss
from .loss import PolicyLoss, PPOPtxActorLoss, ValueLoss, LogSigLoss, LogExpLoss

__all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'PairWiseLoss']
__all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss']
1 change: 1 addition & 0 deletions applications/ChatGPT/chatgpt/models/bloom/bloom_rm.py
@@ -33,4 +33,5 @@ def __init__(self,
if checkpoint:
model.gradient_checkpointing_enable()
value_head = nn.Linear(model.config.hidden_size, 1)
value_head.weight.data.normal_(mean=0.0, std=1/(model.config.hidden_size + 1))
super().__init__(model, value_head, lora_rank, lora_train_bias)
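The same small-standard-deviation initialization is applied to the GPT and OPT reward heads in the changes below. A standalone sketch of the pattern (the hidden size here is only an illustrative value):

```python
# Sketch of the value-head initialization used by the reward models.
# A std of 1/(d_model + 1) keeps the initial reward scores close to zero;
# d_model below is illustrative, not taken from a specific config.
import torch.nn as nn

d_model = 1024  # e.g. model.config.hidden_size for the BLOOM backbone
value_head = nn.Linear(d_model, 1)
value_head.weight.data.normal_(mean=0.0, std=1 / (d_model + 1))

# The reward model applies this head to the backbone's hidden states,
# producing a single scalar reward per (prompt, response) sequence.
```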
1 change: 1 addition & 0 deletions applications/ChatGPT/chatgpt/models/gpt/gpt_rm.py
@@ -35,4 +35,5 @@ def __init__(self,
model.gradient_checkpointing_enable()

value_head = nn.Linear(model.config.n_embd, 1)
value_head.weight.data.normal_(mean=0.0, std=1/(model.config.n_embd + 1))
super().__init__(model, value_head, lora_rank, lora_train_bias)
14 changes: 12 additions & 2 deletions applications/ChatGPT/chatgpt/models/loss.py
@@ -93,13 +93,23 @@ def forward(self,
return policy_loss + self.pretrain_coef * lm_loss


class PairWiseLoss(nn.Module):
class LogSigLoss(nn.Module):
"""
Pairwise Loss for Reward Model
Details: https://arxiv.org/abs/2203.02155
"""

def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
probs = torch.sigmoid(chosen_reward - reject_reward)
log_probs = torch.log(probs)
loss = -log_probs.mean()
return loss


class LogExpLoss(nn.Module):
"""
Pairwise Loss for Reward Model
Details: https://arxiv.org/abs/2204.05862
"""
def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
return loss
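The two losses above are the same pairwise objective written in two forms, since -log σ(x) = log(1 + exp(-x)) with x = chosen_reward - reject_reward. A quick sketch checking this on toy reward tensors (assuming the `chatgpt` package from this repo is importable):

```python
# Sketch: the two pairwise losses agree numerically, because
# -log(sigmoid(x)) == log(1 + exp(-x)) with x = chosen_reward - reject_reward.
import torch

from chatgpt.models.loss import LogSigLoss, LogExpLoss

chosen_reward = torch.tensor([1.2, 0.3, -0.5])
reject_reward = torch.tensor([0.4, 0.9, -1.0])

log_sig = LogSigLoss()(chosen_reward, reject_reward)
log_exp = LogExpLoss()(chosen_reward, reject_reward)

print(log_sig, log_exp)              # identical up to floating-point error
assert torch.allclose(log_sig, log_exp)
```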
1 change: 1 addition & 0 deletions applications/ChatGPT/chatgpt/models/opt/opt_rm.py
@@ -34,4 +34,5 @@ def __init__(self,
model.gradient_checkpointing_enable()

value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
value_head.weight.data.normal_(mean=0.0, std=1/(model.config.word_embed_proj_dim + 1))
super().__init__(model, value_head, lora_rank, lora_train_bias)
111 changes: 69 additions & 42 deletions applications/ChatGPT/chatgpt/trainer/rm.py
@@ -1,13 +1,12 @@
from abc import ABC

import pandas as pd
import loralib as lora
import torch
from chatgpt.dataset import RewardDataset
from chatgpt.models.loss import PairWiseLoss
from torch.optim import Adam, Optimizer
from torch.utils.data import DataLoader
from datetime import datetime
from torch.optim import Optimizer, lr_scheduler
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from .strategies import Strategy
from .utils import is_rank_0

@@ -20,74 +19,102 @@ class RewardModelTrainer(ABC):
model (torch.nn.Module): the model to train
strategy (Strategy): the strategy to use for training
optim(Optimizer): the optimizer to use for training
train_dataset (RewardDataset): the dataset to use for training
eval_dataset (RewardDataset): the dataset to use for evaluation
loss_fn (callable): the loss function to use for training
train_dataset (Dataset): the dataset to use for training
valid_dataset (Dataset): the dataset to use for validation
eval_dataset (Dataset): the dataset to use for evaluation
batch_size (int, defaults to 1): the batch size while training
max_epochs (int, defaults to 2): the number of epochs to train
optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
"""

def __init__(
self,
model,
strategy: Strategy,
optim: Optimizer,
train_dataset: RewardDataset,
eval_dataset: RewardDataset,
loss_fn,
train_dataset: Dataset,
valid_dataset: Dataset,
eval_dataset: Dataset,
batch_size: int = 1,
max_epochs: int = 2,
max_epochs: int = 1,
) -> None:
super().__init__()
self.strategy = strategy
self.epochs = max_epochs
self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size)

self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
self.valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True)

self.model = strategy.setup_model(model)
if "DDP" in str(self.strategy):
self.model = self.model.module
self.loss_fn = PairWiseLoss()
self.loss_fn = loss_fn
self.optimizer = strategy.setup_optimizer(optim, self.model)
self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, self.train_dataloader.__len__()//100)


def fit(self, use_lora):
def eval_acc(self, dataloader):
dist = 0
on = 0
cnt = 0
self.model.eval()
with torch.no_grad():
for chosen_ids, c_mask, reject_ids, r_mask in dataloader:
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask)
for i in range(len(chosen_reward)):
cnt += 1
if chosen_reward[i] > reject_reward[i]:
on += 1
dist += (chosen_reward - reject_reward).mean().item()
dist_mean = dist / len(dataloader)
acc = on / cnt
self.model.train()
return dist_mean, acc


def fit(self):
time = datetime.now()
epoch_bar = tqdm(range(self.epochs), desc='Train epoch', disable=not is_rank_0())
for epoch in range(self.epochs):
step_bar = tqdm(range(self.train_dataloader.__len__()),
desc='Train step of epoch %d' % epoch,
disable=not is_rank_0())
# train
self.model.train()
cnt = 0
acc = 0
dist = 0
for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
chosen_ids = chosen_ids.squeeze(1).cuda()
c_mask = c_mask.squeeze(1).cuda()
reject_ids = reject_ids.squeeze(1).cuda()
r_mask = r_mask.squeeze(1).cuda()
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask)
loss = self.loss_fn(chosen_reward, reject_reward)
self.strategy.backward(loss, self.model, self.optimizer)
self.strategy.optimizer_step(self.optimizer)
self.optimizer.zero_grad()
cnt += 1
if cnt == 100:
self.scheduler.step()
dist, acc = self.eval_acc(self.valid_dataloader)
cnt = 0
if is_rank_0():
log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], columns=['step', 'loss', 'dist', 'acc'])
log.to_csv('log_%s.csv' % time, mode='a', header=False, index=False)
step_bar.update()
step_bar.set_postfix({'loss': loss.item()})

step_bar.set_postfix({'dist': dist, 'acc': acc})
# eval
self.model.eval()
with torch.no_grad():
dist = 0
loss_sum = 0
for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
chosen_ids = chosen_ids.squeeze(1).cuda()
c_mask = c_mask.squeeze(1).cuda()
reject_ids = reject_ids.squeeze(1).cuda()
r_mask = r_mask.squeeze(1).cuda()
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask)
dist += (chosen_reward - reject_reward).mean().item()
loss = self.loss_fn(chosen_reward, reject_reward)
loss_sum += loss.item()
dist_mean = dist / self.eval_dataloader.__len__()
loss_mean = loss_sum / self.eval_dataloader.__len__()
dist, acc = self.eval_acc(self.eval_dataloader)
if is_rank_0():
log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], columns=['step', 'loss', 'dist', 'acc'])
log.to_csv('log.csv', mode='a', header=False, index=False)
epoch_bar.update()
step_bar.set_postfix({'loss': loss_mean, 'dist_mean': dist_mean})
step_bar.set_postfix({'dist': dist, 'acc': acc})
step_bar.close()
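For reference, a standalone sketch of what the new `eval_acc` helper computes: `acc` is the fraction of pairs where the chosen response receives the higher reward, and `dist` is the mean reward margin averaged over batches. The model, dataloader, and device handling are simplified placeholders here:

```python
# Sketch of the accuracy/distance metrics logged by the updated trainer.
import torch


@torch.no_grad()
def eval_acc(model, dataloader, device="cuda"):
    model.eval()
    correct, total, dist = 0, 0, 0.0
    for chosen_ids, c_mask, reject_ids, r_mask in dataloader:
        chosen_reward = model(chosen_ids.squeeze(1).to(device),
                              attention_mask=c_mask.squeeze(1).to(device))
        reject_reward = model(reject_ids.squeeze(1).to(device),
                              attention_mask=r_mask.squeeze(1).to(device))
        correct += (chosen_reward > reject_reward).sum().item()
        total += chosen_reward.numel()
        dist += (chosen_reward - reject_reward).mean().item()
    model.train()
    # mean reward margin per batch, and pairwise ranking accuracy
    return dist / len(dataloader), correct / total
```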
41 changes: 30 additions & 11 deletions applications/ChatGPT/examples/README.md
@@ -7,26 +7,42 @@ pip install -r requirements.txt
```

## Train the reward model (Stage 2)
We use [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) as the dataset to train our reward model. It is a dataset of chosen & rejected responses to the same prompt.

You can download the dataset from huggingface automatically.

Use the following code to train your reward model.

```shell
# Naive reward model training
python train_reward_model.py --pretrain <your model path> --model <your model type> --strategy naive
# Take naive reward model training with opt-350m as example
python train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy naive
# use colossalai_zero2
torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain <your model path> --model <your model type> --strategy colossalai_zero2
torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy colossalai_zero2
```

### Features and tricks in RM training
- We support the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) and [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) datasets.
- We support two loss functions: 'log_sig' (used by OpenAI) and 'log_exp' (used by Anthropic).
- We report valid_acc and pair_dist instead of the raw loss to monitor progress during training.
- We append a special token to the end of each sequence to get better results.
- We use a cosine-annealing LR scheduler for RM training (see the sketch after this list).
- We use a single linear layer as the value head and initialize its weight from an N(0, 1/(d_model + 1)) distribution.
- We train a Bloom-560m reward model for 1 epoch and find that its test accuracy matches the performance reported in the [Anthropic paper](https://arxiv.org/abs/2112.00861).
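
A minimal sketch of the cosine-annealing schedule, matching the trainer change above: `CosineAnnealingLR` is built with `T_max = len(train_dataloader) // 100` and stepped once every 100 training batches. The parameters, learning rate, and dataloader length below are placeholders, not values from the training script:

```python
# Sketch only: the optimizer parameters, learning rate and dataloader length
# are stand-ins; the trainer builds the scheduler from its real dataloader.
import torch
from torch.optim import Adam, lr_scheduler

params = [torch.nn.Parameter(torch.zeros(1))]   # stand-in for model parameters
optimizer = Adam(params, lr=1e-5)               # lr is illustrative
steps_per_epoch = 20_000                        # stand-in for len(train_dataloader)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, steps_per_epoch // 100)

for step in range(1, steps_per_epoch + 1):
    # ... forward / backward / optimizer step ...
    if step % 100 == 0:
        scheduler.step()    # LR follows a cosine curve down toward zero
```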

### Experiment result
Model performance in the [Anthropic paper](https://arxiv.org/abs/2112.00861):

<div align=center> <img width="512" alt="image" src="https://user-images.githubusercontent.com/70618399/225263321-8d64c3a8-6877-4cc8-9b61-0e1c52d3d94f.png">

<div align=left>Our training & test results for bloom-560m after 1 epoch:

<div align=center> <img width="512" alt="image" src="https://user-images.githubusercontent.com/70618399/225262950-a7f0a686-25de-44ec-98f2-11b83ea86674.png">

<div align=left>

## Train with dummy prompt data (Stage 3)

This script supports 3 strategies:
This script supports 4 strategies:

- naive
- ddp
- colossalai
- colossalai_zero2
- colossalai_gemini

It uses random generated prompt data.

@@ -53,7 +69,7 @@ We use [awesome-chatgpt-prompts](https://huggingface.co/datasets/fka/awesome-cha

You should download `prompts.csv` first.

This script also supports 3 strategies.
This script also supports 4 strategies.

```shell
# display cli help
@@ -75,6 +91,9 @@ python inference.py --model_path <your actor model path> --model <your model typ
python inference.py --model_path ./actor_checkpoint_prompts.pt --pretrain bigscience/bloom-560m --model bloom
```

## Attention
The examples are just a demo for testing our progress on RM and PPO training.


#### data
- [x] [rm-static](https://huggingface.co/datasets/Dahoas/rm-static)
14 changes: 14 additions & 0 deletions applications/ChatGPT/examples/test_ci.sh
@@ -69,3 +69,17 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH \
python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_prompts.pt --pretrain 'gpt2' --model gpt2

rm -rf ${BASE}/actor_checkpoint_prompts.pt

# train rm
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'facebook/opt-350m' --model 'opt' \
--strategy colossalai_zero2 --loss_fn 'log_sig'\
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\
--test True --lora_rank 4

torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'gpt2' --model 'gpt2' \
--strategy colossalai_gemini --loss_fn 'log_exp'\
--dataset 'Dahoas/rm-static' --test True --lora_rank 4

rm -rf ${BASE}/rm_ckpt.pt