# "音色无关转录"模型评估

## 生成转录结果并保存
经过process后每个文件夹里的文件只有后缀不同，且后缀为"npy" "wav" "mid"。wav采样率已经是22050Hz。

In [None]:
import torch
import torchaudio
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
sys.path.append('..')

s_per_frame = 256 / 22050

dataset_folders = ["BACH10_processed", "URMP_processed", "PHENICX_processed"]

In [None]:
def amt_one(model, file):
    waveform, sample_rate = torchaudio.load(file)
    waveform = waveform.unsqueeze(0)
    onset, note = model(waveform)
    onset = onset.cpu().numpy()[0]
    note = note.cpu().numpy()[0]
    return onset, note

def amt_piece(model, piece_folder):
    filename = os.listdir(piece_folder)[0]
    path = os.path.join(piece_folder, filename)[:-3]    # 去掉后缀
    onset, note = amt_one(model, path + "wav")
    midi = np.load(path+"npy")
    # 补时间长度
    freqs, times = note.shape
    if midi.shape[1] < times:
        padding = np.zeros((freqs, times - midi.shape[1]))
        midi = np.concatenate((midi, padding), axis=1)
    elif midi.shape[1] > times:
        midi = midi[:, :times]
    return onset, note, midi

def amt_dataset(model, dataset_folder, output_folder = './'):
    folder_name = os.path.basename(dataset_folder)
    print(f"processing {folder_name}")
    output_folder_name = folder_name.split("_")[0] + "_eval"
    output_path = os.path.join(output_folder, output_folder_name)
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    for piece_folder in os.listdir(dataset_folder):
        if os.path.isdir(os.path.join(dataset_folder, piece_folder)):
            result = amt_piece(model, os.path.join(dataset_folder, piece_folder))
            np.save(os.path.join(output_path, piece_folder+".npy"), np.stack(result, axis=0))
            print(f"\tFinish {piece_folder}")


In [None]:
# 计算帧级的评价指标
from utils.midiarray import freq_map, roll2evalarray
from utils.postprocess import min_len_
import mir_eval

s_per_frame = 256 / 22050
freqmap = freq_map((24, 107), 440)

def frame_eval(note, midi, threshold = 0.5):
    """
    对note进行阈值二值化、移除短音符、转换为mir_eval所需数
    计算帧级评价指标
    note: (freqs, times)
    midi: (freqs, times)
    """
    binary_note = (note > threshold).astype(int)    # 二值化
    # 这个min_len_是原位操作，会修改输入
    est_pitch = roll2evalarray(min_len_(binary_note, 3), freqmap)
    ref_pitch = roll2evalarray(midi, freqmap)
    rst_time = s_per_frame * np.arange(len(est_pitch))
    ref_time = s_per_frame * np.arange(len(ref_pitch))
    result = mir_eval.multipitch.evaluate(ref_time, ref_pitch, rst_time, est_pitch)
    return result   # https://github.com/mir-evaluation/mir_eval/blob/main/mir_eval/multipitch.py


def evaluate_frame_dataset(npy_pathes, threshold = 0.5, log = True):
    """
    对npy_pathes中的所有npy文件用同一个阈值进行评估
    dataset_folder: folder containing npy files, each file is a result of amt_piece, shape (3, freqs, times): onset, note, midi
    """
    accs = []
    ps = []
    rs = []
    f1s = []
    for npy_file in npy_pathes:
        result = np.load(npy_file)
        evaluation = frame_eval(result[1], result[2], threshold)
        acc = evaluation['Accuracy']
        p = evaluation['Precision']
        r = evaluation['Recall']
        accs.append(acc)
        ps.append(p)
        rs.append(r)
        f1s.append(2*p*r/(p+r) if p+r > 0 else 0)
    ACC = np.mean(accs)
    P = np.mean(ps)
    R = np.mean(rs)
    F1 = np.mean(f1s)
    if log:
        # | Acc | P | R | F1 |
        print(f"| {threshold:.5f} | {ACC:.5f} | {P:.5f} | {R:.5f} | {F1:.5f} |")
    return ACC, P, R, F1


