In [34]:
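# Install dependencies. Use %pip (not !pip) so packages land in this kernel's environment, and pin versions for reproducibility.
# %pip install miditoolkit
# %pip install magenta tensorflow-probability  # magenta imports tensorflow_probability at startup (see the traceback further down)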

In [35]:
import os
import random
from miditoolkit import MidiFile
import matplotlib.pyplot as plt
from collections import Counter
import subprocess
import datetime

### Data Exploration

In [5]:
midi_dirpath = 'nesmdb_midi/'
midi_train_dirpath = os.path.join(midi_dirpath, 'train')
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort so that
# downstream random sampling is reproducible. Keep only MIDI files so stray entries
# (e.g. macOS .DS_Store) are never handed to the MIDI parser.
midi_train_filesnames = sorted(
    fname for fname in os.listdir(midi_train_dirpath)
    if fname.lower().endswith(('.mid', '.midi'))
)
midi_train_filepaths = [os.path.join(midi_train_dirpath, filename) for filename in midi_train_filesnames]

In [6]:
def print_progress_bar(iteration, total, prefix='', length=50):
    """Draw a one-line text progress bar that redraws itself in place.

    The bar is printed with a leading and trailing carriage return, so repeated
    calls overwrite the same line. A newline is emitted once ``iteration``
    equals ``total``, so subsequent output starts on a fresh line.

    Args:
        iteration: Number of items completed so far.
        total: Total number of items. A non-positive total (e.g. an empty
            directory) is drawn as a full bar instead of raising
            ``ZeroDivisionError``.
        prefix: Label printed before the bar.
        length: Width of the bar, in characters.
    """
    if total > 0:
        percent = "{0:.1f}".format(100 * (iteration / float(total)))
        # Clamp so an overshooting counter never draws a bar wider than `length`.
        done = min(max(iteration, 0), total)
        filled_length = int(length * done // total)
    else:
        percent = "100.0"
        filled_length = length
    bar = '█' * filled_length + '-' * (length - filled_length)
    print(f'\r{prefix} |{bar}| {percent}% Complete', end='\r', flush=True)
    if iteration == total:
        print()

In [8]:
def to_pretty_midi(filepaths: list, num_samples: int = 100, seed=None) -> list:
    """Randomly sample MIDI files and parse each one with miditoolkit.

    Despite the name, the returned objects are ``miditoolkit.MidiFile``
    instances (not ``pretty_midi.PrettyMIDI``).

    Args:
        filepaths: Paths of candidate MIDI files.
        num_samples: How many files to sample. Values larger than
            ``len(filepaths)`` are clamped, so the whole list is loaded
            (in shuffled order) instead of raising ``ValueError``.
        seed: Optional seed for reproducible sampling. ``None`` uses the
            global ``random`` state.

    Returns:
        Parsed ``MidiFile`` objects in sampled order. Files that fail to parse
        are reported and skipped, so the list may be shorter than
        ``num_samples``.
    """
    num_samples = min(num_samples, len(filepaths))
    # A private Random instance keeps seeded sampling from disturbing global state.
    sampler = random if seed is None else random.Random(seed)
    sampled_filepaths = sampler.sample(filepaths, num_samples)
    midis = []
    for i, filepath in enumerate(sampled_filepaths, start=1):
        print_progress_bar(i, num_samples, prefix='Converting MIDI files to PrettyMIDI')
        try:
            midis.append(MidiFile(filepath))
        except Exception as e:
            # Best effort: one corrupt file should not abort the whole load.
            print(f"Error processing {filepath}: {e}")
    return midis

In [9]:
# Parse the whole training split: sampling len(filepaths) items loads every file, just in shuffled order.
midis = to_pretty_midi(
    filepaths=midi_train_filepaths,
    num_samples=len(midi_train_filepaths),
)


Converting MIDI files to PrettyMIDI |██████████████████████████████████████████████████| 100.0% Complete


In [28]:
# Flatten every instrument track's program number once; the per-program counts and
# the unique-program set are both derived from this single pass.
all_programs = [int(instrument.program) for midi in midis for instrument in midi.instruments]
# How many instrument tracks each file has.
instrument_count_distribution = Counter(len(midi.instruments) for midi in midis)
# How many tracks (across all files) use each program.
unique_instruments_distribution = Counter(all_programs)
unique_instruments = set(unique_instruments_distribution)
# Which combination of programs each file uses (order-independent).
instrument_sets_distribution = Counter(
    tuple(sorted(int(instrument.program) for instrument in midi.instruments)) for midi in midis
)
print("Instrument count distribution:", dict(instrument_count_distribution))
print("Unique instruments:", unique_instruments)
print("Unique instruments distribution:", dict(unique_instruments_distribution))
print("Instrument set distribution:", dict(instrument_sets_distribution))

Instrument count distribution: {4: 2439, 3: 1509, 2: 450, 1: 102, 0: 2}
Unique instruments: {80, 81, 38, 121}
Unique instruments distribution: {80: 4334, 81: 4259, 38: 4045, 121: 2647}
Instrument set distribution: {(38, 80, 81, 121): 2439, (38, 80, 81): 1327, (80, 81): 277, (38, 81): 58, (38, 80, 121): 50, (38, 81, 121): 37, (80, 81, 121): 95, (80,): 41, (38, 80): 97, (38, 121): 8, (80, 121): 8, (81,): 24, (81, 121): 2, (38,): 29, (121,): 8, (): 2}


### LSTM Model (Magenta Performance RNN fine-tuning)

In [None]:
def filter_program_81(input_root, output_root, program=81, splits=('train', 'valid', 'test')):
    """Write copies of MIDI files that keep only the instruments with one program number.

    Intended as a starting point for training performance-rnn, which is suited
    to single-instrument MIDI files. By default only program 81 is kept.

    For each split sub-directory of ``input_root``, every ``.mid``/``.midi``
    file is parsed. If it has at least one instrument whose ``program`` equals
    ``program``, a copy containing only those instruments is written under the
    same file name to ``output_root/<split>``. Files with no matching
    instrument are skipped.

    Args:
        input_root: Directory containing one sub-directory per split.
        output_root: Destination root. Split sub-directories are created as needed.
        program: MIDI program number to keep.
        splits: Names of the split sub-directories to process. Missing split
            directories are reported and skipped.
    """
    for split in splits:
        input_dir = os.path.join(input_root, split)
        if not os.path.isdir(input_dir):
            print(f"Skipping missing split directory: {input_dir}")
            continue
        output_dir = os.path.join(output_root, split)
        os.makedirs(output_dir, exist_ok=True)

        # List the directory once, sorted for a deterministic processing order.
        fnames = sorted(os.listdir(input_dir))
        for i, fname in enumerate(fnames, start=1):
            print_progress_bar(i, len(fnames), prefix=f'Processing {split} MIDI files')
            if not fname.lower().endswith(('.mid', '.midi')):
                continue
            path = os.path.join(input_dir, fname)
            try:
                midi = MidiFile(path)
            except Exception as e:
                # Best effort: one corrupt file should not abort the whole split.
                print(f"Error processing {path}: {e}")
                continue
            filtered_instr = [inst for inst in midi.instruments if inst.program == program]

            if filtered_instr:
                midi.instruments = filtered_instr
                midi.dump(os.path.join(output_dir, fname))

# The output root must match the --input_dir later passed to convert_dir_to_note_sequences ('nesmdb_81_midi/train').
# filter_program_81(midi_dirpath, 'nesmdb_81_midi')

Processing train MIDI files |██████████████████████████████████████████████████| 100.0% Complete
Processing valid MIDI files |██████████████████████████████████████████████████| 100.0% Complete
Processing test MIDI files |██████████████████████████████████████████████████| 100.0% Complete


In [37]:
def midi_to_note_sequence(midi_dir, output_file='nesmdb_81_notesequences.tfrecord'):
    """Convert a directory of MIDI files into a NoteSequence TFRecord with Magenta.

    Runs Magenta's ``convert_dir_to_note_sequences`` console script, which must
    be on PATH. Magenta imports ``tensorflow_probability`` at startup, so that
    package must be installed too.

    Args:
        midi_dir: Directory of MIDI files to convert (searched recursively).
        output_file: Path of the TFRecord the NoteSequences are written to.

    Raises:
        FileNotFoundError: if ``convert_dir_to_note_sequences`` is not on PATH.
        subprocess.CalledProcessError: if the conversion exits non-zero.
    """
    subprocess.run([
        'convert_dir_to_note_sequences',
        f'--input_dir={midi_dir}',
        f'--output_file={output_file}',
        '--recursive'
    ], check=True)

# Convert the program-81-only training split to NoteSequences (needs magenta + tensorflow_probability installed).
midi_to_note_sequence(midi_dir='nesmdb_81_midi/train')

Traceback (most recent call last):
  File "/Users/leofriedman/Desktop/ucsd/cse_253/cse153-group-project/.env/bin/convert_dir_to_note_sequences", line 5, in <module>
    from magenta.scripts.convert_dir_to_note_sequences import console_entry_point
  File "/Users/leofriedman/Desktop/ucsd/cse_253/cse153-group-project/.env/lib/python3.10/site-packages/magenta/__init__.py", line 17, in <module>
    import magenta.common.beam_search
  File "/Users/leofriedman/Desktop/ucsd/cse_253/cse153-group-project/.env/lib/python3.10/site-packages/magenta/common/__init__.py", line 20, in <module>
    from .nade import Nade
  File "/Users/leofriedman/Desktop/ucsd/cse_253/cse153-group-project/.env/lib/python3.10/site-packages/magenta/common/nade.py", line 24, in <module>
    import tensorflow_probability as tfp
ModuleNotFoundError: No module named 'tensorflow_probability'


CalledProcessError: Command '['convert_dir_to_note_sequences', '--input_dir=nesmdb_81_midi/train', '--output_file=nesmdb_81_notesequences.tfrecord', '--recursive']' returned non-zero exit status 1.

In [None]:
def train_performance_rnn(
    sequence_example_file='nesmdb_81_notesequences.tfrecord',
    run_root='logdir/performance_rnn_finetune',
    num_training_steps=5000,
):
    """Train Performance RNN by running Magenta's ``performance_rnn_train`` script.

    Each call writes to a fresh timestamped directory under ``run_root``, so
    earlier runs are not overwritten.

    Args:
        sequence_example_file: TFRecord passed as ``--sequence_example_file``.
            NOTE(review): Magenta's Performance RNN training expects
            SequenceExamples (the output of ``performance_rnn_create_dataset``),
            while the default here is the NoteSequence file written by
            ``convert_dir_to_note_sequences``. Confirm before training.
        run_root: Parent directory for the per-run log/checkpoint directory.
        num_training_steps: Value passed as ``--num_training_steps``.

    Raises:
        FileNotFoundError: if ``performance_rnn_train`` is not on PATH.
        subprocess.CalledProcessError: if training exits non-zero.
    """
    # str(datetime.now()) contains a space and colons ("2024-01-01 12:00:00.123456"),
    # which are awkward in paths and invalid on Windows; use a compact, unique stamp.
    timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S-%f')
    subprocess.run([
        "performance_rnn_train",
        f"--run_dir={run_root}/{timestamp}",
        f"--sequence_example_file={sequence_example_file}",
        '--hparams=batch_size=64,rnn_layer_sizes=[256,256]',
        f"--num_training_steps={num_training_steps}",
        # NOTE(review): confirm performance_rnn_train defines the next two flags (absl
        # rejects unknown flags); no --config is passed either.
        "--bundle_file=performance_with_dynamics.mag",
        "--save_checkpoints_steps=100",
        "--alsologtostderr"
    ], check=True)

# train_performance_rnn()

In [None]:
def generate_midi_from_performance_rnn(
    bundle_file='performance_with_dynamics.mag',
    output_root='performance_rnn_generated',
    num_outputs=1,
    num_steps=4000,
):
    """Sample MIDI by running Magenta's ``performance_rnn_generate`` script.

    Outputs are written to a fresh timestamped directory under ``output_root``.

    NOTE(review): this samples from ``bundle_file`` (by default the pretrained
    ``performance_with_dynamics.mag`` bundle), not from the checkpoints written
    by ``train_performance_rnn``. Point it at the fine-tuned model if that is
    the intent.

    Args:
        bundle_file: Model bundle passed as ``--bundle_file``.
        output_root: Parent directory for the per-run output directory.
        num_outputs: Number of MIDI files to generate.
        num_steps: Length of each generated sequence, in model steps.

    Raises:
        FileNotFoundError: if ``performance_rnn_generate`` is not on PATH.
        subprocess.CalledProcessError: if generation exits non-zero.
    """
    # str(datetime.now()) contains a space and colons, which are awkward in paths and
    # invalid on Windows; use a compact, unique stamp instead.
    timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S-%f')
    subprocess.run([
        "performance_rnn_generate",
        f"--bundle_file={bundle_file}",
        f"--output_dir={output_root}/{timestamp}",
        f"--num_outputs={num_outputs}",
        # NOTE(review): Performance RNN steps are, as far as I know, fixed time shifts
        # rather than beats, so qpm does not set the length; at an assumed 100
        # steps/second, 4000 steps is ~40 s. TODO confirm against the model config.
        f"--num_steps={num_steps}",
        "--hparams=batch_size=64,rnn_layer_sizes=[256,256]",
        # NOTE(review): confirm performance_rnn_generate defines the next three flags;
        # absl flag parsing fails on unknown flags.
        "--condition_on_primer=false",
        "--inject_primer_during_generation=false",
        "--instrument=81",
        "--alsologtostderr"
    ], check=True)