Support of device ordering #95
Conversation
autoparallel/redistribute_tensor.py
Outdated
and device order without collective ops.
Example:
S(0)RR with device order [0, 1, 2] can be permuted to [0, 2, 1].
I'm confused about a couple of things:
- device order doesn't really matter for replicas, right? So this permutation seems like it should be a no-op
- explain the syntax below? `S(a)[x,y]` means what, Shard(dim=a)? What is the [x, y] part, just the device order of the replicas?
> device order doesn't really matter for replicas right? so this permutation seems like it should be a no-op

Yes, by shuffling the ordering in RRR, we can build the path between src and dest specs under different device orderings. For example, to convert placement S(0)S(0)S(0) with order [0,1,2] to placement RS(0)S(0) with order [2,0,1] (note the device dims for S(0)S(0)S(0) or RS(0)S(0) are always [0,1,2], regardless of device order), it will go through the following path:
| Step | Placements | Device Order |
|---|---|---|
| 0 | S(0)S(0)S(0) | [0,1,2] |
| 1 | S(0)S(0)R | [0,1,2] |
| 2 | S(0)RR | [0,1,2] |
| 3 | RRR | [0,1,2] |
| 4 | RRR | [2,0,1] (no-op) |
| 5 | RS(0)R | [2,0,1] |
| 6 | RS(0)S(0) | [2,0,1] |
Sometimes it doesn't need to convert all placements to RRR, but the algorithm thinks this is the best path in this case.
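To make the step structure concrete, here is a small standalone sketch (the state encoding below is illustrative, not the PR's actual data structure) that writes the path down as (placements, device_order) states and checks the two rules it follows: the order only changes at the fully replicated state, and otherwise exactly one mesh dim changes placement per step.

```python
# Illustrative encoding of the path above as (placements, device_order) states.
path = [
    (("S(0)", "S(0)", "S(0)"), (0, 1, 2)),
    (("S(0)", "S(0)", "R"),    (0, 1, 2)),
    (("S(0)", "R",    "R"),    (0, 1, 2)),
    (("R",    "R",    "R"),    (0, 1, 2)),
    (("R",    "R",    "R"),    (2, 0, 1)),  # no-op: everything is replicated
    (("R",    "S(0)", "R"),    (2, 0, 1)),
    (("R",    "S(0)", "S(0)"), (2, 0, 1)),
]

for (p0, o0), (p1, o1) in zip(path, path[1:]):
    if o0 != o1:
        # the device order may only be permuted when every placement is R
        assert p0 == p1 and all(p == "R" for p in p0)
    else:
        # otherwise exactly one mesh dim changes placement per step
        assert sum(a != b for a, b in zip(p0, p1)) == 1
```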
> explain the syntax below? S(a)[x,y] means what, Shard(dim=a)? what is the [x, y] part, just device order of the replicas?

`S(a)[x,y]`: shard tensor dim `a` on mesh dims `x` and `y`, where device order of `x` < device order of `y`. `x` and `y` are mesh dims, not device orders, but `[x,y]` is sorted based on device order.
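For illustration only, the notation could be modeled with a small helper like the one below (hypothetical, not a class from the PR):

```python
from dataclasses import dataclass

@dataclass
class OrderedShard:
    tensor_dim: int        # the `a` in S(a)
    mesh_dims: list[int]   # the [x, y]: mesh dims, kept sorted by device order

    def __str__(self) -> str:
        return f"S({self.tensor_dim}){self.mesh_dims}"

# S(0)[1, 2]: tensor dim 0 is sharded on mesh dims 1 and 2, with mesh dim 1
# coming earlier in the device order than mesh dim 2.
print(OrderedShard(0, [1, 2]))  # S(0)[1, 2]
```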
I will update the design doc to cover those details.
I notice a limitation: currently this PR only has two device orderings specified, one for the src placements and another for the dest placements. In fact, each Shard() of a specific tensor dim should have its own specified device ordering. I need to provide different orderings for different tensor dims, but that should be a small change on top of the current setup. I will make the update.
autoparallel/redistribute_tensor.py
Outdated
# case 1. Shard(a) -> Shard(b), use all to all, apply to case:
# S(a)[x] -> S(b)[x] or
# S(a)[x,y]S(b)[z,k] -> S(a)[x]S(b)[z,k,y], where device order of `y`
# > device order of `z` and `k` (need confirm)
> need confirm

I gave it another thought; we actually don't need to define one device order for each Shard(x). One device order for the src placement and one for the dest placement should be sufficient. This can also tell the sharding order for each tensor dim. The drawback is that different
i,
transform_info.logical_shape,
target_placement.dim,
)
I am concerned about this, because shard to shard cannot always be done in one collective (as far as I know). The canonical example is S(0)S(1) to S(1)S(1) with left-to-right ordering. To do this you must all-gather on two device mesh dims (https://gist.github.com/mattjj/c9ec34f9e9a6fe0b99cbfe8fd56a3808). Perhaps you have an invariant from _gen_transform_infos that you will never request to do this but even then I don't like it.
Yea, I agree that under the canonical order, S(0)S(1) -> S(1)S(1) can't be handled in one all-to-all. As I mentioned in the code:

> # For S(a), S(b), only the last device order of S(a) and S(b) can be all to all interchangeably. (need confirm)

The code won't allow S(0)S(1) -> S(1)S(1) to happen in one shot. We can only operate on the last order of a Shard(x), and the transition path should be: S(0)S(1) -> S(0)R -> S(1)R -> S(1)S(1), which seems consistent with the reference you pasted.
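For reference, the multi-hop path can be spelled out with stock DTensor redistributes. This is only a sketch, assuming a recent PyTorch where `torch.distributed.tensor` is public, a 2x2 mesh, and a 4-rank torchrun launch; it is not the PR's code path.

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, Replicate, distribute_tensor

# 2x2 mesh; run with: torchrun --nproc-per-node 4 this_file.py
mesh = init_device_mesh("cpu", (2, 2))
dt = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0), Shard(1)])  # S(0)S(1)

# Each hop touches a single mesh dim, matching the path above.
dt = dt.redistribute(mesh, [Shard(0), Replicate()])  # S(0)R
dt = dt.redistribute(mesh, [Shard(1), Replicate()])  # S(1)R
dt = dt.redistribute(mesh, [Shard(1), Shard(1)])     # S(1)S(1)
```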
autoparallel/redistribute_tensor.py
Outdated
# S(a)[x,y]S(b)[z,k] -> S(a)[x]S(b)[z,k,y], where device order of `y`
# > device order of `z` and `k` (need confirm)

# case 2. Shard() -> Replicate(), use all gather, apply to case:
To add another layer of complexity to this whole thing, I think this algorithm is overly focused on the optimal order of collectives to perform on the mesh dims as specified.
But in many cases we may be better off doing 1 step on a larger communicator than 2 steps on 2 smaller communicators.
This point probably applies to every case, but I give a specific example here: both of the following could be one-shotted via all-gather on larger communicators, which we may have access to (but aren't considering today):
S(a)[x,y,z] -> S(a)[x]R[y,z]
S(a)[x,y,z] -> R[x,y,z]
I think we can define those patterns under different cases. Basically, in the graph we can refine the next possible states given the current state. I am thinking maybe we can even handle more optimization cases here. E.g., you mentioned we can work on a parent mesh to flatten multiple device dims and gather in one shot. We may detect the pattern here without exposing the device mesh hierarchy to the user.
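A rough sketch of what such a pattern check could look like (a hypothetical helper, not part of the PR): if a Shard's trailing mesh dims, in device order, all become replicated in the target, they are candidates for a single all-gather over a flattened communicator.

```python
def flattenable_gather_dims(src_mesh_dims, dst_mesh_dims):
    """Return the suffix of mesh dims that get gathered away, e.g.
    S(a)[x,y,z] -> S(a)[x] gives [y, z], and S(a)[x,y,z] -> R gives [x, y, z]."""
    # only handles the simple case where dst is a prefix of src
    assert dst_mesh_dims == src_mesh_dims[: len(dst_mesh_dims)]
    return src_mesh_dims[len(dst_mesh_dims):]

print(flattenable_gather_dims(["x", "y", "z"], ["x"]))  # ['y', 'z']
print(flattenable_gather_dims(["x", "y", "z"], []))     # ['x', 'y', 'z']
```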
autoparallel/redistribute_tensor.py
Outdated
# <tensor dim> is sharded on <list of device dims>, where the <list of
# device dims> is sorted by device order.

# Blow are possible transition from one sharding state to another. We
Below
)
return ret

def get_next_state(
So is the overall idea here that you just exhaustively enumerate all the collectives you could do and specify what exactly happens to the sharding if you run that collective? If so, why do we not also specify what collective was run and then read it out directly (instead of having to reconstruct it again when we actually apply the transforms)?
Hmm, not sure if I understand "read it out directly" correctly. The _TransformInfo requires that we input the expected logical shape after each collective operation. This logical shape information is not needed during the graph search; that's why I split it out.
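For what the reviewer suggests, one possible shape of the record (hypothetical names, not the PR's actual _TransformInfo) would carry the collective alongside the state change, with the logical shape attached only after the search:

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional

class Collective(Enum):
    ALL_GATHER = auto()
    ALL_TO_ALL = auto()
    REDUCE_SCATTER = auto()
    LOCAL_CHUNK = auto()  # local split, no communication

@dataclass
class TransformStep:
    mesh_dim: int
    collective: Collective
    logical_shape: Optional[list[int]] = None  # filled in after the graph search
```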
Looks like we don't need to support DTensor DCP even with device ordering here. The DTensor in autoparallel works like an intermediate state. Even with device order specification, we still return a normal local tensor.
@zpcore we still need DTensor DCP support here, because model parameters might be in a different order than the canonical order (and this is indeed the case in the current state of main).
autoparallel/ordered_sharding.py
Outdated
curr_spec,
tgt_spec,
src_device_order=placement_order,
dst_device_order=canonical,
Hi @fmassa, can you help confirm whether it is true that dst_device_order is always canonical? Or are src_device_order/dst_device_order reversed here?
@zpcore yes. For now we are assuming that only the parameters of the model (or its direct descendants) can have a non-canonical placement.
So we decide what the placement of each parameter is, and the only place we actually perform the conversion with ordered sharding is for those parameters. So for now I believe the code is correct.
But note that once we have full support for ordered sharding, we can actually remove this limiting assumption and optimize for the src / dst device order globally.
d592a02 to a8c01a5 (Compare)
I figured out the reason why the redistribute API gives a different loss value; there are several issues:
I believe the code should work if we force the backward grad node pair for
It turns out I was wrong in #95 (comment) regarding the first issue. The device order generally follows (0,1), except some last chain nodes related to params and their grads. I updated the expected placement order and now the loss curve and MFU start to match. (note:
8094347 to 8304ced (Compare)
fmassa left a comment
LGTM, thanks a lot!
I have only some minor comments, and if the runs are working fine wrt loss / speed then I'm confident we are not missing anything here!
# below is supposed not to be triggered
redistribute_node_order[p] = ((1, 0), False)
for node, (order, _) in redistribute_node_order.items():
    node.meta["device_order"] = order
> IIUC, this is only used for printing / logging, is that right?

Yes, actually I didn't make use of this meta field because we call trace_structured for autoparallel_sharding_optimizer_log before we made the device order change. I added it here in case we need it for debugging in the future.
@@ -0,0 +1,864 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
I have not reviewed the logic in this file, but I assume it will be moved to PyTorch soon, is that right?
Yes, no need to worry about reviewing this part.
autoparallel/apply_sharding.py
Outdated
origin_order = None
tgt_order = None
if node in self.param_placement_order:
    tgt_order, do_redistribute = self.param_placement_order[node]
I'm a bit confused by the do_redistribute -- in this case, it will always be true because we have that curr_spec.placements != tgt_spec.placements, is that right?
If yes, then this is here for clarity, is that right?
I should rename do_redistribute to do_reorder. This bool tells whether we need to permute the device order or not, regardless of the placements. I have updated the naming to make it clear.
For reference, this is the full set of runs I did to compare the PR with the main HEAD: I didn't notice accuracy or performance regressions. One point I would like to bring out in this PR is that now we can define cost values for collective ops s.t. we can find a
Taken from pytorch/pytorch#160266, but I'm hitting an assertion for now
084eeb1 to d6969f7 (Compare)
Nice! Let's set some time early next week to discuss it.



Based on pytorch/pytorch#160266. Now the llama3 run can pass.
I rewrote the redistribute algorithm to support ordered device meshes. Now the sharding transformation is much clearer than before. It uses Dijkstra's algorithm to find the path with the minimal cost based on the following parameters:
To make it more precise, we can turn those costs into functions of the tensor shape.
Currently redistribute uses the original solution if no device order is specified. We can enable the new algorithm for all cases by setting `use_greedy_transform = False` in the `_gen_transform_infos_non_cached` function.
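A minimal sketch of the search loop described above (the state encoding and cost model are illustrative assumptions; `get_next_state` is a function name from the PR, but its signature here is guessed): Dijkstra over (placements, device_order) states, where each neighbor comes with a per-collective cost and the transform to record.

```python
import heapq
from itertools import count

def find_min_cost_path(src_state, dst_state, get_next_state):
    """get_next_state(state) -> iterable of (next_state, cost, transform)."""
    tie = count()  # tiebreaker so heapq never has to compare states or paths
    best = {src_state: 0.0}
    queue = [(0.0, next(tie), src_state, [])]
    while queue:
        cost, _, state, path = heapq.heappop(queue)
        if state == dst_state:
            return path  # list of transforms with minimal total cost
        if cost > best.get(state, float("inf")):
            continue  # stale queue entry
        for nxt, step_cost, transform in get_next_state(state):
            new_cost = cost + step_cost
            if new_cost < best.get(nxt, float("inf")):
                best[nxt] = new_cost
                heapq.heappush(queue, (new_cost, next(tie), nxt, path + [transform]))
    return None  # dst_state unreachable
```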