<a href="https://colab.research.google.com/github/cnovak232/DL_Speech_Enhancement/blob/main/DL_Final_Project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Load Data by cloning the repo - easiest way to access shared data

In [None]:
!git clone https://github.com/cnovak232/DL_Speech_Enhancement.git


Define some helper functions for plotting and playing audio

In [None]:
import torch
import torchaudio as ta
import librosa as lib
from IPython.display import Audio, display
import matplotlib
import matplotlib.pyplot as plt

# helper functions for audio playback and plotting
# mostly taken from the torchaudio tutorials

def play_audio(waveform, sample_rate):
    """Render an inline audio player for a mono or stereo waveform.

    Args:
        waveform: 2-D tensor whose first axis is unpacked as the channel
            count and whose second axis is the frame count.
        sample_rate: Playback rate handed to the audio widget.

    Raises:
        ValueError: If the waveform has neither one nor two channels.
    """
    samples = waveform.numpy()
    channel_count, _ = samples.shape

    # Guard first: only mono and stereo can be played back.
    if channel_count not in (1, 2):
        raise ValueError("Waveform with more than 2 channels are not supported.")

    if channel_count == 1:
        payload = samples[0]
    else:
        payload = (samples[0], samples[1])
    display(Audio(payload, rate=sample_rate))

def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None):
    """Plot every channel of a waveform on its own subplot.

    Args:
        waveform: 2-D tensor; its shape is unpacked as (channels, frames).
        sample_rate: Sampling rate used to convert frame indices to seconds.
        title: Title placed above the whole figure.
        xlim: Optional x-axis limits, applied to every channel's axes.
        ylim: Optional y-axis limits, applied to every channel's axes.
    """
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sample_rate

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        # plt.subplots returns a bare Axes for a single subplot; wrap it so
        # the loop below can treat both cases uniformly.
        axes = [axes]
    for c, ax in enumerate(axes):
        ax.plot(time_axis, waveform[c], linewidth=1)
        ax.grid(True)
        # Labels and limits are set per channel; previously these ran after
        # the loop and therefore only affected the last subplot.
        if num_channels > 1:
            ax.set_ylabel(f'Channel {c+1}')
        if xlim:
            ax.set_xlim(xlim)
        if ylim:
            ax.set_ylim(ylim)

    figure.suptitle(title)
    plt.show(block=False)

def get_spectrogram(
    waveform = None,
    n_fft = 512,
    win_len = None,
    hop_len = None,
    power = 1.0 ):
    """Apply a torchaudio ``Spectrogram`` transform to ``waveform``.

    Args:
        waveform: Audio tensor passed to the transform.
        n_fft: FFT size forwarded to the transform.
        win_len: Window length forwarded as ``win_length``.
        hop_len: Hop length forwarded as ``hop_length``.
        power: Exponent forwarded as ``power``.

    Returns:
        The output of the configured transform applied to ``waveform``.
    """
    # Frames are centred and the signal edges are reflect-padded.
    settings = dict(
        n_fft=n_fft,
        win_length=win_len,
        hop_length=hop_len,
        center=True,
        pad_mode="reflect",
        power=power,
    )
    transform = ta.transforms.Spectrogram(**settings)
    return transform(waveform)

def plot_spectrogram(spec, type = "amplitude", title=None, ylabel='freq_bin', aspect='auto', xmax=None):
    """Show a spectrogram as a dB-scaled image with a colour bar.

    Args:
        spec: Spectrogram passed through ``AmplitudeToDB`` before display.
        type: Scale type handed to ``AmplitudeToDB`` as its first argument.
        title: Axes title; a default title is used when falsy.
        ylabel: Label of the y axis.
        aspect: Aspect setting forwarded to ``imshow``.
        xmax: Optional upper bound of the x axis (in frames).
    """
    figure, axis = plt.subplots(1, 1)

    heading = title or 'Spectrogram (db)'
    axis.set_title(heading)
    axis.set_ylabel(ylabel)
    axis.set_xlabel('frame')

    # Convert to decibels, then draw with the lowest frequency at the bottom.
    spec_db = ta.transforms.AmplitudeToDB(type)(spec)
    image = axis.imshow(spec_db, origin='lower', aspect=aspect)

    if xmax:
        axis.set_xlim((0, xmax))

    figure.colorbar(image, ax=axis)
    plt.show(block=False)

