In [None]:
import torch
import os
import tqdm

import sys
sys.path.append("..")
from utils.checkpoint import save_ckpt_template as save_ckpt, load_ckpt_template as load_ckpt

# Pick the compute device. The second GPU (cuda:1) is preferred, as before, but
# hosts exposing a single GPU fall back to cuda:0 instead of failing later with
# an invalid device ordinal on the first `.to(device)`.
if torch.cuda.is_available():
    preferred_gpu = 1
    gpu_index = preferred_gpu if torch.cuda.device_count() > preferred_gpu else 0
    device = torch.device(f'cuda:{gpu_index}')
    torch.backends.cudnn.benchmark = True   # let cuDNN auto-tune convolution algorithms
else:
    device = torch.device('cpu')

print(device)

# CQT configuration
from model.config import CONFIG
# Cache the per-frame duration from the shared config.
# NOTE(review): presumably seconds per CQT frame, judging by the name — confirm in model.config.
s_per_frame = CONFIG.s_per_frame

In [None]:
from data.septimbre.data import Instruments
# Main training set: items read from the small_256 folder with mix=2
# (the training loop treats dim 1 of each input as `mix` separate sources).
dataset = Instruments(folder='../data/septimbre/small_256', mix=2, input='.wav', output='.npy')
print("训练集大小: ", len(dataset))
# Fetch the first item only to report the tensor shapes.
input, label = dataset[0]
print("输入大小: ", input.shape)
print("输出大小: ", label.shape)
# With cudnn benchmark enabled, keep the batch size constant, ideally a common
# divisor of the dataset sizes so that no odd-sized final batch appears.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=18,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

# Validation set: items read from the tiny_256 folder with mix=2.
val = Instruments(folder='../data/septimbre/tiny_256', mix=2, input='.wav', output='.npy')
print("测试集大小: ", len(val))
# Fetch the first item only to report the tensor shapes.
input, label = val[0]
print("输入大小: ", input.shape)
print("输出大小: ", label.shape)
# Validation order is fixed (no shuffling) so runs are comparable.
valloader = torch.utils.data.DataLoader(
    val,
    batch_size=18,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

# Extra training set with mix=3. It is consumed by the training loop, so its
# size is reported as a training set (the label previously said "test set").
# NOTE(review): this reads the same folder as the validation set `val`
# (tiny_256); confirm that training on mixtures of the validation sources is
# intended, since it can make the validation loss optimistic.
train_3 = Instruments(
    folder = '../data/septimbre/tiny_256',
    mix = 3,
    input = '.wav',
    output = '.npy'
)
print("训练集(mix=3)大小: ", len(train_3))
# Fetch the first item only to report the tensor shapes.
input, label = train_3[0]
print("输入大小: ", input.shape)
print("输出大小: ", label.shape)
train3loader = torch.utils.data.DataLoader(train_3, batch_size=18, shuffle=True, pin_memory=True, num_workers=4)

In [None]:
from septimbre import SepTimbreAMT
# Build the network and give each half its own optimizer: the note
# (transcription) branch on one side, every remaining parameter on the other.
model = SepTimbreAMT(CONFIG.CQT).to(device)
optimizer_amt = torch.optim.AdamW(model.note_branch.parameters(), lr=3e-4)
cluster_params = [
    entry[1]
    for entry in model.named_parameters()
    if not entry[0].startswith('note_branch')
]
optimizer_cluster = torch.optim.AdamW(cluster_params, lr=3e-4)

# Note branch: shrink the LR when the monitored loss plateaus.
schedular_amt = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_amt,
    mode='min',
    factor=0.3,
    patience=3,
    threshold=1e-3,
)
# Cluster side: cosine annealing with warm restarts (plateau-based alternative kept for reference).
# schedular_cluster = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_cluster, mode='min', factor=0.3, patience=4, threshold=1e-3)
schedular_cluster = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer_cluster,
    T_0=4,
    T_mult=2,
    eta_min=1e-6,
)

