Commit

fix some dumb mistakes (should return a tuple from grad methods)
m0saan committed Jul 3, 2023
1 parent 137effa commit c2f21aa
Showing 2 changed files with 13 additions and 13 deletions.
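
For context: in a reverse-mode autodiff engine like minima, an op's `gradient` method is expected to return one gradient per input, and the backward pass pairs that tuple with `node.children`. Returning a bare `Tensor` breaks the pairing for single-input ops, which is what this commit fixes. A minimal sketch of the contract, assuming a hypothetical backward loop (not minima's actual implementation):

    # Minimal sketch, assuming a hypothetical backward loop --
    # not minima's actual code.
    def accumulate_grads(node, out_grad, grads):
        input_grads = node.op.gradient(out_grad, node)  # must be a tuple
        # One gradient per input; a bare Tensor would be mis-iterated
        # (or rejected) by zip instead of being paired with the input.
        for child, grad in zip(node.children, input_grads):
            if child in grads:
                grads[child] = grads[child] + grad
            else:
                grads[child] = grad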
8 changes: 4 additions & 4 deletions minima/operators.py
@@ -535,7 +535,7 @@ def gradient(self, out_grad: Tensor, node: Tensor) -> Tuple[Tensor,]:
            The gradients with respect to the inputs.
        """
        a = node.children[0].compute_cached_data()
-       return out_grad * Tensor(a > 0)
+       return (out_grad * Tensor(a > 0), )

def relu(a: Tensor) -> Tensor:
    """
@@ -754,7 +754,7 @@ def gradient(self, out_grad: Tensor, node: Tensor) -> Tuple[Tensor, ...]:
            Tuple[Tensor, ...]: The gradient with respect to the input tensor.
        """
        input_shape = node.children[0].shape
-       return reshape(out_grad, input_shape),
+       return (reshape(out_grad, input_shape), )

def reshape(a: Tensor, shape: Tuple[int, ...]) -> Tensor:
    """
@@ -932,7 +932,7 @@ def gradient(self, out_grad: Tensor, node: Tensor) -> Tuple[Tensor]:
        broadcasted_grad = broadcast_to(reshaped_grad, node.children[0].shape)

        # The gradient method needs to return a tuple, even though there's only one input
-       return (broadcasted_grad,)
+       return (broadcasted_grad, )


def summation(a: Tensor, axes: Optional[tuple] = None) -> Tensor:
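
The summation gradient above reverses a reduction: `reshaped_grad` re-inserts the reduced axes as size-1 dimensions and `broadcasted_grad` expands the result back to the input's shape. A NumPy sketch of the two steps:

    import numpy as np

    # Gradient of x.sum(axis=0): re-insert the reduced axis as size 1,
    # then broadcast back to the input's shape.
    x = np.random.randn(4, 5)
    out_grad = np.ones(5)
    reshaped = out_grad.reshape(1, 5)
    grad_x = np.broadcast_to(reshaped, x.shape)  # (4, 5), all ones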
@@ -1004,7 +1004,7 @@ def gradient(self, out_grad: Tensor, node: Tensor) -> Tuple[Tensor]:
        sum_over = tuple([idx for idx in range(len(self.shape)) if self.shape[idx] != shape[idx]])

        # Finally, we reshape the gradient after summing over the appropriate dimensions to match `a`'s shape.
-       return reshape(summation(out_grad, sum_over), a_shape)
+       return (reshape(summation(out_grad, sum_over), a_shape), )

def broadcast_to(a: Tensor, shape: Tuple[int, ...]) -> Tensor:
    """
18 changes: 9 additions & 9 deletions nbs/01_operators.ipynb
@@ -1216,7 +1216,7 @@
" The gradients with respect to the inputs.\n",
" \"\"\"\n",
" a = node.children[0].compute_cached_data()\n",
" return out_grad * Tensor(a > 0)\n",
" return (out_grad * Tensor(a > 0), )\n",
"\n",
"def relu(a: Tensor) -> Tensor:\n",
" \"\"\"\n",
@@ -1675,7 +1675,7 @@
" Tuple[Tensor, ...]: The gradient with respect to the input tensor.\n",
" \"\"\"\n",
" input_shape = node.children[0].shape\n",
" return reshape(out_grad, input_shape), \n",
" return (reshape(out_grad, input_shape), )\n",
"\n",
"def reshape(a: Tensor, shape: Tuple[int, ...]) -> Tensor:\n",
" \"\"\"\n",
@@ -1892,11 +1892,11 @@
{
"data": {
"text/plain": [
"tensor([[ 0.7752, -6.1767, -5.3929, -8.5221, 5.0633],\n",
" [-10.6650, 8.4653, 3.0637, 5.5141, 4.5129],\n",
" [ 3.4482, -5.3584, -2.6963, 1.9794, 1.4696],\n",
" [ 7.3074, 0.4407, 2.0073, 1.7052, -11.3009],\n",
" [ 3.4645, -3.7268, -3.1002, 1.4796, -2.3030]])"
"tensor([[-0.1309, -1.9203, -4.4179, 2.8422, -0.4453],\n",
" [-1.5883, -8.1020, -6.7316, -1.3045, 0.6170],\n",
" [-0.5317, 2.3444, 1.6038, -3.5786, -0.1689],\n",
" [ 1.0831, -1.3743, 0.8485, -3.0593, 2.2023],\n",
" [ 0.3071, 1.8321, -3.6827, -9.4409, -1.1884]])"
]
},
"execution_count": null,
@@ -2194,7 +2194,7 @@
" broadcasted_grad = broadcast_to(reshaped_grad, node.children[0].shape)\n",
"\n",
" # The gradient method needs to return a tuple, even though there's only one input\n",
" return (broadcasted_grad,)\n",
" return (broadcasted_grad, )\n",
"\n",
"\n",
"def summation(a: Tensor, axes: Optional[tuple] = None) -> Tensor:\n",
@@ -2485,7 +2485,7 @@
" sum_over = tuple([idx for idx in range(len(self.shape)) if self.shape[idx] != shape[idx]])\n",
"\n",
" # Finally, we reshape the gradient after summing over the appropriate dimensions to match `a`'s shape.\n",
" return reshape(summation(out_grad, sum_over), a_shape)\n",
" return (reshape(summation(out_grad, sum_over), a_shape), )\n",
"\n",
"def broadcast_to(a: Tensor, shape: Tuple[int, ...]) -> Tensor:\n",
" \"\"\"\n",
