# Few-Shot Learning For Musical Instrument ID
## PyTorch Example: Prototypical Network on TinySOL Dataset

In this coding example, we will use PyTorch to train a [Prototypical Network](foundations/approaches.md) for musical instrument identification.

Since it is easily accessible with [mirdata](https://github.com/mir-dataset-loaders/mirdata), we will use the [TinySOL dataset](https://zenodo.org/record/3632193), which includes 14 instrument classes. For a bigger challenge, try the [MedleyDB and MedleyDB 2.0 datasets](), which together contain XX instrument classes {cite}`flores2021leveraging`.

Throughout the notebook we will also use the `common` library. It is internal to this tutorial and contains utilities shared across chapters. To use it in your own code, clone it from our [GitHub repository](https://github.com/music-fsl-zsl/tutorial).

## Table of Contents
1. [Data](#data)
2. [Model](#model)
3. [Training](#training)
4. [Evaluation](#evaluation)
5. [Conclusion](#conclusion)
6. [References](#references)

## But First! Install Requirements

```python
%%capture
# data
!pip install mirdata

# models
!pip install torch
!pip install torchaudio

# audio
!pip install librosa
!pip install soundfile
```

## ... And Set a Random Seed for Reproducibility

```python
import random
import numpy as np
import torch

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
```

## 1. Data <a name="data"></a>

As mentioned above, we will be working with the [TinySOL](https://zenodo.org/record/3632193) dataset, which contains 2478 single-note samples performed by 14 different instruments.

Let's start by downloading the dataset with [mirdata](https://github.com/mir-dataset-loaders/mirdata). We will then split it into train, validation, and test sets.

```python
import mirdata

dataset = mirdata.initialize('tinysol')
dataset.download()
```

```
INFO: Downloading ['audio', 'annotations'] to /home/hugo/mir_datasets/tinysol
INFO: [audio] downloading TinySOL.tar.gz
INFO: /home/hugo/mir_datasets/tinysol/audio/TinySOL.tar.gz already exists and will not be downloaded. Rerun with force_overwrite=True to delete this file and force the download.
INFO: [annotations] downloading TinySOL_metadata.csv
INFO: /home/hugo/mir_datasets/tinysol/annotation/TinySOL_metadata.csv already exists and will not be downloaded. Rerun with force_overwrite=True to delete this file and force the download.
```


```python
from common.audio_utils import widget
example_track = dataset.choice_track()

print("=================================")
print(f"Sample: {example_track.instrument_full}")
widget(example_track.audio_path)
```

```
Sample: Clarinet in Bb
```


The next thing we need to do is create train/validation/test splits. 

Remember that in few-shot learning, we want to recognize new instruments at inference time. This means no instrument class may appear in more than one of the train, validation, and test sets.

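We'll use 80% of the instrument classes for training, 10% for validation, and 10% for testing. The train and validation counts are rounded down and the test set takes the remainder. With 14 classes, this gives 11 training, 1 validation, and 2 test instruments.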

```python
import math
import random
from typing import List, Tuple

# collect the unique instrument names, sorted so the shuffle below is reproducible
instruments = sorted({track.instrument_full for track in dataset.load_tracks().values()})

# figure out the instrument splits
random.shuffle(instruments)

def train_val_test_split(classes: List[str], ratios: Tuple[float, float, float]):
    """
    Split a list of classes into train, val, and test sets.
    The train and val counts are rounded down; the test set takes the remainder.
    """
    # compare with a tolerance, since floating-point sums are not exact (0.1 + 0.2 != 0.3)
    assert math.isclose(sum(ratios), 1.0), "Ratios must sum to 1"
    n_train = int(len(classes) * ratios[0])
    n_val = int(len(classes) * ratios[1])
    return classes[:n_train], classes[n_train:n_train+n_val], classes[n_train+n_val:]

train_insts, val_insts, test_insts = train_val_test_split(instruments, (0.8, 0.1, 0.1))

# print the instruments, as well as splits
print("=========== Instruments =============")
print(f"There are {len(instruments)} instruments in the dataset.")
print(instruments)

print("\n=========== Splits =============")
print(f"Train ({len(train_insts)}): {train_insts}")
print(f"Val ({len(val_insts)}): {val_insts}")
print(f"Test ({len(test_insts)}): {test_insts}")
```

```
There are 14 instruments in the dataset.
['French Horn', 'Violin', 'Flute', 'Contrabass', 'Trombone', 'Viola', 'Clarinet in Bb', 'Bass Tuba', 'Oboe', 'Bassoon', 'Cello', 'Trumpet in C', 'Accordion', 'Alto Saxophone']

Train (11): ['French Horn', 'Violin', 'Flute', 'Contrabass', 'Trombone', 'Viola', 'Clarinet in Bb', 'Bass Tuba', 'Oboe', 'Bassoon', 'Cello']
Val (1): ['Trumpet in C']
Test (2): ['Accordion', 'Alto Saxophone']
```
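As a quick sanity check, the sketch below re-applies the same truncating split arithmetic to 14 class names. It then checks two properties that few-shot evaluation depends on: the splits must be disjoint, and each split needs at least `n_way` classes. The placeholder class names and the `split_sizes` and `check_split` helpers are hypothetical. `n_way=4` matches the `EpisodicTinySOL` default used below.

```python
from typing import Dict, List, Tuple


def split_sizes(n_classes: int, ratios: Tuple[float, float, float]) -> Tuple[int, int, int]:
    """Class counts under int() truncation, mirroring train_val_test_split above."""
    n_train = int(n_classes * ratios[0])
    n_val = int(n_classes * ratios[1])
    return n_train, n_val, n_classes - n_train - n_val  # test takes the remainder


def check_split(splits: Dict[str, List[str]], n_way: int) -> List[str]:
    """Return the problems found in a class split (an empty list means none)."""
    problems = []
    names = list(splits)
    # few-shot evaluation requires unseen classes: no class may be in two splits
    for i, a in enumerate(names):
        for b in names[i + 1:]:
            overlap = set(splits[a]) & set(splits[b])
            if overlap:
                problems.append(f"{a}/{b} overlap: {sorted(overlap)}")
    # an n_way episode needs at least n_way distinct classes to sample from
    for name, classes in splits.items():
        if len(classes) < n_way:
            problems.append(f"{name} has {len(classes)} class(es), fewer than n_way={n_way}")
    return problems


toy_classes = [f"class_{i}" for i in range(14)]  # hypothetical stand-ins for the 14 instruments
n_tr, n_va, n_te = split_sizes(len(toy_classes), (0.8, 0.1, 0.1))
toy_splits = {
    "train": toy_classes[:n_tr],
    "val": toy_classes[n_tr:n_tr + n_va],
    "test": toy_classes[n_tr + n_va:],
}
print(n_tr, n_va, n_te)
for problem in check_split(toy_splits, n_way=4):
    print(problem)
```

This prints `11 1 2`, followed by one problem each for the validation split (1 class) and the test split (2 classes). `BaseEpisodicDataset.__getitem__` uses `random.Random.sample` to pick classes, and that method raises `ValueError` when asked for more items than the population holds. So either the split ratios or the `n_way` used for evaluation episodes has to account for this.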


Now that we have separated our instruments into train, validation, and test sets, we can write a dataset class that samples few-shot episodes from them.
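Episode sampling itself does not depend on audio. Below is a minimal sketch on toy string items that mirrors the logic of `BaseEpisodicDataset.__getitem__` below. It seeds a `random.Random` with the episode index, picks `n_way` classes, and then splits `n_support + n_query` items per class into support and query. The `sample_episode` function and the toy data are hypothetical, written only for illustration.

```python
import random
from typing import Dict, List, Tuple


def sample_episode(
    items_by_class: Dict[str, List[str]],
    n_way: int, n_support: int, n_query: int, idx: int,
) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
    """Sample one episode as (support, query) lists of (class, item) pairs."""
    rng = random.Random(idx)  # seeding with the episode index makes episode idx reproducible
    classes = rng.sample(sorted(items_by_class), n_way)  # pick n_way distinct classes
    support, query = [], []
    for c in classes:
        # draw n_support + n_query distinct items from class c, then split them
        items = rng.sample(items_by_class[c], n_support + n_query)
        support += [(c, x) for x in items[:n_support]]
        query += [(c, x) for x in items[n_support:]]
    return support, query


# hypothetical toy data: 6 classes with 30 items each
toy_data = {f"inst{k}": [f"inst{k}_note{j}" for j in range(30)] for k in range(6)}
toy_support, toy_query = sample_episode(toy_data, n_way=4, n_support=5, n_query=15, idx=0)
print(len(toy_support), len(toy_query))  # 20 60  (4 * 5 support items, 4 * 15 query items)
print(sample_episode(toy_data, 4, 5, 15, idx=0) == (toy_support, toy_query))  # True
```

The same `rng` drives both class and item selection, so episode `idx` picks the same classes and items in every epoch. In `EpisodicTinySOL`, however, the excerpt offset inside each audio file is drawn from NumPy's global generator rather than from `rng`, so the audio crops can still vary between epochs.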

```python
import random
from collections import defaultdict
from typing import Callable, Dict, List

import librosa
import mirdata
import numpy as np
import torch

class BaseEpisodicDataset(torch.utils.data.Dataset):
    """
    Base class for episodic datasets.

    This class handles the logic of sampling episodes, but leaves the actual
    loading of individual items to subclasses. Subclasses should implement the
    load_items_by_class method.
    """

    def __init__(self,
        classlist: List[str],
        n_way: int, n_support: int,
        n_query: int, n_episodes: int,
    ):
        self.classlist = classlist
        self.n_way = n_way
        self.n_support = n_support
        self.n_query = n_query
        self.n_episodes = n_episodes

    def load_items_by_class(self,
            class_name: str,
            n: int,
            rng: random.Random
        ) -> List[Dict]:
        """
        Sample n items from a class.

        implement me!
        """
        raise NotImplementedError

    def __getitem__(self, idx: int):
        """
        Sample an episode of n_way classes,
        with n_support + n_query items each.
        """
        # use the index to seed the random number generator,
        # so the same classes and tracks are picked for a given episode every time
        rng = random.Random(idx)

        # first, pick which classes to use
        # (rng.sample raises ValueError if classlist has fewer than n_way classes)
        classes = rng.sample(self.classlist, self.n_way)

        # then, sample n_support + n_query items from each class
        support = []
        query = []
        for c in classes:
            items = self.load_items_by_class(c, self.n_support + self.n_query, rng)
            support.extend(items[:self.n_support])
            query.extend(items[self.n_support:])

        return support, query

    def __len__(self):
        return self.n_episodes

def load_excerpt(audio_path: str, duration: float):
    """
    Load a random excerpt of `duration` seconds from an audio file,
    at the file's native sample rate.
    """
    audio, sr = librosa.load(audio_path, sr=None, mono=True)
    n_samples = int(duration * sr)
    # zero-pad files shorter than the requested duration
    if len(audio) < n_samples:
        audio = np.pad(audio, (0, n_samples - len(audio)))
    # randint's upper bound is exclusive, so +1 keeps a file of exactly n_samples valid
    start = np.random.randint(0, len(audio) - n_samples + 1)
    return audio[start:start + n_samples], sr

class EpisodicTinySOL(BaseEpisodicDataset):
    def __init__(self,
            dataset: "mirdata.core.Dataset",
            instruments: List[str],
            n_way: int = 4,
            n_support: int = 5,
            n_query: int = 15,
            n_episodes: int = 1000,
            duration: float = 1.0,
            sample_rate: int = 16000,
            transform: Callable = None
        ):
        super().__init__(instruments, n_way, n_support, n_query, n_episodes)
        self.dataset = dataset
        self.instruments = instruments
        self.transform = transform
        self.duration = duration
        # stored only: load_excerpt keeps each file's native sample rate (sr=None)
        self.sample_rate = sample_rate

        # organize the tracks by instrument,
        # so we can sample episodes from them
        self.tracks = defaultdict(list)
        for track in dataset.load_tracks().values():
            if track.instrument_full in instruments:
                self.tracks[track.instrument_full].append(track)

    def load_item(self,
        track: mirdata.datasets.tinysol.Track
    ) -> Dict:
        """
        Given a track object, load the audio and return a
        dictionary containing the audio and its metadata, such as
        the instrument label, audio_path, and sr.
        """
        audio, sr = load_excerpt(track.audio_path, self.duration)
        if self.transform:
            audio = self.transform(audio)

        item = {
            'audio': audio,
            'audio_path': track.audio_path,
            'sr': sr,
            'instrument': track.instrument_full,
            'label': self.instruments.index(track.instrument_full)
        }
        return item

    def load_items_by_class(self,
            classname: str, n: int,
            rng: random.Random
        ) -> List[Dict]:
        """
        Sample n items from a given class.
        """
        tracks = self.tracks[classname]
        sample = rng.sample(tracks, n)
        return [self.load_item(t) for t in sample]
```


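```python
from torch.utils.data import DataLoader

N_WAY = 4  # the EpisodicTinySOL default
BATCH_SIZE = 1
NUM_WORKERS = 0

train_dataset = EpisodicTinySOL(dataset, train_insts, n_way=N_WAY)
# rng.sample raises ValueError when a split has fewer than n_way classes
# (val and test hold 1 and 2 here), so cap n_way at the split size.
# A 1-way episode is trivial; holding out more classes gives a more meaningful evaluation.
val_dataset = EpisodicTinySOL(dataset, val_insts, n_way=min(N_WAY, len(val_insts)))
test_dataset = EpisodicTinySOL(dataset, test_insts, n_way=min(N_WAY, len(test_insts)))

train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS, shuffle=True
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS, shuffle=False
)
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS, shuffle=False
)
```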


We now have train, validation and test dataloaders, which we can use to train and evaluate our model! Let's look at a sample batch.

```python
sample_batch = next(iter(train_loader))

support, query = sample_batch

support[0]
```

```
{'audio': tensor([[-0.0519, -0.0414, -0.0291,  ...,  0.0323,  0.0186,  0.0041]]),
 'audio_path': ['/home/hugo/mir_datasets/tinysol/audio/Winds/Clarinet_Bb/ordinario/ClBb-ord-E6-mf-N-T19d.wav'],
 'sr': tensor([44100]),
 'instrument': ['Clarinet in Bb'],
 'label': tensor([6])}
```
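The nested structure above comes from PyTorch's default collation with `batch_size=1`: audio has shape `[1, n_samples]`, strings are wrapped in one-element lists, and `sr` and `label` are 1-element tensors. The minimal sketch below reproduces this with dummy items passed through `default_collate`, which is the function a `DataLoader` uses when no `collate_fn` is given. The `make_item` helper and its values are hypothetical.

```python
import numpy as np
from torch.utils.data.dataloader import default_collate


def make_item(label: int, n_samples: int = 44100) -> dict:
    """Hypothetical stand-in for one item dict returned by EpisodicTinySOL.load_item."""
    return {
        "audio": np.zeros(n_samples, dtype=np.float32),
        "audio_path": f"/path/to/item_{label}.wav",
        "sr": 44100,
        "instrument": f"instrument_{label}",
        "label": label,
    }


# one episode = (support list, query list), the shape returned by __getitem__
toy_episode = ([make_item(0), make_item(1)], [make_item(0)])
# with batch_size=1, the DataLoader collates a list holding a single episode
toy_support, toy_query = default_collate([toy_episode])

first = toy_support[0]
print(first["audio"].shape)  # torch.Size([1, 44100]): a leading batch dimension of 1
print(first["instrument"])   # ['instrument_0']: strings are gathered into a list
print(first["label"])        # tensor([0]): Python ints become a 1-D tensor
```

Because of this extra batch dimension, code that consumes an episode from these loaders will typically need to index or squeeze away the leading dimension of size 1.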

## 2. Model <a name="model"></a>

We will be using a [Prototypical Network](foundations/approaches.md) for this task.
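As a preview of what the model computes, here is a minimal sketch of the prototypical-network classification rule from Snell et al. (2017), applied to precomputed embeddings. Each class prototype is the mean of its support embeddings, and each query is scored by the negative squared Euclidean distance to every prototype. The `prototypical_logits` function, the tensor shapes, and the toy embeddings are assumptions for illustration, and the embedding backbone is not shown. The sketch also assumes episode-local labels in `0..n_way-1`. `EpisodicTinySOL` instead returns indices into the split's instrument list (e.g. `6` for Clarinet in Bb above), so those labels would need remapping per episode.

```python
import torch
import torch.nn.functional as F


def prototypical_logits(
    support: torch.Tensor,         # [n_way * n_support, dim] support embeddings
    support_labels: torch.Tensor,  # [n_way * n_support] episode-local labels in 0..n_way-1
    query: torch.Tensor,           # [n_queries, dim] query embeddings
    n_way: int,
) -> torch.Tensor:
    """Return [n_queries, n_way] logits: negative squared distance to each class prototype."""
    # prototype for class c = mean of the support embeddings labelled c
    prototypes = torch.stack(
        [support[support_labels == c].mean(dim=0) for c in range(n_way)]
    )  # [n_way, dim]
    # squared Euclidean distance from every query to every prototype
    dists = torch.cdist(query, prototypes) ** 2  # [n_queries, n_way]
    return -dists


# toy 2-way, 2-shot episode with 2-D embeddings (hypothetical values)
toy_support = torch.tensor([[0.0, 0.0], [0.0, 2.0], [4.0, 0.0], [4.0, 2.0]])
toy_support_labels = torch.tensor([0, 0, 1, 1])
toy_query = torch.tensor([[0.5, 1.0], [3.0, 1.0]])
toy_query_labels = torch.tensor([0, 1])

logits = prototypical_logits(toy_support, toy_support_labels, toy_query, n_way=2)
preds = logits.argmax(dim=-1)                     # tensor([0, 1])
loss = F.cross_entropy(logits, toy_query_labels)  # training objective on the query set
```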

Because forward passes in few-shot learning require a support set and a query set, we will 