From 956a483173e77ebf655ca9636a5f7b6ef010b307 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Tue, 23 Nov 2021 14:09:15 -0800 Subject: [PATCH] [deepspeed] zero inference (#14253) * [deepspeed] zero inference * only z3 makes sense for inference * fix and style * docs * rework * fix test * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * responding to suggestions Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- docs/source/main_classes/deepspeed.rst | 55 +++++++++++ setup.py | 2 +- src/transformers/deepspeed.py | 96 ++++++++++++++----- src/transformers/dependency_versions_table.py | 2 +- src/transformers/trainer.py | 9 +- tests/deepspeed/test_deepspeed.py | 23 ++++- 6 files changed, 149 insertions(+), 38 deletions(-) diff --git a/docs/source/main_classes/deepspeed.rst b/docs/source/main_classes/deepspeed.rst index db639bb53d5531..5b2e6e64e5c0c5 100644 --- a/docs/source/main_classes/deepspeed.rst +++ b/docs/source/main_classes/deepspeed.rst @@ -46,6 +46,20 @@ won't be possible on a single GPU. parts of DeepSpeed like ``zero.Init`` for ZeRO stage 3 and higher. To tap into this feature read the docs on :ref:`deepspeed-non-trainer-integration`. +What is integrated: + +Training: + +1. DeepSpeed ZeRO training supports the full ZeRO stages 1, 2 and 3 with ZeRO-Infinity (CPU and NVME offload). + +Inference: + +1. DeepSpeed ZeRO Inference supports ZeRO stage 3 with ZeRO-Infinity. It uses the same ZeRO protocol as training, but + it doesn't use an optimizer and a lr scheduler and only stage 3 is relevant. For more details see: + :ref:`deepspeed-zero-inference`. + +There is also DeepSpeed Inference - this is a totally different technology which uses Tensor Parallelism instead of +ZeRO (coming soon). @@ -1628,6 +1642,47 @@ larger multi-dimensional shape, this means that the parameter is partitioned and +.. _deepspeed-zero-inference: + + +ZeRO Inference +======================================================================================================================= + +ZeRO Inference uses the same config as ZeRO-3 Training. You just don't need the optimizer and scheduler sections. In +fact you can leave these in the config file if you want to share the same one with the training. They will just be +ignored. + +Otherwise you just need to pass the usual :class:`~transformers.TrainingArguments` arguments. For example: + +.. code-block:: bash + + deepspeed --num_gpus=2 your_program.py --do_eval --deepspeed ds_config.json + +The only important thing is that you need to use a ZeRO-3 configuration, since ZeRO-2 provides no benefit whatsoever +for the inference as only ZeRO-3 performs sharding of parameters, whereas ZeRO-1 shards gradients and optimizer states. + +Here is an example of running ``run_translation.py`` under DeepSpeed deploying all available GPUs: + +.. code-block:: bash + + deepspeed examples/pytorch/translation/run_translation.py \ + --deepspeed tests/deepspeed/ds_config_zero3.json \ + --model_name_or_path t5-small --output_dir output_dir \ + --do_eval --max_eval_samples 50 --warmup_steps 50 \ + --max_source_length 128 --val_max_target_length 128 \ + --overwrite_output_dir --per_device_eval_batch_size 4 \ + --predict_with_generate --dataset_config "ro-en" --fp16 \ + --source_lang en --target_lang ro --dataset_name wmt16 \ + --source_prefix "translate English to Romanian: " + +Since for inference there is no need for additional large memory used by the optimizer states and the gradients you +should be able to fit much larger batches and/or sequence length onto the same hardware. + + +Additionally DeepSpeed is currently developing a related product called Deepspeed-Inference which has no relationship +to the ZeRO technology, but instead uses tensor parallelism to scale models that can't fit onto a single GPU. This is a +work in progress and we will provide the integration once that product is complete. + Filing Issues ======================================================================================================================= diff --git a/setup.py b/setup.py index cf96f9e4ef1dc3..4d59a717f27047 100644 --- a/setup.py +++ b/setup.py @@ -97,7 +97,7 @@ "cookiecutter==1.7.2", "dataclasses", "datasets", - "deepspeed>=0.5.3", + "deepspeed>=0.5.7", "docutils==0.16.0", "fairscale>0.3", "faiss-cpu", diff --git a/src/transformers/deepspeed.py b/src/transformers/deepspeed.py index bb5d25d4b2375b..edbcbd50cca200 100644 --- a/src/transformers/deepspeed.py +++ b/src/transformers/deepspeed.py @@ -111,6 +111,29 @@ def get_value(self, ds_key_long, default=None): return default return config.get(ds_key, default) + def del_config_sub_tree(self, ds_key_long, must_exist=False): + """ + Deletes a sub-section of the config file if it's found. + + Unless ``must_exist`` is :obj:`True` the section doesn't have to exist. + """ + config = self.config + + # find the config node of interest if it exists + nodes = ds_key_long.split(".") + for node in nodes: + parent_config = config + config = config.get(node) + if config is None: + if must_exist: + raise ValueError(f"Can't find {ds_key_long} entry in the config: {self.config}") + else: + return + + # if found remove it + if parent_config is not None: + parent_config.pop(node) + def is_true(self, ds_key_long): """ Returns :obj:`True`/:obj:`False` only if the value is set, always :obj:`False` otherwise. So use this method to @@ -280,30 +303,10 @@ def deepspeed_config(): return None -def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None): +def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps): """ - Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args. - - If ``resume_from_checkpoint`` was passed then an attempt to resume from a previously saved checkpoint will be made. - - Args: - trainer: Trainer object - num_training_steps: per single gpu - resume_from_checkpoint: path to a checkpoint if to resume from after normal DeepSpeedEngine load - - Returns: model, optimizer, lr_scheduler - + A convenience wrapper that deals with optimizer and lr scheduler configuration. """ - import deepspeed - from deepspeed.utils import logger as ds_logger - - model = trainer.model - args = trainer.args - - hf_deepspeed_config = args.hf_deepspeed_config - hf_deepspeed_config.trainer_config_finalize(args, model, num_training_steps) - - # resume config update - some bits like `model` and `num_training_steps` only become available during train config = hf_deepspeed_config.config # Optimizer + Scheduler @@ -351,13 +354,54 @@ def _lr_scheduler_callable(optimizer): else: lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer) - # keep for quick debug: - # from pprint import pprint; pprint(config) + return optimizer, lr_scheduler + + +def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False): + """ + Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args. + + If ``resume_from_checkpoint`` was passed then an attempt to resume from a previously saved checkpoint will be made. + + Args: + trainer: Trainer object + num_training_steps: per single gpu + resume_from_checkpoint: path to a checkpoint if to resume from after normal DeepSpeedEngine load + inference: launch in inference mode (no optimizer and no lr scheduler) + + Returns: model, optimizer, lr_scheduler + + """ + import deepspeed + from deepspeed.utils import logger as ds_logger + + model = trainer.model + args = trainer.args + + # resume config update - some bits like `model` and `num_training_steps` only become available during train + hf_deepspeed_config = args.hf_deepspeed_config + hf_deepspeed_config.trainer_config_finalize(args, model, num_training_steps) + config = hf_deepspeed_config.config - # set the Deepspeed log level consistent with the trainer + # set the Deepspeed log level consistent with the Trainer ds_logger.setLevel(args.get_process_log_level()) - model_parameters = filter(lambda p: p.requires_grad, model.parameters()) + if inference: + # only Z3 makes sense for the inference + if not hf_deepspeed_config.is_zero3(): + raise ValueError("ZeRO inference only makes sense with ZeRO Stage 3 - please adjust your config") + + # in case the training config is re-used for inference + hf_deepspeed_config.del_config_sub_tree("optimizer") + hf_deepspeed_config.del_config_sub_tree("lr_scheduler") + optimizer, lr_scheduler = None, None + model_parameters = None + else: + optimizer, lr_scheduler = deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps) + model_parameters = filter(lambda p: p.requires_grad, model.parameters()) + + # keep for quick debug: + # from pprint import pprint; pprint(config) model, optimizer, _, lr_scheduler = deepspeed.initialize( model=model, diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 294cd16c9b1717..b074ffe13a36ef 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -8,7 +8,7 @@ "cookiecutter": "cookiecutter==1.7.2", "dataclasses": "dataclasses", "datasets": "datasets", - "deepspeed": "deepspeed>=0.5.3", + "deepspeed": "deepspeed>=0.5.7", "docutils": "docutils==0.16.0", "fairscale": "fairscale>0.3", "faiss-cpu": "faiss-cpu", diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index f954fe3ae016c9..7e6d500265725b 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2229,15 +2229,12 @@ def evaluation_loop( # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval # from the checkpoint eventually - deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None) + deepspeed_engine, _, _ = deepspeed_init( + self, num_training_steps=0, resume_from_checkpoint=None, inference=True + ) self.model = deepspeed_engine.module self.model_wrapped = deepspeed_engine self.deepspeed = deepspeed_engine - # XXX: we don't need optim/sched for inference, but this needs to be sorted out, since - # for example the Z3-optimizer is a must for zero3 to work even for inference - what we - # don't need is the deepspeed basic optimizer which is self.optimizer.optimizer - deepspeed_engine.optimizer.optimizer = None - deepspeed_engine.lr_scheduler = None model = self._wrap_model(self.model, training=False) diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py index 5c06d8b57f4d07..8e7587235df08e 100644 --- a/tests/deepspeed/test_deepspeed.py +++ b/tests/deepspeed/test_deepspeed.py @@ -697,11 +697,10 @@ class TestDeepSpeedWithLauncher(TestCasePlus): def test_basic_distributed(self, stage): self.run_and_check(stage=stage, distributed=True) - @parameterized.expand(stages) - def test_do_eval_no_train(self, stage): - # we should not fail if train is skipped + def test_do_eval_no_train(self): + # testing only zero3 since zero2 makes no sense with inference self.run_and_check( - stage=stage, + stage=ZERO3, eval_steps=1, distributed=False, do_train=False, @@ -755,6 +754,22 @@ def test_resume_train_not_from_ds_checkpoint(self, stage): self.do_checks(output_dir, do_train=do_train, do_eval=do_eval) + @require_torch_multi_gpu + @parameterized.expand(["fp16", "fp32"]) + def test_inference(self, dtype): + # this is just inference, so no optimizer should be loaded + # it only works for z3 (makes no sense with z1-z2) + fp16 = True if dtype == "fp16" else False + self.run_and_check( + stage=ZERO3, + model_name=T5_TINY, + distributed=True, + do_train=False, + do_eval=True, + quality_checks=False, + fp16=fp16, + ) + def do_checks(self, output_dir, do_train=True, do_eval=True, quality_checks=True): if do_train: