# JAX Basics

In [None]:
import jax
from jax import numpy as jnp
import numpy as np


# Build the same (2, 1) column vector twice: once as a NumPy array, once as a JAX array.
x = np.array([1.0, 2.0])[None].T
y = jnp.array([1.0, 2.0])[None].T

# numpy arrays are transformed to JAX arrays automatically at operation execution (no to_tensor, from_numpy mumbo-jumbo)
print("Dot product:", x.T@y)
# Bare expression so that the notebook also displays the (1, 1) result as the cell output.
x.T@y

In [None]:
# y[0] = 3.0
# JAX arrays are immutable, so the item assignment above is not supported.
# `.at[...].set(...)` returns an updated copy instead; `y` itself stays unchanged
# here because the result is not assigned back.
y.at[0].set(3.0)

In [None]:
# we should not do this
class Func:
    """Deliberately impure callable: every call mutates `self.x`.

    Serves as the counter-example announced by the comment above the class.
    """
    def __init__(self):
        # Running offset that is added to every subsequent input.
        self.x = 0
    def __call__(self, x):
        res = jnp.square(x+self.x)
        # Hidden side effect: the result of the next call depends on this one.
        self.x += x
        return res


func = Func()


In [None]:
# Differentiating the stateful callable: `self.x += x` runs while JAX traces the call,
# so the object's state is modified as a side effect of taking the gradient.
jax.grad(func)(2.0)

# The JIT

Stands for "Just in Time Compilation". All JAX operations are implemented in terms of operations in XLA – the Accelerated Linear Algebra compiler.

In [None]:

# The same random (2000, 1000) matrix, once as a NumPy array and once as a JAX array.
M = np.random.uniform(1.0, 2.0, (2000, 1000))
M_jax = jnp.array(M)

def numpy_function(M):
    """NumPy baseline: log((M M^T)^2) - M M^T, evaluated element-wise."""
    squared_gram = (M @ M.T) ** 2
    return np.log(squared_gram) - M @ M.T

def function(M):
    """JAX version (not jitted): log((M M^T)^2) - M M^T, evaluated element-wise."""
    squared_gram = (M @ M.T) ** 2
    return jnp.log(squared_gram) - M @ M.T

@jax.jit
def jit_function(M):
    """JIT-compiled JAX version: log((M M^T)^2) - M M^T, evaluated element-wise."""
    squared_gram = (M @ M.T) ** 2
    return jnp.log(squared_gram) - M @ M.T


In [None]:
print("Numpy function:")
%timeit  numpy_function(M)
print("JAX function:")
%timeit  function(M_jax)
print("JAX jit function:")
%timeit  jit_function(M_jax)

# Vectorization / Parallelization

`jax.vmap` and `jax.pmap`

In [None]:
# vmap usage
# say we have a "complicated function" that we want to apply row-wise, ie. over axis=0

M = np.random.uniform(1, 10, (200, 2))


def func(x):
    """Row-wise predicate: is x[0] squared greater than exp(x[1])?

    Uses `jnp.exp` rather than `np.exp` so that the function also accepts the
    traced values JAX passes in under `vmap`/`jit`/`pmap`; NumPy functions
    cannot be applied to tracers.
    """
    return x[0]**2 > jnp.exp(x[1])

#@jax.jit
def naive(M):
    """Apply `func` to every row of M in a Python-level loop and stack the results."""
    per_row = list(map(func, M))
    return jnp.stack(per_row)

#@jax.jit
def with_vmap(M):
    """Apply `func` to every row of M by vectorizing it over the leading axis."""
    batched_func = jax.vmap(func)
    return batched_func(M)


def with_pmap(M): # this is just for the sake of example, it goes across devices
    """Apply `func` to every row of M, mapping the leading axis across devices.

    `pmap` compiles the mapped function itself, so it is not wrapped in an outer
    `jax.jit`: JAX warns that jit-of-pmap causes inefficient data movement.
    The leading axis of M must not exceed the number of available devices.
    """
    return jax.pmap(func)(M)


print("Naive:")
%timeit -n 50 naive(M)
print("With vmap:")
%timeit -n 50 with_vmap(M)

# Calculating Gradients

Using the `jax.grad` function

In [None]:
import functools
from matplotlib import pyplot as plt
# n-th order polynomial

# Degree of the polynomial and its n randomly drawn roots in [-3, 3).
n = 6
roots = np.random.uniform(-3, 3, n).astype(np.float64)
def nth_order_polynomial(x, roots):
    """Evaluate the monic polynomial with the given roots, prod_r (x - r), at x."""
    return functools.reduce(lambda acc, r: acc*(x-r), roots, 1)


# Bind the roots so that `poly` is a function of x only, then JIT-compile it.
poly = jax.jit(functools.partial(nth_order_polynomial, roots=roots))

# Evaluate the polynomial on 100 evenly spaced points in [-2, 2] and plot it.
x = jnp.linspace(-2, 2, 100)

y = poly(x)


plt.plot(x, y)

In [None]:
# we can trace the computation of the polynomial:
# make_jaxpr shows the sequence of primitive operations recorded for the input 2.0
jax.make_jaxpr(poly)(2.0)

In [None]:
# this is how we compute gradients / derivatives
# (only the last bare expression is shown as the cell output, i.e. the value-and-gradient pair)
jax.grad(poly)(0.)
jax.value_and_grad(poly)(0.)


In [None]:
# n-th order derivatives of this function?
grad_func = poly
nth_order_grads = []  # NOTE(review): never appended to in this cell; looks like a leftover
for i in range(n):

    # Each pass differentiates the previous derivative once more.
    grad_func = jax.grad(grad_func)
    # jax.grad needs a scalar-valued function, so vmap evaluates it point by point over x.
    y = jax.vmap(grad_func)(x)
    plt.plot(x, y, label=f'n={i+1}')
plt.legend()

# JAX Can Differentiate with Respect to (almost) Anything

JAX can differentiate with respect to any type, as long as we register it; standard Python containers are supported out of the box. The data structures that JAX can handle are called `Pytrees`.

`Pytrees` by default include compositions of standard Python containers, `list`, `tuple`, `dict`. But we can also add our custom `Pytree` nodes.






