<a href="https://colab.research.google.com/github/fred-dev/synthetic_ornithology_training/blob/main/Wave_GAN_NEW_with_buckets_stateful_LSTM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install wandb
!pip install tinytag

In [None]:
import platform
import cpuinfo
import os

# Get CPU information; fall back to "" if py-cpuinfo cannot report a brand string.
cpu_info = cpuinfo.get_cpu_info()
cpu_brand = cpu_info.get("brand_raw", "")
print("CPU Brand:", cpu_brand)

# Check if the CPU is from AMD
is_amd_cpu = "AMD" in cpu_brand

# MKL_THREADING_LAYER=GNU tells Intel MKL to use the GNU OpenMP threading layer.
# It does NOT replace MKL with OpenBLAS. MKL reads it when it is loaded, so it must
# be set before numpy/torch are imported (they are imported in the next cell).
# NOTE(review): presumably a workaround for MKL threading problems on AMD hosts — confirm.
if is_amd_cpu:
    os.environ["MKL_THREADING_LAYER"] = "GNU"
    print("Set MKL_THREADING_LAYER=GNU for AMD CPU.")
else:
    print("Using default backend.")


In [None]:
# 1. Import required libraries
import math
import os
import numpy as np
from tinytag import TinyTag
import json
import torch
import torchaudio
import torch.nn as nn
import librosa
import wandb
from torch.utils.data import Dataset, DataLoader, BatchSampler, SubsetRandomSampler
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from torch.nn.functional import multi_head_attention_forward
from torch.nn import MultiheadAttention
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.nn import MultiheadAttention
import logging
from torch import autograd
import sys

# Record the interpreter and PyTorch versions in the notebook output for reproducibility.
print("Python version: " + sys.version)

print("Pytorch version: " + torch.__version__)





In [None]:
from google.colab import drive
# Mount Google Drive so the project storage folders below are reachable.
drive.mount("/content/drive", force_remount=True)  # Don't change this.

# Every project input and output path lives under this one Drive folder.
_colab_storage = "/content/drive/MyDrive/colab_storage"

# Audio sample folders and the folder holding per-sample JSON.
audio_folder = f"{_colab_storage}/ronxgin_data_samples"
audio_folder_all = f"{_colab_storage}/audioWeather/audiofilesProcessed"
json_folder = f"{_colab_storage}/ronxgin_data_samples"

# Model, output and bucket-data directories.
model_path = f"{_colab_storage}/modelData"
output_path = f"{_colab_storage}/outputData"
bucketData_path = f"{_colab_storage}/bucketData"

# Raw and normalised weather database, plus the shared project settings file.
json_database_path = f"{_colab_storage}/JSON_Database/synthetic_ornithology_complete.json"
json_database_normalised_path = f"{_colab_storage}/JSON_Database/synthetic_ornithology_complete_normalised.json"
project_settings_json_path = f"{_colab_storage}/Project_settings/s_o_settings.json"



In [None]:

def setupParameters(file_path):
    """Load the project settings JSON and flatten it into a single dict.

    Top-level keys, the ``mel_settings`` section and the ``hyperparameters``
    section are merged into one flat dictionary. ``sample_rate`` is exposed as
    ``master_sample_rate``. A derived ``steps_per_element`` is added: the number
    of mel hops needed to cover one sequence element, rounded up.

    Args:
        file_path: Path to the project settings JSON file.

    Returns:
        dict with the flattened settings plus ``steps_per_element``.

    Raises:
        KeyError: if any expected setting is missing from the file.
    """
    with open(file_path, 'r') as settings_file:
        raw_settings = json.load(settings_file)

    flat = {"master_sample_rate": raw_settings["sample_rate"]}

    top_level_keys = (
        "log_level",
        "sequence_element_length_ms",
        "bucket_min_duration_ms",
        "bucket_max_duration_ms",
        "bucket_max_size_variation_ms",
    )
    for key in top_level_keys:
        flat[key] = raw_settings[key]

    mel_keys = ("mel_n_fft", "mel_hop_length", "mel_n_mels", "mel_window_fn", "mel_normalized")
    for key in mel_keys:
        flat[key] = raw_settings["mel_settings"][key]

    hyperparameter_keys = (
        "num_attention_heads",
        "num_epochs",
        "input_dim",
        "hidden_dim",
        "num_layers",
        "conditioning_dim",
        "learning_rate",
    )
    for key in hyperparameter_keys:
        flat[key] = raw_settings["hyperparameters"][key]

    # Samples per element divided by the hop length, rounded up to whole frames.
    samples_per_element = flat["master_sample_rate"] * flat["sequence_element_length_ms"] / 1000
    flat["steps_per_element"] = math.ceil(samples_per_element / flat["mel_hop_length"])

    return flat

# Root logging: timestamped records, DEBUG and above.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)

# Project settings, shared with every later cell through this global.
systemParams = setupParameters(project_settings_json_path)
print(systemParams)  # print the parameters

# Notebook-wide logger; it accepts DEBUG and above.
logger = logging.getLogger(__name__)
logger.setLevel("DEBUG")

# ANSI escape sequences used to color console output per log level.
LOG_COLORS = {
    logging.DEBUG: '\033[94m',    # Blue
    logging.INFO: '\033[92m',     # Green
    logging.WARNING: '\033[93m',  # Yellow
    logging.ERROR: '\033[91m',    # Red
    logging.CRITICAL: '\033[91m'  # Red (same as ERROR)
}

