In [1]:

import numpy as np
import os
import torch
import json
from basismixer import make_datasets
from basismixer.utils import load_pyc_bz, save_pyc_bz
from torch.utils.data import DataLoader, ConcatDataset
from basismixer.predictive_models import (construct_model,
                                          FullPredictiveModel,
                                          SupervisedTrainer,
                                          MSELoss)
from partitura import save_performance_midi, load_musicxml
from basismixer.performance_codec import get_performance_codec
from partitura.score import unfold_part_maximal, merge_parts
import partitura.musicanalysis as ma




In [2]:
model_config = [
    dict(
        onsetwise=False,
        basis_functions=[
            "polynomial_pitch_feature",
            "articulation_feature",
            "duration_feature",
            "fermata_feature",
            "metrical_strength_feature",
            "time_signature_feature",
            "relative_score_position_feature",
        ],
        parameter_names=["velocity_dev", "velocity_trend", "timing", "articulation_log"],
        seq_len=1,
        model=dict(
            constructor=["basismixer.predictive_models", "FeedForwardModel"],
            args=dict(hidden_size=128),
        ),
        train_args=dict(
            optimizer=["Adam", dict(lr=1e-4)],
            epochs=100,
            save_freq=10,
            early_stopping=10,
            batch_size=1000,
        ),
    ),
]

# Make datasets

In [3]:
dataset_fn = 'data/data.pyc.bz'
if dataset_fn is not None and os.path.exists(dataset_fn):
    datasets = load_pyc_bz(dataset_fn)
else:
    datasets =  make_datasets(model_config, 'asap-dataset-main', 'asap')
    if dataset_fn is not None:
        save_pyc_bz(datasets, dataset_fn)

# Training Model

In [4]:

def jsonize_dict(input_dict):
    out_dict = dict()
    for k, v in input_dict.items():
        if isinstance(v, np.ndarray):
            out_dict[k] = v.tolist()
        elif isinstance(v, dict):
            out_dict[k] = jsonize_dict(v)
        else:
            out_dict[k] = v
    return out_dict

def build_model(config, in_names, out_names, out_dir):
    model_cfg = config['model'].copy()
    model_cfg['args']['input_names'] = in_names
    model_cfg['args']['input_size'] = len(in_names)
    model_cfg['args']['output_names'] = out_names
    model_cfg['args']['output_size'] = len(out_names)
    model_cfg['args']['input_type'] = 'onsetwise' if config['onsetwise'] else 'notewise'
    model_name = ('-'.join(out_names) +
                  '-' + ('onsetwise' if config['onsetwise'] else 'notewise'))
    model_out_dir = os.path.join(out_dir, model_name)
    if not os.path.exists(model_out_dir):
        os.mkdir(model_out_dir)
    
    config_out = os.path.join(model_out_dir, 'config.json')
    json.dump(jsonize_dict(model_cfg),
              open(config_out, 'w'),
              indent=2)
    model = construct_model(model_cfg)

    return model, model_out_dir

RNG = np.random.RandomState(1984)

def split_datasets(datasets, valid_size=0.1):

    n_pieces = len(datasets)

    dataset_idx = np.arange(n_pieces)
    RNG.shuffle(dataset_idx)
    len_valid = np.maximum(int(n_pieces * valid_size), 1)

    valid_idxs = dataset_idx[:len_valid]
    train_idxs = dataset_idx[len_valid:]

    return (ConcatDataset([datasets[i] for i in train_idxs]),
            ConcatDataset([datasets[i] for i in valid_idxs]))


def train_model(model, train_set, valid_set,
                config, out_dir):
    batch_size = config['train_args'].pop('batch_size')
    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              shuffle=True)
    valid_loader = DataLoader(valid_set,
                              batch_size=batch_size,
                              shuffle=False)

    loss = MSELoss()

    optim_name, optim_args = config['train_args']['optimizer']
    optim = getattr(torch.optim, optim_name)
    config['train_args']['optimizer'] = optim(model.parameters(), **optim_args)
    train_args = config['train_args']
    train_args.pop('seq_len', None)
    trainer = SupervisedTrainer(model=model,
                                train_loss=loss,
                                valid_loss=loss,
                                train_dataloader=train_loader,
                                valid_dataloader=valid_loader,
                                out_dir=out_dir,
                                **config['train_args'])
    trainer.train()


In [None]:

models = []
out_dir = 'models'
for (dataset, in_names, out_names), config in zip(datasets, model_config):
    model, model_out_dir = build_model(config, in_names, out_names, out_dir)
    train_set, valid_set = split_datasets(dataset)
    train_model(model, train_set, valid_set, config, model_out_dir)   
    models.append(model)

In [14]:
models

[FeedForwardModel(
   (hidden_layers): Sequential(
     (0): Linear(in_features=43, out_features=128, bias=True)
     (1): ReLU()
   )
   (output): Linear(in_features=128, out_features=4, bias=True)
 )]

# Generate performance

In [5]:
xml_fn = 'asap-dataset-main/Bach/Prelude/bwv_846/xml_score.musicxml'
models_dir = 'models' 

In [6]:
def load_model(models_dir):
    models = []
    for f in os.listdir(models_dir):
        path = os.path.join(models_dir, f)
        if os.path.isdir(path):
            model_config = json.load(open(os.path.join(path, 'config.json')))
            params = torch.load(os.path.join(path, 'best_model.pth'), 
                                map_location=torch.device('cpu'))['state_dict']
        
            model = construct_model(model_config, params)
            models.append(model)
    
    output_names = list(set([name for out_name in [m.output_names for m in models] for name in out_name]))
    input_names = list(set([name for in_name in [m.input_names for m in models] for name in in_name]))
    input_names.sort()
    output_names.sort()

    default_values = dict(
        velocity_trend=64,
        velocity_dev=0,
        beat_period_standardized=0,
        timing=0,
        articulation_log=0,
        beat_period_mean=0.5,
        beat_period_std=0.1)
    all_output_names = list(default_values.keys())
    full_model = FullPredictiveModel(models, input_names,
                                     all_output_names, default_values)

    return full_model, output_names
            
