In [None]:
import librosa

# --- data feeding ---
# Number of segments per training batch.
batch_size = 32

# --- spectrogram ---
# Number of frequency bins per spectrogram frame.
spec_feature_count = 128

# --- LSTM segments ---
# Every training example is a fixed-length window of spectrogram frames.
segment_length_secs = 5
# sr and hop_length are written out explicitly; they are the values
# librosa.time_to_frames falls back to when they are omitted.
segment_length = librosa.time_to_frames(
    segment_length_secs,
    sr=22050,
    hop_length=512,
)

# --- output classes ---
# One output channel per MIDI drum note listed here; the position in the list
# is the class index.
drum_notes = [35, 38, 42, 46, 41, 43, 45, 47, 48, 50, 49, 51]
output_classes = len(drum_notes)


In [None]:
import tensorflow as tf
import pretty_midi
from matplotlib import pyplot as plt
import numpy as np
import librosa
import librosa.display
from os.path import isfile
import glob
from multiprocessing import Process
from pydub import AudioSegment

# TODO: data augmentation using audiomentations: https://github.com/iver56/audiomentations
# TODO: pad data with zero_padding / silence

class PlotLosses(tf.keras.callbacks.Callback):
    """Keras callback that redraws the train/validation loss curves after each epoch.

    The per-epoch values are kept on the instance (`x`, `losses`, `val_losses`,
    `logs`) so they can be inspected after training.
    """

    def on_train_begin(self, logs=None):
        """Reset the recorded history at the start of a training run."""
        self.i = 0
        self.x = []
        self.losses = []
        self.val_losses = []
        self.fig = plt.figure()
        self.logs = []

    def on_epoch_end(self, epoch, logs=None):
        """Record this epoch's losses and redraw the plot.

        Args:
            epoch: epoch index supplied by Keras (unused; an internal counter
                is kept instead).
            logs: metrics dict supplied by Keras; 'loss' and 'val_loss' are read.
        """
        # Imported locally so the callback does not depend on a later notebook
        # cell having been executed before training starts.
        from IPython.display import clear_output

        if logs is None:
            logs = {}

        self.logs.append(logs)
        self.x.append(self.i)
        self.losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))
        self.i += 1

        # Replace the previous epoch's output instead of stacking plots.
        clear_output(wait=True)
        print('loss: ' + str(logs.get('loss')))
        print("val_loss:" + str(logs.get('val_loss')))
        plt.plot(self.x, self.losses)
        plt.plot(self.x, self.val_losses)
        # Fixed y-range so consecutive redraws are visually comparable.
        plt.ylim(0, 0.1)
        plt.yticks(np.arange(0.01, 0.1, 0.01))
        plt.ylabel('loss')
        plt.xlabel('epoch')
        plt.legend(['train', 'validation'], loc='upper right')
        plt.show()

# Maps notes into 12 drum classes (kick, snare, closed hihat, open hihat, 6x toms, crash cymbal, ride cymbal)
# Built from (class note, source notes) pairs: every source note is folded
# onto its class note. Notes that appear nowhere here are absent from the map.
gm_drum_map = {
    source_note: class_note
    for class_note, source_notes in (
        (35, (35, 36)),      # kicks
        (38, (38, 40)),      # snares
        (42, (42, 44)),      # hihats
        (46, (46,)),         # open hihat
        (41, (41,)),         # toms
        (43, (43,)),
        (45, (45,)),
        (47, (47,)),
        (48, (48,)),
        (50, (50,)),
        (49, (49, 55, 57)),  # crash / splash cymbals
        (51, (51, 59)),      # ride cymbals
    )
    for source_note in source_notes
}

# Lookup from drum note number to its position in drum_notes.
class_to_idx = dict(zip(drum_notes, range(len(drum_notes))))
# Number of worker processes write_tfrecords spawns (one output shard each).
thread_count = 16
# NOTE(review): the two high_pass_* values are not read anywhere in the code
# visible here - confirm whether they are still needed.
high_pass_hz = 15000
high_pass_ratio = 40

def plot_labels(labels):
    """Draw a label or prediction matrix as an image on the current axes.

    Args:
        labels: 2-D array-like with one row per entry of `drum_notes` and one
            column per frame.
    """
    # Tick count follows drum_notes instead of a hard-coded 12 so the tick
    # positions and tick names can never get out of sync.
    drum_names = [pretty_midi.note_number_to_drum_name(n) for n in drum_notes]
    plt.yticks(range(len(drum_notes)), drum_names)
    plt.imshow(labels, interpolation='nearest', aspect='auto')
    # imshow puts row 0 at the top; flip so the first class is at the bottom.
    plt.gca().invert_yaxis()

def samples_to_np(a):
    """Convert an object exposing `get_array_of_samples()` to a float32 array.

    The raw sample values are cast to float32 and divided by 32767.0
    (presumably to bring signed 16-bit samples into roughly [-1, 1] - confirm
    the sample width against the caller).

    Args:
        a: object with a `get_array_of_samples()` method returning a sequence
            of numbers.

    Returns:
        1-D np.float32 array of the scaled samples.
    """
    raw_samples = np.array(a.get_array_of_samples())
    as_float = raw_samples.astype(np.float32)
    return as_float / 32767.0

def process_file(data_file, callback):
    """Turn one audio file and its MIDI transcription into fixed-length training segments.

    The audio is converted to a mel spectrogram in dB. The MIDI notes are
    mapped through `gm_drum_map` and rendered into a per-frame label matrix
    with one row per drum class. Both are cut into consecutive windows of
    `segment_length` frames; a trailing partial window is dropped.

    Args:
        data_file: path of the audio file. The label file is expected next to
            it with the same base name and a '.midi' extension.
        callback: called once per segment as `callback(stft_slice, labels_slice)`
            with arrays of shape (spec_feature_count, segment_length) and
            (output_classes, segment_length). Labels are float32: 1.0 on an
            onset frame, 0.5 on the frames directly before/after it, else 0.

    Returns:
        None. Returns silently without calling `callback` when either file is
        missing.

    Raises:
        Exception: if the audio or the spectrogram contains NaN values.
    """
    # Local import: only `isfile` is imported from os.path at file level.
    from os.path import splitext

    # splitext instead of data_file[:-4] so extensions of any length work.
    label_file = splitext(data_file)[0] + '.midi'
    if not isfile(data_file) or not isfile(label_file):
        return

    audio, sr = librosa.load(data_file)
    if np.isnan(audio).any():
        raise Exception('found nan in transformed audio clip')

    # The same hop length must be used for the spectrogram and for converting
    # onset times to frame indices, otherwise labels and features drift apart.
    hop_length = 512
    mel_power = librosa.feature.melspectrogram(
        y=audio, sr=sr, n_mels=spec_feature_count, n_fft=2048, hop_length=hop_length)
    stft_features = librosa.power_to_db(mel_power)
    if np.isnan(stft_features).any():
        raise Exception('found nan in stft')

    pm = pretty_midi.PrettyMIDI(label_file)
    # A MIDI file without any instrument track is treated like a track with no
    # notes (all-zero labels) instead of failing with an IndexError.
    notes = pm.instruments[0].notes if pm.instruments else []

    pitches = []
    onsets = []
    for note in notes:
        mapped_note = gm_drum_map.get(note.pitch, 0)
        if mapped_note == 0:
            # Note is not part of any of the tracked drum classes.
            continue
        if note.velocity < 5:
            # Ignore very soft hits.
            continue
        pitches.append(mapped_note)
        onsets.append(note.start)

    onset_frames = librosa.time_to_frames(onsets, sr=sr, hop_length=hop_length)

    # Float dtype is required: in an integer array the 0.5 soft targets are
    # truncated to 0, which also wiped real onsets in neighbouring frames.
    frame_count = stft_features.shape[1]
    labels = np.zeros((output_classes, frame_count), dtype=np.float32)

    for pitch, frame in zip(pitches, onset_frames):
        if frame >= frame_count:
            continue
        class_row = labels[class_to_idx[pitch]]
        class_row[frame] = 1.0
        # Soft target vectors: the neighbouring frames get 0.5, but never
        # lower a value that an adjacent onset has already set to 1.
        if frame > 0:
            class_row[frame - 1] = max(class_row[frame - 1], 0.5)
        if frame < frame_count - 1:
            class_row[frame + 1] = max(class_row[frame + 1], 0.5)

    segment_count = frame_count // segment_length
    for segment_index in range(segment_count):
        start = segment_index * segment_length
        stop = start + segment_length
        stft_slice = stft_features[:, start:stop]
        labels_slice = labels[:, start:stop]
        if stft_slice.shape[1] != labels_slice.shape[1]:
            raise Exception('mismatched total frame count between stft_features and labels')
        callback(stft_slice, labels_slice)