# Create a custom log formatter to add colors
class ColoredFormatter(logging.Formatter):
    """Formatter that wraps each formatted record in its level's ANSI color.

    The level name is prefixed to the fully formatted message, and the color is
    reset at the end. Levels missing from LOG_COLORS (e.g. custom levels) are
    emitted uncolored.
    """

    def format(self, record):
        # Default to '' so an unknown level does not print the literal text "None".
        log_color = LOG_COLORS.get(record.levelno, '')
        log_msg = super().format(record)
        return f'{log_color}{record.levelname} - {log_msg}\033[0m'


# Drop handlers left by a previous run of this cell. Iterate over a copy:
# removing from logger.handlers while iterating it directly skips every other
# handler, so re-runs used to leave duplicates and print each message twice.
for handler in list(logger.handlers):
    logger.removeHandler(handler)

# Console handler that accepts DEBUG and above.
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)

# Apply the custom log formatter to the console handler
colored_formatter = ColoredFormatter('%(asctime)s - %(levelname)s - %(message)s')
console_handler.setFormatter(colored_formatter)

# Add the console handler to the logger. Disable propagation so records are not
# also emitted by the root handler installed by logging.basicConfig.
logger.addHandler(console_handler)
logger.propagate = False


In [None]:
# Emit one sample message per level to check the colored console output.
for log_method, sample_message in (
    (logger.debug, "This is a debug message"),
    (logger.info, "This is an info message"),
    (logger.warning, "This is a warning message"),
    (logger.error, "This is an error message"),
    (logger.critical, "This is a critical message"),
):
    log_method(sample_message)



In [None]:
# Normalisation helpers for the conditioning parameters. Each one is fixed and
# named after its field, so the same transform can be re-applied later with the
# stored mean/std values.

def _standardise(value, mean, std):
    """Z-score ``value`` as ``(value - mean) / std``.

    A zero ``std`` (every training value identical) would give NaN/inf. Following
    the StandardScaler convention, a scale of 1 is used instead, i.e. the value is
    only centred.
    """
    if std == 0:
        return value - mean
    return (value - mean) / std


def normalise_lat(lat):
    """Map latitude from [-90, 90] degrees to [0, 1]."""
    return (lat - (-90)) / (90 - (-90))

def normalise_lon(lon):
    """Map longitude from [-180, 180] degrees to [0, 1]."""
    return (lon - (-180)) / (180 - (-180))

def normalise_temp(temp, mean, std):
    """Z-score a temperature with the dataset mean/std."""
    return _standardise(temp, mean, std)

def normalise_humidity(humidity, mean, std):
    """Z-score a humidity value with the dataset mean/std."""
    return _standardise(humidity, mean, std)

def normalise_wind_speed(wind_speed, mean, std):
    """Z-score a wind speed with the dataset mean/std."""
    return _standardise(wind_speed, mean, std)

def normalise_wind_direction(wind_direction, mean, std):
    """Z-score a wind direction with the dataset mean/std.

    NOTE(review): direction is circular (0 and 360 are the same heading), so a
    sin/cos encoding may suit it better than a z-score.
    """
    return _standardise(wind_direction, mean, std)

def normalise_pressure(pressure, mean, std):
    """Z-score a pressure value with the dataset mean/std."""
    return _standardise(pressure, mean, std)

def normalise_elevation(elevation, mean, std):
    """Z-score an elevation with the dataset mean/std."""
    return _standardise(elevation, mean, std)

def normalise_minutes_of_day(minutes_of_day):
    """Map minutes since midnight (0..1439) to [0, 1)."""
    return minutes_of_day / 1440

def normalise_day_of_year(day_of_year):
    """Map day of year 1..365 to [0, 1]; leap-day 366 maps slightly above 1."""
    return (day_of_year - 1) / 364


def _plot_stacked_series(series_by_label, title_prefix=""):
    """Show one vertically stacked line plot per ``label -> values`` entry.

    Each subplot plots the values against their entry index. The y-axis is
    labelled with the label and the title is ``title_prefix + label``.
    """
    n_plots = len(series_by_label)
    # squeeze=False keeps a 2-D array of Axes even when there is a single plot.
    fig, axs = plt.subplots(n_plots, figsize=(8, 6 * n_plots), squeeze=False)
    for ax, (label, series) in zip(axs[:, 0], series_by_label.items()):
        ax.plot(series)
        ax.set_xlabel("Entry")
        ax.set_ylabel(label)
        ax.set_title(f"{title_prefix}{label}")
    plt.tight_layout()
    plt.show()


