From 08e4f7b9a13560de9a919dc5cfd61f3b4987a1f8 Mon Sep 17 00:00:00 2001 From: Huanyu He Date: Wed, 18 Jun 2025 15:47:11 -0700 Subject: [PATCH 1/2] add stride into KJT pytree (#2587) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2587 # context * Previously, for a KJT, only the following fields and `_keys` were stored in the pytree flatten specs. All other arguments/parameters were derived accordingly. ``` _fields = [ "_values", "_weights", "_lengths", "_offsets", ] ``` * In particular, the `stride` (int) of a KJT, which represents the `batch_size`, is computed by `_maybe_compute_stride_kjt`: ``` def _maybe_compute_stride_kjt( keys: List[str], stride: Optional[int], lengths: Optional[torch.Tensor], offsets: Optional[torch.Tensor], stride_per_key_per_rank: Optional[List[List[int]]], ) -> int: if stride is None: if len(keys) == 0: stride = 0 elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0: stride = max([sum(s) for s in stride_per_key_per_rank]) elif offsets is not None and offsets.numel() > 0: stride = (offsets.numel() - 1) // len(keys) elif lengths is not None: stride = lengths.numel() // len(keys) else: stride = 0 return stride ``` * The previously stored pytree flatten specs are sufficient when the `batch_size` is static; however, this no longer holds in a variable batch size scenario, where the `stride_per_key_per_rank` is not `None`. * One example is `dedup_ebc`, where the actual batch_size is variable (depending on the dedup data), but the output of the ebc should always be the **true** `stride` (static). * During ir_export, the output shape will be calculated from the `kjt.stride()` function, which would be incorrect if the pytree specs only contain the `keys`. * This diff adds the `stride` into the KJT pytree flatten/unflatten functions so that a fakified KJT would have the correct stride value. 
Differential Revision: D66400821 Reviewed By: PaulZhang12 --- .github/scripts/install_libs.sh | 2 +- requirements.txt | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/scripts/install_libs.sh b/.github/scripts/install_libs.sh index 27522ff92..193aaf46d 100644 --- a/.github/scripts/install_libs.sh +++ b/.github/scripts/install_libs.sh @@ -28,4 +28,4 @@ elif [ "$CHANNEL" = "test" ]; then fi -${CONDA_RUN} pip install importlib-metadata +${CONDA_RUN} pip install importlib-metadata click PyYAML diff --git a/requirements.txt b/requirements.txt index 6b17aeac6..6239d0d90 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,9 @@ black +click cmake fbgemm-gpu hypothesis==6.70.1 +importlib-metadata iopath numpy pandas @@ -13,6 +15,7 @@ torchx tqdm usort parameterized +PyYAML # for tests # https://github.com/pytorch/pytorch/blob/b96b1e8cff029bb0a73283e6e7f6cc240313f1dc/requirements.txt#L3 From 696d332d9fde50daeab21886373bcec5e4e3eb39 Mon Sep 17 00:00:00 2001 From: Huanyu He Date: Wed, 18 Jun 2025 16:03:36 -0700 Subject: [PATCH 2/2] fix stride_per_key_per_rank in stagger scenario in D74366343 (#3111) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/3111 # context * original diff D74366343 broke cogwheel test and was reverted * the error stack P1844048578 is shown below: ``` File "/dev/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/dev/torch/nn/modules/module.py", line 1784, in _call_impl return forward_call(*args, **kwargs) File "/dev/torchrec/distributed/train_pipeline/runtime_forwards.py", line 84, in __call__ data = request.wait() File "/dev/torchrec/distributed/types.py", line 334, in wait ret: W = self._wait_impl() File "/dev/torchrec/distributed/embedding_sharding.py", line 655, in _wait_impl kjts.append(w.wait()) File "/dev/torchrec/distributed/types.py", line 334, in wait ret: W = self._wait_impl() File "/dev/torchrec/distributed/dist_data.py", 
line 426, in _wait_impl return type(self._input).dist_init( File "/dev/torchrec/sparse/jagged_tensor.py", line 2993, in dist_init return kjt.sync() File "/dev/torchrec/sparse/jagged_tensor.py", line 2067, in sync self.length_per_key() File "/dev/torchrec/sparse/jagged_tensor.py", line 2281, in length_per_key _length_per_key = _maybe_compute_length_per_key( File "/dev/torchrec/sparse/jagged_tensor.py", line 1192, in _maybe_compute_length_per_key _length_per_key_from_stride_per_key(lengths, stride_per_key) File "/dev/torchrec/sparse/jagged_tensor.py", line 1144, in _length_per_key_from_stride_per_key if _use_segment_sum_csr(stride_per_key): File "/dev/torchrec/sparse/jagged_tensor.py", line 1131, in _use_segment_sum_csr elements_per_segment = sum(stride_per_key) / len(stride_per_key) ZeroDivisionError: division by zero ``` * the complaint is `stride_per_key` is an empty list, which comes from the following function call: ``` stride_per_key = _maybe_compute_stride_per_key( self._stride_per_key, self._stride_per_key_per_rank, self.stride(), self._keys, ) ``` * the only place this `stride_per_key` could be empty is when the `stride_per_key_per_rank.dim() != 2` ``` def _maybe_compute_stride_per_key( stride_per_key: Optional[List[int]], stride_per_key_per_rank: Optional[torch.IntTensor], stride: Optional[int], keys: List[str], ) -> Optional[List[int]]: if stride_per_key is not None: return stride_per_key elif stride_per_key_per_rank is not None: if stride_per_key_per_rank.dim() != 2: # after permute the kjt could be empty return [] rt: List[int] = stride_per_key_per_rank.sum(dim=1).tolist() if not torch.jit.is_scripting() and is_torchdynamo_compiling(): pt2_checks_all_is_size(rt) return rt elif stride is not None: return [stride] * len(keys) else: return None ``` # the main change from D74366343 is that the `stride_per_key_per_rank` in `dist_init`: * baseline ``` if stagger > 1: stride_per_key_per_rank_stagger: List[List[int]] = [] local_world_size = num_workers // 
stagger for i in range(len(keys)): stride_per_rank_stagger: List[int] = [] for j in range(local_world_size): stride_per_rank_stagger.extend( stride_per_key_per_rank[i][j::local_world_size] ) stride_per_key_per_rank_stagger.append(stride_per_rank_stagger) stride_per_key_per_rank = stride_per_key_per_rank_stagger ``` * D76875546 (correct, this diff) ``` if stagger > 1: indices = torch.arange(num_workers).view(stagger, -1).T.reshape(-1) stride_per_key_per_rank = stride_per_key_per_rank[:, indices] ``` * D74366343 (incorrect, reverted) ``` if stagger > 1: local_world_size = num_workers // stagger indices = [ list(range(i, num_workers, local_world_size)) for i in range(local_world_size) ] stride_per_key_per_rank = stride_per_key_per_rank[:, indices] ``` Differential Revision: D76903646 --- torchrec/pt2/utils.py | 4 +- .../api_tests/test_jagged_tensor_schema.py | 6 +- torchrec/sparse/jagged_tensor.py | 121 +++++++++++------- 3 files changed, 83 insertions(+), 48 deletions(-) diff --git a/torchrec/pt2/utils.py b/torchrec/pt2/utils.py index 55accff68..44af5ae1f 100644 --- a/torchrec/pt2/utils.py +++ b/torchrec/pt2/utils.py @@ -54,7 +54,7 @@ def kjt_for_pt2_tracing( values=values, lengths=lengths, weights=kjt.weights_or_none(), - stride_per_key_per_rank=[[stride]] * n, + stride_per_key_per_rank=torch.IntTensor([[stride]] * n, device="cpu"), inverse_indices=(kjt.keys(), inverse_indices_tensor), ) @@ -85,7 +85,7 @@ def kjt_for_pt2_tracing( lengths=lengths, weights=weights, stride=stride if not is_vb else None, - stride_per_key_per_rank=kjt.stride_per_key_per_rank() if is_vb else None, + stride_per_key_per_rank=kjt._stride_per_key_per_rank if is_vb else None, inverse_indices=inverse_indices, ) diff --git a/torchrec/schema/api_tests/test_jagged_tensor_schema.py b/torchrec/schema/api_tests/test_jagged_tensor_schema.py index eacb10d9e..d51368b12 100644 --- a/torchrec/schema/api_tests/test_jagged_tensor_schema.py +++ b/torchrec/schema/api_tests/test_jagged_tensor_schema.py @@ -9,7 
+9,7 @@ import inspect import unittest -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import torch from torchrec.schema.utils import is_signature_compatible @@ -112,7 +112,9 @@ def __init__( lengths: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, stride: Optional[int] = None, - stride_per_key_per_rank: Optional[List[List[int]]] = None, + stride_per_key_per_rank: Optional[ + Union[List[List[int]], torch.IntTensor] + ] = None, # Below exposed to ensure torch.script-able stride_per_key: Optional[List[int]] = None, length_per_key: Optional[List[int]] = None, diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index ba5fdd470..9cda3f9dd 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -1094,13 +1094,15 @@ def _maybe_compute_stride_kjt( stride: Optional[int], lengths: Optional[torch.Tensor], offsets: Optional[torch.Tensor], - stride_per_key_per_rank: Optional[List[List[int]]], + stride_per_key_per_rank: Optional[torch.IntTensor], ) -> int: if stride is None: if len(keys) == 0: stride = 0 - elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0: - stride = max([sum(s) for s in stride_per_key_per_rank]) + elif ( + stride_per_key_per_rank is not None and stride_per_key_per_rank.numel() > 0 + ): + stride = int(stride_per_key_per_rank.sum(dim=1).max().item()) elif offsets is not None and offsets.numel() > 0: stride = (offsets.numel() - 1) // len(keys) elif lengths is not None: @@ -1452,8 +1454,8 @@ def _maybe_compute_kjt_to_jt_dict( def _kjt_empty_like(kjt: "KeyedJaggedTensor") -> "KeyedJaggedTensor": # empty like function fx wrapped, also avoids device hardcoding stride, stride_per_key_per_rank = ( - (None, kjt.stride_per_key_per_rank()) - if kjt.variable_stride_per_key() + (None, kjt._stride_per_key_per_rank) + if kjt._stride_per_key_per_rank is not None and kjt.variable_stride_per_key() else 
(kjt.stride(), None) ) @@ -1639,14 +1641,20 @@ def _maybe_compute_lengths_offset_per_key( def _maybe_compute_stride_per_key( stride_per_key: Optional[List[int]], - stride_per_key_per_rank: Optional[List[List[int]]], + stride_per_key_per_rank: Optional[torch.IntTensor], stride: Optional[int], keys: List[str], ) -> Optional[List[int]]: if stride_per_key is not None: return stride_per_key elif stride_per_key_per_rank is not None: - return [sum(s) for s in stride_per_key_per_rank] + if stride_per_key_per_rank.dim() != 2: + # after permute the kjt could be empty + return [] + rt: List[int] = stride_per_key_per_rank.sum(dim=1).tolist() + if not torch.jit.is_scripting() and is_torchdynamo_compiling(): + pt2_checks_all_is_size(rt) + return rt elif stride is not None: return [stride] * len(keys) else: @@ -1725,7 +1733,9 @@ def __init__( lengths: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, stride: Optional[int] = None, - stride_per_key_per_rank: Optional[List[List[int]]] = None, + stride_per_key_per_rank: Optional[ + Union[torch.IntTensor, List[List[int]]] + ] = None, # Below exposed to ensure torch.script-able stride_per_key: Optional[List[int]] = None, length_per_key: Optional[List[int]] = None, @@ -1747,8 +1757,14 @@ def __init__( self._lengths: Optional[torch.Tensor] = lengths self._offsets: Optional[torch.Tensor] = offsets self._stride: Optional[int] = stride - self._stride_per_key_per_rank: Optional[List[List[int]]] = ( - stride_per_key_per_rank + if not torch.jit.is_scripting() and is_torchdynamo_compiling(): + # in pt2.compile the stride_per_key_per_rank has to be torch.Tensor or None + # does not take List[List[int]] + assert not isinstance(stride_per_key_per_rank, list) + self._stride_per_key_per_rank: Optional[torch.IntTensor] = ( + torch.IntTensor(stride_per_key_per_rank, device="cpu") + if isinstance(stride_per_key_per_rank, list) + else stride_per_key_per_rank ) self._stride_per_key: Optional[List[int]] = stride_per_key 
self._length_per_key: Optional[List[int]] = length_per_key @@ -1759,6 +1775,8 @@ def __init__( self._inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = ( inverse_indices ) + # this is only needed for torch.compile case + self._pt2_stride_per_key_per_rank: Optional[List[List[int]]] = None # legacy attribute, for backward compatabilibity self._variable_stride_per_key: Optional[bool] = None @@ -1774,10 +1792,6 @@ def _init_pt2_checks(self) -> None: return if self._stride_per_key is not None: pt2_checks_all_is_size(self._stride_per_key) - if self._stride_per_key_per_rank is not None: - # pyre-ignore [16] - for s in self._stride_per_key_per_rank: - pt2_checks_all_is_size(s) @staticmethod def from_offsets_sync( @@ -1987,7 +2001,7 @@ def from_jt_dict(jt_dict: Dict[str, JaggedTensor]) -> "KeyedJaggedTensor": kjt_stride, kjt_stride_per_key_per_rank = ( (stride_per_key[0], None) if all(s == stride_per_key[0] for s in stride_per_key) - else (None, [[stride] for stride in stride_per_key]) + else (None, torch.IntTensor(stride_per_key, device="cpu").reshape(-1, 1)) ) kjt = KeyedJaggedTensor( keys=kjt_keys, @@ -2152,12 +2166,32 @@ def stride_per_key_per_rank(self) -> List[List[int]]: Returns: List[List[int]]: stride per key per rank of the KeyedJaggedTensor. 
""" - stride_per_key_per_rank = self._stride_per_key_per_rank - return stride_per_key_per_rank if stride_per_key_per_rank is not None else [] + # making a local reference to the class variable to make jit.script behave + _stride_per_key_per_rank = self._stride_per_key_per_rank + if ( + not torch.jit.is_scripting() + and is_torchdynamo_compiling() + and _stride_per_key_per_rank is not None + ): + if self._pt2_stride_per_key_per_rank is not None: + return self._pt2_stride_per_key_per_rank + stride_per_key_per_rank = _stride_per_key_per_rank.tolist() + for stride_per_rank in stride_per_key_per_rank: + pt2_checks_all_is_size(stride_per_rank) + self._pt2_stride_per_key_per_rank = stride_per_key_per_rank + return stride_per_key_per_rank + return ( + [] + if _stride_per_key_per_rank is None + else _stride_per_key_per_rank.tolist() + ) def variable_stride_per_key(self) -> bool: """ Returns whether the KeyedJaggedTensor has variable stride per key. + NOTE: `self._variable_stride_per_key` could be `False` when `self._stride_per_key_per_rank` + is not `None`. It might be assigned to False externally/intentionally, usually the + `self._stride_per_key_per_rank` is trivial. Returns: bool: whether the KeyedJaggedTensor has variable stride per key. 
@@ -2302,13 +2336,16 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]: start_offset = 0 _length_per_key = self.length_per_key() _offset_per_key = self.offset_per_key() + # use local copy/ref for self._stride_per_key_per_rank to satisfy jit.script + _stride_per_key_per_rank = self._stride_per_key_per_rank for segment in segments: end = start + segment end_offset = _offset_per_key[end] keys: List[str] = self._keys[start:end] stride_per_key_per_rank = ( - self.stride_per_key_per_rank()[start:end] + _stride_per_key_per_rank[start:end, :] if self.variable_stride_per_key() + and _stride_per_key_per_rank is not None else None ) if segment == len(self._keys): @@ -2456,17 +2493,24 @@ def permute( length_per_key = self.length_per_key() permuted_keys: List[str] = [] - permuted_stride_per_key_per_rank: List[List[int]] = [] permuted_length_per_key: List[int] = [] permuted_length_per_key_sum = 0 for index in indices: key = self.keys()[index] permuted_keys.append(key) permuted_length_per_key.append(length_per_key[index]) - if self.variable_stride_per_key(): - permuted_stride_per_key_per_rank.append( - self.stride_per_key_per_rank()[index] - ) + + stride_per_key = self._stride_per_key + permuted_stride_per_key = ( + [stride_per_key[i] for i in indices] if stride_per_key is not None else None + ) + + _stride_per_key_per_rank = self._stride_per_key_per_rank + permuted_stride_per_key_per_rank = ( + _stride_per_key_per_rank[indices, :] + if self.variable_stride_per_key() and _stride_per_key_per_rank is not None + else None + ) permuted_length_per_key_sum = sum(permuted_length_per_key) if not torch.jit.is_scripting() and is_non_strict_exporting(): @@ -2518,9 +2562,7 @@ def permute( self.weights_or_none(), permuted_length_per_key_sum, ) - stride_per_key_per_rank = ( - permuted_stride_per_key_per_rank if self.variable_stride_per_key() else None - ) + kjt = KeyedJaggedTensor( keys=permuted_keys, values=permuted_values, @@ -2528,8 +2570,8 @@ def permute( 
lengths=permuted_lengths.view(-1), offsets=None, stride=self._stride, - stride_per_key_per_rank=stride_per_key_per_rank, - stride_per_key=None, + stride_per_key_per_rank=permuted_stride_per_key_per_rank, + stride_per_key=permuted_stride_per_key, length_per_key=permuted_length_per_key if len(permuted_keys) > 0 else None, lengths_offset_per_key=None, offset_per_key=None, @@ -2848,7 +2890,7 @@ def dist_init( if variable_stride_per_key: assert stride_per_rank_per_key is not None - stride_per_key_per_rank_tensor: torch.Tensor = stride_per_rank_per_key.view( + stride_per_key_per_rank: torch.Tensor = stride_per_rank_per_key.view( num_workers, len(keys) ).T.cpu() @@ -2885,23 +2927,14 @@ def dist_init( weights, ) - stride_per_key_per_rank = torch.jit.annotate( - List[List[int]], stride_per_key_per_rank_tensor.tolist() - ) + if stride_per_key_per_rank.numel() == 0: + stride_per_key_per_rank = torch.zeros( + (len(keys), 1), device="cpu", dtype=torch.int64 + ) - if not stride_per_key_per_rank: - stride_per_key_per_rank = [[0]] * len(keys) if stagger > 1: - stride_per_key_per_rank_stagger: List[List[int]] = [] - local_world_size = num_workers // stagger - for i in range(len(keys)): - stride_per_rank_stagger: List[int] = [] - for j in range(local_world_size): - stride_per_rank_stagger.extend( - stride_per_key_per_rank[i][j::local_world_size] - ) - stride_per_key_per_rank_stagger.append(stride_per_rank_stagger) - stride_per_key_per_rank = stride_per_key_per_rank_stagger + indices = torch.arange(num_workers).view(stagger, -1).T.reshape(-1) + stride_per_key_per_rank = stride_per_key_per_rank[:, indices] kjt = KeyedJaggedTensor( keys=keys,