In [14]:
from functools import partial

import jax
import jax.numpy as jnp
import jax.random as jr
from jax.flatten_util import ravel_pytree
import matplotlib.pyplot as plt
import numpy as onp
import optax
from flax import linen as nn

In [15]:
# Problem dimensions (per-sample input / output widths)
input_size, output_size = 32, 10

# Meta-batch layout
K = 100          # samples per task (split into K//2 support / K - K//2 query later)
batch_size = 32  # number of tasks per meta batch

# Learning rates (not used by the hypernet cells in this notebook section)
alpha = 0.1   # inner learning rate
lr = 0.001    # outer learning rate

seed = 42

In [16]:
key = jax.random.key(seed)  # root PRNG key from the config seed; split it before each use

## JAX tutorial: Meta-Learning with HyperNets

In JAX, it's straightforward to sketch a HyperNet-based meta learning algorithm.

Let us first define a meta-batch for meta-learning, having a static regression problem in mind:

In [17]:
# The meta batch: batch_size static regression datasets (tasks), each with K
# (x, y) samples. Placeholder data: x and y are independent standard normals.
# A Generator seeded from the config cell makes Restart & Run All reproducible
# (the legacy global onp.random state used previously was never seeded).
rng = onp.random.default_rng(seed)
batch_x = rng.standard_normal((batch_size, K, input_size))
batch_y = rng.standard_normal((batch_size, K, output_size))

n_support = K // 2  # number of support samples per task
# support set, aka context, training set: first K//2 samples of each task
batch_x1 = batch_x[:, :n_support]
batch_y1 = batch_y[:, :n_support]
# query set, test set: remaining K - K//2 samples of each task
batch_x2 = batch_x[:, n_support:]
batch_y2 = batch_y[:, n_support:]

The (meta) batch consists of ``batch_size`` input-output pairs of ``K`` elements each.

The idea is that ``(batch_x[i], batch_y[i])`` is a dataset with ``K`` samples from the **same** data-generating system.
Conversely, ``(batch_x[i], batch_y[i])`` and ``(batch_x[j], batch_y[j])``, for $i \neq j$, are two datasets from different, yet **related** data-generating systems.

Let us define a simple MLP as base architecture.

In [18]:
# A simple MLP (stock code from copilot)
class MLP(nn.Module):
    """Two-layer perceptron: Dense(hidden_size) -> ReLU -> Dense(output_size).

    Layers are declared inline in ``__call__`` via ``nn.compact``; the hidden
    layer is created before the output layer.
    """

    hidden_size: int
    output_size: int

    @nn.compact
    def __call__(self, x):
        hidden = nn.relu(nn.Dense(self.hidden_size)(x))
        return nn.Dense(self.output_size)(hidden)

In [19]:
# Initialize the MLP on a single dummy (unbatched) input with input_size features
mlp = MLP(hidden_size=128, output_size=output_size)
key, subkey = jr.split(key)
x = jnp.ones((input_size,))  # dummy input vector
params = mlp.init(subkey, x)  # MLP parameter pytree
mlp_out = mlp.apply(params, x)  # forward pass: MLP predictions (not parameters)
mlp_out.shape

(10,)

In [20]:
# Flatten the MLP parameter pytree into one vector. Keep the inverse function so
# a flat vector (such as the hypernet output) can be mapped back to MLP params.
params_flat, unflatten_params_fn = ravel_pytree(params)
n_params = params_flat.size  # length of the vector the hypernet must produce
n_params

5514

The hypernet takes in a dataset and generates the corresponding MLP parameters: in effect, it is a learned learning algorithm. Since we consider static regression, we use a Deep Set hypernet, because it is permutation-invariant, like the algorithm we aim to learn (the order of the samples in a dataset carries no information).