def preprocess_json(input_file_path, output_file_path, project_settings_file_path):
    """Normalise the raw weather/location database and store the statistics used.

    Steps:
      1. Load the raw entries from ``input_file_path`` and plot each raw field.
      2. Compute the mean and population std (``np.std``) of temperature, humidity,
         wind speed, wind direction, pressure and elevation. Write them into the
         project settings JSON as ``<field>_mean`` / ``<field>_std``.
      3. Normalise every entry with the ``normalise_*`` helpers. Write the list of
         ``{"filename", "normalized_*"}`` dicts to ``output_file_path``.
      4. Plot every normalised field.

    Raises:
        ValueError: if the input database has no entries. The statistics would
            otherwise be NaN and overwrite the stored settings.
    """
    with open(input_file_path, 'r') as file:
        data = json.load(file)
    if not data:
        raise ValueError(f"preprocess_json: no entries found in {input_file_path}")

    # Raw values per field, collected once and reused for plots and statistics.
    raw_columns = {
        "Temperature": [entry["main"]["temp"] for entry in data],
        "Humidity": [entry["main"]["humidity"] for entry in data],
        "Wind Speed": [entry["wind"]["speed"] for entry in data],
        "Wind Direction": [entry["wind"]["deg"] for entry in data],
        "Pressure": [entry["main"]["pressure"] for entry in data],
        "Elevation": [entry["elevation"] for entry in data],
        "Minutes of Day": [entry["minutesOfDay"] for entry in data],
        "Day of Year": [entry["dayOfYear"] for entry in data],
    }
    _plot_stacked_series(raw_columns, title_prefix="Pre-Normalization: ")

    # Z-scored fields: settings-key prefix -> column in raw_columns.
    standardised_fields = {
        "temp": "Temperature",
        "humidity": "Humidity",
        "wind_speed": "Wind Speed",
        "wind_direction": "Wind Direction",
        "pressure": "Pressure",
        "elevation": "Elevation",
    }
    stats = {
        prefix: (np.mean(raw_columns[column]), np.std(raw_columns[column]))
        for prefix, column in standardised_fields.items()
    }
    for prefix, (_, std) in stats.items():
        if std == 0:
            logger.warning(f"preprocess_json: '{prefix}' is constant across all entries (std == 0); z-scoring it is degenerate.")

    # Read existing project settings, add/overwrite <field>_mean and <field>_std, write back.
    with open(project_settings_file_path, 'r') as settings_file:
        project_settings = json.load(settings_file)
    for prefix, (mean, std) in stats.items():
        project_settings[f"{prefix}_mean"] = mean
        project_settings[f"{prefix}_std"] = std
    with open(project_settings_file_path, 'w') as settings_file:
        json.dump(project_settings, settings_file, indent=4)

    normalized_data = [
        {
            "filename": entry["filename"],
            "normalized_latitude": normalise_lat(entry["coord"]["lat"]),
            "normalized_longitude": normalise_lon(entry["coord"]["lon"]),
            "normalized_temperature": normalise_temp(entry["main"]["temp"], *stats["temp"]),
            "normalized_humidity": normalise_humidity(entry["main"]["humidity"], *stats["humidity"]),
            "normalized_wind_speed": normalise_wind_speed(entry["wind"]["speed"], *stats["wind_speed"]),
            "normalized_wind_direction": normalise_wind_direction(entry["wind"]["deg"], *stats["wind_direction"]),
            "normalized_pressure": normalise_pressure(entry["main"]["pressure"], *stats["pressure"]),
            "normalized_elevation": normalise_elevation(entry["elevation"], *stats["elevation"]),
            "normalized_minutes_of_day": normalise_minutes_of_day(entry["minutesOfDay"]),
            "normalized_day_of_year": normalise_day_of_year(entry["dayOfYear"]),
        }
        for entry in data
    ]

    with open(output_file_path, 'w') as file:
        json.dump(normalized_data, file, indent=4)

    # Plot label -> key in each normalised entry.
    normalized_plot_fields = {
        "Normalized Temperature": "normalized_temperature",
        "Normalized Humidity": "normalized_humidity",
        "Normalized Wind Speed": "normalized_wind_speed",
        "Normalized Wind Direction": "normalized_wind_direction",
        "Normalized Pressure": "normalized_pressure",
        "Normalized Elevation": "normalized_elevation",
        "Normalized Minutes of Day": "normalized_minutes_of_day",
        "Normalized Day of Year": "normalized_day_of_year",
        "Normalized Latitude": "normalized_latitude",
        "Normalized Longitude": "normalized_longitude",
    }
    _plot_stacked_series({
        label: [entry[key] for entry in normalized_data]
        for label, key in normalized_plot_fields.items()
    })

In [None]:
preprocess_json(
    input_file_path=json_database_path,
    output_file_path=json_database_normalised_path,
    project_settings_file_path=project_settings_json_path,
)

In [None]:
def get_audio_duration_ms(file_path):
    """Return an audio file's duration in whole milliseconds (truncated), via TinyTag.

    Returns 0 in two cases, which callers filtering with a positive minimum length
    therefore skip:
      - TinyTag raises UnicodeDecodeError while reading the file;
      - TinyTag reports no duration (``duration is None``).
    """
    try:
        audio = TinyTag.get(file_path)
    except UnicodeDecodeError:
        logger.error(f"Unicode Decode Error for file: {file_path}")
        return 0
    # TinyTag leaves duration as None when it cannot determine it; None * 1000 would raise.
    if audio.duration is None:
        logger.warning(f"get_audio_duration_ms: No duration available for file: {file_path}")
        return 0
    return int(audio.duration * 1000)


