In [1]:
from functools import partial
import jax
import jax.numpy as jnp
from flax import linen as nn
import optax
import numpy as onp
import matplotlib.pyplot as plt
import jax.random as jr

In [2]:
# --- problem dimensions ---
input_size = 32   # features per input sample
output_size = 10  # regression targets per sample

# --- meta-learning setup ---
K = 100          # K-shot: samples per task (split into support/query halves below)
batch_size = 32  # meta batch size: number of tasks per meta-batch

# --- learning rates ---
alpha = 0.1  # inner-loop (per-task adaptation) gradient-descent step size
lr = 0.001   # outer-loop (meta) learning rate

## JAX tutorial: meta-learning with MAML

It's straightforward to sketch an optimization-based meta-learning algorithm such as MAML (Model-Agnostic Meta-Learning) in JAX, without any dedicated meta-learning library. 

Let us first define a meta-batch for meta-learning, having a static regression problem in mind:

In [3]:
# The meta batch: `batch_size` tasks, each with `K` (input, target) pairs.
# Seeded so that "Restart & Run All" reproduces the printed outputs; float32
# matches JAX's default precision (x64 disabled).
# NOTE: the data are pure-noise placeholders -- swap in a real task sampler.
data_rng = onp.random.default_rng(0)
batch_x = data_rng.standard_normal((batch_size, K, input_size), dtype=onp.float32)
batch_y = data_rng.standard_normal((batch_size, K, output_size), dtype=onp.float32)

# Split every task in half along the sample axis.
n_support = K // 2
# support set, aka context / training set: used by the inner (adaptation) update
batch_x1 = batch_x[:, :n_support]
batch_y1 = batch_y[:, :n_support]
# query set, aka test set: the adapted parameters are evaluated here
batch_x2 = batch_x[:, n_support:]
batch_y2 = batch_y[:, n_support:]

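The (meta) batch consists of ``batch_size`` tasks, each a dataset of ``K`` input-output pairs. The first ``K//2`` pairs of each task form its support set ``(x1, y1)``, and the remaining pairs form its query set ``(x2, y2)``. 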

The idea is that ``(batch_x[i], batch_y[i])`` is a dataset with ``K`` samples from the **same** data-generating system. 
Conversely, ``(batch_x[i], batch_y[i])`` and ``(batch_x[j], batch_y[j])``, for ``i != j``, are two datasets from different, yet **related** data-generating systems. (In this toy example the data are random placeholders, so this structure is only notional.)

Let us define a simple MLP as base architecture.

In [4]:
class MLP(nn.Module):
    """Two-layer perceptron: Dense(hidden_size) -> ReLU -> Dense(output_size).

    Attributes:
        hidden_size: width of the hidden layer.
        output_size: dimension of the network output.
    """

    hidden_size: int
    output_size: int

    @nn.compact
    def __call__(self, x):
        # Layers are created in the same order as before (hidden, then output),
        # so the auto-generated parameter names Dense_0 / Dense_1 are unchanged.
        hidden = nn.relu(nn.Dense(self.hidden_size)(x))
        return nn.Dense(self.output_size)(hidden)

In [5]:
# Initialize the MLP and sanity-check one forward pass on a dummy input.
mlp = MLP(hidden_size=128, output_size=output_size)
key = jr.PRNGKey(0)
x = jnp.ones((input_size,))  # a single (unbatched) example with `input_size` features
params = mlp.init(key, x)  # initialize parameters from the example input
output = mlp.apply(params, x)  # forward pass
assert output.shape == (output_size,), output.shape
output.shape

(10,)

Let us build the MAML loss incrementally. Taking advantage of JAX's transforms, first we define the loss for a single instance in the meta-dataset, and then vectorize it with a ``vmap`` transform. This, in my opinion, maximizes readability of the code.

The first step is to define the standard regression loss for a single dataset:

In [6]:
def loss_fn(params, x, y):
    """Mean-squared error of ``mlp`` (with ``params``) on one dataset ``(x, y)``.

    The squared residuals are averaged over all samples and output dimensions.
    """
    residuals = mlp.apply(params, x) - y
    return jnp.mean(residuals ** 2)

loss_fn(params, batch_x[0], batch_y[0])  # Loss for the first task

Array(1.5897797, dtype=float32)

Then, we define the MAML inner update step, namely a gradient descent step. Pretty straightforward in plain JAX!

In [7]:
# MAML inner update
def inner_update(p, x1, y1, inner_lr=None, num_steps=1):
    """Adapt parameters ``p`` to one task by gradient descent on its support set.

    Args:
        p: parameter pytree of ``mlp``.
        x1, y1: support-set inputs / targets of a single task.
        inner_lr: inner-loop step size. ``None`` (default) uses the global
            ``alpha``, read at call time as in the original single-step version.
        num_steps: number of gradient-descent steps (a Python int, so the loop
            is unrolled at trace time). The default of 1 is the classic one-step
            MAML update.

    Returns:
        The adapted parameter pytree, with the same structure as ``p``.
    """
    step_size = alpha if inner_lr is None else inner_lr
    for _ in range(num_steps):
        grads = jax.grad(loss_fn)(p, x1, y1)
        # theta <- theta - step_size * grad, applied leaf-by-leaf over the pytree
        p = jax.tree_util.tree_map(lambda param, g: param - step_size * g, p, grads)
    return p

params_updated = inner_update(params, batch_x1[0], batch_y1[0])  # Inner update for the first task

The MAML loss is the regression loss ``loss_fn`` measured on ``(x2, y2)`` with parameters updated with GD executed on ``(x1, y1)``:

In [8]:
def maml_loss(p, x1, y1, x2, y2):
    """MAML objective for a single task.

    Adapts ``p`` with ``inner_update`` on the support set ``(x1, y1)`` and
    returns ``loss_fn`` of the adapted parameters on the query set ``(x2, y2)``.
    Because ``inner_update`` itself calls ``jax.grad``, differentiating this
    function w.r.t. ``p`` differentiates through the inner step.
    """
    adapted = inner_update(p, x1, y1)
    return loss_fn(adapted, x2, y2)

maml_loss(params, batch_x1[0], batch_y1[0], batch_x2[0], batch_y2[0])  # MAML loss for the first task

Array(1.4345717, dtype=float32)

The ``maml_loss`` defined above only handles a single instance in the meta-dataset. Let us batchify it with ``vmap``!

In [9]:
# batched maml loss
def batched_maml_loss(p, x1_b, y1_b, x2_b, y2_b):
    """Mean ``maml_loss`` over a meta-batch of tasks.

    ``p`` is shared by every task (``in_axes=None``); the four data arrays are
    mapped over their leading (task) axis, so ``(x1_b[i], y1_b[i], x2_b[i], y2_b[i])``
    is task ``i``.
    """
    per_task_loss = jax.vmap(maml_loss, in_axes=(None, 0, 0, 0, 0))
    task_losses = per_task_loss(p, x1_b, y1_b, x2_b, y2_b)
    return jnp.mean(task_losses)

batched_maml_loss(params, batch_x1, batch_y1, batch_x2, batch_y2)  # mean MAML loss over the whole meta-batch

Array(1.4123442, dtype=float32)

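Voilà! This is the meta-training loss we want to minimize with respect to ``params``. Because the inner update is itself differentiable, ``jax.value_and_grad`` gives us the meta-gradient needed for optimization.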

In [10]:
# Meta-loss and its gradient w.r.t. `params` (differentiating through the inner update)
loss, grads = jax.value_and_grad(batched_maml_loss)(params, batch_x1, batch_y1, batch_x2, batch_y2)

The actual MAML training loop may look like the following:

In [11]:
meta_optimizer = optax.adam(learning_rate=lr)  # outer-loop (meta) optimizer
meta_opt_state = meta_optimizer.init(params)

@jax.jit
def make_step(p, s, x1, y1, x2, y2):
    """One meta-training step on a meta-batch.

    Computes the batched MAML loss and its gradient w.r.t. ``p``, then applies
    one ``meta_optimizer`` update. Returns ``(new_params, new_opt_state, loss)``,
    where ``loss`` is evaluated at the parameters *before* the update.
    """
    meta_loss, meta_grads = jax.value_and_grad(batched_maml_loss)(p, x1, y1, x2, y2)
    updates, new_state = meta_optimizer.update(meta_grads, s)
    return optax.apply_updates(p, updates), new_state, meta_loss


# Meta-training loop.
# NOTE: every step reuses the same fixed meta-batch built above; a real run would
# draw a fresh batch of tasks each step (TODO: task sampler, not defined here).
# NOTE: `params` is rebound to the trained parameters, so re-running this cell
# continues training from the current `params`.
num_meta_steps = 100
losses = []
for _ in range(num_meta_steps):
    params, meta_opt_state, loss = make_step(
        params, meta_opt_state, batch_x1, batch_y1, batch_x2, batch_y2
    )
    losses.append(loss)

See the full [maml meta learning example](gallery/maml_sines.ipynb) in the gallery!