Skip to content

Commit

Permalink
reverse AD ok!
Browse files Browse the repository at this point in the history
  • Loading branch information
m0saan committed Jun 3, 2023
1 parent e280c79 commit 6e9d982
Show file tree
Hide file tree
Showing 2 changed files with 342 additions and 37 deletions.
3 changes: 3 additions & 0 deletions minima/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,9 @@ def gradient(self, out_grad: Tensor, node: Tensor) -> Tuple[Tensor]:
# If axes were specified, set those dimensions to 1 in the new shape
if self.axes:
for axis in self.axes: new_shape[axis] = 1

else:
new_shape = [1] * len(new_shape)

# Reshape out_grad to the new shape
reshaped_grad = reshape(out_grad, new_shape)
Expand Down
Loading

0 comments on commit 6e9d982

Please sign in to comment.