In [1]:
%load_ext autoreload
%autoreload 2

import os, sys
sys.path.append('/screamlab/home/eri24816/improved-diffusion')
from os import path 
from utils.pianoroll import PianoRoll, PianoRollDataset
from tqdm import tqdm

from metrics import iou, mse

In [2]:
# keys
data_dir = '/screamlab/home/eri24816/pianoroll_dataset/data/dataset_1/pianoroll/'
songs = PianoRollDataset(data_dir,32,32,metadata_file='/screamlab/home/eri24816/pianoroll_dataset/data/dataset_1/metadata.csv')
len(songs.pianorolls)

Creating dataset 32


2570

In [26]:
from typing import List
def get_sim_map(q,k,sim_function):
    m = []
    for k_bar in k.split(32):
        m.append(sim_function(q,k_bar))
    return m

def get_sim_maps(q,ks,sim_function = iou):
    all_sim_maps : List[List] = [] # [key_song, query_pos, key_pos]
    for k in tqdm(ks):
        sim_maps : List[List] = [] # [query_pos, key_pos]
        k = k.to_tensor()/128
        for q_bar in q.split(32):
            sim_maps.append(get_sim_map(q_bar,k,sim_function))
        all_sim_maps.append(sim_maps)
    return all_sim_maps

def rank_similiar_parts(all_sim_maps, q_start=0, q_end=None):
    # rank similarity maps
    scores = []
    for i_key,sim_maps in enumerate(all_sim_maps):
        sim_maps = sim_maps[q_start:q_end]
        q_len = len(sim_maps)
        k_len = len(sim_maps[0])
        for q_pos_on_k in range(k_len - q_len): # move query window over key
            score = 0
            for q_pos in range(q_len):
                score += sim_maps[q_pos][q_pos_on_k+q_pos] # sum similarity of query bars

            scores.append((i_key,q_pos_on_k,score.item()/q_len))
    scores = sorted(scores,key=lambda x: x[2],reverse=True)
    return scores

In [42]:
def process_query(q_file,ks,sim_function = iou):
    global all_sim_maps
    # load query
    q = PianoRoll.from_midi(q_file).to_tensor(0,16*32,True)/128

    # get bar to bar similarity maps. [key_song, query_pos, key_pos]
    all_sim_maps = get_sim_maps(q,ks,sim_function)

    # prepare save path
    from os import path
    import json
    q_id = path.basename(q_file).replace('.mid','')
    q_dir = path.dirname(q_file)
    save_dir = path.join(q_dir,'similar',q_id)
    os.makedirs(save_dir,exist_ok=True)

    # prepare metadata
    metadata_path = path.join(save_dir,'metadata.json')

    files_metadata = {}

    for q_start in range(0,16,4): # move query window over query (4 bars at a time)
        q_end = q_start + 4

        # rank similar keys
        scores = rank_similiar_parts(all_sim_maps, q_start=q_start, q_end=q_end)

        # save query
        q_path = path.join(save_dir,f'{q_start}_query.mid')
        PianoRoll.from_tensor(q*128).slice(q_start*32,q_end*32).to_midi(q_path)
        files_metadata[path.basename(q_path)]={
            'title':f'Query {q_id}',
            'start':q_start,
            'end':q_end,
            'score':None
        }

        for rank,(i_key, q_pos_on_k, score) in enumerate(scores[:10]):
            #print(i_key, q_pos_on_k, score,ks[i_key].metadata.name)
            k = ks[i_key]
            k_path = path.join(save_dir,f'{q_start}_{rank}.mid')
            k.slice(q_pos_on_k*32,(q_pos_on_k+q_end-q_start)*32).to_midi(k_path)

            files_metadata[path.basename(k_path)]={
                'title':k.metadata.name,
                'start':q_pos_on_k,
                'end':q_pos_on_k + q_end - q_start,
                'score':score
            }
    
    # save metadata
    metadata = {'files':files_metadata}
    with open(metadata_path,'w') as f:
        json.dump(metadata,f)
    
    # update global metadata
    global_metadata_path = path.join(q_dir,'similar','metadata.json')
    if not path.exists(global_metadata_path):
        global_metadata = {'queries':[]}
    else:
        with open(global_metadata_path) as f:
            global_metadata = json.load(f)
    global_metadata['queries'].append(q_id)
    with open(global_metadata_path,'w') as f:
        json.dump(global_metadata,f)

