In [None]:
import torch
import sys
# Make ../model importable. torch.load below unpickles a whole model object, which
# presumably needs its class definitions (and `septimbre`, imported later) on sys.path.
sys.path.append("../model")
# weights_only=False lets torch.load unpickle arbitrary Python objects; only load
# checkpoints from a trusted source.
model = torch.load("sepamt_model.pth", map_location=torch.device('cpu'), weights_only=False)
# Seconds per model frame. Presumably a hop of 256 samples at 22050 Hz (the PyTorch
# path downsamples >44 kHz input by 2) — TODO confirm against the model's CQT settings.
s_per_frame = 256 / 22050

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from utils.midiarray import midi2numpy

def draw_midi_with_channel(ax, midiarr, colors=('Reds', 'Blues'), labels=('Piano', 'Violin')):
    """
    Draw a multi-track piano roll on `ax`, one colormap per track.

    ax: a matplotlib Axes, typically the current subplot (plt.gca()).
    midiarr: 3D array-like holding one 2D roll per track, or a path to a MIDI file
        (converted with midi2numpy(path, s_per_frame, track_separate=True)).
        Each 2D roll is drawn with imshow(origin='lower'), so its first axis is the
        vertical one. NOTE(review): the notebook crops model outputs to
        test_midi.shape[2] frames, which suggests the layout is
        (n_channel, n_pitch, n_time) — confirm against midi2numpy.
    colors: colormap names accepted by plt.get_cmap, or colormap objects.
        Reused cyclically when there are more tracks than colors.
    labels: legend labels; the legend is drawn only if len(labels) == len(colors).

    Cell values: 1 is drawn semi-transparent (alpha 0.6) and 2 opaque. Every other
    value, including other positive values, stays fully transparent.

    Raises:
        ValueError: if midiarr (after loading) is not 3-dimensional.
    """
    # Load from a MIDI file if a path was given.
    if isinstance(midiarr, str):
        midiarr = midi2numpy(midiarr, s_per_frame, track_separate=True)
    midiarr = np.asarray(midiarr)
    # Validate directly passed arrays as well, not only those loaded from a file.
    if midiarr.ndim != 3:
        raise ValueError("midiarr must be a 3D numpy array")

    cmaps = [plt.get_cmap(color) if isinstance(color, str) else color for color in colors]

    # White background layer.
    background = np.ones_like(midiarr[0])
    ax.imshow(background, cmap='gray_r', aspect='auto', origin='lower')

    for i, track in enumerate(midiarr):
        cmap = cmaps[i % len(cmaps)]
        # Every active cell takes the top colour of this track's colormap ...
        value = np.zeros_like(track, dtype=float)
        value[track > 0] = 1.0
        # ... and its opacity encodes the cell value.
        alpha = np.zeros_like(track, dtype=float)
        alpha[track == 1] = 0.6   # semi-transparent
        alpha[track == 2] = 1.0   # opaque
        rgba = cmap(value)        # RGBA array, last axis of length 4
        rgba[..., 3] = alpha
        ax.imshow(rgba, aspect='auto', origin='lower')

    if len(labels) != len(colors):
        return
    legend_elements = [Patch(facecolor=cmaps[i](0.8), alpha=0.6, label=labels[i]) for i in range(len(labels))]
    ax.legend(handles=legend_elements, loc='upper right')

# 基本使用

In [None]:
# 读取音频
import torchaudio
from utils.wavtool import waveInfo
from utils.midiarray import midi2numpy
import os

# Input mixture to transcribe. Other test files used before:
#   "../data/inferMusic/short mix.wav"
#   "../data/inferMusic/three_mix.wav"
#   "../data/inferMusic/flute_guitar_Humoresque.wav"
#   "../data/inferMusic/flute_violin.wav"
#   "../data/inferMusic/孤独な巡礼simple.wav"
test_wave_path = "../data/inferMusic/flute_violin_piano.wav"

# Optional ground truth: a .mid file with the same stem next to the audio.
# test_midi ends up as the per-track roll array from midi2numpy, or None if absent.
reference_midi_path = os.path.splitext(test_wave_path)[0] + ".mid"
test_midi = (
    midi2numpy(reference_midi_path, s_per_frame, track_separate=True)
    if os.path.exists(reference_midi_path)
    else None
)

info = waveInfo(test_wave_path)

waveform, sample_rate = torchaudio.load(test_wave_path, normalize=True)  # (channels, samples)
# Mix multi-channel audio down to mono, as the ONNX benchmark cell does. The export
# cells trace with a (1, 1, samples) input, so the model presumably expects one channel.
if waveform.shape[0] > 1:
    waveform = waveform.mean(dim=0, keepdim=True)
waveform = waveform.unsqueeze(0)  # (1, 1, samples)
if info["sample_rate"] > 44000:
    waveform = model.cqt.down2sample(waveform)
    # Single quotes inside the f-string: reusing " there is a SyntaxError before Python 3.12.
    print(f"$ downsampled to {info['sample_rate']//2}Hz")
print(waveform.shape)

model.eval()
with torch.no_grad():
    onset, mask, emb = model(waveform)
    # L2-normalise each embedding vector along dim=1 (epsilon keeps silent bins finite).
    emb_norm = torch.sqrt(emb.pow(2).sum(dim=1, keepdim=True) + 1e-8)
    emb = (emb / emb_norm)[0].cpu().numpy()   # (D, 84, frame); D noted as 18 here, 12 in the ONNX cell — confirm
    mask = mask[0].cpu().numpy()             # (84, frame)
    onset = onset[0].cpu().numpy()

if test_midi is not None:
    # Crop every model output to the reference MIDI's frame count.
    n_midi_frames = test_midi.shape[2]
    emb = emb[:, :, :n_midi_frames]
    mask = mask[:, :n_midi_frames]
    onset = onset[:, :n_midi_frames]

## 用聚类
需要事先指定类别数目 N（有参考 MIDI 时取其音轨数，否则默认为 3）。

In [None]:
from sklearn.cluster import SpectralClustering
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import matplotlib.pyplot as plt

# Number of clusters: one per track of the reference MIDI when available, else 3.
N = 3 if test_midi is None else len(test_midi)

# Select the time-frequency bins whose activation exceeds 0.5; call their count n.
positions = np.where(mask > 0.5)
emb_extracted = emb[:, positions[0], positions[1]].T        # (n, D)
if emb_extracted.shape[0] < N:
    raise ValueError(f"only {emb_extracted.shape[0]} bins have mask > 0.5; need at least N={N} to cluster")

# Pairwise cosine similarity, shape (n, n).
# NOTE(review): memory is O(n^2); long recordings with many active bins can exhaust RAM.
similarity_matrix = cosine_similarity(emb_extracted)

# Spectral clustering on a precomputed affinity. exp() maps cosine values in [-1, 1]
# to strictly positive affinities. random_state is fixed so re-running gives the same partition.
spectral = SpectralClustering(n_clusters=N, affinity='precomputed', assign_labels="cluster_qr", random_state=0)
labels = spectral.fit_predict(np.exp(similarity_matrix))

# Layout: [midi (only if available)], note activation, onset, then one panel per cluster.
pre_figures = 2 + (0 if test_midi is None else 1)
sub_figures = N + pre_figures

plt.figure(figsize=(10, 5*sub_figures))

if test_midi is not None:
    plt.subplot(sub_figures, 1, 1)
    plt.title('midi')
    draw_midi_with_channel(plt.gca(), test_midi, colors=['Reds', 'Blues', 'Greens'], labels=[])

plt.subplot(sub_figures, 1, pre_figures - 1)
plt.title('note')
plt.imshow(mask, aspect='auto', origin='lower', cmap='gray')
plt.colorbar()

plt.subplot(sub_figures, 1, pre_figures)
plt.title('onset')
plt.imshow(onset, aspect='auto', origin='lower', cmap='gray')
plt.colorbar()

# Paint each cluster's bins back onto an empty grid shaped like mask.
clustered_classes = []
for i in range(N):
    class_i = np.zeros(mask.shape)
    class_i[positions[0], positions[1]] = (labels == i).astype(int)
    clustered_classes.append(class_i)
    plt.subplot(sub_figures, 1, i + pre_figures + 1)
    plt.title(f'class{i}')
    plt.imshow(class_i, aspect='auto', origin='lower', cmap='gray')

plt.tight_layout()
plt.show()

## 音符层级的聚类
先获取“音色无关转录”结果并做音符创建后处理，然后在每个音符内按帧激活值对嵌入向量加权平均，得到该音符的特征（代码中还以注释保留了平方、开方、exp（类 softmax）等加权方式供对比）。

In [None]:
# 必须先获取了转录结果
from sklearn.cluster import SpectralClustering
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import matplotlib.pyplot as plt
from utils.postprocess import output_to_notes_polyphonic

# Rescale activations to [0, 1]. The max(..., 1e-8) guard avoids 0/0 -> NaN when
# nothing is active.
mask = mask / max(mask.max(), 1e-8)
onset = onset / max(onset.max(), 1e-8)

note_events = output_to_notes_polyphonic(
    mask, onset,
    frame_thresh = 0.3,
    onset_thresh = 0.4,
    neighbor_trick = False,
    midi_offset = 0 # keep raw row indices so each note can be looked up in mask/emb
)

# One embedding per note: activation-weighted sum over the note's frames, L2-normalised.
embeddings = []
for start, end, f, amp in note_events:
    note_act = mask[f, start:end]   # (end-start,)
    note_emb = emb[:, f, start:end] # (D, end-start)
    # Other weightings tried: np.ones_like(note_act), note_act * note_act,
    # np.sqrt(note_act), np.exp(note_act * note_act) (softmax-like, stresses the strongest frames).
    # NOTE(review): an earlier comment said squaring worked better than linear weighting,
    # yet linear weighting is what is active — confirm which is intended.
    weight = note_act
    weighted_emb = (note_emb * weight).sum(axis=1)  # (D,)
    # Guard against a zero vector (empty or all-zero note) turning into NaNs.
    normalized_emb = weighted_emb / max(np.linalg.norm(weighted_emb), 1e-8)
    embeddings.append(normalized_emb)

# Number of clusters: one per track of the reference MIDI when available, else 3.
N = 3 if test_midi is None else len(test_midi)
if len(embeddings) < N:
    raise ValueError(f"only {len(embeddings)} notes were detected; need at least N={N} to cluster")

similarity_matrix = cosine_similarity(embeddings)
# exp() keeps the precomputed affinity strictly positive; random_state fixed for reproducible labels.
spectral = SpectralClustering(n_clusters=N, affinity='precomputed', assign_labels="cluster_qr", random_state=0)
labels = spectral.fit_predict(np.exp(similarity_matrix))

# Plotting. For a two-column layout, (6, 3.5) per panel keeps the font size readable.
pre_figures = 2 + (0 if test_midi is None else 1)
sub_figures = N + pre_figures

plt.figure(figsize=(10, 5*sub_figures))

if test_midi is not None:
    plt.subplot(sub_figures, 1, 1)
    plt.title('midi')
    draw_midi_with_channel(plt.gca(), test_midi, colors=['Reds', 'Blues', 'Greens'], labels=[])

plt.subplot(sub_figures, 1, pre_figures - 1)
plt.title('note')
plt.imshow(mask, aspect='auto', origin='lower', cmap='gray')
# no colorbar: values were rescaled to [0, 1] above

plt.subplot(sub_figures, 1, pre_figures)
plt.title('onset')
plt.imshow(onset, aspect='auto', origin='lower', cmap='gray')

# Paint each cluster's notes back onto a grid: sustain frames = amp, onset frame = 2 * amp.
clustered_classes = []
for i in range(N):
    class_i = np.zeros(mask.shape)
    indices = np.where(labels == i)[0]
    for idx in indices:
        start, end, f, amp = note_events[idx]
        class_i[f, start+1:end] = amp
        class_i[f, start] = 2 * amp # onset
    clustered_classes.append(class_i)
    plt.subplot(sub_figures, 1, i + pre_figures + 1)
    plt.title(f'class{i}')
    plt.imshow(class_i, aspect='auto', origin='lower', cmap='hot')

plt.tight_layout()
plt.show()

In [None]:
# 输出为midi
# 根据 labels 将 note_events 分为 N 类
import os
from utils.midiarray import notes2midi, midi_merge

# Notes were extracted with midi_offset=0 (raw pitch-row indices), so shift them into
# MIDI note numbers. 24 matches the offset passed to cluster_notes in the ONNX benchmark
# cell; presumably it is the MIDI note of pitch row 0 — confirm against the model's CQT range.
MIDI_PITCH_OFFSET = 24
CLUSTERED_MIDI_PATH = "../test/clustered_output.mid"

# Group note_events by cluster label (labels[k] belongs to note_events[k]).
assert len(labels) == len(note_events), "labels must come from clustering these note_events"
note_classes = [[] for _ in range(N)]
for (start, end, pitch, amp), label in zip(note_events, labels):
    note_classes[label].append((start, end, pitch + MIDI_PITCH_OFFSET, amp))

midis = [notes2midi(note_class) for note_class in note_classes]
# The output directory may not exist on a fresh checkout.
os.makedirs(os.path.dirname(CLUSTERED_MIDI_PATH), exist_ok=True)
midi_merge(midis).save(CLUSTERED_MIDI_PATH)

# 取代聚类：迭代过滤
每次取激活最强处的嵌入作为一个类的代表，计算各时频点与它的（激活加权）相似度，再从激活图和嵌入中减去该类的成分，然后重复。

In [None]:
frame_threshold = 0.2

def iterate_filter(mask, emb, threshold=0.):
    """One peel-off step: take the embedding at the strongest activation and remove it.

    mask: 2D activation map (indexed (pitch, frame) in this notebook).
    emb: embeddings shaped (D, *mask.shape); the notebook feeds unit-norm vectors along axis 0.
    threshold: bins whose similarity is above it have the peak embedding subtracted.

    Returns (similarity_scores, mask2, emb2):
        similarity_scores: dot product of the peak embedding with the activation-weighted
            embeddings, same shape as mask.
        mask2: mask minus similarity_scores.
        emb2: emb minus (clipped similarity x peak embedding), re-normalised along axis 0.
    """
    # Reference embedding at the location of the largest activation.
    peak = np.unravel_index(np.argmax(mask, axis=None), mask.shape)
    peak_emb = emb[:, peak[0], peak[1]]

    # Activation-weighted similarity of every bin to the reference.
    weighted_emb = emb * mask
    scores = np.tensordot(peak_emb, weighted_emb, axes=(0, 0)).reshape(emb.shape[1], emb.shape[2])

    residual_mask = mask - scores

    # Only bins clearly matching the reference lose that component; the clipping keeps
    # the remaining classes from being pushed towards each other.
    kept = np.where(scores > threshold, scores, 0)
    residual_emb = emb - kept * peak_emb[:, None, None]
    norms = np.linalg.norm(residual_emb, axis=0, keepdims=True) + 1e-6
    residual_emb = residual_emb / norms
    return scores, residual_mask, residual_emb

In [None]:
# Iteration 1 of the peel-off filter, starting from the full activation map.
similarity_scores1, mask1, emb1 = iterate_filter(mask, emb, frame_threshold)

# Panels: similarity to the peak embedding, its binarisation, and the corrected activation.
# Every panel gets a title (the binarisation panel previously had none).
fig, axes = plt.subplots(3, 1, figsize=(12, 15))
panels = [
    ('class_max', similarity_scores1),
    (f'class_max > {frame_threshold}', similarity_scores1 > frame_threshold),
    ("after iteration 1", mask1),
]
for ax, (panel_title, panel_image) in zip(axes, panels):
    ax.set_title(panel_title)
    im = ax.imshow(panel_image, aspect='auto', origin='lower', cmap='gray')
    fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# Iteration 2: peel the next class off the activation/embeddings left by iteration 1.
similarity_scores2, mask2, emb2 = iterate_filter(mask1, emb1, frame_threshold)

# Panels: similarity to the peak embedding, its binarisation, and the corrected activation.
# Every panel gets a title (the binarisation panel previously had none).
fig, axes = plt.subplots(3, 1, figsize=(12, 15))
panels = [
    ('class_max', similarity_scores2),
    (f'class_max > {frame_threshold}', similarity_scores2 > frame_threshold),
    ("after iteration 2", mask2),
]
for ax, (panel_title, panel_image) in zip(axes, panels):
    ax.set_title(panel_title)
    im = ax.imshow(panel_image, aspect='auto', origin='lower', cmap='gray')
    fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# Iteration 3: peel the next class off the activation/embeddings left by iteration 2.
similarity_scores3, mask3, emb3 = iterate_filter(mask2, emb2, frame_threshold)

# Panels: similarity to the peak embedding, its binarisation, and the corrected activation.
# Every panel gets a title (the binarisation panel previously had none).
fig, axes = plt.subplots(3, 1, figsize=(12, 15))
panels = [
    ('class_max', similarity_scores3),
    (f'class_max > {frame_threshold}', similarity_scores3 > frame_threshold),
    ("after iteration 3", mask3),
]
for ax, (panel_title, panel_image) in zip(axes, panels):
    ax.set_title(panel_title)
    im = ax.imshow(panel_image, aspect='auto', origin='lower', cmap='gray')
    fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# Two-iteration walkthrough of the peel-off filter, read column-wise (1-3 left, 4-6 right).
fig, axs = plt.subplots(3, 2, figsize=(12, 10), sharex=True, sharey=True)

steps = [
    (axs[0, 0], '1. AMT result', mask),
    (axs[1, 0], '2. similarity to max', similarity_scores1),
    (axs[2, 0], '3. note binarization', similarity_scores1 > frame_threshold),
    # Step 4 is the activation corrected by iteration 1 (mask1). It is the input from
    # which step 5 (similarity_scores2) was computed; mask2 only exists after step 6.
    (axs[0, 1], '4. probability correction', mask1),
    (axs[1, 1], '5. similarity to max', similarity_scores2),
    # Same threshold as step 3 and as used inside iterate_filter.
    (axs[2, 1], '6. note binarization', similarity_scores2 > frame_threshold),
]
for ax, step_title, step_image in steps:
    ax.set_title(step_title)
    ax.imshow(step_image, aspect='auto', origin='lower')

plt.tight_layout()
plt.show()


## 导出为ONNX

In [None]:
import torch
import torch.nn as nn
from septimbre import SepTimbreAMT_44100, Encoder_44100

# Wrap the loaded PyTorch model for export. Per the class name, it is meant for 44.1 kHz
# input; presumably it downsamples internally — verify in septimbre.py.
model_44100 = SepTimbreAMT_44100(model)

In [None]:
# Export the 44.1 kHz wrapper to ONNX with the dynamo-based exporter.
model_44100.eval()

# Dummy tracing input: (batch=1, channels=1, samples); only the sample axis is dynamic.
input_audio = torch.randn((1, 1, 22050), dtype=torch.float32)    # (fixed, fixed, dynamic)

with torch.no_grad():
    torch.onnx.export(
        model_44100,
        (input_audio,),
        'septimbre_44100.onnx',
        input_names = ["audio"],    # the ONNX benchmark cell feeds {'audio': ...}
        output_names = ["onset","frame","embedding"],
        # Keys must match the parameter names of model_44100.forward (assumed to be `x`).
        dynamic_shapes = {
            'x': {2:'time'},    # same as model.forward input
            # auto infer output shapes
        },
        dynamo=True,
        verbose=False,
        verify=True,            # check the exported graph with ONNX Runtime
        external_data=False,    # keep weights inside the single .onnx file
        # NOTE(review): appears to apply only to the TorchScript exporter; likely ignored with dynamo=True — confirm.
        autograd_inlining=False
    )

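### 先测试 FLOPS、RTF 等指标
Real-Time Factor (RTF = (inference time + post-processing time) / audio duration; lower is faster), FLOPS (not measured below yet — TODO), actual inference time, peak memory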

In [None]:
# memory_profiler is needed by the benchmark cell below. A bare `uv ...` line is not
# valid Python in a code cell, so run it via subprocess. `--python sys.executable`
# installs into this kernel's interpreter rather than whatever env uv would pick.
import subprocess
import sys

subprocess.check_call(["uv", "pip", "install", "--python", sys.executable, "memory_profiler"])

In [None]:
# 使用onnxruntime推理
import onnxruntime as ort
import numpy as np
import torchaudio
import time
import psutil
import os
from memory_profiler import memory_usage
import sys
sys.path.append("..")
from utils.postprocess import cluster_notes
from utils.midiarray import notes2midi

input_wave_path = "../data/inferMusic/flute_violin.wav"  # 5min, 44100Hz
N_clusters = 2
ort_session = ort.InferenceSession("septimbre_44100.onnx")  # create an inference session

# Load the audio as a (1, 1, samples) float32 array, mixing multi-channel input down to mono.
waveform, sample_rate = torchaudio.load(input_wave_path, normalize=True)
if sample_rate != 44100:
    # NOTE(review): the graph was exported from the *_44100 wrapper, so it presumably
    # expects 44.1 kHz input; nothing in this cell resamples.
    print(f"warning: {input_wave_path} is {sample_rate} Hz, but septimbre_44100.onnx presumably expects 44100 Hz")
waveform = waveform.unsqueeze(0).cpu()
if waveform.shape[1] > 1:
    waveform = waveform.mean(1, keepdim=True)
waveform = waveform.numpy().astype(np.float32)
wave_duration = waveform.shape[2] / sample_rate

# Time the ONNX forward pass plus note clustering while memory_profiler samples memory.
inference_time = 0
def run_onnx():
    """Run the ONNX graph on `waveform` and cluster its outputs into notes.

    Sets the global `inference_time` to forward-pass time + post-processing time in
    seconds (time.perf_counter: monotonic, high resolution) and returns cluster_notes' result.
    """
    global inference_time
    t0 = time.perf_counter()
    outputs = ort_session.run(None, {'audio': waveform})
    t1 = time.perf_counter()
    print(f"inference time inside function: {t1 - t0:.3f} seconds")
    onset = outputs[0][0]   # (84, time)
    frame = outputs[1][0]   # (84, time)
    emb = outputs[2][0]     # (D, 84, time); D noted as 12 here, 18 in the PyTorch cell — confirm
    t2 = time.perf_counter()
    note_events = cluster_notes(
        frame, onset, emb, N_clusters, 24,
        frame_thresh=0.32,
        onset_thresh=0.4,
        neighbor_trick=False
    )
    t3 = time.perf_counter()
    print(f"post-processing time inside function: {t3 - t2:.3f} seconds")
    inference_time = (t1 - t0) + (t3 - t2)
    return note_events

# memory_usage samples process memory (MiB) every 0.1 s while run_onnx runs.
mem_usage, outputs = memory_usage((run_onnx, ), retval=True, interval=0.1)

print(f"inference time: {inference_time:.3f} seconds for {wave_duration:.3f} seconds audio with sampleRate {sample_rate} Hz, real-time factor: {inference_time/wave_duration:.4f}")
# Report the memory peak the benchmark is meant to measure.
print(f"peak memory: {max(mem_usage):.1f} MiB")

os.makedirs("../test", exist_ok=True)  # output directory may not exist yet
notes2midi(outputs).save("../test/septimbre.mid")

### 专门导出encoder部分

In [None]:
# Export only the encoder wrapper; its single ONNX output is named "embedding".
# nn.Module.eval() returns the module itself, so this both builds and switches to eval mode.
encoder_44100 = Encoder_44100(model).eval()

# Dummy tracing input: (batch=1, channels=1, samples); only the sample axis is dynamic.
input_audio = torch.randn((1, 1, 22050), dtype=torch.float32)    # (fixed, fixed, dynamic)

with torch.no_grad():
    torch.onnx.export(
        encoder_44100,
        (input_audio,),
        'septimbre_encoder_44100.onnx',
        input_names = ["audio"],
        output_names = ["embedding"],
        # Keys must match the parameter names of encoder_44100.forward (assumed to be `x`).
        dynamic_shapes = {
            'x': {2:'time'},    # same as model.forward input
            # auto infer output shapes
        },
        dynamo=True,
        verbose=False,
        verify=True,            # check the exported graph with ONNX Runtime
        external_data=False,    # keep weights inside the single .onnx file
        # NOTE(review): appears to apply only to the TorchScript exporter; likely ignored with dynamo=True — confirm.
        autograd_inlining=False
    )