In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import os
import matplotlib.pyplot as plt
import warnings
from utilityFunctions import *

from style_encoder import StyleEncoder
from content_encoder import ContentEncoder
from discriminator import Discriminator
from new_decoder import Decoder, compute_comprehensive_loss
from losses import (infoNCE_loss, margin_loss, adversarial_loss, 
                   disentanglement_loss)
from dataloader import get_dataloader


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")


Using device: cpu


In [9]:
# train parameters
config = {
    "style_dim": 256,
    "content_dim": 256,
    "transformer_heads": 4,
    "transformer_layers": 4,
    "cnn_channels": [16, 32, 64, 128, 256],
    
    # Training params
    "epochs": 100,
    "batch_size": 16,          
    "lr_gen": 3e-5,
    "lr_disc": 1e-5,
    "beta1": 0.5,
    "beta2": 0.999,
    "weight_decay": 1e-4, 
    
    # losses weights
    "lambda_adv_disc": 1.0,
    "lambda_adv_gen": 0.5,
    "lambda_disent": 1.0,
    "lambda_cont": 0.5,
    "lambda_margin": 0.5,
    "lambda_recon": 5.0,
    
    # for stability
    "grad_clip_value": 0.5, 
    "warmup_epochs": 5,
    "nan_threshold": 5,          # Max NaN consecutivi
    
    # paths
    "piano_dir": "dataset/train/piano",
    "violin_dir": "dataset/train/violin",
    "stats_path": "stats_stft_cqt.npz",
    
    # save
    "save_dir": "checkpoints",
    "save_interval": 10,
    
    # training strategy ----> EXPERIMENT WITH DIFFERENT VALUES e.g. 10
    "discriminator_steps": 3,
    "generator_steps": 5,
}

os.makedirs(config["save_dir"], exist_ok=True)

In [10]:
# conservative weight initialization
def init_weights_conservative(m):
    """
    Inizializzazione conservativa dei pesi per prevenire NaN
    """
    if isinstance(m, nn.Conv2d):
        # Xavier uniforme con gain ridotto
        nn.init.xavier_uniform_(m.weight, gain=0.2)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        # Xavier uniforme con gain ridotto
        nn.init.xavier_uniform_(m.weight, gain=0.2)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)

# utility functions

def check_for_nan(*tensors, names=None):
    """
    Controlla se ci sono NaN o Inf nei tensori
    
    Args:
        *tensors: Tensori da controllare
        names: Nomi dei tensori per il debug
    
    Returns:
        bool: True se trovati NaN/Inf
    """
    if names is None:
        names = [f"tensor_{i}" for i in range(len(tensors))]
    
    for tensor, name in zip(tensors, names):
        if torch.isnan(tensor).any():
            print(f"🚨 NaN detected in {name}")
            return True
        if torch.isinf(tensor).any():
            print(f"🚨 Inf detected in {name}")
            return True
    return False



def set_requires_grad(models, requires_grad):
    """
    Abilita/disabilita i gradienti per i modelli
    
    Args:
        models: Modello singolo o lista di modelli
        requires_grad: True per abilitare, False per disabilitare
    """
    if not isinstance(models, list):
        models = [models]
    for model in models:
        for param in model.parameters():
            param.requires_grad = requires_grad



def get_learning_rate_multiplier(epoch, warmup_epochs):
    """
    Calcola il moltiplicatore del learning rate per il warmup
    
    Args:
        epoch: Epoca corrente
        warmup_epochs: Numero di epoche di warmup
    
    Returns:
        float: Moltiplicatore del learning rate
    """
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    return 1.0

def save_checkpoint(epoch, models_dict, optimizers_dict, schedulers_dict, config):
    """
    Salva checkpoint completo
    
    Args:
        epoch: Epoca corrente
        models_dict: Dizionario dei modelli
        optimizers_dict: Dizionario degli ottimizzatori
        schedulers_dict: Dizionario degli schedulers
        config: Configurazione
    """
    checkpoint = {
        'epoch': epoch,
        'config': config
    }
    
    # Salva stati dei modelli
    for name, model in models_dict.items():
        checkpoint[name] = model.state_dict()
    
    # Salva stati degli ottimizzatori
    for name, optimizer in optimizers_dict.items():
        checkpoint[name] = optimizer.state_dict()
    
    # Salva stati degli schedulers
    for name, scheduler in schedulers_dict.items():
        checkpoint[name] = scheduler.state_dict()
    
    checkpoint_path = os.path.join(config["save_dir"], f"checkpoint_epoch_{epoch}.pth")
    torch.save(checkpoint, checkpoint_path)
    print(f"💾 Checkpoint saved: {checkpoint_path}")

In [11]:
# model initialization
print("Initializing models...")

style_encoder = StyleEncoder(
    cnn_out_dim=config["style_dim"],
    transformer_dim=config["style_dim"],
    num_heads=config["transformer_heads"],
    num_layers=config["transformer_layers"]
).to(device)

content_encoder = ContentEncoder(
    cnn_out_dim=config["content_dim"],
    transformer_dim=config["content_dim"],
    num_heads=config["transformer_heads"],
    num_layers=config["transformer_layers"],
    channels_list=config["cnn_channels"]
).to(device)

discriminator = Discriminator(
    input_dim=config["style_dim"],
    hidden_dim=128
).to(device)

decoder = Decoder(
    d_model=config["style_dim"],
    nhead=config["transformer_heads"],
    num_layers=config["transformer_layers"]
).to(device)


# Applica inizializzazione conservativa
models = [style_encoder, content_encoder, discriminator, decoder]
model_names = ["style_encoder", "content_encoder", "discriminator", "decoder"]

# inizializzazione conservativa
for model, name in zip(models, model_names):
    model.apply(init_weights_conservative)
    print(f"{name} initialized")


# inizializzazione regolare
# from style_encoder import initialize_weights
# initialize_weights(style_encoder)
# initialize_weights(content_encoder)
# initialize_weights(decoder)
# initialize_weights(discriminator)

Initializing models...
style_encoder initialized
content_encoder initialized
discriminator initialized
decoder initialized


In [12]:
# optimizers and schedulers
print("🔧 Setting up optimizers and schedulers...")

optimizer_G = optim.AdamW(
    list(style_encoder.parameters()) + 
    list(content_encoder.parameters()) + 
    list(decoder.parameters()),
    lr=config["lr_gen"],
    betas=(config["beta1"], config["beta2"]),
    weight_decay=config["weight_decay"]
)

optimizer_D = optim.AdamW(
    discriminator.parameters(),
    lr=config["lr_disc"],
    betas=(config["beta1"], config["beta2"]),
    weight_decay=config["weight_decay"]
)

# Schedulers
scheduler_G = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_G, mode='min', factor=0.7, patience=5
)

scheduler_D = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_D, mode='min', factor=0.7, patience=5
)

