
torch-native pipeline parallelism for big models #2345

Merged: 39 commits, Feb 6, 2024

Changes from 3 commits

Commits (39)
e713e28  Broken version (muellerzr, Jan 16, 2024)
2767bb1  Timing I would expect (muellerzr, Jan 16, 2024)
06f04a9  Working version! (muellerzr, Jan 16, 2024)
9eef9dd  Use MethodType (muellerzr, Jan 17, 2024)
449eb8d  working test (muellerzr, Jan 17, 2024)
77f8e92  Tests (muellerzr, Jan 17, 2024)
df7779a  Use no split module classes explicitly (muellerzr, Jan 17, 2024)
e3f6b99  Put split_points in pipelien (muellerzr, Jan 17, 2024)
8792a8c  Store split points in hf_split_points (muellerzr, Jan 17, 2024)
7ca4bcc  fix case num_process=1 (SunMarc, Jan 17, 2024)
dac1daa  Allow for dynamic batch padding (#2352) (muellerzr, Jan 25, 2024)
364c3b6  Rm literal (muellerzr, Jan 25, 2024)
6a8479b  Allow users to pass in max_memory (muellerzr, Jan 25, 2024)
303c9cc  Note about recursion (muellerzr, Jan 25, 2024)
d497e8a  Document, document, document (muellerzr, Jan 25, 2024)
06bbc5b  Right import check (muellerzr, Jan 26, 2024)
5e047da  Merge branch 'main' into pippy-integration-v2 (muellerzr, Jan 26, 2024)
a5059e6  Fix bug, add tests to multigpu runners (muellerzr, Jan 26, 2024)
71346a1  Change default to None (muellerzr, Jan 26, 2024)
fe66b93  Start of docs (muellerzr, Feb 5, 2024)
d2af472  Try again? (muellerzr, Feb 5, 2024)
8dc6c6c  Try again x2 (muellerzr, Feb 5, 2024)
4d0aeb2  Trailing comma (muellerzr, Feb 5, 2024)
309b71a  Move import (muellerzr, Feb 5, 2024)
9f561f1  Clean (muellerzr, Feb 5, 2024)
d5a6fda  typehint (muellerzr, Feb 5, 2024)
954a668  typo (muellerzr, Feb 5, 2024)
853f552  From code review (muellerzr, Feb 5, 2024)
1362e5c  Use num_chunks (muellerzr, Feb 5, 2024)
68bd89b  Update tests/test_utils.py (muellerzr, Feb 5, 2024)
181fbda  Bad copy/paste (muellerzr, Feb 5, 2024)
9157cf1  hf_split_points (muellerzr, Feb 6, 2024)
f2c6e08  Apply suggestions from code review (muellerzr, Feb 6, 2024)
9f20496  Year (muellerzr, Feb 6, 2024)
e1961d6  Nit (muellerzr, Feb 6, 2024)
8c72a5e  better title (muellerzr, Feb 6, 2024)
3eaa967  Rephrase (muellerzr, Feb 6, 2024)
31fcde4  Rephrase (muellerzr, Feb 6, 2024)
7c3d183  Try spacing maybe? (muellerzr, Feb 6, 2024)
96 changes: 96 additions & 0 deletions src/accelerate/inference.py
@@ -0,0 +1,96 @@
import math
from functools import partial
from typing import Literal

from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points
from pippy.PipelineStage import PipelineStage

from .state import PartialState
from .utils import (
    calculate_maximum_sizes,
    convert_bytes,
    infer_auto_device_map,
    send_to_device,
)


ParallelMode = Literal["sequential", "pipeline_parallel"]


def generate_device_map(model, num_processes: int = 1):
"""
Calculates the device map for `model` with an offset for PiPPy
"""
no_split_module_classes = getattr(model, "_no_split_modules", [])
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
if num_processes == 1:
return infer_auto_device_map(model, no_split_module_classes=no_split_module_classes, clean_result=False)
model_size, shared = calculate_maximum_sizes(model)

# Split into `n` chunks for each GPU
memory = (model_size + shared[0]) / num_processes
memory = convert_bytes(memory)
value, ending = memory.split(" ")

# Add a chunk to deal with potential extra shared memory instances
memory = math.ceil(float(value)) * 1.1
memory = f"{memory} {ending}"
device_map = infer_auto_device_map(
model,
max_memory={i: memory for i in range(num_processes)},
no_split_module_classes=no_split_module_classes,
clean_result=False,
)

Member (review comment):

We can definitely generate a balanced device_map exclusively for pippy (e.g. `device_map="balanced_pippy"`) if the current balanced option is not the best fit for it. However, I think it would be great if the user could also use other options like "sequential". I haven't tried it, but what happens when we only fill 2 GPUs out of the 4 available (a possible "sequential" case)?

    return device_map

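To make the per-GPU budget above concrete, here is a small worked example with made-up sizes (not from this PR); the string is roughly what `convert_bytes` would produce:

import math

# Hypothetical sizes: a ~12.6 GB model with ~0.5 GB of shared weights, split over 2 processes
model_size, largest_shared = 12.6 * 1024**3, 0.5 * 1024**3
num_processes = 2

per_gpu = (model_size + largest_shared) / num_processes  # ~6.55 GB per process
value, ending = "6.55", "GB"                              # roughly what convert_bytes(per_gpu) splits into
budget = math.ceil(float(value)) * 1.1                    # round up to 7, then add a 10% cushion -> ~7.7
max_memory = {i: f"{budget} {ending}" for i in range(num_processes)}
# -> roughly "7.7 GB" per process, used as `max_memory` for infer_auto_device_map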

def build_pipeline(model, device_map, args, kwargs) -> PipelineStage:
"""
Attaches the split points to the model based on `self.device_map` and generates a `PipelineStage`. Requires passing
in needed `args` and `kwargs` as the model needs on the CPU.
"""
# We need to annotate the split points in the model for PiPPy
state = PartialState()
split_points = []
for i in range(1, state.num_processes):
split_points.append(next(k for k, v in device_map.items() if v == i))
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
annotate_split_points(model, {split_point: PipeSplitWrapper.SplitPoint.BEGINNING for split_point in split_points})
pipe = Pipe.from_tracing(model, num_chunks=state.num_processes, example_args=args, example_kwargs=kwargs)
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
stage = PipelineStage(pipe, state.local_process_index, device=state.device)

return stage

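To illustrate what `build_pipeline` does with the device map, here is a tiny sketch with made-up module names (not from this PR): the first module assigned to each process after the first becomes a split point.

from collections import OrderedDict

# Illustrative device_map (module name -> process index); the names are hypothetical
device_map = OrderedDict(
    [("embed", 0), ("layers.0", 0), ("layers.1", 1), ("layers.2", 1), ("head", 1)]
)
num_processes = 2

# Same logic as in build_pipeline: the first module owned by each rank > 0 becomes a split point
split_points = [next(k for k, v in device_map.items() if v == i) for i in range(1, num_processes)]
assert split_points == ["layers.1"]  # the model is cut at the beginning of `layers.1`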

def pippy_forward(forward, *args, **kwargs):
    state = PartialState()
    output = None
    if state.is_local_main_process:
        # Only the local main process passes the actual inputs through
        forward(*args, **kwargs)
    elif state.is_last_process:
        # Only the last process gets the pipeline's output back
        output = forward()
    else:
        # Remaining processes just run their stage
        forward()
    return output

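Because only the last process gets a value back from `pippy_forward`, callers that need the result on every rank could broadcast it afterwards. A minimal sketch using plain `torch.distributed` (an assumption for illustration, not part of this PR):

import torch.distributed as dist


def broadcast_pipeline_output(output):
    # `output` is the value returned by the wrapped forward: None on every rank except the last
    obj = [output]
    dist.broadcast_object_list(obj, src=dist.get_world_size() - 1)
    return obj[0]  # populated on every rank after the broadcast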

def prepare_pippy(model, device_map="auto", example_args=(), example_kwargs={}):
"""
Wraps `model` for PipelineParallelism
"""
example_args = send_to_device(example_args, "cpu")
example_kwargs = send_to_device(example_kwargs, "cpu")
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
if device_map == "auto":
device_map = generate_device_map(model, PartialState().num_processes)
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
stage = build_pipeline(model, device_map, example_args, example_kwargs)

Member (review comment):

Just a thought about how to handle the split points:

1. We only expose `device_map` with predefined options ("sequential", "balanced_pippy").
2. We let the user pass a custom `device_map`. For the custom case, it can get complicated since the user needs to be careful about the ordering (an `OrderedDict()`) and has to assign the GPUs in a sequential manner because of `split_points.append(next(k for k, v in device_map.items() if v == i))`. So that can be quite complicated.
3. We let the user pass their own split points as a `List[str]` (see the sketch after this thread).

I think 1) is a must. Between 2) and 3), I prefer 3) since it is easier for the user.

Collaborator (author):

Agreed to do 1 and 3

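As an illustration of option 3, a minimal sketch of what accepting user-supplied split points could look like; the `attach_user_split_points` helper and the layer name are hypothetical, only the `annotate_split_points` call mirrors the code above:

from pippy.IR import PipeSplitWrapper, annotate_split_points


def attach_user_split_points(model, split_points):
    """Hypothetical helper: cut the pipeline right before each named submodule."""
    annotate_split_points(
        model, {name: PipeSplitWrapper.SplitPoint.BEGINNING for name in split_points}
    )


# e.g. two pipeline stages, cut at a (made-up) layer name:
# attach_user_split_points(model, ["transformer.h.12"])
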
    model._original_forward = model.forward
    model._original_call = model.__call__
    model.pippy_stage = stage

    model_forward = partial(pippy_forward, forward=model.pippy_stage.forward)

    def forward(*args, **kwargs):
        return model_forward(*args, **kwargs)

    # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
    forward.__wrapped__ = model_forward

Member (review comment):

nice !

    model.forward = forward
    return stage
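
For context, a minimal usage sketch of the API added in this file, assuming a Transformers model and a multi-GPU `accelerate launch` run; the model name, input shapes, and the kwargs-only call are illustrative, not taken from this diff:

import torch
from transformers import AutoModelForCausalLM

from accelerate.inference import prepare_pippy

model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

# Example inputs stay on CPU; prepare_pippy only uses them to trace the model and place the split points
example_inputs = {"input_ids": torch.randint(0, model.config.vocab_size, (2, 64))}
prepare_pippy(model, device_map="auto", example_kwargs=example_inputs)

with torch.no_grad():
    output = model(**example_inputs)

# Per `pippy_forward`, only the last process gets the pipeline output; the other ranks see None
if output is not None:
    print(output)
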
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
@@ -69,6 +69,7 @@
    is_msamp_available,
    is_npu_available,
    is_pandas_available,
    is_pippy_available,
    is_rich_available,
    is_sagemaker_available,
    is_tensorboard_available,
4 changes: 4 additions & 0 deletions src/accelerate/utils/imports.py
@@ -126,6 +126,10 @@ def is_deepspeed_available():
    return _is_package_available("deepspeed")


def is_pippy_available():
    return _is_package_available("torchpippy")


def is_bf16_available(ignore_tpu=False):
"Checks if bf16 is supported, optionally ignoring the TPU"
if is_tpu_available():
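
For completeness, a sketch of how the new check could guard the `pippy` import before building a pipeline; the error message and placement are illustrative, not taken from this PR:

from accelerate.utils import is_pippy_available

if not is_pippy_available():
    raise ImportError("Pipeline parallelism requires the `torchpippy` package: `pip install torchpippy`")

from pippy.IR import Pipe  # only imported once the check has passed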