feature(whl): add rlhf pipeline. #748

Open · wants to merge 18 commits into base: main
17 changes: 17 additions & 0 deletions ding/bonus/config.py
@@ -167,6 +167,13 @@
cfg.batch_size = 320
cfg.epoch_per_collect = 10
cfg.learning_rate = 3e-4
elif env_id == 'chat':
cfg.epoch_per_collect = 1
cfg.batch_size = 1
cfg.learning_rate = 5e-7
cfg.answers_per_question = 3
cfg.kl_penalty_weight = 0.1
cfg.ppo_param_init = False
else:
raise KeyError("not supported env type: {}".format(env_id))
else:
@@ -315,6 +322,16 @@
)
cfg = EasyDict(cfg)
return DriveEnvWrapper(MetaDrivePPOOriginEnv(cfg))
elif env_id == 'chat':
from dizoo.chat.env import ChatEnv
return ChatEnv(
batch_size=1,
reward_model_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en/recover",
tokenizer_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en",
data_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/data/ppo_data",
maxlen_prompt=128,
maxlen_res=128,
)
else:
raise KeyError("not supported env type: {}".format(env_id))

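For reference, a minimal sketch of how this new branch would presumably be exercised through the bonus API. The `PPOF` constructor and `train` signatures follow existing `ding.bonus` usage and are assumptions here, not part of this diff; the `exp_name` and `step` values are hypothetical.

from ding.bonus import PPOF

# env_id='chat' picks up the RLHF defaults added above:
# batch_size=1, learning_rate=5e-7, answers_per_question=3, kl_penalty_weight=0.1
agent = PPOF(env_id='chat', exp_name='rlhf_chat_demo')  # hypothetical exp_name
agent.train(step=1000)  # hypothetical step budget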
25 changes: 19 additions & 6 deletions ding/bonus/ppof.py
@@ -1,3 +1,4 @@
import copy
from typing import Optional, Union, List
from ditk import logging
from easydict import EasyDict
@@ -9,7 +10,7 @@
import torch
from ding.framework import task, OnlineRLContext
from ding.framework.middleware import interaction_evaluator_ttorch, PPOFStepCollector, multistep_trainer, CkptSaver, \
wandb_online_logger, offline_data_saver, termination_checker, ppof_adv_estimator
wandb_online_logger, offline_data_saver, termination_checker, ppof_adv_estimator, ChatCollector
from ding.envs import BaseEnv, BaseEnvManagerV2, SubprocessEnvManagerV2
from ding.policy import PPOFPolicy, single_env_forward_wrapper_ttorch
from ding.utils import set_pkg_seed
@@ -62,6 +63,8 @@
'Hopper-v3',
'HalfCheetah-v3',
'Walker2d-v3',
# rlhf
'chat'
]
"""
Overview:
@@ -170,6 +173,8 @@
action_shape = int(action_space.n)
elif isinstance(action_space, (gym.spaces.Tuple, gymnasium.spaces.Tuple)):
action_shape = get_hybrid_shape(action_space)
elif action_space is None:
pass
else:
action_shape = action_space.shape

@@ -191,7 +196,11 @@
popart_head=True,
**self.cfg.model
)
self.policy = PPOFPolicy(self.cfg, model=model)
if self.cfg.chat_data:
orig_model = copy.deepcopy(model)
else:
orig_model = None
self.policy = PPOFPolicy(self.cfg, model=model, orig_model=orig_model)
if policy_state_dict is not None:
self.policy.load_state_dict(policy_state_dict)
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
@@ -246,10 +255,14 @@
pass

with task.start(ctx=OnlineRLContext()):
task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env))
task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
task.use(PPOFStepCollector(self.seed, self.policy, collector_env, self.cfg.n_sample))
task.use(ppof_adv_estimator(self.policy))
if self.policy._cfg.chat_data:
# task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env))
# task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
task.use(ChatCollector(self.seed, self.policy, collector_env, self.cfg.n_sample))
else:
task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env))
task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
task.use(PPOFStepCollector(self.seed, self.policy, collector_env, self.cfg.n_sample))
task.use(multistep_trainer(self.policy, log_freq=n_iter_log_show))
task.use(
wandb_online_logger(
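In chat mode the middleware stack above reduces to collector plus trainer (the evaluator and checkpoint saver are commented out for now). A condensed sketch of the resulting pipeline, assuming `seed`, `policy`, `collector_env`, `n_sample` and `n_iter_log_show` are prepared as in `PPOF.train`; the `max_step` value is hypothetical:

from ding.framework import task, OnlineRLContext
from ding.framework.middleware import ChatCollector, multistep_trainer

with task.start(ctx=OnlineRLContext()):
    # chat mode: one ChatCollector pass gathers rewards for all sampled answers
    task.use(ChatCollector(seed, policy, collector_env, n_sample))
    task.use(multistep_trainer(policy, log_freq=n_iter_log_show))
    task.run(max_step=100)  # hypothetical step budget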
2 changes: 1 addition & 1 deletion ding/framework/middleware/__init__.py
@@ -1,5 +1,5 @@
from .functional import *
from .collector import StepCollector, EpisodeCollector, PPOFStepCollector
from .collector import StepCollector, EpisodeCollector, PPOFStepCollector, ChatCollector
from .learner import OffPolicyLearner, HERLearner
from .ckpt_handler import CkptSaver
from .distributer import ContextExchanger, ModelExchanger, PeriodicalModelExchanger
61 changes: 61 additions & 0 deletions ding/framework/middleware/collector.py
@@ -1,3 +1,4 @@
import copy
from typing import TYPE_CHECKING
from easydict import EasyDict
import treetensor.torch as ttorch
@@ -190,4 +191,64 @@
break


class ChatCollector:
"""
Overview:
The collector for chat (RLHF) data, which samples several answers per prompt, scores them \
through the environment, and packs the results for training. Use the `__call__` method to \
execute the whole collection process.
"""

def __new__(cls, *args, **kwargs):
if task.router.is_active and not task.has_role(task.role.COLLECTOR):
return task.void()
return super(ChatCollector, cls).__new__(cls)

def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll_len: int = 1) -> None:
"""
Arguments:
- seed (:obj:`int`): Random seed.
- policy (:obj:`Policy`): The policy to be collected.
- env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
its derivatives are supported.
- n_sample (:obj:`int`): The number of samples for one collection.
- unroll_len (:obj:`int`): The unroll length of a training sample, default to 1.
"""
self.env = env
self.env.seed(seed)
self.env.launch()
self.env = self.env._envs[0]
self.policy = policy
self.n_sample = n_sample
self.unroll_len = unroll_len

def __call__(self, ctx: "OnlineRLContext") -> None:
"""
Overview:
An encapsulation of inference and rollout: sample `answers_per_question` answers for the \
current batch of prompts, step the environment once to score them, and store the packed \
result in `ctx.train_data`.
Input of ctx:
- env_step (:obj:`int`): The env steps which will increase during collection.
"""
device = self.policy._device

obs = ttorch.as_tensor(self.env.last_batch['text_vec'])
batch_size = obs.shape[0]
obs = obs.to(device)

total_action = [[] for _ in range(batch_size)] # [B, answers_per_question, T]
for _ in range(self.policy._cfg.answers_per_question):
_, inference_output = self.policy._model.actor.generate(obs, **ctx.collect_kwargs)
for i in range(batch_size):
total_action[i].append(copy.deepcopy(inference_output[i]))

mask, resp, rew = self.env.step(total_action)
ctx.env_step += 1
ctx.env_episode += 1

train_data = {}
train_data['obs'] = resp  # [B * answers_per_question, T]
train_data['reward'] = rew  # [B * answers_per_question, ]
train_data['mask'] = mask  # [B * answers_per_question, T]

ctx.train_data = ttorch.as_tensor(train_data)


# TODO battle collector
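The shape convention in `__call__` is easiest to see with the chat defaults (`batch_size=1`, `answers_per_question=3`, `maxlen_res=128`). A small illustration with dummy tensors; the mask semantics (1 for response tokens, 0 elsewhere) is an assumption based on typical RLHF pipelines, not spelled out in this diff:

import torch

B, K, T = 1, 3, 128  # batch size, answers_per_question, max response length
resp = torch.zeros(B * K, T, dtype=torch.long)  # one token row per sampled answer
rew = torch.zeros(B * K)                        # one scalar reward per answer
mask = torch.ones(B * K, T)                     # assumed: 1 = response token, 0 = padding

# mirrors what ChatCollector packs into ctx.train_data
train_data = {'obs': resp, 'reward': rew, 'mask': mask}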
2 changes: 1 addition & 1 deletion ding/model/common/__init__.py
@@ -2,4 +2,4 @@
QuantileHead, FQFHead, RegressionHead, ReparameterizationHead, MultiHead, BranchingHead, head_cls_map, \
independent_normal_dist, AttentionPolicyHead, PopArtVHead, EnsembleHead
from .encoder import ConvEncoder, FCEncoder, IMPALAConvEncoder
from .utils import create_model
from .utils import create_model, top_p_logits
15 changes: 15 additions & 0 deletions ding/model/common/tests/test_utils.py
@@ -0,0 +1,15 @@
import pytest
import torch
from ding.model.common.utils import top_p_logits


@pytest.mark.unittest
class TestUtils:

def test_top_p_logits(self):
test_logit = torch.Tensor([[0., 0.91, 0.05, 0.04], [0.04, 0.46, 0.46, 0.04]])

# topp=0.9 keeps an entry only while the mass before it is < 0.9: row 1 keeps
# just 0.91 (renormalized to 1.0); row 2 keeps both 0.46 entries (0.5 each).
gt_logit = torch.Tensor([[0., 1., 0., 0.], [0., 0.5, 0.5, 0.]])

pred_logit = top_p_logits(test_logit)
assert torch.sum((gt_logit - pred_logit) ** 2).item() < 1e-8
26 changes: 26 additions & 0 deletions ding/model/common/utils.py
@@ -29,3 +29,29 @@ def create_model(cfg: EasyDict) -> torch.nn.Module:
import_module(cfg.pop('import_names', []))
# here we must use the pop operation to ensure compatibility
return MODEL_REGISTRY.build(cfg.pop("type"), **cfg)


def top_p_logits(logits: torch.Tensor, topp: float = 0.9, filter_value: float = 0, min_topk: int = 1):
"""
Overview:
Filter a distribution using nucleus (top-p) filtering. The output has the same shape, with \
filtered-out entries set to filter_value and the remainder renormalized.
Arguments:
- logits (:obj:`torch.Tensor`): The input logits for top-p sampling.
- topp (:obj:`float`): The top-p value, such as 0.9.
- filter_value (:obj:`float`): The value for masked logits in output, default as 0.
- min_topk (:obj:`int`): The minimum number of entries to keep, default as 1 (which means that \
at least one entry is never masked).
Returns:
- cum_logits (:obj:`torch.Tensor`): The output distribution after masking and renormalization.
"""
cum_logits = logits.clone()
if topp > 0:
logits_sorted, inds = torch.sort(logits, dim=-1, descending=True)
# drop every entry whose preceding cumulative mass already reaches topp
mask = (logits_sorted.cumsum(dim=-1) - logits_sorted) >= topp
mask[..., :min_topk] = False
# scatter the mask back from sorted order to the original positions
mask = torch.zeros_like(mask).to(torch.bool).scatter_(dim=-1, index=inds, src=mask)
cum_logits[mask] = filter_value
# renormalize the surviving entries so each row sums to 1
cum_logits.div_(cum_logits.sum(dim=-1, keepdim=True))
return cum_logits
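A quick usage example mirroring the unit test above. Note that, despite the name, the input is treated as a (possibly unnormalized) probability distribution rather than raw log-probabilities, since the cumulative-mass comparison and the final renormalization assume non-negative entries:

import torch
from ding.model.common.utils import top_p_logits

probs = torch.Tensor([[0., 0.91, 0.05, 0.04], [0.04, 0.46, 0.46, 0.04]])
print(top_p_logits(probs))
# -> [[0., 1., 0., 0.], [0., 0.5, 0.5, 0.]]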
1 change: 1 addition & 0 deletions ding/model/template/__init__.py
@@ -5,6 +5,7 @@
from .vac import VAC, DREAMERVAC
from .bc import DiscreteBC, ContinuousBC
from .language_transformer import LanguageTransformer
from .lm_vac import LlamaVAC
# algorithm-specific
from .pg import PG
from .ppg import PPG
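The `LlamaVAC` implementation (`lm_vac.py`) is not shown in this diff. From the `ChatCollector` call site (`self.policy._model.actor.generate(obs, **ctx.collect_kwargs)`), the minimum interface it must expose looks roughly like the sketch below; the method body and return types are assumptions inferred from that call site, not the actual model code:

import torch

class ActorStub:
    """Hypothetical stand-in for LlamaVAC.actor as consumed by ChatCollector."""

    def generate(self, obs: torch.Tensor, **kwargs):
        # must return a 2-tuple; ChatCollector keeps only the second element,
        # which is indexed per batch item (inference_output[i])
        batch_size = obs.shape[0]
        return None, [obs[i] for i in range(batch_size)]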