In [None]:
def square(data):
    """Sum of squares of the 'x' and 'y' entries of a dict."""
    return data['x']**2 + data['y']**2

# The gradient has the same structure as the input: a dict with keys "x" and "y".
g = jax.grad(square)({"x": 1.0, "y": 2.0})

print(f"Gradient 1: {g}")


def square(data):
    """Sum of squares of the first two entries of a sequence (redefines `square`)."""
    return data[0]**2 + data[1] ** 2

# Same computation, this time with a list as the differentiated container.
g = jax.grad(square)([1.0, 2.0])
print(f"Gradient 2: {g}")


from dataclasses import dataclass

@dataclass
class Point:
    """Plain 2-D point used to demonstrate differentiation w.r.t. a custom container."""
    x: float
    y: float


# now we register the Point class as a Pytree node, so that JAX knows how to deal with it
from jax.tree_util import register_pytree_node


def _point_flatten(point):
    """Unpack a Point into its children (the two coordinates) and no auxiliary data."""
    return [point.x, point.y], None


def _point_unflatten(aux_data, children):
    """Pack the children produced by `_point_flatten` back into a Point."""
    return Point(*children)


register_pytree_node(Point, _point_flatten, _point_unflatten)

def square(data):
    """Sum of squares of the `x` and `y` attributes (redefines `square` for Point)."""
    return data.x**2 + data.y ** 2

# Works because Point was registered as a pytree node above.
g = jax.grad(square)(Point(1.0, 2.0))
print(f"Gradient 3: {g}")



## Useful Pytree Operations

In [None]:
# further useful operations on Pytrees

def mul(x, y):
    """Return the product of the two arguments."""
    return x*y


tree_a = {"x": 1.0, "y": 2.0}
tree_b = {"x": 2.0, "y": 4.0}

# `tree_multimap` was deprecated and then removed from JAX; `tree_map` accepts any
# number of trees with matching structure and replaces it.
from jax.tree_util import tree_map


# equivalent of python `map` but applied to pytrees

res = tree_map(lambda x: x**2, tree_a)
print("Result of map:", res)

# constructing operations between multiple trees
res = tree_map(lambda x,y,z: x+y+z, tree_a, tree_b, tree_b)
print("Result of multimap:", res)

# Linear Regression Using SGD

In [None]:
# dataset for linear regression
def generate_data(d, N):
    """Create a noiseless linear-regression problem.

    Returns the (N, d) inputs, the (N, 1) targets `inputs @ theta`, the
    ground-truth (d, 1) weights and an independently drawn (d, 1) initial guess.
    """
    inputs = np.random.uniform(-1,1, (N, d))
    true_theta = np.random.uniform(-1,1, (d,1))
    targets = inputs @ true_theta
    initial_guess = np.random.uniform(-1,1, (d,1))
    return inputs, targets, true_theta, initial_guess

# 500 samples with 20 features; theta is the ground truth, theta_ a random initial guess.
x, y, theta, theta_ = generate_data(20, 500)

def vis_lines(theta, theta_):
    """Plot two lines built from the first two coefficients of each parameter vector.

    Entry [0, 0] is used as the slope and entry [1, 0] as the intercept; `theta`
    is labelled 'gt' and `theta_` is labelled 'pred'.
    """
    grid = np.linspace(-2, 2, 100)
    for coeffs, label in ((theta, 'gt'), (theta_, 'pred')):
        plt.plot(grid, coeffs[0, 0]*grid + coeffs[1, 0], label=label)
    plt.legend()



In [None]:
@jax.jit
def predict(x, theta):
    """Linear model: matrix product of the inputs `x` with the weights `theta`."""
    return jnp.matmul(x, theta)

def mse(x, theta, y):
    """Mean squared error between the targets `y` and `predict(x, theta)`."""
    residual = y - predict(x, theta)
    return (residual**2).mean()

grad_func = jax.grad(mse, argnums=[1]) # returns a function which takes the same arguments as the wrapped one


# returns tuple of gradients with respect to arguments of function
# (argnums is a one-element list, hence the one-element tuple unpacked with `grad, =`).
# This call is evaluated at the ground-truth theta, where y == x @ theta.
grad, = grad_func(x, theta, y)

# for practicality, this is also available
# (argnums=1 as a plain int returns the gradient itself rather than a tuple)
loss, grad = jax.value_and_grad(mse, argnums=1)(x,theta_, y)
print(f"MSE loss: {loss}")



vis_lines(theta, theta_)

In [None]:
# minimal hand-written loop to optimize our model: 50 full-batch gradient steps, step size 0.01
for _ in range(50):
    loss, grad = jax.value_and_grad(mse, argnums=1)(x,theta_, y)
    # In-place update of the parameter array.
    theta_ -= 0.01*grad 

# `loss` is the value computed before the last parameter update.
print(f"MSE: {loss}")
vis_lines(theta, theta_)

# Jacobians, Hessians

In [None]:
# what if we want to get a Jacobian? x is d-dimensional
# Jacobian of the prediction for the first sample with respect to that sample (argnums=0).
J = jax.jacobian(predict, argnums=0)(x[0], theta)
J.shape

In [None]:
# what about the Hessian?
# Forward-over-reverse composition; `H, =` unpacks the single leading entry of the result.
# `predict` is linear in x, so this Hessian is zero everywhere.
H, = jax.jacfwd(jax.jacrev(predict), argnums=0)(x[0], theta)
print("Hessian shape:", H.shape)
# Same quantity via the convenience wrapper.
H, = jax.hessian(predict, argnums=0)(x[0], theta)

# 2nd Order Optimization!

In [None]:
from matplotlib import pyplot as plt
d = 10

@jax.jit
def f(x):
    """Quartic objective: sum of the fourth powers along the last axis."""
    fourth_powers = x**4
    return jnp.sum(fourth_powers, axis=-1)


# 100 one-dimensional points as a (100, 1) array, so f sums over a single coordinate.
x = np.linspace(-10, 10, 100)[:, None]
plt.plot(x, f(x))

# what can we say about the hessian of this function?

