In [None]:
import torch
import os
import tqdm

import sys
sys.path.append("..")
from utils.checkpoint import save_ckpt_template as save_ckpt, load_ckpt_template as load_ckpt

# Pick the GPU when one is visible; otherwise fall back to the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if _use_cuda else 'cpu')
if _use_cuda:
    # Let cuDNN autotune convolution algorithms to speed up convolutions.
    torch.backends.cudnn.benchmark = True

print(device)

In [None]:
from data.septimbre.data import Instruments
def _load_split(folder, size_msg, shuffle, batch_size=18):
    """Build an ``Instruments`` dataset from *folder* plus a DataLoader over it.

    Prints the dataset size (prefixed by *size_msg*) and the shapes of the first
    sample so the data layout can be checked at a glance.

    Args:
        folder: directory holding the pre-computed ``.cqt.npy`` / ``.npy`` files.
        size_msg: label printed before the dataset length.
        shuffle: whether the loader shuffles samples each epoch.
        batch_size: samples per batch. With cudnn.benchmark enabled it is best to
            keep every batch the same size, ideally a common divisor of both
            dataset sizes.

    Returns:
        (dataset, loader)
    """
    split = Instruments(
        folder = folder,
        mix = 2,
        input = '.cqt.npy',
        output = '.npy'
    )
    print(size_msg, len(split))
    # Local names: the original cell assigned the built-in name ``input`` at notebook scope.
    sample_input, sample_label = split[0]
    print("输入大小: ", sample_input.shape)
    print("输出大小: ", sample_label.shape)
    loader = torch.utils.data.DataLoader(split, batch_size=batch_size, shuffle=shuffle, pin_memory=True, num_workers=4)
    return split, loader


# Training split (shuffled) and validation split (fixed order).
dataset, dataloader = _load_split('../data/septimbre/small_256', "训练集大小: ", shuffle=True)
val, valloader = _load_split('../data/septimbre/tiny_256', "测试集大小: ", shuffle=False)

In [None]:
from resepnet import Cluster
model = Cluster().to(device)

# AdamW over all model parameters. The LR is multiplied by 0.4 once the monitored
# loss has not improved (relative threshold 1e-3) for more than 3 consecutive steps.
# (The name `schedular` is referenced as-is by the training cell.)
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)
schedular = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.4,
    patience=3,
    threshold=1e-3,
)

# Rolling checkpoint file and the per-epoch loss log.
checkpoint_path, loss_path = "cluster.pth", "cluster.loss.txt"

# Training bookkeeping; overwritten when a checkpoint is resumed.
min_loss, epoch_now = float('inf'), 0

In [None]:
# Load the model/optimizer parameters saved by the previous run, if any.
states = load_ckpt(model, optimizer, checkpoint_path)
if not states:
    print("No checkpoint loaded. Training from scratch.")
else:
    # avg_loss is restored for reference only; training resumes at epoch_now + 1.
    min_loss, avg_loss, epoch_now = states

In [None]:
from model.loss import LossNorm, DWA, PCGrad
# Static weight applied to the clustering loss relative to the AMT loss.
r_cluster = 0.01

