In [1]:
# Environment sanity check: import TensorFlow and report which version is installed.
# If the import raises "TypeError: Descriptors cannot be created directly", the
# installed protobuf package is too new for this TensorFlow build; per the error
# text, downgrading protobuf to 3.20.x or lower resolves it.
import tensorflow as tf

# GPU visibility can be inspected with tf.config.list_physical_devices('GPU').
print(tf.version.VERSION)

TypeError: Descriptors cannot be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
If you cannot immediately regenerate your protos, some other possible workarounds are:
 1. Downgrade the protobuf package to 3.20.x or lower.
 2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).

More information: https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates

In [None]:
import glob
import numpy as np
import tensorflow as tf
from keras import Model
from keras.layers import Input, Bidirectional, LSTM, TimeDistributed, Dense
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.optimizers import Adam

# Install 'mixed_float16' as the global Keras dtype policy. This must run before
# any layers are constructed, because layers pick up the policy at creation time.
tf.keras.mixed_precision.set_global_policy(tf.keras.mixed_precision.Policy('mixed_float16'))

def data_generator(x_paths, onset_paths):
    """
    Generator function to load and preprocess data.

    Files are consumed pairwise: the i-th entry of ``x_paths`` is loaded together
    with the i-th entry of ``onset_paths``. Each feature array is divided by its
    own maximum value, then the feature and label arrays are iterated in lockstep
    along their first axis and yielded one entry at a time.

    Args:
    x_paths (list): List of file paths for input features.
    onset_paths (list): List of file paths for onset labels.

    Yields:
    tuple: Tuple of preprocessed input features and onset labels.
    """
    for x_path, onset_path in zip(x_paths, onset_paths):
        X = np.load(x_path)
        onset = np.load(onset_path)

        # Peak-normalise per file. A file whose maximum is exactly 0 (e.g. an
        # all-zero array) would otherwise divide by zero and emit NaN/inf samples;
        # dividing by 1 instead leaves the values untouched while producing the
        # same result dtype as the regular division.
        peak = np.max(X)
        X = X / (peak if peak != 0 else 1)

        for x, onset_frame in zip(X, onset):
            yield x, onset_frame

def build_model(batch_size=10, sequence_length=100, n_features=264, n_outputs=88):
    """
    Builds and returns the LSTM model for onset detection.

    The network is a single stateful bidirectional LSTM followed by a per-timestep
    sigmoid dense layer. The model is returned uncompiled; the caller is
    responsible for calling ``compile``.

    Args:
    batch_size (int): Fixed batch size of the input. A stateful LSTM needs a known
        batch size, so every batch fed to the model must have exactly this size.
    sequence_length (int): Number of timesteps per input sequence.
    n_features (int): Number of input features per timestep.
    n_outputs (int): Number of sigmoid outputs per timestep.

    Returns:
    keras.Model: Uncompiled LSTM model mapping
        (batch_size, sequence_length, n_features) inputs to
        (batch_size, sequence_length, n_outputs) outputs.
    """
    # `shape` + `batch_size` fixes the full batch shape without relying on the
    # legacy `batch_input_shape` keyword.
    input_layer = Input(shape=(sequence_length, n_features), batch_size=batch_size, name='onset_input')

    # NOTE(review): stateful=True carries the LSTM state of sample i in one batch
    # over to sample i in the next batch and never resets it automatically.
    # Confirm the input pipeline orders batches so that this continuation is
    # meaningful, and that states are reset where sequences end.
    onset_lstm = Bidirectional(
        LSTM(128, activation='tanh', return_sequences=True, stateful=True, name='onset_lstm')
    )(input_layer)

    # Keep the output layer in float32 even when a mixed-precision global policy
    # is active: float16 sigmoid outputs feeding the loss are numerically fragile.
    onset_output = TimeDistributed(
        Dense(n_outputs, activation='sigmoid', kernel_initializer='he_normal', dtype='float32', name='onset_output'),
        dtype='float32',
    )(onset_lstm)

    return Model(inputs=input_layer, outputs=onset_output)

