<a href="https://colab.research.google.com/github/fred-dev/wav_gan/blob/main/Fred_WAV_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
  # Data / model locations; presumably Google Drive is mounted at /content/drive (Colab).
  storage_root = "/content/drive/MyDrive/colab_storage"
  audio_folder = storage_root + "/ronxgin_data_samples"
  json_folder = audio_folder  # metadata .json files live alongside the .wav files
  model_path = storage_root + "/colab_output"
  output_path = model_path + "/"


In [None]:
!pip install wandb


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
# 1. Import required libraries
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio.transforms as T
import torchaudio
import numpy as np
import wandb
from datetime import datetime
import torch.nn.functional as F



In [None]:
# Authenticate this kernel with Weights & Biases so the wandb.init() cell below can start a run.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33ms222405968[0m ([33msyntheticornithology[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
# 2. Define the dataset class
class AudioDataset(Dataset):
    """Pairs every ``.wav`` file in ``audio_folder`` with a JSON metadata file.

    For ``<stem>.wav`` or ``<stem>_P.wav`` the metadata is read from
    ``<json_folder>/<stem>.json``. Each item is ``(waveform, params)``:
    ``waveform`` is the tensor returned by ``torchaudio.load`` (passed through
    ``transform`` when one is given) and ``params`` is a float32 tensor of the
    10 conditioning values listed in ``__getitem__``.
    """

    def __init__(self, audio_folder, json_folder, transform=None):
        self.audio_folder = audio_folder
        self.json_folder = json_folder
        self.transform = transform
        # Sorted so index -> file is stable across runs and machines (os.listdir order is arbitrary).
        self.file_list = sorted(f for f in os.listdir(audio_folder) if f.endswith(".wav"))

    def __len__(self):
        return len(self.file_list)

    @staticmethod
    def _json_stem(audio_filename):
        """Return the metadata stem: extension dropped and one trailing ``_P`` removed."""
        stem = os.path.splitext(audio_filename)[0]
        # str.rstrip('_P') would strip any run of '_'/'P' characters ("STOP_P" -> "STO");
        # remove only the literal suffix.
        if stem.endswith("_P"):
            stem = stem[:-len("_P")]
        return stem

    def __getitem__(self, idx):
        filename = self.file_list[idx]
        audio_path = os.path.join(self.audio_folder, filename)
        json_path = os.path.join(self.json_folder, self._json_stem(filename) + ".json")

        # NOTE(review): the sample rate returned by torchaudio.load is discarded, so a
        # transform with a fixed sample_rate assumes every file uses it — confirm.
        waveform, _ = torchaudio.load(audio_path)

        with open(json_path) as f:
            data = json.load(f)

        params = [
            data["coord"]["lat"],
            data["coord"]["lon"],
            # NOTE(review): identical to the wind "deg" entry further down. generate_audio labels
            # this slot "Degrees" and its example value (23.16) looks like a temperature —
            # possibly meant data["main"]["temp"]; confirm against the JSON schema.
            data["wind"]["deg"],
            data["main"]["humidity"],
            data["wind"]["speed"],
            data["wind"]["deg"],
            data["main"]["pressure"],
            data["elevation"],
            data["minutesOfDay"],
            data["dayOfYear"],
        ]

        if self.transform:
            waveform = self.transform(waveform)

        return waveform, torch.tensor(params, dtype=torch.float32)



In [None]:
# 3. Define the generator and discriminator models
class Generator(nn.Module):
    """MLP mapping a (noise + conditioning) vector to a flat, Tanh-bounded output.

    Args:
        input_dim: length of each input row (noise values + conditioning params).
        output_dim: length of each generated row (e.g. 128 * 128 for a flattened spectrogram).

    ``forward`` takes ``(batch, input_dim)`` and returns ``(batch, output_dim)`` in [-1, 1].
    """

    def __init__(self, input_dim, output_dim):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        # Consumes fc1's 256 features per sample directly. Reshaping to 256 * 128 would fold
        # 128 samples into a single row and only work for batch sizes divisible by 128.
        self.model = nn.Sequential(
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, output_dim),
            nn.Tanh()
        )

    def forward(self, x):
        # Non-linearity after fc1; without it fc1 and the first Linear collapse into one linear map.
        return self.model(torch.relu(self.fc1(x)))

class Discriminator(nn.Module):
    """MLP scoring a flat input row with a sigmoid (1 = real, 0 = fake under BCELoss targets).

    Args:
        input_size: length of each input row (e.g. flattened spectrogram + conditioning params).
        output_size: number of sigmoid outputs per row (1 for a single real/fake score).
    """

    def __init__(self, input_size, output_size):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(input_size, 256)
        self.model = nn.Sequential(
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, output_size),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Non-linearity after fc1 so it does not merge with the next Linear into one linear map.
        return self.model(F.leaky_relu(self.fc1(x), 0.2))




In [None]:
# 4. Define training functions
def train_discriminator(real_data, fake_data, optimizer, criterion, discriminator=None):
    """Run one discriminator update: push D(real) toward 1 and D(fake) toward 0.

    Args:
        real_data: batch of real discriminator inputs.
        fake_data: batch of generated inputs; detached so no gradient reaches the generator.
        optimizer: optimizer over the discriminator's parameters.
        criterion: loss called as ``criterion(preds, targets)``, e.g. ``nn.BCELoss()``.
        discriminator: model to train. When omitted, a module-level ``discriminator``
            variable is used (the original behaviour).

    Returns:
        The summed real + fake loss as a Python float.

    Raises:
        NameError: if no discriminator is passed and no module-level one exists.
    """
    if discriminator is None:
        discriminator = globals().get("discriminator")
        if discriminator is None:
            raise NameError("train_discriminator: pass discriminator=... or define a global 'discriminator'")

    optimizer.zero_grad()

    real_preds = discriminator(real_data)
    real_loss = criterion(real_preds, torch.ones_like(real_preds))

    # detach(): otherwise backward() runs into (and frees) the generator's graph, so a later
    # generator step on the same fake_data would fail.
    fake_preds = discriminator(fake_data.detach())
    fake_loss = criterion(fake_preds, torch.zeros_like(fake_preds))

    total_loss = real_loss + fake_loss
    total_loss.backward()
    optimizer.step()

    return total_loss.item()

def train_generator(fake_data, optimizer, criterion, discriminator=None):
    """Run one generator update: push D(fake) toward 1 so the generator learns to fool D.

    Args:
        fake_data: generated batch, still attached to the generator's graph.
        optimizer: optimizer over the generator's parameters.
        criterion: loss called as ``criterion(preds, targets)``, e.g. ``nn.BCELoss()``.
        discriminator: model used to score ``fake_data``. When omitted, a module-level
            ``discriminator`` variable is used (the original behaviour).

    Returns:
        The generator loss as a Python float.

    Raises:
        NameError: if no discriminator is passed and no module-level one exists.
    """
    if discriminator is None:
        discriminator = globals().get("discriminator")
        if discriminator is None:
            raise NameError("train_generator: pass discriminator=... or define a global 'discriminator'")

    optimizer.zero_grad()

    preds = discriminator(fake_data)
    loss = criterion(preds, torch.ones_like(preds))

    loss.backward()
    optimizer.step()

    return loss.item()

In [None]:
# 5. Connect to Weights and Biases for tracking progress
# Start a W&B run in the "audio-gan" project; run this after the wandb.login() cell above.
wandb.init(project="audio-gan")

In [None]:
def custom_collate_fn(batch):
    """Collate ``(spectrogram, params)`` pairs into two batch tensors.

    Each spectrogram is zero-padded on the right of its last axis so that its
    ``size(2)`` matches the longest item in the batch (for 3-D inputs these are
    the same axis), then all are stacked. The params tensors are stacked as-is.
    """
    spectrograms, conditions = zip(*batch)
    longest = max(spec.size(2) for spec in spectrograms)
    padded = [F.pad(spec, (0, longest - spec.size(2))) for spec in spectrograms]
    return torch.stack(padded), torch.stack(conditions)


In [None]:
# 6. Train the model
def _mel_batch_to_flat(mel_batch, n_frames=128):
    """Flatten a collated mel batch to (batch, n_mels * n_frames) discriminator inputs.

    A 4-D ``(batch, channels, n_mels, frames)`` batch has its channel axis averaged
    (mono mix-down); the time axis is then zero-padded or truncated to ``n_frames``.
    """
    if mel_batch.dim() == 4:
        mel_batch = mel_batch.mean(dim=1)
    frames = mel_batch.size(-1)
    if frames < n_frames:
        mel_batch = F.pad(mel_batch, (0, n_frames - frames))
    else:
        mel_batch = mel_batch[..., :n_frames]
    return mel_batch.reshape(mel_batch.size(0), -1)


def train_gan(audio_folder, json_folder, epochs, batch_size, learning_rate, device, save_interval, model_path):
    """Train the conditional audio GAN and checkpoint both networks.

    Real examples are 128-mel spectrograms (n_fft 2048, hop 1024, sample_rate 44100)
    cut or padded to 128 frames and flattened to 128 * 128 values, each paired with
    the 10 conditioning params from ``AudioDataset``. The generator's input is
    128 * 128 noise values followed by the same 10 params.

    Args:
        audio_folder / json_folder: passed to ``AudioDataset``.
        epochs: number of passes over the dataset.
        batch_size: DataLoader batch size (the last batch may be smaller).
        learning_rate: Adam learning rate for both networks.
        device: torch device for the models and training tensors.
        save_interval: save checkpoints every this many epochs.
        model_path: directory for ``.pth`` checkpoints (created if missing).

    Side effects: prints per-batch losses, logs them to W&B when a run is active,
    and writes generator/discriminator state dicts into ``model_path``.
    """
    n_mels, n_frames, n_params = 128, 128, 10
    spec_dim = n_mels * n_frames

    # Left on the CPU: it is applied inside DataLoader workers to the CPU waveforms that
    # torchaudio.load returns, so moving its buffers to CUDA causes a device mismatch.
    mel_spectrogram_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=44100, n_mels=n_mels, hop_length=1024, n_fft=2048
    )

    dataset = AudioDataset(audio_folder, json_folder, transform=mel_spectrogram_transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=custom_collate_fn)

    # Generator input = spec_dim noise values + n_params conditioning values.
    generator = Generator(n_params + spec_dim, spec_dim).to(device)
    discriminator = Discriminator(spec_dim + n_params, 1).to(device)

    criterion = nn.BCELoss()
    optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

    os.makedirs(model_path, exist_ok=True)
    log_to_wandb = wandb.run is not None

    for epoch in range(1, epochs + 1):
        for batch_idx, (mel_batch, params) in enumerate(dataloader):
            # TODO(review): mel power values are >= 0 and unbounded while the generator ends in
            # Tanh ([-1, 1]); consider log-scaling / normalising real data to the same range.
            real_data = _mel_batch_to_flat(mel_batch, n_frames=n_frames).to(device)
            params = params.to(device)
            n = real_data.size(0)
            real_labels = torch.ones(n, 1, device=device)
            fake_labels = torch.zeros(n, 1, device=device)

            # Train discriminator
            optimizer_D.zero_grad()

            noise = torch.randn(n, spec_dim, device=device)
            fake_data = generator(torch.cat((noise, params), dim=1))

            real_validity = discriminator(torch.cat((real_data, params), dim=1))
            fake_validity = discriminator(torch.cat((fake_data.detach(), params), dim=1))

            real_loss = criterion(real_validity, real_labels)
            fake_loss = criterion(fake_validity, fake_labels)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

            # Train generator
            optimizer_G.zero_grad()

            fake_validity = discriminator(torch.cat((fake_data, params), dim=1))
            g_loss = criterion(fake_validity, real_labels)

            g_loss.backward()
            optimizer_G.step()

            print(f"Epoch [{epoch}/{epochs}] Batch [{batch_idx+1}/{len(dataloader)}] Loss D: {d_loss.item():.4f}, Loss G: {g_loss.item():.4f}")
            if log_to_wandb:
                wandb.log({"epoch": epoch, "loss_D": d_loss.item(), "loss_G": g_loss.item()})

        if epoch % save_interval == 0:
            torch.save(generator.state_dict(), os.path.join(model_path, f"generator_epoch_{epoch}.pth"))
            torch.save(discriminator.state_dict(), os.path.join(model_path, f"discriminator_epoch_{epoch}.pth"))

    torch.save(generator.state_dict(), os.path.join(model_path, "generator_final.pth"))
    torch.save(discriminator.state_dict(), os.path.join(model_path, "discriminator_final.pth"))



