In [14]:
import os

# Set XLA_FLAGS before jax is imported so the forced host-device count is
# already in the environment whenever the XLA backend gets initialized.
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

import math
import shutil
import tempfile
from typing import Callable

import jax
import jax.numpy as jnp
import optax
import orbax.checkpoint
from jax import random


In [15]:
try:
    import flax
except ModuleNotFoundError:  # Install flax if missing
    !pip install --quiet flax
    import flax

from flax import linen as nn

In [16]:
# Root PRNG key for the notebook; later cells split sub-keys off it.
rng = random.PRNGKey(42)

In [17]:
class SimpleClassifier(nn.Module):
    """Two-layer perceptron: Dense(num_hidden) -> tanh -> Dense(num_outputs).

    Attributes:
        num_hidden: Width of the hidden dense layer.
        num_outputs: Width of the output dense layer.
    """

    num_hidden: int
    num_outputs: int

    def setup(self):
        """Instantiate the two dense sub-layers used by ``__call__``."""
        self.linear1 = nn.Dense(features=self.num_hidden)
        self.linear2 = nn.Dense(features=self.num_outputs)

    def __call__(self, x):
        """Return the output-layer activations for input ``x``."""
        hidden = nn.tanh(self.linear1(x))
        return self.linear2(hidden)

In [18]:
# Build the classifier with 8 hidden units and a single output.
model = SimpleClassifier(num_outputs=1, num_hidden=8)
# The module's printed form lists its configured attributes.
print(model)

SimpleClassifier(
    # attributes
    num_hidden = 8
    num_outputs = 1
)


In [19]:
# Derive three fresh keys: one to carry forward, one for the dummy input,
# one for parameter initialization.
rng, inp_rng, init_rng = random.split(rng, num=3)
# Dummy batch: 8 samples with 2 features each.
inp = random.normal(key=inp_rng, shape=(8, 2))
# Create the model's variables from the init key and the example input.
params = model.init(init_rng, inp)
print(inp)

[[ 0.60576403  0.7990441 ]
 [-0.908927   -0.63525754]
 [-1.2226585  -0.83226097]
 [-0.47417238 -1.2504351 ]
 [-0.17678244 -0.04917514]
 [-0.41177532 -0.39363015]
 [ 1.3116323   0.21555556]
 [ 0.41164538 -0.28955024]]


In [20]:
# Run the model on the dummy batch using the initialized variables,
# then show the dtype of the result.
res = model.apply(params, inp)
res.dtype

dtype('float32')

In [22]:
# Re-bind `res` to a bfloat16 copy of itself and show the new dtype.
res = res.astype(jnp.bfloat16)
res.dtype

dtype(bfloat16)

In [8]:
# Display the variables produced by model.init.
params

{'params': {'linear1': {'kernel': Array([[-1.4184448 , -0.13778795,  0.01538001, -0.16879076, -0.04171572,
           -0.13396461,  1.3444221 ,  0.3372816 ],
          [-0.88903946, -0.36091748, -0.41084424,  1.3910713 ,  1.4182491 ,
           -0.68443036, -0.84274894,  1.0029515 ]], dtype=float32),
   'bias': Array([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)},
  'linear2': {'kernel': Array([[ 0.06771415],
          [ 0.3815392 ],
          [-0.44763517],
          [ 0.10989622],
          [-0.12707736],
          [ 0.03953529],
          [-0.51339453],
          [ 0.33707327]], dtype=float32),
   'bias': Array([0.], dtype=float32)}}}

In [None]:

# Scratch location for checkpoints; it lives under the system temp directory
# so wiping it on every run cannot touch anything else.
ckpt_dir = os.path.join(tempfile.gettempdir(), "flax_ckpt")

# Start from a clean slate: remove checkpoints left over from a previous run.
if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir)