In [None]:
import torch
import librosa
from hppnet.transcriber import HPPNet
from hppnet.midi import save_midi
from hppnet.decoding import extract_notes
import numpy as np
from mir_eval.util import midi_to_hz

# Audio front-end configuration shared by every cell below.
SAMPLE_RATE = 16000                      # target sampling rate in Hz
HOP_LENGTH = 20 * SAMPLE_RATE // 1000    # samples per frame: a 20 ms hop
MIN_MIDI = 21                            # MIDI number added to the pitch indices from extract_notes (21 = A0)

In [None]:
# Checkpoint containing the whole pickled model object (``.eval()`` is called
# directly on the loaded value, so it is not a bare state_dict).
model_path = 'checkpoints/model-142000-mestro-maps_mus-my-f0.912n0.973.pt'

# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing inside torch.load.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(security): torch.load unpickles arbitrary Python objects; only load
# checkpoints that come from a trusted source.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
# which rejects fully pickled modules like this one; pass weights_only=False
# there if loading fails -- confirm against the installed version.
model = torch.load(model_path, map_location=device).eval()

# Project-specific flag on the loaded model; presumably switches it to its
# inference code path -- confirm in hppnet.transcriber.
model.inference_mode = True

In [None]:
# Recording to transcribe; librosa resamples it to SAMPLE_RATE and downmixes to mono.
audio_path = '恋×シンアイ彼女/水月陵 - flower -piano arrangement-.flac'
audio, sr = librosa.load(audio_path, mono=True, sr=SAMPLE_RATE)

# Wrap the waveform in a tensor and prepend a leading axis -> (1, n_samples).
audio = torch.tensor(audio)
audio_length = audio.shape[-1]
audio = audio.reshape(-1, audio_length)

# Number of hop-sized frames needed to cover the waveform (ceiling division).
n_step = (audio_length - 1) // HOP_LENGTH + 1

# Decoding thresholds handed to extract_notes, and the maximum number of
# frames processed in a single forward pass.
onset_threshold = 0.5
frame_threshold = 0.4
clip_len = 4096

if n_step <= clip_len:
    # Short recording: one forward pass over the whole waveform.
    torch.cuda.empty_cache()
    with torch.no_grad():
        pred = model.forward(audio)
else:
    # Long recording: process fixed-length clips to avoid running out of memory.
    print('n_step > clip_len %d ' % clip_len, audio.shape)
    clip_list = [clip_len] * (n_step // clip_len)
    res = n_step % clip_len
    if res != 0:
        # Move half of the shortfall from the last full clip into the
        # remainder so the final two clips end up with similar lengths.
        # The total number of frames stays equal to n_step.
        shift = (clip_len - res) // 2
        clip_list[-1] -= shift
        clip_list.append(res + shift)

    print('clip list:', clip_list)

    begin = 0
    parts = {}  # output name -> list of per-clip tensors kept on the CPU
    for clip in clip_list:
        end = begin + clip
        audio_i = audio[0][HOP_LENGTH*begin:HOP_LENGTH*end].unsqueeze(0)
        torch.cuda.empty_cache()
        with torch.no_grad():
            pred_i = model.forward(audio_i)

        for key, item in pred_i.items():
            # Trim EVERY clip (including the first) to exactly `clip` frames on
            # axis 2, so surplus frames returned by the model cannot shift the
            # timeline of later clips. Offload to the CPU right away to free
            # device memory before the next clip.
            parts.setdefault(key, []).append(item[:, :, :clip, :].to('cpu'))
        begin += clip

    # Concatenate once along the frame axis instead of re-copying the growing
    # tensor on every iteration.
    pred = {key: torch.cat(chunks, dim=2) for key, chunks in parts.items()}
# Post-process every output in place: squeeze axis 0 (a no-op unless its size
# is 1), clamp negatives to zero, then squeeze the new axis 0 the same way.
# Assumes the outputs are (1, 1, frames, pitches) -- TODO confirm.
for value in pred.values():
    value.squeeze_(0)
    value.relu_()
    value.squeeze_(0)


# Decode notes from the onset / frame / velocity outputs.
# Returns pitches, intervals and velocities (in that order).
p_est, i_est, v_est = extract_notes(
    pred['onset'],
    pred['frame'],
    pred['velocity'],
    onset_threshold,
    frame_threshold,
)

# Seconds per frame; multiplying converts the intervals (presumably frame
# indices -- confirm in hppnet.decoding) into seconds, one (start, end) row each.
scaling = HOP_LENGTH / SAMPLE_RATE
i_est = (i_est * scaling).reshape(-1, 2)

# Offset the pitch values by MIN_MIDI to get MIDI numbers, then convert to Hz.
p_est = np.array([midi_to_hz(MIN_MIDI + pitch) for pitch in p_est])

# Output name: audio file name + checkpoint file name + '.mid', written to the
# current working directory.
audio_name = audio_path.rsplit('/', 1)[-1]
checkpoint_name = model_path.rsplit('/', 1)[-1]
midi_path = audio_name + checkpoint_name + '.mid'

# Save midi file
save_midi(midi_path, p_est, i_est, v_est)