<a href="https://colab.research.google.com/github/JoshStrong/MAML/blob/master/MAML_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Model Agnostic Meta Learning (MAML)

Uses Google Colab's GPU for faster training with PyTorch. \\
Original MAML paper: https://arxiv.org/pdf/1703.03400.pdf

In [9]:
# import dependencies
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

import copy
import matplotlib.pyplot as plt

In [10]:
# Report which GPU Colab assigned; degrade to a readable message instead of raising when no CUDA device exists.
device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu (no CUDA device available)"
device_name

'Tesla K80'

In [11]:
# Install torchviz, which provides make_dot (imported below) for drawing autograd graphs.
# The leading `!` runs a shell command; this line only works inside IPython/Colab, not plain Python.
!pip install torchviz
from graphviz import Digraph
import torch
from torch.autograd import Variable


# make_dot was moved to https://github.com/szagoruyko/pytorchviz
from torchviz import make_dot



# Implementing the original regression tasks

In [12]:
class Sine_Task():
  """
  A single sinusoid regression task: y = amplitude * sin(phase + x),
  with inputs x drawn uniformly from [xmin, xmax).
  """

  def __init__(self, amplitude, phase, xmin, xmax):
    self.amplitude = amplitude
    self.phase = phase
    self.xmin = xmin
    self.xmax = xmax

  def oracle(self, x):
    """
    Oracle: returns output of sin function with given amplitude, phase and input x

    PARAMETERS:
    1. x - input; a torch tensor (any device) or anything np.sin accepts
    """
    # torch.sin keeps tensors on their own device (np.sin cannot handle CUDA tensors).
    if torch.is_tensor(x):
      return self.amplitude * torch.sin(self.phase + x)
    return self.amplitude * np.sin(self.phase + x)

  def sample_data(self, size=1, device="cuda"):
    """
      sample_data: sample input/output pairs of this task, with x uniform in [xmin, xmax)

      PARAMETERS:
      1. size - amount of sampled data
      2. device - device the returned tensors are moved to (default "cuda", as before)

      RETURNS:
      x, y - tensors of shape (size, 1)
    """
    # Shift by xmin (not -xmax) so the samples cover [xmin, xmax) for any range,
    # not only ranges symmetric about zero.
    x = torch.rand(size) * (self.xmax - self.xmin) + self.xmin
    y = self.oracle(x)
    x = x.unsqueeze(1).to(device)
    y = y.unsqueeze(1).to(device)

    return x, y

In [13]:
class Sine_Task_Distribution():
  """
  Distribution over sinusoid tasks: amplitude and phase are drawn uniformly
  from their ranges; every task shares the same input range [xmin, xmax).
  """

  def __init__(self, amplitude_min, amplitude_max, phase_min, phase_max, xmin, xmax):
    self.amplitude_min, self.amplitude_max = amplitude_min, amplitude_max
    self.phase_min, self.phase_max = phase_min, phase_max
    self.xmin, self.xmax = xmin, xmax

  def sample_task(self):
    """
    Draw one Sine_Task with amplitude ~ U(amplitude_min, amplitude_max)
    and phase ~ U(phase_min, phase_max).
    """
    # Arguments are evaluated left to right, so the amplitude is drawn before
    # the phase from the NumPy global RNG.
    return Sine_Task(
      np.random.uniform(self.amplitude_min, self.amplitude_max),
      np.random.uniform(self.phase_min, self.phase_max),
      self.xmin,
      self.xmax,
    )

In [14]:
class Regressor(nn.Module):
  """Fully connected 1 -> 40 -> 40 -> 1 network with ReLU activations."""

  def __init__(self):
    super().__init__()
    # Creation order fc1, fc2, fc3 fixes both the parameter order and the
    # order in which random initialisation consumes the torch RNG.
    self.fc1 = nn.Linear(1, 40)
    self.fc2 = nn.Linear(40, 40)
    self.fc3 = nn.Linear(40, 1)

  def forward(self, x):
    hidden = F.relu(self.fc1(x))
    hidden = F.relu(self.fc2(hidden))
    return self.fc3(hidden)

The (shortened) MAML algorithm:




*   Sample task (or a batch of tasks) $\mathcal{T}_i$.
*   Sample $D_i^{tr}, D_i^{test}$ from sampled task $\mathcal{T}_i$.
*   Inner Loop: Starting from the meta-parameters $\theta$, take gradient step(s) on the training set $D_i^{tr}$ to produce task-specific parameters $\phi_i$: $\phi_i \leftarrow \theta - \alpha \nabla_\theta \mathcal{L}(\theta, D_i^{tr})$
*   Outer Loop: Update $\theta$ using stochastic gradient descent on the test-set losses of the adapted parameters, summed over the sampled tasks:
$\theta \leftarrow \theta - \beta \nabla_\theta \sum_i \mathcal{L}(\phi_i, D_i^{test})$

where $\mathcal{L}(\cdot, \cdot)$ is the chosen loss function of the network.



Let $U(\theta, D^{tr}) := \phi = \theta - \alpha \nabla_\theta \mathcal{L}(\theta, D^{tr})$ denote the update rule used for optimising $\phi$.



The meta-optimisation objective is given as 
\begin{align*}
    \underset{\theta}{\min}\,\,\mathcal{L}(\phi, D^{test}) = \underset{\theta}{\min}\,\,\mathcal{L}(U(\theta, D^{tr}), D^{test}).
\end{align*}
We require $\frac{d}{d\theta}\mathcal{L}(\phi, D^{test})$
\begin{align*}
    \frac{d}{d\theta}\mathcal{L}(\phi, D^{test}) &= \frac{d}{d\theta}\mathcal{L}(U(\theta, D^{tr}), D^{test})\\
    &= \underbrace{\nabla_{\Theta}\mathcal{L}(\Theta, D^{test})|_{\Theta=U(\theta, D^{tr})}}_{(1)} \underbrace{\dfrac{d}{d \theta} U(\theta, D^{tr})}_{(2)}. && (\text{via chain rule})
\end{align*}

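(1) is a row vector that can be computed with a single backward pass of the network: set the parameters to $\Theta$, then differentiate the loss $\mathcal{L}$ with respect to $\Theta$.\\ Term (2) is the Jacobian of the update rule. Differentiating $U(\theta, D^{tr}) = \theta - \alpha \nabla_\theta \mathcal{L}(\theta, D^{tr})$ with respect to $\theta$ gives
\begin{align*}
    \dfrac{d}{d\theta} U(\theta, D^{tr}) = I - \alpha \nabla_\theta^2 \mathcal{L}(\theta, D^{tr}),
\end{align*}
which contains the Hessian of the training loss. The inner-loop gradient must therefore itself stay part of the computation graph so that this second-order term can be differentiated in the outer loop.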

In [15]:
class MAML():
  """
  Model-Agnostic Meta-Learning with second-order meta-gradients.

  The inner loop adapts a functional copy of the meta-parameters using
  torch.autograd.grad(..., create_graph=True). Because the adapted parameters
  are built from the meta-parameters, the query-set loss stays differentiable
  w.r.t. self.model's parameters. The outer loop back-propagates through the
  inner-loop updates and steps self.meta_optimizer.
  """

  def __init__(self, model, tasks, inner_lr, meta_lr, K=10, inner_steps=1,
               tasks_per_meta_batch=1000, criterion=None):
    """
    PARAMETERS:
    1. model - meta-model (nn.Module) whose parameters are meta-learned
    2. tasks - task distribution exposing sample_task(); each task exposes sample_data(size)
    3. inner_lr - step size (alpha) of each inner-loop gradient step
    4. meta_lr - learning rate (beta) of the outer-loop SGD optimiser
    5. K - number of examples sampled for each training and test set of a task
    6. inner_steps - number of inner-loop gradient steps per task
    7. tasks_per_meta_batch - number of tasks whose test losses are summed per meta-update
    8. criterion - loss function; defaults to a fresh nn.MSELoss(reduction='mean')
    """
    self.tasks = tasks # Instance of Sine_Task_Distribution
    self.model = model # Meta-model
    # Build the default loss per instance rather than sharing one default-argument object.
    self.criterion = criterion if criterion is not None else nn.MSELoss(reduction='mean')
    self.meta_optimizer = torch.optim.SGD(model.parameters(), meta_lr) # Meta-optimiser for outer loop updates

    self.inner_lr = inner_lr
    self.meta_lr = meta_lr
    self.K = K
    self.inner_steps = inner_steps
    self.tasks_per_meta_batch = tasks_per_meta_batch

  @staticmethod
  def _functional_forward(model, params, x):
    """Run `model` on `x` with its parameters replaced by the tensors in the dict `params`."""
    try:
      from torch.func import functional_call  # torch >= 2.0
    except ImportError:
      from torch.nn.utils.stateless import functional_call  # torch 1.12 - 1.13
    return functional_call(model, params, (x,))

  def inner_loop(self, task):
    """
        Computes inner-loop adaptation for one task.

        PARAMETERS:
        1. task - an instantiation from the task distribution (exposes sample_data(size))

        ALGORITHM:
        1. Sample K training examples and compute the loss with the current parameters
        2. Take inner_steps gradient steps: phi <- phi - inner_lr * grad, keeping the graph
        3. Sample K fresh test examples and evaluate the adapted parameters on them

        RETURNS:
        Scalar test loss of the adapted parameters, differentiable w.r.t. self.model's parameters.
      """
    X_train, y_train = task.sample_data(self.K)

    # Start from the meta-parameters themselves (not a detached copy) so the
    # adapted parameters remain functions of theta in the autograd graph.
    params = dict(self.model.named_parameters())

    for _ in range(self.inner_steps):
      y_pred = self._functional_forward(self.model, params, X_train)
      loss = self.criterion(y_pred, y_train)

      # create_graph=True records the gradient computation itself, so the outer
      # loop can differentiate through it (the Hessian term of MAML).
      grads = torch.autograd.grad(loss, list(params.values()),
                                  create_graph=True, allow_unused=True)
      # Parameters that do not influence the loss (grad None) are left unchanged.
      params = {
        name: p if g is None else p - self.inner_lr * g
        for (name, p), g in zip(params.items(), grads)
      }

    # Evaluate the adapted parameters on new data sampled from the same task.
    X_test, y_test = task.sample_data(self.K)
    y_pred = self._functional_forward(self.model, params, X_test)
    return self.criterion(y_pred, y_test)

  def outer_loop(self, num_iterations):
    """
      Computes outer-loop optimisation.

      PARAMETERS:
      1. num_iterations - number of outer-loop gradient descent steps

      ALGORITHM (per iteration):
      1. Sample tasks_per_meta_batch tasks
      2. Compute each task's adapted test loss with inner_loop
      3. Update meta-parameters with the gradient of the summed test losses

      RETURNS:
      List of floats: the summed meta-loss of each iteration (measured before its update).
    """
    history = []

    for _ in range(num_iterations):
      self.meta_optimizer.zero_grad()
      meta_loss = 0.0

      for _ in range(self.tasks_per_meta_batch):
        task = self.tasks.sample_task()
        task_loss = self.inner_loop(task)
        # Back-propagating per task accumulates exactly the gradient of the sum,
        # while only one task's second-order graph is held in memory at a time.
        task_loss.backward()
        meta_loss += task_loss.item()

      self.meta_optimizer.step()
      history.append(meta_loss)

    return history

In [17]:
# Seed torch, then build a regressor on the GPU and a sinusoid task distribution:
# amplitude in [0.1, 5], phase in [0, pi], inputs in [-5, 5].
torch.manual_seed(2)
model1 = Regressor().cuda()
a = Sine_Task_Distribution(0.1, 5, 0, np.pi, -5, 5)

# Sample one task and show its input range.
task = a.sample_task()
print(task.xmin)
print(task.xmax)

maml = MAML(model1, a, inner_lr=0.01, meta_lr=0.01, K=10, inner_steps=1, tasks_per_meta_batch=3)

# List every meta-parameter with its gradient (None until a backward pass has run).
for p in model1.parameters():
  print(p, p.grad)

-5
5
Parameter containing:
tensor([[ 0.2294],
        [-0.2380],
        [ 0.2742],
        [-0.0511],
        [ 0.4272],
        [ 0.2381],
        [-0.1149],
        [-0.8085],
        [ 0.2283],
        [-0.8853],
        [ 0.1314],
        [ 0.0665],
        [-0.2199],
        [ 0.8177],
        [ 0.0667],
        [ 0.4147],
        [ 0.4232],
        [-0.5899],
        [-0.3844],
        [ 0.9617],
        [-0.9795],
        [-0.0679],
        [-0.0792],
        [ 0.7093],
        [-0.0951],
        [ 0.2633],
        [-0.0480],
        [-0.5599],
        [-0.5668],
        [-0.4858],
        [-0.9084],
        [-0.6490],
        [ 0.2353],
        [ 0.6581],
        [ 0.0493],
        [-0.4584],
        [ 0.4395],
        [-0.3839],
        [-0.2215],
        [-0.5482]], device='cuda:0', requires_grad=True) None
Parameter containing:
tensor([-0.3140, -0.9266,  0.4267,  0.3888,  0.1986,  0.4910,  0.4238,  0.0442,
         0.1059,  0.0764,  0.5336,  0.6717,  0.7181,  0.5796, -0.243