In [58]:
import glob

from pathlib import Path

# Collect the MusicNet training recordings and their label files. Both lists are
# sorted so that they have a deterministic order and can be zipped pairwise.
wav_files = sorted(glob.glob("../Data/MusicNet_Dataset/musicnet/musicnet/train_data/*.wav"))
csv_files = sorted(glob.glob("../Data/MusicNet_Dataset/musicnet/musicnet/train_labels/*.csv"))

# Make sure the number of files matches.
assert len(wav_files) == len(csv_files), "Số lượng file WAV và CSV không khớp!"

# Equal counts alone do not guarantee correct pairs: if one WAV and a different
# CSV were missing, the lists would still have the same length but every pair
# after the gap would be shifted. Compare the file stems one by one instead.
wav_stems = [Path(path).stem for path in wav_files]
csv_stems = [Path(path).stem for path in csv_files]
assert wav_stems == csv_stems, "WAV and CSV file names do not pair up one-to-one"

In [60]:
from Model.CNN_3layer import CNN_3L

In [62]:
import torch.nn as nn
import torch.optim as optim
from Model.CNN_3L_pro import *
import torch
import os

# Hyper-parameters and paths, named once instead of being repeated as literals.
NUM_EPOCHS = 20          # epochs spent on each file before moving on to the next
BATCH_SIZE = 64
LEARNING_RATE = 0.0001
CHECKPOINT_PATH = "checkpoint.pth"

# Pick the fastest available backend: CUDA, then Apple MPS, then plain CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Build the model, the loss function and the optimizer.
model = CNN_3L().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Resume from an earlier run when a checkpoint exists. load_state_dict() raises
# RuntimeError when the checkpoint was written by a different architecture.
if os.path.exists(CHECKPOINT_PATH):
    checkpoint = torch.load(CHECKPOINT_PATH, weights_only=True, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Train on one recording at a time.
# NOTE(review): resuming restores the weights only; the loop below still starts
# again from the first file.
for wav_path, csv_path in zip(wav_files, csv_files):
    print(f"\nfile : {wav_path}")
    # Load the features and labels of this recording.
    X_train, y_train = load_wav_csv(wav_path, csv_path)
    # Wrap them in a DataLoader.
    # NOTE(review): shuffle=False feeds the batches in stored order on every
    # epoch -- confirm that this is intended for training.
    dataset = MusicDataset(X_train, y_train)
    train_loader = data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

    # A recording that yields no samples would make the averages below divide
    # by zero, so skip it instead of crashing the whole run.
    if len(train_loader) == 0:
        print(f"Skipped {wav_path}: no training samples")
        continue

    for epoch in range(NUM_EPOCHS):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        # Report only the last epoch of each file to keep the output short.
        if epoch == NUM_EPOCHS - 1:
            print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Loss: {running_loss/len(train_loader):.4f}, "
                  f"Accuracy: {100 * correct / total:.2f}%")

    # Write to a temporary file first and swap it in afterwards, so that an
    # interrupted save cannot leave a truncated checkpoint behind.
    tmp_checkpoint_path = CHECKPOINT_PATH + ".tmp"
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, tmp_checkpoint_path)
    os.replace(tmp_checkpoint_path, CHECKPOINT_PATH)
    print(f"Updated checkpoint sau file {wav_path}")

print("done training !")

RuntimeError: Error(s) in loading state_dict for CNN_Pro:
	Missing key(s) in state_dict: "conv4.weight", "conv4.bias", "bn4.weight", "bn4.bias", "bn4.running_mean", "bn4.running_var". 
	size mismatch for conv1.weight: copying a param with shape torch.Size([32, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 1, 5, 5]).
	size mismatch for fc1.weight: copying a param with shape torch.Size([256, 32768]) from checkpoint, the shape in current model is torch.Size([512, 4096]).
	size mismatch for fc1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for fc2.weight: copying a param with shape torch.Size([129, 256]) from checkpoint, the shape in current model is torch.Size([129, 512]).

In [None]:
import torch
import librosa
import numpy as np

