In [254]:
import numpy as np 
import pretty_midi
import h5py
import os
import pandas as pd
from random import randrange

In [299]:
def data_preprocess(file_name, bar_number=32):
    '''
    Load generated bars from an HDF5 file and flatten them into one pitch-labelled roll.

    Parameters
    ----------
    file_name : str
        Path to the HDF5 file. Relative paths are resolved against the current
        working directory; absolute paths are used as-is.
    bar_number : int
        Number of leading bars to keep (default 32).

    Returns
    -------
    numpy.ndarray of int64, shape (n_rows - 1, n_bars * steps_per_bar)
        One row per pitch (row index == pitch number). A cell holds the pitch
        number at a note onset, 1 while that pitch is sustained, and 0 otherwise.

    Notes
    -----
    Assumes the file's ``data`` dataset is laid out as (bar, row, time step) and
    that its last row is a single onset flag shared by all pitches, i.e. the
    material is monophonic -- TODO confirm against the generator.
    Onsets of pitches 0 and 1 are indistinguishable from silence/sustain in this
    encoding (their row index is 0 or 1), so they are effectively dropped.

    Raises
    ------
    ValueError
        If the selection yields no bars (empty dataset or ``bar_number`` selecting none).
    '''
    path = os.path.join(os.getcwd(), file_name)
    # The context manager guarantees the HDF5 handle is closed even on error.
    with h5py.File(path, 'r') as f:
        bars = np.array(f['data'])[:bar_number]

    if len(bars) == 0:
        raise ValueError("no bars selected from %r (bar_number=%r)" % (file_name, bar_number))

    # (bar, row, time) -> (row, bar * time): lay all bars side by side in time.
    roll = np.concatenate(tuple(bars), axis=1)
    # numpy bool "+" is a logical OR, which would hide the onset sum of 2.
    if roll.dtype == np.bool_:
        roll = roll.astype(np.int64)

    pitch_rows = roll[:-1]
    onset_row = roll[-1]
    # A pitch row plus the onset row equals 2 exactly where that pitch is active
    # at an onset step.
    is_onset = (pitch_rows + onset_row) == 2
    pitch_numbers = np.arange(len(pitch_rows))[:, np.newaxis]
    # Onsets take the row's pitch number; every other cell keeps its original
    # value. Silent rows therefore stay 0 at another pitch's onset instead of
    # becoming a stray 1 that would extend the preceding note.
    labelled = np.where(is_onset, pitch_numbers, pitch_rows)
    return labelled.astype(np.int64)

In [274]:
def bpm_to_unit_time(bpm):
    """Return the duration in seconds of one 1/32 note at ``bpm`` beats per minute.

    One beat (a quarter note) lasts 60/bpm seconds and contains eight 1/32 notes.
    """
    thirty_seconds_per_minute = bpm * 2 * 4  # 8 thirty-second notes per beat
    return 60.0 / thirty_seconds_per_minute

In [275]:
def row_start_end_compute(row,unit_time):
    '''
    Extract the notes contained in one piano-roll row.

    A cell with a value in 2..127 starts a note whose pitch is that value; the
    note continues over every immediately following cell equal to 1.

    Returns five parallel lists, one entry per note, in this order:
    start times, end times (exclusive, i.e. one unit after the last cell),
    start indices, end indices (inclusive) and pitch numbers.
    Times are the indices multiplied by ``unit_time``.
    '''
    starts, ends, onset_positions, offset_positions, pitches = [], [], [], [], []
    last = len(row) - 1
    for pos, value in enumerate(row):
        if not (1 < value <= 127):
            continue
        stop = pos
        # Walk forward over the sustain cells that belong to this onset.
        while stop < last and row[stop + 1] == 1:
            stop += 1
        onset_positions.append(pos)
        offset_positions.append(stop)
        starts.append(pos * unit_time)
        ends.append((stop + 1) * unit_time)
        pitches.append(value)
    return starts, ends, onset_positions, offset_positions, pitches

def data_start_end_matrix(data_one_track, unit_time):
    '''
    Run ``row_start_end_compute`` on every row of a track and collect the results.

    Returns five lists, each with one entry per row of ``data_one_track`` (in row
    order): start times, end times, start indices, end indices and pitch numbers,
    the same order in which ``row_start_end_compute`` returns them.
    '''
    collected = ([], [], [], [], [])
    for row in data_one_track:
        start_time, end_time, start_idices, end_idices, notes_numbers = row_start_end_compute(row, unit_time)
        per_row = (start_time, end_time, start_idices, end_idices, notes_numbers)
        for bucket, item in zip(collected, per_row):
            bucket.append(item)
    return collected

