In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

import numpy as np
import os
from tqdm import tqdm
from pydub import AudioSegment
import pandas as pd
from torchmetrics.audio import SignalDistortionRatio

In [3]:
class AudioDataset(torch.utils.data.Dataset):
    """Paired (mixture, background) waveforms indexed by a CSV file.

    Each CSV row must have a ``mixture`` and a ``background`` column holding
    audio file paths. Both files are decoded with pydub, converted to
    ``target_channels`` channels at ``target_sample_rate`` Hz, scaled to
    float32 in [-1, 1) and fitted to ``target_duration``.

    Args:
        csv_file: path to the CSV index.
        target_duration: clip length in milliseconds (pydub's time unit). Each
            waveform is truncated, or zero-padded at the end, to
            ``int(target_sample_rate * target_duration / 1000)`` samples so
            that items of different lengths can be stacked by the default
            collate function. ``None`` keeps each file's native length.
        target_sample_rate: output sample rate in Hz.
        target_channels: output channel count.

    Items are ``(mixture, background)`` float32 tensors of shape
    ``(target_channels, num_samples)``.
    """

    def __init__(self, csv_file, target_duration=10000, target_sample_rate=16000,
                 target_channels=1):
        self.df = pd.read_csv(csv_file)
        self.target_duration = target_duration
        self.target_channels = target_channels
        self.target_sample_rate = target_sample_rate

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        background = self._load_waveform(self.df.loc[idx, 'background'])
        mixture = self._load_waveform(self.df.loc[idx, 'mixture'])
        return torch.from_numpy(mixture), torch.from_numpy(background)

    def _load_waveform(self, path):
        """Decode ``path``, remix/resample it and return a fixed-length float32 array."""
        audio = (AudioSegment.from_file(path)
                 .set_channels(self.target_channels)
                 .set_frame_rate(self.target_sample_rate))
        waveform, _ = self._pydub_to_array(audio)
        return np.ascontiguousarray(self._fit_length(waveform))

    def _fit_length(self, waveform: np.ndarray) -> np.ndarray:
        """Truncate or right-pad the last (time) axis to ``target_duration``."""
        if self.target_duration is None:
            return waveform
        num_samples = int(self.target_sample_rate * self.target_duration / 1000)
        length = waveform.shape[-1]
        if length >= num_samples:
            return waveform[..., :num_samples]
        pad_width = [(0, 0)] * (waveform.ndim - 1) + [(0, num_samples - length)]
        return np.pad(waveform, pad_width)

    def _pydub_to_array(self, audio: "AudioSegment") -> "tuple[np.ndarray, int]":
        """Convert a pydub segment to a ``(channels, samples)`` float32 array in [-1, 1).

        Returns the array together with the segment's frame rate.
        """
        samples = np.array(audio.get_array_of_samples(), dtype=np.float32)
        # pydub stores multichannel audio interleaved (L0, R0, L1, R1, ...), so
        # split frames first and then transpose to channel-major.
        samples = np.ascontiguousarray(samples.reshape((-1, audio.channels)).T)
        full_scale = 1 << (8 * audio.sample_width - 1)
        return samples / full_scale, audio.frame_rate

In [4]:
class WaveUNetBlock(nn.Module):
    """Conv1d -> BatchNorm1d -> LeakyReLU(0.2) unit used throughout the Wave-U-Net.

    Padding is ``dilation * ((kernel_size - 1) // 2)``. With an odd kernel and
    stride 1 the time length is preserved. With stride 2 and the default kernel
    (15) a length-L input yields ceil(L / 2) outputs.
    """

    def __init__(self, in_channels, out_channels, kernel_size=15, stride=1, dilation=1):
        super().__init__()
        half_span = dilation * ((kernel_size - 1) // 2)

        # Attribute names (conv / bn / activation) are part of the state_dict
        # keys of saved checkpoints, so they must stay as they are.
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=half_span,
            dilation=dilation,
        )
        self.bn = nn.BatchNorm1d(out_channels)
        self.activation = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x):
        convolved = self.conv(x)
        normalized = self.bn(convolved)
        return self.activation(normalized)

