### Features

1. BPM (Beats per Minute): 90

- Purpose: Measures the tempo of the audio file. In drum loops, the BPM is critical for synchronizing and controlling the rhythm.
- Use case: This can be used as a conditioning factor for your VAE or StyleGAN models to control the tempo of generated audio loops.

2. Zero-Crossing Rate: 0.0148 (approx)

- Purpose: Measures how often the signal changes sign. This feature can indicate the noisiness or tonal quality of the audio.
- Use case: Can be relevant in distinguishing different types of drum sounds (e.g., harsh/noisy snare vs. smooth bass).


3. Spectral Centroid: 796.29 Hz

- Purpose: Represents the "brightness" of the sound by indicating where the center of mass of the spectrum is located.
- Use case: Helpful for timbre analysis, especially in distinguishing between bright or dark sound patterns in drum loops.


4. Spectral Bandwidth: 1406.69 Hz

- Purpose: Measures how widely the spectrum is spread around the spectral centroid, which indicates the range of frequencies present in the sound.
- Use case: Can be useful for telling narrow-band sounds from broadband ones (e.g., a loop with both kicks and hi-hats will tend to have a broader bandwidth than one dominated by a kick).


5. RMS Energy: 0.1047

- Purpose: The root-mean-square amplitude of the signal, averaged over frames. It is a simple proxy for loudness or power.
- Use case: This can directly inform the energy feature for your model, allowing it to condition audio generation based on loudness levels.


6. MFCCs (Mel-Frequency Cepstral Coefficients): A 13-element array

- Purpose: Used to represent the timbral texture of the audio. MFCCs capture the shape of the spectral envelope and are widely used in audio and speech processing.
- Use case: These coefficients can be directly used to model the timbre of the drum sounds, making them a critical input for VAE or GAN models.


7. Chroma Features: A 12-element array

- Purpose: Represents the distribution of energy across the 12 pitch classes (similar to musical notes).
- Use case: Can be used to analyze harmonic content or pitch distribution. Drums are mostly unpitched, so the chroma profile is usually flatter and less informative than for harmonic instruments.


8. Mood: 'creative'

- Purpose: A subjective label mapped from the file's metadata tags by `map_tags_to_features` (e.g., the tag 'Sound-Design' maps to 'creative'); it is not computed from the audio signal. It can be used as a condition to give generated drum loops different emotional flavors.
- Use case: This feature can serve as a high-level condition for generating different types of drum loops (e.g., happy, dark, energetic, calm).


9. Energy: 0.9

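- Purpose: A tag-derived intensity score assigned by `map_tags_to_features` (e.g., the tag 'Drums' maps to 0.9). It reflects how powerful or intense the drum loop is expected to be; the signal-level counterpart is RMS Energy (item 5).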
- Use case: Energy levels can be conditioned to control the intensity of generated drum loops.
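
To make two of the low-level definitions above concrete, here is a minimal NumPy sketch of the zero-crossing rate and the spectral centroid, computed on two synthetic sine tones rather than on a drum loop. The helper names and the test tones are illustrative assumptions; the extraction code in the next section uses Librosa's frame-wise versions, so its numbers will differ in detail.

```python
import numpy as np

def zero_crossing_rate(y):
    """Fraction of adjacent sample pairs whose signs differ."""
    signs = np.signbit(y)
    return float(np.mean(signs[1:] != signs[:-1]))

def spectral_centroid(y, sr):
    """Magnitude-weighted mean frequency (Hz) of the whole signal."""
    magnitudes = np.abs(np.fft.rfft(y))
    freqs = np.fft.rfftfreq(len(y), d=1.0 / sr)
    return float(np.sum(freqs * magnitudes) / (np.sum(magnitudes) + 1e-12))

sr = 22050
t = np.arange(sr) / sr  # one second of audio
low_tone = np.sin(2 * np.pi * 100 * t)    # dark, bass-like tone
high_tone = np.sin(2 * np.pi * 4000 * t)  # bright tone

zcr_low, zcr_high = zero_crossing_rate(low_tone), zero_crossing_rate(high_tone)
centroid_low = spectral_centroid(low_tone, sr)
centroid_high = spectral_centroid(high_tone, sr)
print(f"ZCR: low={zcr_low:.4f}, high={zcr_high:.4f}")
print(f"Centroid (Hz): low={centroid_low:.1f}, high={centroid_high:.1f}")
```

The brighter tone has both the higher zero-crossing rate and the higher centroid, which is the intuition behind using these two values to separate dark from bright drum sounds.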

### Extract audio features

In [1]:
import librosa
import numpy as np
import os

def map_tags_to_features(tags):
    """
    Maps tags to high-level features such as mood, energy, etc.
    """
    mood_mapping = {
        'Ambient': 'calm',
        'Cinematic': 'epic',
        'Looping': 'neutral',
        'Drone': 'dark',
        'Atmosphere': 'calm',
        'Pad': 'soft',
        'Drums': 'energetic',
        'Sound-Design': 'creative'
    }
    
    energy_mapping = {
        'Drums': 0.9,
        'Cinematic': 0.8,
        'Looping': 0.6,
        'Drone': 0.4,
        'Pad': 0.5,
        'Ambient': 0.3
    }
    
    # Default values
    mood = 'neutral'
    energy = 0.5
    
    for tag in tags:
        if tag in mood_mapping:
            mood = mood_mapping[tag]
        if tag in energy_mapping:
            energy = max(energy, energy_mapping[tag])
    
    return mood, energy


def extract_audio_features_with_metadata(audio_path, sr=22050, metadata=None):
    """
    Extracts audio features using Librosa and combines them with metadata.
    
    Parameters:
    audio_path (str): Path to the audio file.
    sr (int): Sampling rate.
    metadata (dict): High-level feature metadata.

    Returns:
    dict: Dictionary of extracted audio and metadata features.
    """
    try:
        # Print the path being loaded
        # print(f"Trying to load audio file: {audio_path}")
        if os.path.exists(audio_path):
            y, sr = librosa.load(audio_path, sr=sr)
            # print("Audio file loaded successfully")
        else:
            print(f"File not found: {audio_path}")
            return None

        features = {}

        # Extract low-level audio features with Librosa
        bpm, _ = librosa.beat.beat_track(y=y, sr=sr)
        if isinstance(bpm, (np.ndarray, np.generic)):  # Only call .item() on NumPy scalars
            bpm = bpm.item()  # Convert to scalar if it's a single-item array
        # print(f"Extracted BPM: {bpm} (Type: {type(bpm)})")
        features['bpm'] = bpm

        # Other features (make sure they are floats and not arrays)
        features['zero_crossing_rate'] = float(np.mean(librosa.feature.zero_crossing_rate(y)))
        features['spectral_centroid'] = float(np.mean(librosa.feature.spectral_centroid(y=y, sr=sr)))
        features['spectral_bandwidth'] = float(np.mean(librosa.feature.spectral_bandwidth(y=y, sr=sr)))
        features['rms_energy'] = float(np.mean(librosa.feature.rms(y=y)))
        # print(f"Extracted Zero Crossing Rate: {features['zero_crossing_rate']} (Type: {type(features['zero_crossing_rate'])})")
        
        # MFCCs and Chroma features are arrays
        features['mfccs'] = np.mean(librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13), axis=1)  # 13 MFCC coefficients
        features['chroma'] = np.mean(librosa.feature.chroma_stft(y=y, sr=sr), axis=1)  # 12 chroma values
        # print(f"Extracted MFCCs: {features['mfccs']} (Type: {type(features['mfccs'])})")
        # print(f"Extracted Chroma: {features['chroma']} (Type: {type(features['chroma'])})")

        # Map metadata to relevant features
        if metadata:
            features.update({
                'kick_intensity': metadata.get('Kick', 0),
                'snare_intensity': metadata.get('Snare', 0),
                'hihat_presence': metadata.get('Hi-hat', 0),
                'percussion': metadata.get('Percussion', 0),
                'mood': metadata.get('Mood', 'neutral'),
                'timbre': metadata.get('Timbre', 'balanced'),
                'energy': metadata.get('Energy', 0),
                'danceability': metadata.get('Danceability', 0),
                'global_loudness': metadata.get('Global Loudness', 0),
                'dynamic_complexity': metadata.get('Dynamic Complexity', 0),
                'valence': metadata.get('Valence', 'neutral'),
                'arousal': metadata.get('Arousal', 'calm'),
                'pitch_confidence': metadata.get('Pitch Confidence', 0),
                'harmonicity': metadata.get('Harmonicity', 0),
                'texture_density': metadata.get('Texture Density', 0),
                'articulation': metadata.get('Articulation', 'smooth')
            })

        return features
    except Exception as e:
        print(f"Error loading or extracting features from file: {audio_path}, Error: {e}")
        return None


def print_feature_info(features):
    """
    Print detailed information about extracted features.
    
    Parameters:
    features (dict): Extracted audio features.
    """
    for key, value in features.items():
        print(f"Feature: {key}")
        print(f"Type: {type(value)}")
        
        # Check if it's a scalar value
        if isinstance(value, (int, float)):
            print(f"Value: {value}")
        
        # If it's an array or list, print the shape and first few values
        elif isinstance(value, np.ndarray):
            print(f"Shape: {value.shape}")
            print(f"First few values: {value[:5]}")  # Display the first 5 elements if it's an array
            
        # For other types, print directly
        else:
            print(f"Value: {value}")
        
        print("\n")  # Separate outputs for clarity


# Example Usage
audio_path = "D:/LV-NTF+LoopGAN/data/FSL10K-trimmed/505_301.wav"
metadata = {
    "Kick": 0.8, 
    "Snare": 0.7, 
    "Hi-hat": 0.9, 
    "Percussion": 0.6, 
    "Mood": "happy", 
    "Timbre": "bright", 
    "Energy": 0.85, 
    "Danceability": 0.9
}

# Extract features
audio_features = extract_audio_features_with_metadata(audio_path, metadata=metadata)

# Print detailed info about the features
# print_feature_info(audio_features)
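
String-valued features such as `mood` cannot be fed to a network directly. One possible way to turn them into a conditioning vector is sketched below, assuming a one-hot encoding over the mood labels that appear as values in `mood_mapping`. The helper `build_condition_vector` and the `max_bpm` ceiling are hypothetical and not part of the notebook.

```python
import numpy as np

# Assumed vocabulary: the mood labels that appear as values in mood_mapping
MOODS = ['calm', 'epic', 'neutral', 'dark', 'soft', 'energetic', 'creative']

def build_condition_vector(mood, energy, bpm, max_bpm=200.0):
    """One-hot mood + energy + scaled BPM as a float32 vector (hypothetical helper)."""
    one_hot = np.zeros(len(MOODS), dtype=np.float32)
    if mood in MOODS:
        one_hot[MOODS.index(mood)] = 1.0  # unknown moods stay all-zero
    scaled_bpm = min(bpm / max_bpm, 1.0)  # max_bpm is an arbitrary assumed ceiling
    extras = np.array([energy, scaled_bpm], dtype=np.float32)
    return np.concatenate([one_hot, extras])

# Values taken from the feature list: mood 'creative', energy 0.9, BPM 90
cond = build_condition_vector('creative', 0.9, 90)
print(cond.shape, cond)
```

Such a vector could be concatenated with the latent code before decoding, but how the condition enters the model is a design decision the notebook leaves open.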


### Custom Dataset Class

### VAE Encoder and Decoder


In [2]:
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)

    def forward(self, x):
        h = torch.relu(self.fc1(x))
        h = torch.relu(self.fc2(h))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

class Decoder(nn.Module):
    def __init__(self, latent_dim, output_dim):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(latent_dim, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, output_dim)

    def forward(self, z):
        h = torch.relu(self.fc1(z))
        h = torch.relu(self.fc2(h))
        return torch.sigmoid(self.fc3(h))


### Dataloader

### StyleGAN (Mel-spectrogram Generation)


In [3]:
class StyleGAN(nn.Module):
    def __init__(self, latent_dim, output_dim):
        super(StyleGAN, self).__init__()
        self.fc1 = nn.Linear(latent_dim, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, output_dim)

    def forward(self, z):
        h = torch.relu(self.fc1(z))
        h = torch.relu(self.fc2(h))
        return torch.sigmoid(self.fc3(h))


### MelGAN (Mel-Spectrogram to Audio)


In [4]:
class MelGAN(nn.Module):
    def __init__(self):
        super(MelGAN, self).__init__()
        self.conv1 = nn.ConvTranspose1d(1, 128, 4, stride=2, padding=1)
        self.conv2 = nn.ConvTranspose1d(128, 64, 4, stride=2, padding=1)
        self.conv3 = nn.ConvTranspose1d(64, 1, 4, stride=2, padding=1)

    def forward(self, x):
        # Add a channel dimension only for 2-D input: [batch_size, time_steps] -> [batch_size, 1, time_steps]
        if x.dim() == 2:
            x = x.unsqueeze(1)
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        return torch.tanh(self.conv3(x))
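
Each transposed convolution in `MelGAN` uses `kernel_size=4`, `stride=2`, `padding=1`, which doubles the time dimension. The sketch below applies the output-length formula documented for PyTorch's `ConvTranspose1d` in plain Python; the helper name is an assumption, and the starting length of 256 is simply the `output_dim` used later in the notebook.

```python
def conv_transpose1d_out_len(length, kernel_size, stride, padding,
                             output_padding=0, dilation=1):
    """Output length of a 1-D transposed convolution (formula from the PyTorch docs)."""
    return ((length - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)

length = 256  # e.g. the StyleGAN output_dim used later in the notebook
for layer in range(3):
    length = conv_transpose1d_out_len(length, kernel_size=4, stride=2, padding=1)
    print(f"after conv{layer + 1}: {length}")
```

Three such layers upsample by a factor of 8, so a 256-step input yields 2048 output samples.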




### VAE-GAN Hybrid Model

In [5]:
class DrumLoopVAEStyleGAN(nn.Module):
    def __init__(self, latent_dim, input_dim, output_dim):
        super(DrumLoopVAEStyleGAN, self).__init__()
        self.encoder = Encoder(input_dim=input_dim, latent_dim=latent_dim)
        self.decoder = Decoder(latent_dim=latent_dim, output_dim=output_dim)
        self.stylegan = StyleGAN(latent_dim=latent_dim, output_dim=output_dim)
        self.melgan = MelGAN()

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        std = torch.clamp(std, min=1e-6)  # Avoid too small std
        eps = torch.randn_like(std)
        return mu + eps * std


    def forward(self, x):
        # Encoder step
        mu, logvar = self.encoder(x)
        # print(f"mu: min {mu.min().item()}, max {mu.max().item()}, mean {mu.mean().item()}, std {mu.std().item()}")
        # print(f"logvar: min {logvar.min().item()}, max {logvar.max().item()}, mean {logvar.mean().item()}, std {logvar.std().item()}")

        # Reparameterize to sample from latent space
        z = self.reparameterize(mu, logvar)
        # Generate mel spectrogram with StyleGAN
        generated_mel_spec = self.stylegan(z)
        
        # Ensure that the mel_spec is reshaped to have 1 channel before feeding it to MelGAN
        if generated_mel_spec.dim() == 2:  # If it is [batch_size, time_steps]
            generated_mel_spec = generated_mel_spec.unsqueeze(1)  # Add a channel dimension

        # Generate audio with MelGAN
        audio = self.melgan(generated_mel_spec)
        return audio, generated_mel_spec, mu, logvar
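
The `reparameterize` method above draws `z = mu + eps * std` with `eps` sampled from a standard normal, which keeps the sampling step differentiable with respect to `mu` and `logvar`. As a rough sanity check of that formula, the NumPy sketch below (the values of `mu` and `logvar` are arbitrary assumptions) confirms that the samples have the intended mean and standard deviation.

```python
import numpy as np

rng = np.random.default_rng(0)

def reparameterize(mu, logvar, rng):
    """Sample z ~ N(mu, exp(logvar)) as a deterministic function of noise eps."""
    std = np.exp(0.5 * logvar)
    eps = rng.standard_normal(mu.shape)
    return mu + eps * std

mu = np.full(100_000, 1.5)
logvar = np.full(100_000, np.log(0.25))  # variance 0.25 -> std 0.5
z = reparameterize(mu, logvar, rng)
print(f"sample mean={z.mean():.3f}, sample std={z.std():.3f}")
```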


### Loss Functions (VAE and GAN)



In [6]:
import torch.nn.functional as F


def vae_loss(reconstructed_x, x, mu, logvar, kl_weight=1.5):
    # Ensure the reconstructed_x (mel_spec) and x (features) have the same dimensions
    if reconstructed_x.size(-1) != x.size(-1):
        reconstructed_x = reconstructed_x[:, :x.size(-1)]  # Match length to input features

    # Reconstruction loss (MSE between reconstructed mel_spec and input features)
    recon_loss = nn.MSELoss()(reconstructed_x, x)

    # Clamp logvar to avoid extremely large or small values but allow variability
    logvar = torch.clamp(logvar, min=-10, max=10)
    mu = torch.clamp(mu, min=-5, max=5)

    # KL divergence loss (regularization term for the VAE)
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    kl_div = torch.clamp(kl_div, min=0, max=100)  # Clip to avoid exploding values

    # Print for debugging if needed
    # print(f"recon_loss: {recon_loss.item()}, kl_div: {kl_div.item()}")

    if torch.isnan(kl_div) or torch.isinf(kl_div):
        print("NaN or Inf detected in KL divergence")

    # Return total loss with KL weighting
    return recon_loss + kl_weight * kl_div, recon_loss, kl_div






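def gan_loss(discriminator, real_data, fake_data):
    # Targets must match the shape of the discriminator's output, not of its input
    real_pred = discriminator(real_data)
    fake_pred = discriminator(fake_data)
    real_loss = nn.BCELoss()(real_pred, torch.ones_like(real_pred))
    fake_loss = nn.BCELoss()(fake_pred, torch.zeros_like(fake_pred))
    return real_loss + fake_loss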
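
The KL term in `vae_loss` is the closed-form divergence between the encoder's diagonal Gaussian and a standard normal prior. The NumPy sketch below evaluates the same expression on two hand-picked cases. Unlike `vae_loss`, it averages over the batch instead of summing and applies no clamping; both are assumptions made here for readability, not the notebook's behavior.

```python
import numpy as np

def kl_to_standard_normal(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dims, averaged over the batch."""
    per_sample = -0.5 * np.sum(1 + logvar - mu**2 - np.exp(logvar), axis=1)
    return float(np.mean(per_sample))

batch, latent_dim = 8, 4
zeros = np.zeros((batch, latent_dim))
ones = np.ones((batch, latent_dim))

kl_prior = kl_to_standard_normal(zeros, zeros)   # posterior equals the prior
kl_shifted = kl_to_standard_normal(ones, zeros)  # mean shifted by 1 in every dimension
print(kl_prior, kl_shifted)
```

The divergence is zero when the posterior matches the prior, and a unit shift of the mean costs 0.5 per latent dimension, i.e. 2.0 for the four dimensions used here.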


### Define the Data Loader (train_loader)

In [7]:
# Path to the folder containing trimmed audio files
TRIMMED_AUDIO_FOLDER_PATH = "D:/LV-NTF+LoopGAN/data/FSL10K-trimmed"

# Function to get all audio file paths from the directory
def get_audio_file_paths(folder_path):
    audio_paths = []
    for root, _, files in os.walk(folder_path):
        for file in files:
            if file.endswith(".wav"):  # Ensure only .wav files are picked
                audio_paths.append(os.path.abspath(os.path.join(root, file)))  # Convert to absolute path
    return audio_paths

# Load the audio file paths
audio_paths = get_audio_file_paths(TRIMMED_AUDIO_FOLDER_PATH)


class AudioDataset(torch.utils.data.Dataset):
    def __init__(self, audio_paths, metadata, feature_extraction_fn, required_input_dim=128):
        self.audio_paths = audio_paths
        self.metadata = metadata or {}
        self.feature_extraction_fn = feature_extraction_fn
        self.required_input_dim = required_input_dim

    def __len__(self):
        return len(self.audio_paths)

    def __getitem__(self, idx):
        audio_file = self.audio_paths[idx]
        
        if not os.path.exists(audio_file):
            print(f"File not found: {audio_file}")
            return None

        try:
            feature_dict = self.feature_extraction_fn(audio_file)
            if feature_dict is None:
                raise ValueError("Feature extraction returned None")
        except Exception as e:
            print(f"Error loading or extracting features from file: {audio_file}, Error: {e}")
            return None
        
        # Extract relevant numeric features from the feature_dict
        scalar_features = np.array([
            feature_dict.get('bpm', 0),
            feature_dict.get('zero_crossing_rate', 0),
            feature_dict.get('spectral_centroid', 0),
            feature_dict.get('spectral_bandwidth', 0),
            feature_dict.get('rms_energy', 0),
        ], dtype=np.float32)
        
        # Combine MFCCs and chroma features with scalar features
        mfcc_features = feature_dict.get('mfccs', np.zeros(13))
        chroma_features = feature_dict.get('chroma', np.zeros(12))

        numeric_features = np.concatenate((scalar_features, mfcc_features, chroma_features)).astype(np.float32)
        
        # Normalize the combined features
        features = (numeric_features - numeric_features.min()) / (numeric_features.max() - numeric_features.min() + 1e-6)

        # Ensure the features are padded or truncated to the required input dimension
        if features.size < self.required_input_dim:
            features = np.pad(features, (0, self.required_input_dim - features.size), 'constant')
        else:
            features = features[:self.required_input_dim]
        
        return torch.tensor(features, dtype=torch.float32)






def collate_fn(batch):
    # Filter out None items from the batch
    batch = [b for b in batch if b is not None]
    
    if len(batch) == 0:
        return None  # Return None if all elements in the batch were None

    # Stack valid items
    return torch.stack(batch)

# Define dataset and dataloader
audio_dataset = AudioDataset(audio_paths=audio_paths, feature_extraction_fn=extract_audio_features_with_metadata, metadata=metadata)
train_loader = torch.utils.data.DataLoader(
    dataset=audio_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn  # Use the custom collate function
)
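
`AudioDataset` min-max normalizes each sample across all of its features at once. Because the features have very different scales (BPM near 90, spectral centroid near 800 Hz, zero-crossing rate near 0.01), the largest one tends to dominate the normalized vector. One alternative, sketched here on synthetic data, is to standardize each feature column using statistics computed over the whole dataset. The helper names and the synthetic values are assumptions for illustration only.

```python
import numpy as np

def fit_standardizer(feature_matrix):
    """Per-column mean and std over a [num_files, num_features] matrix."""
    mean = feature_matrix.mean(axis=0)
    std = feature_matrix.std(axis=0) + 1e-6  # avoid division by zero
    return mean, std

def standardize(features, mean, std):
    """Z-score features with precomputed dataset statistics."""
    return (features - mean) / std

rng = np.random.default_rng(0)
# Synthetic stand-in for [bpm, zero_crossing_rate, spectral_centroid] across 200 files
raw = np.column_stack([
    rng.normal(90, 15, 200),
    rng.normal(0.015, 0.005, 200),
    rng.normal(800, 200, 200),
])
mean, std = fit_standardizer(raw)
scaled = standardize(raw, mean, std)
print(scaled.mean(axis=0).round(3), scaled.std(axis=0).round(3))
```

With this scheme every column ends up with roughly zero mean and unit variance, so no single feature dominates. Note that standardized features are no longer confined to [0, 1], which matters if the reconstruction target is compared against a sigmoid output.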


### Utility function for training

In [8]:
import librosa.display
import matplotlib.pyplot as plt
from IPython.display import Audio, display
import numpy as np
def visualize_spectrogram(mel_spec, title="Mel-Spectrogram", sr=22050):
    """Visualizes a mel-spectrogram."""
    plt.figure(figsize=(10, 4))
    librosa.display.specshow(librosa.power_to_db(mel_spec, ref=np.max), sr=sr, x_axis='time', y_axis='mel')
    plt.colorbar(format='%+2.0f dB')
    plt.title(title)
    plt.tight_layout()
    plt.show()

def playback_audio(audio, sr=22050):
    """Plays back the generated audio in Jupyter."""
    display(Audio(audio, rate=sr))


### Training loop


In [None]:
import torch.optim as optim
from tqdm.auto import tqdm  # Using autonotebook for Jupyter/VScode compatibility
import matplotlib.pyplot as plt
import librosa
from IPython.display import Audio
import numpy as np
import os

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize model and optimizer
latent_dim = 64
input_dim = 128  # Adjust based on extracted features
output_dim = 256  # Mel-spectrogram size

vae_stylegan = DrumLoopVAEStyleGAN(latent_dim=latent_dim, input_dim=input_dim, output_dim=output_dim).to(device)
optimizer = optim.Adam(vae_stylegan.parameters(), lr=1e-4)

# Lists to store metrics for plotting
recon_losses = []
kl_divs = []
learning_rates = []
gradient_norms = []

# Function to load checkpoint
def load_checkpoint(checkpoint_path, model, optimizer):
    if checkpoint_path and os.path.isfile(checkpoint_path):
        print(f"Loading checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)  # Load onto the active device
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1  # Start from the next epoch
        recon_losses.extend(checkpoint['recon_losses'])  # Continue tracking losses
        kl_divs.extend(checkpoint['kl_divs'])
        learning_rates.extend(checkpoint['learning_rates'])
        gradient_norms.extend(checkpoint.get('gradient_norms', []))  # Key may be missing in older checkpoints
        print(f"Checkpoint loaded. Resuming from epoch {start_epoch}")
        return start_epoch
    else:
        print(f"No checkpoint found at {checkpoint_path}. Starting from scratch.")
        return 0  # Start from the beginning if no checkpoint found

# Function to plot training dynamics
def plot_training_dynamics():
    epochs = list(range(1, len(recon_losses) + 1))
    
    plt.figure(figsize=(15, 8))
    
    # Reconstruction loss plot
    plt.subplot(2, 2, 1)
    plt.plot(epochs, recon_losses, label="Reconstruction Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Reconstruction Loss")
    
    # KL divergence plot
    plt.subplot(2, 2, 2)
    plt.plot(epochs, kl_divs, label="KL Divergence")
    plt.xlabel("Epoch")
    plt.ylabel("KL Divergence")
    plt.title("KL Divergence")
    
    # Learning rate plot
    plt.subplot(2, 2, 3)
    plt.plot(epochs, learning_rates, label="Learning Rate")
    plt.xlabel("Epoch")
    plt.ylabel("Learning Rate")
    plt.title("Learning Rate Over Time")
    
    # Gradient norm plot
    plt.subplot(2, 2, 4)
    plt.plot(epochs, gradient_norms, label="Gradient Norm")
    plt.xlabel("Epoch")
    plt.ylabel("Gradient Norm")
    plt.title("Gradient Norm Over Time")

    plt.tight_layout()
    plt.show()


# Path to your checkpoint
checkpoint_path = "vae_stylegan_checkpoint_epoch_10.pth"  # Adjust as necessary

# Load checkpoint if available and resume training
start_epoch = load_checkpoint(checkpoint_path, vae_stylegan, optimizer)



# Training loop
num_epochs = 100
for epoch in range(start_epoch, num_epochs):
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch", dynamic_ncols=True)
    
    vae_stylegan.train()
    avg_loss = 0.0
    batch_count = 0
    
    epoch_recon_loss = 0.0
    epoch_kl_div = 0.0
    epoch_mu_stats = {'min': float('inf'), 'max': float('-inf'), 'mean': 0.0, 'std': 0.0}
    epoch_logvar_stats = {'min': float('inf'), 'max': float('-inf'), 'mean': 0.0, 'std': 0.0}

    # Delay KL weight to allow focus on reconstruction first
    if epoch < 10:
        kl_weight = 0
    else:
        kl_weight = min(1.0, (epoch - 10) / 20)

    # Alternative: cyclical KL annealing, where the weight oscillates between 0 and 1
    # to emphasize reconstruction in some phases and regularization in others:
    # kl_weight = 0.5 * (1 + np.cos(np.pi * (epoch % 20) / 20))
    
    for batch_idx, batch in enumerate(progress_bar):
        if batch is None:
            continue  # Skip invalid batch

        features = batch.to(device)

        # Normalize input features to avoid extreme values
        features = (features - features.min()) / (features.max() - features.min() + 1e-6)

        # Forward pass through the VAE-StyleGAN model
        audio_output, mel_spec, mu, logvar = vae_stylegan(features.unsqueeze(1))

        # Compute the VAE loss
        total_loss, recon_loss, kl_div = vae_loss(mel_spec.squeeze(1), features, mu, logvar, kl_weight)

        # Backward pass, then clip gradients before the optimizer step
        # (clipping only has an effect once backward() has populated the gradients)
        optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(vae_stylegan.parameters(), max_norm=5.0)
        optimizer.step()

        avg_loss += total_loss.item()
        batch_count += 1

        # Accumulate reconstruction loss, KL divergence for this epoch
        epoch_recon_loss += recon_loss.item()
        epoch_kl_div += kl_div.item()

        # Update epoch-wise stats for mu and logvar
        epoch_mu_stats['min'] = min(epoch_mu_stats['min'], mu.min().item())
        epoch_mu_stats['max'] = max(epoch_mu_stats['max'], mu.max().item())
        epoch_mu_stats['mean'] += mu.mean().item()
        epoch_mu_stats['std'] += mu.std().item()

        epoch_logvar_stats['min'] = min(epoch_logvar_stats['min'], logvar.min().item())
        epoch_logvar_stats['max'] = max(epoch_logvar_stats['max'], logvar.max().item())
        epoch_logvar_stats['mean'] += logvar.mean().item()
        epoch_logvar_stats['std'] += logvar.std().item()

        # Update the progress bar with basic info for each batch
        progress_bar.set_postfix({
            "Batch Loss": f"{total_loss.item():.3f}",
            "Recon Loss": f"{recon_loss.item():.3f}",
            "KL Div": f"{kl_div.item():.2f}",
            "Avg Loss": f"{avg_loss / batch_count:.3f}",
            "KL Weight": f"{kl_weight:.2f}"
        })

    # Compute mean statistics for mu and logvar for the entire epoch
    epoch_mu_stats['mean'] /= batch_count
    epoch_mu_stats['std'] /= batch_count
    epoch_logvar_stats['mean'] /= batch_count
    epoch_logvar_stats['std'] /= batch_count

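    # Record metrics for the current epoch
    recon_losses.append(epoch_recon_loss / batch_count)
    kl_divs.append(epoch_kl_div / batch_count)
    learning_rates.append(optimizer.param_groups[0]['lr'])
    # Gradient norm of the last batch (gradients persist until the next zero_grad call);
    # without this, plot_training_dynamics fails because gradient_norms stays empty
    grad_sq = sum(p.grad.norm(2).item() ** 2 for p in vae_stylegan.parameters() if p.grad is not None)
    gradient_norms.append(grad_sq ** 0.5)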

    # End of epoch: log detailed metrics for the epoch
    tqdm.write(f"mu: min {epoch_mu_stats['min']:.5f}, max {epoch_mu_stats['max']:.5f}, mean {epoch_mu_stats['mean']:.5f}, std {epoch_mu_stats['std']:.5f}")
    tqdm.write(f"logvar: min {epoch_logvar_stats['min']:.5f}, max {epoch_logvar_stats['max']:.5f}, mean {epoch_logvar_stats['mean']:.5f}, std {epoch_logvar_stats['std']:.5f}")
    tqdm.write(f"recon_loss: {epoch_recon_loss / batch_count:.5f}, kl_div: {epoch_kl_div / batch_count:.5f}")

    
    # Save model checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': vae_stylegan.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'recon_losses': recon_losses,
            'kl_divs': kl_divs,
            'learning_rates': learning_rates,
            'gradient_norms': gradient_norms
        }, f"vae_stylegan_checkpoint_epoch_{epoch+1}.pth")


# Plot the training dynamics after training
plot_training_dynamics()
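
The training loop uses a delayed linear KL warm-up and mentions a cyclical cosine alternative in a comment. The sketch below wraps both schedules in standalone functions so their values can be inspected before training. The function names and default arguments are assumptions chosen to mirror the constants in the loop.

```python
import math

def linear_kl_weight(epoch, warmup_epochs=10, ramp_epochs=20):
    """Zero during warm-up, then a linear ramp to 1.0 over ramp_epochs."""
    if epoch < warmup_epochs:
        return 0.0
    return min(1.0, (epoch - warmup_epochs) / ramp_epochs)

def cyclical_kl_weight(epoch, period=20):
    """Cosine schedule from the comment in the training loop."""
    return 0.5 * (1 + math.cos(math.pi * (epoch % period) / period))

for epoch in (0, 10, 20, 30, 40):
    print(epoch, linear_kl_weight(epoch), round(cyclical_kl_weight(epoch), 3))
```

As written, the cosine form starts each cycle at 1 and decays toward 0, so every cycle begins with full regularization. If the intent is to begin each cycle with reconstruction only, the schedule would need to be inverted; that is a design choice the notebook does not settle.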


### Metrics for evaluation
- reconstruction loss
- Frechet Audio Distance (FAD)
- Inception Score
- Visual Inspection of Mel-Spectrogram
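
Frechet Audio Distance compares the Gaussian statistics of embeddings computed from real and generated audio. The sketch below implements only the Fréchet distance between two sets of embedding vectors in NumPy. Which embedding model produces those vectors is outside its scope, and the random arrays stand in for real embeddings purely for illustration.

```python
import numpy as np

def frechet_distance(emb_a, emb_b):
    """Fréchet distance between Gaussians fitted to two [n_samples, dim] embedding sets."""
    mu_a, mu_b = emb_a.mean(axis=0), emb_b.mean(axis=0)
    cov_a = np.cov(emb_a, rowvar=False)
    cov_b = np.cov(emb_b, rowvar=False)
    # Tr(sqrt(cov_a @ cov_b)) equals the sum of square roots of the product's eigenvalues
    eigvals = np.linalg.eigvals(cov_a @ cov_b)
    trace_sqrt = np.sum(np.sqrt(np.clip(eigvals.real, 0, None)))
    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(cov_a) + np.trace(cov_b) - 2 * trace_sqrt)

rng = np.random.default_rng(0)
real = rng.normal(size=(500, 4))  # placeholder "real" embeddings
fake = real + 3.0                 # same covariance, mean shifted by 3 in each dimension

d_same = frechet_distance(real, real)
d_shift = frechet_distance(real, fake)
print(d_same, d_shift)
```

Identical sets give a distance of approximately zero, while a pure mean shift contributes the squared length of the shift vector (here 4 × 3² = 36), so lower values indicate generated audio whose embedding statistics are closer to the real data.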