In [1]:
## Standard libraries
import os
import numpy as np
import random
import json

from collections import defaultdict
from statistics import mean, stdev
from copy import deepcopy

## tqdm for loading bars
from tqdm.auto import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install --quiet pytorch-lightning>=1.4
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Setting the seed
pl.seed_everything(42)
CHECKPOINT_PATH = "../saved_models"

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.determinstic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

Global seed set to 42


Device: cpu


## Few-shot classification

We start our implementation by discussing the dataset setup. In this notebook, we will use 16 different NLP tasks. Each task may have different input (e.g. one or two sentences) and different number of output classes (e.g. 'entailment', 'neutral', 'contradiction' for MNLI, and 'positive'/'negative' for SST2). Instead of splitting the training, validation, and test set over examples, we will split them over datasets: we will use 5 classes for training, and 5 for validation, and 6 for testing. Our overall goal is to obtain a model that can classify the samples in the test datasets with seeing very few examples. 

### Data preprocessing

First, let's define functions to load the datasets. Next, we need to prepare the dataset in the training, validation and test split as mentioned before. Huggingface gives us the training, validation and test set as separate dataset objects. The next code cells will merge the separate training and validation datasets, and then create the new train-val-test split.

In [2]:
from dataloader import load_mnli, load_qqp, load_sst2, load_boolq, load_cb

In [3]:
# Define the datasets that we are using for metalearning


# TODO: Add other dataloaders
# In the meantime just use one dataset for testing purposes

DATALOADERS = {
#    "mnli": load_mnli(),
#    "qqp": load_qqp(),
#    "sst": load_sst2(),
#    "wgrande": 
#    "boolq": load_boolq(),
#    "imdb": 
#    "hswag":
#    "mrpc":
#    "argument": 
#    "scitail": 
#    "sociqa": 
#    "cosqa": 
#    "csqa": 
#    "sick": 
#    "rte": 
    "cb": load_cb()
}



Reusing dataset super_glue (/Users/FrankVerhoef/.cache/huggingface/datasets/super_glue/cb/1.0.2/d040c658e2ddef6934fdd97deb45c777b6ff50c524781ea434e7219b56a428a7)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /Users/FrankVerhoef/.cache/huggingface/datasets/super_glue/cb/1.0.2/d040c658e2ddef6934fdd97deb45c777b6ff50c524781ea434e7219b56a428a7/cache-f6a9e244a16fe613.arrow
Loading cached processed dataset at /Users/FrankVerhoef/.cache/huggingface/datasets/super_glue/cb/1.0.2/d040c658e2ddef6934fdd97deb45c777b6ff50c524781ea434e7219b56a428a7/cache-9038dd7cd8f9c87e.arrow
Loading cached processed dataset at /Users/FrankVerhoef/.cache/huggingface/datasets/super_glue/cb/1.0.2/d040c658e2ddef6934fdd97deb45c777b6ff50c524781ea434e7219b56a428a7/cache-41619ee74b27d2c7.arrow


In [4]:
def combine_train_valid(name):
    
    (ds, id2label), key = DATALOADERS[name]

    # combine input data from train and validation set
    all_inputs = torch.cat((ds["train"]["input_ids"], ds[key]["input_ids"]), dim=0)
    all_token_types = torch.cat((ds["train"]["token_type_ids"], ds[key]["token_type_ids"]), dim=0)
    all_masks = torch.cat((ds["train"]["attention_mask"], ds[key]["attention_mask"]), dim=0)
    all_labels = torch.cat((ds["train"]["labels"], ds[key]["labels"]), dim=0)
    
    return {
        "input_ids": all_inputs, 
        "token_type_ids": all_token_types, 
        "attention_mask": all_masks,
        "labels": all_labels
    }

DATASETS = {ds: combine_train_valid(ds) for ds in DATALOADERS.keys()}
TASK_IDS = {name: id for id, name in enumerate(DATASETS.keys())}

In [5]:
combined_dataset = {
    "tasks": torch.hstack([
        torch.tensor([TASK_IDS[ds_name]] * len(ds["labels"]))
        for ds_name, ds in DATASETS.items()
    ]),
    "input_ids": torch.vstack([ds["input_ids"] for ds in DATASETS.values()]),
    "token_type_ids": torch.vstack([ds["token_type_ids"] for ds in DATASETS.values()]),
    "attention_mask": torch.vstack([ds["attention_mask"] for ds in DATASETS.values()]),
    "labels": torch.hstack([ds["labels"] for ds in DATASETS.values()])
}

In [6]:
class NLPDataset(data.Dataset):

    def __init__(self, tasks, input_ids, token_type_ids, attention_mask, labels):

        super().__init__()
        self.tasks = tasks
        self.input_ids = input_ids
        self.token_type_ids = token_type_ids
        self.attention_mask = attention_mask
        self.labels = labels

    def __getitem__(self, idx):
        task = self.tasks[idx]
        x = (self.input_ids[idx], self.token_type_ids[idx], self.attention_mask[idx])
        label = self.labels[idx]
        return task, x, label

    def __len__(self):
        return self.inputs.shape[0]


In [7]:
def dataset_from_tasks(dataset, tasks, **kwargs):
    
    task_mask = (dataset["tasks"][:,None] == tasks[None,:]).any(dim=-1)
    dataset = NLPDataset(
        tasks = dataset["tasks"][task_mask], 
        input_ids = dataset["input_ids"][task_mask], 
        token_type_ids = dataset["token_type_ids"][task_mask], 
        attention_mask = dataset["attention_mask"][task_mask], 
        labels = dataset["labels"][task_mask], 
        **kwargs
    )
    
    return dataset

In [8]:
# For testing purposes just use one small dataset for everything
# TODO: Check memory usage! Loading everything caused crash on Colab :(

