In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import os
import matplotlib.pyplot as plt

from style_encoder import StyleEncoder
from content_encoder import ContentEncoder
from discriminator import Discriminator
from decoder import Decoder
from losses import (infoNCE_loss, margin_loss, adversarial_loss, 
                   disentanglement_loss, compute_comprehensive_loss)
from Dataloader import get_dataloader

# Device selection: GPU when CUDA is available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 1. Configuration parameters
# ==============================
config = {
    # Architecture
    "style_dim": 256,
    "content_dim": 256,
    "transformer_heads": 4,
    "transformer_layers": 4,
    "cnn_channels": [16, 32, 64, 128, 256],

    # Training
    "epochs": 100,
    "batch_size": 16,  # Must be even so each batch can be balanced piano/violin
    "lr": 2e-4,
    "beta1": 0.5,
    "beta2": 0.999,
    "seed": 42,        # Seeds torch (all devices) and numpy for reproducible runs

    # Loss weights
    "lambda_adv_disc": 1.0,     # Discriminator
    "lambda_adv_gen": 0.1,      # Generator (entropy)
    "lambda_disent": 0.5,       # Style/content disentanglement
    "lambda_cont": 0.5,         # Contrastive loss (InfoNCE)
    "lambda_margin": 0.2,       # Margin loss
    "lambda_recon": 10.0,       # Reconstruction

    # Data paths
    "piano_dir": "path/to/piano_dataset",
    "violin_dir": "path/to/violin_dataset",
    "stats_path": "stats_stft_cqt.npz",

    # Checkpointing
    "save_dir": "checkpoints",
    "save_interval": 5,
}

# Fail fast on settings the rest of the notebook relies on: the batch-size parity is
# required by the piano/violin balancing, and save_interval is used as a modulus.
if config["batch_size"] <= 0 or config["batch_size"] % 2 != 0:
    raise ValueError(f"batch_size must be a positive even number, got {config['batch_size']}")
if config["save_interval"] <= 0:
    raise ValueError(f"save_interval must be a positive integer, got {config['save_interval']}")

# Seed before any model construction, weight init or data shuffling so that
# Restart & Run All reproduces the same run.
torch.manual_seed(config["seed"])
np.random.seed(config["seed"])

# Create the checkpoint directory
os.makedirs(config["save_dir"], exist_ok=True)

In [None]:

# 2. Models
# =========
# NOTE(review): the decoder is built with d_model=style_dim, but the training loop
# also feeds it content embeddings (content_dim). Both are 256 in this config;
# confirm whether the two sizes are required to match.
style_encoder = StyleEncoder(
    cnn_out_dim=config["style_dim"],
    transformer_dim=config["style_dim"],
    num_heads=config["transformer_heads"],
    num_layers=config["transformer_layers"],
    channels_list=config["cnn_channels"]
).to(device)

content_encoder = ContentEncoder(
    cnn_out_dim=config["content_dim"],
    transformer_dim=config["content_dim"],
    num_heads=config["transformer_heads"],
    num_layers=config["transformer_layers"],
    channels_list=config["cnn_channels"]
).to(device)

# The discriminator width is kept in config (default 128, inserted if absent), so the
# config stored in every checkpoint fully describes the architecture.
discriminator = Discriminator(
    input_dim=config["style_dim"],
    hidden_dim=config.setdefault("disc_hidden_dim", 128)
).to(device)

decoder = Decoder(
    d_model=config["style_dim"],
    nhead=config["transformer_heads"],
    num_layers=config["transformer_layers"]
).to(device)

