In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import os
import matplotlib.pyplot as plt

from style_encoder import StyleEncoder
from content_encoder import ContentEncoder
from discriminator import Discriminator
from new_decoder import Decoder, compute_comprehensive_loss
from losses import (infoNCE_loss, margin_loss, adversarial_loss, 
                   disentanglement_loss)
from Dataloader import get_dataloader

# Configurazione dei dispositivi
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [2]:
config = {
    # Architettura
    "style_dim": 256,
    "content_dim": 256,
    "transformer_heads": 4,
    "transformer_layers": 4,
    "cnn_channels": [16, 32, 64, 128, 256],
    
    # Training
    "epochs": 100,
    "batch_size": 16,
    "lr": 1e-4,
    "beta1": 0.5,
    "beta2": 0.999,
    
    # Pesi delle loss
    "lambda_adv_disc": 1.0,
    "lambda_adv_gen": 0.1,
    "lambda_disent": 0.5,
    "lambda_cont": 0.5,
    "lambda_margin": 0.2,
    "lambda_recon": 10.0,
    
    # Percorsi dati
    "piano_dir": "dataset/train/piano",
    "violin_dir": "dataset/train/violin",
    "stats_path": "stats_stft_cqt.npz",
    
    # Salvataggio
    "save_dir": "checkpoints",
    "save_interval": 5,
    
    # Training strategy
    "discriminator_steps": 1,  # Numero di step discriminator per step generator
    "generator_steps": 1,      # Numero di step generator per step discriminator
}

# Creazione directory salvataggio
os.makedirs(config["save_dir"], exist_ok=True)

In [3]:
style_encoder = StyleEncoder(
    cnn_out_dim=config["style_dim"],
    transformer_dim=config["style_dim"],
    num_heads=config["transformer_heads"],
    num_layers=config["transformer_layers"]
).to(device)

content_encoder = ContentEncoder(
    cnn_out_dim=config["content_dim"],
    transformer_dim=config["content_dim"],
    num_heads=config["transformer_heads"],
    num_layers=config["transformer_layers"],
    channels_list=config["cnn_channels"]
).to(device)

discriminator = Discriminator(
    input_dim=config["style_dim"],
    hidden_dim=128
).to(device)

decoder = Decoder(
    d_model=config["style_dim"],
    nhead=config["transformer_heads"],
    num_layers=config["transformer_layers"]
).to(device)

# Inizializzazione pesi
def init_weights(m):
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

style_encoder.apply(init_weights)
content_encoder.apply(init_weights)
discriminator.apply(init_weights)
decoder.apply(init_weights)

Decoder(
  (conv_encoder): Sequential(
    (0): Conv2d(2, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
    (12): AdaptiveAvgPool2d(output_size=(32, 16))
  )
  (spatial_projection): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runnin

In [4]:
# 3. Ottimizzatori
# =================
# Ottimizzatori separati per gruppi di modelli
optimizer_G = optim.Adam(
    list(style_encoder.parameters()) + 
    list(content_encoder.parameters()) + 
    list(decoder.parameters()),
    lr=config["lr"],
    betas=(config["beta1"], config["beta2"])
)

optimizer_D = optim.Adam(
    discriminator.parameters(),
    lr=config["lr"],
    betas=(config["beta1"], config["beta2"])
)

train_loader = get_dataloader(
    piano_dir=config["piano_dir"],
    violin_dir=config["violin_dir"],
    batch_size=config["batch_size"],
    shuffle=True,
    stats_path=config["stats_path"]
)


In [5]:
def set_requires_grad(models, requires_grad):
    """Abilita/disabilita i gradienti per un insieme di modelli"""
    if not isinstance(models, list):
        models = [models]
    for model in models:
        for param in model.parameters():
            param.requires_grad = requires_grad

def save_checkpoint(epoch):
    """Salva i checkpoint dei modelli"""
    checkpoint = {
        'epoch': epoch,
        'style_encoder': style_encoder.state_dict(),
        'content_encoder': content_encoder.state_dict(),
        'discriminator': discriminator.state_dict(),
        'decoder': decoder.state_dict(),
        'optimizer_G': optimizer_G.state_dict(),
        'optimizer_D': optimizer_D.state_dict(),
        'config': config
    }
    torch.save(checkpoint, os.path.join(config["save_dir"], f"checkpoint_epoch_{epoch}.pth"))


In [6]:
def discriminator_training_step(x, labels):
    """Singolo step di training per il discriminatore"""
    
    # only discriminator gets gradients
    set_requires_grad(discriminator, True)
    set_requires_grad([style_encoder, content_encoder, decoder], False)
    
    # no gradient flow for generators
    with torch.no_grad():
        style_emb, class_emb = style_encoder(x, labels)
        content_emb = content_encoder(x)
        
        # Debug: controlla se ci sono NaN negli embeddings
    if torch.isnan(style_emb).any() or torch.isnan(class_emb).any() or torch.isnan(content_emb).any():
        print("WARNING: NaN detected in embeddings!")
        print(f"style_emb has NaN: {torch.isnan(style_emb).any()}")
        print(f"class_emb has NaN: {torch.isnan(class_emb).any()}")
        print(f"content_emb has NaN: {torch.isnan(content_emb).any()}")
        return float('nan')
    
    # compute loss with detach
    disc_loss, _ = adversarial_loss(
        style_emb.detach(), 
        class_emb.detach(),
        content_emb.detach(),
        discriminator,
        labels,
        compute_for_discriminator=True,
        lambda_content=config["lambda_adv_disc"]
    )
    
        # Debug: controlla se la loss è NaN
    if torch.isnan(disc_loss):
        print("WARNING: NaN in discriminator loss!")
        return float('nan')
    
    # discriminator backprop
    optimizer_D.zero_grad()
    disc_loss.backward()
    
    # Gradient clipping per il discriminatore
    torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1.0)
    
    optimizer_D.step()
    
    return disc_loss.item()



def generator_training_step(x, labels):
    """Singolo step di training per i generatori"""

    set_requires_grad([style_encoder, content_encoder, decoder], True)
    set_requires_grad(discriminator, False)
    

    style_emb, class_emb = style_encoder(x, labels)
    content_emb = content_encoder(x)
    
    if torch.isnan(style_emb).any() or torch.isnan(class_emb).any() or torch.isnan(content_emb).any():
        print("WARNING: NaN detected in generator embeddings!")
        return {
            'total_G': float('nan'),
            'adv_gen': float('nan'),
            'disent': float('nan'),
            'cont': float('nan'),
            'margin': float('nan'),
            'recon': float('nan')
        }
    
    # keep grads
    _, adv_gen_loss = adversarial_loss(
        style_emb, 
        class_emb,
        content_emb,
        discriminator,
        labels,
        compute_for_discriminator=False,
        lambda_content=config["lambda_adv_gen"]
    )
    
    # disentanglement loss
    disent_loss = disentanglement_loss(
        style_emb,
        content_emb.mean(dim=1),
        use_hsic=True
    )
    
    # contrastive loss
    cont_loss = infoNCE_loss(style_emb, labels)
    
    # margin loss 
    margin_loss_val = margin_loss(class_emb)
    
    # reconstruction loss
    stft_part = x[:, :, :, :, :513]
    recon_x = decoder(content_emb, style_emb, y=stft_part)
    recon_losses = compute_comprehensive_loss(recon_x, stft_part)
    recon_loss = recon_losses['total_loss']
    
    
    # Debug: controlla se qualche loss è NaN
    losses_to_check = [adv_gen_loss, disent_loss, cont_loss, margin_loss_val, recon_loss]
    loss_names = ['adv_gen', 'disent', 'cont', 'margin', 'recon']
    
    for loss_val, loss_name in zip(losses_to_check, loss_names):
        if torch.isnan(loss_val):
            print(f"WARNING: NaN in {loss_name} loss!")
            
    
    total_G_loss = (
        config["lambda_adv_gen"] * adv_gen_loss +
        config["lambda_disent"] * disent_loss +
        config["lambda_cont"] * cont_loss +
        config["lambda_margin"] * margin_loss_val +
        config["lambda_recon"] * recon_loss
    )
    
    # Backpropagation
    optimizer_G.zero_grad()
    total_G_loss.backward()
    
    # Gradient clipping
    torch.nn.utils.clip_grad_norm_(
        list(style_encoder.parameters()) + 
        list(content_encoder.parameters()) +
        list(decoder.parameters()),
        0.5
    )
    
    optimizer_G.step()
    
    return {
        'total_G': total_G_loss.item(),
        'adv_gen': adv_gen_loss.item(),
        'disent': disent_loss.item(),
        'cont': cont_loss.item(),
        'margin': margin_loss_val.item(),
        'recon': recon_loss.item()
    }

In [7]:
# 6. Ciclo di training principale
# ================================
loss_history = {
    'total_G': [],
    'disc': [],
    'disent': [],
    'cont': [],
    'margin': [],
    'recon': [],
    'adv_gen': []
}

print("Starting training...")
for epoch in tqdm(range(config["epochs"]), desc="Training Progress"):
    # Modalità training per tutti i modelli
    style_encoder.train()
    content_encoder.train()
    discriminator.train()
    decoder.train()
    
    epoch_losses = {key: [] for key in loss_history.keys()}
    
    for batch_idx, (x, labels) in enumerate(train_loader):
        x = x.to(device)          # (B, S, 2, T, F)
        labels = labels.to(device) # (B,)
        
        # ================================================================
        # Alternating training: Discriminator -> Generator
        # ================================================================
        
        # Step 1: Train Discriminator
        for _ in range(config["discriminator_steps"]):
            disc_loss = discriminator_training_step(x, labels)
            epoch_losses['disc'].append(disc_loss)
        
        # Step 2: Train Generator
        for _ in range(config["generator_steps"]):
            gen_losses = generator_training_step(x, labels)
            
            # Aggiungi tutte le loss del generatore
            for key, value in gen_losses.items():
                epoch_losses[key].append(value)
        
        # Log progress periodically
        if batch_idx % 1 == 0:
            
            # Controlla se ci sono NaN nelle loss prima di calcolare le medie
            if any(np.isnan(epoch_losses['disc'][-config["discriminator_steps"]:])):
                print(f"Epoch {epoch+1}, Batch {batch_idx}: NaN detected in discriminator loss!")
                break
            if any(np.isnan(epoch_losses['total_G'][-config["generator_steps"]:])):
                print(f"Epoch {epoch+1}, Batch {batch_idx}: NaN detected in generator loss!")
                break
            
            
            avg_disc = np.mean(epoch_losses['disc'][-config["discriminator_steps"]:])
            avg_gen = np.mean(epoch_losses['total_G'][-config["generator_steps"]:])
            avg_recon = np.mean(epoch_losses['recon'][-config["generator_steps"]:])
            avg_disent = np.mean(epoch_losses['disent'][-config["generator_steps"]:])
            avg_cont = np.mean(epoch_losses['cont'][-config["generator_steps"]:])
            avg_margin = np.mean(epoch_losses['margin'][-config["generator_steps"]:])
            avg_adv_gen = np.mean(epoch_losses['adv_gen'][-config["generator_steps"]:])
            
            print(f"Epoch {epoch+1}, Batch {batch_idx}:")
            print(f"  D_loss: {avg_disc:.4f} | G_total: {avg_gen:.4f}")
            print(f"  Recon: {avg_recon:.4f} | Disent: {avg_disent:.4f} | Cont: {avg_cont:.4f}")
            print(f"  Margin: {avg_margin:.4f} | Adv_gen: {avg_adv_gen:.4f}")
            print("-" * 40)
    
    # ================================================================
    # Fine epoca: salvataggio e logging
    # ================================================================
    
    # Aggiungi loss dell'epoca alla storia
    for key in loss_history.keys():
        if epoch_losses[key]:  # Solo se ci sono valori
            loss_history[key].extend(epoch_losses[key])
    
    # Salvataggio checkpoint
    if (epoch + 1) % config["save_interval"] == 0:
        save_checkpoint(epoch + 1)
    
    # Logging delle loss medie dell'epoca
    print(f"\nEpoch {epoch+1}/{config['epochs']} Summary:")
    for key in ['disc', 'total_G', 'recon', 'disent', 'cont', 'margin']:
        if epoch_losses[key]:
            avg_loss = np.mean(epoch_losses[key])
            print(f"  {key}: {avg_loss:.4f}")
    print("-" * 50)

# 7. Salvataggio finale e visualizzazione
# ========================================
save_checkpoint(config["epochs"])

# Plot delle loss
plt.figure(figsize=(15, 10))

# Crea subplot per diverse categorie di loss
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

loss_groups = [
    ('disc', 'Discriminator Loss'),
    ('total_G', 'Generator Total Loss'),
    ('recon', 'Reconstruction Loss'),
    ('disent', 'Disentanglement Loss'),
    ('cont', 'Contrastive Loss'),
    ('adv_gen', 'Adversarial Generator Loss')
]

for i, (loss_name, title) in enumerate(loss_groups):
    if loss_name in loss_history and loss_history[loss_name]:
        # Smooth the curves for better visualization
        values = loss_history[loss_name]
        if len(values) > 100:
            # Moving average for smoother curves
            window = len(values) // 100
            smoothed = np.convolve(values, np.ones(window)/window, mode='valid')
            axes[i].plot(smoothed, label=f'{loss_name} (smoothed)')
        axes[i].plot(values, alpha=0.3, label=f'{loss_name} (raw)')
        axes[i].set_title(title)
        axes[i].set_xlabel('Iteration')
        axes[i].set_ylabel('Loss')
        axes[i].legend()
        axes[i].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(config["save_dir"], "loss_curves.png"), dpi=300, bbox_inches='tight')
