# YOHO training

In [2]:
# Dependencies: third-party packages first, then the project's own helpers.
import pandas as pd
import torch

from utils import AudioClip, AudioFile, TUTDataset, YOHODataGenerator

# Seed torch's global random number generator with a fixed value.
torch.manual_seed(0)

# Log the pandas version alongside the results of this run.
print("Pandas version: ", pd.__version__)

Pandas version:  2.1.2


## Data generator

In [3]:
# Number of leading rows of the training CSV that are turned into clips.
# 1 reproduces the previous `_ < 1` filter on the default row index.
MAX_TRAIN_FILES = 1

# Select the rows *before* creating any AudioFile: the previous version applied
# its filter after the inner loop, so every row of the CSV was passed through
# AudioFile(...).audioclips(...) and all but the first row's clips were thrown away.
train_files = pd.read_csv("./data/tut.train.csv").head(MAX_TRAIN_FILES)

# Flatten the per-file clip sequences into a single list.
# win_ms / hop_ms are forwarded unchanged to AudioFile.audioclips
# (presumably the clip length and stride in milliseconds — TODO confirm in utils).
audioclips = [
    audioclip
    for _, file in train_files.iterrows()
    for audioclip in AudioFile(filepath=file.filepath, labels=file.events).audioclips(
        win_ms=2560, hop_ms=1960
    )
]

In [4]:
# Feature settings; none of the three is referenced inside this cell.
N_MELS, HOP_MS, WIN_MS = 40, 10, 40

# Wrap the clip list built earlier in the dataset class.
tut_train = TUTDataset(audioclips=audioclips)

# NOTE(review): the dataset is built from clips, so this count is presumably
# the number of clips rather than of audio files — confirm and relabel if so.
print(f"Number of audio files: {len(tut_train)}")

# Report duration and sampling rate of the first clip only.
first_clip = tut_train.audioclips[0]
print(f"Duration: {first_clip.duration} seconds")
print(f"Sampling rate: {first_clip.sr} Hz")

Number of audio files: 122
Duration: 2.56 seconds
Sampling rate: 44100 Hz


In [5]:
# Batch generator over the training dataset: one item per batch, shuffled.
train_dataloader = YOHODataGenerator(
    tut_train,
    batch_size=1,
    shuffle=True,
)

# Draw a single batch to inspect the shapes the generator yields.
batch_iterator = iter(train_dataloader)
train_features, train_labels = next(batch_iterator)

print(f"Train features shape: {train_features.shape}")
print(f"Train labels shape: {train_labels.shape}")

Train features shape: torch.Size([1, 1, 257, 40])
Train labels shape: torch.Size([1, 9, 18])


In [6]:
# numpy is no longer needed by this cell; the import is kept because later
# cells of the notebook may rely on `np` being defined here.
import numpy as np

from models import YOHO

# Shapes used for the sanity check. INPUT_SHAPE excludes the batch axis;
# OUTPUT_SHAPE matches the per-item label shape printed by the data-generator cell.
INPUT_SHAPE = (1, 257, 40)
OUTPUT_SHAPE = (9, 18)

# All-zero dummy batch holding one input, created directly as a float32 tensor
# (previously a float64 NumPy array was allocated, converted and unsqueezed twice,
# and `mel` was rebound from an ndarray to a tensor along the way).
mel = torch.zeros((1, *INPUT_SHAPE), dtype=torch.float32)

# Only the output shape is inspected, so skip gradient tracking for this pass.
with torch.no_grad():
    prediction = YOHO(input_shape=INPUT_SHAPE, output_shape=OUTPUT_SHAPE)(mel)

prediction.shape

torch.Size([1, 9, 18])

In [7]:
from torchsummary import summary

# Layer-by-layer summary of the model for a (1, 257, 40) input.
# output_shape is (9, 18) so the summarised model agrees with the label tensors
# produced by the data generator and with the model used for the shape check;
# the previous (18, 9) was transposed relative to both.
summary(YOHO(input_shape=(1, 257, 40), output_shape=(9, 18)), (1, 257, 40))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 32, 129, 20]             288
       BatchNorm2d-2          [-1, 32, 129, 20]              64
              ReLU-3          [-1, 32, 129, 20]               0
            Conv2d-4          [-1, 32, 129, 20]             288
       BatchNorm2d-5          [-1, 32, 129, 20]              64
              ReLU-6          [-1, 32, 129, 20]               0
            Conv2d-7          [-1, 64, 129, 20]           2,048
       BatchNorm2d-8          [-1, 64, 129, 20]             128
              ReLU-9          [-1, 64, 129, 20]               0
          Dropout-10          [-1, 64, 129, 20]               0
DepthwiseSeparableConv-11          [-1, 64, 129, 20]               0
           Conv2d-12           [-1, 64, 65, 10]             576
      BatchNorm2d-13           [-1, 64, 65, 10]             128
             ReLU-14           [-1