In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import os

# models
from style_encoder import StyleEncoder, initialize_weights
from content_encoder import ContentEncoder
from discriminator import Discriminator
from new_decoder import Decoder, compute_comprehensive_loss  # Nuovo decoder dinamico
from losses import infoNCE_loss, margin_loss, adversarial_loss, disentanglement_loss
from dummy_dataloader import get_dummy_dataloader

# device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

# hyperparameters
EPOCHS = 50
BATCH_SIZE = 8
LR_GEN = 1e-4  # Learning rate per Encoders + Decoder
LR_DISC = 1e-4  # Learning rate per Discriminator
TRANSFORMER_DIM = 256
NUM_FRAMES = 4 # S (ora solo per riferimento, il decoder è dinamico)
STFT_T, STFT_F = 287, 513 # Dimensioni STFT (assumendo n_fft=1024)
CQT_T, CQT_F = 287, 84   # Dimensioni CQT

# loss weights
LAMBDA_RECON = 10.0
LAMBDA_INFO_NCE = 1.0
LAMBDA_MARGIN = 1.0
LAMBDA_DISENTANGLE = 0.1
LAMBDA_ADV_GEN = 0.5 # Peso per la loss avversaria del generatore

# Pesi per le loss comprehensive del decoder
LAMBDA_TEMPORAL = 0.3
LAMBDA_PHASE = 0.2
LAMBDA_SPECTRAL = 0.1
LAMBDA_CONSISTENCY = 0.1

MODEL_SAVE_PATH = "./saved_models"
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

Using device: cuda
GPU: NVIDIA GeForce RTX 3080
GPU Memory: 10.0 GB


aggiungere data augmentation

