<a href="https://colab.research.google.com/github/JoshStrong/MAML/blob/master/CIFAR_FS_maml.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the learn2learn meta-learning library used throughout this notebook
# (the leading `!` runs the line as a shell command from the notebook).
!pip install learn2learn

In [None]:
import random
import numpy as np

import torch
from torch import nn, optim

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
import random

import copy

import learn2learn as l2l
from learn2learn.data.transforms import (NWays,
                                         KShots,
                                         LoadData,
                                         RemapLabels,
                                         ConsecutiveLabels)
from tqdm.notebook import tqdm

In [None]:
# Show which GPU (CUDA device 0) this runtime was given for training;
# as the last expression in the cell, its value is displayed below.
torch.cuda.get_device_name(device=0)

'Tesla T4'

## Acknowledgements

This code makes heavy use of the meta-learning library learn2learn [1], which makes implementing meta-learning algorithms such as MAML much easier. Without learn2learn, implementing MAML requires "hacking" around PyTorch's functionality. I chose to avoid that in this second experimental study to help ensure a bug-free MAML implementation when applying MAML to a more complex computer-vision task.



[1] Sebastien M.R. Arnold, Praateek Mahajan, Debajyoti Datta, Ian Bunner. "learn2learn". https://github.com/learnables/learn2learn, 2019

## Construct the model to be trained and define helper functions

In [None]:
# IPython `?` help: display the documentation of learn2learn's ConvBase,
# the convolutional backbone that CifarCNN (next cell) is built on.
?l2l.vision.models.ConvBase

In [None]:
# CNN model to be used in training.
# The body (`features`) and the head (`linear`) are separate submodules so the
# head can be handled on its own, e.g. for potential ANIL training.

class CifarCNN(torch.nn.Module):
    """Conv backbone + spatial mean pooling + linear classification head.

    Args:
        output_size: number of logits produced by the head (classes per task).
        hidden_size: channel width of the conv blocks; also the size of the
            pooled feature vector fed to the head.
        layers: number of conv blocks in the learn2learn ConvBase backbone.
    """

    def __init__(self, output_size=5, hidden_size=32, layers=4):
        super().__init__()
        self.hidden_size = hidden_size

        # `layers`-deep conv backbone on 3-channel (RGB) input, built by learn2learn.
        conv_backbone = l2l.vision.models.ConvBase(
            output_size=hidden_size,
            hidden=hidden_size,
            channels=3,
            max_pool=False,
            layers=layers,
            max_pool_factor=0.5,
        )

        def spatial_mean(x):
            # assumes x is (batch, channels, height, width): average over dims 2 and 3
            return x.mean(dim=[2, 3])

        # Body: backbone -> pooled per-channel features -> flat vector.
        self.features = torch.nn.Sequential(
            conv_backbone,
            l2l.nn.Lambda(spatial_mean),
            l2l.nn.Flatten(),
        )

        # Head: fully-connected classifier initialised with learn2learn's maml_init_.
        head = torch.nn.Linear(self.hidden_size, output_size, bias=True)
        l2l.vision.models.maml_init_(head)
        self.linear = head

    def forward(self, x):
        """Return the class logits for the image batch `x`."""
        return self.linear(self.features(x))

In [None]:
# Build the network with its default configuration spelled out, then display it.
model = CifarCNN(output_size=5, hidden_size=32, layers=4)
model

