From 965fac1c08679a38eee1340a9ee3ad9811851ecb Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Fri, 26 May 2023 09:36:16 -0700 Subject: [PATCH] remove assert for offload --- .../DeepSpeed-Chat/training/step1_supervised_finetuning/main.py | 2 -- .../training/step2_reward_model_finetuning/main.py | 2 -- .../DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py | 2 -- 3 files changed, 6 deletions(-) diff --git a/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py b/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py index 088088ea3..4528c5db9 100644 --- a/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py +++ b/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py @@ -198,8 +198,6 @@ def main(): # If passed along, set the training seed now. set_random_seed(args.seed) - assert not args.offload, "zero-offload is not currently supported but coming soon!" - torch.distributed.barrier() tokenizer = load_hf_tokenizer(args.model_name_or_path, fast_tokenizer=True) diff --git a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py index 0ef70a788..38f25b207 100644 --- a/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py +++ b/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py @@ -187,8 +187,6 @@ def main(): args.global_rank = torch.distributed.get_rank() - assert not args.offload, "zero-offload is not currently supported but coming soon!" - ds_config = get_train_ds_config(offload=args.offload, stage=args.zero_stage) ds_config[ diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py index 6903150a4..eb8be4ea4 100644 --- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py +++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py @@ -365,8 +365,6 @@ def main(): args.global_rank = torch.distributed.get_rank() - assert not args.offload, "zero-offload is not currently supported but coming soon!" - unsupervised_training_enabled = args.unsupervised_dataset_name and args.unsupervised_dataset_config_name if unsupervised_training_enabled: # if we enable unsupervised training, we need to double the batch size for actor model