In [None]:
# Auto-reload external modules
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.optim as optim

from insect import meta_train, evaluate_model, INSECTModel
from maml import maml_train, maml_evaluate, MAMLModel

In [7]:
#---------------------------
# Meta-Learning Setup
#---------------------------
# The model takes x and memory as input.
# x is 1-D and memory is a small vector (memory_dim-D, currently 20-D).
# They are concatenated, so the model's input dimension is x_dim + memory_dim.
x_dim = 1
memory_dim = 20
hidden_dim = 128
output_dim = 1

# Seed torch before building the model so weight initialisation is reproducible
# across Restart & Run All.
# NOTE(review): assumes INSECTModel draws its initial weights from torch's global
# RNG; if meta_train samples tasks via `random`/numpy, seed those too — TODO confirm.
seed = 0
torch.manual_seed(seed)

model = INSECTModel(x_dim + memory_dim, hidden_dim, output_dim)

# We'll define a meta-training loop that updates model weights so that it can
# later adapt quickly by changing the memory.

#---------------------------
# Meta-Training Configuration
#---------------------------
meta_optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

# Hyperparameters. The meta_train / evaluate_model cells below read these
# globals directly, so don't reuse these names for other models in this notebook.
meta_iterations = 6000   # Number of meta-training iterations
inner_steps = 10         # Number of memory adaptation steps per task
adapt_lr = 0.01          # Learning rate for adapting memory at test-time
task_batch_size = 10     # Number of points per task for adaptation
test_batch_size = 1      # Number of points to test after adaptation

In [None]:
# Meta-train INSECT using the configuration globals defined above.
# Arguments are positional, in the same order as the original call.
meta_train(
    meta_iterations,
    inner_steps,
    adapt_lr,
    task_batch_size,
    test_batch_size,
    memory_dim,
    model,
    loss_fn,
    meta_optimizer,
)

In [None]:
# Evaluate the meta-trained INSECT model.
# NOTE: positional order differs from meta_train — the model comes first here.
evaluate_model(
    model,
    task_batch_size,
    test_batch_size,
    adapt_lr,
    memory_dim,
    loss_fn,
)

In [102]:
input_dim = 1
hidden_dim = 138
output_dim = 1

max_meta_iterations = 6000
inner_steps = 10
inner_lr = 0.01
task_batch_size = 10
test_batch_size = 10

maml_model = MAMLModel(input_dim, hidden_dim, output_dim)
maml_loss_fn = nn.MSELoss()
maml_meta_optimizer = optim.Adam(maml_model.parameters(), lr=0.001)

In [None]:
# Meta-train MAML
maml_train(
    maml_model,
    maml_loss_fn,
    maml_meta_optimizer,
    max_meta_iterations,
    inner_steps,
    inner_lr,
    task_batch_size,
    test_batch_size
)


In [None]:
# Evaluate MAML model
maml_evaluate(
    model=maml_model,
    loss_fn=maml_loss_fn,
    inner_steps=10,
    inner_lr=0.01,
    task_batch_size=10
)