In [None]:
import torch
from virtuoso import model as modelzoo
from virtuoso import model_parameters as param
from omegaconf import OmegaConf
import _pickle as pickle
import sys
import yaml
from virtuoso import parser

# Register virtuoso.model_parameters (imported above as `param`) under the
# top-level module name 'model_parameters'.
# Presumably needed so that pickled objects referring to that module name can be
# resolved when unpickling -- TODO confirm.
sys.modules['model_parameters'] = param


In [None]:
# Obtain the object returned by virtuoso.parser.get_parser() and display it.
# NOTE(review): a later cell reads `args.modelCode`; confirm that get_parser()
# returns parsed arguments rather than an unparsed parser object.
args = parser.get_parser()
args

In [None]:
# `dict` here is NOT the builtin: it is the nested parameter dictionary built in
# the cell that loops over `vars(param)`. That cell must have been run first.
conf = OmegaConf.create(dict)
# Persist the same dictionary as block-style YAML.
with open('isgn_param.yml', 'w') as f:
    yaml.dump(dict, f, default_flow_style=False)

In [None]:
# Build the ISGN network from the OmegaConf config; the second argument is
# presumably the device (another cell passes `device` in that position).
model = modelzoo.ISGN(conf, 'cpu')
model

In [None]:
# Read the YAML parameter file back in (the file is written by the cell that
# calls yaml.dump) to check the round trip.
with open('isgn_param.yml', 'r') as f:
    yaml_obj = yaml.load(f, Loader=yaml.FullLoader)

In [None]:
# Display the OmegaConf configuration object.
conf

In [None]:
# SECURITY: pickle.load (and torch.load, which unpickles internally) can execute
# arbitrary code. Only load these files when they come from a trusted source.
#
# NOTE: this rebinds `param`, which until now referred to the
# virtuoso.model_parameters module imported at the top of the notebook.
with open("isgn_param.dat", 'rb') as f:
    param = pickle.load(f)

# map_location='cpu' keeps the checkpoint loadable on a CPU-only machine even if
# it was saved from a GPU; the model in this notebook is also built on 'cpu'.
with open("prime_isgn_best.pth.tar", 'rb') as f:
    weights = torch.load(f, map_location='cpu')

In [None]:
# Flatten the loaded parameter object into plain nested dictionaries so it can be
# handed to OmegaConf / YAML.
# NOTE: the name `dict` shadows the builtin; it is kept because other cells read it.
dict = {'nn_params': {}, 'training_params': {}}
for key in vars(param):
    value = getattr(param, key)
    if isinstance(value, param.Param):
        # Nested parameter group: copy each of its attributes.
        dict['nn_params'][key] = {
            subkey: getattr(value, subkey) for subkey in vars(value)
        }
    elif key == 'training_args':
        # Training arguments go into their own top-level section.
        dict['training_params'][key] = {
            subkey: getattr(param.training_args, subkey)
            for subkey in vars(param.training_args)
        }
    else:
        # Plain attribute: store as-is.
        dict['nn_params'][key] = value

dict


In [None]:
# 'training_args' is stored inside the 'training_params' section of `dict`
# (see the cell that builds `dict`), so it has to be reached through that key;
# a direct dict['training_args'] lookup raises KeyError.
# The unfinished `for` statement that used to follow was a SyntaxError and has
# been dropped.
test = dict['training_params']['training_args']
test

In [None]:


# Choose the network class from substrings of the model code.
if 'isgn' in args.modelCode:
    MODEL = modelzoo.ISGN(NET_PARAM, device).to(device)
elif 'han' in args.modelCode:
    # An 'ar' substring switches HAN_Integrated to its step-by-step mode.
    step_by_step = 'ar' in args.modelCode
    MODEL = modelzoo.HAN_Integrated(NET_PARAM, device, step_by_step).to(device)
elif 'trill' in args.modelCode:
    MODEL = modelzoo.TrillRNN(NET_PARAM, device).to(device)
else:
    # NOTE: MODEL stays undefined on this path.
    print('Error: Unclassified model code')
    # Model = modelzoo.HAN_VAE(NET_PARAM, device, False).to(device)

