In [1]:
import dataset as dtst
import torch
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [2]:
# Load the model
from model import DisentangleVAE
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DisentangleVAE.init_model(device)

# load model parameters
# pytorch uses .pt file to save model parameters.
model_path = 'result/models/disvae-nozoth_epoch.pt'  
# setting cuda if cuda is available, which will speed up the computation.
model.load_model(model_path, map_location=device)

In [3]:
# Load the dataset
shift_low = -6
shift_high = 6
num_bar = 2
contain_chord = True
fns = dtst.collect_data_fns()
dataset = dtst.wrap_dataset(fns, np.arange(len(fns)), shift_low, shift_high,
                            num_bar=num_bar, contain_chord=contain_chord)
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

The folder contains 886 .npz files.
Selected 858 files, all are in duple meter.


In [4]:
# Generate a random chord
# For now, uniformly sample a triad in root position. At actual training time, consider basing the sample distribution on noised frequencies. 
def gen_chord():
    out = torch.zeros(8, 36)
    for i in range(out.shape[0]):
        out[i] = gen_chord_step()
    return out

def gen_chord_step():
    root = torch.randint(high=12, size=(1, ))
    bass = 0
    
    out = torch.zeros(1, 36)
    out[0, root] = 1
    out[0, bass + 24] = 1
    
    out[0, 12 + root] = 1
    out[0, 12 + (root + 7) % 12] = 1

    if random.random() < 0.5:
        # Minor
        out[0, 12 + (root + 4) % 12] = 1
    else:
        out[0, 12 + (root + 3) % 12] = 1

    return out

gen_chord()


tensor([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1.,
         0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1.,
         0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0.,
         0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
         0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0.,
         0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0.,
         1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0

In [5]:
# Nearest neighbours
# An naive rule set to apply HST. Consider using a learned one.
def hst_nn(notes, chord_mat):
    notes_out = []
    for i in range(chord_mat.shape[0]):
        notes_out += apply_nn_step(notes, chord_mat[i: i+1, :], i)
    return notes_out

def apply_nn_step(notes, chord, step, quant_size=1):
    # Filter out notes starting in the window
    filtered_notes = []
    for note in notes:
        if note.start >= quant_size * step and note.start <= quant_size * (step + 1):
            filtered_notes.append(note)
    
    for note in filtered_notes:
        pitch_dist = 999
        pit_out = 0
        for pit in range(20, 100):
            # Is this pitch included in the chroma?
            if chord[0, 12 + pit % 12] == 0:
                continue
            dist = abs(note.pitch - pit)
            if dist < pitch_dist:
                pitch_dist = dist
                pit_out = pit
        note.pitch = pit_out
    return filtered_notes

melody, pr, pr_mat, ptree, chord = dataset[12]
_, notes = model.decoder.grid_to_pr_and_notes(ptree.squeeze(0).astype(int))
hst_nn(notes, gen_chord())

[Note(start=3.500000, end=4.000000, pitch=74, velocity=100),
 Note(start=4.000000, end=5.500000, pitch=53, velocity=100),
 Note(start=4.000000, end=4.500000, pitch=80, velocity=100),
 Note(start=4.000000, end=5.500000, pitch=53, velocity=100),
 Note(start=4.000000, end=4.500000, pitch=80, velocity=100),
 Note(start=4.250000, end=5.500000, pitch=61, velocity=100),
 Note(start=4.500000, end=5.250000, pitch=65, velocity=100),
 Note(start=4.500000, end=4.750000, pitch=80, velocity=100),
 Note(start=4.750000, end=6.000000, pitch=73, velocity=100),
 Note(start=5.000000, end=5.250000, pitch=75, velocity=100),
 Note(start=5.000000, end=5.250000, pitch=75, velocity=100),
 Note(start=5.500000, end=5.750000, pitch=87, velocity=100),
 Note(start=5.750000, end=6.000000, pitch=87, velocity=100),
 Note(start=6.000000, end=7.250000, pitch=56, velocity=100),
 Note(start=6.000000, end=7.250000, pitch=85, velocity=100),
 Note(start=6.000000, end=7.250000, pitch=56, velocity=100),
 Note(start=6.000000, en

In [6]:
def prec_recall_f1(pred, ref):
    hit = 0
    for note in pred:
        start = note.start
        end = note.end
        pit = note.pitch
        for note_ref in ref:
            if note_ref.start == start and note_ref.end == end and note_ref.pitch == pit:
                hit += 1
                break
    n_pred = len(pred)
    n_true = len(ref)
    prec = hit / n_pred
    recall = hit / n_true
    f1 = 2 / (1 / prec + 1 / recall)
    return prec, recall, f1

def run_trial():
    n_trial = 5
    precs = []
    recalls = []
    f1s = []

    iter = 0
    while True:
        for batch in loader:
            iter += 1
            if iter > n_trial:
                break

            melody, pr, pr_mat, ptree, _ = batch
            pr_mat = pr_mat[0]
            chord = gen_chord()
            chord_ = chord.unsqueeze(0)
            _, notes = model.decoder.grid_to_pr_and_notes(ptree[0].squeeze(0).numpy().astype(int))

            polydis_out = model.swap(pr_mat.float(), pr_mat.float(), chord_.float(), chord_.float(), fix_rhy=True, fix_chd=False)
            _, notes_polydis = model.decoder.grid_to_pr_and_notes(polydis_out.squeeze(0).astype(int))
            notes_rule = hst_nn(notes, chord)
            
            prec, recall, f1 = prec_recall_f1(notes_rule, notes_polydis)
            precs.append(prec)
            recalls.append(recall)
            f1s.append(f1s)
        
        if iter > n_trial:
            break
    return precs, recalls, f1s

prec, recall, f1 = run_trial()

In [7]:
print(prec)

[0.4262295081967213, 0.16129032258064516, 0.4186046511627907, 0.32, 0.2916666666666667]
