In [2]:
# Compute the DTW (dynamic time warping) distance between the note sequences of two songs
import io
import os
import shutil
from dtw import dtw,accelerated_dtw
import numpy as np
from tqdm import tqdm

# output_dir = 'l2m'

In [3]:
# Token that terminates a sentence inside a token sequence (sequences are split on it).
SEP = '[sep]'
# Alignment marker token; it is dropped before tokens are paired into notes.
ALIGN = '[align]'
# Number of hypotheses per input line that get_songs expects in a generation log;
# when the log contains exactly len(inputs) * beam hypotheses, only the best-scoring one is kept.
beam = 5

In [4]:
# Input files.
# NOTE(review): these are absolute paths on one specific machine; adjust them before re-running elsewhere.
# input_lyric_file and song_id_file are not passed as arguments: get_songs reads them as globals.
input_lyric_file = '/home/dwj/data/code/model_modify/zhonghao/data/para/test.lyric'
ground_truth_melody_file = '/home/dwj/data/code/model_modify/zhonghao/data/para/test.melody'
generate_melody_file = '/home/dwj/data/code/model_modify/zhonghao/generate/l2m/result/2021_7_7_test.log'
song_id_file = '/home/dwj/data/code/model_modify/zhonghao/data/para/test_song_id.txt'

# Options forwarded to get_songs when loading the generated melodies:
#   generated_file  - parse the file as a generation log (only lines starting with 'H-' are used)
#   get_last        - keep the tokens that follow the last [sep] of a sequence
#   find_structure  - enforce strict pitch-duration pairs inside every sentence
#   cut_exceed_sent - drop generated sentences beyond the number of lyric sentences
generated_file = True
get_last = True
find_structure = True
cut_exceed_sent = True

In [5]:
def get_pitch_duration_structure(note_seq):
    """Keep only well-formed (pitch, duration) pairs from an integer token sequence.

    Tokens <= 128 are treated as pitches and tokens > 128 as durations. A pair is
    emitted only when a pitch is immediately followed by a duration; in a run of
    pitches only the last one survives, and in a run of durations only the first
    one survives:

        p1 d1 p2 p3 d2 p4 d3    -> p1 d1 p3 d2 p4 d3
        p1 d1 p2 d2 d3 p3 d4    -> p1 d1 p2 d2 p3 d4
        p1 d1 p2 p3 d2 d3 p4 d4 -> p1 d1 p3 d2 p4 d4

    Returns a flat list [pitch, duration, pitch, duration, ...] whose items are
    converted to strings (the result is used for string-based metrics).
    """
    pairs = []
    total = len(note_seq)
    pos = 0
    # A pair needs a token and its successor, so the very last token can never start one.
    while pos + 1 < total:
        token = note_seq[pos]
        follower = note_seq[pos + 1]
        if token > 128:
            # Duration that is not directly preceded by a pitch: discard it.
            pos += 1
        elif follower <= 128:
            # Pitch followed by another pitch: discard the earlier pitch.
            pos += 1
        else:
            # Pitch directly followed by a duration: keep both and jump past them.
            pairs.append(str(token))
            pairs.append(str(follower))
            pos += 2
    return pairs

In [6]:
# Split a token sequence containing [sep] tokens into sentences.
def separate_sentences(x,find_structure = False):
    """Return the list of sentences found in token list `x`.

    Every SEP token closes one sentence; the SEP itself is not included.
    Tokens that follow the last SEP are discarded. When `find_structure`
    is true, each sentence is converted to ints and passed through
    get_pitch_duration_structure.
    """
    tokens = x.copy()
    sentences = []
    start = 0
    for pos, token in enumerate(tokens):
        if token != SEP:
            continue
        sentence = tokens[start:pos]
        if find_structure:
            sentence = get_pitch_duration_structure([int(t) for t in sentence])
        sentences.append(sentence)
        # The next sentence begins right after this separator.
        start = pos + 1
    return sentences

