In [39]:
import jax
import jax.numpy as np

In [40]:
class Layer:
    def __init__(self, key, input_dim, output_dim, activation_function):
        self.weight = jax.random.normal(key, shape=(input_dim, output_dim))
        self.bias = jax.random.normal(key, shape=(output_dim, ))
        self.activation_function = activation_function
    
    def forward(self, X):
        return self.activation_function(X @ self.weight + self.bias)
    
    def parameters(self):
        return self.weight, self.bias

In [41]:
class MLP:
    def __init__(self, key, inout):
        self.layers = [
            Layer(key, inout[i], inout[i+1], jax.nn.relu if i < len(inout)-1 else lambda x: x)
            for i, key in enumerate(jax.random.split(key, len(inout)-1))
        ]
        
    def forward(self, X):
        for layer in self.layers:
            X = layer.forward(X)
        return X
    
    def parameters(self):
        return [layer.parameters() for layer in self.layers]

In [42]:
key = jax.random.PRNGKey(3022023)
m = MLP(key, [10,5,5,5,2])

In [43]:
X = jax.random.normal(key, shape=(100,10))

In [44]:
m.parameters()

[(Array([[-0.6132147 ,  0.36154914, -0.74044627,  1.6700312 , -0.19783515],
         [ 0.5530242 , -0.05752362,  1.5809982 , -0.31582314,  1.1741048 ],
         [ 2.479786  , -2.2329628 ,  0.472316  ,  0.37306872, -0.2113169 ],
         [ 2.073347  , -0.7681033 , -1.5203044 , -1.5845248 , -0.89777017],
         [-0.27748603,  0.51009345, -0.98009336,  0.28076887,  0.32567716],
         [ 0.7538772 , -1.1409338 , -0.5065866 ,  0.44741458, -0.25905186],
         [-0.31241515,  0.09745613,  0.94000655, -1.1347353 ,  0.42485195],
         [ 1.2054641 ,  0.6270398 ,  0.29184487, -0.18270147, -0.15717538],
         [-1.0393407 , -1.793633  , -0.38585338, -1.1537268 , -0.42840117],
         [ 0.03811966, -1.3135958 ,  0.1697295 , -0.07127163, -0.03138855]],      dtype=float32),
  Array([-0.5372561 ,  0.8950548 ,  0.46274158,  0.18877883,  1.0085387 ],      dtype=float32)),
 (Array([[-0.00855536,  0.6676877 , -0.49237144, -0.9737191 , -0.7140486 ],
         [-1.2498182 ,  0.09928743,  0.213685

In [45]:
m.forward(X).sum()

Array(50.209366, dtype=float32)

In [46]:
def Dense(input_dim, output_dim, activation_function=lambda x: x):
    """
    Creates an init function and forward functions.
    The forward function takes as arguments:
    * init return
    * X
    
    and returns Y.
    """
    
    def init_func(key):
        weight = jax.random.normal(key, shape=(input_dim, output_dim))
        bias = np.zeros((output_dim, ))
        return weight, bias
    
    def forward_func(params, X):
        weight, bias = params
        return activation_function(X @ weight + bias)
    
    return init_func, forward_func

def Sequential(*args):
    init_functions, forward_functions = zip(*args)
    def init_func(key):
        return [
            init_func(key)
            for init_func, key in zip(init_functions, jax.random.split(key, len(init_functions)))
        ]
    def forward_func(sequence, X):
        for params, forward_func in zip(sequence, forward_functions):
            X = forward_func(params, X)
        return X
    
    return init_func, forward_func

In [51]:
init_func, forward = Sequential(
    Dense(10,10, jax.nn.relu),
    Dense(10,10, jax.nn.relu),
    Dense(10,10)
)

In [52]:
key = jax.random.PRNGKey(3022023)
params = init_func(key)

In [53]:
X = jax.random.normal(key, shape=(100,10))

In [57]:
forward(params, X)

Array([[ 1.67991886e+01, -4.43893003e+00, -6.75098991e+00,
         1.27910223e+01,  7.31538343e+00,  1.45428228e+01,
        -2.13500404e+00,  4.68682938e+01, -3.22327766e+01,
         2.82776585e+01],
       [ 1.92166209e+00, -1.31667519e+00,  8.53878689e+00,
        -2.90949488e+00,  7.75505662e-01,  4.67215443e+00,
         6.28032160e+00,  5.96889591e+00, -2.94500828e+00,
         1.27225990e+01],
       [ 2.92721291e+01, -3.26458168e+01, -9.17534065e+00,
         5.81160879e+00, -6.54226542e-01,  1.92927551e+01,
         2.81190491e+00,  4.25197601e+01, -4.09291000e+01,
         1.89763393e+01],
       [ 4.02200317e+01, -1.92107925e+01, -7.52824926e+00,
         1.27588568e+01,  4.24554491e+00,  1.90923538e+01,
        -5.16887808e+00,  5.10422554e+01, -3.77789764e+01,
         2.42275906e+01],
       [ 8.04972363e+00,  1.14204097e+00,  1.04273033e+01,
         1.52237630e+00,  1.71597719e+00,  5.23491621e+00,
         2.44497752e+00,  1.13938398e+01, -1.04997149e+01,
         5.