BasicAMT 采用两阶段训练：第二阶段中 CQT 层参与训练，因此没有预先计算 CQT；而 BasicPitch 论文直接使用预先计算好的 CQT 结果（CQT 不是可训练参数）。因此应使用 BasicAMT 第二阶段训练所用的数据，并由本文件为这些数据预先计算 CQT（保存为 `.cqt.npy`）。

In [None]:
import torch
import torchaudio
import numpy as np

import sys
# Expose the parent directory on the import path so the project-local
# `model` package (imported in the next cell) can be found.
sys.path.extend([".."])

# Every tensor in this notebook lives on the CPU.
device = torch.device("cpu")
print(device)

In [None]:
# CQT configuration: read the [CQT] table from the shared model config and
# build the CQT front end (requires_grad=False) plus an EnergyNorm layer.
import tomllib
with open('../model/config.toml', 'br') as cfg_file:
    full_config = tomllib.load(cfg_file)
CQTconfig = full_config['CQT']
# Presumably seconds per CQT frame (hop in samples / fs in Hz).
s_per_frame = CQTconfig['hop'] / CQTconfig['fs']

from model.CQT import CQTsmall
from model.layers import EnergyNorm

# Keyword arguments forwarded verbatim from the config table.
cqt_kwargs = {
    key: CQTconfig[key]
    for key in ('fmin', 'octaves', 'bins_per_octave', 'hop', 'filter_scale')
}
cqt = CQTsmall(CQTconfig['fs'], requires_grad=False, **cqt_kwargs).to(device)

norm = EnergyNorm(output_type=0)


In [None]:
import os
import warnings

# Dataset root: each sub-directory (e.g. "inst0") holds "<id>.wav" audio files
# next to "<id>.npy" label arrays; this cell writes "<id>.cqt.npy" beside them.
data_folder = "../data/septimbre/multi_large_256"

for sub_name in sorted(os.listdir(data_folder)):
    sub_dir = os.path.join(data_folder, sub_name)
    if not os.path.isdir(sub_dir):
        continue
    for file_name in sorted(os.listdir(sub_dir)):
        if not file_name.endswith(".wav"):
            continue
        wav_name = os.path.join(sub_dir, file_name)
        # Derive sibling paths from the stem; str.replace(".wav", ...) would
        # also rewrite any ".wav" occurring earlier in the file name.
        stem = os.path.join(sub_dir, os.path.splitext(file_name)[0])
        npy_name = stem + ".npy"
        cqt_name = stem + ".cqt.npy"

        # Only the label's frame count (axis 1) is needed, so memory-map the
        # array instead of reading all of it into RAM.
        frames = np.load(npy_name, mmap_mode="r").shape[1]

        waveform, sample_rate = torchaudio.load(wav_name, normalize=True)
        if sample_rate != CQTconfig['fs']:
            # The CQT layer was built for CQTconfig['fs']; audio at any other
            # rate would misalign both the frequency bins and frame timing.
            waveform = torchaudio.functional.resample(
                waveform, orig_freq=sample_rate, new_freq=CQTconfig['fs']
            )

        # NOTE(review): multi-channel files yield a [1, C, T] batch here;
        # confirm CQTsmall handles C > 1 as intended.
        batch = waveform.unsqueeze(0).to(device)  # add batch dim
        # Preprocessing needs no gradients; no_grad also guarantees .numpy()
        # works even if the normalisation layer has trainable parameters.
        with torch.no_grad():
            cqt_out = norm(cqt(batch))
        cqt_data = cqt_out.squeeze(0).cpu().numpy()  # drop batch dim, e.g. [2, 288, time]

        n_cqt_frames = cqt_data.shape[-1]
        if n_cqt_frames < frames:
            warnings.warn(
                f"{wav_name}: CQT has {n_cqt_frames} frames but label has "
                f"{frames}; saved CQT is shorter than the label"
            )
        # Keep only the first `frames` time steps so CQT and label line up.
        np.save(cqt_name, cqt_data[..., :frames])


In [None]:
# Visual sanity check: plot one label piano roll above the magnitude of the
# CQT that was saved for the same clip.
import matplotlib.pyplot as plt

sample_stem = "../data/septimbre/multi_large_256/inst0/0"
cqt_data = np.load(sample_stem + ".cqt.npy")
np_data = np.load(sample_stem + ".npy")

# Assumes axis 0 of the saved CQT has two components (presumably real and
# imaginary parts - TODO confirm) and combines them into a magnitude.
# np.hypot equals sqrt(a**2 + b**2) but avoids intermediate overflow.
cqt_mag = np.hypot(cqt_data[0], cqt_data[1])

# Set the figure size once, up front.
plt.figure(figsize=(14, 13))

plt.subplot(2, 1, 1)
plt.imshow(np_data, aspect='auto', origin='lower', cmap='gray')
plt.title('Piano Roll Data')
plt.xlabel('Time Frame')
plt.ylabel('MIDI Note')

plt.subplot(2, 1, 2)
plt.imshow(cqt_mag, aspect='auto', origin='lower', cmap='hot')
plt.title('CQT Magnitude')
plt.xlabel('Time Frame')
plt.ylabel('Frequency Bin')
plt.colorbar()

plt.tight_layout()
plt.show()

print(cqt_mag.shape, np_data.shape)