From 165c2d17d9d3a12f6d115daa157df4c9bf5cac8a Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Mon, 18 Nov 2024 14:00:09 -0800
Subject: [PATCH] add NJT/TD support in test data generator (#2528)

Summary:
# context
* add NJT/TD support in test data generator
* add NJT/TD input option in pipeline benchmark
* resolve pyre/typing errors in multiple places
* should be safe to land, no production impact

NOTE: This diff is split from the next one (D65103519) to resolve pyre/typing errors

Reviewed By: dstaay-fb

Differential Revision: D65120889
---
 install-requirements.txt                      |   1 +
 requirements.txt                              |   1 +
 ...enchmark_split_table_batched_embeddings.py |   9 +-
 .../distributed/benchmark/benchmark_utils.py  |   5 +-
 .../distributed/test_utils/infer_utils.py     |   4 +-
 torchrec/distributed/test_utils/test_model.py | 131 ++++++++++++------
 .../distributed/tests/test_infer_shardings.py |   3 +
 .../tests/pipeline_benchmarks.py              |  12 +-
 .../tests/test_train_pipelines.py             |   6 +-
 .../keyed_jagged_tensor_benchmark_lib.py      |   1 +
 10 files changed, 126 insertions(+), 47 deletions(-)

diff --git a/install-requirements.txt b/install-requirements.txt
index ab2736d78..ed3c6aced 100644
--- a/install-requirements.txt
+++ b/install-requirements.txt
@@ -1,4 +1,5 @@
 fbgemm-gpu
+tensordict
 torchmetrics==1.0.3
 tqdm
 pyre-extensions
diff --git a/requirements.txt b/requirements.txt
index b60a348f4..6d63107dd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,6 +7,7 @@ numpy
 pandas
 pyre-extensions
 scikit-build
+tensordict
 torchmetrics==1.0.3
 torchx
 tqdm
diff --git a/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py b/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py
index b03e7b417..8af1f9a46 100644
--- a/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py
+++ b/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py
@@ -9,6 +9,8 @@
 
 #!/usr/bin/env python3
 
+from typing import Dict, List
+
 import click
 
 import torch
@@ -82,9 +84,10 @@ def op_bench(
     )
 
     def _func_to_benchmark(
-        kjt: KeyedJaggedTensor,
+        kjts: List[Dict[str, KeyedJaggedTensor]],
         model: torch.nn.Module,
     ) -> torch.Tensor:
+        kjt = kjts[0]["feature"]
         return model.forward(kjt.values(), kjt.offsets())
 
     # breakpoint() # import fbvscode; fbvscode.set_trace()
@@ -108,8 +111,8 @@ def _func_to_benchmark(
 
     result = benchmark_func(
         name=f"SplitTableBatchedEmbeddingBagsCodegen-{num_embeddings}-{embedding_dim}-{num_tables}-{batch_size}-{bag_size}",
-        bench_inputs=inputs,  # pyre-ignore
-        prof_inputs=inputs,  # pyre-ignore
+        bench_inputs=[{"feature": inputs}],
+        prof_inputs=[{"feature": inputs}],
         num_benchmarks=10,
         num_profiles=10,
         profile_dir=".",
diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py
index 2177830cc..e8b37c4ec 100644
--- a/torchrec/distributed/benchmark/benchmark_utils.py
+++ b/torchrec/distributed/benchmark/benchmark_utils.py
@@ -374,11 +374,14 @@ def get_inputs(
 
         if train:
             sparse_features_by_rank = [
-                model_input.idlist_features for model_input in model_input_by_rank
+                model_input.idlist_features
+                for model_input in model_input_by_rank
+                if isinstance(model_input.idlist_features, KeyedJaggedTensor)
             ]
             inputs_batch.append(sparse_features_by_rank)
         else:
             sparse_features = model_input_by_rank[0].idlist_features
+            assert isinstance(sparse_features, KeyedJaggedTensor)
             inputs_batch.append([sparse_features])
 
     # Transpose if train, as inputs_by_rank is currently in [B X R] format
diff --git a/torchrec/distributed/test_utils/infer_utils.py b/torchrec/distributed/test_utils/infer_utils.py
index 12fe247a5..a2018aaa2 100644
--- a/torchrec/distributed/test_utils/infer_utils.py
+++ b/torchrec/distributed/test_utils/infer_utils.py
@@ -262,6 +262,7 @@ def model_input_to_forward_args_kjt(
     Optional[torch.Tensor],
 ]:
     kjt = mi.idlist_features
+    assert isinstance(kjt, KeyedJaggedTensor)
    return (
         kjt._keys,
         kjt._values,
@@ -289,7 +290,8 @@
 ]:
     idlist_kjt = mi.idlist_features
     idscore_kjt = mi.idscore_features
-    assert idscore_kjt is not None
+    assert isinstance(idlist_kjt, KeyedJaggedTensor)
+    assert isinstance(idscore_kjt, KeyedJaggedTensor)
     return (
         mi.float_features,
         idlist_kjt._keys,
diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py
index b91ed8eb4..83ab255ee 100644
--- a/torchrec/distributed/test_utils/test_model.py
+++ b/torchrec/distributed/test_utils/test_model.py
@@ -14,6 +14,7 @@
 
 import torch
 import torch.nn as nn
+from tensordict import TensorDict
 from torchrec.distributed.embedding_tower_sharding import (
     EmbeddingTowerCollectionSharder,
     EmbeddingTowerSharder,
@@ -46,8 +47,8 @@
 @dataclass
 class ModelInput(Pipelineable):
     float_features: torch.Tensor
-    idlist_features: KeyedJaggedTensor
-    idscore_features: Optional[KeyedJaggedTensor]
+    idlist_features: Union[KeyedJaggedTensor, TensorDict]
+    idscore_features: Optional[Union[KeyedJaggedTensor, TensorDict]]
     label: torch.Tensor
 
     @staticmethod
@@ -76,11 +77,13 @@ def generate(
         randomize_indices: bool = True,
         device: Optional[torch.device] = None,
         max_feature_lengths: Optional[List[int]] = None,
+        input_type: str = "kjt",
     ) -> Tuple["ModelInput", List["ModelInput"]]:
         """
         Returns a global (single-rank training) batch
         and a list of local (multi-rank training) batches of world_size.
         """
+
         batch_size_by_rank = [batch_size] * world_size
         if variable_batch_size:
             batch_size_by_rank = [
@@ -199,11 +202,26 @@ def _validate_pooling_factor(
             )
             global_idlist_lengths.append(lengths)
             global_idlist_indices.append(indices)
-        global_idlist_kjt = KeyedJaggedTensor(
-            keys=idlist_features,
-            values=torch.cat(global_idlist_indices),
-            lengths=torch.cat(global_idlist_lengths),
-        )
+
+        if input_type == "kjt":
+            global_idlist_input = KeyedJaggedTensor(
+                keys=idlist_features,
+                values=torch.cat(global_idlist_indices),
+                lengths=torch.cat(global_idlist_lengths),
+            )
+        elif input_type == "td":
+            dict_of_nt = {
+                k: torch.nested.nested_tensor_from_jagged(
+                    values=values,
+                    lengths=lengths,
+                )
+                for k, values, lengths in zip(
+                    idlist_features, global_idlist_indices, global_idlist_lengths
+                )
+            }
+            global_idlist_input = TensorDict(source=dict_of_nt)
+        else:
+            raise ValueError(f"For IdList features, unknown input type {input_type}")
 
         for idx in range(len(idscore_ind_ranges)):
             ind_range = idscore_ind_ranges[idx]
@@ -245,16 +263,25 @@ def _validate_pooling_factor(
             global_idscore_lengths.append(lengths)
             global_idscore_indices.append(indices)
             global_idscore_weights.append(weights)
-        global_idscore_kjt = (
-            KeyedJaggedTensor(
-                keys=idscore_features,
-                values=torch.cat(global_idscore_indices),
-                lengths=torch.cat(global_idscore_lengths),
-                weights=torch.cat(global_idscore_weights),
+
+        if input_type == "kjt":
+            global_idscore_input = (
+                KeyedJaggedTensor(
+                    keys=idscore_features,
+                    values=torch.cat(global_idscore_indices),
+                    lengths=torch.cat(global_idscore_lengths),
+                    weights=torch.cat(global_idscore_weights),
+                )
+                if global_idscore_indices
+                else None
             )
-            if global_idscore_indices
-            else None
-        )
+        elif input_type == "td":
+            assert (
+                len(idscore_features) == 0
+            ), "TensorDict does not support weighted features"
+            global_idscore_input = None
+        else:
+            raise ValueError(f"For weighted features, unknown input type {input_type}")
 
         if randomize_indices:
             global_float = torch.rand(
@@ -303,27 +330,48 @@ def _validate_pooling_factor(
                     weights[lengths_cumsum[r] : lengths_cumsum[r + 1]]
                 )
 
-            local_idlist_kjt = KeyedJaggedTensor(
-                keys=idlist_features,
-                values=torch.cat(local_idlist_indices),
-                lengths=torch.cat(local_idlist_lengths),
-            )
+            if input_type == "kjt":
+                local_idlist_input = KeyedJaggedTensor(
+                    keys=idlist_features,
+                    values=torch.cat(local_idlist_indices),
+                    lengths=torch.cat(local_idlist_lengths),
+                )
 
-            local_idscore_kjt = (
-                KeyedJaggedTensor(
-                    keys=idscore_features,
-                    values=torch.cat(local_idscore_indices),
-                    lengths=torch.cat(local_idscore_lengths),
-                    weights=torch.cat(local_idscore_weights),
+                local_idscore_input = (
+                    KeyedJaggedTensor(
+                        keys=idscore_features,
+                        values=torch.cat(local_idscore_indices),
+                        lengths=torch.cat(local_idscore_lengths),
+                        weights=torch.cat(local_idscore_weights),
+                    )
+                    if local_idscore_indices
+                    else None
+                )
+            elif input_type == "td":
+                dict_of_nt = {
+                    k: torch.nested.nested_tensor_from_jagged(
+                        values=values,
+                        lengths=lengths,
+                    )
+                    for k, values, lengths in zip(
+                        idlist_features, local_idlist_indices, local_idlist_lengths
+                    )
+                }
+                local_idlist_input = TensorDict(source=dict_of_nt)
+                assert (
+                    len(idscore_features) == 0
+                ), "TensorDict does not support weighted features"
+                local_idscore_input = None
+
+            else:
+                raise ValueError(
+                    f"For weighted features, unknown input type {input_type}"
                 )
-                if local_idscore_indices
-                else None
-            )
 
             local_input = ModelInput(
                 float_features=global_float[r * batch_size : (r + 1) * batch_size],
-                idlist_features=local_idlist_kjt,
-                idscore_features=local_idscore_kjt,
+                idlist_features=local_idlist_input,
+                idscore_features=local_idscore_input,
                 label=global_label[r * batch_size : (r + 1) * batch_size],
             )
             local_inputs.append(local_input)
@@ -331,8 +379,8 @@ def _validate_pooling_factor(
         return (
             ModelInput(
                 float_features=global_float,
-                idlist_features=global_idlist_kjt,
-                idscore_features=global_idscore_kjt,
+                idlist_features=global_idlist_input,
+                idscore_features=global_idscore_input,
                 label=global_label,
             ),
             local_inputs,
@@ -623,8 +671,9 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput":
 
     def record_stream(self, stream: torch.Stream) -> None:
         self.float_features.record_stream(stream)
-        self.idlist_features.record_stream(stream)
-        if self.idscore_features is not None:
+        if isinstance(self.idlist_features, KeyedJaggedTensor):
+            self.idlist_features.record_stream(stream)
+        if isinstance(self.idscore_features, KeyedJaggedTensor):
             self.idscore_features.record_stream(stream)
         self.label.record_stream(stream)
 
@@ -1753,10 +1802,12 @@ def forward(
         if self._preproc_module is not None:
             modified_input = self._preproc_module(modified_input)
         elif self._run_preproc_inline:
+            idlist_features = modified_input.idlist_features
+            assert isinstance(idlist_features, KeyedJaggedTensor)
             modified_input.idlist_features = KeyedJaggedTensor.from_lengths_sync(
-                modified_input.idlist_features.keys(),
-                modified_input.idlist_features.values(),
-                modified_input.idlist_features.lengths(),
+                idlist_features.keys(),
+                idlist_features.values(),
+                idlist_features.lengths(),
             )
 
         modified_idlist_features = self.preproc_nonweighted(
@@ -1820,6 +1871,8 @@ def forward(self, input: ModelInput) -> ModelInput:
         )
 
         # stride will be same but features will be joined
+        assert isinstance(modified_input.idlist_features, KeyedJaggedTensor)
+        assert isinstance(self._extra_input.idlist_features, KeyedJaggedTensor)
         modified_input.idlist_features = KeyedJaggedTensor.concat(
             [modified_input.idlist_features, self._extra_input.idlist_features]
         )
diff --git a/torchrec/distributed/tests/test_infer_shardings.py b/torchrec/distributed/tests/test_infer_shardings.py
index fc80ab17e..8118b6d9c 100755
--- a/torchrec/distributed/tests/test_infer_shardings.py
+++ b/torchrec/distributed/tests/test_infer_shardings.py
@@ -1969,6 +1969,7 @@ def test_sharded_quant_fp_ebc_tw(
         inputs = []
         for model_input in model_inputs:
             kjt = model_input.idlist_features
+            assert isinstance(kjt, KeyedJaggedTensor)
             kjt = kjt.to(local_device)
             weights = torch.rand(
                 kjt._values.size(0), dtype=torch.float, device=local_device
@@ -2149,6 +2150,7 @@ def test_sharded_quant_mc_ec_rw(
         inputs = []
         for model_input in model_inputs:
             kjt = model_input.idlist_features
+            assert isinstance(kjt, KeyedJaggedTensor)
             kjt = kjt.to(local_device)
             weights = None
             inputs.append(
@@ -2285,6 +2287,7 @@ def test_sharded_quant_fp_ebc_tw_meta(self, compute_device: str) -> None:
         )
         inputs = []
         kjt = model_inputs[0].idlist_features
+        assert isinstance(kjt, KeyedJaggedTensor)
         kjt = kjt.to(local_device)
         weights = torch.rand(
             kjt._values.size(0), dtype=torch.float, device=local_device
diff --git a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
index 538264c04..e8dc5eccb 100644
--- a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
+++ b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
@@ -75,6 +75,11 @@ def _gen_pipelines(
     default=100,
     help="Total number of sparse embeddings to be used.",
 )
+@click.option(
+    "--ratio_features_weighted",
+    default=0.4,
+    help="percentage of features weighted vs unweighted",
+)
 @click.option(
     "--dim_emb",
     type=int,
@@ -132,6 +137,7 @@ def _gen_pipelines(
 def main(
     world_size: int,
     n_features: int,
+    ratio_features_weighted: float,
     dim_emb: int,
     n_batches: int,
     batch_size: int,
@@ -149,8 +155,9 @@ def main(
     os.environ["MASTER_ADDR"] = str("localhost")
     os.environ["MASTER_PORT"] = str(get_free_port())
 
-    num_features = n_features // 2
-    num_weighted_features = n_features // 2
+    num_weighted_features = int(n_features * ratio_features_weighted)
+    num_features = n_features - num_weighted_features
+
     tables = [
         EmbeddingBagConfig(
             num_embeddings=(i + 1) * 1000,
@@ -257,6 +264,7 @@ def _generate_data(
             world_size=world_size,
             num_float_features=num_float_features,
             pooling_avg=pooling_factor,
+            input_type=input_type,
         )[1]
         for i in range(num_batches)
     ]
diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
index 729de3dec..18d0abd7a 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -301,7 +301,11 @@ def test_equal_to_non_pipelined_with_input_transformer(self) -> None:
         optimizer_cpu = optim.SGD(model_cpu.model.parameters(), lr=0.01)
         optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01)
 
-        data = [i.idlist_features for i in local_model_inputs]
+        data = [
+            i.idlist_features
+            for i in local_model_inputs
+            if isinstance(i.idlist_features, KeyedJaggedTensor)
+        ]
         dataloader = iter(data)
         pipeline = TrainPipelinePT2(
             model_gpu, optimizer_gpu, self.device, input_transformer=kjt_for_pt2_tracing
diff --git a/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py b/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py
index 235495494..1c409fcf2 100644
--- a/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py
+++ b/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py
@@ -169,6 +169,7 @@ def generate_kjt(
         randomize_indices=True,
         device=device,
     )[0]
+    assert isinstance(global_input.idlist_features, KeyedJaggedTensor)
     return global_input.idlist_features
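
Usage note (illustrative only, not part of the diff above): the sketch below shows how the new `input_type` option on `ModelInput.generate` might be exercised directly. The table and feature names are made up for the example, the keyword arguments follow the `generate()` call pattern used in pipeline_benchmarks.py, and `weighted_tables` is left empty because the TensorDict path asserts that there are no weighted (idscore) features.

# Illustrative sketch (not part of this patch): request TensorDict-based inputs
# via the new `input_type` flag. Table/feature names here are hypothetical, and
# the keyword arguments mirror how pipeline_benchmarks.py calls ModelInput.generate.
from tensordict import TensorDict
from torchrec.distributed.test_utils.test_model import ModelInput
from torchrec.modules.embedding_configs import EmbeddingBagConfig

tables = [
    EmbeddingBagConfig(
        num_embeddings=1000,
        embedding_dim=64,
        name="table_0",
        feature_names=["feature_0"],
    )
]

# The "td" path does not support weighted (idscore) features, so no weighted
# tables are passed when requesting input_type="td".
global_batch, per_rank_batches = ModelInput.generate(
    batch_size=8,
    world_size=2,
    num_float_features=10,
    tables=tables,
    weighted_tables=[],
    input_type="td",  # "kjt" (default) or "td"
)

assert isinstance(global_batch.idlist_features, TensorDict)
# Each entry is a jagged (NJT) nested tensor keyed by feature name.
njt = global_batch.idlist_features["feature_0"]
print(njt.values().shape)  # flat ids across the world_size * batch_size samples

With `input_type="kjt"` (the default) the same call keeps returning KeyedJaggedTensor inputs, so existing tests and benchmarks are unaffected.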