# Defining a custom task

In `pyannote.audio`, a *task* is a combination of a **_problem_** that needs to be addressed and an **experimental protocol**.

For example, one can address **_voice activity detection_** following the **AMI only_words** experimental protocol, by instantiating the following *task*:


In [None]:
# this assumes that the AMI corpus has been set up for diarization
# according to https://github.com/pyannote/AMI-diarization-setup
import os
os.environ['PYANNOTE_DATABASE_CONFIG'] = '/Users/bredin/Development/pyannote/pyannote-db/AMI-diarization-setup/pyannote/database.yml'

from pyannote.database import get_protocol, FileFinder
ami = get_protocol('AMI.SpeakerDiarization.only_words', 
                   preprocessors={'audio': FileFinder()})

# address voice activity detection
from pyannote.audio.tasks import VoiceActivityDetection
task = VoiceActivityDetection(ami)

A growing collection of tasks is readily available in `pyannote.audio.tasks`...

In [None]:
from pyannote.audio.tasks import __all__ as TASKS; print('\n'.join(TASKS))

... but you will eventually want to use `pyannote.audio` to address a different task.  
In this example, we will add a new task addressing the **sound event detection** problem.



## Problem specification

A problem is expected to be solved by a model $f$ that takes an audio chunk  $X$ as input and returns its predicted solution $\hat{y} = f(X)$. 

### Resolution

Depending on the addressed problem, you might expect the model to output just one prediction for the whole audio chunk (`Resolution.CHUNK`) or a temporal sequence of predictions (`Resolution.FRAME`).

In our particular case, we would like the model to provide one decision for the whole chunk:

In [None]:
from pyannote.audio.core.task import Resolution
resolution = Resolution.CHUNK
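
To make the difference between the two resolutions concrete, here is a minimal sketch in plain Python. It does not call `pyannote.audio`: the two helpers and the `frame_step` value are hypothetical and only illustrate the expected shape of the target in each case.

```python
# Hypothetical helpers (not part of pyannote.audio) that only illustrate
# how the shape of the target differs between the two resolutions.

def chunk_target_shape(num_classes: int) -> tuple:
    """Resolution.CHUNK: one prediction per class for the whole chunk."""
    return (num_classes,)


def frame_target_shape(duration: float, frame_step: float, num_classes: int) -> tuple:
    """Resolution.FRAME: one prediction per class for every frame.

    `frame_step` (in seconds) is an assumed value: in practice, the number
    of output frames depends on the model architecture.
    """
    num_frames = round(duration / frame_step)
    return (num_frames, num_classes)


print(chunk_target_shape(10))             # (10,)
print(frame_target_shape(5.0, 0.02, 10))  # (250, 10)
```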

### Type of problem

Similarly, the type of your problem may fall into one of these generic machine learning categories:
* `Problem.BINARY_CLASSIFICATION` for binary classification
* `Problem.MONO_LABEL_CLASSIFICATION` for multi-class classification 
* `Problem.MULTI_LABEL_CLASSIFICATION` for multi-label classification
* `Problem.REGRESSION` for regression
* `Problem.REPRESENTATION` for representation learning

In our particular case, we would like the model to do multi-label classification because one audio chunk may contain multiple sound events:

In [None]:
from pyannote.audio.core.task import Problem
problem = Problem.MULTI_LABEL_CLASSIFICATION
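
In a multi-label setting, the target for one chunk is typically a "multi-hot" vector with one entry per class. The sketch below shows this encoding in plain Python; `multi_hot` is a hypothetical helper introduced for illustration, not a `pyannote.audio` function.

```python
# Hypothetical helper (not part of pyannote.audio): encode the set of
# active classes of a chunk as a {0|1} vector, one entry per class.

def multi_hot(active_labels, classes):
    y = [0] * len(classes)
    for label in active_labels:
        # y[k] = 1 means that the k-th class is active in the chunk
        y[classes.index(label)] = 1
    return y


classes = ["Speech", "Dog", "Cat"]
print(multi_hot(["Dog", "Speech"], classes))  # [1, 1, 0]
print(multi_hot([], classes))                 # [0, 0, 0]
```

Several entries can be set to 1 at once, which is what distinguishes this problem from mono-label (multi-class) classification.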

In [None]:
from pyannote.audio.core.task import Specifications
specifications = Specifications(
    problem=problem,
    resolution=resolution,
    duration=5.0,
    classes=["Speech", "Dog", "Cat", "Alarm_bell_ringing", "Dishes", 
             "Frying", "Blender", "Running_water", "Vacuum_cleaner", 
             "Electric_shaver_toothbrush"],
)

A task is expected to be solved by a model $f$ that (usually) takes an audio chunk  $X$ as input and returns its predicted solution $\hat{y} = f(X)$. 

To help train the model $f$, the task $\mathcal{T}$ is in charge of:
- generating $(X, y)$ training samples using the **dataset**
- defining the loss function $\mathcal{L}(y, \hat{y})$
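
As an illustration of what $\mathcal{L}(y, \hat{y})$ can look like for multi-label classification, a common choice is binary cross-entropy averaged over classes. The sketch below is a plain Python version written under that assumption; it is not taken from the `pyannote.audio` source code, and `binary_cross_entropy` here is a hypothetical standalone function.

```python
import math

# Assumption: binary cross-entropy averaged over classes, a usual choice
# for multi-label problems. This is an illustrative sketch, not the
# pyannote.audio implementation.

def binary_cross_entropy(y, y_hat, eps=1e-7):
    """y: list of {0|1} targets, y_hat: list of predicted probabilities."""
    total = 0.0
    for target, prob in zip(y, y_hat):
        # clip probabilities to avoid log(0)
        prob = min(max(prob, eps), 1.0 - eps)
        total -= target * math.log(prob) + (1 - target) * math.log(1.0 - prob)
    return total / len(y)


y = [1, 0, 1]
print(binary_cross_entropy(y, [0.9, 0.1, 0.8]))  # confident and correct: low loss
print(binary_cross_entropy(y, [0.1, 0.9, 0.2]))  # confident and wrong: high loss
```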


In [None]:
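import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import numpy as np
from pyannote.core import Annotation, Segment, SlidingWindow
from pyannote.database import Protocol
from pyannote.audio import Model
from pyannote.audio.core.task import Problem, Resolution, Specifications, Task
from pyannote.audio.utils.random import create_rng_for_worker
from torch_audiomentations.core.transforms_interface import BaseWaveformTransform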
# Your custom task must be a subclass of `pyannote.audio.core.task.Task`
class SoundEventDetection(Task):
    """Sound event detection"""

    def __init__(
        self,
        protocol: Protocol,
        duration: float = 5.0,
        min_duration: Optional[float] = None,
        warm_up: Union[float, Tuple[float, float]] = 0.0,
        batch_size: int = 32,
        num_workers: Optional[int] = None,
        pin_memory: bool = False,
        augmentation: Optional[BaseWaveformTransform] = None,
        **other_params,
    ):

        super().__init__(
            protocol,
            duration=duration,
            min_duration=min_duration,
            warm_up=warm_up,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            augmentation=augmentation,
        )

    def setup(self, stage=None):

        if stage == "fit":

            # load metadata for training subset
            self.train_metadata_ = list()
            for training_file in self.protocol.train():
                self.train_metadata_.append({
                    # path to audio file (str)
                    "audio": training_file["audio"],
                    # duration of audio file (float)
                    "duration": training_file["duration"],
                    # reference annotation (pyannote.core.Annotation)
                    "annotation": training_file["annotation"],
                })

            # gather the list of classes
            classes = set()
            for training_file in self.train_metadata_:
                classes.update(training_file["annotation"].labels())
            classes = sorted(classes)

            # specify the addressed problem
            self.specifications = Specifications(
                # it is a multi-label classification problem
                problem=Problem.MULTI_LABEL_CLASSIFICATION,
                # we expect the model to output one prediction 
                # for the whole chunk
                resolution=Resolution.CHUNK,
                # the model will ingest chunks with that duration (in seconds)
                duration=self.duration,
                # human-readable names of classes
                classes=classes)

            # `has_validation` is True iff protocol defines a development set
            if not self.has_validation:
                return

            # load metadata for validation subset
            self.validation_metadata_ = list()
            for validation_file in self.protocol.development():
                self.validation_metadata_.append({
                    "audio": validation_file["audio"],
                    "num_samples": math.floor(validation_file["duration"] / self.duration),
                    "annotation": validation_file["annotation"],
                })
            
            

    def train__iter__(self):
        # this method generates training samples, one at a time, "ad infinitum". each worker 
        # of the dataloader will run it, independently from other workers. pyannote.audio and
        # pytorch-lightning will take care of making batches out of it.

        # create worker-specific random number generator (RNG) to avoid this common bug:
        # tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/
        rng = create_rng_for_worker(self.model.current_epoch)

        # load list and number of classes
        classes = self.specifications.classes
        num_classes = len(classes)

        # yield training samples "ad infinitum"
        while True:

            # select training file at random
            random_training_file, *_ = rng.choices(self.train_metadata_, k=1)

            # select one chunk at random 
            random_start_time = rng.uniform(0, random_training_file["duration"] - self.duration)
            random_chunk = Segment(random_start_time, random_start_time + self.duration)

            # load audio excerpt corresponding to random chunk
            X = self.model.audio.crop(random_training_file["audio"], 
                                      random_chunk, 
                                      fixed=self.duration)
            
            # load labels corresponding to random chunk as {0|1} numpy array
            # y[k] = 1 means that kth class is active
            y = np.zeros((num_classes,))
            active_classes = random_training_file["annotation"].crop(random_chunk).labels()
            for active_class in active_classes:
                y[classes.index(active_class)] = 1
        
            # yield training samples as a dict (use 'X' for input and 'y' for target)
            yield {'X': X, 'y': y}

    def train__len__(self):
        # since train__iter__ runs "ad infinitum", we need a way to define what an epoch is.
        # this is the purpose of this method. it outputs the number of training samples that
        # make an epoch.

        # we compute this number as the total duration of the training set divided by 
        # duration of training chunks. we make sure that an epoch is at least one batch long,
        # or pytorch-lightning will complain
        train_duration = sum(training_file["duration"] for training_file in self.train_metadata_)
        return max(self.batch_size, math.ceil(train_duration / self.duration))

    def val__getitem__(self, sample_idx):

        # load list and number of classes
        classes = self.specifications.classes
        num_classes = len(classes)


        # find which part of the validation set corresponds to sample_idx
        num_samples = np.cumsum([
            validation_file["num_samples"] for validation_file in self.validation_metadata_])
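        # `num_samples` is a cumulative sum: the file containing sample_idx is
        # the first one whose cumulative count is strictly greater than sample_idx
        file_idx = np.where(num_samples > sample_idx)[0][0]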
        validation_file = self.validation_metadata_[file_idx]
        idx = sample_idx - (num_samples[file_idx] - validation_file["num_samples"]) 
        chunk = SlidingWindow(start=0., duration=self.duration, step=self.duration)[idx]

        # load audio excerpt corresponding to current chunk
        X = self.model.audio.crop(validation_file["audio"], chunk, fixed=self.duration)

        # load labels corresponding to current chunk as {0|1} numpy array
        # y[k] = 1 means that kth class is active
        y = np.zeros((num_classes,))
        active_classes = validation_file["annotation"].crop(chunk).labels()
        for active_class in active_classes:
            y[classes.index(active_class)] = 1

        return {'X': X, 'y': y}

    def val__len__(self):
        return sum(validation_file["num_samples"] 
                   for validation_file in self.validation_metadata_)

    # the `pyannote.audio.core.task.Task` base class provides default `training_step` and
    # `validation_step` methods that rely on self.specifications to guess which
    # loss and metrics should be used. you can override them to customize this behavior.
    # more details can be found in the pytorch-lightning documentation and in the
    # pyannote.audio.core.task.Task source code.

    # def training_step(self, batch, batch_idx: int):
    #    return loss

    # def validation_step(self, batch, batch_idx: int):
    #    return metric

    # pyannote.audio.tasks.segmentation.mixin also provides a convenient mixin
    # for "segmentation" tasks (i.e. with Resolution.FRAME) that already defines
    # several useful methods.
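
The index arithmetic in `val__getitem__` can be checked in isolation. The sketch below reproduces the same mapping from a global `sample_idx` to a (file, chunk) pair using only the standard library; `locate` is a hypothetical helper introduced for illustration, and the per-file chunk counts are made-up values.

```python
from bisect import bisect_right
from itertools import accumulate

# Hypothetical helper (not part of pyannote.audio): map a global validation
# sample index to (file index, chunk index within that file).

def locate(sample_idx, num_samples_per_file):
    # cumulative number of chunks up to and including each file
    cumulative = list(accumulate(num_samples_per_file))
    # first file whose cumulative count is strictly greater than sample_idx
    file_idx = bisect_right(cumulative, sample_idx)
    # global index of the first chunk of that file
    first_idx = cumulative[file_idx] - num_samples_per_file[file_idx]
    return file_idx, sample_idx - first_idx


# three validation files contributing 3, 2 and 4 chunks respectively
counts = [3, 2, 4]
print(locate(0, counts))  # (0, 0)
print(locate(3, counts))  # (1, 0)
print(locate(8, counts))  # (2, 3)
```

With a chunk duration of 5 seconds, chunk index 3 of the third file would cover the segment from 15 to 20 seconds of that file.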
