In [1]:
!pip install git+https://github.com/kandarpa02/neonet.git

Collecting git+https://github.com/kandarpa02/neonet.git
  Cloning https://github.com/kandarpa02/neonet.git to /tmp/pip-req-build-2ql3ovk1
  Running command git clone --filter=blob:none --quiet https://github.com/kandarpa02/neonet.git /tmp/pip-req-build-2ql3ovk1
  Resolved https://github.com/kandarpa02/neonet.git to commit ab86f58baa51cf9175b60bf74be573731f209e10
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: neonet
  Building wheel for neonet (setup.py) ... done
  Created wheel for neonet: filename=neonet-0.0.1a1-py3-none-any.whl size=7385 sha256=8e700a56d896e4a9e82aca94585abbc37a69fa33cbf64ada66ae6b5bf8ff1659
  Stored in directory: /tmp/pip-ephem-wheel-cache-s221lttu/wheels/30/68/a3/1e288d38f0373f2c46d66052dfbc2f2bc5a647564ceaa6b289
Successfully built neonet
Installing collected packages: neonet
Successfully installed neonet-0.0.1a1


In [2]:
import neonet as neo
import neonet.numpy as nep
from neonet import autograd
from neonet.functions import fn_forward

In [24]:
# You can define any function and its backward rule
# by subclassing the autograd.Policy class. Its inner workings are a bit
# verbose; I will make everything clear once it is complete.

class IF_IT_WORKS_DONT_TOUCH_IT(autograd.Policy):
    """Affine map ``out = X @ Y + b`` with a hand-written backward rule.

    ``forward`` stores its inputs on ``self.ctx`` so that ``backward`` can
    turn the upstream gradient into one gradient per input (X, Y, b).
    """

    def forward(self, X, Y, b):
        """Return ``X @ Y + b``; ``b`` is broadcast against the matrix product."""
        # Retrieved again in backward through self.ctx.release.
        self.ctx.save(X, Y, b)
        return (X @ Y) + b

    def backward(self, grad):
        """Propagate ``grad`` (shaped like ``X @ Y + b``) back to the inputs.

        Returns:
            ``(x_grad, y_grad, b_grad)`` in the same order as ``forward``'s
            arguments; ``b_grad`` always has exactly ``b.shape``.
        """
        X, Y, b = self.ctx.release
        # Matmul gradient rules; assumes X and Y are 2-D — TODO confirm other ranks.
        x_grad = grad @ Y.T
        y_grad = X.T @ grad
        # b was broadcast in forward, so its gradient is grad summed over every
        # broadcast axis. This is correct for any bias shape that broadcasts to
        # the output: scalar, (n,), (1,), (1, n), (m, 1), ...
        b_grad = self._sum_to_shape(grad, b.shape)
        return x_grad, y_grad, b_grad

    @staticmethod
    def _sum_to_shape(grad, shape):
        """Sum a broadcast gradient back down to ``shape`` (inverse of broadcasting).

        Assumes ``grad`` is a NumPy/CuPy-style array supporting ``ndim``,
        ``sum(axis=tuple, keepdims=...)`` and ``reshape``.
        """
        shape = tuple(shape)
        # Axes that broadcasting prepended in front of the operand's own axes.
        n_lead = grad.ndim - len(shape)
        if n_lead > 0:
            grad = grad.sum(axis=tuple(range(n_lead)))
        # Axes where the operand had length 1 but the output was stretched.
        stretched = tuple(
            i for i, dim in enumerate(shape) if dim == 1 and grad.shape[i] != 1
        )
        if stretched:
            grad = grad.sum(axis=stretched, keepdims=True)
        return grad.reshape(shape)

In [25]:
# Random operands for out = X @ Y + b. With device 'cuda', neonet uses CuPy
# under the hood (NumPy's evil twin). 'cpu' is the device name used with .to()
# later in this notebook; presumably it also works here if no GPU is available.
# NOTE(review): no random seed is set, so the values differ on every run.
DEVICE = 'cuda'
X = neo.randn((3, 4), device=DEVICE)
Y = neo.randn((4, 2), device=DEVICE)
b = neo.randn((2,), device=DEVICE)

In [26]:
# Show the randomly drawn operands of out = X @ Y + b, with consistent labels.
print(f"matrix X\n{X}")
print(f"matrix Y\n{Y}")
print(f"vector b\n{b}")

