In [1]:
import os
import librosa
import matplotlib.pyplot as plt
import torch
import torchaudio 
from torchaudio import transforms
from torch.utils.data import random_split
from IPython.display import Audio

In [2]:
from classifier.audioutils import AudioUtils
from classifier.sounddataset import SoundDataset

In [3]:
# Load the audio input data as a list of (path, label) pairs.

# NOTE: Assumes all samples share the same sample rate - if this is not the case, add a
# method to AudioUtils to resample them to 44100 Hz (or whatever target rate is chosen).

CURRENT_DIR = os.getcwd()
AUDIO_DIR = os.path.join(CURRENT_DIR, "audio_input")


def find_labelled_wavs(root):
    """Collect every ``.wav`` file under ``root`` together with its integer class label.

    The label is parsed from the name of the directory that directly contains the file:
    the text before the first underscore, e.g. ``root/3_hihat/a.wav`` -> 3. The class
    directories are expected to carry labels 1-10 (1. kick, 2. snare, ... 10. cymbal).

    Args:
        root: Directory to scan recursively.

    Returns:
        A list of ``(path, label)`` tuples sorted by path, so the order does not depend on
        the filesystem's directory-listing order. Empty if ``root`` does not exist
        (``os.walk`` yields nothing in that case).

    Raises:
        ValueError: If a ``.wav`` file lives in a directory whose name does not start with
            an integer (e.g. a stray file directly inside ``root``).
    """
    pairs = []
    for subdir, _dirs, files in os.walk(root):
        for name in files:
            # Case-insensitive so that e.g. "clap.WAV" is not silently skipped.
            if os.path.splitext(name)[-1].lower() != ".wav":
                continue
            label_text = os.path.basename(subdir).split("_")[0]
            try:
                label = int(label_text)
            except ValueError as exc:
                raise ValueError(
                    f"Cannot derive an integer label from directory {subdir!r} "
                    f"(needed for {name!r}); expected names like '1_kick'."
                ) from exc
            pairs.append((os.path.join(subdir, name), label))
    # os.walk order is filesystem-dependent; sort so the dataset order is reproducible.
    return sorted(pairs)


audio_datas = find_labelled_wavs(AUDIO_DIR)

print("Number of input files: ", len(audio_datas))

Number of input files:  2244


In [4]:
# Measure every clip and use the average length as the common sample length that all
# clips are padded or truncated to.
# NOTE(review): the padding/truncation itself presumably happens inside SoundDataset
# (desired_sample_len is passed to it in the next cell) — confirm in classifier/sounddataset.py.

# Without this guard an empty scan (e.g. the notebook was started from a different working
# directory, since AUDIO_DIR is built from os.getcwd()) fails below with a bare ZeroDivisionError.
if not audio_datas:
    raise ValueError(f"No .wav files found under {AUDIO_DIR!r}; cannot compute clip lengths.")

lengths = []
for audio_data in audio_datas:
    # audio_data is a (path, label) pair; dimension 1 of the signal is taken as the clip
    # length (assumes AudioUtils.open returns (signal, sample_rate) with the signal shaped
    # (channels, frames) — TODO confirm).
    # NOTE: `sig` / `sr` stay bound to the last clip after the loop, as before.
    sig, sr = AudioUtils.open(audio_data[0])
    lengths.append(sig.shape[1])

max_length = max(lengths)
average_length = sum(lengths) / len(lengths)
desired_sample_len = int(average_length)

print("Max Length:", max_length, "Average Length:", average_length, "Sample Size:", desired_sample_len)


Max Length: 341501 Average Length: 20734.03787878788 Sample Size: 20734


In [5]:
# Create the SoundDataset.
# NOTE(review): the meaning of the positional `1` is not visible here (presumably a target
# channel count) — confirm against classifier/sounddataset.py.
audio_dataset = SoundDataset(audio_datas, 1, desired_sample_len)

SPLIT_SEED = 42   # fixed seed so the train/validation split survives kernel restarts
BATCH_SIZE = 16

# Random 80:20 split between training and validation. The split must be taken over
# `audio_dataset`, not the raw `audio_datas` list of (path, label) tuples: splitting the
# raw list would make the loaders batch file paths instead of the dataset's items
# (presumably loaded, length-normalised audio — confirm in SoundDataset).
num_items = len(audio_dataset)
num_train = round(num_items * 0.8)
num_validate = num_items - num_train
train_ds, val_ds = random_split(
    audio_dataset,
    [num_train, num_validate],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Create data loaders: shuffle the training set each epoch; keep validation order fixed.
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
validation_dl = torch.utils.data.DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
