In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import os

# models
from style_encoder import StyleEncoder, initialize_weights
from content_encoder import ContentEncoder
from discriminator import Discriminator
from new_decoder import Decoder, compute_comprehensive_loss  # Nuovo decoder dinamico
from losses import infoNCE_loss, margin_loss, adversarial_loss, disentanglement_loss
from Dataloader import get_dataloader

# Device configuration: use CUDA when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

# Reproducibility: seed PyTorch's global RNG so that anything drawing from it
# (weight initialisation, dropout, shuffling) is repeatable after Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Hyperparameters
EPOCHS = 50
BATCH_SIZE = 8
LR_GEN = 2e-4  # learning rate for the encoders + decoder
LR_DISC = 1e-4  # learning rate for the discriminator
TRANSFORMER_DIM = 256
NUM_FRAMES = 4
STFT_T, STFT_F = 287, 513  # STFT dimensions (n_fft=1024)
CQT_T, CQT_F = 287, 84   # CQT dimensions

# Loss weights
LAMBDA_RECON = 10.0
LAMBDA_INFO_NCE = 1.0
LAMBDA_MARGIN = 1.0
LAMBDA_DISENTANGLE = 0.1
LAMBDA_ADV_GEN = 0.5  # weight of the generator's adversarial loss

# Weights of the terms inside the decoder's comprehensive reconstruction loss
LAMBDA_TEMPORAL = 0.3
LAMBDA_PHASE = 0.2
LAMBDA_SPECTRAL = 0.1
LAMBDA_CONSISTENCY = 0.1

# Directory where checkpoints are written; created up front so saving cannot fail on it.
MODEL_SAVE_PATH = "./saved_models"
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

Using device: cuda
GPU: NVIDIA GeForce RTX 3080
GPU Memory: 10.0 GB


TODO: add data augmentation.

