In [13]:
from glob import glob
import numpy as np
import os
from pathlib import Path
from tqdm import tqdm
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import random

In [3]:
with open('dictionary_new.pkl', 'rb') as f:
    token2id, id2token = pickle.load(f)

In [10]:
def group_bar(tokens):
    groups = []
    cur_group = []

    assert tokens[0] == 'Bar_None'
    cur_group.append(tokens[0])

    for t in tokens[1:]:
        if t == 'Bar_None':
            groups.append(cur_group)
            cur_group = []
        cur_group.append(t)

    if cur_group:
        groups.append(cur_group)
        
    return groups

def full2sky(tokens):
    return list(filter(lambda t: not t.startswith('Velocity') and not t.startswith('Tempo'), tokens))

In [14]:
BAR_LEN = 160

cond_data = []
target_data = []

for full_path in tqdm(glob('dataset/pop2piano_data/remi_events_full/*/*.pkl')):
    sky_path = full_path.replace('remi_events_full', 'remi_events_sky')

    full_tokens = np.load(full_path, allow_pickle=True)
    sky_tokens = np.load(sky_path, allow_pickle=True)

    full_tokens = [f'{t["name"]}_{t["value"]}' for t in full_tokens]
    sky_tokens = [f'{t["name"]}_{t["value"]}' for t in sky_tokens]
    
    full_groups = group_bar(full_tokens)
    sky_groups = group_bar(sky_tokens)

    assert len(sky_groups) <= len(full_groups)

    while len(sky_groups) < len(full_groups):
        sky_groups.append(full2sky(full_groups[len(sky_groups)]))
        
    tempo_token = next((t for t in full_tokens if t.startswith('Tempo')), None)
    assert tempo_token is not None

    cond_tokens = ['[BOS]', tempo_token]
    for group in sky_groups:
        cond_tokens.append('Track_Skyline')
        cond_tokens += group
    cond_tokens.append('[EOS]')

    target_tokens = []
    for full_group in full_groups:
        target_tokens.append('Track_Midi')
        target_tokens += full_group
        target_tokens += ['[NONE]'] * (BAR_LEN - len(full_group))

    cond_data.append(' '.join(cond_tokens))
    target_data.append(' '.join(target_tokens))

indices = list(range(len(cond_data)))
random.shuffle(indices)
train_size = int(len(indices) * 0.8)
train_indices = indices[:train_size]
valid_indices = indices[train_size:]

with open('cond_train.txt', 'w') as f:
    for i in train_indices:
        f.write(cond_data[i] + '\n')

with open('target_train.txt', 'w') as f:
    for i in train_indices:
        f.write(target_data[i] + '\n')

with open('cond_valid.txt', 'w') as f:
    for i in valid_indices:
        f.write(cond_data[i] + '\n')

with open('target_valid.txt', 'w') as f:
    for i in valid_indices:
        f.write(target_data[i] + '\n')


100%|██████████| 1392/1392 [00:06<00:00, 218.30it/s]


In [2]:
import pickle
with open('/tmp2/b11902010/dmir_lab/diffusion_compose_and_embellish/dictionary_all.pkl', 'rb') as f:
    d = pickle.load(f)
token2id, id2token = d
voc = [id2token[i] for i in range(len(id2token))]
voc = ['[MASK]', '[PAD]', '[NONE]', '[BOS]', '[EOS]'] + voc
token2id = {t: i for i, t in enumerate(voc)}
id2token = {i: t for i, t in enumerate(voc)}

with open('dictionary_new.pkl', 'wb') as f:
    pickle.dump((token2id, id2token), f)

In [3]:
import pickle
with open('dictionary_new.pkl', 'rb') as f:
    d = pickle.load(f)
d

({'[MASK]': 0,
  '[PAD]': 1,
  '[NONE]': 2,
  '[BOS]': 3,
  '[EOS]': 4,
  'Bar_None': 5,
  'Beat_0': 6,
  'Beat_1': 7,
  'Beat_10': 8,
  'Beat_11': 9,
  'Beat_12': 10,
  'Beat_13': 11,
  'Beat_14': 12,
  'Beat_15': 13,
  'Beat_2': 14,
  'Beat_3': 15,
  'Beat_4': 16,
  'Beat_5': 17,
  'Beat_6': 18,
  'Beat_7': 19,
  'Beat_8': 20,
  'Beat_9': 21,
  'Chord_A#_+': 22,
  'Chord_A#_/o7': 23,
  'Chord_A#_7': 24,
  'Chord_A#_M': 25,
  'Chord_A#_M7': 26,
  'Chord_A#_m': 27,
  'Chord_A#_m7': 28,
  'Chord_A#_o': 29,
  'Chord_A#_o7': 30,
  'Chord_A#_sus2': 31,
  'Chord_A#_sus4': 32,
  'Chord_A_+': 33,
  'Chord_A_/o7': 34,
  'Chord_A_7': 35,
  'Chord_A_M': 36,
  'Chord_A_M7': 37,
  'Chord_A_m': 38,
  'Chord_A_m7': 39,
  'Chord_A_o': 40,
  'Chord_A_o7': 41,
  'Chord_A_sus2': 42,
  'Chord_A_sus4': 43,
  'Chord_B_+': 44,
  'Chord_B_/o7': 45,
  'Chord_B_7': 46,
  'Chord_B_M': 47,
  'Chord_B_M7': 48,
  'Chord_B_m': 49,
  'Chord_B_m7': 50,
  'Chord_B_o': 51,
  'Chord_B_o7': 52,
  'Chord_B_sus2': 53,
  'C