def sort_and_bucket_audio_files(folder, min_length_ms, max_length_ms, max_variance_ms, bucketData_path):
    """Group the .wav files in ``folder`` into buckets of similar duration and save them.

    Files with a duration (ms, via get_audio_duration_ms) in
    [min_length_ms, max_length_ms] are sorted by duration and swept in order. A new
    bucket starts whenever a file is more than ``max_variance_ms`` longer than the
    first (shortest) file of the current bucket.

    The result is written to ``<bucketData_path>/bucket_data.json`` as a list of
    ``{"files": [{"file", "duration"}, ...], "max_duration": int}``. The directory
    is created if needed. An empty list is written, with a warning, when no file
    is in range.
    """
    audio_files = [os.path.join(folder, file) for file in os.listdir(folder) if file.endswith('.wav')]

    logger.info(f"sort_and_bucket_audio_files: Found {len(audio_files)} audio files.")

    # Read each duration exactly once (it is file I/O) and keep only in-range files.
    in_range = []
    for path in audio_files:
        duration = get_audio_duration_ms(path)
        if min_length_ms <= duration <= max_length_ms:
            in_range.append((duration, path))

    logger.info(f"sort_and_bucket_audio_files: {len(in_range)} audio files have a duration between {min_length_ms} ms and {max_length_ms} ms.")

    # Sorting (duration, path) pairs breaks ties by path, so buckets do not depend on os.listdir order.
    in_range.sort()

    bucketed_files = []

    def _close_bucket(bucket):
        # Record a finished bucket together with its longest duration.
        max_duration = max(item['duration'] for item in bucket)
        bucketed_files.append({'files': bucket, 'max_duration': max_duration})
        logger.info(f"sort_and_bucket_audio_files: Bucket {len(bucketed_files)} created with {len(bucket)} audio files and max duration of {max_duration} ms.")

    current_bucket = []
    bucket_start_ms = None
    for duration, path in in_range:
        if current_bucket and duration - bucket_start_ms > max_variance_ms:
            _close_bucket(current_bucket)
            current_bucket = []
        if not current_bucket:
            bucket_start_ms = duration
        current_bucket.append({'file': path, 'duration': duration})

    if current_bucket:
        _close_bucket(current_bucket)
    else:
        logger.warning("sort_and_bucket_audio_files: No audio files in range; writing an empty bucket list.")

    logger.info(f"sort_and_bucket_audio_files: Created {len(bucketed_files)} buckets in total.")

    # Saving bucket data as a JSON file
    os.makedirs(bucketData_path, exist_ok=True)
    bucket_file = os.path.join(bucketData_path, 'bucket_data.json')
    with open(bucket_file, 'w') as f:
        json.dump(bucketed_files, f, indent=4)

    logger.info(f"sort_and_bucket_audio_files: Bucket data saved to {bucket_file}.")