def find_best_threshold(npy_pathes, origin_range = (0.1, 0.9), step_num = 10, generation = 4, log = True):
    if log:
        print("| threshold | Acc | P | R | F1 |")
        print("| --------- | --- |---|---|----|")
    
    start = origin_range[0]
    end = origin_range[1]
    step = (end - start) / step_num
    
    best_thre = -1
    max_f1 = -1
    best_thre_idx = -1

    for g in range(generation):
        lastF1 = -1
        thresholds = np.r_[start:end:step]
        for idx, thre in enumerate(thresholds):
            ACC, P, R, F1 = evaluate_frame_dataset(npy_pathes, thre, log)
            if F1 > max_f1:
                max_f1 = F1
                best_thre_idx = idx
                best_thre = thre
            if F1 < lastF1: # 假设F1是一个凹函数，只要开始下降就可以停止了
                break
            lastF1 = F1
        if log:
            print(f"| Best threshold | {best_thre} | ~ | ~ | F1: {max_f1} |")
        
        # 如果是边缘的话，下一轮start不会覆盖到最优值，所以提前加入；否则清空
        if best_thre_idx == -1: # 说明最优值还在左边
            best_thre_idx = -1  # -1表示最大值在左边外面
            start = best_thre
            end = thresholds[0]
            step = (end - start) / step_num
            start += step
        elif best_thre_idx == 0:    # 最值就是最左边的
            best_thre_idx = -1
            start = best_thre
            end = thresholds[1]
            step = (end - start) / step_num
            start += step
        elif best_thre_idx == -2:   # 如果是右边缘的右边最大，说明最优值还在右边
            best_thre_idx = -2      # -2表示最大值在右边外面
            start = thresholds[-1]
            end = best_thre
            step = (end - start) / step_num
            start += step
        elif best_thre_idx == len(thresholds) - 1:  # 最值就是最右边的
            best_thre_idx = -2
            start = thresholds[-2]
            end = best_thre
            step = (end - start) / step_num
            start += step
        else:
            start = thresholds[best_thre_idx-1]
            end = thresholds[best_thre_idx+1]
            step = (end - start) / step_num
            start += step   # 少分析一轮
            max_f1 = -1     # 清空最大值，因为一定在区间内

    return best_thre

### 运行模型

In [None]:
model_folder_name = "tiny"
sys.path.append(f'../{model_folder_name}')
model = torch.load(f"../{model_folder_name}/basicamt_model.pth")
model.eval()

In [None]:
# 得到所有运行结果
with torch.no_grad():
    for dataset_folder in dataset_folders:
        amt_dataset(model, dataset_folder, f"./{model_folder_name}")

### 计算帧级指标，并寻找最好阈值

In [None]:
# 只看BACH10的合奏结果
npyfolder = f"{model_folder_name}/BACH10_eval"
npys = [os.path.join(npyfolder, f) for f in os.listdir(npyfolder) if f.endswith("0.npy")]
find_best_threshold(npys, (0.1, 0.5), step_num=10, generation=4, log=True)

- BACH10合奏：Best threshold: 0.15312, F1: 0.8029144908856092

| threshold | Acc | P | R | F1 |
| --------- | --- |---|---|----|
| 0.10000 | 0.65786 | 0.73881 | 0.85705 | 0.79322 |
| 0.14000 | 0.67115 | 0.79435 | 0.81210 | 0.80279 |
| 0.18000 | 0.66741 | 0.83337 | 0.76996 | 0.80009 |
| Best threshold | 0.14 | ~ | ~ | F1: 0.8027928864331081 |
| 0.10800 | 0.66312 | 0.75257 | 0.84787 | 0.79704 |
| 0.11600 | 0.66638 | 0.76456 | 0.83826 | 0.79938 |
| 0.12400 | 0.66885 | 0.77563 | 0.82911 | 0.80116 |
| 0.13200 | 0.67041 | 0.78556 | 0.82038 | 0.80227 |
| 0.14000 | 0.67115 | 0.79435 | 0.81210 | 0.80279 |
| 0.14800 | 0.67121 | 0.80310 | 0.80322 | 0.80283 |
| 0.15600 | 0.67129 | 0.81163 | 0.79494 | 0.80287 |
| 0.16400 | 0.67049 | 0.81933 | 0.78658 | 0.80230 |
| Best threshold | 0.156 | ~ | ~ | F1: 0.8028732920285189 |
| 0.14960 | 0.67106 | 0.80460 | 0.80150 | 0.80271 |
| 0.15120 | 0.67132 | 0.80669 | 0.79979 | 0.80289 |
| 0.15280 | 0.67140 | 0.80843 | 0.79821 | 0.80295 |
| 0.15440 | 0.67143 | 0.81007 | 0.79664 | 0.80297 |
| 0.15600 | 0.67129 | 0.81163 | 0.79494 | 0.80287 |
| Best threshold | 0.1544 | ~ | ~ | F1: 0.8029676127610162 |
| 0.15312 | 0.67135 | 0.80869 | 0.79789 | 0.80291 |
| 0.15344 | 0.67135 | 0.80907 | 0.79752 | 0.80291 |
| Best threshold | 0.15312 | ~ | ~ | F1: 0.8029144908856092 |

In [None]:
# 看BACH10所有音频的结果（独奏+合奏）
npyfolder = f"{model_folder_name}/BACH10_eval"
npys = [os.path.join(npyfolder, f) for f in os.listdir(npyfolder) if f.endswith(".npy")]
find_best_threshold(npys, (0.25, 0.7), step_num=10, generation=4, log=True)