def norm_spec( spec ):
    """Return ``spec`` unchanged together with a scale factor of 1.

    Peak normalisation (dividing by ``spec.max()`` and returning that
    maximum) is currently disabled, so this is an identity placeholder that
    keeps the ``(normalised, scale)`` return shape callers unpack.
    """
    scale = 1
    return spec, scale


Define a Custom Dataset class for the Data and read it in

In [None]:
# read dataset in and downsample / transform / pad if needed
from torch.utils.data import Dataset
import torch
import os

class VoiceBankDemand(Dataset):
    """Paired clean/noisy utterances loaded from two parallel directories.

    Each item is a tuple ``(clean_data, noisy_data, orig_len)`` where
    ``orig_len`` is the clean clip's length in samples after optional
    resampling and before trimming/padding.

    * ``data != "test"``: ``clean_data`` / ``noisy_data`` are the waveforms,
      or their transformed versions when ``transform`` is given.
    * ``data == "test"`` with a transform: ``clean_data`` is
      ``(clean_waveform, clean_magnitude)`` and ``noisy_data`` is
      ``(noisy_magnitude, noisy_phase, noisy_waveform, norm_val)``.
      Without a transform the raw waveforms are returned.

    Args:
        clean_dir: Directory containing the clean recordings.
        noisy_dir: Directory containing the noisy recordings.
        list_dir: File names present in both directories.
        data: ``"test"`` selects the test-style return structure.
        len_samples: If truthy, every clip is trimmed or zero-padded to this
            many samples.
        downsample: If truthy, target sample rate to resample to.
        transform: Optional callable applied to each waveform.
    """

    def __init__(self, clean_dir, noisy_dir, list_dir, 
                 data = "train", len_samples = None, downsample = None, 
                 transform = None ):
        self.clean_dir = clean_dir
        self.noisy_dir = noisy_dir
        self.list_dir = list_dir
        self.num_samples = len_samples
        self.downsample = downsample
        self.transform = transform
        self.data = data
        # One Resample transform per source rate, built lazily and reused so
        # it is not reconstructed on every item fetch.
        self._resamplers = {}
    
    def __len__(self):
        return len(self.list_dir)

    def _fit_length(self, audio):
        """Trim or zero-pad ``audio`` along axis 1 to ``self.num_samples``."""
        target = int(self.num_samples)
        current = audio.shape[1]
        if current > target:
            return audio[:, :target]
        if current < target:
            # Pads only the end of the last axis, so any channel count works.
            return torch.nn.functional.pad(audio, (0, target - current))
        return audio

    def __getitem__( self, idx ):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        clean_name = os.path.join( self.clean_dir, self.list_dir[idx] )
        noisy_name = os.path.join( self.noisy_dir, self.list_dir[idx] )
        clean_audio, fs = ta.load(clean_name)
        noisy_audio, fs = ta.load(noisy_name)

        if self.downsample:
            downsampler = self._resamplers.get(fs)
            if downsampler is None:
                downsampler = ta.transforms.Resample(fs, self.downsample)
                self._resamplers[fs] = downsampler
            clean_audio = downsampler( clean_audio )
            noisy_audio = downsampler( noisy_audio )

        orig_len = clean_audio.shape[1]

        if self.num_samples:
            # Use the instance's own target length (not a module-level
            # global) so the dataset works wherever it is constructed.
            clean_audio = self._fit_length( clean_audio )
            noisy_audio = self._fit_length( noisy_audio )
        
        if self.data == "test":
            if self.transform:
                noisy_trnsfrm = self.transform( noisy_audio )
                clean_trnsfrm = self.transform( clean_audio )
                clean_mag,_ = norm_spec( torch.abs(clean_trnsfrm) )
                clean_audio = (clean_audio, clean_mag)
                noisy_mag, norm_val = norm_spec( torch.abs(noisy_trnsfrm) )
                noisy_phase = torch.angle(noisy_trnsfrm)
                noisy_audio = (noisy_mag, noisy_phase, noisy_audio, norm_val)
        else:
            if self.transform:
                clean_audio,_ = norm_spec( self.transform( clean_audio ) )
                noisy_audio,_ = norm_spec( self.transform( noisy_audio ) )

        sample = (clean_audio, noisy_audio, orig_len)

        return sample

