From 0867c09318818f96cb62f610fe8ea95da213bc37 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Tue, 6 Feb 2024 13:00:40 -0500 Subject: [PATCH] torch-native pipeline parallelism for big models (#2345) * Broken version * Timing I would expect * Working version! * Use MethodType * working test * Tests * Use no split module classes explicitly * Put split_points in pipelien * Store split points in hf_split_points * fix case num_process=1 * Allow for dynamic batch padding (#2352) * Allow for dynamic batch paddign * Fix test * Update src/accelerate/inference.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Break early after the first valid bs is found * Less slicy-dicy * Test cv model * Start, need to test * Use dataloader-like logic * Refactor to utils * With tests * Update the source * Clean * bs=1 case * Add test * add some failing test * Almost working version * Much cleaner implementation * Use pad_input_tensor * All tests passing! * Do it at tracing too --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: Marc Sun * Rm literal * Allow users to pass in max_memory * Note about recursion * Document, document, document * Right import check * Fix bug, add tests to multigpu runners * Change default to None * Start of docs * Try again? * Try again x2 * Trailing comma * Move import * Clean * typehint * typo * From code review * Use num_chunks * Update tests/test_utils.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Bad copy/paste * hf_split_points --------- Co-authored-by: Marc Sun Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- docs/source/_toctree.yml | 2 + docs/source/package_reference/inference.md | 20 +++ .../usage_guides/distributed_inference.md | 104 ++++++++++- src/accelerate/__init__.py | 1 + src/accelerate/inference.py | 164 ++++++++++++++++++ src/accelerate/test_utils/__init__.py | 1 + .../scripts/external_deps/test_pippy.py | 130 ++++++++++++++ src/accelerate/test_utils/testing.py | 8 + src/accelerate/utils/__init__.py | 3 + src/accelerate/utils/imports.py | 23 +-- src/accelerate/utils/operations.py | 57 ++++++ tests/test_multigpu.py | 26 ++- tests/test_utils.py | 67 ++++++- 13 files changed, 587 insertions(+), 19 deletions(-) create mode 100644 docs/source/package_reference/inference.md create mode 100644 src/accelerate/inference.py create mode 100644 src/accelerate/test_utils/scripts/external_deps/test_pippy.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 23d07dcd843..3a25edfcf0a 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -89,6 +89,8 @@ title: Logging - local: package_reference/big_modeling title: Working with large models + - local: package_reference/inference + title: Distributed inference with big models - local: package_reference/kwargs title: Kwargs handlers - local: package_reference/utilities diff --git a/docs/source/package_reference/inference.md b/docs/source/package_reference/inference.md new file mode 100644 index 00000000000..4347f98441e --- /dev/null +++ b/docs/source/package_reference/inference.md @@ -0,0 +1,20 @@ + + +# The inference API + +These docs refer to the [PiPPy](https://github.com/PyTorch/PiPPy) integration. 
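As a quick orientation, the entry point documented on this reference page can be imported either from the package root or from the `inference` module; a minimal sketch (both import paths come from this patch — see the usage guide further down for the full walkthrough):

```python
# prepare_pippy wraps an existing torch.nn.Module for PiPPy-based
# pipeline-parallel inference; both import paths are added by this patch.
from accelerate import prepare_pippy
# equivalently:
from accelerate.inference import prepare_pippy
```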
+ +[[autodoc]] inference.prepare_pippy diff --git a/docs/source/usage_guides/distributed_inference.md b/docs/source/usage_guides/distributed_inference.md index 41053658482..8ed896c8fb5 100644 --- a/docs/source/usage_guides/distributed_inference.md +++ b/docs/source/usage_guides/distributed_inference.md @@ -15,12 +15,18 @@ rendered properly in your Markdown viewer. # Distributed Inference with 🤗 Accelerate -Distributed inference is a common use case, especially with natural language processing (NLP) models. Users often want to -send a number of different prompts, each to a different GPU, and then get the results back. This also has other cases -outside of just NLP, however for this tutorial we will focus on just this idea of each GPU receiving a different prompt, -and then returning the results. +Distributed inference can fall into three brackets: -## The Problem +1. Loading an entire model onto each GPU and sending chunks of a batch through each GPU's model copy at a time +2. Loading parts of a model onto each GPU and processing a single input at one time +3. Loading parts of a model onto each GPU and using what is called scheduled Pipeline Parallelism to combine the two prior techniques. + +We're going to go through the first and the last bracket, showcasing how to do each as they are more realistic scenarios. + + +## Sending chunks of a batch automatically to each loaded model + +This is the most memory-intensive solution, as it requires each GPU to keep a full copy of the model in memory at a given time. Normally when doing this, users send the model to a specific device to load it from the CPU, and then move each prompt to a different device. @@ -55,7 +61,6 @@ a simple way to manage this. (To learn more, check out the relevant section in t Can it manage it? Yes. Does it add unneeded extra code however: also yes. -## The Solution With 🤗 Accelerate, we can simplify this process by using the [`Accelerator.split_between_processes`] context manager (which also exists in `PartialState` and `AcceleratorState`). This function will automatically split whatever data you pass to it (be it a prompt, a set of tensors, a dictionary of the prior data, etc.) across all the processes (with a potential @@ -134,3 +139,90 @@ with distributed_state.split_between_processes(["a dog", "a cat", "a chicken"], On the first GPU, the prompts will be `["a dog", "a cat"]`, and on the second GPU it will be `["a chicken", "a chicken"]`. Make sure to drop the final sample, as it will be a duplicate of the previous one. + +## Memory-efficient pipeline parallelism (experimental) + +This next part will discuss using *pipeline parallelism*. This is an **experimental** API utilizing the [PiPPy library by PyTorch](https://github.com/pytorch/PiPPy/) as a native solution. + +The general idea with pipeline parallelism is: say you have 4 GPUs and a model big enough it can be *split* on four GPUs using `device_map="auto"`. With this method you can send in 4 inputs at a time (for example here, any amount works) and each model chunk will work on an input, then receive the next input once the prior chunk finished, making it *much* more efficient **and faster** than the method described earlier. 
Here's a visual taken from the PyTorch repository:
+
+![PiPPy example](https://camo.githubusercontent.com/681d7f415d6142face9dd1b837bdb2e340e5e01a58c3a4b119dea6c0d99e2ce0/68747470733a2f2f692e696d6775722e636f6d2f657955633934372e706e67)
+
+To illustrate how you can use this with Accelerate, we have created a [model zoo example](https://github.com/muellerzr/pippy-device-map-playground/) showcasing a number of different models and situations. In this tutorial, we'll show this method for GPT2 across two GPUs.
+
+Before you proceed, please make sure you have the latest pippy installed by running the following:
+
+```bash
+pip install torchpippy
+```
+
+We require at least version 0.2.0. To confirm that you have the correct version, run `pip show torchpippy`.
+
+Start by creating the model on the CPU:
+
+```{python}
+from transformers import GPT2ForSequenceClassification, GPT2Config
+
+config = GPT2Config()
+model = GPT2ForSequenceClassification(config)
+model.eval()
+```
+
+Next you'll need to create some example inputs to use. These help PiPPy trace the model.
+
+  How you create this example will determine the relative batch size that will be used/passed
+  through the model at a given time, so make sure to remember how many items there are!
+
+```{python}
+input = torch.randint(
+    low=0,
+    high=config.vocab_size,
+    size=(2, 1024),  # bs x seq_len
+    device="cpu",
+    dtype=torch.int64,
+    requires_grad=False,
+)
+```
+Next, we need to actually perform the tracing and get the model ready. To do so, use the [`inference.prepare_pippy`] function; it will fully wrap the model for pipeline parallelism automatically:
+
+```{python}
+from accelerate.inference import prepare_pippy
+example_inputs = {"input_ids": input}
+model = prepare_pippy(model, example_args=(input,))
+```
+
+  There are a variety of parameters you can pass through to `prepare_pippy`:
+
+  * `split_points` lets you determine which layers to split the model at. By default we use wherever `device_map="auto"` declares, such as `fc` or `conv1`.
+
+  * `num_chunks` determines how the batch will be split and sent to the model itself (so `num_chunks=1` with four split points/four GPUs results in naive model parallelism, where a single input gets passed between the four layer split points).
+
+From here, all that's left is to actually perform the distributed inference!
+
+When passing inputs, we highly recommend passing them in as a tuple of arguments. Using `kwargs` is supported; however, this approach is experimental.
+
+```{python}
+args = some_more_arguments
+with torch.no_grad():
+    output = model(*args)
+```
+
+When finished, all the data will be on the last GPU, which you can find and extract using [`PartialState`]:
+
+```{python}
+from accelerate import PartialState
+
+if PartialState().is_last_process:
+    print(output)
+```
+
+And that's it! To explore more, please check out the examples in [this repository](https://github.com/muellerzr/pippy-device-map-playground/) and our documentation as we work on improving this integration.
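Since the tutorial snippets above are shown piecemeal, here is a consolidated sketch of the same GPT2 walkthrough as one launchable script. This is an illustration only: the `num_chunks=2` value, the `.to("cuda")` move, and the `accelerate launch` command are assumptions based on the guide and the test script in this patch, not additional API surface.

```python
# Consolidated sketch of the GPT2 example above. Launch with something like:
#   accelerate launch --num_processes 2 pippy_gpt2_example.py
import torch
from transformers import GPT2Config, GPT2ForSequenceClassification

from accelerate import PartialState, prepare_pippy

config = GPT2Config()
model = GPT2ForSequenceClassification(config)
model.eval()

# Two sequences, so each of the two pipeline chunks gets one input at tracing time.
inputs = torch.randint(
    low=0,
    high=config.vocab_size,
    size=(2, 1024),  # bs x seq_len
    device="cpu",
    dtype=torch.int64,
)

# split_points="auto" is the default and num_chunks falls back to the number of
# processes; both are spelled out here only to show where the knobs plug in.
model = prepare_pippy(model, example_args=(inputs,), split_points="auto", num_chunks=2)

with torch.no_grad():
    output = model(inputs.to("cuda"))

# Only the last process holds the assembled output; other ranks return None.
if PartialState().is_last_process:
    print(output)
```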
diff --git a/src/accelerate/__init__.py b/src/accelerate/__init__.py index 6d45e3a483a..65791f2d106 100644 --- a/src/accelerate/__init__.py +++ b/src/accelerate/__init__.py @@ -11,6 +11,7 @@ load_checkpoint_and_dispatch, ) from .data_loader import skip_first_batches +from .inference import prepare_pippy from .launchers import debug_launcher, notebook_launcher from .state import PartialState from .utils import ( diff --git a/src/accelerate/inference.py b/src/accelerate/inference.py new file mode 100644 index 00000000000..13ac7402966 --- /dev/null +++ b/src/accelerate/inference.py @@ -0,0 +1,164 @@ +import math +from types import MethodType +from typing import Any, Dict, Optional + +from .state import PartialState +from .utils import ( + calculate_maximum_sizes, + convert_bytes, + ignorant_find_batch_size, + infer_auto_device_map, + is_pippy_available, + pad_input_tensors, + send_to_device, +) + + +if is_pippy_available(): + from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points + from pippy.PipelineStage import PipelineStage + + +def generate_device_map(model, num_processes: int = 1, no_split_module_classes=None, max_memory: dict = None): + """ + Calculates the device map for `model` with an offset for PiPPy + """ + if num_processes == 1: + return infer_auto_device_map(model, no_split_module_classes=no_split_module_classes, clean_result=False) + if max_memory is None: + model_size, shared = calculate_maximum_sizes(model) + + # Split into `n` chunks for each GPU + memory = (model_size + shared[0]) / num_processes + memory = convert_bytes(memory) + value, ending = memory.split(" ") + + # Add a chunk to deal with potential extra shared memory instances + memory = math.ceil(float(value)) * 1.1 + memory = f"{memory} {ending}" + max_memory = {i: memory for i in range(num_processes)} + device_map = infer_auto_device_map( + model, + max_memory=max_memory, + no_split_module_classes=no_split_module_classes, + clean_result=False, + ) + return device_map + + +def find_pippy_batch_size(args, kwargs): + found_batch_size = None + for arg in args: + found_batch_size = ignorant_find_batch_size(arg) + if found_batch_size is not None: + break + for kwarg in kwargs.values(): + found_batch_size = ignorant_find_batch_size(kwarg) + if found_batch_size is not None: + break + return found_batch_size + + +def build_pipeline(model, split_points, args, kwargs, num_chunks): + """ + Attaches the split points to the model based on `self.device_map` and generates a `PipelineStage`. Requires passing + in needed `args` and `kwargs` as the model needs on the CPU. + + Users can pass in custom `num_chunks` as an optional hyper-parameter. 
By default will use + `AcceleratorState.num_processes` + """ + # We need to annotate the split points in the model for PiPPy + state = PartialState() + annotate_split_points(model, {split_point: PipeSplitWrapper.SplitPoint.BEGINNING for split_point in split_points}) + found_batch_size = find_pippy_batch_size(args, kwargs) + if found_batch_size != num_chunks: + args = pad_input_tensors(args, found_batch_size, num_chunks) + kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks) + pipe = Pipe.from_tracing(model, num_chunks=num_chunks, example_args=args, example_kwargs=kwargs) + stage = PipelineStage(pipe, state.local_process_index, device=state.device) + + return stage + + +def pippy_forward(forward, num_chunks, *args, **kwargs): + state = PartialState() + output = None + + if state.num_processes == 1: + output = forward(*args, **kwargs) + elif state.is_local_main_process: + found_batch_size = find_pippy_batch_size(args, kwargs) + if found_batch_size is None: + raise ValueError("Could not find batch size from args or kwargs") + else: + if found_batch_size != num_chunks: + args = pad_input_tensors(args, found_batch_size, num_chunks) + kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks) + forward(*args, **kwargs) + elif state.is_last_process: + output = forward() + else: + forward() + return output + + +def prepare_pippy( + model, + split_points="auto", + no_split_module_classes=None, + example_args=(), + example_kwargs: Optional[Dict[str, Any]] = None, + num_chunks=None, +): + """ + Wraps `model` for PipelineParallelism + + Args: + model (`torch.nn.Module`): + A model we want to split for pipeline-parallel inference + split_points (`str`, defaults to 'auto'): + How to generate the split points and chunk the model across each GPU. 'auto' will find the best balanced + split given any model. + no_split_module_classes (`List[str]`): + A list of class names for layers we don't want to be split. + example_args (tuple of `torch.Tensor`): + The expected inputs for the model that uses order-based inputs. Recommended to use this method if possible. + example_kwargs (dict of `torch.Tensor`) + The expected inputs for the model that uses dictionary-based inputs. This is a *highly* limiting structure + that requires the same keys be present at *all* inference calls. Not recommended unless the prior condition + is true for all cases. + num_chunks (`int`): + The number of different stages the Pipeline will have. By default it will assign one chunk per GPU, but + this can be tuned and played with. In general one should have num_chunks >= num_gpus. + """ + if not is_pippy_available(): + raise ImportError( + "`pippy` was not found to be installed on your system. 
Please " + "install using `pip install torchpippy` or ensure you have at least version 0.2.0" + ) + state = PartialState() + example_args = send_to_device(example_args, "cpu") + example_kwargs = send_to_device(example_kwargs, "cpu") + if num_chunks is None: + num_chunks = state.num_processes + if split_points == "auto": + device_map = generate_device_map(model, num_chunks, no_split_module_classes=no_split_module_classes) + split_points = [] + for i in range(1, num_chunks): + split_points.append(next(k for k, v in device_map.items() if v == i)) + model.hf_split_points = split_points + stage = build_pipeline(model, split_points, example_args, example_kwargs, num_chunks) + model._original_forward = model.forward + model._original_call = model.__call__ + model.pippy_stage = stage + model.hf_split_points = split_points + + def forward(*args, **kwargs): + return pippy_forward(stage.forward, num_chunks, *args, **kwargs) + + # To act like a decorator so that it can be popped when doing `extract_model_from_parallel` + # Note: creates an infinite recursion loop with `generate` + model_forward = MethodType(forward, model) + forward.__wrapped__ = model_forward + model.forward = forward + return model diff --git a/src/accelerate/test_utils/__init__.py b/src/accelerate/test_utils/__init__.py index 2d7e8b177ad..0dbee933ed3 100644 --- a/src/accelerate/test_utils/__init__.py +++ b/src/accelerate/test_utils/__init__.py @@ -13,6 +13,7 @@ require_multi_gpu, require_multi_xpu, require_non_cpu, + require_pippy, require_single_device, require_single_gpu, require_single_xpu, diff --git a/src/accelerate/test_utils/scripts/external_deps/test_pippy.py b/src/accelerate/test_utils/scripts/external_deps/test_pippy.py new file mode 100644 index 00000000000..9c2e7bb32de --- /dev/null +++ b/src/accelerate/test_utils/scripts/external_deps/test_pippy.py @@ -0,0 +1,130 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch +from torchvision.models import resnet34 +from transformers import ( + BertConfig, + BertForMaskedLM, + GPT2Config, + GPT2ForSequenceClassification, + T5Config, + T5ForConditionalGeneration, +) + +from accelerate import PartialState +from accelerate.inference import prepare_pippy +from accelerate.utils import DistributedType, send_to_device, set_seed + + +model_to_config = { + "t5": (T5ForConditionalGeneration, T5Config, 1024), + "bert": (BertForMaskedLM, BertConfig, 512), + "gpt2": (GPT2ForSequenceClassification, GPT2Config, 1024), +} + + +def get_model_and_data_for_text(model_name, device, num_processes: int = 2): + initializer, config, seq_len = model_to_config[model_name] + config_args = {} + # Eventually needed for batch inference tests on gpt-2 when bs != 1 + # if model_name == "gpt2": + # config_args["pad_token_id"] = 0 + model_config = config(**config_args) + model = initializer(model_config) + return model, torch.randint( + low=0, + high=model_config.vocab_size, + size=(num_processes, seq_len), + device=device, + dtype=torch.int64, + requires_grad=False, + ) + + +def test_gpt2(batch_size: int = 2): + set_seed(42) + state = PartialState() + model, inputs = get_model_and_data_for_text("gpt2", "cpu", batch_size) + model = prepare_pippy(model, example_args=(inputs,), no_split_module_classes=model._no_split_modules) + # For inference args need to be a tuple + inputs = inputs.to("cuda") + with torch.no_grad(): + output = model(inputs) + # Zach: Check that we just grab the real outputs we need at the end + if not state.is_last_process: + assert output is None, "Output was not generated on just the last process!" + else: + assert output is not None, "Output was not generated in the last process!" + + +def test_t5(batch_size: int = 2): + set_seed(42) + state = PartialState() + model, inputs = get_model_and_data_for_text("t5", "cpu", batch_size) + example_inputs = {"input_ids": inputs, "decoder_input_ids": inputs} + model = prepare_pippy( + model, + no_split_module_classes=model._no_split_modules, + example_kwargs=example_inputs, + ) + # For inference args need to be a tuple + inputs = send_to_device(example_inputs, "cuda:0") + with torch.no_grad(): + output = model(*inputs.values()) + # Zach: Check that we just grab the real outputs we need at the end + if not state.is_last_process: + assert output is None, "Output was not generated on just the last process!" + else: + assert output is not None, "Output was not generated in the last process!" + + +def test_resnet(batch_size: int = 2): + set_seed(42) + state = PartialState() + model = resnet34() + input_tensor = torch.rand(batch_size, 3, 224, 224) + model = prepare_pippy( + model, + example_args=(input_tensor,), + ) + inputs = send_to_device(input_tensor, "cuda:0") + with torch.no_grad(): + output = model(inputs) + # Zach: Check that we just grab the real outputs we need at the end + if not state.is_last_process: + assert output is None, "Output was not generated on just the last process!" + else: + assert output is not None, "Output was not generated in the last process!" 
+ + +if __name__ == "__main__": + state = PartialState() + state.print("Testing pippy integration...") + if state.distributed_type == DistributedType.MULTI_GPU: + state.print("Testing GPT2...") + test_gpt2() + # Issue: When modifying the tokenizer for batch GPT2 inference, there's an issue + # due to references + # NameError: cannot access free variable 'chunk_args_list' where it is not associated with a value in enclosing scope + # test_gpt2(3) + state.print("Testing T5...") + test_t5() + test_t5(1) + test_t5(3) + state.print("Testing CV model...") + test_resnet() + test_resnet(3) + else: + print("Less than two GPUs found, not running tests!") diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py index 90b5b8d9798..cdeee5fa995 100644 --- a/src/accelerate/test_utils/testing.py +++ b/src/accelerate/test_utils/testing.py @@ -40,6 +40,7 @@ is_mps_available, is_npu_available, is_pandas_available, + is_pippy_available, is_tensorboard_available, is_timm_available, is_torch_version, @@ -290,6 +291,13 @@ def require_pandas(test_case): return unittest.skipUnless(is_pandas_available(), "test requires pandas")(test_case) +def require_pippy(test_case): + """ + Decorator marking a test that requires pippy installed. These tests are skipped when pippy isn't installed + """ + return unittest.skipUnless(is_pippy_available(), "test requires pippy")(test_case) + + _atleast_one_tracker_available = ( any([is_wandb_available(), is_tensorboard_available()]) and not is_comet_ml_available() ) diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py index 121b9ffca77..4b3a5efa6ea 100644 --- a/src/accelerate/utils/__init__.py +++ b/src/accelerate/utils/__init__.py @@ -70,6 +70,7 @@ is_npu_available, is_pandas_available, is_peft_available, + is_pippy_available, is_rich_available, is_sagemaker_available, is_tensorboard_available, @@ -126,12 +127,14 @@ gather_object, get_data_structure, honor_type, + ignorant_find_batch_size, initialize_tensors, is_namedtuple, is_tensor_information, is_torch_tensor, listify, pad_across_processes, + pad_input_tensors, recursively_apply, reduce, send_to_device, diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py index f2147e75986..7a0947b4d96 100644 --- a/src/accelerate/utils/imports.py +++ b/src/accelerate/utils/imports.py @@ -38,12 +38,13 @@ _torch_distributed_available = torch.distributed.is_available() -def _is_package_available(pkg_name): +def _is_package_available(pkg_name, metadata_name=None): # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version package_exists = importlib.util.find_spec(pkg_name) is not None if package_exists: try: - _ = importlib.metadata.metadata(pkg_name) + # Some libraries have different names in the metadata + _ = importlib.metadata.metadata(pkg_name if metadata_name is None else metadata_name) return True except importlib.metadata.PackageNotFoundError: return False @@ -73,15 +74,7 @@ def get_ccl_version(): def is_msamp_available(): - package_exists = importlib.util.find_spec("msamp") is not None - if package_exists: - try: - # MS-AMP has a different metadata name - _ = importlib.metadata.metadata("ms-amp") - return True - except importlib.metadata.PackageNotFoundError: - return False - return False + return _is_package_available("msamp", "ms-amp") def is_transformer_engine_available(): @@ -126,6 +119,14 @@ def is_deepspeed_available(): return _is_package_available("deepspeed") +def is_pippy_available(): + 
package_exists = _is_package_available("pippy", "torchpippy") + if package_exists: + pippy_version = version.parse(importlib.metadata.version("torchpippy")) + return compare_versions(pippy_version, ">", "0.1.1") + return False + + def is_bf16_available(ignore_tpu=False): "Checks if bf16 is supported, optionally ignoring the TPU" if is_tpu_available(): diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py index 27b2dc46457..e6b2350b2c0 100644 --- a/src/accelerate/utils/operations.py +++ b/src/accelerate/utils/operations.py @@ -263,6 +263,23 @@ def find_batch_size(data): return data.shape[0] +def ignorant_find_batch_size(data): + """ + Same as [`utils.operations.find_batch_size`] except will ignore if `ValueError` and `TypeErrors` are raised + + Args: + data (nested list/tuple/dictionary of `torch.Tensor`): The data from which to find the batch size. + + Returns: + `int`: The batch size. + """ + try: + return find_batch_size(data) + except (ValueError, TypeError): + pass + return None + + def listify(data): """ Recursively finds tensors in a nested list/tuple/dictionary and converts them to a list of numbers. @@ -606,6 +623,46 @@ def _pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False): ) +def pad_input_tensors(tensor, batch_size, num_processes, dim=0): + """ + Takes a `tensor` of arbitrary size and pads it so that it can work given `num_processes` needed dimensions. + + New tensors are just the last input repeated. + + E.g.: + Tensor: ([3,4,4]) Num processes: 4 Expected result shape: ([4,4,4]) + + """ + + def _pad_input_tensors(tensor, batch_size, num_processes, dim=0): + remainder = batch_size // num_processes + last_inputs = batch_size - (remainder * num_processes) + if batch_size // num_processes == 0: + to_pad = num_processes - batch_size + else: + to_pad = num_processes - (batch_size // num_processes) + # In the rare case that `to_pad` is negative, + # we need to pad the last inputs - the found `to_pad` + if last_inputs > to_pad & to_pad < 1: + to_pad = last_inputs - to_pad + old_size = tensor.shape + new_size = list(old_size) + new_size[0] = batch_size + to_pad + new_tensor = tensor.new_zeros(tuple(new_size)) + indices = tuple(slice(0, old_size[dim]) if i == dim else slice(None) for i in range(len(new_size))) + new_tensor[indices] = tensor + return new_tensor + + return recursively_apply( + _pad_input_tensors, + tensor, + error_on_other_type=True, + batch_size=batch_size, + num_processes=num_processes, + dim=dim, + ) + + @verify_operation def reduce(tensor, reduction="mean", scale=1.0): """ diff --git a/tests/test_multigpu.py b/tests/test_multigpu.py index 140ed7f8247..f9cb472491e 100644 --- a/tests/test_multigpu.py +++ b/tests/test_multigpu.py @@ -21,7 +21,14 @@ import accelerate from accelerate import Accelerator from accelerate.big_modeling import dispatch_model -from accelerate.test_utils import assert_exception, device_count, execute_subprocess_async, require_multi_device +from accelerate.test_utils import ( + assert_exception, + device_count, + execute_subprocess_async, + require_multi_device, + require_multi_gpu, + require_pippy, +) from accelerate.utils import patch_environment @@ -66,6 +73,23 @@ def test_distributed_data_loop(self): with patch_environment(omp_num_threads=1, cuda_visible_devices="0,1"): execute_subprocess_async(cmd, env=os.environ.copy()) + @require_multi_gpu + @require_pippy + def test_pippy(self): + """ + Checks the integration with the pippy framework + """ + print(f"Found {torch.cuda.device_count()} devices") + 
cmd = [ + "accelerate", + "launch", + "--multi_gpu", + f"--num_processes={torch.cuda.device_count()}", + self.pippy_file_path, + ] + with patch_environment(omp_num_threads=1): + execute_subprocess_async(cmd, env=os.environ.copy()) + if __name__ == "__main__": accelerator = Accelerator() diff --git a/tests/test_utils.py b/tests/test_utils.py index 239214bfc3c..4658e44ef1a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import os import pickle import tempfile @@ -34,6 +33,7 @@ find_device, listify, pad_across_processes, + pad_input_tensors, patch_environment, recursively_apply, save, @@ -237,3 +237,68 @@ def test_pad_across_processes(self): with self.assertWarns(CannotPadNestedTensorWarning): nt2 = pad_across_processes(nt) self.assertIs(nt, nt2) + + def test_slice_and_concatenate(self): + # First base case: 2 processes, batch size of 1 + num_processes = 2 + batch_size = 1 + batch = torch.rand(batch_size, 4) + result = pad_input_tensors(batch, batch_size, num_processes) + # We should expect there to be 2 items now + assert result.shape == torch.Size([2, 4]) + + # Second base case: 2 processes, batch size of 3 + num_processes = 2 + batch_size = 3 + batch = torch.rand(batch_size, 4) + result = pad_input_tensors(batch, batch_size, num_processes) + # We should expect there to be 4 items now + assert result.shape == torch.Size([4, 4]) + + # Third base case: 3 processes, batch size of 4 + num_processes = 3 + batch_size = 4 + batch = torch.rand(batch_size, 4, 4) + result = pad_input_tensors(batch, batch_size, num_processes) + # We should expect there to be 6 items now + assert result.shape == torch.Size([6, 4, 4]) + + # Fourth base case: 4 processes, batch size of 3 + num_processes = 4 + batch_size = 3 + batch = torch.rand(batch_size, 4, 4) + result = pad_input_tensors(batch, batch_size, num_processes) + # We should expect there to be 4 items now + assert result.shape == torch.Size([4, 4, 4]) + + # Fifth base case: 6 processes, batch size of 4 + num_processes = 6 + batch_size = 4 + batch = torch.rand(batch_size, 4, 4) + result = pad_input_tensors(batch, batch_size, num_processes) + # We should expect there to be 6 items now + assert result.shape == torch.Size([6, 4, 4]) + + # Sixth base case: 6 processes, batch size of 1 + num_processes = 6 + batch_size = 1 + batch = torch.rand(batch_size, 4, 4) + result = pad_input_tensors(batch, batch_size, num_processes) + # We should expect there to be 6 items now + assert result.shape == torch.Size([6, 4, 4]) + + # Seventh base case: 6 processes, batch size of 2 + num_processes = 6 + batch_size = 2 + batch = torch.rand(batch_size, 4, 4) + result = pad_input_tensors(batch, batch_size, num_processes) + # We should expect there to be 6 items now + assert result.shape == torch.Size([6, 4, 4]) + + # Eighth base case: 6 processes, batch size of 61 + num_processes = 6 + batch_size = 61 + batch = torch.rand(batch_size, 4, 4) + result = pad_input_tensors(batch, batch_size, num_processes) + # We should expect there to be 66 items now + assert result.shape == torch.Size([66, 4, 4])
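
As an aside, every shape expectation in `test_slice_and_concatenate` above is consistent with one rule: `pad_input_tensors` grows the batch dimension up to the next multiple of `num_processes` so the batch can be split evenly across processes. A minimal sketch of that rule (an illustration of the expected sizes only, not the library's implementation):

```python
import math


def expected_padded_batch(batch_size: int, num_processes: int) -> int:
    # Round the batch dimension up to the next multiple of num_processes.
    return math.ceil(batch_size / num_processes) * num_processes


# Mirrors the cases exercised in the tests above.
assert expected_padded_batch(1, 2) == 2
assert expected_padded_batch(3, 2) == 4
assert expected_padded_batch(4, 3) == 6
assert expected_padded_batch(3, 4) == 4
assert expected_padded_batch(61, 6) == 66
```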