# "音色无关转录"模型评估
针对basic-pitch已经训练的模型（论文作者给出）展开。按照官网要求进行环境配置：
```
pip install basic-pitch[tf]
```

## 生成转录结果并保存
经过预处理（process）后，每个曲目文件夹中的文件仅后缀不同，后缀分别为 "npy"、"wav"、"mid"。wav 的采样率已是 22050 Hz。

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
# Make the parent directory importable (presumably for the `utils` package imported later — confirm).
sys.path.append('..')

# Seconds per model frame: presumably a 256-sample hop at the 22050 Hz sample rate of the wav files.
s_per_frame = 256 / 22050

# Preprocessed dataset folders; amt_dataset treats each sub-directory as one piece.
dataset_folders = ["BACH10_processed", "URMP_processed", "PHENICX_processed"]

In [None]:
# Load the pretrained basic-pitch model (checkpoint path ICASSP_2022_MODEL_PATH from the basic_pitch package)
# NOTE(review): `tf` is not referenced below; presumably imported to make sure the TensorFlow backend is present — confirm.
import tensorflow as tf
from basic_pitch.inference import predict, Model
from basic_pitch import ICASSP_2022_MODEL_PATH

model = Model(ICASSP_2022_MODEL_PATH)

In [None]:
def amt_one(model, file):
    """
    Run basic-pitch on one audio file and return its onset and note activations.

    model: a loaded basic_pitch Model
    file: path of the audio file to transcribe
    returns: (onset, note), both transposed to (freqs, times)
    """
    # predict -> (model_output, midi_data, note_events); only the raw activations are used here.
    activations, _midi_data, _note_events = predict(file, model)
    # The output dict is time-major over the 88 piano keys. Transpose to (freqs, times) and keep
    # bins 3..-2, presumably to match the 84-pitch range used by freq_map((24, 107), ...) — confirm.
    keep = slice(3, -1)
    note = activations['note'].T[keep]
    onset = activations['onset'].T[keep]
    return onset, note

def amt_piece(model, piece_folder):
    """
    Transcribe one piece and load its ground-truth roll, aligned in time.

    piece_folder holds files that differ only by suffix ("npy", "wav", "mid").
    The .wav is transcribed with amt_one and the .npy is loaded as ground truth.
    Entries without one of those suffixes (e.g. .DS_Store) are ignored.

    returns: (onset, note, midi), each (freqs, times); midi is zero-padded or
        truncated along the time axis to the length of the model output.
    raises: FileNotFoundError if piece_folder has no .npy/.wav/.mid file.
    """
    # Choose a file with a known suffix rather than an arbitrary listdir() entry:
    # a stray file would otherwise produce a bogus base path.
    candidates = sorted(
        name for name in os.listdir(piece_folder)
        if os.path.splitext(name)[1] in (".npy", ".wav", ".mid")
    )
    if not candidates:
        raise FileNotFoundError(f"no .npy/.wav/.mid file found in {piece_folder}")
    base = os.path.join(piece_folder, os.path.splitext(candidates[0])[0])

    onset, note = amt_one(model, base + ".wav")
    midi = np.load(base + ".npy")
    # Align the ground truth's time length with the model output
    freqs, times = note.shape
    if midi.shape[1] < times:
        padding = np.zeros((freqs, times - midi.shape[1]))
        midi = np.concatenate((midi, padding), axis=1)
    elif midi.shape[1] > times:
        midi = midi[:, :times]
    return onset, note, midi

def amt_dataset(model, dataset_folder, output_folder = './'):
    """
    Transcribe every piece of a dataset and save one stacked result per piece.

    dataset_folder: e.g. "BACH10_processed"; every non-hidden sub-directory is
        treated as one piece (see amt_piece), other entries are skipped
    output_folder: results are written to
        <output_folder>/<prefix>_eval/<piece>.npy, where <prefix> is the dataset
        folder name up to its first "_"
    Each saved array is np.stack((onset, note, midi)), i.e. shape (3, freqs, times).
    """
    # normpath drops a trailing separator; otherwise basename("BACH10_processed/") == ""
    # and results would land in a folder called "_eval".
    folder_name = os.path.basename(os.path.normpath(dataset_folder))
    print(f"processing {folder_name}")
    output_folder_name = folder_name.split("_")[0] + "_eval"
    output_path = os.path.join(output_folder, output_folder_name)
    os.makedirs(output_path, exist_ok=True)

    # sorted() makes the processing order (and the log) reproducible
    for piece_folder in sorted(os.listdir(dataset_folder)):
        piece_path = os.path.join(dataset_folder, piece_folder)
        # Skip plain files and hidden directories such as .ipynb_checkpoints
        if piece_folder.startswith(".") or not os.path.isdir(piece_path):
            continue
        result = amt_piece(model, piece_path)
        np.save(os.path.join(output_path, piece_folder + ".npy"), np.stack(result, axis=0))
        print(f"\tFinish {piece_folder}")


In [None]:
# Frame-level evaluation metrics
from utils.midiarray import freq_map, roll2evalarray
from utils.postprocess import min_len_
import mir_eval

# Re-declared so this cell can run on its own (same value as in the first cell).
s_per_frame = 256 / 22050
# Pitch-bin -> frequency map, presumably for MIDI pitches 24..107 with A4 = 440 Hz — see utils.midiarray.freq_map.
freqmap = freq_map((24, 107), 440)

def frame_eval(note, midi, threshold = 0.5):
    """
    Compute frame-level multi-pitch metrics for one piece.

    The note activations are binarized at `threshold` and short notes are
    removed with min_len_(..., 3). min_len_ works in place, which is harmless
    here because the binarized array is a fresh copy. Both rolls are then
    converted with roll2evalarray into the per-frame format mir_eval expects.

    note: (freqs, times) note activations
    midi: (freqs, times) ground-truth roll
    threshold: binarization threshold for `note`
    returns: metrics dict from mir_eval.multipitch.evaluate
        (https://github.com/mir-evaluation/mir_eval/blob/main/mir_eval/multipitch.py)
    """
    cleaned = min_len_((note > threshold).astype(int), 3)
    est_freqs = roll2evalarray(cleaned, freqmap)
    ref_freqs = roll2evalarray(midi, freqmap)
    # Frame index -> time in seconds
    est_times = s_per_frame * np.arange(len(est_freqs))
    ref_times = s_per_frame * np.arange(len(ref_freqs))
    return mir_eval.multipitch.evaluate(ref_times, ref_freqs, est_times, est_freqs)


def evaluate_frame_dataset(npy_pathes, threshold = 0.5, log = True):
    """
    Evaluate every result file with one shared threshold and average the scores.

    npy_pathes: iterable of .npy paths, each an array saved by amt_dataset with
        shape (3, freqs, times) = (onset, note, midi); only note and midi are used
    threshold: binarization threshold forwarded to frame_eval
    log: if True, print one markdown row "| threshold | Acc | P | R | F1 |"
    returns: (ACC, P, R, F1), unweighted means over the files. F1 is the mean of
        the per-file F1 values, not the F1 of the mean P and R. All four are NaN
        when npy_pathes is empty.
    """
    per_file = []
    for npy_file in npy_pathes:
        stacked = np.load(npy_file)
        scores = frame_eval(stacked[1], stacked[2], threshold)
        acc, p, r = scores['Accuracy'], scores['Precision'], scores['Recall']
        per_file.append((acc, p, r, 2*p*r/(p+r) if p+r > 0 else 0))
    ACC, P, R, F1 = (np.mean([row[col] for row in per_file]) for col in range(4))
    if log:
        # | Acc | P | R | F1 |
        print(f"| {threshold:.5f} | {ACC:.5f} | {P:.5f} | {R:.5f} | {F1:.5f} |")
    return ACC, P, R, F1


