Skip to content

Commit

Permalink
fix small bug about dense_grad (apache#5695)
Browse files Browse the repository at this point in the history
  • Loading branch information
handar423 authored and trevor-m committed Jun 18, 2020
1 parent 5c916fc commit 68c6501
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
7 changes: 4 additions & 3 deletions python/tvm/relay/op/_tensor_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,9 +472,10 @@ def bias_add_grad(orig, grad):
def dense_grad(orig, grad):
"""Returns [grad' @ weight, data @ grad']"""
data, weight = orig.args
return [collapse_sum_like(transpose(grad) * weight, data),
collapse_sum_like(data * transpose(grad), weight)]

return [collapse_sum_like(_nn.dense(grad, transpose(weight),
units=weight.checked_type.shape[1]), data),
collapse_sum_like(_nn.dense(transpose(grad), transpose(data),
units=data.checked_type.shape[1]), weight)]

@register_gradient("reshape")
def reshape_grad(orig, grad):
Expand Down
1 change: 1 addition & 0 deletions tests/python/relay/test_op_grad_level2.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def verify_dense_grad(d_shape, w_shape):
def test_dense_grad():
verify_dense_grad((1, 8), (16, 8))
verify_dense_grad((1, 4), (3, 4))
verify_dense_grad((5, 4), (3, 4))


def verify_batch_flatten_grad(d_shape):
Expand Down

0 comments on commit 68c6501

Please sign in to comment.