In [None]:
%%capture
!pip install torch
!pip install pytorch-lightning
!pip install numpy
!pip install --no-cache-dir --upgrade music-fsl

## Variables


In [None]:
# --- audio ---
sample_rate = 16000      # sample rate of the audio

# --- episode layout ---
n_way = 5                # number of classes per episode
n_support = 5            # number of support examples per class
n_query = 20             # number of samples per class to use as query

# --- how many episodes to generate ---
n_train_episodes = 1000  # training episodes
n_val_episodes = 50      # validation episodes

# --- data loading ---
num_workers = 10         # number of workers to use for data loading

## Good_sounds

In [None]:
import os
from collections import defaultdict

import librosa
import numpy as np
import pytorch_lightning as pl
import torch
from torch import nn
from torchmetrics import Accuracy

from music_fsl.backbone import Backbone
from music_fsl.protonet import PrototypicalNet

In [None]:
import torch
from typing import List, Dict, Any

class ClassConditionalDataset(torch.utils.data.Dataset):
    """
    Interface for datasets whose examples can be looked up per class.

    Nothing is implemented here: every member raises ``NotImplementedError``
    and is meant to be overridden by a concrete dataset.
    """

    def __getitem__(self, index: int) -> Dict[Any, Any]:
        """
        Fetch the example stored at ``index``.

        Returns:
            Dict[Any, Any]: The example; implementations must return a dictionary.
        """
        raise NotImplementedError()

    @property
    def classlist(self) -> List[str]:
        """
        Labels of all the classes available in the dataset.

        Returns:
            List[str]: One entry per class label.
        """
        raise NotImplementedError()

    @property
    def class_to_indices(self) -> Dict[str, List[int]]:
        """
        Lookup table from class label to the dataset indices of that class.

        Lets a caller pick examples that belong to specific classes.

        Returns:
            Dict[str, List[int]]: Class label -> list of dataset indices.
        """
        raise NotImplementedError()

In [None]:
# from collections import defaultdict
# import mirdata
# import librosa
# import music_fsl.util as util
# from typing import List, Dict


# class GoodSounds(ClassConditionalDataset):
#     """
#     Initialize a `GoodSounds Dataset Loader` dataset instance.
    
#     Args:
#         instruments (List[str]): A list of instruments to include in the dataset.
#         duration (float): The duration of each audio clip in the dataset (in seconds).
#         sample_rate (int): The sample rate of the audio clips in the dataset (in Hz).
#         dataset - loaded mirdata.dataset
#     """

#     INSTRUMENTS = [
#         'flute', 'cello', 'clarinet', 'trumpet', 'violin', 'sax_alto', 'sax_tenor', 'sax_baritone', 'sax_soprano', 'oboe', 'piccolo', 'bass'
#     ]

#     def __init__(self, 
#             instruments: List[str] = None,
#             duration: float = 1.0, 
#             sample_rate: int = 16000,
#             dataset = None,
#         ):
#         if instruments is None:
#             instruments = self.INSTRUMENTS

#         self.instruments = instruments  
#         self.duration = duration
#         self.sample_rate = sample_rate

#         # initialize the medley_solos_db dataset and download if necessary
#         if dataset is not None:
#             self.dataset = dataset
#         else:
#           self.dataset = mirdata.initialize('medley_solos_db')
#           self.dataset.download()

#         # make sure the instruments passed in are valid
#         for instrument in instruments:
#             assert instrument in self.INSTRUMENTS, f"{instrument} is not a valid instrument"

#         # load all tracks for this instrument
#         self.tracks = []
#         i = 0
#         for track in self.dataset.load_tracks().values():
#             if track.instrument in self.instruments:
#               if librosa.get_duration(filename=track.audio_path) >= duration:
#                 self.tracks.append(track)


#     @property
#     def classlist(self) -> List[str]:
#         return self.instruments

#     @property
#     def class_to_indices(self) -> Dict[str, List[int]]:
#         # cache it in self._class_to_indices 
#         # so we don't have to recompute it every time
#         if not hasattr(self, "_class_to_indices"):
#             self._class_to_indices = defaultdict(list)
#             for i, track in enumerate(self.tracks):
#                 self._class_to_indices[track.instrument].append(i)

#         return self._class_to_indices

#     def __getitem__(self, index) -> Dict:
#         # load the track for this index
#         track = self.tracks[index]

#         # load the excerpt
#         data = util.load_excerpt(track.audio_path, self.duration, self.sample_rate)
#         data["label"] = track.instrument

#         return data

#     def __len__(self) -> int:
#         return len(self.tracks)

In [None]:
class GoodSounds(ClassConditionalDataset):
    """
    Class-conditional dataset over a local folder of Good-sounds recordings.

    ``dataset_path`` is scanned for directories whose name starts with one of
    the requested instruments, either as the first ``_``-separated token
    (``flute_...``) or as the first two tokens (``sax_alto_...``). Every
    ``.wav`` file found under such a directory's ``neumann`` sub-folder that
    lasts at least ``duration`` seconds becomes one track of the dataset.

    Args:
        instruments (List[str]): A list of instruments to include in the dataset.
            Defaults to every entry of ``INSTRUMENTS``.
        duration (float): The duration of each audio clip in the dataset (in seconds).
        sample_rate (int): The sample rate of the audio clips in the dataset (in Hz).
        dataset_path (str): Directory that contains the recording directories.

    Raises:
        AssertionError: If an entry of ``instruments`` is not in ``INSTRUMENTS``.
        ValueError: If ``dataset_path`` is None.
    """

    INSTRUMENTS = [
        'flute', 'cello', 'clarinet', 'trumpet', 'violin', 'sax_alto', 'sax_tenor', 'sax_baritone', 'sax_soprano', 'oboe', 'piccolo', 'bass'
    ]

    def __init__(self, 
            instruments: List[str] = None,
            duration: float = 1.0, 
            sample_rate: int = 16000,
            dataset_path: str = None
        ):
        if instruments is None:
            instruments = self.INSTRUMENTS

        self.instruments = instruments  
        self.duration = duration
        self.sample_rate = sample_rate
        self.dataset_path = dataset_path

        # make sure the instruments passed in are valid
        for instrument in instruments:
            assert instrument in self.INSTRUMENTS, f"{instrument} is not a valid instrument"

        # os.listdir(None) would silently scan the current working directory
        if dataset_path is None:
            raise ValueError("dataset_path is required: pass the directory that holds the recordings")

        # Each track is a [audio_path, instrument] pair. Directory and file
        # names are sorted so that a dataset index refers to the same file on
        # every machine; the seeded episode sampling relies on that.
        self.tracks = []
        for recording_dir in sorted(os.listdir(self.dataset_path)):
            instrument = self._match_instrument(recording_dir)
            if instrument is None:
                continue

            scan_root = os.path.join(self.dataset_path, recording_dir, 'neumann')
            for current_dir, subdirs, filenames in os.walk(scan_root):
                subdirs.sort()  # fixes the order in which os.walk descends
                for filename in sorted(filenames):
                    if not filename.endswith('.wav'):
                        continue
                    # join with the directory being walked, so files nested
                    # below 'neumann' get their real path
                    audio_path = os.path.join(current_dir, filename)
                    if self._clip_duration(audio_path) >= duration:
                        self.tracks.append([audio_path, instrument])

    def _match_instrument(self, dirname: str):
        """Return the requested instrument that `dirname` starts with, or None."""
        tokens = dirname.split('_')
        # try 'flute' first, then two-token names such as 'sax_alto'
        for n_tokens in (1, 2):
            if len(tokens) < n_tokens:
                break
            candidate = '_'.join(tokens[:n_tokens])
            if candidate in self.instruments:
                return candidate
        return None

    @staticmethod
    def _clip_duration(audio_path: str) -> float:
        """Return the duration in seconds of the audio file at `audio_path`."""
        try:
            return librosa.get_duration(path=audio_path)
        except TypeError:
            # librosa releases that predate the `path` keyword
            return librosa.get_duration(filename=audio_path)

    @property
    def classlist(self) -> List[str]:
        """The instruments this dataset was built with."""
        return self.instruments

    @property
    def class_to_indices(self) -> Dict[str, List[int]]:
        """Map each instrument to the indices of its tracks in the dataset."""
        # cache it in self._class_to_indices 
        # so we don't have to recompute it every time
        if not hasattr(self, "_class_to_indices"):
            self._class_to_indices = defaultdict(list)
            for i, (_, instrument) in enumerate(self.tracks):
                self._class_to_indices[instrument].append(i)

        return self._class_to_indices

    def __getitem__(self, index) -> Dict:
        """
        Load the excerpt for the track at `index` and attach its instrument
        under the "label" key.
        """
        audio_path, instrument = self.tracks[index]

        # load the excerpt
        data = util.load_excerpt(audio_path, self.duration, self.sample_rate)
        data["label"] = instrument

        return data

    def __len__(self) -> int:
        """Number of tracks that passed the duration filter."""
        return len(self.tracks)

