In [5]:
import jax
import jax.numpy as jnp
from jax import random
import math
from typing import Callable

In [6]:
try:
    import flax
except ModuleNotFoundError: # Install flax if missing
    !pip install --quiet flax
    import flax

from flax import linen as nn

In [10]:
# Single seed for the whole notebook: a fresh Restart & Run All reproduces every random draw.
SEED = 42
rng = jax.random.PRNGKey(SEED)

In [7]:
class SimpleClassifier(nn.Module):
    """Two-layer perceptron: Dense -> activation -> Dense.

    Attributes:
        num_hidden: Number of hidden neurons (features of the first Dense layer).
        num_outputs: Number of output neurons (features of the second Dense layer).
        act_fn: Activation applied between the two Dense layers. Defaults to
            ``nn.tanh``, so existing ``SimpleClassifier(num_hidden=..., num_outputs=...)``
            calls behave exactly as before.
    """
    num_hidden : int   # Number of hidden neurons
    num_outputs : int  # Number of output neurons
    act_fn : Callable = nn.tanh  # Hidden-layer activation (default preserves original behavior)

    def setup(self):
        # Create the modules we need to build the network
        # nn.Dense is a linear layer
        self.linear1 = nn.Dense(features=self.num_hidden)
        self.linear2 = nn.Dense(features=self.num_outputs)

    def __call__(self, x):
        """Run the forward pass.

        Args:
            x: Input array. In this notebook it is shaped (batch, features),
                e.g. (8, 2).

        Returns:
            The output of ``linear2``, which has ``num_outputs`` features. No final
            activation is applied, so the values are unnormalized scores.
        """
        x = self.linear1(x)
        x = self.act_fn(x)
        x = self.linear2(x)
        return x

In [8]:
model = SimpleClassifier(num_hidden=8, num_outputs=1)
# Ending the cell with the model displays its repr, which lists its attributes
model

SimpleClassifier(
    # attributes
    num_hidden = 8
    num_outputs = 1
)


In [13]:
# Split off fresh keys: one for the dummy input, one for parameter initialization.
# NOTE: this cell rebinds `rng`, so it is not idempotent. Re-running it draws a
# different `inp` and `params`. To reproduce the saved outputs, restart and run
# from the `rng = jax.random.PRNGKey(...)` cell in order.
rng, inp_rng, init_rng = jax.random.split(rng, 3)
inp = jax.random.normal(inp_rng, (8, 2))  # Batch size 8, input size 2
# Initialize the model, passing `inp` as an example input
params = model.init(init_rng, inp)
inp

[[-0.21089035 -1.3627948 ]
 [-0.04500385 -1.1536394 ]
 [ 1.9141139  -0.47701314]
 [ 0.6478766   0.6747401 ]
 [ 2.9508727  -0.8744793 ]
 [ 1.3046614  -0.525778  ]
 [ 0.5039801   1.0394477 ]
 [-0.16569884 -0.4633415 ]]


In [12]:
# Forward pass: evaluate the model on `inp` with the parameters from `model.init`
model.apply(params, x=inp)

Array([[ 0.08598191],
       [ 0.18361846],
       [ 0.23252794],
       [-0.41932803],
       [ 0.09644738],
       [ 0.02926508],
       [-0.44354892],
       [-0.4412725 ]], dtype=float32)