
Conversation

aporialiao (Contributor)

Summary: The current implementation already supports CW sharding, so this only adds a unit test and some utils for generating rank placements for +1 ranks.

Differential Revision: D72017500

facebook-github-bot added the CLA Signed (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) and fb-exported labels on Apr 1, 2025
facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D72017500

aporialiao added a commit to aporialiao/torchrec that referenced this pull request Apr 9, 2025
Summary:
Pull Request resolved: meta-pytorch#2863

The current implementation already supports CW sharding, so this only adds a unit test and some utils for generating rank placements for +1 ranks.

Reviewed By: iamzainhuda

Differential Revision: D72017500

Summary:

This will be crucial for any non-TW sharding type going through dynamic sharding, as it handles the case where the embedding dimension is not the same across shards.

For example: 
```
num_embeddings = 4
embedding_dim = 16

Table 0: CW sharded across ranks: [0, 1]
Table 1: CW sharded across rank: [0]

Table 0 shard 0 size: [8, 4]
Table 1 shard 0 size: [16, 4]
```
This will require `Table 0 shard 0` and `Table 1 shard 0` to be concatenated along dimension 1, as in the sketch below.
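
A minimal sketch of that concatenation, assuming the shard layout shown above (tensor names here are illustrative, not torchrec's actual variables):

```python
import torch

# Shards laid out as in the example: [shard_embedding_dim, num_embeddings].
table_0_shard_0 = torch.randn(8, 4)   # Table 0, CW shard 0 of 2
table_1_shard_0 = torch.randn(16, 4)  # Table 1, its only CW shard

# Dim 0 differs (8 vs. 16), so transpose first, then concatenate along
# dimension 1, keeping num_embeddings as the leading dimension.
combined = torch.cat([table_0_shard_0.T, table_1_shard_0.T], dim=1)
assert combined.shape == (4, 24)  # [num_embeddings, 8 + 16]
```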

Main changes: 
## All_to_all Collective input/output tensor composition & processing
1. Concatenating the `local_input_tensor` for the `all_to_all` collective along dimension 1 instead of dimension 0. The sharded slice of the embedding dimension varies from shard to shard, while the number of embeddings is the same across all shards/tables, so shards are joined along the variable dimension while the consistent one is kept fixed.
2. This means both the `local_input_tensor` and the `local_output_tensor` must be **transposed** and processed accordingly before being passed into the `all_to_all` collective.
3. Made a small optimization so the `local_output_tensor` is no longer grown repeatedly via `torch.cat`: since only the final dimensions are needed, the empty output tensor is allocated once (see the sketch below).
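
A hedged sketch of points 1-3, assuming an already-initialized process group with two ranks; the split sizes and names (`send_split_dims`, `recv_split_dims`) are made up for illustration and are not torchrec's actual code:

```python
import torch
import torch.distributed as dist

num_embeddings = 4
send_split_dims = [8, 16]  # embedding-dim columns sent to ranks 0 and 1
recv_split_dims = [8, 8]   # embedding-dim columns received from ranks 0 and 1

# Compose the input by concatenating transposed shards along dim 1, then
# transpose back so the variable (split) dimension is dim 0, which is the
# dimension all_to_all_single splits on.
local_input_tensor = torch.cat(
    [torch.randn(num_embeddings, d) for d in send_split_dims], dim=1
).T.contiguous()  # shape: [sum(send_split_dims), num_embeddings]

# Point 3: allocate the output once from the known final dimensions instead
# of growing it incrementally with torch.cat.
local_output_tensor = torch.empty(sum(recv_split_dims), num_embeddings)

dist.all_to_all_single(
    local_output_tensor,
    local_input_tensor,
    output_split_sizes=recv_split_dims,
    input_split_sizes=send_split_dims,
)
```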

## Correct Order of `all_to_all` tensor output
To handle multiple shards per table, we need to properly store the **order** in which the `all_to_all` collective collects the tensors across ranks. The order of shards composing the `local_output_tensor` is:
    1. First, ordered by rank
    2. Then, by table order in the EBC -> this can be inferred from the `module_sharding_plan`
    3. Finally, by the shard order within this table.
* Since we can assume each rank contains only one shard per table, we only need to track (1) and (2). The return type of `shards_all_to_all` and the input type of `update_state_dict_post_resharding` are updated to be a flattened list in the above order (see the sketch after this list).

* Also, I'm storing each shard's dim 1 size (the `shard_size`) in this output while composing the `local_output_tensor`, to avoid having to re-query it in `update_state_dict_post_resharding`.
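
A minimal sketch of that flattened ordering; the helper name and its inputs are hypothetical, not the actual torchrec API:

```python
from typing import Dict, List, Tuple

def flattened_shard_order(
    ranks: List[int],
    tables_per_rank: Dict[int, List[str]],  # tables in EBC plan order, per rank
) -> List[Tuple[int, str]]:
    """Flatten shards first by rank, then by table order in the EBC."""
    order: List[Tuple[int, str]] = []
    for rank in sorted(ranks):               # 1. first ordered by rank
        for table in tables_per_rank[rank]:  # 2. then by table order in the EBC
            order.append((rank, table))
    return order

# Two ranks, each holding at most one CW shard per table:
print(flattened_shard_order([0, 1], {0: ["table_0", "table_1"], 1: ["table_0"]}))
# [(0, 'table_0'), (0, 'table_1'), (1, 'table_0')]
```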


This will ensure correct behavior in the CW sharding implementation/test in the next diff.

Reviewed By: iamzainhuda

Differential Revision: D72486367
