In [1]:
AUDIO_DIR = "/Users/zainhazzouri/projects/Bachelor_Thesis/Data1/Kaggle/"
SAMPLE_RATE = 22050 # sample rate of the audio file
bit_depth = 16 # bit depth of the audio file
hop_length = 512
n_mfcc = 20 # number of MFCCs features
n_fft=1024, # window size
n_mels = 256 # number of mel bands to generate
win_length = None # window length



In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import glob
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


class SpeechMusicDataset(Dataset):
    def __init__(self, music_waves, speech_waves, transform=None):
        self.music_waves = music_waves
        self.speech_waves = speech_waves
        self.transform = transform
        self.file_list = self.music_waves + self.speech_waves

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        waveform, _ = torchaudio.load(file_path)
        label = 0 if 'music_wav' in file_path else 1

        if self.transform:
            waveform = self.transform(waveform)

        return waveform, label

# Preprocessing function for the dataset
def preprocess(waveform, target_length=8000, sample_rate=SAMPLE_RATE, n_mfcc=n_mfcc):
    waveform_length = waveform.size(1)

    if waveform_length < target_length:
        num_padding = target_length - waveform_length
        padding = torch.zeros(1, num_padding)
        waveform = torch.cat((waveform, padding), 1)
    elif waveform_length > target_length:
        waveform = waveform[:, :target_length]

    mfcc = torchaudio.transforms.MFCC(sample_rate=sample_rate, n_mfcc=n_mfcc)(waveform)
    return mfcc

# Set device
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_built():  # if you have apple silicon mac
    device = "mps"  # if it doesn't work try device = torch.device('mps')
else:
    device = "cpu"
print(f"Using {device}")

# Set the path to the folder containing the music and speech datasets
AUDIO_DIR = "/Users/zainhazzouri/projects/Bachelor_Thesis/Data1/Kaggle/"

# Load the dataset
music_waves = glob.glob(AUDIO_DIR + "music_wav" + "/*.wav")
speech_waves = glob.glob(AUDIO_DIR + "speech_wav" + "/*.wav")
transform = preprocess

dataset = SpeechMusicDataset(music_waves, speech_waves, transform=transform)






def pad_waveform(waveform, desired_length):
    if waveform.shape[-1] < desired_length:
        padding = desired_length - waveform.shape[-1]
        waveform = torch.nn.functional.pad(waveform, (0, padding))
    return waveform






# Load the dataset
music_waves = glob.glob(AUDIO_DIR + "music_wav" + "/*.wav")
speech_waves = glob.glob(AUDIO_DIR + "speech_wav" + "/*.wav")
transform = preprocess

dataset = SpeechMusicDataset(music_waves, speech_waves, transform=transform)

# Split the dataset into training and validation sets
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])



# Define Wave-U-Net architecture
class WaveUNet(nn.Module):
    def __init__(self, in_channels=1, num_classes=10, num_features=40):
        super(WaveUNet, self).__init__()

        # Encoding layers
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, num_features, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2)),
            nn.ReLU(),
            nn.Conv2d(num_features, num_features * 2, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2)),
            nn.ReLU(),
            nn.Conv2d(num_features * 2, num_features * 4, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2)),
            nn.ReLU(),
        )

        # Decoding layers
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(num_features * 4, num_features * 2, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1)),
            nn.ReLU(),
            nn.ConvTranspose2d(num_features * 2, num_features, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1)),
            nn.ReLU(),
            nn.ConvTranspose2d(num_features, num_features, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1)),
            nn.ReLU(),
        )

        # Global average pooling and fully connected layer
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = self.fc(x)
        return x


# Training parameters
batch_size = 16
num_epochs = 50
learning_rate = 0.001

# Create DataLoader
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Initialize model, loss, and optimizer
model = WaveUNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    print(f"Epoch: {epoch+1}/{num_epochs}")

    model.train()
    running_loss = 0.0

    # Add tqdm progress bar
    for i, (inputs, targets) in enumerate(tqdm(train_loader, desc="Training", ncols=100)):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)

        # Calculate loss
        loss = criterion(outputs, targets)

        # Backward pass
        loss.backward()

        # Optimize
        optimizer.step()

        # Update loss
        running_loss += loss.item()

    # Calculate average loss for the epoch
    epoch_loss = running_loss / len(train_loader)
    print(f"Loss: {epoch_loss:.4f}")

print("Training finished.")

# Save the trained model
torch.save(model.state_dict(), "waveunet_speech_music_discrimination.pth")




Using mps
Epoch: 1/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:05<00:00,  1.27it/s]


Loss: 1.5180
Epoch: 2/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.27it/s]


Loss: 0.8268
Epoch: 3/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.53it/s]


Loss: 0.6912
Epoch: 4/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 16.23it/s]


Loss: 0.9955
Epoch: 5/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.64it/s]


Loss: 0.8157
Epoch: 6/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 16.23it/s]


Loss: 0.7512
Epoch: 7/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.79it/s]


Loss: 0.7527
Epoch: 8/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.94it/s]


Loss: 0.7241
Epoch: 9/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 16.04it/s]


Loss: 0.7374
Epoch: 10/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.13it/s]


Loss: 0.6971
Epoch: 11/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.99it/s]


Loss: 0.6992
Epoch: 12/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 16.09it/s]


Loss: 0.7014
Epoch: 13/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.71it/s]


Loss: 0.7047
Epoch: 14/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.22it/s]


Loss: 0.6978
Epoch: 15/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.54it/s]


Loss: 0.6927
Epoch: 16/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.05it/s]


Loss: 0.6947
Epoch: 17/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 14.85it/s]


Loss: 0.6886
Epoch: 18/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.96it/s]


Loss: 0.6924
Epoch: 19/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.57it/s]


Loss: 0.6950
Epoch: 20/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.48it/s]


Loss: 0.6970
Epoch: 21/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.93it/s]


Loss: 0.6807
Epoch: 22/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.75it/s]


Loss: 0.6813
Epoch: 23/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 16.12it/s]


Loss: 0.6720
Epoch: 24/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 16.35it/s]


Loss: 0.6693
Epoch: 25/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.91it/s]


Loss: 0.6554
Epoch: 26/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.89it/s]


Loss: 0.6273
Epoch: 27/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.70it/s]


Loss: 0.6240
Epoch: 28/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.99it/s]


Loss: 0.5883
Epoch: 29/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.82it/s]


Loss: 0.4485
Epoch: 30/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.49it/s]


Loss: 0.3740
Epoch: 31/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 16.05it/s]


Loss: 0.3752
Epoch: 32/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 16.12it/s]


Loss: 0.3651
Epoch: 33/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 16.32it/s]


Loss: 0.2847
Epoch: 34/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.65it/s]


Loss: 0.3351
Epoch: 35/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.30it/s]


Loss: 0.2395
Epoch: 36/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 16.56it/s]


Loss: 0.1883
Epoch: 37/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.73it/s]


Loss: 0.2177
Epoch: 38/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.97it/s]


Loss: 0.1616
Epoch: 39/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 16.05it/s]


Loss: 0.1732
Epoch: 40/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 16.01it/s]


Loss: 0.0870
Epoch: 41/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 16.28it/s]


Loss: 0.1419
Epoch: 42/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.88it/s]


Loss: 0.0665
Epoch: 43/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.82it/s]


Loss: 0.0453
Epoch: 44/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.71it/s]


Loss: 0.0555
Epoch: 45/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.49it/s]


Loss: 0.1068
Epoch: 46/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.72it/s]


Loss: 0.1856
Epoch: 47/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 16.01it/s]


Loss: 0.0819
Epoch: 48/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 16.22it/s]


Loss: 0.0411
Epoch: 49/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.36it/s]


Loss: 0.0153
Epoch: 50/50


Training: 100%|███████████████████████████████████████████████████████| 7/7 [00:00<00:00, 15.64it/s]

Loss: 0.0059
Training finished.





In [8]:
# Show model summary
try:
    from torchsummary import summary
    print("\nModel summary:")

    original_device = device
    if device == 'mps': # because MPS is not supported by torchsummary
        device = 'cpu'
    model.to(device)

    summary(model, input_size=(1, 40, 431), device=device)

    if original_device == 'mps': # Restore original device
        device = original_device
        model.to(device)

except ImportError:
    print("\nPlease install torchsummary to display the model summary. Use `pip install torchsummary`.")


Model summary:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 40, 20, 216]           1,040
              ReLU-2          [-1, 40, 20, 216]               0
            Conv2d-3          [-1, 80, 10, 108]          80,080
              ReLU-4          [-1, 80, 10, 108]               0
            Conv2d-5           [-1, 160, 5, 54]         320,160
              ReLU-6           [-1, 160, 5, 54]               0
   ConvTranspose2d-7          [-1, 80, 10, 108]         320,080
              ReLU-8          [-1, 80, 10, 108]               0
   ConvTranspose2d-9          [-1, 40, 20, 216]          80,040
             ReLU-10          [-1, 40, 20, 216]               0
  ConvTranspose2d-11          [-1, 40, 40, 432]          40,040
             ReLU-12          [-1, 40, 40, 432]               0
AdaptiveAvgPool2d-13             [-1, 40, 1, 1]               0
           Linear-14   