plt.show()

Starting training...


Training Progress:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1, Batch 0:
  D_loss: 1.6904 | G_total: 29.2459
  Recon: 2.8032 | Disent: 0.0020 | Cont: 2.5507
  Margin: 0.0000 | Adv_gen: -0.6212
----------------------------------------
style_emb has NaN: True
class_emb has NaN: True
content_emb has NaN: True


Training Progress:   1%|          | 1/100 [00:41<1:08:32, 41.54s/it]

Epoch 1, Batch 1: NaN detected in discriminator loss!

Epoch 1/100 Summary:
  disc: nan
  total_G: nan
  recon: nan
  disent: nan
  cont: nan
  margin: nan
--------------------------------------------------
style_emb has NaN: True
class_emb has NaN: True
content_emb has NaN: True


Training Progress:   2%|▏         | 2/100 [00:51<37:46, 23.13s/it]  

Epoch 2, Batch 0: NaN detected in discriminator loss!

Epoch 2/100 Summary:
  disc: nan
  total_G: nan
  recon: nan
  disent: nan
  cont: nan
  margin: nan
--------------------------------------------------
style_emb has NaN: True
class_emb has NaN: True
content_emb has NaN: True


Training Progress:   3%|▎         | 3/100 [01:02<27:56, 17.29s/it]

Epoch 3, Batch 0: NaN detected in discriminator loss!

Epoch 3/100 Summary:
  disc: nan
  total_G: nan
  recon: nan
  disent: nan
  cont: nan
  margin: nan
--------------------------------------------------


Training Progress:   3%|▎         | 3/100 [01:05<35:32, 21.98s/it]


KeyboardInterrupt: 

In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import os
import matplotlib.pyplot as plt
import warnings

from style_encoder import StyleEncoder
from content_encoder import ContentEncoder
from discriminator import Discriminator
from new_decoder import Decoder, compute_comprehensive_loss
from losses import (infoNCE_loss, margin_loss, adversarial_loss, 
                   disentanglement_loss)
from Dataloader import get_dataloader

# Enhanced device configuration with memory management
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"Memory allocated: {torch.cuda.memory_allocated()/1024**2:.2f} MB")
    print(f"Memory reserved: {torch.cuda.memory_reserved()/1024**2:.2f} MB")

