In [1]:
import numpy as np
from scipy import signal
from scipy.io import wavfile

In [2]:
def align_signals(mix, isolated, max_offset=1000):
    """Find the lag (in samples) that best aligns ``isolated`` with ``mix``.

    The lag is the position of the cross-correlation peak, searched only
    among lags in ``[-max_offset, max_offset]``. A positive result means
    ``isolated`` has to be delayed (zeros prepended) to line up with ``mix``;
    a negative result means its leading samples have to be dropped.

    Args:
        mix: 1-D signal, or a 2-D array that is assumed to be laid out as
            (samples, channels); multi-channel input is averaged to mono.
        isolated: signal to align against ``mix``, same conventions.
        max_offset: largest absolute lag that is considered (negative values
            are treated as 0).

    Returns:
        The best lag as a Python ``int``; 0 when either input is empty.
    """
    mix = np.asarray(mix, dtype=np.float64)
    isolated = np.asarray(isolated, dtype=np.float64)

    # Cross-correlating 2-D arrays would be a 2-D correlation whose flat
    # argmax is not a sample lag, so align on a mono downmix instead.
    if mix.ndim > 1:
        mix = mix.mean(axis=1)
    if isolated.ndim > 1:
        isolated = isolated.mean(axis=1)

    if mix.size == 0 or isolated.size == 0:
        return 0

    max_offset = max(0, int(max_offset))
    correlation = signal.correlate(mix, isolated, mode='full')

    # In 'full' mode, index k of the result corresponds to lag k - (N - 1),
    # where N = len(isolated).
    zero_lag = len(isolated) - 1

    # Restrict the peak search to the allowed window. Clamping the global
    # peak afterwards would return +/-max_offset whenever the true peak lies
    # outside the window, which is not the best lag inside it.
    lo = max(0, zero_lag - max_offset)
    hi = min(len(correlation), zero_lag + max_offset + 1)
    best_index = lo + int(np.argmax(correlation[lo:hi]))

    return best_index - zero_lag

def extract_signal(mix_file, isolated_file, output_file):
    """Subtract an isolated track from a mixture and save the residual as WAV.

    Both files are read with ``scipy.io.wavfile``, the isolated track is
    time-aligned to the mixture with ``align_signals``, subtracted from it,
    peak-normalized to the int16 range and written to ``output_file`` using
    the mixture's sample rate.

    Args:
        mix_file: path of the mixture WAV file.
        isolated_file: path of the WAV file containing the track to remove.
        output_file: path of the WAV file to write.

    Returns:
        The extracted signal as an ``int16`` NumPy array (the data written).

    Raises:
        ValueError: if the two files have different sample rates, since a
            sample-by-sample subtraction would then be meaningless.

    Note:
        The result is always rescaled to full scale, so the absolute level
        of the residual is not preserved.
    """
    # Load both files
    sr_mix, mix = wavfile.read(mix_file)
    sr_iso, isolated = wavfile.read(isolated_file)

    if sr_mix != sr_iso:
        raise ValueError(
            f"Sample rate mismatch: {mix_file} is {sr_mix} Hz, "
            f"{isolated_file} is {sr_iso} Hz"
        )

    # Work in float to avoid integer overflow in the subtraction
    mix = mix.astype(np.float64)
    isolated = isolated.astype(np.float64)

    # Estimate the lag on a mono downmix: multi-channel data is
    # (samples, channels) and a lag is only defined along the sample axis.
    mix_ref = mix.mean(axis=1) if mix.ndim > 1 else mix
    iso_ref = isolated.mean(axis=1) if isolated.ndim > 1 else isolated
    offset = int(align_signals(mix_ref, iso_ref))

    # Apply the lag along the sample axis only
    if offset > 0:
        # Pad axis 0 only; a bare (offset, 0) would also pad the channel axis.
        pad_width = [(offset, 0)] + [(0, 0)] * (isolated.ndim - 1)
        isolated = np.pad(isolated, pad_width)[:len(mix)]
    elif offset < 0:
        isolated = isolated[-offset:len(mix) - offset]

    # Trim both signals to the common length and subtract
    min_len = min(len(mix), len(isolated))
    extracted = mix[:min_len] - isolated[:min_len]

    # Peak-normalize (skipped for an empty or all-zero residual)
    peak = np.max(np.abs(extracted)) if extracted.size else 0.0
    if peak > 0:
        extracted = extracted / peak * 32767

    extracted = np.clip(extracted, -32768, 32767).astype(np.int16)

    # Save
    wavfile.write(output_file, sr_mix, extracted)

    return extracted

