torch-native pipeline parallelism for big models #2345

Merged
merged 39 commits on Feb 6, 2024
Changes from 19 commits
Commits (39)
e713e28
Broken version
muellerzr Jan 16, 2024
2767bb1
Timing I would expect
muellerzr Jan 16, 2024
06f04a9
Working version!
muellerzr Jan 16, 2024
9eef9dd
Use MethodType
muellerzr Jan 17, 2024
449eb8d
working test
muellerzr Jan 17, 2024
77f8e92
Tests
muellerzr Jan 17, 2024
df7779a
Use no split module classes explicitly
muellerzr Jan 17, 2024
e3f6b99
Put split_points in pipeline
muellerzr Jan 17, 2024
8792a8c
Store split points in hf_split_points
muellerzr Jan 17, 2024
7ca4bcc
fix case num_process=1
SunMarc Jan 17, 2024
dac1daa
Allow for dynamic batch padding (#2352)
muellerzr Jan 25, 2024
364c3b6
Rm literal
muellerzr Jan 25, 2024
6a8479b
Allow users to pass in max_memory
muellerzr Jan 25, 2024
303c9cc
Note about recursion
muellerzr Jan 25, 2024
d497e8a
Document, document, document
muellerzr Jan 25, 2024
06bbc5b
Right import check
muellerzr Jan 26, 2024
5e047da
Merge branch 'main' into pippy-integration-v2
muellerzr Jan 26, 2024
a5059e6
Fix bug, add tests to multigpu runners
muellerzr Jan 26, 2024
71346a1
Change default to None
muellerzr Jan 26, 2024
fe66b93
Start of docs
muellerzr Feb 5, 2024
d2af472
Try again?
muellerzr Feb 5, 2024
8dc6c6c
Try again x2
muellerzr Feb 5, 2024
4d0aeb2
Trailing comma
muellerzr Feb 5, 2024
309b71a
Move import
muellerzr Feb 5, 2024
9f561f1
Clean
muellerzr Feb 5, 2024
d5a6fda
typehint
muellerzr Feb 5, 2024
954a668
typo
muellerzr Feb 5, 2024
853f552
From code review
muellerzr Feb 5, 2024
1362e5c
Use num_chunks
muellerzr Feb 5, 2024
68bd89b
Update tests/test_utils.py
muellerzr Feb 5, 2024
181fbda
Bad copy/paste
muellerzr Feb 5, 2024
9157cf1
hf_split_points
muellerzr Feb 6, 2024
f2c6e08
Apply suggestions from code review
muellerzr Feb 6, 2024
9f20496
Year
muellerzr Feb 6, 2024
e1961d6
Nit
muellerzr Feb 6, 2024
8c72a5e
better title
muellerzr Feb 6, 2024
3eaa967
Rephrase
muellerzr Feb 6, 2024
31fcde4
Rephrase
muellerzr Feb 6, 2024
7c3d183
Try spacing maybe?
muellerzr Feb 6, 2024
155 changes: 155 additions & 0 deletions src/accelerate/inference.py
@@ -0,0 +1,155 @@
import math
from types import MethodType

from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points
from pippy.PipelineStage import PipelineStage

from .state import PartialState
from .utils import (
calculate_maximum_sizes,
convert_bytes,
ignorant_find_batch_size,
infer_auto_device_map,
is_pippy_available,
pad_input_tensors,
send_to_device,
)


def generate_device_map(model, num_processes: int = 1, no_split_module_classes=None, max_memory: dict = None):
"""
Calculates the device map for `model` with an offset for PiPPy
"""
if num_processes == 1:
return infer_auto_device_map(model, no_split_module_classes=no_split_module_classes, clean_result=False)
if max_memory is None:
model_size, shared = calculate_maximum_sizes(model)

# Split into `n` chunks for each GPU
memory = (model_size + shared[0]) / num_processes
memory = convert_bytes(memory)
value, ending = memory.split(" ")

# Add a chunk to deal with potential extra shared memory instances
memory = math.ceil(float(value)) * 1.1
memory = f"{memory} {ending}"
max_memory = {i: memory for i in range(num_processes)}
device_map = infer_auto_device_map(
model,
max_memory=max_memory,
no_split_module_classes=no_split_module_classes,
clean_result=False,
)
return device_map


def find_pippy_batch_size(args, kwargs):
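"""
Scans `args` and then, if nothing was found, `kwargs` for the first value with a discoverable batch
dimension and returns its size, or `None` if no batch size can be found.
"""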
found_batch_size = None
for arg in args:
found_batch_size = ignorant_find_batch_size(arg)
if found_batch_size is not None:
break
if found_batch_size is None:
for kwarg in kwargs.values():
found_batch_size = ignorant_find_batch_size(kwarg)
if found_batch_size is not None:
break
return found_batch_size


def build_pipeline(model, split_points, args, kwargs, num_chunks) -> PipelineStage:
"""
Attaches the split points to the model based on `split_points` and generates a `PipelineStage`. Requires passing
in the `args` and `kwargs` the model expects, placed on the CPU.

Users can pass in a custom `num_chunks` as an optional hyperparameter. By default it will use
`AcceleratorState.num_processes`.
"""
# We need to annotate the split points in the model for PiPPy
state = PartialState()
annotate_split_points(model, {split_point: PipeSplitWrapper.SplitPoint.BEGINNING for split_point in split_points})
found_batch_size = find_pippy_batch_size(args, kwargs)
if found_batch_size != num_chunks:
args = pad_input_tensors(args, found_batch_size, num_chunks)
kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks)
pipe = Pipe.from_tracing(model, num_chunks=num_chunks, example_args=args, example_kwargs=kwargs)
stage = PipelineStage(pipe, state.local_process_index, device=state.device)

return stage


def pippy_forward(forward, *args, **kwargs):
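"""
Dispatches `forward` based on the current process: with a single process it is a plain call of `forward`;
otherwise the first local process feeds the (padded) inputs into the pipeline, intermediate processes run
their stage without inputs, and only the last process returns the pipeline output (every other rank returns `None`).
"""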
state = PartialState()
output = None

if state.num_processes == 1:
output = forward(*args, **kwargs)
elif state.is_local_main_process:
found_batch_size = find_pippy_batch_size(args, kwargs)
if found_batch_size is None:
raise ValueError("Could not find batch size from args or kwargs")
else:
if found_batch_size != state.num_processes:
args = pad_input_tensors(args, found_batch_size, state.num_processes)
kwargs = pad_input_tensors(kwargs, found_batch_size, state.num_processes)
forward(*args, **kwargs)
elif state.is_last_process:
output = forward()
else:
forward()
return output


def prepare_pippy(
model, split_points="auto", no_split_module_classes=None, example_args=(), example_kwargs={}, num_chunks=None
):
"""
Wraps `model` for pipeline parallelism

Args:
model (`torch.nn.Module`):
A model we want to split for pipeline-parallel inference
split_points (`str` or `List[str]`, defaults to 'auto'):
How to generate the split points and chunk the model across each GPU. 'auto' will find the best balanced
split given any model; alternatively, a list of layer names to split at can be passed in directly.
no_split_module_classes (`List[str]`):
A list of class names for layers we don't want to be split.
example_args (tuple of `torch.Tensor`):
The expected inputs for the model that uses positional (order-based) inputs. Using this format is
recommended whenever possible.
example_kwargs (dict of `torch.Tensor`):
The expected inputs for the model that uses dictionary-based inputs. This is a *highly* limiting structure
that requires the same keys to be present at *all* inference calls. Not recommended unless the prior
condition is true for all cases.
num_chunks (`int`):
The number of chunks (micro-batches) the pipeline will split the input into. By default one chunk is
assigned per GPU, but this can be tuned; in general `num_chunks` should be greater than the number of GPUs.
"""
if not is_pippy_available():
raise ImportError(
"`pippy` was not found to be installed on your system. Please "
"install using `pip install git+https://github.com/pytorch/PiPPy"
)
state = PartialState()
example_args = send_to_device(example_args, "cpu")
example_kwargs = send_to_device(example_kwargs, "cpu")
if num_chunks is None:
num_chunks = state.num_processes
if split_points == "auto":
device_map = generate_device_map(model, num_chunks, no_split_module_classes=no_split_module_classes)
split_points = []
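# The device map assigns module names to process indices; the first module placed on each
# subsequent process becomes a split point (illustratively, with two processes the first
# module mapped to process 1 starts the second pipeline stage).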
for i in range(1, num_chunks):
split_points.append(next(k for k, v in device_map.items() if v == i))
stage = build_pipeline(model, split_points, example_args, example_kwargs, num_chunks)
model._original_forward = model.forward
model._original_call = model.__call__
model.pippy_stage = stage
model.hf_split_points = split_points

def forward(*args, **kwargs):
return pippy_forward(stage.forward, *args, **kwargs)

# To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
# Note: creates an infinite recursion loop with `generate`
model_forward = MethodType(forward, model)
forward.__wrapped__ = model_forward
model.forward = forward
return model
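
For reference, a minimal usage sketch of the `prepare_pippy` API added above, adapted from the test script in this PR (the GPT-2 model, input shapes, and launch setup are illustrative assumptions, not part of the diff):

import torch
from transformers import GPT2Config, GPT2ForSequenceClassification

from accelerate import PartialState
from accelerate.inference import prepare_pippy

state = PartialState()
model = GPT2ForSequenceClassification(GPT2Config())
model.eval()

# Example inputs must live on the CPU; they are only used to trace and split the model.
example_input = torch.randint(0, 1024, (2, 512), dtype=torch.int64)
model = prepare_pippy(model, example_args=(example_input,), no_split_module_classes=model._no_split_modules)

# At inference time, move the real inputs to the local device. Only the last
# process receives the gathered output; every other rank gets `None`.
with torch.no_grad():
    output = model(example_input.to(state.device))

The sketch assumes a multi-process launch (e.g. via `accelerate launch`), with one process per GPU.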
1 change: 1 addition & 0 deletions src/accelerate/test_utils/__init__.py
@@ -13,6 +13,7 @@
require_multi_gpu,
require_multi_xpu,
require_non_cpu,
require_pippy,
require_single_device,
require_single_gpu,
require_single_xpu,
130 changes: 130 additions & 0 deletions src/accelerate/test_utils/scripts/external_deps/test_pippy.py
@@ -0,0 +1,130 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torchvision.models import resnet34
from transformers import (
BertConfig,
BertForMaskedLM,
GPT2Config,
GPT2ForSequenceClassification,
T5Config,
T5ForConditionalGeneration,
)

from accelerate import PartialState
from accelerate.inference import prepare_pippy
from accelerate.utils import DistributedType, send_to_device, set_seed


model_to_config = {
"t5": (T5ForConditionalGeneration, T5Config, 1024),
"bert": (BertForMaskedLM, BertConfig, 512),
"gpt2": (GPT2ForSequenceClassification, GPT2Config, 1024),
}


def get_model_and_data_for_text(model_name, device, num_processes: int = 2):
initializer, config, seq_len = model_to_config[model_name]
config_args = {}
# Eventually needed for batch inference tests on gpt-2 when bs != 1
# if model_name == "gpt2":
# config_args["pad_token_id"] = 0
model_config = config(**config_args)
model = initializer(model_config)
return model, torch.randint(
low=0,
high=model_config.vocab_size,
size=(num_processes, seq_len),
device=device,
dtype=torch.int64,
requires_grad=False,
)


def test_gpt2(batch_size: int = 2):
set_seed(42)
state = PartialState()
model, inputs = get_model_and_data_for_text("gpt2", "cpu", batch_size)
model = prepare_pippy(model, example_args=(inputs,), no_split_module_classes=model._no_split_modules)
# For inference, args need to be a tuple
inputs = inputs.to("cuda")
with torch.no_grad():
output = model(inputs)
# Zach: Check that we just grab the real outputs we need at the end
if not state.is_last_process:
assert output is None, "Output was generated on a process other than the last one!"
else:
assert output is not None, "Output was not generated on the last process!"


def test_t5(batch_size: int = 2):
set_seed(42)
state = PartialState()
model, inputs = get_model_and_data_for_text("t5", "cpu", batch_size)
example_inputs = {"input_ids": inputs, "decoder_input_ids": inputs}
model = prepare_pippy(
model,
no_split_module_classes=model._no_split_modules,
example_kwargs=example_inputs,
)
# For inference, args need to be a tuple
inputs = send_to_device(example_inputs, "cuda:0")
with torch.no_grad():
output = model(*inputs.values())
# Zach: Check that we just grab the real outputs we need at the end
if not state.is_last_process:
assert output is None, "Output was generated on a process other than the last one!"
else:
assert output is not None, "Output was not generated on the last process!"


def test_resnet(batch_size: int = 2):
set_seed(42)
state = PartialState()
model = resnet34()
input_tensor = torch.rand(batch_size, 3, 224, 224)
model = prepare_pippy(
model,
example_args=(input_tensor,),
)
inputs = send_to_device(input_tensor, "cuda:0")
with torch.no_grad():
output = model(inputs)
# Zach: Check that we just grab the real outputs we need at the end
if not state.is_last_process:
assert output is None, "Output was generated on a process other than the last one!"
else:
assert output is not None, "Output was not generated on the last process!"


if __name__ == "__main__":
state = PartialState()
state.print("Testing pippy integration...")
if state.distributed_type == DistributedType.MULTI_GPU:
state.print("Testing GPT2...")
test_gpt2()
# Issue: When modifying the tokenizer for batch GPT2 inference, there's an issue
# due to references
# NameError: cannot access free variable 'chunk_args_list' where it is not associated with a value in enclosing scope
# test_gpt2(3)
state.print("Testing T5...")
test_t5()
test_t5(1)
test_t5(3)
state.print("Testing CV model...")
test_resnet()
test_resnet(3)
else:
print("Less than two GPUs found, not running tests!")
8 changes: 8 additions & 0 deletions src/accelerate/test_utils/testing.py
@@ -40,6 +40,7 @@
is_mps_available,
is_npu_available,
is_pandas_available,
is_pippy_available,
is_tensorboard_available,
is_timm_available,
is_torch_version,
@@ -290,6 +291,13 @@ def require_pandas(test_case):
return unittest.skipUnless(is_pandas_available(), "test requires pandas")(test_case)


def require_pippy(test_case):
"""
Decorator marking a test that requires pippy to be installed. These tests are skipped when pippy isn't installed.
"""
return unittest.skipUnless(is_pippy_available(), "test requires pippy")(test_case)


_atleast_one_tracker_available = (
any([is_wandb_available(), is_tensorboard_available()]) and not is_comet_ml_available()
)
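
A minimal sketch of how the new `require_pippy` decorator is meant to be used in a test module (the test class and test body are illustrative):

import unittest

from accelerate.test_utils import require_pippy


class PippyIntegrationTester(unittest.TestCase):
    @require_pippy
    def test_pipeline_wrapping(self):
        # Skipped automatically when pippy is not installed.
        ...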
3 changes: 3 additions & 0 deletions src/accelerate/utils/__init__.py
@@ -70,6 +70,7 @@
is_npu_available,
is_pandas_available,
is_peft_available,
is_pippy_available,
is_rich_available,
is_sagemaker_available,
is_tensorboard_available,
@@ -126,12 +127,14 @@
gather_object,
get_data_structure,
honor_type,
ignorant_find_batch_size,
initialize_tensors,
is_namedtuple,
is_tensor_information,
is_torch_tensor,
listify,
pad_across_processes,
pad_input_tensors,
recursively_apply,
reduce,
send_to_device,
23 changes: 12 additions & 11 deletions src/accelerate/utils/imports.py
@@ -38,12 +38,13 @@
_torch_distributed_available = torch.distributed.is_available()


def _is_package_available(pkg_name):
def _is_package_available(pkg_name, metadata_name=None):
# Check that we're importing the actual library and not a same-named "pkg_name" directory by trying to grab its metadata
package_exists = importlib.util.find_spec(pkg_name) is not None
if package_exists:
try:
_ = importlib.metadata.metadata(pkg_name)
# Some libraries have different names in the metadata
_ = importlib.metadata.metadata(pkg_name if metadata_name is None else metadata_name)
return True
except importlib.metadata.PackageNotFoundError:
return False
@@ -73,15 +74,7 @@ def get_ccl_version():


def is_msamp_available():
package_exists = importlib.util.find_spec("msamp") is not None
if package_exists:
try:
# MS-AMP has a different metadata name
_ = importlib.metadata.metadata("ms-amp")
return True
except importlib.metadata.PackageNotFoundError:
return False
return False
return _is_package_available("msamp", "ms-amp")


def is_transformer_engine_available():
@@ -126,6 +119,14 @@ def is_deepspeed_available():
return _is_package_available("deepspeed")


def is_pippy_available():
package_exists = _is_package_available("pippy", "torchpippy")
if package_exists:
pippy_version = version.parse(importlib.metadata.version("torchpippy"))
return compare_versions(pippy_version, ">", "0.1.1")
return False


def is_bf16_available(ignore_tpu=False):
"Checks if bf16 is supported, optionally ignoring the TPU"
if is_tpu_available():