From 6e9d982231e938ecba3cc3b70b7a54b9532d2330 Mon Sep 17 00:00:00 2001 From: __mo_san__ <50895527+m0saan@users.noreply.github.com> Date: Sat, 3 Jun 2023 07:10:55 +0100 Subject: [PATCH] reverse AD ok! --- minima/operators.py | 3 + nbs/01_operators.ipynb | 376 +++++++++++++++++++++++++++++++++++++---- 2 files changed, 342 insertions(+), 37 deletions(-) diff --git a/minima/operators.py b/minima/operators.py index b116def..55f0c53 100644 --- a/minima/operators.py +++ b/minima/operators.py @@ -920,6 +920,9 @@ def gradient(self, out_grad: Tensor, node: Tensor) -> Tuple[Tensor]: # If axes were specified, set those dimensions to 1 in the new shape if self.axes: for axis in self.axes: new_shape[axis] = 1 + + else: + new_shape = [1] * len(new_shape) # Reshape out_grad to the new shape reshaped_grad = reshape(out_grad, new_shape) diff --git a/nbs/01_operators.ipynb b/nbs/01_operators.ipynb index 7626871..131d23a 100644 --- a/nbs/01_operators.ipynb +++ b/nbs/01_operators.ipynb @@ -241,17 +241,14 @@ "metadata": {}, "outputs": [ { - "ename": "NameError", - "evalue": "name 'needle' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43mop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mb\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m result\n", - "Cell \u001b[0;32mIn[3], line 26\u001b[0m, in \u001b[0;36mEWiseAdd.compute\u001b[0;34m(self, a, b)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcompute\u001b[39m(\u001b[38;5;28mself\u001b[39m, a: NDArray, b: NDArray) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m NDArray:\n\u001b[1;32m 16\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;124;03m Computes the element-wise sum of two tensors.\u001b[39;00m\n\u001b[1;32m 18\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;124;03m The element-wise sum of a and b.\u001b[39;00m\n\u001b[1;32m 25\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 26\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43ma\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mb\u001b[49m\n", - "File \u001b[0;32m~/Desktop/minima/minima/autograd.py:855\u001b[0m, in \u001b[0;36mTensor.__add__\u001b[0;34m(self, other)\u001b[0m\n\u001b[1;32m 852\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m!=\u001b[39m other\u001b[38;5;241m.\u001b[39mshape:\n\u001b[1;32m 853\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAssertionError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTensors must be of the same shape for addition. Got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mshape\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m and \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mother\u001b[38;5;241m.\u001b[39mshape\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 855\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mneedle\u001b[49m\u001b[38;5;241m.\u001b[39mops\u001b[38;5;241m.\u001b[39mEWiseAdd()(\u001b[38;5;28mself\u001b[39m, other)\n\u001b[1;32m 857\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(other, (\u001b[38;5;28mint\u001b[39m, \u001b[38;5;28mfloat\u001b[39m)):\n\u001b[1;32m 858\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m needle\u001b[38;5;241m.\u001b[39mops\u001b[38;5;241m.\u001b[39mAddScalar(scalar\u001b[38;5;241m=\u001b[39mother)(\u001b[38;5;28mself\u001b[39m)\n", - "\u001b[0;31mNameError\u001b[0m: name 'needle' is not defined" - ] + "data": { + "text/plain": [ + "minima.Tensor([5 7 9])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -272,7 +269,18 @@ "execution_count": null, "id": "c6db6072-65fb-4cad-af75-e87236dbc921", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "minima.Tensor([5 7 9])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "result = add(a, b)\n", "result" @@ -291,7 +299,18 @@ "execution_count": null, "id": "ba584690-7fe5-4d93-95d7-2cd7e373811b", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "minima.Tensor([5 7 9])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "op(a,b)" ] @@ -309,7 +328,19 @@ "execution_count": null, "id": "1dbbc80e-a2f0-46d2-97df-35344197da88", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "minima.Tensor([[ 8 10 12]\n", + " [14 16 18]])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "a = Tensor([[1, 2, 3], [4, 5, 6]])\n", "b = Tensor([[7, 8, 9], [10, 11, 12]])\n", @@ -1707,7 +1738,18 @@ "execution_count": null, "id": "b83b2489-10fb-4796-82cd-15c94d1586e1", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(torch.Size([3, 5, 7]), torch.Size([5, 7]), torch.Size([7, 5]))" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "out_shape, A_shape, B_shape = out_grad.shape, A.shape, B.shape\n", "out_shape, A_shape, B_shape" @@ -1718,7 +1760,18 @@ "execution_count": null, "id": "7b883959-15a4-4f63-827e-6d2cea419e16", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(3, 2)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "len(out_shape), len(A_shape)" ] @@ -1728,7 +1781,18 @@ "execution_count": null, "id": "b267d83b-2604-444e-8bd6-01e1c7b19cb2", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "range(0, 1)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "rng = range(len(out_shape) - len(A_shape))\n", "rng" @@ -1739,7 +1803,18 @@ "execution_count": null, "id": "865fd42c-482d-479e-9d1f-9fe7183a5d37", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(0,)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "tuple(rng)" ] @@ -1749,7 +1824,18 @@ "execution_count": null, "id": "5849d174-7aab-4406-9c4f-caad9eaf95fe", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(0,)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "axes_to_sum_over = tuple(range(len(out_shape) - len(A_shape)))\n", "axes_to_sum_over" @@ -1760,7 +1846,22 @@ "execution_count": null, "id": "2ad258ce-40a0-4f6c-beed-69691d3c9a3c", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ -5.0927, 2.2641, -4.6312, 1.7456, -3.5126],\n", + " [ -1.0134, 1.5722, -0.7978, -0.3207, 1.4216],\n", + " [ -4.8770, -7.6298, -4.5855, 4.5206, 4.9156],\n", + " [ 0.0776, -3.4887, -7.7766, 1.1798, -5.2914],\n", + " [ -1.1959, -6.8538, -10.8534, -0.3308, -4.9055]])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "torch.sum(out_grad @ B, axes_to_sum_over) " ] @@ -1910,7 +2011,18 @@ "execution_count": null, "id": "2c938f74-5389-453e-9970-9459de0182a1", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([12., 15., 18.], grad_fn=)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "y" ] @@ -1930,7 +2042,17 @@ "execution_count": null, "id": "d17f204d-bf16-41b7-803b-a3e6a45ce692", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[1., 1., 1.],\n", + " [1., 1., 1.],\n", + " [1., 1., 1.]])\n" + ] + } + ], "source": [ "# Mock out_grad tensor\n", "out_grad = torch.tensor([1., 1., 1.])\n", @@ -2014,6 +2136,9 @@ " # If axes were specified, set those dimensions to 1 in the new shape\n", " if self.axes:\n", " for axis in self.axes: new_shape[axis] = 1\n", + " \n", + " else:\n", + " new_shape = [1] * len(new_shape)\n", "\n", " # Reshape out_grad to the new shape\n", " reshaped_grad = reshape(out_grad, new_shape)\n", @@ -2071,7 +2196,20 @@ "execution_count": null, "id": "1cc455d1-a34b-4edc-abd5-72adc86fa13b", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[1., 2., 3.],\n", + " [1., 2., 3.],\n", + " [1., 2., 3.]], grad_fn=)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "b" ] @@ -2081,7 +2219,18 @@ "execution_count": null, "id": "44ae4881-226b-42c5-bbf8-678b83f74b50", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([3, 3])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "b.shape" ] @@ -2106,7 +2255,18 @@ "execution_count": null, "id": "bd44d852-e580-49ef-93a5-3f7a61be32d9", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([3., 3., 3.])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "a.grad" ] @@ -2116,7 +2276,18 @@ "execution_count": null, "id": "24f51faf-819b-4937-94c7-5817b1fa8316", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([3, 3])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# Define the output gradient tensor\n", "out_grad = torch.tensor([[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]])\n", @@ -2128,7 +2299,18 @@ "execution_count": null, "id": "7f5f46a8-4fcf-4c53-9e8e-5484a4b4f789", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([3])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "a_shape = a.shape\n", "a_shape" @@ -2139,7 +2321,18 @@ "execution_count": null, "id": "069bcef5-d269-4625-b957-dfbc1555abb0", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "[3]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "shape = [1] * (len((3,3)) - len((3,3))) + list(a_shape)\n", "shape" @@ -2166,7 +2359,15 @@ "execution_count": null, "id": "946d4abc-1bf7-4a87-9e73-b6c2e5ac2c92", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([3., 6., 9.])\n" + ] + } + ], "source": [ "# The gradient for the broadcast operation is the sum of out_grad over the dimension that was broadcasted\n", "grad_a = out_grad.sum(dim=0)\n", @@ -2268,7 +2469,18 @@ "execution_count": null, "id": "d3761138-0a81-4a85-b966-04708310479c", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(2, 3)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "a.shape" ] @@ -2278,7 +2490,18 @@ "execution_count": null, "id": "87adc1bb-0b23-466f-b03c-de92b35e5407", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(5, 2, 3)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "a_br = br.compute(a)\n", "a_br.shape" @@ -2289,7 +2512,18 @@ "execution_count": null, "id": "75e4c955-9b3f-4bbd-a2c3-e6ee06db40dc", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(5, 2, 3)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "out_grad = Tensor(numpy.ones_like(a_br))\n", "out_grad.shape" @@ -2300,7 +2534,31 @@ "execution_count": null, "id": "240e5580-3a55-461e-9b0e-fd5f5647b3c8", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "minima.Tensor([[[1 1 1]\n", + " [1 1 1]]\n", + "\n", + " [[1 1 1]\n", + " [1 1 1]]\n", + "\n", + " [[1 1 1]\n", + " [1 1 1]]\n", + "\n", + " [[1 1 1]\n", + " [1 1 1]]\n", + "\n", + " [[1 1 1]\n", + " [1 1 1]]])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "out_grad" ] @@ -2310,7 +2568,18 @@ "execution_count": null, "id": "39c1c27b-b01d-4563-8483-f15dc370f8f6", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(2, 3)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "a_shape = a.shape\n", "a_shape" @@ -2331,7 +2600,18 @@ "execution_count": null, "id": "7d0bb6a2-effd-4d41-8038-dcc46d7b3efc", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "((5, 2, 3), [1, 2, 3])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "br.shape, shape" ] @@ -2341,7 +2621,18 @@ "execution_count": null, "id": "835d2f72-61a7-4fb9-b94b-2dcba608c028", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(0,)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "sum_over = tuple([idx for idx in range(len(br.shape)) if br.shape[idx] != shape[idx]])\n", "sum_over" @@ -2352,7 +2643,18 @@ "execution_count": null, "id": "9cf744e8-42b4-49cd-bb21-d38a4931b917", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "(2, 3)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "reshape(summation(out_grad, sum_over), a_shape).shape" ]