In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import numpy as np
import pandas as pd
import os
from pydub import AudioSegment
import torch
from torch import nn, optim
from tqdm import tqdm
import csv
from itertools import cycle

In [2]:
class AudioDatasetMixture(torch.utils.data.Dataset):
    """Dataset returning the 'mixture' audio of each CSV row as a float32 tensor.

    Each item has shape (target_channels, num_samples). The file is converted with pydub
    to target_channels channels and target_sample_rate Hz, then every sample is divided
    by 2 ** (8 * sample_width - 1).

    Args:
        csv_file: CSV file with a 'mixture' column holding audio file paths.
        target_duration: stored as an attribute; not used by this class.
        target_sample_rate: frame rate (Hz) the audio is resampled to.
        target_channels: number of channels the audio is converted to.
    """

    def __init__(self, csv_file, target_duration=10000, target_sample_rate=16000,
                 target_channels=1):
        self.df = pd.read_csv(csv_file)
        self.target_duration = target_duration
        self.target_channels = target_channels
        self.target_sample_rate = target_sample_rate

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load row `idx`'s mixture file as a (target_channels, num_samples) tensor."""
        mixture_path = self.df.loc[idx, 'mixture']
        mixture_audio = (AudioSegment.from_file(mixture_path)
                         .set_channels(self.target_channels)
                         .set_frame_rate(self.target_sample_rate))
        mixture, _ = self._pydub_to_array(mixture_audio)
        return torch.Tensor(mixture)

    def _pydub_to_array(self, audio: "AudioSegment") -> "tuple[np.ndarray, int]":
        """Convert an AudioSegment to a (channels, num_samples) float32 array and its frame rate.

        pydub returns multi-channel samples interleaved frame by frame ([c0, c1, c0, c1, ...]),
        so they are reshaped to (num_samples, channels) and transposed. Reshaping straight to
        (channels, -1) would put the first half of the interleaved stream into channel 0.
        """
        full_scale = 1 << (8 * audio.sample_width - 1)
        samples = np.array(audio.get_array_of_samples(), dtype=np.float32)
        frames = samples.reshape((-1, audio.channels))
        return np.ascontiguousarray(frames.T / full_scale), audio.frame_rate

In [3]:
class AudioDatasetBackground(torch.utils.data.Dataset):
    """Dataset returning the 'background' audio of each CSV row as a float32 tensor.

    Each item has shape (target_channels, num_samples). The file is converted with pydub
    to target_channels channels and target_sample_rate Hz, then every sample is divided
    by 2 ** (8 * sample_width - 1).

    Args:
        csv_file: CSV file with a 'background' column holding audio file paths.
        target_duration: stored as an attribute; not used by this class.
        target_sample_rate: frame rate (Hz) the audio is resampled to.
        target_channels: number of channels the audio is converted to.
    """

    def __init__(self, csv_file, target_duration=10000, target_sample_rate=16000,
                 target_channels=1):
        self.df = pd.read_csv(csv_file)
        self.target_duration = target_duration
        self.target_channels = target_channels
        self.target_sample_rate = target_sample_rate

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load row `idx`'s background file as a (target_channels, num_samples) tensor."""
        background_path = self.df.loc[idx, 'background']
        background_audio = (AudioSegment.from_file(background_path)
                            .set_channels(self.target_channels)
                            .set_frame_rate(self.target_sample_rate))
        background, _ = self._pydub_to_array(background_audio)
        return torch.Tensor(background)

    def _pydub_to_array(self, audio: "AudioSegment") -> "tuple[np.ndarray, int]":
        """Convert an AudioSegment to a (channels, num_samples) float32 array and its frame rate.

        pydub returns multi-channel samples interleaved frame by frame ([c0, c1, c0, c1, ...]),
        so they are reshaped to (num_samples, channels) and transposed. Reshaping straight to
        (channels, -1) would put the first half of the interleaved stream into channel 0.
        """
        full_scale = 1 << (8 * audio.sample_width - 1)
        samples = np.array(audio.get_array_of_samples(), dtype=np.float32)
        frames = samples.reshape((-1, audio.channels))
        return np.ascontiguousarray(frames.T / full_scale), audio.frame_rate

In [4]:
def pydub_to_array(audio: "AudioSegment") -> "tuple[np.ndarray, int]":
    """Convert a pydub AudioSegment to a (channels, num_samples) float32 array and its frame rate.

    Samples are divided by 2 ** (8 * sample_width - 1). pydub returns multi-channel samples
    interleaved frame by frame ([c0, c1, c0, c1, ...]), so they are reshaped to
    (num_samples, channels) and transposed. Reshaping straight to (channels, -1) would put
    the first half of the interleaved stream into channel 0.

    Returns:
        (array, frame_rate): a C-contiguous float32 array of shape (channels, num_samples)
        and the segment's frame rate in Hz.
    """
    full_scale = 1 << (8 * audio.sample_width - 1)
    samples = np.array(audio.get_array_of_samples(), dtype=np.float32)
    frames = samples.reshape((-1, audio.channels))
    return np.ascontiguousarray(frames.T / full_scale), audio.frame_rate

