# 训练不具备分离能力的网络
basicamt.py

In [1]:
import torch
import os
import tqdm

import sys
sys.path.append("../../..")
from utils.checkpoint import save_ckpt_template as save_ckpt, load_ckpt_template as load_ckpt

# Choose the compute device: CUDA when PyTorch can see a GPU, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # Let cuDNN benchmark convolution algorithms to speed up convolutions.
    torch.backends.cudnn.benchmark = True

print(device)

cuda


In [2]:
from data.septimbre.data import Instruments
# Settings shared by the training and validation sets.
# NOTE(review): `input` / `output` look like the file suffixes of the feature
# and label arrays inside `folder` -- confirm against data.septimbre.data.
file_suffixes = dict(input='.cqt.npy', output='.npy')

# With cudnn benchmark enabled, every batch should ideally have the same size,
# so the batch size should be a common divisor of both dataset lengths.
loader_options = dict(batch_size=18, pin_memory=True, num_workers=4)

# ---- training set ----
dataset = Instruments(
    folder='../../../data/septimbre/piano_large_short_256',
    mix=1,
    **file_suffixes,
)
print("训练集大小: ", len(dataset))
input, label = dataset[0]   # peek at one sample only to report its shapes
print("输入大小: ", input.shape)
print("输出大小: ", label.shape)
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, **loader_options)

# ---- validation set (iterated in a fixed order) ----
val = Instruments(
    folder='../../../data/septimbre/piano_medium_short_256',
    mix=1,
    **file_suffixes,
)
print("测试集大小: ", len(val))
input, label = val[0]
print("输入大小: ", input.shape)
print("输出大小: ", label.shape)
valloader = torch.utils.data.DataLoader(val, shuffle=False, **loader_options)

训练集大小:  1800
输入大小:  torch.Size([1, 2, 288, 360])
输出大小:  torch.Size([1, 84, 360])
测试集大小:  180
输入大小:  torch.Size([1, 2, 288, 360])
输出大小:  torch.Size([1, 84, 360])


In [3]:
from basicamt_noDilation import BasicAMT_noDilation
# Files written during training: the rolling checkpoint and the per-epoch loss log.
checkpoint_path = "basicamt_noDilation.pth"
loss_path = "basicamt_noDilation.loss.txt"

# Bookkeeping for a fresh run; both values are replaced when a checkpoint is loaded.
min_loss = float('inf')     # best loss seen so far
epoch_now = 0               # last finished epoch (0 = nothing trained yet)

# Network, optimizer and learning-rate scheduler.
model = BasicAMT_noDilation()
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# The variable name is spelled `schedular` on purpose: the training loop refers
# to it under exactly this name.
schedular = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.32,
    patience=2,
    threshold=1e-3,
)

In [4]:
# Restore model/optimizer parameters from the previous run, if any.
# A falsy return from load_ckpt means nothing was loaded; otherwise it is
# unpacked as (min_loss, avg_loss, epoch_now).
states = load_ckpt(model, optimizer, checkpoint_path)
if not states:
    print("No checkpoint loaded. Training from scratch.")
else:
    min_loss, avg_loss, epoch_now = states

Checkpoint file 'basicamt_noDilation.pth' does not exist.
No checkpoint loaded. Training from scratch.


In [5]:
epoch_total = 50            # number of the last epoch to train (epochs are counted from 1)
stage_save_interval = 45    # on multiples of this, the checkpoint goes to "epoch{N}.pth" instead of the rolling file