# Pick the fastest available backend: CUDA, then Apple MPS, then plain CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Load the trained model. checkpoint.pth is written by the training cell from a
# CNN_3L instance, so the same architecture has to be built here; loading these
# weights into CNN_Pro fails with a RuntimeError (missing conv4/bn4 keys and
# size mismatches on conv1/fc1/fc2).
model = CNN_3L().to(device)
checkpoint = torch.load("checkpoint.pth", weights_only=True, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # switch to inference mode

# Quick sanity check on the loaded weights of the last layer.
for name, param in model.named_parameters():
    if "fc2" in name:
        print(name, param.mean().item(), param.std().item())

def predict_notes(wav_path, sr=44100, hop_length=512, window_size=128, step=64, batch_size=64):
    """Transcribe a WAV file into a list of note events.

    The audio is turned into a log-mel spectrogram, cut into overlapping
    windows, and every window is classified by the module-level ``model`` on
    the module-level ``device``. Consecutive windows with the same predicted
    class are merged into a single event.

    The model is used as-is; it is expected to already be in eval mode.

    Args:
        wav_path: Path of the audio file to transcribe.
        sr: Sample rate the audio is resampled to when it is loaded.
        hop_length: Number of samples between two spectrogram frames.
        window_size: Number of spectrogram frames in one model input.
        step: Number of frames between the starts of consecutive windows.
        batch_size: Number of windows sent through the model at once.

    Returns:
        A list of ``(note, start_seconds, end_seconds)`` tuples, where ``note``
        is the predicted class index as a plain ``int``. Class 128 ("no note")
        is never reported. The list is empty when the audio is shorter than
        one window.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    no_note = 128  # class index the model uses for "no note"

    y, sr = librosa.load(wav_path, sr=sr)
    mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, hop_length=hop_length)
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
    print("Mel spectrogram shape:", mel_spec_db.shape)
    print("Max value:", np.max(mel_spec_db), "Min value:", np.min(mel_spec_db))

    # Start frame of every window. The "+ 1" keeps the window that ends exactly
    # on the last frame, which a bare "n_frames - window_size" bound would drop.
    n_frames = mel_spec_db.shape[1]
    indices = list(range(0, n_frames - window_size + 1, step))
    if not indices:
        # Audio shorter than one window: nothing to classify.
        return []

    # Classify the windows in small batches so that a long recording does not
    # have to fit through the model in one forward pass.
    window_classes = []
    with torch.no_grad():
        for first in range(0, len(indices), batch_size):
            batch_starts = indices[first:first + batch_size]
            batch = np.stack([mel_spec_db[:, i:i + window_size] for i in batch_starts])
            batch = torch.tensor(batch, dtype=torch.float32).unsqueeze(1).to(device)
            window_classes.extend(torch.argmax(model(batch), dim=1).cpu().tolist())

    def frame_to_time(frame):
        """Convert a spectrogram frame index into seconds."""
        return (frame * hop_length) / sr

    # Merge runs of identical predictions into (note, start, end) events.
    results = []
    prev_note = None
    start_time = None

    for idx, note in zip(indices, window_classes):
        if note == prev_note:
            continue
        now = frame_to_time(idx)
        # Close the note that was sounding, unless it was "no note".
        if prev_note is not None and prev_note != no_note:
            results.append((prev_note, start_time, now))
        # Open a new note, unless the new class is "no note".
        if note != no_note:
            start_time = now
        prev_note = note

    # A note still sounding at the end of the file is closed one step after the
    # start of the last window (clipped to the spectrogram length), so that it
    # gets the same duration per window as every other note instead of ending
    # where the last window begins.
    if prev_note is not None and prev_note != no_note:
        end_time = frame_to_time(min(indices[-1] + step, n_frames))
        results.append((prev_note, start_time, end_time))

    return results

# Use the model to predict the notes of one WAV file.
wav_file = "../Data/MusicNet_Dataset/musicnet/musicnet/train_data/2478.wav"
predicted_notes = predict_notes(wav_file)
print("Predicted notes:", predicted_notes)
# predicted_notes holds (note, start, end) tuples; take the note column only,
# otherwise np.unique would mix pitches with timestamps.
print("Unique predicted values:", np.unique([note for note, _, _ in predicted_notes]))

# predict_notes never reports class 128 ("no note"), so every event is a pitch.
for note, start, end in predicted_notes:
    note_name = librosa.midi_to_note(int(note))  # MIDI number -> note name
    print(f"Note: {note_name}, Start: {start:.2f}s, End: {end:.2f}s")