# Locations of the paired clean/noisy recordings inside the cloned repo.
train_clean_path = './DL_Speech_Enhancement/clean_trainset_28spk_wav'
train_noisy_path = './DL_Speech_Enhancement/noisy_trainset_28spk_wav'
test_clean_path  = './DL_Speech_Enhancement/clean_testset_wav'
test_noisy_path  = './DL_Speech_Enhancement/noisy_testset_wav'

# File names are taken from the clean directories and sorted so the order is
# deterministic; the same names are used to look up the noisy files.
list_dir_train = os.listdir(train_clean_path)
list_dir_train.sort()
list_dir_test = os.listdir(test_clean_path)
list_dir_test.sort()

target_fs = 8000 # target sample rate in Hz (8 kHz)
spectrogram = ta.transforms.Spectrogram(
    n_fft=512,
    power=1.0,
    normalized = False ) # magnitude spectrogram
complex_spec = ta.transforms.Spectrogram(
    n_fft=512,
    power=None,
    normalized = False ) # return complex spectrum

WAVE_UNET_LEN = 16384 # segment length (in samples) used for the Wave-U-Net processing
num_samples = WAVE_UNET_LEN * 2 # 32768 samples, about 4.1 s at 8 kHz

# returns a tuple of (clean_data, noisy_data, original_length)
# data type depends on if transform was specified
train_set = VoiceBankDemand( clean_dir = train_clean_path,
                             noisy_dir = train_noisy_path,
                             list_dir = list_dir_train,
                             len_samples = num_samples, # clip or zero-pad every clip to num_samples
                             downsample = target_fs, # resample to target_fs (8 kHz)
                             transform = None )

# returns tuples within tuples: clean_data, noisy_data, original length
# If using a transform (like spectrogram) the subtuples will be:
# clean_data: (clean_waveform, clean_spec)
# noisy_data: ( noisy_spec, noisy_phase, noisy_waveform, normalization_val )
test_set = VoiceBankDemand( clean_dir = test_clean_path,
                            noisy_dir = test_noisy_path,
                            list_dir = list_dir_test,
                            data = "test",
                            len_samples = num_samples,
                            downsample = target_fs,
                            transform = None )


# Quick sanity check: fetch one training pair and plot the clean waveform.
clean, noisey, orgin_len = train_set[1]

plot_waveform(clean,target_fs)

#clean_test, noisy_test, orig_len = test_set[0]

#noisy_mag, noisy_phase, noisy_audio, norm_val = noisy_test

#plot_spectrogram(clean_spec.squeeze())
#plot_spectrogram(noisy_spec.squeeze())

#play_audio(noisy_audio[:,:orig_len],target_fs)

In [None]:
import torch
import torch.nn as nn

# U-Net style autoencoder used as a starting point, adapted from:
# https://medium.com/@sriskandaryan/autoencoders-demystified-audio-signal-denoising-32a491ab023a

