In [None]:
import torch
import os
import tqdm

import sys
sys.path.append("../../../")
from utils.checkpoint import save_ckpt_template as save_ckpt, load_ckpt_template as load_ckpt

# Prefer the GPU whenever CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # Speeds up convolution computation.
    torch.backends.cudnn.benchmark = True

print(device)

# CQT configuration
from model.config import CONFIG
# Cache the per-frame duration from the shared config.
# NOTE(review): presumably seconds per CQT frame — confirm in model.config.
s_per_frame = CONFIG.s_per_frame

In [None]:
from data.musicnet.data import MusicNetData

def collate_fn(batch):
    """Collate (input, output) samples into a stacked input tensor and a list of outputs.

    Args:
        batch: sequence of 2-tuples ``(input_tensor, output)``; the input
            tensors must be stackable (same shape).

    Returns:
        ``(inputs, outputs)`` where ``inputs`` is the inputs stacked along a
        new leading batch dimension and ``outputs`` is a plain list holding
        the per-sample outputs unchanged (they are not stacked).
    """
    # Transpose the list of pairs into two parallel lists.
    stacked_inputs, kept_outputs = map(list, zip(*batch))
    return torch.stack(stacked_inputs), kept_outputs  # inputs: [batch, ...]

# Training split: '.wav' files are the inputs and '.npy' files the outputs.
dataset = MusicNetData(
    folder = '../../../data/musicnet/train',
    input = '.wav',
    output = '.npy'
)
print("训练集大小: ", len(dataset))
# Peek at the first sample to report tensor shapes. The names are chosen so the
# builtin input() is not shadowed at notebook scope.
sample_input, sample_label = dataset[0]
print("输入大小: ", sample_input.shape)
print("输出大小: ", sample_label.shape)
# With cudnn.benchmark enabled, the batch size should preferably stay constant,
# i.e. be a common divisor of the sizes of both datasets.
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=18, shuffle=True, pin_memory=True, num_workers=4,
    collate_fn=collate_fn
)

# Held-out split, read from the 'test' folder with the same file conventions.
val = MusicNetData(
    folder = '../../../data/musicnet/test',
    input = '.wav',
    output = '.npy'
)
print("测试集大小: ", len(val))
# Peek at the first sample to report tensor shapes. The names are chosen so the
# builtin input() is not shadowed at notebook scope.
val_sample_input, val_sample_label = val[0]
print("输入大小: ", val_sample_input.shape)
print("输出大小: ", val_sample_label.shape)
# No shuffling for evaluation; otherwise the same loader settings as for training.
valloader = torch.utils.data.DataLoader(
    val, batch_size=18, shuffle=False, pin_memory=True, num_workers=4,
    collate_fn=collate_fn
)

In [None]:
from septimbre_rescale import SepTimbreAMT_joint
# Build the joint model and move it to the selected device.
model = SepTimbreAMT_joint(CONFIG.CQT).to(device)

# Two optimizers: one for the note (AMT) branch, one for every other parameter.
optimizer_amt = torch.optim.AdamW(model.note_branch.parameters(), lr=3e-4)
cluster_params = [
    param
    for name, param in model.named_parameters()
    if not name.startswith('note_branch')
]
optimizer_cluster = torch.optim.AdamW(cluster_params, lr=3e-4)

# AMT branch: reduce the learning rate when the monitored value stops improving.
schedular_amt = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_amt,
    mode='min',
    factor=0.3,
    patience=3,
    threshold=1e-3,
)
# Cluster branch: cosine annealing with warm restarts.
# Alternative that was tried before, kept for reference:
# schedular_cluster = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_cluster, mode='min', factor=0.3, patience=4, threshold=1e-3)
schedular_cluster = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer_cluster,
    T_0=4,
    T_mult=2,
    eta_min=1e-6,
)

# Output locations and the training state that a loaded checkpoint may overwrite.
checkpoint_path, loss_path = "septimbre.pth", "septimbre.loss.txt"
min_loss = float('inf')
epoch_now = 0
train_amt = True

In [None]:
# Load the model/optimizer state saved by the previous run, if any.
states = load_ckpt(model, [optimizer_amt, optimizer_cluster], checkpoint_path)
if not states:
    print("No checkpoint loaded. Training from scratch.")
else:
    # Restore the bookkeeping values returned by the checkpoint loader.
    min_loss, avg_loss, epoch_now = states

In [None]:
# Start from a pre-trained note branch: restore its weights, then freeze it so
# that only the remaining parameters are trained.
note_branch_state_dict = torch.load(
    "../../sepamt_note_branch.pth",
    map_location=device,
    weights_only=True,
)
model.note_branch.load_state_dict(note_branch_state_dict)
train_amt = False
# Turn off gradients for every parameter of the note branch.
model.note_branch.requires_grad_(False)

