# 最佳帧阈值与 CQT 频谱的均值、标准差有关吗？
省流（结论）：没关系。

In [None]:
import sys
sys.path.append('..')

from model.CQT import CQTsmall_fir
from model.config import CONFIG
from model.layers import EnergyNorm
import torch
class doCQT(torch.nn.Module):
    """Waveform -> energy-normalised CQT spectrogram.

    Chains ``CQTsmall_fir`` (configured from ``CONFIG.CQT``) with
    ``EnergyNorm(output_type=1, log_scale=False)``.
    """

    def __init__(self):
        super().__init__()
        # Attribute names ``cqt`` / ``norm`` are part of the module's
        # state_dict keys, so they are kept as-is.
        self.cqt = CQTsmall_fir(config=CONFIG.CQT)
        self.norm = EnergyNorm(output_type=1, log_scale=False)

    def forward(self, x):
        """Return ``self.norm(self.cqt(x))``."""
        return self.norm(self.cqt(x))
# Module-level CQT front end; getInfo() reads this global to compute the
# spectrogram whose mean/std are collected. eval() switches it to inference mode.
spectrum = doCQT()
spectrum.eval()

In [None]:
import os
# NOTE(review): ``np`` is used later but never imported explicitly; it
# presumably comes from this star import — confirm.
from instrument_agnostic_eval_utils import *
import sys
# Same path as added in the first cell; appending it again only duplicates
# the sys.path entry.
sys.path.append('..')

# Seconds per frame: hop of 256 samples at 22050 Hz. Passed to midi2numpy to
# rasterise the reference MIDI, so audio is implicitly assumed to be 22050 Hz.
# NOTE(review): confirm this matches the model's CQT hop / sample rate.
s_per_frame = 256 / 22050

import sys
model_folder_name = "basicamt"
# Presumably needed so torch.load can import the model's class definitions
# while unpickling the full model object — verify.
sys.path.append(f'../{model_folder_name}')
# SECURITY: weights_only=False unpickles arbitrary Python objects, which can
# execute code. Only load checkpoints from a trusted source.
# NOTE(review): no map_location is given, so tensors load onto the device
# they were saved from; getInfo feeds CPU waveforms — confirm devices match.
model = torch.load(f"../{model_folder_name}/best_basicamt_model.pth", weights_only=False)
model.eval()

In [None]:
from instrument_agnostic_eval_utils import frame_eval

def find_best_threshold(pred, midi, step_num=10, generation=4, eval_fn=None):
    """Coarse-to-fine search for the threshold that maximises frame-level F1.

    Generation 0 scans ``step_num`` evenly spaced thresholds in [0.02, 0.98].
    Each later generation zooms into the window [best - step, best + step]
    (lower end clamped at 0) around the best threshold found so far and scans
    ``step_num`` interior points of it, so the step shrinks by a factor of
    ``(step_num + 1) / 2`` per generation. Within one generation the scan
    stops at the first threshold whose F1 is lower than the previous one,
    i.e. the F1-vs-threshold curve is assumed to be unimodal.

    Args:
        pred: Predicted activations, passed unchanged to ``eval_fn``.
        midi: Reference piano roll, passed unchanged to ``eval_fn``.
        step_num: Number of thresholds evaluated per generation (>= 2).
        generation: Number of scan/zoom rounds.
        eval_fn: Callable ``eval_fn(pred, midi, threshold)`` returning a
            mapping with ``'Precision'`` and ``'Recall'`` entries.
            Defaults to ``frame_eval``.

    Returns:
        ``(best_threshold, best_f1)``; ``(-1, -1)`` when ``generation`` < 1.

    Raises:
        ValueError: if ``step_num`` < 2 (the initial grid needs two points).
    """
    if step_num < 2:
        raise ValueError(f"step_num must be >= 2, got {step_num}")
    if eval_fn is None:
        eval_fn = frame_eval

    start = 0.02
    end = 0.98
    step = (end - start) / (step_num - 1)

    best_thre = -1
    max_f1 = -1

    for _ in range(generation):
        last_f1 = -1
        for i in range(step_num):
            thre = start + i * step
            result = eval_fn(pred, midi, thre)
            p = result['Precision']
            r = result['Recall']
            f1 = 2 * p * r / (p + r) if p + r > 0 else 0
            if f1 > max_f1:
                max_f1 = f1
                best_thre = thre
            # Past the peak of a unimodal curve: nothing better further right.
            if f1 < last_f1:
                break
            last_f1 = f1

        # Zoom into [best - step, best + step]. Both ends are derived from
        # best_thre, so clamping the lower end at 0 does not widen the window.
        lo = max(0, best_thre - step)
        hi = best_thre + step
        step = (hi - lo) / (step_num + 1)
        # Scan only the interior points of the new window.
        start = lo + step

    return best_thre, max_f1