In [None]:
# Starting point: the d-dimensional vector with every coordinate equal to 10.
x0 = jnp.array([10.0]*d)
# what does the hessian look like?
jax.hessian(f)(x0).shape

In [None]:
# let's try to minimize it via gradient descent
x_ = x0
# Trajectory of the first coordinate and of the objective value, for plotting.
trajx = [x_[0]]
trajy = [f(x_)]

err = 1e-5
for i in range(300):
    g = jax.grad(f)(x_)
    # Fixed step size of 1/2.
    # NOTE(review): the gradient of x**4 at x=10 is 4000 per coordinate, so the first
    # step moves each coordinate by 2000 - confirm that this overshoot is intended.
    x_ -= 1/2* g
    trajx.append(x_[0])
    trajy.append(f(x_))
    if trajy[-1] < err:
        break
# NOTE(review): printed even when the loop ends without reaching `err`, and it reports
# the threshold `err` rather than the final objective value.
print(f"Converged in {i} steps, {err}, {x_}")
plt.plot(x, f(x))
plt.plot(trajx, trajy)
plt.scatter(trajx, trajy, color='red')

In [None]:
# now let's minimize it with Newton's method (a 2nd order method): the gradient is
# rescaled by the inverse Hessian instead of a fixed step size
x_ = x0
# Trajectory of the first coordinate and of the objective value, for plotting.
trajx = [x0[0]]
trajy = [f(x0)]

err = 1e-5
for i in range(100):
    g = jax.grad(f)(x_)
    H = jax.hessian(f)(x_)
    # Newton step: x <- x - H^{-1} g
    x_ -= (jnp.linalg.inv(H) @ g).flatten()
    trajx.append(x_[0])
    trajy.append(f(x_))
    if trajy[-1] < err:
        break
# NOTE(review): printed even when the loop ends without reaching `err`, and it reports
# the threshold `err` rather than the final objective value.
print(f"Converged in {i} steps, {err}, {x_}")
plt.plot(x, f(x))
plt.plot(trajx, trajy)
plt.scatter(trajx, trajy, color='red')

# Note on Gradient Calculation

<p style="font-size:20px">

Some options:
* finite differences $\frac{df}{dx} \approx \frac{f(x+h)-f(x)}{h}$ for a small step $h$ (exact in the limit $h \to 0$)
* symbolic
* automatic differentiation - most of deep learning
</p>


In [None]:
@jax.jit
def square(x):
    """Return x squared (JIT-compiled)."""
    return jnp.square(x)

def finite_differences(f, h):
    """Return a function approximating f' with the forward difference of step `h`."""
    def approx_derivative(x):
        rise = f(x+h)-f(x)
        return rise/h
    return approx_derivative


finite_differences(square, 1e-5)(2.0)

# Autodiff Forward vs. Reverse Mode

<p style="font-size:16px">

**Forward Mode**: augments the outputs of the forward pass with their derivatives in a (primal, tangent) tuple $(x, \dot x)$. This is preferred in the case where the number of inputs is much smaller than the number of outputs, in practice we compute a Jacobian-vector product.


**Reverse Mode**: comes in two stages. First we make a forward pass through our computation graph, which is followed by a backward pass that computes the partial derivatives with respect to the intermediate variables (adjoints). Backpropagation is a special case of reverse mode autodiff. This is preferred when the number of outputs is much smaller than the number of inputs; in practice we compute a vector-Jacobian product.

**Reverse on Forward**: hybrid, for example computing Hessian.

</p>

In [None]:
# back to our Hessian example, let's time it with different ways of computing the gradient

def func(x):
    """Scalar test function: sum(sin(log(exp(sin x + cos x)) + cos(exp(sin x))))."""
    first_term = jnp.log(jnp.exp(jnp.sin(x) + jnp.cos(x)))
    second_term = jnp.cos(jnp.exp(jnp.sin(x)))
    return jnp.sum(jnp.sin(first_term + second_term))

print("Only reverse mode autodiff:")
%timeit jax.jacrev(jax.jacrev(func), argnums=0)(x)
print("Hybrid autodiff forward then reverse:")
%timeit jax.jacrev(jax.jacfwd(func), argnums=0)(x)
print("Hybrid autodiff reverse mode then forward mode:")
%timeit jax.jacfwd(jax.jacrev(func), argnums=0)(x)
print("JAX hessian func")
%timeit jax.hessian(func, argnums=0)(x)


# JAX Random Numbers

Random numbers in JAX are annoying. There is no stateful random number generator like the one in `numpy`; instead we need to pass around a `key` that we split with `jax.random.split`. This is also where the JAX and numpy syntax for distributions differ considerably.

In [None]:
random_seed = 123
key = jax.random.PRNGKey(random_seed) 

# Ten independent sub-keys derived from the seed key.
rngs = jax.random.split(key, 10)


print("These are 10 random numbers")
print(jnp.array([jax.random.normal(k) for k in rngs]))
# Sampling is a pure function of the key: reusing the keys reproduces the numbers.
print("These are the same numbers")
print(jnp.array([jax.random.normal(k) for k in rngs]))


key, _ = jax.random.split(rngs[-1])

normal_sample = jax.random.normal(key)

keys = jax.random.split(key, 10)

# Draw one sample per fresh key and actually print them (the original printed an
# empty line and never used `keys`).
normal_samples = jnp.array([jax.random.normal(k) for k in keys])
print("These are samples from normal distribution")
print(normal_samples)



# Here come the neural networks...

In [None]:
# say we use an ensemble neural network (this cause a bit of pain for me and Sebastian to implement in PyTorch)
from jax.tree_util import tree_flatten, tree_unflatten

def tree_stack(trees):
    """Stack a list of identically structured trees leaf by leaf.

    Given ((a, b), c) and ((a', b'), c') the result is
    ((stack(a, a'), stack(b, b')), stack(c, c')); the structure of the first
    tree is used for the output. Handy for preparing the input of a vmapped
    function from a list of objects.
    """
    flattened = [tree_flatten(tree) for tree in trees]
    leaves_per_tree = [leaves for leaves, _ in flattened]
    first_treedef = flattened[0][1]
    # zip(*...) groups the i-th leaf of every tree together.
    stacked_leaves = [jnp.stack(group) for group in zip(*leaves_per_tree)]
    return first_treedef.unflatten(stacked_leaves)