def set_eval_mode(module):
    """Put ``module`` and every module nested inside it into eval mode.

    Args:
        module: a ``torch.nn.Module``; ``eval()`` is called on it and on each
            of its descendants.

    Returns:
        None.
    """
    # modules() yields the module itself followed by all of its descendants.
    for submodule in module.modules():
        submodule.eval()
# The note branch was frozen above; also switch it (and its sub-modules) to eval mode.
set_eval_mode(model.note_branch)

In [None]:
# NOTE: nn.Module.train() / .eval() return the module itself, so `model_train`
# and `model_eval` below wrap the very same module and share its train/eval flag.
model_train = model.train()
if not train_amt:
    # The note branch is not being trained: keep it in eval mode.
    set_eval_mode(model.note_branch)
model_train = torch.compile(model_train, mode="max-autotune")
# NOTE(review): model.eval() flips the shared module to eval mode, so once this
# cell has run the model is in eval mode. The mode must be set again before each
# training / validation phase.
model_eval = torch.compile(model.eval(), mode="max-autotune")

In [None]:
from model.loss import LossNorm, DWA, PCGrad
# Multi-task loss helpers from model.loss; MTL_bp uses them when the matching flag is set.
lossnorm = LossNorm(0.91)   # NOTE(review): meaning of 0.91 is defined in model.loss — confirm there
dwa = DWA(2)                # NOTE(review): meaning of 2 is defined in model.loss — confirm there
pcgrad = PCGrad(optimizer_cluster)  # wraps the cluster optimizer only
def MTL_bp(cluster_loss, amt_loss, use_lossnorm = True, use_dwa = True, use_pcgrad = False):
    """Back-propagate the two task losses and step the optimizers.

    Args:
        cluster_loss: loss tensor of the clustering task.
        amt_loss: loss tensor of the AMT task; stacked with ``cluster_loss``,
            so both must have the same shape.
        use_lossnorm: if True, multiply the stacked losses by ``lossnorm(losses)``.
        use_dwa: if True, multiply the (possibly re-weighted) losses by ``dwa(losses)``.
        use_pcgrad: if True, run backward/step/zero_grad through the ``pcgrad``
            wrapper; otherwise back-propagate the summed loss and step
            ``optimizer_cluster`` directly.

    Returns:
        None. The AMT optimizer is additionally stepped and cleared when the
        global ``train_amt`` is truthy.
    """
    weighted = torch.stack([cluster_loss, amt_loss])
    if use_lossnorm:
        weighted = lossnorm(weighted) * weighted
    if use_dwa:
        weighted = dwa(weighted) * weighted

    # Compute gradients, then pick the object that owns the cluster update.
    if not use_pcgrad:
        weighted.sum().backward()
        stepper = optimizer_cluster
    else:
        pcgrad.pc_backward(weighted)
        stepper = pcgrad
    stepper.step()
    stepper.zero_grad()

    # The AMT branch is only updated while it is being trained.
    if not train_amt:
        return
    optimizer_amt.step()
    optimizer_amt.zero_grad()

In [None]:
epoch_total = 62
stage_save_interval = 50    # save checkpoint every $stage_save_interval$ epochs