def array_to_pydub(audio_np_array: np.ndarray, sample_rate: int = 16000, sample_width: int = 2, channels: int = 1) -> "AudioSegment":
    """Convert a float waveform (nominally in [-1, 1]) to a pydub AudioSegment.

    This is the inverse of pydub_to_array. Samples are multiplied by 2 ** (8 * sample_width - 1),
    clipped to the signed integer range of that width, truncated to integers and packed in
    native byte order.

    Args:
        audio_np_array: 1-D array of samples, or a (channels, num_samples) array, which is
            interleaved frame by frame before packing.
        sample_rate: frame rate of the returned segment in Hz.
        sample_width: bytes per sample; 1, 2 or 4.
        channels: channel count of the returned segment.

    Raises:
        ValueError: if sample_width is not 1, 2 or 4.
    """
    int_types = {1: np.int8, 2: np.int16, 4: np.int32}
    if sample_width not in int_types:
        raise ValueError(f"Unsupported sample_width {sample_width}; expected 1, 2 or 4")
    full_scale = 2 ** (8 * sample_width - 1)
    # float64 keeps the clip bounds exact even for 32-bit output.
    samples = np.asarray(audio_np_array, dtype=np.float64)
    if samples.ndim == 2:
        # pydub expects interleaved frames; tobytes() below walks this view in C order.
        samples = samples.T
    # Without clipping, +1.0 scales to full_scale (one past the integer max) and wraps
    # around to the most negative value when cast.
    scaled = np.clip(samples * full_scale, -full_scale, full_scale - 1)
    return AudioSegment(scaled.astype(int_types[sample_width]).tobytes(),
                        frame_rate=sample_rate, sample_width=sample_width, channels=channels)

In [5]:
# === DEVICE, DATASETS AND LOADERS ===
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

train_csv = "csv/train_wav.csv"
validation_csv = "csv/val_wav.csv"
test_csv = "csv/test_wav.csv"

# Mixture and background datasets read the same CSVs, so index i refers to the same
# example in both.
train_m = AudioDatasetMixture(csv_file=train_csv)
validation_m = AudioDatasetMixture(csv_file=validation_csv)
test_m = AudioDatasetMixture(csv_file=test_csv)

train_b = AudioDatasetBackground(csv_file=train_csv)
validation_b = AudioDatasetBackground(csv_file=validation_csv)
test_b = AudioDatasetBackground(csv_file=test_csv)

batch_size = 18

# The training loop zips a mixture loader with the matching background loader, so batch i
# of one must contain the same rows as batch i of the other. Shuffling the two training
# loaders independently would pair every mixture with an unrelated background.
# Each training loader therefore gets its own torch.Generator, seeded with the same value.
# Identically configured loaders consume their generators identically, so while they are
# iterated together they draw the same permutation every epoch. Re-run this cell to
# re-sync them after iterating only one of the two.
shuffle_seed = 0

mixture_train_loader = torch.utils.data.DataLoader(
    train_m, batch_size=batch_size, shuffle=True,
    generator=torch.Generator().manual_seed(shuffle_seed))
mixture_validation_loader = torch.utils.data.DataLoader(validation_m, batch_size=batch_size, shuffle=False)
mixture_test_loader = torch.utils.data.DataLoader(test_m, batch_size=batch_size, shuffle=False)

background_train_loader = torch.utils.data.DataLoader(
    train_b, batch_size=batch_size, shuffle=True,
    generator=torch.Generator().manual_seed(shuffle_seed))
background_validation_loader = torch.utils.data.DataLoader(validation_b, batch_size=batch_size, shuffle=False)
background_test_loader = torch.utils.data.DataLoader(test_b, batch_size=batch_size, shuffle=False)

In [6]:
class SeparableConv1d(nn.Module):
    """Depthwise-separable 1-D convolution with optional BatchNorm and LeakyReLU.

    The pipeline is:
    1. A per-channel (depthwise, groups=in_channels) convolution that uses the given
       kernel_size/stride/padding/dilation.
    2. A 1x1 (pointwise) convolution that mixes channels into out_channels.
    3. BatchNorm1d(out_channels) when `norm` is true, and an in-place LeakyReLU when
       `activation` is true.

    Each disabled stage is an nn.Identity.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, bias=True, norm=True, activation=True):
        super().__init__()
        # Submodules are registered in this exact order (depthwise, pointwise, bn, act);
        # state_dict keys and parameters() order, and therefore saved checkpoints, depend on it.
        self.depthwise = nn.Conv1d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
        )
        self.pointwise = nn.Conv1d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias
        )
        if norm:
            self.bn = nn.BatchNorm1d(out_channels)
        else:
            self.bn = nn.Identity()
        if activation:
            self.act = nn.LeakyReLU(inplace=True)
        else:
            self.act = nn.Identity()

    def forward(self, x):
        return self.act(self.bn(self.pointwise(self.depthwise(x))))

In [7]:
class LocalSelfAttention1D(nn.Module):
    """Multi-head self-attention restricted to a sliding window around each time step.

    Input and output have shape (B, C, T), where C must equal `embed_dim`. The query at
    step t attends to the keys/values in [t - window_size // 2, t + window_size // 2],
    clipped to the sequence bounds. Attention is evaluated one query step at a time, which
    means T separate calls to the shared nn.MultiheadAttention.
    """

    def __init__(self, embed_dim, num_heads=4, window_size=31):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.window_size = window_size

    def forward(self, x):
        seq = x.permute(0, 2, 1)  # (B, T, C), the layout batch_first attention expects
        _, length, _ = seq.shape
        radius = self.window_size // 2

        result = torch.zeros_like(seq)
        step = 0
        while step < length:
            lo = max(0, step - radius)
            hi = min(length, step + radius + 1)
            window = seq[:, lo:hi, :]
            attended, _ = self.attn(seq[:, step:step + 1, :], window, window)
            result[:, step:step + 1, :] = attended
            step += 1

        return result.permute(0, 2, 1)  # back to (B, C, T)

In [8]:
### UNet1D with separable convolutions - depth 4 ###
class UNet1D(nn.Module):
    """1-D U-Net with SeparableConv1d encoder stages and ConvTranspose1d decoder stages.

    Maps a (batch, input_channels, time) tensor to a tensor of the same shape with values
    in [-1, 1] (final Tanh). Skip connections concatenate encoder features onto decoder
    features along the channel axis.

    Args:
        input_channels: channels of the input and of the output waveform.
        base_channels: width of the first encoder stage; deeper stages use 2x, 4x and 8x.
    """

    def __init__(self, input_channels=1, base_channels=32):
        super(UNet1D, self).__init__()
        # Encoder: enc1 keeps the length; enc2-enc4 each halve it (stride 2, rounding up).
        self.enc1 = SeparableConv1d(input_channels, base_channels, kernel_size=9, padding=4)
        self.enc2 = SeparableConv1d(base_channels, base_channels*2, kernel_size=9, stride=2, padding=4)
        self.enc3 = SeparableConv1d(base_channels*2, base_channels*4, kernel_size=9, stride=2, padding=4)
        self.enc4 = SeparableConv1d(base_channels*4, base_channels*8, kernel_size=9, stride=2, padding=4)

        # Bottleneck (length preserved)
        self.bottleneck = SeparableConv1d(base_channels*8, base_channels*8, kernel_size=9, padding=4)

        # Decoder: each ConvTranspose1d (kernel 4, stride 2, padding 1) exactly doubles the
        # length. dec3 and dec2 take doubled input channels because of the skip concatenation.
        self.dec4 = nn.Sequential(
            nn.ConvTranspose1d(base_channels*8, base_channels*4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(base_channels*4),
            nn.LeakyReLU(inplace=True)
        )
        self.dec3 = nn.Sequential(
            nn.ConvTranspose1d(base_channels*8, base_channels*2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(base_channels*2),
            nn.LeakyReLU(inplace=True)
        )
        self.dec2 = nn.Sequential(
            nn.ConvTranspose1d(base_channels*4, base_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm1d(base_channels),
            nn.LeakyReLU(inplace=True)
        )
        # Output head: concatenated (d2, e1) features -> input_channels, squashed by Tanh.
        self.dec1 = nn.Sequential(
            SeparableConv1d(base_channels*2, input_channels, kernel_size=9, padding=4, norm=False, activation=False),
            nn.Tanh()
        )

    def forward(self, x):
        """Run the U-Net on x of shape (batch, input_channels, time).

        enc2-enc4 each produce ceil(L / 2) samples while the decoder exactly doubles, so the
        skip tensors only line up when L is a multiple of 8. Other lengths are right-padded
        with zeros up to the next multiple of 8, and the output is cropped back to L.
        Aligned lengths are unaffected. In train mode, BatchNorm statistics then include the
        padded zeros.
        """
        length = x.shape[-1]
        pad = (-length) % 8
        if pad:
            x = F.pad(x, (0, pad))

        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)

        # Bottleneck
        b = self.bottleneck(e4)

        # Decoder with skip connections
        d4 = self.dec4(b)
        d4 = torch.cat([d4, e3], dim=1)
        d3 = self.dec3(d4)
        d3 = torch.cat([d3, e2], dim=1)
        d2 = self.dec2(d3)
        d2 = torch.cat([d2, e1], dim=1)
        out = self.dec1(d2)

        return out[..., :length]

In [9]:
class Discriminator1D(nn.Module):
    """Convolutional discriminator that outputs one logit per input waveform.

    The architecture is:
    - Four Conv1d stages, each with stride 2 and padding kernel // 2. Every stage is
      followed by BatchNorm (except the first) and LeakyReLU(0.2). Channels widen from
      base_channels to 8 * base_channels.
    - A kernel-3 Conv1d down to one channel.
    - Global average pooling over time.

    The result is flattened to shape (batch,).
    """

    def __init__(self, input_channels=1, base_channels=32):
        super().__init__()

        def stage(c_in, c_out, kernel, batch_norm):
            # Layers are appended in the order conv -> [BN] -> LeakyReLU so the Sequential
            # indices (and state_dict keys such as 'layer2.1.weight') match the checkpoint layout.
            modules = [nn.Conv1d(c_in, c_out, kernel_size=kernel, stride=2, padding=kernel // 2)]
            if batch_norm:
                modules.append(nn.BatchNorm1d(c_out))
            modules.append(nn.LeakyReLU(0.2, inplace=True))
            return nn.Sequential(*modules)

        self.layer1 = stage(input_channels, base_channels, 7, batch_norm=False)
        self.layer2 = stage(base_channels, base_channels * 2, 5, batch_norm=True)
        self.layer3 = stage(base_channels * 2, base_channels * 4, 5, batch_norm=True)
        self.layer4 = stage(base_channels * 4, base_channels * 8, 5, batch_norm=True)

        self.final_conv = nn.Conv1d(base_channels * 8, 1, kernel_size=3, padding=1)
        self.avgpool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        pipeline = (self.layer1, self.layer2, self.layer3, self.layer4,
                    self.final_conv, self.avgpool)
        for module in pipeline:
            x = module(x)
        return x.view(-1)

In [10]:
# # === MODELLI & PARAMETRI ===
# generator = UNet1D().to(device)
# discriminator = Discriminator1D().to(device)
# 
# num_epochs = 30
# adversarial_loss = nn.BCEWithLogitsLoss()
# lambda_l1 = 10 
# loss_fn = nn.L1Loss()
# 
# # Learning rate bilanciati
# optimizer_G = torch.optim.AdamW(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
# optimizer_D = torch.optim.AdamW(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999))
# 
# # === FILE E PATH ===
# os.makedirs("results-3", exist_ok=True)
# csv_path = "results-3/training_history.csv"
# checkpoint_path = "results-3/checkpoint.pth"
# 
# # Definizione delle intestazioni del CSV 
# fieldnames = ['epoch', 'train_loss_D', 'train_loss_G', 'train_loss_G_GAN', 'train_loss_G_L1', 'val_loss_L1']
# 
# # Inizializza il CSV solo se non esiste
# if not os.path.exists(csv_path):
#     with open(csv_path, 'w', newline='') as csvfile:
#         writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
#         writer.writeheader()
# 
# # === CHECKPOINT: riprendi se esiste ===
# start_epoch = 0
# best_val_loss = float('inf')
# if os.path.exists(checkpoint_path):
#     print("Loading checkpoint...")
#     checkpoint = torch.load(checkpoint_path, map_location=device)
#     generator.load_state_dict(checkpoint['generator_state_dict'])
#     discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
#     optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
#     optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
#     start_epoch = checkpoint['epoch'] + 1
#     best_val_loss = checkpoint.get('best_val_loss', float('inf'))
#     print(f"Resuming training from epoch {start_epoch}")
# 
# # === TRAINING LOOP ===
# for epoch in range(start_epoch, num_epochs):
#     generator.train()
#     discriminator.train()
#     running_loss_D = 0.0
#     running_loss_G = 0.0
#     running_loss_G_GAN = 0.0
#     running_loss_G_L1 = 0.0
# 
#     train_pbar = tqdm(zip(mixture_train_loader, cycle(background_train_loader)),
#                       total=len(mixture_train_loader),
#                       desc=f'Epoch {epoch+1}/{num_epochs} - Training',
#                       leave=False, unit='batch')
# 
#     for mixture, background in train_pbar:
#         mixture = mixture.to(device)
#         background = background.to(device)
#         if mixture.ndim == 2:
#             mixture = mixture.unsqueeze(1)
#         if background.ndim == 2:
#             background = background.unsqueeze(1)
# 
#         # --- 1. Train Discriminator ---
#         optimizer_D.zero_grad()
#         with torch.no_grad():
#             fake_background = generator(mixture).detach()
#         real_output = discriminator(background)
#         fake_output = discriminator(fake_background)
#         real_labels = torch.ones_like(real_output)
#         fake_labels = torch.zeros_like(fake_output)
#         loss_D_real = adversarial_loss(real_output, real_labels)
#         loss_D_fake = adversarial_loss(fake_output, fake_labels)
#         loss_D = 0.5 * (loss_D_real + loss_D_fake)
#         loss_D.backward()
#         optimizer_D.step()
# 
#         # --- 2. Train Generator ---
#         optimizer_G.zero_grad()
#         fake_background = generator(mixture)
#         fake_output = discriminator(fake_background)
#         loss_G_GAN = adversarial_loss(fake_output, real_labels)
#         loss_G_L1 = F.l1_loss(fake_background, background)
#         loss_G = loss_G_GAN + lambda_l1 * loss_G_L1
#         loss_G.backward()
#         optimizer_G.step()
# 
#         running_loss_D += loss_D.item()
#         running_loss_G += loss_G.item()
#         running_loss_G_GAN += loss_G_GAN.item()
#         running_loss_G_L1 += loss_G_L1.item()
# 
#         train_pbar.set_postfix({
#             'D_loss': f'{loss_D.item():.4f}',
#             'G_loss': f'{loss_G.item():.4f}',
#             'G_GAN': f'{loss_G_GAN.item():.4f}',
#             'G_L1': f'{loss_G_L1.item():.4f}'
#         })
# 
#     train_pbar.close()
# 
#     avg_loss_D = running_loss_D / len(mixture_train_loader)
#     avg_loss_G = running_loss_G / len(mixture_train_loader)
#     avg_loss_G_GAN = running_loss_G_GAN / len(mixture_train_loader)
#     avg_loss_G_L1 = running_loss_G_L1 / len(mixture_train_loader)
# 
#     # === VALIDATION ===
#     generator.eval()
#     val_loss_L1 = 0.0
# 
#     val_pbar = tqdm(zip(mixture_validation_loader, cycle(background_validation_loader)),
#                     total=len(mixture_validation_loader),
#                     desc=f'Epoch {epoch+1}/{num_epochs} - Validation',
#                     leave=False, unit='batch')
# 
#     with torch.no_grad():
#         for mixture, background in val_pbar:
#             mixture = mixture.to(device)
#             background = background.to(device)
#             if mixture.ndim == 2:
#                 mixture = mixture.unsqueeze(1)
#             if background.ndim == 2:
#                 background = background.unsqueeze(1)
# 
#             pred_background = generator(mixture)
#             batch_val_loss = F.l1_loss(pred_background, background).item()
#             val_loss_L1 += batch_val_loss
# 
#             val_pbar.set_postfix({'Val_L1': f'{batch_val_loss:.4f}'})
# 
#     val_pbar.close()
#     val_loss_L1 /= len(mixture_validation_loader)
# 
#     print(f"[Epoch {epoch+1}/{num_epochs}] "
#           f"Train_Loss_D: {avg_loss_D:.4f} | "
#           f"Train_Loss_G: {avg_loss_G:.4f} | "
#           f"G_GAN: {avg_loss_G_GAN:.4f} | "
#           f"G_L1: {avg_loss_G_L1:.4f} | "
#           f"Val_L1: {val_loss_L1:.4f}")
# 
#     # === LOG E CSV ===
#     with open(csv_path, 'a', newline='') as csvfile:
#         writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
#         writer.writerow({
#             'epoch': epoch + 1,
#             'train_loss_D': avg_loss_D,
#             'train_loss_G': avg_loss_G,
#             'train_loss_G_GAN': avg_loss_G_GAN,
#             'train_loss_G_L1': avg_loss_G_L1,
#             'val_loss_L1': val_loss_L1
#         })
# 
#     # === CHECKPOINT ===
#     checkpoint = {
#         'epoch': epoch,
#         'generator_state_dict': generator.state_dict(),
#         'discriminator_state_dict': discriminator.state_dict(),
#         'optimizer_G_state_dict': optimizer_G.state_dict(),
#         'optimizer_D_state_dict': optimizer_D.state_dict(),
#         'best_val_loss': best_val_loss
#     }
#     torch.save(checkpoint, checkpoint_path)
# 
#     # === BEST MODEL ===
#     if val_loss_L1 < best_val_loss:
#         best_val_loss = val_loss_L1
#         torch.save(generator.state_dict(), "results-3/best_generator.pth")
#         torch.save(discriminator.state_dict(), "results-3/best_discriminator.pth")
#         print(f"New best model saved at epoch {epoch+1} with Val_L1: {val_loss_L1:.4f}")
# 
# print(f"\nTraining completed! Best validation L1 loss: {best_val_loss:.4f}")
# print(f"Training history saved to: {csv_path}")


In [None]:
# === MODELS & PARAMETERS ===
generator = UNet1D().to(device)
discriminator = Discriminator1D().to(device)

num_epochs = 30
adversarial_loss = nn.BCEWithLogitsLoss()
lambda_l1 = 10
loss_fn = nn.L1Loss()  # NOTE(review): not referenced in this cell; L1 is computed with F.l1_loss

# === STRATEGY 1: different learning rates ===
# The discriminator starts with a lower learning rate than the generator.
optimizer_G = torch.optim.AdamW(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_D = torch.optim.AdamW(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999))  # slower

# === STRATEGY 2: update frequency ===
# D is only eligible for an update on generator steps where g_steps % d_train_freq == 0.
d_train_freq = 2
g_steps = 0  # global generator step counter, persisted in the checkpoint

# === STRATEGY 3: dynamic thresholds ===
d_accuracy_threshold = 0.8  # skip D updates while its running accuracy is above 80%
g_loss_threshold = 2.0      # above this adversarial loss, G's objective leans more on L1

# === FILES & PATHS ===
os.makedirs("results-4", exist_ok=True)
csv_path = "results-4/training_history.csv"
checkpoint_path = "results-4/checkpoint.pth"

# CSV header of the per-epoch training history
fieldnames = ['epoch', 'train_loss_D', 'train_loss_G', 'train_loss_G_GAN', 'train_loss_G_L1',
              'val_loss_L1', 'd_accuracy', 'd_train_steps', 'g_train_steps']

# Create the CSV only if it does not exist yet; resumed runs append to it.
if not os.path.exists(csv_path):
    with open(csv_path, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

# The loops below pair batch i of each mixture loader with batch i of the matching
# background loader via zip(), which silently stops at the shorter of the two.
assert len(mixture_train_loader) == len(background_train_loader), "train loaders differ in length"
assert len(mixture_validation_loader) == len(background_validation_loader), "validation loaders differ in length"


def _to_model_input(batch):
    """Move a batch to `device`, adding a channel axis to 2-D (batch, time) tensors."""
    batch = batch.to(device)
    return batch.unsqueeze(1) if batch.ndim == 2 else batch


def _discriminator_loss(real_output, fake_output):
    """Mean of D's BCE-with-logits losses for the targets real -> 1 and generated -> 0."""
    loss_D_real = adversarial_loss(real_output, torch.ones_like(real_output))
    loss_D_fake = adversarial_loss(fake_output, torch.zeros_like(fake_output))
    return 0.5 * (loss_D_real + loss_D_fake)


# === CHECKPOINT: resume if one exists ===
start_epoch = 0
best_val_loss = float('inf')
if os.path.exists(checkpoint_path):
    print("Loading checkpoint...")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint['generator_state_dict'])
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
    optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
    optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    best_val_loss = checkpoint.get('best_val_loss', float('inf'))
    g_steps = checkpoint.get('g_steps', 0)
    print(f"Resuming training from epoch {start_epoch}")