In [None]:
def generate_audio(generator_path, params, duration, output_folder, device):
    """Synthesize one clip with a trained generator and save it next to its parameters.

    Args:
        generator_path: path to a ``Generator`` state dict saved by ``train_gan``.
        params: the 10 conditioning values, in the order ``AudioDataset`` builds them.
        duration: maximum clip length in seconds. The 128-frame spectrogram with hop 1024
            only spans about 3 s at 44.1 kHz, so longer durations are not reached.
        output_folder: directory for the ``.wav`` and ``.json`` outputs (created if missing).
        device: torch device to run the generator on.

    Returns:
        ``(output_audio_path, output_json_path)``.
    """
    n_mels, n_frames, n_fft, hop_length, sample_rate = 128, 128, 2048, 1024, 44100
    spec_dim = n_mels * n_frames

    generator = Generator(10 + spec_dim, spec_dim).to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only runtime.
    generator.load_state_dict(torch.load(generator_path, map_location=device))
    generator.eval()

    condition = torch.tensor(params, dtype=torch.float32, device=device).view(1, -1)
    with torch.no_grad():
        # spec_dim noise values + 10 params = the generator's 10 + spec_dim inputs.
        noise = torch.randn(1, spec_dim, device=device)
        fake_data = generator(torch.cat((noise, condition), dim=1))

    # The inversion transforms below live on the CPU, so move the spectrogram there.
    # TODO(review): Tanh output lies in [-1, 1] but mel power is >= 0; negatives are clipped
    # here — an explicit scale matching the training-time normalisation would be better.
    mel = fake_data.view(1, n_mels, n_frames).cpu().clamp(min=0)
    # n_stft must equal n_fft // 2 + 1 (1025 bins) to match GriffinLim(n_fft=2048).
    mel_inverse = T.InverseMelScale(n_stft=n_fft // 2 + 1, n_mels=n_mels, sample_rate=sample_rate)
    griffin_lim = T.GriffinLim(n_fft=n_fft, hop_length=hop_length, n_iter=32)

    waveform = griffin_lim(mel_inverse(mel))
    waveform = waveform[:, :int(duration * sample_rate)]

    os.makedirs(output_folder, exist_ok=True)
    timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    output_audio_path = os.path.join(output_folder, f"generated_audio_{timestamp}.wav")
    output_json_path = os.path.join(output_folder, f"generated_audio_{timestamp}.json")

    torchaudio.save(output_audio_path, waveform, sample_rate=44100)

    parameter_names = [
        "Latitude",
        "Longitude",
        "Degrees",
        "Humidity",
        "Wind speed",
        "Wind direction",
        "Pressure",
        "Elevation",
        "Minutes of day",
        "Day of year",
    ]

    parameter_data = {name: value for name, value in zip(parameter_names, params)}

    with open(output_json_path, "w") as f:
        json.dump(parameter_data, f, indent=4)

    return output_audio_path, output_json_path



In [None]:
if __name__ == "__main__":
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Full training run; checkpoints land in model_path every 50 epochs and at the end.
    train_gan(
        audio_folder,
        json_folder,
        epochs=100,
        batch_size=32,
        learning_rate=0.0002,
        device=device,
        save_interval=50,
        model_path=model_path,
    )


RuntimeError: ignored

In [None]:
# Example file generation
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    generator_path = os.path.join(model_path, "generator_final.pth")
    # 10 conditioning values in AudioDataset's order: latitude, longitude, "degrees",
    # humidity, wind speed, wind direction, pressure, elevation, minutes of day, day of year.
    params = [-24.8874, 150.9657, 23.16, 73, 4.78, 8, 1015, 506, 546, 110]  # Replace with actual parameters
    duration = 5.0  # In seconds

    # Fail with an actionable message rather than an opaque torch.load error.
    if not os.path.isfile(generator_path):
        raise FileNotFoundError(
            f"No trained generator at {generator_path}; run the training cell (train_gan) first."
        )
    generate_audio(generator_path, params, duration, output_path, device)

FileNotFoundError: ignored