diff --git a/torchrec/distributed/train_pipeline/__init__.py b/torchrec/distributed/train_pipeline/__init__.py index d86235591..214e81cb5 100644 --- a/torchrec/distributed/train_pipeline/__init__.py +++ b/torchrec/distributed/train_pipeline/__init__.py @@ -13,6 +13,10 @@ Out, TrainPipelineContext, ) +from torchrec.distributed.train_pipeline.tracing import ( # noqa + ArgInfoStepFactory, # noqa + Tracer, # noqa +) from torchrec.distributed.train_pipeline.train_pipelines import ( # noqa EvalPipelineSparseDist, # noqa PrefetchTrainPipelineSparseDist, # noqa @@ -25,17 +29,14 @@ TrainPipelineSparseDist, # noqa TrainPipelineSparseDistCompAutograd, # noqa ) +from torchrec.distributed.train_pipeline.types import ArgInfo, CallArgs # noqa from torchrec.distributed.train_pipeline.utils import ( # noqa _override_input_dist_forwards, # noqa _rewrite_model, # noqa _start_data_dist, # noqa _to_device, # noqa _wait_for_batch, # noqa - ArgInfo, # noqa - ArgInfoStepFactory, # noqa - CallArgs, # noqa DataLoadingThread, # noqa SparseDataDistUtil, # noqa StageOut, # noqa - Tracer, # noqa ) diff --git a/torchrec/distributed/train_pipeline/postproc.py b/torchrec/distributed/train_pipeline/postproc.py new file mode 100644 index 000000000..d762d7f0d --- /dev/null +++ b/torchrec/distributed/train_pipeline/postproc.py @@ -0,0 +1,257 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +import logging +from collections import OrderedDict +from typing import Any, Dict, Iterable, Iterator, Optional, Set, Tuple, Union + +import torch + +from torch.nn.modules.module import _IncompatibleKeys +from torch.profiler import record_function + +from torchrec.distributed.train_pipeline.pipeline_context import TrainPipelineContext +from torchrec.distributed.train_pipeline.types import CallArgs +from torchrec.streamable import Pipelineable + +logger: logging.Logger = logging.getLogger(__name__) + + +class NoOpStream: + """No-Op Context manager that takes in a stream""" + + def __init__(self, stream: Optional[torch.Stream]) -> None: + self._stream = stream + + def __enter__(self) -> "NoOpStream": + """Return `self` upon entering the runtime context.""" + return self + + # pyre-ignore + def __exit__(self, exc_type, exc_value, traceback) -> None: + return None + + +class PipelinedPostproc(torch.nn.Module): + """ + Wrapper around postproc module found during model graph traversal for sparse data dist + pipelining. In addition to the original module, it encapsulates information needed for + execution such as list of ArgInfo and the current training pipeline context. + + Args: + postproc_module (torch.nn.Module): postproc module to run + fqn (str): fqn of the postproc module in the model being pipelined + args (CallArgs): CallArgs for the postproc module + context (TrainPipelineContext): Training context for the next iteration / batch + + Returns: + Any + + Example: + postproc = PipelinedPostproc(postproc_module, fqn, args, context) + # module-swap with pipeliend postproc + setattr(model, fqn, postproc) + """ + + _FORCE_STATE_DICT_LOAD = True + + def __init__( + self, + postproc_module: torch.nn.Module, + fqn: str, + args: CallArgs, + context: TrainPipelineContext, + # TODO: make streams non-optional - skipping now to avoid ripple effect + default_stream: Optional[torch.Stream], + dist_stream: Optional[torch.Stream], + ) -> None: + super().__init__() + self._postproc_module = postproc_module + self._fqn = fqn + self._args = args + self._context = context + self._default_stream = default_stream + self._dist_stream = dist_stream + if not default_stream: + logger.warning( + f"Postproc module {fqn} has no default stream. This may cause race conditions and NaNs during training!" + ) + if not dist_stream: + logger.warning( + f"Postproc module {fqn} has no dist stream. This may cause race conditions and NaNs during training!" + ) + + if self._dist_stream: + device: torch.device = self._dist_stream.device + # pyre-ignore + self._stream_context = ( + torch.get_device_module(device).stream + if device.type in ["cuda", "mtia"] + else torch.cuda.stream + ) + else: + self._stream_context = NoOpStream + + @property + def postproc_module(self) -> torch.nn.Module: + return self._postproc_module + + @property + def fqn(self) -> str: + return self._fqn + + # pyre-ignore + def forward(self, *input, **kwargs) -> Any: + """ + Args: + Any args and kwargs during model fwd + During _start_data_dist, input[0] contains the current data + Returns: + Any + """ + if self._fqn in self._context.postproc_fwd_results: + # This should only be hit in two cases: + # 1) During model forward + # During model forward, avoid duplicate work + # by returning the cached result from previous + # iteration's _start_data_dist + # 2) During _start_data_dist when postproc module is + # shared by more than one args. e.g. if we have + # postproc_out_a = postproc_a(input) + # postproc_out_b = postproc_b(postproc_out_a) <- postproc_a shared + # postproc_out_c = postproc_c(postproc_out_a) <-^ + # When processing postproc_b, we cache value of postproc_a(input) + # so when processing postproc_c, we can reuse postproc_a(input) + res = self._context.postproc_fwd_results[self._fqn] + return res + + # Everything below should only be called during _start_data_dist stage + + # Build up arg and kwargs from recursive call to pass to postproc module + # Arguments to postproc module can be also be a derived product + # of another postproc module call, as long as module is pipelineable + + # Use input[0] as _start_data_dist only passes 1 arg + args, kwargs = self._args.build_args_kwargs(input[0]) + + with record_function(f"## sdd_input_postproc {self._context.index} ##"): + # should be no-op as we call this in dist stream + with self._stream_context(self._dist_stream): + res = self._postproc_module(*args, **kwargs) + + # Ensure postproc modules output is safe to use from default stream later + if self._default_stream and self._dist_stream: + self._default_stream.wait_stream(self._dist_stream) + + if isinstance(res, (torch.Tensor, Pipelineable, Iterable, Dict)): + # Result from module forward might be a complex type such as + # Tuple[KeyedJaggedTensor, Dict[str, torch.Tensor]] + # In this case, we need to first iterate over each element of tuple + # and call record_stream on first item as KJT is Pipelineable + # for the second item (Dict), we iterate over the values and call + # record_stream accordingly. + + # pyre-ignore[6] + PipelinedPostproc.recursive_record_stream(res, self._default_stream) + elif self._context.index == 0: + logger.warning( + f"Result of postproc module {self._fqn} is of type {type(res)}. We currently expect it to be a Tensor, Pipelineable, Iterable, or Dict to handle memory safety. If your output is not of this type, please add support for it above. Otherwise you might run into NaNs or CUDA Illegal Memory issues during training!" + ) + + with self._stream_context(self._default_stream): + # Cache results, only during _start_data_dist + self._context.postproc_fwd_results[self._fqn] = res + + return res + + @property + def args(self) -> CallArgs: + return self._args + + def set_context(self, context: TrainPipelineContext) -> None: + self._context = context + + def get_context(self) -> TrainPipelineContext: + return self._context + + def named_modules( + self, + memo: Optional[Set[torch.nn.Module]] = None, + prefix: str = "", + remove_duplicate: bool = True, + ) -> Iterator[Tuple[str, torch.nn.Module]]: + if memo is None: + memo = set() + if self not in memo: + if remove_duplicate: + memo.add(self) + # This is needed because otherwise the rewrite won't find the existing postproc, and will create a new one + # Also, `named_modules` need to include self - see base implementation in the nn.modules.Module + yield prefix, self + # Difference from base implementation is here - the child name (_postproc_module) is not added to the prefix + yield from self._postproc_module.named_modules( + memo, prefix, remove_duplicate + ) + + def named_parameters( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[Tuple[str, torch.nn.Parameter]]: + yield from self._postproc_module.named_parameters( + prefix, + recurse, + remove_duplicate, + ) + + def named_buffers( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[Tuple[str, torch.Tensor]]: + yield from self._postproc_module.named_buffers( + prefix, recurse, remove_duplicate + ) + + # pyre-ignore [14] + def state_dict( + self, + destination: Optional[Dict[str, Any]] = None, + prefix: str = "", + keep_vars: bool = False, + ) -> Dict[str, Any]: + # super().state_dict(destination, prefix, keep_vars) + if destination is None: + destination = OrderedDict() + # pyre-ignore [16] + destination._metadata = OrderedDict() + self._postproc_module.state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) + return destination + + # pyre-ignore [14] + def load_state_dict( + self, + state_dict: OrderedDict[str, torch.Tensor], + strict: bool = True, + ) -> _IncompatibleKeys: + return self._postproc_module.load_state_dict(state_dict, strict=strict) + + @staticmethod + def recursive_record_stream( + # pyre-fixme[2]: Parameter `re` must have a type that does not contain `Any` + res: Union[torch.Tensor, Pipelineable, Iterable[Any], Dict[Any, Any]], + stream: torch.Stream, + ) -> None: + if isinstance(res, torch.Tensor) and res.device.type in ["cuda", "mtia"]: + res.record_stream(stream) + elif isinstance(res, Pipelineable): + res.record_stream(stream) + elif isinstance(res, (list, tuple)): + for v in res: + PipelinedPostproc.recursive_record_stream(v, stream) + elif isinstance(res, dict): + for v in res.values(): + PipelinedPostproc.recursive_record_stream(v, stream) diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py index 89bc97210..95ad64959 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py @@ -48,9 +48,16 @@ create_module_and_freeze, ) from torchrec.distributed.train_pipeline.pipeline_context import TrainPipelineContext +from torchrec.distributed.train_pipeline.postproc import PipelinedPostproc from torchrec.distributed.train_pipeline.tests.test_train_pipelines_base import ( TrainPipelineSparseDistTestBase, ) +from torchrec.distributed.train_pipeline.tracing import ( + GetAttrArgInfoStep, + GetItemArgInfoStep, + NoopArgInfoStep, + PostprocArgInfoStep, +) from torchrec.distributed.train_pipeline.train_pipelines import ( EvalPipelineSparseDist, PrefetchTrainPipelineSparseDist, @@ -65,13 +72,8 @@ DataLoadingThread, EmbeddingPipelinedForward, get_h2d_func, - GetAttrArgInfoStep, - GetItemArgInfoStep, - NoopArgInfoStep, PipelinedForward, - PipelinedPostproc, PipelineStage, - PostprocArgInfoStep, SparseDataDistUtil, StageOut, ) diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py index 0138a4114..bab9e01e5 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py @@ -23,15 +23,14 @@ from torchrec.distributed.train_pipeline.tests.test_train_pipelines_base import ( TrainPipelineSparseDistTestBase, ) -from torchrec.distributed.train_pipeline.utils import ( - _rewrite_model, +from torchrec.distributed.train_pipeline.tracing import ( ArgInfo, ArgInfoStepFactory, CallArgs, NodeArgsHelper, - PipelinedForward, PipelinedPostproc, ) +from torchrec.distributed.train_pipeline.utils import _rewrite_model, PipelinedForward from torchrec.distributed.types import ShardingType from torchrec.sparse.jagged_tensor import KeyedJaggedTensor diff --git a/torchrec/distributed/train_pipeline/tracing.py b/torchrec/distributed/train_pipeline/tracing.py new file mode 100644 index 000000000..946348785 --- /dev/null +++ b/torchrec/distributed/train_pipeline/tracing.py @@ -0,0 +1,589 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +import logging +from dataclasses import dataclass +from itertools import chain +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +import torch + +if not torch._running_with_deploy(): + from torch.distributed._composable.fsdp.fully_shard import FSDPModule as FSDP2 +else: + + class FSDP2: + pass + + +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.fx.immutable_collections import ( + immutable_dict as fx_immutable_dict, + immutable_list as fx_immutable_list, +) +from torch.fx.node import Node + +from torchrec.distributed.model_parallel import ShardedModule +from torchrec.distributed.train_pipeline.pipeline_context import TrainPipelineContext +from torchrec.distributed.train_pipeline.postproc import PipelinedPostproc +from torchrec.distributed.train_pipeline.types import ArgInfo, BaseArgInfoStep, CallArgs + +logger: logging.Logger = logging.getLogger(__name__) + + +class NoopArgInfoStep(BaseArgInfoStep): + # pyre-ignore + def process(self, arg) -> Any: + return arg + + +class GetAttrArgInfoStep(BaseArgInfoStep): + def __init__(self, attr_name: str) -> None: + super().__init__() + self.attr_name = attr_name + + # pyre-ignore + def process(self, arg) -> Any: + return getattr(arg, self.attr_name) + + +class GetItemArgInfoStep(BaseArgInfoStep): + def __init__(self, item_index: Union[str, int]) -> None: + super().__init__() + self.item_index = item_index + + # pyre-ignore + def process(self, arg) -> Any: + return arg[self.item_index] + + +class PostprocArgInfoStep(BaseArgInfoStep): + def __init__(self, postproc_module: PipelinedPostproc) -> None: + super().__init__() + self.postproc_module = postproc_module + + # pyre-ignore + def process(self, arg) -> Any: + return self.postproc_module(arg) + + +class ScalarArgInfoStep(BaseArgInfoStep): + def __init__(self, value: object) -> None: + super().__init__() + self.value = value + + # pyre-ignore + def process(self, _arg) -> Any: + return self.value + + +class ListArgInfoStep(BaseArgInfoStep): + def __init__(self, value: List[object]) -> None: + super().__init__() + self.value = value + + # pyre-ignore + def process(self, arg) -> Any: + return [ + (v if not isinstance(v, ArgInfo) else v.process_steps(arg)) + for v in self.value + ] + + +class DictArgInfoStep(BaseArgInfoStep): + def __init__(self, value: Dict[str, object]) -> None: + super().__init__() + self.value = value + + # pyre-ignore + def process(self, arg) -> Any: + return { + k: (v if not isinstance(v, ArgInfo) else v.process_steps(arg)) + for k, v in self.value.items() + } + + +class ArgInfoStepFactory: + """ + Convenience class to reduce the amount of imports the external uses will have. + Should closely follow the constructor interfaces for the corresponding classes. + """ + + @classmethod + def noop(cls) -> NoopArgInfoStep: + return NoopArgInfoStep() + + @classmethod + def get_attr(cls, name: str) -> GetAttrArgInfoStep: + return GetAttrArgInfoStep(name) + + @classmethod + def get_item(cls, index: Union[str, int]) -> GetItemArgInfoStep: + return GetItemArgInfoStep(index) + + @classmethod + def postproc( + cls, pipelined_postproc_module: PipelinedPostproc + ) -> PostprocArgInfoStep: + return PostprocArgInfoStep(pipelined_postproc_module) + + @classmethod + def from_scalar(cls, value: object) -> ScalarArgInfoStep: + return ScalarArgInfoStep(value) + + @classmethod + def from_list(cls, value: List[object]) -> ListArgInfoStep: + return ListArgInfoStep(value) + + @classmethod + def from_dict(cls, value: Dict[str, object]) -> DictArgInfoStep: + return DictArgInfoStep(value) + + +def _check_args_for_call_module( + node: torch.fx.Node, +) -> bool: + """ + Recursively checks if args to a node is the result of a call_module. + """ + if node.op == "call_module": + return True + + for arg in node.args: + if isinstance(arg, torch.fx.Node) and _check_args_for_call_module(arg): + return True + + return False + + +def _check_postproc_pipelineable( + module: torch.nn.Module, +) -> bool: + for _, _ in module.named_parameters(recurse=True): + # Cannot have any trainable params for it to be pipelined + logger.warning( + f"Module {module} cannot be pipelined as it has trainable parameters" + ) + return False + return True + + +def _find_postproc_module_recursive( + module: torch.nn.Module, + postproc_module_fqn: str, +) -> Optional[torch.nn.Module]: + """ + Finds the postproc module in the model. + """ + for name, child in module.named_modules(): + if name == postproc_module_fqn: + return child + return None + + +class NodeArgsHelper: + def __init__( + self, + model: torch.nn.Module, + context: TrainPipelineContext, + pipeline_postproc: bool, + default_stream: Optional[torch.Stream] = None, + dist_stream: Optional[torch.Stream] = None, + ) -> None: + self._model = model + self._context = context + self._pipeline_postproc = pipeline_postproc + self._default_stream = default_stream + self._dist_stream = dist_stream + self._pipelined_postprocs: Set[PipelinedPostproc] = set() + + @property + def pipelined_postprocs(self) -> Set[PipelinedPostproc]: + return self._pipelined_postprocs + + def _swap_postproc_module_recursive( + self, + module: torch.nn.Module, + to_swap_module: torch.nn.Module, + postproc_module_fqn: str, + path: str = "", + ) -> torch.nn.Module: + """ + Swaps the postproc module in the model. + """ + if isinstance(module, PipelinedPostproc): + return module + + if path == postproc_module_fqn: + return to_swap_module + + for name, child in module.named_children(): + child = self._swap_postproc_module_recursive( + child, + to_swap_module, + postproc_module_fqn, + path + "." + name if path else name, + ) + setattr(module, name, child) + + return module + + def _handle_constant( + self, + arg: Any, # pyre-ignore + arg_info: ArgInfo, + for_postproc_module: bool = False, + ) -> Optional[ArgInfo]: + if not self._pipeline_postproc: + return None + + if isinstance(arg, fx_immutable_dict): + step = ArgInfoStepFactory.from_dict( + { + k: self._handle_collection_element(v, for_postproc_module) + for k, v in arg.items() + } + ) + elif isinstance(arg, fx_immutable_list): + step = ArgInfoStepFactory.from_list( + [self._handle_collection_element(v, for_postproc_module) for v in arg] + ) + else: + step = ArgInfoStepFactory.from_scalar(arg) + arg_info.add_step(step) + return arg_info + + # pyre-ignore[3] + def _handle_collection_element( + self, + # pyre-ignore[2] + arg: Any, + for_postproc_module: bool = False, + ) -> Any: + if not isinstance(arg, torch.fx.Node): + return arg + + arg_info_nested = self._get_node_args_helper_inner( + arg, + for_postproc_module, + ) + return arg_info_nested + + def _handle_placeholder( + self, child_node: torch.fx.Node, arg_info: ArgInfo + ) -> ArgInfo: + # note: mutates arg_info + if hasattr(child_node, "ph_key"): + # pyre-fixme[16] + ph_key: str = child_node.ph_key + # example: ph_key = 'event_id_list_features_seqs[marketplace]' + ph_key = ph_key.replace("[", ".") + ph_keys = ph_key.split(".") + for key in ph_keys: + if "]" in key: + k_ = key[:-1] + try: + k_ = int(k_) + except ValueError: + pass + arg_info.append_step(ArgInfoStepFactory.get_item(k_)) + else: + arg_info.append_step(ArgInfoStepFactory.get_attr(key)) + else: + # no-op + arg_info.add_step(ArgInfoStepFactory.noop()) + return arg_info + + def _handle_module( + self, child_node: torch.fx.Node, arg_info: ArgInfo + ) -> Optional[ArgInfo]: + postproc_module_fqn = str(child_node.target) + postproc_module = _find_postproc_module_recursive( + self._model, postproc_module_fqn + ) + + if not self._pipeline_postproc: + logger.warning( + f"Found module {postproc_module} that potentially modifies KJ. Train pipeline initialized with `pipeline_postproc=False` (default), so we assume KJT input modification. To allow torchrec to check if this module can be safely pipelined, please set `pipeline_postproc=True`" + ) + return None + + if not postproc_module: + # Could not find such module, should not happen + return None + + if isinstance(postproc_module, PipelinedPostproc): + # Already did module swap and registered args, early exit + self._pipelined_postprocs.add(postproc_module) + arg_info.add_step(ArgInfoStepFactory.postproc(postproc_module)) + return arg_info + + if not isinstance(postproc_module, torch.nn.Module): + logger.warning( + f"Expected postproc_module to be nn.Module but was {type(postproc_module)}" + ) + return None + + # check if module is safe to pipeline i.e.no trainable param + if not _check_postproc_pipelineable(postproc_module): + return None + + # For module calls, `self` isn't counted + total_num_args = len(child_node.args) + len(child_node.kwargs) + if total_num_args == 0: + # module call without any args, assume KJT modified + return None + + # recursive call to check that all inputs to this postproc module + # is either made of postproc module or non-modifying train batch input + # transformations + postproc_args, num_found_safe_postproc_args = self.get_node_args( + child_node, + for_postproc_module=True, + ) + if num_found_safe_postproc_args == total_num_args: + logger.info( + f"""Module {postproc_module} is a valid postproc module (no + trainable params and inputs can be derived from train batch input + via a series of either valid postproc modules or non-modifying + transformations) and will be applied during sparse data dist + stage""" + ) + + pipelined_postproc_module = PipelinedPostproc( + postproc_module, + postproc_module_fqn, + postproc_args, + self._context, + default_stream=self._default_stream, + dist_stream=self._dist_stream, + ) + + # module swap + self._model = self._swap_postproc_module_recursive( + self._model, pipelined_postproc_module, postproc_module_fqn + ) + + self._pipelined_postprocs.add(pipelined_postproc_module) + arg_info.add_step(ArgInfoStepFactory.postproc(pipelined_postproc_module)) + return arg_info + + return None + + def _get_node_args_helper_inner( + self, + # pyre-ignore + arg, + for_postproc_module: bool = False, + ) -> Optional[ArgInfo]: + arg_info = ArgInfo([]) + while True: + if not isinstance(arg, torch.fx.Node): + return self._handle_constant(arg, arg_info, for_postproc_module) + + child_node = arg + + if child_node.op == "placeholder": + return self._handle_placeholder(arg, arg_info) + elif child_node.op == "call_module": + return self._handle_module(arg, arg_info) + elif ( + child_node.op == "call_function" + and child_node.target.__module__ == "builtins" + # pyre-fixme[16] + and child_node.target.__name__ == "getattr" + ): + arg_info.add_step( + # pyre-fixme[6]: For 2nd argument expected `str` but got Unknown + ArgInfoStepFactory.get_attr(child_node.args[1]) + ) + arg = child_node.args[0] + elif ( + child_node.op == "call_function" + and child_node.target.__module__ == "_operator" + # pyre-fixme[16] + and child_node.target.__name__ == "getitem" + ): + arg_info.add_step( + # pyre-fixme[6]: For 2nd argument expected `str` but got Unknown + ArgInfoStepFactory.get_item(child_node.args[1]) + ) + arg = child_node.args[0] + elif ( + child_node.op == "call_function" + and child_node.target.__module__ == "torch.utils._pytree" + # pyre-fixme[16] + and child_node.target.__name__ == "tree_unflatten" + ): + """ + This is for the PT2 export path where we unflatten the input to reconstruct + the structure with the recorded tree spec. + """ + step = arg_info.steps[0] + assert isinstance(step, GetItemArgInfoStep) + # pyre-fixme[16] + arg = child_node.args[0][step.item_index] + elif ( + child_node.op == "call_function" + and child_node.target.__module__ == "torchrec.sparse.jagged_tensor" + # pyre-fixme[16] + and child_node.target.__name__ == "KeyedJaggedTensor" + ): + call_module_found = False + + for arg_node in chain(child_node.args, child_node.kwargs.values()): + if isinstance( + arg_node, torch.fx.Node + ) and _check_args_for_call_module(arg_node): + call_module_found = True + break + + if call_module_found: + break + + if "values" in child_node.kwargs: + arg = child_node.kwargs["values"] + else: + arg = child_node.args[1] + + elif child_node.op == "call_method" and child_node.target == "get": + # pyre-ignore[6] + arg_info.add_step(ArgInfoStepFactory.get_item(child_node.args[1])) + arg = child_node.args[0] + else: + break + + # if we couldn't hit one of the "decisive" outcomes (constant, placeholder or module), return "not found" + return None + + def _get_node_args_helper( + self, + arguments, # pyre-ignore[2] + # Add `None` constants to arg info only for postproc modules + # Defaults to False for backward compatibility + for_postproc_module: bool = False, + ) -> Tuple[List[ArgInfo], int]: + """ + Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s. + It also counts the number of (args + kwargs) found. + """ + num_found = 0 + arg_info_list = [] + for arg in arguments: + if not for_postproc_module and arg is None: + arg_info = ArgInfo([ArgInfoStepFactory.from_scalar(None)]) + arg_info_list.append(arg_info) + num_found += 1 + continue + arg_info = self._get_node_args_helper_inner( + arg, + for_postproc_module, + ) + if arg_info is not None: + num_found += 1 + arg_info_list.append(arg_info) + return arg_info_list, num_found + + def get_node_args( + self, + node: Node, + for_postproc_module: bool = False, + ) -> Tuple[CallArgs, int]: + pos_arg_info_list, args_found = self._get_node_args_helper( + node.args, + for_postproc_module, + ) + kwargs_arg_info_list, kwargs_found = self._get_node_args_helper( + node.kwargs.values(), + for_postproc_module, + ) + + # Replace with proper names for kwargs + kwargs_info_list = dict(zip(node.kwargs, kwargs_arg_info_list)) + + return CallArgs(pos_arg_info_list, kwargs_info_list), args_found + kwargs_found + + +def _get_leaf_module_names(model: torch.nn.Module) -> List[str]: + """ + Returns a list of top level modules to be used as leaf modules for FX tracing. + This is a shallow FX trace that only goes the minimum depth required to pipeline. + Any sub-module who does not contain a ShardedModule would be considered as a leaf + module unless explicitly tagged as `_is_pytorch_fx_traceable = True`. + """ + + def _get_leaf_module_names_helper( + model: torch.nn.Module, + path: str, + leaf_module_names: Set[str], + ) -> bool: + """ + recursive function returns True if any of the sub-modules is ShardedModule. + it also added the fqns of the sub-modules who do not contain any ShardedModule + into the `leaf_module_names` unless it's marked as `_is_pytorch_fx_traceable = True`, + which suggests this ShardedModule-free module should NOT be treated as a leaf module + """ + sharded_children = set() + for name, child in model.named_children(): + curr_path = path + name + if isinstance(child, ShardedModule): + sharded_children.add(name) + else: + child_sharded = _get_leaf_module_names_helper( + child, + curr_path + ".", + leaf_module_names, + ) + if child_sharded: + sharded_children.add(name) + + # only do this for hybrid module (has sharded child) + if len(sharded_children) > 0: + for name, child in model.named_children(): + if name in sharded_children: + continue + # assume module is leaf node unless annotated otherwise + if not getattr(child, "_is_pytorch_fx_traceable", False): + leaf_module_names.add(path + name) + return len(sharded_children) > 0 + + leaf_module_names: Set[str] = set() + _get_leaf_module_names_helper( + model, + "", + leaf_module_names, + ) + return list(leaf_module_names) + + +class Tracer(torch.fx.Tracer): + """ + The Trace class used in `_rewrite_model`, treating all ShardedModules and ShardedModule-free + modules as leaf modules. A module who is not a ShardedModule but contains ShardedModule would + NOT be considered as a leaf module. + """ + + # Disables proxying buffers during tracing. Ideally, proxying buffers would be + # disabled, but some models are currently mutating buffer values, which causes errors + # during tracing. If those models can be rewritten to not do that, we can likely + # remove this line. + proxy_buffer_attributes = False + + def __init__(self, leaf_modules: Optional[List[str]] = None) -> None: + super().__init__() + self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else [] + + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + if ( + isinstance(m, ShardedModule) + or module_qualified_name in self._leaf_modules + or isinstance(m, FSDP) + or isinstance(m, FSDP2) + ): + return True + return super().is_leaf_module(m, module_qualified_name) diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py index 99babf987..e8517e8d7 100644 --- a/torchrec/distributed/train_pipeline/train_pipelines.py +++ b/torchrec/distributed/train_pipeline/train_pipelines.py @@ -40,6 +40,7 @@ PrefetchTrainPipelineContext, TrainPipelineContext, ) +from torchrec.distributed.train_pipeline.tracing import PipelinedPostproc from torchrec.distributed.train_pipeline.utils import ( _override_input_dist_forwards, _pipeline_detach_model, @@ -54,7 +55,6 @@ EmbeddingPipelinedForward, InSyncEmbeddingPipelinedForward, PipelinedForward, - PipelinedPostproc, PipelineStage, PrefetchPipelinedForward, RunnableType, diff --git a/torchrec/distributed/train_pipeline/types.py b/torchrec/distributed/train_pipeline/types.py new file mode 100644 index 000000000..f96dfcc8c --- /dev/null +++ b/torchrec/distributed/train_pipeline/types.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +import abc +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple + + +class BaseArgInfoStep(abc.ABC): + @abc.abstractmethod + # pyre-ignore + def process(self, arg) -> Any: + raise Exception("Not implemented in the BaseArgInfoStep") + + def __eq__(self, other: object) -> bool: + """ + Some tests use the equality checks on the ArgInfo and/or CallArgs, so it's + natural to use dataclasses for ArgInfoStep implementations. However + Torchrec doesn't like dataclasses: https://github.com/pytorch/pytorch/issues/74909 + + So, this class creates a makeshift generic implementation similar to dataclass, but without + dataclass. + """ + if not isinstance(other, type(self)): + return False + return all( + getattr(self, field_name) == getattr(other, field_name) + for field_name in self.__dict__.keys() + ) + + +@dataclass +class ArgInfo: + """ + Representation of args from a node. + + Attributes: + steps (List[ArgInfoStep]): sequence of transformations from input batch. + Steps can be thought of consequtive transformations on the input, with + output of previous step used as an input for the next. I.e. for 3 steps + it is similar to step3(step2(step1(input))) + See `BaseArgInfoStep` class hierearchy for supported transformations + """ + + steps: List[BaseArgInfoStep] + + def add_step(self, step: BaseArgInfoStep) -> "ArgInfo": + self.steps.insert(0, step) + return self + + def append_step(self, step: BaseArgInfoStep) -> "ArgInfo": + self.steps.append(step) + return self + + # pyre-ignore[3] + def process_steps( + self, + arg: Any, # pyre-ignore[2] + ) -> Any: + if not self.steps: + return None + for step in self.steps: + arg = step.process(arg) + + return arg + + +@dataclass +class CallArgs: + args: List[ArgInfo] + kwargs: Dict[str, ArgInfo] + + # pyre-ignore[3] + def build_args_kwargs( + self, initial_input: Any # pyre-ignore[2] + ) -> Tuple[List[Any], Dict[str, Any]]: + args = [arg.process_steps(initial_input) for arg in self.args] + kwargs = { + key: arg.process_steps(initial_input) for key, arg in self.kwargs.items() + } + return args, kwargs diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index 74337514c..cd6736c1d 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -6,17 +6,14 @@ # LICENSE file in the root directory of this source tree. # pyre-strict -import abc - import contextlib import copy import itertools import logging -from collections import defaultdict, deque, OrderedDict +from collections import defaultdict, deque from contextlib import AbstractContextManager from dataclasses import dataclass -from itertools import chain from threading import Event, Thread from typing import ( Any, @@ -39,24 +36,9 @@ import torch from torch import distributed as dist +from torch.profiler import record_function from torch.utils.hooks import RemovableHandle -if not torch._running_with_deploy(): - from torch.distributed._composable.fsdp.fully_shard import FSDPModule as FSDP2 -else: - - class FSDP2: - pass - - -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.fx.immutable_collections import ( - immutable_dict as fx_immutable_dict, - immutable_list as fx_immutable_list, -) -from torch.fx.node import Node -from torch.nn.modules.module import _IncompatibleKeys -from torch.profiler import record_function from torchrec.distributed.dist_data import KJTAllToAll, KJTAllToAllTensorsAwaitable from torchrec.distributed.embedding_sharding import ( FusedKJTListSplitsAwaitable, @@ -72,7 +54,13 @@ class FSDP2: PrefetchTrainPipelineContext, TrainPipelineContext, ) - +from torchrec.distributed.train_pipeline.tracing import ( + _get_leaf_module_names, + CallArgs, + NodeArgsHelper, + PipelinedPostproc, + Tracer, +) from torchrec.distributed.types import Awaitable, LazyNoWait from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor @@ -112,425 +100,6 @@ class PipelineStage: data_exhausted_callback: Optional[Callable[[], None]] = None -class BaseArgInfoStep(abc.ABC): - @abc.abstractmethod - # pyre-ignore - def process(self, arg) -> Any: - raise Exception("Not implemented in the BaseArgInfoStep") - - def __eq__(self, other: object) -> bool: - """ - Some tests use the equality checks on the ArgInfo and/or CallArgs, so it's - natural to use dataclasses for ArgInfoStep implementations. However - Torchrec doesn't like dataclasses: https://github.com/pytorch/pytorch/issues/74909 - - So, this class creates a makeshift generic implementation similar to dataclass, but without - dataclass. - """ - if not isinstance(other, type(self)): - return False - return all( - getattr(self, field_name) == getattr(other, field_name) - for field_name in self.__dict__.keys() - ) - - -class NoopArgInfoStep(BaseArgInfoStep): - # pyre-ignore - def process(self, arg) -> Any: - return arg - - -class GetAttrArgInfoStep(BaseArgInfoStep): - def __init__(self, attr_name: str) -> None: - super().__init__() - self.attr_name = attr_name - - # pyre-ignore - def process(self, arg) -> Any: - return getattr(arg, self.attr_name) - - -class GetItemArgInfoStep(BaseArgInfoStep): - def __init__(self, item_index: Union[str, int]) -> None: - super().__init__() - self.item_index = item_index - - # pyre-ignore - def process(self, arg) -> Any: - return arg[self.item_index] - - -class PostprocArgInfoStep(BaseArgInfoStep): - def __init__(self, postproc_module: "PipelinedPostproc") -> None: - super().__init__() - self.postproc_module = postproc_module - - # pyre-ignore - def process(self, arg) -> Any: - return self.postproc_module(arg) - - -class ScalarArgInfoStep(BaseArgInfoStep): - def __init__(self, value: object) -> None: - super().__init__() - self.value = value - - # pyre-ignore - def process(self, _arg) -> Any: - return self.value - - -class ListArgInfoStep(BaseArgInfoStep): - def __init__(self, value: List[object]) -> None: - super().__init__() - self.value = value - - # pyre-ignore - def process(self, arg) -> Any: - return [ - (v if not isinstance(v, ArgInfo) else v.process_steps(arg)) - for v in self.value - ] - - -class DictArgInfoStep(BaseArgInfoStep): - def __init__(self, value: Dict[str, object]) -> None: - super().__init__() - self.value = value - - # pyre-ignore - def process(self, arg) -> Any: - return { - k: (v if not isinstance(v, ArgInfo) else v.process_steps(arg)) - for k, v in self.value.items() - } - - -class ArgInfoStepFactory: - """ - Convenience class to reduce the amount of imports the external uses will have. - Should closely follow the constructor interfaces for the corresponding classes. - """ - - @classmethod - def noop(cls) -> NoopArgInfoStep: - return NoopArgInfoStep() - - @classmethod - def get_attr(cls, name: str) -> GetAttrArgInfoStep: - return GetAttrArgInfoStep(name) - - @classmethod - def get_item(cls, index: Union[str, int]) -> GetItemArgInfoStep: - return GetItemArgInfoStep(index) - - @classmethod - def postproc( - cls, pipelined_postproc_module: "PipelinedPostproc" - ) -> PostprocArgInfoStep: - return PostprocArgInfoStep(pipelined_postproc_module) - - @classmethod - def from_scalar(cls, value: object) -> ScalarArgInfoStep: - return ScalarArgInfoStep(value) - - @classmethod - def from_list(cls, value: List[object]) -> ListArgInfoStep: - return ListArgInfoStep(value) - - @classmethod - def from_dict(cls, value: Dict[str, object]) -> DictArgInfoStep: - return DictArgInfoStep(value) - - -@dataclass -class ArgInfo: - """ - Representation of args from a node. - - Attributes: - steps (List[ArgInfoStep]): sequence of transformations from input batch. - Steps can be thought of consequtive transformations on the input, with - output of previous step used as an input for the next. I.e. for 3 steps - it is similar to step3(step2(step1(input))) - See `BaseArgInfoStep` class hierearchy for supported transformations - """ - - steps: List[BaseArgInfoStep] - - def add_step(self, step: BaseArgInfoStep) -> "ArgInfo": - self.steps.insert(0, step) - return self - - def append_step(self, step: BaseArgInfoStep) -> "ArgInfo": - self.steps.append(step) - return self - - # pyre-ignore[3] - def process_steps( - self, - arg: Any, # pyre-ignore[2] - ) -> Any: - if not self.steps: - return None - for step in self.steps: - arg = step.process(arg) - - return arg - - -@dataclass -class CallArgs: - args: List[ArgInfo] - kwargs: Dict[str, ArgInfo] - - # pyre-ignore[3] - def build_args_kwargs( - self, initial_input: Any # pyre-ignore[2] - ) -> Tuple[List[Any], Dict[str, Any]]: - args = [arg.process_steps(initial_input) for arg in self.args] - kwargs = { - key: arg.process_steps(initial_input) for key, arg in self.kwargs.items() - } - return args, kwargs - - -def recursive_record_stream( - # pyre-fixme[2]: Parameter `re` must have a type that does not contain `Any` - res: Union[torch.Tensor, Pipelineable, Iterable[Any], Dict[Any, Any]], - stream: torch.Stream, -) -> None: - if isinstance(res, torch.Tensor) and res.device.type in ["cuda", "mtia"]: - res.record_stream(stream) - elif isinstance(res, Pipelineable): - res.record_stream(stream) - elif isinstance(res, (list, tuple)): - for v in res: - recursive_record_stream(v, stream) - elif isinstance(res, dict): - for v in res.values(): - recursive_record_stream(v, stream) - - -class NoOpStream: - """No-Op Context manager that takes in a stream""" - - def __init__(self, stream: Optional[torch.Stream]) -> None: - self._stream = stream - - def __enter__(self) -> "NoOpStream": - """Return `self` upon entering the runtime context.""" - return self - - # pyre-ignore - def __exit__(self, exc_type, exc_value, traceback) -> None: - return None - - -class PipelinedPostproc(torch.nn.Module): - """ - Wrapper around postproc module found during model graph traversal for sparse data dist - pipelining. In addition to the original module, it encapsulates information needed for - execution such as list of ArgInfo and the current training pipeline context. - - Args: - postproc_module (torch.nn.Module): postproc module to run - fqn (str): fqn of the postproc module in the model being pipelined - args (CallArgs): CallArgs for the postproc module - context (TrainPipelineContext): Training context for the next iteration / batch - - Returns: - Any - - Example: - postproc = PipelinedPostproc(postproc_module, fqn, args, context) - # module-swap with pipeliend postproc - setattr(model, fqn, postproc) - """ - - _FORCE_STATE_DICT_LOAD = True - - def __init__( - self, - postproc_module: torch.nn.Module, - fqn: str, - args: CallArgs, - context: TrainPipelineContext, - # TODO: make streams non-optional - skipping now to avoid ripple effect - default_stream: Optional[torch.Stream], - dist_stream: Optional[torch.Stream], - ) -> None: - super().__init__() - self._postproc_module = postproc_module - self._fqn = fqn - self._args = args - self._context = context - self._default_stream = default_stream - self._dist_stream = dist_stream - if not default_stream: - logger.warning( - f"Postproc module {fqn} has no default stream. This may cause race conditions and NaNs during training!" - ) - if not dist_stream: - logger.warning( - f"Postproc module {fqn} has no dist stream. This may cause race conditions and NaNs during training!" - ) - - if self._dist_stream: - device: torch.device = self._dist_stream.device - # pyre-ignore - self._stream_context = ( - torch.get_device_module(device).stream - if device.type in ["cuda", "mtia"] - else torch.cuda.stream - ) - else: - self._stream_context = NoOpStream - - @property - def postproc_module(self) -> torch.nn.Module: - return self._postproc_module - - @property - def fqn(self) -> str: - return self._fqn - - # pyre-ignore - def forward(self, *input, **kwargs) -> Any: - """ - Args: - Any args and kwargs during model fwd - During _start_data_dist, input[0] contains the current data - Returns: - Any - """ - if self._fqn in self._context.postproc_fwd_results: - # This should only be hit in two cases: - # 1) During model forward - # During model forward, avoid duplicate work - # by returning the cached result from previous - # iteration's _start_data_dist - # 2) During _start_data_dist when postproc module is - # shared by more than one args. e.g. if we have - # postproc_out_a = postproc_a(input) - # postproc_out_b = postproc_b(postproc_out_a) <- postproc_a shared - # postproc_out_c = postproc_c(postproc_out_a) <-^ - # When processing postproc_b, we cache value of postproc_a(input) - # so when processing postproc_c, we can reuse postproc_a(input) - res = self._context.postproc_fwd_results[self._fqn] - return res - - # Everything below should only be called during _start_data_dist stage - - # Build up arg and kwargs from recursive call to pass to postproc module - # Arguments to postproc module can be also be a derived product - # of another postproc module call, as long as module is pipelineable - - # Use input[0] as _start_data_dist only passes 1 arg - args, kwargs = self._args.build_args_kwargs(input[0]) - - with record_function(f"## sdd_input_postproc {self._context.index} ##"): - # should be no-op as we call this in dist stream - with self._stream_context(self._dist_stream): - res = self._postproc_module(*args, **kwargs) - - # Ensure postproc modules output is safe to use from default stream later - if self._default_stream and self._dist_stream: - self._default_stream.wait_stream(self._dist_stream) - - if isinstance(res, (torch.Tensor, Pipelineable, Iterable, Dict)): - # Result from module forward might be a complex type such as - # Tuple[KeyedJaggedTensor, Dict[str, torch.Tensor]] - # In this case, we need to first iterate over each element of tuple - # and call record_stream on first item as KJT is Pipelineable - # for the second item (Dict), we iterate over the values and call - # record_stream accordingly. - - # pyre-ignore[6] - recursive_record_stream(res, self._default_stream) - elif self._context.index == 0: - logger.warning( - f"Result of postproc module {self._fqn} is of type {type(res)}. We currently expect it to be a Tensor, Pipelineable, Iterable, or Dict to handle memory safety. If your output is not of this type, please add support for it above. Otherwise you might run into NaNs or CUDA Illegal Memory issues during training!" - ) - - with self._stream_context(self._default_stream): - # Cache results, only during _start_data_dist - self._context.postproc_fwd_results[self._fqn] = res - - return res - - @property - def args(self) -> CallArgs: - return self._args - - def set_context(self, context: TrainPipelineContext) -> None: - self._context = context - - def get_context(self) -> TrainPipelineContext: - return self._context - - def named_modules( - self, - memo: Optional[Set[torch.nn.Module]] = None, - prefix: str = "", - remove_duplicate: bool = True, - ) -> Iterator[Tuple[str, torch.nn.Module]]: - if memo is None: - memo = set() - if self not in memo: - if remove_duplicate: - memo.add(self) - # This is needed because otherwise the rewrite won't find the existing postproc, and will create a new one - # Also, `named_modules` need to include self - see base implementation in the nn.modules.Module - yield prefix, self - # Difference from base implementation is here - the child name (_postproc_module) is not added to the prefix - yield from self._postproc_module.named_modules( - memo, prefix, remove_duplicate - ) - - def named_parameters( - self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True - ) -> Iterator[Tuple[str, torch.nn.Parameter]]: - yield from self._postproc_module.named_parameters( - prefix, - recurse, - remove_duplicate, - ) - - def named_buffers( - self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True - ) -> Iterator[Tuple[str, torch.Tensor]]: - yield from self._postproc_module.named_buffers( - prefix, recurse, remove_duplicate - ) - - # pyre-ignore [14] - def state_dict( - self, - destination: Optional[Dict[str, Any]] = None, - prefix: str = "", - keep_vars: bool = False, - ) -> Dict[str, Any]: - # super().state_dict(destination, prefix, keep_vars) - if destination is None: - destination = OrderedDict() - # pyre-ignore [16] - destination._metadata = OrderedDict() - self._postproc_module.state_dict( - destination=destination, prefix=prefix, keep_vars=keep_vars - ) - return destination - - # pyre-ignore [14] - def load_state_dict( - self, - state_dict: OrderedDict[str, torch.Tensor], - strict: bool = True, - ) -> _IncompatibleKeys: - return self._postproc_module.load_state_dict(state_dict, strict=strict) - - TForwardContext = TypeVar("TForwardContext", bound=TrainPipelineContext) EmbeddingModuleRetType = Union[Dict[str, JaggedTensor], KeyedTensor] @@ -803,34 +372,6 @@ def __call__(self, input: KeyedJaggedTensor) -> KJTSplitsAllToAllMeta: ) -class Tracer(torch.fx.Tracer): - """ - The Trace class used in `_rewrite_model`, treating all ShardedModules and ShardedModule-free - modules as leaf modules. A module who is not a ShardedModule but contains ShardedModule would - NOT be considered as a leaf module. - """ - - # Disables proxying buffers during tracing. Ideally, proxying buffers would be - # disabled, but some models are currently mutating buffer values, which causes errors - # during tracing. If those models can be rewritten to not do that, we can likely - # remove this line. - proxy_buffer_attributes = False - - def __init__(self, leaf_modules: Optional[List[str]] = None) -> None: - super().__init__() - self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else [] - - def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: - if ( - isinstance(m, ShardedModule) - or module_qualified_name in self._leaf_modules - or isinstance(m, FSDP) - or isinstance(m, FSDP2) - ): - return True - return super().is_leaf_module(m, module_qualified_name) - - def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In: assert isinstance( batch, (torch.Tensor, Pipelineable) @@ -976,424 +517,6 @@ def _fuse_input_dist_splits(context: TrainPipelineContext) -> None: ) -def _check_args_for_call_module( - node: torch.fx.Node, -) -> bool: - """ - Recursively checks if args to a node is the result of a call_module. - """ - if node.op == "call_module": - return True - - for arg in node.args: - if isinstance(arg, torch.fx.Node) and _check_args_for_call_module(arg): - return True - - return False - - -def _check_postproc_pipelineable( - module: torch.nn.Module, -) -> bool: - for _, _ in module.named_parameters(recurse=True): - # Cannot have any trainable params for it to be pipelined - logger.warning( - f"Module {module} cannot be pipelined as it has trainable parameters" - ) - return False - return True - - -def _find_postproc_module_recursive( - module: torch.nn.Module, - postproc_module_fqn: str, -) -> Optional[torch.nn.Module]: - """ - Finds the postproc module in the model. - """ - for name, child in module.named_modules(): - if name == postproc_module_fqn: - return child - return None - - -class NodeArgsHelper: - def __init__( - self, - model: torch.nn.Module, - context: TrainPipelineContext, - pipeline_postproc: bool, - default_stream: Optional[torch.Stream] = None, - dist_stream: Optional[torch.Stream] = None, - ) -> None: - self._model = model - self._context = context - self._pipeline_postproc = pipeline_postproc - self._default_stream = default_stream - self._dist_stream = dist_stream - self._pipelined_postprocs: Set[PipelinedPostproc] = set() - - @property - def pipelined_postprocs(self) -> Set[PipelinedPostproc]: - return self._pipelined_postprocs - - def _swap_postproc_module_recursive( - self, - module: torch.nn.Module, - to_swap_module: torch.nn.Module, - postproc_module_fqn: str, - path: str = "", - ) -> torch.nn.Module: - """ - Swaps the postproc module in the model. - """ - if isinstance(module, PipelinedPostproc): - return module - - if path == postproc_module_fqn: - return to_swap_module - - for name, child in module.named_children(): - child = self._swap_postproc_module_recursive( - child, - to_swap_module, - postproc_module_fqn, - path + "." + name if path else name, - ) - setattr(module, name, child) - - return module - - def _handle_constant( - self, - arg: Any, # pyre-ignore - arg_info: ArgInfo, - for_postproc_module: bool = False, - ) -> Optional[ArgInfo]: - if not self._pipeline_postproc: - return None - - if isinstance(arg, fx_immutable_dict): - step = ArgInfoStepFactory.from_dict( - { - k: self._handle_collection_element(v, for_postproc_module) - for k, v in arg.items() - } - ) - elif isinstance(arg, fx_immutable_list): - step = ArgInfoStepFactory.from_list( - [self._handle_collection_element(v, for_postproc_module) for v in arg] - ) - else: - step = ArgInfoStepFactory.from_scalar(arg) - arg_info.add_step(step) - return arg_info - - # pyre-ignore[3] - def _handle_collection_element( - self, - # pyre-ignore[2] - arg: Any, - for_postproc_module: bool = False, - ) -> Any: - if not isinstance(arg, torch.fx.Node): - return arg - - arg_info_nested = self._get_node_args_helper_inner( - arg, - for_postproc_module, - ) - return arg_info_nested - - def _handle_placeholder( - self, child_node: torch.fx.Node, arg_info: ArgInfo - ) -> ArgInfo: - # note: mutates arg_info - if hasattr(child_node, "ph_key"): - # pyre-fixme[16] - ph_key: str = child_node.ph_key - # example: ph_key = 'event_id_list_features_seqs[marketplace]' - ph_key = ph_key.replace("[", ".") - ph_keys = ph_key.split(".") - for key in ph_keys: - if "]" in key: - k_ = key[:-1] - try: - k_ = int(k_) - except ValueError: - pass - arg_info.append_step(ArgInfoStepFactory.get_item(k_)) - else: - arg_info.append_step(ArgInfoStepFactory.get_attr(key)) - else: - # no-op - arg_info.add_step(ArgInfoStepFactory.noop()) - return arg_info - - def _handle_module( - self, child_node: torch.fx.Node, arg_info: ArgInfo - ) -> Optional[ArgInfo]: - postproc_module_fqn = str(child_node.target) - postproc_module = _find_postproc_module_recursive( - self._model, postproc_module_fqn - ) - - if not self._pipeline_postproc: - logger.warning( - f"Found module {postproc_module} that potentially modifies KJ. Train pipeline initialized with `pipeline_postproc=False` (default), so we assume KJT input modification. To allow torchrec to check if this module can be safely pipelined, please set `pipeline_postproc=True`" - ) - return None - - if not postproc_module: - # Could not find such module, should not happen - return None - - if isinstance(postproc_module, PipelinedPostproc): - # Already did module swap and registered args, early exit - self._pipelined_postprocs.add(postproc_module) - arg_info.add_step(ArgInfoStepFactory.postproc(postproc_module)) - return arg_info - - if not isinstance(postproc_module, torch.nn.Module): - logger.warning( - f"Expected postproc_module to be nn.Module but was {type(postproc_module)}" - ) - return None - - # check if module is safe to pipeline i.e.no trainable param - if not _check_postproc_pipelineable(postproc_module): - return None - - # For module calls, `self` isn't counted - total_num_args = len(child_node.args) + len(child_node.kwargs) - if total_num_args == 0: - # module call without any args, assume KJT modified - return None - - # recursive call to check that all inputs to this postproc module - # is either made of postproc module or non-modifying train batch input - # transformations - postproc_args, num_found_safe_postproc_args = self.get_node_args( - child_node, - for_postproc_module=True, - ) - if num_found_safe_postproc_args == total_num_args: - logger.info( - f"""Module {postproc_module} is a valid postproc module (no - trainable params and inputs can be derived from train batch input - via a series of either valid postproc modules or non-modifying - transformations) and will be applied during sparse data dist - stage""" - ) - - pipelined_postproc_module = PipelinedPostproc( - postproc_module, - postproc_module_fqn, - postproc_args, - self._context, - default_stream=self._default_stream, - dist_stream=self._dist_stream, - ) - - # module swap - self._model = self._swap_postproc_module_recursive( - self._model, pipelined_postproc_module, postproc_module_fqn - ) - - self._pipelined_postprocs.add(pipelined_postproc_module) - arg_info.add_step(ArgInfoStepFactory.postproc(pipelined_postproc_module)) - return arg_info - - return None - - def _get_node_args_helper_inner( - self, - # pyre-ignore - arg, - for_postproc_module: bool = False, - ) -> Optional[ArgInfo]: - arg_info = ArgInfo([]) - while True: - if not isinstance(arg, torch.fx.Node): - return self._handle_constant(arg, arg_info, for_postproc_module) - - child_node = arg - - if child_node.op == "placeholder": - return self._handle_placeholder(arg, arg_info) - elif child_node.op == "call_module": - return self._handle_module(arg, arg_info) - elif ( - child_node.op == "call_function" - and child_node.target.__module__ == "builtins" - # pyre-fixme[16] - and child_node.target.__name__ == "getattr" - ): - arg_info.add_step( - # pyre-fixme[6]: For 2nd argument expected `str` but got Unknown - ArgInfoStepFactory.get_attr(child_node.args[1]) - ) - arg = child_node.args[0] - elif ( - child_node.op == "call_function" - and child_node.target.__module__ == "_operator" - # pyre-fixme[16] - and child_node.target.__name__ == "getitem" - ): - arg_info.add_step( - # pyre-fixme[6]: For 2nd argument expected `str` but got Unknown - ArgInfoStepFactory.get_item(child_node.args[1]) - ) - arg = child_node.args[0] - elif ( - child_node.op == "call_function" - and child_node.target.__module__ == "torch.utils._pytree" - # pyre-fixme[16] - and child_node.target.__name__ == "tree_unflatten" - ): - """ - This is for the PT2 export path where we unflatten the input to reconstruct - the structure with the recorded tree spec. - """ - step = arg_info.steps[0] - assert isinstance(step, GetItemArgInfoStep) - # pyre-fixme[16] - arg = child_node.args[0][step.item_index] - elif ( - child_node.op == "call_function" - and child_node.target.__module__ == "torchrec.sparse.jagged_tensor" - # pyre-fixme[16] - and child_node.target.__name__ == "KeyedJaggedTensor" - ): - call_module_found = False - - for arg_node in chain(child_node.args, child_node.kwargs.values()): - if isinstance( - arg_node, torch.fx.Node - ) and _check_args_for_call_module(arg_node): - call_module_found = True - break - - if call_module_found: - break - - if "values" in child_node.kwargs: - arg = child_node.kwargs["values"] - else: - arg = child_node.args[1] - - elif child_node.op == "call_method" and child_node.target == "get": - # pyre-ignore[6] - arg_info.add_step(ArgInfoStepFactory.get_item(child_node.args[1])) - arg = child_node.args[0] - else: - break - - # if we couldn't hit one of the "decisive" outcomes (constant, placeholder or module), return "not found" - return None - - def _get_node_args_helper( - self, - # pyre-ignore - arguments, - # Add `None` constants to arg info only for postproc modules - # Defaults to False for backward compatibility - for_postproc_module: bool = False, - ) -> Tuple[List[ArgInfo], int]: - """ - Goes through the args/kwargs of a node and arranges them into a list of `ArgInfo`s. - It also counts the number of (args + kwargs) found. - """ - num_found = 0 - arg_info_list = [] - for arg in arguments: - if not for_postproc_module and arg is None: - arg_info = ArgInfo([ArgInfoStepFactory.from_scalar(None)]) - arg_info_list.append(arg_info) - num_found += 1 - continue - arg_info = self._get_node_args_helper_inner( - arg, - for_postproc_module, - ) - if arg_info is not None: - num_found += 1 - arg_info_list.append(arg_info) - return arg_info_list, num_found - - def get_node_args( - self, - node: Node, - for_postproc_module: bool = False, - ) -> Tuple[CallArgs, int]: - pos_arg_info_list, args_found = self._get_node_args_helper( - node.args, - for_postproc_module, - ) - kwargs_arg_info_list, kwargs_found = self._get_node_args_helper( - node.kwargs.values(), - for_postproc_module, - ) - - # Replace with proper names for kwargs - kwargs_info_list = dict(zip(node.kwargs, kwargs_arg_info_list)) - - return CallArgs(pos_arg_info_list, kwargs_info_list), args_found + kwargs_found - - -def _get_leaf_module_names_helper( - model: torch.nn.Module, - path: str, - leaf_module_names: Set[str], -) -> bool: - """ - recursive function returns True if any of the sub-modules is ShardedModule. - it also added the fqns of the sub-modules who do not contain any ShardedModule - into the `leaf_module_names` unless it's marked as `_is_pytorch_fx_traceable = True`, - which suggests this ShardedModule-free module should NOT be treated as a leaf module - """ - sharded_children = set() - for name, child in model.named_children(): - curr_path = path + name - if isinstance(child, ShardedModule): - sharded_children.add(name) - else: - child_sharded = _get_leaf_module_names_helper( - child, - curr_path + ".", - leaf_module_names, - ) - if child_sharded: - sharded_children.add(name) - - # only do this for hybrid module (has sharded child) - if len(sharded_children) > 0: - for name, child in model.named_children(): - if name in sharded_children: - continue - # assume module is leaf node unless annotated otherwise - if not getattr(child, "_is_pytorch_fx_traceable", False): - leaf_module_names.add(path + name) - return len(sharded_children) > 0 - - -def _get_leaf_module_names(model: torch.nn.Module) -> List[str]: - """ - Returns a list of top level modules to be used as leaf modules for FX tracing. - This is a shallow FX trace that only goes the minimum depth required to pipeline. - Any sub-module who does not contain a ShardedModule would be considered as a leaf - module unless explicitly tagged as `_is_pytorch_fx_traceable = True`. - """ - - leaf_module_names: Set[str] = set() - _get_leaf_module_names_helper( - model, - "", - leaf_module_names, - ) - return list(leaf_module_names) - - def _jit_modules(module: torch.nn.Module, path: str, optional: bool = True) -> bool: sharded_children = set() for name, child in module.named_children():