fix embedding_backward_dense decomp with broadcasting (pytorch#95499)
bdhirsh authored and jhavukainen committed Mar 15, 2024
1 parent 768647d commit 2af0af4
Showing 2 changed files with 20 additions and 1 deletion.
19 changes: 19 additions & 0 deletions test/dynamo/test_repros.py
@@ -858,6 +858,25 @@ def f(x):

         f(torch.ones(2, device="cuda", dtype=torch.float64))
 
+    def test_embedding_backward_broadcasting_decomp(self):
+        def f(grad_output, indices):
+            num_weights = 10
+            padding_idx = 1
+            scale_grad_by_freq = True
+            return torch.ops.aten.embedding_dense_backward(
+                grad_output, indices, num_weights, padding_idx, scale_grad_by_freq
+            )
+
+        f_compiled = torch.compile(f, backend="aot_eager")
+
+        grad_output = torch.ones(2, 4, 3, dtype=torch.float16)
+        indices = torch.ones(2, 4, dtype=torch.int64)
+
+        out_ref = f(grad_output, indices)
+        out_test = f_compiled(grad_output, indices)
+
+        self.assertEqual(out_ref, out_test)
+
     def test_reformer_eval(self):
         with torch.no_grad():
             cnt = self._reformer(nopython=True)
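The failing configuration pairs 2-D indices with a 3-D grad_output: under scale_grad_by_freq=True, the decomposition computes counts[indices], which has the same shape as indices, while grad_output carries one extra trailing embedding dimension. A minimal sketch of the shape mismatch, reusing the shapes from the test above (the commented-out line is the pre-fix behavior):

import torch

grad_output = torch.ones(2, 4, 3, dtype=torch.float16)  # indices.shape + (embedding_dim,)
indices = torch.ones(2, 4, dtype=torch.int64)

# Mirror the scale_grad_by_freq computation from the decomposition.
counts = torch.zeros(10, dtype=torch.int64)
counts = counts.index_put([indices], torch.ones_like(indices), accumulate=True)
grad_weights_scale = counts[indices]  # shape (2, 4), same as indices

# Pre-fix: unsqueeze(1) gives (2, 1, 4), which cannot broadcast against
# (2, 4, 3) and raises a RuntimeError.
# grad_output / grad_weights_scale.unsqueeze(1)

# Post-fix: unsqueeze(-1) gives (2, 4, 1), which broadcasts along the
# trailing embedding dimension.
scaled = grad_output / grad_weights_scale.unsqueeze(-1)
print(scaled.shape)  # torch.Size([2, 4, 3])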
2 changes: 1 addition & 1 deletion torch/_decomp/decompositions.py
@@ -1070,7 +1070,7 @@ def embedding_dense_backward(
         ones = torch.ones_like(indices)
         counts = counts.index_put([indices], ones, accumulate=True)
         grad_weights_scale = counts[indices]
-        grad_output = grad_output / grad_weights_scale.unsqueeze(1)
+        grad_output = grad_output / grad_weights_scale.unsqueeze(-1)
 
     mask = _unsqueeze_to_dim(indices == padding_idx, grad_output.ndim)
     grad = grad_output.masked_fill(mask, 0)
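unsqueeze(1) and unsqueeze(-1) coincide only for 1-D indices, where both produce shape (N, 1); for k-D indices, grad_output has shape indices.shape + (embedding_dim,), so the frequency scale must be expanded along the last dimension. A quick check of the fixed decomposition against the eager kernel, as a sketch: it assumes embedding_dense_backward in this file remains directly callable as a plain Python function.

import torch
from torch._decomp.decompositions import embedding_dense_backward

grad_output = torch.ones(2, 4, 3, dtype=torch.float16)
indices = torch.ones(2, 4, dtype=torch.int64)

# Same arguments as the new test: num_weights=10, padding_idx=1,
# scale_grad_by_freq=True.
eager = torch.ops.aten.embedding_dense_backward(grad_output, indices, 10, 1, True)
decomp = embedding_dense_backward(grad_output, indices, 10, 1, True)
torch.testing.assert_close(eager, decomp)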