def find_best_threshold(npy_pathes, origin_range = (0.1, 0.9), step_num = 10, generation = 4, log = True, eval_fn = None):
    """
    Coarse-to-fine search for the threshold that maximizes the mean frame-level F1.

    Each generation evaluates an evenly spaced grid of thresholds. The next
    generation divides the interval [best - step, best + step] (clipped to
    origin_range) into `step_num` finer steps around the best threshold so far.
    The best threshold and its F1 are kept across generations, so a finer grid
    can never replace it with a worse value.

    Within a grid, evaluation stops as soon as F1 decreases; this assumes F1 is
    unimodal (concave) in the threshold.

    npy_pathes: result files passed to the evaluator
    origin_range: (start, end) of the first grid; the search never leaves it
    step_num: number of steps each generation's interval is divided into
    generation: number of refinement rounds
    log: print a markdown table row per evaluated threshold plus a per-round summary
    eval_fn: optional evaluator with evaluate_frame_dataset's signature,
        (npy_pathes, threshold, log) -> (ACC, P, R, F1); defaults to evaluate_frame_dataset
    returns: the best threshold found, or -1 if no valid (non-NaN) F1 was obtained
    """
    if eval_fn is None:
        eval_fn = evaluate_frame_dataset
    if log:
        print("| threshold | Acc | P | R | F1 |")
        print("| --------- | --- |---|---|----|")

    lo_bound, hi_bound = origin_range
    start = lo_bound
    end = hi_bound
    step = (end - start) / step_num

    best_thre = -1
    max_f1 = -1

    for _ in range(generation):
        if not step > 0:    # empty / degenerate interval: nothing left to refine
            break
        lastF1 = -1
        for thre in np.arange(start, end, step):
            ACC, P, R, F1 = eval_fn(npy_pathes, thre, log)
            if F1 > max_f1:
                max_f1 = F1
                best_thre = thre
            if F1 < lastF1:     # unimodal assumption: stop once F1 starts to fall
                break
            lastF1 = F1
        if log:
            print(f"| Best threshold | {best_thre} | ~ | ~ | F1: {max_f1} |")
        if max_f1 < 0:      # no valid F1 so far (e.g. NaN for an empty file list)
            break

        # Next grid: one current step on each side of the global best, clipped to
        # origin_range. Its left end was already evaluated (a neighbour of the best,
        # or the best itself at the range edge), so skip it.
        lo = max(best_thre - step, lo_bound)
        hi = min(best_thre + step, hi_bound)
        step = (hi - lo) / step_num
        start = lo + step
        end = hi

    return best_thre

### 运行模型

In [None]:
model_folder_name = "basicpitch_raw"
# Produce all model outputs: one ./basicpitch_raw/<PREFIX>_eval/<piece>.npy per piece (see amt_dataset)
for dataset_folder in dataset_folders:
    amt_dataset(model, dataset_folder, f"./{model_folder_name}")

### 计算帧级指标，并寻找最好阈值

In [None]:
# Only the BACH10 ensemble (mixture) results
npyfolder = f"{model_folder_name}/BACH10_eval"
# Ensemble files are selected by the "0.npy" ending — presumably the piece naming produced by preprocessing; confirm.
npys = [os.path.join(npyfolder, f) for f in os.listdir(npyfolder) if f.endswith("0.npy")]
find_best_threshold(npys, (0.1, 0.5), step_num=10, generation=4, log=True)

- BACH10合奏：Best threshold: 0.356, F1: 0.809163494183113

| threshold | Acc | P | R | F1 |
| --------- | --- |---|---|----|
| 0.10000 | 0.06911 | 0.06917 | 0.98839 | 0.12928 |
| 0.14000 | 0.44051 | 0.45429 | 0.93507 | 0.61103 |
| 0.18000 | 0.56174 | 0.59487 | 0.90897 | 0.71895 |
| 0.22000 | 0.62229 | 0.67667 | 0.88480 | 0.76676 |
| 0.26000 | 0.65851 | 0.73505 | 0.86283 | 0.79365 |
| 0.30000 | 0.67461 | 0.77503 | 0.83842 | 0.80519 |
| 0.34000 | 0.67982 | 0.80420 | 0.81448 | 0.80883 |
| 0.38000 | 0.67823 | 0.82614 | 0.79122 | 0.80764 |
| Best threshold | 0.34 | ~ | ~ | F1: 0.8088316027103488 |
| 0.30800 | 0.67692 | 0.78218 | 0.83377 | 0.80682 |
| 0.31600 | 0.67801 | 0.78819 | 0.82877 | 0.80759 |
| 0.32400 | 0.67905 | 0.79407 | 0.82394 | 0.80831 |
| 0.33200 | 0.67961 | 0.79906 | 0.81947 | 0.80870 |
| 0.34000 | 0.67982 | 0.80420 | 0.81448 | 0.80883 |
| 0.34800 | 0.68018 | 0.80896 | 0.81019 | 0.80907 |
| 0.35600 | 0.68033 | 0.81361 | 0.80581 | 0.80916 |
| 0.36400 | 0.67949 | 0.81744 | 0.80097 | 0.80857 |
| Best threshold | 0.356 | ~ | ~ | F1: 0.809163494183113 |
| 0.34960 | 0.68006 | 0.80974 | 0.80925 | 0.80899 |
| 0.35120 | 0.68007 | 0.81061 | 0.80839 | 0.80899 |
| 0.35280 | 0.68011 | 0.81159 | 0.80746 | 0.80901 |
| 0.35440 | 0.68024 | 0.81264 | 0.80661 | 0.80910 |
| 0.35600 | 0.68033 | 0.81361 | 0.80581 | 0.80916 |
| 0.35760 | 0.68016 | 0.81428 | 0.80492 | 0.80904 |
| Best threshold | 0.356 | ~ | ~ | F1: 0.809163494183113 |

