In [None]:
import random
import time

import pandas as pd
import tensorflow as tf
from matplotlib import pyplot as plt
import sys
sys.path.append('..')
from data.load_data import *
from processing.utils import *

#print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

# Seed every random number generator the notebook relies on, so that a
# Restart & Run All reproduces the same results.
seed = 2022
# Python's own `random` module is used further down for the starter window,
# the note lengths and the note velocities, so it has to be seeded as well.
random.seed(seed)
tf.random.set_seed(seed)
np.random.seed(seed)

In [None]:
def read_midi(midi_path):
    """Convert one MIDI file into a list of events.

    Runs the project helpers in sequence: read note and tempo items, quantize
    the notes, extract chords from the quantized notes, group all items up to
    the end time of the last note, and convert the groups to events.
    """
    raw_notes, tempos = read_items(midi_path)
    notes = quantize_items(raw_notes)
    # The end of the final note is used as the upper time bound for grouping.
    last_end = notes[-1].end
    chords = extract_chords(notes)
    grouped = group_items(chords + tempos + notes, last_end)
    return item2event(grouped)

def transform_midi(midi_paths):
    """Read every MIDI file and return all events as one flat object array.

    Events from the files are concatenated in the order of ``midi_paths``.
    """
    collected = []
    for midi_path in midi_paths:
        # Append this file's events to the running list, preserving order.
        collected.extend(read_midi(midi_path))
    return np.array(collected, dtype=object)

def transform_midi_alt(midi_paths):
    """Read every MIDI file and return all events as a DataFrame.

    Each event contributes one row, built from its ``to_dict()`` result.
    Events are gathered in the order of ``midi_paths``.
    """
    gathered = []
    for midi_path in midi_paths:
        gathered.extend(read_midi(midi_path))
    rows = [event.to_dict() for event in gathered]
    return pd.DataFrame.from_records(rows)

In [None]:
# Load the training data: collect the files of the "MOZART_SMALL" dataset
# (`get_all_files` is a project helper; presumably it returns MIDI paths --
# TODO confirm) and convert them into one flat object array of events.
print("Loading data...")
midi_paths = get_all_files(dataset_name="MOZART_SMALL")
dataset = transform_midi(midi_paths=midi_paths)

In [None]:
# `dataset` is a NumPy object array (see `transform_midi`), which has no
# `.head()`. Build a DataFrame from the events' dict form for inspection,
# the same way `transform_midi_alt` does.
df = pd.DataFrame.from_records([event.to_dict() for event in dataset])
df.head()

In [None]:
# Wrap the event array in a tf.data pipeline; every entry of `dataset`
# becomes one element of `notes_dataset`.
# NOTE(review): `dataset` is created with dtype=object in `transform_midi`;
# confirm that TensorFlow can convert its elements to tensors.
print("Creating tensorflow dataset...")
notes_dataset = tf.data.Dataset.from_tensor_slices(dataset)
print(f">> {notes_dataset.element_spec}")

In [None]:
def create_sequences(dataset: tf.data.Dataset, seq_length: int, vocab_size=128) -> tf.data.Dataset:
    """Build (input sequence, label) training examples from a flat dataset.

    A window of ``seq_length + 1`` elements slides over ``dataset`` one
    element at a time. The first ``seq_length`` elements of each window,
    divided by ``vocab_size``, form the input; the last element supplies the
    label, returned as a dict keyed by ``"pitch"``.
    """
    label_keys = ["pitch"]
    # One extra element per window: it becomes the label.
    window_size = seq_length + 1

    windowed = dataset.window(window_size, shift=1, stride=1, drop_remainder=True)

    def to_tensor(window):
        # `window` yields a dataset of datasets; batching each sub-dataset and
        # flat-mapping the result gives a dataset of tensors.
        return window.batch(window_size, drop_remainder=True)

    flat_sequences = windowed.flat_map(to_tensor)

    def to_example(sequence):
        """Split one window into normalized inputs and a label dict."""
        target = sequence[-1]
        labels = {key: target[position] for position, key in enumerate(label_keys)}
        # Normalize the inputs by the vocabulary size.
        normalized_inputs = sequence[:-1] / vocab_size
        return normalized_inputs, labels

    return flat_sequences.map(to_example, num_parallel_calls=tf.data.AUTOTUNE)