def prepare_sequences(audio_folder, json_file, bucketData_path, element_duration_ms):
    """Yield mel-spectrogram batches and conditioning vectors for every bucketed file.

    For each file listed in ``<bucketData_path>/bucket_data.json``:
      - compute ``batch_size = int(bucket max_duration / element_duration_ms) + 1``;
      - load the audio at ``systemParams['master_sample_rate']``;
      - pad with zeros (or trim) it to exactly ``batch_size`` equal chunks of
        ``element_duration_ms``;
      - convert every chunk to a mel spectrogram;
      - pair the result with the file's normalised parameters from ``json_file``.

    Files with no matching JSON entry are logged and skipped.

    Args:
        audio_folder: Unused; file paths come from the bucket data.
        json_file: Path to the normalised JSON database (list of entries with "filename").
        bucketData_path: Directory containing bucket_data.json.
        element_duration_ms: Length of one sequence element in milliseconds.

    Yields:
        (batch_data, batch_params, n_elements), where:
          - batch_data has shape (batch_size, n_mels, frames);
          - batch_params has shape (batch_size, 10);
          - n_elements == batch_size.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load bucket data
    logger.info("prepare_sequences: Loading bucketing settings")
    with open(os.path.join(bucketData_path, 'bucket_data.json'), 'r') as f:
        buckets = json.load(f)

    # Instantiate the MelSpectrogram transformation. It owns buffers (STFT window,
    # mel filterbank), so it must live on the same device as its input tensors.
    logger.info("prepare_sequences: Setting up MelSpectrogram transformation")
    mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=systemParams['master_sample_rate'],
        n_fft=systemParams['mel_n_fft'],
        hop_length=systemParams['mel_hop_length'],
        n_mels=systemParams['mel_n_mels'],
        normalized=systemParams['mel_normalized']
    ).to(device)

    # Print the settings of the MelSpectrogram transformation
    logger.info("prepare_sequences: MelSpectrogram settings:")
    logger.info(f"prepare_sequences: Sample rate: {mel_spectrogram.sample_rate}")
    logger.info(f"prepare_sequences: N FFT: {mel_spectrogram.n_fft}")
    logger.info(f"prepare_sequences: Hop length: {mel_spectrogram.hop_length}")
    logger.info(f"prepare_sequences: Number of Mel bins: {mel_spectrogram.n_mels}")
    logger.info(f"prepare_sequences: Normalized: {mel_spectrogram.normalized}")

    logger.info("prepare_sequences: Loading normalized data")
    with open(json_file, 'r') as f:
        json_data = json.load(f)

    # Index entries by filename once, keeping the first match for duplicate names,
    # instead of scanning the whole list for every audio file.
    entries_by_name = {}
    for entry in json_data:
        entries_by_name.setdefault(entry["filename"], entry)

    sample_rate = systemParams['master_sample_rate']
    # Integer sample count per element. Used for both the length check and the
    # padding, so the pad amount can never be negative.
    samples_per_element = int(element_duration_ms * sample_rate / 1000)

    for bucket in buckets:
        batch_size = int(bucket["max_duration"] / element_duration_ms) + 1
        logger.debug(f"prepare_sequences: batch_size: {batch_size}")
        total_samples = batch_size * samples_per_element

        for file_data in bucket["files"]:
            audio_file = file_data["file"]

            json_key = os.path.splitext(os.path.basename(audio_file))[0].replace('_P', '')
            json_entry = entries_by_name.get(json_key)
            if json_entry is None:
                logger.error("prepare_sequences: No matching JSON entry found for audio file: %s", audio_file)
                continue

            audio_data, sr = librosa.load(audio_file, sr=sample_rate)
            logger.debug("prepare_sequences: Loading audio file: " + audio_file)

            # Force exactly batch_size equal-length chunks so every spectrogram has
            # the same frame count. Unequal chunks would make stacking fail.
            if len(audio_data) < total_samples:
                audio_data = np.pad(audio_data, (0, total_samples - len(audio_data)), mode='constant')
            else:
                audio_data = audio_data[:total_samples]
            audio_chunks = audio_data.reshape(batch_size, samples_per_element)
            logger.debug("prepare_sequences: Audio split into: " + str(len(audio_chunks)) + " elements")
            logger.debug("prepare_sequences: Transforming audio to Spectrogram")

            # MelSpectrogram maps (..., time) -> (..., n_mels, frames), so all chunks
            # are transformed in one batched call.
            batch_data = mel_spectrogram(torch.tensor(audio_chunks).to(device))
            logger.debug("prepare_sequences: Audio elements Shape: " + str(len(batch_data)))

            # Must stay in this order: the conditioning layer's input features follow it.
            params = [
                json_entry["normalized_latitude"],
                json_entry["normalized_longitude"],
                json_entry["normalized_temperature"],
                json_entry["normalized_humidity"],
                json_entry["normalized_wind_speed"],
                json_entry["normalized_wind_direction"],
                json_entry["normalized_pressure"],
                json_entry["normalized_elevation"],
                json_entry["normalized_minutes_of_day"],
                json_entry["normalized_day_of_year"],
            ]

            # One copy of the parameter vector per element: [batch_size, num_params]
            batch_params = torch.tensor(params, device=device).unsqueeze(0).repeat(batch_size, 1)

            logger.info("prepare_sequences: batch_data shape" + str(batch_data.shape))
            logger.info("prepare_sequences: batch_params shape " + str(batch_params.shape))
            logger.info("prepare_sequences: Length of batch data " + str(len(batch_data)))
            logger.info("prepare_sequences: yielding batch")
            yield batch_data, batch_params, len(batch_data)



class Generator(nn.Module):
    """Stateful LSTM generator with conditioned multi-head attention pooling.

    Noise of shape (batch, seq, input_dim) runs through an LSTM whose (h, c) state
    is kept in ``self.hidden_state`` between calls. At every time step, each
    attention head builds a query from ``[lstm_output_t ; conditioning_layer(conditioning)]``
    and attends over all LSTM time steps of the same sample. The heads' context
    vectors and the projected conditioning are then mapped to ``audio_dim``.

    Output shape: (batch, seq, audio_dim).
    """

    def __init__(self, input_dim, hidden_dim, num_layers, audio_dim, num_attention_heads, conditioning_dim):
        super(Generator, self).__init__()

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # Define device here
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.hidden_state = None

        self.num_attention_heads = num_attention_heads
        # Each head maps [lstm_output_t ; projected conditioning] to a hidden_dim query.
        self.attention_layers = nn.ModuleList([
            nn.Linear(hidden_dim * 2, hidden_dim) for _ in range(num_attention_heads)
        ])

        self.additional_linear_layer = nn.Linear(hidden_dim * num_attention_heads + hidden_dim, hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, audio_dim)
        self.conditioning_layer = nn.Linear(conditioning_dim, hidden_dim)

    def reset_hidden_state(self, batch_size):
        """Set the LSTM (h, c) state to zeros for ``batch_size`` sequences.

        The state is created on the device and dtype of the LSTM weights, so it
        follows the module after ``.to(...)``.
        """
        weight = self.lstm.weight_ih_l0
        shape = (self.lstm.num_layers, batch_size, self.lstm.hidden_size)
        h_0 = torch.zeros(shape, device=weight.device, dtype=weight.dtype)
        c_0 = torch.zeros(shape, device=weight.device, dtype=weight.dtype)
        self.hidden_state = (h_0, c_0)

    def forward(self, x, conditioning):
        """Generate frames from noise ``x`` (batch, seq, input_dim) and ``conditioning`` (batch, conditioning_dim)."""
        batch_size, sequence_length, _ = x.shape

        # An LSTM cannot continue a state of a different batch size: start from zeros.
        if self.hidden_state is None or self.hidden_state[0].size(1) != batch_size:
            self.reset_hidden_state(batch_size)

        lstm_output, new_hidden_state = self.lstm(x, self.hidden_state)
        # Keep the state for the next call but cut its autograd history (truncated
        # BPTT). A later backward() then never reaches into an earlier, freed graph.
        self.hidden_state = tuple(state.detach() for state in new_hidden_state)

        logger.debug(f"Generator forward: Shape of LSTM output: {lstm_output.shape}, expected: ({batch_size}, {sequence_length}, {self.lstm.hidden_size})")

        # Conditioning does not vary over time, so project it once.
        conditioned_input = self.conditioning_layer(conditioning)
        score_scale = math.sqrt(self.lstm.hidden_size)

        outputs = []

        for t in range(sequence_length):
            attention_input = lstm_output[:, t, :]
            attention_input_conditioned = torch.cat((attention_input, conditioned_input), dim=-1)

            logger.debug(f"Generator forward: Shape of attention_input_conditioned: {attention_input_conditioned.shape}, expected: ({batch_size}, {self.lstm.hidden_size + self.conditioning_layer.out_features})")

            attention_outputs = []
            for attention_layer in self.attention_layers:
                query = attention_layer(attention_input_conditioned)  # (batch, hidden)
                # Scaled dot-product score of every time step against the query.
                # Softmax over time (dim=1), never over the batch, so each sample
                # attends only within its own sequence.
                attention_scores = torch.bmm(lstm_output, query.unsqueeze(-1)).squeeze(-1) / score_scale
                attention_weights = F.softmax(attention_scores, dim=1).unsqueeze(-1)  # (batch, seq, 1)
                attention_outputs.append(torch.sum(attention_weights * lstm_output, dim=1))  # (batch, hidden)

            combined_output = torch.cat(attention_outputs, dim=1)
            logger.debug(f"Generator forward: Shape of combined_output: {combined_output.shape}, expected: ({batch_size}, {self.lstm.hidden_size * self.num_attention_heads})")

            additional_output = self.additional_linear_layer(torch.cat((combined_output, conditioned_input), dim=-1))
            output = self.output_layer(additional_output)
            logger.debug(f"Generator forward: Shape of final output: {output.shape}, expected: ({batch_size}, {self.output_layer.out_features})")

            outputs.append(output.unsqueeze(1))

        final_output = torch.cat(outputs, dim=1)
        logger.debug(f"Generator forward: Shape of final output: {final_output.shape}")
        return final_output



# Training the generator
def train_generator(discriminator, generator, batch_size, conditioning, optimizer, steps_per_element):
    """Run one WGAN generator update and return the (scalar tensor) generator loss.

    The generator state is reset to ``batch_size``. Noise of shape
    (batch_size, steps_per_element, input_dim) is drawn on ``conditioning``'s
    device and passed to the generator together with ``conditioning``. The output
    is flattened to (batch_size, -1) for the discriminator. The loss is
    ``-mean(prediction)``, and ``optimizer`` is stepped.
    """
    device = conditioning.device
    logger.info("Training generator")
    # Reset gradients
    optimizer.zero_grad()

    # Reset the generator's hidden state
    generator.reset_hidden_state(batch_size)

    # Generate fake data
    noise = torch.randn(batch_size, steps_per_element, systemParams['input_dim'], device=device)
    logger.debug(f"train_generator: Shape of noise: {noise.shape}, expected: ({batch_size}, {steps_per_element}, {systemParams['input_dim']})")

    # Conditioning is passed to the generator as a separate argument. Concatenating
    # it onto the 3-D noise is impossible (it is 2-D), and the generator never used such a tensor.
    logger.debug(f"train_generator: Shape of conditioning: {conditioning.shape}")
    fake_data = generator(noise, conditioning)
    audio_dim = fake_data.shape[-1]
    logger.debug(f"train_generator: Shape of fake_data: {fake_data.shape}, expected: ({batch_size}, {steps_per_element}, {audio_dim})")

    # Flatten to the discriminator's (batch, steps * audio_dim) input layout.
    fake_data_reshaped = fake_data.reshape(batch_size, -1)
    logger.debug(f"train_generator: Shape of reshaped fake_data: {fake_data_reshaped.shape}, expected: ({batch_size}, {steps_per_element * audio_dim})")

    # Get the discriminator's prediction
    prediction = discriminator(fake_data_reshaped)
    logger.debug(f"train_generator: Shape of discriminator's prediction: {prediction.shape}, expected: ({batch_size}, 1)")

    # Aim to fool the discriminator: maximise its score on fake data.
    g_loss = -torch.mean(prediction)
    logger.info("train_generator: Generator loss: %s", g_loss.item())
    g_loss.backward()

    # Update weights with gradients
    optimizer.step()

    return g_loss



# Define the Discriminator class with LSTM
class Discriminator(nn.Module):
    """LSTM critic that maps a sequence of audio feature frames to one score per sample.

    ``forward`` accepts either:
      - a flattened batch of shape (batch, steps_per_element * audio_dim), the
        layout the training functions pass in; or
      - an already shaped (batch, steps, audio_dim) tensor.

    The LSTM starts from a zero state on every call. A linear layer maps the last
    time step's output to a (batch, 1) score, with no activation.
    """

    def __init__(self, audio_dim, hidden_dim, num_layers, steps_per_element):
        super(Discriminator, self).__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.lstm = nn.LSTM(audio_dim, hidden_dim, num_layers, batch_first=True)
        self.output_layer = nn.Linear(hidden_dim, 1)
        self.steps_per_element = steps_per_element
        # Placeholder; forward() replaces it with a correctly sized zero state.
        self.hidden_cell = (torch.zeros(self.num_layers, 1, self.hidden_dim),
                            torch.zeros(self.num_layers, 1, self.hidden_dim))

    def reset_hidden_state(self, batch_size):
        """Set a zero (h, c) state for ``batch_size`` on the LSTM weights' device/dtype."""
        weight = self.lstm.weight_ih_l0
        shape = (self.num_layers, batch_size, self.hidden_dim)
        self.hidden_cell = (
            torch.zeros(shape, device=weight.device, dtype=weight.dtype),
            torch.zeros(shape, device=weight.device, dtype=weight.dtype)
        )

    def forward(self, x):
        """Score ``x`` ((batch, steps*audio_dim) or (batch, steps, audio_dim)) -> (batch, 1).

        Raises:
            ValueError: if a 2-D input width is not a multiple of steps_per_element,
                or if ``x`` is neither 2-D nor 3-D.
        """
        batch_size = x.size(0)
        self.reset_hidden_state(batch_size)
        logger.debug(f"Discriminator Actual shape of x: {x.shape}")
        logger.debug(f"Discriminator steps_per_element: {self.steps_per_element})")

        if x.dim() == 3:
            audio_dim = x.size(2)
        elif x.dim() == 2:
            seq_length_times_audio_dim = x.size(1)
            if seq_length_times_audio_dim % self.steps_per_element != 0:
                raise ValueError(
                    f"Discriminator: input width {seq_length_times_audio_dim} is not a multiple "
                    f"of steps_per_element={self.steps_per_element}"
                )
            audio_dim = seq_length_times_audio_dim // self.steps_per_element
            # reshape (not view) also accepts non-contiguous inputs, e.g. permuted tensors.
            x = x.reshape(batch_size, self.steps_per_element, audio_dim)
        else:
            raise ValueError(f"Discriminator: expected a 2-D or 3-D input, got shape {tuple(x.shape)}")

        logger.debug(f"Discriminator input shape: {x.shape}, expected: ({batch_size}, {self.steps_per_element}, {audio_dim})")

        # The final (h, c) state is not needed.
        lstm_output, _ = self.lstm(x, self.hidden_cell)

        # Use only the last sequence output
        last_output = lstm_output[:, -1, :]
        logger.debug(f"Discriminator LSTM last output shape: {last_output.shape}, expected: ({batch_size}, {self.hidden_dim})")

        return self.output_layer(last_output)