optimizer.zero_grad()
# `epoch_now` is the last finished epoch (0 for a fresh run), so training
# resumes at epoch_now + 1. The stop value is epoch_total + 1 because range()
# excludes its upper bound and epoch `epoch_total` itself must still run.
for epoch in range(epoch_now + 1, epoch_total + 1):
    # ---- training ----
    model.train()
    train_loss = 0
    for (features, target) in tqdm.tqdm(dataloader):
        # For this dataset the per-sample shapes printed at load time were
        # (1, 2, 288, 360) and (1, 84, 360), so a batch is
        #   features: (batch, mix, 2, 288, 360)
        #   target:   (batch, mix, 7 * 12, 360)
        features = features.to(device)
        target = target.to(device)

        # The basicamt target does not separate timbres, so mixing several
        # sources is used purely as data augmentation.
        if dataset.mix > 1:
            mixed = torch.mean(features, dim=1, keepdim=False)    # CQT of the mixture
            midi_mixed, _ = target.max(dim=-3, keepdim=False)     # union of the sources' notes
        else:
            mixed = features.squeeze(dim=1)
            midi_mixed = target.squeeze(dim=1)
        # midi_mixed: (batch, 7 * 12, 360)

        # Additive Gaussian noise; 0.01 was judged suitable by listening to
        # and visualising the result.
        mixed = mixed + torch.randn_like(mixed) * 0.01

        onset, note = model(mixed)
        # onset & note: same layout as midi_mixed
        loss = BasicAMT_noDilation.loss(onset, note, midi_mixed)
        train_loss += loss.item()

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Model-specific hook run after every optimizer step.
        # NOTE(review): presumably clamps a parameter K -- see BasicAMT_noDilation.
        model.clampK()

    train_loss /= len(dataloader)

    # ---- validation ----
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for (features, target) in tqdm.tqdm(valloader):
            features = features.to(device)
            target = target.to(device)

            if val.mix > 1:
                # Mix with the mean, exactly as in training, so the validation
                # input has the same scale as what the model was trained on.
                mixed = torch.mean(features, dim=1, keepdim=False)
                midi_mixed, _ = target.max(dim=-3, keepdim=False)
            else:
                mixed = features.squeeze(dim=1)
                midi_mixed = target.squeeze(dim=1)

            onset, note = model(mixed)
            loss = BasicAMT_noDilation.loss(onset, note, midi_mixed)
            val_loss += loss.item()

        val_loss /= len(valloader)

    # The scheduler and the best-model bookkeeping both monitor a blend that is
    # dominated by the validation loss.
    final_loss = val_loss * 0.85 + train_loss * 0.15
    schedular.step(final_loss)

    # ---- save checkpoint ----
    # A checkpoint is written every epoch; on multiples of stage_save_interval
    # it is written to a separate per-epoch file instead of the rolling one.
    checkpoint_filename = f"epoch{epoch}.pth" if epoch % stage_save_interval == 0 else checkpoint_path
    # min_loss is deliberately passed *before* it is updated below.
    # NOTE(review): presumably save_ckpt compares final_loss against it -- confirm in utils.checkpoint.
    save_ckpt(epoch, model, min_loss, final_loss, optimizer, checkpoint_filename)
    if final_loss < min_loss:
        min_loss = final_loss

    print(f"====> Epoch: {epoch} Average train loss: {train_loss:.4f}; Average val loss: {val_loss:.4f}")
    # Append to the loss log; the plotting cell parses this exact format.
    with open(loss_path, 'a') as f:
        f.write(f"{epoch}:\ttrain_loss: {train_loss:.4f}\tval_loss: {val_loss:.4f}\n")
    epoch_now = epoch

100%|██████████| 100/100 [00:05<00:00, 17.73it/s]
100%|██████████| 10/10 [00:00<00:00, 70.21it/s]


====> Epoch: 1 Average train loss: 116632.4355; Average val loss: 48888.2879


100%|██████████| 100/100 [00:02<00:00, 34.70it/s]
100%|██████████| 10/10 [00:00<00:00, 74.53it/s]


====> Epoch: 2 Average train loss: 22872.1482; Average val loss: 20476.3230


100%|██████████| 100/100 [00:02<00:00, 34.64it/s]
100%|██████████| 10/10 [00:00<00:00, 60.66it/s]


====> Epoch: 3 Average train loss: 11665.0997; Average val loss: 13699.6149


100%|██████████| 100/100 [00:02<00:00, 34.66it/s]
100%|██████████| 10/10 [00:00<00:00, 67.49it/s]