# Multi-task helpers used by MTL_bp (semantics live in model/loss.py):
lossnorm = LossNorm(0.9)    # constructed with 0.9 — presumably a running-average factor; confirm
dwa = DWA(2)                # presumably Dynamic Weight Average over the 2 task losses; confirm
#   dwa.set_init([r_cluster, 1])   # optional: seed DWA with the static weights (currently disabled)
pcgrad = PCGrad(optimizer)  # wraps the optimizer; only used when MTL_bp(use_pcgrad=True)
def MTL_bp(cluster_loss, amt_loss, use_lossnorm = True, use_dwa = True, use_pcgrad = True):
    """Backpropagate the two task losses and take one optimizer step.

    The losses are stacked as [cluster_loss, amt_loss]. When enabled, the output of
    the module-level `lossnorm` and then `dwa` helpers is multiplied elementwise into
    them. Gradients then come either from `pcgrad.pc_backward` (PCGrad) or from the
    plain sum of the weighted losses. The matching optimizer (`pcgrad` or
    `optimizer`) then steps and clears its gradients.

    Args:
        cluster_loss: scalar loss tensor for the clustering task.
        amt_loss: scalar loss tensor for the (presumably transcription) task.
        use_lossnorm: apply the `lossnorm` re-weighting.
        use_dwa: apply the `dwa` re-weighting.
        use_pcgrad: use PCGrad instead of a summed backward pass.
    """
    weighted = torch.stack([cluster_loss, amt_loss])
    if use_lossnorm:
        weighted = lossnorm(weighted) * weighted
    if use_dwa:
        weighted = dwa(weighted) * weighted

    if use_pcgrad:
        # PCGrad builds the combined gradient from the per-task losses itself.
        pcgrad.pc_backward(weighted)
        stepper = pcgrad
    else:
        weighted.sum().backward()
        stepper = optimizer
    stepper.step()
    stepper.zero_grad()

In [None]:
epoch_total = 10            # NOTE: range() below is end-exclusive, so the last epoch run is epoch_total - 1
stage_save_interval = 75    # additionally keep an "epoch{N}.pth" snapshot every stage_save_interval epochs


def _mix_sources(x, mix):
    """Collapse the per-source axis (dim 1) of a batch into a single mixture.

    Training and validation both call this, so the model sees identically scaled
    inputs in both phases (the mean over sources).
    """
    if mix > 1:
        return torch.mean(x, dim=1, keepdim=False)
    return x.squeeze(dim=1)


def _l2_normalize(emb, eps=1e-8):
    """Scale embeddings to unit L2 norm along dim 1; eps avoids division by zero."""
    return emb / torch.sqrt(emb.pow(2).sum(dim=1, keepdim=True) + eps)


optimizer.zero_grad()
for epoch in range(epoch_now+1, epoch_total):
    # ---------------- training ----------------
    model.train()
    train_loss = [0, 0]   # accumulated [clustering loss, AMT loss]
    for (input, target) in tqdm.tqdm(dataloader):
        input = input.to(device)    # sources stacked along dim 1 — remaining dims TODO confirm
        target = target.to(device)  # per-source targets; earlier note said (batch, mix, 7 * 12, 660) — confirm

        mixed = _mix_sources(input, dataset.mix)   # mixture of the sources' CQTs
        # Augment with small Gaussian noise. An earlier note judged 0.01 suitable
        # (by ear and visually), but 0.003 is what is actually used — TODO reconcile.
        mixed = mixed + torch.randn_like(mixed) * 0.003

        emb, mask, onset = model(mixed)
        emb = _l2_normalize(emb)
        clloss, amtloss = Cluster.loss(emb, mask, onset, target)
        train_loss[0] += clloss.item()
        train_loss[1] += amtloss.item()
        # Down-weight the clustering loss by r_cluster before the multi-task backward/step.
        MTL_bp(clloss * r_cluster, amtloss, use_pcgrad=False)

    # NOTE(review): per-batch losses are summed but divided by the number of samples;
    # if Cluster.loss returns a batch mean, these are (mean loss / batch_size).
    train_loss[0] /= len(dataset)
    train_loss[1] /= len(dataset)

    # ---------------- validation ----------------
    model.eval()
    val_loss = [0, 0]
    with torch.no_grad():
        for (input, target) in tqdm.tqdm(valloader):
            input = input.to(device)
            target = target.to(device)

            # Same mixing as training (this used torch.sum before, i.e. inputs `mix` times larger).
            mixed = _mix_sources(input, val.mix)

            emb, mask, onset = model(mixed)
            emb = _l2_normalize(emb)
            clloss, amtloss = Cluster.loss(emb, mask, onset, target)
            val_loss[0] += clloss.item()
            val_loss[1] += amtloss.item()

        val_loss[0] /= len(val)
        val_loss[1] /= len(val)

    # Combine both task losses into one scalar, scaling the clustering loss by r_cluster as in training.
    train_loss_sum = LossNorm.norm_sum(torch.tensor([train_loss[0] * r_cluster, train_loss[1]])).item()
    val_loss_sum = LossNorm.norm_sum(torch.tensor([val_loss[0] * r_cluster, val_loss[1]])).item()

    # LR schedule and model selection track a blend weighted towards validation.
    final_loss = val_loss_sum * 0.85 + train_loss_sum * 0.15
    schedular.step(final_loss)

    # Save checkpoints. The rolling checkpoint is always refreshed so resuming never falls back
    # to an older epoch; a stage snapshot is written in addition every stage_save_interval epochs.
    # min_loss is passed *before* it is updated — presumably save_ckpt compares it with
    # final_loss (e.g. to write the "best_" copy loaded later); verify in utils/checkpoint.py.
    # NOTE: the LR scheduler state is not passed to save_ckpt, so it restarts on resume.
    save_ckpt(epoch, model, min_loss, final_loss, optimizer, checkpoint_path)
    if epoch % stage_save_interval == 0:
        save_ckpt(epoch, model, min_loss, final_loss, optimizer, f"epoch{epoch}.pth")
    if final_loss < min_loss:
        min_loss = final_loss

    print(f"====> Epoch: {epoch} Average train loss: {train_loss_sum:.4f} = {train_loss[0]:.4f} & {train_loss[1]:.4f}; Average val loss: {val_loss_sum:.4f} = {val_loss[0]:.4f} & {val_loss[1]:.4f}")
    with open(loss_path, 'a') as f:
        f.write(f"{epoch}:\ttrain_loss: {train_loss_sum:.4f}\tval_loss: {val_loss_sum:.4f}\t{train_loss[0]:.4f}\t{train_loss[1]:.4f}\t{val_loss[0]:.4f}\t{val_loss[1]:.4f}\n")
    epoch_now = epoch