def train_discriminator(discriminator, generator, real_data, conditioning, optimizer, steps_per_element, gp_lambda=10.0):
    """Run one WGAN-GP critic update and return the (scalar tensor) critic loss.

    Loss = ``-mean(D(real)) + mean(D(fake)) + gp_lambda * GP``, where GP is the
    gradient penalty ``mean((||grad D(interp)||_2 - 1)^2)`` on random interpolates
    between the real and fake batches. Only ``optimizer`` is stepped. The fake
    batch is produced without gradients, from a freshly reset generator state.

    Args:
        real_data: Real batch; it may be non-contiguous (e.g. permuted). It is
            flattened to (batch, -1).
        gp_lambda: Gradient-penalty weight (10 by default, as before).
    """
    logger.info("train_discriminator: Training Discriminator")

    # Reset gradients
    optimizer.zero_grad()

    batch_size = real_data.size(0)
    device = real_data.device

    # 1. Real data. reshape (not view): the caller passes a permuted, non-contiguous tensor.
    real_flat = real_data.reshape(batch_size, -1)
    prediction_real = discriminator(real_flat)
    logger.debug("train_discriminator: Discriminator's prediction for real data: %s", prediction_real.mean().item())

    # 2. Fake data. Start the stateful generator from a zero state sized for this
    # batch. No graph is needed, because the fake sample is a constant for the critic.
    noise = torch.randn(batch_size, steps_per_element, systemParams['input_dim'], device=device)
    generator.reset_hidden_state(batch_size)
    with torch.no_grad():
        fake_flat = generator(noise, conditioning).reshape(batch_size, -1)
    prediction_fake = discriminator(fake_flat)
    logger.debug("train_discriminator: Discriminator's prediction for fake data: %s", prediction_fake.mean().item())

    # Compute the Wasserstein Loss
    d_loss = -torch.mean(prediction_real) + torch.mean(prediction_fake)

    logger.info("train_discriminator: d_loss.requires_grad: %s", d_loss.requires_grad)

    # 3. Gradient penalty on random interpolates between real and fake samples.
    alpha = torch.rand(batch_size, 1, device=device)
    interpolates = (alpha * real_flat + (1 - alpha) * fake_flat).requires_grad_(True)

    # cuDNN's RNN kernels do not support double backward, which create_graph=True
    # requires. Run this pass on the native LSTM implementation instead.
    with torch.backends.cudnn.flags(enabled=False):
        d_interpolates = discriminator(interpolates)
        gradients = autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(d_interpolates),
            create_graph=True,
            retain_graph=True,
            only_inputs=True
        )[0]
    gradients = gradients.reshape(batch_size, -1)

    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * gp_lambda
    logger.info("train_discriminator: Gradient penalty: %s", gradient_penalty.item())

    # Add gradient penalty to the discriminator loss
    d_loss = d_loss + gradient_penalty
    logger.info("train_discriminator: Discriminator loss: %s", d_loss.item())

    # Update weights
    d_loss.backward()
    logger.info("train_discriminator: d_loss.backward passed")

    optimizer.step()

    return d_loss


