#### Audio Feature Extraction with Metadata

In [25]:
import librosa
import numpy as np
import os

def map_tags_to_features(tags):
    """
    Maps tags to high-level features such as mood, energy, etc.
    """
    mood_mapping = {
        'Ambient': 'calm',
        'Cinematic': 'epic',
        'Looping': 'neutral',
        'Drone': 'dark',
        'Atmosphere': 'calm',
        'Pad': 'soft',
        'Drums': 'energetic',
        'Sound-Design': 'creative'
    }
    
    energy_mapping = {
        'Drums': 0.9,
        'Cinematic': 0.8,
        'Looping': 0.6,
        'Drone': 0.4,
        'Pad': 0.5,
        'Ambient': 0.3
    }
    
    # Default values
    mood = 'neutral'
    energy = 0.5
    
    for tag in tags:
        if tag in mood_mapping:
            mood = mood_mapping[tag]
        if tag in energy_mapping:
            energy = max(energy, energy_mapping[tag])
    
    return mood, energy


def extract_audio_features_with_metadata(audio_path, sr=22050, metadata=None):
    """
    Extracts audio features using Librosa and combines them with metadata.
    
    Parameters:
    audio_path (str): Path to the audio file.
    sr (int): Sampling rate.
    metadata (dict): High-level feature metadata.

    Returns:
    dict: Dictionary of extracted audio and metadata features.
    """
    try:
        if os.path.exists(audio_path):
            y, sr = librosa.load(audio_path, sr=sr)
        else:
            print(f"File not found: {audio_path}")
            return None

        features = {}

        # Extract low-level audio features with Librosa
        bpm, _ = librosa.beat.beat_track(y=y, sr=sr)
        features['bpm'] = bpm if isinstance(bpm, (float, int)) else bpm.item()

        features['zero_crossing_rate'] = float(np.mean(librosa.feature.zero_crossing_rate(y)))
        features['spectral_centroid'] = float(np.mean(librosa.feature.spectral_centroid(y=y, sr=sr)))
        features['spectral_bandwidth'] = float(np.mean(librosa.feature.spectral_bandwidth(y=y, sr=sr)))
        features['rms_energy'] = float(np.mean(librosa.feature.rms(y=y)))

        features['mfccs'] = np.mean(librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13), axis=1)
        features['chroma'] = np.mean(librosa.feature.chroma_stft(y=y, sr=sr), axis=1)

        # Map metadata to relevant features (these are placeholders for model learning)
        if metadata:
            features.update({
                'kick_intensity': metadata.get('Kick', 0),
                'snare_intensity': metadata.get('Snare', 0),
                'hihat_presence': metadata.get('Hi-hat', 0),
                'percussion': metadata.get('Percussion', 0),
                'mood': metadata.get('Mood', 'neutral'),
                'timbre': metadata.get('Timbre', 'balanced'),
                'energy': metadata.get('Energy', 0),
                'danceability': metadata.get('Danceability', 0),
                'global_loudness': metadata.get('Global Loudness', 0),
                'dynamic_complexity': metadata.get('Dynamic Complexity', 0),
                'valence': metadata.get('Valence', 'neutral'),
                'arousal': metadata.get('Arousal', 'calm'),
                'pitch_confidence': metadata.get('Pitch Confidence', 0),
                'harmonicity': metadata.get('Harmonicity', 0),
                'texture_density': metadata.get('Texture Density', 0),
                'articulation': metadata.get('Articulation', 'smooth')
            })

        return features
    except Exception as e:
        print(f"Error loading or extracting features from file: {audio_path}, Error: {e}")
        return None


def print_feature_info(features):
    """
    Print detailed information about extracted features.
    
    Parameters:
    features (dict): Extracted audio features.
    """
    for key, value in features.items():
        print(f"Feature: {key}")
        print(f"Type: {type(value)}")
        
        if isinstance(value, (int, float)):
            print(f"Value: {value}")
        elif isinstance(value, np.ndarray):
            print(f"Shape: {value.shape}")
            print(f"First few values: {value[:5]}")  # Display the first 5 elements
        else:
            print(f"Value: {value}")
        print("\n")


# Example Usage
audio_path = "D:/LV-NTF+LoopGAN/data/FSL10K-trimmed/505_301.wav"
metadata = {
    "Kick": 0.8, 
    "Snare": 0.7, 
    "Hi-hat": 0.9, 
    "Percussion": 0.6, 
    "Mood": "happy", 
    "Timbre": "bright", 
    "Energy": 0.85, 
    "Danceability": 0.9
}

