In [22]:
import math
import os
import shutil
from typing import Callable

# Set XLA_FLAGS ahead of the jax import so the forced host-device count is
# already in the environment when JAX initializes its backend.
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp
import optax
import orbax.checkpoint
from jax import random


In [9]:
try:
    import flax
except ModuleNotFoundError:  # Install flax if missing
    !pip install --quiet flax
    import flax

from flax import linen as nn

In [10]:
# Root PRNG key for the notebook; every later key is split off from this one.
SEED = 42
rng = jax.random.PRNGKey(SEED)

In [17]:
class SimpleClassifier(nn.Module):
    """Small feed-forward network: Dense -> tanh -> Dense.

    Attributes:
        num_hidden: number of features produced by the first dense layer.
        num_outputs: number of features produced by the second dense layer.
    """

    num_hidden: int
    num_outputs: int

    def setup(self):
        """Create the two dense sub-layers used by ``__call__``."""
        self.linear1 = nn.Dense(self.num_hidden)
        self.linear2 = nn.Dense(self.num_outputs)

    def __call__(self, x):
        """Apply linear1, a tanh non-linearity, then linear2 to ``x``."""
        hidden = nn.tanh(self.linear1(x))
        return self.linear2(hidden)

In [18]:
# Build the classifier with 8 hidden features and a single output feature.
classifier_config = dict(num_hidden=8, num_outputs=1)
model = SimpleClassifier(**classifier_config)
# The printed representation lists the module's attributes.
print(model)

SimpleClassifier(
    # attributes
    num_hidden = 8
    num_outputs = 1
)


In [19]:
# Split the current key three ways: a replacement for `rng`, a key for the
# random input batch, and a key for parameter initialization.
rng, inp_rng, init_rng = jax.random.split(rng, num=3)

# Random input batch: 8 rows (batch size) by 2 columns (input size).
batch_size, input_dim = 8, 2
inp = jax.random.normal(inp_rng, shape=(batch_size, input_dim))

# Initialize the model's variables using the input batch.
params = model.init(init_rng, inp)
print(inp)

[[-0.21089035 -1.3627948 ]
 [-0.04500385 -1.1536394 ]
 [ 1.9141139  -0.47701314]
 [ 0.6478766   0.6747401 ]
 [ 2.9508727  -0.8744793 ]
 [ 1.3046614  -0.525778  ]
 [ 0.5039801   1.0394477 ]
 [-0.16569884 -0.4633415 ]]


In [20]:
# Run the model on the input batch with the initialized variables and show
# the result as the cell output.
predictions = model.apply(params, inp)
predictions

Array([[-0.15102524],
       [-0.24684235],
       [-1.0363195 ],
       [-0.4635623 ],
       [-1.1413009 ],
       [-0.9482199 ],
       [-0.27979338],
       [ 0.10091412]], dtype=float32)

In [21]:
# Show the variables returned by model.init: a 'params' collection holding a
# kernel and a bias for each of linear1 and linear2.
params

{'params': {'linear1': {'kernel': Array([[ 0.9349091 , -0.6829195 ,  1.1437343 , -0.00761674,  0.708436  ,
            0.43199268,  0.3907109 ,  0.2719054 ],
          [-1.3076503 ,  0.7088574 ,  0.01150007, -0.22155732, -0.4058151 ,
            0.738009  ,  0.26554996, -0.18983054]], dtype=float32),
   'bias': Array([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)},
  'linear2': {'kernel': Array([[ 0.5661024 ],
          [ 0.45534813],
          [-0.55111575],
          [-0.4596257 ],
          [-0.7046358 ],
          [ 0.16724774],
          [-0.0724616 ],
          [ 0.03341391]], dtype=float32),
   'bias': Array([0.], dtype=float32)}}}

In [23]:
# Checkpoint directory (presumably used by later save/restore cells — confirm).
ckpt_dir = '/tmp/flax_ckpt'

# Remove any directory left at this path by a previous run so the notebook
# starts from a clean state. `shutil` is provided by the notebook's import
# cell; without that import this line raised NameError whenever the
# directory already existed.
if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir)