# Sampling Few-Shot Learning Episodes

In the last chapter, we learned how to create a class-conditional dataset for few-shot learning, and we implemented our own class-conditional version of the TinySOL dataset. In this chapter, we will use PyTorch to build an episodic sampler for a [Class Conditional Dataset](/fsl-example/datasets.html), and use it to sample few-shot learning episodes.

To recap from our [foundations chapter](/foundations-fsl/foundations.md), episodic training is a technique used in few-shot learning to make effective use of a large training dataset. Each training iteration is framed as a self-contained learning task, known as an episode, which simulates a few-shot learning scenario with a small number of labeled examples for a set of classes. During episodic training, the model is presented with a new $N$-way, $K$-shot classification task at each step ($N$ classes, with $K$ labeled examples per class), and must classify the examples in the query set using only the labeled examples in the support set. This teaches the model how to learn from a small amount of data and adapt quickly to new tasks.

## Anatomy of an Episode

```{figure} ../assets/foundations/support-query.png
---
name: support-query
---
A few-shot learning episode splits data into two separate sets: the support set (the few labeled examples of novel data) and the query set (the data we want to label).
```

In few-shot learning, an episode consists of two sets of data: the support set and the query set.

- The support set contains a small number of labeled examples for each of the classes in the episode. We use the examples in the support set to guide the few-shot learning model in the classification task. 

- The query set contains additional examples for each of the classes (typically more than the support set), whose labels the model does not see when making predictions. During training, we make predictions for the examples in the query set and compute a loss between these predictions and the true labels to update the model parameters. During evaluation, we use the predictions for the query set to compute evaluation metrics for the episode.


```python
import random

import torch

from music_fsl.data import ClassConditionalDataset
import music_fsl.util as util
```

## Building an `EpisodeSampler` class

To sample few-shot learning episodes, we can take advantage of PyTorch's [`Sampler`](https://pytorch.org/docs/stable/data.html#torch.utils.data.Sampler) class, a base class that can be extended to create custom sampling strategies for a dataset. In this section, we will implement an `EpisodeSampler` class that extends `Sampler` to generate few-shot learning episodes. Note that a standard PyTorch sampler yields dataset indices, which a `DataLoader` then uses to fetch and batch the items. Our `EpisodeSampler` instead loads and collates the items itself and yields complete episodes, so in this chapter we iterate over it directly.

To create our own `EpisodeSampler`, we will need to implement the following methods:

- `__iter__`: The `__iter__` method is responsible for generating the episodes that we will use for training and evaluation. It should iterate over the episodes and yield an episode with a support and query set at each iteration.

- `__len__`: The `__len__` method is responsible for returning the total number of episodes that the sampler will generate.
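For comparison, here is a minimal sketch of the conventional `Sampler` pattern, in which `__iter__` yields dataset indices and a `DataLoader` fetches and batches the corresponding items. The `EveryOtherSampler` class and the toy `TensorDataset` are illustrative assumptions, not part of `music_fsl`.

```python
import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset


class EveryOtherSampler(Sampler):
    """Illustrative sampler that yields every other dataset index."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # A conventional sampler yields *indices*; the DataLoader loads the items.
        return iter(range(0, len(self.data_source), 2))

    def __len__(self):
        # The number of indices __iter__ produces.
        return (len(self.data_source) + 1) // 2


toy_dataset = TensorDataset(torch.arange(10))
toy_sampler = EveryOtherSampler(toy_dataset)
loader = DataLoader(toy_dataset, sampler=toy_sampler, batch_size=2)

# Each batch is a list holding one tensor, because TensorDataset items are 1-tuples.
batches = [batch[0].tolist() for batch in loader]
print(len(toy_sampler), batches)  # 5 [[0, 2], [4, 6], [8]]
```

The `EpisodeSampler` in this chapter departs from this pattern on purpose: one episode needs items from several classes, split into two sets, which does not fit a flat stream of indices.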


Let's start by writing an `__init__` method, which will be responsible for initializing the sampler. 

We'll add in the ability to specify the number of classes to sample per episode (`n_way`), the number of support examples to sample per class (`n_support`), and the number of query examples to sample per class (`n_query`). We'll also add in the ability to specify the number of episodes to sample (`n_episodes`).

```python
class EpisodeSampler(torch.utils.data.Sampler):
   """
    A sampler for few-shot learning tasks.

    Args:
        dataset (ClassConditionalDataset): The dataset to sample episodes from.
        n_way (int): The number of classes to sample per episode.
            Default: 5.
        n_support (int): The number of samples per class to use as support.
            Default: 5.
        n_query (int): The number of samples per class to use as query.
            Default: 20.
        n_episodes (int): The number of episodes to generate.
            Default: 100.
    """
    def __init__(self,
        dataset: ClassConditionalDataset, 
        n_way: int = 5, 
        n_support: int = 5,
        n_query: int = 20,
        n_episodes: int = 100,
    ):
        self.dataset = dataset

        self.n_way = n_way
        self.n_support = n_support
        self.n_query = n_query
        self.n_episodes = n_episodes
```
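The constructor above does not check its arguments, so an impossible configuration only fails later, inside `random.sample`. A minimal sketch of an up-front check follows. `validate_episode_params` is a hypothetical helper (not part of `music_fsl`), and it assumes `class_to_indices` maps each class name to a list of dataset indices, which matches how that attribute is used in `__iter__`. Since any class can be drawn for an episode, every class needs at least `n_support + n_query` examples.

```python
def validate_episode_params(class_to_indices, n_way, n_support, n_query):
    """Hypothetical helper: raise ValueError if some episode could not be sampled."""
    n_classes = len(class_to_indices)
    if n_way > n_classes:
        raise ValueError(f"n_way={n_way} exceeds the {n_classes} classes in the dataset")

    # Every class may be drawn, so every class must have enough examples.
    needed = n_support + n_query
    too_small = [c for c, idx in class_to_indices.items() if len(idx) < needed]
    if too_small:
        raise ValueError(f"classes {too_small} have fewer than {needed} examples")


# Hypothetical toy mapping from class name to dataset indices.
toy = {"Oboe": [0, 1, 2, 3], "Viola": [4, 5, 6, 7], "Flute": [8, 9]}

validate_episode_params(toy, n_way=2, n_support=1, n_query=1)  # passes silently
try:
    validate_episode_params(toy, n_way=2, n_support=2, n_query=2)
except ValueError as err:
    print(err)  # classes ['Flute'] have fewer than 4 examples
```

A check like this could be called at the end of `__init__` with `self.dataset.class_to_indices`, so that misconfiguration surfaces when the sampler is created rather than midway through training.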

## Sampling episodes: the `__iter__` method
Next, we will implement the `__iter__` method, which generates the episodes used for training and evaluation. Let's walk through the steps needed to sample a single episode.

```python
    def __iter__(self):
        """Sample an episode from the class conditional dataset"""
```

First, we need to decide which subset of the classlist will appear in the episode. We do this by sampling `n_way` classes from the classlist.

```python
        # sample the list of classes for this episode
        episode_classlist = random.sample(self.dataset.classlist, self.n_way)
```
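`random.sample` draws without replacement, so the `n_way` classes in an episode are always distinct, and their order in `episode_classlist` is what later defines each class's target index. The sketch below uses a made-up classlist and a seeded `random.Random` instance for reproducibility (an assumption; the sampler itself calls the module-level `random.sample`). Since Python 3.11, `random.sample` requires a sequence such as a list, so passing a `set` of classes raises a `TypeError`.

```python
import random

# Hypothetical classlist standing in for dataset.classlist.
classlist = ["Oboe", "Bassoon", "Viola", "Trombone", "French Horn", "Flute", "Cello"]

# A dedicated, seeded generator makes the draw reproducible.
rng = random.Random(0)
episode_classlist = rng.sample(classlist, 5)

# Sampling without replacement: 5 distinct classes, in a random order.
print(episode_classlist)
```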

Next, we need to sample the support and query sets for each class.
We can start by creating an empty list for each set and iterating through the episode's classes:


```python
        # sample the support and query sets for this episode
        support, query = [], []
        for c in episode_classlist:
```

We need to sample `n_support` and `n_query` examples for each class. Because our dataset is an instance of a [Class Conditional Dataset](/fsl-example/datasets.html), we can use the `class_to_indices` attribute to get the indices of the examples for each class. 

```python
            # grab the dataset indices for this class
            all_indices = self.dataset.class_to_indices[c]
```
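The previous chapter built `class_to_indices` inside the dataset. As a reminder of what such a mapping holds, here is a minimal sketch that groups dataset indices by class label. It assumes one label per dataset item; `build_class_to_indices` and the `labels` list are hypothetical, not the `ClassConditionalDataset` code.

```python
from collections import defaultdict


def build_class_to_indices(labels):
    """Hypothetical sketch: map each class label to the dataset indices that carry it."""
    class_to_indices = defaultdict(list)
    for idx, label in enumerate(labels):
        class_to_indices[label].append(idx)
    return dict(class_to_indices)


# Hypothetical per-item labels, where position i is dataset index i.
labels = ["Oboe", "Viola", "Oboe", "Trombone", "Viola", "Oboe"]

class_to_indices = build_class_to_indices(labels)
print(class_to_indices)  # {'Oboe': [0, 2, 5], 'Viola': [1, 4], 'Trombone': [3]}
```

Building this mapping once lets the sampler look up a class's examples directly instead of scanning the whole dataset for every episode.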

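Once we have all the indices for a given class (`c`), we sample `n_support + n_query` of them without replacement and load the corresponding items from the dataset. Note that `random.sample` raises a `ValueError` if the class has fewer than `n_support + n_query` examples.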

```python
            # sample the support and query sets for this class
            indices = random.sample(all_indices, self.n_support + self.n_query)
            items = [self.dataset[i] for i in indices]
```
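Drawing `n_support + n_query` indices in a single call and then splitting them guarantees that no example appears in both the support and the query set, whereas two separate `random.sample` calls could place the same example in both. The sketch below checks this on made-up indices for one class, using a seeded generator as an assumption for reproducibility.

```python
import random

# Hypothetical dataset indices for a single class with 30 examples.
all_indices = list(range(100, 130))
n_support, n_query = 5, 20

rng = random.Random(42)
# One draw without replacement, then a split: the two sets cannot overlap.
indices = rng.sample(all_indices, n_support + n_query)
support_idx, query_idx = indices[:n_support], indices[n_support:]

print(len(support_idx), len(query_idx))  # 5 20
print(set(support_idx).isdisjoint(query_idx))  # True
```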

Next, we add the class target to each item we sampled.

**NOTE**: the target is the index of the class within `episode_classlist`, not within the dataset's full classlist. This is important, since we will use this index later to compute the cross-entropy loss during training.

```python
            # add the class label to each item
            for item in items:
                item["target"] = torch.tensor(episode_classlist.index(c))
```
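To see why episode-relative targets matter, consider the loss. A model scores each query example against the `n_way` classes of the episode, so its logits have shape `(number of queries, n_way)`, and `torch.nn.functional.cross_entropy` expects every target to be an integer in `[0, n_way)`. A global class index from the full classlist could fall outside that range. In the sketch below, the logits are random stand-ins for model output, and the dict lookup is an alternative to `episode_classlist.index(c)` that produces the same indices.

```python
import torch
import torch.nn.functional as F

episode_classlist = ["Oboe", "Bassoon", "Viola"]  # n_way = 3

# Map each class to its position in this episode's classlist.
class_to_target = {c: i for i, c in enumerate(episode_classlist)}

# Hypothetical class names of four query items.
query_labels = ["Viola", "Oboe", "Bassoon", "Viola"]
targets = torch.tensor([class_to_target[c] for c in query_labels])

# Random stand-in for model scores: one row per query, one column per episode class.
logits = torch.randn(len(query_labels), len(episode_classlist))
loss = F.cross_entropy(logits, targets)  # mean over queries -> 0-dim tensor

print(targets.tolist(), loss.shape)  # [2, 0, 1, 2] torch.Size([])
```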

Then we split the items we sampled into support and query items.
```python
            # split the support and query sets
            support.extend(items[:self.n_support])
            query.extend(items[self.n_support:]) 
```
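Because the loop handles one class at a time, the support and query sets come out grouped by class. The sketch below traces only the targets, with small made-up sizes: the support set ends up with `n_way * n_support` items and the query set with `n_way * n_query`. With the defaults (`n_way=5`, `n_support=5`, `n_query=20`), that is 25 support items and 100 query items.

```python
n_way, n_support, n_query = 3, 2, 4

support_targets, query_targets = [], []
for target in range(n_way):  # one iteration per episode class, as in the loop above
    # Stand-ins for the targets of the n_support + n_query items sampled for this class.
    items = [target] * (n_support + n_query)
    support_targets.extend(items[:n_support])
    query_targets.extend(items[n_support:])

print(support_targets)  # [0, 0, 1, 1, 2, 2]
print(query_targets)    # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
```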

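To wrap it up, we collate the items in each set into a single dictionary, so that each field (such as the audio) becomes one batched tensor that can be processed in a single forward pass. We also attach the episode's classlist to both sets, so that each target index can be mapped back to a class name. Since the details of writing a collate function aren't covered here, we invite the reader to check out the [PyTorch data loading docs](https://pytorch.org/docs/stable/data.html#loading-batched-and-non-batched-data) for more information.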
```python
        # collate the support and query sets
        support = util.collate_list_of_dicts(support)
        query = util.collate_list_of_dicts(query)

        support["classlist"] = episode_classlist
        query["classlist"] = episode_classlist
        
        yield support, query
```
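The implementation of `util.collate_list_of_dicts` is not shown in this chapter. As a rough idea of what such a function might do, here is a minimal sketch, assuming every item is a dict with the same keys: tensor values are stacked along a new leading batch dimension, and other values (such as strings) are gathered into lists. The function name `collate_list_of_dicts_sketch` and the `instrument` key are hypothetical. PyTorch's own `default_collate` handles batches of dicts in a similar way.

```python
from typing import Any, Dict, List

import torch


def collate_list_of_dicts_sketch(items: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Hypothetical sketch of a list-of-dicts collate function.

    Assumes every item has the same keys. Tensors are stacked along a new
    batch dimension; any other value type is collected into a list.
    """
    batch = {}
    for key in items[0]:
        values = [item[key] for item in items]
        if isinstance(values[0], torch.Tensor):
            batch[key] = torch.stack(values)
        else:
            batch[key] = values
    return batch


items = [
    {"audio": torch.zeros(1, 16000), "target": torch.tensor(0), "instrument": "Oboe"},
    {"audio": torch.zeros(1, 16000), "target": torch.tensor(1), "instrument": "Viola"},
]
batch = collate_list_of_dicts_sketch(items)
print(batch["audio"].shape, batch["target"].shape)  # torch.Size([2, 1, 16000]) torch.Size([2])
```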


The full implementation of the `EpisodeSampler` class is below. It also includes the `__len__` method and a `print_episode` helper for summarizing an episode.

```python
class EpisodeSampler(torch.utils.data.Sampler):
    """
    A sampler for few-shot learning tasks.

    Args:
        dataset (ClassConditionalDataset): The dataset to sample episodes from.
        n_way (int): The number of classes to sample per episode.
            Default: 5.
        n_support (int): The number of samples per class to use as support.
            Default: 5.
        n_query (int): The number of samples per class to use as query.
            Default: 20.
        n_episodes (int): The number of episodes to generate.
            Default: 100.
    """
    def __init__(self,
        dataset: ClassConditionalDataset, 
        n_way: int = 5, 
        n_support: int = 5,
        n_query: int = 20,
        n_episodes: int = 100,
    ):
        self.dataset = dataset

        self.n_way = n_way
        self.n_support = n_support
        self.n_query = n_query
        self.n_episodes = n_episodes
    
    def __iter__(self):
        """Iterate through the episodes generated by the sampler. 

        Each episode is a tuple of two dictionaries: a support set and a query set.
        The support set contains a set of samples from each of the classes in the
        episode, and the query set contains another set of samples from each of the
        classes. The class labels are added to each item in the support and query
        sets, and the list of classes is also included in each dictionary.

        Yields:
            Tuple[Dict[str, Any], Dict[str, Any]]: A tuple containing the support
            set and the query set for an episode. Exactly `n_episodes` episodes
            are yielded, matching `__len__`.
        """
        # generate n_episodes episodes, so that iteration agrees with __len__
        for _ in range(self.n_episodes):
            # sample the list of classes for this episode
            episode_classlist = random.sample(self.dataset.classlist, self.n_way)

            # sample the support and query sets for this episode
            support, query = [], []
            for c in episode_classlist:
                # grab the dataset indices for this class
                all_indices = self.dataset.class_to_indices[c]

                # sample the support and query sets for this class
                indices = random.sample(all_indices, self.n_support + self.n_query)
                items = [self.dataset[i] for i in indices]

                # add the class label to each item
                for item in items:
                    item["target"] = torch.tensor(episode_classlist.index(c))

                # split the support and query sets
                support.extend(items[:self.n_support])
                query.extend(items[self.n_support:])

            # collate the support and query sets
            support = util.collate_list_of_dicts(support)
            query = util.collate_list_of_dicts(query)

            support["classlist"] = episode_classlist
            query["classlist"] = episode_classlist

            yield support, query

    def __len__(self):
        """Return the number of episodes the sampler generates."""
        return self.n_episodes

    def print_episode(self, support, query):
        """Print a summary of the support and query sets for an episode.

        Args:
            support (Dict[str, Any]): The support set for an episode.
            query (Dict[str, Any]): The query set for an episode.
        """
        print("Support Set:")
        print(f"  Classlist: {support['classlist']}")
        print(f"  Audio Shape: {support['audio'].shape}")
        print(f"  Target Shape: {support['target'].shape}")
        print()
        print("Query Set:")
        print(f"  Classlist: {query['classlist']}")
        print(f"  Audio Shape: {query['audio'].shape}")
        print(f"  Target Shape: {query['target'].shape}")
```

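The sampler draws from Python's global `random` state, so each run produces different episodes unless `random.seed` is called first. For evaluation it is often convenient to reuse the same episodes across runs. One option, sketched below, is to give the sampling logic its own seeded `random.Random` instance. `sample_episode_indices` is a hypothetical helper that returns only the episode classlist and dataset indices, leaving item loading and collation out.

```python
import random
from typing import Dict, List, Tuple


def sample_episode_indices(
    class_to_indices: Dict[str, List[int]],
    n_way: int,
    n_support: int,
    n_query: int,
    rng: random.Random,
) -> Tuple[List[str], List[int], List[int]]:
    """Hypothetical helper: draw one episode's classes and indices from `rng`."""
    # Sort the class names so the draw does not depend on dict insertion order.
    episode_classlist = rng.sample(sorted(class_to_indices), n_way)
    support_idx, query_idx = [], []
    for c in episode_classlist:
        indices = rng.sample(class_to_indices[c], n_support + n_query)
        support_idx.extend(indices[:n_support])
        query_idx.extend(indices[n_support:])
    return episode_classlist, support_idx, query_idx


# Hypothetical toy mapping: 4 classes with 10 dataset indices each.
names = ["Oboe", "Viola", "Flute", "Cello"]
toy = {c: list(range(i * 10, i * 10 + 10)) for i, c in enumerate(names)}

# Two generators with the same seed produce identical episodes.
ep_a = sample_episode_indices(toy, n_way=2, n_support=2, n_query=3, rng=random.Random(123))
ep_b = sample_episode_indices(toy, n_way=2, n_support=2, n_query=3, rng=random.Random(123))
print(ep_a == ep_b)  # True
```

In the sampler itself, this could correspond to a hypothetical `seed` argument that `__init__` uses to create `self.rng = random.Random(seed)`, with `self.rng.sample` replacing the calls to `random.sample`.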
## Putting it Together: Sampling an Example Episode

Super! Let's grab the class-conditional `TinySOL` dataset we created in the last chapter, and use the `EpisodeSampler` to sample an episode from it.

```python
from music_fsl.data import TinySOL

dataset = TinySOL()

# create an episode sampler over the dataset
sampler = EpisodeSampler(
    dataset,
    n_way=5,
    n_support=5,
    n_query=20,
    n_episodes=100,
)

# draw the first episode
support, query = next(iter(sampler))
```


```python
sampler.print_episode(support, query)
```

```
Support Set:
  Classlist: ['Oboe', 'Bassoon', 'Viola', 'Trombone', 'French Horn']
  Audio Shape: torch.Size([25, 1, 16000])
  Target Shape: torch.Size([25])

Query Set:
  Classlist: ['Oboe', 'Bassoon', 'Viola', 'Trombone', 'French Horn']
  Audio Shape: torch.Size([100, 1, 16000])
  Target Shape: torch.Size([100])
```


In this chapter, we learned how to create an `EpisodeSampler` class that extends PyTorch's `Sampler` class to sample few-shot learning episodes. The `EpisodeSampler` lets us specify the number of classes per episode, the number of support and query examples per class, and the number of episodes to sample. It iterates over the episodes and yields a support set and a query set for each one: the support set contains labeled examples for each of the classes in the episode, and the query set contains further examples of the same classes, whose labels the model must predict. This lets us use the `EpisodeSampler` to generate few-shot learning tasks for training and evaluation.

Next, we'll write code to create a Prototypical Network model that can be trained on the episodes generated by the `EpisodeSampler`.