In [1]:
%load_ext autoreload
%autoreload 2

In [66]:
import glob
import os
import random
from collections import Counter

import nesmdb
import numpy as np
import pretty_midi
from IPython.display import display, Audio
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm

import utils

In [78]:
# Directories for the original midi data files
MIDI_ROOT_DIR = "data/nesmdb_midi"
MIDI_TRAIN_DIR = f"{MIDI_ROOT_DIR}/train"
MIDI_VALID_DIR = f"{MIDI_ROOT_DIR}/valid"
MIDI_TEST_DIR = f"{MIDI_ROOT_DIR}/test"

# Output directories in which the converted event-based data will be stored
EVENT_ROOT_DIR = "data/nesmdb_event"
EVENT_TRAIN_DIR = f"{EVENT_ROOT_DIR}/train"
EVENT_VALID_DIR = f"{EVENT_ROOT_DIR}/valid"
EVENT_TEST_DIR = f"{EVENT_ROOT_DIR}/test"

# Output directories for the final tokenized version of the data
TOKEN_ROOT_DIR = "data/nesmdb_token"
TOKEN_TRAIN_DIR = f"{TOKEN_ROOT_DIR}/train"
TOKEN_VALID_DIR = f"{TOKEN_ROOT_DIR}/valid"
TOKEN_TEST_DIR = f"{TOKEN_ROOT_DIR}/test"

VOCAB_PATH = "data/vocab.txt"

In [5]:
for d in [
    EVENT_TRAIN_DIR, EVENT_VALID_DIR, EVENT_TEST_DIR,
    TOKEN_TRAIN_DIR, TOKEN_VALID_DIR, TOKEN_TEST_DIR,
]:
    os.makedirs(d, exist_ok=True)

## Convert MIDI files to events

First, we convert the MIDI data to an event-based format, that represents music as `NOTEON`, `NOTEOFF`, and `WT` events, which represent notes turning on, off, and time passing, respectively.

In [6]:
midi_train_files = glob.glob(f"{MIDI_TRAIN_DIR}/*.mid")
midi_valid_files = glob.glob(f"{MIDI_VALID_DIR}/*.mid")
midi_test_files = glob.glob(f"{MIDI_TEST_DIR}/*.mid")

print(f"Train set files: {len(midi_train_files)}")
print(f"Validation set files: {len(midi_valid_files)}")
print(f"Test set files: {len(midi_test_files)}")

Train set files: 4502
Validation set files: 403
Test set files: 373


In [None]:
for file in tqdm(midi_train_files):
    prepro.convert_midi_file(file, outdir=EVENT_TRAIN_DIR)

for file in tqdm(midi_valid_files):
    prepro.convert_midi_file(file, outdir=EVENT_VALID_DIR)

for file in tqdm(midi_test_files):
    prepro.convert_midi_file(file, outdir=EVENT_TEST_DIR)

## Convert event files to tokens

Next, we tokenize the event-based data by quantizing each wait time event and adding the special `BOS` and `EOS` tokens to represent the start and end of the sequence, respectively.

In [9]:
event_train_files = glob.glob(f"{EVENT_TRAIN_DIR}/*.txt")
event_valid_files = glob.glob(f"{EVENT_VALID_DIR}/*.txt")
event_test_files = glob.glob(f"{EVENT_TEST_DIR}/*.txt")

event_files = event_train_files + event_valid_files + event_test_files

In [86]:
vocab = utils.load_data(VOCAB_PATH)

In [102]:
def wt_to_int(wt_token):
    """
    Converts "WT_23" to 23
    """
    return int(wt_token[3:])


wait_times = [wt_to_int(t) for t in vocab if "WT" in t]


def quantize_wait_event(wt_token):
    wait_time = wt_to_int(wt_token)
    quant_wait_time = min(wait_times, key=lambda x: abs(x - wait_time))
    quant_wait_event = f"WT_{quant_wait_time}"
    return quant_wait_event


def convert_event_file(infile, outdir):
    """Convert an event data file to a token file.
    
    All wait events are quantized, and the BOS and EOS tokens are
    added at the start and end of the sequence, respectively.
    """
    event_data = utils.load_data(infile)

    token_data = [
        quantize_wait_event(t) if "WT" in t else t
        for t in event_data
    ]
    token_data.insert(0, "BOS")
    token_data.append("EOS")

    _, fname = os.path.split(infile)
    outfile = os.path.join(outdir, fname)
    with open(outfile, "w") as f:
        f.write("\n".join(token_data))

In [104]:
# for file in tqdm(event_train_files):
#     convert_event_file(file, outdir=TOKEN_TRAIN_DIR)

for file in tqdm(event_valid_files):
    convert_event_file(file, outdir=TOKEN_VALID_DIR)

for file in tqdm(event_test_files):
    convert_event_file(file, outdir=TOKEN_TEST_DIR)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=403.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=373.0), HTML(value='')))


