# Lab 7 - few-shot learning and hypernetworks

Plan for today:
* learn about the concept of few-shot learning
* familiarize ourselves with hypernetworks
* connect those two concepts by implementing a technique from [this paper](https://arxiv.org/pdf/1706.03466.pdf).

In [3]:
!pip install learn2learn


In [50]:
from typing import Tuple

import learn2learn as l2l
import matplotlib.pyplot as plt
import torch
from learn2learn.data import MetaDataset, TaskDataset
from learn2learn.vision.models import OmniglotCNN
from torch import nn
from torchvision import transforms as T
from torchvision.datasets import Omniglot, EMNIST


## Few-shot learning

In general, neural networks require huge amounts of data to train well. Few-shot learning (FSL) techniques aim to construct models that can quickly adapt to numerous **tasks** given only a limited amount of data.

One of the most popular use cases for FSL is image classification. We define $K$-shot, $N$-way classification as the task of distinguishing between $N$ classes based on $K$ labeled examples per class; these $N \cdot K$ examples form the **support set**. The model is then asked to classify the **query set** of previously unseen images, which belong to the same $N$ classes.

During training, we construct **tasks** consisting of support and query examples from a set of training classes and task the model with adapting to those tasks.

We evaluate the model on tasks sampled from a set of classes **separate from the training set** - after all, we want to measure how well the model adapts to previously unseen tasks!
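
To make the support/query terminology concrete, here is a minimal sketch of splitting one task into its two parts. The helper name `split_task` and the random stand-in data are assumptions made for illustration; they are not part of learn2learn.

```python
import torch


def split_task(X, y, n_way, n_shot, n_query):
    """Split a flat task batch into (X_support, y_support, X_query, y_query).

    Assumes X holds exactly n_way * (n_shot + n_query) examples and that
    every class appears n_shot + n_query times.
    """
    # Group examples by label so that each class occupies one contiguous block.
    order = torch.argsort(y)
    X, y = X[order], y[order]

    per_class = n_shot + n_query
    X = X.reshape(n_way, per_class, *X.shape[1:])
    y = y.reshape(n_way, per_class)

    # The first n_shot examples of every class go to the support set,
    # the remaining n_query examples go to the query set.
    X_support = X[:, :n_shot].reshape(n_way * n_shot, *X.shape[2:])
    X_query = X[:, n_shot:].reshape(n_way * n_query, *X.shape[2:])
    y_support = y[:, :n_shot].reshape(-1)
    y_query = y[:, n_shot:].reshape(-1)
    return X_support, y_support, X_query, y_query


# Random stand-in for a 5-way task with 5 support and 15 query images per class.
X = torch.randn(5 * 20, 1, 28, 28)
y = torch.arange(5).repeat_interleave(20)
X_support, y_support, X_query, y_query = split_task(X, y, n_way=5, n_shot=5, n_query=15)
print(X_support.shape, X_query.shape)
```

Which examples of a class end up in the support set is arbitrary here; any fixed split works as long as support and query examples do not overlap.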



One of the most popular datasets for FSL is Omniglot. 

### Task for you - import the Omniglot dataset from the [learn2learn](http://learn2learn.net/) package and visualize an example task:
* sample a single task from the tasksets
* draw a grid with images and their classes
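
One possible way to draw such a grid is sketched below. `plot_task` is a hypothetical helper, and the random tensors only stand in for the `X`, `y` pair sampled from `tasksets.train`; treat it as a starting point rather than the expected solution.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for scripts; drop this line inside a notebook
import matplotlib.pyplot as plt
import torch


def plot_task(X, y, n_cols=10):
    """Draw every image of a task in a grid, titled with its class label.

    Assumes X has shape [n_examples, 1, H, W] and y has shape [n_examples].
    """
    n = X.shape[0]
    n_rows = (n + n_cols - 1) // n_cols  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols, 1.2 * n_rows), squeeze=False)
    for i, ax in enumerate(axes.flat):
        ax.axis("off")
        if i < n:
            ax.imshow(X[i, 0].numpy(), cmap="gray")
            ax.set_title(str(int(y[i])), fontsize=8)
    return fig


# Random stand-in data: 5 classes with 5 images each.
X = torch.rand(25, 1, 28, 28)
y = torch.arange(5).repeat_interleave(5)
fig = plot_task(X, y, n_cols=5)
```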

In [51]:
shots = 5
queries = 15
ways = 5


tasksets = l2l.vision.benchmarks.get_tasksets(
    'omniglot',
    train_ways=ways,
    train_samples=shots + queries,
    test_ways=ways,
    test_samples=shots + queries,
    num_tasks=20000,
    root='~/data',
)

for X, y in tasksets.train:
  break

# your code here - visualize the examples from X


## Hypernetworks

Hypernetworks are models that, given some conditioning input, predict the weights of another neural network, which in turn performs the downstream task. The concept has been used in many areas, such as generative models, point clouds, conditional flows, and few-shot learning.
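
As a minimal illustration of the idea, the sketch below uses a single linear layer as the hypernetwork: it maps a condition vector to the weight and bias of a second linear layer, which is then applied to the input. The class name `TinyHypernet` and all sizes are assumptions chosen for the example.

```python
import torch
from torch import nn


class TinyHypernet(nn.Module):
    """Predicts the weight and bias of a linear layer from a condition vector."""

    def __init__(self, cond_dim, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # One output per predicted parameter: the weight matrix plus the bias.
        self.generator = nn.Linear(cond_dim, out_features * in_features + out_features)

    def forward(self, condition, x):
        params = self.generator(condition)
        n_weight = self.out_features * self.in_features
        weight = params[:n_weight].reshape(self.out_features, self.in_features)
        bias = params[n_weight:]
        # The target layer has no parameters of its own; it uses the predicted ones.
        return nn.functional.linear(x, weight, bias)


torch.manual_seed(0)
hyper = TinyHypernet(cond_dim=8, in_features=4, out_features=3)
condition = torch.randn(8)
x = torch.randn(10, 4)
out = hyper(condition, x)
print(out.shape)  # torch.Size([10, 3])
```

Only the generator is trained: gradients of the downstream loss flow through the predicted weights back into the hypernetwork.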


## Bringing it all together - Parameter Prediction from Activations

Today, we will use hypernetworks for FSL. We will base our solution on the paper [Few-Shot Image Recognition by Predicting Parameters from Activations](https://arxiv.org/pdf/1706.03466.pdf).

The PPA model consists of:
* a convolutional backbone
* a parameter prediction hypernet

First, we process the support and query samples through the backbone and obtain embeddings of size $E$. Next, we want to predict the weights of the classifier, which maps these embeddings to scores over $C$ classes. The classifier is therefore a linear layer with a weight matrix of shape $(E, C)$.

We can predict the weights of the classifier in several ways:
* concatenate all of the support embeddings and predict all of the classifier parameters
* predict, for each class $c$, the *portion* of parameters of shape $(E, 1)$ responsible for that class, based only on the support embeddings from class $c$. Then, concatenate all $C$ portions into weights of shape $(E, C)$.
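
The shape bookkeeping of the two options above can be sketched on random embeddings. This is an illustration under assumptions, not the lab solution: the predictors are single linear layers, labels are assumed to lie in $0, \dots, C-1$, and averaging the support embeddings of a class before predicting its column is just one possible design choice.

```python
import torch
from torch import nn

n_way, n_shot, emb_dim = 5, 5, 64
torch.manual_seed(0)

# Stand-ins for the backbone outputs of one task.
support_emb = torch.randn(n_way * n_shot, emb_dim)
support_labels = torch.arange(n_way).repeat_interleave(n_shot)
query_emb = torch.randn(15, emb_dim)

# Option 1: predict all classifier weights from all support embeddings at once.
all_predictor = nn.Linear(n_way * n_shot * emb_dim, emb_dim * n_way)
# Sort by label so the concatenation order is the same for every task.
order = torch.argsort(support_labels)
W_all = all_predictor(support_emb[order].reshape(-1)).reshape(emb_dim, n_way)

# Option 2: predict one column per class from that class's embeddings only.
class_predictor = nn.Linear(emb_dim, emb_dim)
columns = []
for c in range(n_way):
    class_mean = support_emb[support_labels == c].mean(dim=0)  # shape (emb_dim,)
    columns.append(class_predictor(class_mean))
W_per_class = torch.stack(columns, dim=1)  # shape (emb_dim, n_way)

# Either weight matrix turns embeddings into class logits.
logits_all = query_emb @ W_all
logits_per_class = query_emb @ W_per_class
print(W_all.shape, W_per_class.shape, logits_per_class.shape)
```

Note that the input size of the first predictor grows with both the number of ways and the number of shots, while the second predictor is independent of both.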



### Task for you - implement the few-shot hypernetwork
* implement two variants of classifier generation:
  * generating **all** weights based on **all** support class embeddings
  * generating **weight fragments** responsible for predicting a single class based solely on the support embeddings of that class


In [32]:
class Hypernet(nn.Module):
  def __init__(self, n_shot: int, n_way: int, hidden_size: int = 64, weights_per_class: bool = True):
    super().__init__()
    self.cnn = OmniglotCNN(hidden_size=hidden_size).features 
    # a convolutional net which transforms an image of shape (1, 28, 28) to vectors of shape `hidden_size`

    self.weight_predictor = ... 

  def forward(
      self, 
      support_examples: torch.Tensor, 
      support_labels: torch.Tensor,
      query_examples: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    support_examples: [n_shot * n_way, 1, 28, 28]
    support_labels: [n_shot * n_way]
    query_examples: [n_query, 1, 28, 28]

    Returns a tuple of logits:
      (y_pred_support, y_pred_query)
      of shapes:
      (
        [n_shot * n_way, n_way],
        [n_query, n_way]        
      ) 
    """

    # 1: process the supports and queries through the cnn
    # 2: generate the weights of the classifier based on the support embeddings
    # 3: classify the support and query embeddings with the generated weights




In [None]:
shots = 5
queries = 15
ways = 5


tasksets = l2l.vision.benchmarks.get_tasksets(
    'omniglot',
    train_ways=ways,
    train_samples=shots + queries,
    test_ways=ways,
    test_samples=shots + queries,
    num_tasks=20000,
    root='~/data',
)

### Task for you - finish implementing the training loop:
* add the necessary training loss and optimizer parts
* track the meta-training and meta-validation losses and accuracies throughout the training epochs and plot them after the training
* train the two variants of hypernetwork on Omniglot 
* train the hypernets in two settings:
  * 1-shot, 5-way
  * 5-shot, 5-way
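
For the loss and accuracy tracking, a small helper along these lines may be useful. `loss_and_accuracy` is a hypothetical name, and cross-entropy is an assumed (though common) choice for classification logits; whether to train on the query logits only or on both support and query logits is left to you.

```python
import torch
from torch import nn


def loss_and_accuracy(logits, targets):
    """Return (cross-entropy loss tensor, accuracy as a Python float)."""
    loss = nn.functional.cross_entropy(logits, targets)
    accuracy = (logits.argmax(dim=1) == targets).float().mean().item()
    return loss, accuracy


# Toy logits for 3 examples and 2 classes.
logits = torch.tensor([[2.0, 0.0], [0.0, 3.0], [1.0, 0.5]], requires_grad=True)
targets = torch.tensor([0, 1, 1])

loss, acc = loss_and_accuracy(logits, targets)
loss.backward()  # in the training loop this is followed by optimizer.step()
print(acc)
```

In the training loop, remember to call `optimizer.zero_grad()` before each backward pass, and to wrap the meta-validation pass in `torch.no_grad()` so that no gradients are accumulated there.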

In [None]:
def train_hypernet(
    hypernet: Hypernet,
    tasksets,
    optimizer,
    num_epochs: int = 20,
    n_shot: int = shots,
    n_query: int = queries,
    n_ways: int = ways,
    img_shape = (1, 28, 28)
):

  for e in range(num_epochs):
    # meta-training:

    for X, y in tasksets.train:
      # reshape X and y to have each class in a separate row
      X = X.reshape(n_ways, n_shot+n_query, *img_shape)
      y = y.reshape(n_ways, n_shot+n_query,)

      # separate support from query
      X_support, X_query = X[:, :n_shot], X[:, n_shot:]
      y_support, y_query = y[:, :n_shot], y[: ,n_shot:]

      # re-flatten the tensors
      X_support = X_support.reshape(n_ways * n_shot, *img_shape)
      X_query = X_query.reshape(n_ways * n_query, *img_shape)
      y_support = y_support.reshape(n_ways * n_shot)
      y_query = y_query.reshape(n_ways * n_query)


      # predictions
      y_support_pred, y_query_pred = hypernet(X_support, y_support, X_query)

      # YOUR CODE HERE
      # 
      #####
      


    # meta-validation
    for X, y in tasksets.validation:
      # reshape X and y to have each class in a separate row
      X = X.reshape(n_ways, n_shot+n_query, *img_shape)
      y = y.reshape(n_ways, n_shot+n_query,)

      # separate support from query
      X_support, X_query = X[:, :n_shot], X[:, n_shot:]
      y_support, y_query = y[:, :n_shot], y[: ,n_shot:]

      # re-flatten the tensors
      X_support = X_support.reshape(n_ways * n_shot, *img_shape)
      X_query = X_query.reshape(n_ways * n_query, *img_shape)
      y_support = y_support.reshape(n_ways * n_shot)
      y_query = y_query.reshape(n_ways * n_query)

      # YOUR CODE HERE
      # 
      ####

  # plot the training / validation losses and accuracies
  

In [None]:
# initialize and train the hypernetwork

**Question for you** - which variant of the few-shot hypernetwork worked better? Why?

## Final validation

Let's validate our models on one more dataset - EMNIST - which contains digits and Latin alphabet characters.

### Task for you
* Based on the [documentation](http://learn2learn.net/tutorials/task_transform_tutorial/transform_tutorial/), prepare the EMNIST meta-dataset. Then, calculate the accuracy of the hypernetworks you've trained on tasks sampled from that dataset.

# From 27.05 - project presentations!
* Guidelines are [here](https://docs.google.com/document/d/1Xr49OjhKMTZu1Cxmz3b1exezXf9OWDgnlo3IzuYEhtw/edit)