In [None]:
import os, sys
root_path = './'  ## replace this with your root path (i.e., path of this current project)
## expose the project root to Python child processes via PYTHONPATH
## (note: this assignment overwrites any PYTHONPATH that was already set)
os.environ['PYTHONPATH'] = root_path
## make project-local packages (e.g. `utils.indexed_datasets` imported below) importable in this kernel
sys.path.append(root_path)
import miditoolkit
import numpy as np
import math
import os, pickle, glob, shutil
from tqdm import tqdm
from utils.indexed_datasets import IndexedDatasetBuilder
import multiprocessing as mp
import traceback

In [None]:
## Path to a MIDI file used by the "Test cell"s below. `phrasing` asserts the file has exactly
## one instrument, so use a monophonic sample. The empty string is only a placeholder.
test_sample = ""  ## replace this with a midi test sample

In [None]:
def stress_simple (start, dur, reso=480):
    """Simplified stress label that looks only at the onset's position in a 4-beat bar.

    The bar is `4 * reso` ticks long. Beat 0 of the bar is "<strong>", beat 2 is
    "<substrong>" and every other position is "<weak>". `dur` is accepted so the
    signature matches `stress`, but it is not used.
    """
    bar_len = 4 * reso
    offset_in_bar = start - bar_len * (start // bar_len)
    beat_index = offset_in_bar // (bar_len // 4)
    return {0: "<strong>", 2: "<substrong>"}.get(beat_index, "<weak>")

In [None]:
def stress (start, dur, reso=480):
    """Full stress detection based on a note's onset and its duration.

    If `dur` equals one of `reso * {4, 2, 1}` or `reso // {2, 4, 8, 16}` ticks and the onset
    is a multiple of `dur`, the stress grid is 4 notes of that duration. Otherwise it is a
    bar of 4 beats (`4 * reso`). The grid position of the onset is then labelled
    "<strong>" (slot 0), "<substrong>" (slot 2) or "<weak>" (any other slot).

    Args:
        start: note onset in ticks.
        dur: note duration in ticks.
        reso: ticks per beat.

    Returns:
        One of "<strong>", "<substrong>", "<weak>".
    """
    grid_durs = (reso, reso//2, reso//4, reso//8, reso//16, reso*2, reso*4)
    ## `dur > 0` guard: for reso < 16, reso//16 == 0, so a zero-length note would
    ## otherwise match and `start % 0` would raise ZeroDivisionError
    if dur > 0 and dur in grid_durs and start % dur == 0:
        ## categorise the note duration
        unit_len = 4 * dur
    else:
        unit_len = 4 * reso
    beat_pos = start - (start//unit_len) * unit_len
    beat_num = beat_pos // (unit_len//4)
    if beat_num == 0:
        return "<strong>"
    elif beat_num == 2:
        return "<substrong>"
    else:
        return "<weak>"

In [None]:
def prosody (midi_pth: str):
    """Label every note of a MIDI file with a (stress, length) pair.

    Notes are visited instrument by instrument, in each instrument's note-list order.
    Stress comes from `stress(...)` using the file's ticks_per_beat. Length is "<long>"
    when the note lasts more than one beat and "<short>" otherwise.

    Args:
        midi_pth: path to the MIDI file.

    Returns:
        A list of (stress, length) string tuples, one per note.
    """
    midi = miditoolkit.MidiFile(midi_pth)
    reso = midi.ticks_per_beat
    labels = []
    for inst in midi.instruments:
        for note in inst.notes:
            note_dur = note.end - note.start
            strength = stress(start=note.start, dur=note_dur, reso=reso)
            length = "<long>" if note_dur > reso else "<short>"
            labels.append((strength, length))
    return labels

In [None]:
## Test cell: number of prosody labels in the test sample
## (guarded so "Restart & Run All" does not fail while `test_sample` is still the empty placeholder)
if test_sample:
    print(len(prosody(test_sample)))
else:
    print("test_sample is empty; set it to a MIDI path to run this check")

In [None]:
def phrasing (midi_pth: str):
    """Heuristic phrase-level segmentation of a monophonic MIDI file.

    A note is a phrase-end candidate if it lasts longer than one beat, or if it is
    followed by a rest of at least half a beat. The first and last notes are never
    candidates. Runs of consecutive candidates are collapsed to a single boundary. For two
    neighbouring candidates, the earlier one is kept if their durations differ by more
    than half a beat; otherwise the later one is kept.

    Phrase markers are attached to the loaded MidiFile object. They are only written to
    disk if the `midi.dump(...)` line below is uncommented.

    Args:
        midi_pth: the path to the midi file to segment (n.b. the midi file has to be
            monophonic; only "exactly one instrument" is checked).

    Returns:
        is_boundary: "<true>"/"<false>" per note; "<true>" marks a phrase-ending note.
        note_info: (bar, position in bar, pitch, duration) per note, in ticks, with a bar
            taken as 4 beats.
        union: sorted indices of the phrase-ending notes.
    """
    midi = miditoolkit.MidiFile(midi_pth)
    assert len(midi.instruments) == 1  ## monophonic check (partial): one track may still hold chords
    reso = midi.ticks_per_beat
    bar_len = 4 * reso  ## a bar == 4 beats (assumes 4/4)
    notes = midi.instruments[0].notes.copy()

    def dur(note: miditoolkit.Note):
        return abs(note.end - note.start)

    long_idx, pause_idx, note_info = [], [], []
    for idx, note in enumerate(notes):
        note_bar = note.start // bar_len
        note_pos = note.start - note_bar * bar_len  ## relative position in the current bar
        note_dur = note.end - note.start
        note_info.append((note_bar, note_pos, note.pitch, note_dur))
        if note_dur > reso:
            long_idx.append(idx)
        if idx > 0 and note.start - notes[idx-1].end >= reso//2:
            pause_idx.append(idx-1)  ## the note before a rest ends a phrase

    ## first and last notes are never phrase boundaries
    union = sorted(set(long_idx + pause_idx) - {0, len(notes) - 1})

    ## collapse runs of adjacent boundaries. After a deletion, do not advance: the element
    ## now at position i must be compared with its new predecessor (advancing unconditionally
    ## skipped comparisons, e.g. [3, 4, 5] ended as [4, 5]). The duration tolerance is half a
    ## beat (240 ticks at 480 ticks per beat), matching the pause threshold above.
    i = 1
    while i < len(union):
        if union[i] - union[i-1] == 1:
            if abs(dur(notes[union[i-1]]) - dur(notes[union[i]])) > reso//2:
                del union[i]
            else:
                del union[i-1]
        else:
            i += 1

    ## annotate the phrase markers
    midi.markers = []
    for k, b in enumerate(union):
        midi.markers.append(miditoolkit.Marker(time=notes[b].end, text=f"Phrase_{k}"))

    ## uncomment the following line if you intend to save the segmented midi
    # midi.dump(os.path.join('./', os.path.basename(midi_pth)[:-4]+'_phrased.mid'))

    boundary_set = set(union)
    is_boundary = ["<true>" if k in boundary_set else "<false>" for k in range(len(notes))]

    assert len(note_info) == len(is_boundary)
    return is_boundary, note_info, union

In [None]:
def get_notes (midi_pth: str):
    """Extract per-note (bar, position in bar, pitch, duration) tuples from a MIDI file.

    This produces the same tuples as the `note_info` computed in `phrasing`: ticks, with
    a bar taken as 4 beats of `ticks_per_beat`. (The original body called
    `note_info.append()` with no argument, which raises TypeError, and returned nothing.)

    Args:
        midi_pth: path to a monophonic MIDI file (exactly one instrument is asserted).

    Returns:
        A list of (bar, position, pitch, duration) tuples, one per note.
    """
    midi = miditoolkit.MidiFile(midi_pth)
    assert len(midi.instruments) == 1  ## monophonic
    reso = midi.ticks_per_beat
    bar_len = 4 * reso
    notes = midi.instruments[0].notes.copy()

    note_info = []
    for note in notes:
        note_bar = note.start // bar_len
        note_info.append((note_bar, note.start - note_bar * bar_len, note.pitch, note.end - note.start))
    return note_info

In [None]:
## Test cell: boundary labels and note info should both have one entry per note
## (guarded so "Restart & Run All" does not fail while `test_sample` is still the empty placeholder)
if test_sample:
    bound, note_info, _ = phrasing(test_sample)
    print(len(bound), len(note_info))
else:
    print("test_sample is empty; set it to a MIDI path to run this check")

In [None]:
def tokenise (midi_pth, event2word_dict, max_bars=200):
    """Tokenise a MIDI file into parallel source (prosody) and target (note) word sequences.

    Each note yields:
      - a source row (stress, length, phrase boundary), and
      - a target row (bar, position, pitch, duration, phrase boundary).

    The target sequence starts with a "<s>" row and ends with a "</s>" row; all other
    fields of those rows are "<pad>". The source sequence gets only a trailing "</s>" row.

    Args:
        midi_pth: the path to the midi to tokenise (must pass `phrasing`'s
            single-instrument check).
        event2word_dict: nested mapping {field: {event string: word}} with the fields
            'Strength', 'Length', 'Phrase', 'Bar', 'Pos', 'Pitch', 'Dur'. Values are
            presumably integer word ids — confirm against the vocabulary builder.
        max_bars: if any note falls in bar index >= max_bars, the file is rejected. The
            default of 200 presumably matches the number of "Bar_*" events in the
            vocabulary — confirm before changing it.

    Returns:
        (src_words, tgt_words) as lists of dicts, or ([], []) if the file contains a note
        beyond `max_bars`.
    """
    prsd = prosody(midi_pth)
    bound, notes, _ = phrasing(midi_pth)
    assert len(prsd) == len(notes)

    def special_tgt_row(bar_event):
        ## target row carrying a sequence marker in 'bar' and padding in every other field
        return {
            'bar': event2word_dict['Bar'][bar_event],
            'pos': event2word_dict['Pos']["<pad>"],
            'token': event2word_dict['Pitch']["<pad>"],
            'dur': event2word_dict['Dur']["<pad>"],
            'phrase': event2word_dict['Phrase']["<pad>"],
        }

    src_words, tgt_words = [], [special_tgt_row("<s>")]

    for (strength, length), boundary, (bar, pos, pitch, dur) in zip(prsd, bound, notes):
        if bar >= max_bars:
            return [], []
        src_words.append({
            'strength': event2word_dict['Strength'][strength],
            'length': event2word_dict['Length'][length],
            'phrase': event2word_dict['Phrase'][boundary],
        })
        tgt_words.append({
            'bar': event2word_dict['Bar'][f"Bar_{bar}"],
            'pos': event2word_dict['Pos'][f"Pos_{pos}"],
            'token': event2word_dict['Pitch'][f"Pitch_{pitch}"],
            'dur': event2word_dict['Dur'][f"Dur_{dur}"],
            'phrase': event2word_dict['Phrase'][boundary],
        })

    ## eos
    src_words.append({
        'strength': event2word_dict['Strength']["</s>"],
        'length': event2word_dict['Length']["</s>"],
        'phrase': event2word_dict['Phrase']["</s>"],
    })
    tgt_words.append(special_tgt_row("</s>"))

    return src_words, tgt_words

In [None]:
def data_to_binary (midi_pth, i, event2word_dict, split, max_words=1024):
    """Tokenise one MIDI file into a binarisation record (called from a process pool).

    Args:
        midi_pth: path to the MIDI file.
        i: index of the file in the caller's list (not used here).
        event2word_dict: vocabulary passed through to `tokenise`.
        split: split name (not used here).
        max_words: samples with more target words than this are dropped.

    Returns:
        A one-element list holding the sample dict, or None if the file was rejected
        (empty tokenisation or too long) or raised an exception.
    """
    try:
        src_words, tgt_words = tokenise(midi_pth, event2word_dict)
        if len(src_words) == 0 or len(tgt_words) == 0 or len(tgt_words) > max_words:
            return None

        data_sample = {
            'input_path': midi_pth,
            'item_name': os.path.basename(midi_pth),
            'src_words': src_words,
            'tgt_words': tgt_words,
            'word_length': len(tgt_words)
        }

        return [data_sample]

    except Exception:
        ## Best-effort: one bad file must not abort the whole batch. Name the file, because
        ## the traceback printed from a worker process does not say which input failed.
        print(f'| failed to binarise {midi_pth}', file=sys.stderr)
        traceback.print_exc()
        return None

In [None]:
def data2binary(dataset_dirs, words_dir, split, word2event_dict, event2word_dict):
    """Binarise every MIDI file of one split and write it as an indexed dataset.

    Files are read from `<dataset_dir>/<split>/*.mid` for each dataset dir and tokenised
    in a process pool (`N_PROC` env var, default 2 workers). Outputs:
      - `<words_dir>/<split>/`: indexed dataset (any existing directory is deleted first)
      - `<words_dir>/<split>_words_length.npy`: `word_length` of every kept sample
      - `<words_dir>/<split>_words.npy`: all kept samples

    Args:
        dataset_dirs: directories of data to binarise, each with a `<split>` subdirectory.
        words_dir: output directory of binarised data.
        split: the name of the split of the dataset (e.g., 'train', 'valid', 'test').
        word2event_dict: unused here; kept for interface compatibility.
        event2word_dict: vocabulary passed to `data_to_binary`.

    Returns:
        (all_words, words_length): the kept samples and their word lengths.
    """
    ## make a fresh output directory (previous contents are removed)
    save_dir = f'{words_dir}/{split}'
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)
    os.makedirs(save_dir)

    ## glob returns files in arbitrary filesystem order; sort so the written dataset is reproducible
    midi_files = sorted(
        path
        for dataset_dir in dataset_dirs
        for path in glob.glob(os.path.join(dataset_dir, split, "*.mid"))
    )

    ds_builder = IndexedDatasetBuilder(save_dir)  # index dataset
    n_proc = int(os.getenv('N_PROC', 2))  # keep this small: too many workers easily run out of memory

    words_length, all_words = [], []
    ## the context manager terminates the workers even if collecting results raises
    with mp.Pool(n_proc) as pool:
        futures = [
            pool.apply_async(data_to_binary, args=[path, idx, event2word_dict, split])
            for idx, path in enumerate(midi_files)
        ]
        pool.close()
        for future in tqdm(futures):
            item = future.get()
            if item is None:
                continue
            for sample in item:
                words_length.append(sample['word_length'])
                all_words.append(sample)
                ds_builder.add_item(sample)  # add item index
        pool.join()

    # save
    ds_builder.finalize()
    np.save(f'{words_dir}/{split}_words_length.npy', words_length)
    np.save(f'{words_dir}/{split}_words.npy', all_words)
    print(f'| # {split}_tokens: {sum(words_length)}')

    return all_words, words_length

In [None]:
## shuffle and split dataset
def split_data(output_dir='./', dataset_dirs=None, seed=None):
    """Shuffle the MIDI files of `dataset_dirs` and copy them into train/valid/test splits.

    Train and valid get floor(80%) and floor(10%) of the files; the remainder goes to test.
    Files are copied to `<output_dir>/{train,valid,test}/`, and those directories are
    created if missing.

    Args:
        output_dir: destination root for the split directories.
        dataset_dirs: directories whose top-level `*.mid` files are split. Defaults to an
            empty list, so nothing is copied (the original placeholder behaviour).
        seed: optional int seed for a reproducible shuffle. If None, the module-level
            `random` state is used, as before.

    Note: files are copied by basename, so identically named files from different
    directories overwrite each other in the output.
    """
    import random  # only this function needs it

    if dataset_dirs is None:
        dataset_dirs = []  ## fill in your own directories to datasets (or pass them in)
    all_files = []
    for dataset_dir in dataset_dirs:
        ## sorted so a fixed seed gives the same split regardless of filesystem listing order
        all_files.extend(sorted(glob.glob(os.path.join(dataset_dir, '*.mid'))))
    print(f"|>>> Total Files: {len(all_files)}")

    ## shuffle
    indices = list(range(len(all_files)))
    rng = random.Random(seed) if seed is not None else random
    rng.shuffle(indices)
    train_end = int(np.floor(0.8*len(all_files)))
    valid_end = int(train_end + np.floor(0.1*len(all_files)))
    splits = {
        'train': indices[:train_end],
        'valid': indices[train_end:valid_end],
        'test': indices[valid_end:],
    }
    assert len(all_files) == sum(len(idx) for idx in splits.values())
    print(f"|>>>>> Train Files: {len(splits['train'])}")
    print(f"|>>>>> Valid Files: {len(splits['valid'])}")
    print(f"|>>>>> Test Files: {len(splits['test'])}")

    for split, split_idx in splits.items():
        split_dir = os.path.join(output_dir, split)
        os.makedirs(split_dir, exist_ok=True)
        for k in split_idx:
            shutil.copy(all_files[k], os.path.join(split_dir, os.path.basename(all_files[k])))

In [None]:
## copy the dataset into train/valid/test folders under './' using split_data's defaults
split_data(output_dir='./')

In [None]:
## Binarise each split with data2binary.
## NOTE(review): `dataset_dirs`, `words_dir`, `word2event_dict` and `event2word_dict` are not
## defined in any visible cell (`dataset_dirs` exists only as a local inside `split_data`),
## so on a fresh kernel this cell raises NameError. Define them in a config cell first.
## Presumably `dataset_dirs` should contain split_data's output directory, since data2binary
## reads `<dataset_dir>/<split>/*.mid` — confirm.
for split in ['train', 'valid', 'test']:
    data2binary(dataset_dirs=dataset_dirs,
                words_dir=words_dir,
                split=split,
                word2event_dict=word2event_dict,
                event2word_dict=event2word_dict)