**From Haiku documentation**

In [1]:
pip install git+https://github.com/deepmind/dm-haiku

Collecting git+https://github.com/deepmind/dm-haiku
  Cloning https://github.com/deepmind/dm-haiku to /tmp/pip-req-build-bb0j1585
  Running command git clone --filter=blob:none --quiet https://github.com/deepmind/dm-haiku /tmp/pip-req-build-bb0j1585
  Resolved https://github.com/deepmind/dm-haiku to commit a7b7e73dae840153ecd828e97a64b6a875b168f7
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting jmp>=0.0.2 (from dm-haiku==0.0.13.dev0)
  Downloading jmp-0.0.4-py3-none-any.whl (18 kB)
Building wheels for collected packages: dm-haiku
  Building wheel for dm-haiku (setup.py) ... [?25l[?25hdone
  Created wheel for dm-haiku: filename=dm_haiku-0.0.13.dev0-py3-none-any.whl size=373915 sha256=5717c5aab50a0c52ea8e142bc7faa4bce03fa180efde8b8ccf70c60d47911f47
  Stored in directory: /tmp/pip-ephem-wheel-cache-jtz30r4p/wheels/b1/df/f1/a357fa8f00c36052bdae1e1fd363650c0bd1e8c3959487b6fb
Successfully built dm-haiku
Installing collected packages: jmp, dm-haiku
Successfully installed dm-h

# Haiku Basics

In [2]:
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

## A first example with `hk.transform`

A linear module whose weights and biases use custom initializers: a truncated normal scaled by `1/sqrt(fan_in)` for the weights and all ones for the biases.

### Define module

In [3]:
class MyLinear1(hk.Module):
    """Affine layer ``y = x @ w + b`` with hand-picked initializers.

    The weight matrix is initialized with ``hk.initializers.TruncatedNormal``
    scaled by ``1 / sqrt(fan_in)``; the bias starts as all ones.

    Args:
        output_size: number of output features.
        name: optional Haiku module name; parameters are grouped under it
            (e.g. ``my_linear1`` or ``custom_linear/~/old_linear``).
    """

    def __init__(self, output_size, name=None):
        super().__init__(name=name)
        self.output_size = output_size

    def __call__(self, x):
        fan_in = x.shape[-1]
        fan_out = self.output_size
        weight = hk.get_parameter(
            "w",
            shape=[fan_in, fan_out],
            dtype=x.dtype,
            init=hk.initializers.TruncatedNormal(1. / np.sqrt(fan_in)),
        )
        bias = hk.get_parameter("b", shape=[fan_out], dtype=x.dtype, init=jnp.ones)
        return jnp.dot(x, weight) + bias

### Pure function transformation

In [4]:
def _forward_fn_linear1(x):
    """Build a 2-output ``MyLinear1`` and apply it to ``x`` (meant for ``hk.transform``)."""
    return MyLinear1(output_size=2)(x)

# Turn the impure forward function into a pair of pure functions (init, apply).
forward_linear1 = hk.transform(_forward_fn_linear1)

`hk.transform` returns a `Transformed` object holding two pure functions, `init` and `apply`: `init` initializes the parameters and `apply` runs forward inference on the module.


In [5]:
forward_linear1  # repr lists the `init` and `apply` functions produced by hk.transform

Transformed(init=<function without_state.<locals>.init_fn at 0x7cf8241d9cf0>, apply=<function without_state.<locals>.apply_fn at 0x7cf8241d9c60>)

#### `init` fn
Calling the `init` function initializes the network's parameters and returns them. It takes a `jax.random.PRNGKey` and a sample input (usually just dummy values that tell the network the expected input shapes).


In [8]:
# Fixed seed so the initial parameters are reproducible.
rng_key = jax.random.PRNGKey(42)
# One example with 3 features; init only needs it to infer parameter shapes.
dummy_x = jnp.array([[1., 2., 3.]])

params = forward_linear1.init(rng=rng_key, x=dummy_x)
print(params)

{'my_linear1': {'w': Array([[-0.30350363,  0.5123802 ],
       [ 0.08009142, -0.3163005 ],
       [ 0.6056666 ,  0.5820702 ]], dtype=float32), 'b': Array([1., 1.], dtype=float32)}}
(1, 3)


#### `apply` fn
Use the params to apply the forward function to some inputs.

In [11]:
# A single-row batch and a two-row batch, both with 3 features.
sample_x = jnp.array([[1., 2., 3.]])
sample_x_2 = jnp.array([[4., 5., 6.],
                        [7., 8., 9.]])

# MyLinear1 draws no randomness in its forward pass, so the two calls on
# `sample_x` with the same params produce identical outputs.
output1, output2, output3 = (
    forward_linear1.apply(params=params, x=inputs, rng=rng_key)
    for inputs in (sample_x, sample_x, sample_x_2)
)

print(output1)
print(output2)  # same as output1
print(output3)

[[2.6736789 2.6259897]]
[[2.6736789 2.6259897]]
[[3.820442 4.960439]
 [4.967205 7.294889]]


### Inference without random key
In some cases the module we built is inherently non-stochastic, so passing a random key to the `apply` method is redundant. Haiku offers another transformation, `hk.without_apply_rng`, which wraps the result of `hk.transform` and removes the `rng` argument from `apply`.

In [12]:
# `hk.without_apply_rng` drops the `rng` argument from `apply`; `init` still
# takes a key because it samples the initial parameters.
forward_without_rng = hk.without_apply_rng(
    hk.transform(_forward_fn_linear1)
)

params = forward_without_rng.init(rng=rng_key, x=sample_x)
output = forward_without_rng.apply(params=params, x=sample_x)  # no rng needed

print(output)

[[2.6736789 2.6259897]]


We can also change the parameters (building a new, modified params tree) and run forward inference again to get a different output for the same inputs. This is exactly what happens when gradient descent updates the parameters during training.

In [21]:
# Add 1 to every leaf of the params pytree; tree_map builds a new tree.
mutated_params = jax.tree_util.tree_map(lambda leaf: leaf + 1., params)
print(f'Mutated params \n : {mutated_params}')

mutated_output = forward_without_rng.apply(params=mutated_params, x=sample_x)
print(f'Output with mutated params \n {mutated_output}')

Mutated params 
 : {'my_linear1': {'b': Array([2., 2.], dtype=float32), 'w': Array([[0.69649637, 1.5123801 ],
       [1.0800915 , 0.6836995 ],
       [1.6056666 , 1.5820701 ]], dtype=float32)}}
Output with mutated params 
 [[9.673679 9.62599 ]]


## Stateful Inference in Haiku
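For some modules you might want to maintain internal state and carry it over across function calls. `hk.transform_with_state` supports this: state is read with `hk.get_state` and written with `hk.set_state`. `init` returns `(params, state)`, and `apply` takes the current state and returns `(output, new_state)`.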

In [22]:
def stateful_f(x):
    """Return ``x + multiplier * counter`` and increment the ``counter`` state.

    ``counter`` is Haiku *state* (int32 scalar, initialized to 1) and
    ``multiplier`` is a *parameter* of shape [1] initialized to 1. The output
    uses the counter value read at the start of the call, before the increment.
    """
    counter = hk.get_state("counter", shape=[], dtype=jnp.int32, init=jnp.ones)
    multiplier = hk.get_parameter("multiplier", shape=[1], dtype=x.dtype, init=jnp.ones)
    hk.set_state("counter", counter + 1)
    return x + multiplier * counter

stateful_forward = hk.without_apply_rng(hk.transform_with_state(stateful_f))
sample_x = jnp.array([[5., ]])
params, state = stateful_forward.init(x=sample_x, rng=rng_key)

print(f'Initial params:\n{params}\nInitial state:\n{state}')
print('##########')
for step in range(1, 4):
    # Feed the returned state back in so the counter keeps growing.
    output, state = stateful_forward.apply(params, state, x=sample_x)
    print(f'After {step} iterations:\nOutput: {output}\nState: {state}')
    print('##########')

Initial params:
{'~': {'multiplier': Array([1.], dtype=float32)}}
Initial state:
{'~': {'counter': Array(1, dtype=int32)}}
##########
After 1 iterations:
Output: [[6.]]
State: {'~': {'counter': Array(2, dtype=int32)}}
##########
After 2 iterations:
Output: [[7.]]
State: {'~': {'counter': Array(3, dtype=int32)}}
##########
After 3 iterations:
Output: [[8.]]
State: {'~': {'counter': Array(4, dtype=int32)}}
##########


## Built-in Haiku nets and nested modules

In [23]:
class MyModuleCustom(hk.Module):
    """Two-stage network: a built-in ``hk.nets.MLP`` followed by ``MyLinear1``.

    The submodules are created in ``__init__`` under this module, so their
    parameters are nested under its name (e.g. ``custom_linear/~/old_linear``).

    Args:
        output_size: number of outputs of the final ``MyLinear1`` layer.
        name: Haiku module name, ``'custom_linear'`` by default.
    """

    def __init__(self, output_size=2, name='custom_linear'):
        super().__init__(name=name)
        # MLP with layer sizes 2 then 3; its output feeds the MyLinear1 head.
        self._mlp = hk.nets.MLP(output_sizes=[2, 3], name='hk_internal_linear')
        self._head = MyLinear1(output_size=output_size, name='old_linear')

    def __call__(self, x):
        hidden = self._mlp(x)
        return self._head(hidden)

def _custom_forward_fn(x):
    """Build a default ``MyModuleCustom`` and run it on ``x`` (for ``hk.transform``)."""
    return MyModuleCustom()(x)

In [27]:
custom_forward_without_rng = hk.without_apply_rng(hk.transform(_custom_forward_fn))
# sample_x here is the (1, 1) array from the stateful example, hence w: (1, 2) in linear_0.
params = custom_forward_without_rng.init(rng=rng_key, x=sample_x)

# Print only the parameter shapes, then display the full parameter values.
param_shapes = jax.tree.map(lambda leaf: leaf.shape, params)
print(param_shapes)
params

{'custom_linear/~/hk_internal_linear/~/linear_0': {'b': (2,), 'w': (1, 2)}, 'custom_linear/~/hk_internal_linear/~/linear_1': {'b': (3,), 'w': (2, 3)}, 'custom_linear/~/old_linear': {'b': (2,), 'w': (3, 2)}}


{'custom_linear/~/hk_internal_linear/~/linear_0': {'w': Array([[ 1.51595   , -0.23353337]], dtype=float32),
  'b': Array([0., 0.], dtype=float32)},
 'custom_linear/~/hk_internal_linear/~/linear_1': {'w': Array([[-0.22075887, -0.27375957,  0.5931483 ],
         [ 0.7818068 ,  0.72626334, -0.6860752 ]], dtype=float32),
  'b': Array([0., 0., 0.], dtype=float32)},
 'custom_linear/~/old_linear': {'w': Array([[ 0.28584382,  0.31626168],
         [ 0.2335775 , -0.4827032 ],
         [-0.14647584, -0.7185701 ]], dtype=float32),
  'b': Array([1., 1.], dtype=float32)}}

## RNG Keys with `hk.next_rng_key()`

In [28]:
class HkRandom2(hk.Module):
    """Sample a boolean array shaped like the input using ``hk.next_rng_key()``.

    Each entry is ``True`` with probability ``1 - rate``.

    Args:
        rate: probability that an entry is ``False``.
    """

    def __init__(self, rate=0.5):
        super().__init__()
        self.rate = rate

    def __call__(self, x):
        true_prob = 1.0 - self.rate
        return jax.random.bernoulli(hk.next_rng_key(), true_prob, shape=x.shape)


class HkRandomNest(hk.Module):
    """Sample two boolean arrays: one via a nested ``HkRandom2``, one directly.

    Illustrates drawing keys with ``hk.next_rng_key()`` both in this module and
    inside a nested module.

    Args:
        rate: probability that an entry is ``False``; also passed to the
            nested module so both samples use the same rate.
    """

    def __init__(self, rate=0.5):
        super().__init__()
        self.rate = rate
        # Forward `rate` so the nested module does not silently use its own default.
        self._another_random_module = HkRandom2(rate=rate)

    def __call__(self, x):
        """Print and return ``(p1, p2)``, two boolean arrays shaped like ``x``."""
        # Draw this module's key before calling the child, keeping the original
        # key order (and therefore the sampled values) unchanged.
        key2 = hk.next_rng_key()
        p1 = self._another_random_module(x)
        p2 = jax.random.bernoulli(key2, 1.0 - self.rate, shape=x.shape)

        print(f"Bernoullis are: {p1, p2}")
        # Return the samples so callers (e.g. `prediction = forward.apply(...)`) can use them.
        return p1, p2

# Stochastic modules need an rng in `apply` (for hk.next_rng_key), so they must
# not be wrapped with hk.without_apply_rng().
def _forward_fn_random(x):
    return HkRandomNest()(x)

forward = hk.transform(_forward_fn_random)

x = jnp.array(1.)  # scalar input, so each sample is a scalar bool
print("INIT:")
params = forward.init(rng_key, x=x)  # init also runs __call__, hence the print below
print("APPLY:")
prediction = forward.apply(params, x=x, rng=rng_key)

INIT:
Bernoullis are: (Array(True, dtype=bool), Array(False, dtype=bool))
APPLY:
Bernoullis are: (Array(True, dtype=bool), Array(False, dtype=bool))


**Note: passing the same random key to multiple calls of `apply` produces exactly the same "random" results every time!**

In [32]:
# Reusing one key: every call reproduces the same samples.
for _repeat in range(3):
    forward.apply(params, x=x, rng=rng_key)

Bernoullis are: (Array(True, dtype=bool), Array(False, dtype=bool))
Bernoullis are: (Array(True, dtype=bool), Array(False, dtype=bool))
Bernoullis are: (Array(True, dtype=bool), Array(False, dtype=bool))


In [33]:
# solution 1: split off a fresh subkey for each call and carry the other half forward.
for _call in range(3):
    rng_key, apply_rng_key = jax.random.split(rng_key)
    forward.apply(params, x=x, rng=apply_rng_key)

Bernoullis are: (Array(False, dtype=bool), Array(False, dtype=bool))
Bernoullis are: (Array(True, dtype=bool), Array(False, dtype=bool))
Bernoullis are: (Array(False, dtype=bool), Array(False, dtype=bool))


In [34]:
# solution 2: hk.PRNGSequence hands out a new key on every next() call.
rng_sequence = hk.PRNGSequence(rng_key)
for apply_key in (next(rng_sequence) for _ in range(3)):
    forward.apply(params, x=x, rng=apply_key)

Bernoullis are: (Array(False, dtype=bool), Array(True, dtype=bool))
Bernoullis are: (Array(False, dtype=bool), Array(False, dtype=bool))
Bernoullis are: (Array(False, dtype=bool), Array(True, dtype=bool))


# Limitations of Nesting JAX Functions and Haiku Modules

Using a JAX transform inside `hk.transform` usually means transforming a side-effecting function, which results in an `UnexpectedTracerError`. The sections below describe two ways around this: `hk.lift`, and Haiku's own versions of JAX transforms (e.g. `hk.eval_shape`).

Once a Haiku network has been transformed to a pair of pure functions using hk.transform, it’s possible to freely combine these with any JAX transformations like `jax.jit`, `jax.grad`, `jax.lax.scan` and so on.

In [38]:
# example: jax.eval_shape applied to a function that calls hk.get_parameter
def net(x): # inside of a hk.transform, this is still side-effecting
    """Multiply ``x`` by a (2, 2) parameter ``w`` initialized to ones."""
    w = hk.get_parameter("w", (2, 2), init=jnp.ones)

    return w @ x

def eval_shape_net(x):
    """Run ``jax.eval_shape`` directly on the side-effecting ``net`` (expected to fail)."""
    output_shape = jax.eval_shape(net, x)   # eval_shape on side-effecting function
    return net(x)

init, _ = hk.transform(eval_shape_net)

try:
    init(jax.random.PRNGKey(0), jnp.ones((2, 2)))
except jax.errors.UnexpectedTracerError as e:
    print(e)
else:
    # Make a silent success visible instead of printing nothing.
    print("No UnexpectedTracerError was raised.")

An UnexpectedTracerError was raised while inside a Haiku transformed function (see error above).
Hint: are you using a JAX transform or JAX control-flow function (jax.vmap/jax.lax.scan/...) inside a Haiku transform? You might want to use the Haiku version of the transform instead (hk.vmap/hk.scan/...).
See https://dm-haiku.readthedocs.io/en/latest/notebooks/transforms.html on why you can't use JAX transforms inside a Haiku module.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError


In [43]:
# solution: create the parameter in the transformed function and pass it into a
# pure `net`, so jax.eval_shape only ever sees a side-effect-free function.
def net(w, x):
    """Pure matrix product ``w @ x``."""
    return w @ x

def eval_shape_net(x):
    """Create ``w`` with Haiku, then call ``jax.eval_shape`` on the pure ``net``."""
    w = hk.get_parameter("w", (3, 2), init=jnp.ones)
    output_shape = jax.eval_shape(net, w, x)  # fine: net has no Haiku side effects
    return net(w, x)

init, apply = hk.transform(eval_shape_net)
key = jax.random.PRNGKey(0)
x = jnp.ones((2, 3))
params = init(key, x=x)
apply(params, key, x)

Array([[2., 2., 2.],
       [2., 2., 2.],
       [2., 2., 2.]], dtype=float32)

In [45]:
# however: here the Haiku module itself is passed to jax.eval_shape
# (expected to raise UnexpectedTracerError).
def eval_shape_net(x):
    """Call ``jax.eval_shape`` on an ``hk.nets.MLP`` instance, then run the MLP."""
    net = hk.nets.MLP([300, 100])
    output_shape = jax.eval_shape(net, x)
    return output_shape, net(x)

init, _ = hk.transform(eval_shape_net)

try:
    init(jax.random.PRNGKey(0), jnp.ones((2, 2)))
except jax.errors.UnexpectedTracerError as e:
    print(e)
else:
    # Report a silent success explicitly so the reader is not left guessing
    # whether the expected error occurred.
    print("No UnexpectedTracerError was raised.")

## Using `hk.lift`



In [47]:
def eval_shape_net(x):
    """Get the MLP output shape via ``hk.lift`` plus a pure apply, then run the MLP.

    Returns ``(out, output_shape)``. Two separate parameter sets end up
    registered: the lifted copy under ``inner/...`` and the directly called
    ``net`` under ``mlp/...``.
    """
    net = hk.nets.MLP([300, 100])                 # still side-effecting
    inner_init, inner_apply = hk.transform(net)   # nested transform -> pure functions

    # Register the nested params in the outer module scope under the name "inner".
    inner_params = hk.lift(inner_init, name="inner")(hk.next_rng_key(), x)

    # inner_apply is a functionally pure function, so JAX may transform it.
    output_shape = jax.eval_shape(inner_apply, inner_params, hk.next_rng_key(), x)
    out = net(x)
    return out, output_shape

x = jnp.ones((100, 100))
key = jax.random.PRNGKey(0)

init, apply = hk.transform(eval_shape_net)
params = init(key, x=x)
# Run apply once to check it works; its (out, output_shape) result is discarded.
apply(params, key, x)

# Show only the parameter shapes: both "inner/..." and plain "mlp/..." entries appear.
jax.tree.map(lambda leaf: leaf.shape, params)

{'inner/mlp/~/linear_0': {'b': (300,), 'w': (100, 300)},
 'inner/mlp/~/linear_1': {'b': (100,), 'w': (300, 100)},
 'mlp/~/linear_0': {'b': (300,), 'w': (100, 300)},
 'mlp/~/linear_1': {'b': (100,), 'w': (300, 100)}}

## Using Haiku versions of JAX transforms

In [48]:
def eval_shape_net(x):
    """Compute the MLP output shape with Haiku's ``hk.eval_shape``, then run it.

    Returns ``(out, output_shape)``.
    """
    mlp = hk.nets.MLP([300, 100])
    output_shape = hk.eval_shape(mlp, x)  # Haiku-aware counterpart of jax.eval_shape
    return mlp(x), output_shape

init, apply = hk.transform(eval_shape_net)
mlp_key = jax.random.PRNGKey(777)
mlp_inputs = jnp.ones((100, 100))
params = init(mlp_key, mlp_inputs)
out = apply(params, mlp_key, mlp_inputs)  # (mlp output, ShapeDtypeStruct)

In [50]:
out[-1]  # second element of (out, output_shape): the ShapeDtypeStruct from hk.eval_shape

ShapeDtypeStruct(shape=(100, 100), dtype=float32)

In [60]:
def test_get_parameter_fn():
    """Request parameter "w" twice and report whether both calls return the same object.

    Prints "w1 is w2" when they do, and always returns 0.
    """
    w1 = hk.get_parameter("w", [], init=jnp.zeros)
    w2 = hk.get_parameter("w", [], init=jnp.zeros)  # same name -> looked up again

    if w1 is w2:
        print("w1 is w2")
    return 0

test_get_parameter = hk.transform(test_get_parameter_fn)
params = test_get_parameter.init(jax.random.PRNGKey(0))  # init runs the fn, so it prints once
print(jax.tree.map(lambda leaf: leaf.shape, params))      # a single scalar "w"
test_get_parameter.apply(params=params, rng=jax.random.PRNGKey(0))

w1 is w2
{'~': {'w': ()}}
w1 is w2


0

# Flax Interop

## Flax inside Haiku
```python
haiku.experimental.flax.lift(mod, *, name)
```

Lifts a flax `nn.Module` into a Haiku transformed function.

For a Flax Module (e.g. `mod = nn.Dense(10)`), `mod = lift(mod)` allows you to run the call method of the module as if the module was a regular Haiku module.

Parameters and state from the Flax module are registered with Haiku and become part of the params/state dictionaries (as returned from `init`/`apply`).


In [95]:
import flax
from transformers import BertConfig, BertModel, FlaxAutoModel, AutoTokenizer

In [96]:
# Tokenizer matching the "bert-base-uncased" checkpoint used by the model cells below.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [97]:
def f(tokens, pretrained_path="bert-base-uncased"):
    """Run a lifted Flax BERT on tokenizer outputs inside a Haiku transform.

    Args:
        tokens: mapping of model inputs as produced by the tokenizer
            (``input_ids``, ``token_type_ids``, ``attention_mask``).
        pretrained_path: checkpoint name or path for ``FlaxAutoModel.from_pretrained``.

    Returns:
        The outputs of the lifted model.
    """
    # Create and "lift" a Flax module
    # NOTE(review): hk.experimental.flax.lift is documented for flax nn.Module
    # instances; FlaxAutoModel.from_pretrained returns a transformers model
    # wrapper (its flax module is presumably `.module`) -- confirm this lift works.
    pretrained_model = hk.experimental.flax.lift(FlaxAutoModel.from_pretrained(pretrained_path), name='bert')
    # `tokens` is a mapping: unpack it as keyword arguments. `*tokens` would pass
    # only the key strings ("input_ids", ...) as positional arguments.
    embeddings = pretrained_model(**tokens)

    return embeddings

In [98]:
# Tokenize here so this cell does not depend on `tokens` created in a later cell
# (it would otherwise fail with a NameError on Restart & Run All).
sentences = ["I am a student at UIT."]
tokens = tokenizer(sentences, return_tensors='jax')

# Call the Flax model directly (no Haiku) for comparison.
pretrained_model = FlaxAutoModel.from_pretrained("bert-base-uncased")
embeddings = pretrained_model(**tokens)

Some weights of FlaxBertModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: {('pooler', 'dense', 'kernel'), ('pooler', 'dense', 'bias')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [99]:
embeddings[0].shape  # first model output; presumably last_hidden_state (batch, seq_len, hidden) -- confirm

(1, 10, 768)

In [100]:
# prompt: load pretrained_model from "bert-base-uncased" then convert to haiku module

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import flax
from transformers import BertConfig, BertModel, FlaxAutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
def f(tokens, pretrained_path="bert-base-uncased"):
    """Run a lifted Flax BERT on tokenizer outputs inside a Haiku transform.

    Args:
        tokens: mapping of model inputs as produced by the tokenizer.
        pretrained_path: checkpoint name or path for ``FlaxAutoModel.from_pretrained``.

    Returns:
        The outputs of the lifted model.
    """
    # Create and "lift" a Flax module
    # NOTE(review): hk.experimental.flax.lift is documented for flax nn.Module
    # instances; confirm it accepts the transformers model object returned here.
    pretrained_model = hk.experimental.flax.lift(FlaxAutoModel.from_pretrained(pretrained_path), name='bert')
    # Unpack the mapping as keyword arguments; `*tokens` would pass only its key strings.
    embeddings = pretrained_model(**tokens)

    return embeddings

pretrained_model = FlaxAutoModel.from_pretrained("bert-base-uncased")
# NOTE(review): the Haiku conversion below is still disabled; enable it once the
# nn.Module question noted in `f` is resolved.
# haiku_module = hk.experimental.flax.lift(pretrained_model, name='bert')


Some weights of FlaxBertModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: {('pooler', 'dense', 'kernel'), ('pooler', 'dense', 'bias')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [101]:
# Tokenize one sentence into JAX arrays (input_ids, token_type_ids, attention_mask).
sentences = ["I am a student at UIT."]
tokens = tokenizer(sentences, return_tensors="jax")
tokens

{'input_ids': Array([[  101,  1045,  2572,  1037,  3076,  2012, 21318,  2102,  1012,
          102]], dtype=int32), 'token_type_ids': Array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32), 'attention_mask': Array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32)}