In [9]:
import jax 
import optax
import haiku as hk
import numpy as np
import jax.numpy as jnp

In [10]:
INPUT_DIM = 5
HIDDEN_DIM = 2

SEED = 42

In [21]:
from typing import Optional

class ToyModel(hk.Module):
    def __init__(self, hidden_dim: int, name: Optional[str] = None):
        super().__init__(name=name)
        self.hidden_dim = hidden_dim

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        j = x.shape[-1] # Input dimension
        w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(j))

        # Get parameters
        W = hk.get_parameter('W', shape=[self.hidden_dim, j], dtype=x.dtype, init=w_init)
        b = hk.get_parameter('b', shape=[self.hidden_dim], dtype=x.dtype, init=jnp.ones)

        # Pass input through network
        x = x.T
        x = W.dot(x)
        x = W.T.dot(x)
        x = x + b

        return x

In [22]:
def toy_model_fn(x: jnp.ndarray) -> jnp.ndarray:
    model = ToyModel(hidden_dim=HIDDEN_DIM)
    return model(x)

toy_model = hk.without_apply_rng(hk.transform(toy_model_fn))

In [25]:
rng_key = jax.random.PRNGKey(SEED)
dummy_x = jnp.array([[1., 2., 3., 4., 5.]])

params = toy_model.init(rng_key, dummy_x)

TypeError: Incompatible shapes for dot: got (2, 5) and (1, 5).

In [17]:
dummy_x.shape

(1, 5)