# Number of elements fed to the model per example; `create_sequences` takes
# one additional element per window to use as the label.
seq_length = 64
vocab_size = 128  # range of pitches supported in pretty_midi
sequence_dataset = create_sequences(notes_dataset, seq_length, vocab_size)
print(sequence_dataset.element_spec)

In [None]:
# Sanity check: look at a single (input sequence, target) example.
for seq, target in sequence_dataset.take(1):
    print('sequence shape:', seq.shape)
    print('sequence elements (first 5):', seq[0: 5])
    print('target:', target)

In [None]:
batch_size = 1024
# Size the shuffle buffer from the loaded event array: a window of
# `seq_length + 1` sliding by one over N elements yields N - seq_length
# examples, so this buffer spans the whole sequence dataset.
buffer_size = len(dataset) - seq_length  # the number of items in the dataset
train_dataset = (sequence_dataset
                 .shuffle(buffer_size)
                 .batch(batch_size, drop_remainder=True)
                 .cache()
                 .prefetch(tf.data.AUTOTUNE))

In [None]:
# Model: one LSTM layer over a window of `seq_length` single-feature steps,
# followed by a dense layer and one logit per pitch class.
input_shape = (seq_length, 1)
learning_rate = 0.005

inputs = tf.keras.Input(input_shape)
x = tf.keras.layers.LSTM(512)(inputs)
x = tf.keras.layers.Dense(1024)(x)

# The output width follows `vocab_size` so the number of logits always matches
# the pitch range used to normalize the inputs.
outputs = {
    "pitch": tf.keras.layers.Dense(vocab_size, name="pitch")(x)
}

model = tf.keras.Model(inputs, outputs)

# The "pitch" head emits raw logits, hence `from_logits=True`.
loss = {
    "pitch": tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
}

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

model.compile(loss=loss, optimizer=optimizer)

model.summary()

In [None]:
# Loss of the still-untrained model, as a baseline for the training curve.
model.evaluate(train_dataset, return_dict=True)

In [None]:
# Training callbacks:
#  - ModelCheckpoint writes the weights only (not the full model) to a file
#    whose name contains the epoch number.
#    NOTE(review): some Keras versions require a specific file suffix when
#    `save_weights_only=True` -- confirm against the installed version.
#  - EarlyStopping watches the training loss, stops once it has not improved
#    for 5 epochs, and restores the best weights seen.
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        filepath="./training_checkpoints/ckpt_{epoch}",
        save_weights_only=True),
    tf.keras.callbacks.EarlyStopping(
        monitor='loss',
        patience=5,
        verbose=1,
        restore_best_weights=True)
]

In [None]:
# Train for at most `epochs` epochs; the EarlyStopping callback may end the
# run sooner. No validation data is passed, so only training loss is tracked.
epochs = 50
history = model.fit(
    train_dataset,
    epochs=epochs,
    callbacks=callbacks,
)

In [None]:
# Training loss per epoch.
plt.plot(history.epoch, history.history['loss'], label='total loss')
plt.show()

In [None]:
def predict_next_note(notes: np.ndarray, model: tf.keras.Model, temperature=1.0) -> int:
    """Sample the next note ID from a trained sequence model.

    Args:
        notes: Input window of notes, without a batch dimension.
        model: Model whose prediction is a dict with a ``'pitch'`` logits entry.
        temperature: Positive divisor applied to the logits before sampling;
            larger values flatten the sampling distribution.

    Returns:
        The sampled pitch as a Python ``int``.
    """
    assert temperature > 0
    # The model expects a leading batch axis.
    batched = tf.expand_dims(notes, 0)
    logits = model.predict(batched)['pitch']
    logits /= temperature
    # Draw one sample from the categorical distribution given by the logits.
    sampled = tf.random.categorical(logits, num_samples=1)
    return int(tf.squeeze(sampled, axis=-1))

