In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import os

# models
from style_encoder import StyleEncoder, initialize_weights
from content_encoder import ContentEncoder
from discriminator import Discriminator
from new_decoder import Decoder, compute_comprehensive_loss  # Nuovo decoder dinamico
from losses import infoNCE_loss, margin_loss, adversarial_loss, disentanglement_loss
from Dataloader import get_dataloader

# device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

# hyperparameters
EPOCHS = 50
BATCH_SIZE = 8
LR_GEN = 5e-5  # Learning rate per Encoders + Decoder
LR_DISC = 1e-5  # Learning rate per Discriminator
TRANSFORMER_DIM = 256
NUM_FRAMES = 4
STFT_T, STFT_F = 287, 513 # Dimensioni STFT (n_fft=1024)
CQT_T, CQT_F = 287, 84   # Dimensioni CQT

# loss weights
LAMBDA_RECON = 1.0
LAMBDA_INFO_NCE = 0.5
LAMBDA_MARGIN = 0.5
LAMBDA_DISENTANGLE = 0.5
LAMBDA_ADV_GEN = 0.01 # Peso per la loss avversaria del generatore

# Pesi per le loss comprehensive del decoder
LAMBDA_TEMPORAL = 0.3
LAMBDA_PHASE = 0.2
LAMBDA_SPECTRAL = 0.1
LAMBDA_CONSISTENCY = 0.1

MODEL_SAVE_PATH = "./saved_models"
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

Using device: cuda
GPU: NVIDIA GeForce RTX 3080
GPU Memory: 10.0 GB


In [2]:

# models
from style_encoder import StyleEncoder, initialize_weights
from content_encoder import ContentEncoder
from discriminator import Discriminator
from new_decoder import Decoder, compute_comprehensive_loss  # Nuovo decoder dinamico
from losses import infoNCE_loss, margin_loss, adversarial_loss, disentanglement_loss
from Dataloader import get_dataloader

# device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

# hyperparameters
EPOCHS = 50
BATCH_SIZE = 8
LR_GEN = 2e-4  # Learning rate per Encoders + Decoder
LR_DISC = 1e-4  # Learning rate per Discriminator
TRANSFORMER_DIM = 256
NUM_FRAMES = 4
STFT_T, STFT_F = 287, 513 # Dimensioni STFT (n_fft=1024)
CQT_T, CQT_F = 287, 84   # Dimensioni CQT

# loss weights
LAMBDA_RECON = 10.0
LAMBDA_INFO_NCE = 1.0
LAMBDA_MARGIN = 1.0
LAMBDA_DISENTANGLE = 0.1
LAMBDA_ADV_GEN = 0.5 # Peso per la loss avversaria del generatore

# Pesi per le loss comprehensive del decoder
LAMBDA_TEMPORAL = 0.3
LAMBDA_PHASE = 0.2
LAMBDA_SPECTRAL = 0.1
LAMBDA_CONSISTENCY = 0.1

MODEL_SAVE_PATH = "./saved_models"
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

Using device: cuda
GPU: NVIDIA GeForce RTX 3080
GPU Memory: 10.0 GB


aggiungere data augmentation

In [3]:
def init_weights_conservative(m):
    """
    Inizializzazione conservativa dei pesi per prevenire NaN
    """
    if isinstance(m, nn.Conv2d):
        # Xavier uniforme con gain ridotto
        nn.init.xavier_uniform_(m.weight, gain=0.2)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        # Xavier uniforme con gain ridotto
        nn.init.xavier_uniform_(m.weight, gain=0.2)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)

