In [1]:
# Note: must set this env variable before jax is imported
import os
os.environ['XLA_FLAGS'] = "--xla_force_host_platform_device_count=8"  # hack to pretend i have many devices :)
import tqdm
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
import flax
import functools

# Sanity check: list the devices JAX sees. With the XLA_FLAGS override set
# before the jax import above, this should show 8 emulated CPU devices.
print(jax.devices())

[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]


In [None]:
# make a "dataset"
# Size of the synthetic regression problem: examples per batch and features per example.
batch_size = 64
input_dim = 500

# The pmapped code path reshapes each batch to (num_devices, -1, ...), so the
# batch must divide evenly across the 8 emulated devices.
num_devices = len(jax.devices())
assert num_devices == 8
assert batch_size % num_devices == 0

def batch(gt):
    """Draw one random mini-batch for the toy linear-regression problem.

    Inputs are standard-normal samples (taken from NumPy's global RNG) scaled
    by 1/10; targets are the noiseless linear responses ``sum(x * gt)`` over
    the last axis.

    Args:
        gt: ground-truth weights, multiplied elementwise against every input
            row, so it must broadcast against ``(batch_size, input_dim)``.

    Returns:
        Tuple ``(xs, ys)``. ``xs`` has shape ``(batch_size, input_dim)``;
        ``ys`` has shape ``(batch_size,)`` when ``gt`` is a 1-D weight vector.
    """
    raw_inputs = np.random.randn(batch_size, input_dim)  # advances the global NumPy RNG
    inputs = jnp.array(raw_inputs) / 10
    targets = (inputs * gt).sum(axis=-1)
    return inputs, targets

def forward(_params, _x, _y):
    """Sum-of-squared-errors loss of a linear model.

    Args:
        _params: model weights, broadcast elementwise against ``_x``.
        _x: inputs whose last axis is the feature axis.
        _y: targets, broadcast against the per-example predictions.

    Returns:
        The squared residuals summed over every element.
    """
    predictions = (_params * _x).sum(axis=-1)  # one prediction per example
    residuals = predictions - _y
    return (residuals ** 2).sum()

@functools.partial(jax.pmap, axis_name='devices', in_axes=(0, 0, 0))
def forward_pmapped(_params, _x, _y):
    """Data-parallel version of the sum-of-squared-errors loss.

    Every argument is mapped over its leading axis (``in_axes=(0, 0, 0)``), so
    inside this function each device sees one slice of the parameters, inputs
    and targets.

    Args:
        _params: per-device model weights.
        _x: per-device inputs; axis 1 is summed over as the feature axis.
        _y: per-device targets.

    Returns:
        The per-device losses summed across the ``'devices'`` axis; after the
        pmap this is one copy of that total per device.
    """
    predictions = (_params * _x).sum(axis=1)
    residuals = predictions - _y
    local_loss = (residuals ** 2).sum()
    # All-reduce so that every device holds the loss over the whole batch.
    return jax.lax.psum(local_loss, axis_name='devices')

def train_step(_params, _eta, _x, _y):
    """Take one gradient-descent step on the ``forward`` loss.

    Args:
        _params: current model weights.
        _eta: learning rate (step size).
        _x: batch inputs.
        _y: batch targets.

    Returns:
        The updated weights ``_params - _eta * dLoss/d_params``.
    """
    # Differentiate the loss w.r.t. its first positional argument (the weights).
    param_grads = jax.grad(forward, argnums=0)(_params, _x, _y)
    update = _eta * param_grads
    return _params - update

@functools.partial(jax.pmap, axis_name='devices', in_axes=(0, None, 0, 0))
def train_step_pmapped(_params, _eta, _x, _y):
    """Data-parallel gradient-descent step on the ``forward`` loss.

    Parameters, inputs and targets are mapped over their leading axis, while
    the learning rate is broadcast to every device (``in_axes`` entry ``None``).

    Args:
        _params: per-device model weights.
        _eta: learning rate shared by all devices.
        _x: per-device slice of the batch inputs.
        _y: per-device slice of the batch targets.

    Returns:
        Per-device updated weights, each stepped along the gradient summed
        over the ``'devices'`` axis.
    """
    # Gradient of the local loss w.r.t. the weights (first positional argument).
    local_grads = jax.grad(forward, argnums=0)(_params, _x, _y)
    # All-reduce so that every device applies the same full-batch gradient.
    summed_grads = jax.lax.psum(local_grads, axis_name='devices')
    return _params - _eta * summed_grads