# dataloader
print("🔧 Setting up dataloader...")

try:
    train_loader = get_dataloader(
        piano_dir=config["piano_dir"],
        violin_dir=config["violin_dir"],
        batch_size=config["batch_size"],
        shuffle=True,
        stats_path=config["stats_path"]
    )
    print(f"✅ DataLoader created successfully with batch_size={config['batch_size']}")
except Exception as e:
    print(f"❌ Error creating DataLoader: {e}")
    raise

🔧 Setting up optimizers and schedulers...
🔧 Setting up dataloader...
✅ Loaded separate statistics:
  Piano: train_set_stats/stats_stft_cqt_piano.npz
  Violin: train_set_stats/stats_stft_cqt_violin.npz
✅ DataLoader created successfully with batch_size=16


## Regular training

In [13]:
# training steps for discriminator and generator
def discriminator_training_step(x, labels, epoch):

    # only discriminator requires gradients
    # set_requires_grad(discriminator, True)
    # set_requires_grad([style_encoder, content_encoder, decoder], False)
    
    try:
        # no gradient flow for generators
        with torch.no_grad():
            style_emb, class_emb = style_encoder(x, labels)
            content_emb = content_encoder(x)
        
        # NaN check
        if check_for_nan(style_emb, class_emb, content_emb, 
                         names=["style_emb", "class_emb", "content_emb"]):
            return float('nan')
        
        # detach embeddings to avoid gradient flow
        # removed detach
        disc_loss, _ = adversarial_loss(
            style_emb,
            class_emb,
            content_emb,
            discriminator,
            labels,
            compute_for_discriminator=True,
            lambda_content=config["lambda_adv_disc"]
        )
        
        # NaN
        if check_for_nan(disc_loss, names=["disc_loss"]):
            return float('nan')
        
        # Backpropagation
        optimizer_D.zero_grad()
        disc_loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(discriminator.parameters(), config["grad_clip_value"])
        
        # Controllo gradienti per NaN
        for name, param in discriminator.named_parameters():
            if param.grad is not None:
                if check_for_nan(param.grad, names=[f"discriminator.{name}.grad"]):
                    return float('nan')
        
        optimizer_D.step()
        return disc_loss.item()
        
    except Exception as e:
        print(f"❌ Error in discriminator training step: {e}")
        return float('nan')