# Weight initialization
def init_weights(m):
    """Xavier-normal weights and zero biases for ``nn.Conv2d`` / ``nn.Linear`` modules.

    Meant to be passed to ``nn.Module.apply``. Modules of any other type are left
    untouched, and a layer built with ``bias=False`` only gets its weight reset.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    nn.init.xavier_normal_(m.weight)
    has_bias = m.bias is not None
    if has_bias:
        nn.init.zeros_(m.bias)

# Re-initialize every sub-model. The fixed order matters: xavier_normal_ draws from
# torch's global RNG, so the order determines which draws each model receives.
for net in (style_encoder, content_encoder, discriminator, decoder):
    net.apply(init_weights)

# 3. Optimizers
# =================
# Two independent Adam optimizers sharing lr/betas from config: one for the
# "generator" side (both encoders + decoder) and one for the discriminator.
adam_kwargs = {"lr": config["lr"], "betas": (config["beta1"], config["beta2"])}

generator_param_list = [
    p
    for net in (style_encoder, content_encoder, decoder)
    for p in net.parameters()
]
optimizer_G = optim.Adam(generator_param_list, **adam_kwargs)

optimizer_D = optim.Adam(discriminator.parameters(), **adam_kwargs)


In [None]:

# 4. Dataset and DataLoader
# ========================
# Shuffled training loader over the piano and violin directories.
# NOTE(review): stats_path ("stats_stft_cqt.npz") presumably holds feature
# normalization statistics used by the dataset; confirm in get_dataloader.
loader_kwargs = {
    "piano_dir": config["piano_dir"],
    "violin_dir": config["violin_dir"],
    "batch_size": config["batch_size"],
    "shuffle": True,
    "stats_path": config["stats_path"],
}
train_loader = get_dataloader(**loader_kwargs)

# 5. Funzioni di utilità
# =======================
def set_requires_grad(models, requires_grad):
    """Enable or disable gradients for every parameter of one or more models.

    Args:
        models: a single model (anything exposing ``.parameters()``, e.g. an
            ``nn.Module``) or any iterable of such models (list, tuple, generator...).
        requires_grad: value assigned to ``param.requires_grad`` for every parameter.
    """
    # Only wrap an actual model. The previous "not a list" test also wrapped tuples,
    # which then crashed because a tuple has no .parameters().
    if hasattr(models, "parameters"):
        models = [models]
    for model in models:
        for param in model.parameters():
            param.requires_grad = requires_grad

def save_checkpoint(epoch):
    """Save all models, both optimizers and the config to ``config["save_dir"]``.

    The checkpoint is first written to a temporary file and then moved into place
    with ``os.replace``. An interruption mid-save therefore cannot leave a truncated
    ``checkpoint_epoch_{epoch}.pth`` behind.

    Args:
        epoch: epoch number stored in the checkpoint and used in the file name.

    Returns:
        Path of the written checkpoint file.
    """
    checkpoint = {
        'epoch': epoch,
        'style_encoder': style_encoder.state_dict(),
        'content_encoder': content_encoder.state_dict(),
        'discriminator': discriminator.state_dict(),
        'decoder': decoder.state_dict(),
        'optimizer_G': optimizer_G.state_dict(),
        'optimizer_D': optimizer_D.state_dict(),
        'config': config
    }
    path = os.path.join(config["save_dir"], f"checkpoint_epoch_{epoch}.pth")
    tmp_path = path + ".tmp"
    torch.save(checkpoint, tmp_path)
    os.replace(tmp_path, path)  # atomic rename within the same directory
    return path

# 6. Training loop
# =====================
# Per-iteration history of every logged loss, one list per loss name.
# The key order fixes the legend order of the final loss plot.
LOSS_NAMES = ('total_G', 'disc', 'disent', 'cont', 'margin', 'recon', 'adv_gen')
loss_history = {name: [] for name in LOSS_NAMES}


In [None]:
# Generator-side parameters (both encoders + decoder). The set never changes during
# training, so the list used for gradient clipping is built once, outside the loop.
generator_params = [
    p
    for net in (style_encoder, content_encoder, decoder)
    for p in net.parameters()
]
# Max gradient norm for the generator step; overridable via config["grad_clip"].
grad_clip = config.get("grad_clip", 1.0)

for epoch in tqdm(range(config["epochs"]), desc="Training Progress"):
    # Training mode for every model
    style_encoder.train()
    content_encoder.train()
    discriminator.train()
    decoder.train()

    # loss_history keeps growing across epochs. Remember where this epoch starts so
    # the end-of-epoch average covers exactly this epoch's batches (no reliance on
    # len(train_loader)).
    epoch_start = len(loss_history['total_G'])

    for x, labels in train_loader:
        x = x.to(device)            # (B, S, 2, T, F)
        labels = labels.to(device)  # (B,)

        # ==================================================================
        # Phase 1: discriminator update
        # ==================================================================
        set_requires_grad([style_encoder, content_encoder, decoder], False)
        set_requires_grad(discriminator, True)

        # Encoder forward pass without building a graph: only the discriminator
        # is optimized in this phase.
        # NOTE(review): this is an extra train-mode forward pass per batch. If the
        # encoders contain BatchNorm, their running stats are updated twice; confirm
        # this is intended.
        with torch.no_grad():
            style_emb, class_emb = style_encoder(x, labels)
            content_emb = content_encoder(x)

        # Adversarial loss for the discriminator
        disc_loss, _ = adversarial_loss(
            style_emb.detach(),
            class_emb.detach(),
            content_emb.detach(),
            discriminator,
            labels,
            compute_for_discriminator=True,
            lambda_content=config["lambda_adv_disc"]
        )

        optimizer_D.zero_grad()
        disc_loss.backward()
        optimizer_D.step()

        # ==================================================================
        # Phase 2: generator update (encoders + decoder)
        # ==================================================================
        # Discriminator weights are frozen (no gradient accumulates in them), while
        # gradients still flow through it back into the embeddings.
        set_requires_grad([style_encoder, content_encoder, decoder], True)
        set_requires_grad(discriminator, False)

        style_emb, class_emb = style_encoder(x, labels)
        content_emb = content_encoder(x)

        # Adversarial loss for the generator
        _, adv_gen_loss = adversarial_loss(
            style_emb,
            class_emb,
            content_emb,
            discriminator,
            labels,
            compute_for_discriminator=False,
            lambda_content=config["lambda_adv_gen"]
        )

        # Style/content disentanglement loss
        disent_loss = disentanglement_loss(
            style_emb,
            content_emb.mean(dim=1),  # mean over dim 1 (the sequence axis)
            use_hsic=True
        )

        # Contrastive loss on style embeddings
        cont_loss = infoNCE_loss(style_emb, labels)

        # Margin loss on class embeddings
        margin_loss_val = margin_loss(class_emb)

        # Reconstruction (the decoder also receives the target x via y=)
        recon_x = decoder(content_emb, class_emb, y=x)
        recon_losses = compute_comprehensive_loss(recon_x, x)
        recon_loss = recon_losses['total_loss']

        # Total generator objective.
        # NOTE(review): adv_gen_loss was computed with lambda_content=lambda_adv_gen
        # and is scaled by lambda_adv_gen again here. If adversarial_loss already
        # applies lambda_content, the effective weight is lambda_adv_gen**2; confirm
        # against losses.adversarial_loss.
        total_G_loss = (
            config["lambda_adv_gen"] * adv_gen_loss +
            config["lambda_disent"] * disent_loss +
            config["lambda_cont"] * cont_loss +
            config["lambda_margin"] * margin_loss_val +
            config["lambda_recon"] * recon_loss
        )

        optimizer_G.zero_grad()
        total_G_loss.backward()
        # Gradient clipping to stabilize training
        torch.nn.utils.clip_grad_norm_(generator_params, grad_clip)
        optimizer_G.step()

        # Per-iteration loss logging
        loss_history['disc'].append(disc_loss.item())
        loss_history['disent'].append(disent_loss.item())
        loss_history['cont'].append(cont_loss.item())
        loss_history['margin'].append(margin_loss_val.item())
        loss_history['recon'].append(recon_loss.item())
        loss_history['adv_gen'].append(adv_gen_loss.item())
        loss_history['total_G'].append(total_G_loss.item())

    # ==================================================================
    # End-of-epoch operations
    # ==================================================================
    # Periodic checkpoint
    if (epoch + 1) % config["save_interval"] == 0:
        save_checkpoint(epoch + 1)

    # Mean of the losses logged during this epoch only (NaN if the loader was empty)
    avg_losses = {
        k: (float(np.mean(v[epoch_start:])) if len(v) > epoch_start else float("nan"))
        for k, v in loss_history.items()
    }
    # tqdm.write keeps the epoch progress bar intact instead of interleaving print() output.
    tqdm.write(
        f"\nEpoch {epoch+1}/{config['epochs']}:\n"
        f"  Disc Loss: {avg_losses['disc']:.4f}\n"
        f"  Total G Loss: {avg_losses['total_G']:.4f}\n"
        f"  Recon Loss: {avg_losses['recon']:.4f}\n"
        f"  Disent Loss: {avg_losses['disent']:.4f}"
    )

# 7. Final checkpoint and loss plot
# ========================================
# The loop already saved the last epoch when epochs is a positive multiple of
# save_interval. Only save here when it did not, to avoid rewriting the same file.
already_saved = config["epochs"] > 0 and config["epochs"] % config["save_interval"] == 0
if not already_saved:
    save_checkpoint(config["epochs"])

# Per-iteration curves for every logged loss (total_G is the lambda-weighted
# generator objective; the other curves are the individual loss terms).
fig, ax = plt.subplots(figsize=(12, 8))
for loss_name, values in loss_history.items():
    ax.plot(values, label=loss_name)
ax.set(xlabel="Iteration", ylabel="Loss", title="Training Losses")
ax.legend()
ax.grid(True)
fig.savefig(os.path.join(config["save_dir"], "loss_curves.png"))
plt.show()