diff --git a/torchrec/distributed/comm_ops.py b/torchrec/distributed/comm_ops.py index 53e0509a7..5daa4d0c9 100644 --- a/torchrec/distributed/comm_ops.py +++ b/torchrec/distributed/comm_ops.py @@ -619,7 +619,7 @@ def forward( permuted_lengths_after_sparse_data_all2all, sharded_input_embeddings, _, - ) = torch.ops.fbgemm.permute_sparse_data( + ) = torch.ops.fbgemm.permute_2D_sparse_data( forward_recat_tensor, lengths_after_sparse_data_all2all.view(local_T * world_size, -1), sharded_input_embeddings.view(-1), @@ -676,7 +676,7 @@ def backward(ctx, *unused) -> Tuple[None, None, None, Tensor]: if permuted_lengths_after_sparse_data_all2all is not None: with record_function("## alltoall_seq_embedding_bwd_permute ##"): - _, sharded_grad_input, _ = torch.ops.fbgemm.permute_sparse_data( + _, sharded_grad_input, _ = torch.ops.fbgemm.permute_2D_sparse_data( backward_recat_tensor, permuted_lengths_after_sparse_data_all2all, sharded_grad_input, diff --git a/torchrec/distributed/dist_data.py b/torchrec/distributed/dist_data.py index c6428f3bd..46d5926c5 100644 --- a/torchrec/distributed/dist_data.py +++ b/torchrec/distributed/dist_data.py @@ -195,7 +195,7 @@ def _wait_impl(self) -> KeyedJaggedTensor: with record_function("## all2all_data:recat_values ##"): if self._recat.numel(): - lengths, values, weights = torch.ops.fbgemm.permute_sparse_data( + lengths, values, weights = torch.ops.fbgemm.permute_2D_sparse_data( self._recat, lengths.view(self._workers * self._splits[self._pg.rank()], -1), values, diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 1cb0a0fed..d5699b8a2 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -742,7 +742,7 @@ def permute( permuted_lengths, permuted_values, permuted_weights, - ) = torch.ops.fbgemm.permute_sparse_data( + ) = torch.ops.fbgemm.permute_2D_sparse_data( indices_tensor, self.lengths().view(len(self._keys), -1), self.values(),