In [4]:
def train():
    print(f"Training on device: {device}")

    # load models
    style_encoder = StyleEncoder(transformer_dim=TRANSFORMER_DIM).to(device)
    content_encoder = ContentEncoder(transformer_dim=TRANSFORMER_DIM).to(device)

    # decoder
    decoder = Decoder(
        d_model=TRANSFORMER_DIM,
        nhead=4,
        num_layers=4,  # Ridotto per Colab
        dim_feedforward=TRANSFORMER_DIM * 2,
        dropout=0.1
    ).to(device)
    
    discriminator = Discriminator(input_dim=TRANSFORMER_DIM).to(device)

    # initialize weights
    # initialize_weights(style_encoder)
    # initialize_weights(content_encoder)
    # initialize_weights(discriminator)
    # decoder già ha la sua inizializzazione
    models = [style_encoder, content_encoder, discriminator, decoder]
    model_names = ["style_encoder", "content_encoder", "discriminator", "decoder"]

    for model, name in zip(models, model_names):
        model.apply(init_weights_conservative)
        print(f"✅ {name} initialized")

    # optimizer for generators (style encoder, content encoder, decoder)
    optimizer_G = optim.Adam(
        list(style_encoder.parameters()) + list(content_encoder.parameters()) + list(decoder.parameters()),
        lr=LR_GEN, betas=(0.5, 0.999)
    )
    
    # optimizer for discriminator
    optimizer_D = optim.Adam(discriminator.parameters(), lr=LR_DISC, betas=(0.5, 0.999))

    # Create train and validation dataloaders
    train_dataloader = get_dataloader(
        piano_dir="dataset/train/piano",
        violin_dir="dataset/train/violin",
        batch_size=BATCH_SIZE,
        shuffle=True,
        stats_path="stats_stft_cqt.npz"
    )
    
    val_dataloader = get_dataloader(
        piano_dir="dataset/val/piano", 
        violin_dir="dataset/val/violin",
        batch_size=BATCH_SIZE,
        shuffle=False,  # No shuffle for validation
        stats_path="stats_stft_cqt.npz"
    )
    
    print(f"Training batches: {len(train_dataloader)}")
    print(f"Validation batches: {len(val_dataloader)}")
    
    # loss function for reconstruction
    def recon_loss_fn(output, target):
        loss_dict = compute_comprehensive_loss(
            output, target, 
            lambda_temporal=LAMBDA_TEMPORAL,
            lambda_phase=LAMBDA_PHASE,
            lambda_spectral=LAMBDA_SPECTRAL,
            lambda_consistency=LAMBDA_CONSISTENCY
        )
        return loss_dict['total_loss'], loss_dict
    
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []

    # train loop with epoch progress bar
    epoch_pbar = tqdm(range(EPOCHS), desc="Training", unit="epoch")
    
    for epoch in epoch_pbar:
        epoch_pbar.set_description(f"Epoch {epoch+1}/{EPOCHS}")
        
        # ================================================================== #
        #                             TRAINING                               #
        # ================================================================== #
        style_encoder.train()
        content_encoder.train()
        decoder.train()
        discriminator.train()
        
        train_loss_epoch = 0
        train_recon_loss_epoch = 0
        train_batches = 0
        
        for i, (x, labels) in enumerate(train_dataloader):
            x, labels = x.to(device), labels.to(device) # x: (B, S, 2, T, F)
            stft_part = x[:, :, :, :, :STFT_F]  # STFT part

            # ================================================================== #
            #                             Discriminator                          #
            # ================================================================== #
            optimizer_D.zero_grad()
            
            # with torch.no_grad() to avoid computing gradients for the encoders    <----------------
            style_emb, class_emb = style_encoder(x, labels)
            content_emb = content_encoder(x)
            
            # adversarial loss for the discriminator
            discriminator_loss, _ = adversarial_loss(style_emb.detach(), class_emb.detach(), 
                                                     content_emb.detach(), discriminator, labels, 
                                                     compute_for_discriminator=True)
            # discriminator_loss = torch.tensor(0.0, device=device)
            
            discriminator_loss.backward()
            optimizer_D.step()

            # ================================================================== #
            #               Generators (Style Encoder, Content Encoder)          #
            # ================================================================== #
            optimizer_G.zero_grad()

            # forward pass
            style_emb, class_emb = style_encoder(x, labels)
            content_emb = content_encoder(x)

            # adversarial loss for the generator
            _, adv_generator_loss = adversarial_loss(style_emb, class_emb, content_emb, discriminator, labels,
                                                 compute_for_discriminator=False)

            # adv_generator_loss = torch.tensor(0.0, device=device)

            # disentanglement loss
            disent_loss = disentanglement_loss(style_emb, content_emb.mean(dim=1), use_hsic=True)

            if len(torch.unique(labels)) > 1:
                # contrastive losses
                loss_infonce = infoNCE_loss(style_emb, labels)
                loss_margin = margin_loss(class_emb)
            else:
                # Fallback se tutti i label sono uguali in questo batch
                loss_infonce = torch.tensor(0.0, device=device)
                loss_margin = torch.tensor(0.0, device=device)

            # reconstruction loss
            reconstructed_spec = decoder(content_emb, style_emb, y=stft_part)  # y=x per teacher forcing
            loss_recon, loss_dict = recon_loss_fn(reconstructed_spec, stft_part)

            if torch.isnan(loss_recon):
                print(f"⚠️ NaN detected in reconstruction loss at batch {i+1}")
                print(f"   Reconstructed spec stats: min={reconstructed_spec.min():.4f}, max={reconstructed_spec.max():.4f}")
                print(f"   Target spec stats: min={stft_part.min():.4f}, max={stft_part.max():.4f}")
                continue  # Skip this batch
            
            # total generator loss
            total_gen_loss = (
                LAMBDA_RECON * loss_recon +
                LAMBDA_INFO_NCE * loss_infonce +
                LAMBDA_MARGIN * loss_margin +
                LAMBDA_DISENTANGLE * disent_loss 
                # LAMBDA_ADV_GEN * adv_generator_loss
            )

            total_gen_loss.backward()

            # gradient clipping
            torch.nn.utils.clip_grad_norm_(
                list(style_encoder.parameters()) + 
                list(content_encoder.parameters()) + 
                list(decoder.parameters()), 
                max_norm=0.5
            )

            # AGGIUNGI CONTROLLO NaN sui gradienti
            has_nan_grad = False
            for name, param in decoder.named_parameters():
                if param.grad is not None and torch.isnan(param.grad).any():
                    print(f"⚠️ NaN gradient detected in {name}")
                    has_nan_grad = True
            
            if has_nan_grad:
                print("⚠️ Skipping optimizer step due to NaN gradients")
                continue

            optimizer_G.step()

            # Accumulate losses
            train_loss_epoch += total_gen_loss.item()
            train_recon_loss_epoch += loss_recon.item()
            train_batches += 1

            # Print batch metrics every batch
            print(f"Epoch {epoch+1}/{EPOCHS} - Batch {i+1}/{len(train_dataloader)} | "
                  f"D_loss: {discriminator_loss.item():.4f} | "
                  f"G_loss: {total_gen_loss.item():.4f} | "
                  f"Recon: {loss_recon.item():.4f} | "
                  f"InfoNCE: {loss_infonce.item():.4f} | "
                  f"Margin: {loss_margin.item():.4f} | "
                  f"Disentangle: {disent_loss.item():.4f}")
        
        # Average training losses
        avg_train_loss = train_loss_epoch / train_batches
        avg_train_recon_loss = train_recon_loss_epoch / train_batches
        train_losses.append(avg_train_loss)
        
        # ================================================================== #
        #                            VALIDATION                              #
        # ================================================================== #
        style_encoder.eval()
        content_encoder.eval()
        decoder.eval()
        discriminator.eval()
        
        val_loss_epoch = 0
        val_recon_loss_epoch = 0
        val_batches = 0
        
        print(f"\n🔍 Running validation for epoch {epoch+1}...")
        
        with torch.no_grad():
            for i, (x, labels) in enumerate(val_dataloader):
                x, labels = x.to(device), labels.to(device)
                stft_part = x[:, :, :, :, :STFT_F]

                # Forward pass
                style_emb, class_emb = style_encoder(x, labels)
                content_emb = content_encoder(x)

                # Validation losses (only the main ones)
                if len(torch.unique(labels)) > 1:
                    loss_infonce = infoNCE_loss(style_emb, labels)
                    loss_margin = margin_loss(class_emb)
                else:
                    loss_infonce = torch.tensor(0.0, device=device)
                    loss_margin = torch.tensor(0.0, device=device)

                disent_loss = disentanglement_loss(style_emb, content_emb.mean(dim=1), use_hsic=True)
                
                # Reconstruction loss
                reconstructed_spec = decoder(content_emb, style_emb, y=stft_part)
                loss_recon, _ = recon_loss_fn(reconstructed_spec, stft_part)
                
                # Total validation loss
                total_val_loss = (
                    LAMBDA_RECON * loss_recon +
                    LAMBDA_INFO_NCE * loss_infonce +
                    LAMBDA_MARGIN * loss_margin +
                    LAMBDA_DISENTANGLE * disent_loss
                )

                val_loss_epoch += total_val_loss.item()
                val_recon_loss_epoch += loss_recon.item()
                val_batches += 1

        # Average validation losses
        avg_val_loss = val_loss_epoch / val_batches
        avg_val_recon_loss = val_recon_loss_epoch / val_batches
        val_losses.append(avg_val_loss)
        
        # Print epoch summary
        print(f"\nEpoch {epoch+1}/{EPOCHS} Summary:")
        print(f"    Train Loss: {avg_train_loss:.4f} | Train Recon: {avg_train_recon_loss:.4f}")
        print(f"    Val Loss:   {avg_val_loss:.4f} | Val Recon:   {avg_val_recon_loss:.4f}")
        
        # Update progress bar with current losses
        epoch_pbar.set_postfix({
            'Train_Loss': f'{avg_train_loss:.4f}',
            'Val_Loss': f'{avg_val_loss:.4f}',
            'Best_Val': f'{best_val_loss:.4f}'
        })
        
        # Save best model based on validation loss
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            print(f"New best validation loss: {best_val_loss:.4f}. Saving model...")
            
            torch.save({
                'epoch': epoch,
                'style_encoder_state_dict': style_encoder.state_dict(),
                'content_encoder_state_dict': content_encoder.state_dict(),
                'decoder_state_dict': decoder.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_G_state_dict': optimizer_G.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
                'best_val_loss': best_val_loss,
                'train_losses': train_losses,
                'val_losses': val_losses,
            }, os.path.join(MODEL_SAVE_PATH, 'best_model.pth'))
        
        # Early stopping check (opzionale)
        if epoch > 10 and avg_val_loss > max(val_losses[-5:]):
            print("⚠️ Validation loss not improving. Consider early stopping.")
    
    print(f"\nTraining completed!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    
    return train_losses, val_losses

In [5]:
# Avvia il training e cattura le loss curves
train_losses, val_losses = train()

Training on device: cuda
✅ style_encoder initialized
✅ content_encoder initialized
✅ discriminator initialized
✅ decoder initialized
Training batches: 77
Validation batches: 10
Training batches: 77
Validation batches: 10


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


Epoch 1/50 - Batch 1/77 | D_loss: 1.7333 | G_loss: 41.8426 | Recon: 3.9903 | InfoNCE: 1.9390 | Margin: 0.0000 | Disentangle: 0.0072
Epoch 1/50 - Batch 2/77 | D_loss: 1.7329 | G_loss: 28.8424 | Recon: 2.6938 | InfoNCE: 1.9045 | Margin: 0.0000 | Disentangle: 0.0012
Epoch 1/50 - Batch 2/77 | D_loss: 1.7329 | G_loss: 28.8424 | Recon: 2.6938 | InfoNCE: 1.9045 | Margin: 0.0000 | Disentangle: 0.0012
Epoch 1/50 - Batch 3/77 | D_loss: 1.7326 | G_loss: 67.2175 | Recon: 6.5580 | InfoNCE: 1.6376 | Margin: 0.0000 | Disentangle: 0.0005
Epoch 1/50 - Batch 3/77 | D_loss: 1.7326 | G_loss: 67.2175 | Recon: 6.5580 | InfoNCE: 1.6376 | Margin: 0.0000 | Disentangle: 0.0005
Epoch 1/50 - Batch 4/77 | D_loss: 1.7293 | G_loss: 27.1641 | Recon: 2.5911 | InfoNCE: 1.2531 | Margin: 0.0000 | Disentangle: 0.0001
Epoch 1/50 - Batch 4/77 | D_loss: 1.7293 | G_loss: 27.1641 | Recon: 2.5911 | InfoNCE: 1.2531 | Margin: 0.0000 | Disentangle: 0.0001
Epoch 1/50 - Batch 5/77 | D_loss: 1.7246 | G_loss: 35.7803 | Recon: 3.4675 |

Epoch 1/50:   2%|▏         | 1/50 [00:57<47:12, 57.81s/epoch, Train_Loss=37.4376, Val_Loss=41.9694, Best_Val=inf]


Epoch 1/50 Summary:
    Train Loss: 37.4376 | Train Recon: 3.6004
    Val Loss:   41.9694 | Val Recon:   4.0653
New best validation loss: 41.9694. Saving model...


Epoch 2/50:   2%|▏         | 1/50 [00:57<47:12, 57.81s/epoch, Train_Loss=37.4376, Val_Loss=41.9694, Best_Val=inf]

Epoch 2/50 - Batch 1/77 | D_loss: 0.9301 | G_loss: 27.5369 | Recon: 2.6400 | InfoNCE: 1.1364 | Margin: 0.0000 | Disentangle: 0.0002
Epoch 2/50 - Batch 2/77 | D_loss: 0.8717 | G_loss: 30.3047 | Recon: 2.9203 | InfoNCE: 1.1016 | Margin: 0.0000 | Disentangle: 0.0001
Epoch 2/50 - Batch 2/77 | D_loss: 0.8717 | G_loss: 30.3047 | Recon: 2.9203 | InfoNCE: 1.1016 | Margin: 0.0000 | Disentangle: 0.0001
Epoch 2/50 - Batch 3/77 | D_loss: 0.8649 | G_loss: 58.1938 | Recon: 5.7094 | InfoNCE: 1.0998 | Margin: 0.0000 | Disentangle: 0.0001
Epoch 2/50 - Batch 3/77 | D_loss: 0.8649 | G_loss: 58.1938 | Recon: 5.7094 | InfoNCE: 1.0998 | Margin: 0.0000 | Disentangle: 0.0001
Epoch 2/50 - Batch 4/77 | D_loss: 0.8638 | G_loss: 33.1460 | Recon: 3.2045 | InfoNCE: 1.1008 | Margin: 0.0000 | Disentangle: 0.0001
Epoch 2/50 - Batch 4/77 | D_loss: 0.8638 | G_loss: 33.1460 | Recon: 3.2045 | InfoNCE: 1.1008 | Margin: 0.0000 | Disentangle: 0.0001
Epoch 2/50 - Batch 5/77 | D_loss: 0.8529 | G_loss: 103.2068 | Recon: 10.2107

Epoch 2/50:   2%|▏         | 1/50 [01:32<1:15:11, 92.08s/epoch, Train_Loss=37.4376, Val_Loss=41.9694, Best_Val=inf]



KeyboardInterrupt: 

In [1]:
import torch
import numpy as np
import random
import IPython.display as ipd
from utilityFunctions import inverse_STFT, plot_stft
from dataloader import get_dataloader

def play_batch(batch_size=4, plot_stft_flag=False):
    """
    Carica un batch randomico dal dataloader, stampa le label e lo strumento associato,
    riproduce l'audio ricostruito dalla STFT e opzionalmente mostra la STFT.
    """
    # Crea il dataloader (usa shuffle per randomizzare i batch)
    dataloader = get_dataloader(
        piano_dir="dataset/train/piano",
        violin_dir="dataset/train/violin",
        batch_size=batch_size,
        shuffle=True,
        stats_path="stats_stft_cqt.npz"
    )
    print(f"\nDataloader creato con {len(dataloader)} batch di dimensione {batch_size}")
          
    # Scegli un batch randomico
    num_batches = len(dataloader)
    batch_idx = random.randint(0, num_batches - 1)
    for i, (x, labels) in enumerate(dataloader):
        if i == batch_idx:
            break

    print(f"\n🎲 Batch randomico scelto: {batch_idx+1}/{num_batches}")
    print(f"Shape batch: {x.shape}, Shape labels: {labels.shape}")
    print(f"Labels batch: {labels.tolist()}")

    # Associa label a strumento
    for idx, label in enumerate(labels):
        strumento = "Piano" if label.item() == 0 else "Violin"
        print(f"Sample {idx}: Label={label.item()} ({strumento})")

    # Per ogni audio nel batch: ricostruisci e riproduci
    for idx in range(x.shape[0]):
        label = labels[idx].item()
        strumento = "Piano" if label == 0 else "Violin"
        # Prendi la prima finestra e il primo canale (shape: (2, T, F))
        stft_tensor = x[idx, 0]  # (2, T, F_stft+CQT)
        # Prendi solo la parte STFT (usa costante STFT_F)
        stft_part = stft_tensor[:, :, :513]  # (2, T, 513)
        # La funzione inverse_STFT si aspetta (2, T, F), quindi non serve permutare
        # Ricostruisci audio
        try:
            audio = inverse_STFT(stft_part)
            audio = audio / (torch.max(torch.abs(audio)) + 1e-8)
            print(f"\n🔊 Sample {idx} - {strumento}:")
            display(ipd.Audio(audio.cpu().numpy(), rate=22050))
        except Exception as e:
            print(f"Errore nella ricostruzione audio: {e}")

        # Plot STFT se richiesto
        if plot_stft_flag:
            print(f"STFT di Sample {idx} - {strumento}:")
            plot_stft(stft_part)

# Esempio di utilizzo:
play_batch(batch_size=8, plot_stft_flag=False)

✅ Loaded separate statistics:
  Piano: stats_stft_cqt_piano.npz
  Violin: stats_stft_cqt_violin.npz

Dataloader creato con 77 batch di dimensione 8

🎲 Batch randomico scelto: 72/77
Shape batch: torch.Size([8, 3, 2, 287, 597]), Shape labels: torch.Size([8])
Labels batch: [0, 0, 0, 0, 1, 1, 1, 1]
Sample 0: Label=0 (Piano)
Sample 1: Label=0 (Piano)
Sample 2: Label=0 (Piano)
Sample 3: Label=0 (Piano)
Sample 4: Label=1 (Violin)
Sample 5: Label=1 (Violin)
Sample 6: Label=1 (Violin)
Sample 7: Label=1 (Violin)

🔊 Sample 0 - Piano:



🔊 Sample 1 - Piano:



🔊 Sample 2 - Piano:



🔊 Sample 3 - Piano:



🔊 Sample 4 - Violin:



🔊 Sample 5 - Violin:



🔊 Sample 6 - Violin:



🔊 Sample 7 - Violin:


In [4]:
def count_trainable_parameters():
    """Conta i parametri allenabili di tutto il modello"""
    print("=" * 60)
    print("📊 ANALISI PARAMETRI ALLENABILI DEL MODELLO")
    print("=" * 60)
    
    # Crea tutti i modelli
    style_encoder = StyleEncoder(transformer_dim=TRANSFORMER_DIM)
    content_encoder = ContentEncoder(transformer_dim=TRANSFORMER_DIM)
    decoder = Decoder(
        d_model=TRANSFORMER_DIM,
        nhead=4,
        num_layers=4,
        dim_feedforward=TRANSFORMER_DIM * 2,
        dropout=0.1
    )
    discriminator = Discriminator(input_dim=TRANSFORMER_DIM)
    
    models = [
        ("Style Encoder", style_encoder),
        ("Content Encoder", content_encoder),
        ("Decoder", decoder),
        ("Discriminator", discriminator)
    ]
    
    total_params = 0
    total_trainable_params = 0
    
    for name, model in models:
        # Conta parametri totali
        model_params = sum(p.numel() for p in model.parameters())
        # Conta parametri allenabili
        model_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        
        total_params += model_params
        total_trainable_params += model_trainable_params
        
        # Dimensione in MB
        param_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**2)
        
        print(f"{name:>15}: {model_params:>10,} parametri ({model_trainable_params:>10,} allenabili) | {param_size_mb:>6.1f} MB")
    
    print("-" * 60)
    print(f"{'TOTALE':>15}: {total_params:>10,} parametri ({total_trainable_params:>10,} allenabili)")
    
    # Calcola dimensioni totali
    total_size_mb = 0
    for name, model in models:
        total_size_mb += sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**2)
    
    print(f"{'DIMENSIONE':>15}: {total_size_mb:>47.1f} MB")
    
    # Stima memoria GPU per training
    # Modello + gradients + optimizer states (Adam ~ 2x params) + attivazioni
    estimated_gpu_memory = total_size_mb * 4  # rough estimate
    print(f"{'MEM. STIMATA':>15}: {estimated_gpu_memory:>44.0f} MB (solo modello)")
    
    # Analisi per batch size
    print("\n" + "=" * 60)
    print("💾 ANALISI MEMORIA PER BATCH SIZE")
    print("=" * 60)
    
    batch_sizes = [4, 8, 16, 32]
    for bs in batch_sizes:
        # Stima memoria per batch (approssimativa)
        # Input: (B, S, 2, T, F) * 4 bytes per float32
        input_size_mb = bs * NUM_FRAMES * 2 * STFT_T * (STFT_F + CQT_F) * 4 / (1024**2)
        
        # Memoria totale stimata
        total_mem_mb = estimated_gpu_memory + input_size_mb * 2  # input + gradients
        
        status = "✅" if total_mem_mb < 8000 else "⚠️" if total_mem_mb < 12000 else "❌"
        print(f"Batch size {bs:>2}: {total_mem_mb:>6.0f} MB {status}")
    
    print("\n" + "=" * 60)
    print("🎯 CONFIGURAZIONE RACCOMANDATA")
    print("=" * 60)
    
    if total_trainable_params < 50_000_000:  # 50M
        print("✅ Modello di dimensioni ragionevoli per il training")
    elif total_trainable_params < 100_000_000:  # 100M
        print("⚠️ Modello grande - considerare mixed precision training")
    else:
        print("❌ Modello molto grande - ridurre dimensioni o usare tecniche avanzate")
    
    # Raccomandazioni finali
    print(f"\n🚀 Per GPU con 12GB: Batch size raccomandato = {8 if estimated_gpu_memory < 4000 else 4}")
    print(f"🔧 Per GPU con 8GB:  Batch size raccomandato = {4 if estimated_gpu_memory < 3000 else 2}")
    
    return total_trainable_params, total_size_mb

# Esegui l'analisi
total_params, model_size = count_trainable_parameters()

📊 ANALISI PARAMETRI ALLENABILI DEL MODELLO
  Style Encoder: 12,905,312 parametri (12,905,312 allenabili) |   49.2 MB
Content Encoder:  4,450,480 parametri ( 4,450,480 allenabili) |   17.0 MB
        Decoder:  3,681,515 parametri ( 3,681,515 allenabili) |   14.0 MB
  Discriminator:     49,666 parametri (    49,666 allenabili) |    0.2 MB
------------------------------------------------------------
         TOTALE: 21,086,973 parametri (21,086,973 allenabili)
     DIMENSIONE:                                            80.4 MB
   MEM. STIMATA:                                          322 MB (solo modello)

💾 ANALISI MEMORIA PER BATCH SIZE
Batch size  4:    364 MB ✅
Batch size  8:    405 MB ✅
Batch size 16:    489 MB ✅
Batch size 32:    656 MB ✅

🎯 CONFIGURAZIONE RACCOMANDATA
✅ Modello di dimensioni ragionevoli per il training

🚀 Per GPU con 12GB: Batch size raccomandato = 8
🔧 Per GPU con 8GB:  Batch size raccomandato = 4
