In [1]:
import pandas as pd
import tensorflow as tf
from matplotlib import pyplot as plt
import sys
sys.path.append('..')
from data.load_data import *
from processing.utils import *
from NotezartTransformer import NotezartTransformer

import pickle
from pathlib import Path

#print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

# Seed TensorFlow's and NumPy's global random generators with one shared value.
# NOTE(review): `np` is never imported explicitly in this cell; presumably it is
# brought in by one of the star imports above — TODO confirm, or add
# `import numpy as np` to the import list.
seed = 2022
tf.random.set_seed(seed)
np.random.seed(seed)

# Absolute form of the relative folder below (resolved against the current working
# directory). Later cells build their dictionary/data/checkpoint/output paths under it.
checkpoint_path = Path('resource/gen4/v2').absolute()

In [2]:
def read_midi(midi_path):
    """Read one MIDI file and return its events as an object array.

    The file is turned into note and tempo items, the note items are quantized,
    chord items are extracted from the quantized notes, and chord + tempo + note
    items are grouped (using the ``end`` of the last quantized note as the
    maximum time) before being converted to events.

    Args:
        midi_path: path of the MIDI file, passed straight to ``read_items``.

    Returns:
        ``np.ndarray`` with ``dtype=object`` built from the events produced by
        ``item2event``.

    Raises:
        IndexError: if quantization yields no note items.
    """
    notes, tempos = read_items(midi_path)
    notes = quantize_items(notes)
    # Taken before chord extraction, exactly like the step order this replaces.
    last_note_end = notes[-1].end
    timeline = extract_chords(notes) + tempos + notes
    grouped = group_items(timeline, last_note_end)
    return np.array(item2event(grouped), dtype=object)

def transform_midi(midi_paths):
    """Read several MIDI files into per-file event sequences.

    Files that fail to load are reported on stdout (with the error that caused
    the failure) and skipped, so both returned collections only hold entries
    for files that loaded successfully.

    Args:
        midi_paths: iterable of MIDI file paths, each handed to ``read_midi``.

    Returns:
        A pair ``(all_events, events)``. ``all_events`` is a list with the value
        returned by ``read_midi`` for each loaded file. ``events`` is a 1-D
        object array whose i-th entry is the array of ``to_key()`` values of the
        i-th loaded file.
    """
    all_events = []
    key_sequences = []
    for path in midi_paths:
        try:
            midi = read_midi(path)
            keys = np.asarray([e.to_key() for e in midi])
        except Exception as exc:
            # Only ordinary errors are skipped; KeyboardInterrupt / SystemExit
            # must still be able to stop a long import run.
            print(f"Failed: {path} ({exc!r})")
            continue
        key_sequences.append(keys)
        all_events.append(midi)

    # Fill a pre-sized object array element by element: np.asarray(list, dtype=object)
    # would collapse equal-length sequences into a 2-D array instead of an array
    # of sequences.
    events = np.empty(len(key_sequences), dtype=object)
    for index, keys in enumerate(key_sequences):
        events[index] = keys
    return all_events, events

def build_lookup(midi_paths, dictionary_path):
    """Build the event <-> word lookup tables for a set of MIDI files and pickle them.

    Every distinct event key found by ``transform_midi`` gets an integer id,
    assigned in the order returned by ``np.unique`` (sorted). The two mappings
    are pickled together as the list ``[event2word, word2event]``.

    Args:
        midi_paths: MIDI file paths, forwarded to ``transform_midi``.
        dictionary_path: destination of the pickle file. Its parent directory is
            created when it does not exist yet.

    Returns:
        The ``all_events`` list returned by ``transform_midi``.
    """
    all_events, events = transform_midi(midi_paths=midi_paths)

    unique_events = np.unique([e for s in events for e in s])
    event2word = {event: index for index, event in enumerate(unique_events)}
    word2event = {index: event for event, index in event2word.items()}

    # A fresh checkpoint tree has no dictionary folder yet; without this the
    # open() below raises FileNotFoundError after all files were already parsed.
    Path(dictionary_path).parent.mkdir(parents=True, exist_ok=True)
    with open(dictionary_path, 'wb') as handle:
        pickle.dump([event2word, word2event], handle, protocol=pickle.HIGHEST_PROTOCOL)

    return all_events