# Extract features
audio_features = extract_audio_features_with_metadata(audio_path, metadata=metadata)
# Optionally, uncomment this for debugging feature extraction
# print_feature_info(audio_features)


#### Model Architecture (Encoder and Decoder)

In [26]:
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(Encoder, self).__init__()
        # Linear layers to downsample the input to latent space
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc_mu = nn.Linear(128, latent_dim)  # Latent space: Mean (mu)
        self.fc_logvar = nn.Linear(128, latent_dim)  # Latent space: Log-variance (logvar)

    def forward(self, x):
        """
        Forward pass to get mu and logvar for the latent space.
        """
        h = torch.relu(self.fc1(x))  # Apply ReLU activation
        h = torch.relu(self.fc2(h))
        mu = self.fc_mu(h)  # Get mean for latent space
        logvar = self.fc_logvar(h)  # Get log variance for latent space
        return mu, logvar


class Decoder(nn.Module):
    def __init__(self, latent_dim, output_dim):
        super(Decoder, self).__init__()
        # Linear layers to upsample the latent space back to the original input space
        self.fc1 = nn.Linear(latent_dim, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, output_dim)

    def forward(self, z):
        """
        Forward pass to decode from latent space z back to the input space.
        """
        h = torch.relu(self.fc1(z))  # Apply ReLU activation
        h = torch.relu(self.fc2(h))
        return torch.sigmoid(self.fc3(h))  # Sigmoid to normalize output (between 0 and 1)


#### StyleGAN component

In [27]:
class StyleGAN(nn.Module):
    def __init__(self, latent_dim, output_dim):
        super(StyleGAN, self).__init__()
        # Fully connected layers to process the latent space
        self.fc1 = nn.Linear(latent_dim, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, output_dim)

    def forward(self, z):
        # StyleGAN adds complexity by introducing deeper fully connected layers
        h = torch.relu(self.fc1(z))
        h = torch.relu(self.fc2(h))
        return torch.sigmoid(self.fc3(h))  # Sigmoid to match mel-spectrogram normalization


#### MelGAN Component

In [28]:
class MelGAN(nn.Module):
    def __init__(self):
        super(MelGAN, self).__init__()
        # Transposed convolutions to generate audio from mel-spectrogram
        self.conv1 = nn.ConvTranspose1d(1, 128, 4, stride=2, padding=1)
        self.conv2 = nn.ConvTranspose1d(128, 64, 4, stride=2, padding=1)
        self.conv3 = nn.ConvTranspose1d(64, 1, 4, stride=2, padding=1)

    def forward(self, x):
        # MelGAN takes in a mel-spectrogram and generates waveforms using transpose convolutions
        if x.shape[1] != 1:  
            x = x.unsqueeze(1)  # Ensure it has 1 channel for ConvTranspose1d
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        return torch.tanh(self.conv3(x))  # Tanh ensures the output waveform is within [-1, 1]


#### DrumLoopVAEStyleGAN Model
- Implementing LV-NTF for Tensor Factorization of Audio Data
- Integrating LV-NTF in the VAE-StyleGAN Pipeline
- Implementing LTN (Loop Transformation Network)

In [29]:
import tensorly as tl
from tensorly.decomposition import non_negative_tucker
import numpy as np

def apply_lv_ntf(mel_spec, rank=(1, 64, 128)):
    """
    Apply Nonnegative Tucker Decomposition (LV-NTF) on mel-spectrogram.
    
    Params:
    - mel_spec: Input mel-spectrogram (2D or 3D tensor).
    - rank: The rank of decomposition (adjusted based on input dimensions).
    
    Returns:
    - core_tensor: Core tensor from LV-NTF.
    - factors: Factor matrices (latent factors).
    """
    # Ensure input mel-spectrogram is 3D (for LV-NTF)
    if mel_spec.ndim == 2:
        mel_spec = np.expand_dims(mel_spec, axis=0)  # Add channel dimension if 2D

    # Apply non-negative Tucker decomposition
    core_tensor, factors = non_negative_tucker(mel_spec, rank=rank, init='random', tol=10e-5)
    
    return core_tensor, factors