# Bookkeeping for checkpoints, the loss log and resuming.
checkpoint_path = "septimbre.pth"
loss_path = "septimbre.loss.txt"
min_loss = float('inf')   # best validation loss seen so far
epoch_now = 0             # last finished epoch (0 = nothing trained yet)
train_amt = True          # whether the note branch is optimised as well

In [None]:
# # Optional: initialise the note branch with parameters borrowed from a trained BasicAMT model
# sys.path.append("../basicamt")
# basicamt_model = torch.load("../basicamt/best_basicamt_model.pth", map_location=device, weights_only=False)
# model.note_branch.fromBasicAMT(basicamt_model)

In [None]:
# Initialise the note branch from separately trained weights, then freeze it
# so that only the remaining parameters are optimised.
note_branch_state_dict = torch.load(
    "./sepamt_note_branch.pth",
    map_location=device,
    weights_only=True,
)
model.note_branch.load_state_dict(note_branch_state_dict)
train_amt = False
# Turn off gradient tracking for every parameter of the note branch.
model.note_branch.requires_grad_(False)

def set_eval_mode(module):
    """Call ``eval()`` on `module` and on every module nested beneath it.

    The tree is walked depth-first via ``children()``, parents before their
    children and siblings in their declared order.
    """
    pending = [module]
    while pending:
        current = pending.pop()
        current.eval()
        # Push children reversed so the first child is popped (visited) first.
        pending.extend(reversed(list(current.children())))
# Put the frozen note branch (and all of its submodules) into evaluation mode.
set_eval_mode(model.note_branch)

In [None]:
# Resume from the previous run's checkpoint, if one can be loaded.
states = load_ckpt(model, [optimizer_amt, optimizer_cluster], checkpoint_path)
if not states:
    print("No checkpoint loaded. Training from scratch.")
else:
    # Restore the best loss so far, the last average loss and the epoch counter.
    min_loss, avg_loss, epoch_now = states

In [None]:
# Create two compiled handles over the SAME underlying module: one used for
# training steps, one for validation. Because the module is shared, so are its
# train/eval flags.
model.train()
if not train_amt:
    # The frozen note branch stays in evaluation mode.
    set_eval_mode(model.note_branch)
model_train = torch.compile(model, mode="max-autotune")
# NOTE: `model.eval()` executes eagerly right here, so once this cell finishes
# the shared module is in eval mode for BOTH handles. The mode has to be set
# explicitly before each training or validation phase.
model_eval = torch.compile(model.eval(), mode="max-autotune")

In [None]:
from model.loss import LossNorm, DWA, PCGrad
# Multi-task loss-balancing helpers used by MTL_bp.
# NOTE(review): 0.91 looks like a smoothing/decay factor for the loss normaliser — confirm in model.loss.
lossnorm = LossNorm(0.91)
# NOTE(review): presumably dynamic weight averaging with parameter 2 — confirm in model.loss.
dwa = DWA(2)
# Wraps the cluster optimizer; only used when MTL_bp is called with use_pcgrad=True.
pcgrad = PCGrad(optimizer_cluster)
def MTL_bp(cluster_loss, amt_loss, use_lossnorm = True, use_dwa = True, use_pcgrad = False):
    """Back-propagate the two task losses and advance the optimizers.

    Args:
        cluster_loss: scalar tensor, loss of the clustering/embedding task.
        amt_loss: scalar tensor, loss of the transcription task.
        use_lossnorm: multiply the stacked losses by the factors from `lossnorm`.
        use_dwa: multiply the stacked losses by the factors from `dwa`.
        use_pcgrad: back-propagate and step through the `pcgrad` wrapper
            instead of summing the losses and stepping `optimizer_cluster`.

    The note-branch optimizer is stepped and cleared only while the global
    `train_amt` flag is set.
    """
    task_losses = torch.stack([cluster_loss, amt_loss])
    if use_lossnorm:
        task_losses = lossnorm(task_losses) * task_losses
    if use_dwa:
        task_losses = dwa(task_losses) * task_losses

    if not use_pcgrad:
        # Plain joint objective: sum of the (possibly re-weighted) task losses.
        task_losses.sum().backward()
        stepper = optimizer_cluster
    else:
        # The wrapper performs its own per-task backward pass.
        pcgrad.pc_backward(task_losses)
        stepper = pcgrad
    stepper.step()
    stepper.zero_grad()

    if not train_amt:
        return
    optimizer_amt.step()
    optimizer_amt.zero_grad()

