From bca5ecff754da8c7d13fea15a0acb511d4266dfd Mon Sep 17 00:00:00 2001
From: Xing Liu
Date: Mon, 14 Feb 2022 15:29:17 -0800
Subject: [PATCH] Fix QuantEBC for shared sparse features

Summary: A sparse feature can be shared between multiple tables. Fix QuantEBC for shared sparse features

Reviewed By: zyan0

Differential Revision: D34220554

fbshipit-source-id: 5c64c991b9d8bab8c295335f7ebd0868d0e403b8
---
 torchrec/quant/embedding_modules.py        | 43 +++++++---
 .../quant/tests/test_embedding_modules.py  | 79 ++++++++++++-------
 2 files changed, 82 insertions(+), 40 deletions(-)

diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py
index 7a85de2c3..e986f2b43 100644
--- a/torchrec/quant/embedding_modules.py
+++ b/torchrec/quant/embedding_modules.py
@@ -135,10 +135,15 @@ def to_sparse_type(data_type: DataType) -> SparseType:
 
         self._is_weighted = is_weighted
         self._embedding_bag_configs: List[EmbeddingBagConfig] = embedding_configs
-        # pyre-fixme[24]: Non-generic type `nn.modules.container.ModuleList` cannot
-        #  take parameters.
-        self.embedding_bags: nn.ModuleList[nn.Module] = nn.ModuleList()
+        self.embedding_bags: nn.ModuleList = nn.ModuleList()
+        self._embedding_names: List[str] = []
+        self._lengths_per_embedding: List[int] = []
+        shared_feature: Dict[str, bool] = {}
+        table_names = set()
         for emb_config in self._embedding_bag_configs:
+            if emb_config.name in table_names:
+                raise ValueError(f"Duplicate table name {emb_config.name}")
+            table_names.add(emb_config.name)
             emb_module = IntNBitTableBatchedEmbeddingBagsCodegen(
                 embedding_specs=[
                     (
@@ -155,39 +160,51 @@ def to_sparse_type(data_type: DataType) -> SparseType:
                 weight_lists=[table_name_to_quantized_weights[emb_config.name]],
                 device=device,
             )
-
             self.embedding_bags.append(emb_module)
+            if not emb_config.feature_names:
+                emb_config.feature_names = [emb_config.name]
+            for feature_name in emb_config.feature_names:
+                if feature_name not in shared_feature:
+                    shared_feature[feature_name] = False
+                else:
+                    shared_feature[feature_name] = True
+                self._lengths_per_embedding.append(emb_config.embedding_dim)
+
+        for emb_config in self._embedding_bag_configs:
+            for feature_name in emb_config.feature_names:
+                if shared_feature[feature_name]:
+                    self._embedding_names.append(feature_name + "@" + emb_config.name)
+                else:
+                    self._embedding_names.append(feature_name)
 
     def forward(
         self,
         features: KeyedJaggedTensor,
     ) -> KeyedTensor:
-        keys: List[str] = []
         pooled_embeddings: List[Tensor] = []
         length_per_key: List[int] = []
+        feature_dict = features.to_dict()
         for emb_config, emb_module in zip(
             self._embedding_bag_configs, self.embedding_bags
         ):
             for feature_name in emb_config.feature_names:
-                keys.append(feature_name)
-
-                values = features[feature_name].values()
-                offsets = features[feature_name].offsets()
-                weights = features[feature_name].weights_or_none()
+                f = feature_dict[feature_name]
+                values = f.values()
+                offsets = f.offsets()
                 pooled_embeddings.append(
                     emb_module(
                         indices=values.int(),
                         offsets=offsets.int(),
-                        per_sample_weights=weights,
+                        per_sample_weights=f.weights() if self._is_weighted else None,
                     ).float()
                 )
 
                 length_per_key.append(emb_config.embedding_dim)
 
         return KeyedTensor(
-            keys=keys,
+            keys=self._embedding_names,
             values=torch.cat(pooled_embeddings, dim=1),
-            length_per_key=length_per_key,
+            length_per_key=self._lengths_per_embedding,
         )
 
     def state_dict(
diff --git a/torchrec/quant/tests/test_embedding_modules.py b/torchrec/quant/tests/test_embedding_modules.py
index c0c926156..e9a6a146c 100644
--- a/torchrec/quant/tests/test_embedding_modules.py
+++ b/torchrec/quant/tests/test_embedding_modules.py
@@ -6,6 +6,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import unittest
+from typing import List
 
 import torch
 from torchrec.modules.embedding_configs import (
@@ -21,20 +22,11 @@
 
 
 class EmbeddingBagCollectionTest(unittest.TestCase):
-    def test_ebc(self) -> None:
-        eb1_config = EmbeddingBagConfig(
-            name="t1", embedding_dim=16, num_embeddings=10, feature_names=["f1"]
-        )
-        eb2_config = EmbeddingBagConfig(
-            name="t2", embedding_dim=16, num_embeddings=10, feature_names=["f2"]
-        )
-        ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
+    def _test_ebc(
+        self, tables: List[EmbeddingBagConfig], features: KeyedJaggedTensor
+    ) -> None:
+        ebc = EmbeddingBagCollection(tables=tables)
 
-        features = KeyedJaggedTensor(
-            keys=["f1", "f2"],
-            values=torch.as_tensor([0, 1]),
-            lengths=torch.as_tensor([1, 1]),
-        )
         embeddings = ebc(features)
 
         # test forward
@@ -50,23 +42,56 @@ def test_ebc(self) -> None:
         quantized_embeddings = qebc(features)
 
         self.assertEqual(embeddings.keys(), quantized_embeddings.keys())
-        self.assertEqual(embeddings["f1"].shape, quantized_embeddings["f1"].shape)
-        self.assertTrue(
-            torch.allclose(
-                embeddings["f1"].cpu(),
-                quantized_embeddings["f1"].cpu().float(),
-                atol=1,
+        for key in embeddings.keys():
+            self.assertEqual(embeddings[key].shape, quantized_embeddings[key].shape)
+            self.assertTrue(
+                torch.allclose(
+                    embeddings[key].cpu(),
+                    quantized_embeddings[key].cpu().float(),
+                    atol=1,
+                )
             )
-        )
-        self.assertTrue(
-            torch.allclose(
-                embeddings["f2"].cpu(),
-                quantized_embeddings["f2"].cpu().float(),
-                atol=1,
-            )
-        )
 
         # test state dict
         state_dict = ebc.state_dict()
         quantized_state_dict = qebc.state_dict()
         self.assertEqual(state_dict.keys(), quantized_state_dict.keys())
+
+    def test_ebc(self) -> None:
+        eb1_config = EmbeddingBagConfig(
+            name="t1", embedding_dim=16, num_embeddings=10, feature_names=["f1"]
+        )
+        eb2_config = EmbeddingBagConfig(
+            name="t2", embedding_dim=16, num_embeddings=10, feature_names=["f2"]
+        )
+        features = KeyedJaggedTensor(
+            keys=["f1", "f2"],
+            values=torch.as_tensor([0, 1]),
+            lengths=torch.as_tensor([1, 1]),
+        )
+        self._test_ebc([eb1_config, eb2_config], features)
+
+    def test_shared_tables(self) -> None:
+        eb_config = EmbeddingBagConfig(
+            name="t1", embedding_dim=16, num_embeddings=10, feature_names=["f1", "f2"]
+        )
+        features = KeyedJaggedTensor(
+            keys=["f1", "f2"],
+            values=torch.as_tensor([0, 1]),
+            lengths=torch.as_tensor([1, 1]),
+        )
+        self._test_ebc([eb_config], features)
+
+    def test_shared_features(self) -> None:
+        eb1_config = EmbeddingBagConfig(
+            name="t1", embedding_dim=16, num_embeddings=10, feature_names=["f1"]
+        )
+        eb2_config = EmbeddingBagConfig(
+            name="t2", embedding_dim=16, num_embeddings=10, feature_names=["f1"]
+        )
+        features = KeyedJaggedTensor(
+            keys=["f1"],
+            values=torch.as_tensor([0, 1]),
+            lengths=torch.as_tensor([1, 1]),
+        )
+        self._test_ebc([eb1_config, eb2_config], features)
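
Note (illustrative only, not part of the patch): a "shared sparse feature" here is one feature name listed in the feature_names of more than one table. The sketch below shows the scenario this fix targets; the table/feature names are made up and the torchrec import paths are assumed. Per the change above, pooled outputs for a shared feature are keyed as "<feature>@<table>" so the quantized module's keys line up with the un-quantized EmbeddingBagCollection.

    import torch
    from torchrec.modules.embedding_configs import EmbeddingBagConfig
    from torchrec.modules.embedding_modules import EmbeddingBagCollection
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    # Two tables consume the same sparse feature "f1".
    t1 = EmbeddingBagConfig(
        name="t1", embedding_dim=16, num_embeddings=10, feature_names=["f1"]
    )
    t2 = EmbeddingBagConfig(
        name="t2", embedding_dim=16, num_embeddings=10, feature_names=["f1"]
    )
    ebc = EmbeddingBagCollection(tables=[t1, t2])

    # One lookup per sample for the shared feature "f1".
    features = KeyedJaggedTensor(
        keys=["f1"],
        values=torch.as_tensor([0, 1]),
        lengths=torch.as_tensor([1, 1]),
    )

    pooled = ebc(features)
    # Each table produces its own pooled output for "f1"; the keys are
    # disambiguated as "feature@table", which is what the fixed QuantEBC emits too.
    print(pooled.keys())  # expected: ["f1@t1", "f1@t2"]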