In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import pretty_midi

from data_preprocessing import get_train_data
from data_preprocessing import group_cqt_frames

from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.losses import BinaryCrossentropy

In [None]:
# Recording under test and its ground-truth MIDI (same file stem, different extension).
audio = "../databases/combined_database/MAPS_ENSTDkCl/MAPS_ENSTDkCl_2/ENSTDkCl/MUS/MAPS_MUS-scn15_12_ENSTDkCl.wav"
midi = "../databases/combined_database/MAPS_ENSTDkCl/MAPS_ENSTDkCl_2/ENSTDkCl/MUS/MAPS_MUS-scn15_12_ENSTDkCl.mid"

# Analysis settings handed to get_train_data.
# NOTE(review): 88 bins presumably means one bin per piano key (A0-C8) - confirm
# against data_preprocessing.
sampling_rate, hop_length, n_bins = 16000, 512, 88

# Output options handed to get_train_data (plotting off, frame-based outputs on).
show_cqt_pr = False
pr_in_frames, cqt_in_frames = True, True

# Context-window size in frames.
# NOTE(review): presumably meant for group_cqt_frames - confirm.
num_frames_before = num_frames_after = 3

# Load the CQT features and the matching labels for this recording.
cqt, labels = get_train_data(
    audio,
    sampling_rate,
    hop_length,
    n_bins,
    midi,
    show_cqt_pr,
    pr_in_frames,
    cqt_in_frames,
)

In [None]:
# Restore the trained network from its HDF5 checkpoint.
model = load_model("../saved-models/onenote-saved-models/saved_model_11.h5")

# Re-compile with an explicit optimizer, loss and metric, using the classes
# already imported at the top of the notebook.
model.compile(
    optimizer=SGD(learning_rate=0.01, momentum=0.9),
    loss=BinaryCrossentropy(),
    metrics=[tf.keras.metrics.F1Score()],
)

# Uncomment to inspect the architecture:
# model.summary()

In [None]:
# Ground-truth labels with their axes reversed.
# NOTE(review): presumably this makes them frames-first like the model output - confirm.
labels_transposed = np.transpose(labels)

# Group each CQT frame with its neighbours, using the window configured in the
# settings cell rather than repeating the numbers here, so that changing
# num_frames_before / num_frames_after actually takes effect.
grouped_cqt = group_cqt_frames(cqt, num_frames_before, num_frames_after)

# The model is fed magnitudes only.
grouped_cqt_abs = np.abs(grouped_cqt)

# All other predict() arguments are left at their Keras defaults.
predicted_cqt = model.predict(grouped_cqt_abs)

# Activations strictly above this value count as "note on".
threshold = 0.4

# Binarise the predictions: 1 where the activation exceeds the threshold, else 0.
binary_matrix = np.where(predicted_cqt > threshold, 1, 0)

In [None]:
# Shapes of the arrays as returned by get_train_data.
for raw_array in (labels, cqt):
    print(raw_array.shape)

print("--------------------------")

# Shapes after transposing, grouping and running inference.
for derived_array in (labels_transposed, grouped_cqt_abs, predicted_cqt):
    print(derived_array.shape)

In [None]:
# Draw the raw model output first, then its thresholded version, with identical
# layout so the two figures can be compared directly.
rolls_to_plot = (
    (predicted_cqt, 'Piano Roll Representation'),
    (binary_matrix, 'Binary Piano Roll Representation'),
)

for roll, roll_title in rolls_to_plot:
    fig, ax = plt.subplots(figsize=(15, 5))
    # Transpose so that frames run along x and note index along y.
    image = ax.imshow(roll.T, cmap='binary', aspect='auto', origin='lower')
    ax.set_xlabel('Frames')
    ax.set_ylabel('MIDI Note Index (A0 - C8)')
    ax.set_title(roll_title)
    fig.colorbar(image, ax=ax, label='Note Presence (1=On, 0=Off)')
    plt.show()

In [None]:
# Convert a binary piano roll into a PrettyMIDI object.
def piano_roll_to_pretty_midi(piano_roll, fs=100, program=0, velocity=100, lowest_pitch=21):
    """Convert a piano-roll array into a PrettyMIDI object with a single instrument.

    Parameters
    ----------
    piano_roll : array-like, shape (notes, frames)
        One row per pitch, one column per frame. Any value greater than zero
        is treated as "note on".
    fs : float
        Frame rate of the roll in frames per second; frame ``i`` starts at
        ``i / fs`` seconds.
    program : int
        MIDI program number given to the instrument.
    velocity : int
        Velocity assigned to every generated note.
    lowest_pitch : int
        MIDI pitch represented by row 0 (21 is A0, the lowest piano key).

    Returns
    -------
    pretty_midi.PrettyMIDI
        Object holding one instrument with one note per run of consecutive
        active frames in each row.

    Raises
    ------
    ValueError
        If ``piano_roll`` is not 2-D, if ``fs`` is not positive, or if the
        rows would map to pitches outside the MIDI range 0-127 (the usual
        cause is passing a (frames, notes) array without transposing it).
    """
    roll = np.asarray(piano_roll)
    if roll.ndim != 2:
        raise ValueError(
            f"piano_roll must be 2-D (notes, frames), got {roll.ndim} dimension(s)"
        )
    if fs <= 0:
        raise ValueError(f"fs must be positive, got {fs}")

    n_notes = roll.shape[0]
    if lowest_pitch < 0 or lowest_pitch + n_notes - 1 > 127:
        raise ValueError(
            f"{n_notes} rows starting at pitch {lowest_pitch} fall outside the MIDI "
            "range 0-127; piano_roll must be (notes, frames) - transpose it if needed"
        )

    pm = pretty_midi.PrettyMIDI()
    instrument = pretty_midi.Instrument(program=program)

    # Binarise to a signed type so the frame-to-frame difference is exactly
    # +1 / -1 (unsigned or boolean input would otherwise wrap or fail in diff).
    active = (roll > 0).astype(np.int8)

    # Pad one column of zeros on each side so that notes already sounding in
    # the first frame, or still sounding in the last, get an onset/offset too.
    padded = np.pad(active, ((0, 0), (1, 1)), 'constant')

    # +1 marks the frame where a note starts, -1 the frame where it stops.
    changes = np.diff(padded, axis=1)

    for row in range(n_notes):
        onsets = np.flatnonzero(changes[row] == 1)
        offsets = np.flatnonzero(changes[row] == -1)
        # The padding guarantees each onset has a matching later offset, so
        # the two index arrays have equal length and pair up in order.
        for onset, offset in zip(onsets, offsets):
            instrument.notes.append(
                pretty_midi.Note(
                    velocity=velocity,
                    pitch=int(lowest_pitch + row),
                    start=float(onset / fs),
                    end=float(offset / fs),
                )
            )

    pm.instruments.append(instrument)
    return pm

In [None]:
# Frame rate of the piano roll in frames per second: one frame is produced
# every hop_length audio samples.
fs = sampling_rate / hop_length

# binary_matrix is laid out (frames, notes), while piano_roll_to_pretty_midi
# unpacks its input as (notes, frames), so transpose before converting.
midi_data = piano_roll_to_pretty_midi(binary_matrix.T, fs=fs)
midi_data.write("predicted_output.mid")