From 090a46a2c656b72083bac579e7fb14960950e1a0 Mon Sep 17 00:00:00 2001 From: James Dong Date: Thu, 19 Jun 2025 20:07:56 -0700 Subject: [PATCH] Update docstrings for stride_per_key_per_rank Summary: Created from CodeHub with https://fburl.com/edit-in-codehub Differential Revision: D76999897 --- torchrec/sparse/jagged_tensor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 9cda3f9dd..14d6577d9 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -1679,9 +1679,9 @@ class KeyedJaggedTensor(Pipelineable, metaclass=JaggedTensorMeta): offsets (Optional[torch.Tensor]): jagged slices, represented as cumulative offsets. stride (Optional[int]): number of examples per batch. - stride_per_key_per_rank (Optional[List[List[int]]]): batch size - (number of examples) per key per rank, with the outer list representing the - keys and the inner list representing the values. + stride_per_key_per_rank (Optional[Union[torch.IntTensor, List[List[int]]]]): + batch size (number of examples) per key per rank, with the outer list + representing the keys and the inner list representing the values. Each value in the inner list represents the number of examples in the batch from the rank of its index in a distributed context. length_per_key (Optional[List[int]]): start length for each key.