In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
import torchinfo
import torchaudio
import torchaudio.transforms as T
from torchinfo import summary
import pandas as pd
import numpy as np
import os
from tqdm import tqdm
from pydub import AudioSegment
import warnings
import random
import torch.nn.init as init
from local_attention import LocalMHA
from torchmetrics.audio import SignalDistortionRatio
from demucs.hdemucs import HDemucs


warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

In [16]:
class WaveNet(nn.Module):
    """1-D convolutional encoder/decoder that splits a mono waveform in two.

    Layout, as built in ``__init__``:

    * four encoder stages, each a stride-4 ``Conv1d`` (kernel 8) followed by a
      kernel-1 ``Conv1d``, widening 1 -> 32 -> 128 -> 512 -> 2048 channels;
    * a bottleneck that projects to 1024 channels and applies ``LocalMHA``
      along the time axis;
    * three decoder stages that concatenate the matching encoder output,
      mix it with a kernel-1 convolution and upsample with a stride-4
      ``ConvTranspose1d``, plus a last stage that upsamples back to the
      waveform rate;
    * two output branches ("foreground" and "background"), each of which is
      concatenated with the raw input before a final kernel-1 convolution.

    Input:  ``x`` of shape ``(batch, 1, time)`` (the first convolution takes one
            channel and the output convolutions take ``1 + 1`` channels).
    Output: ``(foreground, background)``, each ``(batch, 1, time)`` and passed
            through ``tanh``, so values lie in ``[-1, 1]``.

    Attribute names and their creation order are part of the checkpoint
    format (``state_dict`` keys) and of seeded initialisation; do not rename
    or reorder them.
    """

    def __init__(self):
        super().__init__()
        # Encoder
        self.wave_enc1_conv1 = nn.Conv1d(in_channels=1, out_channels=32, kernel_size=8, stride=4)
        self.wave_enc1_norm1 = nn.BatchNorm1d(32)
        self.wave_enc1_conv2 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=1)
        self.wave_enc1_norm2 = nn.BatchNorm1d(32)

        self.wave_enc2_conv1 = nn.Conv1d(in_channels=32, out_channels=128, kernel_size=8, stride=4)
        self.wave_enc2_norm1 = nn.BatchNorm1d(128)
        self.wave_enc2_conv2 = nn.Conv1d(in_channels=128, out_channels=128, kernel_size=1)
        self.wave_enc2_norm2 = nn.BatchNorm1d(128)

        self.wave_enc3_conv1 = nn.Conv1d(in_channels=128, out_channels=512, kernel_size=8, stride=4)
        self.wave_enc3_norm1 = nn.BatchNorm1d(512)
        self.wave_enc3_conv2 = nn.Conv1d(in_channels=512, out_channels=512, kernel_size=1)
        self.wave_enc3_norm2 = nn.BatchNorm1d(512)

        self.wave_enc4_conv1 = nn.Conv1d(in_channels=512, out_channels=2048, kernel_size=8, stride=4)
        self.wave_enc4_norm1 = nn.BatchNorm1d(2048)
        self.wave_enc4_conv2 = nn.Conv1d(in_channels=2048, out_channels=2048, kernel_size=1)
        self.wave_enc4_norm2 = nn.BatchNorm1d(2048)

        # Bottleneck attention
        self.wave_bn_conv1 = nn.Conv1d(in_channels=2048, out_channels=1024, kernel_size=1)
        self.wave_bn_norm1 = nn.BatchNorm1d(1024)
        self.wave_bn_local_attention = LocalMHA(dim=1024, window_size=32, heads=4, dropout=0.3, causal=False,
                                                  prenorm=False, exact_windowsize=False)
        self.wave_bn_conv2 = nn.Conv1d(in_channels=1024, out_channels=1024, kernel_size=1)
        self.wave_bn_norm2 = nn.BatchNorm1d(1024, affine=True)

        # Decoder. The *_conv2 input widths are "reduced decoder channels +
        # encoder skip channels": 512 + 2048, 128 + 512, 32 + 128.
        self.wave_dec4_conv1 = nn.Conv1d(in_channels=1024, out_channels=512, kernel_size=1, stride=1)
        self.wave_dec4_norm1 = nn.BatchNorm1d(512)
        self.wave_dec4_conv2 = nn.Conv1d(2560, 512, kernel_size=1)
        self.wave_dec4_norm2 = nn.BatchNorm1d(512)
        self.wave_dec4_deconv = nn.ConvTranspose1d(in_channels=512, out_channels=512, kernel_size=8, stride=4)
        self.wave_dec4_norm3 = nn.BatchNorm1d(512)

        self.wave_dec3_conv1 = nn.Conv1d(in_channels=512, out_channels=128, kernel_size=1, stride=1)
        self.wave_dec3_norm1 = nn.BatchNorm1d(128)
        self.wave_dec3_conv2 = nn.Conv1d(640, 128, kernel_size=1)
        self.wave_dec3_norm2 = nn.BatchNorm1d(128)
        self.wave_dec3_deconv = nn.ConvTranspose1d(in_channels=128, out_channels=128, kernel_size=8, stride=4)
        self.wave_dec3_norm3 = nn.BatchNorm1d(128)

        self.wave_dec2_conv1 = nn.Conv1d(in_channels=128, out_channels=32, kernel_size=1, stride=1)
        self.wave_dec2_norm1 = nn.BatchNorm1d(32)
        self.wave_dec2_conv2 = nn.Conv1d(160, 32, kernel_size=1)
        self.wave_dec2_norm2 = nn.BatchNorm1d(32)
        self.wave_dec2_deconv = nn.ConvTranspose1d(in_channels=32, out_channels=32, kernel_size=8, stride=4)
        self.wave_dec2_norm3 = nn.BatchNorm1d(32)

        # Last stage: the transposed convolution consumes 16 + 32 channels
        # (reduced decoder features + first encoder output).
        self.wave_dec1_conv1 = nn.Conv1d(in_channels=32, out_channels=16, kernel_size=1, stride=1)
        self.wave_dec1_norm1 = nn.BatchNorm1d(16)
        self.wave_dec1_deconv = nn.ConvTranspose1d(in_channels=48, out_channels=16, kernel_size=8, stride=4)
        self.wave_dec1_norm2 = nn.BatchNorm1d(16)
        self.wave_dec1_conv2 = nn.Conv1d(16, 16, kernel_size=1)
        self.wave_dec1_norm3 = nn.BatchNorm1d(16)

        # Foreground/background separation
        self.wave_conv_foreground = nn.Conv1d(in_channels=16, out_channels=1, kernel_size=1)
        self.wave_norm_foreground = nn.BatchNorm1d(1)

        self.wave_conv_background = nn.Conv1d(in_channels=16, out_channels=1, kernel_size=1)
        self.wave_norm_background = nn.BatchNorm1d(1)

        # Output (2 input channels = branch estimate + raw input waveform)
        self.wave_output_conv_foreground = nn.Conv1d(in_channels=2, out_channels=1, kernel_size=1)
        self.wave_output_norm_foreground = nn.BatchNorm1d(1)

        self.wave_output_conv_background = nn.Conv1d(in_channels=2, out_channels=1, kernel_size=1)
        self.wave_output_norm_background = nn.BatchNorm1d(1)

    def forward(self, x):
        """Separate ``x`` into a foreground and a background estimate.

        Args:
            x: waveform tensor of shape ``(batch, 1, time)``.

        Returns:
            Tuple ``(foreground, background)``, each ``(batch, 1, time)`` with
            values in ``[-1, 1]``.
        """
        # Encoder
        wave_res1 = self.wave_enc1_conv1(x)
        wave_res1 = self.wave_enc1_norm1(wave_res1)
        wave_res1 = F.gelu(wave_res1)
        wave_res1 = self.wave_enc1_conv2(wave_res1)
        wave_res1 = self.wave_enc1_norm2(wave_res1)
        # NOTE: unlike stages 2-4, the residual here is the post-norm tensor
        # itself, not the activation that preceded conv2. Left untouched so
        # existing checkpoints keep producing the same outputs.
        wave_e1 = F.gelu(wave_res1) + wave_res1

        wave_res2 = self.wave_enc2_conv1(wave_e1)
        wave_res2 = self.wave_enc2_norm1(wave_res2)
        wave_res2 = F.gelu(wave_res2)
        wave_e2 = self.wave_enc2_conv2(wave_res2)
        wave_e2 = self.wave_enc2_norm2(wave_e2)
        wave_e2 = F.gelu(wave_e2) + wave_res2

        wave_res3 = self.wave_enc3_conv1(wave_e2)
        wave_res3 = self.wave_enc3_norm1(wave_res3)
        wave_res3 = F.gelu(wave_res3)
        wave_e3 = self.wave_enc3_conv2(wave_res3)
        wave_e3 = self.wave_enc3_norm2(wave_e3)
        wave_e3 = F.gelu(wave_e3) + wave_res3

        wave_res4 = self.wave_enc4_conv1(wave_e3)
        wave_res4 = self.wave_enc4_norm1(wave_res4)
        wave_res4 = F.gelu(wave_res4)
        wave_e4 = self.wave_enc4_conv2(wave_res4)
        wave_e4 = self.wave_enc4_norm2(wave_e4)
        wave_e4 = F.gelu(wave_e4) + wave_res4

        # Bottleneck
        wave_bn_res = self.wave_bn_conv1(wave_e4)
        wave_bn_res = self.wave_bn_norm1(wave_bn_res)
        wave_bn_res = F.gelu(wave_bn_res)
        # The attention layer is fed (batch, time, channels); convolutions use
        # (batch, channels, time), hence the permutes around it.
        wave_bn = wave_bn_res.permute(0, 2, 1)
        wave_bn = self.wave_bn_local_attention(wave_bn)
        wave_bn = wave_bn.permute(0, 2, 1)
        wave_bn = self.wave_bn_conv2(wave_bn)
        wave_bn = self.wave_bn_norm2(wave_bn)
        wave_bn = F.gelu(wave_bn) + wave_bn_res

        ### Decoder ###
        # Decoder layer 4
        wave_res4 = self.wave_dec4_conv1(wave_bn)
        wave_res4 = self.wave_dec4_norm1(wave_res4)
        wave_res4 = F.gelu(wave_res4)
        wave_res4 = F.pad(wave_res4, (0, wave_e4.shape[-1] - wave_res4.shape[-1]))  # right-pad to the skip length
        wave_d4 = torch.cat([wave_res4, wave_e4], dim=1)  # concatenate with the skip connection
        wave_d4 = self.wave_dec4_conv2(wave_d4)
        wave_d4 = self.wave_dec4_norm2(wave_d4)
        wave_d4 = F.gelu(wave_d4) + wave_res4  # residual connection
        wave_d4 = self.wave_dec4_deconv(wave_d4)  # upsampling
        wave_d4 = self.wave_dec4_norm3(wave_d4)
        wave_d4 = F.gelu(wave_d4)

        # Decoder layer 3
        wave_res3 = self.wave_dec3_conv1(wave_d4)
        wave_res3 = self.wave_dec3_norm1(wave_res3)
        wave_res3 = F.gelu(wave_res3)
        wave_res3 = F.pad(wave_res3, (0, wave_e3.shape[-1] - wave_res3.shape[-1]))  # right-pad to the skip length
        wave_d3 = torch.cat([wave_res3, wave_e3], dim=1)  # concatenate with the skip connection
        wave_d3 = self.wave_dec3_conv2(wave_d3)
        wave_d3 = self.wave_dec3_norm2(wave_d3)
        wave_d3 = F.gelu(wave_d3) + wave_res3  # residual connection
        wave_d3 = self.wave_dec3_deconv(wave_d3)  # upsampling
        wave_d3 = self.wave_dec3_norm3(wave_d3)
        wave_d3 = F.gelu(wave_d3)

        # Decoder layer 2
        wave_res2 = self.wave_dec2_conv1(wave_d3)
        wave_res2 = self.wave_dec2_norm1(wave_res2)
        wave_res2 = F.gelu(wave_res2)
        wave_res2 = F.pad(wave_res2, (0, wave_e2.shape[-1] - wave_res2.shape[-1]))  # right-pad to the skip length
        wave_d2 = torch.cat([wave_res2, wave_e2], dim=1)  # concatenate with the skip connection
        wave_d2 = self.wave_dec2_conv2(wave_d2)
        wave_d2 = self.wave_dec2_norm2(wave_d2)
        wave_d2 = F.gelu(wave_d2) + wave_res2  # residual connection
        wave_d2 = self.wave_dec2_deconv(wave_d2)  # upsampling
        wave_d2 = self.wave_dec2_norm3(wave_d2)
        wave_d2 = F.gelu(wave_d2)

        # Decoder layer 1
        wave_d2 = self.wave_dec1_conv1(wave_d2)
        wave_d2 = self.wave_dec1_norm1(wave_d2)
        wave_d2 = F.gelu(wave_d2)
        wave_d2 = torch.cat([F.pad(wave_d2, (0, wave_e1.shape[-1] - wave_d2.shape[-1])), wave_e1], dim=1)  # skip connection
        wave_res1 = self.wave_dec1_deconv(wave_d2)  # upsampling
        wave_res1 = self.wave_dec1_norm2(wave_res1)
        wave_res1 = F.gelu(wave_res1)

        wave_d1 = self.wave_dec1_conv2(wave_res1)
        wave_d1 = self.wave_dec1_norm3(wave_d1)
        wave_d1 = F.gelu(wave_d1) + wave_res1  # residual connection

        # Foreground/background separation
        wave_foreground = self.wave_conv_foreground(wave_d1)
        wave_background = self.wave_conv_background(wave_d1)
        wave_foreground = self.wave_norm_foreground(wave_foreground)
        wave_background = self.wave_norm_background(wave_background)
        wave_foreground = F.gelu(wave_foreground)
        wave_background = F.gelu(wave_background)

        # With kernel 8 / stride 4 the decoder only reproduces the input length
        # when that length is a multiple of 4; otherwise it comes out a few
        # samples short and the concatenation with `x` below would fail.
        # Right-pad with zeros, exactly as the skip connections above do.
        # The pad is zero-width (a no-op) when the lengths already match.
        missing = x.shape[-1] - wave_foreground.shape[-1]
        wave_foreground = F.pad(wave_foreground, (0, missing))
        wave_background = F.pad(wave_background, (0, missing))

        # Output: each branch sees its own estimate stacked with the raw input
        foreground = self.wave_output_conv_foreground(torch.cat([wave_foreground, x], dim=1))
        background = self.wave_output_conv_background(torch.cat([wave_background, x], dim=1))
        foreground = F.gelu(foreground)
        background = F.gelu(background)
        foreground = self.wave_output_norm_foreground(foreground)
        background = self.wave_output_norm_background(background)
        foreground = torch.tanh(foreground)
        background = torch.tanh(background)

        return foreground, background