# Enhanced configuration with more conservative settings
config = {
    # Architecture - reduced complexity
    "style_dim": 128,  # Reduced from 256
    "content_dim": 128,  # Reduced from 256
    "transformer_heads": 2,  # Reduced from 4
    "transformer_layers": 2,  # Reduced from 4
    "cnn_channels": [16, 32, 64, 128],  # Reduced depth
    
    # Training - very conservative settings
    "epochs": 100,
    "batch_size": 8,  # Reduced from 16
    "lr": 1e-5,  # Much more conservative
    "beta1": 0.5,
    "beta2": 0.999,
    "weight_decay": 1e-6,  # Added weight decay
    
    # Loss weights - more balanced
    "lambda_adv_disc": 0.5,  # Reduced
    "lambda_adv_gen": 0.1,
    "lambda_disent": 0.1,    # Reduced
    "lambda_cont": 0.1,      # Reduced
    "lambda_margin": 0.05,   # Reduced
    "lambda_recon": 1.0,     # Reduced significantly
    
    # Gradient control
    "grad_clip_value": 0.25,  # Very conservative
    "grad_accumulation_steps": 2,  # Accumulate gradients
    
    # Stability controls
    "warmup_epochs": 5,
    "nan_check_interval": 10,
    "early_stop_nan_count": 5,
    
    # Paths
    "piano_dir": "dataset/train/piano",
    "violin_dir": "dataset/train/violin",
    "stats_path": "stats_stft_cqt.npz",
    
    # Saving
    "save_dir": "checkpoints",
    "save_interval": 10,
    
    # Training strategy
    "discriminator_steps": 1,
    "generator_steps": 1,
}

# Create save directory
os.makedirs(config["save_dir"], exist_ok=True)

# Enhanced weight initialization
def init_weights_conservative(m):
    """More conservative weight initialization to prevent NaN"""
    if isinstance(m, nn.Conv2d):
        # Xavier initialization for conv layers
        nn.init.xavier_uniform_(m.weight, gain=0.1)  # Small gain
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        # Xavier initialization for linear layers
        nn.init.xavier_uniform_(m.weight, gain=0.1)  # Small gain
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.MultiheadAttention):
        # Special handling for transformer attention
        if hasattr(m, 'in_proj_weight') and m.in_proj_weight is not None:
            nn.init.xavier_uniform_(m.in_proj_weight, gain=0.1)
        if hasattr(m, 'out_proj') and m.out_proj.weight is not None:
            nn.init.xavier_uniform_(m.out_proj.weight, gain=0.1)

# NaN detection and handling utilities
class NaNDetector:
    def __init__(self):
        self.nan_count = 0
        self.max_nan_count = config["early_stop_nan_count"]
    
    def check_tensor(self, tensor, name="tensor"):
        """Check if tensor contains NaN or Inf"""
        if torch.isnan(tensor).any():
            print(f"⚠️  NaN detected in {name}!")
            self.nan_count += 1
            return True
        if torch.isinf(tensor).any():
            print(f"⚠️  Inf detected in {name}!")
            self.nan_count += 1
            return True
        return False
    
    def check_model_parameters(self, model, model_name):
        """Check all model parameters for NaN/Inf"""
        for name, param in model.named_parameters():
            if param.grad is not None:
                if self.check_tensor(param.grad, f"{model_name}.{name}.grad"):
                    return True
            if self.check_tensor(param, f"{model_name}.{name}"):
                return True
        return False
    
    def should_stop(self):
        return self.nan_count >= self.max_nan_count

# Initialize models with conservative settings
style_encoder = StyleEncoder(
    cnn_out_dim=config["style_dim"],
    transformer_dim=config["style_dim"],
    num_heads=config["transformer_heads"],
    num_layers=config["transformer_layers"]
).to(device)

content_encoder = ContentEncoder(
    cnn_out_dim=config["content_dim"],
    transformer_dim=config["content_dim"],
    num_heads=config["transformer_heads"],
    num_layers=config["transformer_layers"],
    channels_list=config["cnn_channels"]
).to(device)

discriminator = Discriminator(
    input_dim=config["style_dim"],
    hidden_dim=64  # Reduced from 128
).to(device)

decoder = Decoder(
    d_model=config["style_dim"],
    nhead=config["transformer_heads"],
    num_layers=config["transformer_layers"]
).to(device)

# Apply conservative initialization
for model in [style_encoder, content_encoder, discriminator, decoder]:
    model.apply(init_weights_conservative)

# Initialize NaN detector
nan_detector = NaNDetector()

# Enhanced optimizers with weight decay
optimizer_G = optim.AdamW(  # Using AdamW for better stability
    list(style_encoder.parameters()) + 
    list(content_encoder.parameters()) + 
    list(decoder.parameters()),
    lr=config["lr"],
    betas=(config["beta1"], config["beta2"]),
    weight_decay=config["weight_decay"]
)

optimizer_D = optim.AdamW(
    discriminator.parameters(),
    lr=config["lr"],
    betas=(config["beta1"], config["beta2"]),
    weight_decay=config["weight_decay"]
)

# Learning rate schedulers for stability
scheduler_G = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_G, mode='min', factor=0.5, patience=5
)

scheduler_D = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_D, mode='min', factor=0.5, patience=5
)

# Dataset and DataLoader
try:
    train_loader = get_dataloader(
        piano_dir=config["piano_dir"],
        violin_dir=config["violin_dir"],
        batch_size=config["batch_size"],
        shuffle=True,
        stats_path=config["stats_path"]
    )
    print(f"✅ DataLoader created successfully with batch_size={config['batch_size']}")
except Exception as e:
    print(f"❌ Error creating DataLoader: {e}")
    raise

# Enhanced utility functions
def set_requires_grad(models, requires_grad):
    """Enable/disable gradients for models"""
    if not isinstance(models, list):
        models = [models]
    for model in models:
        for param in model.parameters():
            param.requires_grad = requires_grad

def save_checkpoint(epoch, additional_info=None):
    """Save model checkpoints with additional info"""
    checkpoint = {
        'epoch': epoch,
        'style_encoder': style_encoder.state_dict(),
        'content_encoder': content_encoder.state_dict(),
        'discriminator': discriminator.state_dict(),
        'decoder': decoder.state_dict(),
        'optimizer_G': optimizer_G.state_dict(),
        'optimizer_D': optimizer_D.state_dict(),
        'scheduler_G': scheduler_G.state_dict(),
        'scheduler_D': scheduler_D.state_dict(),
        'config': config,
        'nan_count': nan_detector.nan_count
    }
    
    if additional_info:
        checkpoint.update(additional_info)
    
    checkpoint_path = os.path.join(config["save_dir"], f"checkpoint_epoch_{epoch}.pth")
    torch.save(checkpoint, checkpoint_path)
    print(f"💾 Checkpoint saved: {checkpoint_path}")

def get_learning_rate_multiplier(epoch):
    """Get learning rate multiplier for warmup"""
    if epoch < config["warmup_epochs"]:
        return (epoch + 1) / config["warmup_epochs"]
    return 1.0

def safe_loss_computation(loss_fn, *args, **kwargs):
    """Safely compute loss with NaN checking"""
    try:
        loss = loss_fn(*args, **kwargs)
        if torch.isnan(loss).any() or torch.isinf(loss).any():
            print(f"⚠️  Invalid loss detected in {loss_fn.__name__}")
            return torch.tensor(0.0, device=device, requires_grad=True)
        return loss
    except Exception as e:
        print(f"❌ Error in loss computation {loss_fn.__name__}: {e}")
        return torch.tensor(0.0, device=device, requires_grad=True)

def discriminator_training_step(x, labels, epoch):
    """Enhanced discriminator training step with NaN protection"""
    set_requires_grad(discriminator, True)
    set_requires_grad([style_encoder, content_encoder, decoder], False)
    
    # Forward pass with gradient stopping
    with torch.no_grad():
        try:
            style_emb, class_emb = style_encoder(x, labels)
            content_emb = content_encoder(x)
        except Exception as e:
            print(f"❌ Error in encoder forward pass: {e}")
            return float('nan')
    
    # Check for NaN in embeddings
    if (nan_detector.check_tensor(style_emb, "style_emb") or 
        nan_detector.check_tensor(class_emb, "class_emb") or 
        nan_detector.check_tensor(content_emb, "content_emb")):
        return float('nan')
    
    # Compute discriminator loss
    try:
        disc_loss, _ = adversarial_loss(
            style_emb.detach(),
            class_emb.detach(),
            content_emb.detach(),
            discriminator,
            labels,
            compute_for_discriminator=True,
            lambda_content=config["lambda_adv_disc"]
        )
    except Exception as e:
        print(f"❌ Error in discriminator loss computation: {e}")
        return float('nan')
    
    if nan_detector.check_tensor(disc_loss, "disc_loss"):
        return float('nan')
    
    # Backpropagation
    optimizer_D.zero_grad()
    disc_loss.backward()
    
    # Gradient clipping
    torch.nn.utils.clip_grad_norm_(discriminator.parameters(), config["grad_clip_value"])
    
    # Check gradients
    if nan_detector.check_model_parameters(discriminator, "discriminator"):
        return float('nan')
    
    optimizer_D.step()
    
    return disc_loss.item()