# Usage: remove the background track from the toy-dataset mixture and
# write the residual to 'extracted_a.wav'.
extracted_a = extract_signal(
    'audio_sources/dataset_toy/mix_3/mixture.wav',
    'audio_sources/dataset_toy/mix_3/background.wav',
    'extracted_a.wav',
)

In [None]:
# Adversarial training of a generator that predicts the target source from
# (mixture, background), against a discriminator that scores waveforms.
# NOTE(review): WaveUNetGenerator, AudioDiscriminator, device, train_loader,
# tqdm, torch and nn are not defined in this cell; they must already be in
# scope from elsewhere for this cell to run on a fresh kernel.

# Initialize models
generator = WaveUNetGenerator(input_channels=2, output_channels=1).to(device)
discriminator = AudioDiscriminator(input_channels=1).to(device)

# Optimizers (one per network, same learning rate)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)

# Loss functions
# NOTE(review): nn.BCELoss expects probabilities in [0, 1]; this presumably
# relies on AudioDiscriminator ending with a sigmoid — TODO confirm.
bce_loss = nn.BCELoss()
l1_loss = nn.L1Loss()

# Weight of the adversarial term relative to the reconstruction term
lambda_adv = 0.001

n_epochs = 20

# Lowest average generator loss seen so far; drives checkpointing below
best_g_loss = float('inf')

