In [None]:
import torch
import sys
# Make the project's model package importable. Presumably required because the
# checkpoint below is a pickled model object whose classes live in ../model --
# TODO confirm.
sys.path.append("../model")

# SECURITY: torch.load unpickles the file, and unpickling can execute arbitrary
# code. Only load checkpoints that come from a trusted source.
# weights_only=False is passed explicitly: the loaded object is used as a full
# model later in the notebook (model.eval(), model(waveform)), and
# PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle
# arbitrary module objects.
model = torch.load(
    "cluster_model_1.pth",
    map_location=torch.device('cpu'),
    weights_only=False,
)

# Duration of one analysis frame in seconds.
# Presumably a hop of 256 samples at a 22050 Hz sample rate -- TODO confirm.
s_per_frame = 256 / 22050

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from utils.midiarray import midi2numpy

def draw_midi_with_channel(ax, midiarr, colors=['Reds', 'Blues'], labels=['Piano', 'Violin']):
    """
    Draw a multi-channel MIDI array on ``ax``, using one colormap per channel.

    Parameters
    ----------
    ax :
        Matplotlib axes to draw on (typically one subplot of a figure).
    midiarr :
        Either a 3D array whose first axis is the channel, or a path (``str``)
        to a MIDI file, which is loaded with
        ``midi2numpy(path, s_per_frame, track_separate=True)``.
        Every channel is drawn as an image: axis 1 becomes the rows (vertical,
        row 0 at the bottom) and axis 2 the columns.
        NOTE(review): the original docstring gave the shape as
        (n_channel, n_time, n_pitch), but other cells of this notebook slice
        axis 2 as the frame axis, which suggests (n_channel, n_pitch, n_time)
        -- confirm against midi2numpy.
        Cells equal to 1 are drawn semi-transparent (alpha 0.6), cells equal to
        2 fully opaque; every other value stays fully transparent.
    colors :
        Colormap names (resolved with ``plt.get_cmap``) or colormap objects.
        When there are more channels than colormaps, the colormaps are reused
        cyclically.
    labels :
        Legend labels, one per entry of ``colors``. The legend is only drawn
        when ``len(labels) == len(colors)``.

    Raises
    ------
    ValueError
        If ``midiarr`` is not 3-dimensional, or if ``colors`` is empty.
    """

    # Load from file when a path was given.
    if isinstance(midiarr, str):
        midiarr = midi2numpy(midiarr, s_per_frame, track_separate=True)

    # Validate the shape for BOTH input kinds (previously only file input was checked).
    midiarr = np.asarray(midiarr)
    if midiarr.ndim != 3:
        raise ValueError("midiarr must be a 3D numpy array")
    if len(colors) == 0:
        raise ValueError("colors must contain at least one colormap")

    cmaps = [plt.get_cmap(color) if isinstance(color, str) else color for color in colors]

    # White background underneath the transparent channel layers.
    background = np.ones_like(midiarr[0])
    ax.imshow(background, cmap='gray_r', aspect='auto', origin='lower')

    for i, ch in enumerate(midiarr):
        # Cycle through the colormaps so extra channels do not raise IndexError.
        cmap = cmaps[i % len(cmaps)]

        # Every active cell gets the colormap's top colour ...
        intensity = np.zeros_like(ch, dtype=float)
        intensity[ch > 0] = 1.0

        # ... while the cell value only selects the opacity.
        alpha = np.zeros_like(ch, dtype=float)
        alpha[ch == 1] = 0.6  # semi-transparent
        alpha[ch == 2] = 1.0  # opaque

        # Calling the colormap yields an RGBA array; overwrite its alpha channel.
        rgba = cmap(intensity)
        rgba[..., 3] = alpha
        ax.imshow(rgba, aspect='auto', origin='lower')

    # The legend needs exactly one label per colormap; otherwise skip it.
    if len(labels) != len(colors):
        return
    legend_elements = [
        Patch(facecolor=cmaps[i](0.8), alpha=0.6, label=labels[i])
        for i in range(len(labels))
    ]
    ax.legend(handles=legend_elements, loc='upper right')