In [17]:
class VoxStrideNet(nn.Module):
    """Strided 1-D encoder/decoder with recurrent layers and two output heads.

    Five encoder stages downsample the waveform with stride-4 convolutions
    (stages 2-5 also run a bidirectional LSTM); five decoder stages mirror
    them with stride-4 transposed convolutions fed by encoder skip
    connections. Two 1-channel heads produce the "event" and "background"
    estimates, each squashed with ``tanh``.

    Module attribute names and creation order are kept exactly as they were,
    because they define the ``state_dict`` layout and seeded initialisation.
    """

    def __init__(self):
        super().__init__()

        # ---- Encoder (unchanged) ----
        self.e1_conv1 = nn.Conv1d(1, 32, kernel_size=8, stride=4)
        self.e1_in1 = nn.InstanceNorm1d(32, affine=True)
        self.e1_conv2 = nn.Conv1d(32, 32, kernel_size=1)
        self.e1_in2 = nn.InstanceNorm1d(32, affine=True)

        self.e2_conv1 = nn.Conv1d(32, 64, kernel_size=8, stride=4)
        self.e2_in1 = nn.InstanceNorm1d(64, affine=True)
        self.e2_bilstm = nn.LSTM(64, 64, batch_first=True, bidirectional=True)
        self.e2_conv2 = nn.Conv1d(128, 64, kernel_size=1)
        self.e2_in2 = nn.InstanceNorm1d(64, affine=True)

        self.e3_conv1 = nn.Conv1d(64, 128, kernel_size=8, stride=4)
        self.e3_in1 = nn.InstanceNorm1d(128, affine=True)
        self.e3_bilstm = nn.LSTM(128, 128, batch_first=True, bidirectional=True)
        self.e3_conv2 = nn.Conv1d(256, 128, kernel_size=1)
        self.e3_in2 = nn.InstanceNorm1d(128, affine=True)

        self.e4_conv1 = nn.Conv1d(128, 256, kernel_size=8, stride=4)
        self.e4_in1 = nn.InstanceNorm1d(256, affine=True)
        self.e4_bilstm = nn.LSTM(256, 256, batch_first=True, bidirectional=True)
        self.e4_conv2 = nn.Conv1d(512, 256, kernel_size=1)
        self.e4_in2 = nn.InstanceNorm1d(256, affine=True)

        self.e5_conv1 = nn.Conv1d(256, 512, kernel_size=8, stride=4)
        self.e5_in1 = nn.InstanceNorm1d(512, affine=True)
        self.e5_bilstm = nn.LSTM(512, 512, batch_first=True, bidirectional=True)
        self.e5_conv2 = nn.Conv1d(1024, 512, kernel_size=1)
        self.e5_in2 = nn.InstanceNorm1d(512, affine=True)

        # ---- Decoder ----
        # Each transposed convolution takes "skip channels + reduced channels".
        self.d5_conv1 = nn.Conv1d(512, 256, kernel_size=1, stride=1)
        self.d5_in1 = nn.InstanceNorm1d(256, affine=True)
        self.d5_deconv1 = nn.ConvTranspose1d(512 + 256, 256, kernel_size=8, stride=4)
        self.d5_in2 = nn.InstanceNorm1d(256, affine=True)
        self.d5_lstm = nn.LSTM(256, 256, batch_first=True)
        self.d5_conv2 = nn.Conv1d(256, 256, kernel_size=1)
        self.d5_in3 = nn.InstanceNorm1d(256, affine=True)

        self.d4_conv1 = nn.Conv1d(256, 128, kernel_size=1, stride=1)
        self.d4_in1 = nn.InstanceNorm1d(128, affine=True)
        self.d4_deconv1 = nn.ConvTranspose1d(256 + 128, 128, kernel_size=8, stride=4)
        self.d4_in2 = nn.InstanceNorm1d(128, affine=True)
        # This is the only bidirectional decoder LSTM, so d4_conv2 reads 2 * 128 channels.
        self.d4_lstm = nn.LSTM(128, 128, batch_first=True, bidirectional=True)
        self.d4_conv2 = nn.Conv1d(256, 128, kernel_size=1)
        self.d4_in3 = nn.InstanceNorm1d(128, affine=True)

        self.d3_conv1 = nn.Conv1d(128, 64, kernel_size=1, stride=1)
        self.d3_in1 = nn.InstanceNorm1d(64, affine=True)
        self.d3_deconv1 = nn.ConvTranspose1d(128 + 64, 64, kernel_size=8, stride=4)
        self.d3_in2 = nn.InstanceNorm1d(64, affine=True)
        self.d3_lstm = nn.LSTM(64, 64, batch_first=True)
        self.d3_conv2 = nn.Conv1d(64, 64, kernel_size=1)
        self.d3_in3 = nn.InstanceNorm1d(64, affine=True)

        self.d2_conv1 = nn.Conv1d(64, 32, kernel_size=1, stride=1)
        self.d2_in1 = nn.InstanceNorm1d(32, affine=True)
        self.d2_deconv1 = nn.ConvTranspose1d(64 + 32, 32, kernel_size=8, stride=4)
        self.d2_in2 = nn.InstanceNorm1d(32, affine=True)
        self.d2_lstm = nn.LSTM(32, 32, batch_first=True)
        self.d2_conv2 = nn.Conv1d(32, 32, kernel_size=1)
        self.d2_in3 = nn.InstanceNorm1d(32, affine=True)

        self.d1_conv1 = nn.Conv1d(32, 16, kernel_size=1, stride=1)
        self.d1_in1 = nn.InstanceNorm1d(16, affine=True)
        self.d1_deconv1 = nn.ConvTranspose1d(32 + 16, 16, kernel_size=8, stride=4)
        self.d1_in2 = nn.InstanceNorm1d(16, affine=True)
        self.d1_conv2 = nn.Conv1d(16, 16, kernel_size=1)
        self.d1_in3 = nn.InstanceNorm1d(16, affine=True)

        # ---- Output heads ----
        self.output_conv_event = nn.Conv1d(16, 1, kernel_size=1)
        self.output_in_event = nn.InstanceNorm1d(1, affine=True)
        self.output_conv_background = nn.Conv1d(16, 1, kernel_size=1)
        self.output_in_background = nn.InstanceNorm1d(1, affine=True)

    @staticmethod
    def _through_lstm(lstm, feats):
        """Apply ``lstm`` along time to channel-first ``feats`` and return channel-first."""
        sequence, _ = lstm(feats.permute(0, 2, 1))
        return sequence.permute(0, 2, 1)

    @staticmethod
    def _join_skip(dec_feats, enc_feats):
        """Right-pad ``dec_feats`` to the length of ``enc_feats`` and stack them on channels."""
        shortfall = enc_feats.shape[-1] - dec_feats.shape[-1]
        return torch.cat([F.pad(dec_feats, (0, shortfall)), enc_feats], dim=1)

    def _encoder_stage(self, feats, conv_down, norm_down, conv_mix, norm_mix, lstm=None):
        """Downsample, optionally run the LSTM, mix with a 1x1 conv and add the residual."""
        shortcut = F.leaky_relu(norm_down(conv_down(feats)))
        hidden = shortcut if lstm is None else self._through_lstm(lstm, shortcut)
        return F.gelu(norm_mix(conv_mix(hidden))) + shortcut

    def _decoder_stage(self, feats, skip, conv_reduce, norm_reduce, deconv, norm_up,
                       conv_mix, norm_mix, lstm=None):
        """Reduce channels, merge the skip, upsample, optionally run the LSTM, add the residual."""
        reduced = F.leaky_relu(norm_reduce(conv_reduce(feats)))
        upsampled = F.leaky_relu(norm_up(deconv(self._join_skip(reduced, skip))))
        hidden = upsampled if lstm is None else self._through_lstm(lstm, upsampled)
        return F.leaky_relu(norm_mix(conv_mix(hidden))) + upsampled

    def forward(self, x):
        """Return ``(event, background)`` estimates for the waveform batch ``x``.

        ``x`` must have one channel (the first convolution has ``in_channels=1``).
        """
        # Encoder
        e1 = self._encoder_stage(x, self.e1_conv1, self.e1_in1, self.e1_conv2, self.e1_in2)
        e2 = self._encoder_stage(e1, self.e2_conv1, self.e2_in1, self.e2_conv2, self.e2_in2,
                                 lstm=self.e2_bilstm)
        e3 = self._encoder_stage(e2, self.e3_conv1, self.e3_in1, self.e3_conv2, self.e3_in2,
                                 lstm=self.e3_bilstm)
        e4 = self._encoder_stage(e3, self.e4_conv1, self.e4_in1, self.e4_conv2, self.e4_in2,
                                 lstm=self.e4_bilstm)
        e5 = self._encoder_stage(e4, self.e5_conv1, self.e5_in1, self.e5_conv2, self.e5_in2,
                                 lstm=self.e5_bilstm)

        # Decoder (the deepest stage uses e5 both as input and as its skip)
        d5 = self._decoder_stage(e5, e5, self.d5_conv1, self.d5_in1, self.d5_deconv1, self.d5_in2,
                                 self.d5_conv2, self.d5_in3, lstm=self.d5_lstm)
        d4 = self._decoder_stage(d5, e4, self.d4_conv1, self.d4_in1, self.d4_deconv1, self.d4_in2,
                                 self.d4_conv2, self.d4_in3, lstm=self.d4_lstm)
        d3 = self._decoder_stage(d4, e3, self.d3_conv1, self.d3_in1, self.d3_deconv1, self.d3_in2,
                                 self.d3_conv2, self.d3_in3, lstm=self.d3_lstm)
        d2 = self._decoder_stage(d3, e2, self.d2_conv1, self.d2_in1, self.d2_deconv1, self.d2_in2,
                                 self.d2_conv2, self.d2_in3, lstm=self.d2_lstm)
        d1 = self._decoder_stage(d2, e1, self.d1_conv1, self.d1_in1, self.d1_deconv1, self.d1_in2,
                                 self.d1_conv2, self.d1_in3)

        # Output
        event = F.tanh(self.output_in_event(self.output_conv_event(d1)))
        background = F.tanh(self.output_in_background(self.output_conv_background(d1)))

        return event, background