def generator_training_step(x, labels, epoch):
    """Enhanced generator training step with NaN protection"""
    set_requires_grad([style_encoder, content_encoder, decoder], True)
    set_requires_grad(discriminator, False)
    
    # Forward pass
    try:
        style_emb, class_emb = style_encoder(x, labels)
        content_emb = content_encoder(x)
    except Exception as e:
        print(f"❌ Error in generator forward pass: {e}")
        return {key: float('nan') for key in ['total_G', 'adv_gen', 'disent', 'cont', 'margin', 'recon']}
    
    # Check embeddings
    if (nan_detector.check_tensor(style_emb, "style_emb") or 
        nan_detector.check_tensor(class_emb, "class_emb") or 
        nan_detector.check_tensor(content_emb, "content_emb")):
        return {key: float('nan') for key in ['total_G', 'adv_gen', 'disent', 'cont', 'margin', 'recon']}
    
    # Compute losses safely
    losses = {}
    
    # Adversarial loss for generator
    try:
        _, adv_gen_loss = adversarial_loss(
            style_emb,
            class_emb,
            content_emb,
            discriminator,
            labels,
            compute_for_discriminator=False,
            lambda_content=config["lambda_adv_gen"]
        )
        losses['adv_gen'] = adv_gen_loss
    except Exception as e:
        print(f"❌ Error in adversarial loss: {e}")
        losses['adv_gen'] = torch.tensor(0.0, device=device, requires_grad=True)
    
    # Disentanglement loss
    try:
        disent_loss = disentanglement_loss(
            style_emb,
            content_emb.mean(dim=1),
            use_hsic=True
        )
        losses['disent'] = disent_loss
    except Exception as e:
        print(f"❌ Error in disentanglement loss: {e}")
        losses['disent'] = torch.tensor(0.0, device=device, requires_grad=True)
    
    # Contrastive loss
    try:
        cont_loss = infoNCE_loss(style_emb, labels)
        losses['cont'] = cont_loss
    except Exception as e:
        print(f"❌ Error in contrastive loss: {e}")
        losses['cont'] = torch.tensor(0.0, device=device, requires_grad=True)
    
    # Margin loss
    try:
        margin_loss_val = margin_loss(class_emb)
        losses['margin'] = margin_loss_val
    except Exception as e:
        print(f"❌ Error in margin loss: {e}")
        losses['margin'] = torch.tensor(0.0, device=device, requires_grad=True)
    
    # Reconstruction loss
    try:
        stft_part = x[:, :, :, :, :513]
        recon_x = decoder(content_emb, style_emb, y=stft_part)
        recon_losses = compute_comprehensive_loss(recon_x, stft_part)
        recon_loss = recon_losses['total_loss']
        losses['recon'] = recon_loss
    except Exception as e:
        print(f"❌ Error in reconstruction loss: {e}")
        losses['recon'] = torch.tensor(0.0, device=device, requires_grad=True)
    
    # Check all losses for NaN
    for loss_name, loss_val in losses.items():
        if nan_detector.check_tensor(loss_val, f"{loss_name}_loss"):
            return {key: float('nan') for key in ['total_G', 'adv_gen', 'disent', 'cont', 'margin', 'recon']}
    
    # Apply learning rate warmup
    lr_multiplier = get_learning_rate_multiplier(epoch)
    
    # Compute total loss with reduced weights during warmup
    warmup_factor = lr_multiplier if epoch < config["warmup_epochs"] else 1.0
    
    total_G_loss = (
        config["lambda_adv_gen"] * losses['adv_gen'] * warmup_factor +
        config["lambda_disent"] * losses['disent'] * warmup_factor +
        config["lambda_cont"] * losses['cont'] * warmup_factor +
        config["lambda_margin"] * losses['margin'] * warmup_factor +
        config["lambda_recon"] * losses['recon']
    )
    
    if nan_detector.check_tensor(total_G_loss, "total_G_loss"):
        return {key: float('nan') for key in ['total_G', 'adv_gen', 'disent', 'cont', 'margin', 'recon']}
    
    # Backpropagation
    optimizer_G.zero_grad()
    total_G_loss.backward()
    
    # Gradient clipping
    torch.nn.utils.clip_grad_norm_(
        list(style_encoder.parameters()) + 
        list(content_encoder.parameters()) +
        list(decoder.parameters()),
        config["grad_clip_value"]
    )
    
    # Check gradients
    for model, name in [(style_encoder, "style_encoder"), 
                       (content_encoder, "content_encoder"), 
                       (decoder, "decoder")]:
        if nan_detector.check_model_parameters(model, name):
            return {key: float('nan') for key in ['total_G', 'adv_gen', 'disent', 'cont', 'margin', 'recon']}
    
    optimizer_G.step()
    
    return {
        'total_G': total_G_loss.item(),
        'adv_gen': losses['adv_gen'].item(),
        'disent': losses['disent'].item(),
        'cont': losses['cont'].item(),
        'margin': losses['margin'].item(),
        'recon': losses['recon'].item()
    }

# Training loop with enhanced monitoring
loss_history = {
    'total_G': [],
    'disc': [],
    'disent': [],
    'cont': [],
    'margin': [],
    'recon': [],
    'adv_gen': []
}

print("🚀 Starting enhanced training with NaN protection...")
print(f"📊 Configuration: {config}")

try:
    for epoch in tqdm(range(config["epochs"]), desc="Training Progress"):
        if nan_detector.should_stop():
            print(f"🛑 Early stopping due to too many NaN occurrences ({nan_detector.nan_count})")
            break
        
        # Set models to training mode
        for model in [style_encoder, content_encoder, discriminator, decoder]:
            model.train()
        
        epoch_losses = {key: [] for key in loss_history.keys()}
        batch_count = 0
        
        for batch_idx, (x, labels) in enumerate(train_loader):
            if nan_detector.should_stop():
                print(f"🛑 Stopping epoch {epoch+1} due to NaN issues")
                break
            
            x = x.to(device)
            labels = labels.to(device)
            
            # Check input data
            if nan_detector.check_tensor(x, "input_x") or nan_detector.check_tensor(labels, "input_labels"):
                print(f"⚠️  Skipping batch {batch_idx} due to NaN in input")
                continue
            
            # Training steps
            for _ in range(config["discriminator_steps"]):
                disc_loss = discriminator_training_step(x, labels, epoch)
                if not np.isnan(disc_loss):
                    epoch_losses['disc'].append(disc_loss)
            
            for _ in range(config["generator_steps"]):
                gen_losses = generator_training_step(x, labels, epoch)
                
                # Only add non-NaN losses
                for key, value in gen_losses.items():
                    if not np.isnan(value):
                        epoch_losses[key].append(value)
            
            batch_count += 1
            
            # Periodic logging
            if batch_idx % 1 == 0:  # More frequent logging
                print(f"\n📈 Epoch {epoch+1}/{config['epochs']}, Batch {batch_idx}")
                print(f"   NaN count: {nan_detector.nan_count}")
                
                # Log recent losses
                for key in ['disc', 'total_G', 'recon']:
                    if epoch_losses[key]:
                        recent_losses = epoch_losses[key][-config["generator_steps"]:]
                        avg_loss = np.mean(recent_losses)
                        print(f"   {key}: {avg_loss:.4f}")
                
                # Memory usage
                if torch.cuda.is_available():
                    print(f"   GPU Memory: {torch.cuda.memory_allocated()/1024**2:.1f} MB")
        
        # End of epoch processing
        if not nan_detector.should_stop() and batch_count > 0:
            # Update learning rate schedulers
            if epoch_losses['disc']:
                scheduler_D.step(np.mean(epoch_losses['disc']))
            if epoch_losses['total_G']:
                scheduler_G.step(np.mean(epoch_losses['total_G']))
            
            # Add to history
            for key in loss_history.keys():
                if epoch_losses[key]:
                    loss_history[key].extend(epoch_losses[key])
            
            # Save checkpoint
            if (epoch + 1) % config["save_interval"] == 0:
                save_checkpoint(epoch + 1, {'loss_history': loss_history})
            
            # Epoch summary
            print(f"\n✅ Epoch {epoch+1} completed:")
            for key in ['disc', 'total_G', 'recon']:
                if epoch_losses[key]:
                    print(f"   {key}: {np.mean(epoch_losses[key]):.4f}")
            print(f"   Batches processed: {batch_count}")
            print("-" * 50)
        
        # Clear cache periodically
        if torch.cuda.is_available() and epoch % 10 == 0:
            torch.cuda.empty_cache()

