<a href="https://colab.research.google.com/github/fred-dev/wav_gan/blob/main/Fred_WAV_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive

# Mount Google Drive so the folders below are reachable from the Colab runtime.
drive.mount("/content/drive") # Don't change this.

# Where checkpoints are written/read, and where generated audio is saved.
# Both point at the same Drive directory; only the trailing slash differs.
model_path = "/content/drive/MyDrive/colab_storage/colab_output"
output_path = "/content/drive/MyDrive/colab_storage/colab_output/"

# Training inputs: the .wav clips and their JSON metadata share one folder.
audio_folder = "/content/drive/MyDrive/colab_storage/ronxgin_data_samples"
json_folder = "/content/drive/MyDrive/colab_storage/ronxgin_data_samples"

In [None]:
!pip install wandb

In [None]:
# 1. Import required libraries
import os
import json
import torch
import torchaudio
import torch.nn as nn
import torch.optim as optim
import wandb
from torch.utils.data import Dataset, DataLoader, BatchSampler, SubsetRandomSampler
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from torch.nn.functional import multi_head_attention_forward
from torch.nn import MultiheadAttention

In [None]:
# 2. Define the dataset class
class AudioDataset(Dataset):
    """Dataset pairing each ``.wav`` clip with the values from its JSON metadata file.

    Every item is a ``(features, index)`` tuple:

    * ``features`` is a float tensor of shape ``(30 + 10, time)``: 30 Mel bands
      stacked on top of the 10 metadata values, each metadata value repeated
      along the time axis.
    * ``index`` is the sample's position in ``file_list`` as an ``int64`` scalar
      tensor (it is NOT the metadata vector).

    Args:
        audio_folder: Directory scanned for ``.wav`` files.
        json_folder: Directory holding one ``.json`` file per clip.
        transform: Stored on the instance for callers to inspect.
            NOTE(review): it is never applied; ``__getitem__`` always uses
            ``self.mel_transform``. Confirm whether that is intended.
    """

    def __init__(self, audio_folder, json_folder, transform=None):
        self.audio_folder = audio_folder
        self.json_folder = json_folder
        self.transform = transform
        # Maximum number of spectrogram frames kept per clip.
        self.MAX_LENGTH = 400
        # At most this many audio frames are read per file (5 s at 44.1 kHz).
        self.MAX_NUM_FRAMES = 5 * 44100
        self.mel_transform = torchaudio.transforms.MelSpectrogram(sample_rate=44100, n_mels=30)

        # Sorted so that a given index always maps to the same file.
        self.file_list = sorted(
            f for f in os.listdir(audio_folder) if os.path.splitext(f)[1] == '.wav'
        )

    def __len__(self):
        """Return the number of ``.wav`` files found in ``audio_folder``."""
        return len(self.file_list)

    @staticmethod
    def _json_filename(audio_filename):
        """Map an audio file name to its JSON file name (``x_P.wav`` -> ``x.json``).

        Only a literal trailing ``_P`` is removed from the stem.
        ``str.rstrip('_P')`` would instead strip any run of ``_``/``P``
        characters and corrupt stems such as ``MAP`` or ``CLIP_P``.
        """
        stem = os.path.splitext(audio_filename)[0]
        suffix = "_P"
        if stem.endswith(suffix):
            stem = stem[:-len(suffix)]
        return stem + ".json"

    def __getitem__(self, idx):
        """Load clip ``idx`` and return ``(features, index_tensor)``.

        Raises:
            FileNotFoundError: If the matching JSON file does not exist.
            KeyError: If the JSON file lacks one of the expected fields.
        """
        audio_name = self.file_list[idx]
        audio_path = os.path.join(self.audio_folder, audio_name)
        json_path = os.path.join(self.json_folder, self._json_filename(audio_name))

        print(f"Loading {idx+1}/{len(self.file_list)}: {audio_path}")

        # The file's own sample rate is discarded here.
        # NOTE(review): the Mel transform assumes 44.1 kHz input - confirm the data matches.
        waveform, _ = torchaudio.load(audio_path, num_frames=self.MAX_NUM_FRAMES)

        with open(json_path, encoding="utf-8") as f:
            data = json.load(f)

        # NOTE(review): data["wind"]["deg"] is read twice (positions 2 and 5).
        # Position 2 may have been meant to be a different field - verify
        # against the JSON schema before changing it; the order feeds the model.
        params = [
            data["coord"]["lat"],
            data["coord"]["lon"],
            data["wind"]["deg"],
            data["main"]["humidity"],
            data["wind"]["speed"],
            data["wind"]["deg"],
            data["main"]["pressure"],
            data["elevation"],
            data["minutesOfDay"],
            data["dayOfYear"],
        ]

        # Downmix to mono by averaging channels. For a mono file this is the
        # same as dropping the channel axis; for multi-channel files it keeps
        # the spectrogram 2-D so the slicing/concatenation below stays valid.
        mel_spec = self.mel_transform(waveform.mean(dim=0))
        # Keep at most MAX_LENGTH frames along the time axis.
        mel_spec = mel_spec[:, :self.MAX_LENGTH]

        # Repeat each metadata value once per spectrogram frame.
        params_tensor = torch.tensor(params, dtype=torch.float32).unsqueeze(1)
        params_tensor = params_tensor.expand(-1, mel_spec.size(1))

        # Stack along the feature axis: rows 0-29 are Mel bands, rows 30-39 metadata.
        features = torch.cat((mel_spec, params_tensor), dim=0)

        return features, torch.tensor(idx, dtype=torch.int64)

def collate_fn(batch):
    """Pad variable-length feature tensors and pack them for recurrent layers.

    Args:
        batch: List of ``(features, index)`` pairs where ``features`` has shape
            ``(num_features, time)`` and ``index`` is a scalar tensor.

    Returns:
        A ``(packed_batch, indices)`` tuple. ``packed_batch`` is a
        ``PackedSequence`` built from a ``(batch, time, num_features)`` tensor,
        so ``pad_packed_sequence(packed_batch, batch_first=True)`` yields that
        layout back. ``indices`` stacks the per-sample index tensors in the
        same (length-sorted) order as the packed data.
    """
    # Longest sequence first, as required by enforce_sorted=True below.
    batch = sorted(batch, key=lambda item: item[0].size(1), reverse=True)
    seq_lengths = [features.size(1) for features, _ in batch]
    num_features = batch[0][0].size(0)

    # Zero-padded buffer laid out as (batch, time, features). Packing treats the
    # axis right after the batch axis as time, so time must come before features.
    padded = torch.zeros(len(batch), seq_lengths[0], num_features)
    for i, (features, _) in enumerate(batch):
        padded[i, :seq_lengths[i], :] = features.t()

    packed_batch = nn.utils.rnn.pack_padded_sequence(
        padded, seq_lengths, batch_first=True, enforce_sorted=True
    )

    # Second element of every sample, reordered like the packed data.
    indices = torch.stack([index for _, index in batch], dim=0)

    return packed_batch, indices


# Define the Generator class
class Generator(nn.Module):
    """Self-attention -> LSTM -> linear stack that emits one audio frame per step.

    Args:
        input_dim: Width of the noise part of each input step.
        hidden_dim: LSTM hidden size.
        num_layers: Number of stacked LSTM layers.
        audio_dim: Width of each generated output step.
        additional_features: Width of the conditioning part of each input
            step. ``input_dim + additional_features`` must be divisible by 4
            (the number of attention heads).
    """

    def __init__(self, input_dim, hidden_dim, num_layers, audio_dim, additional_features=10):
        super(Generator, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.audio_dim = audio_dim
        self.additional_features = additional_features

        step_dim = input_dim + additional_features
        # batch_first=True keeps attention on the same (batch, time, feature)
        # layout as the LSTM. With the default (sequence-first) layout the
        # attention would mix information across samples of the batch.
        self.attention = MultiheadAttention(embed_dim=step_dim, num_heads=4, batch_first=True)
        self.lstm = nn.LSTM(step_dim, hidden_dim, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_dim, audio_dim)

    def forward(self, x):
        """Map ``(batch, time, input_dim + additional_features)`` to ``(batch, time, audio_dim)``.

        Returns a plain tensor (not a tuple and not a ``PackedSequence``).
        """
        attn_output, _ = self.attention(x, x, x)
        output, _ = self.lstm(attn_output)
        output = self.linear(output)
        return output


# Define the Discriminator class
class Discriminator(nn.Module):
    """Self-attention -> LSTM -> linear critic conditioned on a parameter vector.

    Args:
        audio_dim: Width of each audio step fed to ``forward``.
        hidden_dim: LSTM hidden size.
        num_layers: Number of stacked LSTM layers.
        output_dim: Width of the score produced per sample.
        additional_features: Width of the conditioning vector appended to every
            step (defaults to 10, the value previously hard-coded).
            ``audio_dim + additional_features`` must be divisible by 4
            (the number of attention heads).
    """

    def __init__(self, audio_dim, hidden_dim, num_layers, output_dim, additional_features=10):
        super(Discriminator, self).__init__()

        self.audio_dim = audio_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.output_dim = output_dim
        self.additional_features = additional_features

        step_dim = audio_dim + additional_features
        # batch_first=True keeps attention on the same (batch, time, feature)
        # layout as the LSTM and as the concatenation done in forward().
        self.attention = MultiheadAttention(embed_dim=step_dim, num_heads=4, batch_first=True)
        self.lstm = nn.LSTM(step_dim, hidden_dim, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, params):
        """Score a batch of sequences.

        Args:
            x: Tensor of shape ``(batch, time, audio_dim)``.
            params: Tensor of shape ``(batch, additional_features)``; it is
                repeated along the time axis and appended to ``x`` here, so
                callers must NOT pre-concatenate it.

        Returns:
            Tensor of shape ``(batch, output_dim)`` taken from the last time
            step. No sigmoid is applied: the values are raw scores (logits).
        """
        x_with_params = torch.cat((x, params.unsqueeze(1).repeat(1, x.size(1), 1)), dim=2)
        attn_output, _ = self.attention(x_with_params, x_with_params, x_with_params)
        output, _ = self.lstm(attn_output)
        # Only the final time step summarises the whole sequence.
        output = self.linear(output[:, -1, :])
        return output




def train_discriminator(real_data, fake_data, params, optimizer, criterion, model=None):
    """Run one optimisation step of the discriminator and return its loss.

    Args:
        real_data: Batch of real samples passed to the discriminator.
        fake_data: Batch of generated samples. It is detached here, so no
            gradient flows back into whatever produced it.
        params: Conditioning values passed alongside both batches.
        optimizer: Optimizer holding the discriminator's parameters.
        criterion: Loss called as ``criterion(predictions, targets)``.
        model: Discriminator to train, called as ``model(data, params)``. When
            omitted, the module-level name ``discriminator`` is used, which
            must then exist (this was the only behaviour before).

    Returns:
        The summed real + fake loss as a Python float.
    """
    if model is None:
        # Legacy path: rely on a global set up by the caller.
        model = discriminator

    optimizer.zero_grad()

    # Real samples should be classified as 1.
    real_preds = model(real_data, params)
    real_loss = criterion(real_preds, torch.ones_like(real_preds))

    # Generated samples should be classified as 0. Detach so this step only
    # produces gradients for the discriminator.
    fake_preds = model(fake_data.detach(), params)
    fake_loss = criterion(fake_preds, torch.zeros_like(fake_preds))

    total_loss = real_loss + fake_loss
    total_loss.backward()
    optimizer.step()

    return total_loss.item()

def train_generator(fake_data, params, optimizer, criterion, model=None):
    """Run one optimisation step of the generator and return its loss.

    Args:
        fake_data: Generated batch, still attached to the generator's graph so
            that gradients can reach the generator's parameters.
        params: Conditioning values passed to the discriminator.
        optimizer: Optimizer holding the generator's parameters.
        criterion: Loss called as ``criterion(predictions, targets)``.
        model: Discriminator used to score ``fake_data``, called as
            ``model(data, params)``. When omitted, the module-level name
            ``discriminator`` is used, which must then exist (this was the
            only behaviour before).

    Returns:
        The generator loss as a Python float.
    """
    if model is None:
        # Legacy path: rely on a global set up by the caller.
        model = discriminator

    optimizer.zero_grad()

    # The generator is rewarded when its samples are classified as real (1).
    preds = model(fake_data, params)
    loss = criterion(preds, torch.ones_like(preds))

    loss.backward()
    optimizer.step()

    return loss.item()

def train_gan(audio_folder, json_folder, epochs, batch_size, learning_rate, device, save_interval, model_path):
    """Train the conditional GAN and write checkpoints to ``model_path``.

    Each batch from the loader is expected to unpack (``batch_first=True``) to
    a ``(batch, time, 30 + 10)`` tensor: 30 Mel bands followed by the 10
    conditioning values that ``AudioDataset`` appends to every frame. The
    second element yielded by the loader holds sample indices and is not used
    for conditioning.

    Args:
        audio_folder: Directory with the ``.wav`` clips.
        json_folder: Directory with the matching ``.json`` files.
        epochs: Number of passes over the dataset.
        batch_size: Samples per batch; incomplete final batches are dropped.
        learning_rate: Adam learning rate used for both networks.
        device: Torch device the models and batches are moved to.
        save_interval: A checkpoint is written every ``save_interval`` epochs.
        model_path: Directory receiving the ``*.pth`` files.
    """
    input_dim = 30   # noise width per step
    hidden_dim = 128
    num_layers = 1
    audio_dim = 30   # Mel bands per step
    cond_dim = 10    # conditioning values per sample

    # n_mels has to be given to the constructor: assigning the attribute
    # afterwards does not rebuild the filterbank.
    # NOTE(review): AudioDataset stores this transform but does not apply it.
    mel_spectrogram_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=44100, n_mels=audio_dim, hop_length=1024, n_fft=2048
    )

    dataset = AudioDataset(audio_folder, json_folder, transform=mel_spectrogram_transform)
    dataloader = DataLoader(
        dataset,
        batch_sampler=BatchSampler(SubsetRandomSampler(range(len(dataset))), batch_size, drop_last=True),
        collate_fn=collate_fn,
    )

    generator = Generator(input_dim, hidden_dim, num_layers, audio_dim).to(device)
    discriminator = Discriminator(audio_dim, hidden_dim, num_layers, 1).to(device)

    os.makedirs(model_path, exist_ok=True)

    # Keep the very first initial weights; later runs do not overwrite them.
    initial_generator_path = os.path.join(model_path, "generator_initial.pth")
    initial_discriminator_path = os.path.join(model_path, "discriminator_initial.pth")

    if not os.path.exists(initial_generator_path):
        torch.save(generator.state_dict(), initial_generator_path)

    if not os.path.exists(initial_discriminator_path):
        torch.save(discriminator.state_dict(), initial_discriminator_path)

    # The discriminator returns raw scores, so the sigmoid has to live in the
    # loss; plain BCELoss rejects inputs outside [0, 1].
    criterion = nn.BCEWithLogitsLoss()
    optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

    for epoch in range(1, epochs + 1):
        for batch_idx, (packed_real_data, _sample_indices) in enumerate(dataloader):
            packed_real_data = packed_real_data.to(device)

            # (batch, time, audio_dim + cond_dim), zero-padded past each clip's end.
            real_features, _ = pad_packed_sequence(packed_real_data, batch_first=True)
            real_data = real_features[:, :, :audio_dim]
            # The conditioning values are constant over time; frame 0 is never padding.
            params = real_features[:, 0, audio_dim:audio_dim + cond_dim]
            current_batch, num_steps = real_data.size(0), real_data.size(1)

            # Generator input: per-step noise with the conditioning appended,
            # matching the generator's input_dim + additional_features width.
            noise = torch.randn(current_batch, num_steps, input_dim, device=device)
            z = torch.cat((noise, params.unsqueeze(1).expand(-1, num_steps, -1)), dim=2)
            fake_data = generator(z)

            # Train discriminator
            optimizer_D.zero_grad()

            # The discriminator appends `params` itself, so pass it separately.
            real_validity = discriminator(real_data, params)
            # detach(): this step must not produce gradients for the generator.
            fake_validity = discriminator(fake_data.detach(), params)

            real_loss = criterion(real_validity, torch.ones_like(real_validity))
            fake_loss = criterion(fake_validity, torch.zeros_like(fake_validity))
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

            # Train generator
            optimizer_G.zero_grad()

            # Re-score the fakes with the freshly updated discriminator.
            fake_validity = discriminator(fake_data, params)
            g_loss = criterion(fake_validity, torch.ones_like(fake_validity))

            g_loss.backward()
            optimizer_G.step()

            print(f"Epoch [{epoch}/{epochs}] Batch [{batch_idx+1}/{len(dataloader)}] Loss D: {d_loss.item():.4f}, Loss G: {g_loss.item():.4f}")

        # Periodic checkpoint
        if epoch % save_interval == 0:
            torch.save(generator.state_dict(), os.path.join(model_path, f"generator_epoch_{epoch}.pth"))
            torch.save(discriminator.state_dict(), os.path.join(model_path, f"discriminator_epoch_{epoch}.pth"))

    # Final weights after the last epoch
    torch.save(generator.state_dict(), os.path.join(model_path, "generator_final.pth"))
    torch.save(discriminator.state_dict(), os.path.join(model_path, "discriminator_final.pth"))



