In [None]:
# If running in Google Colab, uncomment and run this cell first (it clones the repository and checks out the maml-vs-leo branch).
# ! git clone https://github.com/nicolas-aagnes/sequential-leo.git
# %cd sequential-leo
# ! git pull origin
# ! git checkout maml-vs-leo

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%matplotlib inline
from matplotlib import pyplot as plt
# Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and 3.8
# removed the old names (style.use then raises OSError); try the new name first
# and fall back to the legacy one on older matplotlib versions.
try:
    plt.style.use('seaborn-v0_8-whitegrid')
except OSError:
    plt.style.use('seaborn-whitegrid')
import numpy as np
import torch
from IPython import display

In [3]:
import argparse
import itertools
import random

import torch.utils.tensorboard as tensorboard

from config import get_model_and_dataloaders
from models.maml_new import MAML, MAMLConfig

In [4]:
# argparse.Namespace is the standard-library equivalent of the hand-rolled
# attribute bag this cell used to define (argparse is already imported above);
# it also gives `args` a readable repr when displayed. The `Namespace` name
# stays bound for any cell that still refers to it.
Namespace = argparse.Namespace

# Seed the global RNGs so re-running the notebook gives comparable runs.
# NOTE(review): dataloader workers / CUDA kernels may still introduce
# nondeterminism; this only pins the global generators.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

args = Namespace(
    model="maml",
    dataset="sine2D",
    num_support=1,
    num_query=15,
    num_timesteps=10,
    num_timesteps_pred=4,
    num_inner_steps=1,
    inner_lr=0.4,
    learn_inner_lr=False,
    outer_lr=0.001,
    batch_size=64,
    num_train_tasks=10000,
    num_val_tasks=64,
    test=False,
    log_dir=None,  # None -> SummaryWriter uses its default runs/ directory; also passed to the models
    checkpoint_step=-1,
    noise=0.0,
    delta=0.3,
)

In [5]:
def _first_trajectory(batch):
    """Return ``batch[0, 0]`` as a NumPy array.

    Accepts NumPy arrays or torch tensors; tensors are detached and moved to
    the CPU first so GPU / grad-tracking tensors can be plotted as well.
    """
    trajectory = batch[0, 0]
    if hasattr(trajectory, "detach"):
        trajectory = trajectory.detach().cpu()
    return np.asarray(trajectory)


def plot(x_support, y_support, y_pred=None, show_pred=False):
    """Plot the first trajectory of a task batch.

    Draws filled red dots for ``x`` and hollow red circles for ``y``. Despite the
    parameter names, the training cells pass the *query* split here; any pair
    with the layout below works.

    Args:
        x_support: rank-4 array or tensor; element ``[0, 0]`` is drawn with
            coord 0 on the horizontal axis and coord 1 on the vertical axis.
            Assumed layout ``[task, sample, timestep, coord]`` (TODO confirm).
        y_support: same layout as ``x_support``.
        y_pred: optional predictions with the same layout. Defaults to ``None``
            so the function can be called with just ``(x, y)``.
        show_pred: if True and ``y_pred`` is given, overlay the predictions as
            green dots. Off by default, matching the previously disabled overlay.
    """
    fig, ax = plt.subplots(figsize=(5, 5))

    xs = _first_trajectory(x_support)
    ys = _first_trajectory(y_support)
    ax.plot(xs[:, 0], xs[:, 1], "ro", label="x")
    ax.scatter(ys[:, 0], ys[:, 1], facecolors="none", edgecolors="r", label="y")

    if show_pred and y_pred is not None:
        preds = _first_trajectory(y_pred)
        ax.plot(preds[:, 0], preds[:, 1], "go", label="y_pred")

    # Fixed limits keep successive frames of the live training plot comparable.
    ax.set_xlim(-6, 6)
    ax.set_ylim(-5, 5)
    ax.set_xlabel("coordinate 0")
    ax.set_ylabel("coordinate 1")
    ax.legend(loc="upper right")
    plt.show()
    # Callers draw many figures in loops; release each one once shown.
    plt.close(fig)

# MAML Training

In [None]:
# TensorBoard writer for the MAML run (args.log_dir is None -> default runs/ directory).
writer = tensorboard.SummaryWriter(log_dir=args.log_dir)
# Only the dataloaders are taken from the factory; MAML is constructed explicitly below.
# The LEO section further down reads `dataloaders` as well.
_, dataloaders = get_model_and_dataloaders(args)

PLOT_EVERY = 10  # redraw the live query-set plot every N items yielded by model.train

config = MAMLConfig(
    input_size=2,  # sine 2D: two-dimensional points
    hidden_size=512,
    num_timesteps_pred=args.num_timesteps_pred,
)
model = MAML(
    args.num_inner_steps,
    args.inner_lr,
    args.learn_inner_lr,
    args.outer_lr,
    args.log_dir,
    config,
)