In [None]:
# Plot the training and validation loss curves from the log written by the training loop.
import matplotlib.pyplot as plt


def _parse_loss_line(line):
    """Parse one loss-log line into (epoch, train_loss_sum, val_loss_sum).

    Lines are tab-separated: "<epoch>:", "train_loss: <x>", "val_loss: <y>", followed
    by the four component losses. Returns None for blank lines.
    """
    parts = line.strip().split('\t')
    if not parts[0]:
        return None
    return (int(parts[0].split(':')[0]),
            float(parts[1].split(': ')[1]),
            float(parts[2].split(': ')[1]))


# Distinct names so this cell does not overwrite the training cell's epoch/train_loss/val_loss.
epochs = []
train_losses = []
val_losses = []

with open(loss_path, 'r') as f:
    for line in f:
        record = _parse_loss_line(line)
        if record is None:
            continue
        logged_epoch, logged_train, logged_val = record
        epochs.append(logged_epoch)
        train_losses.append(logged_train)
        val_losses.append(logged_val)

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epochs, train_losses, label='Train Loss')
ax.plot(epochs, val_losses, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss')
ax.legend()
ax.grid(True)
plt.show()

# 测试模型

In [None]:
# Load the best model for evaluation. This notebook never writes "best_" + checkpoint_path
# directly — presumably save_ckpt does when a new minimum loss is reached; confirm.
states = load_ckpt(model, optimizer, "best_" + checkpoint_path)
if not states:
    print("No best checkpoint loaded.")
else:
    min_loss, avg_loss, epoch_now = states

In [None]:
# Run inference on the CPU: with cudnn.benchmark on, the first GPU run takes a long
# time (autotuning), which is not worth it for this one-off inference.
torch.backends.cudnn.benchmark = False
device = torch.device('cpu')
model = model.to(device)

In [None]:
# CQT configuration: read the [CQT] section of the shared model config.
import tomllib
with open('../model/config.toml', 'br') as f:
    CQTconfig = tomllib.load(f)['CQT']
s_per_frame = CQTconfig['hop'] / CQTconfig['fs']   # seconds per CQT frame (hop / sample rate)

from model.CQT import CQTsmall_fir

# Forward the relevant config keys to the CQT front-end. The leading positional False
# is passed through unchanged; its meaning is defined in model/CQT.py.
_cqt_kwargs = {key: CQTconfig[key] for key in ('fs', 'fmin', 'octaves', 'bins_per_octave', 'hop', 'filter_scale')}
cqt = CQTsmall_fir(False, requires_grad = True, **_cqt_kwargs).to(device)

In [None]:
# 读取音频，分析为CQT
import torchaudio
from utils.midiarray import numpy2midi
from utils.wavtool import waveInfo
import matplotlib.pyplot as plt
import numpy as np

test_wave_path = "../data/inferMusic/short mix.wav"
waveInfo(test_wave_path)

# torchaudio.load returns (channels, time) plus the file's sample rate.
waveform, sample_rate = torchaudio.load(test_wave_path, normalize=True)
# The CQT filters are built for CQTconfig['fs']; resample when the file uses a different
# rate, otherwise every CQT bin would correspond to the wrong frequency.
target_fs = int(CQTconfig['fs'])
if sample_rate != target_fs:
    waveform = torchaudio.functional.resample(waveform, sample_rate, target_fs)
    sample_rate = target_fs
waveform = waveform.unsqueeze(0)   # add a batch axis: (1, channels, time)
print(waveform.shape)
# Inference only: avoid building an autograd graph through the requires_grad CQT filters.
with torch.no_grad():
    test_cqt_data = cqt(waveform.to(device)).to(device)
print(test_cqt_data.shape)

In [None]:
from sklearn.cluster import SpectralClustering
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import torch
import matplotlib.pyplot as plt

# Assign active time-frequency bins to sources by clustering their embeddings.
# Assumes `model` and `test_cqt_data` were defined by the cells above.
n_clusters = 2

model.eval()
with torch.no_grad():
    emb, mask, onset = model(test_cqt_data)
    emb = emb / torch.sqrt(emb.pow(2).sum(dim=1, keepdim=True) + 1e-8)   # unit L2 norm along dim 1
    emb = emb.cpu().numpy()[0]      # first batch item; earlier note: (16, 84, frame) — confirm
    mask = mask.cpu().numpy()[0]    # earlier note: (84, frame)
    onset = onset.cpu().numpy()[0]

# Bins whose mask value exceeds 0.5 are the active notes to assign (n of them).
positions = np.where(mask > 0.5)
emb_extracted = emb[:, positions[0], positions[1]].T        # (n, embedding_dim)

if emb_extracted.shape[0] >= n_clusters:
    # Pairwise cosine similarity, n x n — memory grows quadratically with the number of active bins.
    similarity_matrix = cosine_similarity(emb_extracted)
    # exp() maps cosine similarity in [-1, 1] to a strictly positive affinity in [1/e, e].
    # random_state makes the eigen-decomposition initialisation (and thus labels) reproducible.
    spectral = SpectralClustering(n_clusters=n_clusters, affinity='precomputed',
                                  assign_labels="cluster_qr", random_state=0)
    labels = spectral.fit_predict(np.exp(similarity_matrix))
else:
    # Too few active bins for spectral clustering: put them all in the first class.
    print(f"Only {emb_extracted.shape[0]} active bins; skipping spectral clustering.")
    labels = np.zeros(emb_extracted.shape[0], dtype=int)

# Scatter the cluster labels back onto the (pitch, frame) grid, one binary map per class.
class1 = np.zeros(mask.shape)
class2 = np.zeros(mask.shape)
class1[positions[0], positions[1]] = (labels == 0).astype(int)
class2[positions[0], positions[1]] = (labels == 1).astype(int)

fig, axes = plt.subplots(3, 1, figsize=(12, 15))
panels = [('note', mask + onset), ('class1', class1), ('class2', class2)]
for ax, (title, image) in zip(axes, panels):
    ax.set_title(title)
    im = ax.imshow(image, aspect='auto', origin='lower', cmap='gray')
    if title == 'note':
        fig.colorbar(im, ax=ax)

fig.tight_layout()
plt.show()