# "音色无关转录"模型评估
learnable CQT

## 生成转录结果并保存
经过 process 后，每个文件夹里的文件只有后缀不同，后缀为 "npy"、"wav"、"mid" 三种。wav 的采样率已经是 22050 Hz。

In [None]:
import os
from instrument_agnostic_eval_utils import *
import sys
sys.path.append('..')

# Seconds covered by one frame step at the 22050 Hz sample rate
# (256 is presumably the hop size in samples -- TODO confirm).
s_per_frame = 256 / 22050

# Pre-processed dataset folders the model is run on.
dataset_folders = [
    f"{name}_processed" for name in ("BACH10", "URMP", "PHENICX")
]

### 运行模型

In [None]:
# Folder name used for this model's outputs (read by the cells below).
model_folder_name = "basicamt_learnableCQT"

# Import torch explicitly instead of relying on the name leaking in through
# the wildcard import of instrument_agnostic_eval_utils.
import torch

# Put the basicamt sources on the import path before importing the model class.
sys.path.append('../basicamt/')
from basicamt.basicamt import BasicAMT_all

# Allow-list the model class for torch's restricted unpickler.
torch.serialization.add_safe_globals([BasicAMT_all])

# NOTE(review): weights_only=False performs a full pickle load of the model
# object, which can execute arbitrary code. Only load trusted checkpoints.
model = torch.load(
    "../basicamt/ablation/learnableCQT/basicamt_model_learnableCQT.pth",
    weights_only=False,
)
model.eval()

In [None]:
# Run the model on every dataset folder with gradient tracking disabled.
# The same "./<model_folder_name>" path is handed to amt_dataset each time
# (presumably the output root -- confirm against amt_dataset).
output_root = f"./{model_folder_name}"
with torch.no_grad():
    for dataset_folder in dataset_folders:
        amt_dataset(model, dataset_folder, output_root)

### 计算帧级与音符级指标，并寻找最好阈值

In [None]:
# Flat accumulator for the metrics of every evaluation cell, kept in one list
# so the final row can be pasted into Excel conveniently.
output_results = list()

In [None]:
# BACH10, ensemble results only: files whose names end with "0.npy".
npyfolder = f"{model_folder_name}/BACH10_eval"
npys = [
    os.path.join(npyfolder, name)
    for name in os.listdir(npyfolder)
    if name.endswith("0.npy")
]

# Frame-level search for the best threshold.
best_thre, max_acc, max_p, max_r, max_f1 = find_best_threshold(
    npys, (0.1, 0.5), step_num=10, generation=4, log=True
)

# Note-level search for the best onset threshold, using the frame threshold found above.
print("note level evaluation at best threshold:")
best_onset_thre, note_p, note_r, note_f, note_overlap = find_best_onset_threshold(
    npys, best_thre, (0.2, 0.6), step_num=10, generation=4, log=True
)

# Append this cell's ten values to the shared results row.
output_results += [
    best_thre, max_acc, max_p, max_r, max_f1,
    best_onset_thre, note_p, note_r, note_f, note_overlap,
]

In [None]:
# BACH10, solo results: every prediction file that is not an ensemble
# ("0.npy") file.
npyfolder = f"{model_folder_name}/BACH10_eval"
# Require the ".npy" suffix too, so that other files in the folder are not
# selected merely because their names do not end with "0.npy".
npys = [
    os.path.join(npyfolder, f)
    for f in os.listdir(npyfolder)
    if f.endswith(".npy") and not f.endswith("0.npy")
]

# Frame-level search for the best threshold.
best_thre, max_acc, max_p, max_r, max_f1 = find_best_threshold(
    npys, (0.4, 0.5), step_num=10, generation=4, log=True
)

# Note-level search for the best onset threshold, using the frame threshold found above.
print("note level evaluation at best threshold:")
best_onset_thre, note_p, note_r, note_f, note_overlap = find_best_onset_threshold(
    npys, best_thre, (0.2, 0.6), step_num=10, generation=4, log=True
)