- BACH10所有音频：Best threshold: 0.30184, F1: 0.860866334491572

| threshold | Acc | P | R | F1 |
| --------- | --- |---|---|----|
| 0.25000 | 0.76245 | 0.85539 | 0.87324 | 0.85901 |
| 0.29500 | 0.76551 | 0.87300 | 0.85966 | 0.86057 |
| 0.34000 | 0.76596 | 0.88729 | 0.84707 | 0.86016 |
| Best threshold | 0.295 | ~ | ~ | F1: 0.8605656078947697 |
| 0.25900 | 0.76346 | 0.85948 | 0.87029 | 0.85960 |
| 0.26800 | 0.76413 | 0.86317 | 0.86742 | 0.85994 |
| 0.27700 | 0.76489 | 0.86684 | 0.86475 | 0.86034 |
| 0.28600 | 0.76519 | 0.87000 | 0.86209 | 0.86045 |
| 0.29500 | 0.76551 | 0.87300 | 0.85966 | 0.86057 |
| 0.30400 | 0.76600 | 0.87632 | 0.85706 | 0.86076 |
| 0.31300 | 0.76596 | 0.87897 | 0.85461 | 0.86063 |
| Best threshold | 0.304 | ~ | ~ | F1: 0.8607566451715536 |
| 0.29680 | 0.76568 | 0.87378 | 0.85911 | 0.86065 |
| 0.29860 | 0.76592 | 0.87459 | 0.85863 | 0.86079 |
| 0.30040 | 0.76603 | 0.87520 | 0.85816 | 0.86083 |
| 0.30220 | 0.76609 | 0.87580 | 0.85766 | 0.86084 |
| 0.30400 | 0.76600 | 0.87632 | 0.85706 | 0.86076 |
| Best threshold | 0.3022 | ~ | ~ | F1: 0.8608435176278999 |
| 0.30076 | 0.76607 | 0.87536 | 0.85805 | 0.86085 |
| 0.30112 | 0.76608 | 0.87547 | 0.85795 | 0.86085 |
| 0.30148 | 0.76610 | 0.87559 | 0.85787 | 0.86086 |
| 0.30184 | 0.76612 | 0.87571 | 0.85777 | 0.86087 |
| 0.30220 | 0.76609 | 0.87580 | 0.85766 | 0.86084 |
| Best threshold | 0.30184 | ~ | ~ | F1: 0.860866334491572 |

In [None]:
# 看PHENICX合奏音频的结果
npyfolder = f"{model_folder_name}/PHENICX_eval"
npys = [os.path.join(npyfolder, f) for f in os.listdir(npyfolder) if f.endswith(".npy")]
find_best_threshold(npys, (0.02, 0.2), step_num=10, generation=4, log=True)

- PHENICX合奏：Best threshold: 0.06464, F1: 0.5951196104555911

| threshold | Acc | P | R | F1 |
| --------- | --- |---|---|----|
| 0.02000 | 0.29774 | 0.32434 | 0.78537 | 0.45879 |
| 0.03800 | 0.39485 | 0.47775 | 0.69539 | 0.56579 |
| 0.05600 | 0.42212 | 0.56043 | 0.63020 | 0.59276 |
| 0.07400 | 0.42333 | 0.60922 | 0.57939 | 0.59346 |
| 0.09200 | 0.41475 | 0.64209 | 0.53722 | 0.58452 |
| Best threshold | 0.074 | ~ | ~ | F1: 0.59346064026473 |
| 0.05960 | 0.42356 | 0.57159 | 0.61946 | 0.59407 |
| 0.06320 | 0.42446 | 0.58208 | 0.60926 | 0.59487 |
| 0.06680 | 0.42493 | 0.59219 | 0.59927 | 0.59523 |
| 0.07040 | 0.42455 | 0.60131 | 0.58927 | 0.59476 |
| Best threshold | 0.0668 | ~ | ~ | F1: 0.5952275971754704 |
| 0.06392 | 0.42458 | 0.58414 | 0.60720 | 0.59496 |
| 0.06464 | 0.42476 | 0.58628 | 0.60524 | 0.59512 |
| 0.06536 | 0.42478 | 0.58829 | 0.60309 | 0.59511 |
| Best threshold | 0.06464 | ~ | ~ | F1: 0.5951196104555911 |
| 0.06406 | 0.42460 | 0.58457 | 0.60677 | 0.59498 |
| 0.06421 | 0.42468 | 0.58503 | 0.60642 | 0.59505 |
| 0.06435 | 0.42471 | 0.58546 | 0.60602 | 0.59507 |
| 0.06450 | 0.42471 | 0.58588 | 0.60558 | 0.59507 |
| 0.06464 | 0.42476 | 0.58628 | 0.60524 | 0.59512 |
| 0.06478 | 0.42473 | 0.58663 | 0.60479 | 0.59509 |
| Best threshold | 0.06464 | ~ | ~ | F1: 0.5951196104555911 |

