In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset, random_split
from torchmeta.toy import Sinusoid
from torchmeta.utils.data import BatchMetaDataLoader

from meta_learning_algorithms import MAML, Reptile, MetaSGD

In [None]:
class SineModel(nn.Module):
    """Small MLP regressor mapping one scalar feature to one scalar output (1 -> 64 -> 32 -> 1).

    The two hidden layers use ReLU. The output layer has no activation, so
    predictions are unbounded (sine targets can be negative).
    """

    def __init__(self):
        super().__init__()
        # Registration order (hidden1, hidden2, hidden3) fixes both the parameter
        # names/order and the RNG draws used for initialisation; the meta-learners
        # may depend on either, so keep it stable.
        self.hidden1 = nn.Linear(1, 64)
        self.hidden2 = nn.Linear(64, 32)
        self.hidden3 = nn.Linear(32, 1)

    def forward(self, x):
        hidden = self.hidden1(x).relu()
        hidden = self.hidden2(hidden).relu()
        return self.hidden3(hidden)

In [None]:
# Seed every source of randomness before any task is sampled or any weight is
# initialised, so "Restart & Run All" reproduces the same tasks and results.
SEED = 0
torch.manual_seed(SEED)  # model init, BatchMetaDataLoader shuffling, random_split

NUM_TASKS = 10_000
SAMPLES_PER_TASK = 20  # later split into even (support) / odd (query) samples
META_BATCH_SIZE = 16

task_sets = Sinusoid(num_samples_per_task=SAMPLES_PER_TASK, num_tasks=NUM_TASKS)
# torchmeta datasets sample from their own NumPy RandomState (not the global torch RNG);
# `.seed()` resets it so the sampled sine tasks are reproducible.
task_sets.seed(SEED)
dataloader = BatchMetaDataLoader(task_sets, batch_size=META_BATCH_SIZE)

# One independently initialised model per algorithm so the learners do not share weights.
maml_sine_model = SineModel()
maml_pp_sine_model = SineModel()
meta_sgd_sine_model = SineModel()
reptile_sine_model = SineModel()

In [None]:
# Materialise every sampled task once, in two index-aligned views (same task order):
#   support_query_dataset[k] -> ((x_support, y_support), (x_query, y_query)), interleaved split
#   support_only_dataset[k]  -> (train_inputs, train_targets), the whole task with no query split
support_query_dataset = []
support_only_dataset = []

for batch in dataloader:
    inputs_batch, targets_batch = batch[0], batch[1]
    # The last meta-batch may hold fewer tasks than the loader's batch_size.
    effective_batch_size = inputs_batch.shape[0]
    for i in range(effective_batch_size):
        train_inputs = inputs_batch[i].float()
        train_targets = targets_batch[i].float()
        # Even-indexed samples form the support set, odd-indexed samples the query set.
        x_support, x_query = train_inputs[0::2], train_inputs[1::2]
        y_support, y_query = train_targets[0::2], train_targets[1::2]
        support_only_dataset.append((train_inputs, train_targets))
        support_query_dataset.append(((x_support, y_support), (x_query, y_query)))

In [None]:
# One meta-learner per algorithm, each wrapping its own SineModel instance.
DEVICE = "cpu"


def sine_regression_kwargs():
    """Return the keyword arguments shared by every meta-learner.

    A fresh ``nn.MSELoss`` is created on each call so no two learners share a loss object.
    """
    return {"loss": nn.MSELoss(), "device": DEVICE}


# Plain MAML and MAML++ use identical learning rates; only the `maml_plus_plus` flag differs.
maml = MAML(
    model=maml_sine_model,
    maml_plus_plus=False,
    inner_lr=1e-2,
    meta_lr=1e-1,
    **sine_regression_kwargs(),
)
maml_plus_plus = MAML(
    model=maml_pp_sine_model,
    maml_plus_plus=True,
    inner_lr=1e-2,
    meta_lr=1e-1,
    **sine_regression_kwargs(),
)
meta_sgd = MetaSGD(
    model=meta_sgd_sine_model,
    inner_lr=1e-3,
    meta_lr=1e-3,
    **sine_regression_kwargs(),
)
reptile = Reptile(
    model=reptile_sine_model,
    inner_lr=1e-1,
    meta_lr=1e-1,
    clipping=4.0,
    **sine_regression_kwargs(),
)

In [None]:
TASK_BATCH_SIZE = 4  # tasks per meta-update
VAL_FRACTION = 0.1
SPLIT_SEED = 0


def keep_task_list(batch):
    """Collate function that returns the sampled tasks unchanged, as a plain list.

    Items are nested tuples of tensors whose structure differs between the two
    datasets, so the default tensor-stacking collate is bypassed. A named
    function (unlike a lambda) also stays picklable if ``num_workers`` is raised.
    """
    return batch


# Split task *indices* once and reuse them for both views of the data.
# support_query_dataset and support_only_dataset are built in lockstep (one entry
# per task, same order). So every algorithm is trained and validated on the same
# tasks, and no task is "train" for one algorithm and "validation" for another.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
sq_train_dataset, sq_val_dataset = random_split(
    support_query_dataset, [1 - VAL_FRACTION, VAL_FRACTION], generator=split_generator
)
so_train_dataset = Subset(support_only_dataset, sq_train_dataset.indices)
so_val_dataset = Subset(support_only_dataset, sq_val_dataset.indices)

# Only the training loaders shuffle; validation order does not need to be random.
sq_train_loader = DataLoader(sq_train_dataset, batch_size=TASK_BATCH_SIZE, shuffle=True, collate_fn=keep_task_list)
sq_val_loader = DataLoader(sq_val_dataset, batch_size=TASK_BATCH_SIZE, shuffle=False, collate_fn=keep_task_list)

so_train_loader = DataLoader(so_train_dataset, batch_size=TASK_BATCH_SIZE, shuffle=True, collate_fn=keep_task_list)
so_val_loader = DataLoader(so_val_dataset, batch_size=TASK_BATCH_SIZE, shuffle=False, collate_fn=keep_task_list)

In [None]:
EPOCHS = 3

# Train every meta-learner that was constructed, with the same budget.
# MAML, MAML++ and Meta-SGD consume (support, query) task pairs. Reptile uses the
# loaders over whole tasks, which have no query split.
maml.train(sq_train_loader, sq_val_loader, epochs=EPOCHS)
maml_plus_plus.train(sq_train_loader, sq_val_loader, epochs=EPOCHS)
meta_sgd.train(sq_train_loader, sq_val_loader, epochs=EPOCHS)
reptile.train(so_train_loader, so_val_loader, epochs=EPOCHS)