In [None]:
import random
import torch
from music_fsl.data import ClassConditionalDataset
import music_fsl.util as util

from typing import Tuple, Dict
class EpisodeDataset(torch.utils.data.Dataset):
    """
        A dataset for sampling few-shot learning tasks from a class-conditional dataset.

    Args:
        dataset (ClassConditionalDataset): The dataset to sample episodes from.
        n_way (int): The number of classes to sample per episode.
            Default: 5.
        n_support (int): The number of samples per class to use as support.
            Default: 5.
        n_query (int): The number of samples per class to use as query.
            Default: 20.
        n_episodes (int): The number of episodes to generate.
            Default: 100.
    """
    def __init__(self,
        dataset: "ClassConditionalDataset", 
        n_way: int = 5, 
        n_support: int = 5,
        n_query: int = 20,
        n_episodes: int = 100,
    ):
        self.dataset = dataset

        self.n_way = n_way
        self.n_support = n_support
        self.n_query = n_query
        self.n_episodes = n_episodes
    
    def __getitem__(self, index: int) -> Tuple[Dict, Dict]:
        """Sample an episode from the class-conditional dataset. 

        Each episode is a tuple of two dictionaries: a support set and a query set.
        The support set contains a set of samples from each of the classes in the
        episode, and the query set contains another set of samples from each of the
        classes. The class labels are added to each item in the support and query
        sets, and the list of classes is also included in each dictionary.

        Args:
            index (int): Episode number. It seeds the sampling, so the same index
                always draws the same classes and the same dataset indices.

        Returns:
            Tuple[Dict[str, Any], Dict[str, Any]]: A tuple containing the support
            set and the query set for an episode.

        Raises:
            ValueError: If the dataset has fewer than `n_way` classes, or a sampled
                class has fewer than `n_support + n_query` examples.
        """
        # seed the random number generator so we can reproduce this episode
        rng = random.Random(index)

        # sample the list of classes for this episode
        classlist = self.dataset.classlist
        if self.n_way > len(classlist):
            raise ValueError(
                f"n_way={self.n_way} is larger than the {len(classlist)} "
                f"classes available in the dataset"
            )
        episode_classlist = rng.sample(classlist, self.n_way)

        n_per_class = self.n_support + self.n_query

        # sample the support and query sets for this episode
        support, query = [], []
        for target, c in enumerate(episode_classlist):
            # grab the dataset indices for this class
            all_indices = self.dataset.class_to_indices[c]
            if len(all_indices) < n_per_class:
                raise ValueError(
                    f"class {c!r} has {len(all_indices)} examples, but "
                    f"n_support + n_query = {n_per_class} are needed per episode"
                )

            # sample the support and query sets for this class
            indices = rng.sample(all_indices, n_per_class)
            items = [self.dataset[i] for i in indices]

            # the target is the position of the class within this episode
            for item in items:
                item["target"] = torch.tensor(target)

            # split the support and query sets
            support.extend(items[:self.n_support])
            query.extend(items[self.n_support:])

        # collate the support and query sets
        support = util.collate_list_of_dicts(support)
        query = util.collate_list_of_dicts(query)

        support["classlist"] = episode_classlist
        query["classlist"] = episode_classlist
        
        return support, query

    def __len__(self):
        """Number of episodes this dataset generates."""
        return self.n_episodes

    def print_episode(self, support, query):
        """Print a summary of the support and query sets for an episode.

        Args:
            support (Dict[str, Any]): The support set for an episode.
            query (Dict[str, Any]): The query set for an episode.
        """
        print("Support Set:")
        print(f"  Classlist: {support['classlist']}")
        print(f"  Audio Shape: {support['audio'].shape}")
        print(f"  Target Shape: {support['target'].shape}")
        print()
        print("Query Set:")
        print(f"  Classlist: {query['classlist']}")
        print(f"  Audio Shape: {query['audio'].shape}")
        print(f"  Target Shape: {query['target'].shape}")