class WaveUNetGenerator(nn.Module):
    """Wave-U-Net generator for audio source separation.

    ``forward(mixture, source_b)`` concatenates both inputs on the channel axis
    (so ``input_channels`` must equal their combined channel count). It runs a
    1-D U-Net with skip connections and returns a ``tanh``-bounded waveform
    with ``output_channels`` channels and the same length as the input.

    The encoder halves the time axis four times (``enc2``..``enc5``) and the
    transposed-conv decoder doubles it back. Skip connections therefore only
    line up when the length is a multiple of 16. Inputs of any length are
    accepted: they are zero-padded on the right to the next multiple of 16, and
    the output is cropped back to the original length.
    """

    # enc2..enc5 each down-sample by 2, dec5..dec2 each up-sample by 2.
    _LENGTH_MULTIPLE = 2 ** 4

    def __init__(self, input_channels=2, output_channels=1, base_channels=16):
        super().__init__()

        # Encoder (downsampling). Attribute names are the state_dict keys of
        # saved checkpoints, so they must not change.
        self.enc1 = WaveUNetBlock(input_channels, base_channels, kernel_size=15)
        self.enc2 = WaveUNetBlock(base_channels, base_channels*2, stride=2)
        self.enc3 = WaveUNetBlock(base_channels*2, base_channels*4, stride=2)
        self.enc4 = WaveUNetBlock(base_channels*4, base_channels*8, stride=2)
        self.enc5 = WaveUNetBlock(base_channels*8, base_channels*16, stride=2)

        # Bottleneck
        self.bottleneck = WaveUNetBlock(base_channels*16, base_channels*16, kernel_size=15)

        # Decoder (upsampling): each ConvTranspose1d doubles the length. The
        # following block fuses the upsampled features with the matching
        # encoder skip (hence the doubled input channel count).
        self.dec5 = nn.ConvTranspose1d(base_channels*16, base_channels*8,
                                      kernel_size=4, stride=2, padding=1)
        self.dec5_conv = WaveUNetBlock(base_channels*16, base_channels*8)

        self.dec4 = nn.ConvTranspose1d(base_channels*8, base_channels*4,
                                      kernel_size=4, stride=2, padding=1)
        self.dec4_conv = WaveUNetBlock(base_channels*8, base_channels*4)

        self.dec3 = nn.ConvTranspose1d(base_channels*4, base_channels*2,
                                      kernel_size=4, stride=2, padding=1)
        self.dec3_conv = WaveUNetBlock(base_channels*4, base_channels*2)

        self.dec2 = nn.ConvTranspose1d(base_channels*2, base_channels,
                                      kernel_size=4, stride=2, padding=1)
        self.dec2_conv = WaveUNetBlock(base_channels*2, base_channels)

        # Output layer (1x1 conv to the requested channel count)
        self.final_conv = nn.Conv1d(base_channels, output_channels, kernel_size=1)

    def forward(self, mixture, source_b):
        # Input: concatenate mixture and source_b along the channel axis.
        x = torch.cat([mixture, source_b], dim=1)

        # Pad the time axis so every down/up-sampling stage divides evenly;
        # otherwise the decoder outputs and encoder skips differ in length.
        length = x.shape[-1]
        remainder = length % self._LENGTH_MULTIPLE
        if remainder:
            x = F.pad(x, (0, self._LENGTH_MULTIPLE - remainder))

        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)
        e5 = self.enc5(e4)

        # Bottleneck
        b = self.bottleneck(e5)

        # Decoder with skip connections
        d5 = self.dec5(b)
        d5 = torch.cat([d5, e4], dim=1)
        d5 = self.dec5_conv(d5)

        d4 = self.dec4(d5)
        d4 = torch.cat([d4, e3], dim=1)
        d4 = self.dec4_conv(d4)

        d3 = self.dec3(d4)
        d3 = torch.cat([d3, e2], dim=1)
        d3 = self.dec3_conv(d3)

        d2 = self.dec2(d3)
        d2 = torch.cat([d2, e1], dim=1)
        d2 = self.dec2_conv(d2)

        # Output, cropped back to the caller's length (a no-op when no padding
        # was added).
        output = torch.tanh(self.final_conv(d2))
        return output[..., :length]