In [11]:
def train():
    print(f"Training on device: {device}")

    # load models
    style_encoder = StyleEncoder(transformer_dim=TRANSFORMER_DIM).to(device)
    content_encoder = ContentEncoder(transformer_dim=TRANSFORMER_DIM).to(device)

    # decoder
    decoder = Decoder(
        d_model=TRANSFORMER_DIM,
        nhead=4,
        num_layers=4,  # Ridotto per Colab
        dim_feedforward=TRANSFORMER_DIM * 2,
        dropout=0.1
    ).to(device)
    
    discriminator = Discriminator(input_dim=TRANSFORMER_DIM).to(device)

    # initialize weights
    initialize_weights(style_encoder)
    initialize_weights(content_encoder)
    initialize_weights(discriminator)
    # decoder già ha la sua inizializzazione

    # optimizer for generators (style encoder, content encoder, decoder)
    optimizer_G = optim.Adam(
        list(style_encoder.parameters()) + list(content_encoder.parameters()) + list(decoder.parameters()),
        lr=LR_GEN, betas=(0.5, 0.999)
    )
    
    # optimizer for discriminator
    optimizer_D = optim.Adam(discriminator.parameters(), lr=LR_DISC, betas=(0.5, 0.999))

    dataloader = get_dummy_dataloader(batch_size=BATCH_SIZE)
    
    # loss function for reconstruction
    def recon_loss_fn(output, target):
        loss_dict = compute_comprehensive_loss(
            output, target, 
            lambda_temporal=LAMBDA_TEMPORAL,
            lambda_phase=LAMBDA_PHASE,
            lambda_spectral=LAMBDA_SPECTRAL,
            lambda_consistency=LAMBDA_CONSISTENCY
        )
        return loss_dict['total_loss'], loss_dict
    
    best_loss = float('inf')

    # train loop
    for epoch in range(EPOCHS):
        
        style_encoder.train()
        content_encoder.train()
        decoder.train()
        discriminator.train()
        
        for i, (x, labels) in enumerate(tqdm(dataloader, unit="batch", desc=f"Epoch {epoch+1}/{EPOCHS}")):
            x, labels = x.to(device), labels.to(device) # x: (B, S, 2, T, F)
            stft_part = x[:, :, :, :, :STFT_F]  # STFT part

            # ================================================================== #
            #                             Discriminator                          #
            # ================================================================== #
            optimizer_D.zero_grad()
            
            # compute embeddings with no_grad to avoid backpropagation through the encoders
            with torch.no_grad():
                style_emb, class_emb = style_encoder(x, labels)
                content_emb = content_encoder(x)
            
            # adversarial loss for the discriminator
            discriminator_loss, _ = adversarial_loss(style_emb.detach(), class_emb.detach(), 
                                                     content_emb.detach(), discriminator, labels, 
                                                     compute_for_discriminator=True)
            
            discriminator_loss.backward()
            optimizer_D.step()


            # ================================================================== #
            #               Generators (Style Encoder, Content Encoder)          #
            # ================================================================== #
            optimizer_G.zero_grad()

            # forward pass
            style_emb, class_emb = style_encoder(x, labels)
            content_emb = content_encoder(x)

            # adversarial loss for the generator
            _, adv_generator_loss = adversarial_loss(style_emb, class_emb, content_emb, discriminator, labels,
                                                 compute_for_discriminator=False)

            # disentanglement loss
            disent_loss = disentanglement_loss(style_emb, content_emb.mean(dim=1), use_hsic=True)

            if len(torch.unique(labels)) > 1:
                # contrastive losses
                loss_infonce = infoNCE_loss(style_emb, labels)
                loss_margin = margin_loss(class_emb)
            else:
                raise ValueError("Labels must contain at least two unique classes for contrastive losses.")

            # reconstruction loss
            reconstructed_spec = decoder(content_emb, style_emb, y=stft_part)  # y=x per teacher forcing
            loss_recon, loss_dict = recon_loss_fn(reconstructed_spec, stft_part)
            
            # total generator loss
            total_gen_loss = (
                LAMBDA_RECON * loss_recon +
                LAMBDA_INFO_NCE * loss_infonce +
                LAMBDA_MARGIN * loss_margin +
                LAMBDA_DISENTANGLE * disent_loss +
                LAMBDA_ADV_GEN * adv_generator_loss
            )

            total_gen_loss.backward()
            optimizer_G.step()

            # loss logging 
            if i % 1 == 0:
                # Converte tensor scalari a float per evitare errori
                temporal_loss_val = loss_dict['temporal_loss'].item() if loss_dict['temporal_loss'].numel() > 0 else 0.0
                spectral_loss_val = loss_dict['spectral_loss'].item() if loss_dict['spectral_loss'].numel() > 0 else 0.0
                
                tqdm.write(
                    f"Batch {i}/{len(dataloader)} | "
                    f"Discriminator loss: {discriminator_loss.item():.4f} | "
                    f"Total Generator loss: {total_gen_loss.item():.4f} | "
                    f"Reconstruction loss: {loss_recon.item():.4f} | "
                    f"  - MSE: {loss_dict['mse_loss'].item():.4f} | "
                    f"  - Magnitude: {loss_dict['mag_loss'].item():.4f} | "
                    f"  - Phase: {loss_dict['phase_loss'].item():.4f} | "
                    f"  - Temporal: {temporal_loss_val:.4f} | "
                    f"  - Spectral: {spectral_loss_val:.4f} | "
                    f"Adversary Generator loss: {adv_generator_loss.item():.4f} | "
                    f"Disentanglement loss: {disent_loss.item():.4f}"
                )
                
        # saving best model
        current_recon_loss = loss_recon.item()
        if current_recon_loss < best_loss:
            best_loss = current_recon_loss
            print(f"\nNew best reconstruction loss: {best_loss:.4f}. Saving model...")
            
            torch.save({
                'epoch': epoch,
                'style_encoder_state_dict': style_encoder.state_dict(),
                'content_encoder_state_dict': content_encoder.state_dict(),
                'decoder_state_dict': decoder.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_G_state_dict': optimizer_G.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
                'best_loss': best_loss,
            }, os.path.join(MODEL_SAVE_PATH, 'best_model.pth'))

In [12]:
train()

Training on device: cuda


Epoch 1/50:   0%|          | 0/13 [00:00<?, ?batch/s]

Norm of S: 20.0259
Norm of C: 11.9199
Norm of K: 5.7436
Norm of L: 7.0742


Epoch 1/50:   8%|▊         | 1/13 [00:11<02:18, 11.51s/batch]

