In [None]:
# If running in Google Colab, uncomment the lines below and run this cell as well.
# ! git clone https://github.com/nicolas-aagnes/sequential-leo.git
# %cd sequential-leo
# ! git pull origin
# ! git checkout leo2d

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%matplotlib inline
from matplotlib import pyplot as plt
# Matplotlib 3.6 renamed the bundled "seaborn-*" styles to "seaborn-v0_8-*" and
# later releases dropped the old names, so pick whichever name this install
# actually provides (falling back to the default style if neither exists).
_grid_style = next(
    (
        name
        for name in ("seaborn-v0_8-whitegrid", "seaborn-whitegrid")
        if name in plt.style.available
    ),
    "default",
)
plt.style.use(_grid_style)
import numpy as np
import torch
from IPython import display

In [None]:
import torch.utils.tensorboard as tensorboard
import argparse
from config import get_model_and_dataloaders
from models.maml_new import MAML, MAMLConfig

In [None]:
class Namespace:
    """Bare attribute container: each keyword argument becomes an attribute.

    Example: ``Namespace(a=1).a == 1``.
    """

    def __init__(self, **kwargs):
        # Write straight into the instance dictionary, one entry per keyword,
        # preserving the order in which the keywords were given.
        attributes = vars(self)
        for key, value in kwargs.items():
            attributes[key] = value

# Run settings gathered into one attribute-style object. The training cell
# hands it to get_model_and_dataloaders and model.train, and reads individual
# fields (step counts, learning rates, log_dir) when constructing the model.
args = Namespace(**{
    "model": "maml",
    "dataset": "sine2D",
    "num_support": 1,
    "num_query": 15,
    "num_timesteps": 10,
    "num_timesteps_pred": 4,
    "num_inner_steps": 1,
    "inner_lr": 0.4,
    "learn_inner_lr": False,
    "outer_lr": 0.001,
    "batch_size": 32,
    "num_train_tasks": 100000,
    "num_val_tasks": 32,
    "test": False,
    "log_dir": None,
    "checkpoint_step": -1,
    "noise": 0.0,
    "delta": 0.3,
})

In [None]:
def plot(x_support, y_support, y_pred):
    """Draw inputs, targets and predictions for the first entry of a batch.

    Only element ``[0, 0]`` of each argument is shown. Along the last axis,
    index 0 is used as the horizontal coordinate and index 1 as the vertical
    one.

    Args:
        x_support: indexable as ``[0, 0, :, k]``; drawn as filled red dots.
        y_support: indexable the same way; drawn as hollow red circles.
        y_pred: indexable the same way; drawn as filled green dots.
    """
    _, ax = plt.subplots(figsize=(5, 5))

    inputs = x_support[0, 0]
    targets = y_support[0, 0]
    predicted = y_pred[0, 0]

    ax.plot(inputs[:, 0], inputs[:, 1], "ro")
    ax.scatter(targets[:, 0], targets[:, 1], facecolors='none', edgecolors='r')
    ax.plot(predicted[:, 0], predicted[:, 1], "go")

    # Fixed axis limits keep successive frames comparable.
    ax.set_xlim((-6, 6))
    ax.set_ylim((-5, 5))
    plt.show()

In [None]:
# Build the TensorBoard writer and the data loaders from the shared settings.
# The first value returned by get_model_and_dataloaders is discarded because
# the MAML model is constructed explicitly below.
writer = tensorboard.SummaryWriter(log_dir=args.log_dir)
_, dataloaders = get_model_and_dataloaders(args)

config = MAMLConfig(input_size=2, hidden_size=16, num_timesteps_pred=args.num_timesteps_pred)
model = MAML(args.num_inner_steps, args.inner_lr, args.learn_inner_lr, args.outer_lr, args.log_dir, config)

# model.train is consumed as an iterable of (task_batch, predictions) pairs.
# On every 10th item the previous output is cleared and the query data is
# redrawn together with the current predictions.
training_stream = model.train(dataloaders["train"], dataloaders["val"], writer, args, True)
for step, (task_batch, predictions) in enumerate(training_stream):
    if step % 10 != 0:
        continue
    _, _, x_query, y_query = task_batch
    display.clear_output(wait=True)
    plot(x_query, y_query, predictions.detach().cpu().numpy())

In [None]:
# Testing dataset.
def plot(x_support, y_support, y_pred=None, index=None):
    """Plot one task from a batch, optionally with predictions.

    This definition replaces the earlier three-argument ``plot`` in the
    notebook namespace, so it accepts both call styles: re-running the
    training cell after this cell keeps working instead of raising a
    ``TypeError``.

    Args:
        x_support: indexable as ``[i, 0, :, k]``; drawn as filled red dots.
        y_support: indexable the same way; drawn as hollow red circles.
        y_pred: optional, indexable the same way; drawn as filled green dots
            when given.
        index: which batch entry to draw. When omitted, a random entry is
            chosen if ``y_pred`` is None (dataset inspection), and entry 0 is
            used if ``y_pred`` is given (matching the three-argument version).
    """
    if index is None:
        if y_pred is None:
            index = np.random.choice(x_support.shape[0])
        else:
            index = 0

    plt.figure(figsize=(5, 5))
    plt.plot(x_support[index, 0, :, 0], x_support[index, 0, :, 1], "ro")
    plt.scatter(y_support[index, 0, :, 0], y_support[index, 0, :, 1], facecolors='none', edgecolors='r')
    if y_pred is not None:
        plt.plot(y_pred[index, 0, :, 0], y_pred[index, 0, :, 1], "go")
    # Fixed axis limits keep successive figures comparable.
    plt.xlim((-6, 6))
    plt.ylim((-5, 5))
    plt.show()

# Draw one figure per batch yielded by the validation loader; only the support
# tensors are plotted, the query tensors are unpacked but not used here.
val_batches = dataloaders["val"]
for task_batch in val_batches:
    (x_support, y_support, x_query, y_query) = task_batch
    plot(x_support, y_support)