CifarCNN(
  (features): Sequential(
    (0): ConvBase(
      (0): ConvBlock(
        (normalize): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU()
        (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (1): ConvBlock(
        (normalize): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU()
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (2): ConvBlock(
        (normalize): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU()
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (3): ConvBlock(
        (normalize): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU()
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
  

In [None]:
# Accuracy metric used to evaluate the model's performance.

def accuracy(predictions, targets):
    """Return the fraction of samples whose arg-max logit matches the target label.

    `predictions` holds one row of class scores per sample; the result is a
    0-dim float tensor.
    """
    predicted_labels = torch.argmax(predictions, dim=1).view(targets.shape)
    n_correct = predicted_labels.eq(targets).sum().float()
    return n_correct / targets.size(0)

## Define fast_adapt function for inner loop training

In [None]:
# inner loop for vision tasks

def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device):
    """
    fast_adapt: computes the inner loop of MAML for computer vision tasks

    The task's samples are split 50:50 into an adaptation set D^{tr}
    (even indices) and an evaluation set D^{test} (odd indices). The learner
    takes `adaptation_steps` gradient steps on D^{tr} and is then scored on
    D^{test}.

    PARAMETERS:
    1. batch = (data, labels) pair for one sampled task
    2. learner = cloned model to adapt (e.g. the result of maml.clone())
    3. loss = loss function to quantify performance; expected to already
       average over the batch (nn.CrossEntropyLoss(reduction='mean') here)
    4. adaptation_steps = number of inner loop gradient steps to be performed
    5. shots = number of adaptation samples per class
    6. ways = number of classes
    7. device = device the batch is moved to (e.g. 'cuda' or 'cpu')

    RETURNS:
    (evaluation_error, evaluation_accuracy) of the adapted learner on D^{test}.
    evaluation_error stays attached to the autograd graph so the caller can
    call .backward() on it for the outer-loop (meta) update.
    """
    # Separate data & labels from the sampled task and move them to `device`.
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Even indices -> adaptation set D^{tr}, odd indices -> evaluation set D^{test}.
    # NOTE(review): assumes the batch holds 2*shots*ways samples with each class's
    # samples stored consecutively, so each half receives `shots` per class.
    adaptation_indices = np.zeros(data.size(0), dtype=bool)
    adaptation_indices[np.arange(shots * ways) * 2] = True
    evaluation_indices = torch.from_numpy(~adaptation_indices)  # numpy mask -> tensor
    adaptation_indices = torch.from_numpy(adaptation_indices)   # numpy mask -> tensor
    adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
    evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]

    # Inner loop: gradient steps on the task-specific copy of the model.
    for _ in range(adaptation_steps):
        # The loss is already a per-sample mean; dividing it by the batch size
        # again would silently shrink the effective inner learning rate.
        adaptation_error = loss(learner(adaptation_data), adaptation_labels)
        # learner.adapt updates the learner's parameters from this loss.
        learner.adapt(adaptation_error)

    # Score the adapted model on D^{test}. No adapt step here: the caller
    # backpropagates this loss and performs the meta-update.
    predictions = learner(evaluation_data)
    evaluation_error = loss(predictions, evaluation_labels)
    evaluation_accuracy = accuracy(predictions, evaluation_labels)
    return evaluation_error, evaluation_accuracy

## Use learn2learn library for easy access to CIFAR-FS dataset

In [None]:
# Load CIFAR-FS through learn2learn as few-shot tasks.
# Each task draws 2*shots samples per class: half form the adaptation set
# D^{tr} and half the evaluation set D^{test}.

shots = 5
ways = 5
samples_per_class = 2 * shots
tasksets = l2l.vision.benchmarks.get_tasksets(
    name='cifarfs',
    train_ways=ways,
    train_samples=samples_per_class,
    test_ways=ways,
    test_samples=samples_per_class,
    root='~/data',
)

## Define CIFARFS_MAML function

CIFARFS_MAML implements the main portion of the MAML algorithm (the outer meta-training loop) for computer-vision tasks.

In [None]:
def _meta_validate(maml, loss, num_tasks, adaptation_steps, shots, ways, device):
    """Return (mean error, mean accuracy) of `maml` after adapting to `num_tasks` validation tasks."""
    total_error = 0.0
    total_accuracy = 0.0
    for _ in range(num_tasks):
        learner = maml.clone()
        batch = tasksets.validation.sample()
        evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                           learner,
                                                           loss,
                                                           adaptation_steps,
                                                           shots,
                                                           ways,
                                                           device)
        total_error += evaluation_error.item()
        total_accuracy += evaluation_accuracy.item()
    return total_error / num_tasks, total_accuracy / num_tasks


def CIFARFS_MAML(ways=5, shots=5, meta_lr=0.003, inner_lr=0.5, meta_batch_size=32,
                 adaptation_steps=1, num_iterations=10000, device='cuda', log_every=50):
    """Meta-train a CifarCNN with MAML on the CIFAR-FS `tasksets` built above.

    PARAMETERS:
    ways, shots = task configuration; must match the module-level `tasksets`
                  (the model head gets `ways` outputs)
    meta_lr = Adam learning rate for the outer-loop update
    inner_lr = learning rate of the inner-loop adaptation steps
    meta_batch_size = tasks per outer-loop update (and validation tasks per log)
    adaptation_steps = inner-loop gradient steps per task
    num_iterations = number of outer-loop updates
    device = device the model and batches are placed on
    log_every = meta-validate and print every `log_every` iterations
                (0 disables validation/logging)

    RETURNS:
    The CifarCNN wrapped by MAML during training, i.e. the meta-learned model.
    """
    # Instantiate the CNN with one output per class in a task, move it to `device`.
    model = CifarCNN(output_size=ways)
    model.to(device)
    maml = l2l.algorithms.MAML(model, lr=inner_lr, first_order=False)

    # Use Adam optimiser and cross entropy loss
    opt = optim.Adam(maml.parameters(), meta_lr)
    loss = nn.CrossEntropyLoss(reduction='mean')

    # Outer loop: one meta-update per iteration.
    for iteration in tqdm(range(num_iterations)):
        opt.zero_grad()

        # For each sampled task T_i: adapt a clone, then accumulate the
        # meta-gradient of its post-adaptation loss into maml's parameters.
        for _ in range(meta_batch_size):
            learner = maml.clone()
            batch = tasksets.train.sample()
            evaluation_error, _ = fast_adapt(batch,
                                             learner,
                                             loss,
                                             adaptation_steps,
                                             shots,
                                             ways,
                                             device)
            evaluation_error.backward()

        # Meta-validation only matters when it is printed, so only run it then.
        # It runs before opt.step(), i.e. with this iteration's pre-update weights.
        if log_every and iteration % log_every == 0:
            avg_meta_valid_err, avg_meta_valid_acc = _meta_validate(
                maml, loss, meta_batch_size, adaptation_steps, shots, ways, device)
            print('\n')
            print('Iteration', iteration)
            print('Average Meta Valid Error', avg_meta_valid_err)
            print('Average Meta Valid Accuracy', avg_meta_valid_acc)

        # Average the accumulated gradients over the tasks and take a meta-step.
        for p in maml.parameters():
            if p.grad is not None:
                p.grad.mul_(1.0 / meta_batch_size)
        opt.step()

    return model

In [None]:
# Meta-train, passing the same task configuration used to build `tasksets`
# so the model head and the D^{tr}/D^{test} split stay consistent with it.
model = CIFARFS_MAML(ways=ways, shots=shots)

HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))

  if param.grad is not None:




Iteration 0
Average Meta Valid Error 0.06647501117549837
Average Meta Valid Accuracy 0.19999998807907104


Iteration 50
Average Meta Valid Error 0.06198663148097694
Average Meta Valid Accuracy 0.2687499930616468


Iteration 100
Average Meta Valid Error 0.06170825252775103
Average Meta Valid Accuracy 0.28749999264255166


Iteration 150
Average Meta Valid Error 0.06118645064998418
Average Meta Valid Accuracy 0.29874999169260263


Iteration 200
Average Meta Valid Error 0.05942488776054233
Average Meta Valid Accuracy 0.34874999057501554


Iteration 250
Average Meta Valid Error 0.06233933346811682
Average Meta Valid Accuracy 0.2974999900907278


Iteration 300
Average Meta Valid Error 0.06011559732723981
Average Meta Valid Accuracy 0.3387499905657023


Iteration 350
Average Meta Valid Error 0.05991319171153009
Average Meta Valid Accuracy 0.347499989438802


Iteration 400
Average Meta Valid Error 0.06033550656866282
Average Meta Valid Accuracy 0.3399999898392707


Iteration 450
Average Meta

In [None]:
# Save the meta-parameters to a file.
# `model` is the CifarCNN returned by CIFARFS_MAML; its state_dict keys
# ('features...', 'linear...') carry no 'module.' prefix, while a wrapped model
# presumably would. Strip the prefix where present but keep EVERY entry, so
# the checkpoint loads straight into a fresh CifarCNN either way (filtering on
# the prefix alone would save an empty dict for the bare model).
maml_dict = model.state_dict()
prefix = 'module.'
n_clip = len(prefix)
adapted_dict = {(k[n_clip:] if k.startswith(prefix) else k): v
                for k, v in maml_dict.items()}

PATH = "/content/model_CISFAR"
torch.save(adapted_dict, PATH)

In [None]:
# Load the saved MAML meta-parameters into a fresh CifarCNN.
PATH = "/content/model_CISFAR"
saved_model = CifarCNN()
# map_location='cpu' makes loading independent of the device the checkpoint
# was written from; the model is moved to the GPU below.
load_result = saved_model.load_state_dict(torch.load(PATH, map_location='cpu'), strict=False)
# strict=False skips mismatched entries silently; report them so a partially
# (or entirely) randomly-initialised model is not evaluated unnoticed.
if load_result.missing_keys or load_result.unexpected_keys:
    print('WARNING: checkpoint/model mismatch.',
          'Missing keys:', load_result.missing_keys,
          'Unexpected keys:', load_result.unexpected_keys)
saved_model.to('cuda')

CifarCNN(
  (features): Sequential(
    (0): ConvBase(
      (0): ConvBlock(
        (normalize): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU()
        (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (1): ConvBlock(
        (normalize): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU()
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (2): ConvBlock(
        (normalize): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU()
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (3): ConvBlock(
        (normalize): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU()
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
  

## Perform fine-tuning, evaluate performance

In [None]:
# Meta-test: adapt the meta-learned model to unseen test tasks and report the
# average post-adaptation error and accuracy over `meta_batch_size` tasks.
meta_test_error = 0.0
meta_test_accuracy = 0.0
meta_batch_size = 32
meta_lr = 0.003  # training hyper-parameter; not used in this cell
inner_lr = 0.5
adaptation_steps = 1

# No outer-loop update happens at test time, so no optimiser is created.
loss = nn.CrossEntropyLoss(reduction='mean')

# NOTE: this evaluates `model` as returned by CIFARFS_MAML; substitute
# `saved_model` to evaluate the checkpoint loaded from disk instead.
maml = l2l.algorithms.MAML(model, lr=inner_lr, first_order=False)
for _ in range(meta_batch_size):
    # Adapt a fresh clone on the task's D^{tr} and score it on its D^{test}.
    learner = maml.clone()
    batch = tasksets.test.sample()
    evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                       learner,
                                                       loss,
                                                       adaptation_steps,
                                                       shots,
                                                       ways,
                                                       'cuda')
    meta_test_error += evaluation_error.item()
    meta_test_accuracy += evaluation_accuracy.item()

# Report once, after all tasks, instead of a running partial sum per task.
print('Avg Meta Test Error', meta_test_error / meta_batch_size)
print('Avg Meta Test Accuracy', f'{100 * meta_test_accuracy / meta_batch_size:.2f}', '%')

  if param.grad is not None:


Avg Meta Test Accuracy 2.0 %
Avg Meta Test Accuracy 4.0 %
Avg Meta Test Accuracy 6.0 %
Avg Meta Test Accuracy 9.0 %
Avg Meta Test Accuracy 11.0 %
Avg Meta Test Accuracy 13.0 %
Avg Meta Test Accuracy 15.0 %
Avg Meta Test Accuracy 17.0 %
Avg Meta Test Accuracy 19.0 %
Avg Meta Test Accuracy 21.0 %
Avg Meta Test Accuracy 23.0 %
Avg Meta Test Accuracy 25.0 %
Avg Meta Test Accuracy 28.000000000000004 %
Avg Meta Test Accuracy 30.0 %
Avg Meta Test Accuracy 32.0 %
Avg Meta Test Accuracy 34.0 %
Avg Meta Test Accuracy 36.0 %
Avg Meta Test Accuracy 38.0 %
Avg Meta Test Accuracy 39.0 %
Avg Meta Test Accuracy 41.0 %
Avg Meta Test Accuracy 43.0 %
Avg Meta Test Accuracy 44.0 %
Avg Meta Test Accuracy 46.0 %
Avg Meta Test Accuracy 47.0 %
Avg Meta Test Accuracy 49.0 %
Avg Meta Test Accuracy 51.0 %
Avg Meta Test Accuracy 54.0 %
Avg Meta Test Accuracy 56.00000000000001 %
Avg Meta Test Accuracy 57.99999999999999 %
Avg Meta Test Accuracy 60.0 %
Avg Meta Test Accuracy 62.0 %
Avg Meta Test Accuracy 64.0 %