In [5]:

class AudioDiscriminator(nn.Module):
    """1-D convolutional discriminator that scores waveforms as real or fake.

    Five stride-2 Conv1d stages (kernel 15, padding 7) grow the channel count
    from ``base_channels`` to ``16 * base_channels``. Every stage except the
    first is followed by BatchNorm1d, and all stages use LeakyReLU(0.2). The
    features are average-pooled over time, then a Linear layer and a Sigmoid
    map them to one probability per example, shape ``(batch, 1)``.
    """

    def __init__(self, input_channels=1, base_channels=64):
        super().__init__()

        # Stage 1: no normalisation directly on the raw waveform.
        stages = [
            nn.Conv1d(input_channels, base_channels, kernel_size=15, stride=2, padding=7),
            nn.LeakyReLU(0.2),
        ]
        # Stages 2-5 each double the channel count. Modules are created and
        # appended in the same order as a hand-written list would be, so the
        # Sequential indices (and the state_dict keys of saved checkpoints)
        # stay stable.
        width = base_channels
        for _ in range(4):
            stages.extend([
                nn.Conv1d(width, width * 2, kernel_size=15, stride=2, padding=7),
                nn.BatchNorm1d(width * 2),
                nn.LeakyReLU(0.2),
            ])
            width *= 2
        self.layers = nn.Sequential(*stages)

        # Global average pooling over time + binary classification head.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(width, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.classifier(self.layers(x))

In [6]:
def pydub_to_array(audio: "AudioSegment") -> "tuple[np.ndarray, int]":
    """Convert a pydub segment to a ``(channels, samples)`` float32 array in [-1, 1).

    Args:
        audio: pydub ``AudioSegment`` (any channel count / integer sample width).

    Returns:
        ``(waveform, frame_rate)`` where ``waveform`` is channel-major float32.
    """
    samples = np.array(audio.get_array_of_samples(), dtype=np.float32)
    # pydub returns multichannel audio interleaved (L0, R0, L1, R1, ...): split
    # into frames first, then transpose to channel-major.
    samples = np.ascontiguousarray(samples.reshape((-1, audio.channels)).T)
    full_scale = 1 << (8 * audio.sample_width - 1)
    return samples / full_scale, audio.frame_rate

def array_to_pydub(audio_np_array: np.ndarray, sample_rate: int = 16000, sample_width: int = 2, channels: int = 1) -> AudioSegment:
    return AudioSegment((audio_np_array * (2 ** (8 * sample_width - 1))).astype(np.int16).tobytes(),
                        frame_rate=sample_rate, sample_width=sample_width, channels=channels)

In [7]:
# ---------- configuration ----------
SEED = 42
# Seed the global torch RNG (also used for weight initialisation later on) and
# give the training loader its own seeded generator, so the shuffle order is
# reproducible across Restart & Run All.
torch.manual_seed(SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Set DEBUG_CSV to a small index (e.g. 'pezz.csv') to run every split on it as
# a quick smoke test; leave it as None to use the real splits.
DEBUG_CSV = None
train_csv = DEBUG_CSV or "train_dataset_wav.csv"
validation_csv = DEBUG_CSV or "val_dataset_wav.csv"
test_csv = DEBUG_CSV or "test_dataset_wav.csv"

train = AudioDataset(csv_file=train_csv)
validation = AudioDataset(csv_file=validation_csv)
test = AudioDataset(csv_file=test_csv)

batch_size = 2

train_loader = torch.utils.data.DataLoader(
    train, batch_size=batch_size, shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
validation_loader = torch.utils.data.DataLoader(validation, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=False)

In [8]:
# ---------- models ----------
# `device` is defined in the data-loading cell above.
generator = WaveUNetGenerator(input_channels=2, output_channels=1).to(device)
discriminator = AudioDiscriminator(input_channels=1).to(device)

# ---------- optimizers ----------
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999))

# ---------- losses ----------
bce_loss = nn.BCELoss()    # the discriminator already ends in a Sigmoid
l1_loss  = nn.L1Loss()
lambda_adv = 1e-3          # weight of the adversarial term vs. L1 reconstruction