# 基本使用

In [None]:
# 读取音频
import torchaudio
from utils.wavtool import waveInfo
from utils.midiarray import midi2numpy
import os

# Candidate input recordings; keep exactly one assignment active.
# test_wave_path = "../data/inferMusic/short mix.wav"
test_wave_path = "../data/inferMusic/three_mix.wav"
# test_wave_path = "../data/inferMusic/孤独な巡礼simple.wav"

# Look for a MIDI file with the same base name next to the audio file.
# splitext swaps only the real extension, whatever its case, instead of
# replacing every ".wav" substring in the path.
test_midi_path = os.path.splitext(test_wave_path)[0] + ".mid"
if os.path.exists(test_midi_path):
    test_midi = midi2numpy(test_midi_path, s_per_frame, track_separate=True)
else:
    # No reference MIDI available; later cells skip the MIDI plot.
    test_midi = None

info = waveInfo(test_wave_path)

waveform, sample_rate = torchaudio.load(test_wave_path, normalize=True)
waveform = waveform.unsqueeze(0)  # add a leading axis (presumably the batch axis the model expects)
if info["sample_rate"] > 44000:
    # High-rate input is passed through the model's own down-sampler first.
    waveform = model.cqt.down2sample(waveform)
    # Single quotes inside the f-string keep this valid before Python 3.12.
    print(f"$ downsampled to {info['sample_rate']//2}Hz")
print(waveform.shape)

## 用聚类
需要明确类别数目

In [None]:
from sklearn.cluster import SpectralClustering
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import matplotlib.pyplot as plt

N = 3                   # number of clusters to form; must be chosen by hand
MASK_THRESHOLD = 0.55   # bins whose mask value exceeds this are clustered

# Requires `model` and `waveform` from the earlier cells.
model.eval()
with torch.no_grad():
    emb, mask, onset = model(waveform)
    # L2-normalise along dim 1; the epsilon avoids division by zero.
    emb = emb / torch.sqrt(emb.pow(2).sum(dim=1, keepdim=True) + 1e-8)
    # Drop the leading axis. Shapes as noted by the original author:
    emb = emb.cpu().numpy()[0]      # (18, 84, frame)
    mask = mask.cpu().numpy()[0]    # (84, frame)
    onset = onset.cpu().numpy()[0]

if test_midi is not None:
    # Truncate the model outputs to the length of the reference MIDI.
    emb = emb[:, :, :test_midi.shape[2]]
    mask = mask[:, :test_midi.shape[2]]
    onset = onset[:, :test_midi.shape[2]]

# Select the n bins whose mask value exceeds the threshold.
positions = np.where(mask > MASK_THRESHOLD)
emb_extracted = emb[:, positions[0], positions[1]].T        # (n, 18)

# Fail early with a clear message instead of an opaque eigen-solver error.
if emb_extracted.shape[0] < N:
    raise ValueError(
        f"only {emb_extracted.shape[0]} bins exceed the mask threshold "
        f"{MASK_THRESHOLD}; at least {N} are needed to form {N} clusters"
    )

# Pairwise cosine similarity between the selected embeddings, shape (n, n).
similarity_matrix = cosine_similarity(emb_extracted)

# Spectral clustering on a precomputed affinity. exp() maps the cosine
# similarities from [-1, 1] to strictly positive affinities.
# random_state is fixed so that repeated runs give the same labels.
spectral = SpectralClustering(
    n_clusters=N,
    affinity='precomputed',
    assign_labels="cluster_qr",
    random_state=0,
)
labels = spectral.fit_predict(np.exp(similarity_matrix))

# Rows of the figure: [midi (optional)], note mask, onset, then one per cluster.
pre_figures = 2 + (0 if test_midi is None else 1)
sub_figures = N + pre_figures

plt.figure(figsize=(10, 5*sub_figures))

if test_midi is not None:
    plt.subplot(sub_figures, 1, 1)
    plt.title('midi')
    # labels=[] differs in length from colors, so no legend is drawn.
    draw_midi_with_channel(plt.gca(), test_midi, colors=['Reds', 'Blues', 'Greens'], labels=[])