class UNet(nn.Module):
    """2-D U-Net style encoder/decoder built from DownConvBlock/UpConvBlock.

    The encoder is six stride-2 ``DownConvBlock`` stages. The decoder is five
    ``UpConvBlock`` stages, each of which concatenates its output with the
    matching encoder feature map along the channel axis, followed by an
    upsample, a zero-pad and a final convolution to ``chnls_out`` channels.

    NOTE(review): ``up_conv_layer_6``, ``up_conv_layer_7`` and ``activation``
    are constructed here but never called in ``forward``; confirm whether
    they are leftovers or whether the Tanh was meant to be applied.

    Args:
        chnls_in: Number of input channels.
        chnls_out: Number of output channels.
    """

    def __init__(self, chnls_in=1, chnls_out=1):
        super(UNet, self).__init__()
        # Encoder: channel widths 64 -> 128 -> 256 -> 256 -> 256 -> 256.
        self.down_conv_layer_1 = DownConvBlock(chnls_in, 64, norm=False)
        self.down_conv_layer_2 = DownConvBlock(64, 128)
        self.down_conv_layer_3 = DownConvBlock(128, 256)
        self.down_conv_layer_4 = DownConvBlock(256, 256, dropout=0.5)
        self.down_conv_layer_5 = DownConvBlock(256, 256, dropout=0.5)
        self.down_conv_layer_6 = DownConvBlock(256, 256, dropout=0.5)

        # Decoder: every block's output is concatenated with an encoder skip,
        # so the next block's input width is (block output + skip channels).
        self.up_conv_layer_1 = UpConvBlock(256, 256, kernel_size=(2,3), stride=2, padding=0, dropout=0.5)# in: enc6 (256); output is joined with enc5 (256)
        self.up_conv_layer_2 = UpConvBlock(512, 256, kernel_size=(2,3), stride=2, padding=0, dropout=0.5) # in: 256+256; output is joined with enc4 (256)
        self.up_conv_layer_3 = UpConvBlock(512, 256, kernel_size=(2,3), stride=2, padding=0, output_padding = 2, dropout=0.5) # in: 256+256; output is joined with enc3 (256)
        self.up_conv_layer_4 = UpConvBlock(512, 128, dropout=0.5) # in: 256+256; output is joined with enc2 (128)
        self.up_conv_layer_5 = UpConvBlock(256, 64) # in: 128+128; output is joined with enc1 (64)
        # The next two blocks are not used by forward().
        self.up_conv_layer_6 = UpConvBlock(512, 128)
        self.up_conv_layer_7 = UpConvBlock(256, 64)
        self.upsample_layer = nn.Upsample(scale_factor=2)
        # Pads one row on top and one column on the left before the final conv.
        self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0))
        # Input width 128 = 64 (up_conv_layer_5 output) + 64 (enc1 skip).
        self.conv_layer_1 = nn.Conv2d(128, chnls_out, 4, padding='same')
        # Not applied in forward().
        self.activation = nn.Tanh()
    
    def forward(self, x):
        """Encode ``x``, decode with skip connections and return the result."""
        # Encoder path; every output is kept for the skip connections.
        enc1 = self.down_conv_layer_1(x)
        enc2 = self.down_conv_layer_2(enc1) 
        enc3 = self.down_conv_layer_3(enc2)
        enc4 = self.down_conv_layer_4(enc3)
        enc5 = self.down_conv_layer_5(enc4)
        enc6 = self.down_conv_layer_6(enc5)
 
        # Decoder path; the second argument is the skip that gets concatenated.
        dec1 = self.up_conv_layer_1(enc6, enc5)
        dec2 = self.up_conv_layer_2(dec1, enc4)
        dec3 = self.up_conv_layer_3(dec2, enc3)
        dec4 = self.up_conv_layer_4(dec3, enc2)
        dec5 = self.up_conv_layer_5(dec4, enc1)

        # Output head: upsample x2, zero-pad, then project to chnls_out.
        final = self.upsample_layer(dec5)
        final = self.zero_pad(final)
        final = self.conv_layer_1(final)
        return final

class UpConvBlock(nn.Module):
    """Transposed-conv decoder stage with a concatenating skip connection.

    Applies ``ConvTranspose2d -> InstanceNorm2d -> ReLU`` (plus ``Dropout``
    when ``dropout`` is non-zero) and concatenates the result with an encoder
    feature map along the channel axis.

    Args:
        ip_sz: Input channels of the transposed convolution.
        op_sz: Output channels of the transposed convolution.
        kernel_size: Kernel size of the transposed convolution.
        stride: Stride of the transposed convolution.
        padding: Padding of the transposed convolution.
        output_padding: Output padding of the transposed convolution.
        dropout: Dropout probability; ``0.0`` disables the dropout layer.
    """

    def __init__(self, ip_sz, op_sz, kernel_size=4, stride= 2, padding=0 , output_padding = 0, dropout=0.0):
        super(UpConvBlock, self).__init__()
        modules = [
            nn.ConvTranspose2d(ip_sz, op_sz, kernel_size=kernel_size, stride=stride, padding=padding, output_padding = output_padding),
            nn.InstanceNorm2d(op_sz),
            nn.ReLU(),
        ]
        if dropout:
            modules.append(nn.Dropout(dropout))
        # Build the pipeline once instead of wrapping the layers in a new
        # nn.Sequential on every forward call. Sub-modules keep the same
        # indices, so state_dict keys ("layers.0.weight", ...) are unchanged.
        self.layers = nn.Sequential(*modules)

    def forward(self, x, enc_ip):
        """Upsample ``x`` and concatenate it with ``enc_ip`` on dim 1."""
        x = self.layers(x)
        return torch.cat((x, enc_ip), 1)


