In [1]:
import flax
import flax.linen as nn
import jax
import jax.nn.initializers as initializers
import jax.numpy as jnp
from flax.training.train_state import TrainState
import numpy as np

import optax

# Seed the PRNG with a fixed value and split it once: `key` is the carried-on
# key, `test_key` is the one handed to the test network's `init`.
key, test_key = jax.random.split(jax.random.PRNGKey(1234), num=2)

In [2]:
class Test(nn.Module):
    """
    Critic Network

    Chains three Dense layers named D0, D1 and D2 (2, 2 and 6 output
    features); no activation is applied between or after them.
    """
    @nn.compact
    def __call__(self, x):
        # (layer name, output features) for each layer, in application order.
        layer_specs = (("D0", 2), ("D1", 2), ("D2", 6))
        hidden = x
        for layer_name, width in layer_specs:
            hidden = nn.Dense(width, name=layer_name)(hidden)
        return hidden

In [3]:
# One example observation: a single row with four features.
obs = jnp.array(
    [
        [0.0, 0.0, -0.5, 0.5],
    ]
)
# Instance of the network definition used by the cells below.
test = Test()

In [4]:
# Initialize the network variables from an all-ones input shaped like `obs`.
# The whole dict returned by `init` (including its top-level 'params' key) is
# stored as `params`, so later cells index it as test_state.params['params'].
init_variables = test.init(test_key, np.ones_like(obs))
adam = optax.adam(learning_rate=1e-3)
test_state = TrainState.create(apply_fn=test.apply, params=init_variables, tx=adam)

In [5]:
# Display the variables held by the train state: a kernel and a bias for each
# of D0, D1 and D2, nested under a top-level 'params' key.
test_state.params

{'params': {'D0': {'kernel': Array([[ 0.21801375,  0.65828794],
          [-1.05181   , -0.51655525],
          [ 0.7681196 , -0.1240018 ],
          [-0.84719193, -0.1861443 ]], dtype=float32),
   'bias': Array([0., 0.], dtype=float32)},
  'D1': {'kernel': Array([[ 0.08604589, -0.616443  ],
          [ 0.13109617,  0.25108585]], dtype=float32),
   'bias': Array([0., 0.], dtype=float32)},
  'D2': {'kernel': Array([[ 0.21583107, -1.3484992 ,  1.0286711 , -0.76696074,  1.4725155 ,
            0.32232338],
          [-1.0597309 , -0.18715075, -0.51606786, -0.48071578, -0.21511376,
           -0.25362113]], dtype=float32),
   'bias': Array([0., 0., 0., 0., 0., 0.], dtype=float32)}}}

In [6]:
# Run the network on `obs` with the stored variables; this is the reference
# output that the hand-written forward pass further down is compared against.
test.apply(test_state.params, obs)

Array([[-0.53522307,  0.00749008, -0.32858858, -0.17916106, -0.21375245,
        -0.14800559]], dtype=float32)

In [12]:
# Strip the top-level 'params' key to get the per-layer dict (D0, D1, D2),
# each entry holding a 'kernel' and a 'bias'.
weights = test_state.params['params']

In [13]:
# Recompute the forward pass by hand: each Dense layer is an affine map,
# input @ kernel + bias, applied in the order D0 -> D1 -> D2.
x = np.dot(obs, weights['D0']['kernel']) + weights['D0']['bias']
x = np.dot(x, weights['D1']['kernel']) + weights['D1']['bias']
output = np.dot(x, weights['D2']['kernel']) + weights['D2']['bias']

# Check agreement with the module programmatically instead of comparing the
# printed digits by eye. The two paths differ in the last float32 digit, so an
# explicit tolerance is used rather than exact equality.
np.testing.assert_allclose(
    output, test.apply(test_state.params, obs), rtol=1e-5, atol=1e-6
)

output

Array([[-0.535223  ,  0.00749008, -0.32858858, -0.17916106, -0.21375243,
        -0.14800559]], dtype=float32)