In [None]:
epoch_total = 62
stage_save_interval = 40    # save checkpoint every $stage_save_interval$ epochs


def _mix_sources(sources, n_mix):
    """Collapse the per-source waveforms of a batch into one noisy mixture.

    Args:
        sources: 4-D batch whose dim 1 indexes the individual sources
            (annotated upstream as (batch, mix, 2, time) — TODO confirm).
        n_mix: number of sources per item (the dataset's `mix` setting).

    Returns:
        The mixture with dim 1 removed, plus a small random noise floor.
    """
    if n_mix > 1:
        # Random per-source gains in [0.8, 1.2), shared across the batch;
        # dividing by their sum keeps the overall level stable.
        w = torch.rand(n_mix, device=sources.device) * (0.4) + 0.8
        mixed = torch.sum(sources * w.view(1, -1, 1, 1), dim=1, keepdim=False) / torch.sum(w)
    else:
        mixed = sources.squeeze(dim=1)
    # Gaussian noise whose amplitude is drawn per item from [0, 0.0005).
    return mixed + torch.randn_like(mixed) * 0.0005 * torch.rand(mixed.shape[0], 1, 1, device=mixed.device)


def _train_pass(loader, n_mix, loss_acc):
    """Run one optimisation pass over `loader`.

    Adds each batch's cluster loss to loss_acc[0] and its transcription loss
    to loss_acc[1] (both accumulated in place).
    """
    for (sources, target) in tqdm.tqdm(loader):
        sources = sources.to(device)
        target = target.to(device)  # annotated upstream as (batch, mix, 7 * 12, 660) — TODO confirm

        mixed = _mix_sources(sources, n_mix)
        onset, frame, emb = model_train(mixed)
        emb = torch.nn.functional.normalize(emb, p=2, dim=1)
        clloss, amtloss = SepTimbreAMT.loss(emb, frame, onset, target)
        loss_acc[0] += clloss.item()
        loss_acc[1] += amtloss.item()
        MTL_bp(clloss, amtloss, use_lossnorm = False, use_dwa = False, use_pcgrad=False)
        # Project-specific constraint re-applied after every optimizer step.
        model.clampK()