In [None]:
def train():
    """Train the encoders, decoder and discriminator, validating after every epoch.

    Each training batch performs a discriminator update on embeddings computed without
    gradient tracking, followed by a generator update (style encoder, content encoder,
    decoder) on the weighted sum of the reconstruction, InfoNCE, margin, disentanglement
    and adversarial losses. Whenever the epoch's average validation loss improves, a
    checkpoint is written to ``MODEL_SAVE_PATH/best_model.pth``.

    Batches whose losses or gradient norm are non-finite are skipped, so a single bad
    batch cannot write NaN/Inf into the weights.

    Returns:
        tuple[list, list]: per-epoch average training losses and validation losses
        (an entry is NaN when no batch contributed to that epoch).
    """
    print(f"Training on device: {device}")

    # Build the models
    style_encoder = StyleEncoder(transformer_dim=TRANSFORMER_DIM).to(device)
    content_encoder = ContentEncoder(transformer_dim=TRANSFORMER_DIM).to(device)

    decoder = Decoder(
        d_model=TRANSFORMER_DIM,
        nhead=4,
        num_layers=4,  # reduced for Colab
        dim_feedforward=TRANSFORMER_DIM * 2,
        dropout=0.1
    ).to(device)

    discriminator = Discriminator(input_dim=TRANSFORMER_DIM).to(device)

    # Initialise weights (per the original note, the decoder initialises itself)
    initialize_weights(style_encoder)
    initialize_weights(content_encoder)
    initialize_weights(discriminator)

    # Parameters updated by the generator optimiser; also reused for gradient clipping.
    generator_params = (
        list(style_encoder.parameters())
        + list(content_encoder.parameters())
        + list(decoder.parameters())
    )
    optimizer_G = optim.Adam(generator_params, lr=LR_GEN, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=LR_DISC, betas=(0.5, 0.999))

    # Train and validation dataloaders
    train_dataloader = get_dataloader(
        piano_dir="dataset/train/piano",
        violin_dir="dataset/train/violin",
        batch_size=BATCH_SIZE,
        shuffle=True,
        stats_path="stats_stft_cqt.npz"
    )

    val_dataloader = get_dataloader(
        piano_dir="dataset/val/piano",
        violin_dir="dataset/val/violin",
        batch_size=BATCH_SIZE,
        shuffle=False,  # no shuffling for validation
        stats_path="stats_stft_cqt.npz"
    )

    print(f"Training batches: {len(train_dataloader)}")
    print(f"Validation batches: {len(val_dataloader)}")

    def recon_loss_fn(output, target):
        """Return (total reconstruction loss, full loss dict) for a decoder output."""
        loss_dict = compute_comprehensive_loss(
            output, target,
            lambda_temporal=LAMBDA_TEMPORAL,
            lambda_phase=LAMBDA_PHASE,
            lambda_spectral=LAMBDA_SPECTRAL,
            lambda_consistency=LAMBDA_CONSISTENCY
        )
        return loss_dict['total_loss'], loss_dict

    best_val_loss = float('inf')
    train_losses = []
    val_losses = []

    epoch_pbar = tqdm(range(EPOCHS), desc="Training", unit="epoch")

    for epoch in epoch_pbar:
        epoch_pbar.set_description(f"Epoch {epoch+1}/{EPOCHS}")

        # ================================================================== #
        #                             TRAINING                               #
        # ================================================================== #
        style_encoder.train()
        content_encoder.train()
        decoder.train()
        discriminator.train()

        train_loss_epoch = 0
        train_recon_loss_epoch = 0
        train_batches = 0

        for i, (x, labels) in enumerate(train_dataloader):
            x, labels = x.to(device), labels.to(device)  # x: (B, S, 2, T, F)
            stft_part = x[:, :, :, :, :STFT_F]  # keep only the STFT bins of the last axis

            # ---------------------- Discriminator step ---------------------- #
            optimizer_D.zero_grad()

            # The discriminator update must not reach the encoders, so skip building
            # their autograd graph altogether (outputs are therefore already detached).
            with torch.no_grad():
                style_emb, class_emb = style_encoder(x, labels)
                content_emb = content_encoder(x)

            discriminator_loss, _ = adversarial_loss(style_emb, class_emb,
                                                     content_emb, discriminator, labels,
                                                     compute_for_discriminator=True)

            if torch.isfinite(discriminator_loss):
                discriminator_loss.backward()
                optimizer_D.step()
            else:
                # Stepping on a NaN/Inf loss would corrupt the discriminator weights.
                print(f"Non-finite discriminator loss at batch {i+1}; skipping discriminator step")

            # ------- Generator step (style encoder, content encoder, decoder) ------- #
            optimizer_G.zero_grad()

            style_emb, class_emb = style_encoder(x, labels)
            content_emb = content_encoder(x)

            # Adversarial loss for the generator
            _, adv_generator_loss = adversarial_loss(style_emb, class_emb, content_emb, discriminator, labels,
                                                     compute_for_discriminator=False)

            # Disentanglement loss between style and time-averaged content embeddings
            disent_loss = disentanglement_loss(style_emb, content_emb.mean(dim=1), use_hsic=True)

            if len(torch.unique(labels)) > 1:
                # Contrastive losses need at least two classes in the batch
                loss_infonce = infoNCE_loss(style_emb, labels)
                loss_margin = margin_loss(class_emb)
            else:
                # Fallback when every label in the batch is identical
                loss_infonce = torch.tensor(0.0, device=device)
                loss_margin = torch.tensor(0.0, device=device)

            # Reconstruction loss (the target is also passed as y, for teacher forcing)
            reconstructed_spec = decoder(content_emb, style_emb, y=stft_part)
            loss_recon, _ = recon_loss_fn(reconstructed_spec, stft_part)

            if torch.isnan(loss_recon):
                print(f"‚ö†Ô∏è NaN detected in reconstruction loss at batch {i+1}")
                print(f"   Reconstructed spec stats: min={reconstructed_spec.min():.4f}, max={reconstructed_spec.max():.4f}")
                print(f"   Target spec stats: min={stft_part.min():.4f}, max={stft_part.max():.4f}")
                continue  # skip this batch

            total_gen_loss = (
                LAMBDA_RECON * loss_recon +
                LAMBDA_INFO_NCE * loss_infonce +
                LAMBDA_MARGIN * loss_margin +
                LAMBDA_DISENTANGLE * disent_loss +
                LAMBDA_ADV_GEN * adv_generator_loss
            )

            if not torch.isfinite(total_gen_loss):
                # One of the auxiliary losses blew up even though the reconstruction is fine.
                print(f"Non-finite generator loss at batch {i+1}; skipping generator step")
                continue

            total_gen_loss.backward()

            # Gradient clipping. clip_grad_norm_ returns the total norm; if that norm is
            # NaN/Inf the clipped gradients are non-finite too, and an optimiser step would
            # turn every generator weight into NaN (all later batches would then be NaN).
            grad_norm = torch.nn.utils.clip_grad_norm_(generator_params, max_norm=1.0)
            if not torch.isfinite(grad_norm):
                print(f"Non-finite gradient norm at batch {i+1}; skipping generator step")
                optimizer_G.zero_grad()
                continue

            optimizer_G.step()

            # Accumulate losses
            train_loss_epoch += total_gen_loss.item()
            train_recon_loss_epoch += loss_recon.item()
            train_batches += 1

            # Per-batch metrics
            print(f"Epoch {epoch+1}/{EPOCHS} - Batch {i+1}/{len(train_dataloader)} | "
                  f"D_loss: {discriminator_loss.item():.4f} | "
                  f"G_loss: {total_gen_loss.item():.4f} | "
                  f"Recon: {loss_recon.item():.4f} | "
                  f"InfoNCE: {loss_infonce.item():.4f} | "
                  f"Margin: {loss_margin.item():.4f} | "
                  f"Disentangle: {disent_loss.item():.4f}")

        # Average training losses (NaN when every batch was skipped, instead of dividing by 0)
        avg_train_loss = train_loss_epoch / train_batches if train_batches else float('nan')
        avg_train_recon_loss = train_recon_loss_epoch / train_batches if train_batches else float('nan')
        train_losses.append(avg_train_loss)

        # ================================================================== #
        #                            VALIDATION                              #
        # ================================================================== #
        style_encoder.eval()
        content_encoder.eval()
        decoder.eval()
        discriminator.eval()

        val_loss_epoch = 0
        val_recon_loss_epoch = 0
        val_batches = 0

        print(f"\nüîç Running validation for epoch {epoch+1}...")

        with torch.no_grad():
            for x, labels in val_dataloader:
                x, labels = x.to(device), labels.to(device)
                stft_part = x[:, :, :, :, :STFT_F]

                style_emb, class_emb = style_encoder(x, labels)
                content_emb = content_encoder(x)

                # Validation uses the main losses only (no adversarial term)
                if len(torch.unique(labels)) > 1:
                    loss_infonce = infoNCE_loss(style_emb, labels)
                    loss_margin = margin_loss(class_emb)
                else:
                    loss_infonce = torch.tensor(0.0, device=device)
                    loss_margin = torch.tensor(0.0, device=device)

                disent_loss = disentanglement_loss(style_emb, content_emb.mean(dim=1), use_hsic=True)

                reconstructed_spec = decoder(content_emb, style_emb, y=stft_part)
                loss_recon, _ = recon_loss_fn(reconstructed_spec, stft_part)

                total_val_loss = (
                    LAMBDA_RECON * loss_recon +
                    LAMBDA_INFO_NCE * loss_infonce +
                    LAMBDA_MARGIN * loss_margin +
                    LAMBDA_DISENTANGLE * disent_loss
                )

                val_loss_epoch += total_val_loss.item()
                val_recon_loss_epoch += loss_recon.item()
                val_batches += 1

        # Average validation losses (NaN for an empty validation loader)
        avg_val_loss = val_loss_epoch / val_batches if val_batches else float('nan')
        avg_val_recon_loss = val_recon_loss_epoch / val_batches if val_batches else float('nan')
        val_losses.append(avg_val_loss)

        # Epoch summary
        print(f"\nEpoch {epoch+1}/{EPOCHS} Summary:")
        print(f"    Train Loss: {avg_train_loss:.4f} | Train Recon: {avg_train_recon_loss:.4f}")
        print(f"    Val Loss:   {avg_val_loss:.4f} | Val Recon:   {avg_val_recon_loss:.4f}")

        # Save the best model according to the validation loss
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            print(f"New best validation loss: {best_val_loss:.4f}. Saving model...")

            torch.save({
                'epoch': epoch,
                'style_encoder_state_dict': style_encoder.state_dict(),
                'content_encoder_state_dict': content_encoder.state_dict(),
                'decoder_state_dict': decoder.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_G_state_dict': optimizer_G.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
                'best_val_loss': best_val_loss,
                'train_losses': train_losses,
                'val_losses': val_losses,
            }, os.path.join(MODEL_SAVE_PATH, 'best_model.pth'))

        # Update the progress bar after the best-loss bookkeeping so 'Best_Val' is current.
        epoch_pbar.set_postfix({
            'Train_Loss': f'{avg_train_loss:.4f}',
            'Val_Loss': f'{avg_val_loss:.4f}',
            'Best_Val': f'{best_val_loss:.4f}'
        })

        # Early-stopping hint (optional). Compare against the five epochs BEFORE this one:
        # val_losses already ends with the current value, so including it in the window
        # would make the condition impossible to satisfy.
        if epoch > 10 and avg_val_loss > max(val_losses[-6:-1]):
            print("‚ö†Ô∏è Validation loss not improving. Consider early stopping.")

    print(f"\nTraining completed!")
    print(f"Best validation loss: {best_val_loss:.4f}")

    return train_losses, val_losses