for epoch in range(n_epochs):
    generator.train()
    discriminator.train()

    # Per-epoch sums of the per-batch losses
    running_g_loss = 0.0
    running_d_loss = 0.0

    for mixture, background in tqdm(train_loader, desc=f"Epoch {epoch+1}/{n_epochs}"):
        mixture = mixture.to(device)         # (B, 1, T) per the original note — TODO confirm
        background = background.to(device)   # (B, 1, T) per the original note — TODO confirm

        # === Train Generator ===
        predicted_a = generator(mixture, background)

        # The discriminator scores the generated sample; the generator is
        # rewarded when that score matches the "real" label (all ones).
        d_fake_pred = discriminator(predicted_a)
        g_adv_loss = bce_loss(d_fake_pred, torch.ones_like(d_fake_pred))

        # Reconstruction loss: compare the prediction with the target, which
        # is defined here as mixture minus background.
        real_a = (mixture - background).detach()  # ground truth
        g_recon_loss = l1_loss(predicted_a, real_a)

        g_loss = g_recon_loss + lambda_adv * g_adv_loss

        # Only the generator is stepped here. Gradients that this backward
        # pass leaves on the discriminator are cleared by
        # d_optimizer.zero_grad() before the discriminator update below.
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # === Train Discriminator ===
        # Detach the prediction so the discriminator loss does not
        # backpropagate into the generator.
        with torch.no_grad():
            fake_a_detached = predicted_a.detach()
        d_real_pred = discriminator(real_a)
        d_fake_pred = discriminator(fake_a_detached)

        # Real samples are labelled 1, generated samples 0
        d_real_loss = bce_loss(d_real_pred, torch.ones_like(d_real_pred))
        d_fake_loss = bce_loss(d_fake_pred, torch.zeros_like(d_fake_pred))
        d_loss = d_real_loss + d_fake_loss

        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        running_g_loss += g_loss.item()
        running_d_loss += d_loss.item()

    # Average over the number of batches in the loader
    avg_g_loss = running_g_loss / len(train_loader)
    avg_d_loss = running_d_loss / len(train_loader)

    print(f"Epoch {epoch+1}/{n_epochs} | G_loss: {avg_g_loss:.4f} | D_loss: {avg_d_loss:.4f}")
    # Optional per-epoch checkpoints (disabled):
    # torch.save(generator.state_dict(), f"generator_epoch_{epoch+1}.pth")
    # torch.save(discriminator.state_dict(), f"discriminator_epoch_{epoch+1}.pth")

    # Keep the weights of the epoch with the lowest average *training*
    # generator loss; no validation data is used in this cell.
    if avg_g_loss < best_g_loss:
        best_g_loss = avg_g_loss
        torch.save(generator.state_dict(), "best_generator.pth")
        torch.save(discriminator.state_dict(), "best_discriminator.pth")
        print(f"✅ Saved best models at epoch {epoch+1}")


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class UnsupervisedSeparationModel(nn.Module):
    """Fully unsupervised two-source audio separation network.

    A shared 1-D convolutional encoder compresses the mixture; two
    independent transposed-convolution decoders each reconstruct one source
    from the same encoding. Outputs pass through ``tanh`` and therefore lie
    in [-1, 1].

    Args:
        input_channels: number of channels of the input mixture.
        base_channels: channel width of the first encoder layer; deeper
            layers use 2x, 4x and 8x this value.
    """
    def __init__(self, input_channels=1, base_channels=16):
        super().__init__()

        # Shared encoder: four stride-2 convolutions (kernel 15, padding 7),
        # each mapping a length L to ceil(L / 2).
        self.shared_encoder = nn.Sequential(
            nn.Conv1d(input_channels, base_channels, 15, 2, 7),
            nn.LeakyReLU(0.2),
            nn.Conv1d(base_channels, base_channels*2, 15, 2, 7),
            nn.BatchNorm1d(base_channels*2),
            nn.LeakyReLU(0.2),
            nn.Conv1d(base_channels*2, base_channels*4, 15, 2, 7),
            nn.BatchNorm1d(base_channels*4),
            nn.LeakyReLU(0.2),
            nn.Conv1d(base_channels*4, base_channels*8, 15, 2, 7),
            nn.BatchNorm1d(base_channels*8),
            nn.LeakyReLU(0.2),
        )

        # One decoder per source
        self.decoder_a = self._build_decoder(base_channels*8, base_channels)
        self.decoder_b = self._build_decoder(base_channels*8, base_channels)

        # 1x1 convolutions projecting each decoder output to a single channel
        self.output_a = nn.Conv1d(base_channels, 1, 1)
        self.output_b = nn.Conv1d(base_channels, 1, 1)

    def _build_decoder(self, in_channels, base_channels):
        """Build a decoder of four transposed convolutions.

        Each layer (kernel 4, stride 2, padding 1) exactly doubles the
        temporal length, so the decoder upsamples by a factor of 16.
        """
        return nn.Sequential(
            nn.ConvTranspose1d(in_channels, base_channels*4, 4, 2, 1),
            nn.BatchNorm1d(base_channels*4),
            nn.ReLU(),
            nn.ConvTranspose1d(base_channels*4, base_channels*2, 4, 2, 1),
            nn.BatchNorm1d(base_channels*2),
            nn.ReLU(),
            nn.ConvTranspose1d(base_channels*2, base_channels, 4, 2, 1),
            nn.BatchNorm1d(base_channels),
            nn.ReLU(),
            nn.ConvTranspose1d(base_channels, base_channels, 4, 2, 1),
            nn.BatchNorm1d(base_channels),
            nn.ReLU(),
        )

    def forward(self, mixture):
        """Separate ``mixture`` into two estimated sources.

        Args:
            mixture: tensor whose last dimension is time and whose channel
                dimension matches ``input_channels``.

        Returns:
            ``(source_a, source_b)``, each with one channel and the same
            temporal length as ``mixture``.
        """
        n_samples = mixture.shape[-1]

        # Shared encoding
        encoded = self.shared_encoder(mixture)

        # Separate decoding
        decoded_a = self.decoder_a(encoded)
        decoded_b = self.decoder_b(encoded)

        # The encoder rounds lengths up at every stage, so for inputs whose
        # length is not a multiple of 16 the decoder yields extra trailing
        # samples. Crop so the outputs can be summed and compared with the
        # mixture sample by sample.
        source_a = torch.tanh(self.output_a(decoded_a))[..., :n_samples]
        source_b = torch.tanh(self.output_b(decoded_b))[..., :n_samples]

        return source_a, source_b

