A toy example demonstrating MAML (Model-Agnostic Meta-Learning): regression to sine waves, where each task is a sine wave with a different amplitude and phase. The model is first meta-trained on a sample of training tasks using higher-order gradients, as described in the MAML paper. During testing, it is fine-tuned on `k` examples from an evaluation task.

This is benchmarked against a model conventionally pre-trained on the aggregated samples of the same training tasks, which is then fine-tuned on `k` examples from the evaluation task in the same way.

In [46]:
%reload_ext autoreload
%autoreload 2

import torch
from torch import nn
from torch import optim
from higher import innerloop_ctx
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm, trange

import notebook_setup
import ppo, utils, meta

In [84]:
# Task distribution: sine waves with amplitude in [0.1, 5) and phase in [0, pi).
def get_task(seed=None):
    """Draw the (amplitude, phase) parameters of one sine-wave task.

    Parameters
    ----------
    seed : int or None
        Seed for a private ``np.random.RandomState``; the same seed always
        yields the same task. ``None`` draws a fresh random task.

    Returns
    -------
    (amplitude, phase) : tuple of np.ndarray, each of shape (1,)
    """
    generator = np.random.RandomState(seed)
    # Draw order matters for reproducibility: phase first, then amplitude.
    phase_draw = generator.rand(1)
    amplitude_draw = generator.rand(1)
    return 0.1 + amplitude_draw * 4.9, phase_draw * np.pi

def get_samples(k, amplitude, phase, seed=None):
    """Sample ``k`` points from the sine wave ``amplitude * sin(x + phase)``.

    The inputs x are drawn uniformly from [-5, 5) using a private
    ``np.random.RandomState(seed)``; ``seed=None`` gives fresh random inputs.

    Returns
    -------
    (x, y) : tuple of float32 torch tensors, each of shape (k, 1)
    """
    generator = np.random.RandomState(seed)
    inputs = -5 + 10 * generator.rand(k)
    targets = amplitude * np.sin(inputs + phase)
    # Column vectors so they can be fed directly to an nn.Linear(1, ...).
    return tuple(torch.from_numpy(arr).float().reshape(-1, 1) for arr in (inputs, targets))

In [92]:
# Experiment configuration
k = 10                       # number of (x, y) examples sampled per task
n_adapt = 10                 # gradient steps taken when fine-tuning on an evaluation task
alpha = 0.01                 # learning rate shared by every optimizer in this notebook
n_training_tasks = 5
training_tasks = [*range(n_training_tasks)]  # seeds passed to get_task for training tasks
evaluation_tasks = [11]      # seeds for the held-out evaluation task(s)
loss = nn.MSELoss()          # regression objective used for every task

In [93]:
def make_model():
    """Build the regressor: an MLP 1 -> 40 -> 40 -> 1 with ReLU activations.

    Modules are instantiated in the same order as a literal
    ``nn.Sequential(Linear, ReLU, Linear, ReLU, Linear)``, so parameter
    initialisation consumes the global torch RNG identically and the
    state_dict keys are ``0.*``, ``2.*`` and ``4.*``.
    """
    hidden = 40
    layers = [
        nn.Linear(1, hidden), nn.ReLU(),
        nn.Linear(hidden, hidden), nn.ReLU(),
        nn.Linear(hidden, 1),
    ]
    return nn.Sequential(*layers)