====> Epoch: 4 Average train loss: 8344.4377; Average val loss: 10291.8371


100%|██████████| 100/100 [00:02<00:00, 34.63it/s]
100%|██████████| 10/10 [00:00<00:00, 62.45it/s]


====> Epoch: 5 Average train loss: 6660.1412; Average val loss: 11438.1713


100%|██████████| 100/100 [00:02<00:00, 34.69it/s]
100%|██████████| 10/10 [00:00<00:00, 62.92it/s]


====> Epoch: 6 Average train loss: 4256.8037; Average val loss: 14056.7487


100%|██████████| 100/100 [00:02<00:00, 34.29it/s]
100%|██████████| 10/10 [00:00<00:00, 70.49it/s]


====> Epoch: 7 Average train loss: 3155.5347; Average val loss: 15922.2385


100%|██████████| 100/100 [00:02<00:00, 34.38it/s]
100%|██████████| 10/10 [00:00<00:00, 74.87it/s]


====> Epoch: 8 Average train loss: 2665.0036; Average val loss: 16839.9940


100%|██████████| 100/100 [00:02<00:00, 34.48it/s]
100%|██████████| 10/10 [00:00<00:00, 68.00it/s]


====> Epoch: 9 Average train loss: 2508.9234; Average val loss: 17022.4516


100%|██████████| 100/100 [00:02<00:00, 34.32it/s]
100%|██████████| 10/10 [00:00<00:00, 63.64it/s]


====> Epoch: 10 Average train loss: 2366.7737; Average val loss: 16456.2772


100%|██████████| 100/100 [00:02<00:00, 34.73it/s]
100%|██████████| 10/10 [00:00<00:00, 70.21it/s]


====> Epoch: 11 Average train loss: 2281.2812; Average val loss: 17982.3525


100%|██████████| 100/100 [00:02<00:00, 34.55it/s]
100%|██████████| 10/10 [00:00<00:00, 66.51it/s]


====> Epoch: 12 Average train loss: 2238.3675; Average val loss: 17468.9014


100%|██████████| 100/100 [00:02<00:00, 34.64it/s]
100%|██████████| 10/10 [00:00<00:00, 67.29it/s]


====> Epoch: 13 Average train loss: 2202.0871; Average val loss: 18788.6777


100%|██████████| 100/100 [00:02<00:00, 34.40it/s]
100%|██████████| 10/10 [00:00<00:00, 59.27it/s]


====> Epoch: 14 Average train loss: 2172.5862; Average val loss: 18368.6885


100%|██████████| 100/100 [00:02<00:00, 34.58it/s]
100%|██████████| 10/10 [00:00<00:00, 63.04it/s]


====> Epoch: 15 Average train loss: 2160.2485; Average val loss: 18624.4393


100%|██████████| 100/100 [00:02<00:00, 34.51it/s]
100%|██████████| 10/10 [00:00<00:00, 62.97it/s]


====> Epoch: 16 Average train loss: 2147.5717; Average val loss: 18294.7106


100%|██████████| 100/100 [00:02<00:00, 33.94it/s]
100%|██████████| 10/10 [00:00<00:00, 68.99it/s]


====> Epoch: 17 Average train loss: 2138.1103; Average val loss: 18418.2979


100%|██████████| 100/100 [00:02<00:00, 34.12it/s]
100%|██████████| 10/10 [00:00<00:00, 64.24it/s]


====> Epoch: 18 Average train loss: 2132.1784; Average val loss: 18430.2217


100%|██████████| 100/100 [00:02<00:00, 34.28it/s]
100%|██████████| 10/10 [00:00<00:00, 63.85it/s]


====> Epoch: 19 Average train loss: 2127.2150; Average val loss: 18705.7637


100%|██████████| 100/100 [00:02<00:00, 34.38it/s]
100%|██████████| 10/10 [00:00<00:00, 62.96it/s]


====> Epoch: 20 Average train loss: 2125.0483; Average val loss: 18768.8759


