Skip to content

Commit

Permalink
FIX Validates that weights are 2d in embedding (pytorch#59314)
Browse files Browse the repository at this point in the history
Summary:
Fixes pytorch#55185

Pull Request resolved: pytorch#59314

Reviewed By: H-Huang

Differential Revision: D28837753

Pulled By: jbschlosser

fbshipit-source-id: 683378244c61b0937c95563f91ef87ab09fd1653
  • Loading branch information
thomasjpfan authored and deniskokarev committed Jun 9, 2021
1 parent 2c17397 commit 6931863
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 7 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Embedding.cpp
Expand Up @@ -15,7 +15,7 @@ namespace at { namespace native {

Tensor embedding(const Tensor & weight, const Tensor & indices,
int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
TORCH_CHECK(weight.dim() >= 1, "'weight' must be at least 1-D");
TORCH_CHECK(weight.dim() == 2, "'weight' must be 2-D");
auto indices_arg = TensorArg(indices, "indices", 1);
checkScalarTypes("embedding", indices_arg, {kLong, kInt});

Expand Down
11 changes: 8 additions & 3 deletions test/test_nn.py
Expand Up @@ -13383,9 +13383,14 @@ def fn(weight):

def test_embedding_scalar_weight_error(self, device):
indices = torch.rand(2, 2, device=device).long()
weight = torch.tensor(1.0, device=device)
with self.assertRaisesRegex(RuntimeError, "'weight' must be at least 1-D"):
torch.nn.functional.embedding(indices, weight)
weights = [
torch.tensor(1.0, device=device),
torch.tensor(1.0, device=device).reshape(1, 1, 1),
]

for weight in weights:
with self.assertRaisesRegex(RuntimeError, "'weight' must be 2-D"):
torch.nn.functional.embedding(indices, weight)

@dtypesIfCUDA(torch.float16, torch.float64)
@dtypes(torch.float64)
Expand Down
10 changes: 7 additions & 3 deletions test/test_torch.py
Expand Up @@ -4088,9 +4088,13 @@ def backward_func(slf, device):

def test_embedding_scalar_weight_error(self, device):
indices = torch.rand(2, 2, device=device).long()
weight = torch.tensor(1.0)
with self.assertRaisesRegex(RuntimeError, "'weight' must be at least 1-D"):
torch.embedding(weight, indices)
weights = [
torch.tensor(1.0, device=device),
torch.tensor(1.0, device=device).reshape(1, 1, 1),
]
for weight in weights:
with self.assertRaisesRegex(RuntimeError, "'weight' must be 2-D"):
torch.embedding(weight, indices)

def test_dist(self, device):
def run_test(x, y):
Expand Down

0 comments on commit 6931863

Please sign in to comment.