# YOHO training

In [None]:
# Import used libraries

import pandas as pd
import matplotlib
from matplotlib import pyplot as plt
import librosa
import torch
# Seed torch's random number generator so that repeated runs of the notebook
# draw the same random values.
torch.manual_seed(0)

# Print the version of each main dependency so results can be tied to an
# environment.
print(f"{pd.__name__} version: {pd.__version__}")
print(f"{matplotlib.__name__} version: {matplotlib.__version__}")
print(f"{librosa.__name__} version: {librosa.__version__}")
print(f"{torch.__name__} version: {torch.__version__}")


# Project-local helper classes from the yoho24 package.
from yoho24.utils import AudioClip, AudioFile, TUTDataset, YOHODataGenerator

In [None]:
def plot_melspectrogram(
    audio: AudioFile, n_mels: int = 40, win_len: float = 1.00, hop_len: float = 1.00
):
    """
    Plot the Mel spectrogram of an audio file.

    The spectrogram is computed by ``audio.mel_spectrogram`` and drawn with
    ``librosa.display.specshow`` using a Mel-scaled frequency axis and frame
    indices on the horizontal axis.

    Args:
        audio: Object providing ``mel_spectrogram(n_mels, win_len, hop_len)``
            and a sampling rate attribute ``sr``.
        n_mels: Number of Mel bands, forwarded to ``audio.mel_spectrogram``.
        win_len: Window length, forwarded to ``audio.mel_spectrogram``
            (presumably in seconds -- TODO confirm against that method).
        hop_len: Hop length, forwarded to ``audio.mel_spectrogram``
            (presumably in seconds -- TODO confirm against that method).

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # `librosa.display` is a submodule that a bare `import librosa` does not
    # load on librosa < 0.10; importing it explicitly keeps the call below
    # working on those versions too.
    import librosa.display

    # Compute the spectrogram first so a failure here does not leave an empty
    # figure behind.
    mel = audio.mel_spectrogram(n_mels=n_mels, win_len=win_len, hop_len=hop_len)

    plt.figure(figsize=(10, 4))
    plt.title("Mel spectrogram")
    librosa.display.specshow(data=mel, sr=audio.sr, x_axis="frames", y_axis="mel")
    # NOTE(review): the colorbar labels values as dB; this assumes
    # `mel_spectrogram` returns a dB-scaled spectrogram -- confirm.
    plt.colorbar(format="%+2.0f dB")
    plt.tight_layout()
    plt.show()

## Data generator

In [None]:
# Spectrogram settings; the two durations are expressed in seconds.
N_MELS = 40
WIN_S = 40e-3  # window length: 40 ms
HOP_S = 10e-3  # hop length: 10 ms

In [None]:
# Build one AudioFile per row of the TUT Sound Events 2017 development CSV.
# NOTE(security): `eval` executes whatever expression is stored in the
# `events` column. That is only acceptable because the CSV is produced by this
# project; if the column holds plain Python literals, prefer
# `ast.literal_eval` -- confirm the column format before switching.
audios = [
    AudioFile(filepath=file.filepath, labels=eval(file.events))
    for _, file in pd.read_csv("./data/processed/TUT/TUT-sound-events-2017-development.csv").iterrows()
]

# Only the first file is cut into clips. Slicing the list up front avoids
# subdividing every file and then throwing away all clips except those of the
# first one; it also copes with an empty `audios` list.
# The window/hop values are presumably in seconds -- TODO confirm against
# `AudioFile.subdivide`.
audioclips = [
    audioclip
    for audio in audios[:1]
    for audioclip in audio.subdivide(win_len=2.56, hop_len=1.96)
]

In [None]:
# Show the label plot of the clip at index 5 (the sixth clip); this raises
# IndexError if fewer than six clips were produced.
audioclips[5].plot_labels()

In [None]:
# Call play() on the second loaded file (index 1) -- presumably audio playback
# inside the notebook; confirm against AudioFile.play.
# Note: `audioclips` is built from the first file only, so this is a different
# recording from the one whose clips are plotted.
audios[1].play()

In [None]:
# Feature settings restated with the durations in milliseconds
# (hop 10 ms, window 40 ms).
N_MELS, HOP_MS, WIN_MS = 40, 10, 40

# Wrap the pre-cut clips in the dataset object.
tut_train = TUTDataset(audios=audioclips)

# Report the dataset length, then the duration and sampling rate of its
# first entry.
print(f"Number of audio files: {len(tut_train)}")
print(f"Duration: {tut_train.audios[0].duration} seconds")
print(f"Sampling rate: {tut_train.audios[0].sr} Hz")

In [None]:
# Wrap the dataset in the project's batch generator: one item per batch,
# shuffling enabled.
train_dataloader = YOHODataGenerator(tut_train, batch_size=1, shuffle=True)

# Pull a single batch from a fresh iterator; it unpacks into a
# (features, labels) pair.
train_features, train_labels = next(iter(train_dataloader))

# Print the shapes of that batch.
print(f"Train features shape: {train_features.shape}")
print(f"Train labels shape: {train_labels.shape}")

In [None]:
import numpy as np

from models import YOHO

# Dummy all-zero float32 input of shape (1, 1, 257, 40), built in one step.
mel = torch.from_numpy(np.zeros((1, 1, 257, 40), dtype=np.float32))

# Run one forward pass through a freshly constructed model and display the
# shape of what comes out.
# NOTE(review): output_shape is (9, 18) here but (18, 9) in the torchsummary
# cell -- confirm which ordering the model expects.
prediction = YOHO(
    input_shape=(1, 257, 40),
    output_shape=(9,18),
)(mel)
prediction.shape

In [None]:
from torchsummary import summary

# NOTE(review): output_shape is (18, 9) here but (9, 18) in the forward-pass
# cell -- confirm which ordering the model expects.
yoho_model = YOHO(input_shape=(1, 257, 40), output_shape=(18,9))

# torchsummary defaults to device="cuda" and, when CUDA is available, builds
# its probe input on the GPU. A freshly constructed model normally has its
# weights on the CPU, which makes the summary fail with a device mismatch on
# CUDA machines. Pass the device the weights actually live on instead.
summary(yoho_model, (1, 257, 40), device=next(yoho_model.parameters()).device.type)