def train_gan(gen, dis, audio_folder, json_file, bucketData_path, element_duration_ms, n_epochs, lr):
    """Train the generator / discriminator pair with a WGAN-style objective.

    For every batch yielded by ``prepare_sequences`` the generator is updated
    once (maximising the critic's score on generated data), then the
    discriminator is updated once via ``train_discriminator``.

    Args:
        gen: Generator module. Must provide ``reset_hidden_state(batch_size)``
            and be callable as ``gen(noise, conditions)``.
        dis: Discriminator / critic module. Must provide
            ``reset_hidden_state(batch_size)`` and accept a ``(batch, -1)``
            flattened tensor.
        audio_folder: Forwarded to ``prepare_sequences``.
        json_file: Forwarded to ``prepare_sequences``.
        bucketData_path: Forwarded to ``prepare_sequences``.
        element_duration_ms: Forwarded to ``prepare_sequences``.
        n_epochs: Number of passes over the data.
        lr: Learning rate used for both Adam optimizers.

    Reads ``systemParams['input_dim']`` and ``systemParams['steps_per_element']``
    from the notebook's global configuration. The models are trained in place;
    the function returns None.
    """
    logger.info("train_gan: Training Gan")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    gen = gen.to(device)
    dis = dis.to(device)
    # Make sure dropout / norm layers (if any) are in training mode.
    gen.train()
    dis.train()

    optimizer_gen = torch.optim.Adam(gen.parameters(), lr=lr)
    optimizer_dis = torch.optim.Adam(dis.parameters(), lr=lr)

    # Loop-invariant configuration, read once.
    input_dim = systemParams['input_dim']
    steps_per_element = systemParams['steps_per_element']

    for epoch in range(n_epochs):
        gen_loss_sum = 0.0
        dis_loss_sum = 0.0
        n_batches = 0

        for real, conditions, batch_size in prepare_sequences(audio_folder, json_file, bucketData_path, element_duration_ms):
            # Swap the last two axes so the layout is (batch, sequence, features).
            # Assumes prepare_sequences yields (batch, features, sequence) -- TODO confirm.
            real = real.permute(0, 2, 1).to(device)
            conditions = conditions.to(device)
            max_sequence_length = real.shape[1]
            logger.debug("train_gan: max_sequence_length: %s", max_sequence_length)
            logger.debug("train_gan: Real data shape: %s, expected: (%s, %s, %s)",
                         real.shape, batch_size, max_sequence_length, input_dim)

            ###################
            # Train Generator
            ###################
            gen.reset_hidden_state(batch_size)
            optimizer_gen.zero_grad()

            noise = torch.randn(batch_size, max_sequence_length, input_dim, device=device)
            logger.debug("train_gan: Noise shape: %s, expected: (%s, %s, %s)",
                         noise.shape, batch_size, max_sequence_length, input_dim)

            fake = gen(noise, conditions)
            logger.debug("train_gan: Fake data shape: %s, expected: (%s, %s, %s)",
                         fake.shape, batch_size, max_sequence_length, input_dim)

            # WGAN generator objective: maximise the critic's score on fakes.
            # This backward also accumulates gradients in dis's parameters;
            # they are discarded by optimizer_dis.zero_grad() below.
            gen_loss = -dis(fake.reshape(fake.size(0), -1)).mean()
            gen_loss.backward()
            optimizer_gen.step()

            #######################
            # Train Discriminator
            #######################
            dis.reset_hidden_state(batch_size)
            optimizer_dis.zero_grad()
            dis_loss = train_discriminator(dis, gen, real.reshape(real.size(0), -1), conditions,
                                           optimizer_dis, steps_per_element)

            gen_loss_sum += gen_loss.item()
            dis_loss_sum += dis_loss.item()
            n_batches += 1

        # Guard against an empty epoch. Without it, the loss variables would
        # be unbound and the summary log would raise.
        if n_batches == 0:
            logger.warning("train_gan: Epoch [%d/%d] produced no batches; check prepare_sequences inputs",
                           epoch + 1, n_epochs)
            continue

        # Report the epoch mean rather than the (noisy) last-batch value.
        logger.info("train_gan: Epoch [%d/%d] Loss D: %s, Loss G: %s (mean over %d batches)",
                    epoch + 1, n_epochs, dis_loss_sum / n_batches, gen_loss_sum / n_batches, n_batches)




