In [None]:
import jax
import jax.numpy as jnp
import haiku as hk
from functools import partial
import math
import numpy as np
from numpy.random import random
import optax
import torch
from typing import Any, Sequence

# Playing with Jax/Haiku/Optax

I found Sabrina Mielke’s post very useful: [From PyTorch to JAX: towards neural net frameworks that purify stateful code](https://sjmielke.com/jax-purify.htm)

In [None]:
# Fix the NumPy seed so that "Restart & Run All" reproduces the same data
# (and the same initial parameters, which are drawn from this stream later).
SEED = 0
np.random.seed(SEED)

# ground-truth parameters of the linear model that generates the labels
w_source = np.array([2., -3.])
b_source = np.array([1.])

# generate random input data: 30 observations with 2 features each
x = random((30, 2)).astype('float32')
# generate labels corresponding to input data x: y = x . w_source + b_source
y = np.dot(x, w_source) + b_source
# labels as a (30, 1) float32 column
y = np.expand_dims(y, axis=1).astype('float32')

# Setup

Our model is:
$$
y_t = 2x^1_t-3x^2_t+1, \quad t\in\{1,\dots,30\}
$$

Our task is, given the 'observations' $(x_t,y_t)_{t\in\{1,\dots,30\}}$, to recover the weights $w^1=2, w^2=-3$ and the bias $b = 1$.

In order to do so, we will solve the following optimization problem:
$$
\underset{w^1,w^2,b}{\operatorname{argmin}} \sum_{t=1}^{30} \left(w^1x^1_t+w^2x^2_t+b-y_t\right)^2
$$

In [None]:
# Random starting point for the learnable parameters: 2 weights, then 1 bias
# (the tuple is evaluated left to right, so the draw order is weights first).
w_init, b_init = random(2), random(1)

# `w` / `b` are plain aliases of the initial values.
w, b = w_init, b_init
print("initial values of the parameters:", w, b)

# Float-tensor copies of the starting point and of the data for PyTorch.
dtype = torch.FloatTensor
w_init_t, b_init_t, x_t, y_t = (
    torch.from_numpy(array).type(dtype) for array in (w_init, b_init, x, y)
)

# SGD step size
learning_rate = 1e-2

# Pytorch version

In [None]:
# Single linear layer: 2 input features -> 1 output.
model = torch.nn.Sequential(torch.nn.Linear(2, 1),)

# Replace PyTorch's own initialisation with the shared starting point
# (w_init_t, b_init_t). The values are copied in place under no_grad rather
# than assigned through `.data`, which is discouraged.
with torch.no_grad():
    for m in model.children():
        m.weight.copy_(w_init_t.unsqueeze(0))  # (2,) -> (1, 2), the layer's weight shape
        m.bias.copy_(b_init_t)

# Sum (not mean) of the squared errors, as in the optimisation problem above.
loss_fn = torch.nn.MSELoss(reduction='sum')

model.train()

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Full-batch gradient descent: every step uses all observations.
for epoch in range(10):
    y_pred = model(x_t)
    loss = loss_fn(y_pred, y_t)
    print("progress:", "epoch:", epoch, "loss",loss.item())
    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# After training
print("estimation of the parameters:")
for param in model.parameters():
    print(param)

# Jax linear layer with Haiku

Playing with a linear layer... you first define your Python function and then 'jaxify' it into a pure function thanks to `hk.transform`.

In [None]:
class toy_config:
    """Toy settings for the first experiment."""
    # number of output features of the linear layer
    size_out = 5


def _linear(x, config):
    """Apply an `hk.Linear` layer with `config.size_out` outputs to `x`.

    Must be called inside an `hk.transform`-ed function.
    """
    layer = hk.Linear(config.size_out)
    return layer(x)


def _toy_forward(x):
    """Forward pass of `_linear` with the toy configuration."""
    return _linear(x, config=toy_config)


# Pure (init, apply) pair; `apply` takes no rng argument.
linear = hk.without_apply_rng(hk.transform(_toy_forward))

In [None]:
# Root PRNG key of the notebook.
rng_key = jax.random.PRNGKey(seed=42)
# Dummy batch (1 sample, 7 features), used below as example input for `linear.init`.
x_dummy = jax.random.normal(rng_key, (1, 7))

In [None]:
# Create the layer's parameters from the example input.
params = linear.init(rng_key, x=x_dummy)

In [None]:
# Display the parameter tree returned by `linear.init`.
params

In [None]:
# `rng_key` has already been used to initialise `params`; drawing the inputs
# from the same key would make them correlated with the parameters.
# Split off a fresh sub-key for the data instead.
_, input_key = jax.random.split(rng_key)
x_in = jax.random.normal(key=input_key, shape=(10, 3, 7))
# Same layer, now applied to an input with two leading axes.
out = linear.apply(x=x_in, params=params)

In [None]:
# Shape check: expected (10, 3, 5) if the layer only maps the last axis
# (7 features -> toy_config.size_out = 5) -- TODO confirm against the output.
out.shape

# Haiku: Linear layer with initialization

I want to check my results by comparing them to the Pytorch implementation so I need to start with the same initial values for the parameters.

In [None]:
# inspired by https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/initializers.py#L51-L63
class Init_jnparray(hk.initializers.Initializer):
    """Haiku initializer that returns a fixed, user-supplied array.

    Useful to start a Haiku module from known values instead of random ones.
    """

    def __init__(self, w: jnp.ndarray):
        """Store the array `w` to be returned as the initial parameter value."""
        self.w = w

    def __call__(self, shape: Sequence[int], dtype: Any) -> jnp.ndarray:
        """Return the stored array cast to `dtype`.

        Raises:
            ValueError: if the stored array's shape differs from the requested `shape`.
        """
        if self.w.shape != tuple(shape):
            # One formatted message: passing several positional arguments to
            # ValueError would render the error as a tuple.
            raise ValueError(
                f'Error in shape! w: {self.w.shape} and shape: {tuple(shape)}')
        return self.w.astype(dtype)

In [None]:
class config:
    """Settings for the regression layer, starting from the same values as the PyTorch model."""
    # a single output feature
    size_out = 1
    # initial weights as a (2, 1) column built from the (2,) vector w_init
    w_source = jnp.expand_dims(jnp.asarray(w_init), axis=1)
    # initial bias, shape (1,)
    b_source = jnp.asarray(b_init)


def _linear(x, config):
    """Apply an `hk.Linear` whose weight and bias start from `config.w_source` / `config.b_source`."""
    layer = hk.Linear(
        config.size_out,
        w_init=Init_jnparray(config.w_source),
        b_init=Init_jnparray(config.b_source),
    )
    return layer(x)

Note that nothing is random here (no random initialization).

In [None]:
def _linear_with_init(x):
    """Forward pass of `_linear` using the fixed-initialisation `config`."""
    return _linear(x, config=config)


linear = hk.without_apply_rng(hk.transform(_linear_with_init))
# Example input with 2 features (1 sample) for `init`.
x_dummy = jax.random.normal(rng_key, (1, 2))
# No PRNG key is passed: the initializers return fixed arrays.
params = linear.init(None, x=x_dummy)

In [None]:
# Display the parameters; they should hold the values of w_init and b_init.
params

In [None]:
# Predictions of the freshly initialised layer on the data x.
out = linear.apply(params, x=x)

In [None]:
# Shape check: one prediction per observation is expected, i.e. (30, 1) -- TODO confirm.
out.shape

# Computing loss and gradients

In [None]:
def mse_loss(y_pred, y_t):
    """Sum of squared errors between predictions and targets.

    Despite the name, the squared errors are summed, not averaged.
    """
    residual = y_pred - y_t
    return jnp.sum(jax.lax.integer_pow(residual, 2))

# Loss of the initial predictions `out` against the labels `y`.
mse_loss(out,y)

In [None]:
def loss_fn(x_in, y_t, config):
    """Loss of the linear model on inputs `x_in` and targets `y_t`.

    Calls `_linear`, so it has to run inside an `hk.transform`-ed function.
    """
    y_pred = _linear(x=x_in, config=config)
    return mse_loss(y_pred, y_t)

In [None]:
# `partial` captures the current `loss_fn` function object together with the
# config. This matters because the name `loss_fn` is rebound further down.
_loss_with_config = partial(loss_fn, config=config)
hk_loss_fn = hk.without_apply_rng(hk.transform(_loss_with_config))

In [None]:
# Initialise the parameters of the transformed loss on the full data set.
params = hk_loss_fn.init(rng_key, x_in=x, y_t=y)

In [None]:
# Display the parameters created by `hk_loss_fn.init`.
params

We redefine the loss as the `apply` method of the transformed Haiku loss; the equivalent of PyTorch's `backward` operation is then done in JAX with [`value_and_grad`](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#evaluate-a-function-and-its-gradient-using-value-and-grad).

In [None]:
# From here on `loss_fn` is the pure function taking (params, x_in, y_t).
loss_fn = hk_loss_fn.apply
# value_and_grad differentiates with respect to the first positional argument, params.
value_and_grad_fn = jax.value_and_grad(loss_fn)
loss, grads = value_and_grad_fn(params, x_in=x, y_t=y)

In [None]:
# Loss value returned by value_and_grad for the current params.
loss

In [None]:
# Gradients returned by value_and_grad, taken with respect to params.
grads

# Optimizer

Now things are rather easy to understand

In [None]:
# Reuse the `learning_rate` given to the PyTorch optimizer, so both
# implementations are guaranteed to take the same step size.
optimizer = optax.sgd(learning_rate=learning_rate)

In [None]:
opt_state = optimizer.init(params)
# Build the value-and-gradient function once; it does not change between epochs.
loss_and_grad = jax.value_and_grad(loss_fn)
for epoch in range(10):
    loss, grads = loss_and_grad(params, x_in=x, y_t=y)
    # Print a Python float, like `loss.item()` in the PyTorch loop, so the two
    # logs can be compared line by line.
    print("progress:", "epoch:", epoch, "loss", float(loss))
    # Optax is functional: `update` returns the parameter updates and the new
    # optimizer state, and `apply_updates` returns the new parameters.
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

# After training
print("estimation of the parameters:")
print(params)

# Jax/Haiku/Optax full code

In [None]:
class config:
    """Layer size and initial parameter values (the same starting point as the PyTorch model)."""
    # a single output feature
    size_out = 1
    # (2,) vector w_init reshaped into a (2, 1) column
    w_source = jnp.asarray(w_init).reshape(-1, 1)
    # initial bias, shape (1,)
    b_source = jnp.asarray(b_init)

class Init_jnparray(hk.initializers.Initializer):
    """Haiku initializer that returns a fixed, user-supplied array."""

    def __init__(self, w: jnp.ndarray):
        """Store the array `w` to be returned as the initial parameter value."""
        self.w = w

    def __call__(self, shape: Sequence[int], dtype: Any) -> jnp.ndarray:
        """Return the stored array cast to `dtype`.

        Raises:
            ValueError: if the stored array's shape differs from the requested `shape`.
        """
        if self.w.shape != tuple(shape):
            # One formatted message: passing several positional arguments to
            # ValueError would render the error as a tuple.
            raise ValueError(
                f'Error in shape! w: {self.w.shape} and shape: {tuple(shape)}')
        return self.w.astype(dtype)
    
def _linear(x, config):
    """Linear layer on `x`, with weight and bias initialised from the arrays in `config`."""
    weight_init = Init_jnparray(config.w_source)
    bias_init = Init_jnparray(config.b_source)
    return hk.Linear(config.size_out, w_init=weight_init, b_init=bias_init)(x)

def mse_loss(y_pred, y_t):
    """Sum (not mean) of the squared differences between `y_pred` and `y_t`."""
    squared_errors = jax.lax.integer_pow(y_pred - y_t, 2)
    return squared_errors.sum()

def loss_fn(x_in, y_t, config):
    """Loss of the linear model's predictions on `x_in` against the targets `y_t`."""
    predictions = _linear(x=x_in, config=config)
    return mse_loss(predictions, y_t)

# Turn the loss into a pure (init, apply) pair. `partial` binds the `loss_fn`
# defined above before the name is rebound to `apply` two lines below.
hk_loss_fn = hk.without_apply_rng(hk.transform(partial(loss_fn, config=config)))
# No PRNG key is passed: the initializers return fixed arrays.
params = hk_loss_fn.init(x_in=x,y_t=y,rng=None)
loss_fn = hk_loss_fn.apply

# Reuse the `learning_rate` given to the PyTorch optimizer, so both
# implementations are guaranteed to take the same step size.
optimizer = optax.sgd(learning_rate=learning_rate)

opt_state = optimizer.init(params)
# Build the value-and-gradient function once; it does not change between epochs.
loss_and_grad = jax.value_and_grad(loss_fn)
for epoch in range(10):
    loss, grads = loss_and_grad(params, x_in=x, y_t=y)
    # Print a Python float, like `loss.item()` in the PyTorch loop, so the two
    # logs can be compared line by line.
    print("progress:", "epoch:", epoch, "loss", float(loss))
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

# After training
print("estimation of the parameters:")
print(params)