In [None]:
temperature = 3.0
num_predictions = 500

# Pick a random starter window that is guaranteed to lie inside the dataset:
# the bounds come from the dataset length and `seq_length` rather than
# fixed numbers.
a = random.randint(0, max(0, len(dataset) - seq_length))
print(f"Using starter notes from {a} to {a + seq_length}...")
sample_notes = np.stack(dataset[a:a + seq_length])
# Seed window, with pitch normalized the same way as the training inputs.
input_notes = (sample_notes[:seq_length] / np.array([vocab_size]))

generated_notes = []
prev_start = 0
for _ in range(num_predictions):
    pitch = predict_next_note(input_notes, model, temperature)
    # Timing is not predicted by the model: each note starts 0.1 after the
    # previous one and gets a random length below 1.
    start = prev_start + 0.1
    end = start + random.random()
    # Store the raw pitch for MIDI output.
    generated_notes.append((pitch, start, end))
    # Slide the window: drop the oldest note and feed the prediction back in,
    # normalized by `vocab_size` so it is on the same scale as the rest of
    # the window and as the training inputs.
    input_note = (pitch / vocab_size,)
    input_notes = np.delete(input_notes, 0, axis=0)
    input_notes = np.append(input_notes, np.expand_dims(input_note, axis=0), axis=0)
    prev_start = start

generated_notes = pd.DataFrame(generated_notes, columns=("pitch", "start", "end"))

In [None]:
def notes_to_midi(notes: pd.DataFrame, out_file: str, instrument_name="Acoustic Grand Piano") -> pretty_midi.PrettyMIDI:
    """Write a table of notes to a single-instrument MIDI file.

    Args:
        notes: DataFrame with ``pitch``, ``start`` and ``end`` columns, one
            row per note.
        out_file: Path the MIDI file is written to.
        instrument_name: Instrument name, resolved to a MIDI program via
            ``pretty_midi.instrument_name_to_program``.

    Returns:
        The ``pretty_midi.PrettyMIDI`` object that was written to ``out_file``.
    """
    pm = pretty_midi.PrettyMIDI()
    instrument = pretty_midi.Instrument(program=pretty_midi.instrument_name_to_program(instrument_name))

    for _, note in notes.iterrows():
        instrument.notes.append(pretty_midi.Note(
            # Velocity is not part of `notes`; draw one at random per note.
            velocity=random.randint(80, 120),
            pitch=int(note["pitch"]),
            start=note["start"],
            end=note["end"],
        ))

    pm.instruments.append(instrument)
    pm.write(out_file)
    return pm

In [None]:
# Name the output file after the current Unix time in whole seconds.
# NOTE(review): assumes the ../output directory already exists -- confirm.
out_file = f"../output/{int(time.time())}.mid"
out_pm = notes_to_midi(generated_notes, out_file=out_file)

In [None]:
def plot_piano_roll(notes: pd.DataFrame, count=None):
    """Draw a piano-roll style plot of ``notes``.

    Args:
        notes: DataFrame with ``pitch``, ``start`` and ``end`` columns.
        count: Number of leading notes to draw; a falsy value draws them all.
    """
    if not count:
        title = 'Whole track'
        count = len(notes['pitch'])
    else:
        title = f'First {count} notes'

    plt.figure(figsize=(20, 4))

    # Each note is a horizontal segment from (start, pitch) to (end, pitch).
    pitches = notes['pitch']
    segment_y = np.stack([pitches, pitches], axis=0)
    segment_x = np.stack([notes['start'], notes['end']], axis=0)
    plt.plot(segment_x[:, :count], segment_y[:, :count], color="b", marker=".")

    plt.xlabel('Time [s]')
    plt.ylabel('Pitch')
    _ = plt.title(title)


# Plot the complete generated track (no `count`, so every note is drawn).
plot_piano_roll(generated_notes)