In [7]:
# Read a melody file and return the pitch-duration token sequence of every song.
def get_songs(file, generated_file = False, get_last = False, find_structure = False, cut_exceed_sent=False):
    """Load melodies from `file` and regroup them into one token list per song.

    Besides `file`, this function reads the module-level globals
    `input_lyric_file` (one lyric line per melody line), `song_id_file`
    (one integer song id per melody line), `beam` and `SEP`.

    Parameters
    ----------
    file : path of the melody file.
    generated_file : True if `file` is a generation log (fairseq output) whose
        hypothesis lines look like 'H-<id>\\t<score>\\t<tokens>'; False if it is a
        plain dataset file with one token sequence per line.
    get_last : if a sequence does not end with SEP, keep the tokens after the
        last SEP (a SEP is appended); otherwise those tokens are dropped.
    find_structure : run get_pitch_duration_structure on every sentence so that
        only strict pitch-duration pairs remain.
    cut_exceed_sent : drop predicted sentences beyond the number of sentences
        (SEP count) of the corresponding lyric line.

    Returns
    -------
    list of lists: entry `k` holds the concatenated tokens (sentences separated
    by SEP) of all lines whose song id is `k`.

    Raises
    ------
    AssertionError if the number of melody lines differs from the number of
    lyric lines.
    """
    with io.open(input_lyric_file,'r') as f:
        input_lyrics = [line.rstrip('\n').split(' ') for line in f]
    input_lyrics_sent_num = [tokens.count(SEP) for tokens in input_lyrics]

    with io.open(song_id_file,'r') as f:
        song_ids = [int(line.rstrip('\n')) for line in f]

    with io.open(file,'r') as f:
        melody_lines = f.readlines()

    if generated_file:
        def _hypothesis_id(line):
            # 'H-12\t-0.3\t...' -> 12
            return int(line.split('\t')[0].split('-')[1])

        melody_lines = [line for line in melody_lines if line.startswith('H-')]
        if len(melody_lines) == len(input_lyrics) * beam:
            # All beam hypotheses were logged: order by sample id, best score first,
            # then keep only the first (best) hypothesis of every sample.
            melody_lines.sort(key=lambda line: (_hypothesis_id(line), -float(line.split('\t')[1])))
            melody_lines = melody_lines[::beam]
        else:
            melody_lines.sort(key=_hypothesis_id)
    # The token sequence is the last tab-separated field of a line.
    melody_lines = [line.rstrip('\n').split('\t')[-1] for line in melody_lines]

    print(len(melody_lines), len(input_lyrics))
    assert len(melody_lines)==len(input_lyrics)

    # Keep only numeric tokens and SEP. This also removes ALIGN tokens; the former
    # list-level comparison against ALIGN compared whole lists with a string and
    # therefore never removed anything.
    melody_seqs = [
        [token for token in line.split(' ') if token.isdigit() or token == SEP]
        for line in melody_lines
    ]

    if get_last:
        for seq in melody_seqs:
            # An empty sequence has no trailing tokens to keep (and seq[-1] would raise).
            if seq and seq[-1] != SEP:
                seq.append(SEP)

    # Split into sentences, optionally enforcing the pitch-duration structure.
    melody_seq_sents = [separate_sentences(seq, find_structure=find_structure) for seq in melody_seqs]

    # Join the sentences of every line again, dropping surplus sentences if requested.
    return_list = []
    for i, sent_seq in enumerate(melody_seq_sents):
        if cut_exceed_sent and len(sent_seq) > input_lyrics_sent_num[i]:
            sent_seq = sent_seq[0:input_lyrics_sent_num[i]]
        cur_song_return = []
        for sent in sent_seq:
            cur_song_return.extend(sent)
            cur_song_return.append(SEP)
        return_list.append(cur_song_return)

    # Merge the lines that belong to the same song. Size the result by the largest
    # id so that ids need not be sorted, and tolerate an empty id file.
    song_num = max(song_ids) + 1 if song_ids else 0
    songs = [[] for _ in range(song_num)]
    for item_id, song_id in enumerate(song_ids):
        songs[song_id].extend(return_list[item_id])

    return songs
            


In [8]:
# Load the reference melodies with all get_songs options left at their defaults (False).
ground_truth_melody = get_songs(ground_truth_melody_file)

714 714


In [9]:
# Load the generated melodies using the options from the configuration cell.
output_melodies = get_songs(
    generate_melody_file,
    generated_file=generated_file,
    get_last=get_last,
    find_structure=find_structure,
    cut_exceed_sent=cut_exceed_sent,
)

714 714


In [10]:
def to_tuple(x):
    """Turn a flat token list into a list of (pitch, duration) tuples.

    SEP and ALIGN tokens are removed first; the remaining tokens are paired up
    in order. A trailing unpaired token is dropped.
    """
    tokens = [tok for tok in x if tok not in (SEP, ALIGN)]
    # Even positions hold pitches, odd positions hold durations.
    return list(zip(tokens[0::2], tokens[1::2]))