In [305]:
def velocity_compute_row(start_idices, beat_pattern, bar_length, week_range, bar_number,
                         heavy_velocity=100, rng=None):
    """
    Assign a MIDI velocity to every note onset according to the bar's accent pattern.

    Onsets that fall exactly on the first step of a beat group are "heavy" and get
    ``heavy_velocity``. All other ("weak") onsets get a random velocity drawn with
    ``randrange(week_range[0], week_range[1])``, so the upper bound is excluded.

    start_idices: onset positions in time steps, counted from the start of the piece.
    beat_pattern: grouping of beats within one bar, e.g. [2, 3, 3].
    bar_length: length of one bar in time steps. With 1/32-note resolution and a
        whole-note bar this is 32. Must be divisible by sum(beat_pattern).
    week_range: (low, high) bounds for weak-onset velocities; high is exclusive.
    bar_number: number of bars the onsets span; every onset must lie before
        bar_length * bar_number.
    heavy_velocity: velocity for heavy onsets (default 100, the previous fixed value).
    rng: optional object with a ``randrange`` method (e.g. ``random.Random(seed)``)
        for reproducible output. Defaults to the module-level ``random`` state.

    Returns a list with one velocity per entry of ``start_idices``, in the same order.
    Raises AssertionError if the pattern does not divide the bar or an onset is
    out of range.
    """
    velocity_list = []
    if len(start_idices) == 0:
        return velocity_list
    assert bar_length % sum(beat_pattern) == 0, "beat pattern must match bar_length"
    # max() rather than the last element, so unsorted input is also range-checked.
    assert max(start_idices) < bar_length * bar_number, "start idices must all be smaller then bar_length*bar_number"

    beat_length = bar_length // sum(beat_pattern)
    # Time steps where each beat group starts, repeated for every bar. A set makes
    # the per-onset membership test O(1) instead of scanning a list.
    heavy_positions = set()
    position = 0
    for beats_in_group in list(beat_pattern) * bar_number:
        heavy_positions.add(position)
        position += beats_in_group * beat_length

    draw = randrange if rng is None else rng.randrange
    for idx in start_idices:
        if idx in heavy_positions:
            velocity_list.append(heavy_velocity)
        else:
            velocity_list.append(draw(week_range[0], week_range[1]))
    return velocity_list

def velocity_compute_matrix(start_idices_matrix, bar_number, beat_pattern=(2, 3, 3), bar_length=32, week_range=(70, 90)):
    """
    Apply ``velocity_compute_row`` to every row of onset indices.

    start_idices_matrix: one list of onset indices per piano-roll row.
    bar_number, beat_pattern, bar_length, week_range: forwarded unchanged to
        ``velocity_compute_row`` for every row.

    Defaults are tuples rather than lists, so a shared default can never be
    mutated between calls.

    Returns one velocity list per row, in row order.
    """
    return [
        velocity_compute_row(start_idices, beat_pattern, bar_length, week_range, bar_number)
        for start_idices in start_idices_matrix
    ]

In [277]:
def append_note(start_idices_matrix,velocity_matrix,start_time_matrix,end_time_matrix,notes_numbers_matrix,instru_name):
    """
    Build a ``pretty_midi.Instrument`` holding every note described by the per-row matrices.

    Row ``r`` of each matrix describes the notes found in row ``r`` of the piano roll.
    Within a row, the k-th entries of the velocity, pitch, start-time and end-time
    lists describe the same note. Only the length of ``start_idices_matrix`` is
    used, to decide how many rows to visit.

    instru_name: instrument name passed to ``pretty_midi.instrument_name_to_program``
        (e.g. 'Cello').

    Returns the new instrument. The caller attaches it to a ``PrettyMIDI`` object.
    """
    program = pretty_midi.instrument_name_to_program(instru_name)
    track = pretty_midi.Instrument(program=program)

    for row in range(len(start_idices_matrix)):
        pitches = notes_numbers_matrix[row]
        velocities = velocity_matrix[row]
        onsets = start_time_matrix[row]
        offsets = end_time_matrix[row]
        track.notes.extend(
            pretty_midi.Note(velocity=velocities[k], pitch=pitches[k], start=onsets[k], end=offsets[k])
            for k in range(len(onsets))
        )
    return track