In [4]:
from virtuoso.inference import get_input_from_xml
from virtuoso.utils import load_dat

# Names of the score-side features requested from get_input_from_xml.
input_keys = (
    'midi_pitch',
    'duration',
    'beat_importance',
    'measure_length',
    'qpm_primo',
    'following_rest',
    'distance_from_abs_dynamic',
    'distance_from_recent_tempo',
    'beat_position',
    'xml_position',
    'grace_order',
    'preceded_by_grace_note',
    'followed_by_fermata_rest',
    'pitch',
    'tempo',
    'dynamic',
    'time_sig_vec',
    'slur_beam_vec',
    'composer_vec',
    'notation',
    'tempo_primo',
    'note_location',
)

# Names of the output features (used below to look up their statistics).
output_keys = (
    'beat_tempo',
    'velocity',
    'onset_deviation',
    'articulation',
    'pedal_refresh_time',
    'pedal_cut_time',
    'pedal_at_start',
    'pedal_at_end',
    'soft_pedal',
    'pedal_refresh',
    'pedal_cut',
)

# Edge types passed to get_input_from_xml.
graph_keys = ['onset', 'forward', 'melisma', 'rest', 'voice']

stats = load_dat('dataset/stat.dat')

# NOTE: `input` shadows the builtin; the name is kept because later cells read it.
score, input, edges, note_locations = get_input_from_xml(
    'test_pieces/bps_5_1/musicxml_cleaned.musicxml',
    'Beethoven',
    input_keys,
    graph_keys,
    stats['stats'],
)

In [None]:
# Display the model input returned by get_input_from_xml (not the builtin `input`).
input

In [5]:
# Load the older statistics file; a later cell passes old_stat[0] and old_stat[1]
# to read_xml_to_array as its `means` and `stds` arguments.
old_stat = load_dat('training_data_stat.dat')

In [None]:
# Display the 'stats' entry of the statistics loaded from dataset/stat.dat.
stats['stats']

In [6]:
import math
from virtuoso.pyScoreParser.utils import binary_index
from virtuoso.pyScoreParser.feature_utils import time_signature_to_vector, pitch_into_vector, cal_beat_importance, note_notation_to_vector, composer_name_to_vec
from virtuoso.pyScoreParser.xml_utils import cal_total_xml_length
import virtuoso.pyScoreParser.xml_direction_encoding as dir_enc