In [21]:
class DeepSet(nn.Module):
    """Permutation-invariant set encoder computing rho(sum_i phi(x_i)).

    phi is a two-layer ReLU MLP applied to every element. The per-element
    features are summed over axis -2 (the set axis), and rho is a
    Dense -> ReLU -> Dense head that produces ``output_size`` values.
    """

    hidden_size: int
    output_size: int

    @nn.compact
    def __call__(self, x):
        # phi: encoder shared across set elements
        feats = nn.relu(nn.Dense(self.hidden_size)(x))
        feats = nn.relu(nn.Dense(self.hidden_size)(feats))

        # Summing over the set axis makes the result independent of element order
        pooled = feats.sum(axis=-2)

        # rho: map the pooled representation to the output vector
        head = nn.relu(nn.Dense(self.hidden_size)(pooled))
        return nn.Dense(self.output_size)(head)

In [22]:
# Example usage: build the hypernet and run it on a dummy context set
pair_dim = input_size + output_size  # each set element is an x concatenated with its y
hypernet = DeepSet(hidden_size=128, output_size=n_params)
key, subkey = jr.split(key)
xy_input = jnp.ones((K, pair_dim))  # K dummy input-output pairs
params_hypernet = hypernet.init(subkey, xy_input)
params_mlp_hn = hypernet.apply(params_hypernet, xy_input)  # flat vector of n_params MLP weights
params_mlp_hn.shape

(5514,)

The standard regression loss, boring stuff

In [23]:
def loss_fn(params, x, y):
    """Mean-squared error of the MLP with parameters ``params`` on inputs x and targets y."""
    residual = mlp.apply(params, x) - y
    return (residual ** 2).mean()

loss_fn(params, batch_x[0], batch_y[0])  # MSE on all K samples of the first task

Array(1.495289, dtype=float32)

The hypernet loss, slightly more interesting!

In [24]:
def hypernet_loss(ph, x1, y1, x2, y2):
    """Query-set loss of an MLP whose weights are generated by the hypernet.

    The hypernet reads the support set (x1, y1) of one task and outputs a flat
    MLP parameter vector. That MLP is then scored on the query set (x2, y2) of
    the same task.

    Args:
        ph: hypernet parameters (as returned by ``hypernet.init``).
        x1, y1: support-set inputs and targets; concatenated along the last axis.
        x2, y2: query-set inputs and targets.

    Returns:
        Scalar mean-squared error (``loss_fn``) of the generated MLP on the query set.
    """
    # Each support row becomes one set element [x, y]
    x1y1 = jnp.concatenate((x1, y1), axis=-1)
    weights = hypernet.apply(ph, x1y1)

    # Map the flat weight vector back to the MLP parameter pytree
    pm = unflatten_params_fn(weights)

    # Evaluate on the query set of the same task
    return loss_fn(pm, x2, y2)

hypernet_loss(params_hypernet, batch_x1[0], batch_y1[0], batch_x2[0], batch_y2[0])

Array(20384282., dtype=float32)

Let us vectorize the hypernet loss to make it amenable for "meta mini-batch training"

In [25]:
def batched_hypernet_loss(ph, x1_b, y1_b, x2_b, y2_b):
    """Mean of ``hypernet_loss`` over a meta-batch of tasks.

    The hypernet parameters ``ph`` are shared by all tasks. The data arguments
    are stacked along a leading task axis, which ``jax.vmap`` maps over.

    Returns:
        Scalar: the average per-task query-set loss across the meta-batch.
    """
    hn_loss_cfg = partial(hypernet_loss, ph)  # fix the shared hypernet params
    hn_loss_vmapped = jax.vmap(hn_loss_cfg)  # map over the task axis of the data
    task_losses = hn_loss_vmapped(x1_b, y1_b, x2_b, y2_b)
    return jnp.mean(task_losses)

batched_hypernet_loss(params_hypernet, batch_x1, batch_y1, batch_x2, batch_y2)  # mean query loss over all tasks

Array(20158264., dtype=float32)

Note: the initial loss is poorly scaled. We should have scaled the hypernet so that it spits out reasonably sized MLP parameters at initialization...

See the full [hypernet meta learning example](gallery/hypernet_sines.ipynb) in the gallery!