In [3]:
def train_model(dataset_name, epochs=100):
    """Build the lookup dictionary for a dataset, prepare its training data and fine-tune.

    Steps, in order: list the dataset's MIDI files, build and pickle the lookup
    dictionary, create a ``NotezartTransformer`` in training mode and load it,
    prepare the training data from the parsed events, pickle that data, then
    call ``finetune``.

    Args:
        dataset_name: dataset identifier forwarded to ``get_all_files``.
        epochs: number of epochs passed to ``finetune`` (defaults to 100).

    Returns:
        The ``NotezartTransformer`` instance after ``finetune`` returned.
    """
    print("Loading data...")
    dictionary_path = f"{checkpoint_path}/dictionary/dictionary.pkl"
    training_set_path = f"{checkpoint_path}/data/training_set.pkl"
    # NOTE(review): checkpoints are written under "model" here, while the later
    # notebook cells use the "checkpoints" folder — confirm which one is intended.
    output_checkpoint_folder = f"{checkpoint_path}/model" # your decision

    midi_paths = get_all_files(dataset_name=dataset_name)

    # Build lookup dictionaries before the model is created.
    # NOTE(review): presumably the model reads this dictionary when loading — confirm.
    all_events = build_lookup(midi_paths=midi_paths, dictionary_path=dictionary_path)
    model = NotezartTransformer(checkpoint=checkpoint_path, is_training=True)
    model.load_model()

    training_data = model.prepare_data(all_events=all_events)

    # Create the data folder on first use so the dump cannot fail with
    # FileNotFoundError after the (expensive) preparation step.
    Path(training_set_path).parent.mkdir(parents=True, exist_ok=True)
    with open(training_set_path, 'wb') as handle:
        pickle.dump(training_data, handle, protocol=pickle.HIGHEST_PROTOCOL)

    model.finetune(epochs=epochs, training_data=training_data, output_checkpoint_folder=output_checkpoint_folder)

    return model
    

In [None]:
# NOTE(review): this binds the `train_model` function object itself to `model`; the
# function is not called, so nothing is trained in this cell. Presumably
# `train_model(<dataset name>)` was intended — confirm which dataset to pass.
model = train_model


In [None]:
from NotezartTransformer import NotezartTransformer

# Training-mode alternative, left disabled:
#model = NotezartTransformer(checkpoint=checkpoint_path, is_training=True)
#model.load_model()

# Create the model with is_training=False and load the saved checkpoint "model-027"
# from the "checkpoints" folder under checkpoint_path.
# NOTE(review): train_model() writes its checkpoints under "<checkpoint_path>/model",
# not "checkpoints" — confirm that model-027 exists at this location.
model = NotezartTransformer(checkpoint=checkpoint_path, is_training=False)
model.load_model(existing_model=f"{checkpoint_path}/checkpoints/model-027")

In [None]:
training_set_path = f"{checkpoint_path}/data/training_set.pkl"

def get_data(training_set_path):
    """Prepare training data with the notebook-level ``model`` and pickle it.

    Args:
        training_set_path: destination file for the pickled training data.

    Returns:
        The object returned by ``model.prepare_data``.
    """
    # NOTE(review): `midi_paths` is read from notebook scope, but the visible cells
    # only define it as a local inside train_model(); on a fresh kernel this raises
    # NameError unless it is defined elsewhere. train_model() also calls
    # prepare_data(all_events=...) by keyword — confirm which argument is expected.
    training_data = model.prepare_data(midi_paths)

    with open(training_set_path, 'wb') as handle:
        pickle.dump(training_data, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return training_data

training_data = get_data(training_set_path)

# Read the data back from the file that was just written. The context manager
# closes the file handle, which a bare pickle.load(open(...)) leaves open.
with open(training_set_path, 'rb') as training_set_file:
    training_data = pickle.load(training_set_file)

In [None]:
# Folder handed to finetune() as output_checkpoint_folder when the call below is enabled.
output_checkpoint_folder = f"{checkpoint_path}/checkpoints" # your decision

# Fine-tuning is disabled in this cell; uncomment to run finetune() with epochs=100
# on training_data.
#model.finetune(epochs=100, training_data=training_data, output_checkpoint_folder=output_checkpoint_folder)

In [None]:
# Call generate() on the loaded model without a prompt (prompt=None).
# output_path points at output/sample.midi under checkpoint_path; presumably the
# generated piece is written there — confirm against NotezartTransformer.generate.
# NOTE(review): the exact meaning of n_target_bar, temperature and topk is defined by
# generate(), which is not visible in this notebook; they look like the number of bars
# to produce and the sampling temperature / top-k cutoff — TODO confirm.
model.generate(
        n_target_bar=20,
        temperature=1.2,
        topk=5,
        output_path=f"{checkpoint_path}/output/sample.midi",
        prompt=None)