
Conversation

duduyi2013
Contributor

Summary:

Change logs

  1. add the ZeroCollisionKeyValueEmbedding embedding lookup
  2. address missing unit test coverage for SSD offloading
  3. add new unit tests for the KV ZCH embedding module
  4. add a temporary hack for calculating bucket metadata
  5. embedding.py updates, detailed below

#######################################################################
########################## embedding.py updates ########################
#######################################################################

  1. keep the original approach of initializing the ShardedTensor (ST) during training init
  2. for KV ZCH tables, the ShardedTensor is initialized with the virtual size for metadata calculation, and the actual tensor-size check is skipped during ST init; this is needed because the table has 0 rows at training init
  3. the new weight_id tensor is not registered in the EC (EmbeddingCollection) because its shape changes at runtime; instead, it is generated in the post_state_dict hook
  4. the new bucket tensor could be registered and preserved, but in this diff it is handled the same way as weight_id
  5. in the post_state_dict hook, we call get_named_split_embedding_weights_snapshot to get, for each table, a Tuple[table_name, weight (ST), weight_id (ST), bucket (ST)]; all three tensors are returned as ShardedTensors, and the destination is updated with them directly (see the sketch after this list)
  6. in the pre_load_state_dict_hook, which runs on load_state_dict(), we skip updating all three tensors, because the tensor assignment is done on the nn.Module side, which does not support updating a KVT through a PMT. This is fine for now because checkpoint loading happens outside of the load_state_dict call, but we need a future plan to make it work cohesively with other tensor types
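For illustration, here is a minimal sketch of the hook flow described in items 5 and 6. It is not the actual torchrec implementation: the class name, the plain-tensor snapshot, and the key layout are assumptions made for the example; the real module yields ShardedTensors built from the backend snapshot.

```python
# Minimal sketch of the post_state_dict / pre_load_state_dict hook flow.
# NOT the torchrec implementation: class name, plain-tensor snapshot, and
# key layout are illustrative assumptions only.
from typing import Iterator, Tuple

import torch
import torch.nn as nn


class KvZchStateDictHooksSketch(nn.Module):
    """Hypothetical module mirroring the state_dict flow described above."""

    def __init__(self, table_names):
        super().__init__()
        self._table_names = list(table_names)
        # weight_id / bucket are intentionally NOT registered as buffers,
        # since their shapes change at runtime.
        self._register_state_dict_hook(self._post_state_dict_hook)
        self._register_load_state_dict_pre_hook(self._pre_load_state_dict_hook)

    def get_named_split_embedding_weights_snapshot(
        self,
    ) -> Iterator[Tuple[str, torch.Tensor, torch.Tensor, torch.Tensor]]:
        # The real module yields ShardedTensors built from a backend snapshot;
        # empty plain tensors keep this sketch self-contained and runnable.
        for name in self._table_names:
            weight = torch.zeros(0, 8)                       # 0 rows at init
            weight_id = torch.zeros(0, dtype=torch.int64)
            bucket = torch.zeros(0, dtype=torch.int64)
            yield name, weight, weight_id, bucket

    @staticmethod
    def _post_state_dict_hook(module, destination, prefix, _local_metadata):
        # Overwrite the destination with the snapshot (weight, weight_id, bucket).
        for name, weight, weight_id, bucket in (
            module.get_named_split_embedding_weights_snapshot()
        ):
            destination[f"{prefix}{name}.weight"] = weight
            destination[f"{prefix}{name}.weight_id"] = weight_id
            destination[f"{prefix}{name}.bucket"] = bucket
        return destination

    def _pre_load_state_dict_hook(self, state_dict, prefix, *args):
        # Skip these keys on load_state_dict(): the default nn.Module assignment
        # path cannot update the KV-backed tensors, and checkpoint restore
        # currently happens outside of load_state_dict.
        for name in self._table_names:
            for suffix in ("weight", "weight_id", "bucket"):
                state_dict.pop(f"{prefix}{name}.{suffix}", None)
```

With this wiring, state_dict() always reflects the current snapshot, while load_state_dict() leaves the three tensors untouched, matching the behavior described in item 6.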

Differential Revision: D73567631

@facebook-github-bot added the CLA Signed label Apr 28, 2025
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D73567631

duduyi2013 added a commit to duduyi2013/torchrec that referenced this pull request May 6, 2025

X-link: pytorch/FBGEMM#4035

Pull Request resolved: meta-pytorch#2922

X-link: facebookresearch/FBGEMM#1120

Reviewed By: kausv, emlin

Differential Revision: D73567631

duduyi2013 added a commit to duduyi2013/torchrec that referenced this pull request May 6, 2025

duduyi2013 added a commit to duduyi2013/torchrec that referenced this pull request May 7, 2025

duduyi2013 added a commit to duduyi2013/torchrec that referenced this pull request May 7, 2025

duduyi2013 added a commit to duduyi2013/torchrec that referenced this pull request May 7, 2025

duduyi2013 added a commit to duduyi2013/torchrec that referenced this pull request May 7, 2025


Labels: CLA Signed, fb-exported