def tree_unstack(tree):
    """Split a tree of stacked leaves into a list of trees (inverse of tree_stack).

    For a tree ((a, b), c) whose leaves all have leading dimension k, returns the
    k trees [((a[0], b[0]), c[0]), ..., ((a[k-1], b[k-1]), c[k-1])]. The number
    of trees is taken from the leading dimension of the first leaf. Handy for
    turning the output of a vmapped function back into ordinary objects.
    """
    leaves, treedef = tree_flatten(tree)
    n_trees = leaves[0].shape[0]
    return [
        treedef.unflatten([leaf[idx] for leaf in leaves])
        for idx in range(n_trees)
    ]



def get_nn_params():
    """Draw random parameters for a 10 -> 512 -> 256 -> 2 network.

    Returns one (weight, bias) pair per layer; weights have shape
    (fan_in, fan_out), biases have shape (fan_out, 1), all uniform in [-1, 1).
    """
    layer_shapes = [(10, 512), (512, 256), (256, 2)]
    return [
        (
            np.random.uniform(-1,1, (fan_in, fan_out)),
            np.random.uniform(-1,1, (fan_out, 1)),
        )
        for fan_in, fan_out in layer_shapes
    ]


@jax.jit
def forward(x, theta):
    """Run `x` through the layers in `theta`, a sequence of (weight, bias) pairs.

    Each layer computes relu(x @ weight) and then adds the transposed bias.
    """
    activations = x
    for weight, bias in theta:
        hidden = jax.nn.relu(activations @ weight)
        activations = hidden + bias.T
    return activations



params = get_nn_params()

# The first weight matrix is (10, 512), so the network needs inputs with 10 features.
# `x` left over from the earlier cells has a different shape, so draw a fresh batch
# here instead of relying on notebook state.
x = np.random.uniform(-1, 1, (1024, 10))

out = forward(x, params)
out.shape



In [None]:
# An ensemble of 20 independently initialised networks.
lots_of_params = [get_nn_params() for _ in range(20)]
# how do we parallelize this?

# One tree whose leaves carry the 20 ensemble members along a new leading axis.
stacked_tree = tree_stack(lots_of_params)

def seq_ensemble_forward(x, trees):
    """Baseline: run every ensemble member one after another and stack the outputs."""
    return jnp.stack([forward(x,tree) for tree in trees])


seq_ensemble_forward = jax.jit(seq_ensemble_forward)

# in_axes: the input x is shared (None); every (weight, bias) pair is mapped over axis 0.
vmap_ensemble_forward = jax.vmap(forward, in_axes=[None,[(0, 0)]*len(lots_of_params[0])])
vmap_ensemble_forward = jax.jit(vmap_ensemble_forward)
# seems to still not be better than linear speedup when stacking the matrices into tensors, but there is some improvement (possibly limited by hardware)

In [None]:
# what about simple stacking?

@jax.jit
def stacked_forward(x, theta):
    """Forward pass for stacked ensemble parameters, relying on matmul broadcasting.

    Each bias is transposed over its last two axes before it is added to
    relu(x @ weights).
    """
    for weights, biases in theta:
        hidden = jax.nn.relu(x @ weights)
        x = hidden + jnp.transpose(biases, (0, 2, 1))
    return x


In [None]:

# Batch of 1024 inputs with 10 features each.
x = np.random.uniform(-1,1, (1024,10))

# Call each variant once before timing, so that JIT compilation is not part of the measurement.
vmap_ensemble_forward(x, stacked_tree)
stacked_forward(x, stacked_tree)
seq_ensemble_forward(x, lots_of_params)

# block_until_ready() waits for the asynchronously dispatched result.
%timeit stacked_forward(x, stacked_tree).block_until_ready() 
%timeit vmap_ensemble_forward(x, stacked_tree).block_until_ready() 
%timeit seq_ensemble_forward(x, lots_of_params).block_until_ready() 




# Flax: Making Things More Simple

Alternatives: `Haiku`, `Stax`, `Objax`

In [None]:
import flax
import jax
from jax import numpy as jnp
import numpy as np
from typing import Optional
# for details about this, read about python >=3.7 dataclasses

class MLP(flax.linen.Module):
    """Multi-layer perceptron defined with `setup`.

    `num_hidden` hidden Dense layers of width `hidden_size`, each followed by the
    activation named by `act_function` (looked up in `flax.linen`), and a final
    Dense layer with `outputs` features and no activation.
    """
    # here, ordering is preserved
    num_hidden: int
    hidden_size: int
    outputs: int
    act_function: Optional[str] = 'relu'

    def setup(self):
        # Use the configured width (the hidden width used to be hard-coded to 64).
        self.layers = [flax.linen.Dense(features=self.hidden_size) for _ in range(self.num_hidden)]
        self.last_layer = flax.linen.Dense(self.outputs)

    def __call__(self, x):
        act_function = getattr(flax.linen, self.act_function)
        for layer in self.layers:
            x = act_function(layer(x))
        # don't apply act in last layer; it maps to the requested number of outputs
        x = self.last_layer(x)
        return x


key = jax.random.PRNGKey(0)

# Dummy batch: 128 samples with 10 features.
X = np.random.randn(128, 10)

# Positional arguments follow the field order: num_hidden, hidden_size, outputs.
model = MLP(2, 64, 2)

# we need to call the init function, takes a batch and key
key, _ = jax.random.split(key)
params = model.init(key, X)


# we need to call the apply function for the forward pass, which also takes the model parameters
y_ = model.apply(params, X)
y_.shape

In [None]:
import flax
from typing import Optional
# for details about this, read about python 3.7 datamodule class

