aporialiao
Contributor

Summary:
Add an initial dynamic sharding API and test. The current version supports EBC (`EmbeddingBagCollection`), TW (table-wise) sharding, and Sharded Tensor. Variants beyond those configurations (e.g. CW, RW, DTensor) will be added in the next few diffs.

Motivation for Dynamic Sharding: Doc [Work in Progress]
Design: [WIP]

What's added here:

  1. A `reshard` API that implements the `update_shards` APIs for `ShardedEmbeddingBagCollection`.

  2. Util functions for dynamic sharding, used by the `update_shards` API:

    1. `extend_shard_name`: extends `table_i` to `embedding_bags.table_i.weight`.
    2. `shards_all_to_all`: contains the all-to-all collective call that redistributes shards in a distributed environment, based on `changed_sharding_params`.
    3. `update_state_dict_post_resharding`: updates a given `state_dict` with the new shard `placements` and `local_shards`.

  3. A multi-process unit test, `test_dynamic_sharding_ebc_tw`, which tests TW-sharded EBCs calling the `reshard` API, sampling from various `world_sizes`, `num_tables`, and `data_types`.

    1. This unit test also uses a few utils to generate random inputs and rank placements. A future TODO is to merge this input generation so that it uses the generate call in D71703434.
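
For readers who want a rough picture of how the three utils fit together, here is a minimal single-process sketch in plain Python. It is not the code in this diff: the `_sketch` function names, their signatures, and the toy "state dict" (parameter name mapped to an owning rank plus a list of values) are all assumptions made for illustration. The real utils operate on sharded tensors and issue an actual all-to-all collective across processes.

```python
from typing import Dict, List, Tuple

# Hypothetical stand-in for a sharded state dict: key -> (owning rank, values).
# The real state dict holds sharded tensors, not Python lists.
ToyStateDict = Dict[str, Tuple[int, List[float]]]
Received = List[List[Tuple[str, List[float]]]]


def extend_shard_name_sketch(shard_name: str) -> str:
    """Map a table name such as `table_0` to `embedding_bags.table_0.weight`."""
    return f"embedding_bags.{shard_name}.weight"


def shards_all_to_all_sketch(
    state: ToyStateDict,
    changed_sharding_params: Dict[str, int],  # assumed shape: table name -> new rank
    world_size: int,
) -> Received:
    """Simulate the redistribution: for each destination rank, list the shards it receives."""
    received: Received = [[] for _ in range(world_size)]
    for table, new_rank in changed_sharding_params.items():
        key = extend_shard_name_sketch(table)
        old_rank, values = state[key]
        if old_rank != new_rank:  # only shards whose placement changed are moved
            received[new_rank].append((key, values))
    return received


def update_state_dict_post_resharding_sketch(
    state: ToyStateDict, received: Received
) -> ToyStateDict:
    """Return a copy of `state` with the new placement recorded for every moved shard."""
    new_state = dict(state)
    for rank, shards in enumerate(received):
        for key, values in shards:
            new_state[key] = (rank, values)
    return new_state


# Two TW-sharded tables on two ranks; move table_0 from rank 0 to rank 1.
state: ToyStateDict = {
    extend_shard_name_sketch("table_0"): (0, [0.1, 0.2]),
    extend_shard_name_sketch("table_1"): (1, [0.3, 0.4]),
}
received = shards_all_to_all_sketch(state, {"table_0": 1}, world_size=2)
new_state = update_state_dict_post_resharding_sketch(state, received)
```

In this toy version only `table_0` changes owner; `table_1` is left untouched because it is not in the changed params, which mirrors the idea of resharding based on `changed_sharding_params` only.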

Future work items (features not yet supported in this diff):

  • CW, RW, and many other sharding types
  • Optimizer saving
  • DTensor implementation

Differential Revision: D69095169

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 27, 2025
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D69095169

aporialiao added a commit to aporialiao/torchrec that referenced this pull request Mar 27, 2025
…2852)

aporialiao added a commit to aporialiao/torchrec that referenced this pull request Mar 28, 2025
…2852)

aporialiao added a commit to aporialiao/torchrec that referenced this pull request Mar 28, 2025
…2852)

aporialiao added a commit to aporialiao/torchrec that referenced this pull request Mar 31, 2025
…rch#2852)

aporialiao added a commit to aporialiao/torchrec that referenced this pull request Mar 31, 2025
…rch#2852)

aporialiao added a commit to aporialiao/torchrec that referenced this pull request Apr 1, 2025
…rch#2852)

aporialiao added a commit to aporialiao/torchrec that referenced this pull request Apr 1, 2025
…rch#2852)

aporialiao added a commit to aporialiao/torchrec that referenced this pull request Apr 1, 2025
…rch#2852)

aporialiao added a commit to aporialiao/torchrec that referenced this pull request Apr 1, 2025
…rch#2852)

aporialiao added a commit to aporialiao/torchrec that referenced this pull request Apr 1, 2025
…rch#2852)

aporialiao added a commit to aporialiao/torchrec that referenced this pull request Apr 1, 2025
…rch#2852)

aporialiao added a commit to aporialiao/torchrec that referenced this pull request Apr 2, 2025
…rch#2852)

aporialiao added a commit to aporialiao/torchrec that referenced this pull request Apr 2, 2025
…rch#2852)

aporialiao added a commit to aporialiao/torchrec that referenced this pull request Apr 2, 2025
…rch#2852)

Summary:
Pull Request resolved: meta-pytorch#2852

aporialiao added a commit to aporialiao/torchrec that referenced this pull request Apr 2, 2025
…rch#2852)

aporialiao added a commit to aporialiao/torchrec that referenced this pull request Apr 2, 2025
…rch#2852)

aporialiao added a commit to aporialiao/torchrec that referenced this pull request Apr 2, 2025
…rch#2852)

Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported