<a href="https://colab.research.google.com/github/fred-dev/wav_gan/blob/main/Fred_WAV_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [34]:
  audio_folder = "/content/drive/MyDrive/colab_storage/ronxgin_data_samples"
  json_folder = "/content/drive/MyDrive/colab_storage/ronxgin_data_samples"
  model_path = "/content/drive/MyDrive/colab_storage/colab_output"
  output_path = "/content/drive/MyDrive/colab_storage/colab_output/"



In [3]:
!pip install wandb


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting wandb
  Downloading wandb-0.14.0-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m25.6 MB/s[0m eta [36m0:00:00[0m
Collecting GitPython!=3.1.29,>=1.0.0
  Downloading GitPython-3.1.31-py3-none-any.whl (184 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m184.3/184.3 KB[0m [31m15.1 MB/s[0m eta [36m0:00:00[0m
Collecting pathtools
  Downloading pathtools-0.1.2.tar.gz (11 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting setproctitle
  Downloading setproctitle-1.3.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)
Collecting sentry-sdk>=1.0.0
  Downloading sentry_sdk-1.18.0-py2.py3-none-any.whl (194 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.8/194.8 KB[0m [31m16.6 MB/s[0m eta [36m0:00:00[0m
Colle

In [4]:
# 1. Import required libraries
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio.transforms as T
import torchaudio
import numpy as np
import wandb
from datetime import datetime
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_sequence




In [5]:
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [74]:
# 2. Define the dataset class
from torch.nn.utils.rnn import pack_sequence

class AudioDataset(Dataset):
    def __init__(self, audio_folder, json_folder, transform=None):
        self.audio_folder = audio_folder
        self.json_folder = json_folder
        self.transform = transform
        self.file_list = [f for f in os.listdir(audio_folder) if f.endswith(".wav")]

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        audio_path = os.path.join(self.audio_folder, self.file_list[idx])
        json_path = os.path.join(self.json_folder, os.path.splitext(self.file_list[idx])[0].rstrip('_P') + ".json")

        print(f"Loading {idx+1}/{len(self.file_list)}: {audio_path}")

        waveform, _ = torchaudio.load(audio_path)

        with open(json_path) as f:
            data = json.load(f)

        params = [
            data["coord"]["lat"],
            data["coord"]["lon"],
            data["wind"]["deg"],
            data["main"]["humidity"],
            data["wind"]["speed"],
            data["wind"]["deg"],
            data["main"]["pressure"],
            data["elevation"],
            data["minutesOfDay"],
            data["dayOfYear"],
        ]

        if self.transform:
            waveform = self.transform(waveform)

        return waveform.squeeze(), torch.tensor(params, dtype=torch.float32)


def collate_fn(batch):
    # Sort the batch by sequence length (descending order)
    batch = sorted(batch, key=lambda x: x[0].size(1), reverse=True)
    
    # Create a list of the sequence lengths for packed sequences
    seq_lengths = [x[0].size(1) for x in batch]
    
    # Pad the batch to have sequences of equal length
    padded_batch = [(
        F.pad(item[0], pad=(0, 0, 0, max(seq_lengths) - item[0].size(-1))),  # Padded sequence
        item[1],  # Original sequence length
        item[2]  # Parameters
    ) for item in batch]
    
    # Convert the padded batch to a packed sequence
    packed_batch = nn.utils.rnn.pack_sequence([x[0] for x in padded_batch])
    
    return packed_batch, torch.tensor([x[1] for x in padded_batch], dtype=torch.long), [x[2] for x in padded_batch]



In [65]:
class Generator(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, audio_dim):
        super(Generator, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.audio_dim = audio_dim

        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_dim, audio_dim)

    def forward(self, x):
        output, _ = self.lstm(x)
        output = self.linear(output)
        return output


class Discriminator(nn.Module):
    def __init__(self, audio_dim, hidden_dim, num_layers, output_dim):
        super(Discriminator, self).__init__()

        self.audio_dim = audio_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.output_dim = output_dim

        self.lstm = nn.LSTM(audio_dim, hidden_dim, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        output, _ = self.lstm(x)
        output = self.linear(output[:, -1, :])
        return output





In [60]:
def train_discriminator(real_data, fake_data, params, optimizer, criterion):
    optimizer.zero_grad()

    real_preds = discriminator(torch.cat((real_data, params), dim=1))
    real_loss = criterion(real_preds, torch.ones_like(real_preds))

    fake_preds = discriminator(torch.cat((fake_data, params), dim=1))
    fake_loss = criterion(fake_preds, torch.zeros_like(fake_preds))

    total_loss = real_loss + fake_loss
    total_loss.backward()
    optimizer.step()

    return total_loss.item()

def train_generator(fake_data, params, optimizer, criterion):
    optimizer.zero_grad()

    preds = discriminator(torch.cat((fake_data, params), dim=1))
    loss = criterion(preds, torch.ones_like(preds))

    loss.backward()
    optimizer.step()

    return loss.item()


In [40]:
# 5. Connect to Weights and Biases for tracking progress
wandb.init(project="audio-gan")

In [76]:
def train_gan(audio_folder, json_folder, epochs, batch_size, learning_rate, device, save_interval, model_path):
    # Define the MEL spectrogram transformation
    mel_spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=44100, n_mels=128, hop_length=1024, n_fft=2048)
    mel_spectrogram_transform.n_mels = 30 # Update the number of Mel bands

    dataset = AudioDataset(audio_folder, json_folder, transform=mel_spectrogram_transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=collate_fn)

    input_dim = 30 # Update the input dimension
    hidden_dim = 128
    num_layers = 1
    audio_dim = 30  # Assuming the Mel spectrogram has 30 dimensions. Update the audio dimension
    generator = Generator(input_dim, hidden_dim, num_layers, audio_dim).to(device)

    discriminator = Discriminator(audio_dim, hidden_dim, num_layers, 1).to(device)


    initial_generator_path = os.path.join(model_path, "generator_initial.pth")
    initial_discriminator_path = os.path.join(model_path, "discriminator_initial.pth")

    if not os.path.exists(initial_generator_path):
        torch.save(generator.state_dict(), initial_generator_path)

    if not os.path.exists(initial_discriminator_path):
        torch.save(discriminator.state_dict(), initial_discriminator_path)

    criterion = nn.BCELoss()
    optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

    for epoch in range(1, epochs + 1):
        for batch_idx, (packed_real_data, params) in enumerate(dataloader):
            packed_real_data, params = packed_real_data.to(device), params.to(device)
            batch_size = packed_real_data.batch_sizes[0]

            # Train discriminator
            optimizer_D.zero_grad()

            noise = torch.randn(batch_size, input_dim, device=device)
            z = torch.cat((noise.unsqueeze(1), params.unsqueeze(1)), dim=2)
            packed_fake_data = generator(z)

            real_data, _ = pad_packed_sequence(packed_real_data, batch_first=True)
            fake_data, _ = pad_packed_sequence(packed_fake_data, batch_first=True)

            real_validity = discriminator(torch.cat((real_data, params.unsqueeze(1).repeat(1, real_data.size(1), 1)), dim=2))

            fake_validity = discriminator(torch.cat((fake_data.detach(), params.unsqueeze(1).repeat(1, fake_data.size(1), 1)), dim=2))

            real_loss = criterion(real_validity, torch.ones(batch_size, 1, device=device))
            fake_loss = criterion(fake_validity, torch.zeros(batch_size, 1, device=device))
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

            # Train generator
            optimizer_G.zero_grad()

            fake_validity = discriminator(torch.cat((fake_data, params.unsqueeze(1).repeat(1, fake_data.size(1), 1)), dim=2))
            g_loss = criterion(fake_validity, torch.ones(batch_size, 1, device=device))

            g_loss.backward()
            optimizer_G.step()

            print(f"Epoch [{epoch}/{epochs}] Batch [{batch_idx+1}/{len(dataloader)}] Loss D: {d_loss.item():.4f}, Loss G: {g_loss.item():.4f}")

        if epoch % save_interval == 0:
            torch.save(generator.state_dict(), os.path.join(model_path, f"generator_epoch_{epoch}.pth"))
            torch.save(discriminator.state_dict(), os.path.join(model_path, f"discriminator_epoch_{epoch}.pth"))

    torch.save(generator.state_dict(), os.path.join(model_path, "generator_final.pth"))
    torch.save(discriminator.state_dict(), os.path.join(model_path, "discriminator_final.pth"))




In [67]:
def generate_audio(generator_path, params, duration, output_folder, device):
    generator = Generator(input_dim=10 + 128 * 128, output_dim=128 * 128, hidden_size=256, num_layers=2).to(device)
    generator.load_state_dict(torch.load(generator_path))
    generator.eval()

    params = torch.tensor(params, dtype=torch.float32).unsqueeze(0).to(device)
    num_steps = int(duration * 44100 / 1024)
    generated_waveforms = []

    print("Generating audio...")

    with torch.no_grad():
        hidden = None
        for step in range(num_steps):
            if hidden is None:
                noise = torch.randn(1, 1, 128 * 128 - 10, device=device)
            else:
                noise = torch.randn(1, 1, 128 * 128 - 10, device=device)

            z = torch.cat((noise, params.unsqueeze(1)), dim=2)
            output, hidden = generator(z, hidden)
            output_waveform = output.squeeze().detach().cpu()
            generated_waveforms.append(output_waveform)

            if step % (num_steps // 10) == 0:
                print(f"Step {step}/{num_steps}")

    print("Audio generation completed.")

    generated_waveform = torch.cat(generated_waveforms, dim=0)
    generated_waveform = generated_waveform.view(1, 128, -1)
    
    mel_inverse = T.InverseMelScale(n_stft=1024, n_mels=128, sample_rate=44100)
    griffin_lim = T.GriffinLim(n_fft=2048, n_iter=32)

    waveform = griffin_lim(mel_inverse(generated_waveform))
    waveform = waveform[:, :int(duration * 44100)]

    timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    output_audio_path = os.path.join(output_folder, f"generated_audio_{timestamp}.wav")
    output_json_path = os.path.join(output_folder, f"generated_audio_{timestamp}.json")

    torchaudio.save(output_audio_path, waveform, sample_rate=44100)

    parameter_names = [
        "Latitude",
        "Longitude",
        "Degrees",
        "Humidity",
        "Wind speed",
        "Wind direction",
        "Pressure",
        "Elevation",
        "Minutes of day",
        "Day of year",
    ]

    parameter_data = {name: value for name, value in zip(parameter_names, params)}

    with open(output_json_path, "w") as f:
        json.dump(parameter_data, f, indent=4)




In [77]:
if __name__ == "__main__":
    os.makedirs(model_path, exist_ok=True)

   

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_gan(audio_folder, json_folder, epochs=100, batch_size=32, learning_rate=0.0002, device=device, save_interval=50, model_path=model_path)



Loading 464/758: /content/drive/MyDrive/colab_storage/ronxgin_data_samples/2015-10-29-08-43-44-210_P.wav
Loading 256/758: /content/drive/MyDrive/colab_storage/ronxgin_data_samples/2019-09-02-17-38-22-137_P.wav
Loading 460/758: /content/drive/MyDrive/colab_storage/ronxgin_data_samples/2015-10-17-08-34-47-374_P.wav
Loading 367/758: /content/drive/MyDrive/colab_storage/ronxgin_data_samples/2011-12-27-11-50-38-439_P.wav
Loading 341/758: /content/drive/MyDrive/colab_storage/ronxgin_data_samples/2011-12-19-10-01-36-886_P.wav
Loading 661/758: /content/drive/MyDrive/colab_storage/ronxgin_data_samples/2020-06-18-10-39-29-801_P.wav
Loading 165/758: /content/drive/MyDrive/colab_storage/ronxgin_data_samples/2022-11-12-08-20-41-169_P.wav
Loading 514/758: /content/drive/MyDrive/colab_storage/ronxgin_data_samples/2016-11-26-10-11-00-372_P.wav
Loading 567/758: /content/drive/MyDrive/colab_storage/ronxgin_data_samples/2018-08-25-02-07-14-927_P.wav
Loading 512/758: /content/drive/MyDrive/colab_storage/r

IndexError: ignored

In [None]:
# Example file generation
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  
    generator_path = os.path.join(model_path, "generator_final.pth")
    params = [-24.8874 ,150.9657 , 23.16 , 73 , 4.78 , 8 , 1015 , 506 , 546 , 110]  # Replace with actual parameters
    duration = 5.0  # In seconds

    generate_audio(generator_path, params, duration, output_path, device)