In [None]:
import glob
import random
import os
import torch
import torch.optim as optim
import time
from Model.BiLSTM import *
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Text file listing the training pairs, one "wav_path,csv_path" per line.
# The pairs are consumed below in the order they appear in this file.
shuffle_file = "../Trained/schuffle.txt"

wav_files, csv_files = [], []
with open(shuffle_file, 'r') as f:
    for line in f:
        line = line.strip()
        if not line:
            # A blank (e.g. trailing) line has no pair to unpack; skip it
            # instead of failing with a ValueError.
            continue
        wav_path, csv_path = line.split(',')
        wav_files.append(wav_path)
        csv_files.append(csv_path)

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# Create the model, loss function and optimizer.
model = BiLSTM().to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
# NOTE(review): this scheduler is created but never referenced again in this cell.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)

checkpoint_path = "../Trained/BiLSTM_shuffle.pth"
if os.path.exists(checkpoint_path):
    # weights_only=True limits unpickling to tensors and plain containers; the
    # prediction cell already loads this same file that way.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint)
    print("Checkpoint loaded.")
else:
    print("No checkpoint found.")

# Position in the shuffled list to resume from; earlier entries are skipped.
index = 83
# glob returns paths in arbitrary order, so sort both lists to keep them
# deterministic and ordered consistently with each other.
test_wav_files = sorted(glob.glob("../Data/MusicNet_Dataset/musicnet/musicnet/test_data/*.wav"))
test_csv_files = sorted(glob.glob("../Data/MusicNet_Dataset/musicnet/musicnet/test_labels/*.csv"))

# Train on the shuffled file list, starting from the resume position.
for i, (wav_path, csv_path) in enumerate(zip(wav_files[index:], csv_files[index:]), start=index):
    print(time.strftime("%H:%M:%S"), f"Training on  {i} : {os.path.basename(wav_path)}")
    train_on_wav(model, optimizer, criterion, wav_path, csv_path, test_wav_files, test_csv_files, epochs=5)

print("Training complete!")

In [None]:
import torch
import os
from Model.BiLSTM import *

# Select the device here so this cell does not depend on the training cell
# having been run first in the same kernel.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# Path of the audio file to run prediction on.
wav_path = "../Data/MusicNet_Dataset/musicnet/musicnet/test_data/2628.wav"
# Load the trained model.
model = BiLSTM().to(device)

model.load_state_dict(torch.load("../Trained/BiLSTM_shuffle.pth", map_location=device, weights_only=True))
# Run prediction.
# NOTE(review): the BiLSTM_new prediction cell calls predict(model, wav_path, device)
# without the extra None argument -- confirm which signature is current.
note_events = predict(model, wav_path, None, device)
# NOTE(review): `:2f` is a minimum-width spec, so the values print with the
# default six decimals; `.2f` may have been intended.
for i, event in enumerate(note_events, start=1):
    print(f"{i} Note: {event['note']}, Start: {event['start']:2f}, Duration: {event['duration']:2f}")
# Print the distinct notes that were detected, in ascending order.
detected_notes = sorted({event['note'] for event in note_events})
print("Các nốt đã phát hiện:", detected_notes)

In [None]:
from collections import Counter

# Tally how many detected events carry each note value.
note_counts = Counter()
note_counts.update(event['note'] for event in note_events)

# Print every note with its number of occurrences, in ascending note order.
print("Các nốt đã phát hiện và số lần xuất hiện:")
for note in sorted(note_counts):
    count = note_counts[note]
    print(f"Note {note}: {count} lần")

In [None]:
import pandas as pd


def count_instruments(csv_path):
    """Print how many rows of a CSV file carry each value of its ``note`` column.

    Despite the function's name, the column that is counted is ``note``.

    Args:
        csv_path: Path of the CSV file to read with ``pandas.read_csv``.

    Returns:
        None. One line per distinct ``note`` value is printed, ordered by
        value; if the file has no ``note`` column a message is printed instead.
    """
    labels = pd.read_csv(csv_path)

    if 'note' in labels.columns:
        # value_counts() tallies the rows per value; sort_index() then orders
        # the tally by the note value rather than by frequency.
        per_note = labels['note'].value_counts().sort_index()
        for value, occurrences in per_note.items():
            print(f"Note '{value}': {occurrences} lần")
    else:
        print("Không tìm thấy cột 'note' trong file CSV.")
        
# Print the per-note row counts for the label CSV test_labels/1819.csv.
count_instruments("../Data/MusicNet_Dataset/musicnet/musicnet/test_labels/1819.csv")