In [18]:
class AudioDataset(torch.utils.data.Dataset):
    """Dataset of (mixture, event, background) waveforms listed in a CSV file.

    The CSV must provide the columns ``event``, ``background`` and ``mixture``,
    each holding the path of an audio file readable by pydub.

    Args:
        csv_file: path of the CSV index.
        target_duration: stored on the instance only.
            NOTE(review): nothing in this class pads or trims clips to this
            value, so all files must already share one length for the default
            DataLoader collation to work - confirm whether enforcement is wanted.
        target_sample_rate: frame rate every clip is converted to.
        target_channels: channel count every clip is converted to.
    """

    def __init__(self, csv_file, target_duration=10000, target_sample_rate=10000,
                 target_channels=1):
        self.df = pd.read_csv(csv_file)
        self.target_duration = target_duration
        self.target_channels = target_channels
        self.target_sample_rate = target_sample_rate

    def __len__(self):
        """Number of rows in the CSV index."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load one example.

        Returns:
            ``(mixture, event, background)`` float32 tensors, each shaped
            ``(channels, samples)`` with samples scaled to ``[-1, 1)``.

        Raises:
            IndexError: if ``idx`` is out of range (positional lookup, so plain
                ``for item in dataset`` iteration terminates correctly).
        """
        row = self.df.iloc[idx]

        # Same load order as before: event, background, mixture.
        event_tensor = self._load_waveform(row['event'])
        background_tensor = self._load_waveform(row['background'])
        mixture_tensor = self._load_waveform(row['mixture'])

        return mixture_tensor, event_tensor, background_tensor

    def _load_waveform(self, path):
        """Decode ``path``, convert it to the target channels/rate, return a float32 tensor."""
        audio = AudioSegment.from_file(path).set_channels(self.target_channels).set_frame_rate(self.target_sample_rate)
        samples, _ = self._pydub_to_array(audio)
        return torch.from_numpy(samples)

    def _pydub_to_array(self, audio: "AudioSegment") -> "tuple[np.ndarray, int]":
        """Convert a pydub segment to a ``(channels, samples)`` float32 array.

        pydub returns multi-channel samples interleaved (L0, R0, L1, R1, ...),
        so the flat buffer is viewed as ``(samples, channels)`` and transposed.
        Reshaping straight to ``(channels, -1)`` would mix the channels.

        Returns:
            ``(array, frame_rate)`` where ``array`` is divided by the integer
            full-scale value ``2 ** (8 * sample_width - 1)``.
        """
        full_scale = 1 << (8 * audio.sample_width - 1)
        interleaved = np.array(audio.get_array_of_samples(), dtype=np.float32)
        per_channel = np.ascontiguousarray(interleaved.reshape((-1, audio.channels)).T)
        return per_channel / full_scale, audio.frame_rate

