Fix for invalid sharding and add self-contained example for Llama3 model from TorchTitan #22
Conversation
This is a self-contained copy-paste from TorchTitan that works
I've confirmed that this PR fixes the issue I described in this doc, namely that the llama3_8b variant was trying to communicate a complex tensor and failing. Specifically, this mul op now gets the same sharding decision as it did in the debugmodel variant, which was always working. However, I am not convinced this is a complete fix; I think we are just lucky that this change steers the solver away from a solution that would shard a complex number.
Further, unless I'm confused, both shardings of the freqs_cis complex tensor, in my broken llama3_8b run and in my working debugmodel run, were valid under the 8-GPU DP=2, TP=4 mesh I was using. So the fact that banning invalid shardings helped at all seems like a butterfly effect to me.
Further details about my complex number case (previously broken case in llama3_8b, previously working case in debugmodel, newly working case with this PR for llama3_8b): it appears that in the new solution some view_as_complex nodes are not double-sharded, which helps them avoid comms before the mul. However, the double shardings (e.g. S0S2 or S0S1) would still have been valid for this tensor shape and mesh dim, which is why I claim this is not a direct fix but rather a butterfly effect / luck. Still, happy to land this and take the win for now while we debate further! Thanks for this fix!
(Excerpt from the PR diff that the comments below refer to.)

```python
for mesh_shape, plc in zip(mesh.shape, input_spec.placements):
    if plc.is_shard():
        dim = plc.dim
        if shape[dim] % mesh_shape == 0:
            ...  # rest of the loop body is not shown in this excerpt
```
Right now in DTensor sharding prop, we only check `tensor_dim_size >= mesh_size` in some rules, as in https://github.com/pytorch/pytorch/blob/0decd966af9cdcb7ab4410cf475d2fc09f2dea0c/torch/distributed/tensor/_ops/_matrix_ops.py#L78-L80 .
Francisco's case is a bit different: what he's doing is filtering out uneven sharding. I would like you guys to clarify this a bit.
I think there are 3 concepts (illustrated in the sketch after the TODO list below):
- (1) empty local tensor. This comes from some special cases of uneven sharding.
- (2) uneven sharding. This doesn't necessarily lead to an empty local tensor.
- (3) the current `is_tensor_shardable` check, which means `tensor_dim_size >= mesh_size`.
And there are some TODOs:
- unify (2) and (3). The consequence is that we won't be able to support uneven sharding. Wanchao has a TODO on unifying (2) and (3) (see https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_ops/utils.py#L147-L148), but I think this requires us to enable "always-on padding" (sorry, I forgot the exact term we've been using, but the meaning is to always pad DTensor's local tensor rather than only pad/unpad around collectives); otherwise this unification rejects all uneven DTensors in sharding prop.
- enforce the `is_tensor_shardable` check in every op's sharding prop.
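A minimal sketch of how these three concepts differ, assuming chunk-style (ceil-division) sharding as DTensor uses; the helper names and example sizes below are made up for illustration and are not DTensor's actual sharding-prop code:

```python
import math

def is_shardable_loose(dim_size: int, mesh_size: int) -> bool:
    # (3) the current is_tensor_shardable-style check
    return dim_size >= mesh_size

def is_even_sharding(dim_size: int, mesh_size: int) -> bool:
    # (2) even sharding: every rank gets the same local size
    return dim_size % mesh_size == 0

def has_empty_local_tensor(dim_size: int, mesh_size: int) -> bool:
    # (1) chunk-style sharding hands out ceil(dim_size / mesh_size)-sized
    # pieces, so trailing ranks can end up with zero elements
    chunk = math.ceil(dim_size / mesh_size)
    return dim_size <= chunk * (mesh_size - 1)

# (32, 8): even; (9, 2): uneven but no empty shard;
# (33, 8): uneven with an empty shard; (4, 8): fails even the loose check
for dim_size, mesh_size in [(32, 8), (9, 2), (33, 8), (4, 8)]:
    print(dim_size, mesh_size,
          is_shardable_loose(dim_size, mesh_size),
          is_even_sharding(dim_size, mesh_size),
          has_empty_local_tensor(dim_size, mesh_size))
```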
My question to @Francescaaa is: why does AutoParallel want to avoid (1)? If this is a hard blocker, then we can consider upstreaming a property to OpStrategy that outputs all even sharding strategies, because filtering out all uneven sharding strategies directly in sharding prop may be problematic for the current DTensor status (i.e. we need to support always-on padding).
Yes, we can ultimately support uneven sharding. But it might be better to focus on the "simple case" first and have everything be evenly shardable.
In the end, we should only filter out what is actually impossible, so something like `tensor_dim_size >= mesh_size` indeed. But to make things simpler to reason about, I'm filtering out a bit more than that.
> My question to @fmassa is: why does AutoParallel want to avoid (1)?

For now, I believe keeping only even sharding will probably be best to ensure the rest of the stack is working as expected, as it will be easier to reason about (i.e., we don't need to check the outcome of many different traces to understand what changed between them).
But when supporting dynamic shapes (which will be important for Ads models), we will indeed need to re-enable uneven sharding.
(Excerpt from the example added in this PR that the comments below refer to.)

```python
# Find the aten.embedding call in the captured graph and constrain it to
# x_sharding, so the solver can't pick the replicated solution discussed below.
embedding_nodes = autop.gm.graph.find_nodes(
    op="call_function", target=torch.ops.aten.embedding.default
)
autop.sharding_optimizer.add_node_constraint(embedding_nodes[0], x_sharding)
```
I want to better understand / think harder about what you think the right way to model activation memory is. I need to think some more, but just dumping some raw thoughts:
(1) Today we only have a constraint to limit total param memory, not activation memory. I guess ideally we would model "peak memory hit during forward graph execution", which seems hard since we need to reason about the lifetimes of the activations (it might be ok to allgather a large activation if we know it dies before we hit peak memory later on). But this is hard because we don't actually know which activations will be saved for backward vs. recomputed (maybe some simplification like "all allgathered activations will be recomputed" is ok?).
(2) At the same time, replicating vs. sharding an activation will also change the runtime. Maybe there is a way to bias the solver against replicating activations, if only because the compute will take longer? Although I guess this is only true for ops that have meaningful flops, since we're saying that pointwise ops are "free" today.
I'm also wondering - if there are going to be cases where a user really does need to add a manual node constraint to get a more optimal solution w.r.t. memory usage, it would be nice to have tooling to make it easier to identify which nodes to change.
One thing that could be cool to do is to use FakeTensor + the memory estimator code to figure out which nodes are live at peak memory during the forward, and maybe provide some visualizations that a user can eyeball to find nodes that they think we have made bad sharding decisions for.
Regarding (1), yes, modelling activation memory properly is going to be hard, as we don't know what decision the partitioner will take. One rough proxy is just to constrain every node to use no more than x GB of memory, but that's probably not great in all situations.
On (2), yes, ideally, if we had cost estimates that go beyond flops for those ops, then just modelling the runtime would be enough. In the embedding case, given that we don't have a flop formula for it, the op is reported as free to compute, which then always favors replicated tensors because the cost to go from replicate to any sharding is 0.
> I'm also wondering - if there are going to be cases where a user really does need to add a manual node constraint to get a more optimal solution w.r.t. memory usage, it would be nice to have tooling to make it easier to identify which nodes to change.
Yes, totally! We have a memory estimator that I've built (and that Shatian has improved) which we could use here. In general, we should work on improving the UX here.
One other possibility is to output a graph representation colored by activation size, which would make it easier to spot problematic nodes.
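As a rough illustration of that kind of tooling, here is a hedged sketch that ranks FX graph nodes by the size of the tensor they produce, assuming the graph was captured with fake tensors so that node.meta["val"] is populated; the function name and top_k default are made up, and this is not AutoParallel's actual memory estimator:

```python
import torch
import torch.fx

def report_large_activations(gm: torch.fx.GraphModule, top_k: int = 10) -> None:
    """Hypothetical helper: rank graph nodes by the size of the tensor they
    produce, using the fake-tensor metadata stored in node.meta['val']."""
    sizes = []
    for node in gm.graph.nodes:
        val = node.meta.get("val")
        if isinstance(val, torch.Tensor):  # FakeTensor is a Tensor subclass
            nbytes = val.numel() * val.element_size()
            sizes.append((nbytes, node.name, tuple(val.shape)))
    for nbytes, name, shape in sorted(sizes, reverse=True)[:top_k]:
        print(f"{name}: {shape} ~ {nbytes / 2**20:.1f} MiB")
```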
…del from TorchTitan (#22)
* Add Llama3 example: This is a self-contained copy-paste from TorchTitan that works
* Cleanup
* Fix license
* Remove invalid configurations that would yield empty shapes
* Cleanup
This makes it easier to experiment with Llama3, as it doesn't require installing TorchTitan.
Also adds an example of how to add node constraints, and a 1d mesh.
The example for the node constraint was necessary for now: the solution the solver proposed was to all-gather the whole set of embeddings + inputs (which is small), and only shard the result afterwards (which is free w.r.t. comms, as it starts from a replicated tensor). This is actually bad, because the whole activation after the embedding is massive; but given that our solver doesn't take activation memory into account (only runtime and input memory), it thought it was fine to do it as is.
As a next step I'll look into adding some activation memory constraint, so that this type of behavior gets forbidden. Another solution would be to add a compute cost to embedding-bag, which for now has 0 compute cost, as only gemm flops are taken into account.
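As a sketch of that second option, a memory-bound estimate for the embedding lookup could be derived from the bytes the gather writes rather than from flops; the function below and its bandwidth default are illustrative assumptions, not the solver's actual cost model:

```python
def embedding_lookup_cost_s(batch: int, seq: int, dim: int,
                            bytes_per_elem: int = 2,            # e.g. bf16
                            mem_bw_bytes_per_s: float = 3.0e12  # placeholder HBM bandwidth
                            ) -> float:
    """Rough runtime estimate for an embedding lookup: the op is memory bound,
    so cost ~ bytes of the gathered output divided by memory bandwidth."""
    out_bytes = batch * seq * dim * bytes_per_elem
    return out_bytes / mem_bw_bytes_per_s

# Example: batch=8, seq=8192, dim=4096 in bf16 -> 0.5 GiB gathered output
print(embedding_lookup_cost_s(8, 8192, 4096))
```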
Fix for smaller batch size
This PR uncovered a bug in the sharding schemes for a number of ops: the sharding schemes would allow sharding a tensor on a dimension which was smaller than the world size. As an example, a tensor of shape `[32, 4096, 4096]` could previously be sharded as `S(0), S(0)` on a mesh of size `(32, 8)`. This case is now forbidden. This should ideally be fixed upstream, I believe.
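To make the example concrete, here is a small arithmetic sketch (plain Python, assuming the chunk-style sharding DTensor uses) of why sharding dim 0 of a [32, 4096, 4096] tensor over both dimensions of a (32, 8) mesh leaves most ranks with an empty local tensor:

```python
import math

def local_sizes(dim_size: int, mesh_size: int) -> list[int]:
    # chunk-style sharding: ceil-sized chunks; trailing ranks may be empty
    chunk = math.ceil(dim_size / mesh_size)
    return [max(0, min(chunk, dim_size - r * chunk)) for r in range(mesh_size)]

# S(0) over the first mesh dim (size 32): every rank keeps 1 row of dim 0
print(local_sizes(32, 32))  # [1, 1, ..., 1]

# S(0) again over the second mesh dim (size 8): only one rank keeps that row,
# the other 7 end up with an empty local tensor
print(local_sizes(1, 8))    # [1, 0, 0, 0, 0, 0, 0, 0]
```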