except KeyboardInterrupt:
    print("\n⏹️  Training interrupted by user")
except Exception as e:
    print(f"\n❌ Training error: {e}")
    import traceback
    traceback.print_exc()
finally:
    # Final checkpoint
    save_checkpoint(epoch + 1, {'loss_history': loss_history, 'final_checkpoint': True})
    
    # Plot results if we have data
    if any(loss_history[key] for key in loss_history.keys()):
        plt.figure(figsize=(15, 10))
        
        valid_loss_groups = []
        for loss_name in ['disc', 'total_G', 'recon', 'disent', 'cont', 'adv_gen']:
            if loss_history[loss_name]:
                valid_loss_groups.append((loss_name, loss_name.replace('_', ' ').title()))
        
        if valid_loss_groups:
            n_plots = len(valid_loss_groups)
            cols = 3
            rows = (n_plots + cols - 1) // cols
            
            fig, axes = plt.subplots(rows, cols, figsize=(15, 5*rows))
            if rows == 1:
                axes = [axes] if cols == 1 else axes
            else:
                axes = axes.flatten()
            
            for i, (loss_name, title) in enumerate(valid_loss_groups):
                values = loss_history[loss_name]
                if len(values) > 100:
                    # Moving average for smoother visualization
                    window = max(1, len(values) // 100)
                    smoothed = np.convolve(values, np.ones(window)/window, mode='valid')
                    axes[i].plot(smoothed, label=f'{loss_name} (smoothed)', linewidth=2)
                
                axes[i].plot(values, alpha=0.3, label=f'{loss_name} (raw)')
                axes[i].set_title(title)
                axes[i].set_xlabel('Iteration')
                axes[i].set_ylabel('Loss')
                axes[i].legend()
                axes[i].grid(True, alpha=0.3)
            
            # Hide unused subplots
            for i in range(len(valid_loss_groups), len(axes)):
                axes[i].set_visible(False)
            
            plt.tight_layout()
            plt.savefig(os.path.join(config["save_dir"], "loss_curves.png"), 
                       dpi=300, bbox_inches='tight')
            plt.show()
    
    print(f"\n🎯 Training session completed!")
    print(f"   Total NaN occurrences: {nan_detector.nan_count}")
    print(f"   Checkpoints saved in: {config['save_dir']}")
    print(f"   Final learning rates - G: {scheduler_G.get_last_lr()}, D: {scheduler_D.get_last_lr()}")

Using device: cpu
✅ DataLoader created successfully with batch_size=8
🚀 Starting enhanced training with NaN protection...
📊 Configuration: {'style_dim': 128, 'content_dim': 128, 'transformer_heads': 2, 'transformer_layers': 2, 'cnn_channels': [16, 32, 64, 128], 'epochs': 100, 'batch_size': 8, 'lr': 1e-05, 'beta1': 0.5, 'beta2': 0.999, 'weight_decay': 1e-06, 'lambda_adv_disc': 0.5, 'lambda_adv_gen': 0.1, 'lambda_disent': 0.1, 'lambda_cont': 0.1, 'lambda_margin': 0.05, 'lambda_recon': 1.0, 'grad_clip_value': 0.25, 'grad_accumulation_steps': 2, 'warmup_epochs': 5, 'nan_check_interval': 10, 'early_stop_nan_count': 5, 'piano_dir': 'dataset/train/piano', 'violin_dir': 'dataset/train/violin', 'stats_path': 'stats_stft_cqt.npz', 'save_dir': 'checkpoints', 'save_interval': 10, 'discriminator_steps': 1, 'generator_steps': 1}


Training Progress:   0%|          | 0/100 [00:00<?, ?it/s]


📈 Epoch 1/100, Batch 0
   NaN count: 0
   disc: 1.3863
   total_G: 2.6985
   recon: 2.6641

📈 Epoch 1/100, Batch 1
   NaN count: 0
   disc: 1.3863
   total_G: 3.7735
   recon: 3.7391

📈 Epoch 1/100, Batch 2
   NaN count: 0
   disc: 1.3863
   total_G: 4.8772
   recon: 4.8430

📈 Epoch 1/100, Batch 3
   NaN count: 0
   disc: 1.3863
   total_G: 3.1949
   recon: 3.1606

📈 Epoch 1/100, Batch 4
   NaN count: 0
   disc: 1.3863
   total_G: 3.1626
   recon: 3.1284

📈 Epoch 1/100, Batch 5
   NaN count: 0
   disc: 1.3863
   total_G: 4.5419
   recon: 4.5078

📈 Epoch 1/100, Batch 6
   NaN count: 0
   disc: 1.3863
   total_G: 2.8154
   recon: 2.7818

📈 Epoch 1/100, Batch 7
   NaN count: 0
   disc: 1.3863
   total_G: 2.3858
   recon: 2.3520

📈 Epoch 1/100, Batch 8
   NaN count: 0
   disc: 1.3863
   total_G: 3.6728
   recon: 3.6392

📈 Epoch 1/100, Batch 9
   NaN count: 0
   disc: 1.3863
   total_G: 3.8458
   recon: 3.8122

📈 Epoch 1/100, Batch 10
   NaN count: 0
   disc: 1.3863
   total_G: 3.9670
   r

Training Progress:   0%|          | 0/100 [03:28<?, ?it/s]



⏹️  Training interrupted by user
💾 Checkpoint saved: checkpoints/checkpoint_epoch_1.pth

🎯 Training session completed!
   Total NaN occurrences: 0
   Checkpoints saved in: checkpoints
   Final learning rates - G: [1e-05], D: [1e-05]


In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import os
import matplotlib.pyplot as plt
import warnings

from style_encoder import StyleEncoder
from content_encoder import ContentEncoder
from discriminator import Discriminator
from new_decoder import Decoder, compute_comprehensive_loss
from losses import (infoNCE_loss, margin_loss, adversarial_loss, 
                   disentanglement_loss)
from Dataloader import get_dataloader

# ================================================================
# CONFIGURAZIONE DISPOSITIVO E MEMORIA
# ================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"Memory allocated: {torch.cuda.memory_allocated()/1024**2:.2f} MB")
    print(f"Memory reserved: {torch.cuda.memory_reserved()/1024**2:.2f} MB")

# ================================================================
# CONFIGURAZIONE TRAINING
# ================================================================
config = {
    # Architettura - parametri originali richiesti
    "style_dim": 256,
    "content_dim": 256,
    "transformer_heads": 4,
    "transformer_layers": 4,
    "cnn_channels": [16, 32, 64, 128, 256],
    
    # Training - parametri conservativi per stabilità
    "epochs": 100,
    "batch_size": 8,          # Ridotto per evitare OOM
    "lr": 5e-5,              # Learning rate conservativo
    "beta1": 0.5,
    "beta2": 0.999,
    "weight_decay": 1e-5,    # Weight decay per prevenire overfitting
    
    # Pesi delle loss - bilanciati per stabilità
    "lambda_adv_disc": 0.8,
    "lambda_adv_gen": 0.1,
    "lambda_disent": 0.3,
    "lambda_cont": 0.2,
    "lambda_margin": 0.1,
    "lambda_recon": 2.0,
    
    # Controlli di stabilità
    "grad_clip_value": 0.5,      # Gradient clipping conservativo
    "warmup_epochs": 5,          # Epochs di warmup
    "nan_threshold": 5,          # Max NaN consecutivi prima di fermarsi
    
    # Percorsi dati
    "piano_dir": "dataset/train/piano",
    "violin_dir": "dataset/train/violin",
    "stats_path": "stats_stft_cqt.npz",
    
    # Salvataggio
    "save_dir": "checkpoints",
    "save_interval": 10,
    
    # Strategia di training
    "discriminator_steps": 1,
    "generator_steps": 1,
}

# Crea directory di salvataggio
os.makedirs(config["save_dir"], exist_ok=True)

# ================================================================
# INIZIALIZZAZIONE PESI CONSERVATIVA
# ================================================================
def init_weights_conservative(m):
    """
    Inizializzazione conservativa dei pesi per prevenire NaN
    """
    if isinstance(m, nn.Conv2d):
        # Xavier uniforme con gain ridotto
        nn.init.xavier_uniform_(m.weight, gain=0.2)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        # Xavier uniforme con gain ridotto
        nn.init.xavier_uniform_(m.weight, gain=0.2)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)