def meta_step(_params, _eta, _x, _y, do_pmap: bool = False,
              num_inner_steps: int = 2, meta_eta: float = 1e-1):
    """Roll out a few training steps and adapt the learning rate by meta-gradient.

    The inner loop applies ``num_inner_steps`` gradient steps with learning
    rate ``_eta`` on the batch, the resulting loss on that same batch is
    differentiated w.r.t. ``_eta``, and ``_eta`` itself takes one
    gradient-descent step of size ``meta_eta``.

    Args:
        _params: model weights (with a leading device axis when ``do_pmap``).
        _eta: current learning rate; this is the differentiated quantity.
        _x: batch inputs. When ``do_pmap`` is set they are reshaped to
            ``(num_devices, -1, input_dim)``.
        _y: batch targets. When ``do_pmap`` is set they are reshaped to
            ``(num_devices, -1)``.
        do_pmap: use the pmapped train/forward functions instead of the
            single-device ones.
        num_inner_steps: number of inner training steps to unroll.
        meta_eta: step size of the learning-rate update.

    Returns:
        Tuple ``(params, eta, loss, meta_grad)``: the weights after the
        rollout, the updated learning rate, the post-rollout loss, and the
        gradient of that loss w.r.t. the incoming learning rate.
    """
    if do_pmap:
        # Shard the batch across devices once, before differentiation.
        _x = _x.reshape(num_devices, -1, input_dim)
        _y = _y.reshape(num_devices, -1)
        inner_step, evaluate = train_step_pmapped, forward_pmapped
    else:
        inner_step, evaluate = train_step, forward

    def _meta_loss(eta, params, xs, ys):
        # Unrolled inner loop: the loss depends on eta through every step.
        for _ in range(num_inner_steps):
            params = inner_step(params, eta, xs, ys)
        loss = evaluate(params, xs, ys)
        if do_pmap:
            # The psum leaves an identical copy of the loss per device; keep the first.
            loss = loss.at[0].get()
        return loss, params

    # argnums defaults to 0, i.e. the gradient is taken w.r.t. the learning rate.
    (loss, new_params), meta_grad = jax.value_and_grad(_meta_loss, has_aux=True)(_eta, _params, _x, _y)
    return new_params, _eta - meta_eta * meta_grad, loss, meta_grad


# "train"
def trial(seed, use_pmap, T):
    """Run one seeded experiment of ``T`` meta-steps and record its history.

    Args:
        seed: seed for NumPy's global RNG, which generates both the
            ground-truth weights and every batch.
        use_pmap: whether to use the pmapped code path of ``meta_step``.
        T: number of meta-steps to run.

    Returns:
        Tuple ``(losses, lrs, meta_grads)`` of NumPy arrays with one entry per
        meta-step. The learning rate recorded at each step is the value
        *after* that step's meta-update.
    """
    np.random.seed(seed)

    # Ground-truth weights of the linear model, and the zero-initialised model.
    gt = jnp.array(np.random.randn(input_dim)) / 10
    params = jnp.zeros((input_dim,))
    if use_pmap:
        # The pmapped functions map the weights over a leading device axis.
        params = flax.jax_utils.replicate(params)
    eta = 0.1  # learning rate, adapted by every meta-step

    history = {'loss': [], 'lr': [], 'meta_grad': []}
    for _ in range(T):
        x, y = batch(gt)
        params, eta, loss, meta_grad = meta_step(params, eta, x, y, use_pmap)
        history['loss'].append(loss)
        history['lr'].append(eta)
        history['meta_grad'].append(meta_grad)

    return (
        np.array(history['loss']),
        np.array(history['lr']),
        np.array(history['meta_grad']),
    )


In [None]:
seeds = np.arange(5)
T = 100

# Run every seed through both code paths; each trial returns
# (losses, lrs, meta_grads) for its T meta-steps.
regular_runs, pmap_runs = [], []
for s in seeds:
    regular_runs.append(trial(s, use_pmap=False, T=T))
    pmap_runs.append(trial(s, use_pmap=True, T=T))

# Regroup per quantity and stack across seeds -> arrays of shape (len(seeds), T).
experiment_losses, experiment_lrs, experiment_meta_grads = (
    np.array(series) for series in zip(*regular_runs)
)
pmap_experiment_losses, pmap_experiment_lrs, pmap_experiment_meta_grads = (
    np.array(series) for series in zip(*pmap_runs)
)

# plot
def plot(arr, label, ax):
    """Plot the across-run mean of ``arr`` with a +/- 1.96 std band.

    The time axis is taken from the array itself, so the function does not
    depend on any notebook-level variables.

    Args:
        arr: 2-D array-like of shape ``(num_runs, num_steps)``.
        label: legend label for the mean curve.
        ax: matplotlib axes (or compatible object) to draw on.

    Raises:
        AssertionError: if ``arr`` is not two-dimensional.
    """
    arr = np.asarray(arr)
    assert arr.ndim == 2, f"expected a (num_runs, num_steps) array, got shape {arr.shape}"

    ts = range(arr.shape[1])
    avgs = np.mean(arr, 0)
    stds = np.std(arr, 0)
    ax.plot(ts, avgs, label=label)
    # The band is 1.96 standard deviations across runs (not a standard error).
    ax.fill_between(ts, avgs - 1.96 * stds, avgs + 1.96 * stds, alpha=0.2)
    ax.legend()

# One panel per recorded quantity, each comparing the regular and pmapped runs.
fig, ax = plt.subplots(3, 1)
panels = [
    ('loss', experiment_losses, pmap_experiment_losses),
    ('learning rate', experiment_lrs, pmap_experiment_lrs),
    ('meta-gradient', experiment_meta_grads, pmap_experiment_meta_grads),
]
for panel_ax, (quantity, regular, pmapped) in zip(ax, panels):
    plot(regular, 'regular', panel_ax)
    plot(pmapped, 'pmapped', panel_ax)
    panel_ax.set_ylabel(quantity)  # label the panel so it can be read on its own
ax[-1].set_xlabel('meta-step')
fig.tight_layout()