Batch 0/13 | Discriminator loss: 1.9618 | Total Generator loss: 36.2249 | Reconstruction loss: 3.4588 |   - MSE: 1.0003 |   - Magnitude: 2.0003 |   - Phase: 3.2910 |   - Temporal: 2.0003 |   - Spectral: 2.0010 | Adversary Generator loss: -0.6501 | Disentanglement loss: 0.0063
Norm of S: 19.5589
Norm of C: 11.3297
Norm of K: 5.8023
Norm of L: 7.1444
Norm of S: 19.5589
Norm of C: 11.3297
Norm of K: 5.8023
Norm of L: 7.1444


Epoch 1/50:  15%|█▌        | 2/13 [00:16<01:23,  7.58s/batch]

Batch 1/13 | Discriminator loss: 2.2692 | Total Generator loss: nan | Reconstruction loss: nan |   - MSE: nan |   - Magnitude: nan |   - Phase: nan |   - Temporal: nan |   - Spectral: nan | Adversary Generator loss: -0.3872 | Disentanglement loss: 0.0058
Norm of S: 21.3380
Norm of C: 11.0624
Norm of K: 5.9204
Norm of L: 7.3465
Norm of S: 21.3380
Norm of C: 11.0624
Norm of K: 5.9204
Norm of L: 7.3465


Epoch 1/50:  23%|██▎       | 3/13 [00:21<01:05,  6.59s/batch]

Batch 2/13 | Discriminator loss: 1.8719 | Total Generator loss: nan | Reconstruction loss: nan |   - MSE: nan |   - Magnitude: nan |   - Phase: nan |   - Temporal: nan |   - Spectral: nan | Adversary Generator loss: -0.6505 | Disentanglement loss: 0.0041
Norm of S: 19.5825
Norm of C: 10.9957
Norm of K: 5.6172
Norm of L: 7.1132
Norm of S: 19.5825
Norm of C: 10.9957
Norm of K: 5.6172
Norm of L: 7.1132


Epoch 1/50:  31%|███       | 4/13 [00:26<00:53,  5.89s/batch]

Batch 3/13 | Discriminator loss: 1.5510 | Total Generator loss: nan | Reconstruction loss: nan |   - MSE: nan |   - Magnitude: nan |   - Phase: nan |   - Temporal: nan |   - Spectral: nan | Adversary Generator loss: -0.6140 | Disentanglement loss: 0.0064
Norm of S: 19.6506
Norm of C: 10.6715
Norm of K: 5.6612
Norm of L: 7.1873
Norm of S: 19.6506
Norm of C: 10.6715
Norm of K: 5.6612
Norm of L: 7.1873


Epoch 1/50:  38%|███▊      | 5/13 [00:31<00:45,  5.63s/batch]

Batch 4/13 | Discriminator loss: 1.7785 | Total Generator loss: nan | Reconstruction loss: nan |   - MSE: nan |   - Magnitude: nan |   - Phase: nan |   - Temporal: nan |   - Spectral: nan | Adversary Generator loss: -0.6813 | Disentanglement loss: 0.0057
Norm of S: 19.3497
Norm of C: 10.0267
Norm of K: 5.6401
Norm of L: 7.2467
Norm of S: 19.3497
Norm of C: 10.0267
Norm of K: 5.6401
Norm of L: 7.2467


Epoch 1/50:  46%|████▌     | 6/13 [00:36<00:38,  5.47s/batch]

Batch 5/13 | Discriminator loss: 1.9700 | Total Generator loss: nan | Reconstruction loss: nan |   - MSE: nan |   - Magnitude: nan |   - Phase: nan |   - Temporal: nan |   - Spectral: nan | Adversary Generator loss: -0.6590 | Disentanglement loss: 0.0054
Norm of S: 18.9831
Norm of C: 10.1108
Norm of K: 5.7365
Norm of L: 7.2443
Norm of S: 18.9831
Norm of C: 10.1108
Norm of K: 5.7365
Norm of L: 7.2443


Epoch 1/50:  54%|█████▍    | 7/13 [00:42<00:32,  5.41s/batch]

Batch 6/13 | Discriminator loss: 1.8440 | Total Generator loss: nan | Reconstruction loss: nan |   - MSE: nan |   - Magnitude: nan |   - Phase: nan |   - Temporal: nan |   - Spectral: nan | Adversary Generator loss: -0.6850 | Disentanglement loss: 0.0052
Norm of S: 18.4729
Norm of C: 10.2994
Norm of K: 5.7524
Norm of L: 7.1844
Norm of S: 18.4729
Norm of C: 10.2994
Norm of K: 5.7524
Norm of L: 7.1844