In [6]:
# Start training and capture the per-epoch loss curves
train_losses, val_losses = train()

Training on device: cuda
üìä Training batches: 77
üìä Validation batches: 10


Epoch 1/50:   0%|          | 0/50 [00:00<?, ?epoch/s]

Epoch 1/50 - Batch 1/77 | D_loss: 1.7941 | G_loss: 68.5450 | Recon: 6.6067 | InfoNCE: 2.8168 | Margin: 0.0000 | Disentangle: 0.0014
‚ö†Ô∏è NaN detected in reconstruction loss at batch 2
   Reconstructed spec stats: min=nan, max=nan
   Target spec stats: min=-141.0991, max=128.4867
‚ö†Ô∏è NaN detected in reconstruction loss at batch 2
   Reconstructed spec stats: min=nan, max=nan
   Target spec stats: min=-141.0991, max=128.4867
‚ö†Ô∏è NaN detected in reconstruction loss at batch 3
   Reconstructed spec stats: min=nan, max=nan
   Target spec stats: min=-115.1092, max=108.7914
‚ö†Ô∏è NaN detected in reconstruction loss at batch 3
   Reconstructed spec stats: min=nan, max=nan
   Target spec stats: min=-115.1092, max=108.7914
‚ö†Ô∏è NaN detected in reconstruction loss at batch 4
   Reconstructed spec stats: min=nan, max=nan
   Target spec stats: min=-77.2417, max=88.3961
‚ö†Ô∏è NaN detected in reconstruction loss at batch 4
   Reconstructed spec stats: min=nan, max=nan
   Target spec stats

Epoch 1/50:   0%|          | 0/50 [01:19<?, ?epoch/s]



KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import IPython.display as ipd
import librosa
import torch.nn.functional as F

def test_dataloader_and_play_audio():
    """
    Smoke-test the training dataloader and check the label layout of one batch.

    Builds the training dataloader, fetches its first batch and reports whether the
    first half of the batch is labelled 0 (piano) and the second half 1 (violin).
    If the directories "dataset/piano" and "dataset/violin" exist, the same check is
    repeated on a dataloader built from them. Every error is caught and printed, so
    the function never raises. Despite its name, no audio is played.
    """
    print("üéµ TESTING DATALOADER AND PLAYING AUDIO")
    print("=" * 60)

    try:
        # Build the test dataloader
        test_dataloader = get_dataloader(
            piano_dir="dataset/train/piano",
            violin_dir="dataset/train/violin",
            batch_size=4,  # smaller batch for the test
            shuffle=True,
            stats_path="stats_stft_cqt.npz"
        )

        print(f"‚úÖ Dataloader creato con successo!")
        print(f"üìä Numero di batch: {len(test_dataloader)}")

        # Fetch a single batch; the checks below previously read x/labels without
        # ever loading them.
        x, labels = next(iter(test_dataloader))

        # Verify batch structure
        batch_size = x.shape[0]
        half_batch = batch_size // 2

        first_half_labels = labels[:half_batch]
        second_half_labels = labels[half_batch:]

        print(f"\nFirst half (Piano) labels: {first_half_labels}")
        print(f"Second half (Violin) labels: {second_half_labels}")

        all_piano = torch.all(first_half_labels == 0)
        all_violin = torch.all(second_half_labels == 1)

        if all_piano and all_violin:
            print("‚úÖ BATCH STRUCTURE CORRECT: First half piano, second half violin")
        else:
            print("‚ùå BATCH STRUCTURE INCORRECT:")
            print(f"  Piano section all 0s: {all_piano}")
            print(f"  Violin section all 1s: {all_violin}")

        # Optional second pass on the un-split dataset directories, if present
        print("\n" + "=" * 60)
        print("üéπ TESTING WITH ACTUAL DATALOADER (if available)")
        print("=" * 60)

        try:
            piano_dir = "dataset/piano"
            violin_dir = "dataset/violin"

            if os.path.exists(piano_dir) and os.path.exists(violin_dir):
                real_dataloader = get_dataloader(piano_dir, violin_dir, batch_size=2)

                # Get a batch
                x_real, labels_real = next(iter(real_dataloader))

                print(f"Real batch shape: {x_real.shape}")
                print(f"Real labels shape: {labels_real.shape}")
                print(f"Real labels: {labels_real}")

                # Verify structure
                half_batch_real = x_real.shape[0] // 2
                first_half_real = labels_real[:half_batch_real]
                second_half_real = labels_real[half_batch_real:]

                print(f"First half (Piano): {first_half_real}")
                print(f"Second half (Violin): {second_half_real}")

                piano_ok = torch.all(first_half_real == 0)
                violin_ok = torch.all(second_half_real == 1)

                if piano_ok and violin_ok:
                    print("‚úÖ REAL DATALOADER STRUCTURE CORRECT")
                else:
                    print("‚ùå REAL DATALOADER STRUCTURE INCORRECT")

            else:
                print("‚ö†Ô∏è Real dataset directories not found, using dummy data only")

        except Exception as e:
            # Best-effort section: report and carry on
            print(f"‚ö†Ô∏è Error with real dataloader: {e}")
            print("Using dummy dataloader only")

    except Exception as e:
        print(f"‚ùå Error testing dataloader: {e}")
        import traceback
        traceback.print_exc()

# Run the dataloader sanity check (reads dataset/train/* and stats_stft_cqt.npz)
test_dataloader_and_play_audio()

üéµ TESTING DATALOADER AND PLAYING AUDIO
‚úÖ Dataloader creato con successo!
üìä Numero di batch: 154
üé≤ Batch casuale scelto: 12

üéπ CONTENUTO DEL BATCH:
   Labels numerici: [0 1 0 0]
   Strumenti:
     Sample 0: Piano
     Sample 1: Violin
     Sample 2: Piano
     Sample 3: Piano

üéµ RIPRODUCI AUDIO DEL BATCH:

üîä Sample 0 - Piano (durata: 3.32s)

üéπ CONTENUTO DEL BATCH:
   Labels numerici: [0 1 0 0]
   Strumenti:
     Sample 0: Piano
     Sample 1: Violin
     Sample 2: Piano
     Sample 3: Piano

