From 363a1d24fe1462aff7b5b871c272afada573bea4 Mon Sep 17 00:00:00 2001
From: Ivan Kobzarev
Date: Tue, 18 Jun 2024 12:41:05 -0700
Subject: [PATCH] Add dynamo configs to TorchrecPT2TrainPipeline (#2130)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2130

Torchrec Dynamo compilation requires these settings; setting them before
every compilation avoids failures for users who forget the configuration.

Add `import fbgemm_gpu.sparse_ops` to register the fbgemm meta functions.

Reviewed By: TroyGarden, gnahzg, Microve

Differential Revision: D58688934
---
 .../distributed/train_pipeline/train_pipelines.py | 14 ++++++++++++--
 torchrec/pt2/utils.py                             | 11 +++++++++++
 2 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index 1f48444d3..0152cf221 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -54,10 +54,10 @@
 )
 from torchrec.distributed.types import Awaitable
 from torchrec.pt2.checks import is_torchdynamo_compiling
+from torchrec.pt2.utils import default_pipeline_input_transformer
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 from torchrec.streamable import Multistreamable
 
-
 logger: logging.Logger = logging.getLogger(__name__)
 
 
@@ -190,7 +190,10 @@ def __init__(
         )
         self._pre_compile_fn = pre_compile_fn
         self._post_compile_fn = post_compile_fn
-        self._input_transformer = input_transformer
+        # pyre-ignore
+        self._input_transformer = (
+            input_transformer or default_pipeline_input_transformer
+        )
         self._iter = 0
         self._cur_batch: Optional[In] = None
 
@@ -215,6 +218,13 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
                 logger.info("Compiling model...")
                 if self._pre_compile_fn:
                     self._pre_compile_fn(self._model)
+
+                # Mandatory dynamo configuration for Torchrec PT2 compilation
+                torch._dynamo.config.capture_scalar_outputs = True
+                torch._dynamo.config.capture_dynamic_output_shape_ops = True
+                torch._dynamo.config.force_unspec_int_unbacked_size_like_on_torchrec_kjt = (
+                    True
+                )
                 self._model.compile(
                     fullgraph=cc.fullgraph, dynamic=cc.dynamic, backend=cc.backend
                 )
diff --git a/torchrec/pt2/utils.py b/torchrec/pt2/utils.py
index ec43cd6a0..bf0800042 100644
--- a/torchrec/pt2/utils.py
+++ b/torchrec/pt2/utils.py
@@ -7,6 +7,7 @@
 
 # pyre-strict
 
+
 import torch
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
@@ -81,3 +82,13 @@ def kjt_for_pt2_tracing(
         stride_per_key_per_rank=kjt.stride_per_key_per_rank() if is_vb else None,
         inverse_indices=inverse_indices,
     )
+
+
+# pyre-ignore
+def default_pipeline_input_transformer(inp):
+    for attr_name in ["id_list_features", "id_score_list_features"]:
+        if hasattr(inp, attr_name):
+            attr = getattr(inp, attr_name)
+            if isinstance(attr, KeyedJaggedTensor):
+                setattr(inp, attr_name, kjt_for_pt2_tracing(attr))
+    return inp
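
Notes:

`capture_scalar_outputs` and `capture_dynamic_output_shape_ops` are generic torch._dynamo flags; here is a minimal standalone sketch (not tied to torchrec, function names are illustrative) of what they unblock under `fullgraph=True`. The third flag, `force_unspec_int_unbacked_size_like_on_torchrec_kjt`, is torchrec-specific and, judging by its name, keeps ints derived from KJTs as unbacked size-like symbolic ints instead of specializing on them; it has no small standalone repro.

```python
import torch

# Without these flags, both functions below fail to compile with fullgraph=True.
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True


@torch.compile(fullgraph=True, backend="eager")
def from_scalar(x: torch.Tensor) -> torch.Tensor:
    n = x.item()             # Tensor -> Python scalar: needs capture_scalar_outputs
    torch._check_is_size(n)  # mark the unbacked int as a valid (>= 0) size
    return torch.zeros(n)


@torch.compile(fullgraph=True, backend="eager")
def from_data_dependent_op(x: torch.Tensor) -> torch.Tensor:
    # nonzero()'s output shape depends on the data:
    # needs capture_dynamic_output_shape_ops
    return torch.nonzero(x)


print(from_scalar(torch.tensor(3)).shape)                        # torch.Size([3])
print(from_data_dependent_op(torch.tensor([0, 1, 0, 2])).shape)  # torch.Size([2, 1])
```

A small usage sketch of the new default input transformer follows; the `Batch` class here is hypothetical, since any pipeline input type that exposes these attribute names works:

```python
from dataclasses import dataclass
from typing import Optional

import torch
from torchrec.pt2.utils import default_pipeline_input_transformer
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


@dataclass
class Batch:  # hypothetical input type with the attribute names the transformer probes
    id_list_features: KeyedJaggedTensor
    id_score_list_features: Optional[KeyedJaggedTensor] = None


kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.tensor([10, 11, 12, 13]),
    lengths=torch.tensor([1, 1, 1, 1]),
)
batch = default_pipeline_input_transformer(Batch(id_list_features=kjt))
# id_list_features was rebuilt via kjt_for_pt2_tracing; id_score_list_features
# is None, not a KeyedJaggedTensor, and is left untouched.
print(type(batch.id_list_features))
```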