class DownConvBlock(nn.Module):
    """Stride-2 convolutional encoder stage.

    Applies ``Conv2d(stride=2, padding=1)``, an optional ``InstanceNorm2d``,
    ``LeakyReLU(0.2)`` and an optional ``Dropout``.

    Args:
        ip_sz: Input channels of the convolution.
        op_sz: Output channels of the convolution.
        kernel_size: Kernel size of the convolution.
        norm: Whether to insert instance normalisation after the convolution.
        dropout: Dropout probability; ``0.0`` disables the dropout layer.
    """

    def __init__(self, ip_sz, op_sz, kernel_size=4, norm=True, dropout=0.0):
        super(DownConvBlock, self).__init__()
        modules = [nn.Conv2d(ip_sz, op_sz, kernel_size, 2, 1)]
        if norm:
            modules.append(nn.InstanceNorm2d(op_sz))
        modules.append(nn.LeakyReLU(0.2))
        if dropout:
            modules.append(nn.Dropout(dropout))
        # Build the pipeline once instead of wrapping the layers in a new
        # nn.Sequential on every forward call. Sub-modules keep the same
        # indices, so state_dict keys ("layers.0.weight", ...) are unchanged.
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the block's layer pipeline."""
        return self.layers(x)

In [None]:
import torch.nn.utils.rnn as rnn

