Skip to content

Commit 02039df

Browse files
jd7-trfacebook-github-bot
authored andcommitted
Add sanity check for stride_per_key_per_rank's dim (#3124)
Summary: Pull Request resolved: #3124 `stride_per_key_per_rank` can only be a 2D tensor or 2D list. Created from CodeHub with https://fburl.com/edit-in-codehub Reviewed By: TroyGarden Differential Revision: D77052836 fbshipit-source-id: c06b94c7c8d0999276b5d626feaf434222586860
1 parent 3169481 commit 02039df

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1761,6 +1761,12 @@ def __init__(
17611761
# in pt2.compile the stride_per_key_per_rank has to be torch.Tensor or None
17621762
# does not take List[List[int]]
17631763
assert not isinstance(stride_per_key_per_rank, list)
1764+
1765+
if isinstance(stride_per_key_per_rank, torch.IntTensor):
1766+
assert (
1767+
stride_per_key_per_rank.dim() == 2
1768+
), f"Expect 2D tensor with shape [len(keys), len(ranks)] for stride_per_key_per_rank, but got tensor with shape: {stride_per_key_per_rank.shape}"
1769+
17641770
self._stride_per_key_per_rank: Optional[torch.IntTensor] = (
17651771
torch.IntTensor(stride_per_key_per_rank, device="cpu")
17661772
if isinstance(stride_per_key_per_rank, list)

0 commit comments

Comments
 (0)