In [None]:
# All BACH10 results (solo + ensemble): every .npy file in the folder
npyfolder = f"{model_folder_name}/BACH10_eval"
npys = [os.path.join(npyfolder, f) for f in os.listdir(npyfolder) if f.endswith(".npy")]
find_best_threshold(npys, (0.25, 0.7), step_num=10, generation=4, log=True)

- BACH10所有音频：Best threshold: 0.47428, F1: 0.8315663457665301

| threshold | Acc | P | R | F1 |
| --------- | --- |---|---|----|
| 0.25000 | 0.64523 | 0.69325 | 0.89696 | 0.77757 |
| 0.29500 | 0.67768 | 0.73955 | 0.88423 | 0.80251 |
| 0.34000 | 0.69596 | 0.77090 | 0.87234 | 0.81641 |
| 0.38500 | 0.70689 | 0.79416 | 0.86138 | 0.82473 |
| 0.43000 | 0.71299 | 0.81393 | 0.84917 | 0.82945 |
| 0.47500 | 0.71570 | 0.83177 | 0.83567 | 0.83148 |
| 0.52000 | 0.71294 | 0.84513 | 0.82030 | 0.82926 |
| Best threshold | 0.475 | ~ | ~ | F1: 0.8314788218697448 |
| 0.43900 | 0.71389 | 0.81783 | 0.84651 | 0.83015 |
| 0.44800 | 0.71442 | 0.82124 | 0.84392 | 0.83056 |
| 0.45700 | 0.71522 | 0.82517 | 0.84122 | 0.83115 |
| 0.46600 | 0.71535 | 0.82834 | 0.83838 | 0.83124 |
| 0.47500 | 0.71570 | 0.83177 | 0.83567 | 0.83148 |
| 0.48400 | 0.71554 | 0.83466 | 0.83289 | 0.83136 |
| Best threshold | 0.475 | ~ | ~ | F1: 0.8314788218697448 |
| 0.46780 | 0.71545 | 0.82909 | 0.83781 | 0.83131 |
| 0.46960 | 0.71546 | 0.82975 | 0.83722 | 0.83132 |
| 0.47140 | 0.71550 | 0.83037 | 0.83671 | 0.83135 |
| 0.47320 | 0.71569 | 0.83112 | 0.83626 | 0.83148 |
| 0.47500 | 0.71570 | 0.83177 | 0.83567 | 0.83148 |
| 0.47680 | 0.71567 | 0.83230 | 0.83516 | 0.83145 |
| Best threshold | 0.475 | ~ | ~ | F1: 0.8314788218697448 |
| 0.47356 | 0.71576 | 0.83132 | 0.83617 | 0.83153 |
| 0.47392 | 0.71582 | 0.83150 | 0.83607 | 0.83156 |
| 0.47428 | 0.71582 | 0.83161 | 0.83598 | 0.83157 |
| 0.47464 | 0.71577 | 0.83168 | 0.83585 | 0.83153 |
| Best threshold | 0.47428 | ~ | ~ | F1: 0.8315663457665301 |