def generator_loss(d_fake, pred_a, real_a):
    """L1 reconstruction plus ``lambda_adv`` x adversarial BCE (fakes labelled as real)."""
    g_adv = bce_loss(d_fake, torch.ones_like(d_fake))
    g_rec = l1_loss(pred_a, real_a)
    return g_rec + lambda_adv * g_adv


def discriminator_loss(d_real, d_fake):
    """BCE pushing D(real) towards 1 and D(fake) towards 0."""
    return (bce_loss(d_real, torch.ones_like(d_real))
            + bce_loss(d_fake, torch.zeros_like(d_fake)))


# ---------- training ----------
n_epochs          = 10
best_val_g_loss   = float("inf")
os.makedirs("models", exist_ok=True)

for epoch in range(1, n_epochs + 1):
    # ======== TRAIN ========
    generator.train()
    discriminator.train()
    train_g_loss = 0.0
    train_d_loss = 0.0

    for mixture, background in tqdm(train_loader,
                                    desc=f"Epoch {epoch}/{n_epochs} [train]",
                                    leave=False):
        mixture    = mixture.to(device)      # (B,1,T)
        background = background.to(device)   # (B,1,T)
        # Target source A is what remains of the mixture once the background
        # is removed.
        real_a     = mixture - background
        # NOTE(review): the generator is conditioned on the ground-truth
        # `background`, from which the target can be computed directly, while
        # `infer_loop` feeds zeros in its place. Confirm which conditioning is
        # intended; as written, training and inference see different inputs.

        # ---- Generator step ----
        # Freeze D so g_loss.backward() skips gradients for D's parameters
        # (they would be discarded by d_optimizer.zero_grad() anyway).
        # Gradients still flow through D into the generator.
        discriminator.requires_grad_(False)
        pred_a = generator(mixture, background)
        g_loss = generator_loss(discriminator(pred_a), pred_a, real_a)

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        discriminator.requires_grad_(True)

        # ---- Discriminator step ----
        # detach() keeps D's loss from back-propagating into the generator.
        # D sees the real batch first, then the fake one.
        fake_a = pred_a.detach()
        d_loss = discriminator_loss(discriminator(real_a), discriminator(fake_a))

        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        train_g_loss += g_loss.item()
        train_d_loss += d_loss.item()

    avg_train_g = train_g_loss / len(train_loader)
    avg_train_d = train_d_loss / len(train_loader)

    # ======== VALIDATION ========
    generator.eval()
    discriminator.eval()
    val_g_loss = 0.0
    val_d_loss = 0.0

    with torch.no_grad():
        for mixture, background in tqdm(validation_loader,
                                        desc=f"Epoch {epoch}/{n_epochs} [val ]",
                                        leave=False):
            mixture    = mixture.to(device)
            background = background.to(device)
            real_a     = mixture - background

            pred_a = generator(mixture, background)
            val_g_loss += generator_loss(discriminator(pred_a), pred_a, real_a).item()

            # Discriminator: forward pass only (no gradients).
            val_d_loss += discriminator_loss(discriminator(real_a),
                                             discriminator(pred_a)).item()

    avg_val_g = val_g_loss / len(validation_loader)
    avg_val_d = val_d_loss / len(validation_loader)

    # ----- log -----
    print(f"Epoch {epoch:02d}/{n_epochs} | "
          f"train G:{avg_train_g:.4f} D:{avg_train_d:.4f} || "
          f"val G:{avg_val_g:.4f} D:{avg_val_d:.4f}")

    # ----- checkpoint on the best validation G loss -----
    if avg_val_g < best_val_g_loss:
        best_val_g_loss = avg_val_g
        torch.save(generator.state_dict(), "models/best_generator.pth")
        torch.save(discriminator.state_dict(), "models/best_discriminator.pth")
        print(f"Saved best models (val G loss {best_val_g_loss:.4f})")


                                                                 

Epoch 01/10 | train G:0.2118 D:1.1485 || val G:0.2136 D:1.6362
Saved best models (val G loss 0.2136)


                                                                 

Epoch 02/10 | train G:0.1681 D:0.7624 || val G:0.1858 D:2.7815
Saved best models (val G loss 0.1858)


                                                                 