100%|██████████| 100/100 [00:02<00:00, 34.41it/s]
100%|██████████| 10/10 [00:00<00:00, 69.16it/s]


====> Epoch: 21 Average train loss: 2121.4751; Average val loss: 18612.1992


100%|██████████| 100/100 [00:02<00:00, 34.47it/s]
100%|██████████| 10/10 [00:00<00:00, 70.76it/s]


====> Epoch: 22 Average train loss: 2119.1003; Average val loss: 18620.3980


100%|██████████| 100/100 [00:02<00:00, 34.39it/s]
100%|██████████| 10/10 [00:00<00:00, 72.32it/s]


====> Epoch: 23 Average train loss: 2120.5454; Average val loss: 18731.7453


100%|██████████| 100/100 [00:02<00:00, 34.36it/s]
100%|██████████| 10/10 [00:00<00:00, 62.39it/s]


====> Epoch: 24 Average train loss: 2119.6298; Average val loss: 18596.6837


100%|██████████| 100/100 [00:02<00:00, 34.33it/s]
100%|██████████| 10/10 [00:00<00:00, 62.15it/s]


====> Epoch: 25 Average train loss: 2118.6645; Average val loss: 18784.6524


100%|██████████| 100/100 [00:02<00:00, 34.02it/s]
100%|██████████| 10/10 [00:00<00:00, 69.06it/s]


====> Epoch: 26 Average train loss: 2117.8307; Average val loss: 18644.2834


100%|██████████| 100/100 [00:02<00:00, 34.02it/s]
100%|██████████| 10/10 [00:00<00:00, 68.39it/s]


====> Epoch: 27 Average train loss: 2119.0754; Average val loss: 18637.0881


100%|██████████| 100/100 [00:02<00:00, 34.31it/s]
100%|██████████| 10/10 [00:00<00:00, 62.61it/s]


====> Epoch: 28 Average train loss: 2117.9427; Average val loss: 18689.2673


100%|██████████| 100/100 [00:02<00:00, 34.14it/s]
100%|██████████| 10/10 [00:00<00:00, 60.66it/s]


====> Epoch: 29 Average train loss: 2116.9886; Average val loss: 18596.0294


100%|██████████| 100/100 [00:02<00:00, 34.37it/s]
100%|██████████| 10/10 [00:00<00:00, 65.18it/s]


====> Epoch: 30 Average train loss: 2116.4794; Average val loss: 18654.8177


100%|██████████| 100/100 [00:02<00:00, 34.44it/s]
100%|██████████| 10/10 [00:00<00:00, 66.31it/s]


====> Epoch: 31 Average train loss: 2116.6236; Average val loss: 18541.9771


100%|██████████| 100/100 [00:02<00:00, 34.37it/s]
100%|██████████| 10/10 [00:00<00:00, 73.64it/s]


====> Epoch: 32 Average train loss: 2118.0721; Average val loss: 18563.6455


100%|██████████| 100/100 [00:02<00:00, 34.16it/s]
100%|██████████| 10/10 [00:00<00:00, 68.94it/s]


====> Epoch: 33 Average train loss: 2117.2569; Average val loss: 18813.8945


100%|██████████| 100/100 [00:02<00:00, 33.91it/s]
100%|██████████| 10/10 [00:00<00:00, 68.82it/s]


====> Epoch: 34 Average train loss: 2117.5206; Average val loss: 18542.0164


100%|██████████| 100/100 [00:02<00:00, 34.11it/s]
100%|██████████| 10/10 [00:00<00:00, 59.72it/s]


====> Epoch: 35 Average train loss: 2116.5483; Average val loss: 18793.7194


100%|██████████| 100/100 [00:02<00:00, 34.10it/s]
100%|██████████| 10/10 [00:00<00:00, 61.03it/s]


====> Epoch: 36 Average train loss: 2115.8146; Average val loss: 18717.6515


