
Conversation

@zpcore
Contributor

@zpcore zpcore commented Aug 13, 2025

Based on pytorch/pytorch#160266. With this change, the llama3 run now passes.

I rewrote the redistribute algorithm to support ordered device meshes. The sharding transformation is now much clearer than before. It uses Dijkstra's algorithm to find the minimal-cost path based on the following parameters:

        all_reduce_cost: int = 4,
        all_to_all_cost: int = 1,
        all_gather_cost: int = 2,
        chunk_cost: int = 0,

To make it more precise, we could turn those costs into functions of the tensor shape.

Currently, redistribute uses the original solution if no device order is specified. We can enable the new algorithm for all cases by setting use_greedy_transform = False in the _gen_transform_infos_non_cached function.

  • Note that DCP with device ordering is not supported yet.
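For intuition, here is a minimal, self-contained sketch of the idea, not the PR's actual implementation: Dijkstra's algorithm over a graph of sharding states, using the per-collective costs above. The state encoding and the `next_states` transition rules are simplified assumptions (the real version also tracks device order).

```python
import heapq

# Hypothetical per-collective costs, mirroring the constants quoted above.
COSTS = {"all_reduce": 4, "all_to_all": 1, "all_gather": 2, "chunk": 0}


def next_states(state):
    """Yield (next_state, collective, cost) edges for a toy mesh example.

    `state` is a tuple with one placement per mesh dim: "R" (replicate),
    "P" (partial), or "S0"/"S1" (shard on tensor dim 0/1). The transition
    rules are simplified assumptions; the PR's real graph also tracks
    device order.
    """
    for i, p in enumerate(state):
        if p.startswith("S"):
            # Shard -> Replicate via all-gather on mesh dim i.
            yield state[:i] + ("R",) + state[i + 1:], f"all_gather(mesh_dim={i})", COSTS["all_gather"]
            # Shard(a) -> Shard(b) via all-to-all on mesh dim i.
            other = "S1" if p == "S0" else "S0"
            yield state[:i] + (other,) + state[i + 1:], f"all_to_all(mesh_dim={i})", COSTS["all_to_all"]
        elif p == "P":
            # Partial -> Replicate via all-reduce on mesh dim i.
            yield state[:i] + ("R",) + state[i + 1:], f"all_reduce(mesh_dim={i})", COSTS["all_reduce"]
        else:  # "R"
            # Replicate -> Shard via local chunking (no communication).
            for tgt in ("S0", "S1"):
                yield state[:i] + (tgt,) + state[i + 1:], f"chunk(mesh_dim={i},{tgt})", COSTS["chunk"]


def cheapest_path(src, dst):
    """Dijkstra over sharding states; returns (total_cost, [collectives])."""
    heap = [(0, src, [])]
    best = {}
    while heap:
        cost, state, path = heapq.heappop(heap)
        if state == dst:
            return cost, path
        if best.get(state, float("inf")) <= cost:
            continue
        best[state] = cost
        for nxt, op, c in next_states(state):
            heapq.heappush(heap, (cost + c, nxt, path + [op]))
    raise ValueError("no path between the given shardings")


# e.g. S(0)S(0) -> R S(0): a single all-gather on mesh dim 0 is cheapest.
print(cheapest_path(("S0", "S0"), ("R", "S0")))
```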

@meta-cla meta-cla bot added the CLA Signed label Aug 13, 2025
@zpcore zpcore requested review from ezyang, fmassa and wconstab August 13, 2025 22:36
and device order without collective ops.
Example:
S(0)RR with device order [0, 1, 2] can be permuted to [0, 2, 1].
Contributor


I'm confused about a couple of things:

  1. Device order doesn't really matter for replicas, right? So this permutation seems like it should be a no-op.
  2. Can you explain the syntax below? S(a)[x,y] means what, Shard(dim=a)? What is the [x, y] part, just the device order of the replicas?

Contributor Author

@zpcore zpcore Aug 13, 2025


Device order doesn't really matter for replicas, right? So this permutation seems like it should be a no-op.

Yes. By shuffling the ordering of RRR, we can build a path between the src and dest specs under different device orderings. For example, to convert placement S(0)S(0)S(0) with order [0,1,2] to placement RS(0)S(0) with order [2,0,1] (note the device dims for S(0)S(0)S(0) or RS(0)S(0) are always [0,1,2], regardless of device order), it will go through the following path:

| Step | Placements | Device Order |
|------|--------------|-------------------|
| 0 | S(0)S(0)S(0) | [0,1,2] |
| 1 | S(0)S(0)R | [0,1,2] |
| 2 | S(0)RR | [0,1,2] |
| 3 | RRR | [0,1,2] |
| 4 | RRR | [2,0,1] (no-op) |
| 5 | RS(0)R | [2,0,1] |
| 6 | RS(0)S(0) | [2,0,1] |

Sometimes it doesn't need to convert all placements to RRR, but the algorithm thinks this is the best path in this case.

Can you explain the syntax below? S(a)[x,y] means what, Shard(dim=a)? What is the [x, y] part, just the device order of the replicas?

S(a)[x,y]: shard tensor dim a on mesh dims x and y, where device order of x < device order of y. x and y are mesh dims, not device orders, but [x,y] is sorted based on device order.
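To make the S(a)[x, y] notation concrete, here is a small hypothetical helper (not code from this PR) that groups mesh dims by the tensor dim they shard and sorts each group by device order, assuming the order is given as one order value per mesh dim:

```python
# Hypothetical illustration of the S(a)[x, y] notation. `device_order[m]` is
# assumed to be the order value assigned to mesh dim `m`; [x, y] then lists
# the mesh dims sharding tensor dim `a`, sorted by that value.
def shard_notation(placements: list[str], device_order: list[int]) -> dict[int, list[int]]:
    by_tensor_dim: dict[int, list[int]] = {}
    for mesh_dim, p in enumerate(placements):
        if p.startswith("S("):                 # e.g. "S(0)" means Shard(dim=0)
            tensor_dim = int(p[2:-1])
            by_tensor_dim.setdefault(tensor_dim, []).append(mesh_dim)
    for mesh_dims in by_tensor_dim.values():
        mesh_dims.sort(key=lambda m: device_order[m])
    return by_tensor_dim


# R S(0) S(0) with device order [2, 0, 1]: tensor dim 0 is sharded on mesh dims
# 1 then 2 (order values 0 < 1), i.e. S(0)[1, 2].
print(shard_notation(["R", "S(0)", "S(0)"], [2, 0, 1]))  # {0: [1, 2]}
```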

Contributor Author

@zpcore zpcore Aug 13, 2025


I will update the design doc to cover those details.

@zpcore
Contributor Author

zpcore commented Aug 14, 2025

I noticed a limitation: currently this PR only has two device orderings specified, one for the src placements and another for the dest placements. In fact, each Shard() of a specific tensor dim should have its own device ordering.

I need to provide a different ordering for each tensor dim, but that should be a small change on top of the current setup. I will make the update.

# case 1. Shard(a) -> Shard(b), use all to all, apply to case:
# S(a)[x] -> S(b)[x] or
# S(a)[x,y]S(b)[z,k] -> S(a)[x]S(b)[z,k,y], where device order of `y`
# > device order of `z` and `k` (need confirm)
Contributor


need confirm

@zpcore
Contributor Author

zpcore commented Aug 14, 2025

I noticed a limitation: currently this PR only has two device orderings specified, one for the src placements and another for the dest placements. In fact, each Shard() of a specific tensor dim should have its own device ordering.

I need to provide a different ordering for each tensor dim, but that should be a small change on top of the current setup. I will make the update.

I gave it another thought: we actually don't need to define one device order for each Shard(x). One device order for the src placement and one for the dest placement should be sufficient.

  • (Wanchao's proposal) Define the order as order_list: list[list[int]] = [[0,1], [1,0]]. Then for placements
    S(1)S(0)S(1)S(0)R, the order info is used as follows:
    - S(0)'s device order is order_list[0] = [0,1].
    - S(1)'s device order is order_list[1] = [1,0].
  • (What this PR does) Define a single order based on the device mesh dim size, order_list: list[int] = [3,1,0,2] (4-D device mesh). Then for placements
    S(1)S(0)S(1)S(0)R, the order info is used as follows:
    - S(0) sits on mesh dims 1 and 3, so its device order is order_list[[1,3]] = [1,2].
    - S(1) sits on mesh dims 0 and 2, so its device order is order_list[[0,2]] = [3,0].

This also conveys the sharding order for each tensor dim. The drawback is that different order_list values can express the same ordering for a given set of placements, which is redundant.
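A minimal sketch of the second option (the flat order list used in this PR), reproducing the worked example above; the helper name is illustrative, not the PR's actual API:

```python
# Sketch: derive each Shard(x)'s device order from a single flat order list,
# where order_list[mesh_dim] is that mesh dim's order value.
def per_shard_device_order(placements: list[str], order_list: list[int]) -> dict[int, list[int]]:
    orders: dict[int, list[int]] = {}
    for mesh_dim, p in enumerate(placements):
        if p.startswith("S("):
            tensor_dim = int(p[2:-1])
            orders.setdefault(tensor_dim, []).append(order_list[mesh_dim])
    return orders


# Placements copied from the example above; the trailing R never indexes order_list.
placements = ["S(1)", "S(0)", "S(1)", "S(0)", "R"]
order_list = [3, 1, 0, 2]
print(per_shard_device_order(placements, order_list))
# {1: [3, 0], 0: [1, 2]}  -> S(0)'s device order is [1, 2], S(1)'s is [3, 0]
```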

i,
transform_info.logical_shape,
target_placement.dim,
)
Contributor


I am concerned about this, because shard to shard cannot always be done in one collective (as far as I know). The canonical example is S(0)S(1) to S(1)S(1) with left-to-right ordering. To do this you must all-gather on two device mesh dims (https://gist.github.com/mattjj/c9ec34f9e9a6fe0b99cbfe8fd56a3808). Perhaps you have an invariant from _gen_transform_infos that you will never request to do this but even then I don't like it.

Contributor Author


Yeah, I agree that under the canonical order, S(0)S(1) -> S(1)S(1) can't be handled in one all-to-all. As I mentioned in the code:

# For S(a), S(b), only the last device order of S(a) and S(b) can be all to all interchangeably. (need confirm)

The code won't allow S(0)S(1) -> S(1)S(1) to happen in one shot. We can only operate on the last order of a Shard(x), and the transition path should be: S(0)S(1) -> S(0)R -> S(1)R -> S(1)S(1). This seems consistent with the reference you pasted.
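Reading the constraint as "each step touches a single mesh dim", the decomposition can be spelled out like this (purely illustrative, not the PR's data structures):

```python
# Each tuple is (placements, collective applied to reach it); device order is
# the canonical [0, 1] throughout. Illustration only.
path = [
    (("S(0)", "S(1)"), None),                        # start
    (("S(0)", "R"),    "all_gather on mesh dim 1"),  # undo the inner shard first
    (("S(1)", "R"),    "all_to_all on mesh dim 0"),  # swap the sharded tensor dim
    (("S(1)", "S(1)"), "chunk on mesh dim 1"),       # re-shard locally, no comm
]
```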

# S(a)[x,y]S(b)[z,k] -> S(a)[x]S(b)[z,k,y], where device order of `y`
# > device order of `z` and `k` (need confirm)

# case 2. Shard() -> Replicate(), use all gather, apply to case:
Contributor


To add another layer of complexity to this whole thing, I think this algorithm is overly focused on the optimal order of collectives to perform on the mesh dims as specified.

But in many cases we may be better off doing 1 step on a larger communicator than 2 steps on 2 smaller communicators.

This point probably applies to every case, but I'll give the specific example here: both of these could be one-shotted via an all-gather on a larger communicator, which we may have access to (but aren't considering today):
S(a)[x,y,z] -> S(a)[x]R[y,z]
S(a)[x,y,z] -> R[x,y,z]

Contributor Author

@zpcore zpcore Aug 14, 2025


I think we can define those patterns under different cases. Basically, in the graph we can refine the next possible states given the current state. I am thinking maybe we can even handle more optimization cases here. E.g., you mentioned we can work on a parent mesh to flatten multiple device dims and gather in one shot. We may detect the pattern here without exposing the device mesh hierarchy to the user.
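One way this could slot into the state graph (a sketch under my own assumptions, not the PR's design): while enumerating next states, also emit a "fused" edge whenever a tensor dim sheds a contiguous suffix of its mesh dims, and run that edge as a single all-gather over a group covering those dims:

```python
# Sketch: detect S(a)[x, y, z] -> S(a)[x] (or -> R) and emit one fused
# all-gather edge over all the dropped mesh dims instead of several hops.
# How the fused collective is actually issued (e.g. on a flattened sub-mesh)
# is left abstract here.
def fused_gather_edges(src_shard_dims: list[int], dst_shard_dims: list[int]):
    # Only valid when the destination keeps a prefix of the source's mesh dims.
    if dst_shard_dims == src_shard_dims[: len(dst_shard_dims)]:
        dropped = src_shard_dims[len(dst_shard_dims):]
        if len(dropped) > 1:
            yield ("all_gather_fused", tuple(dropped))  # one collective, one edge


# S(a)[x, y, z] -> S(a)[x]R[y, z]: gather over mesh dims y and z in one shot.
print(list(fused_gather_edges([0, 1, 2], [0])))  # [('all_gather_fused', (1, 2))]
# S(a)[x, y, z] -> R[x, y, z]: gather over all three mesh dims at once.
print(list(fused_gather_edges([0, 1, 2], [])))   # [('all_gather_fused', (0, 1, 2))]
```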

# <tensor dim> is sharded on <list of device dims>, where the <list of
# device dims> is sorted by device order.

# Blow are possible transition from one sharding state to another. We
Contributor


Below

)
return ret

def get_next_state(
Contributor


So is the overall idea here that you just exhaustively enumerate all the collectives you could do and specify what exactly happens to the sharding if you run that collective? If so, why do we not also specify what collective was run and then read it out directly (instead of having to reconstruct it again when we actually apply the transforms)?

Contributor Author


Hmm, not sure if I understand "read it out directly" correctly. _TransformInfo requires us to provide the expected logical shape after each collective operation. This logical shape information is not needed during the graph search, which is why I split it out.
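For context, a rough sketch of what each per-step record carries; the fields are modeled on DTensor's internal `_TransformInfo`, but treat this as an approximation rather than the exact definition:

```python
from typing import NamedTuple

from torch.distributed.tensor import Placement  # base class of Shard/Replicate/Partial


# Approximate shape of the per-step record (check the actual definition in
# torch.distributed.tensor._redistribute before relying on it).
class TransformStep(NamedTuple):
    mesh_dim: int                                    # mesh dim this collective acts on
    src_dst_placements: tuple[Placement, Placement]  # e.g. (Shard(0), Replicate())
    logical_shape: list[int]                         # expected logical shape at this step


# The graph search itself only needs (mesh_dim, src, dst); logical_shape is
# filled in afterwards from the tensor shape and the steps taken so far,
# which is why it is computed outside the search.
```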

@zpcore zpcore requested a review from XilunWu August 15, 2025 23:13
@zpcore
Contributor Author

zpcore commented Aug 19, 2025

Looks like we don't need to support DTensor DCP even with device ordering here. The DTensor in autoparallel works like an intermediate state. Even with a device order specification, we still return a normal local tensor:
https://github.com/meta-pytorch/autoparallel/blob/b3d667b880849534431a58138b8eee5bb3671cec/autoparallel/ordered_sharding.py#L18C5-L38.

cc @XilunWu @wconstab

@fmassa
Contributor

fmassa commented Aug 20, 2025

Looks like we don't need to support DTensor DCP even with device ordering here. The DTensor in autoparallel works like an intermediate state. Even with a device order specification, we still return a normal local tensor:

@zpcore we still need DTensor DCP support here, because model parameters might be in a different order than the canonical order (and this is indeed the case in the current state of main).

curr_spec,
tgt_spec,
src_device_order=placement_order,
dst_device_order=canonical,
Contributor Author


Hi @fmassa, can you help confirm whether it is true that dst_device_order is always canonical? Or are src_device_order/dst_device_order reversed?

Contributor


@zpcore yes. For now we are assuming that only the parameters of the model (or its direct descendants) can have a non-canonical placement.
So we decide the placement for the parameters, and the only place we actually perform the conversion with ordered sharding is for those parameters. So for now I believe the code is correct.

But note that once we have full support for ordered sharding, we can actually remove this limiting assumption and optimize for the src / dst device order globally.

@zpcore zpcore force-pushed the fmassa/try_new_redistribute_local_tensor branch 2 times, most recently from d592a02 to a8c01a5 Compare August 25, 2025 22:12
@zpcore
Contributor Author

zpcore commented Aug 26, 2025

I figured out the reason why the redistribute API gives a different loss value; there are several issues:

  1. We can't simply assume some param nodes are ordered as (0,1) and all following nodes are ordered as (1,0). If we assume all following nodes are ordered as (1,0), the redistribute rules triggered will be different. E.g., RS(0) -> S(0)S(0): if both are under order (0,1), the conversion path will be RS(0) -> RR -> S(0)R -> S(0)S(0); if both are under order (1,0), the conversion path will be RS(0) -> S(0)S(0). So even though the input and output are under the same device ordering, the result can still differ when we use the new ordered-sharding redistribute API.
     However, since those input/output pairs are both under the same (1,0) order and were handled by the normal redistribution code without ordering taking effect, the code can still give a correct forward-pass result.

  2. Another issue is in the backward path when we replace the order for PS(0) -> S(0)S(0). With @fmassa 's original fix here, PS(0) -> S(0)S(0) works as if it can be changed in one shot. However, if we call the new redistribute API, the path becomes: P(sum)S(0)<0> -> P(sum)S(1)<0> -> S(0)<0>S(1)<0> -> S(0)<0>S(0)<1> (the number in <> is the device order, relative within the same sharding).

I believe the code would work if we force the backward grad node pair for PS(0) -> S(0)S(0) to order S(0)S(0) as (1,0) (in contrast, the current code orders PS(0) as (1,0) and S(0)S(0) as (0,1)); this gives the same output when called with the new redistribute API. However, that is a hack, and I suggest we let the new redistribute API figure out the backward itself. In addition, if we abuse the redistribute API in this way, DCP won't work as expected.
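To summarize the order sensitivity from item 1 as data (illustrative only; placements written as strings):

```python
# RS(0) -> S(0)S(0): the conversion path depends on the chosen device order,
# even when src and dst share the same order. Written out for reference only.
paths = {
    # Order (0,1): mesh dim 0 must hold the outer/coarser split, so the existing
    # shard has to be gathered back before re-chunking in the right order.
    (0, 1): ["RS(0)", "RR", "S(0)R", "S(0)S(0)"],
    # Order (1,0): mesh dim 1 already holds the outer split, so a local chunk
    # on mesh dim 0 is enough.
    (1, 0): ["RS(0)", "S(0)S(0)"],
}
```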

@zpcore
Contributor Author

zpcore commented Aug 27, 2025

It turns out I was wrong in #95 (comment) regarding the first issue. The device orders generally follow (0,1), except for some last-chain nodes related to params and their grads. I updated the expected placement order and now the loss curve and MFU match:
tbm main:torchtitan-8-piz-cr7ndw device_order_redistribute_api:torchtitan-8-piz-p3cvb9 .

(note: _optimize_same_nd_sharding_as_1d is not called in either case)
[screenshots: loss curve and MFU comparison]

@zpcore zpcore force-pushed the fmassa/try_new_redistribute_local_tensor branch 2 times, most recently from 8094347 to 8304ced Compare August 27, 2025 22:54
@zpcore
Contributor Author

zpcore commented Aug 27, 2025

I brought back _optimize_same_nd_sharding_as_1d and cleaned up the logic that discovers the device order permutation. Now I can compare performance with the main branch. The performance matches HEAD now:
tbm head:torchtitan-8-piz-vhbwrqg device_order_API:torchtitan-8-piz-nzrcd9t

[screenshot: performance comparison with HEAD]

Let's see if we can merge the PR by the August milestone report.

Contributor

@fmassa fmassa left a comment


LGTM, thanks a lot!

I have only some minor comments, and if the runs are working fine wrt loss / speed then I'm confident we are not missing anything here!

# below is supposed not to be triggered
redistribute_node_order[p] = ((1, 0), False)
for node, (order, _) in redistribute_node_order.items():
node.meta["device_order"] = order
Contributor


@bdhirsh What is the current state of adding metadata into nodes? Is it something that we should be generally doing?

@zpcore IIUC, this is only used for printing / logging, is that right?

Contributor Author


IIUC, this is only used for printing / logging, is that right?

Yes. Actually, I didn't make use of this meta field because we call trace_structured for autoparallel_sharding_optimizer_log before we make the device order change. I added it here in case we need it for debugging in the future.

@@ -0,0 +1,864 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
Contributor


I have not reviewed the logic in this file, but I assume it will be moved to PyTorch soon, is that right?

Contributor Author


Yes, no need to worry about reviewing this part.

origin_order = None
tgt_order = None
if node in self.param_placement_order:
tgt_order, do_redistribute = self.param_placement_order[node]
Contributor


I'm a bit confused by the do_redistribute -- in this case, it will always be true because we have that curr_spec.placements != tgt_spec.placements, is that right?

If yes, then this is here for clarity, is that right?

Contributor Author


I should rename do_redistribute to do_reorder. This bool tells whether we need to permute the device order or not, regardless of the placements. I have updated the naming to make it clear.

@zpcore
Contributor Author

zpcore commented Aug 28, 2025

For reference, this is the full set of runs I used to compare the PR with the main HEAD:

tbm llama3_FSDP_compile:torchtitan-64-piz-w476ds llama3_autop_1d_compile:torchtitan-64-piz-rjbnhgzr llama3_autop_1d_compile_ruisi_bucket_reorder:torchtitan-64-piz-dk7dwx2 llama3_FSDP_tp_compile:torchtitan-64-piz-f1jd7p llama3_autop_2d_compile:torchtitan-64-piz-r25gpf llama3_autop_2d_compile_ruisi_bucket_reorder:torchtitan-64-piz-q170kd llama3_FSDP_compile(device_order_API):torchtitan-64-piz-jrs2lcdn llama3_autop_1d_compile(device_order_API):torchtitan-64-piz-z69gq5 llama3_autop_1d_compile_ruisi_bucket_reorder(device_order_API):torchtitan-64-piz-sfcgq02 llama3_FSDP_tp_compile(device_order_API):torchtitan-64-piz-rqh00bb llama3_autop_2d_compile(device_order_API):torchtitan-64-piz-gds4w4r llama3_autop_2d_compile_ruisi_bucket_reorder(device_order_API):torchtitan-64-piz-rggr5hs0

I didn't notice accuracy or performance regressions.

One point I would like to call out in this PR is that we can now define cost values for collective ops so that we can find a lower-cost path to transform from one placement to another. Comparing the generated paths with the previous redistribute solution, the new redistribute function can make use of a free device mesh to avoid replication, which should help reduce memory overhead. However, this is not reflected in the current redistribute_cost estimation yet, so it won't change the solver output. As a follow-up, I can start looking into improving redistribute_cost.
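As one possible direction for that follow-up (my assumption, not something in this PR), the fixed per-collective costs could become functions of the bytes moved; a rough, bandwidth-only sketch using standard ring-collective approximations:

```python
# Rough, bandwidth-only cost model: cost grows with the bytes each rank moves.
# These are standard ring-collective approximations, not the PR's cost model.
def all_gather_cost(local_bytes: int, group_size: int) -> float:
    return (group_size - 1) * local_bytes                     # receive everyone else's shard

def all_to_all_cost(local_bytes: int, group_size: int) -> float:
    return (group_size - 1) / group_size * local_bytes        # reshuffle most of the local data

def all_reduce_cost(local_bytes: int, group_size: int) -> float:
    return 2 * (group_size - 1) / group_size * local_bytes    # reduce-scatter + all-gather

def chunk_cost(local_bytes: int, group_size: int) -> float:
    return 0.0                                                # purely local slicing
```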

@zpcore zpcore force-pushed the fmassa/try_new_redistribute_local_tensor branch from 084eeb1 to d6969f7 Compare August 28, 2025 21:10
@zpcore zpcore merged commit 760cc7d into main Aug 28, 2025
6 checks passed
@fmassa fmassa deleted the fmassa/try_new_redistribute_local_tensor branch August 29, 2025 05:18
@fmassa
Contributor

fmassa commented Aug 29, 2025

One point I would like to call out in this PR is that we can now define cost values for collective ops so that we can find a lower-cost path to transform from one placement to another. Comparing the generated paths with the previous redistribute solution, the new redistribute function can make use of a free device mesh to avoid replication, which should help reduce memory overhead. However, this is not reflected in the current redistribute_cost estimation yet, so it won't change the solver output. As a follow-up, I can start looking into improving redistribute_cost.

Nice! Let's set up some time early next week to discuss it.
