In [6]:
# https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html


import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

In [4]:
class MyLinear1(hk.Module):
  """Fully connected layer computing ``jnp.dot(x, w) + b``.

  ``w`` has shape ``[in_features, output_size]`` (``in_features`` is the size
  of the last axis of ``x``) and is initialised from a truncated normal with
  stddev ``1 / sqrt(in_features)``. ``b`` has shape ``[output_size]`` and is
  initialised to ones. Both parameters use ``x.dtype``.
  """

  def __init__(self, output_size, name=None):
    super().__init__(name=name)
    # Size of the last axis of the layer's output.
    self.output_size = output_size

  def __call__(self, x):
    in_features = x.shape[-1]
    out_features = self.output_size

    # "w" is requested before "b"; keep this order so the initial values drawn
    # from the RNG key stay the same.
    weight = hk.get_parameter(
        "w",
        shape=[in_features, out_features],
        dtype=x.dtype,
        init=hk.initializers.TruncatedNormal(1. / np.sqrt(in_features)),
    )
    bias = hk.get_parameter(
        "b", shape=[out_features], dtype=x.dtype, init=jnp.ones)

    return jnp.dot(x, weight) + bias

In [7]:
def _forward_fn_linear1(x):
  """Forward pass: apply a freshly constructed 2-unit ``MyLinear1`` to ``x``.

  The module is constructed inside this function so that it is created while
  running under ``hk.transform``.
  """
  return MyLinear1(output_size=2)(x)

# Wrap the function so it exposes separate `init` and `apply` functions.
forward_linear1 = hk.transform(_forward_fn_linear1)

forward_linear1

Transformed(init=<function without_state.<locals>.init_fn at 0x7f04ce069b80>, apply=<function without_state.<locals>.apply_fn at 0x7f04ce069c10>)

In [8]:
# A single example row; `init` reads its last dimension and dtype to size
# MyLinear1's parameters.
dummy_x = jnp.array([[1., 2., 3.]])
rng_key = jax.random.PRNGKey(42)

# Create the parameters (w: 3x2, b: 2) from the RNG key and the dummy input.
params = forward_linear1.init(rng_key, dummy_x)
print(params)



{'my_linear1': {'w': DeviceArray([[-0.30350363,  0.5123802 ],
             [ 0.08009142, -0.3163005 ],
             [ 0.6056666 ,  0.5820702 ]], dtype=float32), 'b': DeviceArray([1., 1.], dtype=float32)}}


In [9]:
# One row identical to the init input, plus a batch of two unseen rows.
sample_x = jnp.array([[1., 2., 3.]])
sample_x_2 = jnp.array([[4., 5., 6.], [7., 8., 9.]])

output_1 = forward_linear1.apply(params=params, x=sample_x, rng=rng_key)
# Outputs are identical for given inputs since the forward inference is non-stochastic.
output_2 = forward_linear1.apply(params=params, x=sample_x, rng=rng_key)

output_3 = forward_linear1.apply(params=params, x=sample_x_2, rng=rng_key)

# Verify the claims above instead of relying on eyeballing the printed values:
# repeated calls agree exactly, and each input row yields one output row of
# the same width as output_1.
np.testing.assert_array_equal(output_1, output_2)
assert output_3.shape == (sample_x_2.shape[0], output_1.shape[-1]), output_3.shape

print(f'Output 1 : {output_1}')
print(f'Output 2 (same as output 1): {output_2}')
print(f'Output 3 : {output_3}')

Output 1 : [[2.6736789 2.6259897]]
Output 2 (same as output 1): [[2.6736789 2.6259897]]
Output 3 : [[3.820442 4.960439]
 [4.967205 7.294889]]


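## Built-in Haiku nets and nested modules

Modules can contain other modules, including Haiku's built-in ones such as `hk.nets.MLP`; the names of the nested modules become prefixes of the resulting parameter paths.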

In [10]:
# See: https://dm-haiku.readthedocs.io/en/latest/api.html#common-modules

class MyModuleCustom(hk.Module):
  """An ``hk.nets.MLP`` whose output is fed into a ``MyLinear1`` head.

  Args:
    output_size: Width of the ``MyLinear1`` output layer.
    name: Haiku module name; it prefixes the parameter paths of this module
      and of its submodules (e.g. ``custom_linear/~/old_linear``).
    hidden_sizes: Output widths of the MLP's layers, in order. Defaults to
      ``(2, 3)``, the sizes previously hard-coded here, so existing callers
      get the same parameters.
  """

  def __init__(self, output_size=2, name='custom_linear', hidden_sizes=(2, 3)):
    super().__init__(name=name)
    # Both submodules are built here in __init__ (MLP first, then the head);
    # the parameter paths printed for this model show them nested under this
    # module's name. Passing a list keeps the argument type the MLP received
    # before this was parameterised.
    self._internal_linear_1 = hk.nets.MLP(
        output_sizes=list(hidden_sizes), name='hk_internal_linear')
    self._internal_linear_2 = MyLinear1(output_size=output_size, name='old_linear')

  def __call__(self, x):
    """Runs ``x`` through the MLP and then through the ``MyLinear1`` head."""
    hidden = self._internal_linear_1(x)
    return self._internal_linear_2(hidden)

def _custom_forward_fn(x):
  """Forward pass through a default-configured ``MyModuleCustom``."""
  return MyModuleCustom()(x)

# `without_apply_rng` drops the `rng` argument from `apply`; `init` still takes
# a key, which is used to sample the initial weights.
custom_forward_without_rng = hk.without_apply_rng(hk.transform(_custom_forward_fn))

# NOTE(review): this rebinds `params`, replacing the MyLinear1 parameters from
# the earlier init cell. Earlier `forward_linear1.apply(params=params, ...)`
# cells would no longer find their 'my_linear1' entries if re-run after this.
params = custom_forward_without_rng.init(rng_key, sample_x)
params

{'custom_linear/~/hk_internal_linear/~/linear_0': {'w': DeviceArray([[-0.30350363,  0.5123802 ],
               [ 0.08009142, -0.3163005 ],
               [ 0.6056666 ,  0.5820702 ]], dtype=float32),
  'b': DeviceArray([0., 0.], dtype=float32)},
 'custom_linear/~/hk_internal_linear/~/linear_1': {'w': DeviceArray([[-0.22075887, -0.27375957,  0.5931483 ],
               [ 0.7818068 ,  0.72626334, -0.6860752 ]], dtype=float32),
  'b': DeviceArray([0., 0., 0.], dtype=float32)},
 'custom_linear/~/old_linear': {'w': DeviceArray([[ 0.28584382,  0.31626168],
               [ 0.2335775 , -0.4827032 ],
               [-0.14647584, -0.7185701 ]], dtype=float32),
  'b': DeviceArray([1., 1.], dtype=float32)}}