#train_datasets = ["mnli", "qqp"]
#val_datasets = ["sst", "boolq"]
train_datasets = ["cb"]
val_datasets = ["cb"]
test_datasets = ["cb"]

#train_set = dataset_from_tasks(combined_dataset, torch.tensor([TASK_IDS[ds] for ds in train_datasets]))
#val_set = dataset_from_tasks(combined_dataset, torch.tensor([TASK_IDS[ds] for ds in val_datasets]))
train_set = dataset_from_tasks(combined_dataset, torch.tensor([TASK_IDS[ds] for ds in test_datasets]))
val_set = dataset_from_tasks(combined_dataset, torch.tensor([TASK_IDS[ds] for ds in test_datasets]))
test_set = dataset_from_tasks(combined_dataset, torch.tensor([TASK_IDS[ds] for ds in test_datasets]))



In [9]:
test_set[0]

(tensor(0),
 (tensor([  101,  2009,  2001,  1037,  3375,  2653,  1012,  2025,  2517,  2091,
           2021,  4375,  2091,  1012,  2028,  2453,  2360,  2009,  2001, 20956,
           2091,  1012,   102,  1996,  2653,  2001, 20956,  2091,   102,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     

### Data sampling

The strategy of how to use the available training data for learning few-shot adaptation is crucial in meta-learning. At each training step, we randomly select a small number of classes and sample a small number of examples for each class. For the NLP tasks that we have selected, the number of classes is either 2 or 3.

This represents our few-shot training batch, which we also refer to as **support set**. Additionally, we sample a second set of examples from the same classes and refer to this batch as **query set**. Our training objective is to classify the query set correctly from seeing the support set and its corresponding labels. The main difference between our three methods (ProtoNet, MAML, and Proto-MAML) is in how they use the support set to adapt to the training classes.

This subsection summarizes the code that is needed to create such training batches. In PyTorch, we can specify the data sampling procedure by so-called `Sampler` ([documentation](https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler)). Samplers are iterable objects that return indices in the order in which the data elements should be sampled. In our previous notebooks, we usually used the option `shuffle=True` in the `data.DataLoader` objects which creates a sampler returning the data indices in random order. Here, we focus on samplers that return batches of indices that correspond to support and query set batches. Below, we implement such a sampler.

In [10]:
class FewShotBatchSampler(object):

    def __init__(self, dataset_tasks, dataset_targets, N_way, K_shot, include_query=False, shuffle=True, shuffle_once=False):
        """
        Inputs:
            dataset_tasks - PyTorch tensor of the id's of the tasks in the dataset.
            dataset_classes - PyTorch tensor of the classes in the dataset
            N_way - Number of classes to sample per batch.
            K_shot - Number of examples to sample per class in the batch.
            include_query - If True, returns batch of size N_way*K_shot*2, which 
                            can be split into support and query set. Simplifies
                            the implementation of sampling the same classes but 
                            distinct examples for support and query set.
            shuffle - If True, examples and classes are newly shuffled in each
                      iteration (for training)
            shuffle_once - If True, examples and classes are shuffled once in 
                           the beginning, but kept constant across iterations 
                           (for validation)
        """
        super().__init__()
        self.dataset_tasks = dataset_tasks
        self.dataset_targets = dataset_targets
        self.dataset_task_targets = torch.cat((dataset_tasks.unsqueeze(dim=1), dataset_targets.unsqueeze(dim=1)), dim=1)
        self.N_way = N_way
        self.K_shot = K_shot
        self.shuffle = shuffle
        self.include_query = include_query
        if self.include_query:
            self.K_shot *= 2
        self.batch_size = self.N_way * self.K_shot  # Number of overall samples per batch

        # Organize examples by task and class
        self.tasks = torch.unique(self.dataset_tasks).tolist()
        self.classes = {}
        self.num_classes = {}
        self.indices_per_task = {}
        self.batches_per_task = {}
        self.indices_per_class = {}
        self.batches_per_class = {}
        for t in self.tasks:
            self.indices_per_task[t] = torch.where(self.dataset_tasks == t)[0].tolist()
#            print("Indices for task {}: {}".format(t, self.indices_per_task[t]))
#            print("Classes for task {}: {}".format(t, torch.unique(self.dataset_targets[self.indices_per_task[t]]).tolist()))
            self.classes[t] = torch.unique(self.dataset_targets[torch.where(self.dataset_tasks == t)[0]]).tolist()
            self.num_classes[t] = len(self.classes[t])
            self.indices_per_class[t] = {}
            self.batches_per_class[t] = {}  # Number of K-shot batches that each class can provide
            for c in self.classes[t]:
                self.indices_per_class[t][c] = torch.where(
                    (self.dataset_tasks == t) *
                    (self.dataset_targets == c)
                )[0]
                print("Indices per class ({}, {}): {}".format(t, c, self.indices_per_class[t][c]))
                self.batches_per_class[t][c] = self.indices_per_class[t][c].shape[0] // self.K_shot
#                print("Batches per class ({}, {}): {}".format(t, c, self.batches_per_class[t][c]))
            self.batches_per_task[t] = sum(self.batches_per_class[t].values())
#            print("Batches for task {}: {}".format(t, self.batches_per_task[t]))
        self.unique_task_classes = [(t,c) for t in self.tasks for c in self.classes[t]]

        # Create a list of task-class tuples from which we select the N classes per batch
        self.iterations_per_task = [sum(self.batches_per_class[t].values()) // self.N_way for t in self.tasks]
        self.task_list = [t for t in self.tasks for _ in range(self.iterations_per_task[t])]
        print("Task_list  (init): ", self.task_list)
        self.iterations = sum(self.iterations_per_task)
#        print("Iterations: ", self.iterations_per_task)
        self.class_list = {
            t: [c for c in self.classes[t] for _ in range(self.batches_per_class[t][c])]
            for t in self.task_list
        }
        print("Class_list (init): ", self.class_list)
        if shuffle_once or self.shuffle:
            self.shuffle_data()
        else:
            # For testing, we iterate over tasks and classes instead of shuffling them
            for t in self.tasks:
                sort_idxs = [
                    i + p * self.num_classes[t]
                    for i, c in enumerate(self.classes[t]) 
                    for p in range(self.batches_per_class[t][c])
                ]
                self.class_list[t] = np.array(self.class_list[t])[np.argsort(sort_idxs)].tolist()
        print("Class_list (final): ", self.class_list)
        print("Task_list  (final): ", self.task_list)
            
    def shuffle_data(self):
        # Shuffle the examples per task and class.       
        for t,c in self.unique_task_classes:
            perm = torch.randperm(self.indices_per_class[t][c].shape[0])
            self.indices_per_class[t][c] = self.indices_per_class[t][c][perm]

        # Shuffle the order of the tasks
        random.shuffle(self.task_list)
        
        # Lastly, shuffle the class list per task.
        # Note that this way of shuffling does not prevent to choose the same class twice in a batch. 
        # Especially with NLP-tasks with small number of classes, this happens quite often  
        for t in self.tasks:
            random.shuffle(self.class_list[t])

    def __iter__(self):
        # Shuffle data
        if self.shuffle:
            self.shuffle_data()

        # Sample few-shot batches
        start_index = defaultdict(int)
        task_iter = [0] * len(self.tasks)
        for it in range(self.iterations):
            
            # Select N classes for task t for the batch
            t = self.task_list[it]
            idx = task_iter[t] * self.N_way
            task_iter[t] += 1
            class_batch = self.class_list[t][idx:idx+self.N_way]

            # For each task-class tuple, select the next K examples and add them to the batch
            index_batch = []
            for c in class_batch:  
                index_batch.extend(self.indices_per_class[t][c][start_index[t,c]:start_index[t,c]+self.K_shot])
                start_index[t,c] += self.K_shot
                
            # If we return support+query set, sort them so that they are easy to split
            if self.include_query:
                index_batch = index_batch[::2] + index_batch[1::2]
            yield [i.item() for i in index_batch]

    def __len__(self):
        return self.iterations

In [11]:
class TaskBatchSampler(object):

    def __init__(self, dataset_tasks, dataset_targets, batch_size, N_way, K_shot, include_query=False, shuffle=True):
        """
        Inputs:
            dataset_tasks - PyTorch tensor of the id's of the tasks in the dataset.
            dataset_classes - PyTorch tensor of the classes in the dataset
            batch_size - Number of tasks to aggregate in a batch
            N_way - Number of classes to sample per batch.
            K_shot - Number of examples to sample per class in the batch.
            include_query - If True, returns batch of size N_way*K_shot*2, which
                            can be split into support and query set. Simplifies
                            the implementation of sampling the same classes but
                            distinct examples for support and query set.
            shuffle - If True, examples and classes are newly shuffled in each
                      iteration (for training)
        """
        super().__init__()
        self.batch_sampler = FewShotBatchSampler(dataset_tasks, dataset_targets, N_way, K_shot, include_query, shuffle)
        self.task_batch_size = batch_size
        self.local_batch_size = self.batch_sampler.batch_size

    def __iter__(self):
        # Aggregate multiple batches before returning the indices
        batch_list = []
        for batch_idx, batch in enumerate(self.batch_sampler):
            batch_list.extend(batch)
            if (batch_idx+1) % self.task_batch_size == 0:
                yield batch_list
                batch_list = []

    def __len__(self):
        return len(self.batch_sampler)//self.task_batch_size

    
    def get_collate_fn(self):
        # Returns a collate function that converts a list of items into format for transformer model
        
        def collate_fn(item_list):
            input_batch = {
                "input_ids": torch.stack([x[0] for task, x, label in item_list], dim=0),
                "token_type_ids": torch.stack([x[1] for task, x, label in item_list], dim=0),
                "attention_mask": torch.stack([x[2] for task, x, label in item_list], dim=0)
            } 
            label_batch = torch.stack([label for task, x, label in item_list], dim=0)
            return input_batch, label_batch
        
        return collate_fn

In [23]:
# Mock dataset for testing purposes
n_c = [2,3,3]
c_len = [51,13,10]

base = torch.vstack([
    torch.tensor([t, c])
    for t in range(3)
    for c in range(n_c[t])
    for e in range(c_len[t])
])

base = torch.hstack((base, torch.tensor(np.arange(len(base))).unsqueeze(dim=1)))
tasks = base[:,0]
labels = base[:,1]
num = len(base)

dataset = NLPDataset(tasks, torch.ones((num, 2)), torch.zeros((num, 2)), torch.zeros((num, 2)), labels)

In [24]:
tasks, labels

(tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2,
         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
         2, 2, 2]),
 tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,
        

In [25]:
fsb_s = FewShotBatchSampler(tasks, labels, 3, 2, include_query=False, shuffle=False, shuffle_once=True)

Indices per class (0, 0): tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50])
Indices per class (0, 1): tensor([ 51,  52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
         65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,
         79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,  92,
         93,  94,  95,  96,  97,  98,  99, 100, 101])
Indices per class (1, 0): tensor([102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114])
Indices per class (1, 1): tensor([115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127])
Indices per class (1, 2): tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140])
Indices per class (2, 0): tensor([141, 142, 143, 144, 145, 146, 147, 148, 149, 150])
Indices per class (2, 1): tensor([151, 152, 15

In [26]:
# Check out what the FewShotBatchSampler looks like

print("Size of dataset:   ", len(labels))

print("Classes:           ", fsb_s.classes)
print("Number of classes: ", fsb_s.num_classes)
print("Batches per class: ", fsb_s.batches_per_class)
print("Indices per class: ", fsb_s.indices_per_class)
print("Iterations:        ", fsb_s.iterations)
print("Class list:        ", fsb_s.class_list)

print("Print some batches")
for i, b in enumerate(fsb_s):
    print(i, b)

Size of dataset:    171
Classes:            {0: [0, 1], 1: [0, 1, 2], 2: [0, 1, 2]}
Number of classes:  {0: 2, 1: 3, 2: 3}
Batches per class:  {0: {0: 25, 1: 25}, 1: {0: 6, 1: 6, 2: 6}, 2: {0: 5, 1: 5, 2: 5}}
Indices per class:  {0: {0: tensor([41,  3, 42, 19, 35, 49, 40, 29, 15, 16, 11, 26, 22, 21,  6,  0, 45, 44,
        25,  7, 17, 18,  5, 48, 38, 32, 28, 12, 20, 23, 31, 43, 30,  1,  2,  4,
        36, 47, 13, 37, 33, 39, 46, 50,  9, 10, 14, 34, 24, 27,  8]), 1: tensor([ 97,  99,  92,  65,  78, 101,  76,  81,  66,  94,  56,  73,  55,  70,
         64,  69,  58,  75,  62,  61,  63,  79,  77,  60,  53,  67,  71,  52,
         82,  84,  95,  72,  74,  96,  51,  57,  89,  85,  86,  83,  90,  80,
         87,  54,  98,  68,  88,  59, 100,  93,  91])}, 1: {0: tensor([102, 112, 106, 111, 108, 113, 109, 104, 105, 110, 114, 107, 103]), 1: tensor([123, 125, 121, 119, 122, 118, 120, 126, 116, 115, 127, 117, 124]), 2: tensor([132, 128, 130, 134, 131, 139, 133, 135, 137, 140, 129, 138, 136])}, 2

In [27]:
tb_s = TaskBatchSampler(tasks, labels, 5, 3, 2, include_query=False, shuffle=False)

Indices per class (0, 0): tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50])
Indices per class (0, 1): tensor([ 51,  52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
         65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,
         79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,  92,
         93,  94,  95,  96,  97,  98,  99, 100, 101])
Indices per class (1, 0): tensor([102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114])
Indices per class (1, 1): tensor([115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127])
Indices per class (1, 2): tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140])
Indices per class (2, 0): tensor([141, 142, 143, 144, 145, 146, 147, 148, 149, 150])
Indices per class (2, 1): tensor([151, 152, 15

In [28]:
dl = data.DataLoader(
    dataset, 
    batch_sampler=tb_s,
    collate_fn=tb_s.get_collate_fn()
)
for i, (x, labels) in enumerate(dl):
    print("Batch {}: len={}, label={}".format(i, len(x), labels))

Batch 0: len=3, label=tensor([0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1,
        0, 0, 1, 1, 0, 0])
Batch 1: len=3, label=tensor([1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0,
        1, 1, 0, 0, 1, 1])
Batch 2: len=3, label=tensor([0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1,
        0, 0, 1, 1, 0, 0])
Batch 3: len=3, label=tensor([1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 2, 2, 0, 0, 1, 1, 2, 2, 0, 0, 1, 1, 2, 2,
        0, 0, 1, 1, 2, 2])
Batch 4: len=3, label=tensor([0, 0, 1, 1, 2, 2, 0, 0, 1, 1, 2, 2, 0, 0, 1, 1, 2, 2, 0, 0, 1, 1, 2, 2,
        0, 0, 1, 1, 2, 2])


In [29]:
# See if this works
# Testset is a list of tuple with (task, (input_ids, token_type_ids, attention_mask), label)

dataset_tasks = torch.tensor([item[0] for item in test_set])
dataset_targets = torch.tensor([item[2] for item in test_set])
s = TaskBatchSampler(dataset_tasks, dataset_targets, 5, 3, 2, include_query=False, shuffle=True)

Indices per class (0, 0): tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  24,  25,  30,  31,  32,
         34,  35,  36,  37,  38,  39,  40,  44,  45,  46,  47,  48,  49,  50,
         51,  52,  54,  55,  56,  68,  69,  70,  71,  72,  73,  74,  75,  77,
         78,  80,  81,  82,  83,  85,  87,  88,  89,  90,  91,  92,  93,  94,
         97,  98, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 122, 136, 144, 157, 167, 172, 178, 184, 190,
        195, 212, 213, 214, 215, 217, 220, 223, 225, 226, 231, 232, 241, 246,
        247, 248, 249, 252, 256, 257, 263, 268, 269, 270, 272, 279, 280, 281,
        284, 285, 288, 290, 291, 293, 295, 297, 298, 299, 301, 305])