In [19]:
def pydub_to_array(audio: "AudioSegment") -> "tuple[np.ndarray, int]":
    """Convert a pydub segment to a ``(channels, samples)`` float32 array.

    pydub stores multi-channel audio interleaved (L0, R0, L1, R1, ...), so the
    flat sample buffer is viewed as ``(samples, channels)`` and transposed.
    Reshaping directly to ``(channels, -1)`` would scramble the channels for
    anything but mono input.

    Args:
        audio: segment exposing ``get_array_of_samples()``, ``channels``,
            ``sample_width`` (bytes per sample) and ``frame_rate``.

    Returns:
        ``(array, frame_rate)``; ``array`` is scaled by the integer full-scale
        value ``2 ** (8 * sample_width - 1)`` so samples lie in ``[-1, 1)``.
    """
    full_scale = 1 << (8 * audio.sample_width - 1)
    interleaved = np.array(audio.get_array_of_samples(), dtype=np.float32)
    per_channel = np.ascontiguousarray(interleaved.reshape((-1, audio.channels)).T)
    return per_channel / full_scale, audio.frame_rate

def array_to_pydub(audio_np_array: np.ndarray, sample_rate: int = 10000, sample_width: int = 2, channels: int = 1) -> "AudioSegment":
    """Build a pydub ``AudioSegment`` from a float waveform in ``[-1, 1]``.

    Args:
        audio_np_array: 1-D array of samples, or 2-D ``(channels, samples)``
            array (the layout ``pydub_to_array`` produces).
        sample_rate: frame rate of the resulting segment.
        sample_width: bytes per sample; 1, 2 or 4.
        channels: channel count declared on the resulting segment.

    Returns:
        The new ``AudioSegment``.

    Raises:
        ValueError: if ``sample_width`` has no matching signed integer type.
    """
    pcm_dtypes = {1: np.int8, 2: np.int16, 4: np.int32}
    if sample_width not in pcm_dtypes:
        raise ValueError(f"Unsupported sample_width {sample_width}; expected one of {sorted(pcm_dtypes)}")

    full_scale = 2 ** (8 * sample_width - 1)
    # float64 keeps the 32-bit limits exactly representable.
    scaled = np.asarray(audio_np_array, dtype=np.float64) * full_scale
    # A sample of exactly +1.0 maps to full_scale, one past the largest signed
    # value; clip so it saturates and does not wrap to the most negative value.
    pcm = np.clip(scaled, -full_scale, full_scale - 1).astype(pcm_dtypes[sample_width])
    if pcm.ndim == 2:
        # (channels, samples) -> (samples, channels) so the bytes are interleaved
        pcm = pcm.T
    return AudioSegment(np.ascontiguousarray(pcm).tobytes(),
                        frame_rate=sample_rate, sample_width=sample_width, channels=channels)

