In [1]:
AUDIO_DIR = "/Users/zainhazzouri/projects/Bachelor_Thesis/Data/Kaggle"
SAMPLE_RATE = 22050 # sample rate of the audio file
bit_depth = 16 # bit depth of the audio file
hop_length = 512
n_mfcc = 20 # number of MFCCs features
n_fft=1024, # window size
n_mels = 256 # number of mel bands to generate
win_length = None # window length


# Training parameters
batch_size = 16
learning_rate = 0.001
num_epochs = 50




In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import glob
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import nbimporter
from CNN_Model import WaveUNet

class AudioDataset(Dataset):
    def __init__(self, music_waves, speech_waves, mixed_waves, silence_waves, transform=None):
        self.music_waves = music_waves
        self.speech_waves = speech_waves
        self.mixed_waves = mixed_waves
        self.silence_waves = silence_waves
        self.transform = transform
        self.file_list = self.music_waves + self.speech_waves + self.mixed_waves + self.silence_waves

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        waveform, _ = torchaudio.load(file_path)
        if 'music_wav' in file_path:
            label = 0
        elif 'speech_wav' in file_path:
            label = 1
        elif 'mixed_wav' in file_path:
            label = 2
        else:
            label = 3

        if self.transform:
            waveform = self.transform(waveform)

        return waveform, label



# Preprocessing function for the dataset
def preprocess(waveform, target_length=8000, sample_rate=SAMPLE_RATE, n_mfcc=n_mfcc):
    waveform_length = waveform.size(1)

    if waveform_length < target_length:
        num_padding = target_length - waveform_length
        padding = torch.zeros(1, num_padding)
        waveform = torch.cat((waveform, padding), 1)
    elif waveform_length > target_length:
        waveform = waveform[:, :target_length]

    mfcc = torchaudio.transforms.MFCC(sample_rate=sample_rate, n_mfcc=n_mfcc)(waveform)
    return mfcc

# Set device
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_built():  # if you have apple silicon mac
    device = "mps"  # if it doesn't work try device = torch.device('mps')
else:
    device = "cpu"
print(f"Using {device}")

# Set the path to the folder containing the music and speech datasets
AUDIO_DIR = "/Users/zainhazzouri/projects/Bachelor_Thesis/Data/Kaggle/"

# Load the dataset
music_waves = glob.glob(AUDIO_DIR + "music_wav" + "/*.wav")
speech_waves = glob.glob(AUDIO_DIR + "speech_wav" + "/*.wav")
mixed_waves = glob.glob(AUDIO_DIR + "Mix_wav" + "/*.wav")
silence_waves = glob.glob(AUDIO_DIR + "silence_wav" + "/*.wav")


transform = preprocess

# Load the dataset
dataset = AudioDataset(music_waves, speech_waves, mixed_waves, silence_waves, transform=preprocess)





def pad_waveform(waveform, desired_length):
    if waveform.shape[-1] < desired_length:
        padding = desired_length - waveform.shape[-1]
        waveform = torch.nn.functional.pad(waveform, (0, padding))
    return waveform



# Split the dataset into training and validation sets
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# Create DataLoader
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Initialize model, loss, and optimizer
model = WaveUNet(num_classes=4).to(device)  # Update num_classes to 4
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

Using mps