# Build the tempo and dynamic embedding tables once; extract_score_features
# passes them to dir_enc.dynamic_embedding for every note.
TEM_EMB_TAB=  dir_enc.define_tempo_embedding_table()
DYN_EMB_TAB = dir_enc.define_dynamic_embedding_table()
def extract_score_features(xml_notes, measure_positions, beats=None, qpm_primo=0, vel_standard=False):
    """Compute one feature object per note of a parsed score.

    Args:
        xml_notes: Sequence of note objects. Each note is read for ``pitch``,
            ``note_duration``, ``state_fixed``, ``tempo``, ``dynamic``,
            ``note_notations``, ``voice``, ``measure_number``,
            ``following_rest_duration`` and ``followed_by_fermata_rest``.
        measure_positions: Measure start positions in the same units as
            ``note.note_duration.xml_position``. They are searched with
            ``binary_index``, so presumably sorted ascending -- TODO confirm.
        beats: Unused; kept so existing callers keep working.
        qpm_primo: Initial tempo value. ``0`` (the default) means "read it from
            ``xml_notes[0].state_fixed.qpm``".
        vel_standard: Unused; kept so existing callers keep working.

    Returns:
        A list with one ``MusicFeature`` per note, in the order of
        ``xml_notes``. An empty ``xml_notes`` yields an empty list.
    """
    # Nothing to extract; also avoids indexing xml_notes[0] below.
    if len(xml_notes) == 0:
        return []

    # melody_notes = extract_melody_only_from_notes(xml_notes)
    features = []

    if qpm_primo == 0:
        qpm_primo = xml_notes[0].state_fixed.qpm
    # Identical for every note, so compute it once instead of once per note.
    log_qpm_primo = math.log(qpm_primo, 10)

    tempo_primo_word = dir_enc.direction_words_flatten(xml_notes[0].tempo)
    if tempo_primo_word:
        # Reuse the module-level tempo table instead of rebuilding it on every call;
        # only the first two components of the embedding are kept.
        tempo_primo = dir_enc.dynamic_embedding(tempo_primo_word, TEM_EMB_TAB, 5)[0:2]
    else:
        tempo_primo = [0, 0]

    cresc_words = ['cresc', 'decresc', 'dim']

    # Sorted unique onset positions; a note's index in this list is its onset id.
    onset_positions = sorted({note.note_duration.xml_position for note in xml_notes})
    total_length = cal_total_xml_length(xml_notes)

    class NoteLocation:
        def __init__(self):
            self.beat = 0

    class MusicFeature:
        def __init__(self):
            self.midi_pitch = 0
            self.note_location = NoteLocation()

    for note in xml_notes:
        feature = MusicFeature()
        divisions = note.state_fixed.divisions
        note_position = note.note_duration.xml_position

        measure_index = binary_index(measure_positions, note_position)
        if measure_index + 1 < len(measure_positions):
            measure_length = measure_positions[measure_index + 1] - measure_positions[measure_index]
        else:
            # Last measure has no successor: fall back to the previous measure's length.
            # NOTE(review): with a single measure position this evaluates to 0 and the
            # division below fails -- confirm whether that input can occur.
            measure_length = measure_positions[measure_index] - measure_positions[measure_index - 1]

        feature.midi_pitch = note.pitch[1]
        feature.pitch = pitch_into_vector(note.pitch[1])
        feature.duration = note.note_duration.duration / divisions

        # Position of the note inside its measure, as a fraction of the measure length.
        beat_position = (note_position - measure_positions[measure_index]) / measure_length
        feature.beat_position = beat_position
        feature.beat_importance = cal_beat_importance(beat_position, note.tempo.time_numerator)
        feature.measure_length = measure_length / divisions
        feature.note_location.voice = note.voice
        feature.note_location.onset = binary_index(onset_positions, note_position)
        feature.xml_position = note_position / total_length
        feature.grace_order = note.note_duration.grace_order
        feature.is_grace_note = int(note.note_duration.is_grace_note)
        feature.preceded_by_grace_note = int(note.note_duration.preceded_by_grace_note)
        # feature.melody = int(note in melody_notes)

        notations = note.note_notations
        feature.slur_beam_vec = [int(notations.is_slur_start), int(notations.is_slur_continue),
                                 int(notations.is_slur_stop), int(notations.is_beam_start),
                                 int(notations.is_beam_continue), int(notations.is_beam_stop)]

        feature.time_sig_vec = time_signature_to_vector(note.tempo.time_signature)
        feature.following_rest = note.following_rest_duration / divisions
        feature.followed_by_fermata_rest = int(note.followed_by_fermata_rest)

        dynamic_words = dir_enc.direction_words_flatten(note.dynamic)
        tempo_words = dir_enc.direction_words_flatten(note.tempo)

        feature.dynamic = dir_enc.dynamic_embedding(dynamic_words, DYN_EMB_TAB, len_vec=4)
        if feature.dynamic[1] != 0:
            # Scale the second dynamic component by how far the note lies inside each
            # matching crescendo/diminuendo direction.
            for rel in note.dynamic.relative:
                for word in cresc_words:
                    if word in rel.type['type'] or word in rel.type['content']:
                        rel_length = rel.end_xml_position - rel.xml_position
                        if rel_length == float("inf") or rel_length == 0:
                            # Open-ended or zero-length direction: use a fixed fallback length.
                            rel_length = divisions * 10
                        ratio = (note_position - rel.xml_position) / rel_length
                        feature.dynamic[1] *= (ratio + 0.05)
                        # Apply each direction at most once, even if several words match.
                        break

        if note.dynamic.cresciuto:
            feature.cresciuto = (note.dynamic.cresciuto.overlapped + 1) / 2
            if note.dynamic.cresciuto.type == 'diminuendo':
                feature.cresciuto *= -1
        else:
            feature.cresciuto = 0
        feature.dynamic.append(feature.cresciuto)

        feature.tempo = dir_enc.dynamic_embedding(tempo_words, TEM_EMB_TAB, len_vec=5)
        feature.notation = note_notation_to_vector(note)
        feature.qpm_primo = log_qpm_primo
        # The same list object is shared by every feature, as before.
        feature.tempo_primo = tempo_primo
        feature.note_location.measure = note.measure_number - 1
        feature.distance_from_abs_dynamic = (note_position - note.dynamic.absolute_position) / divisions
        feature.distance_from_recent_tempo = (note_position - note.tempo.recently_changed_position) / divisions
        features.append(feature)

    return features