# === TRAINING LOOP ===
for epoch in range(start_epoch, num_epochs):
    generator.train()
    discriminator.train()
    running_loss_D = 0.0
    running_loss_G = 0.0
    running_loss_G_GAN = 0.0
    running_loss_G_L1 = 0.0

    # Per-epoch balancing counters
    d_train_steps = 0
    g_train_steps = 0
    d_correct_predictions = 0
    total_predictions = 0

    # Plain zip rather than zip(..., cycle(...)). itertools.cycle keeps every element it
    # yields so it can replay them, which would hold a whole epoch of background batches
    # in host memory.
    train_pbar = tqdm(zip(mixture_train_loader, background_train_loader),
                      total=len(mixture_train_loader),
                      desc=f'Epoch {epoch+1}/{num_epochs} - Training',
                      leave=False, unit='batch')

    for mixture, background in train_pbar:
        mixture = _to_model_input(mixture)
        background = _to_model_input(background)

        # A single generator forward pass per step. D is fed a detached copy, and the
        # generator update below backpropagates through this same graph.
        fake_background = generator(mixture)

        # === DISCRIMINATOR ACCURACY (and loss, for logging when D is not trained) ===
        with torch.no_grad():
            real_output = discriminator(background)
            fake_output = discriminator(fake_background.detach())

            # A prediction counts as correct when sigmoid > 0.5 for real and < 0.5 for fake.
            real_preds = torch.sigmoid(real_output) > 0.5
            fake_preds = torch.sigmoid(fake_output) < 0.5

            d_correct_predictions += (real_preds.sum() + fake_preds.sum()).item()
            total_predictions += real_preds.numel() + fake_preds.numel()

            loss_D = _discriminator_loss(real_output, fake_output)

        # Running accuracy accumulated since the start of the epoch (not per batch).
        current_d_accuracy = d_correct_predictions / max(total_predictions, 1)

        # === SHOULD THE DISCRIMINATOR BE TRAINED? ===
        # Frequency rule, plus skipping D while it is already too accurate.
        should_train_d = (g_steps % d_train_freq == 0
                          and current_d_accuracy <= d_accuracy_threshold)

        # --- TRAIN DISCRIMINATOR (conditional) ---
        if should_train_d:
            optimizer_D.zero_grad()
            # Fresh forward passes with gradients enabled
            real_output = discriminator(background)
            fake_output = discriminator(fake_background.detach())
            loss_D = _discriminator_loss(real_output, fake_output)
            loss_D.backward()
            optimizer_D.step()
            d_train_steps += 1

        # --- TRAIN GENERATOR ---
        optimizer_G.zero_grad()
        fake_output = discriminator(fake_background)
        loss_G_GAN = adversarial_loss(fake_output, torch.ones_like(fake_output))
        loss_G_L1 = F.l1_loss(fake_background, background)

        # Dynamic weighting: when the adversarial loss is high, halve it and raise the L1 weight.
        if loss_G_GAN.item() > g_loss_threshold:
            adaptive_lambda = lambda_l1 * 1.5
            loss_G = 0.5 * loss_G_GAN + adaptive_lambda * loss_G_L1
        else:
            loss_G = loss_G_GAN + lambda_l1 * loss_G_L1

        loss_G.backward()
        optimizer_G.step()

        g_train_steps += 1
        g_steps += 1

        # Update running losses
        running_loss_D += loss_D.item()
        running_loss_G += loss_G.item()
        running_loss_G_GAN += loss_G_GAN.item()
        running_loss_G_L1 += loss_G_L1.item()

        # Update progress bar
        train_pbar.set_postfix({
            'D_loss': f'{loss_D.item():.4f}',
            'G_loss': f'{loss_G.item():.4f}',
            'D_acc': f'{current_d_accuracy:.3f}',
            'D_trained': 'Y' if should_train_d else 'N',
            'G_GAN': f'{loss_G_GAN.item():.4f}',
            'G_L1': f'{loss_G_L1.item():.4f}'
        })

    train_pbar.close()

    # Epoch averages
    avg_loss_D = running_loss_D / len(mixture_train_loader)
    avg_loss_G = running_loss_G / len(mixture_train_loader)
    avg_loss_G_GAN = running_loss_G_GAN / len(mixture_train_loader)
    avg_loss_G_L1 = running_loss_G_L1 / len(mixture_train_loader)
    final_d_accuracy = d_correct_predictions / max(total_predictions, 1)

    # === VALIDATION ===
    generator.eval()
    val_loss_L1 = 0.0

    val_pbar = tqdm(zip(mixture_validation_loader, background_validation_loader),
                    total=len(mixture_validation_loader),
                    desc=f'Epoch {epoch+1}/{num_epochs} - Validation',
                    leave=False, unit='batch')

    with torch.no_grad():
        for mixture, background in val_pbar:
            mixture = _to_model_input(mixture)
            background = _to_model_input(background)

            pred_background = generator(mixture)
            batch_val_loss = F.l1_loss(pred_background, background).item()
            val_loss_L1 += batch_val_loss

            val_pbar.set_postfix({'Val_L1': f'{batch_val_loss:.4f}'})

    val_pbar.close()
    val_loss_L1 /= len(mixture_validation_loader)

    print(f"[Epoch {epoch+1}/{num_epochs}] "
          f"Train_Loss_D: {avg_loss_D:.4f} | "
          f"Train_Loss_G: {avg_loss_G:.4f} | "
          f"G_GAN: {avg_loss_G_GAN:.4f} | "
          f"G_L1: {avg_loss_G_L1:.4f} | "
          f"Val_L1: {val_loss_L1:.4f} | "
          f"D_Acc: {final_d_accuracy:.3f} | "
          f"D_Steps: {d_train_steps} | G_Steps: {g_train_steps}")

    # === LOG TO CSV ===
    with open(csv_path, 'a', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writerow({
            'epoch': epoch + 1,
            'train_loss_D': avg_loss_D,
            'train_loss_G': avg_loss_G,
            'train_loss_G_GAN': avg_loss_G_GAN,
            'train_loss_G_L1': avg_loss_G_L1,
            'val_loss_L1': val_loss_L1,
            'd_accuracy': final_d_accuracy,
            'd_train_steps': d_train_steps,
            'g_train_steps': g_train_steps
        })

    # === BEST MODEL ===
    # Updated before the checkpoint is written so a resumed run compares against this epoch's best.
    if val_loss_L1 < best_val_loss:
        best_val_loss = val_loss_L1
        torch.save(generator.state_dict(), "results-4/best_generator.pth")
        torch.save(discriminator.state_dict(), "results-4/best_discriminator.pth")
        print(f"New best model saved at epoch {epoch+1} with Val_L1: {val_loss_L1:.4f}")

    # === DYNAMIC ADAPTATION OF D's LEARNING RATE ===
    # Applied before the checkpoint so the adjusted lr is stored in optimizer_D's state dict.
    # Discriminator too weak (accuracy < 60%): raise its learning rate.
    if final_d_accuracy < 0.6:
        for param_group in optimizer_D.param_groups:
            param_group['lr'] = min(param_group['lr'] * 1.1, 5e-4)  # capped at 5e-4
        print(f"D accuracy too low ({final_d_accuracy:.3f}), increasing D learning rate to {optimizer_D.param_groups[0]['lr']:.2e}")

    # Discriminator too strong (accuracy > 85%): lower its learning rate.
    elif final_d_accuracy > 0.85:
        for param_group in optimizer_D.param_groups:
            param_group['lr'] = max(param_group['lr'] * 0.9, 1e-5)  # floored at 1e-5
        print(f"D accuracy too high ({final_d_accuracy:.3f}), decreasing D learning rate to {optimizer_D.param_groups[0]['lr']:.2e}")

    # === CHECKPOINT ===
    checkpoint = {
        'epoch': epoch,
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_G_state_dict': optimizer_G.state_dict(),
        'optimizer_D_state_dict': optimizer_D.state_dict(),
        'best_val_loss': best_val_loss,
        'g_steps': g_steps
    }
    # Write to a temporary file and then swap it in, so an interrupted save cannot corrupt
    # the previous checkpoint.
    tmp_checkpoint_path = checkpoint_path + ".tmp"
    torch.save(checkpoint, tmp_checkpoint_path)
    os.replace(tmp_checkpoint_path, checkpoint_path)

print(f"\nTraining completed! Best validation L1 loss: {best_val_loss:.4f}")
print(f"Training history saved to: {csv_path}")

Loading checkpoint...
Resuming training from epoch 3


Epoch 4/30 - Training:   0%|          | 15/3889 [00:09<35:36,  1.81batch/s, D_loss=0.4561, G_loss=1.3446, D_acc=0.952, D_trained=N, G_GAN=1.1732, G_L1=0.0171] 

In [None]:
def preprocess_audio(filepath, sample_rate=16000, channels=1):
    """Load an audio file as a (1, 1, num_samples) float32 tensor on `device`.

    The file is resampled to `sample_rate` and converted to `channels` channels. Samples are
    divided by 2 ** (8 * sample_width - 1), which normalizes them to roughly [-1, 1]. All
    samples are flattened into one row. NOTE(review): for channels > 1, pydub's interleaved
    channels would end up mixed in that single row.

    Returns:
        (tensor, frame_rate), where frame_rate is the frame rate of the converted segment.
    """
    segment = AudioSegment.from_file(filepath)
    segment = segment.set_frame_rate(sample_rate)
    segment = segment.set_channels(channels)

    peak = 1 << (8 * segment.sample_width - 1)
    samples = np.array(segment.get_array_of_samples(), dtype=np.float32)
    samples /= peak

    batch = torch.tensor(samples)[None, None, :]  # (1, 1, T)
    return batch.to(device), segment.frame_rate

def postprocess_and_export(tensor, filename, sample_rate=16000):
    """Write a waveform tensor to `filename` as a WAV file.

    Singleton dimensions are squeezed away and values are clipped to [-1, 1]. The result is
    converted with array_to_pydub at `sample_rate` and then exported.
    """
    clipped = np.clip(tensor.squeeze().cpu().numpy(), -1.0, 1.0)
    segment = array_to_pydub(clipped, sample_rate=sample_rate)
    segment.export(filename, format="wav")

# Inference on an arbitrary file
def infer_from_path(path_to_wav, output_event_path="event_output.wav"):
    """Estimate the foreground event in an audio file and export it as WAV.

    The generator predicts the background of the mixture. The event is taken as
    mixture - predicted background, after both are trimmed to their common length, and is
    written to `output_event_path`. Leaves `generator` in eval mode.
    """
    generator.eval()
    with torch.no_grad():
        mixture_tensor, sr = preprocess_audio(path_to_wav)
        predicted_background = generator(mixture_tensor)

        # Trim both signals to the shorter length in case the time axes differ.
        n = min(mixture_tensor.shape[-1], predicted_background.shape[-1])

        # Residual = event
        estimated_event = mixture_tensor[..., :n] - predicted_background[..., :n]

        postprocess_and_export(estimated_event, output_event_path, sample_rate=sr)
        print(f"Evento stimato salvato in: {output_event_path}")


In [None]:
# Build the path with os.path.join. The previous literal mixed "/" and "\" separators and
# relied on "\m" being an unrecognized escape (a SyntaxWarning since Python 3.12); it only
# resolved on Windows.
infer_from_path(os.path.join("audio_sources", "train_set", "mix_4962", "mixture.wav"), "pc12_output_event.wav")