100%|██████████| 100/100 [00:02<00:00, 34.07it/s]
100%|██████████| 10/10 [00:00<00:00, 62.35it/s]


====> Epoch: 37 Average train loss: 2116.2076; Average val loss: 18583.7055


 67%|██████▋   | 67/100 [00:02<00:00, 33.34it/s]


KeyboardInterrupt: 

In [None]:
# 绘制训练和验证损失曲线
import matplotlib.pyplot as plt

# Parse the loss log written by the training loop. Each record has the form
#   "<epoch>:\ttrain_loss: <float>\tval_loss: <float>"
epochs = []
train_losses = []
val_losses = []

with open(loss_path, 'r') as f:
    for line in f:
        line = line.strip()
        if not line:
            # Skip blank lines (e.g. a stray trailing newline) instead of
            # crashing on int('').
            continue
        parts = line.split('\t')
        epoch = int(parts[0].split(':')[0])
        train_loss = float(parts[1].split(': ')[1])
        val_loss = float(parts[2].split(': ')[1])

        epochs.append(epoch)
        train_losses.append(train_loss)
        val_losses.append(val_loss)

# Plot both curves against the epoch number.
plt.figure(figsize=(10, 5))
plt.plot(epochs, train_losses, label='Train Loss')
plt.plot(epochs, val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.grid(True)
plt.show()

## 测试模型

In [None]:
# Load the best model: same file name as the rolling checkpoint, prefixed with "best_".
states = load_ckpt(model, optimizer, f"best_{checkpoint_path}")
if not states:
    print("No best checkpoint loaded.")
else:
    # Unpacked in the order (min_loss, avg_loss, epoch_now).
    min_loss, avg_loss, epoch_now = states

In [None]:
# Run inference on the CPU: cudnn benchmark makes the first run slow.
device = torch.device('cpu')
model = model.to(device)

# Benchmarking is no longer wanted once the model has left the GPU.
torch.backends.cudnn.benchmark = False

In [None]:
# CQT configuration: build the CQT front end from the project-wide CONFIG.
from model.config import CONFIG
from model.CQT import CQTsmall_fir

# Value taken straight from the project config.
# NOTE(review): presumably seconds per analysis frame -- confirm in model.config.
s_per_frame = CONFIG.s_per_frame

# Place the transform on the same device as the model.
cqt = CQTsmall_fir(config=CONFIG.CQT).to(device)

In [None]:
# 读取音频，分析为CQT
import torchaudio
from utils.midiarray import numpy2midi
from utils.wavtool import waveInfo
import matplotlib.pyplot as plt
import numpy as np

# Audio file to transcribe (a Windows-style alternative is kept for convenience).
test_wave_path = "../../../data/inferMusic/piano_short.wav"
# test_wave_path = r'C:\amt\data\septimbre\small\inst0\0.wav'
waveInfo(test_wave_path)

waveform, sample_rate = torchaudio.load(test_wave_path, normalize=True)
# Add a leading axis in front of what torchaudio.load returned.
waveform = waveform.unsqueeze(0)
print(waveform.shape)

# torchaudio.load returns a CPU tensor while `cqt` was moved to `device`, so
# move the waveform first. The result then stays on `device`, and this cell
# also works when `device` is not the CPU.
test_cqt_data = cqt(waveform.to(device))
print(test_cqt_data.shape)

In [None]:
# Run the network once on the test CQT with gradient tracking disabled.
model.eval()
with torch.no_grad():
    onset, note = model(test_cqt_data)

# Keep only the first item of the batch, as NumPy arrays, for plotting.
onset = onset.cpu().numpy()[0]
note = note.cpu().numpy()[0]

# Two stacked grayscale images: note activations on top, onsets below.
plt.figure(figsize=(12, 10))
for position, (title, image) in enumerate((('Note', note), ('Onset', onset)), start=1):
    plt.subplot(2, 1, position)
    plt.title(title)
    plt.imshow(image, aspect='auto', origin='lower', cmap='gray')
    plt.colorbar()

plt.tight_layout()
plt.show()