In [None]:
# Mount Google Drive at /content/drive; the dataset paths used further down
# in this notebook point below that directory.
from google.colab import drive
drive.mount('/content/drive')

In [None]:

# Instruments used to build the training dataset.
TRAIN_INSTRUMENTS = [
    'cello',
    'clarinet',
    'violin',
    'sax_alto',
    'sax_baritone',
    'sax_soprano',
    'piccolo',
]

# Instruments used to build the validation and evaluation datasets.
TEST_INSTRUMENTS = [
    'flute',
    'trumpet',
    'sax_tenor',
    'oboe',
    'bass',
]


In [None]:
# Build the training and validation datasets. They read from the same folder
# at the same sample rate and differ only in the instruments they keep.
_good_sounds_kwargs = dict(
    sample_rate=sample_rate,
    dataset_path='/content/drive/MyDrive/good_sounds/sound_files',
)

train_data = GoodSounds(instruments=TRAIN_INSTRUMENTS, **_good_sounds_kwargs)
val_data = GoodSounds(instruments=TEST_INSTRUMENTS, **_good_sounds_kwargs)

In [None]:
# Same report for both datasets, training first and then validation:
# total size, number of classes, and the number of examples for each class.
for split_data in (train_data, val_data):
    print(f"The dataset has {len(split_data)} examples.")
    print(f"The dataset has {len(split_data.classlist)} classes.\n")

    for instrument, indices in split_data.class_to_indices.items():
        print(f"{instrument} has {len(indices)} examples")

In [None]:
# Wrap both datasets in episode samplers. The episode layout (classes per
# episode, support and query sizes) is shared; only the source dataset and the
# number of episodes differ.
_episode_layout = dict(n_way=n_way, n_support=n_support, n_query=n_query)

train_episodes = EpisodeDataset(
    dataset=train_data,
    n_episodes=n_train_episodes,
    **_episode_layout,
)
val_episodes = EpisodeDataset(
    dataset=val_data,
    n_episodes=n_val_episodes,
    **_episode_layout,
)

In [None]:
# One loader per episode dataset, configured identically.
# batch_size=None turns off automatic batching: every item of an episode
# dataset already is a complete (support, query) episode.
from torch.utils.data import DataLoader

train_loader, val_loader = (
    DataLoader(episodes, batch_size=None, num_workers=num_workers)
    for episodes in (train_episodes, val_episodes)
)

In [None]:
# build models
# the backbone is created for the notebook's sample rate and then wrapped
# by the prototypical network that is trained below
backbone = Backbone(sample_rate=sample_rate)
protonet = PrototypicalNet(backbone)


In [None]:

class FewShotLearner(pl.LightningModule):
    """
    LightningModule that trains and evaluates a prototypical network on
    (support, query) episodes.

    Args:
        protonet (nn.Module): Called as `protonet(support, query)`; its output is
            used as the logits for the query set.
        learning_rate (float): Learning rate handed to the Adam optimizer.
    """

    def __init__(self, 
        protonet: nn.Module, 
        learning_rate: float = 0.001,
    ):
        super().__init__()
        # Keep the module itself out of the saved hyperparameters: its weights
        # are part of this module's state, and hyperparameters are meant to be
        # plain values. Loading a checkpoint therefore needs `protonet=` to be
        # passed explicitly, as the evaluation section of this notebook does.
        self.save_hyperparameters(ignore=["protonet"])
        self.protonet = protonet
        self.learning_rate = learning_rate

        self.loss = nn.CrossEntropyLoss()
        self.metrics = nn.ModuleDict({
            'accuracy': Accuracy()
        })

    def configure_optimizers(self):
        """Optimize all parameters of this module with Adam."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def step(self, batch, batch_idx, tag: str):
        """
        Run one episode through the network and log the results.

        Args:
            batch: A `(support, query)` pair; `query["target"]` holds the labels
                the logits are scored against.
            batch_idx: Index of the batch (unused).
            tag (str): Suffix for the logged names, e.g. "loss/train".

        Returns:
            dict: The loss under "loss" plus one entry per metric.
        """
        support, query = batch

        logits = self.protonet(support, query)
        loss = self.loss(logits, query["target"])

        output = {"loss": loss}
        for k, metric in self.metrics.items():
            output[k] = metric(logits, query["target"])

        for k, v in output.items():
            self.log(f"{k}/{tag}", v)
        return output

    def training_step(self, batch, batch_idx):
        """Training episode; logged with the "train" tag."""
        return self.step(batch, batch_idx, "train")
    
    def validation_step(self, batch, batch_idx):
        """Validation episode; logged with the "val" tag."""
        return self.step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        """Test episode; logged with the "test" tag."""
        return self.step(batch, batch_idx, "test")

In [None]:
# wrap the prototypical network in the Lightning module (default learning rate)
learner = FewShotLearner(protonet)

## Train


In [None]:
# set up the trainer
from pytorch_lightning.loggers import TensorBoardLogger

try:
    # module path used by current Lightning releases
    from pytorch_lightning.profilers import SimpleProfiler
except ImportError:
    # older releases only provide the original module path
    from pytorch_lightning.profiler import SimpleProfiler

trainer = pl.Trainer(
    # accelerator/devices replaces the `gpus` argument, which current
    # Lightning releases no longer accept: one GPU when available, else CPU
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    max_epochs=1,
    log_every_n_steps=1, 
    # run validation every 50 training episodes
    val_check_interval=50,
    profiler=SimpleProfiler(
        filename="profile.txt",
    ), 
    # logs are written under ./logs
    logger=TensorBoardLogger(
        save_dir=".",
        name="logs"
    ), 
)

# train!
trainer.fit(learner, train_loader, val_dataloaders=val_loader)

## Save

In [None]:
# Pack the whole /content/logs directory into a single archive.
!zip -r /content/good_sounds_5_5_001.zip /content/logs

In [None]:
# Hand the archive created by the previous cell to the browser as a download.
from google.colab import files
files.download('/content/good_sounds_5_5_001.zip') 

## Tensorboard

In [None]:
# TensorBoard is needed for the extension loaded in the next cell.
!pip install tensorboard

In [None]:
# Load the TensorBoard notebook extension before launching it.
%load_ext tensorboard

In [None]:
tensorboard --logdir /content/logs

## Data visualisation

In [None]:
%%capture
# Dependencies of the evaluation section: a pinned torchmetrics and tqdm.
# NOTE(review): torchmetrics has already been imported by earlier cells, so a
# version change made here presumably only takes effect after the runtime is
# restarted — confirm, or install the pinned version at the top of the notebook.
!pip install "torchmetrics==0.10.2" 
!pip install tqdm

In [None]:
from pathlib import Path

import numpy as np
import torch
import tqdm
from torchmetrics import Accuracy

from music_fsl.util import dim_reduce, embedding_plot, batch_device

In [None]:
# Evaluation settings: the checkpoint to load, the audio sample rate, and the
# device to run on (GPU when one is available, otherwise CPU).
checkpoint_path = "/content/logs/version_0/checkpoints/epoch=0-step=1000.ckpt"
sample_rate = 16000

if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [None]:
# Rebuild the network, restore the trained weights from the checkpoint,
# switch to evaluation mode and move everything to the evaluation device.
protonet = PrototypicalNet(Backbone(sample_rate))
learner = (
    FewShotLearner
    .load_from_checkpoint(checkpoint_path, protonet=protonet)
    .eval()
    .to(DEVICE)
)

In [None]:
# Evaluation uses smaller query sets and its own number of episodes.
n_query = 15
n_episodes = 50

# the held-out instruments, read from the same folder as before
dataset = GoodSounds(
    dataset_path='/content/drive/MyDrive/good_sounds/sound_files',
    instruments=TEST_INSTRUMENTS,
    sample_rate=sample_rate,
)

# load our evaluation data
test_episodes = EpisodeDataset(
    dataset=dataset,
    n_episodes=n_episodes,
    n_way=n_way,
    n_support=n_support,
    n_query=n_query,
)

In [None]:
# Accuracy over the n_way classes of an episode. The same object is updated for
# every test episode and read once at the end with metric.compute().
metric = Accuracy(num_classes=n_way, average="samples")

In [None]:
# collect all the embeddings in the test set
# so we can plot them later
embedding_table = []
pbar = tqdm.tqdm(range(len(test_episodes)))
for episode_idx in pbar:
    support, query = test_episodes[episode_idx]

    # move all tensors to cuda if necessary
    batch_device(support, DEVICE)
    batch_device(query, DEVICE)

    # Pure inference: no gradients are needed, so skip building the autograd
    # graph for every episode.
    # The "embeddings" and "prototypes" entries read below are presumably
    # added to the two dictionaries by this call — confirm in music_fsl.
    with torch.no_grad():
        logits = learner.protonet(support, query)

    # compute the accuracy
    acc = metric(logits, query["target"])
    pbar.set_description(f"Episode {episode_idx} // Accuracy: {acc.item():.2f}")

    # add all the support and query embeddings to our records
    for subset_idx, subset in enumerate((support, query)):
        for emb, label in zip(subset["embeddings"], subset["target"]):
            embedding_table.append({
                "embedding": emb.detach().cpu().numpy(),
                # the target is the position of the class in the episode's classlist
                "label": support["classlist"][label],
                "marker": ("support", "query")[subset_idx], 
                "episode_idx": episode_idx
            })
        
    # also add the prototype embeddings to our records
    for class_idx, emb in enumerate(support["prototypes"]):
        embedding_table.append({
            "embedding": emb.detach().cpu().numpy(),
            "label": support["classlist"][class_idx],
            "marker": "prototype", 
            "episode_idx": episode_idx
        })

In [None]:
# compute the total accuracy across all episodes
# (the metric was updated once per episode in the evaluation loop)
total_acc = metric.compute()
print(f"Total accuracy, averaged across all episodes: {total_acc:.2f}")

In [None]:
# Project every collected embedding down to 2 dimensions with t-SNE.
stacked_embeddings = np.stack([row["embedding"] for row in embedding_table])
embeddings = dim_reduce(
    embeddings=stacked_embeddings,
    method="tsne",
    n_components=2,
)

# Overwrite each record's original embedding with its 2-dim projection.
for record, projected in zip(embedding_table, embeddings):
    record["embedding"] = projected

In [None]:
# Plot all the projected embeddings: colour encodes the instrument, the marker
# tells support, query and prototype points apart.
fig = embedding_plot(
    proj=np.stack([d["embedding"] for d in embedding_table]),
    color_labels=[d["label"] for d in embedding_table],
    marker_labels=[d["marker"] for d in embedding_table],
    # the embeddings come from the GoodSounds test split, not from IRMAS
    title="GoodSounds Protonet Embeddings",
)

fig.show()

In [None]:
# Same plot, restricted to the records of a single episode.
episode_idx = 5

subtable = []
for record in embedding_table:
    if record["episode_idx"] == episode_idx:
        subtable.append(record)

episode_projection = np.stack([record["embedding"] for record in subtable])
episode_labels = [record["label"] for record in subtable]
episode_markers = [record["marker"] for record in subtable]

fig = embedding_plot(
    proj=episode_projection,
    color_labels=episode_labels,
    marker_labels=episode_markers,
    title=f"episode {episode_idx} -- embeddings",
)
fig.show()
fig.show()