torch-native pipeline parallelism for big models #2345

Merged
merged 39 commits on Feb 6, 2024
Changes from 1 commit
Commits
39 commits
e713e28
Broken version
muellerzr Jan 16, 2024
2767bb1
Timing I would expect
muellerzr Jan 16, 2024
06f04a9
Working version!
muellerzr Jan 16, 2024
9eef9dd
Use MethodType
muellerzr Jan 17, 2024
449eb8d
working test
muellerzr Jan 17, 2024
77f8e92
Tests
muellerzr Jan 17, 2024
df7779a
Use no split module classes explicitly
muellerzr Jan 17, 2024
e3f6b99
Put split_points in pipeline
muellerzr Jan 17, 2024
8792a8c
Store split points in hf_split_points
muellerzr Jan 17, 2024
7ca4bcc
fix case num_process=1
SunMarc Jan 17, 2024
dac1daa
Allow for dynamic batch padding (#2352)
muellerzr Jan 25, 2024
364c3b6
Rm literal
muellerzr Jan 25, 2024
6a8479b
Allow users to pass in max_memory
muellerzr Jan 25, 2024
303c9cc
Note about recursion
muellerzr Jan 25, 2024
d497e8a
Document, document, document
muellerzr Jan 25, 2024
06bbc5b
Right import check
muellerzr Jan 26, 2024
5e047da
Merge branch 'main' into pippy-integration-v2
muellerzr Jan 26, 2024
a5059e6
Fix bug, add tests to multigpu runners
muellerzr Jan 26, 2024
71346a1
Change default to None
muellerzr Jan 26, 2024
fe66b93
Start of docs
muellerzr Feb 5, 2024
d2af472
Try again?
muellerzr Feb 5, 2024
8dc6c6c
Try again x2
muellerzr Feb 5, 2024
4d0aeb2
Trailing comma
muellerzr Feb 5, 2024
309b71a
Move import
muellerzr Feb 5, 2024
9f561f1
Clean
muellerzr Feb 5, 2024
d5a6fda
typehint
muellerzr Feb 5, 2024
954a668
typo
muellerzr Feb 5, 2024
853f552
From code review
muellerzr Feb 5, 2024
1362e5c
Use num_chunks
muellerzr Feb 5, 2024
68bd89b
Update tests/test_utils.py
muellerzr Feb 5, 2024
181fbda
Bad copy/paste
muellerzr Feb 5, 2024
9157cf1
hf_split_points
muellerzr Feb 6, 2024
f2c6e08
Apply suggestions from code review
muellerzr Feb 6, 2024
9f20496
Year
muellerzr Feb 6, 2024
e1961d6
Nit
muellerzr Feb 6, 2024
8c72a5e
better title
muellerzr Feb 6, 2024
3eaa967
Rephrase
muellerzr Feb 6, 2024
31fcde4
Rephrase
muellerzr Feb 6, 2024
7c3d183
Try spacing maybe?
muellerzr Feb 6, 2024
18 changes: 12 additions & 6 deletions src/accelerate/inference.py
@@ -1,5 +1,6 @@
import math
from types import MethodType
+from typing import Any, Dict, Optional

from .state import PartialState
from .utils import (
@@ -79,7 +80,7 @@ def build_pipeline(model, split_points, args, kwargs, num_chunks):
return stage


-def pippy_forward(forward, *args, **kwargs):
+def pippy_forward(forward, num_chunks, *args, **kwargs):
state = PartialState()
output = None

@@ -90,7 +91,7 @@ def pippy_forward(forward, *args, **kwargs):
if found_batch_size is None:
raise ValueError("Could not find batch size from args or kwargs")
else:
-if found_batch_size != state.num_processes:
+if found_batch_size != num_chunks:
args = pad_input_tensors(args, found_batch_size, state.num_processes)
kwargs = pad_input_tensors(kwargs, found_batch_size, state.num_processes)
forward(*args, **kwargs)
@@ -102,7 +103,12 @@ def pippy_forward(forward, *args, **kwargs):


def prepare_pippy(
-model, split_points="auto", no_split_module_classes=None, example_args=(), example_kwargs={}, num_chunks=None
+model,
+split_points="auto",
+no_split_module_classes=None,
+example_args=(),
+example_kwargs: Optional[Dict[str, Any]] = None,
+num_chunks=None,
):
"""
Wraps `model` for PipelineParallelism
@@ -123,12 +129,12 @@
is true for all cases.
num_chunks (`int`):
The number of different stages the Pipeline will have. By default it will assign one chunk per GPU, but
-this can be tuned and played with. In general one should have num_chunks > num_gpus.
+this can be tuned and played with. In general one should have num_chunks >= num_gpus.
"""
if not is_pippy_available():
raise ImportError(
"`pippy` was not found to be installed on your system. Please "
"install using `pip install git+https://github.com/pytorch/PiPPy"
"install using `pip install torchpippy` or ensure you have at least version 0.2.0"
)
state = PartialState()
example_args = send_to_device(example_args, "cpu")
@@ -147,7 +153,7 @@
model.hf_split_points = split_points

def forward(*args, **kwargs):
-return pippy_forward(stage.forward, *args, **kwargs)
+return pippy_forward(stage.forward, num_chunks, *args, **kwargs)

# To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
# Note: creates an infinite recursion loop with `generate`
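
The "act like a decorator so that it can be popped" comment (and the "Use MethodType" commit) refer to binding the wrapper as an instance attribute so it shadows the class-level `forward` and can later be removed by `extract_model_from_parallel`. The snippet below is a generic illustration of that pattern, not accelerate's actual wrapper; the toy `nn.Linear` and the `verbose_forward` name are placeholders.

```python
from types import MethodType

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
# Bound method resolved from the class, captured before it is shadowed.
original_forward = model.forward

def verbose_forward(self, *args, **kwargs):
    print("running wrapped forward")
    return original_forward(*args, **kwargs)

# Binding on the instance shadows nn.Linear.forward for this object only.
model.forward = MethodType(verbose_forward, model)
out = model(torch.randn(2, 4))  # prints the message, then runs the real forward

# "Popping" the instance attribute restores the class-level forward.
del model.__dict__["forward"]
out = model(torch.randn(2, 4))  # no message
```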
Expand Down
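To make the `num_chunks` padding change in `pippy_forward` concrete: the batch is split into `num_chunks` micro-batches, so when the incoming batch size does not match what the scheduler expects, the inputs are padded along the batch dimension first. The sketch below is a rough, standalone illustration of that idea only; it is not the library's `pad_input_tensors`, and the helper name and the choice to repeat the last row are assumptions.

```python
import torch

def pad_batch_to_chunks(tensor: torch.Tensor, num_chunks: int) -> torch.Tensor:
    """Pad dim 0 up to the next multiple of num_chunks by repeating the last row."""
    batch_size = tensor.size(0)
    remainder = batch_size % num_chunks
    if remainder == 0:
        return tensor
    pad_rows = num_chunks - remainder
    padding = tensor[-1:].expand(pad_rows, *tensor.shape[1:])
    return torch.cat([tensor, padding], dim=0)

x = torch.randn(3, 16)
padded = pad_batch_to_chunks(x, num_chunks=4)
print(padded.shape)  # torch.Size([4, 16])
```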
Loading
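Finally, a minimal usage sketch of the `prepare_pippy` signature shown in this diff. It assumes a multi-GPU run launched with `accelerate launch` (one process per GPU); the import path `accelerate.inference` follows this file's location, while the toy model, tensor shapes, and default choices are illustrative rather than the PR's documented example.

```python
import torch
import torch.nn as nn

from accelerate import PartialState
from accelerate.inference import prepare_pippy

state = PartialState()

# A toy model with enough layers to be split across pipeline stages.
model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(8)])
model.eval()

# Example inputs are provided on CPU; prepare_pippy traces the model with them.
example_args = (torch.randn(2, 1024),)

model = prepare_pippy(
    model,
    split_points="auto",             # infer split points from a balanced device map
    example_args=example_args,
    num_chunks=state.num_processes,  # the default when num_chunks is left as None
)

with torch.no_grad():
    output = model(*example_args)

# Output may be None on ranks other than the one that receives the final stage.
if output is not None:
    print(output.shape)
```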