# Example usage
mel_spec = np.random.rand(128, 256)  # Sample mel-spectrogram data (2D)
core, factors = apply_lv_ntf(mel_spec, rank=(1, 64, 128))  # Adjusted rank for 3D tensor

print("Core tensor shape:", core.shape)
print("Factor shapes:", [factor.shape for factor in factors])


Core tensor shape: (1, 64, 128)
Factor shapes: [(1, 1), (128, 64), (256, 128)]


In [47]:
import torch.nn as nn
import librosa

class LoopTransformationNetwork(nn.Module):
    def __init__(self):
        super(LoopTransformationNetwork, self).__init__()
        self.time_stretch_factor = nn.Parameter(torch.tensor(1.0))  # Stretch factor, 1.0 means no stretching
        self.pitch_shift_factor = nn.Parameter(torch.tensor(0.0))   # Pitch shift in semitones, 0 means no shift

    def forward(self, audio_waveform, sample_rate=22050):
        """
        Applies time-stretching and pitch-shifting to the audio loop.
        
        Parameters:
        - audio_waveform (Tensor): Input waveform, expected shape (batch_size, samples).
        - sample_rate (int): Sampling rate of the input audio.
        
        Returns:
        - transformed_audio: Audio after time-stretching and pitch-shifting.
        """
        # Convert tensor to NumPy array for librosa
        audio_np = audio_waveform.cpu().numpy()

        # Debug: Check the shape of the audio input
        print(f"Original audio shape: {audio_np.shape}")

        # Ensure audio is a 1D array for librosa
        if audio_np.ndim == 2:
            audio_np = audio_np[0]  # Take the first channel for simplicity (assuming mono input)
        
        # Debug: After flattening to ensure it's 1D
        print(f"Processed audio shape for librosa: {audio_np.shape}")

        # Apply time-stretching using librosa with the correct keyword argument
        try:
            stretched_audio = librosa.effects.time_stretch(audio_np, rate=self.time_stretch_factor.item())  # Corrected here
            print(f"Time-stretched audio shape: {stretched_audio.shape}")
        except Exception as e:
            print(f"Error during time-stretching: {e}")
            return audio_waveform  # Return original audio if stretching fails

        # Apply pitch-shifting using librosa
        try:
            transformed_audio = librosa.effects.pitch_shift(stretched_audio, sr=sample_rate, n_steps=self.pitch_shift_factor.item())
            print(f"Pitch-shifted audio shape: {transformed_audio.shape}")
        except Exception as e:
            print(f"Error during pitch-shifting: {e}")
            return torch.tensor(stretched_audio).unsqueeze(0).to(audio_waveform.device)  # Return stretched audio if pitch shift fails

        # Convert the result back to a tensor
        transformed_audio_tensor = torch.tensor(transformed_audio, dtype=torch.float32).unsqueeze(0).to(audio_waveform.device)

        return transformed_audio_tensor

# Example usage
sample_audio = torch.randn(1, 22050)  # Simulated audio signal with batch size of 1
ltn = LoopTransformationNetwork()
transformed_audio = ltn(sample_audio)


Original audio shape: (1, 22050)
Processed audio shape for librosa: (22050,)
Time-stretched audio shape: (22050,)
Pitch-shifted audio shape: (22050,)


In [45]:
import torch
import torch.nn as nn

class DrumLoopVAEStyleGAN(nn.Module):
    def __init__(self, latent_dim, input_dim, output_dim=None):
        super(DrumLoopVAEStyleGAN, self).__init__()
        self.encoder = Encoder(input_dim=input_dim, latent_dim=latent_dim)
        self.decoder = Decoder(latent_dim=latent_dim, output_dim=output_dim)
        self.stylegan = StyleGAN(latent_dim=latent_dim, output_dim=output_dim)
        self.melgan = MelGAN()
        self.ltn = LoopTransformationNetwork()

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, apply_lv_ntf=False):
        if apply_lv_ntf:
            x, factors = apply_lv_ntf(x)  # Apply LV-NTF

        # Encoder step
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)

        # StyleGAN step
        generated_mel_spec = self.stylegan(z)

        # MelGAN step: Ensure the mel_spec is reshaped to 1 channel
        if generated_mel_spec.dim() == 2:
            generated_mel_spec = generated_mel_spec.unsqueeze(1)

        # Audio generation with MelGAN
        audio = self.melgan(generated_mel_spec)

        # Apply LTN (Loop Transformations) to generated audio
        transformed_audio = self.ltn(audio)

        return transformed_audio, generated_mel_spec, mu, logvar