def _matched_paths(x_pattern, onset_pattern):
    """Expand two glob patterns into equally long, sorted feature/label path lists."""
    # glob.glob returns matches in arbitrary (filesystem) order, so two separate
    # calls give no guarantee that position i in one list corresponds to position
    # i in the other. Sorting both makes the pairing deterministic.
    # NOTE(review): this assumes feature and label files use names that sort into
    # the same order (e.g. identical basenames) - confirm against the
    # preprocessing step that writes them.
    x_paths = sorted(glob.glob(x_pattern))
    onset_paths = sorted(glob.glob(onset_pattern))

    if not x_paths:
        raise ValueError(f"No feature files match pattern {x_pattern!r}")
    if len(x_paths) != len(onset_paths):
        raise ValueError(
            f"Pattern {x_pattern!r} matched {len(x_paths)} feature files but "
            f"{onset_pattern!r} matched {len(onset_paths)} label files"
        )
    return x_paths, onset_paths


def train_model(trainX_pattern, trainOnset_pattern, validX_pattern, validOnset_pattern, epochs=2):
    """
    Trains the LSTM model using the provided training and validation data.

    Builds streaming datasets from the matched files, trains the model returned by
    ``build_model`` with binary cross-entropy and Adam, writes the checkpoint with
    the best validation loss to 'onset_detector.h5' and the final model to
    'onset_last.h5'.

    Args:
    trainX_pattern (str): File pattern for training input features.
    trainOnset_pattern (str): File pattern for training onset labels.
    validX_pattern (str): File pattern for validation input features.
    validOnset_pattern (str): File pattern for validation onset labels.
    epochs (int): Maximum number of training epochs. Defaults to 2.

    Raises:
    ValueError: If a feature pattern matches no files, or a feature pattern and
        its label pattern match a different number of files.
    """
    train_x_paths, train_onset_paths = _matched_paths(trainX_pattern, trainOnset_pattern)
    valid_x_paths, valid_onset_paths = _matched_paths(validX_pattern, validOnset_pattern)

    # Per-sample shapes emitted by the generator: (timesteps, features) inputs and
    # (timesteps, outputs) labels. These must agree with the model's input/output.
    input_signature = (
        tf.TensorSpec(shape=(100, 264), dtype=tf.float32),
        tf.TensorSpec(shape=(100, 88), dtype=tf.float32)
    )

    train_dataset = tf.data.Dataset.from_generator(lambda: data_generator(train_x_paths, train_onset_paths), output_signature=input_signature)
    valid_dataset = tf.data.Dataset.from_generator(lambda: data_generator(valid_x_paths, valid_onset_paths), output_signature=input_signature)

    # The model is built with a fixed batch size of 10, so incomplete trailing
    # batches have to be dropped rather than fed to it.
    train_dataset = train_dataset.batch(10, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
    valid_dataset = valid_dataset.batch(10, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)

    model = build_model()

    checkpoint_callback = ModelCheckpoint('onset_detector.h5', monitor='val_loss', verbose=1, save_best_only=True, mode='auto')
    # With patience=5, early stopping can only trigger when `epochs` is larger
    # than 5; at the default of 2 epochs it is effectively inactive.
    early_stopping_callback = EarlyStopping(patience=5, monitor='val_loss', verbose=1, mode='auto')

    optimizer = Adam(learning_rate=0.0005)
    model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

    model.fit(train_dataset, validation_data=valid_dataset, epochs=epochs, shuffle=False, callbacks=[checkpoint_callback, early_stopping_callback])
    model.save('onset_last.h5')

if __name__ == '__main__':
    # Launch training on the preprocessed feature/label arrays stored on disk.
    train_model(
        trainX_pattern='preprocessed/trainX/*.npy',
        trainOnset_pattern='preprocessed/trainONSET/*.npy',
        validX_pattern='preprocessed/validX/*.npy',
        validOnset_pattern='preprocessed/validONSET/*.npy',
    )