class UnsupervisedLoss(nn.Module):
    """Composite loss for separation without ground-truth sources.

    The total is a weighted sum of five terms: mixture reconstruction,
    source independence, sparsity, spectral diversity and temporal
    smoothness.

    Args:
        lambda_recon: weight of the reconstruction term.
        lambda_indep: weight of the independence term.
        lambda_sparse: weight of the sparsity term.
        lambda_spectral: weight of the spectral-diversity term.
        lambda_smooth: weight of the smoothness term.
    """
    def __init__(self, lambda_recon=1.0, lambda_indep=0.1, lambda_sparse=0.01,
                 lambda_spectral=0.1, lambda_smooth=0.05):
        super().__init__()
        self.lambda_recon = lambda_recon
        self.lambda_indep = lambda_indep
        self.lambda_sparse = lambda_sparse
        self.lambda_spectral = lambda_spectral
        self.lambda_smooth = lambda_smooth

    def forward(self, source_a, source_b, mixture):
        """Compute all loss terms for a batch.

        Args:
            source_a: first estimated source; the smoothness term indexes it
                with three dimensions, time last.
            source_b: second estimated source, same shape as ``source_a``.
            mixture: the input mixture the two sources should sum to.

        Returns:
            Dict of scalar tensors with keys ``'total'``,
            ``'reconstruction'``, ``'independence'``, ``'sparsity'``,
            ``'spectral'`` and ``'smoothness'``.
        """
        # 1. Reconstruction: the sum of the sources must rebuild the mixture
        reconstructed = source_a + source_b
        recon_loss = F.mse_loss(reconstructed, mixture)

        # 2. Independence: penalize similarity between the two waveforms
        independence_loss = self._independence_constraint(source_a, source_b)

        # 3. Sparsity: encourage sparse sources (parsimony)
        sparsity_loss = self._sparsity_constraint(source_a, source_b)

        # 4. Spectral diversity: sources should differ in the frequency domain
        spectral_loss = self._spectral_diversity(source_a, source_b)

        # 5. Smoothness: encourage temporal continuity
        smoothness_loss = self._smoothness_constraint(source_a, source_b)

        total_loss = (self.lambda_recon * recon_loss +
                     self.lambda_indep * independence_loss +
                     self.lambda_sparse * sparsity_loss +
                     self.lambda_spectral * spectral_loss +
                     self.lambda_smooth * smoothness_loss)

        return {
            'total': total_loss,
            'reconstruction': recon_loss,
            'independence': independence_loss,
            'sparsity': sparsity_loss,
            'spectral': spectral_loss,
            'smoothness': smoothness_loss
        }

    def _independence_constraint(self, source_a, source_b):
        """Mean absolute cosine similarity between the two sources.

        Each sample is flattened and L2-normalized (without mean removal),
        so the per-sample value is an uncentered correlation in [-1, 1].
        """
        # Normalize each flattened sample to unit length
        a_norm = F.normalize(source_a.flatten(1), dim=1)
        b_norm = F.normalize(source_b.flatten(1), dim=1)

        # Per-sample cosine similarity
        per_sample = torch.sum(a_norm * b_norm, dim=1)

        # Take the absolute value per sample *before* averaging, so that
        # positively and negatively correlated samples in the same batch
        # cannot cancel each other out.
        return torch.mean(torch.abs(per_sample))

    def _sparsity_constraint(self, source_a, source_b):
        """Sum of the mean absolute amplitudes (L1 penalty) of both sources."""
        return torch.mean(torch.abs(source_a)) + torch.mean(torch.abs(source_b))

    def _spectral_diversity(self, source_a, source_b):
        """Mean cosine similarity between the magnitude spectra.

        The FFT is taken along the last dimension; magnitudes are flattened
        per sample before comparison.
        """
        # FFT of the sources
        fft_a = torch.fft.fft(source_a, dim=-1)
        fft_b = torch.fft.fft(source_b, dim=-1)

        # Magnitude spectra
        mag_a = torch.abs(fft_a)
        mag_b = torch.abs(fft_b)

        # Per-sample similarity between the spectra (to be minimized)
        spectral_corr = F.cosine_similarity(mag_a.flatten(1), mag_b.flatten(1), dim=1)

        return torch.mean(torch.abs(spectral_corr))

    def _smoothness_constraint(self, source_a, source_b):
        """Total-variation penalty: mean absolute first difference in time."""
        tv_a = torch.mean(torch.abs(source_a[:, :, 1:] - source_a[:, :, :-1]))
        tv_b = torch.mean(torch.abs(source_b[:, :, 1:] - source_b[:, :, :-1]))

        return tv_a + tv_b