Epoch 03/10 | train G:0.1596 D:0.6032 || val G:0.1565 D:3.4864
Saved best models (val G loss 0.1565)


                                                                 

Epoch 04/10 | train G:0.1537 D:0.5417 || val G:0.1298 D:2.7105
Saved best models (val G loss 0.1298)


                                                                 

Epoch 05/10 | train G:0.1404 D:0.4661 || val G:0.1052 D:2.8606
Saved best models (val G loss 0.1052)


                                                                 

Epoch 06/10 | train G:0.1339 D:0.5403 || val G:0.0880 D:5.0187
Saved best models (val G loss 0.0880)


                                                                 

Epoch 07/10 | train G:0.1255 D:0.3899 || val G:0.0729 D:3.9754
Saved best models (val G loss 0.0729)


                                                                 

Epoch 08/10 | train G:0.1163 D:0.3236 || val G:0.0596 D:3.8117
Saved best models (val G loss 0.0596)


                                                                 

Epoch 09/10 | train G:0.1107 D:0.2770 || val G:0.0705 D:4.2961


                                                                  

Epoch 10/10 | train G:0.1107 D:0.4047 || val G:0.0415 D:4.1080
Saved best models (val G loss 0.0415)




In [9]:
def infer_loop(generator, discriminator, dataloader, device="cpu", save_audio=False, save_path="outputs",
               sample_rate=16000):
    """Run ``generator`` over ``dataloader`` and collect one result per example.

    Args:
        generator: model called as ``generator(mixture, source_b)``. It is moved
            to ``device`` and put in eval mode.
        discriminator: optional model. When not None, it is moved to ``device``,
            put in eval mode and used to score each prediction.
        dataloader: yields ``(mixture, _)`` batches; the second item is ignored.
        device: device (string or ``torch.device``) to run on.
        save_audio: when True, each prediction is written to
            ``save_path/predicted_v2{index}.wav`` (the directory is created
            only in this case).
        save_path: output directory for the WAV files.
        sample_rate: sample rate recorded in the written WAV files.

    Returns:
        A list of dicts with keys ``index`` (running position over the
        dataloader's output), ``realism_score`` (float, or None without a
        discriminator) and ``prediction`` (CPU tensor).
    """
    if save_audio:
        from torchaudio import save as save_wav
        os.makedirs(save_path, exist_ok=True)

    device = torch.device(device)
    generator.eval()
    generator.to(device)
    # Compare with None explicitly: Module truthiness is not a presence check.
    if discriminator is not None:
        discriminator.eval()
        discriminator.to(device)

    results = []
    next_index = 0  # running counter; does not require dataloader.batch_size

    with torch.no_grad():
        for mixture, _ in dataloader:
            mixture = mixture.to(device)  # (B, 1, T)

            # Dummy source B for inference (zeros).
            # NOTE(review): training conditions the generator on the real
            # background instead; confirm this mismatch is intended.
            source_b_dummy = torch.zeros_like(mixture[:, :1])
            pred_source_a = generator(mixture, source_b_dummy)  # (B, 1, T)

            realism_score = discriminator(pred_source_a) if discriminator is not None else None

            batch_len = pred_source_a.size(0)
            for j in range(batch_len):
                idx = next_index + j
                prediction = pred_source_a[j].cpu()
                if save_audio:
                    filename = os.path.join(save_path, f"predicted_v2{idx}.wav")
                    save_wav(filename, prediction, sample_rate=sample_rate)

                results.append({
                    "index": idx,
                    "realism_score": realism_score[j].item() if realism_score is not None else None,
                    "prediction": prediction
                })
            next_index += batch_len

    return results


In [10]:
# The checkpoints may have been saved from a CUDA device; map_location="cpu"
# lets them load on a CPU-only machine and matches device="cpu" used below.
generator = WaveUNetGenerator(input_channels=2, output_channels=1)
generator.load_state_dict(torch.load("models/best_generator.pth", map_location="cpu"))
discriminator = AudioDiscriminator(input_channels=1)
discriminator.load_state_dict(torch.load("models/best_discriminator.pth", map_location="cpu"))

results = infer_loop(generator, discriminator, test_loader, device="cpu", save_audio=True)