def generate_audio(generator_path, params, duration, output_folder, device):
    """Generate a clip with a trained generator and save it with its parameters.

    Writes ``generated_audio_<timestamp>.wav`` and a ``.json`` file with the
    same stem (holding the input parameters) into ``output_folder``.

    Args:
        generator_path: Path to a generator ``state_dict`` checkpoint.
        params: Sequence of 10 numbers, in the order of ``parameter_names`` below.
        duration: Length of the clip in seconds; must be positive.
        output_folder: Directory for the two output files (created if missing).
        device: Torch device used to run the generator.

    Returns:
        The path of the written ``.wav`` file.

    Raises:
        ValueError: If ``params`` does not hold exactly 10 values or
            ``duration`` is not positive.
    """
    # Local import: datetime is not among the file's top-level imports.
    from datetime import datetime

    # Spectrogram geometry used to turn Mel frames back into audio.
    # NOTE(review): these must match the analysis transform that produced the
    # training features; AudioDataset builds its own MelSpectrogram with
    # default n_fft/hop_length - confirm which settings are intended.
    sample_rate = 44100
    n_fft = 2048
    hop_length = 1024
    n_mels = 30
    noise_dim = 30  # must equal the input_dim the checkpoint was trained with

    parameter_names = [
        "Latitude",
        "Longitude",
        "Degrees",
        "Humidity",
        "Wind speed",
        "Wind direction",
        "Pressure",
        "Elevation",
        "Minutes of day",
        "Day of year",
    ]

    # Keep plain floats around: they are what gets written to the JSON file.
    param_values = [float(value) for value in params]
    if len(param_values) != len(parameter_names):
        raise ValueError(
            f"expected {len(parameter_names)} parameters, got {len(param_values)}"
        )
    if duration <= 0:
        raise ValueError("duration must be positive")

    # Same dimensions as used for training; map_location lets a checkpoint
    # saved on GPU load on a CPU-only runtime.
    generator = Generator(input_dim=noise_dim, hidden_dim=128, num_layers=1, audio_dim=n_mels).to(device)
    generator.load_state_dict(torch.load(generator_path, map_location=device))
    generator.eval()

    num_samples = int(duration * sample_rate)
    # Griffin-Lim yields hop_length * (frames - 1) samples, so use enough
    # frames to cover the requested duration (ceil division, plus one).
    num_steps = -(-num_samples // hop_length) + 1

    print("Generating audio...")

    cond = torch.tensor(param_values, dtype=torch.float32, device=device)
    with torch.no_grad():
        # One pass over the whole sequence: per-step noise with the
        # conditioning values appended to every step.
        noise = torch.randn(1, num_steps, noise_dim, device=device)
        z = torch.cat((noise, cond.view(1, 1, -1).expand(1, num_steps, -1)), dim=2)
        generated = generator(z)  # (1, num_steps, n_mels)

    print("Audio generation completed.")

    # Reorder to (1, n_mels, time). Negative energies are not meaningful for
    # the inversion below, so clamp them to zero.
    mel_spec = generated.detach().cpu().transpose(1, 2).contiguous().clamp(min=0)

    # Mel -> linear spectrogram -> waveform. n_stft is n_fft // 2 + 1 bins.
    mel_inverse = torchaudio.transforms.InverseMelScale(
        n_stft=n_fft // 2 + 1, n_mels=n_mels, sample_rate=sample_rate
    )
    griffin_lim = torchaudio.transforms.GriffinLim(n_fft=n_fft, hop_length=hop_length, n_iter=32)

    waveform = griffin_lim(mel_inverse(mel_spec))
    waveform = waveform[:, :num_samples]

    # Save the generated audio and parameters
    os.makedirs(output_folder, exist_ok=True)
    timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    output_audio_path = os.path.join(output_folder, f"generated_audio_{timestamp}.wav")
    output_json_path = os.path.join(output_folder, f"generated_audio_{timestamp}.json")

    torchaudio.save(output_audio_path, waveform, sample_rate=sample_rate)

    parameter_data = dict(zip(parameter_names, param_values))
    with open(output_json_path, "w") as f:
        json.dump(parameter_data, f, indent=4)

    return output_audio_path

In [None]:
# Log in to Weights & Biases before a run is started.
wandb.login()

In [None]:
# 5. Connect to Weights and Biases for tracking progress
# Starts a run in the "audio-gan" project.
# NOTE(review): no wandb.log(...) call is visible in this notebook's training
# code, so nothing is explicitly logged to this run - confirm this is intended.
wandb.init(project="audio-gan")

In [None]:
if __name__ == "__main__":
    # Prefer the GPU when one is visible; otherwise train on the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Checkpoints are written here, so the folder has to exist first.
    os.makedirs(model_path, exist_ok=True)

    train_gan(
        audio_folder,
        json_folder,
        epochs=100,
        batch_size=40,
        learning_rate=0.0002,
        device=device,
        save_interval=50,
        model_path=model_path,
    )

In [None]:
# Example: generate one clip from the final checkpoint
if __name__ == "__main__":
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Length of the generated clip, in seconds.
    duration = 5.0
    # Ten conditioning values - replace with actual parameters.
    params = [-24.8874, 150.9657, 23.16, 73, 4.78, 8, 1015, 506, 546, 110]
    # Weights written at the end of training.
    generator_path = os.path.join(model_path, "generator_final.pth")

    generate_audio(
        generator_path=generator_path,
        params=params,
        duration=duration,
        output_folder=output_path,
        device=device,
    )