In [94]:
def train(model, opt, n=100, losses=None):
    """Meta-train ``model`` with MAML on per-task samples from ``training_tasks``.

    For each of ``n`` outer iterations, and for every training task:
      1. a differentiable view of ``model`` (``higher.innerloop_ctx``) takes one
         inner Adam step (lr ``alpha``) on ``k`` support samples;
      2. the adapted view is evaluated on ``k`` fresh query samples and the
         query loss is back-propagated through the inner step into
         ``model``'s own parameters (``copy_initial_weights=False``).
    Query gradients of all tasks accumulate and are applied with a single
    ``opt.step()`` per outer iteration.

    Parameters
    ----------
    model : nn.Module
        Model whose initial weights are meta-learned (updated in place).
    opt : torch.optim.Optimizer
        Outer (meta) optimizer over ``model.parameters()``.
    n : int
        Number of outer (meta) iterations.
    losses : list of lists, optional
        If given, ``losses[i]`` receives the query loss of training task ``i``
        at every iteration; it needs at least ``len(training_tasks)`` entries.
        ``None`` (the default) records nothing.
    """
    # Inner-loop optimizer. It is never stepped directly: innerloop_ctx wraps
    # it in a fresh differentiable optimizer on every entry, so one instance
    # can be reused across iterations.
    inner_opt = optim.Adam(model.parameters(), lr=alpha)
    # Drop any gradients left over from earlier work so the first meta-step
    # only uses this call's query gradients.
    opt.zero_grad()
    for _ in trange(n, leave=False):
        for i, task in enumerate(map(get_task, training_tasks)):
            with innerloop_ctx(model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
                # Inner adaptation step on the support samples.
                xi, yi = get_samples(k, *task)
                diffopt.step(loss(fmodel(xi), yi))

                # Query loss on fresh samples; its gradient flows back through
                # the inner step into model.parameters().
                xi, yi = get_samples(k, *task)
                l = loss(fmodel(xi), yi)
                l.backward()
                if losses is not None:
                    losses[i].append(l.item())
        opt.step()
        opt.zero_grad()

def benchmark_train(model, opt, n=100, losses=None):
    """Conventionally pre-train ``model`` on aggregated samples from ``training_tasks``.

    Each of the ``n`` iterations draws ``k`` fresh samples from every training
    task, back-propagates each task's loss (gradients accumulate across tasks)
    and then takes a single ``opt.step()``. There is no inner adaptation step.

    Parameters
    ----------
    model : nn.Module
        Model trained in place.
    opt : torch.optim.Optimizer
        Optimizer over ``model.parameters()``.
    n : int
        Number of optimizer steps.
    losses : list of lists, optional
        If given, ``losses[i]`` receives the loss of training task ``i`` at every
        iteration; it needs at least ``len(training_tasks)`` entries.
        ``None`` (the default) records nothing.
    """
    # Drop stale gradients so the first step only uses this call's losses.
    opt.zero_grad()
    for _ in trange(n, leave=False):
        for i, task in enumerate(map(get_task, training_tasks)):
            xi, yi = get_samples(k, *task)
            l = loss(model(xi), yi)
            l.backward()
            if losses is not None:
                losses[i].append(l.item())
        opt.step()
        opt.zero_grad()

In [95]:
def test(model, seed=0, losses=None, predict=None):
    """Fine-tune ``model`` on each task in ``evaluation_tasks`` and evaluate it.

    For every evaluation task, a functional copy of ``model`` created by
    ``higher.innerloop_ctx`` (default ``copy_initial_weights=True``, no
    higher-order tracking) takes ``n_adapt`` Adam steps (lr ``alpha``) on ``k``
    support samples. ``model``'s own parameters are presumably left untouched
    by this; confirm against the ``higher`` docs if relying on it.

    Parameters
    ----------
    model : nn.Module
        Model to fine-tune.
    seed : int or None
        Seed for the support samples; query samples use ``seed + 1``.
        ``None`` draws fresh random support and query samples.
    losses : list of lists, optional
        If given, ``losses[i]`` receives the post-adaptation query loss of
        evaluation task ``i``.
    predict : torch.Tensor, optional
        If given, the function returns the adapted model's output on
        ``predict`` for the FIRST evaluation task and stops there.

    Returns
    -------
    torch.Tensor or None
        The prediction when ``predict`` is given, otherwise ``None``.
    """
    inner_opt = optim.Adam(model.parameters(), lr=alpha)
    # `seed + 1` would raise TypeError for seed=None; keep queries random then.
    query_seed = None if seed is None else seed + 1
    for i, task in enumerate(map(get_task, evaluation_tasks)):
        with innerloop_ctx(model, inner_opt, track_higher_grads=False) as (fmodel, diffopt):
            xi, yi = get_samples(k, *task, seed=seed)
            for _ in range(n_adapt):
                diffopt.step(loss(fmodel(xi), yi))

            if losses is not None:
                xq, yq = get_samples(k, *task, seed=query_seed)
                # Pure evaluation: no autograd graph is needed.
                with torch.no_grad():
                    losses[i].append(loss(fmodel(xq), yq).item())

            if predict is not None:
                return fmodel(predict)

In [96]:
# Two networks with identical starting weights: `model` is meta-trained with
# MAML, `bench` is pre-trained conventionally. Copying the state dict makes
# the comparison start from the same initialisation.
model = make_model()
bench = make_model()
bench.load_state_dict(model.state_dict())

# One outer optimizer per network, both with learning rate `alpha`.
opt, optb = (optim.Adam(net.parameters(), lr=alpha) for net in (model, bench))

# Per-task loss histories: training losses grow every outer iteration,
# evaluation losses once per epoch.
losses_train, losses_bench_train = ([[] for _ in training_tasks] for _ in range(2))
losses_test, losses_bench_test = ([[] for _ in evaluation_tasks] for _ in range(2))

n_epochs = 10
n_iters_per_epoch = 100
for epoch in trange(n_epochs, leave=False):
    train(model, opt, n_iters_per_epoch, losses_train)
    benchmark_train(bench, optb, n_iters_per_epoch, losses_bench_train)
    test(model, losses=losses_test)
    test(bench, losses=losses_bench_test)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
# Loss performance: post-adaptation query loss on each evaluation task,
# recorded once per epoch of the training loop (MAML model vs. benchmark).
# The task seed is included in each label so several evaluation tasks do not
# produce duplicate legend entries.
for task, t, tb in zip(evaluation_tasks, losses_test, losses_bench_test):
    plt.plot(t, label=f'Test (task {task})')
    plt.plot(tb, label=f'Bench (task {task})')
plt.xlabel('epoch')
plt.ylabel('MSE after fine-tuning')
plt.grid(True)
plt.legend()

In [None]:
# Plotting actual outputs before and after fine-tuning against the true wave.
# `test(..., predict=x)` returns the adapted prediction for the FIRST entry of
# `evaluation_tasks`, so the ground truth must come from that same task.
x = torch.from_numpy(np.linspace(-5, 5, 50, False).reshape(-1, 1)).float()
with torch.no_grad():  # plain forward passes; no graph needed
    y_pre = model(x)
    y_b_pre = bench(x)
# Fine-tuning needs gradients, so these calls stay outside no_grad.
y_post = test(model, predict=x).data
y_b_post = test(bench, predict=x).data
amp, ph = get_task(evaluation_tasks[0])
y = amp * np.sin(x.numpy() + ph)

plt.figure(figsize=(8,6))
plt.plot(x, y_pre, label='Test(Pre)', c='b', ls=':')
plt.plot(x, y_post, label='Test(Post)', c='b')
plt.plot(x, y_b_pre, label='Bench(Pre)', c='r', ls=':')
plt.plot(x, y_b_post, label='Bench(Post)', c='r')
plt.plot(x, y, label='True', c='g')
plt.legend()
plt.grid(True)