torch-native pipeline parallelism for big models (#2345)
* Broken version

* Timing I would expect

* Working version!

* Use MethodType

* working test

* Tests

* Use no split module classes explicitly

* Put split_points in pipeline

* Store split points in hf_split_points

* fix case num_process=1

* Allow for dynamic batch padding (#2352)

* Allow for dynamic batch padding

* Fix test

* Update src/accelerate/inference.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Break early after the first valid bs is found

* Less slicy-dicy

* Test cv model

* Start, need to test

* Use dataloader-like logic

* Refactor to utils

* With tests

* Update the source

* Clean

* bs=1 case

* Add test

* add some failing test

* Almost working version

* Much cleaner implementation

* Use pad_input_tensor

* All tests passing!

* Do it at tracing too

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Marc Sun <marc@huggingface.co>

* Rm literal

* Allow users to pass in max_memory

* Note about recursion

* Document, document, document

* Right import check

* Fix bug, add tests to multigpu runners

* Change default to None

* Start of docs

* Try again?

* Try again x2

* Trailing comma

* Move import

* Clean

* typehint

* typo

* From code review

* Use num_chunks

* Update tests/test_utils.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Bad copy/paste

* hf_split_points

---------

Co-authored-by: Marc Sun <marc@huggingface.co>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
3 people committed Feb 6, 2024
1 parent 0e1ee4b commit 0867c09
Showing 13 changed files with 587 additions and 19 deletions.
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -89,6 +89,8 @@
    title: Logging
  - local: package_reference/big_modeling
    title: Working with large models
  - local: package_reference/inference
    title: Distributed inference with big models
  - local: package_reference/kwargs
    title: Kwargs handlers
  - local: package_reference/utilities
20 changes: 20 additions & 0 deletions docs/source/package_reference/inference.md
@@ -0,0 +1,20 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# The inference API

These docs refer to the [PiPPy](https://github.com/PyTorch/PiPPy) integration.

[[autodoc]] inference.prepare_pippy
104 changes: 98 additions & 6 deletions docs/source/usage_guides/distributed_inference.md
@@ -15,12 +15,18 @@ rendered properly in your Markdown viewer.

# Distributed Inference with 🤗 Accelerate

Distributed inference is a common use case, especially with natural language processing (NLP) models. Users often want to
send a number of different prompts, each to a different GPU, and then get the results back. This also has other cases
outside of just NLP; however, for this tutorial we will focus on just this idea of each GPU receiving a different prompt
and then returning the results.
Distributed inference can fall into three brackets:

## The Problem
1. Loading an entire model onto each GPU and sending chunks of a batch through each GPU's model copy at a time
2. Loading parts of a model onto each GPU and processing a single input at one time
3. Loading parts of a model onto each GPU and using what is called scheduled Pipeline Parallelism to combine the two prior techniques.

We're going to go through the first and the last bracket, showcasing how to do each, as they are the more realistic scenarios.


## Sending chunks of a batch automatically to each loaded model

This is the most memory-intensive solution, as it requires each GPU to keep a full copy of the model in memory at a given time.

Normally when doing this, users send the model to a specific device to load it from the CPU, and then move each prompt to a different device.

@@ -55,7 +61,6 @@ a simple way to manage this. (To learn more, check out the relevant section in t

Can it manage it? Yes. Does it add unneeded extra code, however? Also yes.

## The Solution

With 🤗 Accelerate, we can simplify this process by using the [`Accelerator.split_between_processes`] context manager (which also exists in `PartialState` and `AcceleratorState`).
This function will automatically split whatever data you pass to it (be it a prompt, a set of tensors, a dictionary of the prior data, etc.) across all the processes (with a potential
@@ -134,3 +139,90 @@ with distributed_state.split_between_processes(["a dog", "a cat", "a chicken"],

On the first GPU, the prompts will be `["a dog", "a cat"]`, and on the second GPU it will be `["a chicken", "a chicken"]`.
Make sure to drop the final sample, as it will be a duplicate of the previous one.
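A minimal sketch of that last step (assuming two processes; `generate_one` is a hypothetical stand-in for your actual model call, not an Accelerate API):

```{python}
# Minimal sketch: drop the padded duplicate on the last process
# `generate_one` is a hypothetical helper standing in for your model call
from accelerate import PartialState

distributed_state = PartialState()
with distributed_state.split_between_processes(["a dog", "a cat", "a chicken"], apply_padding=True) as prompts:
    results = [generate_one(prompt) for prompt in prompts]
    if distributed_state.is_last_process:
        results = results[:-1]  # remove the duplicate introduced by padding
```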

## Memory-efficient pipeline parallelism (experimental)

This next part will discuss using *pipeline parallelism*. This is an **experimental** API utilizing the [PiPPy library by PyTorch](https://github.com/pytorch/PiPPy/) as a native solution.

The general idea with pipeline parallelism is: say you have 4 GPUs and a model big enough that it can be *split* across four GPUs using `device_map="auto"`. With this method you can send in 4 inputs at a time (for example here; any amount works) and each model chunk will work on an input, then receive the next input once the prior chunk finishes, making it *much* more efficient **and faster** than the method described earlier. Here's a visual taken from the PyTorch repository:

![PiPPy example](https://camo.githubusercontent.com/681d7f415d6142face9dd1b837bdb2e340e5e01a58c3a4b119dea6c0d99e2ce0/68747470733a2f2f692e696d6775722e636f6d2f657955633934372e706e67)
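
To get a feel for why this is faster, here is a rough back-of-the-envelope sketch (our own illustration, assuming every stage takes the same time per chunk): sequential model parallelism costs about `num_chunks * num_stages` stage-times, while a pipelined schedule costs about `num_stages + num_chunks - 1`:

```{python}
# Rough pipeline speedup estimate, assuming equal per-stage times
num_stages = 4  # model split across 4 GPUs
num_chunks = 4  # 4 microbatches in flight

sequential = num_chunks * num_stages     # 16 stage-times: each input traverses all stages before the next starts
pipelined = num_stages + num_chunks - 1  # 7 stage-times: stages work on different chunks concurrently
print(f"~{sequential / pipelined:.1f}x fewer stage-times")  # ~2.3x
```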

To illustrate how you can use this with Accelerate, we have created a [model zoo example](https://github.com/muellerzr/pippy-device-map-playground/) showcasing a number of different models and situations. In this tutorial, we'll show this method for GPT2 across two GPUs.

Before you proceed, please make sure you have the latest version of PiPPy installed by running the following:

```bash
pip install torchpippy
```

We require at least version 0.2.0. To confirm that you have the correct version, run `pip show torchpippy`.
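
If you'd rather check programmatically, here is a small sketch (assuming the pip distribution is named `torchpippy`, as in the commands above):

```{python}
# Sketch of a runtime version check; `torchpippy` is the pip distribution name used above
from importlib.metadata import version
from packaging.version import parse

assert parse(version("torchpippy")) >= parse("0.2.0"), "Please upgrade with `pip install -U torchpippy`"
```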

Start by creating the model on the CPU:

```{python}
from transformers import GPT2ForSequenceClassification, GPT2Config
config = GPT2Config()
model = GPT2ForSequenceClassification(config)
model.eval()
```

Next you'll need to create some example inputs to use. These help PiPPy trace the model.

<Tip warning={true}>
How you construct this example input will determine the relative batch size that will be used/passed
through the model at a given time, so make sure to remember how many items there are!
</Tip>

```{python}
import torch

input = torch.randint(
    low=0,
    high=config.vocab_size,
    size=(2, 1024),  # bs x seq_len
    device="cpu",
    dtype=torch.int64,
    requires_grad=False,
)
```

Next, we need to actually perform the tracing and get the model ready. To do so, use the [`inference.prepare_pippy`] function and it will fully wrap the model for pipeline parallelism automatically:

```{python}
from accelerate.inference import prepare_pippy

model = prepare_pippy(model, example_args=(input,))
```
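
If your model only takes dictionary-style inputs, `prepare_pippy` also accepts `example_kwargs` (with the caveats described in the API reference); a sketch under that assumption:

```{python}
# Sketch: keyword-based tracing; the same keys must be present at every inference call
model = prepare_pippy(model, example_kwargs={"input_ids": input})
```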

<Tip>

There are a variety of parameters you can pass through to `prepare_pippy`:

* `split_points` lets you determine which layers to split the model at. By default we use wherever `device_map="auto"` declares, such as `fc` or `conv1`.

* `num_chunks` determines how the batch will be split and sent to the model itself (so `num_chunks=1` with four split points/four GPUs will result in naive MP, where a single input gets passed between the four layer split points)

</Tip>
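
For instance, a sketch of overriding both (the split point below is an illustrative GPT-2 module name, not a value verified in this guide):

```{python}
# Sketch: manually split after GPT-2's 6th block and keep 2 chunks in flight
model = prepare_pippy(
    model,
    split_points=["transformer.h.6"],  # illustrative module name; adjust for your model
    num_chunks=2,
    example_args=(input,),
)
```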

From here, all that's left is to actually perform the distributed inference!

<Tip warning={true}>

When passing inputs, we highly recommend passing them in as a tuple of arguments. Using `kwargs` is supported, however, this approach is experimental.
</Tip>

```{python}
args = some_more_arguments
with torch.no_grad():
    output = model(*args)
```
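
For instance, you could simply reuse the example input from tracing (an assumption for illustration; any input with a compatible batch size works thanks to the padding logic):

```{python}
# Sketch: reuse the traced example input as the inference batch
args = (input,)
with torch.no_grad():
    output = model(*args)
```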

When finished, all the data will be on the last GPU, which you can find and extract using [`PartialState`]:

```{python}
from accelerate import PartialState

if PartialState().is_last_process:
    print(output)
```

And that's it! To explore more, please check out the examples in [this repository](https://github.com/muellerzr/pippy-device-map-playground/) and our documentation as we work on improving this integration.
1 change: 1 addition & 0 deletions src/accelerate/__init__.py
@@ -11,6 +11,7 @@
    load_checkpoint_and_dispatch,
)
from .data_loader import skip_first_batches
from .inference import prepare_pippy
from .launchers import debug_launcher, notebook_launcher
from .state import PartialState
from .utils import (
164 changes: 164 additions & 0 deletions src/accelerate/inference.py
@@ -0,0 +1,164 @@
import math
from types import MethodType
from typing import Any, Dict, Optional

from .state import PartialState
from .utils import (
    calculate_maximum_sizes,
    convert_bytes,
    ignorant_find_batch_size,
    infer_auto_device_map,
    is_pippy_available,
    pad_input_tensors,
    send_to_device,
)


if is_pippy_available():
    from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points
    from pippy.PipelineStage import PipelineStage


def generate_device_map(model, num_processes: int = 1, no_split_module_classes=None, max_memory: dict = None):
    """
    Calculates the device map for `model` with an offset for PiPPy
    """
    if num_processes == 1:
        return infer_auto_device_map(model, no_split_module_classes=no_split_module_classes, clean_result=False)
    if max_memory is None:
        model_size, shared = calculate_maximum_sizes(model)

        # Split into `n` chunks for each GPU
        memory = (model_size + shared[0]) / num_processes
        memory = convert_bytes(memory)
        value, ending = memory.split(" ")

        # Add a chunk to deal with potential extra shared memory instances
        memory = math.ceil(float(value)) * 1.1
        memory = f"{memory} {ending}"
        max_memory = {i: memory for i in range(num_processes)}
    device_map = infer_auto_device_map(
        model,
        max_memory=max_memory,
        no_split_module_classes=no_split_module_classes,
        clean_result=False,
    )
    return device_map


def find_pippy_batch_size(args, kwargs):
    """
    Searches `args` and `kwargs` for the first value with a discoverable batch size.
    """
    found_batch_size = None
    for arg in args:
        found_batch_size = ignorant_find_batch_size(arg)
        if found_batch_size is not None:
            break
    # Only fall back to `kwargs` if nothing was found in `args`
    if found_batch_size is None:
        for kwarg in kwargs.values():
            found_batch_size = ignorant_find_batch_size(kwarg)
            if found_batch_size is not None:
                break
    return found_batch_size


def build_pipeline(model, split_points, args, kwargs, num_chunks):
    """
    Attaches the split points to the model based on `self.device_map` and generates a `PipelineStage`. Requires passing
    in the `args` and `kwargs` the model needs on the CPU.

    Users can pass in a custom `num_chunks` as an optional hyperparameter. By default it will use
    `AcceleratorState.num_processes`.
    """
    # We need to annotate the split points in the model for PiPPy
    state = PartialState()
    annotate_split_points(model, {split_point: PipeSplitWrapper.SplitPoint.BEGINNING for split_point in split_points})
    found_batch_size = find_pippy_batch_size(args, kwargs)
    if found_batch_size is not None and found_batch_size != num_chunks:
        args = pad_input_tensors(args, found_batch_size, num_chunks)
        kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks)
    pipe = Pipe.from_tracing(model, num_chunks=num_chunks, example_args=args, example_kwargs=kwargs)
    stage = PipelineStage(pipe, state.local_process_index, device=state.device)

    return stage


def pippy_forward(forward, num_chunks, *args, **kwargs):
    state = PartialState()
    output = None

    if state.num_processes == 1:
        output = forward(*args, **kwargs)
    elif state.is_local_main_process:
        # The first stage feeds the inputs in, padding the batch if it doesn't split evenly into `num_chunks`
        found_batch_size = find_pippy_batch_size(args, kwargs)
        if found_batch_size is None:
            raise ValueError("Could not find batch size from args or kwargs")
        else:
            if found_batch_size != num_chunks:
                args = pad_input_tensors(args, found_batch_size, num_chunks)
                kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks)
        forward(*args, **kwargs)
    elif state.is_last_process:
        # Only the final stage returns an output
        output = forward()
    else:
        # Intermediate stages just pass activations along the pipeline
        forward()
    return output


def prepare_pippy(
    model,
    split_points="auto",
    no_split_module_classes=None,
    example_args=(),
    example_kwargs: Optional[Dict[str, Any]] = None,
    num_chunks=None,
):
    """
    Wraps `model` for pipeline parallelism.

    Args:
        model (`torch.nn.Module`):
            A model we want to split for pipeline-parallel inference.
        split_points (`str` or `List[str]`, defaults to 'auto'):
            How to generate the split points and chunk the model across each GPU. 'auto' will find the best balanced
            split given any model; otherwise pass a list of layer names to split at.
        no_split_module_classes (`List[str]`):
            A list of class names for layers we don't want to be split.
        example_args (tuple of `torch.Tensor`):
            The expected inputs for the model that uses order-based inputs. Recommended to use this method if possible.
        example_kwargs (dict of `torch.Tensor`):
            The expected inputs for the model that uses dictionary-based inputs. This is a *highly* limiting structure
            that requires the same keys be present at *all* inference calls. Not recommended unless the prior condition
            is true for all cases.
        num_chunks (`int`):
            The number of different stages the Pipeline will have. By default it will assign one chunk per GPU, but
            this can be tuned and played with. In general one should have `num_chunks >= num_gpus`.
    """
    if not is_pippy_available():
        raise ImportError(
            "`pippy` was not found to be installed on your system. Please "
            "install using `pip install torchpippy` or ensure you have at least version 0.2.0"
        )
    state = PartialState()
    example_args = send_to_device(example_args, "cpu")
    example_kwargs = send_to_device(example_kwargs, "cpu")
    if num_chunks is None:
        num_chunks = state.num_processes
    if split_points == "auto":
        device_map = generate_device_map(model, num_chunks, no_split_module_classes=no_split_module_classes)
        split_points = []
        # The first module placed on each device after device 0 marks a split point
        for i in range(1, num_chunks):
            split_points.append(next(k for k, v in device_map.items() if v == i))
    stage = build_pipeline(model, split_points, example_args, example_kwargs, num_chunks)
    model._original_forward = model.forward
    model._original_call = model.__call__
    model.pippy_stage = stage
    model.hf_split_points = split_points

    def forward(*args, **kwargs):
        return pippy_forward(stage.forward, num_chunks, *args, **kwargs)

    # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
    # Note: creates an infinite recursion loop with `generate`
    model_forward = MethodType(forward, model)
    forward.__wrapped__ = model_forward
    model.forward = forward
    return model
1 change: 1 addition & 0 deletions src/accelerate/test_utils/__init__.py
@@ -13,6 +13,7 @@
    require_multi_gpu,
    require_multi_xpu,
    require_non_cpu,
    require_pippy,
    require_single_device,
    require_single_gpu,
    require_single_xpu,