class MLPCompact(flax.linen.Module):
    """Multi-layer perceptron whose layers are declared inline via `compact`.

    `num_hidden` Dense layers of width `hidden_size`, each followed by the
    activation named by `act_function`, then a final Dense layer with `outputs`
    features and no activation.
    """
    # here, ordering is preserved
    num_hidden: int
    hidden_size: int
    outputs: int
    act_function: Optional[str] = 'relu'

    @flax.linen.compact
    def __call__(self, x):
        activation = getattr(flax.linen, self.act_function)
        hidden = x
        for _ in range(self.num_hidden):
            hidden = activation(flax.linen.Dense(self.hidden_size)(hidden))
        # don't apply act in last layer
        return flax.linen.Dense(self.outputs)(hidden)


key = jax.random.PRNGKey(0)
# Dummy batch: 128 samples with 10 features.
X = np.random.randn(128, 10)

# Positional arguments follow the field order: num_hidden, hidden_size, outputs.
model = MLPCompact(2, 64, 2)

# we need to call the init function, takes a batch and key to obtain sampled initial parameters
params = model.init(key, X)


# we need to call the apply function, which also takes the model parameters
y_ = model.apply(params, X)

# VAE Example

This is our loss function (minimizing it maximizes the ELBO):
$$
\mathcal{L}(\theta, \phi) = -\mathbb{E}_{q_\theta(z | x)}[\log p_\phi(x | z)] + \mathrm{KL}[q_\theta(z | x) \,\|\, p(z)]  $$


First term of the loss we call reconstruction loss, second term you can see as some kind of complexity/regularization term.


In [None]:
import optax
from numpyro.distributions import Normal

# loss functions

@jax.vmap
def kl_divergence(mean, logvar):
  """Per-example KL term: -0.5 * sum(1 + logvar - mean^2 - exp(logvar))."""
  per_dimension = 1 + logvar - jnp.square(mean) - jnp.exp(logvar)
  return -0.5 * jnp.sum(per_dimension)


@jax.vmap
def binary_cross_entropy_with_logits(logits, labels):
  """Per-example binary cross entropy between `labels` and sigmoid(`logits`), summed."""
  log_p = flax.linen.log_sigmoid(logits)
  # log(1 - p) computed from log p as log(-expm1(log p)).
  log_one_minus_p = jnp.log(-jnp.expm1(log_p))
  return -jnp.sum(labels * log_p + (1. - labels) * log_one_minus_p)


@jax.jit
def loss(logits, mean, logvar, labels=None):
  """Negative ELBO averaged over the batch: reconstruction loss plus KL term.

  Args:
    logits: decoder outputs, one row per example.
    mean, logvar: parameters of the approximate posterior.
    labels: reconstruction targets. Optional only for backward compatibility;
      when omitted, the module-level `image` is used, as before. Prefer passing
      the targets explicitly: a global read inside a jitted function is
      captured when the function is traced, not on every call.
  """
  if labels is None:
    labels = image
  reconstruction_loss = binary_cross_entropy_with_logits(logits, labels)
  kl_div = kl_divergence(mean, logvar)
  return jnp.mean(reconstruction_loss + kl_div)

from typing import Sequence
class Sequential(flax.linen.Module):
  """Chain the given modules, feeding each output into the next module."""
  layers: Sequence[flax.linen.Module]

  def __call__(self, x):
    out = x
    for module in self.layers:
      out = module(out)
    return out


class MLPCompact(flax.linen.Module):
    """Compact MLP whose layer stack lives in `forward` so subclasses can reuse it.

    `num_hidden` Dense layers of width `hidden_size`, each followed by the
    activation named by `act_function`, then a Dense layer with `outputs`
    features and no activation.
    """
    # here, ordering is preserved
    num_hidden: int
    hidden_size: int
    outputs: int
    act_function: Optional[str] = 'relu'

    def forward(self, x):
        """Build and apply the layers inline; call from within a compact method."""
        activation = getattr(flax.linen, self.act_function)
        for _ in range(self.num_hidden):
            x = activation(flax.linen.Dense(self.hidden_size)(x))
        # don't apply act in last layer
        return flax.linen.Dense(self.outputs)(x)

    @flax.linen.compact
    def __call__(self, x):
        return self.forward(x)

class Decoder(MLPCompact):
    """Decoder network; applies the layer stack defined by `MLPCompact.forward`."""

    def setup(self):
        # No-op: no submodules are created here, they are all built inline in __call__.
        # NOTE(review): confirm whether this empty override is needed at all.
        return

    @flax.linen.compact
    def __call__(self, x):
      # Same computation as the parent's __call__.
      x = super().forward(x)
      return x


class VAE(flax.linen.Module):
  """Variational auto-encoder with a diagonal Gaussian posterior.

  Attributes:
    latents: dimensionality of the latent code z.
    outputs: number of decoder outputs (logits).
  """
  latents: int
  outputs: int
  
  def setup(self):
    # The encoder emits mean and log-variance side by side, hence 2 * latents outputs.
    self.encoder = MLPCompact(3, 128, self.latents*2)
    self.decoder = Decoder(3, 256, outputs=self.outputs)

  def __call__(self, key, x, deterministic=False):
    """Encode `x`, sample (or take the mean of) z and decode it.

    Returns the tuple (decoder logits, z, mu, logvar).
    """
    gauss_params = self.encoder(x)
    mu, logvar = jnp.split(gauss_params, 2, -1)
    # Standard deviation; same value as sqrt(exp(logvar)) without the intermediate exp.
    sigma = jnp.exp(0.5 * logvar)

    if not deterministic:
      gauss_dist = Normal(mu, sigma)
      # Reparameterised sample, so gradients flow back into mu and sigma.
      z = gauss_dist.rsample(key)
    else:
      z = mu      
    return self.decoder(z), z, mu, logvar

  def generate(self, key, samples):
    """Decode `samples` latent codes drawn from the standard-normal prior.

    Returns binary images of shape (samples, 28, 28).
    """
    # sample from prior distribution 
    mu = jnp.zeros((samples, self.latents))
    sigma = jnp.ones((samples, self.latents))
    gauss_dist = Normal(mu, sigma)
    z = gauss_dist.sample(key)
    logits =  self.decoder(z)
    # Turn logits into pixel probabilities before thresholding at 0.5. log_sigmoid
    # is never positive, so rounding it cannot produce a 0/1 image.
    probabilities = flax.linen.sigmoid(logits)
    return jnp.round(probabilities).reshape(-1, 28, 28)