# Append this cell's ten values to the shared results row.
output_results.extend([
    best_thre, max_acc, max_p, max_r, max_f1,
    best_onset_thre, note_p, note_r, note_f, note_overlap,
])

In [None]:
# PHENICX ensemble audio: every ".npy" file in the evaluation folder.
npyfolder = f"{model_folder_name}/PHENICX_eval"
npys = [
    os.path.join(npyfolder, name)
    for name in os.listdir(npyfolder)
    if name.endswith(".npy")
]

# Frame-level search for the best threshold.
best_thre, max_acc, max_p, max_r, max_f1 = find_best_threshold(
    npys, (0.02, 0.2), step_num=10, generation=4, log=True
)

# Note-level search for the best onset threshold, using the frame threshold found above.
print("note level evaluation at best threshold:")
best_onset_thre, note_p, note_r, note_f, note_overlap = find_best_onset_threshold(
    npys, best_thre, (0.13, 0.6), step_num=10, generation=4, log=True
)

# Append this cell's ten values to the shared results row.
output_results += [
    best_thre, max_acc, max_p, max_r, max_f1,
    best_onset_thre, note_p, note_r, note_f, note_overlap,
]

In [None]:
# URMP, ensemble results only: files whose names end with "0.npy".
npyfolder = f"{model_folder_name}/URMP_eval"
npys = [
    os.path.join(npyfolder, name)
    for name in os.listdir(npyfolder)
    if name.endswith("0.npy")
]

# Frame-level search for the best threshold.
best_thre, max_acc, max_p, max_r, max_f1 = find_best_threshold(
    npys, (0.1, 0.5), step_num=10, generation=4, log=True
)

# Note-level search for the best onset threshold, using the frame threshold found above.
print("note level evaluation at best threshold:")
best_onset_thre, note_p, note_r, note_f, note_overlap = find_best_onset_threshold(
    npys, best_thre, (0.18, 0.6), step_num=10, generation=4, log=True
)

# Append this cell's ten values to the shared results row.
output_results += [
    best_thre, max_acc, max_p, max_r, max_f1,
    best_onset_thre, note_p, note_r, note_f, note_overlap,
]

In [None]:
# URMP, solo results: every prediction file that is not an ensemble
# ("0.npy") file.
npyfolder = f"{model_folder_name}/URMP_eval"
# Require the ".npy" suffix too, so that other files in the folder are not
# selected merely because their names do not end with "0.npy".
npys = [
    os.path.join(npyfolder, f)
    for f in os.listdir(npyfolder)
    if f.endswith(".npy") and not f.endswith("0.npy")
]

# Frame-level search for the best threshold.
best_thre, max_acc, max_p, max_r, max_f1 = find_best_threshold(
    npys, (0.3, 0.4), step_num=10, generation=4, log=True
)

# Note-level search for the best onset threshold, using the frame threshold found above.
print("note level evaluation at best threshold:")
best_onset_thre, note_p, note_r, note_f, note_overlap = find_best_onset_threshold(
    npys, best_thre, (0.18, 0.6), step_num=10, generation=4, log=True
)

# Append this cell's ten values to the shared results row.
output_results.extend([
    best_thre, max_acc, max_p, max_r, max_f1,
    best_onset_thre, note_p, note_r, note_f, note_overlap,
])

In [None]:
# Report the number of parameters.
def count_parameters(model):
    """Return the total element count of ``model``'s trainable parameters.

    Only parameters yielded by ``model.parameters()`` whose ``requires_grad``
    is truthy are counted; each contributes its ``numel()``.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Print the trainable-parameter count of the model's `cqt` attribute first,
# then of the whole model.
for module in (model.cqt, model):
    print(count_parameters(module))

In [None]:
# Emit all collected metrics on one line, five decimals each, separated by "|".
print(*(format(value, ".5f") for value in output_results), sep="|")