#### VAE Loss and GAN Loss Functions
 The VAE loss function includes both reconstruction loss and KL divergence. It is essential for balancing the quality of reconstruction with ensuring a useful latent space.


In [37]:
import torch.nn.functional as F

def vae_loss(reconstructed_x, x, mu, logvar, kl_weight=1.5):
    """
    VAE loss function combining the reconstruction loss and KL divergence.
    
    Args:
    - reconstructed_x: The output of the VAE's decoder (reconstructed mel-spectrogram).
    - x: The input mel-spectrogram.
    - mu: Mean of the latent variables.
    - logvar: Log variance of the latent variables.
    - kl_weight: Weight applied to the KL divergence term (annealed over epochs).

    Returns:
    - total_loss: The combined VAE loss (reconstruction + KL divergence).
    - recon_loss: The MSE reconstruction loss.
    - kl_div: The KL divergence term.
    """
    # Ensure dimensions match for reconstruction loss
    if reconstructed_x.size(-1) != x.size(-1):
        reconstructed_x = reconstructed_x[:, :x.size(-1)]  # Adjust dimensions
    
    # Reconstruction loss (Mean Squared Error between input and reconstructed output)
    recon_loss = F.mse_loss(reconstructed_x, x)

    # Clamp logvar to avoid extreme values and ensure stability
    logvar = torch.clamp(logvar, min=-10, max=10)
    mu = torch.clamp(mu, min=-5, max=5)

    # KL divergence loss: Regularization term for the VAE
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    kl_div = torch.clamp(kl_div, min=0, max=100)  # Avoid extremely high values

    # Debug: Print KL divergence if NaN or Inf is detected
    if torch.isnan(kl_div) or torch.isinf(kl_div):
        print("NaN or Inf detected in KL divergence")

    # Return total loss with KL weighting
    return recon_loss + kl_weight * kl_div, recon_loss, kl_div


def gan_loss(discriminator, real_data, fake_data):
    """
    GAN loss function using binary cross-entropy (BCE) for the discriminator.
    
    Args:
    - discriminator: The discriminator network.
    - real_data: Real data (ground truth) for the discriminator to classify.
    - fake_data: Generated (fake) data from the generator.

    Returns:
    - total_loss: The combined loss from real and fake data classifications.
    """
    real_loss = F.binary_cross_entropy(discriminator(real_data), torch.ones_like(real_data))
    fake_loss = F.binary_cross_entropy(discriminator(fake_data), torch.zeros_like(fake_data))
    
    return real_loss + fake_loss


##### DataLoader for Audio Dataset

In [38]:
# Path to the folder containing trimmed audio files
TRIMMED_AUDIO_FOLDER_PATH = "D:/LV-NTF+LoopGAN/data/FSL10K-trimmed"

# Function to get all audio file paths from the directory
def get_audio_file_paths(folder_path):
    """
    Recursively gathers all .wav files in the provided folder.
    
    Args:
    - folder_path: The root folder where the .wav files are located.

    Returns:
    - List of absolute paths to the audio files.
    """
    audio_paths = []
    for root, _, files in os.walk(folder_path):
        for file in files:
            if file.endswith(".wav"):  # Only select .wav files
                audio_paths.append(os.path.abspath(os.path.join(root, file)))
    return audio_paths

# Load the audio file paths
audio_paths = get_audio_file_paths(TRIMMED_AUDIO_FOLDER_PATH)

