In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import os
import matplotlib.pyplot as plt
import warnings

from style_encoder import StyleEncoder
from content_encoder import ContentEncoder
from discriminator import Discriminator
from new_decoder import Decoder, compute_comprehensive_loss
from losses import (infoNCE_loss, margin_loss, adversarial_loss, 
                   disentanglement_loss)
from Dataloader import get_dataloader

# ================================================================
# DEVICE AND MEMORY CONFIGURATION
# ================================================================
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Report which GPU was picked up (only meaningful on the CUDA path).
if device.type == "cuda":
    print(f"GPU: {torch.cuda.get_device_name()}")


Using device: cuda
GPU: NVIDIA GeForce RTX 3080


In [17]:

# ================================================================
# TRAINING CONFIGURATION
# ================================================================
# Single source of truth for every tunable value used by the cells below.
# The whole dict is printed at the start of training and stored inside each
# checkpoint by save_checkpoint.
config = {
    # Architecture - originally requested parameters
    "style_dim": 256,
    "content_dim": 256,
    "transformer_heads": 4,
    "transformer_layers": 4,
    "cnn_channels": [16, 32, 64, 128, 256],
    
    # Training - conservative parameters for stability
    "epochs": 100,
    "batch_size": 8,          # Reduced to avoid out-of-memory errors
    "lr": 3e-5,              # Conservative learning rate
    "beta1": 0.5,
    "beta2": 0.999,
    "weight_decay": 1e-4,    # Weight decay to prevent overfitting
    
    # Loss weights - balanced for stability
    "lambda_adv_disc": 0.8,
    "lambda_adv_gen": 0.1,
    "lambda_disent": 0.3,
    "lambda_cont": 0.2,
    "lambda_margin": 0.1,
    "lambda_recon": 2.0,
    
    # Stability controls
    "grad_clip_value": 0.5,      # Conservative gradient-norm clipping threshold
    "warmup_epochs": 5,          # Number of warmup epochs
    "nan_threshold": 5,          # Max consecutive NaN steps before stopping
    
    # Data paths
    "piano_dir": "dataset/train/piano",
    "violin_dir": "dataset/train/violin",
    "stats_path": "stats_stft_cqt.npz",
    
    # Checkpointing
    "save_dir": "checkpoints",
    "save_interval": 10,
    
    # Training strategy: optimisation steps per batch for each side
    "discriminator_steps": 2,
    "generator_steps": 1,
}

# Create the checkpoint directory up front
os.makedirs(config["save_dir"], exist_ok=True)


In [18]:

# ================================================================
# INIZIALIZZAZIONE PESI CONSERVATIVA
# ================================================================
def init_weights_conservative(m):
    """
    Conservatively initialise one module's parameters to reduce the risk of NaNs.

    Designed to be passed to ``model.apply(...)``, which calls it once per
    sub-module:

    - ``nn.Conv2d`` / ``nn.Linear``: Xavier-uniform weights with a reduced
      gain of 0.2, bias set to zero.
    - ``nn.BatchNorm2d`` / ``nn.LayerNorm``: weight set to one, bias set to zero.
    - Any other module type is left untouched.

    Parameters that are ``None`` (e.g. ``bias=False`` layers, or normalisation
    layers built with ``affine=False`` / ``elementwise_affine=False``) are
    skipped instead of raising.

    Args:
        m: The module to initialise; modified in place.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        # Xavier uniform with a reduced gain keeps initial activations small
        nn.init.xavier_uniform_(m.weight, gain=0.2)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm2d, nn.LayerNorm)):
        # Non-affine normalisation layers carry no learnable scale/shift
        if m.weight is not None:
            nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# ================================================================
# FUNZIONI DI UTILITÀ
# ================================================================
def check_for_nan(*tensors, names=None):
    """
    Report whether any of the given tensors contains NaN or Inf values.

    Tensors are checked in order; the first offending tensor is reported on
    stdout (NaN is tested before Inf) and the function returns immediately.

    Args:
        *tensors: Tensors to check.
        names: Optional display names used in the diagnostic message. When
            omitted, or when fewer names than tensors are given, the missing
            names default to ``tensor_<index>`` so that every tensor is still
            checked.

    Returns:
        bool: True if a NaN or Inf was found, False otherwise.
    """
    labels = [] if names is None else list(names)
    # Pad with positional defaults: zip() would otherwise silently skip the
    # tensors that have no matching name.
    labels.extend(f"tensor_{i}" for i in range(len(labels), len(tensors)))

    for tensor, name in zip(tensors, labels):
        if torch.isnan(tensor).any():
            print(f"🚨 NaN detected in {name}")
            return True
        if torch.isinf(tensor).any():
            print(f"🚨 Inf detected in {name}")
            return True
    return False

def set_requires_grad(models, requires_grad):
    """
    Enable or disable gradient tracking for every parameter of the given models.

    Args:
        models: A single model, or a list/tuple of models. Anything that is
            not a list or tuple is treated as one model.
        requires_grad: True to enable gradients, False to freeze the parameters.
    """
    # Tuples are accepted as well as lists; a lone model is wrapped so the
    # loop below has a single code path.
    if not isinstance(models, (list, tuple)):
        models = [models]
    for model in models:
        for param in model.parameters():
            param.requires_grad = requires_grad

def get_learning_rate_multiplier(epoch, warmup_epochs):
    """
    Linear warmup factor for the given epoch.

    Ramps from ``1 / warmup_epochs`` at epoch 0 up to 1.0 at the last warmup
    epoch, and stays at 1.0 from then on.

    Args:
        epoch: Current (0-based) epoch.
        warmup_epochs: Number of warmup epochs.

    Returns:
        float: The warmup multiplier.
    """
    still_warming_up = epoch < warmup_epochs
    return (epoch + 1) / warmup_epochs if still_warming_up else 1.0

def save_checkpoint(epoch, models_dict, optimizers_dict, schedulers_dict, config):
    """
    Save a complete training checkpoint.

    The file is written to ``<config["save_dir"]>/checkpoint_epoch_<epoch>.pth``
    and contains a flat dict with the keys ``'epoch'`` and ``'config'`` plus one
    ``state_dict()`` per entry of the three mappings, stored under that entry's
    own key. Keys are therefore expected to be unique across the three
    mappings; on a clash the later mapping overwrites the earlier one.

    Args:
        epoch: Epoch number recorded in the checkpoint and used in the filename.
        models_dict: Mapping name -> model.
        optimizers_dict: Mapping name -> optimizer.
        schedulers_dict: Mapping name -> scheduler.
        config: Training configuration; must provide ``"save_dir"``.
    """
    checkpoint = {
        'epoch': epoch,
        'config': config
    }

    # Models, optimizers and schedulers all expose state_dict(), so they are
    # serialised the same way (in this order).
    for group in (models_dict, optimizers_dict, schedulers_dict):
        for name, stateful in group.items():
            checkpoint[name] = stateful.state_dict()

    # Do not rely on the caller having created the directory beforehand.
    os.makedirs(config["save_dir"], exist_ok=True)
    checkpoint_path = os.path.join(config["save_dir"], f"checkpoint_epoch_{epoch}.pth")
    torch.save(checkpoint, checkpoint_path)
    print(f"💾 Checkpoint saved: {checkpoint_path}")

In [19]:

# ================================================================
# MODEL INITIALISATION
# ================================================================
print("🔧 Initializing models...")

# Style branch: CNN output and transformer width both follow style_dim.
style_encoder = StyleEncoder(
    cnn_out_dim=config["style_dim"],
    transformer_dim=config["style_dim"],
    num_heads=config["transformer_heads"],
    num_layers=config["transformer_layers"],
)
style_encoder = style_encoder.to(device)

# Content branch: same layout, sized by content_dim, with an explicit CNN
# channel schedule.
content_encoder = ContentEncoder(
    cnn_out_dim=config["content_dim"],
    transformer_dim=config["content_dim"],
    num_heads=config["transformer_heads"],
    num_layers=config["transformer_layers"],
    channels_list=config["cnn_channels"],
)
content_encoder = content_encoder.to(device)

# The discriminator takes inputs of size style_dim.
discriminator = Discriminator(input_dim=config["style_dim"], hidden_dim=128)
discriminator = discriminator.to(device)

decoder = Decoder(
    d_model=config["style_dim"],
    nhead=config["transformer_heads"],
    num_layers=config["transformer_layers"],
)
decoder = decoder.to(device)

# Apply the conservative initialisation to every sub-module of each model
models = [style_encoder, content_encoder, discriminator, decoder]
model_names = ["style_encoder", "content_encoder", "discriminator", "decoder"]

for name, model in zip(model_names, models):
    model.apply(init_weights_conservative)
    print(f"✅ {name} initialized")


# Alternative (disabled): the regular initialisation helper from style_encoder
# from style_encoder import initialize_weights
# initialize_weights(style_encoder)
# initialize_weights(content_encoder)
# initialize_weights(decoder)
# initialize_weights(discriminator)

🔧 Initializing models...
✅ style_encoder initialized
✅ content_encoder initialized
✅ discriminator initialized
✅ decoder initialized


In [20]:
# ================================================================
# OPTIMIZERS AND SCHEDULERS
# ================================================================
print("🔧 Setting up optimizers and schedulers...")

# Both sides use AdamW with identical hyper-parameters taken from the config.
_adamw_kwargs = dict(
    lr=config["lr"],
    betas=(config["beta1"], config["beta2"]),
    weight_decay=config["weight_decay"],
)

# The "generator" side is the two encoders plus the decoder, optimised jointly.
optimizer_G = optim.AdamW(
    [
        *style_encoder.parameters(),
        *content_encoder.parameters(),
        *decoder.parameters(),
    ],
    **_adamw_kwargs,
)
optimizer_D = optim.AdamW(discriminator.parameters(), **_adamw_kwargs)

# Plateau schedulers: multiply the LR by 0.7 after 5 epochs without improvement
# of the metric passed to .step().
scheduler_G = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_G, mode='min', factor=0.7, patience=5
)
scheduler_D = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_D, mode='min', factor=0.7, patience=5
)

# ================================================================
# DATALOADER
# ================================================================
print("🔧 Setting up dataloader...")

try:
    train_loader = get_dataloader(
        piano_dir=config["piano_dir"],
        violin_dir=config["violin_dir"],
        batch_size=config["batch_size"],
        shuffle=True,
        stats_path=config["stats_path"],
    )
    print(f"✅ DataLoader created successfully with batch_size={config['batch_size']}")
except Exception as e:
    # Surface a readable message, then let the original error propagate:
    # nothing below can run without data.
    print(f"❌ Error creating DataLoader: {e}")
    raise


🔧 Setting up optimizers and schedulers...
🔧 Setting up dataloader...
✅ DataLoader created successfully with batch_size=8


In [21]:

# ================================================================
# FUNZIONI DI TRAINING
# ================================================================
def discriminator_training_step(x, labels, epoch):
    """
    Run one optimisation step of the discriminator on a single batch.

    The encoders and the decoder are frozen, the embeddings are computed
    without autograd, and only the discriminator is updated via ``optimizer_D``.
    The step is abandoned (no optimizer update) as soon as a NaN/Inf is found
    in the embeddings, the loss or the gradients.

    Args:
        x: Input batch.
        labels: Labels of the batch.
        epoch: Current epoch. Not used by this step; kept so the signature
            matches ``generator_training_step``.

    Returns:
        float: The discriminator loss, or NaN if the step was abandoned or an
        exception was raised (the exception message is printed).
    """
    # Only the discriminator receives gradients during this step
    set_requires_grad(discriminator, True)
    set_requires_grad([style_encoder, content_encoder, decoder], False)

    try:
        # Encoder forward pass without building an autograd graph
        with torch.no_grad():
            style_emb, class_emb = style_encoder(x, labels)
            content_emb = content_encoder(x)

        if check_for_nan(style_emb, class_emb, content_emb,
                         names=["style_emb", "class_emb", "content_emb"]):
            return float('nan')

        # Discriminator-side adversarial loss (first element of the returned pair)
        disc_loss, _ = adversarial_loss(
            style_emb.detach(),
            class_emb.detach(),
            content_emb.detach(),
            discriminator,
            labels,
            compute_for_discriminator=True,
            lambda_content=config["lambda_adv_disc"]
        )

        if check_for_nan(disc_loss, names=["disc_loss"]):
            return float('nan')

        optimizer_D.zero_grad()
        disc_loss.backward()

        # Clip the global gradient norm; the pre-clip total norm is returned
        total_norm = torch.nn.utils.clip_grad_norm_(
            discriminator.parameters(), config["grad_clip_value"]
        )

        # A finite total norm implies every gradient entry is finite, so the
        # per-parameter scan (which forces a host sync per tensor) only runs
        # when something is actually wrong, to name the offending parameter.
        if not bool(torch.isfinite(torch.as_tensor(total_norm)).all()):
            for name, param in discriminator.named_parameters():
                if param.grad is not None:
                    if check_for_nan(param.grad, names=[f"discriminator.{name}.grad"]):
                        return float('nan')

        optimizer_D.step()
        return disc_loss.item()

    except Exception as e:
        # Deliberate best effort: a failing batch is reported and counted as a
        # NaN step by the caller instead of aborting the whole run.
        print(f"❌ Error in discriminator training step: {e}")
        return float('nan')

# Keys of the dict returned by generator_training_step, in reporting order.
_GENERATOR_LOSS_KEYS = ('total_G', 'adv_gen', 'disent', 'cont', 'margin', 'recon')


def _nan_generator_losses():
    """Return the generator loss dict with every entry set to NaN (marks an abandoned step)."""
    return {key: float('nan') for key in _GENERATOR_LOSS_KEYS}


def generator_training_step(x, labels, epoch):
    """
    Run one optimisation step of the generator side (both encoders + decoder).

    The discriminator is frozen. The total loss is the weighted sum of the
    adversarial, disentanglement, contrastive and margin terms (each scaled by
    the warmup factor for the current epoch) plus the reconstruction term
    (never scaled by warmup). The step is abandoned (no optimizer update) as
    soon as a NaN/Inf is found in the embeddings, any loss or any gradient.

    Args:
        x: Input batch. Indexed as a 5-D tensor below.
        labels: Labels of the batch.
        epoch: Current (0-based) epoch, used for the loss warmup.

    Returns:
        dict: Loss values under the keys 'total_G', 'adv_gen', 'disent',
        'cont', 'margin' and 'recon'. Every value is NaN if the step was
        abandoned or an exception was raised (the exception message is printed).
    """
    # Only the generator side receives gradients during this step
    set_requires_grad([style_encoder, content_encoder, decoder], True)
    set_requires_grad(discriminator, False)

    try:
        # Forward pass
        style_emb, class_emb = style_encoder(x, labels)
        content_emb = content_encoder(x)

        if check_for_nan(style_emb, class_emb, content_emb,
                         names=["style_emb", "class_emb", "content_emb"]):
            return _nan_generator_losses()

        # Generator-side adversarial loss (second element of the returned pair).
        # NOTE(review): lambda_adv_gen is passed here as lambda_content AND is
        # multiplied in again in the total loss below; confirm against
        # adversarial_loss that the weight is not applied twice.
        _, adv_gen_loss = adversarial_loss(
            style_emb,
            class_emb,
            content_emb,
            discriminator,
            labels,
            compute_for_discriminator=False,
            lambda_content=config["lambda_adv_gen"]
        )

        # Disentanglement between the style embedding and the content
        # embedding averaged over dim 1
        disent_loss = disentanglement_loss(
            style_emb,
            content_emb.mean(dim=1),
            use_hsic=True
        )

        # Contrastive loss on the style embedding
        cont_loss = infoNCE_loss(style_emb, labels)

        # Margin loss on the class embedding
        margin_loss_val = margin_loss(class_emb)

        # Reconstruction loss on the first 513 entries of the last axis
        # (presumably the STFT bins of a stacked STFT+CQT input — TODO confirm)
        stft_part = x[:, :, :, :, :513]
        recon_x = decoder(content_emb, style_emb, y=stft_part)
        recon_losses = compute_comprehensive_loss(recon_x, stft_part)
        recon_loss = recon_losses['total_loss']

        losses = [adv_gen_loss, disent_loss, cont_loss, margin_loss_val, recon_loss]
        loss_names = ['adv_gen_loss', 'disent_loss', 'cont_loss', 'margin_loss', 'recon_loss']

        if check_for_nan(*losses, names=loss_names):
            return _nan_generator_losses()

        # Auxiliary losses are ramped in linearly during warmup; the helper
        # already returns 1.0 once warmup is over.
        warmup_factor = get_learning_rate_multiplier(epoch, config["warmup_epochs"])

        total_G_loss = (
            config["lambda_adv_gen"] * adv_gen_loss * warmup_factor +
            config["lambda_disent"] * disent_loss * warmup_factor +
            config["lambda_cont"] * cont_loss * warmup_factor +
            config["lambda_margin"] * margin_loss_val * warmup_factor +
            config["lambda_recon"] * recon_loss
        )

        if check_for_nan(total_G_loss, names=["total_G_loss"]):
            return _nan_generator_losses()

        # Backpropagation
        optimizer_G.zero_grad()
        total_G_loss.backward()

        # Clip the global gradient norm; the pre-clip total norm is returned
        generator_params = (
            list(style_encoder.parameters()) +
            list(content_encoder.parameters()) +
            list(decoder.parameters())
        )
        total_norm = torch.nn.utils.clip_grad_norm_(generator_params, config["grad_clip_value"])

        # A finite total norm implies every gradient entry is finite, so the
        # per-parameter scan (which forces a host sync per tensor) only runs
        # when something is actually wrong, to name the offending parameter.
        if not bool(torch.isfinite(torch.as_tensor(total_norm)).all()):
            for model, model_name in [(style_encoder, "style_encoder"),
                                      (content_encoder, "content_encoder"),
                                      (decoder, "decoder")]:
                for name, param in model.named_parameters():
                    if param.grad is not None:
                        if check_for_nan(param.grad, names=[f"{model_name}.{name}.grad"]):
                            return _nan_generator_losses()

        optimizer_G.step()

        return {
            'total_G': total_G_loss.item(),
            'adv_gen': adv_gen_loss.item(),
            'disent': disent_loss.item(),
            'cont': cont_loss.item(),
            'margin': margin_loss_val.item(),
            'recon': recon_loss.item()
        }

    except Exception as e:
        # Deliberate best effort: a failing batch is reported and counted as a
        # NaN step by the caller instead of aborting the whole run.
        print(f"❌ Error in generator training step: {e}")
        return _nan_generator_losses()

In [None]:
# ================================================================
# MAIN TRAINING LOOP
# ================================================================
print("🚀 Starting training with enhanced NaN protection...")
print(f"📊 Configuration: {config}")

# Per-step loss history, one list per tracked loss
loss_history = {
    key: []
    for key in ('total_G', 'disc', 'disent', 'cont', 'margin', 'recon', 'adv_gen')
}

# Counter of consecutive NaN steps and the limit that triggers early stopping
consecutive_nan_count = 0
max_consecutive_nans = config["nan_threshold"]

# Name -> object mappings handed to save_checkpoint
models_dict = dict(
    style_encoder=style_encoder,
    content_encoder=content_encoder,
    discriminator=discriminator,
    decoder=decoder,
)
optimizers_dict = dict(optimizer_G=optimizer_G, optimizer_D=optimizer_D)
schedulers_dict = dict(scheduler_G=scheduler_G, scheduler_D=scheduler_D)

# `epoch` is pre-bound so the `finally` clause cannot raise NameError when the
# loop body never runs (e.g. config["epochs"] == 0). `training_status` records
# how the run ended so the final report does not claim success after an
# interruption, an error or a NaN-triggered stop.
epoch = -1
training_status = "completed"

try:
    for epoch in tqdm(range(config["epochs"]), desc="Training Progress"):
        # Do not start a new epoch after too many consecutive NaN steps
        if consecutive_nan_count >= max_consecutive_nans:
            training_status = "stopped after repeated NaN losses"
            print(f"🛑 Early stopping: {consecutive_nan_count} consecutive NaN occurrences")
            break

        # Put every model in training mode
        for model in models:
            model.train()

        # Loss values collected during the current epoch
        epoch_losses = {key: [] for key in loss_history.keys()}

        print(f"\n🔄 Epoch {epoch+1}/{config['epochs']}")
        print("-" * 60)

        for batch_idx, (x, labels) in enumerate(train_loader):
            # Move the batch to the training device
            x = x.to(device)
            labels = labels.to(device)

            # Skip batches whose input already contains NaN/Inf
            if check_for_nan(x, labels, names=["input_x", "input_labels"]):
                print(f"⚠️  Skipping batch {batch_idx} due to NaN in input data")
                continue

            # ============================================
            # DISCRIMINATOR TRAINING
            # ============================================
            for _ in range(config["discriminator_steps"]):
                disc_loss = discriminator_training_step(x, labels, epoch)

                if not np.isnan(disc_loss):
                    epoch_losses['disc'].append(disc_loss)
                    consecutive_nan_count = 0  # Reset the counter on success
                else:
                    consecutive_nan_count += 1
                    print(f"⚠️  NaN in discriminator loss (consecutive: {consecutive_nan_count})")

            # ============================================
            # GENERATOR TRAINING
            # ============================================
            for _ in range(config["generator_steps"]):
                gen_losses = generator_training_step(x, labels, epoch)

                # A failed generator step returns NaN for every entry
                nan_in_gen = any(np.isnan(v) for v in gen_losses.values())

                if not nan_in_gen:
                    # Record the valid losses
                    for key, value in gen_losses.items():
                        if key in epoch_losses:
                            epoch_losses[key].append(value)
                    consecutive_nan_count = 0  # Reset the counter on success
                else:
                    consecutive_nan_count += 1
                    print(f"⚠️  NaN in generator losses (consecutive: {consecutive_nan_count})")

            # ============================================
            # PER-BATCH LOGGING
            # ============================================
            # Print detailed losses every 25 batches
            if batch_idx % 25 == 0:
                print(f"\n📊 Batch {batch_idx + 1}:")

                # Discriminator loss (mean over the most recent steps)
                if epoch_losses['disc']:
                    recent_disc = epoch_losses['disc'][-config["discriminator_steps"]:]
                    avg_disc = np.mean(recent_disc)
                    print(f"   Discriminator Loss: {avg_disc:.6f}")

                # Generator total loss
                if epoch_losses['total_G']:
                    recent_gen = epoch_losses['total_G'][-config["generator_steps"]:]
                    avg_total_G = np.mean(recent_gen)
                    print(f"   Generator Total Loss: {avg_total_G:.6f}")

                # Individual generator losses
                gen_loss_names = ['adv_gen', 'disent', 'cont', 'margin', 'recon']
                for loss_name in gen_loss_names:
                    if epoch_losses[loss_name]:
                        recent_loss = epoch_losses[loss_name][-config["generator_steps"]:]
                        avg_loss = np.mean(recent_loss)
                        print(f"     {loss_name.replace('_', ' ').title()}: {avg_loss:.6f}")

                # Additional information
                print(f"   Consecutive NaN count: {consecutive_nan_count}")
                if torch.cuda.is_available():
                    print(f"   GPU Memory: {torch.cuda.memory_allocated()/1024**2:.1f} MB")

            # Abort the epoch as soon as the NaN limit is reached
            if consecutive_nan_count >= max_consecutive_nans:
                training_status = "stopped after repeated NaN losses"
                print(f"🛑 Stopping epoch {epoch+1} due to consecutive NaN issues")
                break

        # ============================================
        # END OF EPOCH: UPDATES AND CHECKPOINTING
        # ============================================
        if consecutive_nan_count < max_consecutive_nans:
            # Step the plateau schedulers on the epoch-mean losses
            if epoch_losses['disc']:
                scheduler_D.step(np.mean(epoch_losses['disc']))
            if epoch_losses['total_G']:
                scheduler_G.step(np.mean(epoch_losses['total_G']))

            # Append this epoch's losses to the global history
            for key in loss_history.keys():
                if epoch_losses[key]:
                    loss_history[key].extend(epoch_losses[key])

            # Periodic checkpoint
            if (epoch + 1) % config["save_interval"] == 0:
                save_checkpoint(epoch + 1, models_dict, optimizers_dict, schedulers_dict, config)

            # Epoch summary
            print(f"\n✅ Epoch {epoch+1} Summary:")
            print("-" * 40)

            for key in ['disc', 'total_G', 'adv_gen', 'disent', 'cont', 'margin', 'recon']:
                if epoch_losses[key]:
                    avg_loss = np.mean(epoch_losses[key])
                    print(f"   {key.replace('_', ' ').title()}: {avg_loss:.6f}")

            # Current learning rates
            current_lr_G = optimizer_G.param_groups[0]['lr']
            current_lr_D = optimizer_D.param_groups[0]['lr']
            print(f"   LR Generator: {current_lr_G:.2e}")
            print(f"   LR Discriminator: {current_lr_D:.2e}")

            print("-" * 60)

        # Periodic release of cached GPU memory
        if torch.cuda.is_available() and epoch % 10 == 0:
            torch.cuda.empty_cache()

except KeyboardInterrupt:
    training_status = "interrupted by user"
    print("\n⏹️  Training interrupted by user")
except Exception as e:
    training_status = "failed with an error"
    print(f"\n❌ Training error: {e}")
    import traceback
    traceback.print_exc()
finally:
    # Final checkpoint: written however the run ended, so partial progress is kept
    print("\n💾 Saving final checkpoint...")
    save_checkpoint(epoch + 1, models_dict, optimizers_dict, schedulers_dict, config)

    # Loss visualisation
    if any(loss_history[key] for key in loss_history.keys()):
        print("📊 Generating loss plots...")

        # Figure layout: one panel per plotted loss
        plt.style.use('default')
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        axes = axes.flatten()

        # (history key, panel title, colour) for each plotted loss
        loss_configs = [
            ('disc', 'Discriminator Loss', 'red'),
            ('total_G', 'Generator Total Loss', 'blue'),
            ('recon', 'Reconstruction Loss', 'green'),
            ('adv_gen', 'Adversarial Generator Loss', 'orange'),
            ('disent', 'Disentanglement Loss', 'purple'),
            ('cont', 'Contrastive Loss', 'brown')
        ]

        for i, (loss_name, title, color) in enumerate(loss_configs):
            if loss_history[loss_name]:
                values = loss_history[loss_name]

                # Raw values
                axes[i].plot(values, alpha=0.4, color=color, linewidth=0.5, label='Raw')

                # Moving average, only when there are enough points
                if len(values) > 50:
                    window = max(1, len(values) // 50)
                    smoothed = np.convolve(values, np.ones(window)/window, mode='valid')
                    # 'valid' yields len(values) - window + 1 points; place each
                    # average at the last iteration it covers so the smoothed
                    # curve lines up with the raw one instead of being shifted left.
                    smoothed_x = np.arange(window - 1, len(values))
                    axes[i].plot(smoothed_x, smoothed, color=color, linewidth=2, label='Smoothed')

                axes[i].set_title(title, fontsize=12, fontweight='bold')
                axes[i].set_xlabel('Iteration')
                axes[i].set_ylabel('Loss')
                axes[i].legend()
                axes[i].grid(True, alpha=0.3)
                axes[i].set_xlim(0, len(values))

        plt.tight_layout()
        plot_path = os.path.join(config["save_dir"], "loss_curves.png")
        plt.savefig(plot_path, dpi=300, bbox_inches='tight')
        plt.show()
        print(f"📈 Loss curves saved to: {plot_path}")

    # Final report: the wording reflects how the run actually ended
    if training_status == "completed":
        print(f"\n🎯 Training completed!")
    else:
        print(f"\n🎯 Training ended early: {training_status}")
    print(f"   Total epochs processed: {epoch + 1}")
    print(f"   Final consecutive NaN count: {consecutive_nan_count}")
    print(f"   Checkpoints saved in: {config['save_dir']}")

    if torch.cuda.is_available():
        print(f"   Final GPU memory: {torch.cuda.memory_allocated()/1024**2:.1f} MB")

    if training_status == "completed":
        print("🎉 Training session finished successfully!")
    else:
        print(f"⚠️  Training session did not finish normally ({training_status})")

🚀 Starting training with enhanced NaN protection...
📊 Configuration: {'style_dim': 256, 'content_dim': 256, 'transformer_heads': 4, 'transformer_layers': 4, 'cnn_channels': [16, 32, 64, 128, 256], 'epochs': 100, 'batch_size': 8, 'lr': 3e-05, 'beta1': 0.5, 'beta2': 0.999, 'weight_decay': 0.0001, 'lambda_adv_disc': 0.8, 'lambda_adv_gen': 0.1, 'lambda_disent': 0.3, 'lambda_cont': 0.2, 'lambda_margin': 0.1, 'lambda_recon': 2.0, 'grad_clip_value': 0.5, 'warmup_epochs': 5, 'nan_threshold': 5, 'piano_dir': 'dataset/train/piano', 'violin_dir': 'dataset/train/violin', 'stats_path': 'stats_stft_cqt.npz', 'save_dir': 'checkpoints', 'save_interval': 10, 'discriminator_steps': 2, 'generator_steps': 1}


Training Progress:   0%|          | 0/100 [00:00<?, ?it/s]


🔄 Epoch 1/100
------------------------------------------------------------

📊 Batch 1:
   Discriminator Loss: 1.594134
   Generator Total Loss: 7.922220
     Adv Gen: -0.693144
     Disent: 0.004668
     Cont: 1.930429
     Margin: 0.000000
     Recon: 3.929293
   Consecutive NaN count: 0
   GPU Memory: 392.8 MB

📊 Batch 2:
   Discriminator Loss: 1.593978
   Generator Total Loss: 7.215709
     Adv Gen: -0.693145
     Disent: 0.001488
     Cont: 1.920036
     Margin: 0.000000
     Recon: 3.576340
   Consecutive NaN count: 0
   GPU Memory: 392.4 MB

📊 Batch 3:
   Discriminator Loss: 1.593538
   Generator Total Loss: 6.371064
     Adv Gen: -0.693146
     Disent: 0.000555
     Cont: 1.915881
     Margin: 0.000000
     Recon: 3.154129
   Consecutive NaN count: 0
   GPU Memory: 392.4 MB

📊 Batch 4:
   Discriminator Loss: 1.593152
   Generator Total Loss: 7.179996
     Adv Gen: -0.693146
     Disent: 0.000518
     Cont: 1.886470
     Margin: 0.000000
     Recon: 3.559184
   Consecutive NaN c

Training Progress:   0%|          | 0/100 [00:14<?, ?it/s]


📊 Batch 19:
   Discriminator Loss: 1.557698
   Generator Total Loss: 10.662372
     Adv Gen: -0.693141
     Disent: 0.000123
     Cont: 1.171685
     Margin: 0.000000
     Recon: 5.314680
   Consecutive NaN count: 0
   GPU Memory: 392.4 MB






⏹️  Training interrupted by user

💾 Saving final checkpoint...
💾 Checkpoint saved: checkpoints\checkpoint_epoch_1.pth

🎯 Training completed!
   Total epochs processed: 1
   Final consecutive NaN count: 0
   Checkpoints saved in: checkpoints
   Final GPU memory: 392.4 MB
🎉 Training session finished successfully!
