In [None]:
import os
import argparse
import datetime
import random
import copy

import pretty_midi
import torch
from torch.utils.tensorboard import SummaryWriter

from dataset import LoaderWrapper
from models.edit_musebert import EditMuseBERT
from utils.data_utils import prep_batch, prep_batch_inference, onset_pitch_duration_prettymidi_notes
import utils.rules

In [None]:
# Device and seeds. `random` chooses the POP909 sample in the inference cell,
# so seed it together with torch to make the whole demo reproducible.
SEED = 21
torch.manual_seed(SEED)
random.seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Make loaders (the positional arguments are passed through to
# dataset.LoaderWrapper unchanged; see that class for their meaning).
wrapper = LoaderWrapper(1, 1)
train_loader = wrapper.get_loader(split='train')
dev_loader = wrapper.get_loader(split='dev')
print(f'Training loader size: {len(train_loader)}')
print(f'Dev loader size: {len(dev_loader)}')
# NOTE(review): split_idx looks like a negative offset marking where the dev
# songs start, so its negation is the dev song count -- confirm in dataset.py.
print(f'Dev #songs: {- dev_loader.dataset.split_idx}')

# Optionally override the rule set used to build the model input, e.g.:
# wrapper.collate.rule = utils.rules.identity

# Model setup. N_DECODER_LAYERS must match the architecture the checkpoint was
# trained with; set it to None to fall back to EditMuseBERT's default depth.
N_DECODER_LAYERS = 2
model_kwargs = {'n_edit_types': wrapper.collate.editor.pitch_range}
if N_DECODER_LAYERS is not None:
    model_kwargs['n_decoder_layers'] = N_DECODER_LAYERS
model = EditMuseBERT(device, wrapper, **model_kwargs).to(device)

# Load from checkpoint
checkpoint = '../results/checkpoints/debug/batchsize32_lr1e-05_0_4999_100.bin'
print(f'loading checkpoint: {checkpoint}')
loaded = torch.load(checkpoint, map_location=device)
model.load_state_dict(loaded['model_state_dict'])
# This notebook only runs inference: switch the module to evaluation mode so
# training-only behaviour (e.g. dropout, if present) is disabled.
model.eval()

: 

In [None]:
# Make a chord prog
# Chord qualities, each given as semitone intervals above the root:
#   M3 major triad, m3 minor triad, A3 augmented triad, d3 diminished triad,
#   M7 major seventh, m7 minor seventh, D7 dominant seventh.
chd_types = dict(
    M3=[0, 4, 7],
    m3=[0, 3, 7],
    A3=[0, 4, 8],
    d3=[0, 3, 6],
    M7=[0, 4, 7, 11],
    m7=[0, 3, 7, 10],
    D7=[0, 4, 7, 10],
)

def make_chd(root, chroma, bass):
    """Encode one chord as a 36-element multi-hot list.

    Layout of the returned list:
        [0:12]  one-hot root pitch class
        [12:24] multi-hot chroma (absolute pitch classes)
        [24:36] one-hot bass pitch class (relative, presumably to the root)

    Args:
        root: root pitch class, 0-11.
        chroma: iterable of absolute pitch classes, each 0-11.
        bass: bass pitch class, 0-11, relative (not absolute).

    Returns:
        list[int] of length 36 containing only 0s and 1s.

    Raises:
        ValueError: if ``root``, ``bass`` or any chroma entry lies outside
            0-11. Without this check such values would silently write into a
            neighbouring section of the vector (or wrap via negative indexing).
    """
    for label, value in (('root', root), ('bass', bass)):
        if not 0 <= value < 12:
            raise ValueError(f'{label} must be in 0-11, got {value}')
    out = [0] * 36
    out[root] = 1
    out[bass + 24] = 1
    for c in chroma:
        if not 0 <= c < 12:
            raise ValueError(f'chroma entries must be in 0-11, got {c}')
        out[c + 12] = 1
    return out

def make_prog(prog_text):
    """Convert a symbolic chord progression into a list of 36-element chord vectors.

    Args:
        prog_text: sequence of ``[root, quality, bass]`` entries, where ``root``
            is a pitch class 0-11, ``quality`` is a key of ``chd_types`` and
            ``bass`` is the bass pitch class relative to the root (0-11).

    Returns:
        list of the 36-element lists produced by ``make_chd``, one per chord.
    """
    out = []
    for chd_text in prog_text:
        root = chd_text[0]
        bass = chd_text[2]
        # chd_types stores intervals above the root, but make_chd expects the
        # chroma as absolute pitch classes, so transpose by the root (mod 12).
        chroma = [(root + interval) % 12 for interval in chd_types[chd_text[1]]]
        out.append(make_chd(root, chroma, bass))
    return out

# Eight chords given as [root, quality, bass-relative-to-root]. Symbolically
# (assuming pitch class 0 = C): Am, Am/E, F, F, F, G, G#dim, A.
# Result: float32 tensor of shape (1, 8, 36) on `device` (batch of one).
cmat = torch.tensor(
    make_prog([
        [9, 'm3', 0],
        [9, 'm3', 7],
        [5, 'M3', 0],
        [5, 'M3', 0],
        [5, 'M3', 0],
        [7, 'M3', 0],
        [8, 'd3', 0],
        [9, 'M3', 0],
    ]),
    dtype=torch.float32,
    device=device,
).unsqueeze(0)

: 

In [None]:
# Get a texture from POP909: pick one random training sample and keep its 3rd
# and 4th fields (pr_mat, ptree -- presumably piano-roll and note tree).
_, _, pr_mat, ptree, _ = train_loader.dataset.polydis_dataset[
    random.randrange(len(train_loader.dataset.polydis_dataset))
]
pr_mat = torch.tensor(pr_mat).to(device).float()
ptree = torch.tensor(ptree).to(device)[0]

# Polydis oracle (presumably the sample's texture re-rendered over `cmat`).
ptree_polydis = wrapper.collate.polydis.swap(pr_mat, pr_mat, cmat, cmat, fix_rhy=True, fix_chd=False)[0]
_, notes_polydis = wrapper.collate.polydis.decoder.grid_to_pr_and_notes(ptree_polydis.astype(int))

# Original notes. `ptree` lives on `device`, and Tensor.numpy() only works for
# CPU tensors, so bring it back to the CPU first (otherwise this fails on CUDA).
_, notes_original = wrapper.collate.polydis.decoder.grid_to_pr_and_notes(ptree.cpu().numpy().astype(int))
# Keep an untouched copy for the MIDI output in case later steps mutate it.
notes_original_ = copy.deepcopy(notes_original)

# Apply rule-based approximations to the input notes
notes_rule = wrapper.collate.rule(notes_original_, cmat[0])
notes_rule_ = copy.deepcopy(notes_rule)

# Convert notes for MuseBERT input
notes_out_line, _, _, _, _ = wrapper.collate.editor.get_edits(notes_rule, notes_polydis)
atr, _, cpt_rel, _, _, length = wrapper.collate.converter.convert(notes_out_line)

# Run the edit model as a batch of one; no gradients are needed for inference.
atr = torch.tensor(atr).to(device).unsqueeze(0)
cpt_rel = torch.tensor(cpt_rel).to(device).unsqueeze(0)
length = [length]
with torch.no_grad():
    inference_out = model.inference(cmat, [atr, cpt_rel, length], return_context_inserts=True)
# Take the first (only) batch element of each returned output.
notes_pred = inference_out[0][0]
notes_context = inference_out[1][0]
notes_insert = inference_out[2][0]

# Convert model outputs to pretty_midi notes for writing
notes_pred = onset_pitch_duration_prettymidi_notes(notes_pred)
notes_context = onset_pitch_duration_prettymidi_notes(notes_context)
notes_insert = onset_pitch_duration_prettymidi_notes(notes_insert)

def write_midi(note_seqs, names, out_dir='../results/demo_out', filename='out.mid'):
    """Write several note sequences as separate tracks of one MIDI file.

    Args:
        note_seqs: iterable of note lists (presumably pretty_midi.Note objects),
            one list per track.
        names: track names, one per entry of ``note_seqs``.
        out_dir: output directory; created if it does not exist yet.
        filename: name of the MIDI file written inside ``out_dir``.

    Returns:
        str: path of the written MIDI file.

    Raises:
        ValueError: if ``note_seqs`` and ``names`` have different lengths.
    """
    note_seqs = list(note_seqs)
    names = list(names)
    if len(note_seqs) != len(names):
        raise ValueError(
            f'got {len(note_seqs)} note sequences but {len(names)} track names'
        )
    # Without this the write fails on a fresh checkout where the folder is missing.
    os.makedirs(out_dir, exist_ok=True)
    mid = pretty_midi.PrettyMIDI()
    for seq, name in zip(note_seqs, names):
        # program=0 is General MIDI Acoustic Grand Piano for every track.
        inst = pretty_midi.Instrument(program=0, name=name)
        inst.notes = seq
        mid.instruments.append(inst)
    out_path = os.path.join(out_dir, filename)
    mid.write(out_path)
    return out_path

# Write each pipeline stage as its own track of a single MIDI file so the
# stages can be compared side by side in a MIDI viewer.
write_midi(
    note_seqs=[
        notes_original_,  # untouched POP909 sample
        notes_polydis,    # PolyDis oracle
        notes_rule_,      # rule-based approximation fed to the model
        notes_pred,       # model output, inference_out[0]
        notes_context,    # model output, inference_out[1]
        notes_insert,     # model output, inference_out[2]
    ],
    names=['original', 'polydis', 'rules', 'edit_final', 'edit_enc', 'edit_insert'],
)
    


: 

: 

: 