Install requirements

Import libraries

In [None]:
!pip install tensorflow tensorflow-datasets librosa

Keep only the dataset samples whose instrument family is keyboard

In [21]:
import tensorflow as tf
import tensorflow_datasets as tfds
import librosa
import numpy as np

# Preprocessing configuration shared by the data pipeline below
SAMPLE_RATE = 16000  # Samples per second used to convert seconds to sample counts
TRIM_LENGTH = SAMPLE_RATE * 3  # Number of samples in the first 3 seconds of a clip
KEYBOARD_FAMILY_LABEL = 4  # Integer id of the keyboard instrument family in NSynth

# Define the processing function
def process_data(example):
    """Turn one NSynth example into an ``(audio, pitch)`` training pair.

    Every example is processed the same way: the audio is cut to its first
    ``TRIM_LENGTH`` samples and the pitch is shifted down by 21.

    Selecting keyboard samples is intentionally not done here. ``Dataset.map``
    cannot drop elements, so a per-sample branch could only pass non-keyboard
    samples through untrimmed and unshifted, mixing two incompatible formats
    in one dataset. Apply ``filter_keyboard_samples`` with ``Dataset.filter``
    before mapping this function.

    Args:
        example: Feature dict of one sample; ``example['audio']`` and
            ``example['pitch']`` are read.

    Returns:
        Tuple ``(audio, pitch)`` with the trimmed waveform and shifted pitch.
    """
    # Trim the audio; clips shorter than TRIM_LENGTH are returned unchanged
    # (no padding is applied).
    audio = example['audio'][:TRIM_LENGTH]

    # NOTE: a CQT (Constant-Q Transform) front-end was prototyped here and is
    # currently disabled. The intended call was
    #   np.abs(librosa.cqt(x, sr=SAMPLE_RATE, fmin=librosa.note_to_hz('A0'),
    #                      n_bins=88, bins_per_octave=12))
    # applied through tf.numpy_function(compute_cqt, [audio], tf.float32).

    # Modify pitch: 21 is the MIDI note number of A0, so pitches in the
    # 88-key piano range become zero-based indices 0..87.
    pitch = example['pitch'] - 21
    return audio, pitch

def filter_keyboard_samples(example):
    """Predicate telling whether an example belongs to the keyboard family.

    Args:
        example: Feature dict of one sample; ``example['instrument']['family']``
            is read.

    Returns:
        Boolean tensor, true when the family equals ``KEYBOARD_FAMILY_LABEL``.
    """
    family = example['instrument']['family']
    return tf.equal(family, KEYBOARD_FAMILY_LABEL)

def get_data_loader(data_split, batch_size=64, num_batches=None):
    """Build a batched ``tf.data`` pipeline of keyboard-only NSynth samples.

    Args:
        data_split: Split name forwarded to ``tfds.load`` (the notebook uses
            'train', 'valid' and 'test').
        batch_size: Number of samples per batch.
        num_batches: When not ``None``, keep only the first ``num_batches``
            batches; ``0`` therefore yields an empty dataset. ``None`` keeps
            every batch.

    Returns:
        A ``tf.data.Dataset`` yielding ``(audio, pitch)`` batches.
    """
    ds = tfds.load('nsynth', split=data_split, as_supervised=False)

    # First, filter out non-keyboard samples so only those are preprocessed
    ds = ds.filter(filter_keyboard_samples)
    ds = ds.map(process_data, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.filter(lambda audio, pitch: tf.reduce_sum(tf.shape(audio)) > 0)  # Filter out empty audio results
    ds = ds.batch(batch_size)

    # Compare against None explicitly: a plain truthiness test would treat
    # num_batches=0 as "no limit" and return the whole split.
    if num_batches is not None:
        ds = ds.take(num_batches)

    # Same AUTOTUNE constant as the map() call above; the
    # tf.data.experimental alias is deprecated.
    ds = ds.prefetch(tf.data.AUTOTUNE)
    return ds

# Input to model: one capped loader per split, all sharing the same batch size
batch_size = 64
train_loader = get_data_loader('train', batch_size=batch_size, num_batches=40)
val_loader = get_data_loader('valid', batch_size=batch_size, num_batches=6)
test_loader = get_data_loader('test', batch_size=batch_size, num_batches=6)
# Class indices 0..87, one per shifted pitch value
classes = [*range(88)]

In [22]:
# Sanity check: look at the first sample of the first training batch
for audio, pitch in train_loader.take(1):
    first_sample_audio, first_sample_pitch = audio[0], pitch[0]

    # Amplitude extremes of the waveform
    min_value = tf.reduce_min(first_sample_audio)
    max_value = tf.reduce_max(first_sample_audio)

    # Peak-to-peak amplitude (max - min)
    range_value = max_value - min_value

    print(f"First sample audio max value: {max_value.numpy()}")
    print(f"First sample audio min value: {min_value.numpy()}")
    print(f"First sample audio range: {range_value.numpy()}")
    print(f"First sample pitch: {first_sample_pitch.numpy()}")


First sample audio max value: 0.46275418996810913
First sample audio min value: -0.46886491775512695
First sample audio range: 0.9316191077232361
First sample pitch: 85


In [23]:
def get_dataset_length(data_loader):
    """Count the elements produced by one full pass over ``data_loader``.

    Args:
        data_loader: Any iterable; for a batched dataset each element is one
            batch.

    Returns:
        The number of elements yielded, as an ``int`` (0 for an empty iterable).
    """
    return sum(1 for _ in data_loader)

# Count the batches in each loader; every call iterates its loader once,
# in the order test, validation, train.
test_loader_length, val_loader_length, train_loader_length = (
    get_dataset_length(loader)
    for loader in (test_loader, val_loader, train_loader)
)

print(f"Test loader length: {test_loader_length}")
print(f"Validation loader length: {val_loader_length}")
print(f"Train loader length: {train_loader_length}")


Test loader length: 6
Validation loader length: 6
Train loader length: 40


In [24]:
def get_dataset_sample_count(data_loader):
    """Sum the batch sizes of every ``(audio, pitch)`` batch in ``data_loader``.

    Args:
        data_loader: Iterable of ``(audio, pitch)`` pairs; only the leading
            dimension of ``audio`` is read.

    Returns:
        Scalar int32 tensor holding the total number of samples. An empty
        loader yields a zero tensor rather than a Python ``int``, so the
        return type does not depend on the input.
    """
    # Start from a tensor so the empty-loader result has the same type as
    # the accumulated one.
    total_samples = tf.constant(0, dtype=tf.int32)
    for audio, _ in data_loader:
        # Leading dimension = number of samples in this batch
        # (assumes audio is batched as [batch_size, features]).
        total_samples += tf.shape(audio)[0]
    return total_samples

# Count the samples in each loader; every call iterates its loader once,
# in the order test, validation, train.
test_samples_count, val_samples_count, train_samples_count = (
    get_dataset_sample_count(loader)
    for loader in (test_loader, val_loader, train_loader)
)

print(f"Train loader samples: {train_samples_count}")
print(f"Validation loader samples: {val_samples_count}")
print(f"Test loader samples: {test_samples_count}")

Train loader samples: 2560
Validation loader samples: 384
Test loader samples: 384