Epoch 1/50:  54%|█████▍    | 7/13 [00:45<00:38,  6.44s/batch]



KeyboardInterrupt: 

In [9]:
def count_trainable_parameters():
    """Conta i parametri allenabili di tutto il modello"""
    print("=" * 60)
    print("📊 ANALISI PARAMETRI ALLENABILI DEL MODELLO")
    print("=" * 60)
    
    # Crea tutti i modelli
    style_encoder = StyleEncoder(transformer_dim=TRANSFORMER_DIM)
    content_encoder = ContentEncoder(transformer_dim=TRANSFORMER_DIM)
    decoder = Decoder(
        d_model=TRANSFORMER_DIM,
        nhead=4,
        num_layers=4,
        dim_feedforward=TRANSFORMER_DIM * 2,
        dropout=0.1
    )
    discriminator = Discriminator(input_dim=TRANSFORMER_DIM)
    
    models = [
        ("Style Encoder", style_encoder),
        ("Content Encoder", content_encoder),
        ("Decoder", decoder),
        ("Discriminator", discriminator)
    ]
    
    total_params = 0
    total_trainable_params = 0
    
    for name, model in models:
        # Conta parametri totali
        model_params = sum(p.numel() for p in model.parameters())
        # Conta parametri allenabili
        model_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        
        total_params += model_params
        total_trainable_params += model_trainable_params
        
        # Dimensione in MB
        param_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**2)
        
        print(f"{name:>15}: {model_params:>10,} parametri ({model_trainable_params:>10,} allenabili) | {param_size_mb:>6.1f} MB")
    
    print("-" * 60)
    print(f"{'TOTALE':>15}: {total_params:>10,} parametri ({total_trainable_params:>10,} allenabili)")
    
    # Calcola dimensioni totali
    total_size_mb = 0
    for name, model in models:
        total_size_mb += sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**2)
    
    print(f"{'DIMENSIONE':>15}: {total_size_mb:>47.1f} MB")
    
    # Stima memoria GPU per training
    # Modello + gradients + optimizer states (Adam ~ 2x params) + attivazioni
    estimated_gpu_memory = total_size_mb * 4  # rough estimate
    print(f"{'MEM. STIMATA':>15}: {estimated_gpu_memory:>44.0f} MB (solo modello)")
    
    # Analisi per batch size
    print("\n" + "=" * 60)
    print("💾 ANALISI MEMORIA PER BATCH SIZE")
    print("=" * 60)
    
    batch_sizes = [4, 8, 16, 32]
    for bs in batch_sizes:
        # Stima memoria per batch (approssimativa)
        # Input: (B, S, 2, T, F) * 4 bytes per float32
        input_size_mb = bs * NUM_FRAMES * 2 * STFT_T * (STFT_F + CQT_F) * 4 / (1024**2)
        
        # Memoria totale stimata
        total_mem_mb = estimated_gpu_memory + input_size_mb * 2  # input + gradients
        
        status = "✅" if total_mem_mb < 8000 else "⚠️" if total_mem_mb < 12000 else "❌"
        print(f"Batch size {bs:>2}: {total_mem_mb:>6.0f} MB {status}")
    
    print("\n" + "=" * 60)
    print("🎯 CONFIGURAZIONE RACCOMANDATA")
    print("=" * 60)
    
    if total_trainable_params < 50_000_000:  # 50M
        print("✅ Modello di dimensioni ragionevoli per il training")
    elif total_trainable_params < 100_000_000:  # 100M
        print("⚠️ Modello grande - considerare mixed precision training")
    else:
        print("❌ Modello molto grande - ridurre dimensioni o usare tecniche avanzate")
    
    # Raccomandazioni finali
    print(f"\n🚀 Per GPU con 12GB: Batch size raccomandato = {8 if estimated_gpu_memory < 4000 else 4}")
    print(f"🔧 Per GPU con 8GB:  Batch size raccomandato = {4 if estimated_gpu_memory < 3000 else 2}")
    
    return total_trainable_params, total_size_mb

# Esegui l'analisi
total_params, model_size = count_trainable_parameters()

📊 ANALISI PARAMETRI ALLENABILI DEL MODELLO
  Style Encoder: 12,905,312 parametri (12,905,312 allenabili) |   49.2 MB
