Skip to content

Commit

Permalink
Support randn_like() for NT (#96528)
Browse files Browse the repository at this point in the history
To satisfy an internal ask.
Pull Request resolved: pytorch/pytorch#96528
Approved by: https://github.com/mikaylagawarecki, https://github.com/cpuhrsch
  • Loading branch information
jbschlosser authored and cyyever committed Mar 27, 2023
1 parent 650985c commit 635ee16
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 2 deletions.
3 changes: 2 additions & 1 deletion aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4439,7 +4439,7 @@
dispatch:
# NB: Although this composite mutates on the inside, it is
# non-differentiable so NonFunctional doesn't apply
CompositeExplicitAutograd: randn_like
CompositeExplicitAutograd, CompositeImplicitAutogradNestedTensor: randn_like
autogen: randn_like.out

- func: randperm(int n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
Expand Down Expand Up @@ -9645,6 +9645,7 @@
MPS: normal_mps_
Meta: normal_meta_
SparseCsrCPU, SparseCsrCUDA: normal_sparse_csr_
NestedTensorCPU, NestedTensorCUDA: normal_nested_
autogen: normal.out

# Only used by the functionalization pass.
Expand Down
6 changes: 6 additions & 0 deletions aten/src/ATen/native/nested/NestedTensorMath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -970,5 +970,11 @@ Tensor reshape_as_nested(const Tensor& self, const Tensor& other) {
return self.reshape(sizes);
}

Tensor& normal_nested_(Tensor& self, double mean, double std, c10::optional<Generator> gen) {
const auto& self_buf = get_nested_tensor_impl(self)->get_buffer();
self_buf.normal_(mean, std, gen);
return self;
}

} // namespace native
} // namespace at
1 change: 1 addition & 0 deletions docs/source/nested.rst
Original file line number Diff line number Diff line change
Expand Up @@ -213,4 +213,5 @@ NestedTensor and any constraints they have.
:func:`torch.transpose`; "Supports transposing of all dims except ``dim=0``."
:func:`torch.Tensor.view`; "Rules for the new shape are similar to that of ``reshape``."
:func:`torch.empty_like`; "Behavior is analogous to that of regular tensors; returns a new empty nested tensor (i.e. with uninitialized values) matching the nested structure of the input."
:func:`torch.randn_like`; "Behavior is analogous to that of regular tensors; returns a new nested tensor with values randomly initialized according to a standard normal distribution matching the nested structure of the input."
:func:`torch.zeros_like`; "Behavior is analogous to that of regular tensors; returns a new nested tensor with all zero values matching the nested structure of the input."
4 changes: 3 additions & 1 deletion test/test_nestedtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,13 +499,15 @@ def test_zero_(self):
t.fill_(0.)
self.assertEqual(nt_ub, t)

@parametrize("func", [torch.ones_like, torch.zeros_like],
@parametrize("func", [torch.ones_like, torch.zeros_like, torch.randn_like],
name_fn=lambda f: f.__name__)
def test_like_functions(self, func):
ntensors = 4
nt = random_nt(torch.device('cpu'), torch.float32, ntensors, (4, 4))
torch.manual_seed(1)
nt_like = func(nt)

torch.manual_seed(1)
for nt_ub in nt_like.unbind():
t_like = func(nt_ub)
self.assertEqual(nt_ub, t_like)
Expand Down

0 comments on commit 635ee16

Please sign in to comment.