def _write_tfrecords_thread(tfrecord_path, data_files):
    """Process every file in `data_files` and write its segments to one TFRecord file.

    Args:
        tfrecord_path: path of the TFRecord file to create.
        data_files: sequence of audio file paths handed to `process_file`.
    """
    with tf.io.TFRecordWriter(tfrecord_path) as writer:

        def _float_feature(array):
            # TFRecord features are flat lists; the 2-D shape is restored by
            # the reader.
            return tf.train.Feature(float_list=tf.train.FloatList(value=array.flatten()))

        def write_record(stft_slice, labels_slice):
            # Serialise one (features, labels) segment as a tf.train.Example.
            features = tf.train.Features(feature={
                'stft': _float_feature(stft_slice),
                'labels': _float_feature(labels_slice),
            })
            record = tf.train.Example(features=features)
            writer.write(record.SerializeToString())

        for data_file in data_files:
            process_file(data_file, write_record)

# When True, write_tfrecords fans the work out over `thread_count` processes.
ENABLE_THREADING = True
def write_tfrecords(data_files, tfrecord_path_prefix):
    """Convert audio/MIDI files into TFRecord file(s).

    With ENABLE_THREADING disabled a single '<prefix>.tfrecords' file is
    written in this process. Otherwise `data_files` is split into
    `thread_count` chunks and each chunk is written by its own worker process
    to '<prefix><index>.tfrecords'.

    Args:
        data_files: sequence of audio file paths.
        tfrecord_path_prefix: output path without the '.tfrecords' suffix.

    Raises:
        RuntimeError: if any worker process exits with a non-zero exit code.
    """
    if not ENABLE_THREADING:
        _write_tfrecords_thread(tfrecord_path_prefix + '.tfrecords', data_files)
        return

    chunked_data_files = np.array_split(data_files, thread_count)
    processes = []
    for thread_index, chunk in enumerate(chunked_data_files):
        shard_path = tfrecord_path_prefix + str(thread_index) + '.tfrecords'
        p = Process(target=_write_tfrecords_thread, args=(shard_path, chunk))
        p.start()
        processes.append(p)

    # Join the processes directly: tf.train.Coordinator is meant for threads
    # and gives no access to a worker's exit status.
    for p in processes:
        p.join()

    # A crashed worker leaves an incomplete shard behind; surface that instead
    # of silently continuing with partial data.
    failed = [f'{p.name} (exit code {p.exitcode})' for p in processes if p.exitcode != 0]
    if failed:
        raise RuntimeError('tfrecord worker process(es) failed: ' + ', '.join(failed))

In [None]:
%%time

# Process all raw wav/midi files and extract features
# import glob

# train_files = glob.glob('/mnt/d/drums/train/**/*.wav', recursive=True)
# validate_files = glob.glob('/mnt/d/drums/valid/**/*.wav', recursive=True)

# print(f'Processing {len(train_files) + len(validate_files)} files')

# # print(train_files)
# write_tfrecords(train_files, '/mnt/d/drums/train')
# write_tfrecords(validate_files, '/mnt/d/drums/validate')

# print('Done!')

In [None]:
# Load features from Google cloud storage
import tensorflow as tf
from matplotlib import pyplot as plt
from google.cloud import storage

# Collect the gs:// URIs of all record files in the bucket, split by the
# 'train' / 'validate' name prefix. Blobs with any other prefix are ignored.
TRAIN_TFRECORDS = []
VALID_TFRECORDS = []

client = storage.Client()
for blob in client.list_blobs('drums-bucket'):
    blob_uri = f'gs://drums-bucket/{blob.name}'
    if blob.name.startswith('train'):
        TRAIN_TFRECORDS.append(blob_uri)
    elif blob.name.startswith('validate'):
        VALID_TFRECORDS.append(blob_uri)

