/
rlhf_engine.py
executable file
·305 lines (265 loc) · 13.3 KB
/
rlhf_engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import time
import torch
import deepspeed
from deepspeed.ops.adam import FusedAdam
from deepspeed.ops.adam import DeepSpeedCPUAdam
from transformers import AutoModelForCausalLM, get_scheduler
from dschat.utils.ds_utils import get_train_ds_config, get_eval_ds_config
from dschat.utils.module.lora import convert_linear_layer_to_lora, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible
from dschat.utils.model.model_utils import create_hf_model, create_critic_model
from dschat.utils.utils import get_optimizer_grouped_parameters
"""
TODOs:
* support HF models for critic (for debugging), must be a previously saved ckpt from step-2
* determine ds_config/zero_stage based on model size, gpu style, world size, etc
- get model size by creating simple meta model
- 1.3b: zero-2 for actor/ref models, zero-0 for others
- 13b+: zero-3 for all models
"""
def log_init(model_name, stime=None, width=90):
    """Print a start/end banner around a model's initialization.

    Call once without ``stime`` before building a model to print
    "[start] Initializing <name> Model [start]". Keep the returned timestamp,
    then call again with ``stime`` set to it to print
    "[end] Initialized <name> Model [end] (duration: X.XXs)".

    Only rank 0 prints. When torch.distributed is unavailable or not
    initialized, the current process is treated as rank 0 instead of raising.

    Args:
        model_name: Label inserted into the banner (e.g. "Actor").
        stime: Timestamp returned by the matching "start" call, or None to
            print the start banner.
        width: Target banner width in characters. The message is centered
            between runs of '*'. Messages longer than ``width`` are printed
            without padding.

    Returns:
        ``time.time()`` at the end of the call. It is returned on every rank,
        so callers can always pass it back as ``stime``.
    """
    distributed = (torch.distributed.is_available()
                   and torch.distributed.is_initialized())
    if not distributed or torch.distributed.get_rank() == 0:
        starting = stime is None
        tag = "start" if starting else "end"
        suffix = "ing" if starting else "ed"
        duration = ""
        if not starting:
            duration = "(duration: {:.2f}s)".format(time.time() - stime)
        msg = f"[{tag}] Initializ{suffix} {model_name} Model [{tag}] {duration}"
        # Clamp at 0 so over-long messages get no padding. Without the clamp,
        # a negative remainder could add one stray trailing '*'.
        padding = max(0, width - len(msg))
        stars = padding // 2
        extra_star = "*" * (padding % 2)
        print("*" * stars + msg + "*" * stars + extra_star)
    return time.time()
