In [12]:
import torch
import librosa
import torchaudio.transforms as T
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import numpy as np

# -------------------- Label Mapping --------------------
# Class-name -> integer class index used as the classification target.
# The position of each name in the tuple below is its index; "no_music" is
# the class the dataset assigns to low-energy segments.
LABELS = {
    name: index
    for index, name in enumerate(("piano", "drums", "bass", "guitar", "no_music"))
}

# -------------------- Dataset Class --------------------
class BabySlakhDataset(Dataset):
    """Fixed-length audio segments paired with integer class labels.

    Every file in ``stem_dict`` is loaded once at construction time and cut
    into back-to-back, non-overlapping windows that are kept in memory as
    raw waveforms. The log-Mel spectrogram of a window is computed lazily
    when the item is requested.
    """

    def __init__(self, stem_dict, segment_duration=2.0, sample_rate=22050, n_mels=128, energy_threshold=0.01):
        """
        Args:
            stem_dict (dict): Maps an audio file path to its integer label.
            segment_duration (float): Length of one segment, in seconds.
            sample_rate (int): Rate the audio is loaded at; also passed to
                the Mel transform.
            n_mels (int): Number of Mel bins produced by the transform.
            energy_threshold (float): Segments whose mean absolute amplitude
                is below this value are labelled ``LABELS["no_music"]``
                instead of the file's own label.
        """
        self.sample_rate = sample_rate
        self.segment_samples = int(segment_duration * sample_rate)
        self.mel_transform = T.MelSpectrogram(sample_rate=sample_rate, n_mels=n_mels)
        self.segments = []

        window = self.segment_samples
        for audio_path, file_label in stem_dict.items():
            audio, _ = librosa.load(audio_path, sr=self.sample_rate)
            # Stopping at len(audio) - window + 1 yields only start offsets
            # that leave room for a full window, so a trailing partial
            # segment is never produced.
            for offset in range(0, len(audio) - window + 1, window):
                chunk = audio[offset:offset + window]
                mean_amplitude = np.abs(chunk).mean()
                assigned = LABELS["no_music"] if mean_amplitude < energy_threshold else file_label
                self.segments.append((chunk, assigned))

    def __len__(self):
        """Return the number of stored segments."""
        return len(self.segments)

    def __getitem__(self, idx):
        """Return ``(log_mel, label)`` for the segment at position ``idx``.

        The waveform gets a leading dimension via ``unsqueeze(0)`` before the
        Mel transform is applied (presumably acting as the channel axis for
        the CNN -- confirm against the model input). ``log1p`` compresses the
        dynamic range of the spectrogram. The label is a ``torch.long``
        scalar tensor.
        """
        samples, target = self.segments[idx]
        batched_wave = torch.tensor(samples).unsqueeze(0)
        log_mel = self.mel_transform(batched_wave).log1p()
        return log_mel, torch.tensor(target, dtype=torch.long)

# -------------------- CNN Model --------------------
class InstrumentClassifier(nn.Module):
    """Small two-stage CNN that maps a single-channel 2-D input to class logits.

    Layout: conv(1->16) + ReLU + 2x2 max-pool, conv(16->32) + ReLU + 2x2
    max-pool, global average pool to one value per channel, then a linear
    layer from 32 features to ``num_classes`` logits. Because of the global
    pooling the linear layer does not depend on the input's spatial size.
    """

    def __init__(self, num_classes=5):
        """
        Args:
            num_classes (int): Number of output logits.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)  # shared by both conv stages
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(32, num_classes)

    def forward(self, x):
        """Return unnormalised class scores, one row per batch element."""
        stage1 = self.pool(nn.functional.relu(self.conv1(x)))
        stage2 = self.pool(nn.functional.relu(self.conv2(stage1)))
        pooled = self.global_pool(stage2)
        # Collapse the trailing 1x1 spatial dims, keeping the batch axis.
        features = pooled.reshape(pooled.shape[0], -1)
        return self.fc1(features)



# -------------------- Evaluation Function --------------------
def evaluate(model, dataloader, label="Validation"):
    """Measure and print the classification accuracy of ``model``.

    The model is moved to CUDA when available (CPU otherwise) and switched
    to eval mode; it is left in that state when the function returns.

    Args:
        model: Module producing per-class scores along dimension 1.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        label (str): Prefix used in the printed report line.

    Returns:
        Accuracy as a percentage; ``0.0`` when the loader yields no samples.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)
    model.eval()

    n_correct, n_seen = 0, 0

    # Gradients are not needed for scoring.
    with torch.no_grad():
        for batch_inputs, batch_targets in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)
            predictions = torch.max(model(batch_inputs), dim=1).indices
            n_correct += (predictions == batch_targets).sum().item()
            n_seen += batch_targets.size(0)

    if n_seen > 0:
        accuracy = 100 * n_correct / n_seen
    else:
        accuracy = 0.0
    print(f"{label} Accuracy: {accuracy:.2f}%")
    return accuracy