# ================================================================
# FUNZIONI DI UTILITÀ
# ================================================================
def check_for_nan(*tensors, names=None):
    """
    Controlla se ci sono NaN o Inf nei tensori
    
    Args:
        *tensors: Tensori da controllare
        names: Nomi dei tensori per il debug
    
    Returns:
        bool: True se trovati NaN/Inf
    """
    if names is None:
        names = [f"tensor_{i}" for i in range(len(tensors))]
    
    for tensor, name in zip(tensors, names):
        if torch.isnan(tensor).any():
            print(f"🚨 NaN detected in {name}")
            return True
        if torch.isinf(tensor).any():
            print(f"🚨 Inf detected in {name}")
            return True
    return False

def set_requires_grad(models, requires_grad):
    """
    Abilita/disabilita i gradienti per i modelli
    
    Args:
        models: Modello singolo o lista di modelli
        requires_grad: True per abilitare, False per disabilitare
    """
    if not isinstance(models, list):
        models = [models]
    for model in models:
        for param in model.parameters():
            param.requires_grad = requires_grad

def get_learning_rate_multiplier(epoch, warmup_epochs):
    """
    Calcola il moltiplicatore del learning rate per il warmup
    
    Args:
        epoch: Epoca corrente
        warmup_epochs: Numero di epoche di warmup
    
    Returns:
        float: Moltiplicatore del learning rate
    """
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    return 1.0

def save_checkpoint(epoch, models_dict, optimizers_dict, schedulers_dict, config):
    """
    Salva checkpoint completo
    
    Args:
        epoch: Epoca corrente
        models_dict: Dizionario dei modelli
        optimizers_dict: Dizionario degli ottimizzatori
        schedulers_dict: Dizionario degli schedulers
        config: Configurazione
    """
    checkpoint = {
        'epoch': epoch,
        'config': config
    }
    
    # Salva stati dei modelli
    for name, model in models_dict.items():
        checkpoint[name] = model.state_dict()
    
    # Salva stati degli ottimizzatori
    for name, optimizer in optimizers_dict.items():
        checkpoint[name] = optimizer.state_dict()
    
    # Salva stati degli schedulers
    for name, scheduler in schedulers_dict.items():
        checkpoint[name] = scheduler.state_dict()
    
    checkpoint_path = os.path.join(config["save_dir"], f"checkpoint_epoch_{epoch}.pth")
    torch.save(checkpoint, checkpoint_path)
    print(f"💾 Checkpoint saved: {checkpoint_path}")

# ================================================================
# INIZIALIZZAZIONE MODELLI
# ================================================================
print("🔧 Initializing models...")

style_encoder = StyleEncoder(
    cnn_out_dim=config["style_dim"],
    transformer_dim=config["style_dim"],
    num_heads=config["transformer_heads"],
    num_layers=config["transformer_layers"]
).to(device)

content_encoder = ContentEncoder(
    cnn_out_dim=config["content_dim"],
    transformer_dim=config["content_dim"],
    num_heads=config["transformer_heads"],
    num_layers=config["transformer_layers"],
    channels_list=config["cnn_channels"]
).to(device)

discriminator = Discriminator(
    input_dim=config["style_dim"],
    hidden_dim=128
).to(device)

decoder = Decoder(
    d_model=config["style_dim"],
    nhead=config["transformer_heads"],
    num_layers=config["transformer_layers"]
).to(device)

# Applica inizializzazione conservativa
models = [style_encoder, content_encoder, discriminator, decoder]
model_names = ["style_encoder", "content_encoder", "discriminator", "decoder"]

for model, name in zip(models, model_names):
    model.apply(init_weights_conservative)
    print(f"✅ {name} initialized")

# ================================================================
# OTTIMIZZATORI E SCHEDULERS
# ================================================================
print("🔧 Setting up optimizers and schedulers...")

# Ottimizzatori con AdamW per maggiore stabilità
optimizer_G = optim.AdamW(
    list(style_encoder.parameters()) + 
    list(content_encoder.parameters()) + 
    list(decoder.parameters()),
    lr=config["lr"],
    betas=(config["beta1"], config["beta2"]),
    weight_decay=config["weight_decay"]
)

optimizer_D = optim.AdamW(
    discriminator.parameters(),
    lr=config["lr"],
    betas=(config["beta1"], config["beta2"]),
    weight_decay=config["weight_decay"]
)

# Schedulers per learning rate adattivo
scheduler_G = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_G, mode='min', factor=0.7, patience=5
)

scheduler_D = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_D, mode='min', factor=0.7, patience=5
)

# ================================================================
# DATALOADER
# ================================================================
print("🔧 Setting up dataloader...")

try:
    train_loader = get_dataloader(
        piano_dir=config["piano_dir"],
        violin_dir=config["violin_dir"],
        batch_size=config["batch_size"],
        shuffle=True,
        stats_path=config["stats_path"]
    )
    print(f"✅ DataLoader created successfully with batch_size={config['batch_size']}")
except Exception as e:
    print(f"❌ Error creating DataLoader: {e}")
    raise

# ================================================================
# FUNZIONI DI TRAINING
# ================================================================
def discriminator_training_step(x, labels, epoch):
    """
    Step di training per il discriminatore
    
    Args:
        x: Input batch
        labels: Labels del batch
        epoch: Epoca corrente
    
    Returns:
        float: Loss del discriminatore
    """
    # Abilita gradienti solo per discriminatore
    set_requires_grad(discriminator, True)
    set_requires_grad([style_encoder, content_encoder, decoder], False)
    
    try:
        # Forward pass senza gradienti per i generatori
        with torch.no_grad():
            style_emb, class_emb = style_encoder(x, labels)
            content_emb = content_encoder(x)
        
        # Controllo NaN negli embeddings
        if check_for_nan(style_emb, class_emb, content_emb, 
                         names=["style_emb", "class_emb", "content_emb"]):
            return float('nan')
        
        # Calcola loss del discriminatore
        disc_loss, _ = adversarial_loss(
            style_emb.detach(),
            class_emb.detach(),
            content_emb.detach(),
            discriminator,
            labels,
            compute_for_discriminator=True,
            lambda_content=config["lambda_adv_disc"]
        )
        
        # Controllo NaN nella loss
        if check_for_nan(disc_loss, names=["disc_loss"]):
            return float('nan')
        
        # Backpropagation
        optimizer_D.zero_grad()
        disc_loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(discriminator.parameters(), config["grad_clip_value"])
        
        # Controllo gradienti per NaN
        for name, param in discriminator.named_parameters():
            if param.grad is not None:
                if check_for_nan(param.grad, names=[f"discriminator.{name}.grad"]):
                    return float('nan')
        
        optimizer_D.step()
        return disc_loss.item()
        
    except Exception as e:
        print(f"❌ Error in discriminator training step: {e}")
        return float('nan')