def getInfo(wave: torch.Tensor, midi):
    """Return the best frame threshold and CQT statistics for one recording.

    Runs the module-level ``spectrum`` front end and ``model`` on ``wave``,
    aligns ``midi`` to the model's frame count, searches for the threshold
    maximising frame-level F1, and summarises the normalised spectrogram.

    Args:
        wave: Waveform tensor passed unchanged to ``spectrum`` and ``model``
            (presumably a mono 1-D signal — TODO confirm against the model).
        midi: 2-D reference array indexed as (row, frame).

    Returns:
        ``(best_thre, std, mean)``: the threshold from ``find_best_threshold``
        and the standard deviation / mean over all spectrogram values.
    """
    with torch.no_grad():
        spec = spectrum(wave).squeeze(0).squeeze(0).cpu().numpy()
        # The model returns (onset, frame activation); only the frame
        # activation is evaluated here.
        _, pred_activation = model(wave)
        pred_activation = pred_activation.squeeze(0).cpu().numpy()

    # Align the reference to the prediction's frame count: zero-pad on the
    # right when shorter, truncate when longer. np.pad keeps midi's own row
    # count and dtype, so the result does not depend on the prediction's
    # row count and is not silently upcast to float64 only when padded.
    # NOTE(review): assumes midi has as many rows as pred_activation; this
    # block does not check it.
    _, n_frames = pred_activation.shape
    missing = n_frames - midi.shape[1]
    if missing > 0:
        midi = np.pad(midi, ((0, 0), (0, missing)))
    midi = midi[:, :n_frames]

    best_thre, _ = find_best_threshold(pred_activation, midi, step_num=10, generation=4)
    return best_thre, np.std(spec), np.mean(spec)

In [None]:
import torchaudio
from utils.midiarray import midi2numpy
def process_dataset(data_folder="BACH10_processed"):
    """Run ``getInfo`` on every piece found under ``data_folder``.

    Each immediate sub-directory of ``data_folder`` is one piece and must
    contain a ``.wav`` recording and a ``.mid`` reference; if several match,
    the first in sorted filename order is used.

    Returns:
        Three parallel lists ``(best_thres, stds, means)``, one entry per
        sub-directory, ordered by sub-directory path.

    Raises:
        FileNotFoundError: if a sub-directory lacks a ``.wav`` or ``.mid``.
    """
    best_thres, stds, means = [], [], []
    # Sorted so the output order is reproducible (os.listdir order is arbitrary).
    sub_folders = sorted(
        os.path.join(data_folder, name)
        for name in os.listdir(data_folder)
        if os.path.isdir(os.path.join(data_folder, name))
    )
    for sub_folder in sub_folders:
        files = sorted(os.listdir(sub_folder))
        wav_name = next((f for f in files if f.endswith('.wav')), None)
        mid_name = next((f for f in files if f.endswith('.mid')), None)
        if wav_name is None or mid_name is None:
            raise FileNotFoundError(f"{sub_folder}: expected a .wav and a .mid file")

        wave_data, fs = torchaudio.load(os.path.join(sub_folder, wav_name))
        # NOTE(review): fs is not checked, but s_per_frame assumes 22050 Hz;
        # confirm every recording uses that sample rate.
        midi_data = midi2numpy(os.path.join(sub_folder, mid_name), s_per_frame)
        # Only the first channel is analysed.
        best_thre, std, mean = getInfo(wave_data[0], midi_data)
        best_thres.append(best_thre)
        stds.append(std)
        means.append(mean)
    return best_thres, stds, means

In [None]:
# One entry per sub-folder of URMP_processed: best frame threshold, and the
# std / mean of that piece's normalised CQT spectrogram.
best_thres, stds, means = process_dataset("URMP_processed")

In [None]:
# Comma-separated best thresholds, in process_dataset order.
",".join(str(thre) for thre in best_thres)

In [None]:
# Comma-separated spectrogram standard deviations, in process_dataset order.
",".join(str(value) for value in stds)

In [None]:
# Comma-separated spectrogram means, in process_dataset order.
",".join(str(value) for value in means)