class SpeechEnhancementLSTM(nn.Module):
    """Batch-first LSTM followed by a per-frame linear projection.

    Args:
        input_size: Feature size of each input frame.
        hidden_size: Hidden size of the LSTM.
        num_layers: Number of stacked LSTM layers.
        output_size: Feature size of each output frame.
        bidirectional: Whether the LSTM runs in both directions; this doubles
            the width fed to the linear layer.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size, bidirectional=True):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            bidirectional=bidirectional,
        )
        if bidirectional:
            projection_in = hidden_size * 2
        else:
            projection_in = hidden_size
        self.fc = nn.Linear(projection_in, output_size)

    def forward(self, x, lengths):
        """Process a padded batch ``x`` whose true lengths are ``lengths``."""
        # Pack so the LSTM skips padded frames; the batch need not be sorted.
        packed_in = rnn.pack_padded_sequence(
            x, lengths, batch_first=True, enforce_sorted=False
        )
        packed_out, _state = self.lstm(packed_in)
        # Unpack back to a padded batch-first tensor before projecting.
        padded_out, _out_lengths = rnn.pad_packed_sequence(packed_out, batch_first=True)
        return self.fc(padded_out)

WaveU-Net for Speech Denoising
For Reference:
https://github.com/haoxiangsnr/Wave-U-Net-for-Speech-Enhancement
https://arxiv.org/pdf/1806.03185.pdf


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class DownSamplingLayer(nn.Module):
    """Encoder stage: ``Conv1d -> BatchNorm1d -> LeakyReLU(0.1)``.

    With the default kernel size, stride and padding the convolution keeps
    the input length; decimation is done by the caller.

    Args:
        channel_in: Input channels of the convolution.
        channel_out: Output channels of the convolution.
        dilation: Dilation of the convolution.
        kernel_size: Kernel size of the convolution.
        stride: Stride of the convolution.
        padding: Padding of the convolution.
    """

    def __init__(self, channel_in, channel_out, dilation=1, kernel_size=15, stride=1, padding=7):
        super().__init__()
        convolution = nn.Conv1d(
            channel_in,
            channel_out,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
        )
        normalisation = nn.BatchNorm1d(channel_out)
        nonlinearity = nn.LeakyReLU(negative_slope=0.1)
        self.main = nn.Sequential(convolution, normalisation, nonlinearity)

    def forward(self, ipt):
        """Apply the stage to ``ipt``."""
        features = self.main(ipt)
        return features

class UpSamplingLayer(nn.Module):
    """Decoder stage: ``Conv1d -> BatchNorm1d -> in-place LeakyReLU(0.1)``.

    With the default kernel size, stride and padding the convolution keeps
    the input length; interpolation is done by the caller.

    Args:
        channel_in: Input channels of the convolution.
        channel_out: Output channels of the convolution.
        kernel_size: Kernel size of the convolution.
        stride: Stride of the convolution.
        padding: Padding of the convolution.
    """

    def __init__(self, channel_in, channel_out, kernel_size=5, stride=1, padding=2):
        super().__init__()
        convolution = nn.Conv1d(
            channel_in,
            channel_out,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        normalisation = nn.BatchNorm1d(channel_out)
        nonlinearity = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.main = nn.Sequential(convolution, normalisation, nonlinearity)

    def forward(self, ipt):
        """Apply the stage to ``ipt``."""
        features = self.main(ipt)
        return features

class WaveUNet2(nn.Module):
    """1-D Wave-U-Net style encoder/decoder operating on raw waveforms.

    The encoder applies ``n_layers`` ``DownSamplingLayer`` stages, keeping
    every second sample after each one. The decoder interpolates by a factor
    of two, concatenates the matching encoder output along the channel axis
    and applies an ``UpSamplingLayer``. The original input is concatenated
    once more before a 1x1 convolution and ``Tanh``.

    Because every skip is concatenated with a x2-interpolated tensor, the
    input length presumably has to be divisible by ``2 ** n_layers`` for the
    shapes to line up — TODO confirm for other lengths.

    Args:
        n_layers: Number of encoder stages (and of decoder stages).
        channels_interval: Channel growth per encoder stage; stage ``i``
            (1-based) outputs ``i * channels_interval`` channels.
    """

    def __init__(self, n_layers=12, channels_interval=24):
        super(WaveUNet2, self).__init__()

        self.n_layers = n_layers
        self.channels_interval = channels_interval
        # Encoder widths: the first stage takes the 1-channel waveform, each
        # later stage takes the previous stage's output width.
        encoder_in_channels_list = [1] + [i * self.channels_interval for i in range(1, self.n_layers)]
        encoder_out_channels_list = [i * self.channels_interval for i in range(1, self.n_layers + 1)]

        # Length after each decimation for a 16384-sample input and 12 layers:
        #          1    => 2    => 3    => 4    => 5    => 6   => 7   => 8   => 9  => 10 => 11 =>12
        # 16384 => 8192 => 4096 => 2048 => 1024 => 512 => 256 => 128 => 64 => 32 => 16 =>  8 => 4
        self.encoder = nn.ModuleList()
        for i in range(self.n_layers):
            self.encoder.append(
                DownSamplingLayer(
                    channel_in=encoder_in_channels_list[i],
                    channel_out=encoder_out_channels_list[i]
                )
            )

        # Bottleneck: keeps the channel count of the last encoder stage.
        self.middle = nn.Sequential(
            nn.Conv1d(self.n_layers * self.channels_interval, self.n_layers * self.channels_interval, 15, stride=1,
                      padding=7),
            nn.BatchNorm1d(self.n_layers * self.channels_interval),
            nn.LeakyReLU(negative_slope=0.1, inplace=True)
        )

        # Decoder input widths = previous output width + skip width. After
        # reversing, the first entry (2 * n_layers * interval) is the
        # bottleneck output joined with the deepest skip.
        decoder_in_channels_list = [(2 * i + 1) * self.channels_interval for i in range(1, self.n_layers)] + [
            2 * self.n_layers * self.channels_interval]
        decoder_in_channels_list = decoder_in_channels_list[::-1]
        decoder_out_channels_list = encoder_out_channels_list[::-1]
        self.decoder = nn.ModuleList()
        for i in range(self.n_layers):
            self.decoder.append(
                UpSamplingLayer(
                    channel_in=decoder_in_channels_list[i],
                    channel_out=decoder_out_channels_list[i]
                )
            )

        # Output head: last decoder output (channels_interval) joined with
        # the 1-channel input, projected to one channel.
        self.out = nn.Sequential(
            nn.Conv1d(1 + self.channels_interval, 1, kernel_size=1, stride=1),
            nn.Tanh()
        )

    def forward(self, input):
        """Return the network output for the waveform batch ``input``."""
        tmp = []  # encoder outputs kept for the skip connections
        o = input

        # Encoder (down-sampling path)
        for i in range(self.n_layers):
            o = self.encoder[i](o)
            tmp.append(o)
            # Decimate: keep every second element along the last axis.
            o = o[:, :, ::2]

        o = self.middle(o)

        # Decoder (up-sampling path)
        for i in range(self.n_layers):
            # Double the length of the last axis by linear interpolation.
            o = F.interpolate(o, scale_factor=2, mode="linear", align_corners=True)
            # Skip Connection: join with the encoder output of the same depth.
            o = torch.cat([o, tmp[self.n_layers - i - 1]], dim=1)
            o = self.decoder[i](o)

        # Final skip from the raw input, then project to one channel.
        o = torch.cat([o, input], dim=1)
        o = self.out(o)
        return o

Main block for training the model and testing.
Initializes all training and testing parameters. 
Performs training and testing in one big loop, testing each epoch to track any overfitting.

In [None]:
# training (work in progress)
from torch.utils.data import DataLoader

def train_model( model, dataloader, criterion, optimizer, device ):
    """Train ``model`` for one epoch and report the mean per-sample loss.

    Args:
        model: Network to train; called with the noisy batch.
        dataloader: Yields ``(clean_audio, noisy_audio, orig_len)`` batches.
        criterion: Loss called as ``criterion(output, clean_audio)``. It is
            expected to return the batch mean (the default reduction of
            ``nn.MSELoss``), which is re-weighted by the batch size here.
        optimizer: Optimizer stepping the model's parameters.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(model, epoch_loss)`` where ``epoch_loss`` is the loss
        averaged over all samples seen, or ``nan`` if the loader was empty.
    """
    model.train()

    running_loss = 0.0
    samples_seen = 0

    # Iterate over data.
    for clean_audio, noisy_audio, _orig_len in dataloader:

        # send inputs to the target device
        clean_audio = clean_audio.to(device)
        noisy_audio = noisy_audio.to(device)
        batch_size = noisy_audio.size(0)

        # zero the parameter gradients
        optimizer.zero_grad()

        with torch.enable_grad():
            # send the noisy speech sample through the network
            output = model(noisy_audio)

            # compute loss between network output and clean audio
            loss = criterion(output, clean_audio)

            # backward + optimize
            loss.backward()
            optimizer.step()

        # The criterion returns a per-batch mean; weight it by the batch size
        # so that dividing by the sample count yields a true per-sample
        # average, also when the last batch is smaller.
        running_loss += loss.item() * batch_size
        samples_seen += batch_size

    epoch_loss = running_loss / samples_seen if samples_seen else float("nan")

    return model, epoch_loss

def test_model( model, device, criterion, dataloader, inv_transform = None ):
    """Evaluate ``model`` on ``dataloader`` and collect its outputs.

    Args:
        model: Network to evaluate; called with the noisy batch.
        device: Device the batches are moved to.
        criterion: Loss called as ``criterion(output, clean_audio)``. It is
            expected to return the batch mean (the default reduction of
            ``nn.MSELoss``), which is re-weighted by the batch size here.
        dataloader: Yields ``(clean_audio, noisy_audio, orig_len)`` batches.
        inv_transform: Currently unused; kept for interface compatibility.

    Returns:
        Tuple ``(cleaned_data, test_loss)``: all model outputs concatenated
        along dim 0 on the CPU in loader order (an empty tensor if the loader
        was empty), and the loss averaged over all samples seen (``nan`` if
        the loader was empty).
    """
    model.eval() 

    enhanced_batches = []
    running_loss = 0.0
    samples_seen = 0

    for clean_audio, noisy_audio, _orig_len in dataloader:
        noisy_audio = noisy_audio.to(device)
        clean_audio = clean_audio.to(device)
        batch_size = noisy_audio.size(0)

        with torch.no_grad():
            # forward
            enhanced_data = model(noisy_audio)

            loss = criterion(enhanced_data, clean_audio)

        # Collect batches and concatenate once at the end instead of growing
        # a tensor with torch.cat on every iteration.
        enhanced_batches.append(enhanced_data.detach().cpu())

        # Weight the per-batch mean by the batch size so the final division
        # gives a true per-sample average.
        running_loss += loss.item() * batch_size
        samples_seen += batch_size

    if enhanced_batches:
        cleaned_data = torch.cat(enhanced_batches, 0)
    else:
        cleaned_data = torch.empty(0)

    test_loss = running_loss / samples_seen if samples_seen else float("nan")

    return cleaned_data, test_loss


# NOTE(review): the training loader is created without shuffle=True, so
# batches follow the sorted file order — confirm this is intended.
train_loader = DataLoader( train_set, batch_size=64)
# The test loader must stay unshuffled: the evaluation cell indexes the
# collected outputs with the same index as test_set.
test_loader = DataLoader( test_set, batch_size=64 )

model = WaveUNet2() # assign model you want to use

learning_rate = 0.01
num_epochs = 10

criterion = nn.MSELoss()

optimizer = torch.optim.Adam(model.parameters(),learning_rate)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")

model = model.to(device)

# Early stopping: `count` tracks how many epochs in a row the test loss was
# higher than in the previous epoch; training stops at the second such epoch.
prev_test_loss = 10000.0
count = 0
for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    model, epoch_loss = train_model( model, train_loader, criterion, optimizer, device )
    print('Train Epoch Loss = ', epoch_loss )

    # Evaluate after every epoch; cleaned_data always holds the outputs of
    # the most recent test pass.
    cleaned_data, test_loss = \
        test_model( model, device, criterion, test_loader, inv_transform = None)
    print('Test Epoch Loss = ', test_loss )

    if prev_test_loss < test_loss:
        count += 1
    else:
        count = 0

    if count > 1:
        break;

    prev_test_loss = test_loss


In [None]:
!pip install torchmetrics[audio]
!pip install pesq

Loop through testing samples and evaluate enhanced data against clean data

In [None]:
from torchmetrics import SignalNoiseRatio
from pesq import pesq

# Per-utterance metrics for the enhanced output and for the unprocessed noisy
# input, both measured against the clean reference.
SNR = SignalNoiseRatio(zero_mean=True)

snr_vals_noisy = torch.empty(len(test_set))
pesq_vals_noisy = torch.empty(len(test_set))
snr_vals = torch.empty(len(test_set))
pesq_vals = torch.empty(len(test_set))

for i in range( len(test_set) ):
    clean_audio,noisy_audio,orig_len = test_set[i]
    clean_wav = clean_audio
    # cleaned_data[i] is the model output for test_set[i]; this relies on the
    # test loader iterating the dataset in order.
    enhanced_audio = cleaned_data[i].detach().cpu()

    # Alternative path for magnitude-spectrogram models (currently unused):
    #un_norm_mag = torch.transpose(enhanced_mag,0,3) * norm_val
    #un_norm_mag = torch.transpose(un_norm_mag,0,3)
    #complex_out = torch.polar(un_norm_mag, noisy_phase)
    #if inv_transform:
        #audio_enhanced = inv_transform(audio_enhanced, clean_wav.shape[2])

    # Drop the zero padding added by the dataset so it does not affect the
    # metrics.
    if enhanced_audio.shape[1] > orig_len:
        enhanced_audio = enhanced_audio[:,:orig_len]
        noisy_audio = noisy_audio[:,:orig_len]
        clean_wav = clean_wav[:,:orig_len]

    clean_np = clean_wav.squeeze().numpy()

    # SNR takes (preds, target); pesq takes (fs, ref, deg, mode), so the
    # clean signal is the reference and the processed signal is the degraded
    # one. PESQ is not symmetric, so the argument order matters.
    snr_vals[i] = SNR(enhanced_audio, clean_wav)
    pesq_vals[i] = pesq( target_fs, clean_np, enhanced_audio.squeeze().numpy(), "nb")

    snr_vals_noisy[i] = SNR(noisy_audio, clean_wav)
    pesq_vals_noisy[i] = pesq( target_fs, clean_np, noisy_audio.squeeze().numpy(), "nb")



Compute average and best SNR and PESQ improvements

In [None]:
# Per-utterance improvement = metric on the enhanced signal minus the same
# metric on the unprocessed noisy signal.
SNR_imp = snr_vals - snr_vals_noisy
PESQ_imp = pesq_vals - pesq_vals_noisy
print("Average PESQ Enhanced", pesq_vals.sum() / len(pesq_vals) )
print("Average PESQ Noisy", pesq_vals_noisy.sum() / len(pesq_vals_noisy) )

# Mean improvement over the whole test set.
avg_SNR_imp = SNR_imp.sum() / len(SNR_imp)
avg_PESQ_imp = PESQ_imp.sum() / len(PESQ_imp)

# Largest single-utterance improvement and the index where it occurs.
peak_SNR_imp = SNR_imp.max()
peak_snr_ind = SNR_imp.argmax()
peak_PESQ_imp = PESQ_imp.max()
peak_PESQ_ind = PESQ_imp.argmax()

print("Average SNR Improvement = ", avg_SNR_imp )
print("Average PESQ Improvement = ", avg_PESQ_imp )

print("Best SNR Improvement = ", peak_SNR_imp )
print("Best PESQ Improvment = ", peak_PESQ_imp )


# use if using magnitude spectrograms
#plot_spectrogram(cleaned_data[0].squeeze())
#plot_spectrogram(test_set[0][0][1].squeeze())

# Listen to the first test utterance before and after enhancement.
play_audio(test_set[0][1],target_fs) # play noisy test sample
play_audio(cleaned_data[0],target_fs) # play the enhanced output for the same sample
