Migrate pippy examples over and run tests #2424

Merged 6 commits on Feb 9, 2024

11 changes: 4 additions & 7 deletions docs/source/usage_guides/distributed_inference.md
@@ -148,7 +148,7 @@ The general idea with pipeline parallelism is: say you have 4 GPUs and a model b

![PiPPy example](https://camo.githubusercontent.com/681d7f415d6142face9dd1b837bdb2e340e5e01a58c3a4b119dea6c0d99e2ce0/68747470733a2f2f692e696d6775722e636f6d2f657955633934372e706e67)

To illustrate how you can use this with Accelerate, we have created a [model zoo example](https://github.com/muellerzr/pippy-device-map-playground/) showcasing a number of different models and situations. In this tutorial, we'll show this method for GPT2 across two GPUs.
To illustrate how you can use this with Accelerate, we have created an [example zoo](https://github.com/huggingface/accelerate/tree/main/examples/inference) showcasing a number of different models and situations. In this tutorial, we'll show this method for GPT2 across two GPUs.

Before you proceed, please make sure you have the latest pippy installed by running the following:

@@ -216,13 +216,10 @@ with torch.no_grad():
output = model(*args)
```

When finished, all the data will be on the last GPU, which you can use the [`PartialState`] to find and extract:
When finished, all the data will be on the CPU on each node:

```{python}
from accelerate import PartialState

if PartialState().is_last_process:
print(output)
print(output)
```

And that's it! To explore more, please check out the examples in [this repository](https://github.com/muellerzr/pippy-device-map-playground/) and our documentation as we work to improving this integration.
And that's it! To explore more, please check out the inference examples in the [Accelerate repo](https://github.com/huggingface/accelerate/tree/main/examples/inference) and our [documentation](../package_reference/inference) as we work on improving this integration.
54 changes: 54 additions & 0 deletions examples/inference/README.md
@@ -0,0 +1,54 @@
# Distributed inference examples with PiPPy

This directory contains a variety of tutorials for using the [PiPPy](https://github.com/PyTorch/PiPPy) pipeline parallelism library with Accelerate. You will find examples covering:

1. How to trace the model using `accelerate.prepare_pippy`
2. How to specify inputs based on what the model expects (when to use `kwargs`, `args`, and such)
3. How to gather the results at the end (a minimal sketch of the whole pattern follows this list)
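
For a quick picture of the overall workflow before diving into the individual scripts, here is a minimal sketch modeled on `bert.py` below (the checkpoint name and input shape are just placeholders):

```python
import torch
from transformers import AutoModelForMaskedLM

from accelerate import PartialState, prepare_pippy

# Load the model on CPU as usual
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
model.eval()

# Example inputs are required so the model can be traced and split into stages
example_input = torch.randint(
    low=0,
    high=model.config.vocab_size,
    size=(2, 512),  # bs x seq_len
    device="cpu",
    dtype=torch.int64,
)

# Trace the model and split it across the available GPUs
model = prepare_pippy(model, split_points="auto", example_args=(example_input,))

# Real inference: the inputs live on the first device
with torch.no_grad():
    output = model(example_input.to("cuda:0"))

# Outputs end up on the CPU on each process; print them once
if PartialState().is_main_process:
    print(output)
```

Launch it with `accelerate launch` (see below) so that one process per GPU is spawned.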

## Installation

This requires the `main` branch of accelerate (or a version of at least 0.27.0) and `pippy` version 0.2.0 or greater. Please install with `pip install .` to pull from the `setup.py` in this repo, or run manually:

```bash
pip install 'accelerate>=0.27.0' 'torchpippy>=0.2.0'
```

## Running code

You can run each script with either `torchrun` or the recommended `accelerate launch`:

```bash
accelerate launch bert.py
```

```bash
torchrun bert.py
```
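
Both launchers will normally detect the available GPUs on their own. If you prefer to pin the process count explicitly (assuming, say, a single node with 2 GPUs), both accept a flag for it:

```bash
accelerate launch --num_processes 2 bert.py
```

```bash
torchrun --nproc_per_node 2 bert.py
```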

## General speedups

One can expect PiPPy to outperform native model parallelism by a multiplicative factor, since all GPUs are working on inputs at all times rather than one input passing through a single GPU at a time while the others wait for it to finish.
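
For reference, the benchmark numbers in the tables below were gathered with a first (warm-up) pass followed by an average over 5 further passes, roughly following this pattern from the example scripts in this directory (it assumes `model` and `input` have already been prepared as those scripts show):

```python
import time

import torch

# First pass: also pays one-time CUDA init and warm-up costs
torch.cuda.synchronize()
start_time = time.time()
with torch.no_grad():
    output = model(input)
torch.cuda.synchronize()
first_batch = time.time() - start_time

# Steady state: average over 5 passes
torch.cuda.synchronize()
start_time = time.time()
for _ in range(5):
    with torch.no_grad():
        output = model(input)
torch.cuda.synchronize()
average_batch = (time.time() - start_time) / 5
```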

Below are some benchmarks we have found when using the accelerate-pippy integration for a few models when running on 2x 4090s:

### Bert

| | Accelerate/Sequential | PiPPy + Accelerate |
|---|---|---|
| First batch | 0.2137s | 0.3119s |
| Average of 5 batches | 0.0099s | **0.0062s** |

### GPT2

| | Accelerate/Sequential | PiPPy + Accelerate |
|---|---|---|
| First batch | 0.1959s | 0.4189s |
| Average of 5 batches | 0.0205s | **0.0126s** |

### T5

| | Accelerate/Sequential | PiPPy + Accelerate |
|---|---|---|
| First batch | 0.2789s | 0.3809s |
| Average of 5 batches | 0.0198s | **0.0166s** |
76 changes: 76 additions & 0 deletions examples/inference/bert.py
@@ -0,0 +1,76 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time

import torch
from transformers import AutoModelForMaskedLM

from accelerate import PartialState, prepare_pippy
from accelerate.utils import set_seed


# Set the random seed to have reproducible outputs
set_seed(42)

# Create an example model
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
model.eval()

# Input configs
# Create example inputs for the model
input = torch.randint(
    low=0,
    high=model.config.vocab_size,
    size=(2, 512),  # bs x seq_len
    device="cpu",
    dtype=torch.int64,
    requires_grad=False,
)


# Create a pipeline stage from the model
# Using `auto` is equivalent to letting `device_map="auto"` figure
# out device mapping and will also split the model according to the
# number of total GPUs available if it fits on one GPU
model = prepare_pippy(model, split_points="auto", example_args=(input,))

# Move the inputs to the first device
input = input.to("cuda:0")

# Take an average of 5 times
# Measure first batch
torch.cuda.synchronize()
start_time = time.time()
with torch.no_grad():
    output = model(input)
torch.cuda.synchronize()
end_time = time.time()
first_batch = end_time - start_time

# Now that CUDA is init, measure after
torch.cuda.synchronize()
start_time = time.time()
for i in range(5):
    with torch.no_grad():
        output = model(input)
torch.cuda.synchronize()
end_time = time.time()

# The outputs are on the CPU on each process;
# we print them once, from the main process only
if PartialState().is_main_process:
    output = torch.stack(tuple(output[0]))
    print(f"Time of first pass: {first_batch}")
    print(f"Average time per batch: {(end_time - start_time)/5}")
75 changes: 75 additions & 0 deletions examples/inference/gpt2.py
@@ -0,0 +1,75 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time

import torch
from transformers import AutoModelForSequenceClassification

from accelerate import PartialState, prepare_pippy
from accelerate.utils import set_seed


# Set the random seed to have reproducible outputs
set_seed(42)

# Create an example model
model = AutoModelForSequenceClassification.from_pretrained("gpt2")
model.eval()

# Input configs
# Create example inputs for the model
input = torch.randint(
    low=0,
    high=model.config.vocab_size,
    size=(2, 1024),  # bs x seq_len
    device="cpu",
    dtype=torch.int64,
    requires_grad=False,
)

# Create a pipeline stage from the model
# Using `auto` is equivalent to letting `device_map="auto"` figure
# out device mapping and will also split the model according to the
# number of total GPUs available if it fits on one GPU
model = prepare_pippy(model, split_points="auto", example_args=(input,))

# Move the inputs to the first device
input = input.to("cuda:0")

# Take an average of 5 times
# Measure first batch
torch.cuda.synchronize()
start_time = time.time()
with torch.no_grad():
    output = model(input)
torch.cuda.synchronize()
end_time = time.time()
first_batch = end_time - start_time

# Now that CUDA is init, measure after
torch.cuda.synchronize()
start_time = time.time()
for i in range(5):
    with torch.no_grad():
        output = model(input)
torch.cuda.synchronize()
end_time = time.time()

# The outputs are on the CPU on each process;
# we print them once, from the main process only
if PartialState().is_main_process:
    output = torch.stack(tuple(output[0]))
    print(f"Time of first pass: {first_batch}")
    print(f"Average time per batch: {(end_time - start_time)/5}")
52 changes: 52 additions & 0 deletions examples/inference/llama.py
@@ -0,0 +1,52 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from accelerate import PartialState, prepare_pippy


# The sdpa implementation, which is the default for torch>2.1.2, fails with tracing + the attention mask kwarg;
# with attn_implementation="eager", the forward pass is very slow for some reason
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True, attn_implementation="sdpa"
)
model.eval()

# Input configs
# Create example inputs for the model
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
prompts = ("I would like to", "I really like to", "The weather is") # bs = 3
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

# Create a pipeline stage from the model
# Using `auto` is equivalent to letting `device_map="auto"` figure
# out device mapping and will also split the model according to the
# number of total GPUs available if it fits on one GPU
model = prepare_pippy(model, split_points="auto", example_args=inputs)

# currently we don't support `model.generate`
# output = model.generate(**inputs, max_new_tokens=1)

with torch.no_grad():
    output = model(**inputs)

# The outputs are on the CPU on each process;
# we print them once, from the main process only
if PartialState().is_main_process:
    next_token_logits = output[0][:, -1, :]
    next_token = torch.argmax(next_token_logits, dim=-1)
    print(tokenizer.batch_decode(next_token))
2 changes: 2 additions & 0 deletions examples/inference/requirements.txt
@@ -0,0 +1,2 @@
accelerate
pippy>=0.2.0
82 changes: 82 additions & 0 deletions examples/inference/t5.py
@@ -0,0 +1,82 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time

import torch
from transformers import AutoModelForSeq2SeqLM

from accelerate import PartialState, prepare_pippy
from accelerate.utils import set_seed


# Set the random seed to have reproducible outputs
set_seed(42)

# Create an example model
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
model.eval()

# Input configs
# Create example inputs for the model
input = torch.randint(
    low=0,
    high=model.config.vocab_size,
    size=(2, 1024),  # bs x seq_len
    device="cpu",
    dtype=torch.int64,
    requires_grad=False,
)

example_inputs = {"input_ids": input, "decoder_input_ids": input}

# Create a pipeline stage from the model
# Using `auto` is equivalent to letting `device_map="auto"` figure
# out device mapping and will also split the model according to the
# number of total GPUs available if it fits on one GPU
model = prepare_pippy(
    model,
    no_split_module_classes=["T5Block"],
    example_kwargs=example_inputs,
)

# The model expects a tuple during real inference
# with the data on the first device
args = (example_inputs["input_ids"].to("cuda:0"), example_inputs["decoder_input_ids"].to("cuda:0"))

# Take an average of 5 times
# Measure first batch
torch.cuda.synchronize()
start_time = time.time()
with torch.no_grad():
    output = model(*args)
torch.cuda.synchronize()
end_time = time.time()
first_batch = end_time - start_time

# Now that CUDA is init, measure after
torch.cuda.synchronize()
start_time = time.time()
for i in range(5):
    with torch.no_grad():
        output = model(*args)
torch.cuda.synchronize()
end_time = time.time()

# The outputs are on the CPU on each process;
# we print them once, from the main process only
if PartialState().is_main_process:
    output = torch.stack(tuple(output[0]))
    print(f"Time of first pass: {first_batch}")
    print(f"Average time per batch: {(end_time - start_time)/5}")
1 change: 1 addition & 0 deletions setup.py
@@ -26,6 +26,7 @@
extras["test_dev"] = [
"datasets",
"evaluate",
"torchpippy>=0.2.0",
"transformers",
"scipy",
"scikit-learn",