In [None]:
# initialize the model and optimizer

# vae template
# 20 latent dimensions, 28*28 outputs (one logit per pixel)
vae = VAE(20, 28*28)
# create optimizer
optimizer = optax.adam(1e-3)
# rk1 initialises the parameters; rk2 is the `key` argument of VAE.__call__.
rk1, rk2 = jax.random.split(key)
params = vae.init(rk1, rk2, jnp.ones((64, 28*28)))
opt_state = optimizer.init(params)


In [None]:
import tensorflow_datasets as tfds
from tqdm import tqdm

# Construct a tf.data.Dataset
ds = tfds.as_numpy(tfds.load('binarized_mnist', split='train', batch_size=64, shuffle_files=True))

# this returns a function!
def construct_training_step(model, optimizer):
  """Build the jitted update function for `model` and `optimizer`.

  The returned function maps (params, opt_state, key, image) to
  (loss, updates, new optimizer state); applying the updates is left to the caller.
  """
  def model_loss(params, key, image):
    # model forward
    logits, _z, mean, logvar = model.apply(params, key, image)
    per_example = binary_cross_entropy_with_logits(logits, image) + kl_divergence(mean, logvar)
    return jnp.mean(per_example)

  loss_and_grads = jax.value_and_grad(model_loss, argnums=0)

  # this is the function that we call in the end
  def update_func(params, opt_state, key,  image):
    loss_value, grads = loss_and_grads(params, key, image)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    return loss_value, updates, new_opt_state

  return jax.jit(update_func)


training_step = construct_training_step(vae, optimizer)



In [None]:
from tqdm.notebook import tqdm
import optax


for e in range(20):
  tqdm_iter = tqdm(enumerate(ds))
  for i, batch in tqdm_iter:
    # Flatten every 28x28 image into a vector of 784 pixels.
    image = batch["image"].reshape(-1, 28*28)

    key, _ = jax.random.split(key)

    # The jitted training step already runs the forward pass, so no separate
    # (un-jitted) vae.apply call is needed here.
    # this is equivalent to  .backward and optimizer.step in PyTorch
    loss, updates, opt_state = training_step(params, opt_state, key, image)
    params = optax.apply_updates(params, updates)
    
    tqdm_iter.set_description(f"Epoch {e:5d}, batch {i:5d}, loss={loss:10.5f}")





In [None]:
# now we sample from p(z)
key, _ = jax.random.split(key)
# `method=vae.generate` runs `generate` instead of `__call__`; here with samples=2.
imgs = vae.apply(params, key, 2, method=vae.generate)
plt.imshow(imgs[0])

# Meta Learning: MAML

In MAML, we concern ourselves with the multi-task setting, so the following objective has two parts. The inner loss is the loss on instances from task 1, and the outer loss is calculated on task 2 using the shifted parameters.

$$
    \mathcal{L}(\theta - \alpha \nabla_\theta \mathcal{L}(\theta, x_1, y_1), x_2, y_2)
$$

It is clear that here we have an optimization step within the loss calculation. Lucky for us, JAX can help us out here!

In [None]:
import functools
# `tree_multimap` was deprecated and then removed from JAX; `tree_map` accepts several
# trees with matching structure and replaces it.
from jax.tree_util import tree_map
import numpy as np



def mse(params, inputs, targets):
    """Mean squared error of the module-level `mlp` on (inputs, targets)."""
    # Computes average loss for the batch
    predictions = mlp.apply(params, inputs)
    return jnp.mean((targets - predictions)**2)

def inner_update(p, x1, y1, alpha=.1):
    """Return theta - alpha * grad(inner_loss): one SGD step on the task data.

    Args:
        p: pytree of model parameters.
        x1, y1: inputs and targets of the inner (adaptation) task.
        alpha: inner-loop step size.
    """
    grads = jax.grad(mse)(p, x1, y1)
    inner_sgd_fn = lambda g, state: (state - alpha*g)
    return tree_map(inner_sgd_fn, grads, p)

def maml_loss(p, x1, y1, x2, y2):
    """Outer loss: error on (x2, y2) after one inner update on (x1, y1)."""
    adapted_params = inner_update(p, x1, y1)
    outer_loss = mse(adapted_params, x2, y2)
    return outer_loss

# returns scalar for all tasks.
def batch_maml_loss(p, x1_b, y1_b, x2_b, y2_b):
    """Mean MAML loss over a batch of tasks (leading axis of every array)."""
    print(x1_b.shape, y1_b.shape, x2_b.shape, y2_b.shape)
    # The parameters are shared by all tasks; only the data is mapped over.
    per_task_loss = jax.vmap(functools.partial(maml_loss, p))
    return jnp.mean(per_task_loss(x1_b, y1_b, x2_b, y2_b))


# this returns a function!
def construct_training_step(model, optimizer):
  """Build the jitted MAML update function for `optimizer`.

  The returned function maps (params, opt_state, x1_b, y1_b, x2_b, y2_b) to
  (loss, updates, new optimizer state). `model` is accepted but not used here.
  """
  loss_and_grads = jax.value_and_grad(batch_maml_loss)

  # this is the function that we call in the end
  def update_func(params, opt_state, x1_b, y1_b, x2_b, y2_b):
    loss_value, grads = loss_and_grads(params, x1_b, y1_b, x2_b, y2_b)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    return loss_value, updates, new_opt_state

  return jax.jit(update_func)

def sample_tasks(outer_batch_size, inner_batch_size):
    """Sample a batch of sinusoid regression tasks y = A * sin(x + phase).

    Args:
        outer_batch_size: number of tasks.
        inner_batch_size: number of points drawn per task and per batch.

    Returns:
        (x1, y1, x2, y2): two independent batches of points for the same tasks,
        each array of shape (outer_batch_size, inner_batch_size, 1).
    """
    # Sample random sinusoid functions
    As = []
    phases = []
    for _ in range(outer_batch_size):        
        As.append(np.random.uniform(low=0.1, high=.5))
        phases.append(np.random.uniform(low=0., high=np.pi))
    def get_batch():
        xs, ys = [], []
        for A, phase in zip(As, phases):
            x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
            # The phase shifts the sinusoid. It must not divide x: a phase close
            # to 0 would make the frequency blow up.
            y = A * np.sin(x + phase)
            xs.append(x)
            ys.append(y)
        return np.stack(xs), np.stack(ys)
    x1, y1 = get_batch()
    x2, y2 = get_batch()
    return x1, y1, x2, y2



