In [55]:
import librosa
import numpy as np
from pathlib import Path
import IPython.display as ipd
import torch
from torch.utils.data import Dataset, DataLoader

def preprocess_audio(file_path, target_sample_rate=16384, target_length=16384):
    """Load an audio file and force it to a fixed number of samples.

    The file is read with ``librosa.load`` at ``target_sample_rate``. Clips
    longer than ``target_length`` samples are cut at the end; shorter clips
    are right-padded with zeros; clips of exactly that length are returned
    unchanged.

    Args:
        file_path: Path of the audio file to read.
        target_sample_rate: Sample rate handed to ``librosa.load``.
        target_length: Number of samples in the returned array.

    Returns:
        The loaded samples, truncated or zero-padded to ``target_length``.
    """
    waveform, _ = librosa.load(file_path, sr=target_sample_rate)

    # Positive -> clip is too short, negative -> clip is too long.
    missing = target_length - len(waveform)
    if missing > 0:
        return np.pad(waveform, (0, missing), 'constant')
    if missing < 0:
        return waveform[:target_length]
    return waveform

In [56]:
class DrumDataset(Dataset):
    """Dataset that serves each audio file as a normalized log-mel spectrogram.

    Args:
        file_paths: Indexable collection of audio file paths; one item per path.
    """

    def __init__(self, file_paths):
        self.file_paths = file_paths

    def __len__(self):
        """Return the number of audio files in the dataset."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load file ``idx`` and return its spectrogram scaled to [0, 1].

        Returns:
            A ``float32`` tensor with a leading channel dimension of size 1,
            followed by the mel-spectrogram's own two dimensions.
        """
        # sr=None keeps the file's native sample rate.
        signal, sr = librosa.load(self.file_paths[idx], sr=None)
        mel_spectrogram = librosa.feature.melspectrogram(y=signal, sr=sr, n_fft=2048, hop_length=512, n_mels=128)
        mel_spectrogram = librosa.power_to_db(mel_spectrogram, ref=np.max)

        # Min-max normalize. A constant spectrogram (e.g. a silent clip) has
        # zero range; dividing by it would fill the tensor with NaN, so map
        # that case to all zeros instead.
        lowest = mel_spectrogram.min()
        value_range = mel_spectrogram.max() - lowest
        if value_range > 0:
            mel_spectrogram = (mel_spectrogram - lowest) / value_range
        else:
            mel_spectrogram = np.zeros_like(mel_spectrogram)
        mel_spectrogram = torch.tensor(mel_spectrogram, dtype=torch.float32)

        # Add channel dimension for CNN
        mel_spectrogram = mel_spectrogram.unsqueeze(0)

        return mel_spectrogram

In [49]:
import torch.nn as nn
class WaveGANGenerator(nn.Module):
    """Transposed-convolution generator mapping a latent vector to a waveform.

    ``layer1`` expands the length-1 latent input to ``output_length // 1024``
    time steps; ``layer2``..``layer6`` each upsample the time axis by exactly
    4 (kernel 25, stride 4, padding 11, output_padding 1), so the final
    tensor has ``output_length`` samples. ReLU follows every hidden layer and
    ``tanh`` bounds the output to [-1, 1].

    Args:
        latent_dim: Size of the latent vector.
        output_length: Number of generated samples; must be a positive
            multiple of 1024 (= 4 ** 5).
        ngf: Base channel width; hidden layers use 16x, 8x, 4x, 2x, 1x ``ngf``.

    Raises:
        ValueError: If ``output_length`` is not a positive multiple of 1024.
    """

    def __init__(self, latent_dim=100, output_length=16384, ngf=64):
        super().__init__()

        # Five stride-4 layers upsample by 4 ** 5 in total.
        upsample_factor = 4 ** 5
        if output_length < upsample_factor or output_length % upsample_factor != 0:
            raise ValueError(
                f"output_length must be a positive multiple of {upsample_factor}, "
                f"got {output_length}"
            )

        self.latent_dim = latent_dim
        self.output_length = output_length
        self.ngf = ngf

        # On a length-1 input, a stride-1 transposed conv with kernel K and no
        # padding yields exactly K time steps (equivalent to a dense layer).
        initial_length = output_length // upsample_factor
        self.layer1 = nn.ConvTranspose1d(latent_dim, ngf * 16, kernel_size=initial_length, stride=1, padding=0)

        # output_padding=1 makes each layer map length L to exactly 4 * L;
        # without it the result is 4 * L - 1 and the lengths drift.
        upsample = dict(kernel_size=25, stride=4, padding=11, output_padding=1)
        self.layer2 = nn.ConvTranspose1d(ngf * 16, ngf * 8, **upsample)
        self.layer3 = nn.ConvTranspose1d(ngf * 8, ngf * 4, **upsample)
        self.layer4 = nn.ConvTranspose1d(ngf * 4, ngf * 2, **upsample)
        self.layer5 = nn.ConvTranspose1d(ngf * 2, ngf, **upsample)
        self.layer6 = nn.ConvTranspose1d(ngf, 1, **upsample)

        # Without a non-linearity between them, the stacked transposed
        # convolutions collapse into a single affine map.
        self.hidden_activation = nn.ReLU()
        self.activation = nn.Tanh()

    def forward(self, z):
        """Generate audio from latent vectors.

        Args:
            z: Tensor holding ``latent_dim`` values per batch element.

        Returns:
            Tensor of shape ``(batch, 1, output_length)`` with values in [-1, 1].
        """
        z = z.view(z.size(0), self.latent_dim, 1)
        for layer in (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5):
            z = self.hidden_activation(layer(z))
        return self.activation(self.layer6(z))

In [60]:
import torch.nn as nn

class WaveGANDiscriminator(nn.Module):
    """1-D convolutional discriminator that scores a waveform as real or fake.

    Six stride-4 convolutions (kernel 25, padding 11) reduce the time axis;
    LeakyReLU follows each hidden layer. The final layer's per-position
    logits are averaged over time and passed through a single sigmoid, so the
    output is a probability suitable for ``nn.BCELoss``.

    Args:
        ndf: Base channel width; hidden layers use 1x, 2x, 4x, 8x, 16x ``ndf``.
    """

    def __init__(self, ndf=64):
        super(WaveGANDiscriminator, self).__init__()
        self.ndf = ndf

        self.layer1 = nn.Conv1d(1, ndf, 25, stride=4, padding=11)
        self.layer2 = nn.Conv1d(ndf, ndf * 2, 25, stride=4, padding=11)
        self.layer3 = nn.Conv1d(ndf * 2, ndf * 4, 25, stride=4, padding=11)
        self.layer4 = nn.Conv1d(ndf * 4, ndf * 8, 25, stride=4, padding=11)
        self.layer5 = nn.Conv1d(ndf * 8, ndf * 16, 25, stride=4, padding=11)
        self.layer6 = nn.Conv1d(ndf * 16, 1, 25, stride=4, padding=11)

        # Without a non-linearity the stacked convolutions are one linear filter.
        self.hidden_activation = nn.LeakyReLU(0.2)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        """Score a batch of waveforms.

        Args:
            x: Tensor of shape ``(batch, 1, length)``.

        Returns:
            Tensor of shape ``(batch,)`` with values in (0, 1).
        """
        for layer in (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5):
            x = self.hidden_activation(layer(x))

        logits = self.layer6(x)
        # Pool the raw logits over time so inputs of any length give one score.
        logits = torch.mean(logits, dim=-1, keepdim=True)

        # Apply the sigmoid exactly once. Squashing twice confines the output
        # to roughly (0.5, 0.73), which makes the "fake" target 0 unreachable.
        probabilities = self.activation(logits)
        return probabilities.view(probabilities.size(0))

In [61]:
# Place both networks on the device the training batches are sent to, and do
# it before the optimizers are constructed from their parameters.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

generator = WaveGANGenerator().to(device)
discriminator = WaveGANDiscriminator().to(device)

In [66]:
from torch import optim

# Binary cross-entropy between the discriminator's probability and the 0/1 target.
criterion = nn.BCELoss()

# Single source of truth for the Adam hyper-parameters. `lr` was previously
# declared as 0.0002 but never used, while both optimizers hard-coded 0.0005;
# the value that was actually in effect is kept.
lr = 0.0005
beta1 = 0.5

optimizerG = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerD = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

In [67]:
from torch.utils.data import DataLoader

# Sort so the sample order does not depend on filesystem enumeration order.
kick_paths = sorted(Path().glob('../data/Kicks/*.wav'))
if not kick_paths:
    # An empty glob would otherwise give an empty loader and a training loop
    # that silently does nothing.
    raise FileNotFoundError("No .wav files found matching ../data/Kicks/*.wav")

samples = [preprocess_audio(p) for p in kick_paths]

batch_size = 8
# Reshuffle every epoch so batches are not always made of the same clips.
dataloader = DataLoader(samples, batch_size=batch_size, shuffle=True)

# Targets used with the BCE loss: 1 for real clips, 0 for generated ones.
real_label = 1
fake_label = 0

In [68]:
# for i, real_audio in enumerate(dataloader):
#     print(real_audio.shape)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The networks must sit on the same device as the batches fed to them.
# nn.Module.to() works in place and does nothing if they are already there.
generator.to(device)
discriminator.to(device)

num_epochs = 5
latent_dim = 100

for epoch in range(num_epochs):
    for i, real_audio in enumerate(dataloader):
        # ---------------------
        # (1) Train Discriminator
        # ---------------------
        discriminator.zero_grad()

        # Real audio: add the channel dimension -> (batch, 1, length)
        real_audio = real_audio.to(device).unsqueeze(1)

        # The last batch can be smaller than the configured size. Use a local
        # name so the notebook-level `batch_size` is not overwritten.
        current_batch = real_audio.size(0)
        real_targets = torch.full((current_batch,), float(real_label), device=device)
        fake_targets = torch.full((current_batch,), float(fake_label), device=device)

        # Real clips should be scored as real.
        output = discriminator(real_audio).view(-1)
        lossD_real = criterion(output, real_targets)
        lossD_real.backward()

        # Generate fake audio and ensure it's (batch, 1, length)
        noise = torch.randn(current_batch, latent_dim, device=device)
        fake_audio = generator(noise)
        if fake_audio.dim() == 2:
            fake_audio = fake_audio.unsqueeze(1)

        # Detach so this backward pass does not reach the generator.
        output = discriminator(fake_audio.detach()).view(-1)
        lossD_fake = criterion(output, fake_targets)
        lossD_fake.backward()

        # Update the discriminator
        optimizerD.step()

        # ---------------------
        # (2) Train Generator
        # ---------------------
        generator.zero_grad()
        # Generator wants the discriminator to classify fakes as real.
        output = discriminator(fake_audio).view(-1)
        lossG = criterion(output, real_targets)
        lossG.backward()

        # Update the generator
        optimizerG.step()

        # Print loss stats every few iterations
        if i % 100 == 0:
            print(f"Epoch [{epoch}/{num_epochs}] Batch [{i}/{len(dataloader)}] "
                  f"Loss_D: {(lossD_real + lossD_fake).item():.4f} Loss_G: {lossG.item():.4f}")

In [None]:
import soundfile as sf

# The training clips are resampled to 16384 Hz (preprocess_audio's default
# target_sample_rate), so the generated waveform is written at that same rate.
# The previous value of 16000 did not match the rate of the training data.
output_sample_rate = 16384

# Generate a fake audio sample
with torch.no_grad():
    noise = torch.randn(1, latent_dim, device=device)  # Generate one latent vector
    generated_audio = generator(noise).cpu().numpy()

# Save the generated audio to a file
sf.write('generated_audio.wav', generated_audio.squeeze(), samplerate=output_sample_rate)