In [20]:
# Run on the first CUDA GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# CSV indexes for the three splits (paths are relative to the working directory).
train_csv = "utils/train_dataset_wav.csv"
validation_csv = "utils/val_dataset_wav.csv"
test_csv = "utils/test_dataset_wav.csv"

# Alternative index left by the author; its purpose is not evident from this notebook.
# train_csv = 'pezz.csv'
# validation_csv = 'pezz.csv'
# test_csv = 'pezz.csv'

batch_size = 8

# One dataset per split, all with the AudioDataset defaults.
train = AudioDataset(train_csv)
validation = AudioDataset(validation_csv)
test = AudioDataset(test_csv)

# Only the training split is shuffled.
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(validation, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test, batch_size=batch_size, shuffle=False)

In [21]:
# Re-derive the device so this cell can run on its own.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Model under training. The VoxStrideNet line is kept as the commented-out alternative.
# model = VoxStrideNet().to(device)
model = HDemucs(
    audio_channels=1,
    channels=24,
    sources=['event', 'background'],
)
model = model.to(device)

# L1 loss on the waveforms, AdamW optimizer, SDR as the reported metric.
criterion = nn.L1Loss()
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)
sdr_metric = SignalDistortionRatio()
sdr_metric = sdr_metric.to(device)

In [22]:
# CSV file that accumulates the per-epoch metrics
csv_filename = "HD_sep_training_log.csv"