Content Encoder:  4,450,480 parametri ( 4,450,480 allenabili) |   17.0 MB
        Decoder:  3,681,515 parametri ( 3,681,515 allenabili) |   14.0 MB
  Discriminator:     49,666 parametri (    49,666 allenabili) |    0.2 MB
------------------------------------------------------------
         TOTALE: 21,086,973 parametri (21,086,973 allenabili)
     DIMENSIONE:                                            80.4 MB
   MEM. STIMATA:                                          322 MB (solo modello)

💾 ANALISI MEMORIA PER BATCH SIZE
Batch size  4:    364 MB ✅
Batch size  8:    405 MB ✅
Batch size 16:    489 MB ✅
Batch size 32:    656 MB ✅

🎯 CONFIGURAZIONE RACCOMANDATA
✅ Modello di dimensioni ragionevoli per il training

🚀 Per GPU con 12GB: Batch size raccomandato = 8
🔧 Per GPU con 8GB:  Batch size raccomandato = 4


In [None]:
# Funzioni di utilità per Colab
import psutil
import gc

def monitor_resources():
    """Monitora le risorse di sistema"""
    # CPU
    cpu_percent = psutil.cpu_percent()
    
    # RAM
    ram = psutil.virtual_memory()
    ram_percent = ram.percent
    
    # GPU
    if torch.cuda.is_available():
        gpu_memory = torch.cuda.memory_allocated() / 1024**2
        gpu_cached = torch.cuda.memory_reserved() / 1024**2
        
        print(f"CPU: {cpu_percent:.1f}% | RAM: {ram_percent:.1f}% | GPU: {gpu_memory:.0f}MB/{gpu_cached:.0f}MB")
        
        # Allarme se troppa memoria
        if gpu_memory > 8000:  # >8GB
            print("⚠️ ATTENZIONE: Memoria GPU alta!")
            return True
    return False