plt.subplot(sub_figures, 1, pre_figures - 1)
plt.title('note')
plt.imshow(mask, aspect='auto', origin='lower', cmap='gray')
plt.colorbar()

plt.subplot(sub_figures, 1, pre_figures)
plt.title('onset')
plt.imshow(onset, aspect='auto', origin='lower', cmap='gray')
plt.colorbar()

# Scatter each cluster's membership back onto the mask grid and plot it.
clustered_classes = []
for i in range(N):
    class_i = np.zeros(mask.shape)
    class_i[positions[0], positions[1]] = (labels == i).astype(int)
    clustered_classes.append(class_i)
    plt.subplot(sub_figures, 1, i + pre_figures + 1)
    plt.title(f'class{i}')
    plt.imshow(class_i, aspect='auto', origin='lower', cmap='gray')

plt.tight_layout()
plt.show()

# 取代聚类

In [None]:
# Seed the first class with the bin where the mask is strongest.
max_position = np.unravel_index(np.argmax(mask, axis=None), mask.shape)
# Embedding vector at that bin.
max_emb = emb[(slice(None),) + tuple(max_position)]

# Dot product of the seed embedding with the mask-weighted embedding of every bin.
similarity_scores = np.tensordot(max_emb, emb*mask, axes=([0], [0])).reshape(emb.shape[1:])

fig_scores, (ax_scores, ax_binary) = plt.subplots(2, 1, figsize=(12, 10))

# Top: raw similarity scores.
ax_scores.set_title('class_max')
im_scores = ax_scores.imshow(similarity_scores, aspect='auto', origin='lower', cmap='gray')
fig_scores.colorbar(im_scores, ax=ax_scores)

# Bottom: the same scores binarised at 0.5.
im_binary = ax_binary.imshow(similarity_scores > 0.5, aspect='auto', origin='lower', cmap='gray')
fig_scores.colorbar(im_binary, ax=ax_binary)

plt.show()

In [None]:
# Residual mask after subtracting the first class's similarity scores.
mask2 = np.subtract(mask, similarity_scores)
# Alias for the scores. The original carried a disabled variant that would
# first zero out scores below 0.01.
s = similarity_scores
# Remove the seed embedding's contribution, scaled per bin by its score.
emb2 = emb - max_emb[:, np.newaxis, np.newaxis] * s

fig_residual, ax_residual = plt.subplots(figsize=(12, 5))
ax_residual.set_title('mask-class_max')
im_residual = ax_residual.imshow(mask2, aspect='auto', origin='lower', cmap='gray')
fig_residual.colorbar(im_residual, ax=ax_residual)

plt.show()

In [None]:
# Seed the second class with the bin where the residual mask is strongest.
max_position2 = np.unravel_index(np.argmax(mask2, axis=None), mask2.shape)
# Embedding vector at that bin.
# NOTE(review): this reads the original `emb` and `mask` rather than the
# residual `emb2` / `mask2` computed in the previous cell -- confirm this is intended.
max_emb2 = emb[(slice(None),) + tuple(max_position2)]

# Dot product of the second seed embedding with the mask-weighted embedding of every bin.
similarity_scores2 = np.tensordot(max_emb2, emb*mask, axes=([0], [0])).reshape(emb.shape[1:])

fig_scores2, (ax_scores2, ax_binary2) = plt.subplots(2, 1, figsize=(12, 10))

# Top: raw similarity scores.
ax_scores2.set_title('class_max')
im_scores2 = ax_scores2.imshow(similarity_scores2, aspect='auto', origin='lower', cmap='gray')
fig_scores2.colorbar(im_scores2, ax=ax_scores2)

# Bottom: the same scores binarised at 0.3.
im_binary2 = ax_binary2.imshow(similarity_scores2 > 0.3, aspect='auto', origin='lower', cmap='gray')
fig_scores2.colorbar(im_binary2, ax=ax_binary2)

plt.show()