Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 40 additions & 25 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,38 +374,41 @@ def _jagged_tensor_string(
)


def _kjt_to_jt_dict(
def _maybe_compute_kjt_to_jt_dict(
stride: int,
keys: List[str],
length_per_key: List[int],
values: torch.Tensor,
lengths: torch.Tensor,
offsets: torch.Tensor,
weights: Optional[torch.Tensor],
jt_dict: Optional[Dict[str, JaggedTensor]],
) -> Dict[str, JaggedTensor]:
jt_dict: Dict[str, JaggedTensor] = {}
values_list = torch.split(values, length_per_key)
lengths_tuple = torch.unbind(lengths.view(-1, stride), dim=0)
if weights is not None:
weights_list = torch.split(weights, length_per_key)
for idx, key in enumerate(keys):
length = lengths_tuple[idx]
offset = _to_offsets(length)
jt_dict[key] = JaggedTensor(
lengths=length,
offsets=offset,
values=values_list[idx],
weights=weights_list[idx],
)
else:
for idx, key in enumerate(keys):
length = lengths_tuple[idx]
offset = _to_offsets(length)
jt_dict[key] = JaggedTensor(
lengths=length,
offsets=offset,
values=values_list[idx],
)
if jt_dict is None:
_jt_dict: Dict[str, JaggedTensor] = {}
values_list = torch.split(values, length_per_key)
lengths_tuple = torch.unbind(lengths.view(-1, stride), dim=0)
if weights is not None:
weights_list = torch.split(weights, length_per_key)
for idx, key in enumerate(keys):
length = lengths_tuple[idx]
offset = _to_offsets(length)
_jt_dict[key] = JaggedTensor(
lengths=length,
offsets=offset,
values=values_list[idx],
weights=weights_list[idx],
)
else:
for idx, key in enumerate(keys):
length = lengths_tuple[idx]
offset = _to_offsets(length)
_jt_dict[key] = JaggedTensor(
lengths=length,
offsets=offset,
values=values_list[idx],
)
jt_dict = _jt_dict
return jt_dict


Expand Down Expand Up @@ -468,6 +471,7 @@ def __init__(
length_per_key: Optional[List[int]] = None,
offset_per_key: Optional[List[int]] = None,
index_per_key: Optional[Dict[str, int]] = None,
jt_dict: Optional[Dict[str, JaggedTensor]] = None,
) -> None:
self._keys: List[str] = keys
self._values: torch.Tensor = values
Expand All @@ -485,6 +489,7 @@ def __init__(
self._length_per_key: Optional[List[int]] = length_per_key
self._offset_per_key: Optional[List[int]] = offset_per_key
self._index_per_key: Optional[Dict[str, int]] = index_per_key
self._jt_dict: Optional[Dict[str, JaggedTensor]] = jt_dict

@staticmethod
def from_offsets_sync(
Expand Down Expand Up @@ -660,6 +665,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
length_per_key=self._length_per_key,
offset_per_key=self._offset_per_key,
index_per_key=self._index_per_key,
jt_dict=self._jt_dict,
)
)
elif segment == 0:
Expand All @@ -682,6 +688,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
length_per_key=None,
offset_per_key=None,
index_per_key=None,
jt_dict=None,
)
)
else:
Expand All @@ -701,6 +708,7 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
length_per_key=split_length_per_key,
offset_per_key=None,
index_per_key=None,
jt_dict=None,
)
)
start = end
Expand Down Expand Up @@ -751,6 +759,7 @@ def permute(
length_per_key=permuted_length_per_key if len(permuted_keys) > 0 else None,
offset_per_key=None,
index_per_key=None,
jt_dict=None,
)
return kjt

Expand All @@ -769,15 +778,18 @@ def __getitem__(self, key: str) -> JaggedTensor:
)

def to_dict(self) -> Dict[str, JaggedTensor]:
return _kjt_to_jt_dict(
_jt_dict = _maybe_compute_kjt_to_jt_dict(
self.stride(),
self.keys(),
self.length_per_key(),
self.values(),
self.lengths(),
self.offsets(),
self.weights_or_none(),
self._jt_dict,
)
self._jt_dict = _jt_dict
return _jt_dict

# pyre-ignore [56]
@torch.jit.unused
Expand All @@ -802,6 +814,7 @@ def to(
length_per_key = self._length_per_key
offset_per_key = self._offset_per_key
index_per_key = self._index_per_key
jt_dict = self._jt_dict

return KeyedJaggedTensor(
keys=self._keys,
Expand All @@ -819,6 +832,7 @@ def to(
length_per_key=length_per_key,
offset_per_key=offset_per_key,
index_per_key=index_per_key,
jt_dict=jt_dict,
)

def __str__(self) -> str:
Expand Down Expand Up @@ -861,6 +875,7 @@ def pin_memory(self) -> "KeyedJaggedTensor":
length_per_key=self._length_per_key,
offset_per_key=self._offset_per_key,
index_per_key=self._index_per_key,
jt_dict=None,
)


Expand Down