In [10]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import numpy as onp
import jax.numpy as np
from jax import grad, jit, vmap, value_and_grad
from jax import random
import jax

In [11]:
# Root PRNG key for this notebook (seed 1). JAX keeps no global random state:
# every sampling call takes an explicit key.
# NOTE(review): `key` is not used by the cells visible here; the init cell
# creates its own root key from seed 42 — confirm whether this one is needed.
key = random.PRNGKey(seed=1)

In [12]:
def ReLU(x):
    """Rectified Linear Unit (ReLU): element-wise max(0, x).

    Negative entries become 0; zero and positive entries pass through as-is.
    Accepts Python scalars or arrays and returns a jax.numpy array.
    """
    floor = 0
    activated = np.maximum(floor, x)
    return activated

In [13]:
from flax import linen as nn

In [14]:
class SimpleClassifier(nn.Module):
    """Two-layer MLP: Dense -> tanh -> Dense.

    No activation is applied after the second layer, so ``__call__`` returns
    the raw outputs of ``linear2``.
    """
    num_hidden : int   # width of the hidden Dense layer
    num_outputs : int  # width of the output Dense layer

    def setup(self):
        # Submodules assigned in setup are named after their attributes, so
        # the parameter tree gets 'linear1' and 'linear2' entries.
        self.linear1 = nn.Dense(self.num_hidden)
        self.linear2 = nn.Dense(self.num_outputs)

    def __call__(self, x):
        """Forward pass: both Dense layers act on the last axis of ``x``."""
        hidden = nn.tanh(self.linear1(x))
        return self.linear2(hidden)

In [15]:
# 8 hidden units feeding a single output unit.
model = SimpleClassifier(num_outputs=1, num_hidden=8)

In [16]:
model  # show the module's configuration; weights live separately (returned by model.init)

SimpleClassifier(
    # attributes
    num_hidden = 8
    num_outputs = 1
)

In [18]:
# Split a fresh root key (seed 42) into three: `rng` to carry forward, one
# subkey for the dummy input batch, one for parameter initialisation.
rng, inp_rng, init_rng = random.split(random.PRNGKey(42), 3)
inp = random.normal(inp_rng, shape=(8, 2))  # dummy batch: 8 samples x 2 features

# model.init runs the model on `inp` (the layers never state their input size;
# the first kernel's shape follows from it) and returns the parameter pytree.
params = model.init(init_rng, inp)
print(params)

{'params': {'linear1': {'kernel': Array([[ 0.5564613 ,  0.9367376 ,  0.2285179 , -0.23255277, -0.25101846,
        -0.48948383,  0.11607227,  0.40487856],
       [-0.3619682 ,  0.9271343 ,  0.6478837 ,  0.26224074,  0.34578732,
         1.1132734 ,  0.06098709,  0.49297702]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}, 'linear2': {'kernel': Array([[ 0.4818003 ],
       [-0.35573798],
       [-0.62196773],
       [ 0.28606406],
       [-0.79486924],
       [ 0.5573447 ],
       [-0.1400483 ],
       [ 0.41512278]], dtype=float32), 'bias': Array([0.], dtype=float32)}}}


In [21]:
# Forward pass on one unbatched 2-feature sample; float literals match the
# float32 parameters.
model.apply(params, np.array([1.0, 2.0]))

Array([-0.38878164], dtype=float32)