optimizer_amt.zero_grad()
optimizer_cluster.zero_grad()
# Position the cosine schedule at the last finished epoch (0 on a fresh run) so
# that a resumed run continues from the matching learning rate.
schedular_cluster.step(epoch=epoch_now)
for epoch in range(epoch_now+1, epoch_total):
    # ---------------- training ----------------
    # model_train and model_eval wrap the same module (train()/eval() return the
    # module itself) and the compile cell leaves it in eval mode, so the mode is
    # switched explicitly at the start of each phase.
    model.train()
    if not train_amt:
        # The note branch is not being trained: keep it in eval mode.
        model.note_branch.eval()
    train_loss = [0., 0.]   # accumulated [cluster loss, amt loss]
    for (batch_input, target) in tqdm.tqdm(dataloader):
        batch_input = batch_input.to(device)    # batch_input: (batch, channel, time)
        # target: list of (mix, notes, time) tensor
        target = [t.to(device) for t in target]
        # Augmentation: additive Gaussian noise, scaled per sample by a random factor in [0, 0.0005).
        batch_input = batch_input + torch.randn_like(batch_input) * 0.0005 * torch.rand(batch_input.shape[0], 1, 1, device=batch_input.device)
        onset, frame, emb = model_train(batch_input)
        emb = torch.nn.functional.normalize(emb, p=2, dim=1)
        clloss, amtloss = SepTimbreAMT_joint.loss(emb, frame, onset, target)
        train_loss[0] += clloss.item()
        train_loss[1] += amtloss.item()
        MTL_bp(clloss, amtloss, use_lossnorm = False, use_dwa = False, use_pcgrad=False)
        model.clampK()

    # NOTE(review): per-batch losses are summed and divided by the number of
    # samples; this is a per-sample average only if the loss is summed over the
    # batch — confirm against SepTimbreAMT_joint.loss.
    train_loss[0] /= len(dataset)
    train_loss[1] /= len(dataset)

    # ---------------- validation ----------------
    model.eval()
    val_loss = [0., 0.]     # accumulated [cluster loss, amt loss]
    with torch.no_grad():
        for (batch_input, target) in tqdm.tqdm(valloader):
            batch_input = batch_input.to(device)
            target = [t.to(device) for t in target]
            onset, frame, emb = model_eval(batch_input)
            emb = torch.nn.functional.normalize(emb, p=2, dim=1)
            clloss, amtloss = SepTimbreAMT_joint.loss(emb, frame, onset, target)
            val_loss[0] += clloss.item()
            val_loss[1] += amtloss.item()

        val_loss[0] /= len(val)
        val_loss[1] /= len(val)

    train_loss_sum = LossNorm.norm_sum(torch.tensor([train_loss[0], train_loss[1]])).item()
    val_loss_sum = LossNorm.norm_sum(torch.tensor([val_loss[0], val_loss[1]])).item()

    final_loss = val_loss[0]    # only care about cluster loss if amt branch is frozen
    if train_amt:
        final_loss = val_loss_sum
        schedular_amt.step(val_loss[1])
    # Advance the cosine schedule to the epoch that just finished. Passing
    # `epoch_now` here would use the previous epoch (it is only updated at the end
    # of the iteration), making the schedule lag and differ after a resume.
    schedular_cluster.step(epoch=epoch)

    # save checkpoint
    checkpoint_filename = f"epoch{epoch}.pth" if epoch % stage_save_interval == 0 else checkpoint_path
    save_ckpt(epoch, model, min_loss, final_loss, [optimizer_amt, optimizer_cluster], checkpoint_filename)
    if final_loss < min_loss:
        min_loss = final_loss

    print(f"====> Epoch: {epoch} Average train loss: {train_loss_sum:.4f} = {train_loss[0]:.4f} & {train_loss[1]:.4f}; Average val loss: {val_loss_sum:.4f} = {val_loss[0]:.4f} & {val_loss[1]:.4f}")
    with open(loss_path, 'a') as f:
        f.write(f"{epoch}:\ttrain_loss: {train_loss_sum:.4f}\tval_loss: {val_loss_sum:.4f}\t{train_loss[0]:.4f}\t{train_loss[1]:.4f}\t{val_loss[0]:.4f}\t{val_loss[1]:.4f}\n")
    epoch_now = epoch

In [None]:
# Plot the training and validation loss curves
import matplotlib.pyplot as plt

epochs = []
train_losses = []
val_losses = []

# Each log line has the form "<epoch>:\ttrain_loss: <x>\tval_loss: <y>\t...",
# as written by the training loop.
with open(loss_path, 'r') as f:
    for line in f:
        line = line.strip()
        if not line:
            continue    # tolerate blank lines instead of failing on int('')
        parts = line.split('\t')
        # Append directly so the training loop's `epoch`, `train_loss` and
        # `val_loss` variables are not overwritten by parsing temporaries.
        epochs.append(int(parts[0].split(':')[0]))
        train_losses.append(float(parts[1].split(': ')[1]))
        val_losses.append(float(parts[2].split(': ')[1]))

# Plot the losses
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epochs, train_losses, label='Train Loss')
ax.plot(epochs, val_losses, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss')
ax.legend()
ax.grid(True)
plt.show()

# 测试模型

In [None]:
# Load the best model (checkpoint file name prefixed with "best_").
states = load_ckpt(model, [optimizer_amt, optimizer_cluster], f"best_{checkpoint_path}")
if not states:
    print("No best checkpoint loaded.")
else:
    # Restore the bookkeeping values returned by the checkpoint loader.
    min_loss, avg_loss, epoch_now = states

In [None]:
# Switch to the CPU, because cudnn.benchmark makes the first run slow.
torch.backends.cudnn.benchmark = False
device = torch.device('cpu')
model = model.to(device)

In [None]:
# Serialize the whole module object (not just its state_dict) to disk.
# NOTE(review): loading this file needs the model's class definition to be
# importable — confirm this is intended rather than saving model.state_dict().
torch.save(model, "sepamt_model.pth")