-
Notifications
You must be signed in to change notification settings - Fork 997
/
main.py
671 lines (607 loc) · 27.1 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
#!/usr/bin/env python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
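Step 3 of DeepSpeed-Chat: RLHF fine-tuning of an actor model alongside a critic
using PPO. Simplified usage of the main building blocks (main() additionally
buffers experience through MiniDataset before calling train_rlhf):

    engine = DeepSpeedRLHFEngine(actor_model_name_or_path=actor_model_name_or_path,
                                 critic_model_name_or_path=critic_model_name_or_path,
                                 tokenizer=tokenizer,
                                 num_total_iters=num_total_iters,
                                 args=args)
    trainer = DeepSpeedPPOTrainer(engine, args)

    for step, prompt_batch in enumerate(prompt_train_dataloader):
        out = trainer.generate_experience(prompt_batch['prompt'],
                                          prompt_batch['prompt_att_mask'],
                                          step)
        actor_loss, critic_loss = trainer.train_rlhf(out)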
"""
import argparse
import os
import random
import time
import torch
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from transformers import (
SchedulerType,
default_data_collator,
)
import deepspeed
from dschat.rlhf.ppo_trainer import DeepSpeedPPOTrainer, DeepSpeedPPOTrainerUnsupervised
from dschat.rlhf.rlhf_engine import DeepSpeedRLHFEngine
from dschat.utils.data.data_utils import create_prompt_dataset, MiniDataset, DataCollatorRLHF, get_unsupervised_data
from dschat.utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, moving_average, save_zero_three_model, load_hf_tokenizer, \
ExponentialMovingAverage
from dschat.utils.module.lora import convert_lora_to_linear_layer
from dschat.utils.perf import print_throughput_step3
from deepspeed.accelerator import get_accelerator
# Module-level TensorBoard SummaryWriter. parse_args() replaces it with a
# SummaryWriter when --enable_tensorboard is passed; otherwise it stays None.
writer = None
def parse_args():
    """Parse and validate the command-line arguments for step-3 RLHF training.

    Side effect: when ``--enable_tensorboard`` is given, the module-level
    ``writer`` is set to a ``SummaryWriter`` logging to
    ``{tensorboard_path}/step3_tensorboard_logs``.

    Returns:
        argparse.Namespace: the parsed arguments, including the DeepSpeed
        config arguments added by ``deepspeed.add_config_arguments``.

    Raises:
        ValueError: if the combination of parallelism / ZeRO / hybrid-engine
        settings is unsupported.
    """
    global writer
    parser = argparse.ArgumentParser(
        description="(Step 3) RLHF training arguments")

    ## Data
    parser.add_argument(
        '--data_path',
        nargs='*',
        default=['Dahoas/rm-static'],
        help=
        'Path to the training dataset. Accepted format: 1) a single data path, 2) multiple datasets in the form: dataset1-path dataset2-path ...'
    )
    parser.add_argument(
        '--data_split',
        type=str,
        default='2,4,4',
        help=
        'Comma-separated list of proportions for training phase 1, 2, and 3 data. For example the split `6,2,2` '
        'will use 60%% of data for phase 1, 20%% for phase 2 and 20%% for phase 3.'
    )
    parser.add_argument(
        '--data_output_path',
        type=str,
        default='/tmp/data_files',
        help=
        'Where to store the data-related files such as shuffle index. This needs to be on a local storage of a node (not on a shared storage)'
    )
    parser.add_argument(
        "--unsupervised_dataset_name",
        type=str,
        default=None,
        help="The name of the dataset to use (via the datasets library).")
    parser.add_argument(
        "--unsupervised_dataset_config_name",
        type=str,
        default=None,
        help=
        "The configuration name of the dataset to use (via the datasets library)."
    )
    parser.add_argument("--unsup_coef",
                        type=float,
                        default=27.8,
                        help='''gamma in Equation 2 from InstructGPT paper''')

    ## Models
    parser.add_argument(
        "--actor_model_name_or_path",
        type=str,
        help=
        "Path to pretrained model or model identifier from huggingface.co/models.",
        required=True)
    parser.add_argument(
        "--critic_model_name_or_path",
        type=str,
        help=
        "Path to pretrained model or model identifier from huggingface.co/models.",
        required=True)
    parser.add_argument(
        "--num_padding_at_beginning",
        type=int,
        default=1,
        help=
        "OPT model has a fixed number (1) of padding tokens at the beginning of the input. We did not see this in other models but keep it as an option for now."
    )

    ## Batching / schedule
    parser.add_argument(
        "--per_device_generation_batch_size",
        type=int,
        default=16,
        help=
        "Batch size (per device) for the training dataloader and generation purpose."
    )
    parser.add_argument(
        "--per_device_training_batch_size",
        type=int,
        default=16,
        help=
        "Mini Batch size (per device) for the training dataloader and training purpose."
    )
    parser.add_argument("--generation_batches",
                        type=int,
                        default=1,
                        help="Generate x batches to go to training mode.")
    parser.add_argument(
        "--ppo_epochs",
        type=int,
        default=1,
        help="For generated data, how many ppo training epochs to run.")
    parser.add_argument("--max_prompt_seq_len",
                        type=int,
                        default=256,
                        help="The maximum prompt sequence length.")
    parser.add_argument("--max_answer_seq_len",
                        type=int,
                        default=256,
                        help="The maximum answer (generated) sequence length.")
    parser.add_argument(
        "--actor_learning_rate",
        type=float,
        default=9.65e-6,
        help="Initial learning rate (after the potential warmup period) to use."
    )
    parser.add_argument(
        "--critic_learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use."
    )
    parser.add_argument("--actor_weight_decay",
                        type=float,
                        default=0.,
                        help="Weight decay to use.")
    parser.add_argument("--critic_weight_decay",
                        type=float,
                        default=0.,
                        help="Weight decay to use.")
    parser.add_argument("--num_train_epochs",
                        type=int,
                        default=1,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="cosine",
        help="The scheduler type to use.",
        choices=[
            "linear", "cosine", "cosine_with_restarts", "polynomial",
            "constant", "constant_with_warmup"
        ],
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help=
        "Number of gradient accumulation steps to perform before each optimizer update."
    )
    parser.add_argument(
        "--num_warmup_steps",
        type=int,
        default=100,
        help="Number of steps for the warmup in the lr scheduler.")
    parser.add_argument("--output_dir",
                        type=str,
                        default=None,
                        help="Where to store the model.")
    parser.add_argument("--seed",
                        type=int,
                        default=None,
                        help="A seed for reproducible training.")
    parser.add_argument(
        "--preprocessing_num_workers",
        type=int,
        default=None,
        help="The number of processes to use for the preprocessing.",
    )
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")

    # DeepSpeed
    parser.add_argument(
        "--enable_hybrid_engine",
        action='store_true',
        help=
        "Enable hybrid engine for actor model to optimize both inference and training through DeepSpeed."
    )
    parser.add_argument(
        "--unpin_actor_parameters",
        action='store_true',
        help=
        "Unpin actor's parameters during generation. This makes generation slower but requires less memory."
    )
    parser.add_argument(
        "--release_inference_cache",
        action='store_true',
        help=
        "Release the memory cache used for inference. This makes generation preparation slower but might increase e2e throughput by using larger batch size."
    )
    parser.add_argument(
        "--inference_tp_size",
        type=int,
        default=1,
        help=
        "Tensor-parallelism degree used for the inference-optimization. Please note hybrid-engine need to be enabled when using this feature."
    )
    parser.add_argument(
        "--tp_gather_partition_size",
        type=int,
        default=8,
        help=
        "Granularity to bring in layers for TP sharding inside the hybrid engine. Please note hybrid-engine and inference_tp_size > 1 need to be true when using this feature."
    )
    parser.add_argument('--offload',
                        action='store_true',
                        help='Enable ZeRO Offload techniques.')
    parser.add_argument('--dtype',
                        type=str,
                        default='fp16',
                        choices=['fp16', 'bf16'],
                        help='Training data type')
    parser.add_argument(
        '--offload_reference_model',
        action='store_true',
        help='Enable ZeRO Offload techniques for reference model')
    parser.add_argument(
        '--actor_zero_stage',
        type=int,
        default=0,
        help='ZeRO optimization stage for Actor model (and clones).')
    parser.add_argument(
        '--critic_zero_stage',
        type=int,
        default=0,
        help='ZeRO optimization stage for Critic model (and reward).')
    parser.add_argument(
        '--actor_gradient_checkpointing',
        action='store_true',
        help='Enable HF gradient checkpointing for Actor model.')
    parser.add_argument(
        '--critic_gradient_checkpointing',
        action='store_true',
        help='Enable HF gradient checkpointing for Critic model.')
    parser.add_argument(
        "--actor_dropout",
        type=float,
        default=None,
        help="If actor dropout configured, use it. "
        "Otherwise, keep the default dropout configuration of the actor model."
    )
    parser.add_argument(
        "--critic_dropout",
        type=float,
        default=None,
        help="If critic dropout configured, use it. "
        "Otherwise, keep the default dropout configuration of the critic model."
    )

    ## LoRA for efficient training setting
    parser.add_argument("--actor_lora_dim",
                        type=int,
                        default=0,
                        help="If > 0, use LoRA for efficient training.")
    parser.add_argument("--actor_lora_module_name",
                        type=str,
                        default="decoder.layers.",
                        help="The scope of LoRA.")
    parser.add_argument("--critic_lora_dim",
                        type=int,
                        default=0,
                        help="If > 0, use LoRA for efficient training.")
    parser.add_argument("--critic_lora_module_name",
                        type=str,
                        default="decoder.layers.",
                        help="The scope of LoRA.")
    parser.add_argument('--only_optimize_lora',
                        action='store_true',
                        help='Only optimize the LoRA parameters.')
    parser.add_argument(
        "--actor_lora_learning_rate",
        type=float,
        default=5e-4,
        help=
        "Initial actor LoRA learning rate (after the potential warmup period) to use."
    )
    parser.add_argument(
        "--critic_lora_learning_rate",
        type=float,
        default=5e-4,
        help=
        "Initial critic LoRA learning rate (after the potential warmup period) to use."
    )

    ## Make EMA as an optional feature
    parser.add_argument('--enable_ema',
                        action='store_true',
                        help='Enable EMA checkpoint for the model.')

    ## Mixed Precision ZeRO++
    parser.add_argument(
        '--enable_mixed_precision_lora',
        action='store_true',
        help='Enable Mixed Precision ZeRO++ for training and generation.')

    ## low precision
    parser.add_argument(
        '--compute_fp32_loss',
        action='store_true',
        help='Relevant for low precision dtypes (fp16, bf16, etc.). '
        'If specified, loss is calculated in fp32. '
        'This applies for both actor and critic models.')

    ## Tensorboard logging
    parser.add_argument('--enable_tensorboard',
                        action='store_true',
                        help='Enable tensorboard logging')
    parser.add_argument(
        '--tensorboard_path',
        type=str,
        default="step3_tensorboard",
        help=
        'Directory under which the step3_tensorboard_logs folder is created.')

    ## Tokenizer
    parser.add_argument(
        "--add_eot_token",
        action='store_true',
        help="Add <|endoftext|> as additional special token to tokenizer")

    ## Actor/critic model overflow alignment
    parser.add_argument(
        '--align_overflow',
        action='store_true',
        help='Align loss scale overflow between actor and critic')

    ## Print actor model answers during training
    parser.add_argument('--print_answers',
                        action='store_true',
                        help='Print prompt and answers during training')
    parser.add_argument(
        "--print_answers_interval",
        type=int,
        default=1,
        help="If --print_answers enabled, controls the printing interval.")

    ## Testing
    parser.add_argument(
        '--enable_test_mode',
        action='store_true',
        help=
        'Enable a testing mode that terminates training based on args.test_stop_step'
    )
    parser.add_argument(
        "--test_stop_step",
        type=int,
        default=0,
        help=
        "Training non-overflow step at which to terminate training during testing."
    )

    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()

    if args.enable_tensorboard:
        print(
            f"Tensorboard logs going to: {args.tensorboard_path}/step3_tensorboard_logs"
        )
        writer = SummaryWriter(
            f"{args.tensorboard_path}/step3_tensorboard_logs")

    # Validate settings. Raise explicitly (not `assert`) so the checks still
    # run when Python is started with -O.
    if args.inference_tp_size > 1 and args.actor_zero_stage != 3:
        raise ValueError(
            "Zero stage 3 must be used to do Tensor sharding in the hybrid engine"
        )

    if args.actor_zero_stage == 2 and args.critic_zero_stage == 2 and args.enable_hybrid_engine and args.offload and args.actor_lora_dim == 0:
        raise ValueError(
            "The combination of [actor_zero_stage==2, critic_zero_stage==2, enable_hybrid_engine=True, offload=True, lora=False] is currently unsupported due to training instability!"
        )

    return args
def create_datasets(args, tokenizer, train_phase=3):
    """Build the prompt and (optional) unsupervised dataloaders for PPO.

    Unsupervised data is used only when both ``--unsupervised_dataset_name``
    and ``--unsupervised_dataset_config_name`` are set.

    Args:
        args: parsed command-line arguments (see ``parse_args``).
        tokenizer: tokenizer forwarded to the dataset builders.
        train_phase: training phase forwarded to ``create_prompt_dataset``.

    Returns:
        ``(prompt_train_dataloader, unsupervised_train_dataloader,
        num_total_iters)``. When unsupervised training is disabled, the second
        item is a list holding one ``None`` per prompt batch, so that it can
        be zipped with the prompt dataloader.
    """
    use_unsup = args.unsupervised_dataset_name and args.unsupervised_dataset_config_name

    prompt_dataset, _ = create_prompt_dataset(
        args.local_rank, args.data_path, args.data_split,
        args.data_output_path, train_phase, args.seed, tokenizer,
        args.max_prompt_seq_len)
    unsup_dataset = get_unsupervised_data(args, tokenizer) if use_unsup else None

    # DataLoaders creation:
    collator = DataCollatorRLHF(args.max_prompt_seq_len,
                                args.inference_tp_size)

    # local_rank == -1 means a non-distributed run: shuffle locally.
    # Otherwise each rank draws its own shard.
    make_sampler = RandomSampler if args.local_rank == -1 else DistributedSampler
    prompt_sampler = make_sampler(prompt_dataset)
    unsup_sampler = make_sampler(unsup_dataset) if use_unsup else None

    gen_bs = args.per_device_generation_batch_size
    prompt_loader = DataLoader(prompt_dataset,
                               collate_fn=collator,
                               sampler=prompt_sampler,
                               batch_size=gen_bs)
    if not use_unsup:
        # Placeholder "dataloader": one None per prompt batch.
        unsup_loader = [None] * len(prompt_loader)
    else:
        unsup_loader = DataLoader(unsup_dataset,
                                  collate_fn=default_data_collator,
                                  sampler=unsup_sampler,
                                  batch_size=gen_bs)

    generation_steps = min(len(prompt_loader), len(unsup_loader))
    minibatches_per_generation = gen_bs / args.per_device_training_batch_size
    updates_per_epoch = generation_steps * minibatches_per_generation * \
        args.ppo_epochs / args.gradient_accumulation_steps
    total_iters = int(args.num_train_epochs * updates_per_epoch)

    return prompt_loader, unsup_loader, total_iters
def _set_sampler_epoch(dataloader, epoch):
    """Forward ``epoch`` to a DistributedSampler so each epoch reshuffles.

    ``dataloader`` may also be the plain list placeholder that
    ``create_datasets`` returns when unsupervised training is disabled. Any
    object without a DistributedSampler is left untouched.
    """
    sampler = getattr(dataloader, "sampler", None)
    if isinstance(sampler, DistributedSampler):
        sampler.set_epoch(epoch)


def _save_models(args, rlhf_engine, tokenizer):
    """Merge LoRA layers and save the trained models under ``args.output_dir``.

    Rank 0 writes HuggingFace-format checkpoints into the ``actor`` and
    ``critic`` sub-folders, plus ``actor_ema`` when ``--enable_ema`` is set.
    Models trained with ZeRO stage 3 are additionally saved via
    ``save_zero_three_model``. That call is made on every rank (presumably
    because partitioned parameters must be gathered collectively — confirm
    in dschat.utils.utils).

    Note: this replaces ``rlhf_engine.actor``, ``rlhf_engine.critic`` and
    ``rlhf_engine.actor_ema`` with their LoRA-merged versions.
    """
    print_rank_0('saving model ...')
    rlhf_engine.actor = convert_lora_to_linear_layer(rlhf_engine.actor)
    rlhf_engine.critic = convert_lora_to_linear_layer(rlhf_engine.critic)
    if args.enable_ema:
        rlhf_engine.actor_ema = convert_lora_to_linear_layer(
            rlhf_engine.actor_ema)

    if torch.distributed.get_rank() == 0:
        save_hf_format(rlhf_engine.actor,
                       tokenizer,
                       args,
                       sub_folder='actor')
        save_hf_format(rlhf_engine.critic,
                       tokenizer,
                       args,
                       sub_folder='critic')
        if args.enable_ema:
            save_hf_format(rlhf_engine.actor_ema,
                           tokenizer,
                           args,
                           sub_folder='actor_ema')

    if args.actor_zero_stage == 3:
        save_zero_three_model(rlhf_engine.actor,
                              global_rank=args.global_rank,
                              save_dir=os.path.join(args.output_dir, 'actor'),
                              zero_stage=args.actor_zero_stage)
        if args.enable_ema:
            save_zero_three_model(rlhf_engine.actor_ema,
                                  global_rank=args.global_rank,
                                  save_dir=os.path.join(
                                      args.output_dir, 'actor_ema'),
                                  zero_stage=args.actor_zero_stage)

    if args.critic_zero_stage == 3:
        save_zero_three_model(rlhf_engine.critic,
                              global_rank=args.global_rank,
                              save_dir=os.path.join(args.output_dir, 'critic'),
                              zero_stage=args.critic_zero_stage)


def main():
    """Entry point for step-3 RLHF (PPO) fine-tuning.

    Steps:
      1. Parse arguments and initialize the distributed backend.
      2. Build the tokenizer, datasets and the DeepSpeed RLHF engine.
      3. Run the generate-experience / PPO-train loop.
      4. Save the models when ``--output_dir`` is given.
    """
    args = parse_args()

    if args.local_rank == -1:
        device = torch.device(get_accelerator().device_name())
    else:
        get_accelerator().set_device(args.local_rank)
        device = torch.device(get_accelerator().device_name(), args.local_rank)

    # Initializes the distributed backend which will take care of synchronizing
    # nodes/GPUs. Called on both branches because torch.distributed.get_rank()
    # and barrier() below require an initialized process group.
    deepspeed.init_distributed()

    args.global_rank = torch.distributed.get_rank()

    unsupervised_training_enabled = args.unsupervised_dataset_name and args.unsupervised_dataset_config_name
    if unsupervised_training_enabled:
        # if we enable unsupervised training, we need to double the batch size for actor model
        args.gradient_accumulation_steps_actor = args.gradient_accumulation_steps * 2
    else:
        args.gradient_accumulation_steps_actor = args.gradient_accumulation_steps

    # If passed along, set the training seed now.
    set_random_seed(args.seed)
    torch.distributed.barrier()

    # load_hf_tokenizer will get the correct tokenizer and set padding tokens based on the model family
    args.end_of_conversation_token = "<|endoftext|>"
    additional_special_tokens = args.end_of_conversation_token if args.add_eot_token else None
    tokenizer = load_hf_tokenizer(args.actor_model_name_or_path,
                                  fast_tokenizer=True,
                                  add_special_tokens=additional_special_tokens)

    prompt_train_dataloader, unsupervised_train_dataloader, num_total_iters = create_datasets(
        args=args, tokenizer=tokenizer, train_phase=3)

    # Validate Mixed Precision ZeRO++ requirements before the (expensive)
    # engine construction. Raise explicitly so the checks survive `python -O`.
    if args.enable_mixed_precision_lora:
        if args.actor_lora_dim <= 0:
            raise ValueError("Mixed Precision LoRA requires LoRA to be enabled")
        if args.actor_zero_stage != 3:
            raise ValueError("Mixed Precision LoRA requires Zero stage 3")

    # RLHF engine is responsible for creating models, loading checkpoints, ds-initialize models/optims/lr-schedulers
    rlhf_engine = DeepSpeedRLHFEngine(
        actor_model_name_or_path=args.actor_model_name_or_path,
        critic_model_name_or_path=args.critic_model_name_or_path,
        tokenizer=tokenizer,
        num_total_iters=num_total_iters,
        args=args)

    # Mixed Precision ZeRO++
    if args.enable_mixed_precision_lora:
        rlhf_engine.actor.optimizer.quantize_nontrainable_params()
        print_rank_0("Mixed Precision ZeRO++ enabled")

    ppo_trainer = DeepSpeedPPOTrainerUnsupervised if unsupervised_training_enabled else DeepSpeedPPOTrainer
    trainer = ppo_trainer(rlhf_engine, args)

    # first number is how many experience-batch to generate, second number is the training batch size, which is the micro-batch size used
    exp_mini_dataset = MiniDataset(args.generation_batches,
                                   args.per_device_training_batch_size)
    unsup_mini_dataset = MiniDataset(args.generation_batches,
                                     args.per_device_training_batch_size)

    # Train!
    print_rank_0(
        f"***** Running training (total_iters={num_total_iters}) *****",
        args.global_rank)

    non_overflow_step_count = 0
    step_average_reward = 0.
    ema_reward_score = ExponentialMovingAverage()

    for epoch in range(args.num_train_epochs):
        # Without set_epoch, DistributedSampler yields the same order every epoch.
        _set_sampler_epoch(prompt_train_dataloader, epoch)
        _set_sampler_epoch(unsupervised_train_dataloader, epoch)

        print_rank_0(
            f"Beginning of Epoch {epoch+1}/{args.num_train_epochs}, Total Generation Batches {min(len(prompt_train_dataloader), len(unsupervised_train_dataloader))}",
            args.global_rank)
        for step, (batch_prompt, batch_unsupervised) in enumerate(
                zip(prompt_train_dataloader, unsupervised_train_dataloader)):

            batch_prompt = to_device(batch_prompt, device)
            out = trainer.generate_experience(batch_prompt['prompt'],
                                              batch_prompt['prompt_att_mask'],
                                              step)

            training_start = time.time()
            if batch_unsupervised is not None:
                batch_unsupervised = to_device(batch_unsupervised, device)
                unsup_dataset = unsup_mini_dataset.add(batch_unsupervised)
            else:
                unsup_dataset = unsup_mini_dataset.add(
                    [[None] * args.per_device_generation_batch_size])

            # The mini-datasets return None until enough generation batches
            # have been buffered; only then is a training round run.
            exp_dataset = exp_mini_dataset.add(out)

            if exp_dataset is not None:
                inner_iter = 0
                actor_loss_sum, critic_loss_sum, unsup_loss_sum = 0, 0, 0
                average_reward = 0

                if args.actor_gradient_checkpointing:
                    rlhf_engine.actor.gradient_checkpointing_enable()

                for ppo_ep in range(args.ppo_epochs):
                    for exp_data, unsup_data in zip(exp_dataset,
                                                    unsup_dataset):
                        actor_loss, critic_loss = trainer.train_rlhf(exp_data)
                        actor_loss_sum += actor_loss.item()
                        critic_loss_sum += critic_loss.item()
                        average_reward += exp_data["rewards"].mean()

                        if unsupervised_training_enabled:
                            unsup_loss = trainer.train_unsupervised(
                                unsup_data, args.unsup_coef)
                            unsup_loss_sum += unsup_loss.item()

                        inner_iter += 1
                        if args.enable_ema:
                            moving_average(rlhf_engine.actor,
                                           rlhf_engine.actor_ema,
                                           zero_stage=args.actor_zero_stage)

                    random.shuffle(exp_dataset)
                    random.shuffle(unsup_dataset)

                end = time.time()
                training_time = end - training_start
                e2e_time = training_time + trainer.generate_time * args.generation_batches  # it is an approximation, we did not include, e.g., rw forward time etc

                print_rank_0(
                    f'Epoch: {epoch} | Step: {step} | PPO Epoch: {ppo_ep+1} | Actor Loss: {actor_loss_sum/inner_iter} | Critic Loss: {critic_loss_sum/inner_iter} | Unsupervised Loss: {unsup_loss_sum/inner_iter}',
                    args.global_rank)
                print_throughput_step3(rlhf_engine.actor.module,
                                       rlhf_engine.critic, args, e2e_time,
                                       trainer.generate_time, training_time,
                                       args.global_rank)

                # average_reward holds the SUM of per-mini-batch mean rewards
                # over inner_iter iterations, averaged across ranks. Divide by
                # inner_iter once and use that same mean for the EMA, the log
                # line and TensorBoard. Previously the EMA was fed the raw sum,
                # which scales with inner_iter.
                average_reward = get_all_reduce_mean(average_reward).item()
                mean_reward = average_reward / inner_iter
                step_average_reward += mean_reward / args.gradient_accumulation_steps_actor
                if (step + 1) % args.gradient_accumulation_steps_actor == 0:
                    ema_reward_score.update(step_average_reward)
                    step_average_reward = 0.

                print_rank_0(
                    f"Average reward score: {mean_reward} | EMA reward score: {ema_reward_score.get()}",
                    args.global_rank)
                print_rank_0(
                    "-------------------------------------------------------------------------------------",
                    args.global_rank)

                if args.enable_tensorboard and torch.distributed.get_rank(
                ) == 0:
                    writer.add_scalar('reward',
                                      mean_reward,
                                      global_step=step)
                    writer.add_scalar('actor_loss',
                                      actor_loss.item(),
                                      global_step=step)
                    writer.add_scalar('actor_loss_sum',
                                      actor_loss_sum,
                                      global_step=step)
                    writer.add_scalar('critic_loss',
                                      critic_loss.item(),
                                      global_step=step)
                    writer.add_scalar('critic_loss_sum',
                                      critic_loss_sum,
                                      global_step=step)
                    writer.flush()

                if args.actor_gradient_checkpointing:
                    rlhf_engine.actor.gradient_checkpointing_disable()

                actor_overflow, critic_overflow = trainer.get_overflow()
                if not actor_overflow and not critic_overflow:
                    non_overflow_step_count += 1

                if args.enable_test_mode and non_overflow_step_count == args.test_stop_step:
                    break

        if args.enable_test_mode:
            break

    if args.output_dir is not None:
        _save_models(args, rlhf_engine, tokenizer)
# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()