matrix X
[[ 1.53645926  0.68121459  0.10153642 -0.62356898]
 [ 2.08279636  1.0575698   0.86665142 -0.30455532]
 [ 0.73587747 -1.62454334 -0.89650051 -0.64107397]]
matrixY
[[ 0.78848922 -0.91403799]
 [ 0.8977777   0.00596389]
 [-1.81868782  1.41260409]
 [ 0.88750909 -0.32401679]]
vector b
[-0.10963281  1.1082946 ]


In [46]:
# Turn the custom Policy into a differentiable function, then evaluate it and
# its gradients w.r.t. every input in one call. The gradients printed below
# equal those of sum(out) (see the JAX check), i.e. the upstream gradient
# appears to be all ones — TODO confirm this default in neonet.
forward = fn_forward(IF_IT_WORKS_DONT_TOUCH_IT)

out, grads = autograd.value_and_grad(forward)(X, Y, b)
print("Output :\n", out, "\n")

# NOTE(review): assumes `grads` iterates in argument order (X, Y, b).
matrices = list(grads.values())
names = ["X_grad", "Y_grad", "b_grad"]
# zip() would silently drop a missing gradient; fail loudly instead.
assert len(matrices) == len(names), (
    f"expected {len(names)} gradients, got {len(matrices)}"
)

for name, mat in zip(names, matrices):
    print(f"Matrix {name}:\n{mat}\n")

Output :
 [[ 0.97534184  0.05345273]
 [ 0.63562825  0.53376321]
 [ 0.07361545 -0.63269553]] 

Matrix X_grad:
[[-0.12554877  0.90374159 -0.40608373  0.5634923 ]
 [-0.12554877  0.90374159 -0.40608373  0.5634923 ]
 [-0.12554877  0.90374159 -0.40608373  0.5634923 ]]

Matrix Y_grad:
[[ 4.3551331   4.3551331 ]
 [ 0.11424104  0.11424104]
 [ 0.07168733  0.07168733]
 [-1.56919827 -1.56919827]]

Matrix b_grad:
[3. 3.]



In [48]:
# Cross-check neonet's gradients against JAX autodiff of the same function.
# Summing the output gives a scalar loss, whose upstream gradient is all ones.
import jax.numpy as jnp
import numpy as np
from jax import grad as gfn


def _to_numpy(arr):
    """Best-effort conversion of a gradient to a host NumPy array for comparison."""
    if hasattr(arr, "to") and hasattr(arr, "numpy"):
        # neonet tensor: same conversion used for X/Y/b below.
        arr = arr.to('cpu').numpy()
    elif hasattr(arr, "get"):
        # Presumably a raw CuPy array — TODO confirm what neonet returns.
        arr = arr.get()
    return np.asarray(arr)


X_, Y_, b_ = X.to('cpu').numpy(), Y.to('cpu').numpy(), b.to('cpu').numpy()

grads_jax = gfn(lambda x, y, bias: (x @ y + bias).sum(), argnums=(0, 1, 2))(X_, Y_, b_)

matrices = list(grads_jax)
names = ["X_JAX_grad", "Y_JAX_grad", "b_JAX_grad"]

for name, mat in zip(names, matrices):
    print(f"Matrix {name}:\n{mat}\n")

# JAX computes in float32 by default (64-bit mode is off), so allow rounding
# differences against neonet's higher-precision results.
for name, neo_grad, jax_grad in zip(names, grads.values(), grads_jax):
    np.testing.assert_allclose(
        _to_numpy(neo_grad), np.asarray(jax_grad), rtol=1e-5, atol=1e-6, err_msg=name
    )
print("neonet gradients match JAX")

Matrix X_JAX_grad:
[[-0.12554878  0.90374154 -0.4060837   0.5634923 ]
 [-0.12554878  0.90374154 -0.4060837   0.5634923 ]
 [-0.12554878  0.90374154 -0.4060837   0.5634923 ]]

Matrix Y_JAX_grad:
[[ 4.355133    4.355133  ]
 [ 0.114241    0.114241  ]
 [ 0.07168728  0.07168728]
 [-1.5691983  -1.5691983 ]]

Matrix b_JAX_grad:
[3. 3.]