# Continue an existing log when one is on disk, otherwise start an empty one.
# NOTE(review): only the log is resumed. Epoch numbers restart at 1 and no
# model weights are reloaded here, so re-running appends repeated epoch ids.
if os.path.exists(csv_filename):
    log_df = pd.read_csv(csv_filename)
else:
    log_df = pd.DataFrame(columns=["epoch", "train_loss", "train_sdr", "val_loss", "val_sdr"])


# Number of epochs
num_epochs = 20
best_val_loss = float("inf")


def _split_sources(output):
    """Return ``(event_pred, background_pred)`` from a model output.

    Tuple outputs are passed through unchanged; a single tensor is indexed
    along dimension 1 (source 0 = event, source 1 = background).
    """
    if isinstance(output, tuple):
        return output
    return output[:, 0, :, :], output[:, 1, :, :]


def _run_epoch(loader, description, training, sdr_key="sdr"):
    """Run one full pass over ``loader`` and return ``(mean_loss, mean_sdr)``.

    Args:
        loader: iterable yielding ``(mixture, event, background)`` batches.
        description: text shown on the progress bar.
        training: when True the model is put in train mode and the optimizer
            is stepped; when False the model is in eval mode and no gradients
            are computed.
        sdr_key: name under which the running SDR is shown on the progress bar.
    """
    model.train(training)
    total_loss = 0.0
    total_sdr = 0.0

    with tqdm(loader, desc=description, unit="batch") as bar:
        # An explicit counter is used because tqdm only refreshes its own
        # `n` attribute at display intervals, so it can lag behind the loop.
        for step, (mixture, event, background) in enumerate(bar, start=1):
            # Move the batch to the selected device
            mixture = mixture.to(device)
            event = event.to(device)
            background = background.to(device)

            if training:
                # Reset the gradients
                optimizer.zero_grad()

            with torch.set_grad_enabled(training):
                # Forward pass
                event_pred, background_pred = _split_sources(model(mixture))

                # Loss: sum of the per-source L1 terms
                loss_event = criterion(event_pred, event)
                loss_background = criterion(background_pred, background)
                loss = loss_event + loss_background

            if training:
                # Backward pass and optimization step
                loss.backward()
                optimizer.step()

            # SDR is reporting only: detach so the metric never extends the
            # autograd graph of the predictions.
            sdr_event = sdr_metric(event_pred.detach(), event)
            sdr_background = sdr_metric(background_pred.detach(), background)
            sdr = (sdr_event + sdr_background) / 2  # mean SDR over event and background

            # Update the running totals
            total_loss += loss.item()
            total_sdr += sdr.item()

            # Show the running means on the progress bar
            bar.set_postfix(**{"loss": total_loss / step, sdr_key: total_sdr / step})

    # Guard against an empty loader, which would otherwise divide by zero.
    num_batches = max(len(loader), 1)
    return total_loss / num_batches, total_sdr / num_batches