model, predicted_parameter_names = load_model(models_dir)

In [13]:
def compute_basis_from_xml(xml_filename, input_names):
    # Load the musicxml file and preprocess the part
    part = preprocess_part(xml_filename)

    # Compute basis functions
    basis = compute_basis_functions(part, input_names)

    return basis, part

def preprocess_part(xml_filename):
    part = load_musicxml(xml_filename, force_note_ids=True)
    part = merge_parts(part)
    part = unfold_part_maximal(part)
    return part

def compute_basis_functions(part, input_names):
    unique_input_names = list(set([name.split('.')[0] for name in input_names]))
    _basis, bf_names = ma.make_note_feats(part, unique_input_names)

    basis = np.zeros((len(_basis), len(input_names)))
    for i, name in enumerate(input_names):
        if name in bf_names:
            basis[:, i] = _basis[:, bf_names.index(name)]

    return basis
    
basis, part = compute_basis_from_xml(xml_fn, model.input_names)

In [9]:

def post_process_predictions(predictions):
    max_articulation = 1.5
    max_bps = 1
    max_timing = 0.2
    predictions['articulation_log'] = np.clip(predictions['articulation_log'],
                                              -max_articulation, max_articulation)
    predictions['velocity_dev'] = np.clip(predictions['velocity_dev'], 0, 0.8)
    predictions['beat_period_standardized'] = np.clip(predictions['beat_period_standardized'],
                                                      -max_bps, max_bps)
    predictions['timing'] = np.clip(predictions['timing'],
                                    -max_timing, max_timing)
    predictions['velocity_trend'][predictions['velocity_trend'] > 0.8] = 0.8
    

# Score positions for each note in the score
score_onsets = part.beat_map([n.start.t for n in part.notes_tied])

# make predictions
predictions = model.predict(basis, score_onsets)
post_process_predictions(predictions)
predictions
# print('Field names:', predictions.dtype.names)

array([(0.567938  , 0.10323137, 0., -2.74854200e-03, -0.16606605, 0.5, 0.1),
       (0.55248684, 0.09134048, 0., -2.24622302e-02,  0.22410601, 0.5, 0.1),
       (0.5225154 , 0.02410828, 0., -2.40946636e-02,  1.1968677 , 0.5, 0.1),
       (0.5413121 , 0.03046213, 0., -2.12163199e-02,  1.1614943 , 0.5, 0.1),
       (0.5609597 , 0.03551603, 0., -1.99739616e-02,  1.1349216 , 0.5, 0.1),
       (0.5345293 , 0.02658806, 0., -2.63696108e-02,  1.0157868 , 0.5, 0.1),
       (0.5543354 , 0.03373576, 0., -2.38831751e-02,  0.93688   , 0.5, 0.1),
       (0.5729949 , 0.03966352, 0., -2.25725770e-02,  0.8429782 , 0.5, 0.1),
       (0.5717347 , 0.08564179, 0., -2.05840282e-02,  0.13284898, 0.5, 0.1),
       (0.57571423, 0.09287181, 0., -2.31509879e-02, -0.20693976, 0.5, 0.1),
       (0.5356383 , 0.02845541, 0., -2.63335891e-02,  1.0759798 , 0.5, 0.1),
       (0.5504419 , 0.03519005, 0., -2.42968779e-02,  1.0951544 , 0.5, 0.1),
       (0.5621021 , 0.03961174, 0., -2.16676462e-02,  1.0192068 , 0.5, 0.1),

In [10]:
perf_codec = get_performance_codec(model.output_names)
predicted_ppart = perf_codec.decode(part, predictions)

In [11]:
midi_fn = 'bach_prelude_846.mid'
save_performance_midi(predicted_ppart, midi_fn)

In [12]:
predicted_ppart.note_array()

array([( 0.        , 0.8912697 ,     0,  856, 60, 59, 0, 1, 'n0-1'),
       ( 0.14471368, 1.0220466 ,   139,  981, 64, 59, 0, 1, 'n1-1'),
       ( 0.27134612, 0.28655177,   260,  276, 67, 63, 0, 1, 'n2-1'),
       ( 0.39346778, 0.27961123,   378,  268, 72, 65, 0, 1, 'n3-1'),
       ( 0.5172254 , 0.27450827,   497,  263, 76, 67, 0, 1, 'n4-1'),
       ( 0.6486211 , 0.25275066,   623,  242, 67, 65, 0, 1, 'n6-1'),
       ( 0.7711346 , 0.23929796,   740,  230, 72, 66, 0, 1, 'n7-1'),
       ( 0.894824  , 0.22421865,   859,  215, 76, 68, 0, 1, 'n8-1'),
       ( 1.0178355 , 1.0964568 ,   977, 1053, 60, 62, 0, 1, 'n9-1'),
       ( 1.1454026 , 0.7580764 ,  1100,  727, 64, 61, 0, 1, 'n10-1'),
       ( 1.2735851 , 0.26351917,  1223,  253, 67, 64, 0, 1, 'n11-1'),
       ( 1.3965484 , 0.26704493,  1341,  256, 72, 65, 0, 1, 'n12-1'),
       ( 1.5189191 , 0.25335053,  1458,  243, 76, 66, 0, 1, 'n13-1'),
       ( 1.6446931 , 0.20847109,  1579,  200, 67, 64, 0, 1, 'n15-1'),
       ( 1.7664402 , 0.182739