In [1]:
!pip install git+https://github.com/kandarpa02/neonet.git

Collecting git+https://github.com/kandarpa02/neonet.git
  Cloning https://github.com/kandarpa02/neonet.git to /tmp/pip-req-build-br1xvgem
  Running command git clone --filter=blob:none --quiet https://github.com/kandarpa02/neonet.git /tmp/pip-req-build-br1xvgem
  Resolved https://github.com/kandarpa02/neonet.git to commit 7a4bc5826f457123a64de4db03cdb87e485a8307
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: neonet
  Building wheel for neonet (setup.py) ... [?25l[?25hdone
  Created wheel for neonet: filename=neonet-0.0.1a1-py3-none-any.whl size=9031 sha256=c6cbe0151c80a93c1dade0279dc326c622cde66a8e91a579f3c475760a6d4e1a
  Stored in directory: /tmp/pip-ephem-wheel-cache-bntaa5_a/wheels/30/68/a3/1e288d38f0373f2c46d66052dfbc2f2bc5a647564ceaa6b289
Successfully built neonet
Installing collected packages: neonet
Successfully installed neonet-0.0.1a1


In [None]:
import neo
import neo.numpy as nep
from neo import autograd
from neo.functions import neo_function

In [None]:
# You can define any funcion and its backward rule
# with autograd.Policy module, its inner working a bit
# verbose, I will make everythng clear once it is complete

class IF_IT_WORKS_DONT_TOUCH_IT(autograd.Policy):
    def forward(self, X, Y, b):
        self.ctx.save(X, Y, b)
        return (X @ Y) + b
    
    def backward(self, grad):
        X, Y, b = self.ctx.release
        x_grad = grad @ Y.T
        y_grad = X.T @ grad
        b_grad = grad.sum(axis=0) if b.size > 1 else grad.sum()
        return x_grad, y_grad, b_grad

In [4]:
X = neo.randn((3,4), device='cuda') # for 'cuda' it uses cupy under the hood (numpy's evil twin)
Y = neo.randn((4,2), device='cuda')
b = neo.randn((2,), device='cuda')

In [5]:
print(f"matrix X\n{X}")
print(f"matrixY\n{Y}")
print(f"vector b\n{b}")

matrix X
[[ 1.63607024  0.65679874  0.49440371 -1.00391713]
 [-0.14999782  1.11938317 -1.35716351 -0.21304375]
 [-0.03646202  0.55950737 -1.20680271  0.70346151]]
matrixY
[[-0.86843071  1.05386462]
 [-0.9428054   1.17611029]
 [ 1.19467713  0.51843372]
 [ 0.5259126   0.19863528]]
vector b
[ 1.32473385 -1.3522553 ]


In [None]:
forward = neo_function(IF_IT_WORKS_DONT_TOUCH_IT)

out, grads = autograd.session.value_and_grad(forward)(X, Y, b)
print("Output :\n", out, "\n")

matrices = list(grads.values())
names = ["X_grad", "Y_grad", "b_grad"]

for name, mat in zip(names, matrices):
    print(f"Matrix {name}:\n{mat}\n")

Output :
 [[-0.65263305  1.20131121]
 [-1.33377853 -0.93973197]
 [-0.24288831 -1.21855391]] 

Matrix X_grad:
[[0.18543391 0.23330489 1.71311085 0.72454788]
 [0.18543391 0.23330489 1.71311085 0.72454788]
 [0.18543391 0.23330489 1.71311085 0.72454788]]

Matrix Y_grad:
[[ 1.4496104   1.4496104 ]
 [ 2.33568928  2.33568928]
 [-2.0695625  -2.0695625 ]
 [-0.51349937 -0.51349937]]

Matrix b_grad:
[3. 3.]



In [7]:
import jax.numpy as jnp
from jax import grad as gfn

X_, Y_, b_ = X.to('cpu').numpy(), Y.to('cpu').numpy(), b.to('cpu').numpy()

grads_jax = gfn(lambda x, y, b: (x@y + b).sum(), argnums=[0,1,2])(X_, Y_, b_)

matrices = list(grads_jax)
names = ["X_JAX_grad", "Y_JAX_grad", "b_JAX_grad"]

for name, mat in zip(names, matrices):
    print(f"Matrix {name}:\n{mat}\n")

INFO:2025-07-18 19:19:05,301:jax._src.xla_bridge:924: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:2025-07-18 19:19:05,314:jax._src.xla_bridge:924: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


Matrix X_JAX_grad:
[[0.18543386 0.23330486 1.7131109  0.72454786]
 [0.18543386 0.23330486 1.7131109  0.72454786]
 [0.18543386 0.23330486 1.7131109  0.72454786]]

Matrix Y_JAX_grad:
[[ 1.4496104  1.4496104]
 [ 2.3356893  2.3356893]
 [-2.0695624 -2.0695624]
 [-0.5134994 -0.5134994]]

Matrix b_JAX_grad:
[3. 3.]