Indices per class (0, 1): tensor([ 23,  26,  27,  28,  29,  41,  42,  43,  53,  57,  58,  59,  60,  61,
         62,  63,  64,  65,  66,  79,  99, 117, 118, 119, 120, 121, 123, 124,
        125, 127, 128

In [31]:
"""
TODO: 
THIS IS NOT WORKING YET
PROBABLY BECAUSE OF PROBLEM WITH UNEQUAL CLASS SIZES WITHIN A TASK
"""

dl = data.DataLoader(
    dataset, 
    batch_sampler=s,
    collate_fn=s.get_collate_fn()
)
for i, (x, labels) in enumerate(dl):
    print("Batch {}: len={}, label={}".format(i, len(x), labels))

IndexError: index 184 is out of bounds for dimension 0 with size 171

Note that this sampler eventually allows the sampling of batches where we use a class twice due to a simpler shuffling function (`shuffle_data`). In other words, during training or validation, if we sample batches for a 5-class 4-shot training setting, it can happen that we get a batch where 2 of the 5 classes are identical. 

Since the NLP classification tasks only have a few classes, this happens very often. It does not constitute any issue if the code for the meta-learning methods support variable number of classes and shots per class. Nonetheless, for our NLP metalearner, we choose to set shuffle to False, and N-WAY to the max number of classes, so that each class is covered in each episode. 

For our experiments, we will use a training setting with 2/3-classes per task, and try out various options for K_SHOT, e.g. 4, 8, 16. This means that each support set contains 2/3 classes with 4/8/16 examples each, i.e., between 8-48 samples overall. 

For simplicity, we implemented the sampling of a support and query set as sampling a support set with twice the number of examples. After sampling a batch from the data loader, we need to split it into a support and query set. We can summarize this step in the following function:

In [32]:
def split_batch(inputs, targets):
    # Split inputs and targets in two batches.
    # Format needs to match with requirements of adaptorfusion model
    
    support_input_ids, query_input_ids = inputs["input_ids"].chunk(2, dim=0)
    support_token_type_ids, query_token_type_ids = inputs["token_type_ids"].chunk(2, dim=0)
    support_attention_mask, query_attention_mask = inputs["attention_mask"].chunk(2, dim=0)
    support_inputs = {
        "input_ids": support_input_ids,
        "token_type_ids": support_token_type_ids,
        "attention_mask": support_attention_mask
    } 
    query_inputs = {
        "input_ids": query_input_ids,
        "token_type_ids": query_token_type_ids,
        "attention_mask": query_attention_mask
    } 
    support_targets, query_targets = targets.chunk(2, dim=0)
    return support_inputs, query_inputs, support_targets, query_targets


In [33]:
class ProtoNet:

    @staticmethod
    def calculate_prototypes(features, targets):
        # Given a stack of features vectors and labels, return class prototypes
        # features - shape [N, proto_dim], targets - shape [N]
        classes, _ = torch.unique(targets).sort()  # Determine which classes we have
        prototypes = []
        for c in classes:
            p = features[torch.where(targets == c)[0]].mean(dim=0)  # Average class feature vectors
            prototypes.append(p)
        prototypes = torch.stack(prototypes, dim=0)
        # Return the 'classes' tensor to know which prototype belongs to which class
        return prototypes, classes



## MAML and ProtoMAML

The second meta-learning algorithm we will look at is MAML, short for Model-Agnostic Meta-Learning. MAML is an optimization-based meta-learning algorithm, which means that it tries to adjust the standard optimization procedure to a few-shot setting. The idea of MAML is relatively simple: given a model, support, and query set during training, we optimize the model for $m$ steps on the support set and evaluate the gradients of the query loss with respect to the original model's parameters. For the same model, we do it for a few different support-query sets and accumulate the gradients. This results in learning a model that provides a good initialization for being quickly adapted to the training tasks. If we denote the model parameters with $\theta$, we can visualize the procedure as follows (Figure credit - [Finn et al.](http://proceedings.mlr.press/v70/finn17a.html)).


The full algorithm of MAML is therefore as follows. At each training step, we sample a batch of tasks, i.e., a batch of support-query set pairs. For each task $\mathcal{T}_i$, we optimize a model $f_{\theta}$ on the support set via SGD, and denote this model as $f_{\theta_i'}$. We refer to this optimization as _inner loop_. Using this new model, we calculate the gradients of the original parameters, $\theta$, with respect to the query loss on $f_{\theta_i'}$. These gradients are accumulated over all tasks and used to update $\theta$. This is called _outer loop_ since we iterate over tasks. The full MAML algorithm is summarized below (Figure credit - [Finn et al.](http://proceedings.mlr.press/v70/finn17a.html)).

To obtain gradients for the initial parameters $\theta$ from the optimized model $f_{\theta_i'}$, we actually need second-order gradients, i.e. gradients of gradients, as the support set gradients depend on $\theta$ as well. This makes MAML computationally expensive, especially when using multiple inner loop steps. A simpler, yet almost equally well-performing alternative is First-Order MAML (FOMAML) which only uses first-order gradients. This means that the second-order gradients are ignored, and we can calculate the outer loop gradients (line 10 in algorithm 2) simply by calculating the gradients with respect to $\theta_i'$ and use those as an update to $\theta$. Hence, the new update rule becomes:

$$
\theta\leftarrow\theta-\beta\sum_{\mathcal{T}_i\sim p(\mathcal{T})}\nabla_{\theta_i'}\mathcal{L}_{\mathcal{T}_i}(f_{\theta_i'})
$$

Note the change of $\theta$ to $\theta_i'$ for $\nabla$.

### ProtoMAML

A problem of MAML is how to design the output classification layer. In case all tasks have a different number of classes, we need to initialize the output layer with zeros or randomly in every iteration. Even if we always have the same number of classes, we just start from random predictions. This requires several inner loop steps to reach a reasonable classification result. To overcome this problem, Triantafillou et al. (2020) propose to combine the merits of Prototypical Networks and MAML. Specifically, we can use prototypes to initialize our output layer to have a strong initialization. Thereby, it can be shown that the softmax over euclidean distances can be reformulated as a linear layer with softmax. To see this, let's first write out the negative Euclidean distance between a feature vector $f_{\theta}(\mathbf{x}^{*})$ of a new data point $\mathbf{x}^{*}$ to a prototype $\mathbf{v}_c$ of class $c$:

$$
-||f_{\theta}(\mathbf{x}^{*})-\mathbf{v}_c||^2=-f_{\theta}(\mathbf{x}^{*})^Tf_{\theta}(\mathbf{x}^{*})+2\mathbf{v}_c^{T}f_{\theta}(\mathbf{x}^{*})-\mathbf{v}_c^T\mathbf{v}_c
$$

We perform the classification across all classes $c\in\mathcal{C}$ and take a softmax on the distance. Hence, any term that is the same for all classes can be removed without changing the output probabilities. In the equation above, this is true for $-f_{\theta}(\mathbf{x}^{*})^Tf_{\theta}(\mathbf{x}^{*})$ since it is independent of any class prototype. Thus, we can write:

$$
-||f_{\theta}(\mathbf{x}^{*})-\mathbf{v}_c||^2=2\mathbf{v}_c^{T}f_{\theta}(\mathbf{x}^{*})-||\mathbf{v}_c||^2+\text{constant}
$$

Taking a second look at the equation above, it looks a lot like a linear layer. For this, we use $\mathbf{W}_{c,\cdot}=2\mathbf{v}_c$ and $b_c=-||\mathbf{v}_c||^2$ which gives us the linear layer $\mathbf{W}f_{\theta}(\mathbf{x}^{*})+\mathbf{b}$. Hence, if we initialize the output weight with twice the prototypes, and the biases by the negative squared L2 norm of the prototypes, we start with a Prototypical Network. MAML allows us to adapt this layer and the rest of the network further. 

In the following, we will implement First-Order ProtoMAML for few-shot classification. The implementation of MAML would be the same except for the output layer initialization. 

### ProtoMAML implementation

For implementing ProtoMAML, we can follow Algorithm 2 with minor modifications. At each training step, we first sample a batch of tasks, and a support and query set for each task. In our case of few-shot classification, this means that we simply sample multiple support-query set pairs from our sampler. For each task, we finetune our current model on the support set. However, since we need to remember the original parameters for the other tasks, the outer loop gradient update, and future training steps, we need to create a copy of our model and finetune only the copy. We can copy a model by using standard Python functions like `deepcopy`. The inner loop is implemented in the function `adapt_few_shot` in the PyTorch Lightning module below. 

After finetuning the model, we apply it to the query set and calculate the first-order gradients with respect to the original parameters $\theta$. In contrast to simple MAML, we also have to consider the gradients with respect to the output layer initialization, i.e. the prototypes, since they directly rely on $\theta$. To realize this efficiently, we take two steps. First, we calculate the prototypes by applying the original model, i.e. not the copied model, on the support elements. When initializing the output layer, we detach the prototypes to stop the gradients. This is because, in the inner loop itself, we do not want to consider gradients through the prototypes back to the original model. However, after the inner loop is finished, we re-attach the computation graph of the prototypes by writing `output_weight = (output_weight - init_weight).detach() + init_weight`. While this line does not change the value of the variable `output_weight`, it adds its dependency on the prototype initialization `init_weight`. Thus, if we call `.backward` on `output_weight`, we will automatically calculate the first-order gradients with respect to the prototype initialization in the original model.

After calculating all gradients and summing them together in the original model, we can take a standard optimizer step. PyTorch Lightning's method is however designed to return a loss-tensor on which we call `.backward` first. Since this is not possible here, we need to perform the optimization step ourselves. All details can be found in the code below.

For implementing (Proto-)MAML with second-order gradients, it is recommended to use libraries such as [$\nabla$higher](https://github.com/facebookresearch/higher) from Facebook AI Research. For simplicity, we stick with first-order methods here.

In [34]:
# Mock model for testing purposes --> Replace with adaptor fusion model

from transformers import BertForSequenceClassification

def get_transformer_model(output_size):
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=output_dim)
    for param in model.bert.bert.parameters():
        param.requires_grad = False
    return model


In [35]:
class ProtoMAML(pl.LightningModule):
    
    def __init__(self, proto_dim, lr, lr_inner, lr_output, num_inner_steps):
        """
        Inputs
            proto_dim - Dimensionality of prototype feature space
            lr - Learning rate of the outer loop Adam optimizer
            lr_inner - Learning rate of the inner loop SGD optimizer
            lr_output - Learning rate for the output layer in the inner loop
            num_inner_steps - Number of inner loop updates to perform
        """
        super().__init__()
        self.save_hyperparameters()
        self.model = get_transformer_model(output_size=self.hparams.proto_dim)
        
        
    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[140,180], gamma=0.1)
        return [optimizer], [scheduler]
        
        
    def run_model(self, local_model, output_weight, output_bias, inputs, labels):
        
        # Execute a model with given output layer weights and inputs
        feats = local_model(inputs)
        preds = F.linear(feats, output_weight, output_bias)
        loss = F.cross_entropy(preds, labels)
        acc = (preds.argmax(dim=1) == labels).float()
        
        return loss, preds, acc
    
    
    def adapt_few_shot(self, support_inputs, support_targets):
        
        # Determine prototype initialization
        support_feats = self.model(support_inputs)
        prototypes, classes = ProtoNet.calculate_prototypes(support_feats, support_targets)
        support_labels = (classes[None,:] == support_targets[:,None]).long().argmax(dim=-1)
        
        # Create inner-loop model and optimizer
        local_model = deepcopy(self.model)
        local_model.train()
        local_optim = optim.SGD(local_model.parameters(), lr=self.hparams.lr_inner)
        local_optim.zero_grad()
        
        # Create output layer weights with prototype-based initialization
        init_weight = 2 * prototypes
        init_bias = -torch.norm(prototypes, dim=1)**2
        output_weight = init_weight.detach().requires_grad_()
        output_bias = init_bias.detach().requires_grad_()
        
        # Optimize inner loop model on support set
        for _ in range(self.hparams.num_inner_steps):
            
            # Determine loss on the support set
            loss, _, _ = self.run_model(local_model, output_weight, output_bias, support_inputs, support_labels)
            
            # Calculate gradients and perform inner loop update
            loss.backward()
            local_optim.step()
            
            # Update output layer via SGD
            output_weight.data -= self.hparams.lr_output * output_weight.grad
            output_bias.data -= self.hparams.lr_output * output_bias.grad
            
            # Reset gradients
            local_optim.zero_grad()
            output_weight.grad.fill_(0)
            output_bias.grad.fill_(0)
            
        # Re-attach computation graph of prototypes
        output_weight = (output_weight - init_weight).detach() + init_weight
        output_bias = (output_bias - init_bias).detach() + init_bias
        
        return local_model, output_weight, output_bias, classes

    
    def outer_loop(self, batch, mode="train"):

        accuracies = []
        losses = []
        self.model.zero_grad()
        
        # Determine gradients for batch of tasks
        for task_batch in batch:
            inputs, targets = task_batch
            support_inputs, query_inputs, support_targets, query_targets = split_batch(inputs, targets)
            
            # Perform inner loop adaptation
            local_model, output_weight, output_bias, classes = self.adapt_few_shot(support_inputs, support_targets)
            
            # Determine loss of query set
            query_labels = (classes[None,:] == query_targets[:,None]).long().argmax(dim=-1)
            loss, preds, acc = self.run_model(local_model, output_weight, output_bias, query_inputs, query_labels)
            
            # Calculate gradients for query set loss
            if mode == "train":
                loss.backward()

                for p_global, p_local in zip(self.model.parameters(), local_model.parameters()):
                    
                    # First-order approx. -> add gradients of finetuned and base model
                    p_global.grad += p_local.grad  
            
            accuracies.append(acc.mean().detach())
            losses.append(loss.detach())
        
        # Perform update of base model
        if mode == "train":
            opt = self.optimizers()
            opt.step()
            opt.zero_grad()
        
        self.log(f"{mode}_loss", sum(losses) / len(losses))
        self.log(f"{mode}_acc", sum(accuracies) / len(accuracies))

        
    def training_step(self, batch, batch_idx):
        self.outer_loop(batch, mode="train")
        return None  # Returning None means we skip the default training optimizer steps by PyTorch Lightning

    
    def validation_step(self, batch, batch_idx):
        # Validation requires to finetune a model, hence we need to enable gradients
        torch.set_grad_enabled(True)
        self.outer_loop(batch, mode="val")
        torch.set_grad_enabled(False)

### Training

To train ProtoMAML, we need to change our sampling slightly. Instead of a single support-query set batch, we need to sample multiple. To implement this, we yet use another Sampler that combines multiple batches from a `FewShotBatchSampler` and returns it afterward. Additionally, we define a `collate_fn` for our data loader which takes the stack of support-query set images and returns the tasks as a list. This makes it easier to process in our PyTorch Lightning module before. The implementation of the sampler can be found below. 

The creation of the data loaders is with this sampler straight-forward. Note that since many samples need to loaded for a training batch, it is recommended to use less workers than usual.

In [48]:
def train_model(model_class, train_loader, val_loader, **kwargs):

    trainer = pl.Trainer(
        default_root_dir=os.path.join(CHECKPOINT_PATH, model_class.__name__),
        gpus=1 if torch.cuda.is_available() else 0,
        max_epochs=200,
        callbacks=[
            ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
            LearningRateMonitor("epoch")
        ],
        progress_bar_refresh_rate=0)
    trainer.logger._default_hp_metric = None

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, model_class.__name__ + ".ckpt")
    if os.path.isfile(pretrained_filename):
        print(f"Found pretrained model at {pretrained_filename}, loading...")
        # Automatically loads the model with the saved hyperparameters
        model = model_class.load_from_checkpoint(pretrained_filename)
    else:
        pl.seed_everything(42)  # To be reproducable
        model = model_class(**kwargs)
        trainer.fit(model, train_loader, val_loader)
        # Load best checkpoint after training
        model = model_class.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)  

    return model

In [50]:
# Training constant
N_WAY = 3     # All tasks have 2 or 3 classes, so set to 3 to ensure all classes are covered in an episode
K_SHOT = 4

# Training set
train_protomaml_sampler = TaskBatchSampler(
    train_set.tasks,
    train_set.labels, 
    include_query=True,
    N_way=N_WAY,
    K_shot=K_SHOT,
    batch_size=16,
    shuffle=False   # Set to False, otherwise you risk getting same class twice in dataset
)
train_protomaml_loader = data.DataLoader(
    train_set, 
    batch_sampler=train_protomaml_sampler,
    collate_fn=train_protomaml_sampler.get_collate_fn(),
    num_workers=2
)

# Validation set
val_protomaml_sampler = TaskBatchSampler(
    val_set.tasks,
    val_set.labels,
    include_query=True,
    N_way=N_WAY,
    K_shot=K_SHOT,
    batch_size=1,  # We do not update the parameters, hence the batch size is irrelevant here
    shuffle=False
)
val_protomaml_loader = data.DataLoader(
    val_set, 
    batch_sampler=val_protomaml_sampler,
    collate_fn=val_protomaml_sampler.get_collate_fn(),
    num_workers=2
)

Indices per class (0, 0): tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  24,  25,  30,  31,  32,
         34,  35,  36,  37,  38,  39,  40,  44,  45,  46,  47,  48,  49,  50,
         51,  52,  54,  55,  56,  68,  69,  70,  71,  72,  73,  74,  75,  77,
         78,  80,  81,  82,  83,  85,  87,  88,  89,  90,  91,  92,  93,  94,
         97,  98, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 122, 136, 144, 157, 167, 172, 178, 184, 190,
        195, 212, 213, 214, 215, 217, 220, 223, 225, 226, 231, 232, 241, 246,
        247, 248, 249, 252, 256, 257, 263, 268, 269, 270, 272, 279, 280, 281,
        284, 285, 288, 290, 291, 293, 295, 297, 298, 299, 301, 305])