In [None]:
import glob
import random
import os
import torch
import torch.optim as optim
import time
from Model.BiLSTM import *
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Text file listing the training pairs, one "wav_path,csv_path" per line.
# The pairs are consumed below in the order they appear in this file.
shuffle_file = "../Trained/schuffle.txt"

wav_files, csv_files = [], []
with open(shuffle_file, 'r') as f:
    for line in f:
        line = line.strip()
        if not line:
            # A blank (e.g. trailing) line has no pair to unpack; skip it
            # instead of failing with a ValueError.
            continue
        wav_path, csv_path = line.split(',')
        wav_files.append(wav_path)
        csv_files.append(csv_path)

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# Create the model, loss function and optimizer.
model = BiLSTM_new().to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
# NOTE(review): this scheduler is created but never referenced again in this cell.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)

checkpoint_path = "../Trained/BiLSTM_4layer.pth"
if os.path.exists(checkpoint_path):
    # weights_only=True limits unpickling to tensors and plain containers; the
    # prediction cell already loads this same file that way.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint)
    print("Checkpoint loaded.")
else:
    print("No checkpoint found.")

# Offset into the shuffled list at which training starts.
index = 280
# glob returns paths in arbitrary order, so sort both lists to keep them
# deterministic and ordered consistently with each other.
test_wav_files = sorted(glob.glob("../Data/MusicNet_Dataset/musicnet/musicnet/test_data/*.wav"))
test_csv_files = sorted(glob.glob("../Data/MusicNet_Dataset/musicnet/musicnet/test_labels/*.csv"))

# Walk the remaining files in groups of `batch_size`; the final group may be shorter.
batch_size = 5
for i in range(index, len(wav_files), batch_size):
    wav_batch = wav_files[i:i+batch_size]
    csv_batch = csv_files[i:i+batch_size]

    print(time.strftime("%H:%M:%S"), f"Training on files {i} to {i+len(wav_batch)-1}")

    train_multiple_wavs(
        model, optimizer, criterion,
        wav_paths=wav_batch,
        csv_paths=csv_batch,
        test_wav_files=test_wav_files,
        test_csv_files=test_csv_files,
        checkpoint_path=checkpoint_path,
        epochs=5
    )

print("Training complete!")

In [1]:
from Model.BiLSTM import *
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# Path of the audio file to run prediction on.
wav_path = "../Data/MusicNet_Dataset/musicnet/musicnet/test_data/2628.wav"
# Load the trained model.
model = BiLSTM_new().to(device)

model.load_state_dict(torch.load("../Trained/BiLSTM_4layer.pth", map_location=device, weights_only=True))
# Run prediction.
note_events = predict(model, wav_path, device)
# Number the events from 1 while printing them.
# NOTE(review): `:2f` is a minimum-width spec, so the values print with the
# default six decimals; `.2f` may have been intended.
for i, event in enumerate(note_events, start=1):
    print(f"{i} Note: {event['note']}, Start: {event['start']:2f}, Duration: {event['duration']:2f}")
# Print the distinct notes that were detected, in ascending order.
detected_notes = sorted({event['note'] for event in note_events})
print("Các nốt đã phát hiện:", detected_notes)

1 Note: 51, Start: 0.000000, Duration: 0.487619
2 Note: 54, Start: 0.000000, Duration: 2.925714
3 Note: 50, Start: 0.000000, Duration: 3.169524
4 Note: 54, Start: 3.169524, Duration: 1.706667
5 Note: 50, Start: 4.144762, Duration: 0.975238
6 Note: 55, Start: 0.000000, Duration: 5.607619
7 Note: 58, Start: 2.194286, Duration: 3.900952
8 Note: 62, Start: 4.144762, Duration: 1.950476
9 Note: 57, Start: 3.413333, Duration: 3.413333
10 Note: 51, Start: 6.582857, Duration: 0.975238
11 Note: 57, Start: 7.801905, Duration: 0.243810
12 Note: 54, Start: 7.314286, Duration: 0.975238
13 Note: 43, Start: 8.045714, Duration: 0.487619
14 Note: 55, Start: 8.289524, Duration: 0.487619
15 Note: 58, Start: 8.533333, Duration: 0.243810
16 Note: 55, Start: 9.020952, Duration: 0.975238
17 Note: 58, Start: 9.020952, Duration: 1.462857
18 Note: 50, Start: 5.607619, Duration: 5.120000
19 Note: 54, Start: 10.240000, Duration: 1.219048
20 Note: 43, Start: 11.215238, Duration: 0.731429
21 Note: 62, Start: 8.53333