def generator_training_step(x, labels, epoch):
    """
    Step di training per i generatori
    
    Args:
        x: Input batch
        labels: Labels del batch
        epoch: Epoca corrente
    
    Returns:
        dict: Dizionario con tutte le loss
    """
    # Abilita gradienti per i generatori
    set_requires_grad([style_encoder, content_encoder, decoder], True)
    set_requires_grad(discriminator, False)
    
    try:
        # Forward pass
        style_emb, class_emb = style_encoder(x, labels)
        content_emb = content_encoder(x)
        
        # Controllo NaN negli embeddings
        if check_for_nan(style_emb, class_emb, content_emb,
                         names=["style_emb", "class_emb", "content_emb"]):
            return {key: float('nan') for key in ['total_G', 'adv_gen', 'disent', 'cont', 'margin', 'recon']}
        
        # Calcola loss avversariale per il generatore
        _, adv_gen_loss = adversarial_loss(
            style_emb,
            class_emb,
            content_emb,
            discriminator,
            labels,
            compute_for_discriminator=False,
            lambda_content=config["lambda_adv_gen"]
        )
        
        # Calcola loss di disentanglement
        disent_loss = disentanglement_loss(
            style_emb,
            content_emb.mean(dim=1),
            use_hsic=True
        )
        
        # Calcola loss contrastiva
        cont_loss = infoNCE_loss(style_emb, labels)
        
        # Calcola margin loss
        margin_loss_val = margin_loss(class_emb)
        
        # Calcola loss di ricostruzione
        stft_part = x[:, :, :, :, :513]
        recon_x = decoder(content_emb, style_emb, y=stft_part)
        recon_losses = compute_comprehensive_loss(recon_x, stft_part)
        recon_loss = recon_losses['total_loss']
        
        # Controllo NaN in tutte le loss
        losses = [adv_gen_loss, disent_loss, cont_loss, margin_loss_val, recon_loss]
        loss_names = ['adv_gen_loss', 'disent_loss', 'cont_loss', 'margin_loss', 'recon_loss']
        
        if check_for_nan(*losses, names=loss_names):
            return {key: float('nan') for key in ['total_G', 'adv_gen', 'disent', 'cont', 'margin', 'recon']}
        
        # Calcola loss totale con warmup
        lr_multiplier = get_learning_rate_multiplier(epoch, config["warmup_epochs"])
        warmup_factor = lr_multiplier if epoch < config["warmup_epochs"] else 1.0
        
        total_G_loss = (
            config["lambda_adv_gen"] * adv_gen_loss * warmup_factor +
            config["lambda_disent"] * disent_loss * warmup_factor +
            config["lambda_cont"] * cont_loss * warmup_factor +
            config["lambda_margin"] * margin_loss_val * warmup_factor +
            config["lambda_recon"] * recon_loss
        )
        
        # Controllo NaN nella loss totale
        if check_for_nan(total_G_loss, names=["total_G_loss"]):
            return {key: float('nan') for key in ['total_G', 'adv_gen', 'disent', 'cont', 'margin', 'recon']}
        
        # Backpropagation
        optimizer_G.zero_grad()
        total_G_loss.backward()
        
        # Gradient clipping
        generator_params = (
            list(style_encoder.parameters()) + 
            list(content_encoder.parameters()) +
            list(decoder.parameters())
        )
        torch.nn.utils.clip_grad_norm_(generator_params, config["grad_clip_value"])
        
        # Controllo gradienti per NaN
        for model, model_name in [(style_encoder, "style_encoder"), 
                                 (content_encoder, "content_encoder"), 
                                 (decoder, "decoder")]:
            for name, param in model.named_parameters():
                if param.grad is not None:
                    if check_for_nan(param.grad, names=[f"{model_name}.{name}.grad"]):
                        return {key: float('nan') for key in ['total_G', 'adv_gen', 'disent', 'cont', 'margin', 'recon']}
        
        optimizer_G.step()
        
        return {
            'total_G': total_G_loss.item(),
            'adv_gen': adv_gen_loss.item(),
            'disent': disent_loss.item(),
            'cont': cont_loss.item(),
            'margin': margin_loss_val.item(),
            'recon': recon_loss.item()
        }
        
    except Exception as e:
        print(f"❌ Error in generator training step: {e}")
        return {key: float('nan') for key in ['total_G', 'adv_gen', 'disent', 'cont', 'margin', 'recon']}

# ================================================================
# CICLO DI TRAINING PRINCIPALE
# ================================================================
print("🚀 Starting training with enhanced NaN protection...")
print(f"📊 Configuration: {config}")

# Inizializza strutture per il tracking delle loss
loss_history = {
    'total_G': [],
    'disc': [],
    'disent': [],
    'cont': [],
    'margin': [],
    'recon': [],
    'adv_gen': []
}

# Contatore per NaN consecutivi
consecutive_nan_count = 0
max_consecutive_nans = config["nan_threshold"]

# Dizionari per salvataggio
models_dict = {
    'style_encoder': style_encoder,
    'content_encoder': content_encoder,
    'discriminator': discriminator,
    'decoder': decoder
}

optimizers_dict = {
    'optimizer_G': optimizer_G,
    'optimizer_D': optimizer_D
}

schedulers_dict = {
    'scheduler_G': scheduler_G,
    'scheduler_D': scheduler_D
}

try:
    for epoch in tqdm(range(config["epochs"]), desc="Training Progress"):
        # Controllo early stopping per NaN
        if consecutive_nan_count >= max_consecutive_nans:
            print(f"🛑 Early stopping: {consecutive_nan_count} consecutive NaN occurrences")
            break
        
        # Imposta tutti i modelli in modalità training
        for model in models:
            model.train()
        
        # Tracking delle loss per l'epoca corrente
        epoch_losses = {key: [] for key in loss_history.keys()}
        
        print(f"\n🔄 Epoch {epoch+1}/{config['epochs']}")
        print("-" * 60)
        
        for batch_idx, (x, labels) in enumerate(train_loader):
            # Trasferisce dati su GPU
            x = x.to(device)
            labels = labels.to(device)
            
            # Controllo NaN nei dati di input
            if check_for_nan(x, labels, names=["input_x", "input_labels"]):
                print(f"⚠️  Skipping batch {batch_idx} due to NaN in input data")
                continue
            
            # ============================================
            # TRAINING DISCRIMINATORE
            # ============================================
            for _ in range(config["discriminator_steps"]):
                disc_loss = discriminator_training_step(x, labels, epoch)
                
                if not np.isnan(disc_loss):
                    epoch_losses['disc'].append(disc_loss)
                    consecutive_nan_count = 0  # Reset counter su successo
                else:
                    consecutive_nan_count += 1
                    print(f"⚠️  NaN in discriminator loss (consecutive: {consecutive_nan_count})")
            
            # ============================================
            # TRAINING GENERATORE
            # ============================================
            for _ in range(config["generator_steps"]):
                gen_losses = generator_training_step(x, labels, epoch)
                
                # Controlla se ci sono NaN nelle loss del generatore
                nan_in_gen = any(np.isnan(v) for v in gen_losses.values())
                
                if not nan_in_gen:
                    # Aggiungi loss valide alla storia
                    for key, value in gen_losses.items():
                        if key in epoch_losses:
                            epoch_losses[key].append(value)
                    consecutive_nan_count = 0  # Reset counter su successo
                else:
                    consecutive_nan_count += 1
                    print(f"⚠️  NaN in generator losses (consecutive: {consecutive_nan_count})")
            
            # ============================================
            # LOGGING PER BATCH
            # ============================================
            # Stampa loss dettagliate per ogni batch
            print(f"\n📊 Batch {batch_idx + 1}:")
            
            # Loss del discriminatore
            if epoch_losses['disc']:
                recent_disc = epoch_losses['disc'][-config["discriminator_steps"]:]
                avg_disc = np.mean(recent_disc)
                print(f"   Discriminator Loss: {avg_disc:.6f}")
            
            # Loss del generatore
            if epoch_losses['total_G']:
                recent_gen = epoch_losses['total_G'][-config["generator_steps"]:]
                avg_total_G = np.mean(recent_gen)
                print(f"   Generator Total Loss: {avg_total_G:.6f}")
            
            # Loss individuali del generatore
            gen_loss_names = ['adv_gen', 'disent', 'cont', 'margin', 'recon']
            for loss_name in gen_loss_names:
                if epoch_losses[loss_name]:
                    recent_loss = epoch_losses[loss_name][-config["generator_steps"]:]
                    avg_loss = np.mean(recent_loss)
                    print(f"     {loss_name.replace('_', ' ').title()}: {avg_loss:.6f}")
            
            # Informazioni aggiuntive
            print(f"   Consecutive NaN count: {consecutive_nan_count}")
            if torch.cuda.is_available():
                print(f"   GPU Memory: {torch.cuda.memory_allocated()/1024**2:.1f} MB")
            
            # Controllo early stopping durante il batch
            if consecutive_nan_count >= max_consecutive_nans:
                print(f"🛑 Stopping epoch {epoch+1} due to consecutive NaN issues")
                break
        
        # ============================================
        # FINE EPOCA: AGGIORNAMENTI E SALVATAGGIO
        # ============================================
        if consecutive_nan_count < max_consecutive_nans:
            # Aggiorna schedulers
            if epoch_losses['disc']:
                scheduler_D.step(np.mean(epoch_losses['disc']))
            if epoch_losses['total_G']:
                scheduler_G.step(np.mean(epoch_losses['total_G']))
            
            # Aggiungi loss dell'epoca alla storia globale
            for key in loss_history.keys():
                if epoch_losses[key]:
                    loss_history[key].extend(epoch_losses[key])
            
            # Salva checkpoint periodicamente
            if (epoch + 1) % config["save_interval"] == 0:
                save_checkpoint(epoch + 1, models_dict, optimizers_dict, schedulers_dict, config)
            
            # Riepilogo dell'epoca
            print(f"\n✅ Epoch {epoch+1} Summary:")
            print("-" * 40)
            
            for key in ['disc', 'total_G', 'adv_gen', 'disent', 'cont', 'margin', 'recon']:
                if epoch_losses[key]:
                    avg_loss = np.mean(epoch_losses[key])
                    print(f"   {key.replace('_', ' ').title()}: {avg_loss:.6f}")
            
            # Informazioni sui learning rates
            current_lr_G = optimizer_G.param_groups[0]['lr']
            current_lr_D = optimizer_D.param_groups[0]['lr']
            print(f"   LR Generator: {current_lr_G:.2e}")
            print(f"   LR Discriminator: {current_lr_D:.2e}")
            
            print("-" * 60)
        
        # Pulizia memoria periodica
        if torch.cuda.is_available() and epoch % 10 == 0:
            torch.cuda.empty_cache()