Indices per class (0, 1): tensor([ 23,  26,  27,  28,  29,  41,  42,  43,  53,  57,  58,  59,  60,  61,
         62,  63,  64,  65,  66,  79,  99, 117, 118, 119, 120, 121, 123, 124,
        125, 127, 128

Now, we are ready to train our ProtoMAML. We use the same feature space size as for ProtoNet but can use a higher learning rate since the outer loop gradients are accumulated over 16 batches. The inner loop learning rate is set to 0.1, which is much higher than the outer loop learning rate because we use SGD in the inner loop instead of Adam. Commonly, the learning rate for the output layer is higher than the base model if the base model is very deep or pre-trained. However, for our setup, we observed no noticeable impact of using a different learning rate than the base model. The number of inner loop updates is another crucial hyperparameter and depends on the similarity of our training tasks. Since all tasks are on images from the same dataset, we notice that a single inner loop update achieves similar performance as 3 or 5 while training considerably faster. However, especially in RL and NLP, a larger number of inner loop steps are often needed.

In [None]:
protomaml_model = train_model(
    ProtoMAML, 
    proto_dim=64, 
    lr=1e-3, 
    lr_inner=0.1,
    lr_output=0.1,
    num_inner_steps=1,  # Often values between 1 and 10
    train_loader=train_protomaml_loader, 
    val_loader=val_protomaml_loader
)

### Testing

We test ProtoMAML in the same manner as ProtoNet, namely by picking random examples in the test set as support sets and use the rest of the dataset as the query set. Instead of just calculating the prototypes for all examples, we need to finetune a separate model for each support set. This is why this process is more expensive than ProtoNet, and in our case, testing $k=\{2,4,8,16,32\}$ can take almost an hour. Hence, we provide evaluation files besides the pretrained models.

In [None]:
def test_protomaml(model, dataset, k_shot=4):
    
    pl.seed_everything(42)
    model = model.to(device)
    num_classes = dataset.targets.unique().shape[0]
    exmps_per_class = dataset.targets.shape[0]//num_classes
    
    # Data loader for full test set as query set
    full_dataloader = data.DataLoader(
        dataset, 
        batch_size=128, 
        num_workers=4, 
        shuffle=False, 
        drop_last=False
    )
    # Data loader for sampling support sets
    sampler = FewShotBatchSampler(
        dataset.targets, 
        include_query=False,
        N_way=num_classes,
        K_shot=k_shot,
        shuffle=False,
        shuffle_once=False
    )
    sample_dataloader = data.DataLoader(dataset, batch_sampler=sampler, num_workers=2)
    
    # We iterate through the full dataset in two manners. First, to select the k-shot batch. 
    # Second, the evaluate the model on all other examples
    accuracies = []
    for (support_imgs, support_targets), support_indices in tqdm(zip(sample_dataloader, sampler), "Performing few-shot finetuning"):
        support_imgs = support_imgs.to(device)
        support_targets = support_targets.to(device)
        
        # Finetune new model on support set
        local_model, output_weight, output_bias, classes = model.adapt_few_shot(support_imgs, support_targets)
        with torch.no_grad():  # No gradients for query set needed
            local_model.eval()
            batch_acc = torch.zeros((0,), dtype=torch.float32, device=device)
            
            # Evaluate all examples in test dataset
            for query_imgs, query_targets in full_dataloader:
                query_imgs = query_imgs.to(device)
                query_targets = query_targets.to(device)
                query_labels = (classes[None,:] == query_targets[:,None]).long().argmax(dim=-1)
                _, _, acc = model.run_model(local_model, output_weight, output_bias, query_imgs, query_labels)
                batch_acc = torch.cat([batch_acc, acc.detach()], dim=0)
                
            # Exclude support set elements
            for s_idx in support_indices:
                batch_acc[s_idx] = 0
            batch_acc = batch_acc.sum().item() / (batch_acc.shape[0] - len(support_indices))
            accuracies.append(batch_acc)
            
    return mean(accuracies), stdev(accuracies)

In contrast to training, it is recommended to use many more inner loop updates during testing. During training, we are not interested in getting the best model from the inner loop, but the model which can provide the best gradients. Hence, one update might be already sufficient in training, but for testing, it was often observed that a larger number of updates can give a considerable performance boost. Thus, we change the inner loop updates to 200 before testing.

In [None]:
protomaml_model.hparams.num_inner_steps = 200

Now, we can test our model. For the pre-trained models, we provide a json file with the results to reduce evaluation time.

In [None]:
protomaml_result_file = os.path.join(CHECKPOINT_PATH, "protomaml_fewshot.json")

if os.path.isfile(protomaml_result_file):
    # Load pre-computed results
    with open(protomaml_result_file, 'r') as f:
        protomaml_accuracies = json.load(f)
    protomaml_accuracies = {int(k): v for k, v in protomaml_accuracies.items()}
else:
    # Perform experiments
    protomaml_accuracies = dict()
    for k in [2, 4, 8, 16, 32]:
        protomaml_accuracies[k] = test_protomaml(protomaml_model, test_set, k_shot=k)
    # Export results
    with open(protomaml_result_file, 'w') as f:
        json.dump(protomaml_accuracies, f, indent=4)

for k in protomaml_accuracies:
    print(f"Accuracy for k={k}: {100.0*protomaml_accuracies[k][0]:4.2f}% (+-{100.0*protomaml_accuracies[k][1]:4.2f