# Lakh MIDI Dataset pre-processing

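This notebook converts the [Lakh MIDI Dataset](https://colinraffel.com/projects/lmd/) (LMD) to Magenta `NoteSequence` protocol buffers. `INPUT_DIR` is expected to point to the [`lmd_full`](http://hog.ee.columbia.edu/craffel/lmd/lmd_full.tar.gz) directory. Each converted file is stored as a pickled `NoteSequence` under `OUTPUT_DIR`, mirroring the directory layout of `INPUT_DIR`; files larger than 100 kB and files that fail to parse are skipped.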

Copyright 2020 InterDigital R&D and Télécom Paris.  
Author: Ondřej Cífka

In [1]:
import concurrent.futures as cf
import os
import pickle
import sys
import datetime
import warnings

import note_seq
import pretty_midi
from tqdm.auto import tqdm

In [2]:
# Root of the extracted `lmd_full` directory; it is walked recursively for input files.
INPUT_DIR = '../lmd_full/'
# Root directory for the pickled outputs; the layout under INPUT_DIR is mirrored here.
OUTPUT_DIR = 'data'
# Expected number of input files; used only as the `total` of the progress bars.
TOTAL_FILES = 178561

In [3]:
# Decrease pretty_midi's MAX_TICK limit to avoid running out of RAM while parsing.
# Long files will be skipped (presumably those whose tick count exceeds this
# limit -- TODO confirm against pretty_midi).
pretty_midi.pretty_midi.MAX_TICK = 1e6

In [4]:
def get_paths():
    """Lazily yield the path of every file found anywhere below ``INPUT_DIR``.

    Paths are produced in ``os.walk`` order and are prefixed with the directory
    they were found in; directories themselves are never yielded.
    """
    for root, _subdirs, names in os.walk(INPUT_DIR):
        yield from (os.path.join(root, name) for name in names)

In [5]:
def process_file(path, max_size=100000):
    """Convert one MIDI file to a pickled ``NoteSequence`` under ``OUTPUT_DIR``.

    The output path mirrors the location of ``path`` relative to ``INPUT_DIR``,
    with the file extension replaced by ``.pickle``. Missing output
    directories are created on demand.

    Args:
        path: Path of the input file, located under ``INPUT_DIR``.
        max_size: Size limit in bytes. Files larger than this are skipped
            without being parsed.

    Returns:
        A tuple ``(out_path, total_time)`` with the path of the written pickle
        and the ``total_time`` of the converted sequence, or ``(None, 0)`` if
        the file was skipped because it is too large or could not be converted.
    """
    # Cheap size check first, so oversized files are never handed to the parser.
    if os.stat(path).st_size > max_size:
        return None, 0

    try:
        with warnings.catch_warnings():
            # Silence only this specific warning; the filter is restored on exit.
            warnings.filterwarnings('ignore', r'Tempo, Key or Time signature change events found on non-zero tracks')
            ns = note_seq.midi_io.midi_file_to_note_sequence(path)
    except note_seq.midi_io.MIDIConversionError:
        # Unparseable file: report it as skipped rather than aborting the whole run.
        return None, 0

    out_path = os.path.splitext(path)[0] + '.pickle'
    out_path = os.path.join(OUTPUT_DIR, os.path.relpath(out_path, INPUT_DIR))
    # exist_ok=True because other worker processes may create the same directory.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, 'wb') as f:
        pickle.dump(ns, f)
    return out_path, ns.total_time

In [6]:
# Create the output root. exist_ok=True keeps the cell re-runnable (a plain
# makedirs raises FileExistsError once the directory exists) and matches the
# per-file makedirs call in process_file.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Convert all files in parallel worker processes. Each element of `results` is
# the (out_path, total_time) tuple returned by process_file, in input order.
with cf.ProcessPoolExecutor(20) as pool:
    paths = tqdm(get_paths(), desc='collect', total=TOTAL_FILES)
    results = list(tqdm(
        pool.map(process_file, paths, chunksize=100),
        desc='convert', total=TOTAL_FILES))

# Skipped/failed files are reported as (None, 0), so they add nothing to the total time.
print(sum(1 for p, _ in results if p is not None), '/', len(results), 'files converted successfully')
print('Total time:', datetime.timedelta(seconds=sum(t for _, t in results)))

HBox(children=(FloatProgress(value=0.0, description='collect', max=178561.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='convert', max=178561.0, style=ProgressStyle(description_w…


169556 / 178561 files converted successfully
Total time: 362 days, 3:02:38.526111