class DeepSpeedRLHFEngine():
    """Create the DeepSpeed engines used for RLHF step-3 training.

    Attributes set by ``__init__``:
        actor: Trainable causal-LM engine (optimizer + LR scheduler),
            optionally LoRA-converted.
        ref: Engine for a second copy of the actor checkpoint, with no
            optimizer.
        actor_ema: Engine for another copy of the actor checkpoint
            (LoRA-converted like the actor, no optimizer), or None unless
            ``args.enable_ema`` is set.
        critic: Trainable engine built with ``create_critic_model``.
        reward: Engine built from the critic checkpoint, with no optimizer.

    Args:
        actor_model_name_or_path: Checkpoint used for actor, ref and EMA.
        critic_model_name_or_path: Checkpoint used for critic and reward.
        tokenizer: Passed to the model constructors.
        args: Namespace with the training options read below.
        num_total_iters: ``num_training_steps`` for both LR schedulers.
    """

    def __init__(self, actor_model_name_or_path, critic_model_name_or_path,
                 tokenizer, args, num_total_iters):
        self.args = args
        self.num_total_iters = num_total_iters
        self.tokenizer = tokenizer

        self.actor = self._init_actor(
            actor_model_name_or_path=actor_model_name_or_path)
        self.ref = self._init_ref(
            actor_model_name_or_path=actor_model_name_or_path)
        self.actor_ema = None
        if self.args.enable_ema:
            self.actor_ema = self._init_ema(
                actor_model_name_or_path=actor_model_name_or_path)
        self.critic = self._init_critic(
            critic_model_name_or_path=critic_model_name_or_path)
        self.reward = self._init_reward(
            critic_model_name_or_path=critic_model_name_or_path)
        if self.args.critic_gradient_checkpointing:
            self.critic.gradient_checkpointing_enable()

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _set_batch_sizes(ds_config, micro_batch_size, world_size,
                         gradient_accumulation_steps):
        """Write DeepSpeed's batch-size keys into ``ds_config`` in place.

        ``train_batch_size`` is set to
        micro_batch_size * world_size * gradient_accumulation_steps.
        The same dict is returned for convenience.

        TODO(jeff): consider also setting 'gradient_accumulation_steps'
        explicitly for clarity.
        """
        ds_config['train_micro_batch_size_per_gpu'] = micro_batch_size
        ds_config['train_batch_size'] = (micro_batch_size * world_size *
                                         gradient_accumulation_steps)
        return ds_config

    def _apply_batch_sizes(self, ds_config, gradient_accumulation_steps):
        """Call ``_set_batch_sizes`` with this run's per-device batch size and world size."""
        return self._set_batch_sizes(ds_config,
                                     self.args.per_device_training_batch_size,
                                     torch.distributed.get_world_size(),
                                     gradient_accumulation_steps)

    @staticmethod
    def _inference_zero_stage(training_zero_stage):
        """ZeRO stage for a model that has no optimizer (ref/EMA/reward).

        If the paired trainable model uses ZeRO-3, keep stage 3 for this
        model too. Otherwise assume there is enough memory and use stage 0.
        """
        return training_zero_stage if training_zero_stage == 3 else 0

    def _apply_trainable_lora(self, model, lora_dim, lora_module_name):
        """Optionally convert ``model`` to LoRA (used by actor and critic).

        Does nothing when ``lora_dim`` <= 0. With ``args.only_optimize_lora``,
        only the LoRA parameters are left trainable, and the model is made
        gradient-checkpointing compatible. NOTE(review): that step is
        presumably only needed once the base weights are frozen; see
        dschat.utils.module.lora.
        """
        if lora_dim > 0:
            model = convert_linear_layer_to_lora(model, lora_module_name,
                                                 lora_dim)
            if self.args.only_optimize_lora:
                model = only_optimize_lora_parameters(model)
                model = make_model_gradient_checkpointing_compatible(model)
        return model

    def _create_optimizer_and_scheduler(self, model, weight_decay,
                                        lora_learning_rate, learning_rate):
        """Build the Adam optimizer and LR scheduler for a trainable model.

        Uses DeepSpeedCPUAdam when ``args.offload`` is set and FusedAdam
        otherwise. Returns ``(optimizer, lr_scheduler)``.
        """
        AdamOptimizer = DeepSpeedCPUAdam if self.args.offload else FusedAdam
        optim_params = get_optimizer_grouped_parameters(
            model, weight_decay, lora_learning_rate)
        optim = AdamOptimizer(optim_params,
                              lr=learning_rate,
                              betas=(0.9, 0.95))
        lr_scheduler = get_scheduler(
            name=self.args.lr_scheduler_type,
            optimizer=optim,
            num_warmup_steps=self.args.num_warmup_steps,
            num_training_steps=self.num_total_iters,
        )
        return optim, lr_scheduler

    # ------------------------------------------------------------------
    # Engine builders
    # ------------------------------------------------------------------

    def _init_actor(self, actor_model_name_or_path):
        """Build the trainable actor engine (hybrid-engine options come from args)."""
        stime = log_init("Actor")

        # DS Config
        ds_config = get_train_ds_config(
            offload=self.args.offload,
            dtype=self.args.dtype,
            stage=self.args.actor_zero_stage,
            enable_hybrid_engine=self.args.enable_hybrid_engine,
            inference_tp_size=self.args.inference_tp_size,
            release_inference_cache=self.args.release_inference_cache,
            pin_parameters=(not self.args.unpin_actor_parameters),
            tp_gather_partition_size=self.args.tp_gather_partition_size,
            max_out_tokens=self.args.max_prompt_seq_len +
            self.args.max_answer_seq_len,
            enable_tensorboard=self.args.enable_tensorboard,
            enable_mixed_precision_lora=self.args.enable_mixed_precision_lora,
            tb_path=self.args.tensorboard_path,
            tb_name="step3_actor")
        self._apply_batch_sizes(ds_config,
                                self.args.gradient_accumulation_steps_actor)

        # Model
        actor_model = create_hf_model(
            model_class=AutoModelForCausalLM,
            model_name_or_path=actor_model_name_or_path,
            tokenizer=self.tokenizer,
            ds_config=ds_config,
            dropout=self.args.actor_dropout)

        # LoRA
        actor_model = self._apply_trainable_lora(
            actor_model, self.args.actor_lora_dim,
            self.args.actor_lora_module_name)

        # Optimizer + LR scheduler
        optim, lr_scheduler = self._create_optimizer_and_scheduler(
            actor_model, self.args.actor_weight_decay,
            self.args.actor_lora_learning_rate, self.args.actor_learning_rate)

        # DeepSpeed Engine
        #TODO: move enable_hybrid_engine and pin_parameters to ds_config
        actor_engine, *_ = deepspeed.initialize(model=actor_model,
                                                optimizer=optim,
                                                lr_scheduler=lr_scheduler,
                                                config=ds_config)

        log_init("Actor", stime=stime)
        return actor_engine

    def _init_ref(self, actor_model_name_or_path):
        """Build the reference engine: a copy of the actor checkpoint, no optimizer."""
        stime = log_init("Ref")

        # DS Config: ZeRO-3 only if the actor uses it, else stage 0.
        zero_stage = self._inference_zero_stage(self.args.actor_zero_stage)
        ds_config = get_eval_ds_config(self.args.offload_reference_model,
                                       self.args.dtype, zero_stage)
        self._apply_batch_sizes(ds_config,
                                self.args.gradient_accumulation_steps_actor)

        ref_model = create_hf_model(AutoModelForCausalLM,
                                    actor_model_name_or_path, self.tokenizer,
                                    ds_config)
        ref_engine, *_ = deepspeed.initialize(model=ref_model,
                                              config=ds_config)

        log_init("Ref", stime=stime)
        return ref_engine

    def _init_ema(self, actor_model_name_or_path):
        """Build the EMA engine: actor checkpoint, LoRA-converted like the actor, no optimizer."""
        stime = log_init("EMA")

        # DS Config: ZeRO-3 only if the actor uses it, else stage 0.
        # Reuses the reference model's offload setting.
        zero_stage = self._inference_zero_stage(self.args.actor_zero_stage)
        ds_config = get_eval_ds_config(self.args.offload_reference_model,
                                       self.args.dtype, zero_stage)
        self._apply_batch_sizes(ds_config,
                                self.args.gradient_accumulation_steps_actor)

        actor_model_ema = create_hf_model(AutoModelForCausalLM,
                                          actor_model_name_or_path,
                                          self.tokenizer, ds_config)
        # Same LoRA structure as the actor; no trainability changes, since
        # this engine has no optimizer.
        if self.args.actor_lora_dim > 0:
            actor_model_ema = convert_linear_layer_to_lora(
                actor_model_ema, self.args.actor_lora_module_name,
                self.args.actor_lora_dim)

        ema_engine, *_ = deepspeed.initialize(model=actor_model_ema,
                                              config=ds_config)

        log_init("EMA", stime=stime)
        return ema_engine

    def _init_critic(self, critic_model_name_or_path):
        """Build the trainable critic engine from the critic checkpoint."""
        stime = log_init("Critic")

        ds_config = get_train_ds_config(
            offload=self.args.offload,
            dtype=self.args.dtype,
            stage=self.args.critic_zero_stage,
            enable_tensorboard=self.args.enable_tensorboard,
            tb_path=self.args.tensorboard_path,
            tb_name="step3_critic")
        self._apply_batch_sizes(ds_config,
                                self.args.gradient_accumulation_steps)

        # Config passed to the model constructor. The batch sizes are set
        # here to pass the sanity check of the DeepSpeed engine.
        ds_eval_config = get_eval_ds_config(offload=False,
                                            dtype=self.args.dtype,
                                            stage=self.args.critic_zero_stage)
        self._apply_batch_sizes(ds_eval_config,
                                self.args.gradient_accumulation_steps)

        # Model
        critic_model = create_critic_model(
            model_name_or_path=critic_model_name_or_path,
            tokenizer=self.tokenizer,
            ds_config=ds_eval_config,
            num_padding_at_beginning=self.args.num_padding_at_beginning,
            rlhf_training=True,
            dropout=self.args.critic_dropout,
            zero_stage=self.args.critic_zero_stage)

        # LoRA
        critic_model = self._apply_trainable_lora(
            critic_model, self.args.critic_lora_dim,
            self.args.critic_lora_module_name)

        # Optimizer + LR scheduler
        optim, lr_scheduler = self._create_optimizer_and_scheduler(
            critic_model, self.args.critic_weight_decay,
            self.args.critic_lora_learning_rate,
            self.args.critic_learning_rate)

        # DeepSpeed Engine
        critic_engine, *_ = deepspeed.initialize(model=critic_model,
                                                 optimizer=optim,
                                                 lr_scheduler=lr_scheduler,
                                                 config=ds_config)

        log_init("Critic", stime=stime)
        return critic_engine

    def _init_reward(self, critic_model_name_or_path):
        """Build the reward engine from the critic checkpoint, with no optimizer."""
        stime = log_init("Reward")

        # DS Config: ZeRO-3 only if the critic uses it, else stage 0.
        zero_stage = self._inference_zero_stage(self.args.critic_zero_stage)
        ds_config = get_eval_ds_config(offload=self.args.offload,
                                       dtype=self.args.dtype,
                                       stage=zero_stage)
        self._apply_batch_sizes(ds_config,
                                self.args.gradient_accumulation_steps)

        # Config passed to the model constructor. The batch sizes are set
        # here to pass the sanity check of the DeepSpeed engine.
        ds_eval_config = get_eval_ds_config(offload=False,
                                            dtype=self.args.dtype,
                                            stage=zero_stage)
        self._apply_batch_sizes(ds_eval_config,
                                self.args.gradient_accumulation_steps)

        # Model
        reward_model = create_critic_model(
            model_name_or_path=critic_model_name_or_path,
            tokenizer=self.tokenizer,
            ds_config=ds_eval_config,
            num_padding_at_beginning=self.args.num_padding_at_beginning,
            rlhf_training=True,
            dropout=self.args.critic_dropout,
            zero_stage=zero_stage)

        reward_engine, *_ = deepspeed.initialize(model=reward_model,
                                                 config=ds_config)

        log_init("Reward", stime=stime)
        return reward_engine