Author: Edoardo De Matteis

In [None]:
# Install dependencies.
# !pip install torchaudio

In [None]:
# Import dependencies.
import os
from pathlib import Path
from typing import Callable, Dict, Optional, Union

import pandas as pd
import torch
import torchaudio
import torchaudio.transforms as T
from torch.utils.data import Dataset, DataLoader

In [None]:
# Dataset locations. Each one can be overridden through an environment variable
# of the same name; otherwise the placeholder default must be edited by hand.
# TRAIN_BIRDCALLS: root directory holding <bird_name>/<bird_call> audio files.
TRAIN_BIRDCALLS = os.environ.get("TRAIN_BIRDCALLS", "path/to/be/defined")
# BIRDCALLS_CSV: CSV index of the calls, read by BirdcallDataset.
BIRDCALLS_CSV = os.environ.get("BIRDCALLS_CSV", "path/to/be/defined")

In [None]:
def get_spectrogram(
    audio_file: Union[str, Path],
    n_fft: int = 1024,
    win_length: Optional[int] = None,
    hop_length: int = 512,
    n_mels: int = 128,
    mel: bool = True,
    audio_format: Optional[str] = "ogg",
) -> torch.Tensor:
    """
    Load an audio file and compute its power spectrogram.

    :param audio_file: Path of the audio file (``str`` or ``Path``).
    :param n_fft: Size of the FFT.
    :param win_length: STFT window size; ``None`` leaves the choice to torchaudio.
    :param hop_length: Number of samples between successive STFT frames.
    :param n_mels: Number of mel filterbanks (only used when ``mel`` is true).
    :param mel: If true, return a mel-scale spectrogram; otherwise a
        linear-frequency one.
    :param audio_format: Format hint forwarded to ``torchaudio.load``
        (``"ogg"`` by default); ``None`` lets the backend infer it.
    :return: Power spectrogram (``power=2.0``) of the loaded waveform, as
        produced by the torchaudio transform.
    """
    waveform, sample_rate = torchaudio.load(audio_file, format=audio_format)

    spectrogram: Callable

    if not mel:
        spectrogram = T.Spectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            center=True,
            pad_mode="reflect",
            power=2.0,
        )
    else:
        # The mel filterbank depends on the file's sample rate, so the
        # transform is built per call rather than once at module level.
        # `onesided` is intentionally not passed: it is deprecated in
        # MelSpectrogram, and older releases already defaulted it to True.
        spectrogram = T.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            center=True,
            pad_mode="reflect",
            power=2.0,
            norm="slaney",
            n_mels=n_mels,
            mel_scale="htk",
        )

    return spectrogram(waveform)

In [None]:
# Use this cell to access data
class BirdcallDataset(Dataset):
    """
    Map-style dataset over a CSV index of bird calls.

    Each item is a dict with the bird's index, its name and the call's file
    name. Audio is only loaded and turned into spectrograms in ``collate_fn``.
    """

    def __init__(self, path: str, **kwargs):
        """
        :param path: CSV file (or anything ``pd.read_csv`` accepts) with
            ``bird_idx``, ``bird_name`` and ``bird_call`` columns.
        :param kwargs: Unused; accepted for interface compatibility.
        """
        # "Unnamed: 0" is presumably the index column left by
        # DataFrame.to_csv(); drop it when present instead of raising
        # KeyError when the CSV was saved without an index.
        df = pd.read_csv(path).drop(columns=["Unnamed: 0"], errors="ignore")
        self.idx = df["bird_idx"].tolist()
        self.name = df["bird_name"].tolist()
        self.call = df["bird_call"].tolist()

    def __len__(self) -> int:
        """Return the number of rows in the CSV index."""
        return len(self.idx)

    def __getitem__(self, item) -> Dict:
        """Return the ``idx``/``name``/``call`` record at position ``item``."""
        return {"idx": self.idx[item], "name": self.name[item], "call": self.call[item]}

    @staticmethod
    def collate_fn(data):
        """
        Turn a list of dataset items into a batch.

        :param data: Items as returned by ``__getitem__``.
        :return: ``{"bird": tensor of bird indices, "spec": list of
            spectrograms}``. Spectrograms stay in a plain list, presumably
            because calls can differ in duration.
        """
        # TRAIN_BIRDCALLS is defined as a str; wrap it in Path so "/" joins
        # path components (str / str raises TypeError).
        root = Path(TRAIN_BIRDCALLS)
        birds = []
        specs = []

        for obj in data:
            audio_file = str(root / obj["name"] / obj["call"])
            spec = get_spectrogram(audio_file)

            birds.append(obj["idx"])
            specs.append(spec)

        return {"bird": torch.tensor(birds), "spec": specs}

In [None]:
# Build the dataset from the CSV index and wrap it in a loader that batches
# 32 items at a time through the dataset's custom collate function.
birdcall_ds = BirdcallDataset(BIRDCALLS_CSV)
birdcall_dl = DataLoader(
    dataset=birdcall_ds,
    batch_size=32,
    collate_fn=BirdcallDataset.collate_fn,
)