optimizer_amt.zero_grad()
optimizer_cluster.zero_grad()
# Position the warm-restart schedule at the last finished epoch, so a resumed
# run continues with the learning rate it would have had.
schedular_cluster.step(epoch=epoch_now)
for epoch in range(epoch_now+1, epoch_total):
    # training
    # model_train and model_eval wrap the same module, and compiling the eval
    # handle leaves that module in eval mode. Switch modes explicitly, otherwise
    # every training step would run in eval mode.
    model.train()
    if not train_amt:
        # The frozen note branch must stay in evaluation mode.
        model.note_branch.eval()
    train_loss = [0., 0.]
    _train_pass(train3loader, train_3.mix, train_loss)
    _train_pass(dataloader, dataset.mix, train_loss)

    # NOTE(review): per-batch losses are divided by the number of samples (not
    # batches); validation does the same, so the two stay comparable — confirm
    # this matches the reduction used inside SepTimbreAMT.loss.
    train_loss[0] /= len(dataset) + len(train_3)
    train_loss[1] /= len(dataset) + len(train_3)

    # validation
    model.eval()
    val_loss = [0., 0.]
    with torch.no_grad():
        for (sources, target) in tqdm.tqdm(valloader):
            sources = sources.to(device)
            target = target.to(device)

            if val.mix > 1:
                # NOTE(review): plain sum here, whereas training divides by the
                # sum of the gains — confirm the level mismatch is intended.
                mixed = torch.sum(sources, dim=1, keepdim=False)
            else:
                mixed = sources.squeeze(dim=1)

            onset, frame, emb = model_eval(mixed)
            emb = torch.nn.functional.normalize(emb, p=2, dim=1)
            clloss, amtloss = SepTimbreAMT.loss(emb, frame, onset, target)
            val_loss[0] += clloss.item()
            val_loss[1] += amtloss.item()

        val_loss[0] /= len(val)
        val_loss[1] /= len(val)

    if train_amt:
        train_loss_sum = LossNorm.norm_sum(torch.tensor([train_loss[0], train_loss[1]])).item()
        val_loss_sum = LossNorm.norm_sum(torch.tensor([val_loss[0], val_loss[1]])).item()
    else:
        # With the note branch frozen only the cluster loss is being optimised.
        train_loss_sum = train_loss[0]
        val_loss_sum = val_loss[0]

    final_loss = val_loss[0]
    if train_amt:
        final_loss = val_loss_sum
        schedular_amt.step(val_loss[1])
    # schedular_cluster.step(val_loss[0])
    # Advance the schedule to the epoch that just finished. `epoch_now` is only
    # updated at the bottom of the loop, so stepping with it repeated the
    # previous position and made the schedule lag behind a resumed run.
    schedular_cluster.step(epoch=epoch)

    # save checkpoint (before min_loss is updated, so save_ckpt sees the previous best)
    checkpoint_filename = f"epoch{epoch}.pth" if epoch % stage_save_interval == 0 else checkpoint_path
    save_ckpt(epoch, model, min_loss, final_loss, [optimizer_amt, optimizer_cluster], checkpoint_filename)
    if final_loss < min_loss:
        min_loss = final_loss

    print(f"====> Epoch: {epoch} Average train loss: {train_loss_sum:.4f} = {train_loss[0]:.4f} & {train_loss[1]:.4f}; Average val loss: {val_loss_sum:.4f} = {val_loss[0]:.4f} & {val_loss[1]:.4f}")
    with open(loss_path, 'a') as f:
        f.write(f"{epoch}:\ttrain_loss: {train_loss_sum:.4f}\tval_loss: {val_loss_sum:.4f}\t{train_loss[0]:.4f}\t{train_loss[1]:.4f}\t{val_loss[0]:.4f}\t{val_loss[1]:.4f}\n")
    epoch_now = epoch

In [None]:
# Plot the training and validation loss curves recorded in the loss log
import matplotlib.pyplot as plt

epochs, train_losses, val_losses = [], [], []

with open(loss_path, 'r') as f:
    for line in f:
        # Tab-separated record: "<epoch>:", "train_loss: <x>", "val_loss: <y>", ...
        parts = line.strip().split('\t')
        epoch = int(parts[0].split(':')[0])
        train_loss = float(parts[1].split(': ')[1])
        val_loss = float(parts[2].split(': ')[1])

        # Append only after all three fields parsed successfully.
        epochs += [epoch]
        train_losses += [train_loss]
        val_losses += [val_loss]

# Draw both curves on one set of axes
plt.figure(figsize=(10, 5))
for series, legend_name in ((train_losses, 'Train Loss'), (val_losses, 'Validation Loss')):
    plt.plot(epochs, series, label=legend_name)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.grid(True)
plt.show()

# Test the model

In [None]:
# Load the best checkpoint (the regular checkpoint name prefixed with "best_")
best_checkpoint_file = "best_" + checkpoint_path
states = load_ckpt(model, [optimizer_amt, optimizer_cluster], best_checkpoint_file)
# states = load_ckpt(model, [optimizer_amt, optimizer_cluster], "epoch32.pth")
if not states:
    print("No best checkpoint loaded.")
