# SuperSlakh

SuperSlakh is a dataset of songs synthesized from MIDI, for training models on music tasks.

## 1. Download Data

To start, we'll download the [Lakh MIDI Dataset](https://colinraffel.com/projects/lmd/) and some soundfonts.

The function downloads every soundfont from [this archive.org collection](https://archive.org/download/free-soundfonts-sf2-2019-04), but feel free to add more to the `soundfonts/` folder.

In total, this downloads 6 GB of MIDI files and 6.4 GB of soundfonts.

I also used:

- A couple of [St. GIGA's](http://stgiga.weebly.com/creations.html) soundfonts (the 4 GB one is great)
- [Tyroland (musical artifacts)](https://musical-artifacts.com/artifacts/1305)
- [Phil's Computer Lab](https://www.philscomputerlab.com/general-midi-and-soundfonts.html)

In [None]:
from src.download import download_midi_and_soundfonts


# Fetch the Lakh MIDI data and the archive.org soundfont collection
# (several GB; see the notes above for sizes).
download_midi_and_soundfonts()

## 2. Catalog & Split Songs

To make the data easier to work with, I'm using a SQLite database to manage Songs, Stems, Kits, and Instruments. This allows parallel reads and writes, and relates the tables as follows:  
`songs > stems <-> instruments < kits`

Now we need to read each soundfont file to get the metadata into the database.

In [None]:
"""
for each soundfont, 
    create a kit
    create instrument for each preset
"""
import fluidsynth
import os
from src.db import SQLiteClient
soundfont_dir = 'soundfonts/'

soundfonts = os.listdir(soundfont_dir)
num_presets = 0
with SQLiteClient() as client:
    fs = fluidsynth.Synth()
    fs.start()
    for file in soundfonts:
        try:
            sfid = fs.sfload(os.path.join(soundfont_dir, file))
            kit_id = client.insert_kit(file)
            for bank in range(128):
                for preset_num in range(128):
                    name = fs.sfpreset_name(sfid, bank, preset_num)
                    if name is not None:
                        num_presets += 1
                        client.insert_instrument(name, bank, preset_num, file, kit_id)
                fs.sfunload(sfid)
        except Exception as e:
            print('ERROR---', e)
        
    fs.delete()

print(len(soundfonts), num_presets)

In [None]:
from src.db import SQLiteClient

# Sanity check: how many distinct kits the cataloged instruments span,
# printed next to the total instrument count.
with SQLiteClient() as client:
    instruments = client.get_all_instruments()
    kit_ids = [inst_row['kit_id'] for inst_row in instruments]
    distinct_kit_count = len(set(kit_ids))
    print(distinct_kit_count, len(instruments))

Next, we do the same for each MIDI file.

We'll also do a bit of filtering here, we're only interested in songs that:
- are valid / readable
- are longer than 30 s and shorter than 6 min
- have a drum track

In [None]:
from src.midi.intake import extract_midi_metadata
import time
from concurrent.futures import as_completed
import random
import sqlite3
"""
for each midi file
    read w/ pretty_midi
    if valid
        create song
        create stem for each track
"""
root = 'midi/'

def process_midi_file(midi_path: str):
    try:
        with SQLiteClient() as client:
            if client.does_path_exist(midi_path):
                print('noneed')
                return
            midi_data = extract_midi_metadata(midi_path)
            if 30 < midi_data.length < 360 and midi_data.has_drum:
                inserted = False
                while not inserted:
                    try:
                        client.insert_song(midi_path, None, midi_data.downbeats, midi_data.beats, midi_data.bpm, [])
                        inserted = True
                    except sqlite3.OperationalError as e:
                        if 'database is locked' in str(e):
                            print('locked, waiting')
                            time.sleep(1)  # Improved back-off strategy might be needed
                        else:
                            return
        print(f"Processed {midi_path}")
        return
    except Exception as e:
        print(f"Error processing {midi_path}: {e}")
        return

def get_file_list(root):
    """Return a shuffled list of ``.mid`` paths one directory level below *root*.

    Expects the layout ``root/<subdir>/<file>.mid``. Entries directly under
    *root* that are not directories (stray archives, READMEs, ...) are skipped
    rather than crashing ``os.listdir``. The resulting order is randomized with
    ``random.shuffle``.
    """
    file_list = []
    for subdir in os.listdir(root):
        subdir_path = os.path.join(root, subdir)
        if not os.path.isdir(subdir_path):
            continue
        file_list.extend(
            os.path.join(subdir_path, name)
            for name in os.listdir(subdir_path)
            if name.endswith('.mid')
        )
    random.shuffle(file_list)
    return file_list


from concurrent.futures import ProcessPoolExecutor

# Catalog every MIDI file in parallel; each process_midi_file call opens its
# own SQLiteClient, so workers share nothing but the database file.
file_list = get_file_list(root)
with ProcessPoolExecutor(max_workers=14) as pool:
    futures = [pool.submit(process_midi_file, midi_path) for midi_path in file_list]
    for finished in as_completed(futures):
        finished.result()  # re-raise anything process_midi_file did not catch

## 3. Classify Instruments

The songs denote which instrument should play each track by the preset number (for the most part). 

However, soundfont preset numbers are determined by whoever made each kit, and while many conform to the General MIDI (GM) standard, many do not.

So, we need to assign a `gm_class` to each instrument, using its filename, preset name, etc. to match it to the most likely instrument class,  
e.g. `E Guitar Cln` is most likely `28, Electric Guitar (clean)`.

So, we'll do a first pass to classify all the exact matches, a second pass using string distance (levenshtein) to classify the very similar matches, and finally an LLM pass to classify the rest.

Then, we'll do a quick check over the stems, and reassign the preset number of each stem whose name is an exact match for another instrument (the default class is 1, Acoustic Grand Piano, and there are many '1's with names like 'Trumpet').

In [None]:
from classify_inst import predict_gm_class_gemini
import pandas as pd
import json
from multiprocessing import Pool

inst_class = pd.read_csv('data/midi_instrument.csv')



def process_row(row):
    """Assign a General MIDI class to one instrument row and store it in the DB.

    A case-insensitive exact name match against ``data/midi_instrument.csv``
    is resolved locally. Otherwise the row's metadata (minus ``id`` and
    ``gm_class``) is sent to ``predict_gm_class_gemini``, whose reply is parsed
    as ``"<program>,<name>"``. If the returned name does not match the returned
    program's name, the program of the returned name is used instead. Any
    failure stores ``-2`` as a sentinel class.

    Args:
        row: A pandas Series from ``inst_df.iterrows()`` containing at least
            ``id`` and ``name``.
    """
    subdict = {k: row[k] for k in row.keys() if k not in ['id', 'gm_class']}
    lower_names = inst_class['name'].str.lower()
    try:
        inst_name = subdict['name'].lower()
        if inst_name in lower_names.values:
            gm_class = inst_class.loc[lower_names == inst_name, 'program'].values[0]
        else:
            raw_class, name = predict_gm_class_gemini(json.dumps(subdict)).split(',', 1)
            # Model replies often carry stray spaces/newlines; strip before comparing.
            gm_class, name = raw_class.strip(), name.strip()
            gm_name = inst_class.loc[inst_class['program'] == int(gm_class), 'name'].values[0]
            print(f"{subdict['name']} -- {name} -- {gm_name}")
            if name.lower() != gm_name.lower():
                # Trust the name over the number when the two disagree.
                gm_class = inst_class.loc[lower_names == name.lower(), 'program'].values[0]
        with SQLiteClient() as db:
            db.update_inst_gm(row['id'], int(gm_class))
    except Exception as e:
        print(f"instrument {row['id']} ({subdict.get('name')}): {e}")
        with SQLiteClient() as db:
            db.update_inst_gm(row['id'], -2)

with SQLiteClient() as db:
    instruments = list(map(dict, db.get_all_instruments()))
    inst_df = pd.DataFrame(instruments)
    print(len(instruments))

# Classify in 4 worker processes; each process_row call writes its own result.
rows = [series for _idx, series in inst_df.iterrows()]
with Pool(4) as pool:
    pool.map(process_row, rows)

Also, many of the drum instruments are 'Misc Sound Effect Banks', so we'll do a quick listen to the kick, snare, hihat of each to make sure they're valid.

In [None]:
import time

with SQLiteClient() as client:
    instruments = [dict(row) for row in client.get_all_instruments()]
    inst_df = pd.DataFrame(instruments)

# This notebook treats banks above 126 as drum banks. Take an explicit copy:
# test_drum_presets writes flags into it with .at, and writing into a
# boolean-mask slice raises SettingWithCopyWarning.
drum_df = inst_df[inst_df['bank'] > 126].copy()

def test_drum_presets(df, driver='alsa', note_seconds=0.5):
    """Audition kick/snare/hi-hat for each unreviewed drum preset and record a verdict.

    For each row whose ``flag`` is 0, the row's soundfont is loaded from
    ``soundfonts/``, its bank/preset is selected on channel 9, and MIDI notes
    36, 38 and 42 are played in turn. The user is then prompted: ``y`` sets
    ``flag`` to 1, ``n`` sets it to -1, and any other answer stops the review.
    Rows whose soundfont fails to load are skipped.

    Args:
        df: DataFrame with ``flag``, ``sf_path``, ``bank`` and ``preset``
            columns. It is updated in place.
        driver: FluidSynth audio driver passed to ``Synth.start``.
        note_seconds: How long each test note is held, in seconds.

    Returns:
        The same *df*, with ``flag`` values updated.
    """
    fs = fluidsynth.Synth()
    fs.start(driver=driver)
    sf_root = 'soundfonts'
    try:
        for index, row in df.iterrows():
            if row['flag'] != 0:
                continue
            sfid = fs.sfload(os.path.join(sf_root, row['sf_path']))
            if sfid == -1:  # FluidSynth signals a failed load with -1
                print(f"Could not load {row['sf_path']}, skipping")
                continue
            try:
                fs.program_select(9, sfid, row['bank'], row['preset'])
                for note in (36, 38, 42):  # Kick, Snare, Hi-Hat
                    fs.noteon(9, note, 100)
                    time.sleep(note_seconds)
                    fs.noteoff(9, note)
                answer = input("Keep this preset? (y/n): ").strip().lower()
            finally:
                # Unload on every path, including the "stop" answer and Ctrl-C.
                fs.sfunload(sfid, True)

            if answer == 'y':
                df.at[index, 'flag'] = 1
            elif answer == 'n':
                df.at[index, 'flag'] = -1
            else:
                print("Stopping.")
                break
    finally:
        # Release the audio driver even if the review is interrupted.
        fs.delete()
    return df

# Review drum presets interactively, then persist the keep/drop flags.
flag_df = test_drum_presets(drum_df)
flag_df.to_csv(path_or_buf='flagged_drum.csv')

## 4. Assign Instruments to Stems

My first thought here was to get all stems & presets of each GM class and assign each preset proportionally (but randomly) to each stem. However, this is very inefficient if you need to render per-song instead of per-preset, as you would need to load a new soundfont for every instrument rather than doing a quick program change (foreshadowing).

So, instead we want to use as few kits per song as possible. The core logic is then:
- for each song:
    - pick a random kit
    - while there are unassigned stems:
        - for each stem
            - pick an unused preset in kit w/ matching class
            - if none, continue
        - if there are still unassigned stems, pick a new kit

This way, we can minimize the amount of kit-switching required by the renderer. Also, the kits tend to be 'similar in vibe' and by choosing mostly instruments from a single kit, the outputs are more sonically cohesive.

In [None]:
from src.db import SQLiteClient
import pandas as pd
import random

with SQLiteClient() as client:
    instruments = [dict(row) for row in client.get_all_instruments()]
    inst_df = pd.DataFrame(instruments)

    stems = [dict(row) for row in client.get_unassigned_stems()]
    stem_df = pd.DataFrame(stems)
    # Shift 0-based MIDI program numbers onto the 1..128 gm_class scale.
    stem_df['program'] = stem_df['program'].apply(lambda x: min(x+1, 128))

song_ids = stem_df['song_id'].unique()
kit_ids = inst_df['kit_id'].unique()
# Count presets per kit in one pass instead of re-filtering inst_df per kit.
preset_counts = inst_df['kit_id'].value_counts()
kit_to_num_presets = {kit_id: int(preset_counts.get(kit_id, 0)) for kit_id in kit_ids}

# Put X kit_ids where X = num presets in kit, so random.choice over this list
# picks a kit with probability proportional to its preset count.
scaled_kit_ids = [kit_id for kit_id in kit_ids for _ in range(kit_to_num_presets[kit_id])]

def pick_random_kit_id(used_ids):
    """Return a random kit id from ``kit_ids`` that is not in *used_ids*.

    Draws uniformly (not weighted by preset count). Returns ``None`` once every
    kit id has been used.
    """
    candidates = [kit for kit in kit_ids if kit not in used_ids]
    return random.choice(candidates) if candidates else None

# Index stems by song and presets by kit. groupby splits each frame in one
# pass; filtering the whole frame once per key is O(keys x rows), which is
# very slow for the full stem table. Keys absent from the groupby (e.g. NaN)
# map to an empty frame, matching what a boolean filter would return.
_stem_groups = dict(tuple(stem_df.groupby('song_id')))
_empty_stems = stem_df.iloc[:0]
song_to_stems = {song_id: _stem_groups.get(song_id, _empty_stems) for song_id in song_ids}

_preset_groups = dict(tuple(inst_df.groupby('kit_id')))
_empty_presets = inst_df.iloc[:0]
kit_to_presets = {kit_id: _preset_groups.get(kit_id, _empty_presets) for kit_id in kit_ids}

In [None]:
def assign_song_presets(song_id):
    """Pick an instrument preset for each unassigned stem of one song and save it.

    Kits are drawn from ``scaled_kit_ids`` (so larger kits are more likely),
    and each drawn kit fills as many remaining stems as it can before another
    kit is drawn, keeping the number of soundfonts per song low. Drum stems
    accept any drum preset in the kit; other stems need a preset whose
    ``gm_class`` equals the stem's ``program``. A preset is used at most once
    per song.

    When every kit has been tried, the search stops and stems still without a
    preset are written back with a missing ``inst_id``. Nothing is written
    when no kit was drawn at all.
    """
    stems_for_song = song_to_stems[song_id]
    used_ids = []
    usable_kits = scaled_kit_ids.copy()
    while stems_for_song['inst_id'].isnull().any():
        if not usable_kits:
            # All kits tried. random.choice([]) would raise IndexError and
            # abort the whole Pool.map, so stop and save what was assigned.
            break
        random_kit_id = random.choice(usable_kits)
        used_ids.append(random_kit_id)
        usable_kits = [kit for kit in usable_kits if kit != random_kit_id]
        presets_in_kit = kit_to_presets[random_kit_id]
        null_stems = stems_for_song[stems_for_song['inst_id'].isnull()]
        for _index, stem in null_stems.iterrows():
            # Exclude presets already given to another stem of this song.
            usable_presets = presets_in_kit[~presets_in_kit['id'].isin(stems_for_song['inst_id'])]
            if stem['is_drum']:
                suitable_presets = usable_presets[usable_presets['is_drum'] == 1]
            else:
                suitable_presets = usable_presets[usable_presets['gm_class'] == stem['program']]
            if not suitable_presets.empty:
                chosen_preset = suitable_presets.sample(1)
                stems_for_song.loc[stems_for_song['id'] == stem['id'], 'inst_id'] = chosen_preset['id'].values[0]

    if used_ids:
        print(f'NUM KITS = {len(used_ids)}')
        unassigned = int(stems_for_song['inst_id'].isnull().sum())
        if unassigned:
            print(f'song {song_id}: {unassigned} stems left without a preset')
        stem_inst_list = stems_for_song[['id', 'inst_id']].to_dict('records')
        with SQLiteClient() as client:
            client.update_stem_inst_ids(stem_inst_list)
    return

In [None]:
from multiprocessing import Pool

# Assign presets song-by-song in 12 worker processes; each call saves its own stems.
with Pool(processes=12) as worker_pool:
    worker_pool.map(assign_song_presets, song_ids)

## 5. Render Songs

In a perfect world, we would simply load each kit, render every stem using a preset from that kit, and then mix the stems into the full songs when you're done. This parallelizes nicely per-preset and is the best-case scenario in terms of I/O overhead.

However, 1.1M FLACs of ~2:00 each at 22 kHz add up to ~5.5 TB, and I don't have that kind of room.

So, we need to instead:
- Render each stem in a song
- Mix the full song
- Delete the stems

Which means we need to parallelize per-song instead, so we can delete stems as we go.

This is a lot less efficient, since we need to load/unload a few kits per song, but it will do.

In [None]:
import multiprocessing
import os
import pretty_midi
import shutil
from src.audio.render import make_synth, fluidsynthesize, save_audio, mix_audios
from src.db import SQLiteClient

SAMPLE_RATE = 22050

def render_stem(stem, synthesizer, sfid, tracknum):
    """Render one stem's MIDI to a normalized FLAC and return the FLAC path.

    The output path is the MIDI path with ``/midi/`` replaced by ``/audio/``
    and ``.mid`` by ``.flac``. An existing output file is reused as-is.
    Otherwise the first instrument track of the stem's MIDI is synthesized on
    channel *tracknum* of *synthesizer* (the caller selects the preset
    beforehand) at ``SAMPLE_RATE``.
    """
    output_path = stem['midi_filepath']
    for old, new in (('/midi/', '/audio/'), ('.mid', '.flac')):
        output_path = output_path.replace(old, new)

    if not os.path.isfile(output_path):
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        track = pretty_midi.PrettyMIDI(stem['midi_filepath']).instruments[0]
        rendered = fluidsynthesize(track, fs=SAMPLE_RATE, synthesizer=synthesizer, sfid=sfid, channel=tracknum)
        save_audio(rendered, output_path, normalize=True, sr=SAMPLE_RATE)

    print('stem -- ' + stem['name'])
    return output_path

def render_song(song_id):
    """Render all stems of a song, mix them into ``full.flac``, and record it.

    Stems are sorted by kit so that each soundfont is loaded once; a new synth
    (at ``SAMPLE_RATE``) is created whenever the soundfont path changes, and
    the previous one is deleted. Stems that already have an ``audio_filepath``
    are reused without a synth. After mixing, the song's ``instruments``
    directory is removed and the mix path is written to the database.
    Songs with no stems are skipped.
    """
    with SQLiteClient() as client:
        stems_and_instruments = client.get_stems_and_instruments(song_id)
    if not stems_and_instruments:
        print(f'no stems for song {song_id}, skipping')
        return
    stems_and_instruments.sort(key=lambda elem: elem[1]['kit_id'])

    stem_paths = []
    current_sf_path = None
    synth, sfid = None, None
    try:
        for stem, instrument in stems_and_instruments:
            if stem['audio_filepath'] is not None:
                stem_paths.append(stem['audio_filepath'])
                continue
            if instrument['sf_path'] != current_sf_path:
                if synth is not None:
                    synth.delete()
                    synth = None  # avoid a second delete() if make_synth fails
                current_sf_path = instrument['sf_path']
                # Pass the sample rate on every (re)creation so all stems match.
                synth, sfid = make_synth(current_sf_path, sr=SAMPLE_RATE)
            tracknum = 9 if stem['is_drum'] else 0
            synth.program_select(tracknum, sfid, instrument['bank'], instrument['preset'])
            stem_paths.append(render_stem(stem, synth, sfid, tracknum))
    finally:
        # Free the last synth too; pool workers render many songs.
        if synth is not None:
            synth.delete()

    output_path = os.path.join(stem_paths[0].split('instruments')[0], 'full.flac')
    mix_audios(stem_paths, output_path)
    print('mixed', output_path)
    shutil.rmtree(os.path.join(os.path.dirname(output_path), 'instruments'), ignore_errors=True)
    with SQLiteClient() as client:
        client.update_song_audio_filepath(song_id, output_path)


with SQLiteClient() as client:
    pending_rows = client.get_unrendered_song_ids()
    song_ids = [row['id'] for row in pending_rows]
print(len(song_ids))

# One song per task, so each song's stems can be deleted right after mixing.
with multiprocessing.Pool(processes=8) as render_pool:
    render_pool.map(render_song, song_ids)

## 6. Extract Features

see `extract_features.py`

In [9]:
from src.audio.io import split_audio_by_tempo
from src.db import SQLiteClient
import random
with SQLiteClient() as client:
    songs = client.get_rendered_songs()

if not songs:
    print('No rendered songs yet - run section 5 first.')
else:
    # Split one random rendered song with split_audio_by_tempo and print, for
    # each split, its first element and the length of its second (presumably
    # a tempo value and an audio array - confirm against src.audio.io).
    random_song = random.choice(songs)
    splits = split_audio_by_tempo(random_song['audio_filepath'], random_song['midi_filepath'])
    for split in splits:
        print(split[0], len(split[1]))

93 2123520


ValueError: not enough values to unpack (expected 2, got 1)

In [28]:
import librosa
from IPython.display import Audio

# Listen to one rendered mix. NOTE: machine-specific absolute path; point it
# at a full.flac under your own audio output directory.
path = '/media/bleu/bulkdata2/superslakh/audio/2/20bb00d8df3f27954c21a977d3fcaaf9/full.flac'
audio, sr = librosa.load(path, sr=22050)  # 22050 is librosa's default target rate
Audio(data=audio, rate=sr)