In [1]:
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

Haiku modules are Python objects that hold references to their own parameters, other modules, and methods that apply functions on user inputs. 

JAX operates on pure function transformations, Haiku modules cannot be instantiated verbatim. --> modules need to be wrapped into pure function transformations.

Haiku provides a simple function transformation: _hk.transform_. Turns functions that use these object-oriented, functionally “impure” modules into pure functions that can be used with JAX.

In [2]:
# Example class, 1 FC linear layer
class MyLinear1(hk.Module): # inherits from hk.Module

  def __init__(self, output_size, name=None):
    super().__init__(name=name)
    self.output_size = output_size

  def __call__(self, x):
    j, k = x.shape[-1], self.output_size # input size, output size
    w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(j)) # weight initialization
    w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
    b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.ones)
    return jnp.dot(x, w) + b

In [10]:
# wrap example class function
def _forward_fn_linear1(x):
  module = MyLinear1(output_size=2)
  return module(x)

# pass function to hk.transform
forward_linear1 = hk.transform(_forward_fn_linear1)

In [11]:
?? forward_linear1

[0;31mType:[0m        Transformed
[0;31mString form:[0m Transformed(init=<function without_state.<locals>.init_fn at 0x11138cdc0>, apply=<function without_state.<locals>.apply_fn at 0x111417040>)
[0;31mLength:[0m      2
[0;31mFile:[0m        ~/miniconda3/envs/safe_sim2real/lib/python3.9/site-packages/haiku/_src/transform.py
[0;31mSource:[0m     
[0;32mclass[0m [0mTransformed[0m[0;34m([0m[0mNamedTuple[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m  [0;34m"""Holds a pair of pure functions.[0m
[0;34m[0m
[0;34m  Attributes:[0m
[0;34m    init: A pure function: ``params = init(rng, *a, **k)``[0m
[0;34m    apply: A pure function: ``out = apply(params, rng, *a, **k)``[0m
[0;34m  """[0m[0;34m[0m
[0;34m[0m[0;34m[0m
[0;34m[0m  [0;31m# Args: [Optional[PRNGKey], ...][0m[0;34m[0m
[0;34m[0m  [0minit[0m[0;34m:[0m [0mCallable[0m[0;34m[[0m[0;34m...[0m[0;34m,[0m [0mhk[0m[0;34m.[0m[0mParams[0m[0;34m][0m[0;34m[0m
[0;34m[0m[0;34m[0m
[

In [12]:
dummy_x = jnp.array([[1., 2., 3.]])
rng_key = jax.random.PRNGKey(42)

params = forward_linear1.init(rng=rng_key, x=dummy_x)
print(params)

{'my_linear1': {'w': Array([[-0.30350363,  0.5123802 ],
       [ 0.08009142, -0.3163005 ],
       [ 0.6056666 ,  0.5820702 ]], dtype=float32), 'b': Array([1., 1.], dtype=float32)}}


In [13]:
sample_x = jnp.array([[1., 2., 3.]])
sample_x_2 = jnp.array([[4., 5., 6.], [7., 8., 9.]])

output_1 = forward_linear1.apply(params=params, x=sample_x, rng=rng_key)
# Outputs are identical for given inputs since the forward inference is non-stochastic.
output_2 = forward_linear1.apply(params=params, x=sample_x, rng=rng_key)

output_3 = forward_linear1.apply(params=params, x=sample_x_2, rng=rng_key)

print(f'Output 1 : {output_1}')
print(f'Output 2 (same as output 1): {output_2}')
print(f'Output 3 : {output_3}')

Output 1 : [[2.6736789 2.6259897]]
Output 2 (same as output 1): [[2.6736789 2.6259897]]
Output 3 : [[3.820442 4.960439]
 [4.967205 7.294889]]


Inference without random key

The module that we built is inherently non-stochastic. In that case, passing a random key to the apply method seems redundant. Haiku offers another transformation hk.without_apply_rng which can be further wrapped around our hk.transform method.

In [20]:
# if deterministic module, wrap transform object with non-probabilistic
forward_without_rng = hk.without_apply_rng(hk.transform(_forward_fn_linear1))

params = forward_without_rng.init(rng=rng_key, x=sample_x)
output = forward_without_rng.apply(x=sample_x, params=params)

print(f'Output without random key in forward pass \n {output_1}')

Output without random key in forward pass 
 [[2.6736789 2.6259897]]


In [24]:
# Change the parameters as you would when performing a gradient update
mutated_params = jax.tree_util.tree_map(lambda x: x+1., params)
print(f'Mutated params \n : {mutated_params}')
mutated_output = forward_without_rng.apply(x=sample_x, params=mutated_params)
print(f'Output with mutated params \n {mutated_output}')

Mutated params 
 : {'my_linear1': {'b': Array([2., 2.], dtype=float32), 'w': Array([[0.69649637, 1.5123801 ],
       [1.0800915 , 0.6836995 ],
       [1.6056666 , 1.5820701 ]], dtype=float32)}}
Output with mutated params 
 [[9.673679 9.62599 ]]


### Stateful inference in Haiku

For some modules you might want to maintain and carry over the internal state across function calls. Here, we demonstrate a simple example, where we declare a state variable counter within our Haiku transformation which gets updated on each call to the function. Note that we didn’t explicitly instantiate this as a Haiku module (the same could be replicated as a hk module as shown earlier).

In [25]:
def stateful_f(x):
  counter = hk.get_state("counter", shape=[], dtype=jnp.int32, init=jnp.ones)
  multiplier = hk.get_parameter('multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones)
  hk.set_state("counter", counter + 1)
  output = x + multiplier * counter
  return output

stateful_forward = hk.without_apply_rng(hk.transform_with_state(stateful_f))
sample_x = jnp.array([[5., ]])
params, state = stateful_forward.init(x=sample_x, rng=rng_key)
print(f'Initial params:\n{params}\nInitial state:\n{state}')
print('##########')
for i in range(3):
  output, state = stateful_forward.apply(params, state, x=sample_x)
  print(f'After {i+1} iterations:\nOutput: {output}\nState: {state}')
  print('##########')


Initial params:
{'~': {'multiplier': Array([1.], dtype=float32)}}
Initial state:
{'~': {'counter': Array(1, dtype=int32)}}
##########
After 1 iterations:
Output: [[6.]]
State: {'~': {'counter': Array(2, dtype=int32)}}
##########
After 2 iterations:
Output: [[7.]]
State: {'~': {'counter': Array(3, dtype=int32)}}
##########
After 3 iterations:
Output: [[8.]]
State: {'~': {'counter': Array(4, dtype=int32)}}
##########


In [26]:
# See: https://dm-haiku.readthedocs.io/en/latest/api.html#common-modules

class MyModuleCustom(hk.Module):
  def __init__(self, output_size=2, name='custom_linear'):
    super().__init__(name=name)
    self._internal_linear_1 = hk.nets.MLP(output_sizes=[2, 3], name='hk_internal_linear')
    self._internal_linear_2 = MyLinear1(output_size=output_size, name='old_linear')

  def __call__(self, x):
    return self._internal_linear_2(self._internal_linear_1(x))

def _custom_forward_fn(x):
  module = MyModuleCustom()
  return module(x)

custom_forward_without_rng = hk.without_apply_rng(hk.transform(_custom_forward_fn))
params = custom_forward_without_rng.init(rng=rng_key, x=sample_x)
params

{'custom_linear/~/hk_internal_linear/~/linear_0': {'w': Array([[ 1.51595   , -0.23353337]], dtype=float32),
  'b': Array([0., 0.], dtype=float32)},
 'custom_linear/~/hk_internal_linear/~/linear_1': {'w': Array([[-0.22075887, -0.27375957,  0.5931483 ],
         [ 0.7818068 ,  0.72626334, -0.6860752 ]], dtype=float32),
  'b': Array([0., 0., 0.], dtype=float32)},
 'custom_linear/~/old_linear': {'w': Array([[ 0.28584382,  0.31626168],
         [ 0.2335775 , -0.4827032 ],
         [-0.14647584, -0.7185701 ]], dtype=float32),
  'b': Array([1., 1.], dtype=float32)}}

In [34]:
class HkRandom2(hk.Module):
  def __init__(self, rate=0.5):
    super().__init__()
    self.rate = rate

  def __call__(self, x):
    key1 = hk.next_rng_key()
    return jax.random.bernoulli(key1, 1.0 - self.rate, shape=x.shape)


class HkRandomNest(hk.Module):
  def __init__(self, rate=0.5):
    super().__init__()
    self.rate = rate
    self._another_random_module = HkRandom2()

  def __call__(self, x):
    key2 = hk.next_rng_key()
    p1 = self._another_random_module(x)
    p2 = jax.random.bernoulli(key2, 1.0 - self.rate, shape=x.shape)
    print(f'Bernoullis are  : {p1, p2}')

# Note that the modules that are stochastic cannot be wrapped with hk.without_apply_rng()
forward = hk.transform(lambda x: HkRandomNest()(x))

x = jnp.array(1.)
print("INIT:")
params = forward.init(rng_key, x=x)
print("APPLY:")
prediction = forward.apply(params, x=x, rng=rng_key)

INIT:
Bernoullis are  : (Array(True, dtype=bool), Array(False, dtype=bool))
APPLY:
Bernoullis are  : (Array(True, dtype=bool), Array(False, dtype=bool))


In [35]:
for _ in range(3):
  forward.apply(params, x=x, rng=rng_key)

Bernoullis are  : (Array(True, dtype=bool), Array(False, dtype=bool))
Bernoullis are  : (Array(True, dtype=bool), Array(False, dtype=bool))
Bernoullis are  : (Array(True, dtype=bool), Array(False, dtype=bool))


In [36]:
for _ in range(3):
  rng_key, apply_rng_key = jax.random.split(rng_key)
  forward.apply(params, x=x, rng=apply_rng_key)

Bernoullis are  : (Array(False, dtype=bool), Array(False, dtype=bool))
Bernoullis are  : (Array(True, dtype=bool), Array(False, dtype=bool))
Bernoullis are  : (Array(False, dtype=bool), Array(False, dtype=bool))


In [38]:
rng_sequence = hk.PRNGSequence(rng_key)
for _ in range(3):
  forward.apply(params, x=x, rng=next(rng_sequence))


Bernoullis are  : (Array(False, dtype=bool), Array(True, dtype=bool))
Bernoullis are  : (Array(False, dtype=bool), Array(False, dtype=bool))
Bernoullis are  : (Array(False, dtype=bool), Array(True, dtype=bool))


In [39]:
??jax.vmap

[0;31mSignature:[0m
[0mjax[0m[0;34m.[0m[0mvmap[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mfun[0m[0;34m:[0m [0;34m'F'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0min_axes[0m[0;34m:[0m [0;34m'Union[int, Sequence[Any]]'[0m [0;34m=[0m [0;36m0[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mout_axes[0m[0;34m:[0m [0;34m'Any'[0m [0;34m=[0m [0;36m0[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0maxis_name[0m[0;34m:[0m [0;34m'Optional[Hashable]'[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0maxis_size[0m[0;34m:[0m [0;34m'Optional[int]'[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mspmd_axis_name[0m[0;34m:[0m [0;34m'Optional[Hashable]'[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m [0;34m->[0m [0;34m'F'[0m[0;34m[0m[0;34m[0m[0m
[0;31mSource:[0m   
[0;32mdef[0m [0mvmap[0m[0;34m([0m[0mfun[0m[0;34m:[0m [0mF[0m[0;34m,[0m[0;34m[0m
[0;34m[0m      