In [13]:
import glob
from pathlib import Path

# Root of the MusicNet dataset, relative to this notebook.
DATA_ROOT = "../Data/MusicNet_Dataset/musicnet/musicnet"

# Sorted so that recordings and their label files line up by position.
wav_files = sorted(glob.glob(f"{DATA_ROOT}/train_data/*.wav"))
csv_files = sorted(glob.glob(f"{DATA_ROOT}/train_labels/*.csv"))

# Make sure the number of audio files matches the number of label files.
assert len(wav_files) == len(csv_files), "Số lượng file WAV và CSV không khớp!"

# An empty glob satisfies the count check above (0 == 0) and would make the
# training loop silently do nothing, so fail loudly instead.
if not wav_files:
    raise FileNotFoundError(f"No WAV files found under {DATA_ROOT}/train_data")

# Training pairs the two lists positionally with zip(); equal counts alone do
# not guarantee each recording is matched with its own label file, so check
# that the basenames (e.g. 1727.wav <-> 1727.csv) agree pairwise.
mismatched_pairs = [
    (wav, csv)
    for wav, csv in zip(wav_files, csv_files)
    if Path(wav).stem != Path(csv).stem
]
assert not mismatched_pairs, f"WAV/CSV basenames do not match, e.g. {mismatched_pairs[0]}"

In [14]:
# The star import is kept first (as originally ordered) because later code
# relies on names it provides, e.g. load_wav_csv, MusicDataset, predict_notes.
# NOTE(review): replace with explicit imports once the full set of used names
# is confirmed.
from Model.CNN_3L_pro import *
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from Model.CNN_3L_pro import CNN_Pro

# Single source of truth for the checkpoint file name: the existence check,
# the load and the save previously used different capitalisations
# ("CNN_pro.pth" vs "CNN_Pro.pth"), so on case-sensitive filesystems training
# never resumed from the saved checkpoint.
CHECKPOINT_PATH = "CNN_Pro.pth"
NUM_EPOCHS = 20        # epochs per file (kept small because we train file by file)
BATCH_SIZE = 64
LEARNING_RATE = 0.0001

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Build the model, loss function and optimizer.
model = CNN_Pro().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Resume model and optimizer state if a previous run left a checkpoint.
if os.path.exists(CHECKPOINT_PATH):
    checkpoint = torch.load(CHECKPOINT_PATH, weights_only=True, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Train on one file at a time, checkpointing after each file.
for wav_path, csv_path in zip(wav_files, csv_files):
    print(f"\nfile : {wav_path}")
    # Load this file's features and labels.
    X_train, y_train = load_wav_csv(wav_path, csv_path)
    dataset = MusicDataset(X_train, y_train)
    # NOTE(review): shuffle=False feeds consecutive samples together in each
    # batch; confirm this ordering is intentional for training.
    train_loader = data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

    # An empty loader would divide by zero in the epoch report below.
    if len(train_loader) == 0:
        print(f"Skipping {wav_path}: no training samples")
        continue

    for epoch in range(NUM_EPOCHS):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        # Only report the final epoch to keep the output short.
        if epoch == NUM_EPOCHS - 1:
            print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Loss: {running_loss/len(train_loader):.4f}, "
                  f"Accuracy: {100 * correct / total:.2f}%")

    # Write to a temporary file and atomically swap it in, so an interruption
    # mid-save cannot corrupt the only existing checkpoint.
    tmp_checkpoint_path = CHECKPOINT_PATH + ".tmp"
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, tmp_checkpoint_path)
    os.replace(tmp_checkpoint_path, CHECKPOINT_PATH)
    print(f"Updated checkpoint sau file {wav_path}")

print("done training !")


file : ../Data/MusicNet_Dataset/musicnet/musicnet/train_data/1727.wav
Epoch 20/20, Loss: 0.1506, Accuracy: 95.67%
Updated checkpoint sau file ../Data/MusicNet_Dataset/musicnet/musicnet/train_data/1727.wav

file : ../Data/MusicNet_Dataset/musicnet/musicnet/train_data/1728.wav
Epoch 20/20, Loss: 0.1240, Accuracy: 96.13%
Updated checkpoint sau file ../Data/MusicNet_Dataset/musicnet/musicnet/train_data/1728.wav

file : ../Data/MusicNet_Dataset/musicnet/musicnet/train_data/1729.wav
Epoch 20/20, Loss: 0.1439, Accuracy: 95.81%
Updated checkpoint sau file ../Data/MusicNet_Dataset/musicnet/musicnet/train_data/1729.wav

file : ../Data/MusicNet_Dataset/musicnet/musicnet/train_data/1730.wav
Epoch 20/20, Loss: 0.1822, Accuracy: 96.15%
Updated checkpoint sau file ../Data/MusicNet_Dataset/musicnet/musicnet/train_data/1730.wav

file : ../Data/MusicNet_Dataset/musicnet/musicnet/train_data/1733.wav
Epoch 20/20, Loss: 0.1735, Accuracy: 95.63%
Updated checkpoint sau file ../Data/MusicNet_Dataset/musicnet

In [15]:
import torch

# Restore the weights saved by the training loop; the checkpoint is a dict
# holding 'model_state_dict' (and 'optimizer_state_dict', unused here).
checkpoint = torch.load("CNN_Pro.pth", map_location=device, weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])
# Put the model into evaluation mode before inference.
model.eval()

# NOTE(review): predict_notes receives only the file path, not `model`;
# confirm it actually uses the weights loaded above rather than its own
# model instance.
wav_file = "../Data/MusicNet_Dataset/musicnet/musicnet/test_data/2303.wav"
predicted_notes = predict_notes(wav_file)

# Each prediction is unpacked as a (start, end, note) triple.
for start, end, note in predicted_notes:
    print(f"Start: {start} µs, End: {end} µs, Note: {note}")

Predicted notes: [128 128 128 128 128 128 128 128 128 128]


In [16]:
import librosa

wav_path = "../Data/MusicNet_Dataset/musicnet/musicnet/train_data/1727.wav"
hop_length = 512  # keep identical to the hop length used by the model

# Load the recording, resampled to 44.1 kHz.
y, sr = librosa.load(wav_path, sr=44100)

# Number of whole hops that fit in the signal.
# NOTE(review): librosa's framing with center=True (its default) produces
# 1 + len(y) // hop_length frames; confirm which convention the model's
# feature extraction actually uses.
n_samples = len(y)
num_frames = n_samples // hop_length

print(
    f"Sample rate (sr): {sr}",
    f"Số mẫu trong file WAV: {len(y)}",
    f"Số frame hợp lý (num_frames): {num_frames}",
    sep="\n",
)

Sample rate (sr): 44100
Số mẫu trong file WAV: 19715328
Số frame hợp lý (num_frames): 38506