# Custom Dataset class to process audio data
class AudioDataset(torch.utils.data.Dataset):
    def __init__(self, audio_paths, metadata, feature_extraction_fn, required_input_dim=128):
        self.audio_paths = audio_paths
        self.metadata = metadata or {}
        self.feature_extraction_fn = feature_extraction_fn
        self.required_input_dim = required_input_dim

    def __len__(self):
        return len(self.audio_paths)

    def __getitem__(self, idx):
        audio_file = self.audio_paths[idx]
        
        if not os.path.exists(audio_file):
            print(f"File not found: {audio_file}")
            return None

        # Extract features using the provided feature extraction function
        try:
            feature_dict = self.feature_extraction_fn(audio_file)
            if feature_dict is None:
                raise ValueError("Feature extraction returned None")
        except Exception as e:
            print(f"Error loading or extracting features from file: {audio_file}, Error: {e}")
            return None
        
        # Extract numerical features for model input
        scalar_features = np.array([
            feature_dict.get('bpm', 0),
            feature_dict.get('zero_crossing_rate', 0),
            feature_dict.get('spectral_centroid', 0),
            feature_dict.get('spectral_bandwidth', 0),
            feature_dict.get('rms_energy', 0),
        ], dtype=np.float32)

        # Combine scalar, MFCCs, and chroma features
        mfcc_features = feature_dict.get('mfccs', np.zeros(13))
        chroma_features = feature_dict.get('chroma', np.zeros(12))
        numeric_features = np.concatenate((scalar_features, mfcc_features, chroma_features)).astype(np.float32)
        
        # Normalize features and pad/truncate to required input size
        features = (numeric_features - numeric_features.min()) / (numeric_features.max() - numeric_features.min() + 1e-6)
        if features.size < self.required_input_dim:
            features = np.pad(features, (0, self.required_input_dim - features.size), 'constant')
        else:
            features = features[:self.required_input_dim]
        
        return torch.tensor(features, dtype=torch.float32)

# Collate function for batching
def collate_fn(batch):
    batch = [b for b in batch if b is not None]
    if len(batch) == 0:
        return None
    return torch.stack(batch)

# Define the dataset and DataLoader
audio_dataset = AudioDataset(audio_paths=audio_paths, feature_extraction_fn=extract_audio_features_with_metadata, metadata=metadata)
train_loader = torch.utils.data.DataLoader(
    dataset=audio_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn  # Use the custom collate function
)


#### Visualization and Audio Playback

In [39]:
import librosa.display
import matplotlib.pyplot as plt
from IPython.display import Audio
import numpy as np

def visualize_spectrogram(mel_spec, title="Mel-Spectrogram", sr=22050):
    """Visualizes a mel-spectrogram."""
    plt.figure(figsize=(10, 4))
    librosa.display.specshow(librosa.power_to_db(mel_spec, ref=np.max), sr=sr, x_axis='time', y_axis='mel')
    plt.colorbar(format='%+2.0f dB')
    plt.title(title)
    plt.tight_layout()
    plt.show()

def playback_audio(audio, sr=22050):
    """Plays back the generated audio."""
    display(Audio(audio, rate=sr))


In [40]:

import os
os.environ["WANDB_API_KEY"] = "c57527a5a2fe25105af1f8467d782bff19d66097"  
os.environ["WANDB_NOTEBOOK_NAME"] = "train.ipynb"  # Name your notebook in W&B

import wandb

# Initialize the W&B project with a unique name for this run
wandb.init(project="VAE-loopGAN", name="Training VAE-StyleGAN Run")

# Configuration for the experiment (can be logged to W&B for tracking)
config = {
    "learning_rate": 1e-4,
    "latent_dim": 64,
    "kl_weight": 0.0,  # Start KL weight at 0, gradually increase during training
    "epochs": 100
}

# Optionally log the configuration to W&B
wandb.config.update(config)