# # Load features from disk
# import glob
# import tensorflow as tf
# from matplotlib import pyplot as plt

# TRAIN_TFRECORDS = glob.glob('/mnt/d/drums/train*.tfrecords')
# VALID_TFRECORDS = glob.glob('/mnt/d/drums/validate*.tfrecords')

# Parsing schema for the records written by _write_tfrecords_thread.
# The writer stores each array flattened, so the 2-D shapes given here are
# what restores the (rows, frames) layout: one row per spectrogram bin for
# 'stft' and one row per drum class for 'labels'.
feature_description = {
    'stft': tf.io.FixedLenFeature((spec_feature_count, segment_length), tf.float32),
    'labels': tf.io.FixedLenFeature((output_classes, segment_length), tf.float32),
}

def _parse_feature(example):
    """Decode one serialized example into a (features, labels) pair.

    Both arrays are stored as (rows, frames) and are transposed to put the
    frame axis first; the features additionally get a trailing axis of size 1.

    Args:
        example: a serialized tf.train.Example.

    Returns:
        Tuple (features, labels) of tensors.
    """
    record = tf.io.parse_single_example(example, feature_description)
    # TODO: transpose when writing instead
    frames_first_stft = tf.transpose(record['stft'])
    features = tf.expand_dims(frames_first_stft, axis=-1)
    labels = tf.transpose(record['labels'])
    return features, labels

def load_dataset(filename, shuffle_buffer=512, parallel_reads=8, parallel_calls=8,
                 threadpool_size=16, prefetch_batches=2):
    """Build a shuffled, batched tf.data pipeline from TFRecord file(s).

    Args:
        filename: a TFRecord path or a list of paths.
        shuffle_buffer: number of examples held in the shuffle buffer.
        parallel_reads: number of files read in parallel.
        parallel_calls: number of examples parsed in parallel.
        threadpool_size: size of the private thread pool used by the pipeline.
        prefetch_batches: number of batches prepared ahead of the consumer.

    Returns:
        A tf.data.Dataset yielding (features, labels) batches of `batch_size`
        examples; a final incomplete batch is dropped.
    """
    dataset_options = tf.data.Options()
    # Element order does not matter here, so let tf.data trade it for speed.
    # `deterministic` replaces the deprecated `experimental_deterministic`.
    dataset_options.deterministic = False
    dataset_options.threading.private_threadpool_size = threadpool_size

    dataset = tf.data.TFRecordDataset(filename, num_parallel_reads=parallel_reads)
    dataset = dataset.with_options(dataset_options)
    dataset = dataset.map(_parse_feature, num_parallel_calls=parallel_calls)
    dataset = dataset.shuffle(shuffle_buffer)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(prefetch_batches)
    return dataset

parsed_train = load_dataset(TRAIN_TFRECORDS)
parsed_valid = load_dataset(VALID_TFRECORDS)

# Sanity check: pull a single training batch, print its shapes and plot the
# first example's spectrogram (top) against its labels (bottom).
for feature_batch, label_batch in parsed_train.take(1):
    print(feature_batch.shape)
    n = 0
    example_features = feature_batch[n]
    example_labels = label_batch[n]
    print(example_features.shape)
    print(example_labels.shape)

    plt.figure(figsize=(20, 20))

    plt.subplot(2, 1, 1)
    # Drop the trailing size-1 axis and put frames on the x-axis for display.
    spectrogram_image = tf.transpose(tf.squeeze(example_features, axis=-1))
    plt.imshow(spectrogram_image, interpolation='nearest', aspect='auto')
    plt.gca().invert_yaxis()

    plt.subplot(2, 1, 2)
    plot_labels(tf.transpose(example_labels))

In [None]:
from tensorflow.keras import Sequential, mixed_precision, Model
from tensorflow.keras.layers import Bidirectional, LSTM, Dropout, TimeDistributed, Dense, Activation, Conv2D, BatchNormalization, Reshape, Flatten, Input, MaxPool2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from IPython.display import clear_output

# --- hyper-parameters ---
cnn_layers = 4      # number of conv blocks (two convolutions + pooling each)
cnn_filters = 32
cnn_dropout = 0.3

