In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

import numpy as np
import os
from tqdm import tqdm
from pydub import AudioSegment
import pandas as pd
from torchmetrics.audio import SignalDistortionRatio

In [None]:
class AudioDataset(torch.utils.data.Dataset):
    """Dataset of (mixture, background) waveform pairs listed in a CSV file.

    The CSV is expected to provide a ``mixture`` and a ``background`` column
    holding audio file paths. Each file is decoded with pydub, converted to
    ``target_channels`` channels at ``target_sample_rate`` Hz and scaled by the
    integer full-scale value so that samples lie in [-1, 1).

    Args:
        csv_file: path of the CSV index.
        target_duration: clip length in milliseconds. Every waveform is trimmed
            or right-padded with zeros to exactly this duration so that items
            can be stacked into a batch and mixture/background always have the
            same length. Pass ``None`` to keep the native length.
        target_sample_rate: sample rate (Hz) every clip is converted to.
        target_channels: number of channels every clip is converted to.
    """

    def __init__(self, csv_file, target_duration=10000, target_sample_rate=16000,
                 target_channels=1):
        self.df = pd.read_csv(csv_file)
        self.target_duration = target_duration
        self.target_channels = target_channels
        self.target_sample_rate = target_sample_rate

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(mixture, background)`` float32 tensors of shape (channels, samples)."""
        row = self.df.iloc[idx]
        mixture = self._load(row['mixture'])
        background = self._load(row['background'])
        return torch.from_numpy(mixture), torch.from_numpy(background)

    def _load(self, path):
        """Decode one file and return a (channels, samples) float32 array."""
        audio = (AudioSegment.from_file(path)
                 .set_channels(self.target_channels)
                 .set_frame_rate(self.target_sample_rate))
        samples, _ = self._pydub_to_array(audio)
        return self._fit_length(samples)

    def _fit_length(self, samples):
        """Trim or zero-pad the last axis to ``target_duration`` milliseconds."""
        if self.target_duration is None:
            return samples
        target_len = int(round(self.target_sample_rate * self.target_duration / 1000))
        current_len = samples.shape[-1]
        if current_len > target_len:
            return np.ascontiguousarray(samples[..., :target_len])
        if current_len < target_len:
            pad_width = [(0, 0)] * (samples.ndim - 1) + [(0, target_len - current_len)]
            return np.pad(samples, pad_width, mode="constant")
        return samples

    def _pydub_to_array(self, audio: "AudioSegment") -> "tuple[np.ndarray, int]":
        """Convert a pydub segment to ``(float32 array of shape (channels, samples), frame_rate)``.

        pydub returns multi-channel audio frame-interleaved (L0, R0, L1, R1, ...),
        so the flat array is viewed as (frames, channels) and transposed; a plain
        ``reshape((channels, -1))`` would mix the channels together.
        """
        interleaved = np.array(audio.get_array_of_samples(), dtype=np.float32)
        # Largest magnitude of a signed integer sample of this width.
        full_scale = 1 << (8 * audio.sample_width - 1)
        planar = interleaved.reshape((-1, audio.channels)).T / full_scale
        return np.ascontiguousarray(planar), audio.frame_rate

In [None]:
class WaveUNetBlock(nn.Module):
    """Basic Wave-U-Net building block: Conv1d -> BatchNorm1d -> LeakyReLU(0.2).

    The padding is ``(kernel_size - 1) // 2 * dilation``, so with an odd kernel
    and ``stride=1`` the temporal length is preserved.
    """
    def __init__(self, in_channels, out_channels, kernel_size=15, stride=1, dilation=1):
        super().__init__()
        padding = (kernel_size - 1) // 2 * dilation

        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                             stride=stride, padding=padding, dilation=dilation)
        self.bn = nn.BatchNorm1d(out_channels)
        self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        return self.activation(self.bn(self.conv(x)))


class WaveUNetGenerator(nn.Module):
    """Wave-U-Net generator for audio separation.

    The network receives ``mixture`` and ``source_b`` concatenated along the
    channel axis, encodes them with four stride-2 stages, and decodes with
    transposed convolutions plus skip connections. The output goes through
    ``tanh`` and therefore lies in [-1, 1].

    Args:
        input_channels: total channels after concatenating mixture and source_b.
        output_channels: channels of the predicted waveform.
        base_channels: width of the first encoder stage; deeper stages double it.
    """

    # Four stride-2 encoder stages => the time axis must be divisible by 2**4 for
    # every decoder output to line up with its skip connection.
    _LENGTH_MULTIPLE = 16

    def __init__(self, input_channels=2, output_channels=1, base_channels=16):
        super().__init__()

        # Encoder (downsampling)
        self.enc1 = WaveUNetBlock(input_channels, base_channels, kernel_size=15)
        self.enc2 = WaveUNetBlock(base_channels, base_channels*2, stride=2)
        self.enc3 = WaveUNetBlock(base_channels*2, base_channels*4, stride=2)
        self.enc4 = WaveUNetBlock(base_channels*4, base_channels*8, stride=2)
        self.enc5 = WaveUNetBlock(base_channels*8, base_channels*16, stride=2)

        # Bottleneck
        self.bottleneck = WaveUNetBlock(base_channels*16, base_channels*16, kernel_size=15)

        # Decoder (upsampling). Each *_conv block takes twice the channels of the
        # transposed convolution output because the skip tensor is concatenated.
        self.dec5 = nn.ConvTranspose1d(base_channels*16, base_channels*8,
                                      kernel_size=4, stride=2, padding=1)
        self.dec5_conv = WaveUNetBlock(base_channels*16, base_channels*8)

        self.dec4 = nn.ConvTranspose1d(base_channels*8, base_channels*4,
                                      kernel_size=4, stride=2, padding=1)
        self.dec4_conv = WaveUNetBlock(base_channels*8, base_channels*4)

        self.dec3 = nn.ConvTranspose1d(base_channels*4, base_channels*2,
                                      kernel_size=4, stride=2, padding=1)
        self.dec3_conv = WaveUNetBlock(base_channels*4, base_channels*2)

        self.dec2 = nn.ConvTranspose1d(base_channels*2, base_channels,
                                      kernel_size=4, stride=2, padding=1)
        self.dec2_conv = WaveUNetBlock(base_channels*2, base_channels)

        # Output layer
        self.final_conv = nn.Conv1d(base_channels, output_channels, kernel_size=1)

    def forward(self, mixture, source_b):
        """Predict a waveform from ``mixture`` and ``source_b``.

        Both inputs are (batch, channels, time) tensors with matching batch and
        time sizes. Any time length is accepted: the input is zero-padded on the
        right to a multiple of 16 and the prediction is cropped back, so the
        output always has the same number of samples as the input.
        """
        # Input: concatenate mixture and source_b along the channel axis
        x = torch.cat([mixture, source_b], dim=1)

        original_length = x.shape[-1]
        extra = (-original_length) % self._LENGTH_MULTIPLE
        if extra:
            x = F.pad(x, (0, extra))

        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)
        e5 = self.enc5(e4)

        # Bottleneck
        b = self.bottleneck(e5)

        # Decoder with skip connections
        d5 = self.dec5(b)
        d5 = torch.cat([d5, e4], dim=1)
        d5 = self.dec5_conv(d5)

        d4 = self.dec4(d5)
        d4 = torch.cat([d4, e3], dim=1)
        d4 = self.dec4_conv(d4)

        d3 = self.dec3(d4)
        d3 = torch.cat([d3, e2], dim=1)
        d3 = self.dec3_conv(d3)

        d2 = self.dec2(d3)
        d2 = torch.cat([d2, e1], dim=1)
        d2 = self.dec2_conv(d2)

        # Output
        output = torch.tanh(self.final_conv(d2))
        if extra:
            # Drop the samples that only exist because of the alignment padding.
            output = output[..., :original_length]
        return output

In [None]:

class AudioDiscriminator(nn.Module):
    """Discriminator for 1-D audio signals.

    Five stride-2 convolutions (kernel 15, padding 7) widen the channels from
    ``base_channels`` to ``base_channels * 16``; every stage except the first is
    followed by batch normalisation, and all use LeakyReLU(0.2). The features are
    then averaged over time and mapped to a single sigmoid score per example.
    """
    def __init__(self, input_channels=1, base_channels=64):
        super().__init__()

        conv_settings = dict(kernel_size=15, stride=2, padding=7)
        widths = [base_channels * factor for factor in (1, 2, 4, 8, 16)]

        # Stage 1 has no batch normalisation.
        stages = [
            nn.Conv1d(input_channels, widths[0], **conv_settings),
            nn.LeakyReLU(0.2),
        ]
        # Stages 2-5: convolution, batch norm, activation.
        for narrow, wide in zip(widths[:-1], widths[1:]):
            stages.append(nn.Conv1d(narrow, wide, **conv_settings))
            stages.append(nn.BatchNorm1d(wide))
            stages.append(nn.LeakyReLU(0.2))
        self.layers = nn.Sequential(*stages)

        # Global average pooling + classification
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(widths[-1], 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.classifier(self.layers(x))

In [35]:
def pydub_to_array(audio: "AudioSegment") -> "tuple[np.ndarray, int]":
    """Convert a pydub segment to ``(float32 array of shape (channels, samples), frame_rate)``.

    Samples are divided by the signed full-scale value of the segment's sample
    width, so they lie in [-1, 1). pydub returns multi-channel audio
    frame-interleaved (L0, R0, L1, R1, ...), hence the flat array is viewed as
    (frames, channels) and transposed rather than reshaped to (channels, -1).
    """
    interleaved = np.array(audio.get_array_of_samples(), dtype=np.float32)
    full_scale = 1 << (8 * audio.sample_width - 1)
    planar = interleaved.reshape((-1, audio.channels)).T / full_scale
    return np.ascontiguousarray(planar), audio.frame_rate


# Signed PCM integer type for each supported sample width (in bytes).
_PCM_DTYPES = {1: np.int8, 2: np.int16, 4: np.int32}


def array_to_pydub(audio_np_array: np.ndarray, sample_rate: int = 16000, sample_width: int = 2, channels: int = 1) -> "AudioSegment":
    """Build a pydub segment from a float array, the inverse of ``pydub_to_array``.

    Args:
        audio_np_array: samples in [-1, 1], either 1-D or shaped
            (channels, samples).
        sample_rate: frame rate in Hz.
        sample_width: bytes per sample (1, 2 or 4).
        channels: number of channels encoded in the data.

    Raises:
        ValueError: if ``sample_width`` is not 1, 2 or 4.
    """
    if sample_width not in _PCM_DTYPES:
        raise ValueError(f"Unsupported sample_width: {sample_width!r} (expected 1, 2 or 4)")
    full_scale = 1 << (8 * sample_width - 1)
    # float64 keeps the 32-bit limits exactly representable while clipping.
    scaled = np.asarray(audio_np_array, dtype=np.float64) * full_scale
    # Clip before the integer cast: +1.0 maps to full_scale, which does not fit
    # in the signed type and would otherwise wrap around to the negative limit.
    pcm = np.clip(scaled, -full_scale, full_scale - 1).astype(_PCM_DTYPES[sample_width])
    if pcm.ndim == 2:
        # (channels, samples) -> frame-interleaved order expected in raw PCM data.
        pcm = pcm.T
    return AudioSegment(np.ascontiguousarray(pcm).tobytes(),
                        frame_rate=sample_rate, sample_width=sample_width, channels=channels)

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# CSV indexes for the three splits.
# train_csv = validation_csv = test_csv = 'pezz.csv'
train_csv, validation_csv, test_csv = (
    "train_dataset_wav.csv",
    "val_dataset_wav.csv",
    "test_dataset_wav.csv",
)

# One dataset per split, built in train / validation / test order.
train, validation, test = (
    AudioDataset(csv_file=index_path)
    for index_path in (train_csv, validation_csv, test_csv)
)

batch_size = 4

# Only the training split is reshuffled; validation and test keep file order.
train_loader = DataLoader(train, shuffle=True, batch_size=batch_size)
validation_loader = DataLoader(validation, shuffle=False, batch_size=batch_size)
test_loader = DataLoader(test, shuffle=False, batch_size=batch_size)

In [None]:
import torch
from torch import nn
from tqdm import tqdm


def train_event_separator(model,
                          train_loader,
                          val_loader=None,
                          *,
                          device: str = "cuda" if torch.cuda.is_available() else "cpu",
                          num_epochs: int = 50,
                          lr: float = 1e-4,
                          alpha: float = 1.0,
                          beta: float = 0.5,
                          scheduler=None,
                          save_path: str = "best_event_sep.pth"):
    """Train ``model`` to predict ``source_a = mixture - source_b``.

    Each batch yields ``(mixture, source_b)``; the model is called as
    ``model(mixture, source_b)`` and optimised with Adam on
    ``alpha * L1(pred + source_b, mixture) + beta * L1(pred, source_a)``.

    NOTE(review): the two terms are algebraically the same quantity
    (``pred + source_b - mixture == pred - source_a``), so the objective is
    effectively ``(alpha + beta) * L1(pred, source_a)``. Both are kept so that
    reported loss values stay comparable with earlier runs.

    Args:
        model: module called as ``model(mixture, source_b)``.
        train_loader: iterable of ``(mixture, source_b)`` batches.
        val_loader: optional iterable of validation batches.
        device: device the model and the batches are moved to.
        num_epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        alpha: weight of the reconstruction term.
        beta: weight of the event term.
        scheduler: either a ready scheduler object (anything with ``.step``) or
            a callable ``optimizer -> scheduler``. The optimizer is created
            inside this function, so the callable form is the only way to attach
            a scheduler to it. ``ReduceLROnPlateau`` is stepped with the
            validation loss (training loss when there is no ``val_loader``).
        save_path: where the state dict with the best validation loss is saved.
            Nothing is saved when ``val_loader`` is ``None``.

    Returns:
        ``{"train": [...], "val": [...]}`` with one mean loss per epoch; the
        validation entry is NaN when there is no ``val_loader``, and a mean over
        zero batches is NaN as well.
    """
    device = torch.device(device)
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    l1_loss = nn.L1Loss()

    # A factory is resolved here, once the optimizer it must wrap exists.
    if scheduler is not None and not hasattr(scheduler, "step") and callable(scheduler):
        scheduler = scheduler(optimizer)

    def _weighted_l1(mixture, source_b):
        """Forward one batch and return the weighted loss tensor."""
        mixture = mixture.to(device)
        source_b = source_b.to(device)

        # source_a is the residual of the mixture once source_b is removed
        source_a = mixture - source_b

        # The model predicts source_a from (mixture, source_b)
        pred_source_a = model(mixture, source_b)

        # Reconstruction loss: pred + source_b ≈ mixture
        loss_rec = l1_loss(pred_source_a + source_b, mixture)
        # Event loss: pred ≈ source_a
        loss_evt = l1_loss(pred_source_a, source_a)
        return alpha * loss_rec + beta * loss_evt

    history = {"train": [], "val": []}
    best_val_loss = float("inf")

    for epoch in range(1, num_epochs + 1):
        model.train()
        running_train_loss = 0.0
        train_batches = 0

        with tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs} [train]", unit="batch") as bar:
            # Count batches explicitly: tqdm only refreshes ``bar.n`` when it
            # redraws, so it is not a reliable denominator for the running mean.
            for train_batches, (mixture, source_b) in enumerate(bar, start=1):
                optimizer.zero_grad()
                loss = _weighted_l1(mixture, source_b)
                loss.backward()
                optimizer.step()

                running_train_loss += loss.item()
                bar.set_postfix(loss=running_train_loss / train_batches)

        epoch_train_loss = running_train_loss / train_batches if train_batches else float("nan")
        history["train"].append(epoch_train_loss)

        # -----------------------
        # VALIDATION
        # -----------------------
        if val_loader is not None:
            model.eval()
            running_val_loss = 0.0
            val_batches = 0
            with torch.no_grad():
                for val_batches, (mixture, source_b) in enumerate(val_loader, start=1):
                    running_val_loss += _weighted_l1(mixture, source_b).item()

            epoch_val_loss = running_val_loss / val_batches if val_batches else float("nan")
            history["val"].append(epoch_val_loss)

            print(f"\nEpoch {epoch}/{num_epochs} → train: {epoch_train_loss:.4f} | val: {epoch_val_loss:.4f}")
        else:
            history["val"].append(float("nan"))
            print(f"\nEpoch {epoch}/{num_epochs} → train: {epoch_train_loss:.4f}")

        # Scheduler
        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(epoch_val_loss if val_loader is not None else epoch_train_loss)
            else:
                scheduler.step()

        # Checkpoint (a NaN validation loss never compares as an improvement)
        if val_loader is not None and epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            torch.save(model.state_dict(), save_path)
            print(f"→ Best model saved to '{save_path}' (val loss {best_val_loss:.4f})")

    return history


In [39]:
def infer_source_a(model, mixture, device="cpu"):
    """Estimate source_a from a mixture alone, feeding zeros in place of source_b.

    Args:
        model: module called as ``model(mixture, source_b)``; it is moved to
            ``device`` and switched to eval mode.
        mixture: tensor shaped (batch, channels, time). A 1-D tensor is treated
            as (time,) and a 2-D tensor as (channels, time); both are expanded
            to a batch of one.
        device: device used for inference.

    Returns:
        The model output for the 3-D mixture, computed without gradients.

    NOTE(review): ``train_event_separator`` in this notebook feeds the true
    source_b to the model, so an all-zero source_b differs from what the model
    saw during training — confirm this is the intended inference setup.
    """
    device = torch.device(device)
    # Model and inputs must live on the same device, otherwise the forward
    # pass raises a device-mismatch error for any non-CPU ``device``.
    model = model.to(device)
    model.eval()

    if mixture.dim() == 1:
        mixture = mixture.unsqueeze(0).unsqueeze(0)  # (T,) -> (1, 1, T)
    elif mixture.dim() == 2:
        mixture = mixture.unsqueeze(0)  # (C, T) -> (1, C, T)
    mixture = mixture.to(device)

    # Dummy source_b: zeros shaped like the first channel of the mixture
    source_b_dummy = torch.zeros_like(mixture[:, :1])
    with torch.no_grad():
        pred_source_a = model(mixture, source_b_dummy)
    return pred_source_a


In [None]:
# The generator takes mixture and source_b stacked on the channel axis
# (2 input channels) and emits a single-channel estimate.
model = WaveUNetGenerator(output_channels=1, input_channels=2)

# Train for 20 epochs; the weights with the best validation loss are written
# to "best_model.pth" and the per-epoch losses are returned in `history`.
history = train_event_separator(
    model,
    train_loader,
    validation_loader,
    save_path="best_model.pth",
    num_epochs=20,
    lr=1e-4,
    beta=0.5,
    alpha=1.0,
)

In [42]:
# Load the example mixture as mono 16 kHz audio. The path uses forward slashes,
# which work on every OS; the earlier backslash-separated literal contained
# invalid string escape sequences and only resolved on Windows.
mixture = AudioSegment.from_file("audio_sources/dataset_toy/mix_2/mixture.wav").set_channels(1).set_frame_rate(16000)

  mixture = AudioSegment.from_file("audio_sources\dataset_toy\mix_2\mixture.wav").set_channels(1).set_frame_rate(16000)


In [43]:
# Decode the mixture into a float array and wrap it in a tensor.
# NOTE: `_` receives the frame rate returned by pydub_to_array and is read again
# by the cell that saves the prediction, so this binding must not be removed.
mixture_array, _ = pydub_to_array(mixture)
mixture_tensor = torch.tensor(mixture_array)

In [44]:
# Shape the mixture will have once a leading batch axis is added.
mixture_tensor[None].shape

torch.Size([1, 1, 160000])

In [46]:
# Load the trained model
model = WaveUNetGenerator(input_channels=2, output_channels=1)
# map_location lets a checkpoint written on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
model.load_state_dict(torch.load("best_model.pth", map_location="cpu"))
model.eval()

# mixture_tensor is assumed to have shape (C, T) or (T,); add the missing axes
if mixture_tensor.dim() == 1:
    mixture_tensor = mixture_tensor.unsqueeze(0).unsqueeze(0)  # (1, 1, T)
elif mixture_tensor.dim() == 2:
    mixture_tensor = mixture_tensor.unsqueeze(0)  # (1, C, T)

# Inference
predicted_source_a = infer_source_a(model, mixture_tensor)

# Save - drop the batch dimension (0). The sample rate comes from the decoded
# segment instead of `_`, which IPython may rebind to the last displayed output.
torchaudio.save("predicted2.wav", predicted_source_a.squeeze(0).cpu(), mixture.frame_rate)
