Conversation

@fmassa (Contributor) commented Jun 27, 2025

This makes it easier to experiment with Llama3, as it doesn't require installing TorchTitan.

Also adds an example of how to add node constraints, as well as a 1d mesh example.
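
For reference, a 1d mesh here just means a flat device mesh. A minimal sketch using the standard torch.distributed.device_mesh API (not the AutoParallel-specific example added in this PR) could look like:

    import os
    from torch.distributed.device_mesh import init_device_mesh

    # assumes launch via torchrun, which sets WORLD_SIZE; all ranks go along a
    # single "dp" dimension instead of the usual 2d ("dp", "tp") layout
    world_size = int(os.environ["WORLD_SIZE"])
    mesh_1d = init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",))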

The node constraint example was necessary for now because the solution the solver proposed was to all-gather the whole set of embeddings + inputs (which is small), and then shard the result afterwards (which is free wrt comms as it would start from a replicated tensor). This is actually bad because the whole activation after the embedding is massive, but given that our solver doesn't take activation memory into account (only runtime and input memory), it thought this was fine as is.

As a next step, I'll look into adding an activation memory constraint so that this type of behavior gets forbidden. Another solution would be to add a compute cost to embedding-bag, which for now has a compute cost of 0 as only gemm flops are taken into account.

Fix for smaller batch size

This PR uncovered a bug in the sharding schemes for a number of ops. Indeed, the sharding schemes would allow sharding tensors on a dimension which was smaller than the world size. As an example, a tensor of shape [32, 4096, 4096] could previously be sharded as S(0),S(0) on a mesh of size (32, 8). This case is now forbidden.
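
To make the forbidden case concrete, here is a small standalone check (plain Python, not the actual code added in this PR) applied to the [32, 4096, 4096] example:

    def is_evenly_shardable(shape, mesh_shape, shard_dims):
        """shard_dims[i] is the tensor dim sharded over mesh dim i, or None for replicate."""
        shards_per_dim = {}
        for mesh_size, dim in zip(mesh_shape, shard_dims):
            if dim is not None:
                shards_per_dim[dim] = shards_per_dim.get(dim, 1) * mesh_size
        return all(shape[dim] % num == 0 for dim, num in shards_per_dim.items())

    # S(0),S(0) on a (32, 8) mesh asks for 32 * 8 = 256 shards of a dim of size 32
    print(is_evenly_shardable([32, 4096, 4096], (32, 8), (0, 0)))  # False -> now forbidden
    # S(0),S(1) is fine: 32 % 32 == 0 and 4096 % 8 == 0
    print(is_evenly_shardable([32, 4096, 4096], (32, 8), (0, 1)))  # True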

This should ideally be fixed upstream I believe.

fmassa added 2 commits June 27, 2025 09:10
This is a self-contained copy-paste from TorchTitan that works
@fmassa requested review from bdhirsh and wconstab June 27, 2025 09:53
@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) Jun 27, 2025
@fmassa changed the title from "Add self-contained example for Llama3 model from TorchTitan" to "Fix for invalid sharding and add self-contained example for Llama3 model from TorchTitan" Jun 27, 2025
@wconstab (Contributor) commented:

I've confirmed that this PR fixes the issue I described in this doc, namely that the llama3_8b variant was trying to communicate a complex tensor and failing.

tlparse from my run

Specifically, this mul op now has the same sharding decision as it had in the debugmodel variant that was always working.

    mul_2 = torch.ops.aten.mul.Tensor(view_as_complex, view_11);  view_as_complex = None  # placement=(S(0)S(2), RR) -> S(0)S(2), cost=[0.0, 0.0]

However I am not convinced this is a complete fix. I think we are just lucky that this change avoids a solution that would shard a complex number.

  • a complete fix should either (a) preferred: guarantee unwrapping complex tensors before communicating them (a rough sketch follows below), or (b) ensure complex tensors are never communicated.
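
For concreteness, option (a) could look roughly like the following sketch, which wraps a plain all_gather with the public view_as_real / view_as_complex ops; this is an illustration, not what DTensor actually does internally:

    import torch
    import torch.distributed as dist

    def all_gather_complex(t: torch.Tensor, group=None) -> torch.Tensor:
        """All-gather a complex tensor by communicating its real view.

        Many collectives don't support complex dtypes, so unwrap to a real
        tensor with a trailing (real, imag) dim of size 2, communicate that,
        and re-wrap the result as complex.
        """
        real = torch.view_as_real(t).contiguous()  # [..., 2], real dtype
        out = [torch.empty_like(real) for _ in range(dist.get_world_size(group))]
        dist.all_gather(out, real, group=group)
        return torch.view_as_complex(torch.cat(out, dim=0))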

Further, unless I'm confused, both of the shardings of the freqs_cis complex tensor in my broken llama3_8b and working debugmodel runs were valid under the 8-GPU DP=2, TP=4 mesh I was using. So the fact that this banning of invalid shardings helped at all seems like a butterfly effect to me?

Further details about my complex number case

Previously broken case in llama3_8b:
view_as_complex_5 gets S(0)S(2) sharding, which shards tensor dims of size 2 and 8 over mesh dims of size 2 and 4, and should be valid.

    view_60 = torch.ops.aten.view.default(view_57, [2, 8192, 8, -1, 2]);  view_57 = None  # placement=(S(0)S(2)) -> S(0)S(2), cost=[0.0]
    view_as_complex_5 = torch.ops.aten.view_as_complex.default(view_60);  view_60 = None  # placement=(S(0)S(2)) -> S(0)S(2), cost=[0.0]

Previously working case in debugmodel:
view_as_complex gets S(0)S(0) sharding of tensor dim 0 (size 16) over mesh dims 2 and 4.

    view_9 = torch.ops.aten.view.default(view_6, [16, 2048, 16, -1, 2]);  view_6 = None  # placement=(S(0)S(0)) -> S(0)S(0), cost=[0.0]   
    view_as_complex = torch.ops.aten.view_as_complex.default(view_9);  view_9 = None  # placement=(S(0)S(0)) -> S(0)S(0), cost=[0.0]

Newly working case with this PR for llama3_8b:

It appears that some 'view_as_complex' nodes are no longer double-sharded, and this helps them avoid comms before the mul. However, the double shardings (e.g. S(0)S(2) or S(0)S(1)) would still have been valid for this tensor shape and mesh, so this is why I claim the fix is not a direct fix but rather a butterfly effect / luck.

    view_10 = torch.ops.aten.view.default(view_7, [2, 8192, 8, -1, 2]);  view_7 = None  # placement=(S(0)P) -> S(0)P, cost=[0.0]
    view_as_complex_1 = torch.ops.aten.view_as_complex.default(view_10);  view_10 = None  # placement=(S(0)R) -> S(0)R, cost=[544.6925883694413]

Still, happy to land this and take the win for now while we debate further! Thanks for this fix!

    # keep only Shard placements where the sharded tensor dim divides evenly
    # across the corresponding mesh dim
    for mesh_shape, plc in zip(mesh.shape, input_spec.placements):
        if plc.is_shard():
            dim = plc.dim
            if shape[dim] % mesh_shape == 0:

Contributor

Technically, we can support uneven sharding, can't we? But for now I am OK with this. Also, @XilunWu / @zpcore are going to put this logic into DTensor, so we can remove it from here once that's ready.

Right now in DTensor sharding prop, we only check if tensor_dim_size >= mesh_size in some rules, as in https://github.com/pytorch/pytorch/blob/0decd966af9cdcb7ab4410cf475d2fc09f2dea0c/torch/distributed/tensor/_ops/_matrix_ops.py#L78-L80.

Francisco's case is a bit different: what he's doing is filtering out uneven sharding. I would like you guys to clarify this a bit.

I think there are three concepts (a small numeric illustration follows the list):

  1. empty local tensor. This comes from some special cases of uneven sharding.
  2. uneven sharding. This doesn't necessarily lead to an empty local tensor.
  3. the current is_tensor_shardable check, which only requires tensor_dim_size >= mesh_size.
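
A small numeric illustration of the distinction, roughly mirroring DTensor's chunk-based splitting (plain Python, sizes picked arbitrarily):

    # shard a tensor dim of size `dim_size` over a mesh dim of size `mesh_size`
    def local_sizes(dim_size, mesh_size):
        chunk = -(-dim_size // mesh_size)  # ceil division, as in torch.chunk
        return [max(0, min(chunk, dim_size - r * chunk)) for r in range(mesh_size)]

    print(local_sizes(8, 4))  # [2, 2, 2, 2] -> even sharding
    print(local_sizes(7, 4))  # [2, 2, 2, 1] -> (2) uneven, but no empty local tensor
    print(local_sizes(6, 4))  # [2, 2, 2, 0] -> (2) uneven AND (1) an empty local tensor
    print(local_sizes(3, 4))  # [1, 1, 1, 0] -> also fails (3): tensor_dim_size < mesh_size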

And there are some TODOs:

  1. unify (2) and (3). The consequence is that we won't be able to support uneven sharding. Wanchao has a TODO on unifying (2) and (3) (see https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_ops/utils.py#L147-L148), but I think this requires us to enable "always-on padding" (sorry, I forgot the exact term we've been using, but the meaning is to always pad DTensor's local tensor rather than only pad/unpad around collectives); otherwise this unification rejects all uneven DTensors in sharding prop.

  2. enforce the is_tensor_shardable check in every op's sharding prop.

My question to @fmassa is: why does AutoParallel want to avoid (1)? If this is a hard blocker, then we can consider upstreaming a property to OpStrategy that outputs only even sharding strategies, because filtering out all uneven sharding strategies directly in sharding prop may be problematic for the current DTensor status (i.e., it would require supporting always-on padding).

Contributor Author

Yes, we can ultimately support uneven sharding. But it might be better to focus on the "simple case" first and have everything be evenly shardable.

In the end, we should filter out everything which is not possible, so something like tensor_dim_size >= mesh_size indeed. But to make things simpler to reason about, I'm filtering out a bit more than that.

My question to @fmassa is: why does AutoParallel want to avoid (1)?

For now I believe keeping only even sharding will probably be best to ensure the rest of the stack is working as expected, as it will be easier to reason about (i.e., we don't need to check the outcome of many different traces to understand what changed between them).
But when we add support for dynamic shapes (which will be important for Ads models), we will indeed need to re-enable uneven sharding.

@wconstab merged commit ef70738 into main Jun 27, 2025
2 checks passed
    # find the aten.embedding node in the traced graph and constrain it to keep
    # the same sharding as the model inputs (x_sharding)
    embedding_nodes = autop.gm.graph.find_nodes(
        op="call_function", target=torch.ops.aten.embedding.default
    )
    autop.sharding_optimizer.add_node_constraint(embedding_nodes[0], x_sharding)

Contributor

I want to better understand / think harder about what the right way to model activation memory is. I need to think some more, but just dumping some raw thoughts:

(1) Today we only have a constraint to limit total param memory, not activation memory. I guess ideally we would model "peak memory hit during forward graph execution", which seems hard since we need to reason about the lifetimes of the activations (it might be OK to all-gather a large activation if we know it dies before we hit peak memory later on). But this is hard because we don't actually know which activations will be saved for backward vs. recomputed (maybe a simplification that all all-gathered activations will be recomputed is OK?).

(2) At the same time, replicating vs. sharding an activation will also change the runtime. Maybe there is a way to bias the solver against replicating activations, if only because the compute will take longer? Although I guess this is only true for ops that have meaningful flops, since we're saying that pointwise ops are "free" today.

Contributor

I'm also wondering - if there are going to be cases where a user really does need to add a manual node constraint to get a more optimal solution w.r.t. memory usage, it would be nice to have tooling to make it easier to identify which nodes to change.

One thing that could be cool to do is to use FakeTensor + memory estimator code to figure out which nodes are run at peak memory during the forward, and maybe provide some visualizations that a user can eyeball to find nodes where they think we have made bad sharding decisions.
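
A rough sketch of what such tooling could look like (the helper name report_activation_sizes is made up; it assumes the graph's nodes already carry FakeTensor metadata in node.meta["val"]):

    import torch
    from torch.fx import GraphModule

    def report_activation_sizes(gm: GraphModule, top_k: int = 10):
        """Rank graph nodes by the size of the activation they produce."""
        sizes = []
        for node in gm.graph.nodes:
            val = node.meta.get("val")
            if val is None:
                continue
            tensors = val if isinstance(val, (tuple, list)) else [val]
            nbytes = sum(
                t.numel() * t.element_size()
                for t in tensors
                if isinstance(t, torch.Tensor)
            )
            sizes.append((nbytes, node))
        sizes.sort(key=lambda x: x[0], reverse=True)
        for nbytes, node in sizes[:top_k]:
            print(f"{node.name:40s} {nbytes / 2**20:10.1f} MiB  {node.target}")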

Contributor Author

Regarding (1), yes, modelling activation memory properly is going to be hard, as we don't know what decision the partitioner will take. One rough proxy is just to constrain every node to not use more than x GB of memory, but that's probably not great in all situations.

On (2), yes, ideally if we had cost estimates that go beyond flops for those ops, then just modelling the runtime would be enough. In the embedding case, given that we don't have a flop formula for it, it is reported as being free to compute, which would then always favor replicated tensors because the cost to go from replicate to any sharding is 0.
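
One option (just a sketch, not AutoParallel's cost model; the helper name and bandwidth constant are made up) would be to fall back to a bytes-moved estimate when an op has no flop formula:

    import torch

    # rough runtime estimate for aten.embedding, which has ~0 flops but reads
    # one weight row per index and writes the gathered output
    def embedding_cost_seconds(indices_shape, weight_shape, dtype=torch.bfloat16,
                               mem_bw_bytes_per_s=2e12):  # assumed HBM bandwidth
        num_indices = 1
        for s in indices_shape:
            num_indices *= s
        embedding_dim = weight_shape[1]
        elem_size = torch.empty((), dtype=dtype).element_size()
        bytes_moved = 2 * num_indices * embedding_dim * elem_size  # read rows + write output
        return bytes_moved / mem_bw_bytes_per_s

    # e.g. (batch=32, seq=8192) indices into a [vocab, hidden] = [128256, 4096] table
    print(embedding_cost_seconds((32, 8192), (128256, 4096)))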

I'm also wondering - if there are going to be cases where a user really does need to add a manual node constraint to get a more optimal solution w.r.t. memory usage, it would be nice to have tooling to make it easier to identify which nodes to change.

Yes, totally! We have a memory estimator that I've built (and that has been improved by Shatian) which we could use here. In general, we should work on improving the UX here.

One other possibility is to output a graph representation colored by activation size, which would make it easier to spot problematic nodes.
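
For example, something along these lines (a hypothetical helper that emits Graphviz DOT, again assuming nodes carry FakeTensor metadata in node.meta["val"]):

    def to_dot_colored_by_size(gm):
        """Emit a Graphviz DOT string where redder nodes produce larger activations."""
        def nbytes(node):
            val = node.meta.get("val")
            tensors = val if isinstance(val, (tuple, list)) else [val]
            return sum(t.numel() * t.element_size() for t in tensors if hasattr(t, "numel"))

        sizes = {n: nbytes(n) for n in gm.graph.nodes}
        biggest = max(sizes.values()) or 1
        lines = ["digraph activations {"]
        for n in gm.graph.nodes:
            saturation = sizes[n] / biggest  # 0 = white, 1 = fully red (HSV)
            lines.append(f'  "{n.name}" [style=filled, fillcolor="0.0 {saturation:.2f} 1.0"];')
            for inp in n.all_input_nodes:
                lines.append(f'  "{inp.name}" -> "{n.name}";')
        lines.append("}")
        return "\n".join(lines)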

@fmassa deleted the fmassa/example_llama3 branch June 30, 2025 07:41
bdhirsh pushed a commit that referenced this pull request Jul 1, 2025
…del from TorchTitan (#22)

* Add Llama3 example

This is a self-contained copy-paste from TorchTitan that works

* Cleanup

* Fix license

* Remove invalid configurations that would yield empty shapes

* Cleanup