In [278]:
def one_track_append(data_one_track, midi, instru_name,
                     bar_number, bpm, beat_pattern=(2, 2, 2, 2), bar_length=32, week_range=(70, 90)):
    """
    Render one instrument track from a pitch-labelled piano roll and add it to ``midi``.

    data_one_track: rows of a single instrument track with all bars concatenated
        in time. In this notebook it is the array produced by ``data_preprocess``.
    midi: object with an ``instruments`` list (a ``pretty_midi.PrettyMIDI`` here).
        It is mutated in place and also returned for convenience.
    instru_name: instrument name for ``append_note`` (e.g. 'Cello').
    bar_number, beat_pattern, bar_length, week_range: accent settings forwarded to
        ``velocity_compute_matrix``. Defaults are tuples so they cannot be mutated
        between calls.
    bpm: tempo used to convert time steps to seconds via ``bpm_to_unit_time``.
    """
    unit_time = bpm_to_unit_time(bpm)
    (start_time_matrix, end_time_matrix, start_idices_matrix,
     _, notes_numbers_matrix) = data_start_end_matrix(data_one_track, unit_time)
    velocity_matrix = velocity_compute_matrix(start_idices_matrix, bar_number, beat_pattern, bar_length, week_range)
    instrument = append_note(start_idices_matrix, velocity_matrix, start_time_matrix,
                             end_time_matrix, notes_numbers_matrix, instru_name)
    midi.instruments.append(instrument)
    return midi

In [279]:
# NOTE(review): inert draft of a multi-track `main()`, kept as a bare string so it never runs.
# Its leading note says data_preprocess must first be changed to produce several
# concatenated tracks. The draft's 'Volin' default looks like a typo for 'Violin';
# fix it before enabling the draft.
# The double-quoted string below is the notebook's echoed output of this cell, not extra code.
'''要改一下data preprocessing的过程，处理成多轨连接的

def main(data, instruments=['Volin'],bar_number=1,bpm=120):
    #输入的数据可以是1个以上轨的数据，需要干掉所有用来补齐的数据（比如-1什么的），串联所有小节
    
    midi = pretty_midi.PrettyMIDI()
    for i in range(len(data)):
        instrument = instruments[i]
        data_one_track = data[i]
        midi = one_track_append(data_one_track,midi,instrument, bar_number,bpm)
    midi.write('test.mid')
'''

"要改一下data preprocessing的过程，处理成多轨连接的\n\ndef main(data, instruments=['Volin'],bar_number=1,bpm=120):\n    #输入的数据可以是1个以上轨的数据，需要干掉所有用来补齐的数据（比如-1什么的），串联所有小节\n    \n    midi = pretty_midi.PrettyMIDI()\n    for i in range(len(data)):\n        instrument = instruments[i]\n        data_one_track = data[i]\n        midi = one_track_append(data_one_track,midi,instrument, bar_number,bpm)\n    midi.write('test.mid')\n"

In [327]:
# Number of leading bars of the generated piece to render; also reused when rendering.
bar_number = 36
data = data_preprocess('generated.hdf5', bar_number=bar_number)


In [328]:
# Debug dump: write the preprocessed roll to a CSV file literally named 'debug'
# (no extension) in the working directory, for manual inspection.
df = pd.DataFrame(data=data)
df.to_csv(path_or_buf='debug')

In [329]:
# 120.0 is pretty_midi's default initial tempo; spelled out so it visibly matches bpm=120 used for rendering.
midi = pretty_midi.PrettyMIDI(initial_tempo=120.0)

In [330]:
# Render the whole preprocessed roll as one cello track at 120 bpm with a 2+3+3 accent grouping.
midi = one_track_append(
    data_one_track=data,
    midi=midi,
    instru_name='Cello',
    bar_number=bar_number,
    bpm=120,
    beat_pattern=[2, 3, 3],
)

In [331]:
# Output file for the rendered MIDI, written to the current working directory.
MIDI_OUTPUT_PATH = 'test3.mid'
midi.write(MIDI_OUTPUT_PATH)