# -------------------- Updated Main --------------------
if __name__ == "__main__":
    # random_split is used below but is not among the file-level imports;
    # import it here so the script runs on its own instead of raising
    # NameError.
    from torch.utils.data import random_split

    # Replace these with your actual stem paths
    stem_dict = {
        "S00.wav": LABELS["guitar"],
        "S01.wav": LABELS["drums"],
        "S02.wav": LABELS["piano"],
        "S03.wav": LABELS["bass"],
    }

    # Load full dataset (each item = one 2s segment)
    full_dataset = BabySlakhDataset(stem_dict)

    # Split: 70% train, 15% val, 15% test. The test split takes the remainder
    # so the three lengths always sum to the dataset size despite int()
    # truncation.
    total_len = len(full_dataset)
    train_len = int(0.7 * total_len)
    val_len = int(0.15 * total_len)
    test_len = total_len - train_len - val_len

    train_dataset, val_dataset, test_dataset = random_split(full_dataset, [train_len, val_len, test_len])

    # Loaders (only the training data is shuffled)
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

    # Model
    model = InstrumentClassifier(num_classes=len(LABELS))

    # Training with validation check each epoch
    def train_with_val(model, train_loader, val_loader, epochs, lr=0.001):
        """Train ``model`` with Adam and cross-entropy, validating every epoch.

        Args:
            model: Network to optimise; moved to CUDA when available.
            train_loader: Iterable of ``(inputs, labels)`` training batches.
            val_loader: Batches passed to ``evaluate`` after each epoch.
            epochs (int): Number of passes over ``train_loader``.
            lr (float): Adam learning rate.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=lr)

        for epoch in range(epochs):
            # evaluate() leaves the model in eval mode, so switch back to
            # training mode at the start of every epoch.
            model.train()
            total_loss = 0
            correct = 0
            total = 0

            for mel_spec, labels in train_loader:
                mel_spec, labels = mel_spec.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(mel_spec)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # NOTE: this accumulates the per-batch mean losses, so the
                # reported "Train Loss" is a sum over batches, not an average.
                total_loss += loss.item()
                _, predicted = outputs.max(1)
                correct += predicted.eq(labels).sum().item()
                total += labels.size(0)

            # Same empty-loader guard as evaluate(), avoiding ZeroDivisionError.
            train_acc = 100 * correct / total if total > 0 else 0.0
            print(f"Epoch {epoch+1}/{epochs} | Train Loss: {total_loss:.4f} | Train Accuracy: {train_acc:.2f}%")
            evaluate(model, val_loader, label="Validation")

    # Train the model with val tracking
    train_with_val(model, train_loader, val_loader, epochs=20)

    # Final test accuracy
    print("\n--- Final Test Evaluation ---")
    evaluate(model, test_loader, label="Test")




Epoch 1/20 | Train Loss: 65.8460 | Train Accuracy: 23.21%
Validation Accuracy: 26.39%
Epoch 2/20 | Train Loss: 61.1590 | Train Accuracy: 28.57%
Validation Accuracy: 51.39%
Epoch 3/20 | Train Loss: 54.4405 | Train Accuracy: 54.46%
Validation Accuracy: 69.44%
Epoch 4/20 | Train Loss: 45.6826 | Train Accuracy: 63.99%
Validation Accuracy: 83.33%
Epoch 5/20 | Train Loss: 38.1972 | Train Accuracy: 62.20%
Validation Accuracy: 76.39%
Epoch 6/20 | Train Loss: 32.2331 | Train Accuracy: 78.57%
Validation Accuracy: 93.06%
Epoch 7/20 | Train Loss: 27.5749 | Train Accuracy: 80.95%
Validation Accuracy: 94.44%
Epoch 8/20 | Train Loss: 24.1439 | Train Accuracy: 83.33%
Validation Accuracy: 94.44%
Epoch 9/20 | Train Loss: 21.5486 | Train Accuracy: 86.61%
Validation Accuracy: 93.06%
Epoch 10/20 | Train Loss: 19.5062 | Train Accuracy: 87.20%
Validation Accuracy: 91.67%
Epoch 11/20 | Train Loss: 18.6431 | Train Accuracy: 87.20%
Validation Accuracy: 93.06%
Epoch 12/20 | Train Loss: 17.0206 | Train Accuracy: 