In [3]:
# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for i, (inputs, labels) in enumerate(tqdm(train_loader)):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / (i + 1)
    epoch_acc = 100 * correct / total
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")

    # Validation
    model.eval()
    val_running_loss = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(val_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            val_running_loss += loss.item()

            _, predicted = torch.max(outputs.data, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

        val_epoch_loss = val_running_loss / (i + 1)
        val_epoch_acc = 100 * val_correct / val_total
        print(f"Validation Loss: {val_epoch_loss:.4f}, Validation Accuracy: {val_epoch_acc:.2f}%")


100%|██████████| 7/7 [00:01<00:00,  4.82it/s]


Epoch 1/50, Loss: 0.9838, Accuracy: 51.46%
Validation Loss: 1.2126, Validation Accuracy: 38.46%


100%|██████████| 7/7 [00:00<00:00, 15.92it/s]


Epoch 2/50, Loss: 0.8674, Accuracy: 52.43%
Validation Loss: 0.8766, Validation Accuracy: 38.46%


100%|██████████| 7/7 [00:00<00:00, 15.88it/s]


Epoch 3/50, Loss: 0.8996, Accuracy: 49.51%
Validation Loss: 0.7724, Validation Accuracy: 38.46%


100%|██████████| 7/7 [00:00<00:00, 15.59it/s]


Epoch 4/50, Loss: 0.8294, Accuracy: 52.43%
Validation Loss: 0.8772, Validation Accuracy: 38.46%


100%|██████████| 7/7 [00:00<00:00, 14.62it/s]


Epoch 5/50, Loss: 0.7686, Accuracy: 52.43%
Validation Loss: 0.7798, Validation Accuracy: 38.46%


100%|██████████| 7/7 [00:00<00:00, 14.89it/s]


Epoch 6/50, Loss: 0.7228, Accuracy: 44.66%
Validation Loss: 0.6863, Validation Accuracy: 61.54%


100%|██████████| 7/7 [00:00<00:00, 14.80it/s]


Epoch 7/50, Loss: 0.6979, Accuracy: 54.37%
Validation Loss: 0.7783, Validation Accuracy: 38.46%


100%|██████████| 7/7 [00:00<00:00, 14.80it/s]


Epoch 8/50, Loss: 0.7096, Accuracy: 52.43%
Validation Loss: 0.6910, Validation Accuracy: 61.54%


100%|██████████| 7/7 [00:00<00:00, 14.89it/s]


Epoch 9/50, Loss: 0.7268, Accuracy: 43.69%
Validation Loss: 0.7195, Validation Accuracy: 38.46%


100%|██████████| 7/7 [00:00<00:00, 14.77it/s]


Epoch 10/50, Loss: 0.6972, Accuracy: 52.43%
Validation Loss: 0.7628, Validation Accuracy: 38.46%


100%|██████████| 7/7 [00:00<00:00, 14.88it/s]


Epoch 11/50, Loss: 0.7066, Accuracy: 52.43%
Validation Loss: 0.7105, Validation Accuracy: 38.46%


100%|██████████| 7/7 [00:00<00:00, 15.05it/s]


Epoch 12/50, Loss: 0.7156, Accuracy: 42.72%
Validation Loss: 0.7075, Validation Accuracy: 38.46%


100%|██████████| 7/7 [00:00<00:00, 15.07it/s]


Epoch 13/50, Loss: 0.6986, Accuracy: 52.43%
Validation Loss: 0.7175, Validation Accuracy: 38.46%


100%|██████████| 7/7 [00:00<00:00, 15.40it/s]


Epoch 14/50, Loss: 0.7072, Accuracy: 52.43%
Validation Loss: 0.7016, Validation Accuracy: 38.46%


100%|██████████| 7/7 [00:00<00:00, 14.89it/s]


Epoch 15/50, Loss: 0.7111, Accuracy: 52.43%
Validation Loss: 0.7102, Validation Accuracy: 38.46%


100%|██████████| 7/7 [00:00<00:00, 15.01it/s]


Epoch 16/50, Loss: 0.6916, Accuracy: 52.43%
Validation Loss: 0.7358, Validation Accuracy: 38.46%


100%|██████████| 7/7 [00:00<00:00, 15.00it/s]


Epoch 17/50, Loss: 0.6674, Accuracy: 52.43%
Validation Loss: 0.7005, Validation Accuracy: 38.46%


100%|██████████| 7/7 [00:00<00:00, 15.02it/s]


Epoch 18/50, Loss: 0.6377, Accuracy: 52.43%
Validation Loss: 0.6535, Validation Accuracy: 38.46%


100%|██████████| 7/7 [00:00<00:00, 15.36it/s]


Epoch 19/50, Loss: 0.6002, Accuracy: 55.34%
Validation Loss: 0.5676, Validation Accuracy: 88.46%


100%|██████████| 7/7 [00:00<00:00, 14.49it/s]


Epoch 20/50, Loss: 0.6052, Accuracy: 73.79%
Validation Loss: 0.4730, Validation Accuracy: 84.62%


100%|██████████| 7/7 [00:00<00:00, 14.79it/s]


Epoch 21/50, Loss: 0.4404, Accuracy: 85.44%
Validation Loss: 2.0857, Validation Accuracy: 50.00%


100%|██████████| 7/7 [00:00<00:00, 15.31it/s]


Epoch 22/50, Loss: 0.4012, Accuracy: 89.32%
Validation Loss: 0.4574, Validation Accuracy: 80.77%


100%|██████████| 7/7 [00:00<00:00, 14.68it/s]


Epoch 23/50, Loss: 0.3084, Accuracy: 88.35%
Validation Loss: 0.5083, Validation Accuracy: 73.08%


100%|██████████| 7/7 [00:00<00:00, 14.68it/s]


Epoch 24/50, Loss: 0.3520, Accuracy: 86.41%
Validation Loss: 0.6250, Validation Accuracy: 80.77%


100%|██████████| 7/7 [00:00<00:00, 15.05it/s]


Epoch 25/50, Loss: 0.3412, Accuracy: 87.38%
Validation Loss: 0.7697, Validation Accuracy: 76.92%


100%|██████████| 7/7 [00:00<00:00, 15.23it/s]


Epoch 26/50, Loss: 0.2193, Accuracy: 91.26%
Validation Loss: 0.5069, Validation Accuracy: 88.46%


100%|██████████| 7/7 [00:00<00:00, 14.92it/s]


Epoch 27/50, Loss: 0.1367, Accuracy: 92.23%
Validation Loss: 1.5187, Validation Accuracy: 73.08%


100%|██████████| 7/7 [00:00<00:00, 14.70it/s]


Epoch 28/50, Loss: 0.1571, Accuracy: 94.17%
Validation Loss: 0.6382, Validation Accuracy: 84.62%


100%|██████████| 7/7 [00:00<00:00, 14.69it/s]


Epoch 29/50, Loss: 0.1157, Accuracy: 94.17%
Validation Loss: 0.7040, Validation Accuracy: 84.62%


100%|██████████| 7/7 [00:00<00:00, 15.01it/s]


Epoch 30/50, Loss: 0.1129, Accuracy: 94.17%
Validation Loss: 2.5732, Validation Accuracy: 65.38%


100%|██████████| 7/7 [00:00<00:00, 14.98it/s]


Epoch 31/50, Loss: 0.0752, Accuracy: 97.09%
Validation Loss: 0.6982, Validation Accuracy: 84.62%


100%|██████████| 7/7 [00:00<00:00, 14.73it/s]


Epoch 32/50, Loss: 0.0679, Accuracy: 97.09%
Validation Loss: 1.6551, Validation Accuracy: 73.08%


100%|██████████| 7/7 [00:00<00:00, 15.10it/s]


Epoch 33/50, Loss: 0.0371, Accuracy: 98.06%
Validation Loss: 0.9659, Validation Accuracy: 80.77%


100%|██████████| 7/7 [00:00<00:00, 15.13it/s]


Epoch 34/50, Loss: 0.0304, Accuracy: 98.06%
Validation Loss: 1.7005, Validation Accuracy: 76.92%


100%|██████████| 7/7 [00:00<00:00, 14.64it/s]


Epoch 35/50, Loss: 0.0175, Accuracy: 99.03%
Validation Loss: 2.0333, Validation Accuracy: 76.92%


100%|██████████| 7/7 [00:00<00:00, 14.62it/s]


Epoch 36/50, Loss: 0.0130, Accuracy: 99.03%
Validation Loss: 1.8766, Validation Accuracy: 84.62%


100%|██████████| 7/7 [00:00<00:00, 14.79it/s]


Epoch 37/50, Loss: 0.0092, Accuracy: 100.00%
Validation Loss: 1.9472, Validation Accuracy: 84.62%


100%|██████████| 7/7 [00:00<00:00, 14.63it/s]


Epoch 38/50, Loss: 0.0023, Accuracy: 100.00%
Validation Loss: 2.0702, Validation Accuracy: 84.62%


100%|██████████| 7/7 [00:00<00:00, 14.20it/s]


Epoch 39/50, Loss: 0.0006, Accuracy: 100.00%
Validation Loss: 2.3868, Validation Accuracy: 84.62%


100%|██████████| 7/7 [00:00<00:00, 14.41it/s]


Epoch 40/50, Loss: 0.0006, Accuracy: 100.00%
Validation Loss: 2.6545, Validation Accuracy: 84.62%


100%|██████████| 7/7 [00:00<00:00, 15.24it/s]


Epoch 41/50, Loss: 0.0001, Accuracy: 100.00%
Validation Loss: 2.7549, Validation Accuracy: 84.62%


100%|██████████| 7/7 [00:00<00:00, 14.54it/s]


Epoch 42/50, Loss: 0.0000, Accuracy: 100.00%
Validation Loss: 2.9355, Validation Accuracy: 80.77%


100%|██████████| 7/7 [00:00<00:00, 14.96it/s]


Epoch 43/50, Loss: 0.0000, Accuracy: 100.00%
Validation Loss: 3.0686, Validation Accuracy: 80.77%


100%|██████████| 7/7 [00:00<00:00, 14.94it/s]


Epoch 44/50, Loss: 0.0000, Accuracy: 100.00%
Validation Loss: 3.1738, Validation Accuracy: 80.77%


100%|██████████| 7/7 [00:00<00:00, 14.83it/s]


Epoch 45/50, Loss: 0.0000, Accuracy: 100.00%
Validation Loss: 3.2605, Validation Accuracy: 80.77%


100%|██████████| 7/7 [00:00<00:00, 14.64it/s]


Epoch 46/50, Loss: 0.0000, Accuracy: 100.00%
Validation Loss: 3.3559, Validation Accuracy: 80.77%


100%|██████████| 7/7 [00:00<00:00, 15.08it/s]


Epoch 47/50, Loss: 0.0000, Accuracy: 100.00%
Validation Loss: 3.4527, Validation Accuracy: 80.77%


100%|██████████| 7/7 [00:00<00:00, 15.22it/s]


Epoch 48/50, Loss: 0.0000, Accuracy: 100.00%
Validation Loss: 3.5391, Validation Accuracy: 80.77%


100%|██████████| 7/7 [00:00<00:00, 14.45it/s]


Epoch 49/50, Loss: 0.0000, Accuracy: 100.00%
Validation Loss: 3.6185, Validation Accuracy: 80.77%


100%|██████████| 7/7 [00:00<00:00, 14.86it/s]


Epoch 50/50, Loss: 0.0000, Accuracy: 100.00%
Validation Loss: 3.7274, Validation Accuracy: 80.77%


In [4]:
# Evaluation function

#1. Average validation loss: This metric is calculated using the same loss function (`criterion`) used during training, which is CrossEntropyLoss in this case. The average validation loss is computed by summing the losses for all validation samples and then dividing by the number of validation samples. A lower average validation loss indicates better performance.
#
# 2. Validation accuracy: This metric measures the percentage of correctly classified samples in the validation set. The accuracy is calculated by counting the number of correct predictions, i.e., when the predicted label matches the true label, and then dividing by the total number of validation samples. A higher validation accuracy indicates better performance.
#
# These two metrics together provide a good evaluation of the model's performance on the validation set. The average validation loss helps assess the model's ability to minimize the loss function, while the validation accuracy measures how well the model is classifying the samples.


def evaluate(val_loader, model, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Forward pass
            outputs = model(inputs)

            # Calculate loss
            loss = criterion(outputs, targets)

            # Update loss
            running_loss += loss.item()

            # Calculate accuracy
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    # Calculate average loss and accuracy
    avg_loss = running_loss / len(val_loader)
    accuracy = 100 * correct / total

    return avg_loss, accuracy

# Evaluate the model
val_loss, val_accuracy = evaluate(val_loader, model, criterion, device)
print(f"Validation Loss: {val_loss:.4f}")
print(f"Validation Accuracy: {val_accuracy:.2f}%")

Validation Loss: 3.7274
Validation Accuracy: 80.77%


In [5]:

# Save the trained model
torch.save(model.state_dict(), "waveunet_speech_music_discrimination.pth")
print("Model saved.")




Model saved.


In [6]:
# Show model summary
try:
    from torchsummary import summary
    print("\nModel summary:")

    original_device = device
    if device == 'mps': # because MPS is not supported by torchsummary
        device = 'cpu'
    model.to(device)

    summary(model, input_size=(1, 40, 431), device=device)

    if original_device == 'mps': # Restore original device
        device = original_device
        model.to(device)

except ImportError:
    print("\nPlease install torchsummary to display the model summary. Use `pip install torchsummary`.")


Model summary:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 40, 20, 216]           1,040
              ReLU-2          [-1, 40, 20, 216]               0
            Conv2d-3          [-1, 80, 10, 108]          80,080
              ReLU-4          [-1, 80, 10, 108]               0
            Conv2d-5           [-1, 160, 5, 54]         320,160
              ReLU-6           [-1, 160, 5, 54]               0
   ConvTranspose2d-7          [-1, 80, 10, 108]         320,080
              ReLU-8          [-1, 80, 10, 108]               0
   ConvTranspose2d-9          [-1, 40, 20, 216]          80,040
             ReLU-10          [-1, 40, 20, 216]               0
  ConvTranspose2d-11          [-1, 40, 40, 432]          40,040
             ReLU-12          [-1, 40, 40, 432]               0
AdaptiveAvgPool2d-13             [-1, 40, 1, 1]               0
           Linear-14   