except KeyboardInterrupt:
    print("\n⏹️  Training interrupted by user")
except Exception as e:
    print(f"\n❌ Training error: {e}")
    import traceback
    traceback.print_exc()
finally:
    # Salvataggio finale
    print("\n💾 Saving final checkpoint...")
    save_checkpoint(epoch + 1, models_dict, optimizers_dict, schedulers_dict, config)
    
    # Visualizzazione delle loss
    if any(loss_history[key] for key in loss_history.keys()):
        print("📊 Generating loss plots...")
        
        # Configura plot
        plt.style.use('default')
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        axes = axes.flatten()
        
        # Definisci le loss da plottare
        loss_configs = [
            ('disc', 'Discriminator Loss', 'red'),
            ('total_G', 'Generator Total Loss', 'blue'),
            ('recon', 'Reconstruction Loss', 'green'),
            ('adv_gen', 'Adversarial Generator Loss', 'orange'),
            ('disent', 'Disentanglement Loss', 'purple'),
            ('cont', 'Contrastive Loss', 'brown')
        ]
        
        for i, (loss_name, title, color) in enumerate(loss_configs):
            if loss_history[loss_name]:
                values = loss_history[loss_name]
                
                # Plot raw values
                axes[i].plot(values, alpha=0.4, color=color, linewidth=0.5, label='Raw')
                
                # Plot smoothed values se ci sono abbastanza punti
                if len(values) > 50:
                    window = max(1, len(values) // 50)
                    smoothed = np.convolve(values, np.ones(window)/window, mode='valid')
                    axes[i].plot(smoothed, color=color, linewidth=2, label='Smoothed')
                
                axes[i].set_title(title, fontsize=12, fontweight='bold')
                axes[i].set_xlabel('Iteration')
                axes[i].set_ylabel('Loss')
                axes[i].legend()
                axes[i].grid(True, alpha=0.3)
                axes[i].set_xlim(0, len(values))
        
        plt.tight_layout()
        plot_path = os.path.join(config["save_dir"], "loss_curves.png")
        plt.savefig(plot_path, dpi=300, bbox_inches='tight')
        plt.show()
        print(f"📈 Loss curves saved to: {plot_path}")
    
    # Riepilogo finale
    print(f"\n🎯 Training completed!")
    print(f"   Total epochs processed: {epoch + 1}")
    print(f"   Final consecutive NaN count: {consecutive_nan_count}")
    print(f"   Checkpoints saved in: {config['save_dir']}")
    
    if torch.cuda.is_available():
        print(f"   Final GPU memory: {torch.cuda.memory_allocated()/1024**2:.1f} MB")
    
    print("🎉 Training session finished successfully!")

Using device: cpu
🔧 Initializing models...
✅ style_encoder initialized
✅ content_encoder initialized
✅ discriminator initialized
✅ decoder initialized
🔧 Setting up optimizers and schedulers...
🔧 Setting up dataloader...
✅ DataLoader created successfully with batch_size=8
🚀 Starting training with enhanced NaN protection...
📊 Configuration: {'style_dim': 256, 'content_dim': 256, 'transformer_heads': 4, 'transformer_layers': 4, 'cnn_channels': [16, 32, 64, 128, 256], 'epochs': 100, 'batch_size': 8, 'lr': 5e-05, 'beta1': 0.5, 'beta2': 0.999, 'weight_decay': 1e-05, 'lambda_adv_disc': 0.8, 'lambda_adv_gen': 0.1, 'lambda_disent': 0.3, 'lambda_cont': 0.2, 'lambda_margin': 0.1, 'lambda_recon': 2.0, 'grad_clip_value': 0.5, 'warmup_epochs': 5, 'nan_threshold': 5, 'piano_dir': 'dataset/train/piano', 'violin_dir': 'dataset/train/violin', 'stats_path': 'stats_stft_cqt.npz', 'save_dir': 'checkpoints', 'save_interval': 10, 'discriminator_steps': 1, 'generator_steps': 1}


Training Progress:   0%|          | 0/100 [00:00<?, ?it/s]


🔄 Epoch 1/100
------------------------------------------------------------

📊 Batch 1:
   Discriminator Loss: 1.593956
   Generator Total Loss: 7.046389
     Adv Gen: -0.693144
     Disent: 0.008900
     Cont: 1.943153
     Margin: 0.000000
     Recon: 3.490996
   Consecutive NaN count: 0

📊 Batch 2:
   Discriminator Loss: 1.593772
   Generator Total Loss: 6.424905
     Adv Gen: -0.693145
     Disent: 0.003024
     Cont: 1.902855
     Margin: 0.000000
     Recon: 3.181236
   Consecutive NaN count: 0

📊 Batch 3:
   Discriminator Loss: 1.593267
   Generator Total Loss: 11.224684
     Adv Gen: -0.693145
     Disent: 0.000838
     Cont: 1.906530
     Margin: 0.000000
     Recon: 5.581118
   Consecutive NaN count: 0

📊 Batch 4:
   Discriminator Loss: 1.593601
   Generator Total Loss: 7.152896
     Adv Gen: -0.693146
     Disent: 0.000788
     Cont: 1.964319
     Margin: 0.000000
     Recon: 3.544070
   Consecutive NaN count: 0

📊 Batch 5:
   Discriminator Loss: 1.592313
   Generator Total 

Training Progress:   0%|          | 0/100 [01:22<?, ?it/s]



⏹️  Training interrupted by user

💾 Saving final checkpoint...
💾 Checkpoint saved: checkpoints/checkpoint_epoch_1.pth

🎯 Training completed!
   Total epochs processed: 1
   Final consecutive NaN count: 0
   Checkpoints saved in: checkpoints
🎉 Training session finished successfully!