In [7]:
# Recompute per-note features with the local extract_score_features so they can
# be compared against the output of get_input_from_xml.
features = extract_score_features(score.xml_notes, score.measure_positions)

In [None]:
# Print the stored 'mean' statistic of every output feature.
for key in output_keys:
    print(stats['stats'][key]['mean'])

In [8]:
def read_xml_to_array(xml_notes, features, composer_name, means, stds):
    """Turn per-note feature objects into flat input rows.

    The first eight scalar features are normalised as ``(value - mean) / std``
    using ``means[0][i]`` and ``stds[0][i]``; five more scalars are copied
    unchanged; then the vector-valued features and the composer vector are
    concatenated onto the row.

    Args:
        xml_notes: Unused; kept for interface compatibility.
        features: Iterable of feature objects exposing the attributes read below.
        composer_name: Name handed to ``composer_name_to_vec``.
        means: Indexable whose first element holds the per-column means.
        stds: Indexable whose first element holds the per-column standard deviations.

    Returns:
        A list with one row (list) per entry of ``features``.
    """
    composer_vec = composer_name_to_vec(composer_name)

    # Column order matters: index i here pairs with means[0][i] / stds[0][i].
    normalized_names = ('midi_pitch', 'duration', 'beat_importance', 'measure_length',
                        'qpm_primo', 'following_rest', 'distance_from_abs_dynamic',
                        'distance_from_recent_tempo')
    raw_names = ('beat_position', 'xml_position', 'grace_order',
                 'preceded_by_grace_note', 'followed_by_fermata_rest')

    rows = []
    for feat in features:
        z_scored = [(getattr(feat, name) - means[0][col]) / stds[0][col]
                    for col, name in enumerate(normalized_names)]
        passthrough = [getattr(feat, name) for name in raw_names]
        rows.append(z_scored + passthrough
                    + feat.pitch + feat.tempo + feat.dynamic + feat.time_sig_vec
                    + feat.slur_beam_vec + composer_vec + feat.notation + feat.tempo_primo)
    return rows


In [9]:
# Build the legacy input rows, normalised with the older statistics
# (old_stat[0] is passed as `means`, old_stat[1] as `stds`).
converted = read_xml_to_array(score.xml_notes, features, 'Beethoven', old_stat[0], old_stat[1])

In [10]:
# numpy is not imported by any earlier cell, so bring it into scope here;
# without this the cell raises NameError on a fresh kernel.
import numpy as np

# Legacy input rows as a 2-D array (one row per entry of `converted`).
old = np.asarray(converted)

In [11]:
# Convert the input from get_input_from_xml to a numpy array for comparison.
# Presumably `input` is a torch tensor with a leading singleton batch dimension -- TODO confirm.
new = input.squeeze().cpu().numpy()

In [12]:
# Difference between the two pipelines for the first five columns of the first row.
# In `old` these are the normalised midi_pitch, duration, beat_importance,
# measure_length and qpm_primo (see read_xml_to_array).
(new[0,:5] -old[0,:5]).tolist()

[0.03164851910161659,
 -0.03590252749381295,
 0.006223060692309712,
 -0.01259772423479208,
 0.04386550496198205]

In [16]:
# Shape of the array produced from the get_input_from_xml input.
new.shape

(2072, 78)