# Plot three example sinusoids with different amplitudes and phases.
for a,phase in zip([1.0, 0.2, 0.5], [0.1, np.pi/2, np.pi]):
    x = np.linspace(0, 10, 100)
    y = a * np.sin(x+phase)
    plt.plot(x, y)


In [None]:
import optax
# Scalar-in, scalar-out regressor: 2 hidden layers of width 64, 1 output.
mlp = MLPCompact(2, 64, 1)
params  = mlp.init(key, jnp.ones((64,1)))
optimizer = optax.adam(1e-2)
opt_state = optimizer.init(params)
# `mse` reads the module-level `mlp`, so it has to be defined before training starts.
training_step = construct_training_step(mlp, optimizer)


In [None]:
from tqdm import tqdm
# History of the outer loss, one entry per training step.
np_batched_maml_loss = []
# Number of points per task batch.
K=20
tqdm_iter = tqdm(range(10000))
for i in tqdm_iter:
    # 4 freshly sampled tasks per step, K points in each of the two batches.
    x1_b, y1_b, x2_b, y2_b = sample_tasks(4, K)
    l, updates, opt_state = training_step(params, opt_state, x1_b, y1_b, x2_b, y2_b)
    np_batched_maml_loss.append(l)
    params = optax.apply_updates(params, updates)
    tqdm_iter.set_postfix_str(f"loss:{l:10.5f}")

In [None]:
# batch the inference across 100 evaluation points
x = np.linspace(-5,5,100).reshape((100, 1)) # (k, 1)
targets = np.sin(x)
# Predictions of the meta-learned parameters before any task-specific adaptation.
predictions = jax.vmap(functools.partial(mlp.apply, params))(x)
plt.plot(x, predictions, label='pre-update predictions')
plt.plot(x, targets, label='target')

# K points from the evaluation task (amplitude 1, phase 0).
x1 = np.random.uniform(low=-5., high=5., size=(K,1))
y1 = 1. * np.sin(x1 + 0.)

net_params = params
# Apply 1 to 4 inner updates on the same K points and plot the predictions after each.
for i in range(1,5):
    net_params = inner_update(net_params, x1, y1)
    predictions = jax.vmap(functools.partial(mlp.apply, net_params))(x)
    plt.plot(x, predictions, label='{}-shot predictions'.format(i))
plt.legend()

# Flows



In [None]:
import sklearn
import math
from sklearn.datasets import make_blobs, make_circles, make_swiss_roll, make_moons, make_spd_matrix
def dataset_transformed_gaussian(size, seed=345):
    """Return `size` standard-normal 2-D points multiplied by a random SPD matrix."""
    mixing = make_spd_matrix(2, random_state=seed)
    noise = np.random.RandomState(seed).randn(size, 2)
    return noise @ mixing

def dataset_half_moon(size, seed=345):
    """Half moon distribution.

    The second coordinate is Gaussian with scale 4; the first is unit Gaussian
    noise around the parabola 0.25 * x2**2.

    from https://blog.evjang.com/2018/01/nf1.html
    """
    generator = np.random.RandomState(seed)
    second = 4 * generator.normal(size=size)
    first = generator.normal(size=size) + 0.25 * second**2
    return np.stack((first, second), axis=1)

def dataset_two_moons(size, seed=345):
    """Return the points of scikit-learn's two-moons dataset (noise 0.1), without labels."""
    points, _labels = make_moons(size, noise=0.1, random_state=seed)
    return points

def dataset_blobs(size, loc=3, seed=345):
    """Return 2-D points from four blobs centred at (+-loc, +-loc), without labels."""
    corners = [[loc, loc], [-loc, -loc], [loc, -loc], [-loc, loc]]
    points, _labels = make_blobs(size, 2, centers=corners, random_state=seed)
    return points

def dataset_circles(size, seed=345):
    """Return the points of two concentric noisy circles (factor 0.3), without labels."""
    points, _labels = make_circles(size, noise=0.1, factor=0.3, random_state=seed)
    return points

def dataset_swiss_roll(size, seed=345):
    """Return a 2-D swiss roll: columns 0 and 2 of scikit-learn's 3-D swiss roll."""
    roll_3d, _position = make_swiss_roll(size, noise=0.6, random_state=seed)
    return roll_3d[:, [0, 2]]

In [None]:
datasets = [dataset_transformed_gaussian, dataset_blobs, dataset_half_moon, 
            dataset_two_moons, dataset_circles, dataset_swiss_roll]

# Lay the datasets out on a grid with 3 columns and as many rows as needed.
cols = 3
rows = int(math.ceil(len(datasets) / cols))
fig, axs = plt.subplots(rows, cols, figsize=(4.8 * cols, 3.2 * rows))

n_samples = 1000
# One 2-D histogram per dataset, titled with the generator's function name.
for idx, (ax, dataset) in enumerate(zip(axs.flat, datasets)):
    X = dataset(n_samples)
    _ = ax.hist2d(X[:, 0], X[:, 1], bins=100);
    ax.set_title(dataset.__name__)

In [None]:
from numpyro.distributions import Normal



# Base distribution: independent unit Gaussians in two dimensions.
base_dist = Normal(jnp.array([0,0]), jnp.array([1,1]) )
# Small network with 2 hidden layers of width 2 and 2 outputs, initialised on a dummy (2, 2) batch.
mlp = MLPCompact(2, 2, 2)
params = mlp.init(key, np.ones((2,2)))


In [None]:

class AffineCoupling(flax.linen.Module):
    """Affine coupling layer.

    The columns `input_dims` pass through unchanged and are fed to `net`, which
    predicts a log-scale and a shift for the columns `output_dims`.

    Attributes:
        input_dims: column indices that condition the transform.
        output_dims: column indices that are transformed.
        net: module mapping x[:, input_dims] to [logscale, shift] (concatenated
            along the last axis).
    """
    input_dims: list
    output_dims: list
    net: flax.linen.Module

    def forward(self, x):
        """Apply y = x * exp(logscale) + shift to the output columns.

        Returns the transformed batch and the per-sample log-determinant of
        the Jacobian.
        """
        x_inp = x[:, self.input_dims]
        params = self.net(x_inp)
        logscale, shift = jnp.split(params, 2, -1)
        x_out = x[:, self.output_dims]*jnp.exp(logscale) + shift
        x = x.at[:,self.output_dims].set(x_out)
        # The log-determinant is the sum of the log-scales. Exponentiating it
        # would give the determinant itself, not its logarithm.
        ldj = logscale.sum(-1)
        return x, ldj


    def __call__(self, x):
        return self.forward(x)

    def inverse(self, x):
        """Undo `forward`: x = (y - shift) / exp(logscale) on the output columns.

        Returns the restored batch and the per-sample log-determinant of the
        Jacobian of the inverse map.
        """
        x_inp = x[:, self.input_dims]
        params = self.net(x_inp)
        logscale, shift = jnp.split(params, 2, -1)
        
        x_out = (x[:, self.output_dims] - shift)/jnp.exp(logscale)
        x = x.at[:,self.output_dims].set(x_out)
        # Keep one value per sample. Summing over the batch as well would add the
        # whole batch's log-determinant to every sample's log-probability.
        ldj = -logscale.sum(-1)
        return x, ldj


class Flow(flax.linen.Module):
    """Normalizing flow: a chain of invertible transforms on top of a base distribution."""
    # NOTE(review): annotated as a single Module, but iterated over in forward/inverse
    # and constructed with a list elsewhere in this file.
    transforms: flax.linen.Module
    # NOTE(review): annotated as a Module, but `.expand`, `.sample` and `.log_prob` are
    # called on it - presumably a distribution object; confirm.
    base_dist: flax.linen.Module
    dim: int

    def forward(self, key, n):
        """Draw `n` samples from the base distribution and push them through the transforms."""
        x = self.base_dist.expand((n, self.dim)).sample(key)
        for t in self.transforms:
            # The log-determinant is not needed for sampling.
            x, _ =  t.forward(x)
        return x
    
    def __call__(self, key, n):
        return self.forward(key, n)

    def inverse(self, x):
        """Map `x` back through the transforms and return (mapped x, logprob).

        `logprob` accumulates the `ldj` returned by every inverse transform and the
        base log-probability of the mapped points, summed over the last axis.
        """
        logprob = 0
        # Undo the transforms in the opposite order of `forward`.
        for t in reversed(self.transforms):
            x, ldj =  t.inverse(x)
            logprob+=ldj
        logprob += self.base_dist.expand(x.shape).log_prob(x).sum(-1)
        return x, logprob



In [None]:
# Dummy batch used to exercise the inverse pass in the next cell.
x = jnp.array(np.ones((10, 2)))

coupling = AffineCoupling([0], [1], mlp)


base_dist = Normal(jnp.array([0,0]), jnp.array([1,1]))
# Four coupling layers that alternate which of the two dimensions is transformed.
flow = Flow([
    AffineCoupling([0], [1], MLPCompact(2, 32, 2)),
    AffineCoupling([1], [0], MLPCompact(2, 32, 2)),
    AffineCoupling([0], [1], MLPCompact(2, 32, 2)),
    AffineCoupling([1], [0], MLPCompact(2, 32, 2))],
    base_dist=base_dist, dim=2)

params =  flow.init(key, key, 2)
# optax.chain applies the transformations in order: clip the raw gradients first and
# let Adam consume the clipped values. With clipping placed after Adam, it would act
# on the already-normalised Adam updates instead of on the gradients.
optimizer = optax.chain(optax.clip_by_global_norm(5.0), optax.adam(1e-4))
opt_state = optimizer.init(params)


# Samples from the untrained flow.
samples = flow.apply(params, key, 500, method=flow.forward)
_ = plt.hist2d(samples[:, 0], samples[:, 1], bins=100)

In [None]:
# Run the data -> latent direction once on the dummy batch.
z, logprob = flow.apply(params, jnp.array(x), method=flow.inverse)

z.shape

In [None]:
# this returns a function!
def construct_training_step(model, optimizer):
  """Build the jitted maximum-likelihood update function for a flow `model`.

  The returned function maps (params, opt_state, x) to
  (negative mean log-likelihood, updates, new optimizer state).
  """
  def model_loss(params, x):
    # model forward (data -> latent direction)
    _z, logprob = model.apply(params, x, method=model.inverse)
    return -jnp.mean(logprob)

  loss_and_grads = jax.value_and_grad(model_loss)

  # this is the function that we call in the end
  def update_func(params, opt_state, x):
    loss_value, grads = loss_and_grads(params, x)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    return loss_value, updates, new_opt_state

  return jax.jit(update_func)


training_step = construct_training_step(flow,  optimizer)


In [None]:
X = dataset_transformed_gaussian(5000)
def iterate_over_array(X, bs):
    """Yield consecutive row slices of `X` of at most `bs` rows each, as JAX arrays."""
    batch_starts = range(0, X.shape[0], bs)
    yield from (jnp.array(X[start:start+bs]) for start in batch_starts)


In [None]:
# 20 passes over the data in mini-batches of 64.
for i in range(20):
    tqdm_iter = tqdm(iterate_over_array(X, 64))
    for x in tqdm_iter:
        # `x` is already a JAX array here (iterate_over_array converts each batch).
        loss, updates, opt_state = training_step(params,opt_state, jnp.array(x))
        params = optax.apply_updates(params, updates)
        tqdm_iter.set_postfix_str(f"loss={loss:10.3f}")

In [None]:
# Samples from the trained flow, shown as a 2-D histogram.
samples = flow.apply(params, key, 5000, method=flow.forward)
_ = plt.hist2d(samples[:, 0], samples[:, 1], bins=100)