VBox(children=(Label(value='0.004 MB of 0.004 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

In [41]:
import torch.optim as optim
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import librosa
import numpy as np
import wandb

# Initialize model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
latent_dim = 64
input_dim = 128  # Input dimension from the extracted features
output_dim = 256  # Output dimension representing the mel-spectrogram size

vae_stylegan = DrumLoopVAEStyleGAN(latent_dim=latent_dim, input_dim=input_dim, output_dim=output_dim).to(device)
optimizer = optim.Adam(vae_stylegan.parameters(), lr=1e-4)  # Using Adam optimizer

# Initialize lists for tracking metrics
recon_losses = []  # Track reconstruction losses
kl_divs = []  # Track KL divergence losses
learning_rates = []  # Track learning rates
gradient_norms = []  # Track gradient norms for debugging and monitoring

# Function to load the training checkpoint
def load_checkpoint(checkpoint_path, model, optimizer):
    """
    Load model checkpoint and resume training from the last saved epoch.
    
    Parameters:
    - checkpoint_path: Path to the checkpoint file.
    - model: The VAE model instance.
    - optimizer: The optimizer instance (Adam).
    
    Returns:
    - start_epoch: The epoch number to resume training from.
    """
    if checkpoint_path and os.path.isfile(checkpoint_path):
        print(f"Loading checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])  # Load model weights
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # Load optimizer state
        start_epoch = checkpoint['epoch'] + 1  # Resume from the next epoch
        recon_losses.extend(checkpoint['recon_losses'])
        kl_divs.extend(checkpoint['kl_divs'])
        learning_rates.extend(checkpoint['learning_rates'])
        
        if 'gradient_norms' in checkpoint:
            gradient_norms.extend(checkpoint['gradient_norms'])
        else:
            gradient_norms = []  # Initialize empty if not found
        print(f"Checkpoint loaded. Resuming from epoch {start_epoch}")
        return start_epoch
    else:
        print(f"No checkpoint found at {checkpoint_path}. Starting from scratch.")
        return 0  # Start from epoch 0 if no checkpoint exists

# Function to plot training dynamics
def plot_training_dynamics():
    """
    Plot graphs for training dynamics: reconstruction loss, KL divergence, learning rate, and gradient norms.
    """
    epochs = list(range(1, len(recon_losses) + 1))
    plt.figure(figsize=(15, 8))

    # Plot reconstruction loss
    plt.subplot(2, 2, 1)
    plt.plot(epochs, recon_losses, label="Reconstruction Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Reconstruction Loss")

    # Plot KL divergence
    plt.subplot(2, 2, 2)
    plt.plot(epochs, kl_divs, label="KL Divergence")
    plt.xlabel("Epoch")
    plt.ylabel("KL Divergence")
    plt.title("KL Divergence")

    # Plot learning rate over time
    plt.subplot(2, 2, 3)
    plt.plot(epochs, learning_rates, label="Learning Rate")
    plt.xlabel("Epoch")
    plt.ylabel("Learning Rate")
    plt.title("Learning Rate Over Time")

    # Plot gradient norms
    plt.subplot(2, 2, 4)
    plt.plot(epochs, gradient_norms, label="Gradient Norm")
    plt.xlabel("Epoch")
    plt.ylabel("Gradient Norm")
    plt.title("Gradient Norm Over Time")

    plt.tight_layout()
    plt.show()

# Load checkpoint if available
checkpoint_path = "vae_stylegan_checkpoint_epoch_10.pth"
start_epoch = load_checkpoint(checkpoint_path, vae_stylegan, optimizer)

# Training loop
num_epochs = 100  # Total number of epochs to train
for epoch in range(start_epoch, num_epochs):
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch", dynamic_ncols=True)
    vae_stylegan.train()
    avg_loss = 0.0
    batch_count = 0

    # Initialize tracking for the current epoch's reconstruction and KL losses
    epoch_recon_loss = 0.0
    epoch_kl_div = 0.0

    # Gradual KL weight increase after 10 epochs
    kl_weight = min(1.0, (epoch - 10) / 20) if epoch >= 10 else 0

    for batch in progress_bar:
        if batch is None:
            continue  # Skip invalid batches

        # Move input features to the device
        features = batch.to(device)
        features = (features - features.min()) / (features.max() - features.min() + 1e-6)  # Normalize input

        # Forward pass through the VAE-StyleGAN
        audio_output, mel_spec, mu, logvar = vae_stylegan(features.unsqueeze(1))  # Add channel dimension
        # Detach for logging/debugging purposes (optional)
        mu_np = mu.detach().cpu().numpy()  # Detach before using NumPy if needed for logging
        logvar_np = logvar.detach().cpu().numpy()  # Detach for logging/debugging

        total_loss, recon_loss, kl_div = vae_loss(mel_spec.squeeze(1), features, mu, logvar, kl_weight)

        # Perform gradient clipping to avoid exploding gradients and update weights
        optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(vae_stylegan.parameters(), max_norm=5.0)
        optimizer.step()

        # Update metrics for tracking
        avg_loss += total_loss.item()
        batch_count += 1
        epoch_recon_loss += recon_loss.item()
        epoch_kl_div += kl_div.item()

        # Display progress
        progress_bar.set_postfix({
            "Batch Loss": f"{total_loss.item():.3f}",
            "Recon Loss": f"{recon_loss.item():.3f}",
            "KL Div": f"{kl_div.item():.2f}",
            "KL Weight": f"{kl_weight:.2f}"
        })

    # Log metrics to W&B after every epoch
    wandb.log({
        "epoch": epoch + 1,
        "Batch Loss": avg_loss / batch_count,
        "Recon Loss": epoch_recon_loss / batch_count,
        "KL Div": epoch_kl_div / batch_count,
        "KL Weight": kl_weight,
        "mu_mean": mu.mean().item(),
        "mu_std": mu.std().item(),
        "logvar_mean": logvar.mean().item(),
        "logvar_std": logvar.std().item()
    })

    # Save model checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': vae_stylegan.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'recon_losses': recon_losses,
            'kl_divs': kl_divs,
            'learning_rates': learning_rates,
            'gradient_norms': gradient_norms
        }
        checkpoint_path = f"vae_stylegan_checkpoint_epoch_{epoch+1}.pth"
        torch.save(checkpoint, checkpoint_path)
        wandb.save(checkpoint_path)

# Plot the training dynamics after completing all epochs
plot_training_dynamics()


No checkpoint found at vae_stylegan_checkpoint_epoch_10.pth. Starting from scratch.


Epoch 1/100:   0%|          | 0/297 [00:00<?, ?batch/s]

  return pitch_tuning(


RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

#### Inference 

In [None]:
import torch
import librosa
import numpy as np
import os

# Load the trained model checkpoint
checkpoint_path = "vae_stylegan_checkpoint_epoch_10.pth"  # Adjust with your actual checkpoint
vae_stylegan = DrumLoopVAEStyleGAN(latent_dim=64, input_dim=128, output_dim=256).to(device)
vae_stylegan.eval()

# Load the model weights from the checkpoint
checkpoint = torch.load(checkpoint_path, map_location=device)
vae_stylegan.load_state_dict(checkpoint['model_state_dict'])

# Function to generate samples from the VAE latent space
def generate_samples(num_samples=5, latent_dim=64, n_mels=128, noise_factor=0.0):
    """ Generates audio samples from the VAE-StyleGAN latent space """
    with torch.no_grad():
        # Generate latent vectors from normal distribution
        z = torch.randn(num_samples, latent_dim).to(device)
        z = z + torch.randn_like(z) * noise_factor  # Add noise to the latent space
        generated_mel_spectrograms = vae_stylegan.decode(z)

        # Check the shape of the generated mel-spectrograms
        print(f"Generated mel-spectrograms shape: {generated_mel_spectrograms.shape}")

        # Infer time steps dynamically based on the output size
        total_elements = generated_mel_spectrograms.numel()
        time_steps = total_elements // (num_samples * n_mels)

        if total_elements != num_samples * n_mels * time_steps:
            raise ValueError(f"Cannot reshape tensor of size {total_elements} into shape ({num_samples}, {n_mels}, {time_steps}).")

        # Reshape the generated mel-spectrograms to the correct dimensions
        generated_mel_spectrograms = generated_mel_spectrograms.view(num_samples, n_mels, time_steps)

        return generated_mel_spectrograms.cpu().numpy()

# Function to convert mel-spectrogram to waveform using Griffin-Lim
def mel_to_audio(mel_spec, sr=22050, n_fft=32):
    """ Convert mel-spectrogram back to audio using Griffin-Lim algorithm """
    mel_spec = np.exp(mel_spec)  # Convert log-mel to linear-mel if applicable

    # Ensure that the mel spectrogram is properly scaled
    if mel_spec.ndim != 2:
        raise ValueError(f"Expected a 2D Mel-spectrogram, got shape: {mel_spec.shape}")

    # Convert mel-spectrogram to audio waveform
    audio_waveform = librosa.feature.inverse.mel_to_audio(mel_spec, sr=sr, n_fft=n_fft)
    return audio_waveform

# Generate some samples from the latent space
num_samples = 5
generated_samples = generate_samples(num_samples)

# Ensure the folder exists for saving audio
output_dir = r"D:\LV-NTF+LoopGAN\inference_audio"
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# Save the waveform to the specified folder
import soundfile as sf
for i, mel_spec in enumerate(generated_samples):
    mel_spec = np.exp(mel_spec)  # Convert log-mel to linear-mel if needed
    print(f"Mel-spectrogram shape: {mel_spec.shape}")

    # Convert the mel-spectrogram back to audio waveform
    audio_waveform = mel_to_audio(mel_spec)

    # Define the file path for each generated audio file
    file_path = os.path.join(output_dir, f"generated_sample_{i}.wav")

    # Save the waveform to the specified folder
    sf.write(file_path, audio_waveform, samplerate=22050)
    print(f"Saved generated audio to {file_path}")