In [None]:
# Group the audio files in audio_folder into duration buckets using the
# configured limits; bucket data is written under / read from bucketData_path.
buckets = sort_and_bucket_audio_files(
    audio_folder,
    systemParams['bucket_min_duration_ms'],
    systemParams['bucket_max_duration_ms'],
    systemParams['bucket_max_size_variation_ms'],
    bucketData_path,
)


In [None]:
# Authenticate with Weights & Biases. No credentials are hard-coded here;
# presumably wandb picks them up from the environment or prompts -- confirm.
wandb.login()

In [None]:
# Echo each model hyper-parameter (value and Python type) before building the networks.
_hyperparam_labels = (
    ("Input dim", 'input_dim'),
    ("Hidden dim", 'hidden_dim'),
    ("Num layers", 'num_layers'),
    ("Mel n mels", 'mel_n_mels'),
    ("Num attention heads", 'num_attention_heads'),
    ("Conditioning dim", 'conditioning_dim'),
)
for _label, _key in _hyperparam_labels:
    _value = systemParams[_key]
    logger.info(f"{_label}: {_value}, type: {type(_value)}")

gen = Generator(
    systemParams['input_dim'],
    systemParams['hidden_dim'],
    systemParams['num_layers'],
    systemParams['mel_n_mels'],
    systemParams['num_attention_heads'],
    systemParams['conditioning_dim'],
)
dis = Discriminator(
    systemParams['mel_n_mels'],
    systemParams['hidden_dim'],
    systemParams['num_layers'],
    systemParams['steps_per_element'],
)

# Call the training function
train_gan(
    gen,
    dis,
    audio_folder,
    json_file=json_database_normalised_path,
    bucketData_path=bucketData_path,
    element_duration_ms=systemParams['sequence_element_length_ms'],
    n_epochs=systemParams['num_epochs'],
    lr=systemParams['learning_rate'],
)

In [None]:
# 5. Connect to Weights and Biases for tracking progress
# NOTE(review): this cell comes after the cell that calls train_gan, and
# train_gan does not call wandb.log, so this run will not record any training
# metrics. Consider calling wandb.init before training and logging the
# per-epoch losses from train_gan.
wandb.init(project="audio-gan")