In [35]:
import haiku as hk
import jax
import jax.numpy as np


class DQN(hk.Module):
    def __init__(self, hidden_size, n_actions):
        super(DQN, self).__init__()
        self.hidden_size = hidden_size
        self.n_actions = n_actions

    def __call__(self, x):
        # cnn hyperparameters are as described in the METHODS section of
        # Mnih et al., 2015, Human-level control through deep reinforcement learning
        cnn = hk.Sequential([
            hk.Conv2D(32, 8, 4),
            jax.nn.relu,
            hk.Conv2D(64, 4, 2),
            jax.nn.relu,
            hk.Conv2D(64, 3, 1),
            jax.nn.relu,
            hk.Linear(self.hidden_size),
            jax.nn.relu,
            hk.Linear(self.n_actions)
            ])
        return cnn(x)


if __name__ == "__main__":
    model = hk.transform(DQN)
    print("model:", model)
    params = model.init(None, 512, 10)
    print("params:", params)
    x = np.ones((84, 84, 4))
    y = model.apply(params, None, x)
    print("y:", y)
    d = jax.grad(model.apply)(params, None, x)

model: Transformed(init=<function without_state.<locals>.init_fn at 0x7f8dbc339830>, apply=<function without_state.<locals>.apply_fn at 0x7f8dbc3398c0>)
params: frozendict({})
y: DQN(
    hidden_size=None,
    n_actions=DeviceArray([[[1., 1., 1., 1.],
                            [1., 1., 1., 1.],
                            [1., 1., 1., 1.],
                            ...,
                            [1., 1., 1., 1.],
                            [1., 1., 1., 1.],
                            [1., 1., 1., 1.]],
              
                           [[1., 1., 1., 1.],
                            [1., 1., 1., 1.],
                            [1., 1., 1., 1.],
                            ...,
                            [1., 1., 1., 1.],
                            [1., 1., 1., 1.],
                            [1., 1., 1., 1.]],
              
                           [[1., 1., 1., 1.],
                            [1., 1., 1., 1.],
                            [1., 1., 1., 1.],
      

TypeError: __init__() missing 1 required positional argument: 'n_actions'

In [28]:
model.get("hidden_size")

In [22]:
model.keys()

KeysOnlyKeysView([])