From 660663ba94937c7bf557697f8ff42d9fa75ead2b Mon Sep 17 00:00:00 2001
From: Xing Liu
Date: Mon, 13 Dec 2021 23:03:03 -0800
Subject: [PATCH] Make to_dict(...) lazy (#28)

Summary:
Pull Request resolved: https://github.com/facebookresearch/torchrec/pull/28

Make to_dict(...) lazy, since it is expensive and might be called multiple
times.

Reviewed By: colin2328, bigning

Differential Revision: D33053563

fbshipit-source-id: 56a197a629b6c148f5f7aeed6d635bc29ea5a010
---
 torchrec/sparse/jagged_tensor.py | 65 ++++++++++++++++++++------------
 1 file changed, 40 insertions(+), 25 deletions(-)

diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index 0c28fa8d2..579d5e3a3 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -374,7 +374,7 @@ def _jagged_tensor_string(
     )


-def _kjt_to_jt_dict(
+def _maybe_compute_kjt_to_jt_dict(
     stride: int,
     keys: List[str],
     length_per_key: List[int],
@@ -382,30 +382,33 @@ def _kjt_to_jt_dict(
     lengths: torch.Tensor,
     offsets: torch.Tensor,
     weights: Optional[torch.Tensor],
+    jt_dict: Optional[Dict[str, JaggedTensor]],
 ) -> Dict[str, JaggedTensor]:
-    jt_dict: Dict[str, JaggedTensor] = {}
-    values_list = torch.split(values, length_per_key)
-    lengths_tuple = torch.unbind(lengths.view(-1, stride), dim=0)
-    if weights is not None:
-        weights_list = torch.split(weights, length_per_key)
-        for idx, key in enumerate(keys):
-            length = lengths_tuple[idx]
-            offset = _to_offsets(length)
-            jt_dict[key] = JaggedTensor(
-                lengths=length,
-                offsets=offset,
-                values=values_list[idx],
-                weights=weights_list[idx],
-            )
-    else:
-        for idx, key in enumerate(keys):
-            length = lengths_tuple[idx]
-            offset = _to_offsets(length)
-            jt_dict[key] = JaggedTensor(
-                lengths=length,
-                offsets=offset,
-                values=values_list[idx],
-            )
+    if jt_dict is None:
+        _jt_dict: Dict[str, JaggedTensor] = {}
+        values_list = torch.split(values, length_per_key)
+        lengths_tuple = torch.unbind(lengths.view(-1, stride), dim=0)
+        if weights is not None:
+            weights_list = torch.split(weights, length_per_key)
+            for idx, key in enumerate(keys):
+                length = lengths_tuple[idx]
+                offset = _to_offsets(length)
+                _jt_dict[key] = JaggedTensor(
+                    lengths=length,
+                    offsets=offset,
+                    values=values_list[idx],
+                    weights=weights_list[idx],
+                )
+        else:
+            for idx, key in enumerate(keys):
+                length = lengths_tuple[idx]
+                offset = _to_offsets(length)
+                _jt_dict[key] = JaggedTensor(
+                    lengths=length,
+                    offsets=offset,
+                    values=values_list[idx],
+                )
+        jt_dict = _jt_dict
     return jt_dict


@@ -468,6 +471,7 @@ def __init__(
         length_per_key: Optional[List[int]] = None,
         offset_per_key: Optional[List[int]] = None,
         index_per_key: Optional[Dict[str, int]] = None,
+        jt_dict: Optional[Dict[str, JaggedTensor]] = None,
     ) -> None:
         self._keys: List[str] = keys
         self._values: torch.Tensor = values
@@ -485,6 +489,7 @@ def __init__(
         self._length_per_key: Optional[List[int]] = length_per_key
         self._offset_per_key: Optional[List[int]] = offset_per_key
         self._index_per_key: Optional[Dict[str, int]] = index_per_key
+        self._jt_dict: Optional[Dict[str, JaggedTensor]] = jt_dict

     @staticmethod
     def from_offsets_sync(
@@ -660,6 +665,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
                         length_per_key=self._length_per_key,
                         offset_per_key=self._offset_per_key,
                         index_per_key=self._index_per_key,
+                        jt_dict=self._jt_dict,
                     )
                 )
             elif segment == 0:
@@ -682,6 +688,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
                         length_per_key=None,
                         offset_per_key=None,
                         index_per_key=None,
+                        jt_dict=None,
                     )
                 )
             else:
@@ -701,6 +708,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
                         length_per_key=split_length_per_key,
                         offset_per_key=None,
                         index_per_key=None,
+                        jt_dict=None,
                     )
                 )
             start = end
@@ -751,6 +759,7 @@ def permute(
             length_per_key=permuted_length_per_key if len(permuted_keys) > 0 else None,
             offset_per_key=None,
             index_per_key=None,
+            jt_dict=None,
         )
         return kjt

@@ -769,7 +778,7 @@ def __getitem__(self, key: str) -> JaggedTensor:
         )

     def to_dict(self) -> Dict[str, JaggedTensor]:
-        return _kjt_to_jt_dict(
+        _jt_dict = _maybe_compute_kjt_to_jt_dict(
             self.stride(),
             self.keys(),
             self.length_per_key(),
@@ -777,7 +786,10 @@ def to_dict(self) -> Dict[str, JaggedTensor]:
             self.lengths(),
             self.offsets(),
             self.weights_or_none(),
+            self._jt_dict,
         )
+        self._jt_dict = _jt_dict
+        return _jt_dict

     # pyre-ignore [56]
     @torch.jit.unused
@@ -802,6 +814,7 @@ def to(
         length_per_key = self._length_per_key
         offset_per_key = self._offset_per_key
         index_per_key = self._index_per_key
+        jt_dict = self._jt_dict

         return KeyedJaggedTensor(
             keys=self._keys,
@@ -819,6 +832,7 @@ def to(
             length_per_key=length_per_key,
             offset_per_key=offset_per_key,
             index_per_key=index_per_key,
+            jt_dict=jt_dict,
         )

     def __str__(self) -> str:
@@ -861,6 +875,7 @@ def pin_memory(self) -> "KeyedJaggedTensor":
             length_per_key=self._length_per_key,
             offset_per_key=self._offset_per_key,
             index_per_key=self._index_per_key,
+            jt_dict=None,
         )
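Usage note (not part of the patch): after this change, the first to_dict() call still splits the values and builds the Dict[str, JaggedTensor], but the result is memoized on self._jt_dict, so repeated calls return the same dict object. A minimal sketch, assuming the from_offsets_sync constructor visible in the hunks above; the feature names and tensor values are illustrative only:

    import torch
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    # Two features over a batch of 2: offsets delimit 4 bags within 5 values,
    # so "f1" gets bags [0:2] and [2:2], and "f2" gets bags [2:3] and [3:5].
    kjt = KeyedJaggedTensor.from_offsets_sync(
        keys=["f1", "f2"],
        values=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]),
        offsets=torch.tensor([0, 2, 2, 3, 5]),
    )

    d1 = kjt.to_dict()  # first call: torch.split + JaggedTensor construction
    d2 = kjt.to_dict()  # second call: returns the cached self._jt_dict
    assert d1 is d2  # same object, nothing recomputed

Threading the cached value through _maybe_compute_kjt_to_jt_dict(...) and writing it back, rather than branching inside to_dict() itself, keeps the expensive path in a plain function, matching the _maybe_compute_* naming convention the rename adopts. Note that split() propagates the cache to a full-width split, while permute() and pin_memory() construct new tensors and therefore pass jt_dict=None.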