# Training loop
for epoch in range(num_epochs):
    train_running_loss, train_running_sdr = _run_epoch(
        train_loader, f"Epoch {epoch + 1}/{num_epochs}", training=True)

    # Validation pass
    val_running_loss, val_running_sdr = _run_epoch(
        validation_loader, f"Epoch {epoch + 1}/{num_epochs} - Validation", training=False, sdr_key="val_sdr")

    print(f'Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_running_loss:.4f}, Train SDR: {train_running_sdr:.4f}')
    print("-" * 70)
    print(f'Epoch [{epoch + 1}/{num_epochs}], Validation Loss: {val_running_loss:.4f}, Validation SDR: {val_running_sdr:.4f}')

    # Append this epoch to the log and rewrite the CSV so progress survives an interruption
    new_row = pd.DataFrame({
        "epoch": [epoch + 1],
        "train_loss": [train_running_loss],
        "train_sdr": [train_running_sdr],
        "val_loss": [val_running_loss],
        "val_sdr": [val_running_sdr]})

    # Concatenating onto an empty frame is avoided: it is deprecated in recent pandas.
    log_df = new_row if log_df.empty else pd.concat([log_df, new_row], ignore_index=True)
    log_df.to_csv(csv_filename, index=False)

    # Save the model whenever the validation loss improves
    if val_running_loss < best_val_loss:
        best_val_loss = val_running_loss
        torch.save(model.state_dict(), "HD_sep_best_model.pth")
        print(f"Best Model Saved (Epoch {epoch+1}, Val Loss: {val_running_loss:.4f})")
        print("-" * 70)

print(f"Training log salvato in {csv_filename}")

Epoch 1/20:   0%|          | 21/8750 [00:19<2:17:01,  1.06batch/s, loss=0.189, sdr=-11.6]


KeyboardInterrupt: 