> ###### *UNIVERSITY OF PISA* - *M.Sc. Computer Science (Artificial Intelligence)*
>
> **Continual Learning 2022/23**
>
> Irene Pisani \
> Matricola 560104 \
> i.pisani1@studenti.unipi.it \



# **An implementation of Model-Agnostic Meta-Learning (MAML) for few-shot supervised image classification**


---



## **Project objectives**

---





1.   Provide an implementation of MAML for few-shot supervised learning using PyTorch.

2.   Reproduce the original experiments on a common few-shot image recognition benchmark, the Omniglot dataset, following the original experimental protocol for image classification: fast learning of N-way classification with 1 or 5 shots and N equal to 5 or 20.

3.   Compare the obtained performance with the original results.

4.   Analyze the impact of the number of inner SGD steps during the training and evaluation phases.

Steps 1. and 2. are addressed in this notebook `MAML_Algorithm.ipynb`, while steps 3. and 4. are addressed in `MAML_ExperimentalResults.ipynb`.



In [10]:
# @title Import useful tools and libraries

import os
import os.path

import torch

from torch.func import vmap
from torch.func import grad
from torch.func import functional_call

from torch import nn
import torch.optim as optim
import torch.utils.data as data
import torch.nn.functional as F
from torch.nn.modules import Module

import torchvision
import torchvision.transforms as transforms

from PIL import Image
from types import SimpleNamespace
from queue import SimpleQueue

import pandas as pd
import numpy as np

from google.colab import drive
drive.mount('/content/drive', force_remount = True)


Mounted at /content/drive



## **Omniglot Dataset**

---





The Omniglot dataset is a common benchmark for few-shot image recognition. It consists of 1623 characters from 50 different alphabets, with 20 instances per character (each instance is drawn by a different person).

**Dataset split**

The dataset was imported directly from `torchvision`, where it is stored as a background set and an evaluation set. Since this partition does not reflect the one used by the authors, the two sets were merged and re-split following the splitting protocol suggested by the authors:
- The training set (TR) is composed of 1200 characters randomly sampled irrespective of the alphabet;
- The test set (TS) is composed of the remaining 423 characters.

In addition, since a validation set (VL) is still required for model checkpointing, 100 characters from the TR set were used as the VL set (i.e., the TR set on which the models were trained consists of 1100 characters).
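
A class-level split of this kind can be sketched as follows. This is a minimal illustration using only the standard library: the function name `split_classes` and the seed are assumptions made for the example and are not part of the notebook, whose `LoadDatasets` function slices the concatenated array in load order instead.

```python
import random

def split_classes(num_classes=1623, n_train=1100, n_val=100, seed=0):
    """Shuffle class indices and split them into disjoint TR / VL / TS groups."""
    rng = random.Random(seed)
    indices = list(range(num_classes))
    rng.shuffle(indices)  # classes are mixed irrespective of their alphabet
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]
    return train, val, test

train_ids, val_ids, test_ids = split_classes()
print(len(train_ids), len(val_ids), len(test_ids))  # 1100 100 423
```

Splitting by class rather than by image keeps every test character unseen during meta-training, which is what the few-shot protocol requires.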

**Dataset preprocessing**

As described in the original paper, downsampling to 28 × 28 was performed over all the images (image channels = 1) during data preprocessing.

Unlike in the original paper, the Omniglot dataset is not augmented with rotations by multiples of 90 degrees; the impact of this choice is discussed along with the experimental results.

In [11]:
#@title Load Dataset

import torchvision

def LoadDatasets(root, img_size):

  # --------------------- Load OMNIGLOT BACKGROUND -----------------------------

  background_dataset = torchvision.datasets.Omniglot(
      root = root,        # Directory for store dataset
      download = True,    # Download zip file
      background = True,  # Train set

      # Apply transformations to the dataset (e.g. resizing)
      transform  = torchvision.transforms.Compose([
          lambda x: x.convert('L'),
          lambda x: x.resize((img_size, img_size)),
          lambda x: np.reshape(x, (img_size, img_size, 1)),
          lambda x: np.transpose(x, [2, 0, 1]),
          lambda x: x/255.]))

  # Transform Background dataset to dictionary
  dict_dataset = dict()

  # iterate over each (image, target) pairs contained in the dataset
  for img, target in background_dataset:

    # add target as key of dataset dictionary
    if target not in dict_dataset:
      dict_dataset[target] = []

    # add image as values of the key in dataset dictionary
    dict_dataset[target].append(img)

  background_dataset = []

  # iterate over each (key:target, value :img) pairs of the dictionary
  for y, x in dict_dataset.items():

    # add image as numpy array in the list dataset
    background_dataset.append(np.array(x).astype('float32'))

  # --------------------- Load OMNIGLOT EVALUATION -----------------------------

  evaluation_dataset = torchvision.datasets.Omniglot(
      root = root,        # Directory for store dataset
      download = True,    # Download zip file
      background = False,  # Test set

      # Apply transformations to the dataset (e.g. resizing)
      transform  = torchvision.transforms.Compose([
          lambda x: x.convert('L'),
          lambda x: x.resize((img_size, img_size)),
          lambda x: np.reshape(x, (img_size, img_size, 1)),
          lambda x: np.transpose(x, [2, 0, 1]),
          lambda x: x/255.]))

  # Transform Evaluation dataset to dictionary
  dict_dataset = dict()

  # iterate over each (image, target) pairs contained in the dataset
  for img, target in evaluation_dataset:

    # add target as key of dataset dictionary
    if target not in dict_dataset:
      dict_dataset[target] = []

    # add image as values of the key in dataset dictionary
    dict_dataset[target].append(img)

  evaluation_dataset = []

  # iterate over each (key:target, value :img) pairs of the dictionary
  for y, x in dict_dataset.items():

    # add image as numpy array in the list dataset
    evaluation_dataset.append(np.array(x).astype('float32'))

  ## --------------------- FULL DATASET and VALIDATION SPLIT -------------------

  # Transform to numpy array
  background_dataset = np.array(background_dataset).astype('float32')
  evaluation_dataset = np.array(evaluation_dataset).astype('float32')

  # Concatenation (background + evaluation)
  dataset = np.concatenate((background_dataset, evaluation_dataset), axis=0)

  # free memory: delete useless versions the dataset
  del dict_dataset
  del evaluation_dataset
  del background_dataset

  # Split full dataset in TR, VL and TS set
  # (classes are taken in load order, background first; no shuffling is applied here)
  dataset = {
      "trainval": dataset[:1200],
      "train": dataset[:1100],
      "val": dataset[1100:1200],
      "test": dataset[1200:]
      }

  return dataset

Designing a meta-batch sampler took considerable time, due to an initial lack of understanding rather than to the complexity of the implementation. Other implementations of MAML available on GitHub were used as a reference and helped with the development. In particular, the code for the meta-batch sampler is closely inspired by https://github.com/dragen1860/MAML-Pytorch, a popular MAML implementation whose meta-batch sampling code is reused across a large number of repositories.

The aim of the `MetaBatchSampler` class is to prepare the dataset to be fed to the base learner under the MAML setting for n-way k-shot supervised classification.

Given the hyper-parameters `n_way`, `k_support` (k-shot for the support set), `k_query` (k-shot for the query set), and `meta_batch_size`, each dataset split (TR, VL or TS) is handled as follows:

  - `load_metabatch_queque()` preloads a queue of 10 meta-batches.
  - `next()` returns the next meta-batch in the queue to pass it to the learner, refilling the queue if it is empty.
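
The construction of a single task inside a meta-batch can be illustrated with a small standalone sketch. It uses only the standard library and strings in place of images; the function `sample_episode` and the toy dictionary are hypothetical and only mirror the logic of the sampler (distinct classes, sampling without replacement, labels re-indexed from 0 to N-1, shuffling).

```python
import random

def sample_episode(class_to_items, n_way, k_support, k_query, rng):
    """Build one N-way task: returns (support, query) lists of (item, label) pairs."""
    classes = rng.sample(sorted(class_to_items), n_way)  # N distinct classes
    support, query = [], []
    for label, cls in enumerate(classes):  # relabel classes as 0..N-1 within the task
        items = rng.sample(class_to_items[cls], k_support + k_query)  # no replacement
        support += [(item, label) for item in items[:k_support]]
        query += [(item, label) for item in items[k_support:]]
    rng.shuffle(support)
    rng.shuffle(query)
    return support, query

# toy data: 30 classes with 20 instances each (strings stand in for images)
toy = {c: [f"img_{c}_{i}" for i in range(20)] for c in range(30)}
support, query = sample_episode(toy, n_way=5, k_support=5, k_query=15, rng=random.Random(1))
print(len(support), len(query))  # 25 75
```

Because the support and query images of a class come from one draw without replacement, no image appears in both sets of the same task.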











In [12]:
#@title Metabatch Sampler
class MetaBatchSampler:

  def __init__(self, options, device = None):

    # get configuration values
    self.root = options.dataset_root
    self.meta_batch_size = options.meta_batch_size
    self.n_way = options.n_way
    self.k_support = options.k_support
    self.k_query = options.k_query
    self.img_size = options.img_size

    # device
    self.device = device

    # Load Omniglot dataset
    self.datasets = LoadDatasets(root = self.root, img_size = self.img_size)

    # dictionary of dataset splits: each one is made of a preloaded queue of metabatches
    self.metabatch_queque = {
        "trainval":  self.load_metabatch_queque(self.datasets["trainval"]),
        "train": self.load_metabatch_queque(self.datasets["train"]),
        "val": self.load_metabatch_queque(self.datasets["val"]),
        "test": self.load_metabatch_queque(self.datasets["test"])
        }

  # ------------- LOAD A QUEUE OF METABATCHES ----------------------------------

  def load_metabatch_queque(self, dataset):

    # size of query set and support set
    support_set_size = self.k_support * self.n_way
    query_set_size   = self.k_query * self.n_way

    # initialize a queue of metabatches with a given number of preloaded metabatches
    metabatch_queque = SimpleQueue()
    preloaded_metabatch = 10

    # iterate over num of preloaded metabatch
    for _ in range(preloaded_metabatch):

      # initialize empty metabatch
      metabatch_x_spt, metabatch_y_spt, metabatch_x_qry, metabatch_y_qry = [], [], [], []

      # iterate over size of metabatch
      for i in range(self.meta_batch_size):

        # initialize empty support set and query set
        x_support, y_support, x_query, y_query = [], [], [], []

        # randomly sample n classes among all the available in the dataset
        sampled_class = np.random.choice(dataset.shape[0], self.n_way, False)

        # iterate over sampled classes
        for j, target in enumerate(sampled_class):

          # for each class randomly sample k support + k query images without replacement
          sampled_imgs = np.random.choice(dataset.shape[1], self.k_support + self.k_query, False)

          # add the first k sampled images to support set and the remaining to query set
          x_support.append(dataset[target][sampled_imgs[:self.k_support]])
          x_query.append(dataset[target][sampled_imgs[self.k_support:]])

          # store corresponding labels
          y_support.append([j]*self.k_support)
          y_query.append([j]*self.k_query)

        # shuffle the support set: [support_size, 1, 28, 28]
        perm = np.random.permutation(support_set_size)
        x_support = np.array(x_support).reshape(support_set_size, 1, self.img_size, self.img_size)[perm]
        y_support = np.array(y_support).reshape(support_set_size)[perm]

        # shuffle the query set: [query_size, 1, 28, 28]
        perm = np.random.permutation(query_set_size)
        x_query = np.array(x_query).reshape(query_set_size, 1, self.img_size, self.img_size)[perm]
        y_query = np.array(y_query).reshape(query_set_size)[perm]

        # append the created support and query sets to the metabatch
        metabatch_x_spt.append(x_support) # metabatch of x support
        metabatch_y_spt.append(y_support) # metabatch of y support
        metabatch_x_qry.append(x_query) # metabatch of x query
        metabatch_y_qry.append(y_query) # metabatch of y query

      # meta batch of support set: [meta batch size, support_set_size, 1, 28, 28]
      metabatch_x_spt = np.array(metabatch_x_spt).astype('float32').reshape(self.meta_batch_size, support_set_size, 1, self.img_size, self.img_size)
      metabatch_y_spt = np.array(metabatch_y_spt).astype(int).reshape(self.meta_batch_size, support_set_size)

      # meta batch of query set: [meta batch size, query_set_size, 1, 28, 28]
      metabatch_x_qry = np.array(metabatch_x_qry).astype('float32').reshape(self.meta_batch_size, query_set_size, 1, self.img_size, self.img_size)
      metabatch_y_qry = np.array(metabatch_y_qry).astype(int).reshape(self.meta_batch_size, query_set_size)

      # transform from numpy to pytorch tensor
      metabatch_x_spt, metabatch_y_spt, metabatch_x_qry, metabatch_y_qry = [
          torch.from_numpy(b).to(self.device) for b in [metabatch_x_spt, metabatch_y_spt, metabatch_x_qry, metabatch_y_qry]
      ]

      # put the created meta batch into the metabatch queue
      metabatch_queque.put([metabatch_x_spt, metabatch_y_spt, metabatch_x_qry, metabatch_y_qry])

    return metabatch_queque

  # ------------- GET NEXT METABATCH IN THE QUEUE ------------------------------

  def next(self, mode='trainval'):

    # refill the queue if it is empty
    if self.metabatch_queque[mode].empty():
      self.metabatch_queque[mode] = self.load_metabatch_queque(self.datasets[mode])

    # get next metabatch, given the dataset split specified in 'mode' parameter
    next_metabatch = self.metabatch_queque[mode].get()

    return next_metabatch



## **Model**

---



The model follows the same architecture proposed by the authors:
4 modules, each with a 3 × 3 strided convolution with 64 filters, followed by batch normalization and a ReLU nonlinearity. The dimensionality of the last hidden layer is 64 and it is fed into a softmax.

Stride = 2 and padding = 1 were used (with padding = 0 in the fourth block, so that the flattened output has 64 features). Note that these values may not coincide with the ones used by the authors, since they are not specified in the paper. In fact, by looking at the original code https://pytorch.org/docs/stable/func.html I was not able to determine the stride value or to assess whether padding was used.
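
The effect of these choices on the feature-map size can be checked with the standard convolution output formula, floor((size + 2·padding − kernel) / stride) + 1. The helper `conv_out` below is a hypothetical name used only for this check; the padding values are the ones that appear in the `FinnCNN` definition.

```python
def conv_out(size, kernel=3, stride=2, padding=1):
    """Spatial output size of a convolution: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

paddings = [1, 1, 1, 0]  # padding of the four blocks, as in FinnCNN
size = 28
sizes = [size]
for p in paddings:
    size = conv_out(size, padding=p)
    sizes.append(size)

print(sizes)             # [28, 14, 7, 4, 1]
print(64 * size * size)  # 64 features reach the linear layer
```

With padding = 1 in the last block the output would be 2 × 2 instead, giving 256 flattened features rather than the 64 expected by `nn.Linear(64, n_way)`.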

Batch normalization is used with `track_running_stats = False`. The motivation comes from statements by Antoniou et al. (2019) in *How to train your MAML* and by Bronskill et al. (2020) in *TaskNorm: Rethinking Batch Normalization for Meta-Learning*. Both claim that in the original MAML implementation the statistics of the current batch, rather than accumulated running statistics, were used for batch normalization during the evaluation phase as well.

`inplace = True` for the ReLU computation was initially used to work around some memory issues, but in the end it turned out to be unnecessary. Some preliminary experiments show that this attribute does not affect the performance.


In [13]:
#@title Model architecture

class FinnCNN(nn.Module):

  def __init__(self, n_way, device):
    super().__init__()

    self.cnn = nn.Sequential(

        # 1° convolutional block
        nn.Conv2d(1, 64, 3, stride = 2, padding = 1),
        nn.BatchNorm2d(64, affine=True, track_running_stats=False),
        nn.ReLU(inplace=True),

        # 2° convolutional block
        nn.Conv2d(64, 64, 3, stride = 2, padding = 1),
        nn.BatchNorm2d(64, affine=True, track_running_stats=False),
        nn.ReLU(inplace=True),

        # 3° convolutional block
        nn.Conv2d(64, 64, 3, stride = 2, padding = 1),
        nn.BatchNorm2d(64, affine=True, track_running_stats=False),
        nn.ReLU(inplace=True),

        # 4° convolutional block
        nn.Conv2d(64, 64, 3, stride = 2, padding = 0),
        nn.BatchNorm2d(64, affine=True, track_running_stats=False),
        nn.ReLU(inplace=True),

        # Classifier layer
        nn.Flatten(),
        nn.Linear(64, n_way)).to(device)

  def forward(self, x):
    x = self.cnn(x)
    return x

## **MAML: Model Agnostic Meta Learning**

---



The `MAML()` class performs the core of the algorithm.

All the required parameters are given as input, together with a preloaded queue of meta-batches.
It initializes a base learner (by calling `FinnCNN()`) and a gradient-based meta-optimizer (Adam).

By calling the `MAML.fit_and_evaluate()` method, the model is trained on the TR set and evaluated on the VL set.

**Model checkpointing** is used to find the best model by monitoring the validation loss during training: the best model is the one with the lowest validation loss.
Note that a model selection phase with a grid search over the hyperparameter space was not performed due to computational and time constraints. All the hyperparameters involved are the same as those mentioned in the original paper.
Model checkpointing nevertheless makes it possible to save the best model across iterations; this is needed because significant training oscillations were observed for 20-way classification (they are discussed further in the experimental results section).

By calling the `MAML.test()` method, the best model's performance is assessed on the TS set.
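
For reference, the two nested updates of Finn et al. (2017) that the class is built around can be written as below. The superscripts "spt" and "qry" for the support and query losses are notation introduced here for readability; $\alpha$ corresponds to `inner_lr` and $\beta$ to `meta_lr`.

$$
\theta_i' = \theta - \alpha \, \nabla_\theta \, \mathcal{L}^{\text{spt}}_{\mathcal{T}_i}(f_\theta)
$$

$$
\theta \leftarrow \theta - \beta \, \nabla_\theta \sum_{\mathcal{T}_i} \mathcal{L}^{\text{qry}}_{\mathcal{T}_i}(f_{\theta_i'})
$$

In this notebook the inner update is repeated `inner_step` times, and the outer step is taken by Adam on the summed query losses rather than by the plain gradient step shown in the second equation.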




**Implementation details with torch.func**

Most of the core steps of MAML were implemented using torch.func functionalities (note that functorch has been migrated into torch.func), following the available [PyTorch](https://pytorch.org/docs/stable/index.html) and [torch.func](https://pytorch.org/docs/stable/func.html) documentation.

*  The `train()` function performs MAML training over a single epoch (during one epoch the algorithm is trained on `total number of classes in TR set // meta-batch size` distinct meta-batches).

  * The inner loop is run in parallel on each task in the meta-batch, in order to efficiently get the query losses required for the meta (outer) update, using:
  ```
  # Compute outer losses by mapping the inner loop function over each task (with its support and query sets)
  all_inner_loss, all_inner_acc, all_loss, all_acc = vmap(self.inner_loop)(x_support, y_support, x_query, y_query)
  ```

  * In the inner loop, the model is fitted to a task, obtaining the adapted parameters θ' starting from θ, without losing the current values of θ:
  ```
  # get model's output (on support set) by replacing parameters and buffers
  logits = functional_call(self.model, (params, buffers), x_support)
  ```
  This makes it possible to compute the output of the model with the specified parameters and buffers instead of the current ones. In this way, the θ parameters that need to be updated in the outer loop are preserved.
  
  * In addition, when adapting the model to a task with an SGD step (to get the adapted parameters θ'), it is necessary to explicitly compute the gradient of the support loss via
  ```
  # compute gradients of inner loss
  grads = grad(self.get_inner_loss)(params, buffers, x_support, y_support)
  ```
  and to use it in the SGD step. The `backward()` method cannot be used here, since it would write gradients to the shared θ parameters rather than yield the gradients needed for the θ' of each task. The SGD update rule was also implemented explicitly.

* The `evaluation()` function is used to assess model performance on the TS or VL set. Here, the model is evaluated after being adapted to the support sets with some SGD steps (note that the number of inner SGD steps may differ between training and evaluation), and the performance on the query sets is returned.
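
To make the idea of differentiating through the inner step concrete, here is a toy sketch in plain Python with no torch.func involved. The scalar model, the squared-error loss and all function names are assumptions made purely for illustration; in the notebook the corresponding gradient is obtained by automatic differentiation through `grad`, `functional_call` and `backward()`.

```python
def loss(w, x, y):
    """Squared error of the scalar model y_hat = w * x."""
    return (w * x - y) ** 2

def loss_grad(w, x, y):
    """Analytic derivative of the loss with respect to w."""
    return 2 * x * (w * x - y)

def adapt(w, support, alpha):
    """One inner SGD step on the support example; w itself is left untouched."""
    x_s, y_s = support
    return w - alpha * loss_grad(w, x_s, y_s)

def meta_grad(w, support, query, alpha):
    """Derivative of the query loss at the adapted weight w.r.t. the initial w."""
    x_s, _ = support
    x_q, y_q = query
    w_adapted = adapt(w, support, alpha)
    d_adapted_d_w = 1 - 2 * alpha * x_s ** 2  # chain rule through the inner step
    return loss_grad(w_adapted, x_q, y_q) * d_adapted_d_w

w, alpha = 0.5, 0.1
support, query = (1.0, 2.0), (2.0, 4.0)

# finite-difference check of the meta-gradient
eps = 1e-6
numeric = (loss(adapt(w + eps, support, alpha), *query)
           - loss(adapt(w - eps, support, alpha), *query)) / (2 * eps)
print(abs(meta_grad(w, support, query, alpha) - numeric) < 1e-5)
```

The factor `d_adapted_d_w` is the second-order term: it appears only because the adapted weight is kept as a function of the initial weight, which is the role played by the functional parameter dictionaries in `inner_loop`.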


In [14]:
#@title Model Agnostic Meta Learning

class MAML():

  # ------------ INITIALIZATION ------------------------------------------------

  def __init__(self, dataset, options, device, path):

    self.path = path # model checkpoint (for saving or load)

    self.meta_lr = options.meta_lr   # beta learning rate in outer loop
    self.inner_lr = options.inner_lr # alpha learning rate in inner loop
    self.tr_inner_steps = options.inner_step # optimization step in inner loop during training
    self.all_ts_inner_steps = options.ts_inner_step # all considered optimization step in inner loop during evaluation
    self.ts_inner_steps = 3 # default number of inner optimization steps during evaluation: set 3 for 5-way and 5 for 20-way
    self.iterations = options.iterations  # epochs

    # Omniglot datasets (TR, VL, TS) already divided in metabatch
    self.dataset = dataset

    # Base Learner: Convolutional NN described by Finn et. al
    self.model = FinnCNN(options.n_way, device)

    # Meta optimizer: Adam
    self.meta_optimizer = optim.Adam(self.model.parameters(), lr = self.meta_lr)

    # Num. of metabatch for each dataset split
    self.num_tr_metabatch = self.dataset.datasets['train'].shape[0] // self.dataset.meta_batch_size
    self.num_trvl_metabatch = self.dataset.datasets['trainval'].shape[0] // self.dataset.meta_batch_size
    self.num_vl_metabatch = self.dataset.datasets['val'].shape[0] // self.dataset.meta_batch_size
    self.num_ts_metabatch = self.dataset.datasets['test'].shape[0] // self.dataset.meta_batch_size

    # initialize optimal loss for model checkpointing
    self.optimal_loss = 100000

    # Store performance history
    self.tr_history = {
        "tr_loss" : [],
        "tr_accuracy": [],
        "tr_inner_loss": [],
        "tr_inner_accuracy" : []
        }
    self.vl_history = {
        "vl_loss" : [],
        "vl_accuracy": [],
        "vl_inner_loss": [],
        "vl_inner_accuracy" : [],

        }
    self.ts_history = {
        "ts_inner_steps":[],
        "ts_loss" : [],
        "ts_accuracy" : [],
        "ts_inner_loss" : [],
        "ts_inner_accuracy" : [],
    }

  def get_inner_loss(self, params, buffers, x, y):

    # Get model's output on support set by replacing parameters and buffers
    logits = functional_call(self.model, (params, buffers), x)

    # compute cross entropy loss function on support set
    loss = F.cross_entropy(logits, y)

    return loss


  # ------------------------ TRAIN - INNER LOOP on SINGLE TASK -----------------

  def inner_loop(self, x_support, y_support, x_query, y_query):

    # store model parameters and buffers
    params = dict(self.model.named_parameters())
    buffers = dict(self.model.named_buffers())

    # Sizes of query set and support set
    query_size = x_query.size(dim = 0)
    support_size = x_support.size(dim = 0)

    # iterate over inner steps
    for step in range(self.tr_inner_steps):

      if step + 1 == self.tr_inner_steps:
        # get model's output (on support set) by replacing parameters and buffers
        logits = functional_call(self.model, (params, buffers), x_support)

        # compute cross entropy loss function on support set
        support_loss = F.cross_entropy(logits, y_support)

        # compute accuracy on support set
        support_acc = (logits.argmax(dim=1) == y_support).sum() / support_size

      # compute gradients of inner loss
      grads = grad(self.get_inner_loss)(params, buffers, x_support, y_support)

      # compute adapted parameters with gradient descent:
      # (!) params = params - alpha * gradient of the loss
      params = {k: params[k] - g * self.inner_lr for k, g in grads.items()}

    # get model's output (on query set) by replacing parameters and buffers
    logits = functional_call(self.model, (params, buffers), x_query)

    # compute cross entropy loss function on query set
    query_loss = F.cross_entropy(logits, y_query)

    # compute accuracy on query set
    query_acc = (logits.argmax(dim=1) == y_query).sum() / query_size


    return support_loss, support_acc, query_loss, query_acc


  # ------------------------ TRAIN - OUTER LOOP  -------------------------------

  def train(self, retrain = True):

    # set model in training state
    self.model.train()

    # initialize metrics and loss for the whole epoch
    meta_loss, meta_acc = 0, 0
    inner_loss, inner_acc = 0, 0

    # set train dataset (TR or TR+VL) and num metabatches
    if retrain:
      mode_train = 'trainval'
      num_metabatch = self.num_trvl_metabatch
    else:
      mode_train = 'train'
      num_metabatch = self.num_tr_metabatch

    # iterate over the total number of metabatches per iteration
    for i in range(num_metabatch):

      # Sample a single metabatch (inner + query) from train set
      x_support, y_support, x_query, y_query = self.dataset.next(mode = mode_train)

      # total number of tasks in the sampled metabatch
      num_tasks = x_support.size(dim = 0)

      # set gradients to zero before updating
      self.meta_optimizer.zero_grad()

      # (!) for each task in the current metabatch train one model (inner loop on single task)

      # Compute outer losses by mapping the inner loop function over each task (with its support and query sets)
      all_inner_loss, all_inner_acc, all_loss, all_acc = vmap(self.inner_loop)(x_support, y_support, x_query, y_query)

      # add chunk_size = 1 size to perform vmap in a for-fashioned loop (to avoid cuda out of memory issues)
      #all_inner_loss, all_inner_acc, all_loss, all_acc = vmap(self.inner_loop, chunk_size = 1)(x_support, y_support, x_query, y_query)

      # Compute gradients of the meta (query) losses
      all_loss.sum().backward()

      # Meta update: update model parameters
      self.meta_optimizer.step()

      all_inner_loss = all_inner_loss.detach().sum() / num_tasks
      all_inner_acc = 100. * all_inner_acc.sum() / num_tasks

      # Compute mean meta-loss and mean-meta accuracy over tasks
      all_loss = all_loss.detach().sum() / num_tasks
      all_acc = 100. * all_acc.sum() / num_tasks

      # update loss and accuracy after each metabatch
      meta_loss += all_loss.item()
      meta_acc += all_acc.item()
      inner_loss += all_inner_loss.item()
      inner_acc += all_inner_acc.item()

    # Get meta-loss and meta-accuracy of the overall iteration
    inner_loss = inner_loss / num_metabatch
    inner_acc = inner_acc / num_metabatch
    meta_loss = meta_loss / num_metabatch
    meta_acc = meta_acc / num_metabatch

    return inner_loss, inner_acc, meta_loss, meta_acc


  # ------------------------ EVALUATION on VL/TS -------------------------------

  def evaluation(self, mode_eval):

    # set model in evaluation status
    self.model.eval()

    # initialize lists for storing performance
    qry_losses = []
    qry_accs = []
    spt_losses = []
    spt_accs = []

    # choose dataset for evaluation
    if mode_eval == "val":
        num_metabatch = self.num_vl_metabatch
    else:
        num_metabatch = self.num_ts_metabatch

    # iterate over the number of metabatch
    for i in range(num_metabatch):

      # sample a single metabatch (support + query) from the evaluation split (VL or TS)
      x_support, y_support, x_query, y_query = self.dataset.next(mode = mode_eval)

      # total number of tasks in the sampled metabatch
      task_num = x_support.size(dim = 0)

      # iterate over the number of tasks
      for j in range(task_num):

        # Inner loop on single task

        # get original model parameter and buffer
        new_params = dict(self.model.named_parameters())
        buffers = dict(self.model.named_buffers())

        # Adapt model to the task by iterating over inner steps
        for _ in range(self.ts_inner_steps):

          # get model's output on support set by replacing parameters and buffers
          spt_logits = functional_call(self.model, (new_params, buffers), x_support[j])

          # compute inner loss function on support set
          spt_loss = F.cross_entropy(spt_logits, y_support[j])

          # compute gradients of inner loss function
          grads = torch.autograd.grad(spt_loss, new_params.values())

          # update model parameters using stochastic gradient descent
          new_params = {k: new_params[k] - g * self.inner_lr for k, g in zip(new_params, grads)}

        # compute logits on query set using adapted parameter
        qry_logits = functional_call(self.model, (new_params, buffers), x_query[j]).detach()

        # compute loss on query set
        qry_loss = F.cross_entropy(qry_logits, y_query[j], reduction='none')

        # add support loss and accuracy to history
        spt_losses.append(spt_loss.detach())
        spt_accs.append((spt_logits.argmax(dim=1) == y_support[j]).detach())

        # add query loss and accuracy to history
        qry_losses.append(qry_loss.detach())
        qry_accs.append((qry_logits.argmax(dim=1) == y_query[j]).detach())

    # get mean inner performance over metabatches
    spt_losses = torch.mean(torch.stack(spt_losses)).item()
    spt_accs = 100. * torch.mean(torch.stack(spt_accs).float()).item()

    # get mean meta performance over metabatches
    qry_losses = torch.cat(qry_losses).mean().item()
    qry_accs = 100. * torch.cat(qry_accs).float().mean().item()

    return  spt_losses, spt_accs, qry_losses, qry_accs


  # ------------------------ FIT MAML to DATASET -------------------------------

  def fit_and_evaluate(self):

    # ---> RUN TRAINING on TRAIN SET and VALIDATE on VALIDATION SET
    #self.model.train()
    for iteration in range(self.iterations):

      # train on Train Set (no final retraining on TR+VL) and evaluate on Validation set
      tr_inner_loss, tr_inner_acc, tr_loss, tr_acc = self.train(retrain = False)
      vl_inner_loss, vl_inner_acc, vl_loss, vl_acc = self.evaluation(mode_eval = "val")

      # add performance to history
      self.tr_history["tr_accuracy"].append(tr_acc)
      self.vl_history["vl_accuracy"].append(vl_acc)
      self.tr_history["tr_loss"].append(tr_loss)
      self.vl_history["vl_loss"].append(vl_loss)
      self.tr_history["tr_inner_loss"].append(tr_inner_loss)
      self.tr_history["tr_inner_accuracy"].append(tr_inner_acc)
      self.vl_history["vl_inner_loss"].append(vl_inner_loss)
      self.vl_history["vl_inner_accuracy"].append(vl_inner_acc)

      '''
      (!) Typically the hold-out validation technique requires a final retraining on the full TR+VL set.
      Here the final retraining has not been executed due to time constraints
      (the training phase is too time-consuming).
      The VL set is still useful for model checkpointing.
      '''

      # show performance to monitor training
      if iteration % 10 == 0 or iteration == self.iterations - 1:
        print(
            f'[Epoch {iteration}] | ',
            f'[TR] Meta Loss: {tr_loss:.2f} - Meta Acc: {tr_acc:.2f} -',
            f'Inner Loss: {tr_inner_loss:.2f} -  Inner Acc: {tr_inner_acc:.2f} |',
            f'[VL] Meta Loss: {vl_loss:.2f} -  Meta Acc: {vl_acc:.2f} -',
            f'Inner Loss: {vl_inner_loss:.2f} - Inner Acc: {vl_inner_acc:.2f} |'
            )

      # model checkpointing: save model if the val loss is the lowest so far
      if vl_loss <= self.optimal_loss:

        # update optimal loss
        self.optimal_loss = vl_loss

        # save model
        torch.save({
                'epoch': iteration,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.meta_optimizer.state_dict(),
                'loss': vl_loss,
                }, self.path)


    return self.tr_history, self.vl_history

  # ------------------------ MAKE INFERENCE on TEST SET ------------------------

  def test(self, more_ts_inner_step = True):

    # load best model from checkpoint
    checkpoint = torch.load(self.path)
    self.model.load_state_dict(checkpoint['model_state_dict'])
    self.meta_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if more_ts_inner_step:
      # Make inference with different number of inner step

      for innerstep in self.all_ts_inner_steps:
        self.ts_inner_steps = innerstep

        for _ in range(10): # Make inference 10 times

          # make inference
          ts_inner_loss, ts_inner_acc, ts_loss, ts_acc = self.evaluation(mode_eval ="test")

          # save test performance
          self.ts_history["ts_inner_steps"].append(self.ts_inner_steps)
          self.ts_history["ts_loss"].append(ts_loss)
          self.ts_history["ts_accuracy"].append(ts_acc)
          self.ts_history["ts_inner_loss"].append(ts_inner_loss)
          self.ts_history["ts_inner_accuracy"].append(ts_inner_acc)

      return self.ts_history

    else: # make inference just for the default ts_inner_steps

      # make inference
      ts_inner_loss, ts_inner_acc, ts_loss, ts_acc = self.evaluation(mode_eval ="test")

      # show test performance
      print(
          f'inner-step: {self.ts_inner_steps}|',
          f'Meta-Test Loss: {ts_loss:.2f} | Meta-Test accuracy: {ts_acc:.2f} | ',
          f'Inner-Test Loss: {ts_inner_loss:.2f} | Inner-Test Acc: {ts_inner_acc:.2f} | '
            )

      return


## **Experiments**

---



| Hyper-Parameters or config. | 5 way | 20 way |
|-----|--------|---|
| k shot for support| 5 or 1 | 5 or 1 |
| k shot for query |  20 - k shot for support | 20 - k shot for support|
| Inner and Outer Loss | Cross Entropy | Cross Entropy |
| Outer optimizer | Adam | Adam |
| Outer learning rate | 1e-3 | 1e-3 |
| Inner optimizer | SGD | SGD |
| Inner learning rate | 0.4 | 0.1 |
| Meta batch size | 32 | 16 |
| Default inner steps (training) | 1 | 5 |
| Default inner steps (evaluation) | 3 | 5 |
| Max epochs | 100 | 100 |
| image input size | 28 | 28 |
| augmentation with rotations | no | no |
| image input channel | 1 | 1 |
| Convolutional Model | Finn et. al (2017) | Finn et. al (2017) |










All models were trained for 100 epochs, following the experiments in *How to train your MAML* (Antoniou et al., 2019), because training for 60000 iterations, as in the original paper by Finn et al. (2017), would have been too expensive.

In [15]:
#@title Configuration and hyperparameters set-up
options = SimpleNamespace(

    # path to data
    dataset_root='.' + os.sep + 'dataset',

    # number of classes for n way classification
    n_way = 5, # 5 or 20

    # k shot for support set
    k_support = 5, # 5 or 1

    # k shot for query set (20 - k_support: 15 for 5-shot, 19 for 1-shot)
    k_query = 15,

    # Number of tasks per meta batch
    meta_batch_size = 32, # 16 for 20-way and 32 for 5-way

    # Meta (outer) learning rate (beta in Finn et al., 2017)
    meta_lr = 1e-3,

    # Inner learning rate (alpha in Finn et al., 2017)
    inner_lr = 0.4, # 0.4 for 5-way and 0.1 for 20-way

    # Image size for resizing
    img_size = 28,

    # Number of training epochs
    iterations = 100,

    # Inner steps during training [tested values for 5-way: 1 3 5]
    inner_step = 1, # default: 1 for 5-way and 5 for 20-way

    # Inner steps during evaluation [tested values: 1 3 5 10]
    ts_inner_step = [1, 3, 5, 10], # default: 3 for 5-way, 5 for 20-way

    # seed
    seed = 1,

)
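
Since `k_support` and `k_query` are set independently in the configuration, a quick sanity check can confirm that they fit within the 20 instances Omniglot provides per character. This is a minimal sketch, not part of the original notebook; `check_episode_sizes` and `INSTANCES_PER_CHARACTER` are hypothetical names.

```python
from types import SimpleNamespace

INSTANCES_PER_CHARACTER = 20  # Omniglot has 20 drawings of each character


def check_episode_sizes(options):
    """Hypothetical check: support and query shots must fit within one class."""
    if options.k_support < 1 or options.k_query < 1:
        raise ValueError("k_support and k_query must both be at least 1")
    per_class = options.k_support + options.k_query
    if per_class > INSTANCES_PER_CHARACTER:
        raise ValueError(
            f"k_support + k_query = {per_class} exceeds "
            f"{INSTANCES_PER_CHARACTER} instances per character"
        )
    # Number of images per task in the support set and in the query set.
    return options.n_way * options.k_support, options.n_way * options.k_query


example = SimpleNamespace(n_way=5, k_support=5, k_query=15)
support_size, query_size = check_episode_sizes(example)
print(support_size, query_size)  # 25 75
```

With `k_support = 1` and `k_query = 15` the check still passes, but only 16 of the 20 instances are used, which differs from the "20 - k shot for support" rule stated in the hyperparameter table.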

In [16]:
# Specify seed and device

torch.manual_seed(options.seed)
np.random.seed(options.seed)

if torch.cuda.is_available():
  torch.cuda.manual_seed_all(options.seed)
  device = 'cuda'
else:
  device = 'cpu'


In [17]:
# Initialize the metabatch sampler for loading and preparing Omniglot dataset
data_sampler = MetaBatchSampler(options, device)

Files already downloaded and verified
Files already downloaded and verified


In [18]:
# base directory for the files of this experiment
base_dir = f'/content/drive/MyDrive/MAML/{options.n_way}way/{options.k_support}shot'

# path to load or store model
checkpoint_path = f'{base_dir}/Model/NEW{options.inner_step}innerstep.pt'

# paths for storing performance
history_train = f'{base_dir}/TR_{options.inner_step}innerstep.csv'
history_val = f'{base_dir}/VL_{options.inner_step}innerstep.csv'
# note: ts_inner_step is a list, so its printed form (e.g. "[1, 3, 5, 10]") becomes part of the file name
history_test = f'{base_dir}/TS_{options.inner_step}-{options.ts_inner_step}innerstep.csv'

# initialize MAML
maml = MAML(data_sampler, options, device, checkpoint_path)

# train on TR and evaluate on VL
tr_history, vl_history = maml.fit_and_evaluate()

# test on TS with the default number of SGD inner steps for evaluation
maml.test(more_ts_inner_step = False)

# test on TS with different number of SGD inner step for evaluation
ts_history = maml.test()


# Save TR performance in csv files for further analysis
df = pd.DataFrame(tr_history)
df.to_csv(history_train)

# Save VL performance in csv files for further analysis
df = pd.DataFrame(vl_history)
df.to_csv(history_val)

# Save TS performance in csv files for further analysis
df = pd.DataFrame(ts_history)
df.to_csv(history_test)

[Epoch 0] |  [TR] Meta Loss: 0.69 - Meta Acc: 81.27 - Inner Loss: 1.65 -  Inner Acc: 19.78 | [VL] Meta Loss: 0.42 -  Meta Acc: 86.92 - Inner Loss: 0.13 - Inner Acc: 99.75 |
[Epoch 10] |  [TR] Meta Loss: 0.07 - Meta Acc: 98.01 - Inner Loss: 1.67 -  Inner Acc: 20.42 | [VL] Meta Loss: 0.14 -  Meta Acc: 95.22 - Inner Loss: 0.01 - Inner Acc: 100.00 |
[Epoch 20] |  [TR] Meta Loss: 0.04 - Meta Acc: 98.66 - Inner Loss: 1.68 -  Inner Acc: 20.24 | [VL] Meta Loss: 0.12 -  Meta Acc: 95.58 - Inner Loss: 0.01 - Inner Acc: 100.00 |
[Epoch 30] |  [TR] Meta Loss: 0.03 - Meta Acc: 98.95 - Inner Loss: 1.70 -  Inner Acc: 20.37 | [VL] Meta Loss: 0.14 -  Meta Acc: 95.42 - Inner Loss: 0.00 - Inner Acc: 99.96 |
[Epoch 40] |  [TR] Meta Loss: 0.03 - Meta Acc: 99.22 - Inner Loss: 1.70 -  Inner Acc: 20.20 | [VL] Meta Loss: 0.13 -  Meta Acc: 95.57 - Inner Loss: 0.00 - Inner Acc: 100.00 |
[Epoch 50] |  [TR] Meta Loss: 0.03 - Meta Acc: 99.15 - Inner Loss: 1.73 -  Inner Acc: 19.63 | [VL] Meta Loss: 0.10 -  Meta Acc: 

Experimental results are further explored and discussed in the `MAML_ExperimentalResults.ipynb` notebook.
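
Because `test()` repeats inference 10 times for each number of inner steps, the returned history holds several entries per setting. A minimal sketch of aggregating them is shown below, assuming the history is a dictionary of equally long lists whose accuracy entries are plain floats; `summarize_test_history` is a hypothetical helper, and the numbers are placeholders for illustration, not measured results.

```python
from collections import defaultdict
from statistics import mean, stdev


def summarize_test_history(ts_history):
    """Group repeated test runs by inner-step count; return (mean, std) accuracy."""
    grouped = defaultdict(list)
    for steps, acc in zip(ts_history["ts_inner_steps"], ts_history["ts_accuracy"]):
        grouped[steps].append(acc)
    return {
        steps: (mean(accs), stdev(accs) if len(accs) > 1 else 0.0)
        for steps, accs in sorted(grouped.items())
    }


# Placeholder values for illustration only, not measured results.
toy_history = {
    "ts_inner_steps": [1, 1, 3, 3],
    "ts_accuracy": [90.0, 92.0, 95.0, 97.0],
}

for steps, (avg, std) in summarize_test_history(toy_history).items():
    print(f"{steps} inner steps: {avg:.2f} +/- {std:.2f}")
```

Reporting the spread alongside the mean helps judge whether a difference between two inner-step settings is larger than the variation between repeated runs.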



## **References**


---





*   [Chelsea Finn, Pieter Abbeel, Sergey Levine. *Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks.* (2017).](https://arxiv.org/pdf/1703.03400.pdf)
*   [Antreas Antoniou, Harrison Edwards, Amos Storkey. *How to train your MAML.* (2019).](https://arxiv.org/pdf/1810.09502.pdf)
*   [Bronskill, Gordon, Requeima, Nowozin, Turner. *TaskNorm: Rethinking Batch Normalization for Meta-Learning.* (2020).](https://arxiv.org/pdf/2003.03284.pdf)
*   [Adam Santoro, Sergey Bartunov, Matthew Botvinick, Daan Wierstra, Timothy Lillicrap. *Meta-Learning with Memory-Augmented Neural Networks.* (2016).](https://proceedings.mlr.press/v48/santoro16.pdf)
*   https://github.com/cbfinn/maml
*   https://github.com/dragen1860/MAML-Pytorch
*   https://pytorch.org/docs/stable/func.html