In [None]:
# PHENICX ensemble recordings: every .npy file in the folder is used
npyfolder = f"{model_folder_name}/PHENICX_eval"
npys = [os.path.join(npyfolder, f) for f in os.listdir(npyfolder) if f.endswith(".npy")]
find_best_threshold(npys, (0.1, 0.4), step_num=10, generation=4, log=True)

- PHENICX合奏：Best threshold: 0.2032, F1: 0.4182288392628133

| threshold | Acc | P | R | F1 |
| --------- | --- |---|---|----|
| 0.10000 | 0.08342 | 0.08465 | 0.84400 | 0.15362 |
| 0.13000 | 0.23923 | 0.28719 | 0.58059 | 0.38417 |
| 0.16000 | 0.25987 | 0.33902 | 0.52054 | 0.41061 |
| 0.19000 | 0.26523 | 0.36868 | 0.48107 | 0.41739 |
| 0.22000 | 0.26508 | 0.38888 | 0.45066 | 0.41736 |
| Best threshold | 0.19 | ~ | ~ | F1: 0.4173910922611989 |
| 0.16600 | 0.26172 | 0.34615 | 0.51168 | 0.41294 |
| 0.17200 | 0.26321 | 0.35269 | 0.50353 | 0.41481 |
| 0.17800 | 0.26411 | 0.35835 | 0.49569 | 0.41596 |
| 0.18400 | 0.26475 | 0.36370 | 0.48808 | 0.41677 |
| 0.19000 | 0.26523 | 0.36868 | 0.48107 | 0.41739 |
| 0.19600 | 0.26554 | 0.37347 | 0.47425 | 0.41780 |
| 0.20200 | 0.26569 | 0.37780 | 0.46804 | 0.41802 |
| 0.20800 | 0.26579 | 0.38194 | 0.46228 | 0.41819 |
| 0.21400 | 0.26547 | 0.38552 | 0.45630 | 0.41781 |
| Best threshold | 0.208 | ~ | ~ | F1: 0.41818627219263216 |
| 0.20320 | 0.26585 | 0.37882 | 0.46700 | 0.41823 |
| 0.20440 | 0.26584 | 0.37963 | 0.46577 | 0.41822 |
| Best threshold | 0.2032 | ~ | ~ | F1: 0.4182288392628133 |
| 0.20332 | 0.26583 | 0.37888 | 0.46684 | 0.41820 |
| 0.20344 | 0.26585 | 0.37899 | 0.46673 | 0.41822 |
| 0.20356 | 0.26584 | 0.37906 | 0.46661 | 0.41821 |
| Best threshold | 0.2032 | ~ | ~ | F1: 0.4182288392628133 |

In [None]:
# Only the URMP ensemble (mixture) results
npyfolder = f"{model_folder_name}/URMP_eval"
# Ensemble files are selected by the "0.npy" ending — presumably the piece naming produced by preprocessing; confirm.
npys = [os.path.join(npyfolder, f) for f in os.listdir(npyfolder) if f.endswith("0.npy")]
find_best_threshold(npys, (0.1, 0.5), step_num=10, generation=4, log=True)

- URMP合奏：Best threshold: 0.3064, F1: 0.5170594175870273