def generator_training_step(x, labels, epoch):

    # set_requires_grad([style_encoder, content_encoder, decoder], True)
    # set_requires_grad(discriminator, False)
    
    try:

        style_emb, class_emb = style_encoder(x, labels)
        content_emb = content_encoder(x)
        

        if check_for_nan(style_emb, class_emb, content_emb,
                         names=["style_emb", "class_emb", "content_emb"]):
            return {key: float('nan') for key in ['total_G', 'adv_gen', 'disent', 'cont', 'margin', 'recon']}
        

        _, adv_gen_loss = adversarial_loss(
            style_emb,
            class_emb,
            content_emb,
            discriminator,
            labels,
            compute_for_discriminator=False,
            lambda_content=config["lambda_adv_gen"]
        )
        
        # disentanglement
        disent_loss = disentanglement_loss(
            style_emb,
            content_emb.mean(dim=1),
            use_hsic=True
        )
        
        # contrastive losses
        cont_loss = infoNCE_loss(style_emb, labels)
        
        margin_loss_val = margin_loss(class_emb)
        
        # reconstruction loss
        # get only stft from input
        stft_part = x[:, :, :, :, :513]

        # duplicate class embedding from (2,d) to (B,d)
        B = content_emb.size(0)
        class_emb = class_emb.repeat_interleave(repeats=B//2, dim=0)

        # reconstruction loss
        recon_x = decoder(content_emb, class_emb, y=stft_part)
        recon_losses = compute_comprehensive_loss(recon_x, stft_part)
        recon_loss = recon_losses['total_loss']
        
        # NaN
        losses = [adv_gen_loss, disent_loss, cont_loss, margin_loss_val, recon_loss]
        loss_names = ['adv_gen_loss', 'disent_loss', 'cont_loss', 'margin_loss', 'recon_loss']
        
        if check_for_nan(*losses, names=loss_names):
            return {key: float('nan') for key in ['total_G', 'adv_gen', 'disent', 'cont', 'margin', 'recon']}
        
        # total loss with warmup
        lr_multiplier = get_learning_rate_multiplier(epoch, config["warmup_epochs"])
        warmup_factor = lr_multiplier if epoch < config["warmup_epochs"] else 1.0
        
        total_G_loss = (
            config["lambda_adv_gen"] * adv_gen_loss * warmup_factor +
            config["lambda_disent"] * disent_loss * warmup_factor +
            config["lambda_cont"] * cont_loss * warmup_factor +
            config["lambda_margin"] * margin_loss_val * warmup_factor +
            config["lambda_recon"] * recon_loss
        )
        
        if check_for_nan(total_G_loss, names=["total_G_loss"]):
            return {key: float('nan') for key in ['total_G', 'adv_gen', 'disent', 'cont', 'margin', 'recon']}
        
        # Backpropagation
        optimizer_G.zero_grad()
        total_G_loss.backward()
        
        # Gradient clipping
        generator_params = (
            list(style_encoder.parameters()) + 
            list(content_encoder.parameters()) +
            list(decoder.parameters())
        )
        torch.nn.utils.clip_grad_norm_(generator_params, config["grad_clip_value"])
        
        # NaN
        for model, model_name in [(style_encoder, "style_encoder"), 
                                 (content_encoder, "content_encoder"), 
                                 (decoder, "decoder")]:
            for name, param in model.named_parameters():
                if param.grad is not None:
                    if check_for_nan(param.grad, names=[f"{model_name}.{name}.grad"]):
                        return {key: float('nan') for key in ['total_G', 'adv_gen', 'disent', 'cont', 'margin', 'recon']}
        
        optimizer_G.step()
        
        return {
            'total_G': total_G_loss.item(),
            'adv_gen': adv_gen_loss.item(),
            'disent': disent_loss.item(),
            'cont': cont_loss.item(),
            'margin': margin_loss_val.item(),
            'recon': recon_loss.item()
        }
        
    except Exception as e:
        print(f"❌ Error in generator training step: {e}")
        return {key: float('nan') for key in ['total_G', 'adv_gen', 'disent', 'cont', 'margin', 'recon']}

In [None]:
# main train loop
print("🚀 Starting training with enhanced NaN protection...")
print(f"📊 Configuration: {config}")


loss_history = {
    'total_G': [],
    'disc': [],
    'disent': [],
    'cont': [],
    'margin': [],
    'recon': [],
    'adv_gen': []
}


consecutive_nan_count = 0
max_consecutive_nans = config["nan_threshold"]


models_dict = {
    'style_encoder': style_encoder,
    'content_encoder': content_encoder,
    'discriminator': discriminator,
    'decoder': decoder
}

optimizers_dict = {
    'optimizer_G': optimizer_G,
    'optimizer_D': optimizer_D
}

schedulers_dict = {
    'scheduler_G': scheduler_G,
    'scheduler_D': scheduler_D
}

try:
    for epoch in tqdm(range(config["epochs"]), desc="Training Progress"):
        # early stopping check for NaN
        if consecutive_nan_count >= max_consecutive_nans:
            print(f"🛑 Early stopping: {consecutive_nan_count} consecutive NaN occurrences")
            break
        
        # set all models to training mode
        for model in models:
            model.train()
        
        epoch_losses = {key: [] for key in loss_history.keys()}
        
        print(f"\n🔄 Epoch {epoch+1}/{config['epochs']}")
        print("-" * 60)
        
        generator_steps = config["generator_steps"]
        disc_steps = config["discriminator_steps"]
            
        # to compute when to train generator and discriminator
        cycle_length = generator_steps + disc_steps
        
        for batch_idx, (x, labels) in enumerate(train_loader):
            x = x.to(device)
            labels = labels.to(device)
            
            # Controllo NaN nei dati di input
            if check_for_nan(x, labels, names=["input_x", "input_labels"]):
                print(f"⚠️  Skipping batch {batch_idx} due to NaN in input data")
                continue
            
            # to compute turns in the cycle
            step_in_cycle = batch_idx % cycle_length
            
            
            # ============================================
            # GENERATORS training
            # ============================================
            if step_in_cycle < generator_steps:
                gen_losses = generator_training_step(x, labels, epoch)
                
                # NaN
                nan_in_gen = any(np.isnan(v) for v in gen_losses.values())
                
                if not nan_in_gen:
                    for key, value in gen_losses.items():
                        if key in epoch_losses:
                            epoch_losses[key].append(value)
                    consecutive_nan_count = 0  # Reset counter su successo
                else:
                    consecutive_nan_count += 1
                    print(f"⚠️  NaN in generator losses (consecutive: {consecutive_nan_count})")           
            
            
            
            # ============================================
            # DISCRIMINATOR training
            # ============================================
            else:
                disc_loss = discriminator_training_step(x, labels, epoch)
                
                # NaN
                if not np.isnan(disc_loss):
                    epoch_losses['disc'].append(disc_loss)
                    consecutive_nan_count = 0  # Reset counter 
                else:
                    consecutive_nan_count += 1
                    print(f"⚠️  NaN in discriminator loss (consecutive: {consecutive_nan_count})")
            
            
            
            
            # ============================================
            # Batch Logging
            # ============================================
            if batch_idx % 1 == 0:
                print(f"\n📊 Batch {batch_idx + 1}:")
                
                # disc loss
                if epoch_losses['disc']:
                    recent_disc = epoch_losses['disc'][-config["discriminator_steps"]:]
                    avg_disc = np.mean(recent_disc)
                    print(f"   Discriminator Loss: {avg_disc:.6f}")
                
                # gen loss
                if epoch_losses['total_G']:
                    recent_gen = epoch_losses['total_G'][-config["generator_steps"]:]
                    avg_total_G = np.mean(recent_gen)
                    print(f"   Generator Total Loss: {avg_total_G:.6f}")
                
                # individual losses
                gen_loss_names = ['adv_gen', 'disent', 'cont', 'margin', 'recon']
                for loss_name in gen_loss_names:
                    if epoch_losses[loss_name]:
                        recent_loss = epoch_losses[loss_name][-config["generator_steps"]:]
                        avg_loss = np.mean(recent_loss)
                        print(f"     {loss_name.replace('_', ' ').title()}: {avg_loss:.6f}")
                

                print(f"   Consecutive NaN count: {consecutive_nan_count}")
                if torch.cuda.is_available():
                    print(f"   GPU Memory: {torch.cuda.memory_allocated()/1024**2:.1f} MB")
            
            # NaN
            if consecutive_nan_count >= max_consecutive_nans:
                print(f"🛑 Stopping epoch {epoch+1} due to consecutive NaN issues")
                break
        
        # ============================================
        # update and save
        # ============================================
        if consecutive_nan_count < max_consecutive_nans:
            # update schedulers
            if epoch_losses['disc']:
                scheduler_D.step(np.mean(epoch_losses['disc']))
            if epoch_losses['total_G']:
                scheduler_G.step(np.mean(epoch_losses['total_G']))
            
            # add loss
            for key in loss_history.keys():
                if epoch_losses[key]:
                    loss_history[key].extend(epoch_losses[key])
            
            # save checkpoint
            if (epoch + 1) % config["save_interval"] == 0:
                save_checkpoint(epoch + 1, models_dict, optimizers_dict, schedulers_dict, config)
            
            # epoch summary
            print(f"\n✅ Epoch {epoch+1} Summary:")
            print("-" * 40)
            
            for key in ['disc', 'total_G', 'adv_gen', 'disent', 'cont', 'margin', 'recon']:
                if epoch_losses[key]:
                    avg_loss = np.mean(epoch_losses[key])
                    print(f"   {key.replace('_', ' ').title()}: {avg_loss:.6f}")
            
            current_lr_G = optimizer_G.param_groups[0]['lr']
            current_lr_D = optimizer_D.param_groups[0]['lr']
            print(f"   LR Generator: {current_lr_G:.2e}")
            print(f"   LR Discriminator: {current_lr_D:.2e}")
            
            print("-" * 60)
        
        # memory clean
        if torch.cuda.is_available() and epoch % 10 == 0:
            torch.cuda.empty_cache()

except KeyboardInterrupt:
    print("\n⏹️  Training interrupted by user")
except Exception as e:
    print(f"\n❌ Training error: {e}")
    import traceback
    traceback.print_exc()
finally:
    # final save
    print("\n💾 Saving final checkpoint...")
    save_checkpoint(epoch + 1, models_dict, optimizers_dict, schedulers_dict, config)
    
    # Visualizzazione delle loss
    if any(loss_history[key] for key in loss_history.keys()):
        print("📊 Generating loss plots...")
        
        # Configura plot
        plt.style.use('default')
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        axes = axes.flatten()
        
        # Definisci le loss da plottare
        loss_configs = [
            ('disc', 'Discriminator Loss', 'red'),
            ('total_G', 'Generator Total Loss', 'blue'),
            ('recon', 'Reconstruction Loss', 'green'),
            ('adv_gen', 'Adversarial Generator Loss', 'orange'),
            ('disent', 'Disentanglement Loss', 'purple'),
            ('cont', 'Contrastive Loss', 'brown')
        ]
        
        for i, (loss_name, title, color) in enumerate(loss_configs):
            if loss_history[loss_name]:
                values = loss_history[loss_name]
                
                # Plot raw values
                axes[i].plot(values, alpha=0.4, color=color, linewidth=0.5, label='Raw')
                
                # Plot smoothed values se ci sono abbastanza punti
                if len(values) > 50:
                    window = max(1, len(values) // 50)
                    smoothed = np.convolve(values, np.ones(window)/window, mode='valid')
                    axes[i].plot(smoothed, color=color, linewidth=2, label='Smoothed')
                
                axes[i].set_title(title, fontsize=12, fontweight='bold')
                axes[i].set_xlabel('Iteration')
                axes[i].set_ylabel('Loss')
                axes[i].legend()
                axes[i].grid(True, alpha=0.3)
                axes[i].set_xlim(0, len(values))
        
        plt.tight_layout()
        plot_path = os.path.join(config["save_dir"], "loss_curves.png")
        plt.savefig(plot_path, dpi=300, bbox_inches='tight')
        plt.show()
        print(f"📈 Loss curves saved to: {plot_path}")
    
    # Riepilogo finale
    print(f"\n🎯 Training completed!")
    print(f"   Total epochs processed: {epoch + 1}")
    print(f"   Final consecutive NaN count: {consecutive_nan_count}")
    print(f"   Checkpoints saved in: {config['save_dir']}")
    
    if torch.cuda.is_available():
        print(f"   Final GPU memory: {torch.cuda.memory_allocated()/1024**2:.1f} MB")
    
    print("🎉 Training session finished successfully!")

🚀 Starting training with enhanced NaN protection...
📊 Configuration: {'style_dim': 256, 'content_dim': 256, 'transformer_heads': 4, 'transformer_layers': 4, 'cnn_channels': [16, 32, 64, 128, 256], 'epochs': 100, 'batch_size': 16, 'lr_gen': 3e-05, 'lr_disc': 1e-05, 'beta1': 0.5, 'beta2': 0.999, 'weight_decay': 0.0001, 'lambda_adv_disc': 1.0, 'lambda_adv_gen': 0.5, 'lambda_disent': 1.0, 'lambda_cont': 0.5, 'lambda_margin': 0.5, 'lambda_recon': 5.0, 'grad_clip_value': 0.5, 'warmup_epochs': 5, 'nan_threshold': 5, 'piano_dir': 'dataset/train/piano', 'violin_dir': 'dataset/train/violin', 'stats_path': 'stats_stft_cqt.npz', 'save_dir': 'checkpoints', 'save_interval': 10, 'discriminator_steps': 3, 'generator_steps': 5}


Training Progress:   0%|          | 0/100 [00:00<?, ?it/s]


🔄 Epoch 1/100
------------------------------------------------------------

📊 Batch 1:
   Generator Total Loss: 13.676488
     Adv Gen: -0.346590
     Disent: 0.000856
     Cont: 2.709802
     Margin: 0.085060
     Recon: 2.686298
   Consecutive NaN count: 0

📊 Batch 2:
   Generator Total Loss: 14.303087
     Adv Gen: -0.346758
     Disent: 0.001143
     Cont: 2.708168
     Margin: 0.047283
     Recon: 2.812398
   Consecutive NaN count: 0


## Train with curriculum

In [None]:
def get_curriculum_weights(epoch, total_epochs):
    progress = epoch / total_epochs
    
    # phase 1 (0-20%): only reconstruction
    if progress < 0.2:
        return {
            "lambda_recon": 10.0,
            "lambda_adv_gen": 0.0,
            "lambda_adv_disc": 0.0,
            "lambda_disent": 0.0,
            "lambda_cont": 0.0,
            "lambda_margin": 0.0
        }
    
    # phase 2 (20-40%): add contrastive
    elif progress < 0.4:
        return {
            "lambda_recon": 8.0,
            "lambda_adv_gen": 0.0,
            "lambda_adv_disc": 0.0,
            "lambda_disent": 0.0,
            "lambda_cont": 0.5,
            "lambda_margin": 0.5
        }
    
    # phase 2 (40-60%): add disentanglement
    elif progress < 0.6:
        return {
            "lambda_recon": 5.0,
            "lambda_adv_gen": 0.0,
            "lambda_adv_disc": 0.0,
            "lambda_disent": 0.5,
            "lambda_cont": 0.8,
            "lambda_margin": 1
        }
    
    # phase 4 (60-100%): introduce adversarial training gradually
    else:
        # gradual increase of adversarial loss
        adv_strength = min(1.0, (progress - 0.6) / 0.4)
        return {
            "lambda_recon": 3.0,
            "lambda_adv_gen": 0.5 * adv_strength,
            "lambda_adv_disc": 0.3 * adv_strength,
            "lambda_disent": 0.3,
            "lambda_cont": 0.8,
            "lambda_margin": 1
        }
    

def display_final_loss_summary(loss_history, config):
    """
    Visualizza il riepilogo finale delle loss e genera i grafici
    """
    print("📊 Generating final loss summary and plots...")
    
    # Controlla se ci sono loss da visualizzare
    if not any(loss_history[key] for key in loss_history.keys()):
        print("⚠️  No loss data to display")
        return
    
    # Configura plot
    plt.style.use('default')
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    axes = axes.flatten()
    
    # Definisci le loss da plottare
    loss_configs = [
        ('disc', 'Discriminator Loss', 'red'),
        ('total_G', 'Generator Total Loss', 'blue'),
        ('recon', 'Reconstruction Loss', 'green'),
        ('adv_gen', 'Adversarial Generator Loss', 'orange'),
        ('disent', 'Disentanglement Loss', 'purple'),
        ('cont', 'Contrastive Loss', 'brown')
    ]
    
    for i, (loss_name, title, color) in enumerate(loss_configs):
        if loss_history[loss_name]:
            values = loss_history[loss_name]
            
            # Plot raw values
            axes[i].plot(values, alpha=0.4, color=color, linewidth=0.5, label='Raw')
            
            # Plot smoothed values se ci sono abbastanza punti
            if len(values) > 50:
                window = max(1, len(values) // 50)
                smoothed = np.convolve(values, np.ones(window)/window, mode='valid')
                axes[i].plot(smoothed, color=color, linewidth=2, label='Smoothed')
            
            axes[i].set_title(title, fontsize=12, fontweight='bold')
            axes[i].set_xlabel('Iteration')
            axes[i].set_ylabel('Loss')
            axes[i].legend()
            axes[i].grid(True, alpha=0.3)
            axes[i].set_xlim(0, len(values))
    
    plt.tight_layout()
    plot_path = os.path.join(config["save_dir"], "curriculum_loss_curves.png")
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"📈 Loss curves saved to: {plot_path}")

In [None]:
def discriminator_training_step_curriculum(x, labels, epoch, total_epochs):
    """
    Step di training per il discriminatore con curriculum learning
    """
    # Ottieni pesi curriculum
    curriculum_weights = get_curriculum_weights(epoch, total_epochs)
    
    # Salta training del discriminatore se non è ancora attivo
    if curriculum_weights["lambda_adv_disc"] == 0:
        return 0.0
    
    # Abilita gradienti solo per discriminatore
    # set_requires_grad(discriminator, True)
    # set_requires_grad([style_encoder, content_encoder, decoder], False)
    
    try:
        # Forward pass senza gradienti per i generatori
        with torch.no_grad():
            style_emb, class_emb = style_encoder(x, labels)
            content_emb = content_encoder(x)
        
        # Controllo NaN negli embeddings
        if check_for_nan(style_emb, class_emb, content_emb, 
                         names=["style_emb", "class_emb", "content_emb"]):
            return float('nan')
        
        # Calcola loss del discriminatore
        # removed detach from style,class,content embeddings
        disc_loss, _ = adversarial_loss(
            style_emb,
            class_emb,
            content_emb,
            discriminator,
            labels,
            compute_for_discriminator=True,
            lambda_content=curriculum_weights["lambda_adv_disc"]
        )
        
        # Controllo NaN nella loss
        if check_for_nan(disc_loss, names=["disc_loss"]):
            return float('nan')
        
        # Backpropagation
        optimizer_D.zero_grad()
        disc_loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(discriminator.parameters(), config["grad_clip_value"])
        
        optimizer_D.step()
        
        return disc_loss.item()
        
    except Exception as e:
        print(f"❌ Error in discriminator training step: {e}")
        return float('nan')



def generator_training_step_curriculum(x, labels, epoch, total_epochs):
    """
    Step di training per i generatori con curriculum learning
    """
    # Ottieni pesi curriculum per l'epoca corrente
    curriculum_weights = get_curriculum_weights(epoch, total_epochs)
    
    # Abilita gradienti per i generatori
    # set_requires_grad([style_encoder, content_encoder, decoder], True)
    # set_requires_grad(discriminator, False)
    
    try:
        # Forward pass
        style_emb, class_emb = style_encoder(x, labels)
        content_emb = content_encoder(x)
        
        # Controllo NaN
        if check_for_nan(style_emb, class_emb, content_emb,
                         names=["style_emb", "class_emb", "content_emb"]):
            return {key: float('nan') for key in ['total_G', 'adv_gen', 'disent', 'cont', 'margin', 'recon']}
        
        # Calcola solo le loss attive nel curriculum
        losses = {}
        
        # Loss di ricostruzione (sempre attiva)
        if curriculum_weights["lambda_recon"] > 0:
            
            stft_part = x[:, :, :, :, :513]
            class_emb = class_emb.repeat_interleave(repeats=stft_part.size(0) // 2, dim=0)
            
            recon_x = decoder(content_emb, class_emb, y=stft_part)
            recon_losses = compute_comprehensive_loss(recon_x, stft_part)
            losses["recon"] = recon_losses['total_loss']
        
        # Loss di disentanglement
        if curriculum_weights["lambda_disent"] > 0:
            losses["disent"] = disentanglement_loss(
                style_emb, content_emb.mean(dim=1), use_hsic=True
            )
        
        # Loss contrastiva
        if curriculum_weights["lambda_cont"] > 0:
            losses["cont"] = infoNCE_loss(style_emb, labels)
        
        # Margin loss
        if curriculum_weights["lambda_margin"] > 0:
            losses["margin"] = margin_loss(class_emb)
        
        # Adversarial loss (solo quando attivata)
        if curriculum_weights["lambda_adv_gen"] > 0:
            _, losses["adv_gen"] = adversarial_loss(
                style_emb, class_emb, content_emb, discriminator, labels,
                compute_for_discriminator=False,
                lambda_content=curriculum_weights["lambda_adv_gen"]
            )
        
        # Controlla NaN in tutte le loss calcolate
        if check_for_nan(*losses.values(), names=list(losses.keys())):
            return {key: float('nan') for key in ['total_G', 'adv_gen', 'disent', 'cont', 'margin', 'recon']}
        
        # Calcola loss totale usando i pesi del curriculum
        total_G_loss = sum(
            curriculum_weights[f"lambda_{key}"] * loss_val
            for key, loss_val in losses.items()
            if f"lambda_{key}" in curriculum_weights
        )
        
        # Backpropagation
        optimizer_G.zero_grad()
        total_G_loss.backward()
        
        # Gradient clipping
        generator_params = (
            list(style_encoder.parameters()) + 
            list(content_encoder.parameters()) +
            list(decoder.parameters())
        )
        torch.nn.utils.clip_grad_norm_(generator_params, config["grad_clip_value"])
        
        optimizer_G.step()
        
        # Prepara output con tutte le loss (0 se non attive)
        output_losses = {
            'total_G': total_G_loss.item(),
            'recon': losses.get("recon", torch.tensor(0.0)).item(),
            'disent': losses.get("disent", torch.tensor(0.0)).item(),
            'cont': losses.get("cont", torch.tensor(0.0)).item(),
            'margin': losses.get("margin", torch.tensor(0.0)).item(),
            'adv_gen': losses.get("adv_gen", torch.tensor(0.0)).item()
        }
        
        return output_losses
        
    except Exception as e:
        print(f"❌ Error in generator training step: {e}")
        return {key: float('nan') for key in ['total_G', 'adv_gen', 'disent', 'cont', 'margin', 'recon']}

In [None]:
def train_with_curriculum():
    """
    Training principale con curriculum learning integrato
    """
    print("🚀 Starting curriculum learning training...")
    
    # Inizializza strutture per il tracking delle loss
    loss_history = {
        'total_G': [],
        'disc': [],
        'disent': [],
        'cont': [],
        'margin': [],
        'recon': [],
        'adv_gen': []
    }
    
    # Contatore per NaN consecutivi
    consecutive_nan_count = 0
    max_consecutive_nans = config["nan_threshold"]
    
    # Dizionari per salvataggio
    models_dict = {
        'style_encoder': style_encoder,
        'content_encoder': content_encoder,
        'discriminator': discriminator,
        'decoder': decoder
    }
    
    optimizers_dict = {
        'optimizer_G': optimizer_G,
        'optimizer_D': optimizer_D
    }
    
    schedulers_dict = {
        'scheduler_G': scheduler_G,
        'scheduler_D': scheduler_D
    }
    
    try:
        for epoch in tqdm(range(config["epochs"]), desc="Curriculum Training"):
            # Controllo early stopping per NaN
            if consecutive_nan_count >= max_consecutive_nans:
                print(f"🛑 Early stopping: {consecutive_nan_count} consecutive NaN occurrences")
                break
            
            # Ottieni pesi curriculum per l'epoca corrente
            curriculum_weights = get_curriculum_weights(epoch, config["epochs"])
            
            # Imposta tutti i modelli in modalità training
            for model in models:
                model.train()
            
            # Tracking delle loss per l'epoca corrente
            epoch_losses = {key: [] for key in loss_history.keys()}
            
            # Stampa fase corrente del curriculum
            progress = epoch / config["epochs"]
            if progress < 0.2:
                phase = "Fase 1: Solo Ricostruzione"
            elif progress < 0.4:
                phase = "Fase 2: + Disentanglement"
            elif progress < 0.6:
                phase = "Fase 3: + Contrastive Learning"
            else:
                phase = "Fase 4: + Adversarial Training"
            
            print(f"\n🔄 Epoch {epoch+1}/{config['epochs']} - {phase}")
            print(f"📊 Curriculum weights: {curriculum_weights}")
            print("-" * 60)
            
            generator_steps = config["generator_steps"]
            disc_steps = config["discriminator_steps"]
            
            # to compute when to train generator and discriminator
            cycle_length = generator_steps + disc_steps
            
            for batch_idx, (x, labels) in enumerate(train_loader):
                # Trasferisce dati su GPU
                x = x.to(device)
                labels = labels.to(device)
                
                # Controllo NaN nei dati di input
                if check_for_nan(x, labels, names=["input_x", "input_labels"]):
                    print(f"⚠️  Skipping batch {batch_idx} due to NaN in input data")
                    continue
                
                # to compute turns in the cycle
                step_in_cycle = batch_idx % cycle_length
                
                # ============================================
                # TRAINING GENERATORE
                # ============================================
                if step_in_cycle < generator_steps:
                    gen_losses = generator_training_step_curriculum(x, labels, epoch, config["epochs"])
                    
                    # Controlla se ci sono NaN nelle loss del generatore
                    nan_in_gen = any(np.isnan(v) for v in gen_losses.values())
                    
                    if not nan_in_gen:
                        # Aggiungi loss valide alla storia
                        for key, value in gen_losses.items():
                            if key in epoch_losses:
                                epoch_losses[key].append(value)
                        consecutive_nan_count = 0
                    else:
                        consecutive_nan_count += 1
                        print(f"⚠️  NaN in generator losses (consecutive: {consecutive_nan_count})")
                        
                
                # ============================================
                # TRAINING DISCRIMINATORE (solo se attivo)
                # ============================================
                else:
                    if curriculum_weights["lambda_adv_disc"] > 0:
                        disc_loss = discriminator_training_step_curriculum(x, labels, epoch, config["epochs"])
                            
                        if not np.isnan(disc_loss):
                            epoch_losses['disc'].append(disc_loss)
                            consecutive_nan_count = 0
                        else:
                            consecutive_nan_count += 1
                            print(f"⚠️  NaN in discriminator loss (consecutive: {consecutive_nan_count})")
                

                
                # ============================================
                # LOGGING PER BATCH
                # ============================================
                if batch_idx % 10 == 0:
                    print(f"\n📊 Batch {batch_idx + 1} / {len(train_loader)}:")
                    
                    # Solo loss attive nel curriculum
                    if curriculum_weights["lambda_adv_disc"] > 0 and epoch_losses['disc']:
                        recent_disc = epoch_losses['disc'][-config["discriminator_steps"]:]
                        avg_disc = np.mean(recent_disc)
                        print(f"   Discriminator Loss: {avg_disc:.6f}")
                    
                    if epoch_losses['total_G']:
                        recent_gen = epoch_losses['total_G'][-config["generator_steps"]:]
                        avg_total_G = np.mean(recent_gen)
                        print(f"   Generator Total Loss: {avg_total_G:.6f}")
                    
                    # Loss individuali attive
                    for loss_name, weight_key in [
                        ('recon', 'lambda_recon'),
                        ('disent', 'lambda_disent'),
                        ('cont', 'lambda_cont'),
                        ('margin', 'lambda_margin'),
                        ('adv_gen', 'lambda_adv_gen')
                    ]:
                        if curriculum_weights[weight_key] > 0 and epoch_losses[loss_name]:
                            recent_loss = epoch_losses[loss_name][-config["generator_steps"]:]
                            avg_loss = np.mean(recent_loss)
                            print(f"     {loss_name.replace('_', ' ').title()}: {avg_loss:.6f}")
                
                # Controllo early stopping durante il batch
                if consecutive_nan_count >= max_consecutive_nans:
                    print(f"🛑 Stopping epoch {epoch+1} due to consecutive NaN issues")
                    break
            
            # ============================================
            # FINE EPOCA: AGGIORNAMENTI E SALVATAGGIO
            # ============================================
            if consecutive_nan_count < max_consecutive_nans:
                # Aggiorna schedulers
                if epoch_losses['disc']:
                    scheduler_D.step(np.mean(epoch_losses['disc']))
                if epoch_losses['total_G']:
                    scheduler_G.step(np.mean(epoch_losses['total_G']))
                
                # Aggiungi loss dell'epoca alla storia globale
                for key in loss_history.keys():
                    if epoch_losses[key]:
                        loss_history[key].extend(epoch_losses[key])
                
                # Salva checkpoint periodicamente
                if (epoch + 1) % config["save_interval"] == 0:
                    save_checkpoint(epoch + 1, models_dict, optimizers_dict, schedulers_dict, config)
                
                # Riepilogo dell'epoca
                print(f"\n✅ Epoch {epoch+1} Summary:")
                print("-" * 40)
                
                # Mostra solo loss attive
                for key, weight_key in [
                    ('disc', 'lambda_adv_disc'),
                    ('total_G', None),
                    ('recon', 'lambda_recon'),
                    ('disent', 'lambda_disent'),
                    ('cont', 'lambda_cont'),
                    ('margin', 'lambda_margin'),
                    ('adv_gen', 'lambda_adv_gen')
                ]:
                    if (weight_key is None or curriculum_weights[weight_key] > 0) and epoch_losses[key]:
                        avg_loss = np.mean(epoch_losses[key])
                        print(f"   {key.replace('_', ' ').title()}: {avg_loss:.6f}")
                
                print("-" * 60)
            
            # Pulizia memoria periodica
            if torch.cuda.is_available() and epoch % 10 == 0:
                torch.cuda.empty_cache()
    
    except KeyboardInterrupt:
        print("\n⏹️  Training interrupted by user")
        print("\n💾 Saving final checkpoint...")
        save_checkpoint(epoch + 1, models_dict, optimizers_dict, schedulers_dict, config)
        
        display_final_loss_summary(loss_history, config)
        print(f"   Checkpoints saved in: {config['save_dir']}")
    except Exception as e:
        print(f"\n❌ Training error: {e}")
        print("\n💾 Saving final checkpoint...")
        save_checkpoint(epoch + 1, models_dict, optimizers_dict, schedulers_dict, config)
        
        display_final_loss_summary(loss_history, config)
        print(f"   Checkpoints saved in: {config['save_dir']}")
        import traceback
        traceback.print_exc()
    finally:
        # Salvataggio finale
        print("\n💾 Saving final checkpoint...")
        save_checkpoint(epoch + 1, models_dict, optimizers_dict, schedulers_dict, config)
        
        display_final_loss_summary(loss_history, config)
        print(f"   Checkpoints saved in: {config['save_dir']}")
        print("🎉 Curriculum learning training completed!")

        for loss_name in ['disc', 'total_G', 'recon', 'adv_gen', 'disent', 'cont', 'margin']:
            if loss_history[loss_name]:
                values = loss_history[loss_name]
                final_avg = np.mean(values[-100:]) if len(values) >= 100 else np.mean(values)
                overall_avg = np.mean(values)
                min_val = np.min(values)
                max_val = np.max(values)
                
                print(f"   {loss_name.replace('_', ' ').title()}:")
                print(f"     Final Average (last 100): {final_avg:.6f}")
                print(f"     Overall Average: {overall_avg:.6f}")
                print(f"     Min: {min_val:.6f}, Max: {max_val:.6f}")
                print(f"     Total iterations: {len(values)}")

        return loss_history

# ================================================================
# AVVIO DEL TRAINING CON CURRICULUM LEARNING
# ================================================================

# Sostituisci il loop di training originale con questo:
loss_history = train_with_curriculum()

## Small test

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import os

def visualize_reconstruction(models_dict, train_loader, device, num_samples=2, save_path=None):
    """
    Visualizza e confronta originale vs ricostruzione per alcuni campioni casuali.
    
    Args:
        models_dict: dizionario con i modelli {'style_encoder': ..., 'content_encoder': ..., 'decoder': ...}
        train_loader: DataLoader per il training
        device: dispositivo (cuda/cpu)
        num_samples: numero di campioni da visualizzare
        save_path: path dove salvare le immagini (opzionale)
    """
    
    # Estrai i modelli
    style_encoder = models_dict['style_encoder']
    content_encoder = models_dict['content_encoder']
    decoder = models_dict['decoder']
    
    # Metti i modelli in modalità eval
    style_encoder.eval()
    content_encoder.eval()
    decoder.eval()
    
    print(f"🎯 Visualizing {num_samples} reconstruction examples...")
    
    with torch.no_grad():
        # Prendi un batch casuale
        batch_iterator = iter(train_loader)
        x, labels = next(batch_iterator)
        
        # Trasferisci su device
        x = x.to(device)
        labels = labels.to(device)
        
        print(f"📊 Batch shape: {x.shape}")
        print(f"📋 Labels: {labels}")
        
        # Limita al numero di campioni richiesto
        x = x[:num_samples]
        labels = labels[:num_samples]
        
        # Forward pass attraverso i modelli
        print("🔄 Forward pass attraverso i modelli...")
        
        # Encoding
        style_emb, class_emb = style_encoder(x, labels)
        content_emb = content_encoder(x)
        
        print(f"   Style embedding shape: {style_emb.shape}")
        print(f"   Class embedding shape: {class_emb.shape}")
        print(f"   Content embedding shape: {content_emb.shape}")
        
        # Estrai la parte STFT (prime 513 frequenze)
        stft_original = x[:, :, :, :, :513]  # (batch, sections, channels, time, freq)
        
        # Ricostruzione
        print("🔧 Generating reconstruction...")
        # duplicate class embedding from (2,d) to (B,d)
        B = content_emb.size(0)
        class_emb = class_emb.repeat(B, 1)
        reconstructed = decoder(content_emb, class_emb, y=stft_original)
        
        print(f"   Original STFT shape: {stft_original.shape}")
        print(f"   Reconstructed shape: {reconstructed.shape}")
        
        # Visualizza ogni campione
        for i in range(num_samples):
            print(f"\n📈 Visualizing sample {i+1}/{num_samples} (Label: {labels[i].item()})")
            
            # Prendi una sezione casuale (es. la prima)
            section_idx = 0
            
            original_section = stft_original[i, section_idx]  # (channels, time, freq)
            reconstructed_section = reconstructed[i, section_idx]  # (channels, time, freq)
            
            print(f"   Section {section_idx} - Original shape: {original_section.shape}")
            print(f"   Section {section_idx} - Reconstructed shape: {reconstructed_section.shape}")
            
            # Visualizza originale
            print(f"   📊 Original STFT - Sample {i+1}")
            plot_stft(original_section, log_scale=True)
            
            # Visualizza ricostruzione
            print(f"   🔧 Reconstructed STFT - Sample {i+1}")
            plot_stft(reconstructed_section, log_scale=True)
            
            # Calcola e visualizza la differenza
            print(f"   📉 Reconstruction Error - Sample {i+1}")
            plot_reconstruction_error(original_section, reconstructed_section)
            
            # Salva se richiesto
            if save_path:
                save_comparison_plots(original_section, reconstructed_section, 
                                    labels[i].item(), i, save_path)


def plot_reconstruction_error(original, reconstructed, sr=22050, hop_length=256):
    """
    Visualizza l'errore di ricostruzione.
    
    Args:
        original: tensor originale (channels, time, freq)
        reconstructed: tensor ricostruito (channels, time, freq)
        sr: sample rate
        hop_length: hop length per STFT
    """
    
    # Calcola magnitude per entrambi
    orig_real, orig_imag = original[0], original[1]
    recon_real, recon_imag = reconstructed[0], reconstructed[1]
    
    orig_magnitude = torch.hypot(orig_real, orig_imag)
    recon_magnitude = torch.hypot(recon_real, recon_imag)
    
    # Errore assoluto
    error = torch.abs(orig_magnitude - recon_magnitude)
    
    # Errore relativo (in dB)
    error_db = 20 * torch.log10(error + 1e-8)
    
    # Visualizza
    plt.figure(figsize=(12, 8))
    
    # Subplot 1: Errore assoluto
    plt.subplot(2, 2, 1)
    plt.imshow(
        error.T.cpu().numpy(),
        origin='lower',
        aspect='auto',
        extent=[0, original.shape[1] * hop_length / sr, 0, sr/2]
    )
    plt.colorbar(label='Absolute Error')
    plt.xlabel('Time (s)')
    plt.ylabel('Frequency (Hz)')
    plt.title('Absolute Reconstruction Error')
    
    # Subplot 2: Errore in dB
    plt.subplot(2, 2, 2)
    plt.imshow(
        error_db.T.cpu().numpy(),
        origin='lower',
        aspect='auto',
        extent=[0, original.shape[1] * hop_length / sr, 0, sr/2]
    )
    plt.colorbar(label='Error (dB)')
    plt.xlabel('Time (s)')
    plt.ylabel('Frequency (Hz)')
    plt.title('Reconstruction Error (dB)')
    
    # Subplot 3: Istogramma degli errori
    plt.subplot(2, 2, 3)
    error_flat = error.flatten().cpu().numpy()
    plt.hist(error_flat, bins=50, alpha=0.7, color='red')
    plt.xlabel('Absolute Error')
    plt.ylabel('Frequency')
    plt.title('Error Distribution')
    plt.yscale('log')
    
    # Subplot 4: Statistiche
    plt.subplot(2, 2, 4)
    plt.axis('off')
    
    # Calcola statistiche
    mae = torch.mean(error).item()
    mse = torch.mean(error**2).item()
    rmse = torch.sqrt(torch.mean(error**2)).item()
    max_error = torch.max(error).item()
    
    # Calcola SNR
    signal_power = torch.mean(orig_magnitude**2)
    noise_power = torch.mean(error**2)
    snr_db = 10 * torch.log10(signal_power / noise_power).item()
    
    stats_text = f"""
    Reconstruction Statistics:
    
    MAE: {mae:.6f}
    MSE: {mse:.6f}
    RMSE: {rmse:.6f}
    Max Error: {max_error:.6f}
    
    SNR: {snr_db:.2f} dB
    
    Original Range: [{torch.min(orig_magnitude):.4f}, {torch.max(orig_magnitude):.4f}]
    Reconstructed Range: [{torch.min(recon_magnitude):.4f}, {torch.max(recon_magnitude):.4f}]
    """
    
    plt.text(0.1, 0.9, stats_text, transform=plt.gca().transAxes, 
             fontsize=10, verticalalignment='top', fontfamily='monospace')
    
    plt.tight_layout()
    plt.show()


def save_comparison_plots(original, reconstructed, label, sample_idx, save_path):
    """
    Salva i plot di confronto su file.
    
    Args:
        original: tensor originale
        reconstructed: tensor ricostruito
        label: etichetta del campione
        sample_idx: indice del campione
        save_path: directory dove salvare
    """
    
    os.makedirs(save_path, exist_ok=True)
    
    # Calcola magnitudes
    orig_real, orig_imag = original[0], original[1]
    recon_real, recon_imag = reconstructed[0], reconstructed[1]
    
    orig_magnitude = torch.hypot(orig_real, orig_imag)
    recon_magnitude = torch.hypot(recon_real, recon_imag)
    
    # Converti in dB
    orig_db = 20 * torch.log10(orig_magnitude + 1e-8)
    recon_db = 20 * torch.log10(recon_magnitude + 1e-8)
    
    # Crea figura comparativa
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    
    # Originale
    im1 = axes[0].imshow(orig_db.T.cpu().numpy(), origin='lower', aspect='auto')
    axes[0].set_title(f'Original (Label: {label})')
    axes[0].set_xlabel('Time')
    axes[0].set_ylabel('Frequency')
    plt.colorbar(im1, ax=axes[0], label='Magnitude (dB)')
    
    # Ricostruzione
    im2 = axes[1].imshow(recon_db.T.cpu().numpy(), origin='lower', aspect='auto')
    axes[1].set_title(f'Reconstructed (Label: {label})')
    axes[1].set_xlabel('Time')
    axes[1].set_ylabel('Frequency')
    plt.colorbar(im2, ax=axes[1], label='Magnitude (dB)')
    
    # Differenza
    diff = torch.abs(orig_db - recon_db)
    im3 = axes[2].imshow(diff.T.cpu().numpy(), origin='lower', aspect='auto')
    axes[2].set_title(f'Absolute Difference (Label: {label})')
    axes[2].set_xlabel('Time')
    axes[2].set_ylabel('Frequency')
    plt.colorbar(im3, ax=axes[2], label='|Difference| (dB)')
    
    plt.tight_layout()
    
    # Salva
    filename = f'reconstruction_comparison_sample_{sample_idx}_label_{label}.png'
    filepath = os.path.join(save_path, filename)
    plt.savefig(filepath, dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"   💾 Saved comparison plot: {filepath}")


def audio_reconstruction_test(models_dict, train_loader, device, num_samples=1):
    """
    Test completo che include anche la ricostruzione audio.
    
    Args:
        models_dict: dizionario con i modelli
        train_loader: DataLoader
        device: dispositivo
        num_samples: numero di campioni da testare
    """
    
    print("🎵 Testing audio reconstruction pipeline...")
    
    # Visualizza ricostruzioni spettrali
    visualize_reconstruction(models_dict, train_loader, device, num_samples)
    
    # Test ricostruzione audio (se hai le funzioni inverse)
    print("\n🔊 Testing audio reconstruction...")
    
    style_encoder = models_dict['style_encoder']
    content_encoder = models_dict['content_encoder']
    decoder = models_dict['decoder']
    
    style_encoder.eval()
    content_encoder.eval()
    decoder.eval()
    
    with torch.no_grad():
        batch_iterator = iter(train_loader)
        x, labels = next(batch_iterator)
        
        x = x.to(device)
        labels = labels.to(device)
        
        # Prendi un campione
        sample = x[0:1]  # (1, sections, channels, time, freq)
        sample_label = labels[0:1]
        
        # Forward pass
        style_emb, class_emb = style_encoder(sample, sample_label)
        content_emb = content_encoder(sample)
        
        # Ricostruzione
        stft_original = sample[:, :, :, :, :513]

        # duplicate class embedding from (2,d) to (B,d)
        B = content_emb.size(0)
        class_emb = class_emb.repeat_interleave(repeats=B//2, dim=0)
        reconstructed = decoder(content_emb, class_emb, y=stft_original)
        
        print(f"🎯 Audio reconstruction test completed!")
        print(f"   Original shape: {stft_original.shape}")
        print(f"   Reconstructed shape: {reconstructed.shape}")
        
        # Qui potresti aggiungere la ricostruzione audio vera e propria
        # usando le tue funzioni inverse_STFT e sections2spectrogram
        
        return {
            'original': stft_original,
            'reconstructed': reconstructed,
            'style_emb': style_emb,
            'content_emb': content_emb,
            'label': sample_label
        }


# Esempio di utilizzo:
"""
# Dopo aver caricato i modelli
models_dict = {
    'style_encoder': style_encoder,
    'content_encoder': content_encoder,
    'decoder': decoder
}

# Visualizza ricostruzioni
visualize_reconstruction(models_dict, train_loader, device, num_samples=3)

# Test completo con salvataggio
visualize_reconstruction(models_dict, train_loader, device, 
                        num_samples=2, save_path="reconstruction_results")

# Test audio completo
results = audio_reconstruction_test(models_dict, train_loader, device)
"""

In [None]:
models_dict = {
    'style_encoder': style_encoder,
    'content_encoder': content_encoder,
    'decoder': decoder
}

train_loader = get_dataloader(
        piano_dir=config["piano_dir"],
        violin_dir=config["violin_dir"],
        batch_size=config["batch_size"],
        shuffle=True,
        stats_path=config["stats_path"]
    )

# Visualizza 3 campioni casuali
visualize_reconstruction(models_dict, train_loader, device, num_samples=6)