In [None]:
# 只看URMP的合奏结果
npyfolder = f"{model_folder_name}/URMP_eval"
npys = [os.path.join(npyfolder, f) for f in os.listdir(npyfolder) if f.endswith("0.npy")]
find_best_threshold(npys, (0.1, 0.5), step_num=10, generation=4, log=True)

- URMP合奏：Best threshold: 0.1384, F1: 0.7232317124904409

| threshold | Acc | P | R | F1 |
| --------- | --- |---|---|----|
| 0.10000 | 0.56187 | 0.69005 | 0.74933 | 0.71611 |
| 0.14000 | 0.57086 | 0.75036 | 0.70254 | 0.72308 |
| 0.18000 | 0.56544 | 0.78914 | 0.66436 | 0.71812 |
| Best threshold | 0.14 | ~ | ~ | F1: 0.7230835170006764 |
| 0.10800 | 0.56556 | 0.70472 | 0.73898 | 0.71910 |
| 0.11600 | 0.56809 | 0.71771 | 0.72932 | 0.72110 |
| 0.12400 | 0.56992 | 0.72978 | 0.72017 | 0.72251 |
| 0.13200 | 0.57070 | 0.74045 | 0.71125 | 0.72306 |
| 0.14000 | 0.57086 | 0.75036 | 0.70254 | 0.72308 |
| 0.14800 | 0.57055 | 0.75937 | 0.69436 | 0.72273 |
| Best threshold | 0.14 | ~ | ~ | F1: 0.7230835170006764 |
| 0.13360 | 0.57087 | 0.74258 | 0.70955 | 0.72317 |
| 0.13520 | 0.57090 | 0.74460 | 0.70776 | 0.72317 |
| 0.13680 | 0.57095 | 0.74659 | 0.70604 | 0.72319 |
| 0.13840 | 0.57102 | 0.74857 | 0.70437 | 0.72323 |
| 0.14000 | 0.57086 | 0.75036 | 0.70254 | 0.72308 |
| Best threshold | 0.1384 | ~ | ~ | F1: 0.7232317124904409 |
| 0.13712 | 0.57094 | 0.74699 | 0.70566 | 0.72318 |
| 0.13744 | 0.57095 | 0.74736 | 0.70534 | 0.72318 |
| 0.13776 | 0.57095 | 0.74772 | 0.70503 | 0.72318 |
| 0.13808 | 0.57097 | 0.74812 | 0.70470 | 0.72320 |
| 0.13840 | 0.57102 | 0.74857 | 0.70437 | 0.72323 |
| 0.13872 | 0.57099 | 0.74893 | 0.70402 | 0.72320 |
| Best threshold | 0.1384 | ~ | ~ | F1: 0.7232317124904409 |

In [None]:
# 只看URMP的独奏结果
npyfolder = f"{model_folder_name}/URMP_eval"
npys = [os.path.join(npyfolder, f) for f in os.listdir(npyfolder) if not f.endswith("0.npy")]
find_best_threshold(npys, (0.35, 0.5), step_num=10, generation=4, log=True)

- URMP的独奏：Best threshold: 0.3518, F1: 0.8055191663063108

| threshold | Acc | P | R | F1 |
| --------- | --- |---|---|----|
| 0.35000 | 0.68625 | 0.85725 | 0.77045 | 0.80550 |
| 0.36500 | 0.68616 | 0.86241 | 0.76626 | 0.80536 |
| Best threshold | 0.35 | ~ | ~ | F1: 0.8054999270748892 |
| 0.35150 | 0.68634 | 0.85782 | 0.77008 | 0.80555 |
| 0.35300 | 0.68635 | 0.85831 | 0.76971 | 0.80556 |
| 0.35450 | 0.68632 | 0.85888 | 0.76926 | 0.80553 |
| Best threshold | 0.353 | ~ | ~ | F1: 0.8055610835598012 |
| 0.35180 | 0.68630 | 0.85788 | 0.76998 | 0.80552 |
| 0.35210 | 0.68630 | 0.85796 | 0.76992 | 0.80552 |
| Best threshold | 0.3518 | ~ | ~ | F1: 0.8055191663063108 |
| 0.35183 | 0.68630 | 0.85788 | 0.76998 | 0.80552 |
| 0.35186 | 0.68630 | 0.85789 | 0.76997 | 0.80552 |
| Best threshold | 0.3518 | ~ | ~ | F1: 0.8055191663063108 |

In [11]:
# 输出参数数量
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(count_parameters(model.cqt))
print(count_parameters(model))

19944
46564
