torchrec support on kvzch emb lookup module #2922
This pull request was exported from Phabricator. Differential Revision: D73567631
Summary:
X-link: pytorch/FBGEMM#4035
Pull Request resolved: meta-pytorch#2922
X-link: facebookresearch/FBGEMM#1120

# Change logs
1. Add the ZeroCollisionKeyValueEmbedding embedding lookup module.
2. Add the missing unit tests for SSD offloading.
3. Add new unit tests for the KV ZCH embedding module.
4. Add a temporary workaround for calculating bucket metadata.
5. Update embedding.py, as detailed below.

## embedding.py updates
1. Keep the original approach of initializing the ShardedTensor during training init.
2. For KV ZCH tables, initialize the ShardedTensor using the virtual size for metadata calculation and skip the actual tensor-size check; this is needed because the table has 0 rows at training init.
3. The new weight_id tensor is not registered in the EC because its shape changes at runtime; it is generated in the post_state_dict hook instead.
4. The new bucket tensor could be registered and preserved, but in this diff it is handled the same way as weight_id.
5. In the post state dict hook, call get_named_split_embedding_weights_snapshot to get Tuple[table_name, weight(ST), weight_id(ST), bucket(ST)]; all three tensors are returned as ShardedTensors, and the destination state dict is updated with them directly (see the sketch after this list).
6. In the pre_load_state_dict_hook, which is called upon load_state_dict(), skip updating all three tensors, because the tensor assignment is done [on the nn.module side](https://fburl.com/code/it5nior8), which does not support updating the KVT through the PMT. This is fine for now because checkpoint loading is done outside the load_state_dict call, but a future plan is needed to make it work cohesively with other types of tensors.

Reviewed By: kausv, emlin

Differential Revision: D73567631
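As a reading aid for items 5 and 6, here is a minimal sketch of the hook pair described above. It is not the actual torchrec implementation: the class name, constructor arguments, and state-dict key layout (`embeddings.<table_name>.weight` / `weight_id` / `bucket`) are assumptions made for illustration; only `get_named_split_embedding_weights_snapshot()` and the `(table_name, weight, weight_id, bucket)` tuple of ShardedTensors come from the summary.

```python
from typing import Any, Dict, Iterable, Tuple

import torch.nn as nn


class KvZchStateDictHooks:
    """Illustrative hook pair for KV ZCH tables (names and key layout are assumed)."""

    def __init__(
        self,
        lookup: nn.Module,
        kv_zch_tables: Iterable[str],
        key_fmt: str = "embeddings.{}.",  # assumed key layout, for illustration only
    ) -> None:
        self._lookup = lookup
        self._kv_zch_tables = set(kv_zch_tables)
        self._key_fmt = key_fmt

    def post_state_dict_hook(
        self,
        module: nn.Module,
        destination: Dict[str, Any],
        prefix: str,
        local_metadata: Dict[str, Any],
    ) -> None:
        # weight_id (and bucket) change shape at runtime, so they are not
        # registered on the module; instead, snapshot all three tensors as
        # ShardedTensors and write them into the destination directly.
        snapshot: Iterable[Tuple[str, Any, Any, Any]] = (
            self._lookup.get_named_split_embedding_weights_snapshot()
        )
        for table_name, weight, weight_id, bucket in snapshot:
            table_prefix = prefix + self._key_fmt.format(table_name)
            destination[table_prefix + "weight"] = weight
            destination[table_prefix + "weight_id"] = weight_id
            destination[table_prefix + "bucket"] = bucket

    def pre_load_state_dict_hook(
        self,
        state_dict: Dict[str, Any],
        prefix: str,
        *args: Any,
    ) -> None:
        # Skip loading weight / weight_id / bucket for KV ZCH tables: the
        # nn.Module assignment path does not support updating the KVT through
        # the PMT, and checkpoint loading for these tables happens outside of
        # load_state_dict().
        for table_name in self._kv_zch_tables:
            table_prefix = prefix + self._key_fmt.format(table_name)
            for suffix in ("weight", "weight_id", "bucket"):
                state_dict.pop(table_prefix + suffix, None)
```

In the real module such a pair would be wired up through the standard nn.Module state-dict hook machinery (a post state dict hook plus a load_state_dict pre-hook), which is how the summary describes the behavior; the sketch only makes the data flow explicit.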