lstm_units = 64
lstm_layers = 3     # number of stacked bidirectional LSTM layers
lstm_dropout = 0.1

num_epochs = 200

# Compute in float16 where possible; the final activation below is kept in
# float32.
mixed_precision.set_global_policy('mixed_float16')

# Input: (frames, spectrogram bins, 1 channel) per example.
inputs = Input(shape=(segment_length, spec_feature_count, 1))

# --- convolutional front end ---
# Pooling is applied along the frequency axis only (pool_size=(1, 3)), so the
# number of frames is preserved and every frame still gets its own prediction.
cnn = inputs
for block_index in range(cnn_layers):
    cnn = Conv2D(cnn_filters, (3, 3), activation='relu', padding='same', strides=1)(cnn)
    cnn = BatchNormalization()(cnn)
    cnn = Conv2D(cnn_filters, (3, 3), activation='relu', padding='same', strides=1)(cnn)
    cnn = BatchNormalization()(cnn)
    cnn = MaxPool2D(pool_size=(1, 3))(cnn)
    cnn = Dropout(cnn_dropout)(cnn)

# Merge the remaining frequency and filter axes into one feature vector per frame.
cnn = Reshape((segment_length, -1))(cnn)

# --- recurrent back end ---
# All layers are built by the same loop; the former `input_shape` argument on
# the first LSTM was not a valid shape and is not needed in a functional model.
lstm = cnn
for layer_index in range(lstm_layers):
    lstm = Bidirectional(LSTM(units=lstm_units, return_sequences=True, activation='tanh'))(lstm)
    lstm = Dropout(lstm_dropout)(lstm)

# Independent per-frame, per-class probabilities (multi-label output).
lstm = TimeDistributed(Dense(output_classes))(lstm)
output = Activation('sigmoid', dtype='float32')(lstm)

model = Model(inputs=inputs, outputs=output)
model.summary()
model.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.001), metrics=['binary_crossentropy'])

plot_losses = PlotLosses()

early = EarlyStopping(monitor='val_loss', min_delta=0, patience=30, verbose=1, mode='auto', restore_best_weights=True)
# The datasets are already batched by load_dataset, so `batch_size` must not
# be passed to fit() as well.
save = model.fit(parsed_train, epochs=num_epochs, validation_data=parsed_valid, callbacks=[plot_losses, early])

In [None]:
TEST_FILE = '/mnt/d/drums/e-gmd-v1.0.0/drummer4/session1/1_rock_87_beat_4-4_1.wav'

def get_onsets(filepath):
    """Run the trained model on one audio file and plot predictions against labels.

    The file is segmented by `process_file`, every segment is passed through
    `model`, and the per-segment results are joined back into one sequence of
    frames. The first 5000 frames of the predictions (top) and of the
    reference labels (bottom) are plotted.

    Args:
        filepath: path of the audio file; its '.midi' label file must sit
            next to it.

    Returns:
        Tuple (all_predictions, all_labels) of arrays with shape
        (frames, output_classes).

    Raises:
        ValueError: if `process_file` produced no segment for `filepath`.
    """
    feature_segments = []
    label_segments = []

    def handle_slices(stft_slice, labels_slice):
        # (bins, frames) -> (frames, bins, 1): the layout declared by the
        # model's Input layer and produced by _parse_feature for training.
        feature_segments.append(np.expand_dims(stft_slice.T, axis=-1))
        label_segments.append(labels_slice.T)

    process_file(filepath, handle_slices)

    if not feature_segments:
        raise ValueError('no segments could be extracted from ' + str(filepath))

    predictions = model.predict(np.array(feature_segments))
    # Join the segments back together along the frame axis.
    all_predictions = np.concatenate(predictions)
    all_labels = np.concatenate(label_segments)

    plt.figure(figsize=(40, 10))
    plt.subplot(2, 1, 1)
    plot_labels(all_predictions[:5000].T)
    plt.subplot(2, 1, 2)
    plot_labels(all_labels[:5000].T)

    return all_predictions, all_labels

# Assigned so the notebook does not echo the full arrays as cell output.
test_predictions, test_labels = get_onsets(TEST_FILE)