[MPS] add `aten::normal.Tensor_float` `aten::normal.float_Tensor` `aten::normal.Tensor_Tensor` (pytorch#80297)

Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch#80297
Approved by: https://github.com/albanD, https://github.com/kulinseth
qqaatw authored and kulinseth committed Jul 9, 2022
1 parent 43bacb4 commit 6de26fa
Showing 3 changed files with 71 additions and 9 deletions.
56 changes: 56 additions & 0 deletions aten/src/ATen/native/mps/operations/Distributions.mm
@@ -127,6 +127,62 @@
  return normal_mps_out(mean_t, std_t, gen, self);
}

Tensor normal_mps(const Tensor& mean, double std, c10::optional<Generator> gen) {
  Tensor output = empty_mps(
      mean.sizes(),
      mean.scalar_type(),
      c10::nullopt,
      kMPS,
      c10::nullopt,
      c10::nullopt);

  Tensor std_t = empty_mps(
      output.sizes(),
      output.scalar_type(),
      c10::nullopt,
      kMPS,
      c10::nullopt,
      c10::nullopt);
  std_t.fill_(std);

  return normal_mps_out(mean, std_t, gen, output);
}

Tensor normal_mps(double mean, const Tensor& std, c10::optional<Generator> gen) {
  Tensor output = empty_mps(
      std.sizes(),
      std.scalar_type(),
      c10::nullopt,
      kMPS,
      c10::nullopt,
      c10::nullopt);

  Tensor mean_t = empty_mps(
      output.sizes(),
      output.scalar_type(),
      c10::nullopt,
      kMPS,
      c10::nullopt,
      c10::nullopt);
  mean_t.fill_(mean);

  return normal_mps_out(mean_t, std, gen, output);
}

Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional<Generator> gen) {
  auto shape = at::infer_size(mean.sizes(), std.sizes());

  Tensor output = empty_mps(
      shape,
      mean.scalar_type(),
      c10::nullopt,
      kMPS,
      c10::nullopt,
      c10::nullopt);

  return normal_mps_out(mean, std, gen, output);
}

Tensor& normal_mps_out(const Tensor& mean, double std, c10::optional<Generator> gen, Tensor& output) {
  TORCH_CHECK(std >= 0.0, "normal_mps_out expects std >= 0.0, but found std=", std);
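
As a rough illustration (not part of the commit, and assuming a machine where MPS is available), the three overloads above correspond to the following Python calls; the shapes, values, and variable names here are purely hypothetical:

import torch

# Illustrative sketch: each call below now routes to one of the new
# MPS kernels instead of raising NotImplementedError on 'mps'.
mean = torch.zeros(3, 4, device="mps")
std = torch.full((4,), 2.0, device="mps")

a = torch.normal(mean, 1.0)   # normal.Tensor_float -> normal_mps(Tensor, double, ...)
b = torch.normal(0.0, std)    # normal.float_Tensor -> normal_mps(double, Tensor, ...)
c = torch.normal(mean, std)   # normal.Tensor_Tensor -> normal_mps(Tensor, Tensor, ...)

# The Tensor/Tensor overload computes its output shape with
# at::infer_size, i.e. the broadcast of the mean and std shapes.
assert c.shape == torch.broadcast_shapes(mean.shape, std.shape)
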
6 changes: 3 additions & 3 deletions aten/src/ATen/native/native_functions.yaml
@@ -8383,7 +8383,7 @@
- func: normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor
  dispatch:
    CPU, CUDA: normal
    #MPS: normal_mps
    MPS: normal_mps
    Meta: normal_meta

- func: normal.float_Tensor_out(float mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
@@ -8395,8 +8395,8 @@
- func: normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor
  dispatch:
    CPU, CUDA: normal
    MPS: normal_mps
    Meta: normal_meta
    #MPS: normal_mps

- func: normal.Tensor_Tensor_out(Tensor mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
  dispatch:
@@ -8407,8 +8407,8 @@
- func: normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor
  dispatch:
    CPU, CUDA: normal
    MPS: normal_mps
    Meta: normal_meta
    #MPS: normal_mps

- func: normal.float_float(float mean, float std, int[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
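
The yaml change registers normal_mps for the MPS dispatch key on these three schemas (the entries were previously commented out). A minimal sketch of the effect, assuming MPS hardware and not taken from the commit:

import torch

# With the MPS entries registered, the dispatcher resolves these ops
# on 'mps' tensors rather than falling through to the error path.
if torch.backends.mps.is_available():
    mean = torch.ones(2, 2, device="mps")
    out = torch.normal(mean, 0.5)  # normal.Tensor_float -> MPS: normal_mps
    print(out.device)  # mps:0
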
18 changes: 12 additions & 6 deletions test/test_mps.py
@@ -4114,9 +4114,6 @@ def helper(shape, x_shape, y_shape, cond_dtype=torch.bool, x_dtype=torch.float):
    # Test normal
    def test_normal(self):
        def helper(shape, mean=0.0, std=1.0):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            mps_out = torch.normal(mean, std, shape, device='mps')

            mean_array = np.ones(shape)
@@ -4129,6 +4126,7 @@ def helper(shape, mean=0.0, std=1.0):
            cpu_std_tensor = torch.tensor(std_array, device='cpu', dtype=torch.float, requires_grad=False)
            std_tensor = cpu_std_tensor.detach().clone().to('mps')

            # test out
            mps_out = torch.zeros(shape, device='mps')
            torch.normal(mean_tensor, std, out=mps_out)

@@ -4138,14 +4136,22 @@
            mps_out = torch.zeros(shape, device='mps')
            torch.normal(mean_tensor, std_tensor, out=mps_out)

            # test without out
            mps_out = torch.normal(mean_tensor, std)
            self.assertEqual(mps_out.size(), mean_tensor.size())

            mps_out = torch.normal(mean, std_tensor)
            self.assertEqual(mps_out.size(), std_tensor.size())

            inferred_shape = torch.broadcast_shapes(mean_tensor.size(), std_tensor.size())
            mps_out = torch.normal(mean_tensor, std_tensor)
            self.assertEqual(mps_out.size(), inferred_shape)

        helper((2, 3, 4, 5, 6))
        helper((100, 100), 2.5, 1.2)

    def test_bernoulli(self):
        def helper(shape, prob=0.5):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
            x = cpu_x.detach().clone().to('mps')

            prob_array = np.ones(shape)
            prob_array *= prob
            cpu_prob_tensor = torch.tensor(prob_array, device='cpu', dtype=torch.float, requires_grad=False)
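
Beyond the shape assertions above, one could also sanity-check the sample statistics; the following is a sketch, not part of the test suite, with illustrative sample sizes and tolerances:

import torch

# Rough statistical check: with many samples, the empirical mean/std
# should approach the requested parameters.
if torch.backends.mps.is_available():
    mean_t = torch.full((10000,), 2.5, device="mps")
    std_t = torch.full((10000,), 1.2, device="mps")
    samples = torch.normal(mean_t, std_t)
    assert abs(samples.mean().item() - 2.5) < 0.1
    assert abs(samples.std().item() - 1.2) < 0.1
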