# Convert every song from a flat token list into a list of (pitch, duration) tuples.
# NOTE: both names are rebound in place, so this cell is not idempotent: running it a
# second time without re-running the get_songs cells would feed tuples back into to_tuple.
ground_truth_melody = list(map(to_tuple, ground_truth_melody))
output_melodies = list(map(to_tuple, output_melodies))

In [11]:
# Duration_vocab = dict([(str(x/100),129+i)  for i,x in enumerate(list(range(25,3325,25)))])
# Map duration token ids (as strings, starting at '129') to duration values:
# '129' -> 0.25, '130' -> 0.5, ... in steps of 0.25.
Duration_vocab = {
    str(129 + index): hundredths / 100
    for index, hundredths in enumerate(range(25, 3325, 25))
}
print(Duration_vocab)

{'129': 0.25, '130': 0.5, '131': 0.75, '132': 1.0, '133': 1.25, '134': 1.5, '135': 1.75, '136': 2.0, '137': 2.25, '138': 2.5, '139': 2.75, '140': 3.0, '141': 3.25, '142': 3.5, '143': 3.75, '144': 4.0, '145': 4.25, '146': 4.5, '147': 4.75, '148': 5.0, '149': 5.25, '150': 5.5, '151': 5.75, '152': 6.0, '153': 6.25, '154': 6.5, '155': 6.75, '156': 7.0, '157': 7.25, '158': 7.5, '159': 7.75, '160': 8.0, '161': 8.25, '162': 8.5, '163': 8.75, '164': 9.0, '165': 9.25, '166': 9.5, '167': 9.75, '168': 10.0, '169': 10.25, '170': 10.5, '171': 10.75, '172': 11.0, '173': 11.25, '174': 11.5, '175': 11.75, '176': 12.0, '177': 12.25, '178': 12.5, '179': 12.75, '180': 13.0, '181': 13.25, '182': 13.5, '183': 13.75, '184': 14.0, '185': 14.25, '186': 14.5, '187': 14.75, '188': 15.0, '189': 15.25, '190': 15.5, '191': 15.75, '192': 16.0, '193': 16.25, '194': 16.5, '195': 16.75, '196': 17.0, '197': 17.25, '198': 17.5, '199': 17.75, '200': 18.0, '201': 18.25, '202': 18.5, '203': 18.75, '204': 19.0, '205': 19.25

In [12]:
# Spread every pitch over its duration, e.g. 64 0.25 65 0.5 67 1 -> 64 65 65 67 67 67 67
def flatten(note_seq, ign_rest = False):
    """Expand (pitch, duration) notes into one pitch value per time step.

    A note contributes int(duration * 4) steps, where the duration is looked up
    in Duration_vocab by the note's duration token. Pitch 128 is a rest: with
    `ign_rest` it is skipped, otherwise it repeats the previous pitch (a rest
    at the very beginning is skipped because there is no previous pitch).
    """
    steps = []
    for note in note_seq:
        pitch = int(note[0])
        n_steps = int(Duration_vocab[note[1]] * 4)
        if pitch != 128:
            steps += [pitch] * n_steps
        elif not ign_rest and steps:
            # Rest: hold the most recent pitch for the length of the rest.
            steps += [steps[-1]] * n_steps
    return steps

# Down-sample a flattened pitch sequence.
def sample(flat_note_seq, freq = 2):
    """Return every `freq`-th element of `flat_note_seq`, starting with the first.

    Only len(flat_note_seq) // freq elements are returned, so a trailing
    incomplete group is dropped.
    Original note on choosing freq: 1/32 -> 1, 1/16 -> 2
    (presumably target note resolution -> freq; confirm against flatten's step size).
    """
    kept = len(flat_note_seq) // freq
    return [flat_note_seq[pos] for pos in range(0, kept * freq, freq)]

# Convert a pitch sequence into pitch differences relative to the preceding element.
def grad(flat_note_seq):
    """Return the first-order differences of `flat_note_seq`.

    Element i of the result is seq[i] - seq[i-1]; the first element is 0.
    The input is not modified. An empty input yields an empty list.
    """
    if len(flat_note_seq)==0:
        return []
    deltas = flat_note_seq.copy()
    previous = flat_note_seq[0]
    for pos in range(1, len(deltas)):
        current = flat_note_seq[pos]
        deltas[pos] = current - previous
        previous = current
    # There is nothing before the first element, so its difference is defined as 0.
    deltas[0] = 0
    return deltas



In [13]:
# Expand every song into one pitch per time step (rests hold the previous pitch).
ground_truth_flat = [flatten(song) for song in ground_truth_melody]
output_melody_flat = [flatten(song) for song in output_melodies]

In [14]:
# Down-sample every flattened song with the default rate of sample (freq=2).
ground_truth_samp = [sample(steps) for steps in ground_truth_flat]
output_melody_samp = [sample(steps) for steps in output_melody_flat]

In [15]:
# Replace absolute pitches by the difference to the preceding sampled pitch.
ground_truth_grad = [grad(steps) for steps in ground_truth_samp]
output_melody_grad = [grad(steps) for steps in output_melody_samp]

In [None]:
# Several ways of computing the DTW score follow: raw pitches, pitch differences, and mean-centred pitches

In [16]:
# DTW distance on the sampled absolute pitches, one value per song,
# divided by the length of the generated sequence.
dtw_samp = []
for song_idx in tqdm(range(len(ground_truth_samp))):
    reference = ground_truth_samp[song_idx]
    generated = output_melody_samp[song_idx]
    # Songs with an empty sequence on either side are left out of the average.
    if not len(reference) or not len(generated):
        continue
    reference_col = np.array(reference).reshape(-1, 1)
    generated_col = np.array(generated).reshape(-1, 1)
    # accelerated_dtw returns (distance, cost matrix, accumulated cost matrix, path);
    # only the distance is needed here.
    distance = accelerated_dtw(reference_col, generated_col, dist='euclidean')[0]
    dtw_samp.append(distance / len(generated_col))

100%|██████████| 461/461 [03:47<00:00,  2.03it/s]


In [17]:
# Average normalised DTW distance over all songs that were not skipped (raw pitches).
sum(dtw_samp)/len(dtw_samp)

4.218924596978432

In [18]:
# DTW distance on the pitch-difference sequences, one value per song,
# divided by the length of the generated sequence.
dtw_grad = []
for song_idx in tqdm(range(len(ground_truth_grad))):
    reference = ground_truth_grad[song_idx]
    generated = output_melody_grad[song_idx]
    # Songs with an empty sequence on either side are left out of the average.
    if not len(reference) or not len(generated):
        continue
    reference_col = np.array(reference).reshape(-1, 1)
    generated_col = np.array(generated).reshape(-1, 1)
    # accelerated_dtw returns (distance, cost matrix, accumulated cost matrix, path);
    # only the distance is needed here.
    distance = accelerated_dtw(reference_col, generated_col, dist='euclidean')[0]
    dtw_grad.append(distance / len(generated_col))

100%|██████████| 461/461 [03:47<00:00,  2.02it/s]


In [19]:
# Average normalised DTW distance over all songs that were not skipped (pitch differences).
sum(dtw_grad)/len(dtw_grad)

0.7069763255528089

In [20]:
# DTW distance on mean-centred pitches: each sequence has its own mean subtracted first,
# so a constant pitch offset between reference and generation does not add to the distance.
dtw_mean = []
for song_idx in tqdm(range(len(ground_truth_samp))):
    reference = ground_truth_samp[song_idx]
    generated = output_melody_samp[song_idx]
    # Songs with an empty sequence on either side are left out of the average.
    if not len(reference) or not len(generated):
        continue
    reference_col = np.array(reference).reshape(-1, 1)
    generated_col = np.array(generated).reshape(-1, 1)
    reference_col = reference_col - np.mean(reference_col)
    generated_col = generated_col - np.mean(generated_col)
    # accelerated_dtw returns (distance, cost matrix, accumulated cost matrix, path);
    # only the distance is needed here.
    distance = accelerated_dtw(reference_col, generated_col, dist='euclidean')[0]
    dtw_mean.append(distance / len(generated_col))

100%|██████████| 461/461 [03:42<00:00,  2.07it/s]


In [21]:
# Average normalised DTW distance over all songs that were not skipped (mean-centred pitches).
sum(dtw_mean)/len(dtw_mean)

2.9824720033086685

In [None]:
# d1 = np.array(ground_truth_samp[0]).reshape(-1,1)
# d2 = np.array(output_melody_samp[0]).reshape(-1,1)
# d, cost_matrix, acc_cost_matrix, path = accelerated_dtw(d1,d2, dist='euclidean')

In [None]:
# d / len(d2)

In [None]:
# import matplotlib.pyplot as plt

# plt.imshow(acc_cost_matrix.T, origin='lower', cmap='gray', interpolation='nearest')
# plt.plot(path[0], path[1], 'w')
# plt.show()

In [None]:
# from collections import defaultdict
# cnt = defaultdict(int)
# for song in ground_truth_melody:
#     for note in song:
#         cnt[note[1]]+=1