üéµ RIPRODUCI AUDIO DEL BATCH:

üîä Sample 0 - Piano (durata: 3.32s)



üîä Sample 1 - Violin (durata: 3.32s)



üîä Sample 2 - Piano (durata: 3.32s)



üîä Sample 3 - Piano (durata: 3.32s)



‚úÖ Test completato!


In [None]:
def play_batch_audio(x, labels, max_samples=4):
    """
    Describe the first samples of a batch (no audio is actually produced).

    For each inspected sample the instrument label, the sample shape and basic
    magnitude statistics of its first sequence frame are printed. Real playback
    would need an inverse STFT, which is not implemented here.

    Args:
        x: Batch tensor of shape (B, S, 2, T, F)
        labels: Labels tensor of shape (B,); 0 is reported as piano, anything else as violin
        max_samples: Upper bound on the number of samples to describe
    """
    rule = "=" * 60
    print(rule)
    print("üîä PLAYING AUDIO SAMPLES FROM BATCH")
    print(rule)

    try:
        # Imported here so that a missing IPython is reported instead of raising.
        from IPython.display import Audio, display
        import torch

        total_in_batch = x.shape[0]
        n_to_show = min(total_in_batch, max_samples)

        print(f"Playing {n_to_show} samples from batch of {total_in_batch}")

        for idx in range(n_to_show):
            sample = x[idx]  # (S, 2, T, F)
            label = labels[idx].item()
            if label == 0:
                instrument = "üéπ Piano"
            else:
                instrument = "üéª Violin"

            print(f"\nSample {idx+1}: {instrument} (Label: {label})")
            print(f"Sample shape: {sample.shape}")

            # Only the first sequence frame is summarised: (2, T, F)
            first_frame = sample[0]

            print(f"  First frame shape: {first_frame.shape}")
            for caption, axis in (("Channels", 0), ("Time frames", 1), ("Frequency bins", 2)):
                print(f"  {caption}: {first_frame.shape[axis]}")

            # Magnitude statistics of the frame
            magnitudes = torch.abs(first_frame)
            mean_magnitude = torch.mean(magnitudes)
            max_magnitude = torch.max(magnitudes)

            print(f"  Mean magnitude: {mean_magnitude:.4f}")
            print(f"  Max magnitude: {max_magnitude:.4f}")

            # Actual playback would require splitting STFT/CQT, inverting the STFT
            # and handing the waveform to IPython.display.Audio.
            print("  (Audio playback requires inverse STFT - not implemented)")

    except ImportError:
        print("‚ö†Ô∏è IPython not available - cannot play audio")
    except Exception as e:
        print(f"‚ùå Error playing audio: {e}")

# Test audio playback with dummy data
# NOTE(review): get_dummy_dataloader is not imported or defined in the cells shown
# here - confirm it is in scope before running this cell on a fresh kernel.
print("Testing audio playback functionality...")
dummy_dataloader = get_dummy_dataloader(batch_size=2)
x_test, labels_test = next(iter(dummy_dataloader))
play_batch_audio(x_test, labels_test, max_samples=4)

In [None]:
def demonstrate_batch_structure():
    """
    Print a walkthrough of the layout of one batch from the dummy dataloader.

    Takes the first batch of ``get_dummy_dataloader(batch_size=3)``, lists every
    sample with its instrument label and the half of the batch it falls in, then
    prints two static explanatory sections. Returns nothing.

    NOTE(review): relies on ``get_dummy_dataloader`` being defined elsewhere.
    """
    rule = "=" * 80

    def section(title, leading_blank=True):
        # Three-line banner; every banner except the first is preceded by an empty line.
        print("\n" + rule if leading_blank else rule)
        print(title)
        print(rule)

    section("üéØ DEMONSTRATING NEW BATCH STRUCTURE", leading_blank=False)

    # One test batch; per the printed text the loader is expected to return
    # twice the requested number of samples.
    loader = get_dummy_dataloader(batch_size=3)
    x, labels = next(iter(loader))

    n_samples = x.shape[0]
    midpoint = n_samples // 2

    print("Original batch_size parameter: 3")
    print(f"Actual batch size: {n_samples} (doubled)")
    print(f"Half batch size: {midpoint}")

    section("üìä BATCH STRUCTURE VISUALIZATION")

    for position in range(n_samples):
        label = labels[position].item()
        if label == 0:
            instrument = "üéπ Piano"
        else:
            instrument = "üéª Violin"
        if position < midpoint:
            half_name = "First Half"
        else:
            half_name = "Second Half"

        print(f"Sample {position+1:2d}: {instrument} (Label: {label}) - {half_name}")

    section("‚úÖ BENEFITS OF THIS STRUCTURE")

    for line in (
        "1. üéØ Every batch contains both piano and violin samples",
        "2. üîÑ Adversarial loss always has both classes to work with",
        "3. üìä Consistent training dynamics across all batches",
        "4. üõ°Ô∏è Prevents batches with only one class (which caused errors)",
        "5. üéµ Maintains balance between instruments in every batch",
    ):
        print(line)

    section("‚öôÔ∏è IMPLEMENTATION DETAILS")

    for line in (
        "‚Ä¢ Dataset.__getitem__ returns both piano and violin for each index",
        "‚Ä¢ collate_fn separates piano and violin samples",
        "‚Ä¢ First half of batch = all piano samples (label 0)",
        "‚Ä¢ Second half of batch = all violin samples (label 1)",
        "‚Ä¢ Actual batch size = 2 √ó requested batch size",
        "‚Ä¢ This ensures robust adversarial training",
    ):
        print(line)

# Run the demonstration (needs get_dummy_dataloader to be in scope)
demonstrate_batch_structure()