In [43]:
for i in range(0,32):
    q_file = f'../log/ema_0.9999_2700000/{i}.mid'
    ks = songs.pianorolls
    process_query(q_file,ks,sim_function=iou)


100%|██████████| 2570/2570 [01:01<00:00, 41.78it/s]
100%|██████████| 2570/2570 [01:01<00:00, 41.76it/s]
100%|██████████| 2570/2570 [00:57<00:00, 44.66it/s]
100%|██████████| 2570/2570 [00:57<00:00, 44.86it/s]
100%|██████████| 2570/2570 [00:58<00:00, 44.03it/s]
100%|██████████| 2570/2570 [00:57<00:00, 44.36it/s]
100%|██████████| 2570/2570 [00:56<00:00, 45.80it/s]
100%|██████████| 2570/2570 [01:01<00:00, 41.92it/s]
100%|██████████| 2570/2570 [00:56<00:00, 45.18it/s]
100%|██████████| 2570/2570 [00:57<00:00, 44.78it/s]
100%|██████████| 2570/2570 [00:59<00:00, 43.21it/s]
100%|██████████| 2570/2570 [00:59<00:00, 43.55it/s]
100%|██████████| 2570/2570 [00:59<00:00, 43.03it/s]
100%|██████████| 2570/2570 [00:59<00:00, 42.84it/s]


In [19]:
from os import path
q_id = path.basename(q_file).replace('.mid','')
q_dir = path.dirname(q_file)
save_dir = path.join(q_dir,'similar',q_id)

In [20]:

q_dir,q_id,q_file,save_dir



('../log/ema_0.9999_2700000',
 '10',
 '../log/ema_0.9999_2700000/10.mid',
 '../log/ema_0.9999_2700000/similar/10')

In [72]:
k.slice(q_pos_on_k*32,q_pos_on_k*32+(q_end-q_start))

PianoRoll Bar 048 - 064 of Sia - Move Your Body (Alan Walker Remix) _ Piano Cover by Pianella Piano

In [27]:

def sort_by_similarity(x,ys : 'list[PianoRoll]' ,sim_function = mse):
    scores={} 
    x = x.to_tensor()/128
    for y in tqdm(ys):
        scores[y]=sim_function(x,y.to_tensor(0,x.shape[0],padding=True)/128).item()
    ys = sorted(ys,key=lambda y: scores[y])
    scores = sorted(scores.values())
    return ys, scores, x




# query
query_file = '../log/ema_0.9999_2700000/6.mid'
query = PianoRoll.from_midi(query_file).slice(32*14,32*15)

# save query
p = query_file.replace('.mid','similar')+'/query.mid'
os.makedirs(path.dirname(p),exist_ok=True)
query.to_midi(p)

keys.append(query) # add query to keys to check similarity function

# sort by similarity
ys, scores, x = sort_by_similarity(query,keys,iou)
print(scores[:10],scores[-10:])

# save 10 most similar
for i, y in enumerate(list(reversed(ys))[:20]):
    p = query_file.replace('.mid','similar')+f'/{i}.mid'
    os.makedirs(path.dirname(p),exist_ok=True)
    y.to_midi(p)



print('done')

100%|██████████| 253853/253853 [00:17<00:00, 14686.22it/s]


[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] [0.26604390144348145, 0.2663997411727905, 0.2673889696598053, 0.2691650688648224, 0.2694220244884491, 0.271394819021225, 0.27623701095581055, 0.28327450156211853, 0.29047897458076477, 0.5194606781005859]
done


In [18]:
__name__

'__main__'

In [30]:
a=50
# format to 000
f'{a:03d}'

'050'

In [94]:
get_sim_map(q_bar,q_bar,iou)

[tensor(0.5801)]