
torch-native pipeline parallelism for big models #2345

Merged
merged 39 commits into from Feb 6, 2024

Changes from 30 commits

Commits (39)
e713e28
Broken version
muellerzr Jan 16, 2024
2767bb1
Timing I would expect
muellerzr Jan 16, 2024
06f04a9
Working version!
muellerzr Jan 16, 2024
9eef9dd
Use MethodType
muellerzr Jan 17, 2024
449eb8d
working test
muellerzr Jan 17, 2024
77f8e92
Tests
muellerzr Jan 17, 2024
df7779a
Use no split module classes explicitly
muellerzr Jan 17, 2024
e3f6b99
Put split_points in pipelien
muellerzr Jan 17, 2024
8792a8c
Store split points in hf_split_points
muellerzr Jan 17, 2024
7ca4bcc
fix case num_process=1
SunMarc Jan 17, 2024
dac1daa
Allow for dynamic batch padding (#2352)
muellerzr Jan 25, 2024
364c3b6
Rm literal
muellerzr Jan 25, 2024
6a8479b
Allow users to pass in max_memory
muellerzr Jan 25, 2024
303c9cc
Note about recursion
muellerzr Jan 25, 2024
d497e8a
Document, document, document
muellerzr Jan 25, 2024
06bbc5b
Right import check
muellerzr Jan 26, 2024
5e047da
Merge branch 'main' into pippy-integration-v2
muellerzr Jan 26, 2024
a5059e6
Fix bug, add tests to multigpu runners
muellerzr Jan 26, 2024
71346a1
Change default to None
muellerzr Jan 26, 2024
fe66b93
Start of docs
muellerzr Feb 5, 2024
d2af472
Try again?
muellerzr Feb 5, 2024
8dc6c6c
Try again x2
muellerzr Feb 5, 2024
4d0aeb2
Trailing comma
muellerzr Feb 5, 2024
309b71a
Move import
muellerzr Feb 5, 2024
9f561f1
Clean
muellerzr Feb 5, 2024
d5a6fda
typehint
muellerzr Feb 5, 2024
954a668
typo
muellerzr Feb 5, 2024
853f552
From code review
muellerzr Feb 5, 2024
1362e5c
Use num_chunks
muellerzr Feb 5, 2024
68bd89b
Update tests/test_utils.py
muellerzr Feb 5, 2024
181fbda
Bad copy/paste
muellerzr Feb 5, 2024
9157cf1
hf_split_points
muellerzr Feb 6, 2024
f2c6e08
Apply suggestions from code review
muellerzr Feb 6, 2024
9f20496
Year
muellerzr Feb 6, 2024
e1961d6
Nit
muellerzr Feb 6, 2024
8c72a5e
better title
muellerzr Feb 6, 2024
3eaa967
Rephrase
muellerzr Feb 6, 2024
31fcde4
Rephrase
muellerzr Feb 6, 2024
7c3d183
Try spacing maybe?
muellerzr Feb 6, 2024
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -89,6 +89,8 @@
title: Logging
- local: package_reference/big_modeling
title: Working with large models
- local: package_reference/inference
title: Distributed inference with big models
- local: package_reference/kwargs
title: Kwargs handlers
- local: package_reference/utilities
20 changes: 20 additions & 0 deletions docs/source/package_reference/inference.md
@@ -0,0 +1,20 @@
<!--Copyright 2021 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# The inference API

These docs refer to the [PiPPy](https://github.com/PyTorch/PiPPy) integration.

[[autodoc]] inference.prepare_pippy
99 changes: 93 additions & 6 deletions docs/source/usage_guides/distributed_inference.md
@@ -15,12 +15,18 @@ rendered properly in your Markdown viewer.

# Distributed Inference with 🤗 Accelerate

Distributed inference is a common use case, especially with natural language processing (NLP) models. Users often want to
send a number of different prompts, each to a different GPU, and then get the results back. This also has other cases
outside of just NLP, however for this tutorial we will focus on just this idea of each GPU receiving a different prompt,
and then returning the results.
Distributed inference can fall into three categories:

## The Problem
1. Loading an entire model onto each GPU and sending chunks of a batch through each model at a time
2. Loading parts of a model onto each GPU and processing a single input at a time
3. Loading parts of a model onto each GPU and using what is called scheduled pipeline parallelism to combine the two prior techniques.

We're going to go through the first and the last, showcasing how to do each, as they are the more realistic scenarios.


## Sending chunks of inputs automatically to each loaded model

This is the most memory-intensive solution, as it requires each GPU to keep a full copy of the model in memory at a given time.

Normally when doing this, users send the model to a specific device to load it from the CPU, and then move each prompt to a different device.

@@ -55,7 +61,6 @@ a simple way to manage this. (To learn more, check out the relevant section in t

Can it manage it? Yes. Does it add unneeded extra code, however? Also yes.

## The Solution

With 🤗 Accelerate, we can simplify this process by using the [`Accelerator.split_between_processes`] context manager (which also exists in `PartialState` and `AcceleratorState`).
This function will automatically split whatever data you pass to it (be it a prompt, a set of tensors, a dictionary of the prior data, etc.) across all the processes (with a potential
@@ -134,3 +139,85 @@ with distributed_state.split_between_processes(["a dog", "a cat", "a chicken"],

On the first GPU, the prompts will be `["a dog", "a cat"]`, and on the second GPU it will be `["a chicken", "a chicken"]`.
Make sure to drop the final sample, as it will be a duplicate of the previous one.
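
Putting this together, here is a minimal sketch of the padded-split pattern described above, assuming a launch across two processes (the prompts are the illustrative ones from this section):

```{python}
from accelerate import PartialState

distributed_state = PartialState()
prompts = ["a dog", "a cat", "a chicken"]

# With `apply_padding=True` the final prompt is repeated so that every process
# receives the same number of items (2 and 2 when launched on two GPUs).
with distributed_state.split_between_processes(prompts, apply_padding=True) as batch:
    print(f"Process {distributed_state.process_index} got: {batch}")
```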

## A more memory-efficient version (experimental)

This next part will discuss using *pipeline parallelism*. This is an **experimental** API utilizing the [PiPPy library by PyTorch](https://github.com/pytorch/PiPPy/) as a native solution.

The general idea with pipeline parallelism is: say you have 4 GPUs and a model big enough that it can be *split* across four GPUs using `device_map="auto"`. With this approach you can send in 4 inputs at a time (any amount works) and each model chunk will work on an input, then receive the next input once the prior chunk has finished with it, making it *much* more efficient **and faster** than the prior version. Here's a visual taken from the PyTorch repository:

![PiPPy example](https://camo.githubusercontent.com/681d7f415d6142face9dd1b837bdb2e340e5e01a58c3a4b119dea6c0d99e2ce0/68747470733a2f2f692e696d6775722e636f6d2f657955633934372e706e67)

To use this with Accelerate, we have created a [model zoo](https://github.com/muellerzr/pippy-device-map-playground/) showcasing a number of different models and situations. In this tutorial, however, we'll use GPT2 across two GPUs.

Before anything else, please make sure you have the latest version of PiPPy installed by running:

```bash
pip install torchpippy
```

We require at least version 0.2.0; run `pip show torchpippy` to check this.

First we need to create the model on the CPU:

```{python}
from transformers import GPT2ForSequenceClassification, GPT2Config

config = GPT2Config()
model = GPT2ForSequenceClassification(config)
model.eval()
```

Next we need to create some example inputs to use. These help PiPPy trace the model.

<Tip warning={true}>
How you make this example input will determine the relative batch size that will be used and passed
through the model at a given time, so make sure to remember it!
</Tip>

```{python}
import torch

input = torch.randint(
    low=0,
    high=config.vocab_size,
    size=(2, 1024),  # bs x seq_len
    device="cpu",
    dtype=torch.int64,
    requires_grad=False,
)
```
Next we need to actually perform the tracing and get the model ready. To do so, simply use the [`inference.prepare_pippy`] function and it will fully wrap the model for pipeline parallelism automatically:

```{python}
from accelerate.inference import prepare_pippy

model = prepare_pippy(model, example_args=(input,))
```

<Tip>
There are a variety of parameters you can pass through to `prepare_pippy` (see the sketch after this tip):
* `split_points` lets you determine where to split the model. By default we use wherever `device_map="auto"` declares
* `num_chunks` determines how the batch will be split and sent to the model itself (so `num_chunks=1` with four split points/four GPUs would give naive model parallelism, where a single input gets passed between the four layer split points)
</Tip>
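
Instead of the default call above, a sketch of overriding these parameters might look like the following (the split point name is hypothetical and must match an actual submodule name in your model):

```{python}
# Hypothetical manual configuration: split a GPT2-style model at one of its
# transformer blocks and use four micro-batches per forward pass.
model = prepare_pippy(
    model,
    split_points=["transformer.h.6"],  # assumed submodule name, model-dependent
    num_chunks=4,
    example_args=(input,),
)
```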

From here all that's left is to actually perform the distributed inference!

<Tip warning={true}>
When passing inputs in, `kwargs` are currently supported but are even *more* experimental, so it's highly recommended to simply pass inputs in as a tuple of arguments.
</Tip>

```{python}
args = some_more_arguments  # your real inputs, shaped like the example inputs used for tracing
with torch.no_grad():
    output = model(*args)
```

Afterwards, all the data will be on the last GPU, which you can find and extract using [`PartialState`]:

```{python}
from accelerate import PartialState

if PartialState().is_last_process:
    print(output)
```

And that's it! To explore more, please check out the examples in [this repository](https://github.com/muellerzr/pippy-device-map-playground/) and our documentation as we work on improving this integration.
1 change: 1 addition & 0 deletions src/accelerate/__init__.py
@@ -11,6 +11,7 @@
load_checkpoint_and_dispatch,
)
from .data_loader import skip_first_batches
from .inference import prepare_pippy
from .launchers import debug_launcher, notebook_launcher
from .state import PartialState
from .utils import (
163 changes: 163 additions & 0 deletions src/accelerate/inference.py
@@ -0,0 +1,163 @@
import math
from types import MethodType
from typing import Any, Dict, Optional

from .state import PartialState
from .utils import (
calculate_maximum_sizes,
convert_bytes,
ignorant_find_batch_size,
infer_auto_device_map,
is_pippy_available,
pad_input_tensors,
send_to_device,
)


if is_pippy_available():
from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points
from pippy.PipelineStage import PipelineStage


def generate_device_map(model, num_processes: int = 1, no_split_module_classes=None, max_memory: dict = None):
"""
Calculates the device map for `model` with an offset for PiPPy
"""
if num_processes == 1:
return infer_auto_device_map(model, no_split_module_classes=no_split_module_classes, clean_result=False)
if max_memory is None:
model_size, shared = calculate_maximum_sizes(model)

# Split into `n` chunks for each GPU
memory = (model_size + shared[0]) / num_processes
memory = convert_bytes(memory)
value, ending = memory.split(" ")

# Add a chunk to deal with potential extra shared memory instances
memory = math.ceil(float(value)) * 1.1
memory = f"{memory} {ending}"
max_memory = {i: memory for i in range(num_processes)}
device_map = infer_auto_device_map(
model,
max_memory=max_memory,
no_split_module_classes=no_split_module_classes,
clean_result=False,
)
return device_map


def find_pippy_batch_size(args, kwargs):
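    # Scan positional args first, then kwargs, for the first input whose batch size
    # can be detected; returns None if no batch size is found.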
found_batch_size = None
for arg in args:
found_batch_size = ignorant_find_batch_size(arg)
if found_batch_size is not None:
break
for kwarg in kwargs.values():
found_batch_size = ignorant_find_batch_size(kwarg)
if found_batch_size is not None:
break
return found_batch_size


def build_pipeline(model, split_points, args, kwargs, num_chunks):
"""
    Attaches the split points to the model based on `split_points` and generates a `PipelineStage`. Requires passing
    in the `args` and `kwargs` that the model needs, placed on the CPU.

    Users can pass in custom `num_chunks` as an optional hyperparameter. By default it will use
    `AcceleratorState.num_processes`.
"""
# We need to annotate the split points in the model for PiPPy
state = PartialState()
annotate_split_points(model, {split_point: PipeSplitWrapper.SplitPoint.BEGINNING for split_point in split_points})
found_batch_size = find_pippy_batch_size(args, kwargs)
if found_batch_size != num_chunks:
args = pad_input_tensors(args, found_batch_size, num_chunks)
kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks)
pipe = Pipe.from_tracing(model, num_chunks=num_chunks, example_args=args, example_kwargs=kwargs)
stage = PipelineStage(pipe, state.local_process_index, device=state.device)

return stage


def pippy_forward(forward, num_chunks, *args, **kwargs):
state = PartialState()
output = None

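    # Dispatch on rank: with a single process we simply run the model directly;
    # otherwise the local main process feeds the (possibly padded) inputs into the
    # pipeline, the last process returns the pipeline's output, and the remaining
    # processes only participate in the scheduled forward passes.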
if state.num_processes == 1:
output = forward(*args, **kwargs)
elif state.is_local_main_process:
found_batch_size = find_pippy_batch_size(args, kwargs)
if found_batch_size is None:
raise ValueError("Could not find batch size from args or kwargs")
else:
if found_batch_size != num_chunks:
args = pad_input_tensors(args, found_batch_size, state.num_processes)
kwargs = pad_input_tensors(kwargs, found_batch_size, state.num_processes)
forward(*args, **kwargs)
elif state.is_last_process:
output = forward()
else:
forward()
return output


def prepare_pippy(
model,
split_points="auto",
no_split_module_classes=None,
example_args=(),
example_kwargs: Optional[Dict[str, Any]] = None,
num_chunks=None,
):
"""
Wraps `model` for PipelineParallelism

Args:
model (`torch.nn.Module`):
A model we want to split for pipeline-parallel inference
split_points (`str`, defaults to 'auto'):
How to generate the split points and chunk the model across each GPU. 'auto' will find the best balanced
split given any model.
no_split_module_classes (`List[str]`):
A list of class names for layers we don't want to be split.
example_args (tuple of `torch.Tensor`):
The expected inputs for the model that uses order-based inputs. Recommended to use this method if possible.
        example_kwargs (dict of `torch.Tensor`):
The expected inputs for the model that uses dictionary-based inputs. This is a *highly* limiting structure
that requires the same keys be present at *all* inference calls. Not recommended unless the prior condition
is true for all cases.
num_chunks (`int`):
The number of different stages the Pipeline will have. By default it will assign one chunk per GPU, but
this can be tuned and played with. In general one should have num_chunks >= num_gpus.
"""
if not is_pippy_available():
raise ImportError(
"`pippy` was not found to be installed on your system. Please "
"install using `pip install torchpippy` or ensure you have at least version 0.2.0"
)
state = PartialState()
example_args = send_to_device(example_args, "cpu")
example_kwargs = send_to_device(example_kwargs, "cpu")
if num_chunks is None:
num_chunks = state.num_processes
if split_points == "auto":
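        # Infer a balanced device map, then use the first module assigned to each
        # subsequent device as that device's split point.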
device_map = generate_device_map(model, num_chunks, no_split_module_classes=no_split_module_classes)
split_points = []
for i in range(1, num_chunks):
split_points.append(next(k for k, v in device_map.items() if v == i))
stage = build_pipeline(model, split_points, example_args, example_kwargs, num_chunks)
model._original_forward = model.forward
model._original_call = model.__call__
model.pippy_stage = stage
model.hf_split_points = split_points

def forward(*args, **kwargs):
return pippy_forward(stage.forward, num_chunks, *args, **kwargs)

# To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
# Note: creates an infinite recursion loop with `generate`
model_forward = MethodType(forward, model)
forward.__wrapped__ = model_forward
model.forward = forward
return model
1 change: 1 addition & 0 deletions src/accelerate/test_utils/__init__.py
@@ -13,6 +13,7 @@
require_multi_gpu,
require_multi_xpu,
require_non_cpu,
require_pippy,
require_single_device,
require_single_gpu,
require_single_xpu,