def clear_memory():
    """Libera memoria GPU"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()
        print("🧹 Memoria GPU liberata")

def get_model_size(model):
    """Calcola la dimensione del modello"""
    param_count = sum(p.numel() for p in model.parameters())
    param_size = sum(p.numel() * p.element_size() for p in model.parameters())
    buffer_size = sum(b.numel() * b.element_size() for b in model.buffers())
    size_mb = (param_size + buffer_size) / 1024**2
    return param_count, size_mb

# Test delle risorse
print("=== RESOURCE MONITORING ===")
monitor_resources()

# ✅ Crea un decoder dinamico di test (senza parametro S!)
test_decoder = DynamicDecoder(d_model=256, nhead=4, num_layers=3)
param_count, size_mb = get_model_size(test_decoder)
print(f"Dynamic Decoder: {param_count:,} parametri ({size_mb:.1f} MB)")

del test_decoder  # Libera memoria
clear_memory()

In [None]:
# Test del decoder dinamico prima del training
def test_dynamic_decoder():
    """Test del decoder dinamico con sequenze di lunghezze diverse"""
    print("🔧 Testing Dynamic Decoder...")
    
    # Parametri di test
    B, d_model = 4, 256
    
    # ✅ Crea decoder dinamico (SENZA parametro S!)
    decoder = DynamicDecoder(d_model=d_model, nhead=4, num_layers=2)
    
    # Test con sequenze di lunghezze diverse
    test_cases = [
        (2, "sequenza molto corta"),
        (4, "sequenza standard"), 
        (6, "sequenza lunga"),
        (8, "sequenza molto lunga")
    ]
    
    for S, description in test_cases:
        print(f"\n--- Test: {description} (S={S}) ---")
        
        # Input di test
        content_emb = torch.randn(B, S, d_model)  # [B, S, d_model] - S dinamico!
        class_emb = torch.randn(B, d_model)       # [B, d_model]
        y_target = torch.randn(B, S, 2, 287, 513)  # [B, S, 2, 287, 513] - S dinamico!
        
        print(f"Content embedding shape: {content_emb.shape}")
        print(f"Class embedding shape: {class_emb.shape}")
        print(f"Target shape: {y_target.shape}")
        
        # Test training mode
        decoder.train()
        try:
            output = decoder(content_emb, class_emb, y=y_target)
            print(f"✅ Training mode output shape: {output.shape}")
            assert output.shape == y_target.shape, f"Shape mismatch: {output.shape} vs {y_target.shape}"
        except Exception as e:
            print(f"❌ Training mode error: {e}")
            return False
        
        # Test inference mode con target_length specificato
        decoder.eval()
        try:
            with torch.no_grad():
                output = decoder(content_emb, class_emb, target_length=S)
                print(f"✅ Inference mode output shape: {output.shape}")
                assert output.shape == (B, S, 2, 287, 513), f"Inference shape mismatch: {output.shape}"
        except Exception as e:
            print(f"❌ Inference mode error: {e}")
            return False
        
        # Test inference mode senza target_length (usa default)
        try:
            with torch.no_grad():
                output_auto = decoder(content_emb, class_emb)
                print(f"✅ Auto-length inference output shape: {output_auto.shape}")
        except Exception as e:
            print(f"❌ Auto-length inference error: {e}")
            return False
        
        # Test loss
        try:
            loss_dict = compute_comprehensive_loss(output, y_target)
            print(f"✅ Loss computation successful: {loss_dict['total_loss'].item():.4f}")
        except Exception as e:
            print(f"❌ Loss computation error: {e}")
            return False
    
    print("\n🎉 Dynamic Decoder test successful! Decoder can handle variable sequence lengths!")
    return True

# Esegui test
print("Testing the new Dynamic Decoder...")
if test_dynamic_decoder():
    print("✅ Dynamic Decoder pronto per il training!")
else:
    print("❌ Dynamic Decoder ha problemi, controlla gli errori.")

In [None]:
# Debug utilities per il training dinamico
def debug_shapes(x, content_emb, style_emb, class_emb, step="forward"):
    """Debug delle dimensioni durante il training"""
    print(f"\n=== DEBUG SHAPES - {step} ===")
    print(f"Input x shape: {x.shape}")
    print(f"Content embedding shape: {content_emb.shape}")
    print(f"Style embedding shape: {style_emb.shape}")
    print(f"Class embedding shape: {class_emb.shape}")
    
    # Verifica dimensioni attese
    B, S, C, H, W = x.shape
    B_c, S_c, D_c = content_emb.shape
    B_s, D_s = style_emb.shape
    B_cl, D_cl = class_emb.shape
    
    print(f"Batch size: {B}")
    print(f"Sequence length: {S}")
    print(f"STFT dims: {C}x{H}x{W}")
    
    # Verifica coerenza
    assert B == B_c == B_s == B_cl, f"Batch size mismatch: {B}, {B_c}, {B_s}, {B_cl}"
    assert S == S_c, f"Sequence length mismatch: {S}, {S_c}"
    assert D_c == D_s == D_cl, f"Embedding dimension mismatch: {D_c}, {D_s}, {D_cl}"
    
    print("✅ All dimensions are consistent!")
    return True

# Versione safe del training con debug del decoder dinamico
def safe_debug_train():
    """Training con debug delle dimensioni per decoder dinamico"""
    print(f"Debug training on device: {device}")
    
    # ✅ Inizializza modelli con decoder dinamico
    style_encoder = StyleEncoder(transformer_dim=TRANSFORMER_DIM).to(device)
    content_encoder = ContentEncoder(transformer_dim=TRANSFORMER_DIM).to(device)
    decoder = DynamicDecoder(d_model=TRANSFORMER_DIM, nhead=4, num_layers=2).to(device)  # Senza S!
    
    # Test con un batch
    dataloader = get_dummy_dataloader(batch_size=2)  # Batch piccolo per debug
    
    print("\n🔍 Testing forward pass with Dynamic Decoder...")
    
    for i, (x, labels) in enumerate(dataloader):
        x, labels = x.to(device), labels.to(device)
        
        print(f"\nBatch {i+1} shapes:")
        print(f"  x: {x.shape}")
        print(f"  labels: {labels.shape}")
        
        # Forward pass encoders
        style_emb, class_emb = style_encoder(x, labels)
        content_emb = content_encoder(x)
        
        # Debug shapes
        debug_shapes(x, content_emb, style_emb, class_emb)
        
        # Test decoder dinamico
        print(f"\n🧪 Testing Dynamic Decoder...")
        try:
            # Training mode con teacher forcing
            decoder.train()
            reconstructed = decoder(content_emb, style_emb, y=x)
            print(f"✅ Training mode output shape: {reconstructed.shape}")
            print(f"Expected shape: {x.shape}")
            assert reconstructed.shape == x.shape, f"Shape mismatch: {reconstructed.shape} vs {x.shape}"
            
            # Inference mode
            decoder.eval()
            with torch.no_grad():
                reconstructed_inf = decoder(content_emb, style_emb)
                print(f"✅ Inference mode output shape: {reconstructed_inf.shape}")
            
            # Test loss
            loss_dict = compute_comprehensive_loss(reconstructed, x)
            print(f"✅ Loss computation: {loss_dict['total_loss'].item():.4f}")
            
            print("✅ Dynamic Decoder test successful!")
            
        except Exception as e:
            print(f"❌ Dynamic Decoder error: {e}")
            import traceback
            traceback.print_exc()
            return False
        
        # Test solo il primo batch
        break
    
    print("\n🎉 Debug test completed successfully!")
    print("🚀 Dynamic Decoder is ready for full training!")
    return True

# Test delle dimensioni del modello dinamico
def check_dynamic_model_size():
    """Controlla le dimensioni del modello dinamico"""
    print("\n📊 Checking Dynamic Model Sizes...")
    
    # Crea modelli
    style_encoder = StyleEncoder(transformer_dim=TRANSFORMER_DIM)
    content_encoder = ContentEncoder(transformer_dim=TRANSFORMER_DIM)
    decoder = DynamicDecoder(d_model=TRANSFORMER_DIM, nhead=4, num_layers=3)
    discriminator = Discriminator(input_dim=TRANSFORMER_DIM)
    
    models = [
        ("Style Encoder", style_encoder),
        ("Content Encoder", content_encoder), 
        ("Dynamic Decoder", decoder),
        ("Discriminator", discriminator)
    ]
    
    total_params = 0
    total_size = 0
    
    for name, model in models:
        param_count, size_mb = get_model_size(model)
        total_params += param_count
        total_size += size_mb
        print(f"{name}: {param_count:,} parametri ({size_mb:.1f} MB)")
    
    print(f"\n📈 TOTALE: {total_params:,} parametri ({total_size:.1f} MB)")
    
    # Stima memoria per batch_size=8
    estimated_memory = total_size * 3 + (8 * 4 * 2 * 287 * 513 * 4) / (1024**2)  # modello + gradients + data
    print(f"💾 Memoria stimata (batch_size=8): ~{estimated_memory:.0f} MB")
    print(f"🖥️ Using device: {device}")
    
    if estimated_memory < 10000:  # <10GB
        print("✅ Dovrebbe funzionare su Colab Free!")
    else:
        print("⚠️ Potrebbe essere troppo per Colab Free")

# Esegui controlli
check_dynamic_model_size()

# Uncomment to run debug test
# safe_debug_train()

In [None]:
# Training con monitoraggio delle risorse
def safe_train():
    """Training con monitoraggio delle risorse per Colab"""
    try:
        print(f"🚀 Avvio training con monitoraggio risorse su {device}...")
        monitor_resources()
        
        # Mostra info GPU se disponibile
        if torch.cuda.is_available():
            print(f"💾 GPU Memory before training: {torch.cuda.memory_allocated()/1024**2:.1f}MB / {torch.cuda.memory_reserved()/1024**2:.1f}MB")
        
        # Avvia il training
        train()
        
    except RuntimeError as e:
        if "out of memory" in str(e):
            print("💥 ERRORE: Memoria GPU insufficiente!")
            print("Prova a ridurre:")
            print("- BATCH_SIZE (attualmente {})".format(BATCH_SIZE))
            print("- NUM_FRAMES (attualmente {})".format(NUM_FRAMES))
            print("- TRANSFORMER_DIM (attualmente {})".format(TRANSFORMER_DIM))
            clear_memory()
        else:
            print(f"❌ Errore durante il training: {e}")
    except KeyboardInterrupt:
        print("⏹️ Training interrotto dall'utente")
        clear_memory()
    except Exception as e:
        print(f"❌ Errore imprevisto: {e}")
        clear_memory()
    finally:
        print(f"🏁 Training terminato su {device}")
        monitor_resources()

# Uncomment to run training
# safe_train()