| threshold | Acc | P | R | F1 |
| --------- | --- |---|---|----|
| 0.10000 | 0.04301 | 0.04324 | 0.88841 | 0.08214 |
| 0.14000 | 0.25517 | 0.28754 | 0.66252 | 0.39861 |
| 0.18000 | 0.31596 | 0.37984 | 0.61983 | 0.46891 |
| 0.22000 | 0.34412 | 0.43502 | 0.58982 | 0.49899 |
| 0.26000 | 0.35715 | 0.47110 | 0.56512 | 0.51231 |
| 0.30000 | 0.36170 | 0.49584 | 0.54281 | 0.51675 |
| 0.34000 | 0.36162 | 0.51411 | 0.52214 | 0.51645 |
| Best threshold | 0.3 | ~ | ~ | F1: 0.5167462516913957 |
| 0.26800 | 0.35867 | 0.47686 | 0.56050 | 0.51379 |
| 0.27600 | 0.35972 | 0.48220 | 0.55572 | 0.51486 |
| 0.28400 | 0.36065 | 0.48707 | 0.55136 | 0.51573 |
| 0.29200 | 0.36132 | 0.49165 | 0.54709 | 0.51640 |
| 0.30000 | 0.36170 | 0.49584 | 0.54281 | 0.51675 |
| 0.30800 | 0.36201 | 0.49982 | 0.53872 | 0.51702 |
| 0.31600 | 0.36198 | 0.50354 | 0.53437 | 0.51695 |
| Best threshold | 0.308 | ~ | ~ | F1: 0.5170180377708988 |
| 0.30160 | 0.36183 | 0.49672 | 0.54201 | 0.51686 |
| 0.30320 | 0.36191 | 0.49751 | 0.54121 | 0.51692 |
| 0.30480 | 0.36198 | 0.49828 | 0.54045 | 0.51699 |
| 0.30640 | 0.36206 | 0.49911 | 0.53963 | 0.51706 |
| 0.30800 | 0.36201 | 0.49982 | 0.53872 | 0.51702 |
| Best threshold | 0.3064 | ~ | ~ | F1: 0.5170594175870273 |
| 0.30512 | 0.36200 | 0.49845 | 0.54027 | 0.51700 |
| 0.30544 | 0.36201 | 0.49862 | 0.54011 | 0.51702 |
| 0.30576 | 0.36202 | 0.49876 | 0.53994 | 0.51702 |
| Best threshold | 0.3064 | ~ | ~ | F1: 0.5170594175870273 |

In [None]:
# Only the URMP solo (single-instrument) results
npyfolder = f"{model_folder_name}/URMP_eval"
# Solo files are the .npy files NOT ending in "0.npy" (the "0" suffix presumably marks the ensemble mix — confirm).
# The explicit ".npy" check keeps stray non-result files (e.g. .DS_Store) away from np.load,
# matching the filters used for the other datasets.
npys = [os.path.join(npyfolder, f) for f in os.listdir(npyfolder)
        if f.endswith(".npy") and not f.endswith("0.npy")]
find_best_threshold(npys, (0.35, 0.6), step_num=10, generation=4, log=True)

- URMP独奏：Best threshold: 0.49, F1: 0.5552016311298443（注：下方日志最后输出的 0.4861 的 F1 为 0.55518，低于 0.49 的 0.55520，故取 0.49）

| threshold | Acc | P | R | F1 |
| --------- | --- |---|---|----|
| 0.35000 | 0.39431 | 0.50414 | 0.59512 | 0.54471 |
| 0.37500 | 0.39832 | 0.51437 | 0.58945 | 0.54836 |
| 0.40000 | 0.40134 | 0.52325 | 0.58400 | 0.55105 |
| 0.42500 | 0.40363 | 0.53114 | 0.57863 | 0.55302 |
| 0.45000 | 0.40539 | 0.53834 | 0.57323 | 0.55441 |
| 0.47500 | 0.40623 | 0.54452 | 0.56761 | 0.55498 |
| 0.50000 | 0.40665 | 0.55016 | 0.56210 | 0.55517 |
| 0.52500 | 0.40635 | 0.55525 | 0.55613 | 0.55469 |
| Best threshold | 0.5 | ~ | ~ | F1: 0.5551684215308371 |
| 0.48000 | 0.40642 | 0.54575 | 0.56657 | 0.55512 |
| 0.48500 | 0.40655 | 0.54694 | 0.56546 | 0.55519 |
| 0.49000 | 0.40660 | 0.54802 | 0.56436 | 0.55520 |
| 0.49500 | 0.40659 | 0.54909 | 0.56315 | 0.55515 |
| Best threshold | 0.49 | ~ | ~ | F1: 0.5552016311298443 |
| 0.48600 | 0.40654 | 0.54713 | 0.56523 | 0.55517 |
| 0.48700 | 0.40655 | 0.54736 | 0.56498 | 0.55517 |
| Best threshold | 0.486 | ~ | ~ | F1: 0.5551734307211511 |
| 0.48610 | 0.40654 | 0.54716 | 0.56520 | 0.55518 |
| 0.48620 | 0.40654 | 0.54717 | 0.56517 | 0.55517 |
| Best threshold | 0.4861 | ~ | ~ | F1: 0.5551764222375639 |