In [14]:
import jax 
import jax.numpy as np
from jax import jit, vmap, grad, value_and_grad

In [105]:
class Dense:
    def __init__(self, input_dim, output_dim, activation_function):
        self.input_dim = input_dim
        self.output_dim = output_dim 
        self.activation_function = activation_function
        
    def init(self, key):
        self.weight = jax.random.normal(key, shape=(self.input_dim, self.output_dim))
        self.bias = np.zeros((self.output_dim, ))
        return self
            
    def forward(self, params, X):
        weight, bias = params
        return self.activation_function(X @ weight + bias)
    
    def parameters(self):
        return (self.weight, self.bias)
    
    def __call__(self, X):
        return self.forward(self.parameters(), X)
    
    def __repr__(self):
        return f"Dense({self.input_dim}, {self.output_dim}, {self.activation_function})"
    
class Sequential:
    def __init__(self, *layers):
        self.layers = layers
        self.forwards = [l.forward for l in layers]
        
    def init(self, key):
        keys = jax.random.split(key, len(self.layers))
        for key, layer in zip(keys, self.layers): 
            layer.init(key)
        return self
        
    def forward(self, params, X):
        for forward, params in zip(self.forwards, params):
            X = forward(params, X)
        return X
    
    def parameters(self):
        return [l.parameters() for l in self.layers]
    
    def __call__(self, X):
        return self.forward(self.parameters(), X)
    
    def __getitem__(self, idx):
        return self.layers[idx]
    
    def __repr__(self):
        return "Sequential(\n\t" + "\n\t".join([str(l) for l in self.layers]) + "\n)"

In [106]:
key = jax.random.PRNGKey(10022023)
X = jax.random.normal(key, shape=(100,10))

In [107]:
m = Dense(10, 1, jax.nn.relu).init(key)

In [108]:
f = m.forward # (m.parameters(), X)

In [109]:
f(m.parameters(), X).sum()

Array(132.36319, dtype=float32)

In [110]:
m = Sequential(
    Dense(10, 10, jax.nn.relu),
    Dense(10, 10, jax.nn.relu),
    Dense(10, 1, jax.nn.relu)
).init(key)

In [111]:
m.forward(m.parameters(), X).sum()

Array(57.864437, dtype=float32)

In [112]:
mgrad = grad(lambda params: m.forward(params, X).sum())

In [113]:
mgrad(m.parameters())

[(Array([[ -0.626534  , -17.438547  ,  16.129044  ,   1.2630304 ,
            2.1067398 ,  -6.9283752 ,   2.2276483 ,   0.68698466,
           -2.61875   , -11.255062  ],
         [ -6.12702   ,   7.6707716 ,   2.2247057 ,   1.0202217 ,
            4.463752  ,   1.1785495 ,   3.4267712 ,   0.06752861,
            3.4373503 ,  -0.7026036 ],
         [  2.4765031 , -19.501953  ,  -0.06946325,  -1.2984426 ,
            7.8993673 ,   5.3442926 ,   0.19254339,   2.2392218 ,
            2.4026012 ,  -8.619766  ],
         [  4.3844624 ,  -2.7639103 ,   2.2434168 ,   1.2925116 ,
           -7.4353733 ,  -0.08772272,  -4.419743  ,   1.8536824 ,
           -2.6192687 ,  -1.557554  ],
         [  0.26664865,   8.552589  ,  -0.87628794,   0.3133682 ,
           -8.754532  ,  -3.5224063 ,  -1.4837507 ,   3.55223   ,
           -1.8649522 ,   0.90313494],
         [  0.6428221 ,  -7.1643724 ,  12.608136  ,  -2.7341893 ,
            2.172816  ,  -2.2688553 ,   1.3390968 ,  -0.47164115,
           -1

In [114]:
m

Sequential(
	Dense(10, 10, <jax._src.custom_derivatives.custom_jvp object at 0x7f205c95ef50>)
	Dense(10, 10, <jax._src.custom_derivatives.custom_jvp object at 0x7f205c95ef50>)
	Dense(10, 1, <jax._src.custom_derivatives.custom_jvp object at 0x7f205c95ef50>)
)