torch-native pipeline parallelism for big models #2345

Merged: 39 commits from pippy-integration-v2 into main on Feb 6, 2024

Conversation

@muellerzr (Collaborator) commented on Jan 16, 2024

Example use:

import torch
from accelerate.inference import prepare_pippy
from accelerate.utils import set_seed
from transformers import T5ForConditionalGeneration, T5Config

set_seed(42)

config = T5Config()
model = T5ForConditionalGeneration(config)
model.eval()

# Create example inputs for the model
input = torch.randint(
    low=0,
    high=config.vocab_size,
    size=(2, 1024),  # bs x seq_len
    device="cpu",
    dtype=torch.int64,
    requires_grad=False,
)

example_inputs = {"input_ids": input, "decoder_input_ids": input}

model = prepare_pippy(model, example_kwargs=example_inputs)

args = (
    example_inputs["input_ids"].to("cuda:0"),
    example_inputs["decoder_input_ids"].to("cuda:0")
)
with torch.no_grad():
    output = model(*args)
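A detail worth noting, not shown in the snippet above (so treat the specifics as assumptions rather than this PR's exact behavior): the script is intended to be launched once per GPU, e.g. with accelerate launch --num_processes 2 example.py, and only the last pipeline stage ends up holding the real output, so post-processing is typically guarded on it:

from accelerate import PartialState

# Only the last pipeline stage receives the actual model output; other ranks
# get a placeholder, so guard any use of `output` on the last process.
if PartialState().is_last_process:
    print(output)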

Speedup:

Measured on 2x 4090s in full precision.

Bert
                       Accelerate/Sequential   PiPPy + Accelerate
First batch            0.2137s                 0.3119s
Average of 5 batches   0.0099s                 0.0062s

GPT2
                       Accelerate/Sequential   PiPPy + Accelerate
First batch            0.1959s                 0.4189s
Average of 5 batches   0.0205s                 0.0126s

T5
                       Accelerate/Sequential   PiPPy + Accelerate
First batch            0.2789s                 0.3809s
Average of 5 batches   0.0198s                 0.0166s

@muellerzr marked this pull request as draft on January 16, 2024.

@SunMarc (Member) left a comment

Very cool API! I like the design and how easy it is to use. I left a few comments, mainly around the split_points.

Comment on lines 93 to 94
# To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
forward.__wrapped__ = model_forward
Nice!
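For readers unfamiliar with the trick being praised here: setting __wrapped__ mirrors what functools.wraps does, which is what lets an unwrapping helper recover the original function later. A minimal, generic sketch of the pattern (not the accelerate implementation):

import functools

def wrap_forward(forward):
    # functools.wraps sets wrapper.__wrapped__ = forward, so the original
    # callable can be "popped" back out later.
    @functools.wraps(forward)
    def wrapper(*args, **kwargs):
        return forward(*args, **kwargs)
    return wrapper

def unwrap(fn):
    # Generic stand-in for an extract_model_from_parallel-style helper.
    return getattr(fn, "__wrapped__", fn)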

Comment on lines 24 to 42
no_split_module_classes = getattr(model, "_no_split_modules", [])
if num_processes == 1:
    return infer_auto_device_map(model, no_split_module_classes=no_split_module_classes, clean_result=False)
model_size, shared = calculate_maximum_sizes(model)

# Split into `n` chunks for each GPU
memory = (model_size + shared[0]) / num_processes
memory = convert_bytes(memory)
value, ending = memory.split(" ")

# Add a chunk to deal with potential extra shared memory instances
memory = math.ceil(float(value)) * 1.1
memory = f"{memory} {ending}"
device_map = infer_auto_device_map(
    model,
    max_memory={i: memory for i in range(num_processes)},
    no_split_module_classes=no_split_module_classes,
    clean_result=False,
)
We can definitely generate a balanced device_map exclusively for pippy (device_map = "balanced_pippy") if the current balanced option is not the best fit here. However, I think it would be great if the user could also use other options like "sequential". I haven't tried it, but what happens when we only fill 2 GPUs out of the 4 available (a possible sequential case)?
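For context, a sequential-style placement that fills only some of the available GPUs can already be expressed through infer_auto_device_map by capping max_memory per device; a minimal sketch, with made-up memory budgets:

from accelerate import infer_auto_device_map

# Hypothetical "sequential" placement over 2 of 4 GPUs: generous budgets for
# devices 0 and 1, nothing for 2 and 3, so layers fill device 0 first, then 1.
device_map = infer_auto_device_map(
    model,
    max_memory={0: "20GiB", 1: "20GiB", 2: "0GiB", 3: "0GiB"},
    no_split_module_classes=getattr(model, "_no_split_modules", []),
)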

Comment on lines 81 to 83
if device_map == "auto":
    device_map = generate_device_map(model, PartialState().num_processes)
stage = build_pipeline(model, device_map, example_args, example_kwargs)
Just a thought about how to handle the split points:

    1. We only expose device_map with predefined options ("sequential", "balanced_pippy").
    2. We let the user pass a custom device_map. The custom case can be complicated, since the user needs to be careful about the order (OrderedDict()) and has to assign the GPUs sequentially because of split_points.append(next(k for k, v in device_map.items() if v == i)), so that can get quite involved.
    3. We let the user pass their own split points as a List[str] (see the sketch after this list).

    I think that 1) is a must. Between 2) and 3), I prefer 3) since it is easier for the user.
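For illustration, option 3 could look something like this from the user's side (the layer names below are made up, and split_points as a keyword argument is still a proposal at this point in the review):

# Hypothetical user-facing call for option 3: the user names the modules at
# which the model is cut into pipeline stages, one cut per stage boundary.
model = prepare_pippy(
    model,
    split_points=["encoder.block.6", "decoder.block.6"],  # made-up module names
    example_kwargs=example_inputs,
)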

muellerzr (Collaborator, Author): Agreed to do 1 and 3.

@SunMarc (Member) left a comment

The API is in good shape! Let's document the main functions a bit and then we can merge it. I left a few comments, but nothing blocking.

Comment on lines 80 to 84
if split_points == "auto":
    device_map = generate_device_map(model, state.num_processes, no_split_module_classes=no_split_module_classes)
    split_points = []
    for i in range(1, state.num_processes):
        split_points.append(next(k for k, v in device_map.items() if v == i))
It would be great to have a sanity check to make sure that we indeed end up with the expected number of split points for self.num_processes, both when we generate the split_points and when the user passes them manually.
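A minimal sketch of such a check, under the assumption that one split point is needed per stage boundary (i.e. num_processes - 1 of them):

# Hypothetical sanity check: the generated or user-provided split points must
# yield exactly one pipeline stage per process.
if len(split_points) != state.num_processes - 1:
    raise ValueError(
        f"Expected {state.num_processes - 1} split points for "
        f"{state.num_processes} processes, got {len(split_points)}."
    )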

muellerzr and others added 9 commits January 25, 2024 14:56
* Allow for dynamic batch padding

* Fix test

* Update src/accelerate/inference.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Break early after the first valid bs is found

* Less slicy-dicy

* Test cv model

* Start, need to test

* Use dataloader-like logic

* Refactor to utils

* With tests

* Update the source

* Clean

* bs=1 case

* Add test

* add some failing test

* Almost working version

* Much cleaner implementation

* Use pad_input_tensor

* All tests passing!

* Do it at tracing too

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Marc Sun <marc@huggingface.co>
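The "dynamic batch padding" these commits refer to boils down to padding the input batch so it divides evenly into the pipeline's microbatches; a rough sketch of the idea (the helper name is hypothetical, not the utility added in this PR):

import torch

def pad_batch_to_multiple(tensor: torch.Tensor, num_chunks: int) -> torch.Tensor:
    # Hypothetical helper: repeat the last row until the batch dimension is
    # divisible by num_chunks, so every microbatch has the same size.
    remainder = tensor.size(0) % num_chunks
    if remainder == 0:
        return tensor
    pad_rows = tensor[-1:].repeat(num_chunks - remainder, *([1] * (tensor.dim() - 1)))
    return torch.cat([tensor, pad_rows], dim=0)

padded = pad_batch_to_multiple(torch.randint(0, 100, (3, 1024)), num_chunks=2)
print(padded.shape)  # torch.Size([4, 1024])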
@kwen2501 left a comment

Thanks a lot for the integration effort!
LGTM!

@SunMarc (Member) left a comment

Thanks for iterating! LGTM

@kwen2501 left a comment

Thanks for writing the doc so quickly! Looks good to me!

muellerzr and others added 3 commits February 5, 2024 15:47
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
@muellerzr marked this pull request as ready for review on February 5, 2024.
@muellerzr changed the title from "Pippy integration v2" to "torch-native pipeline parallelism for big models" on Feb 6, 2024.
@muellerzr (Collaborator, Author):

cc @MKhalusova for the docs!

@MKhalusova (Contributor) left a comment

Nice work! I left a few comments to polish things in the docs a bit.

@muellerzr (Collaborator, Author):

Final comment before merging: things that still need to be done in a later PR at some point (but are okay not being in the first iteration of this joint effort):

  1. Specify balanced_pippy device map and allow a sequential device_map when making the pipeline via prepare_pippy
  2. Look into supporting model.generate() through an alternative hook into the model forward if possible
  3. Make sure all outputs end up on the CPU so users don't need to check at the end, and we can call them via a .gather (see the sketch after this list)
  4. Migrate the pippy-device-map-playground examples over to here as part of our examples folder

(I'll be doing 3 & 4 this week as a follow-up prior to release)
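For item 3, one way to make the output available everywhere, regardless of which rank ran the last stage, is to broadcast it as a Python object; a hypothetical sketch, not the implementation this PR or its follow-ups ship:

import torch.distributed as dist

def share_output_with_all_ranks(output, last_rank: int):
    # Hypothetical: the last pipeline stage broadcasts its (CPU) output so
    # every process returns the same object and no rank check is needed.
    obj = [output if dist.get_rank() == last_rank else None]
    dist.broadcast_object_list(obj, src=last_rank)
    return obj[0]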

@muellerzr merged commit 0867c09 into main on Feb 6, 2024.
25 checks passed
@muellerzr deleted the pippy-integration-v2 branch on February 6, 2024.