In [9]:
import torch
from data import run_length_encoding
from data import event_codec
from data import vocabularies
from data import spectrograms
from data import note_sequences
import note_seq
import numpy as np
import pretty_midi
import sklearn
from typing import Callable, Mapping, Optional, Sequence, Tuple
import editdistance
import mir_eval
import librosa
from utils import _audio_to_frames, AttrDict
from model.ListenAttendSpell import ListenAttendSpell 
from tqdm import tqdm
import yaml
import functools
from evaluation import *
import os 

In [10]:
# Read the YAML settings file and wrap the parsed mapping in AttrDict so the
# nested entries can be reached with attribute syntax.
with open('config/model.yaml', 'r') as config_file:
    file_config = yaml.safe_load(config_file)
config = AttrDict(file_config)
# Override training.beam_size for this run
# (presumably the beam width used while decoding in make_pred_ns — confirm).
config.training.beam_size = 2

# Event codec with two event ranges: 'pitch' spanning note_seq's MIDI pitch
# bounds and 'velocity' spanning 0..127.
pitch_range = event_codec.EventRange(
    'pitch', note_seq.MIN_MIDI_PITCH, note_seq.MAX_MIDI_PITCH)
velocity_range = event_codec.EventRange('velocity', 0, 127)
codec = event_codec.Codec(
    max_shift_steps=300,
    steps_per_second=100,
    event_ranges=[pitch_range, velocity_range],
)
# Vocabulary derived from the codec above.
vocab = vocabularies.vocabulary_from_codec(codec)

# Spectrogram settings are taken from SpectrogramConfig's defaults.
spectrogram_config = spectrograms.SpectrogramConfig()

In [11]:
# Load the saved checkpoint onto the CPU, build the model from it, and switch
# the model to evaluation mode.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
checkpoint_path = "/Users/donghyunlee/Desktop/encoder_decoder_model/test_data/las_model.epoch6399.chkpt"
cpu_device = torch.device('cpu')
model_states = torch.load(checkpoint_path, map_location=cpu_device)
model = ListenAttendSpell.load_model(model_states)
# Kept as the cell's final expression so the model structure is displayed.
model.eval()

ListenAttendSpell(
  (encoder): Encoder(
    (rnn): LSTM(512, 1024, num_layers=4, batch_first=True, dropout=0.2, bidirectional=True)
  )
  (decoder): Decoder(
    (embedding): Embedding(560, 512)
    (rnn): ModuleList(
      (0): LSTMCell(2560, 2048)
    )
    (attention): DotProductAttention()
    (linear): Sequential(
      (0): Linear(in_features=4096, out_features=2048, bias=True)
      (1): Tanh()
      (2): Linear(in_features=2048, out_features=560, bias=True)
    )
    (loss): CrossEntropyLoss()
  )
)

In [12]:
# Audio recording (FLAC) that is passed to segmentize_audio in the next cell.
# NOTE(review): absolute, machine-specific path — consider deriving it from a
# configurable data directory so the notebook runs on other machines.
audio_path = '/Users/donghyunlee/Desktop/encoder_decoder_model/test_data/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.flac'

In [13]:
# Split the recording with segmentize_audio, then run make_spectrogram on each
# element of segmented_frames, preserving order.
segmented_frames, segmented_times = segmentize_audio(audio_path, spectrogram_config)
segmented_spec = [
    make_spectrogram(segment_frames, spectrogram_config)
    for segment_frames in segmented_frames
]

In [14]:
# Run the model over the segment spectrograms to obtain the prediction that is
# later scored against the reference (presumably a NoteSequence — confirm).
# The recorded run's progress bar shows 607 iterations taking about 1h07m, so
# consider caching pred_ns instead of recomputing it on every full re-run.
# NOTE(review): confirm make_pred_ns disables autograd (e.g. torch.no_grad()) internally.
pred_ns = make_pred_ns(segmented_spec, segmented_times, codec, vocab, model, config)

100%|██████████| 607/607 [1:06:45<00:00,  6.60s/it]


In [15]:
# Load the reference MIDI as a note sequence, validate it, apply its sustain
# control changes, and score the prediction against the result.
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
midi_path = '/Users/donghyunlee/Desktop/encoder_decoder_model/test_data/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi'
ref_ns_raw = note_seq.midi_file_to_note_sequence(midi_path)
note_sequences.validate_note_sequence(ref_ns_raw)
ref_ns = note_seq.apply_sustain_control_changes(ref_ns_raw)
output_1 = evaluation_mir_eval(ref_ns, pred_ns)

In [16]:
# Display the metrics returned by evaluation_mir_eval for this recording.
output_1

{'Onset precision': 0.9843219231774236,
 'Onset recall': 0.9543957436027363,
 'Onset F1': 0.9691278621044507,
 'Onset + offset precision': 0.7123072903057225,
 'Onset + offset recall': 0.6906511274385609,
 'Onset + offset F1': 0.7013120658605608,
 'Onset + velocity precision': 0.9619806637052521,
 'Onset + velocity recall': 0.932733721814036,
 'Onset + velocity F1': 0.9471314638538719,
 'Onset + offset + velocity precision': 0.6925790436373138,
 'Onset + offset + velocity recall': 0.6715226754497087,
 'Onset + offset + velocity F1': 0.6818883457679443}