else:
    min_loss, avg_loss, epoch_now = states

In [None]:
# Run inference on the CPU: cudnn benchmark makes the first GPU run slow,
# so it is switched off and the model is moved over.
torch.backends.cudnn.benchmark = False
device = torch.device('cpu')
model = model.to(device)

In [None]:
# Load the audio
import torchaudio
from utils.midiarray import numpy2midi
from utils.wavtool import waveInfo
import matplotlib.pyplot as plt
import numpy as np

# Audio clip used for the qualitative test (alternative clip kept for reference).
test_wave_path = "../data/inferMusic/short mix.wav"
# test_wave_path = "../data/inferMusic/孤独な巡礼simple.wav"
waveInfo(test_wave_path)

waveform, sample_rate = torchaudio.load(test_wave_path, normalize=True)
# Prepend a batch dimension of size 1.
waveform = waveform[None]
print(waveform.shape)

In [None]:
from sklearn.cluster import SpectralClustering
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import torch
import matplotlib.pyplot as plt

# Thresholds and sizes used below, named so they can be tuned in one place.
frame_threshold = 0.4   # frame activation above which a bin counts as active
n_clusters = 2          # number of groups to separate (two panels are plotted)

model.eval()
with torch.no_grad():
    onset, frame, emb = model(waveform)
    # L2-normalise the embedding along dim 1 (epsilon avoids division by zero).
    emb = emb / torch.sqrt(emb.pow(2).sum(dim=1, keepdim=True) + 1e-8)
    # Drop the batch dimension; shapes per the original notes:
    emb = emb.cpu().numpy()[0]      # (16, 84, frame)
    frame = frame.cpu().numpy()[0]    # (84, frame)
    onset = onset.cpu().numpy()[0]
    print(emb.shape, frame.shape, onset.shape)

# Let n be the number of bins whose activation exceeds the threshold.
positions = np.where(frame > frame_threshold)
emb_extracted = emb[:, positions[0], positions[1]].T        # (n, 16)

n_active = emb_extracted.shape[0]
if n_active < n_clusters:
    # Too few active bins to cluster: cosine_similarity rejects an empty array
    # and spectral clustering needs at least as many samples as clusters.
    # Assign everything to the first class instead of crashing.
    print(f"only {n_active} active bins above {frame_threshold}; skipping clustering")
    labels = np.zeros(n_active, dtype=int)
else:
    # Cosine similarity matrix between all active bins.
    similarity_matrix = cosine_similarity(emb_extracted)

    # Spectral clustering on exp(similarity), which keeps the affinity positive.
    print("clustering...")
    spectral = SpectralClustering(n_clusters=n_clusters, affinity='precomputed', assign_labels="cluster_qr")
    labels = spectral.fit_predict(np.exp(similarity_matrix))

# Scatter the cluster labels back onto piano-roll-shaped masks.
class1 = np.zeros(frame.shape)
class2 = np.zeros(frame.shape)
class1[positions[0], positions[1]] = (labels == 0).astype(int)
class2[positions[0], positions[1]] = (labels == 1).astype(int)

plt.figure(figsize=(12, 15))

plt.subplot(3, 1, 1)
plt.title('note')
plt.imshow(frame, aspect='auto', origin='lower', cmap='gray')
plt.colorbar()

for row, (panel_title, mask) in enumerate((('class1', class1), ('class2', class2)), start=2):
    plt.subplot(3, 1, row)
    plt.title(panel_title)
    plt.imshow(mask, aspect='auto', origin='lower', cmap='gray')

plt.tight_layout()
plt.show()

In [None]:
# Serialise the whole model object (not just its state_dict) to disk.
# NOTE(review): loading this back needs torch.load(..., weights_only=False) and
# the project's model classes importable — confirm this is the intended format.
torch.save(model, "sepamt_model.pth")