In [None]:
# Evaluate the saved checkpoint on the test set
def test_saved_model():
    """Evaluate ``best_model.pth`` on the test split and print a short report.

    Rebuilds the four networks with the same constructor arguments used in training,
    loads their weights from ``MODEL_SAVE_PATH/best_model.pth`` and averages, over the
    test dataloader, the same weighted loss used for validation (reconstruction,
    InfoNCE, margin and disentanglement; no adversarial term).

    Returns:
        tuple[float, float] | None: ``(avg_test_loss, avg_test_recon_loss)``;
        ``(nan, nan)`` when the test dataloader yields no batch; ``None`` when no
        checkpoint file exists.
    """
    print("üß™ Testing saved model...")

    # Load the saved model
    checkpoint_path = os.path.join(MODEL_SAVE_PATH, 'best_model.pth')
    if not os.path.exists(checkpoint_path):
        print("‚ùå No saved model found!")
        return

    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Create models
    style_encoder = StyleEncoder(transformer_dim=TRANSFORMER_DIM).to(device)
    content_encoder = ContentEncoder(transformer_dim=TRANSFORMER_DIM).to(device)
    decoder = Decoder(
        d_model=TRANSFORMER_DIM,
        nhead=4,
        num_layers=4,
        dim_feedforward=TRANSFORMER_DIM * 2,
        dropout=0.1
    ).to(device)
    discriminator = Discriminator(input_dim=TRANSFORMER_DIM).to(device)

    # Load state dicts
    style_encoder.load_state_dict(checkpoint['style_encoder_state_dict'])
    content_encoder.load_state_dict(checkpoint['content_encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])

    # Set to evaluation mode
    style_encoder.eval()
    content_encoder.eval()
    decoder.eval()
    discriminator.eval()

    # Create test dataloader
    test_dataloader = get_dataloader(
        piano_dir="dataset/test/piano",
        violin_dir="dataset/test/violin",
        batch_size=BATCH_SIZE,
        shuffle=False,
        stats_path="stats_stft_cqt.npz"
    )

    print(f"üìä Test batches: {len(test_dataloader)}")

    # Running totals
    test_loss = 0
    test_recon_loss = 0
    test_batches = 0

    def recon_loss_fn(output, target):
        """Return (total reconstruction loss, full loss dict) for a decoder output."""
        loss_dict = compute_comprehensive_loss(
            output, target,
            lambda_temporal=LAMBDA_TEMPORAL,
            lambda_phase=LAMBDA_PHASE,
            lambda_spectral=LAMBDA_SPECTRAL,
            lambda_consistency=LAMBDA_CONSISTENCY
        )
        return loss_dict['total_loss'], loss_dict

    with torch.no_grad():
        for x, labels in tqdm(test_dataloader, desc="Testing"):
            x, labels = x.to(device), labels.to(device)
            stft_part = x[:, :, :, :, :STFT_F]  # keep only the STFT bins of the last axis

            # Forward pass
            style_emb, class_emb = style_encoder(x, labels)
            content_emb = content_encoder(x)

            # Reconstruction
            reconstructed_spec = decoder(content_emb, style_emb, y=stft_part)
            loss_recon, _ = recon_loss_fn(reconstructed_spec, stft_part)

            # Contrastive losses need at least two classes in the batch
            if len(torch.unique(labels)) > 1:
                loss_infonce = infoNCE_loss(style_emb, labels)
                loss_margin = margin_loss(class_emb)
            else:
                loss_infonce = torch.tensor(0.0, device=device)
                loss_margin = torch.tensor(0.0, device=device)

            disent_loss = disentanglement_loss(style_emb, content_emb.mean(dim=1), use_hsic=True)

            # Total test loss
            total_test_loss = (
                LAMBDA_RECON * loss_recon +
                LAMBDA_INFO_NCE * loss_infonce +
                LAMBDA_MARGIN * loss_margin +
                LAMBDA_DISENTANGLE * disent_loss
            )

            test_loss += total_test_loss.item()
            test_recon_loss += loss_recon.item()
            test_batches += 1

    # An empty test set used to end in ZeroDivisionError; report it instead.
    if test_batches == 0:
        print("Test dataloader produced no batches; nothing to evaluate.")
        return float('nan'), float('nan')

    # Average test losses
    avg_test_loss = test_loss / test_batches
    avg_test_recon_loss = test_recon_loss / test_batches

    print(f"\nüìä Test Results:")
    print(f"   Test Loss: {avg_test_loss:.4f}")
    print(f"   Test Recon Loss: {avg_test_recon_loss:.4f}")
    print(f"   Best Val Loss: {checkpoint['best_val_loss']:.4f}")

    # Compare the test loss with the best validation loss stored in the checkpoint
    performance_gap = avg_test_loss - checkpoint['best_val_loss']
    if performance_gap < 0.1:
        print("‚úÖ Excellent generalization!")
    elif performance_gap < 0.3:
        print("üü° Good generalization")
    else:
        print("üî¥ Poor generalization - model might be overfitting")

    return avg_test_loss, avg_test_recon_loss

# Uncomment to test the saved model
# test_loss, test_recon_loss = test_saved_model()

In [9]:
def count_trainable_parameters():
    """Print a parameter and memory report for the four networks.

    Instantiates the style encoder, content encoder, decoder and discriminator on the
    default device, prints per-model and total parameter counts and sizes, a rough
    training-memory estimate for several batch sizes, and a batch-size recommendation.

    Returns:
        tuple: (total trainable parameter count, total parameter size in MB).
    """
    rule = "=" * 60

    def banner(title, leading_blank):
        # Title framed by two rules; later sections start with an empty line.
        print("\n" + rule if leading_blank else rule)
        print(title)
        print(rule)

    banner("üìä ANALISI PARAMETRI ALLENABILI DEL MODELLO", leading_blank=False)

    # Build every model of the pipeline
    models = [
        ("Style Encoder", StyleEncoder(transformer_dim=TRANSFORMER_DIM)),
        ("Content Encoder", ContentEncoder(transformer_dim=TRANSFORMER_DIM)),
        ("Decoder", Decoder(
            d_model=TRANSFORMER_DIM,
            nhead=4,
            num_layers=4,
            dim_feedforward=TRANSFORMER_DIM * 2,
            dropout=0.1
        )),
        ("Discriminator", Discriminator(input_dim=TRANSFORMER_DIM)),
    ]

    total_params = 0
    total_trainable_params = 0
    total_size_mb = 0

    for name, model in models:
        # Single pass over the parameters: count, trainable count and byte size
        n_params = 0
        n_trainable = 0
        n_bytes = 0
        for p in model.parameters():
            numel = p.numel()
            n_params += numel
            if p.requires_grad:
                n_trainable += numel
            n_bytes += numel * p.element_size()

        size_mb = n_bytes / (1024**2)

        total_params += n_params
        total_trainable_params += n_trainable
        total_size_mb += size_mb

        print(f"{name:>15}: {n_params:>10,} parametri ({n_trainable:>10,} allenabili) | {size_mb:>6.1f} MB")

    print("-" * 60)
    print(f"{'TOTALE':>15}: {total_params:>10,} parametri ({total_trainable_params:>10,} allenabili)")
    print(f"{'DIMENSIONE':>15}: {total_size_mb:>47.1f} MB")

    # Rough GPU estimate for training: weights + gradients + Adam state (~2x params)
    estimated_gpu_memory = total_size_mb * 4
    print(f"{'MEM. STIMATA':>15}: {estimated_gpu_memory:>44.0f} MB (solo modello)")

    banner("üíæ ANALISI MEMORIA PER BATCH SIZE", leading_blank=True)

    # Bytes of one float32 input sample of shape (S, 2, T, F), F = STFT + CQT bins
    bytes_per_sample = NUM_FRAMES * 2 * STFT_T * (STFT_F + CQT_F) * 4
    for bs in (4, 8, 16, 32):
        input_size_mb = bs * bytes_per_sample / (1024**2)

        # Model estimate plus the input counted twice (input + gradients)
        total_mem_mb = estimated_gpu_memory + input_size_mb * 2

        if total_mem_mb < 8000:
            status = "‚úÖ"
        elif total_mem_mb < 12000:
            status = "‚ö†Ô∏è"
        else:
            status = "‚ùå"
        print(f"Batch size {bs:>2}: {total_mem_mb:>6.0f} MB {status}")

    banner("üéØ CONFIGURAZIONE RACCOMANDATA", leading_blank=True)

    if total_trainable_params < 50_000_000:  # 50M
        print("‚úÖ Modello di dimensioni ragionevoli per il training")
    elif total_trainable_params < 100_000_000:  # 100M
        print("‚ö†Ô∏è Modello grande - considerare mixed precision training")
    else:
        print("‚ùå Modello molto grande - ridurre dimensioni o usare tecniche avanzate")

    # Final recommendations
    batch_for_12gb = 8 if estimated_gpu_memory < 4000 else 4
    batch_for_8gb = 4 if estimated_gpu_memory < 3000 else 2
    print(f"\nüöÄ Per GPU con 12GB: Batch size raccomandato = {batch_for_12gb}")
    print(f"üîß Per GPU con 8GB:  Batch size raccomandato = {batch_for_8gb}")

    return total_trainable_params, total_size_mb

# Run the analysis
total_params, model_size = count_trainable_parameters()

üìä ANALISI PARAMETRI ALLENABILI DEL MODELLO
  Style Encoder: 12,905,312 parametri (12,905,312 allenabili) |   49.2 MB
Content Encoder:  4,450,480 parametri ( 4,450,480 allenabili) |   17.0 MB
        Decoder:  3,681,515 parametri ( 3,681,515 allenabili) |   14.0 MB
  Discriminator:     49,666 parametri (    49,666 allenabili) |    0.2 MB
------------------------------------------------------------
         TOTALE: 21,086,973 parametri (21,086,973 allenabili)
     DIMENSIONE:                                            80.4 MB
   MEM. STIMATA:                                          322 MB (solo modello)

üíæ ANALISI MEMORIA PER BATCH SIZE
Batch size  4:    364 MB ‚úÖ
Batch size  8:    405 MB ‚úÖ
Batch size 16:    489 MB ‚úÖ
Batch size 32:    656 MB ‚úÖ

üéØ CONFIGURAZIONE RACCOMANDATA
‚úÖ Modello di dimensioni ragionevoli per il training

üöÄ Per GPU con 12GB: Batch size raccomandato = 8
üîß Per GPU con 8GB:  Batch size raccomandato = 4


In [None]:
# Funzioni di utilit√† per Colab
import psutil
import gc

def monitor_resources(gpu_alert_mb=8000):
    """Print a one-line snapshot of CPU, RAM and (when available) GPU usage.

    Args:
        gpu_alert_mb: Allocated-GPU-memory threshold, in MB, above which a
            warning is printed and ``True`` is returned. Defaults to 8000,
            the value that used to be hard-coded.

    Returns:
        bool: ``True`` when CUDA is available and the allocated GPU memory
        exceeds ``gpu_alert_mb``; ``False`` in every other case.
    """
    # NOTE(review): psutil documents that cpu_percent() called without an
    # interval returns a meaningless value on its very first call — confirm
    # whether the first reading matters here.
    cpu_percent = psutil.cpu_percent()
    ram_percent = psutil.virtual_memory().percent

    if not torch.cuda.is_available():
        # Without a GPU the function used to print nothing at all, even though
        # the CPU/RAM figures had already been collected. Report them anyway.
        print(f"CPU: {cpu_percent:.1f}% | RAM: {ram_percent:.1f}%")
        return False

    # Both values are converted from bytes to MB.
    gpu_memory = torch.cuda.memory_allocated() / 1024**2
    gpu_cached = torch.cuda.memory_reserved() / 1024**2

    print(f"CPU: {cpu_percent:.1f}% | RAM: {ram_percent:.1f}% | GPU: {gpu_memory:.0f}MB/{gpu_cached:.0f}MB")

    # Warn when the allocated memory crosses the alert threshold.
    if gpu_memory > gpu_alert_mb:
        print("‚ö†Ô∏è ATTENZIONE: Memoria GPU alta!")
        return True
    return False

def clear_memory():
    """Run the garbage collector, then release cached GPU memory.

    The collector runs first (and unconditionally) so that tensors kept alive
    only by reference cycles are destroyed *before* the CUDA cache is emptied;
    with the opposite order their blocks would go back to PyTorch's caching
    allocator after the cache had already been flushed and would stay
    reserved.

    Returns:
        None. A confirmation line is printed only when CUDA is available.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("üßπ Memoria GPU liberata")

def get_model_size(model):
    """Return the parameter count and in-memory size of ``model``.

    Args:
        model: Object exposing ``parameters()`` and ``buffers()`` iterables
            whose items provide ``numel()`` and ``element_size()``.

    Returns:
        tuple: ``(param_count, size_mb)`` where ``param_count`` is the total
        number of parameter elements and ``size_mb`` is the byte size of all
        parameters plus all buffers, divided by 1024**2.
    """
    param_count = 0
    total_bytes = 0

    # Parameters contribute to both the element count and the byte total.
    for tensor in model.parameters():
        elements = tensor.numel()
        param_count += elements
        total_bytes += elements * tensor.element_size()

    # Buffers only contribute to the byte total.
    for tensor in model.buffers():
        total_bytes += tensor.numel() * tensor.element_size()

    return param_count, total_bytes / 1024**2

# Resource check
print("=== RESOURCE MONITORING ===")
monitor_resources()

# Build a throwaway dynamic decoder (no S parameter is passed) to report its size.
# NOTE(review): the notebook's first cell imports `Decoder` from new_decoder, not
# `DynamicDecoder` — confirm this name is defined earlier in the notebook.
test_decoder = DynamicDecoder(d_model=256, nhead=4, num_layers=3)
param_count, size_mb = get_model_size(test_decoder)
print(f"Dynamic Decoder: {param_count:,} parametri ({size_mb:.1f} MB)")

del test_decoder  # Drop the reference to the test model before clearing memory
clear_memory()

In [None]:
# Test del decoder dinamico prima del training
def test_dynamic_decoder():
    """Smoke-test DynamicDecoder on random inputs of several sequence lengths.

    For every sequence length the decoder is exercised four ways, in order:
    training mode with a target ``y``, eval mode with an explicit
    ``target_length``, eval mode with no length given, and finally
    ``compute_comprehensive_loss`` on the explicit-length inference output.

    Returns:
        bool: ``True`` when every step of every case succeeds; ``False`` as
        soon as one step raises (the error is printed, not re-raised).
    """
    print("üîß Testing Dynamic Decoder...")

    batch, d_model = 4, 256
    frame_shape = (2, 287, 513)

    # The decoder is built without any sequence-length argument.
    decoder = DynamicDecoder(d_model=d_model, nhead=4, num_layers=2)

    cases = (
        (2, "sequenza molto corta"),
        (4, "sequenza standard"),
        (6, "sequenza lunga"),
        (8, "sequenza molto lunga"),
    )

    for seq_len, label in cases:
        print(f"\n--- Test: {label} (S={seq_len}) ---")

        # Random inputs; the sequence axis changes from case to case.
        content_emb = torch.randn(batch, seq_len, d_model)
        class_emb = torch.randn(batch, d_model)
        y_target = torch.randn(batch, seq_len, *frame_shape)
        expected_shape = (batch, seq_len, *frame_shape)

        print(f"Content embedding shape: {content_emb.shape}")
        print(f"Class embedding shape: {class_emb.shape}")
        print(f"Target shape: {y_target.shape}")

        # Step 1: training mode, target supplied through `y`.
        decoder.train()
        try:
            output = decoder(content_emb, class_emb, y=y_target)
            print(f"‚úÖ Training mode output shape: {output.shape}")
            assert output.shape == y_target.shape, f"Shape mismatch: {output.shape} vs {y_target.shape}"
        except Exception as e:
            print(f"‚ùå Training mode error: {e}")
            return False

        # Step 2: eval mode with an explicit target length.
        decoder.eval()
        try:
            with torch.no_grad():
                output = decoder(content_emb, class_emb, target_length=seq_len)
            print(f"‚úÖ Inference mode output shape: {output.shape}")
            assert output.shape == expected_shape, f"Inference shape mismatch: {output.shape}"
        except Exception as e:
            print(f"‚ùå Inference mode error: {e}")
            return False

        # Step 3: eval mode, letting the decoder pick its own length.
        try:
            with torch.no_grad():
                output_auto = decoder(content_emb, class_emb)
            print(f"‚úÖ Auto-length inference output shape: {output_auto.shape}")
        except Exception as e:
            print(f"‚ùå Auto-length inference error: {e}")
            return False

        # Step 4: loss, evaluated on the output of step 2 (the explicit-length
        # inference pass), since `output` was rebound there.
        try:
            loss_dict = compute_comprehensive_loss(output, y_target)
            print(f"‚úÖ Loss computation successful: {loss_dict['total_loss'].item():.4f}")
        except Exception as e:
            print(f"‚ùå Loss computation error: {e}")
            return False

    print("\nüéâ Dynamic Decoder test successful! Decoder can handle variable sequence lengths!")
    return True

# Run the test and report the overall outcome
print("Testing the new Dynamic Decoder...")
if test_dynamic_decoder():
    print("‚úÖ Dynamic Decoder pronto per il training!")
else:
    print("‚ùå Dynamic Decoder ha problemi, controlla gli errori.")

In [None]:
# Debug utilities per il training dinamico
def debug_shapes(x, content_emb, style_emb, class_emb, step="forward"):
    """Print the shapes of the four tensors and check that they agree.

    Args:
        x: Tensor whose shape unpacks into five dimensions.
        content_emb: Tensor whose shape unpacks into three dimensions.
        style_emb: Tensor whose shape unpacks into two dimensions.
        class_emb: Tensor whose shape unpacks into two dimensions.
        step: Label shown in the header line.

    Returns:
        bool: Always ``True`` when all checks pass.

    Raises:
        ValueError: If a shape does not have the expected number of dimensions.
        AssertionError: If the leading (batch) sizes differ, if the second
            dimensions of ``x`` and ``content_emb`` differ, or if the last
            dimensions of the three embeddings differ.
    """
    print(f"\n=== DEBUG SHAPES - {step} ===")
    for label, tensor in (
        ("Input x", x),
        ("Content embedding", content_emb),
        ("Style embedding", style_emb),
        ("Class embedding", class_emb),
    ):
        print(f"{label} shape: {tensor.shape}")

    # Unpacking doubles as a rank check: a wrong rank raises ValueError here.
    batch, seq_len, channels, height, width = x.shape
    content_batch, content_seq, content_dim = content_emb.shape
    style_batch, style_dim = style_emb.shape
    class_batch, class_dim = class_emb.shape

    print(f"Batch size: {batch}")
    print(f"Sequence length: {seq_len}")
    print(f"STFT dims: {channels}x{height}x{width}")

    # Cross-tensor consistency checks.
    assert batch == content_batch == style_batch == class_batch, (
        f"Batch size mismatch: {batch}, {content_batch}, {style_batch}, {class_batch}"
    )
    assert seq_len == content_seq, f"Sequence length mismatch: {seq_len}, {content_seq}"
    assert content_dim == style_dim == class_dim, (
        f"Embedding dimension mismatch: {content_dim}, {style_dim}, {class_dim}"
    )

    print("‚úÖ All dimensions are consistent!")
    return True

# Versione safe del training con debug del decoder dinamico
def safe_debug_train():
    """Run a single-batch forward pass through encoders and dynamic decoder.

    Builds the style encoder, content encoder and decoder, pulls the first
    batch from ``get_dummy_dataloader(batch_size=2)``, checks tensor shapes
    with ``debug_shapes`` and then exercises the decoder in training mode,
    in eval mode, and through ``compute_comprehensive_loss``.

    Returns:
        bool: ``False`` if the decoder section raises (the traceback is
        printed); ``True`` otherwise. Errors raised outside the decoder
        section (e.g. by ``debug_shapes``) propagate to the caller.
    """
    import traceback

    print(f"Debug training on device: {device}")

    # Models; the decoder is created without a sequence-length argument.
    style_encoder = StyleEncoder(transformer_dim=TRANSFORMER_DIM).to(device)
    content_encoder = ContentEncoder(transformer_dim=TRANSFORMER_DIM).to(device)
    decoder = DynamicDecoder(d_model=TRANSFORMER_DIM, nhead=4, num_layers=2).to(device)

    # A small batch keeps the debug run cheap.
    dataloader = get_dummy_dataloader(batch_size=2)

    print("\nüîç Testing forward pass with Dynamic Decoder...")

    for batch_idx, (x, labels) in enumerate(dataloader):
        x = x.to(device)
        labels = labels.to(device)

        print(f"\nBatch {batch_idx + 1} shapes:")
        print(f"  x: {x.shape}")
        print(f"  labels: {labels.shape}")

        # Encoders
        style_emb, class_emb = style_encoder(x, labels)
        content_emb = content_encoder(x)

        # Shape sanity check (not guarded: a failure here propagates).
        debug_shapes(x, content_emb, style_emb, class_emb)

        print(f"\nüß™ Testing Dynamic Decoder...")
        try:
            # Training mode: the input itself is passed as the target `y`.
            decoder.train()
            reconstructed = decoder(content_emb, style_emb, y=x)
            print(f"‚úÖ Training mode output shape: {reconstructed.shape}")
            print(f"Expected shape: {x.shape}")
            assert reconstructed.shape == x.shape, f"Shape mismatch: {reconstructed.shape} vs {x.shape}"

            # Eval mode, no target and no gradient tracking.
            decoder.eval()
            with torch.no_grad():
                reconstructed_inf = decoder(content_emb, style_emb)
                print(f"‚úÖ Inference mode output shape: {reconstructed_inf.shape}")

            # Loss on the training-mode reconstruction.
            loss_dict = compute_comprehensive_loss(reconstructed, x)
            print(f"‚úÖ Loss computation: {loss_dict['total_loss'].item():.4f}")

            print("‚úÖ Dynamic Decoder test successful!")

        except Exception as e:
            print(f"‚ùå Dynamic Decoder error: {e}")
            traceback.print_exc()
            return False

        # Only the first batch is inspected.
        break

    print("\nüéâ Debug test completed successfully!")
    print("üöÄ Dynamic Decoder is ready for full training!")
    return True

# Test delle dimensioni del modello dinamico
def check_dynamic_model_size(batch_size=8):
    """Print per-model and total parameter counts plus a rough memory estimate.

    The four models are instantiated on the default device only to be
    measured with ``get_model_size``.

    Args:
        batch_size: Batch size used for the input-tensor part of the memory
            estimate. Defaults to 8, the value that used to be hard-coded.

    Returns:
        None. All results are printed.
    """
    print("\nüìä Checking Dynamic Model Sizes...")

    # Instantiate the models in a fixed order and pair each with its label.
    models = [
        ("Style Encoder", StyleEncoder(transformer_dim=TRANSFORMER_DIM)),
        ("Content Encoder", ContentEncoder(transformer_dim=TRANSFORMER_DIM)),
        ("Dynamic Decoder", DynamicDecoder(d_model=TRANSFORMER_DIM, nhead=4, num_layers=3)),
        ("Discriminator", Discriminator(input_dim=TRANSFORMER_DIM)),
    ]

    total_params = 0
    total_size = 0

    for name, model in models:
        param_count, size_mb = get_model_size(model)
        total_params += param_count
        total_size += size_mb
        print(f"{name}: {param_count:,} parametri ({size_mb:.1f} MB)")

    print(f"\nüìà TOTALE: {total_params:,} parametri ({total_size:.1f} MB)")

    # Input batch of shape [batch_size, NUM_FRAMES, 2, STFT_T, STFT_F] at
    # 4 bytes per element, converted to MB. The frame count and STFT
    # dimensions come from the notebook's configuration constants instead of
    # repeating the literals 4, 287 and 513.
    channels = 2
    bytes_per_element = 4
    batch_mb = (batch_size * NUM_FRAMES * channels * STFT_T * STFT_F * bytes_per_element) / (1024**2)

    # Heuristic kept from the original: three times the model size
    # ("model + gradients + data") plus one input batch.
    estimated_memory = total_size * 3 + batch_mb
    print(f"üíæ Memoria stimata (batch_size={batch_size}): ~{estimated_memory:.0f} MB")
    print(f"üñ•Ô∏è Using device: {device}")

    if estimated_memory < 10000:  # below ~10 GB
        print("‚úÖ Dovrebbe funzionare su Colab Free!")
    else:
        print("‚ö†Ô∏è Potrebbe essere troppo per Colab Free")

# Run the model-size checks
check_dynamic_model_size()

# Uncomment to run debug test
# safe_debug_train()

In [None]:
# Training con monitoraggio delle risorse
def safe_train():
    """Run ``train()`` with resource reporting and best-effort error handling.

    Resource usage is printed before and after training. Errors raised by
    ``train()`` are reported and swallowed (nothing is re-raised) so the
    notebook keeps running. After any failure or a user interrupt, GPU
    memory is released through ``clear_memory()``.

    Returns:
        None.
    """
    # Cleanup is deferred to `finally` on purpose: while an `except` clause is
    # still executing, the exception keeps the failed train() stack frames
    # (and the tensors they hold) alive, so freeing GPU memory from inside the
    # handler cannot release them.
    needs_cleanup = False
    try:
        print(f"üöÄ Avvio training con monitoraggio risorse su {device}...")
        monitor_resources()

        # Show GPU memory (allocated / reserved, in MB) before starting.
        if torch.cuda.is_available():
            print(f"üíæ GPU Memory before training: {torch.cuda.memory_allocated()/1024**2:.1f}MB / {torch.cuda.memory_reserved()/1024**2:.1f}MB")

        train()

    except RuntimeError as e:
        # Any RuntimeError now triggers cleanup, matching the generic
        # Exception branch; previously only the out-of-memory case did.
        needs_cleanup = True
        if "out of memory" in str(e):
            print("üí• ERRORE: Memoria GPU insufficiente!")
            print("Prova a ridurre:")
            print("- BATCH_SIZE (attualmente {})".format(BATCH_SIZE))
            print("- NUM_FRAMES (attualmente {})".format(NUM_FRAMES))
            print("- TRANSFORMER_DIM (attualmente {})".format(TRANSFORMER_DIM))
        else:
            print(f"‚ùå Errore durante il training: {e}")
    except KeyboardInterrupt:
        needs_cleanup = True
        print("‚èπÔ∏è Training interrotto dall'utente")
    except Exception as e:
        needs_cleanup = True
        print(f"‚ùå Errore imprevisto: {e}")
    finally:
        # By now the handled exception has been released, so the memory held
        # by the failed call can actually be reclaimed before the final report.
        if needs_cleanup:
            clear_memory()
        print(f"üèÅ Training terminato su {device}")
        monitor_resources()

# Uncomment to run training
# safe_train()