class UnsupervisedTrainer:
    """Runs training and evaluation epochs for an unsupervised separator.

    Args:
        model: module whose forward pass maps a mixture to
            ``(source_a, source_b)``; it is moved to ``device``.
        loss_fn: callable ``loss_fn(source_a, source_b, mixture)`` returning
            a dict of scalar tensors that contains the key ``'total'``.
        optimizer: optimizer over the model's parameters.
        device: device on which batches and the model are placed.
    """
    def __init__(self, model, loss_fn, optimizer, device='cpu'):
        self.model = model.to(device)
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.device = device

    @staticmethod
    def _accumulate(totals, losses):
        """Add the Python-float value of every loss term into ``totals``."""
        for key, value in losses.items():
            totals[key] = totals.get(key, 0) + value.item()

    @staticmethod
    def _average(totals, n_batches):
        """Return per-batch averages; an empty dict when nothing was seen."""
        return {key: value / n_batches for key, value in totals.items()}

    def train_epoch(self, dataloader):
        """Train for one pass over ``dataloader``.

        Args:
            dataloader: iterable of ``(mixture, _)`` pairs; the second item
                is ignored. It does not need to support ``len()``.

        Returns:
            Dict mapping each loss name to its average over the batches.
        """
        self.model.train()
        total_losses = {}
        n_batches = 0

        for mixture, _ in dataloader:
            mixture = mixture.to(self.device)

            # Forward pass
            source_a, source_b = self.model(mixture)

            # Compute losses
            losses = self.loss_fn(source_a, source_b, mixture)

            # Backward pass
            self.optimizer.zero_grad()
            losses['total'].backward()

            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.optimizer.step()

            self._accumulate(total_losses, losses)
            n_batches += 1

        # Average over the batches actually seen, so iterables without
        # __len__ (e.g. generators) are supported too.
        return self._average(total_losses, n_batches)

    def evaluate(self, dataloader):
        """Compute average losses over ``dataloader`` without updating weights.

        The model is switched to eval mode and left in that mode.

        Args:
            dataloader: iterable of ``(mixture, _)`` pairs; the second item
                is ignored. It does not need to support ``len()``.

        Returns:
            Dict mapping each loss name to its average over the batches.
        """
        self.model.eval()
        total_losses = {}
        n_batches = 0

        with torch.no_grad():
            for mixture, _ in dataloader:
                mixture = mixture.to(self.device)
                source_a, source_b = self.model(mixture)
                losses = self.loss_fn(source_a, source_b, mixture)

                self._accumulate(total_losses, losses)
                n_batches += 1

        return self._average(total_losses, n_batches)

# Usage example
def train_unsupervised_model():
    """Assemble the components needed for fully unsupervised training.

    Builds the separation model, the composite loss, an Adam optimizer and
    a trainer bound to CUDA when available (CPU otherwise). No training
    step is run here.

    Returns:
        ``(trainer, model)``.
    """
    target = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(target)

    model = UnsupervisedSeparationModel()

    # Relative weights of the loss terms; reconstruction dominates.
    loss_weights = {
        'lambda_recon': 1.0,      # reconstruction
        'lambda_indep': 0.1,      # independence
        'lambda_sparse': 0.01,    # sparsity
        'lambda_spectral': 0.1,   # spectral diversity
        'lambda_smooth': 0.05,    # smoothness
    }
    criterion = UnsupervisedLoss(**loss_weights)

    adam = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999))

    trainer = UnsupervisedTrainer(model, criterion, adam, device)
    return trainer, model

# Result analysis helper
def analyze_separation_quality(source_a, source_b, mixture):
    """Summarize separation quality without ground-truth sources.

    Args:
        source_a: first estimated source.
        source_b: second estimated source.
        mixture: the mixture the two sources should add up to.

    Returns:
        Dict of Python floats:
        ``'reconstruction_error'`` (MSE between the summed sources and the
        mixture), ``'correlation'`` (cosine similarity of the flattened
        waveforms), ``'energy_ratio'`` (smaller mean energy divided by the
        larger) and ``'spectral_similarity'`` (cosine similarity of the
        flattened FFT magnitudes).
    """
    def _flat_cosine(x, y):
        # Cosine similarity over all elements, ignoring batch structure
        return F.cosine_similarity(x.flatten(), y.flatten(), dim=0)

    # Mean power of each estimated source
    power_a = torch.mean(source_a**2)
    power_b = torch.mean(source_b**2)
    weaker = min(power_a, power_b)
    stronger = max(power_a, power_b)

    # Magnitude spectra along the last dimension
    spectrum_a = torch.fft.fft(source_a).abs()
    spectrum_b = torch.fft.fft(source_b).abs()

    metrics = {
        'reconstruction_error': F.mse_loss(source_a + source_b, mixture),
        'correlation': _flat_cosine(source_a, source_b),
        'energy_ratio': weaker / stronger,
        'spectral_similarity': _flat_cosine(spectrum_a, spectrum_b),
    }
    return {name: value.item() for name, value in metrics.items()}