2/n CW ShardedTensor Test #2863
Conversation
This pull request was exported from Phabricator. Differential Revision: D72017500
Force-pushed from c818d8c to 6f4b84f
Summary: The current implementation already supports CW sharding, so this only adds a unit test and some utils for generating rank placements for +1 ranks. Differential Revision: D72017500
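The rank-placement utilities themselves are not shown in this conversation. As a purely hypothetical sketch of what such a test helper could look like, the snippet below assigns each table's CW shards round-robin over a given world size; the function name `generate_cw_rank_placements` and the round-robin policy are assumptions, not the utility added in this PR.

```python
# Hypothetical sketch of a rank-placement generator for a CW sharding test.
# Not the utility added in this PR; the round-robin policy is an assumption.
import itertools
from typing import List


def generate_cw_rank_placements(
    world_size: int,
    num_tables: int,
    shards_per_table: int,
) -> List[List[int]]:
    """Return, for each table, the list of ranks its CW shards are placed on,
    assigning shards round-robin across ranks so placements stay balanced."""
    rank_cycle = itertools.cycle(range(world_size))
    return [
        [next(rank_cycle) for _ in range(shards_per_table)]
        for _ in range(num_tables)
    ]


# e.g. generating placements for a world resized to world_size + 1 = 3 ranks
placements = generate_cw_rank_placements(world_size=3, num_tables=2, shards_per_table=2)
# -> [[0, 1], [2, 0]]
```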
Force-pushed from 6f4b84f to 95ba344
Force-pushed from 95ba344 to a530afe
Force-pushed from a530afe to 3363462
Force-pushed from 3363462 to 7a4abb3
Force-pushed from 7a4abb3 to a5cd548
Force-pushed from a5cd548 to 404da5e
Force-pushed from 404da5e to 748fd03
Summary: Pull Request resolved: meta-pytorch#2863. The current implementation already supports CW sharding, so this only adds a unit test and some utils for generating rank placements for +1 ranks. Reviewed By: iamzainhuda. Differential Revision: D72017500
Force-pushed from 748fd03 to d7892ce
Summary: This will be crucial for any non-TW sharding type going through dynamic sharding. It handles the case where the embedding dimension is not the same across shards. For example:

```
num_embeddings = 4
embedding_dim = 16

Table 0: CW sharded across ranks: [0, 1]
Table 1: CW sharded across rank: [0]

Table 0 shard 0 size: [8, 4]
Table 1 shard 0 size: [16, 4]
```

This requires `Table 0 shard 0` and `Table 1 shard 0` to be concatenated in dimension 1.

Main changes:

## All_to_all collective input/output tensor composition & processing

1. Concatenate the `local_input_tensor` for the `all_to_all` collective along dimension 1 instead of dimension 0. Dim 0 varies per shard depending on the sharding, while dim 1 is consistently the same across all shards/tables since it is the number of embeddings.
2. This means we need to **transpose** and properly process both the `local_input_tensor` and the `local_output_tensor` passed into the `all_to_all` collective.
3. Made a small optimization so the `local_output_tensor` is not repeatedly updated via `torch.concat`, since we only need the final dimensions to allocate the empty tensor.

## Correct order of the `all_to_all` tensor output

To handle multiple shards per table, we need to properly store the **order** in which the `all_to_all` collective collects the tensors across ranks. The order of shards composing the `local_output_tensor` is:

1. First ordered by rank.
2. Then ordered by table order in the EBC, which can be inferred from the `module_sharding_plan`.
3. Finally by the shard order for this table.

Since we can assume each rank only contains one shard per table, we only need to track 1. and 2. The return type of `shards_all_to_all` and the input type of `update_state_dict_post_resharding` are updated to be a flattened list in the above order. The `shard_size` in dim 1 is also stored in this output while composing the `local_output_tensor`, to avoid re-querying it in `update_state_dict_post_resharding`.

This will ensure correct behavior in the CW sharding implementation/test in the next diff.

Reviewed By: iamzainhuda

Differential Revision: D72486367
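As a rough illustration of the dim-1 composition and the ordered split described in the summary above, here is a minimal sketch. It is not the torchrec implementation from this diff: the helper names (`compose_local_input`, `split_received_output`) and the transpose-then-concatenate layout are assumptions, and the `all_to_all` exchange itself is omitted so the snippet runs without a process group.

```python
# Minimal, self-contained sketch (not the torchrec implementation) of the
# dim-1 composition and ordered split described above. Helper names and the
# transpose-then-concatenate layout are assumptions; the all_to_all exchange
# itself is omitted so the snippet runs without a process group.
from typing import Dict, List, Tuple

import torch


def compose_local_input(
    shards_by_table: Dict[str, torch.Tensor],  # each shard: [shard_dim_0, num_embeddings]
) -> Tuple[torch.Tensor, List[Tuple[str, int]]]:
    """Transpose each shard so the shared num_embeddings dim comes first, then
    concatenate along dim 1, recording each piece's table name and width so the
    buffer can later be split back in the same order."""
    order: List[Tuple[str, int]] = []
    pieces: List[torch.Tensor] = []
    for table_name, shard in shards_by_table.items():
        t = shard.t().contiguous()            # [num_embeddings, shard_dim_0]
        order.append((table_name, t.size(1)))
        pieces.append(t)
    return torch.cat(pieces, dim=1), order    # [num_embeddings, total_width]


def split_received_output(
    output: torch.Tensor,                      # [num_embeddings, total_width]
    order: List[Tuple[str, int]],              # (table name, dim-1 width) in send order
) -> List[torch.Tensor]:
    """Split the buffer along dim 1 using the recorded widths and transpose
    each piece back to [shard_dim_0, num_embeddings]."""
    widths = [w for _, w in order]
    return [p.t().contiguous() for p in torch.split(output, widths, dim=1)]


# Toy shapes mirroring the example in the summary above.
shards = {
    "table_0": torch.randn(8, 4),    # CW shard of Table 0
    "table_1": torch.randn(16, 4),   # single shard of Table 1
}
local_input, order = compose_local_input(shards)
assert local_input.shape == (4, 24)
restored = split_received_output(local_input, order)
assert restored[0].shape == (8, 4) and restored[1].shape == (16, 4)
```

In the real collective, the recorded order would be flattened across ranks first and tables second, as described in the summary, so the receiver can reassemble each table's shards without re-querying shard sizes.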
Force-pushed from d7892ce to 007b860