# model.train(...) is consumed as an iterator of (task_batch, predictions).
# NOTE(review): the meaning of the trailing positional `True` is not visible here.
try:
    for step, (task_batch, predictions) in enumerate(
        model.train(dataloaders["train"], dataloaders["val"], writer, args, True)
    ):
        if step % PLOT_EVERY == 0:
            _, _, x_query, y_query = task_batch
            display.clear_output(wait=True)
            plot(x_query, y_query, predictions.detach().cpu().numpy())
finally:
    # Training is usually stopped by interrupting the kernel; flush so the
    # TensorBoard events logged up to that point are written to disk.
    writer.flush()

In [None]:
# `model` is rebound by the LEO section, so make sure this cell really evaluates
# the MAML model (guards against out-of-order re-runs).
assert isinstance(model, MAML), f"expected a MAML model, got {type(model).__name__}"
model.eval(dataloaders["val"], 1000)  # NOTE(review): meaning of 1000 not visible here — confirm against MAML.eval

# LEO Training

In [None]:
# FIXME(review): `LEO` and `LEOConfig` are not imported anywhere in this notebook,
# so this cell raises NameError on a fresh kernel. Import them next to
# `MAML, MAMLConfig` in the imports cell (module path not visible here).
# NOTE: this section reuses `args`, `dataloaders` and `plot` from earlier cells.

# Log LEO to its own TensorBoard run rather than reusing MAML's writer, so the two
# models' curves are not written into the same run. With log_dir=None the
# `comment` suffix is appended to the default runs/ directory name; if
# args.log_dir is set, `comment` has no effect and both runs share that directory.
leo_writer = tensorboard.SummaryWriter(log_dir=args.log_dir, comment="_leo")

PLOT_EVERY = 10  # redraw the live query-set plot every N items yielded by model.train

config = LEOConfig(
    num_support=args.num_support,
    num_timesteps_pred=args.num_timesteps_pred,
    input_size=2,  # This is for sine 2D.
    encoder_hidden_size=512,
    relation_net_hidden_size=512,
    z_dim=64,
    decoder_hidden_size=512,
    f_theta_hidden_size=512,
)

model = LEO(
    args.num_inner_steps,
    args.inner_lr,
    args.learn_inner_lr,
    args.outer_lr,
    args.log_dir,
    config,
)

try:
    for step, (task_batch, predictions) in enumerate(
        model.train(dataloaders["train"], dataloaders["val"], leo_writer, args, True)
    ):
        if step % PLOT_EVERY == 0:
            _, _, x_query, y_query = task_batch
            display.clear_output(wait=True)
            plot(x_query, y_query, predictions.detach().cpu().numpy())
finally:
    # Flush TensorBoard events even if training is interrupted by hand.
    leo_writer.flush()

In [None]:
# If the LEO training cell failed, `model` would still be the MAML model and this
# cell would silently report MAML numbers as LEO's; fail loudly instead.
assert isinstance(model, LEO), f"expected a LEO model, got {type(model).__name__}"
model.eval(dataloaders["val"], 1000)  # NOTE(review): meaning of 1000 not visible here — confirm against LEO.eval

In [6]:
# Sanity check: plot a few raw training tasks straight from the dataloader.
# The train loader holds many batches (presumably num_train_tasks / batch_size),
# so only preview the first few instead of drawing one figure per batch.
NUM_PREVIEW_BATCHES = 3

# Separate name so this cell does not clobber the `dataloaders` used by the
# training and evaluation cells.
_, preview_dataloaders = get_model_and_dataloaders(args)
for x_support, y_support, _, _ in itertools.islice(preview_dataloaders["train"], NUM_PREVIEW_BATCHES):
    # No predictions here; passed explicitly because `plot` expects a y_pred argument.
    plot(x_support, y_support, None)

[tensor([[[[ 0.3637,  0.2138],
          [ 0.6868,  0.3961],
          [ 1.0099,  0.5647],
          ...,
          [ 2.6252,  1.0258],
          [ 2.9483,  1.0191],
          [ 3.2714,  0.9772]]],


        [[[-2.0113, -0.2720],
          [-1.6882, -0.6679],
          [-1.3651, -0.9248],
          ...,
          [ 0.2503,  0.3465],
          [ 0.5733,  0.7237],
          [ 0.8964,  0.9505]]],


        [[[ 0.5010,  1.5489],
          [ 0.8241,  2.2996],
          [ 1.1472,  2.7248],
          ...,
          [ 2.7626, -0.2655],
          [ 3.0857, -1.2754],
          [ 3.4087, -2.1048]]],


        ...,


        [[[ 0.2660,  0.2000],
          [ 0.5891,  0.4400],
          [ 0.9122,  0.6737],
          ...,
          [ 2.5275,  1.6326],
          [ 2.8506,  1.7630],
          [ 3.1737,  1.8677]]],


        [[[-0.9194, -0.8935],
          [-0.5964, -0.5998],
          [-0.2733, -0.2803],
          ...,
          [ 1.3421,  1.2180],
          [ 1.6652,  1.4067],
          [ 1.9883,  1.