torch-native pipeline parallelism for big models #2345
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Very cool API! I like the design and how easy it is to use. I left a few comments, mainly around the `split_points`.
src/accelerate/inference.py
Outdated
```python
# To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
forward.__wrapped__ = model_forward
```
Nice!
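For readers following along, the trick being praised looks roughly like this as a standalone sketch (`wrap_forward` is a hypothetical helper, not the PR's actual function):

```python
import torch

def wrap_forward(model: torch.nn.Module) -> torch.nn.Module:
    model_forward = model.forward

    def forward(*args, **kwargs):
        # The real wrapper dispatches through the pipeline here;
        # this sketch just forwards the call unchanged.
        return model_forward(*args, **kwargs)

    # Act like a decorator: stash the original forward so a later call
    # (e.g. `extract_model_from_parallel`) can pop the wrapper off.
    forward.__wrapped__ = model_forward
    model.forward = forward
    return model
```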
src/accelerate/inference.py
Outdated
```python
no_split_module_classes = getattr(model, "_no_split_modules", [])
if num_processes == 1:
    return infer_auto_device_map(model, no_split_module_classes=no_split_module_classes, clean_result=False)
model_size, shared = calculate_maximum_sizes(model)

# Split into `n` chunks for each GPU
memory = (model_size + shared[0]) / num_processes
memory = convert_bytes(memory)
value, ending = memory.split(" ")

# Add a chunk to deal with potential extra shared memory instances
memory = math.ceil(float(value)) * 1.1
memory = f"{memory} {ending}"
device_map = infer_auto_device_map(
    model,
    max_memory={i: memory for i in range(num_processes)},
    no_split_module_classes=no_split_module_classes,
    clean_result=False,
)
```
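To make the padding math concrete, here is a worked example with made-up sizes (the `convert_bytes` output shown is assumed):

```python
import math

# Hypothetical sizes: a 10 GB model with 1 GB of shared weights, 2 processes
model_size, shared, num_processes = 10_000_000_000, (1_000_000_000,), 2

memory = (model_size + shared[0]) / num_processes  # 5.5e9 bytes per GPU
value, ending = "5.5", "GB"                        # what convert_bytes would yield
memory = f"{math.ceil(float(value)) * 1.1} {ending}"
print(memory)  # "6.6000000000000005 GB" -- ceil(5.5) = 6, padded by 10%
```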
We can definitely generate a balanced device_map exclusively for pippy, e.g. `device_map="balanced_pippy"`, if the current balanced option is not the best fit for that. However, I think it would be great if the user could also use other options like "sequential". I didn't try it, but what happens when we only fill 2 GPUs out of the 4 available (a possible "sequential" case)?
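A quick illustration of that question (module names are hypothetical): if a sequential-style map fits the whole model on the first 2 of 4 GPUs, ranks 2 and 3 end up with no modules, and deriving a split point for them fails.

```python
from collections import OrderedDict

# Hypothetical sequential-style map on 4 GPUs that only uses GPUs 0 and 1
device_map = OrderedDict([("embed", 0), ("layers.0", 0), ("layers.1", 1), ("head", 1)])

for rank in range(1, 4):
    try:
        # Same derivation as in the PR: first module assigned to each rank
        print(rank, next(k for k, v in device_map.items() if v == rank))
    except StopIteration:
        print(rank, "-> no module assigned, so no split point for this rank")
```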
src/accelerate/inference.py
Outdated
```python
if device_map == "auto":
    device_map = generate_device_map(model, PartialState().num_processes)
stage = build_pipeline(model, device_map, example_args, example_kwargs)
```
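One detail worth noting: the stages are built by tracing the model with `example_args`/`example_kwargs`, so the shapes passed here constrain later calls (the dynamic batch padding added later in this PR relaxes the batch dimension). A hedged sketch of such tracing inputs for a text model:

```python
import torch

# Hypothetical tracing inputs: a (batch, seq_len) tensor of token ids
example_args = (torch.randint(0, 30522, (2, 128)),)
example_kwargs = {}
```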
Just a thought about how to handle the split points:

1. We only expose `device_map` with predefined options ("sequential", "balanced_pippy").
2. We let the user pass a custom `device_map`. The custom case can be quite complicated: the user needs to be careful about the ordering (`OrderedDict()`) and needs to assign the GPUs sequentially because of `split_points.append(next(k for k, v in device_map.items() if v == i))`.
3. We let the user pass their own split points as a `List[str]`.

I think that 1) is a must. Between 2) and 3), I prefer 3) since it is easier for the user (a sketch of both is below).
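A hedged sketch of options 2 and 3 from the user's side (module names are made up):

```python
from collections import OrderedDict

num_processes = 2

# Option 2: a custom device_map. It must be ordered and must assign GPUs
# sequentially, because split points are derived per rank like this:
device_map = OrderedDict([("embed", 0), ("layers.0", 0), ("layers.1", 1), ("head", 1)])
split_points = [next(k for k, v in device_map.items() if v == i) for i in range(1, num_processes)]
# -> ["layers.1"]

# Option 3: the user directly names the first module of each stage after the first
split_points = ["layers.1"]
```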
Agreed to do 1 and 3
The API is in good shape! Let's document the main functions a bit and we can merge it. I left a few comments, but nothing blocking.
src/accelerate/inference.py
Outdated
```python
if split_points == "auto":
    device_map = generate_device_map(model, state.num_processes, no_split_module_classes=no_split_module_classes)
    split_points = []
    for i in range(1, state.num_processes):
        split_points.append(next(k for k, v in device_map.items() if v == i))
```
It would be great to have a sanity check to make sure that we indeed have `self.num_processes` split points, both when we generate the `split_points` and when the user passes them manually.
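A minimal sketch of that sanity check (note the generation loop above yields one split point per rank after the first, i.e. `num_processes - 1` of them; this is an assumed check, not the merged code):

```python
if len(split_points) != state.num_processes - 1:
    raise ValueError(
        f"Expected {state.num_processes - 1} split points for "
        f"{state.num_processes} pipeline stages, but got {len(split_points)}."
    )
```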
* Allow for dynamic batch padding
* Fix test
* Update src/accelerate/inference.py (Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>)
* Break early after the first valid bs is found
* Less slicy-dicy
* Test cv model
* Start, need to test
* Use dataloader-like logic
* Refactor to utils
* With tests
* Update the source
* Clean
* bs=1 case
* Add test
* add some failing test
* Almost working version
* Much cleaner implementation
* Use pad_input_tensor
* All tests passing!
* Do it at tracing too

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Marc Sun <marc@huggingface.co>
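The commit log above mentions dynamic batch padding via `pad_input_tensor`; a hedged sketch of the idea (the real helper lives in accelerate's utils and its body may differ):

```python
import torch

def pad_input_tensor(tensor: torch.Tensor, batch_size: int) -> torch.Tensor:
    # If the incoming batch is smaller than the batch size the pipeline was
    # traced with, pad along dim 0 by repeating the last row so the traced
    # stages still accept it; callers slice the padding back off the output.
    missing = batch_size - tensor.size(0)
    if missing <= 0:
        return tensor
    pad = tensor[-1:].expand(missing, *tensor.shape[1:]).clone()
    return torch.cat([tensor, pad], dim=0)
```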
Thanks a lot for the integration effort!
LGTM!
Thx for iterating! LGTM
Thanks for writing the doc so quickly! Looks good to me!
cc @MKhalusova for the docs!
Nice work! I left a few comments to polish things in the docs a bit.
Final comment before merging: things that still need to be done in a later PR at some point (but okay not being in the first iteration of this joint effort):
(I'll be doing 3 & 4 this week as a follow-up prior to release)
Example use:
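A hedged sketch of how the API added in this PR might be called, based on the snippets above (`prepare_pippy`, its import path, and its arguments are assumptions and may differ from the merged version). It would be launched with one process per GPU, e.g. via `accelerate launch` or `torchrun`:

```python
import torch
from transformers import AutoModelForCausalLM
from accelerate.inference import prepare_pippy  # import path assumed

model = AutoModelForCausalLM.from_pretrained("gpt2")
example_args = (torch.randint(0, 50257, (2, 64)),)

# Split the model into one pipeline stage per process and wrap its forward
model = prepare_pippy(model, split_points="auto", example_args=example_args)

with torch.no_grad():
    output = model(*example_args)
```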
Speed up:
Using 2x 4090s in full precision:
[Speedup table: Bert, GPT2, T5]