<a href="https://colab.research.google.com/github/cnovak232/DL_Speech_Enhancement/blob/LSTM/DL_Final_Project_LSTM_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Load Data by cloning the repo - easiest way to access shared data

In [None]:
!git clone https://github.com/cnovak232/DL_Speech_Enhancement.git


Define some helper functions for plotting and playing audio

In [None]:
import torch
import torchaudio as ta
import librosa as lib
from IPython.display import Audio, display
import matplotlib
import matplotlib.pyplot as plt

# helper functions for audio playback and plotting
# mostly taken from the torchaudio tutorials

def play_audio(waveform, sample_rate):
    """Render an inline notebook audio player for a mono or stereo waveform.

    Args:
        waveform: 2-D tensor laid out as (channels, frames).
        sample_rate: Playback rate in Hz handed to the audio widget.

    Raises:
        ValueError: If the waveform does not have exactly 1 or 2 channels.
    """
    samples = waveform.numpy()

    # Unpacking also enforces that the input is exactly two-dimensional.
    channel_count, _ = samples.shape
    if channel_count not in (1, 2):
        raise ValueError("Waveform with more than 2 channels are not supported.")

    # Mono is passed as a single array, stereo as a (left, right) pair.
    payload = samples[0] if channel_count == 1 else (samples[0], samples[1])
    display(Audio(payload, rate=sample_rate))

def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None):
    """Plot each channel of a waveform against time in its own subplot.

    Args:
        waveform: 2-D tensor laid out as (channels, frames).
        sample_rate: Sampling rate in Hz, used to build the time axis.
        title: Figure title.
        xlim: Optional x-axis limits applied to every channel's subplot.
        ylim: Optional y-axis limits applied to every channel's subplot.
    """
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sample_rate

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        # plt.subplots returns a bare Axes for a single subplot; wrap it so
        # the loop below can treat both cases the same way.
        axes = [axes]
    for c, axis in enumerate(axes):
        axis.plot(time_axis, waveform[c], linewidth=1)
        axis.grid(True)
        # Labels and limits are set per channel. They used to sit outside the
        # loop, which only decorated the last channel's subplot.
        if num_channels > 1:
            axis.set_ylabel(f'Channel {c+1}')
        if xlim:
            axis.set_xlim(xlim)
        if ylim:
            axis.set_ylim(ylim)

    figure.suptitle(title)
    plt.show(block=False)

def get_spectrogram(
    waveform = None,
    n_fft = 512,
    win_len = None,
    hop_len = None,
    power = 1.0 ):
    """Compute a spectrogram of `waveform` with torchaudio.

    Args:
        waveform: Audio tensor handed to the torchaudio transform.
        n_fft: FFT size.
        win_len: Window length (None uses torchaudio's default).
        hop_len: Hop length (None uses torchaudio's default).
        power: Exponent applied to the magnitude (1.0 = magnitude).

    Returns:
        The output of `torchaudio.transforms.Spectrogram` for `waveform`.
    """
    # Centered frames with reflect padding are fixed; the rest is configurable.
    stft_options = dict(
        n_fft=n_fft,
        win_length=win_len,
        hop_length=hop_len,
        center=True,
        pad_mode="reflect",
        power=power,
    )
    to_spec = ta.transforms.Spectrogram(**stft_options)
    return to_spec(waveform)

def plot_spectrogram(spec, type = "amplitude", title=None, ylabel='freq_bin', aspect='auto', xmax=None):
    """Show a spectrogram on a decibel colour scale.

    Args:
        spec: 2-D spectrogram (freq_bins, frames) holding linear magnitude or
            power values.
        type: "power" if `spec` holds power values; any other value
            (including the default "amplitude") is treated as magnitude.
        title: Plot title; defaults to 'Spectrogram (db)'.
        ylabel: Label for the frequency axis.
        aspect: Aspect setting forwarded to `imshow`.
        xmax: Optional upper limit for the frame axis.
    """
    fig, axs = plt.subplots(1, 1)
    axs.set_title(title or 'Spectrogram (db)')
    axs.set_ylabel(ylabel)
    axs.set_xlabel('frame')

    # torchaudio documents only "power" and "magnitude" as scale types, so map
    # this function's historical "amplitude" default onto "magnitude".
    stype = "power" if type == "power" else "magnitude"
    # top_db bounds the dynamic range so silent (zero-padded) frames do not
    # flatten the colour scale.
    toDb = ta.transforms.AmplitudeToDB(stype, top_db=80)

    # Apply the dB conversion. Previously the transform was built but never
    # used, so linear values were shown under a "(db)" title.
    spec_db = toDb(torch.as_tensor(spec).detach().cpu())
    im = axs.imshow(spec_db, origin='lower', aspect=aspect)
    if xmax:
        axs.set_xlim((0, xmax))
    fig.colorbar(im, ax=axs)
    plt.show(block=False)

def norm_spec( spec ):
    """Scale a spectrogram by its largest value.

    Args:
        spec: Non-empty tensor of spectrogram values.

    Returns:
        Tuple `(normed, peak)`: `peak` is `spec.max()` as a 0-dim tensor and
        `normed` is `spec / peak`. When `peak` is zero (e.g. an all-zero
        spectrogram from silent audio) `spec` is returned unscaled instead of
        being turned into NaNs by a 0/0 division.
    """
    peak = spec.max()
    # Swap a zero peak for 1 so the division is a no-op rather than 0/0.
    divisor = torch.where(peak == 0, torch.ones_like(peak), peak)
    return spec / divisor, peak


Define a Custom Dataset class for the Data and read it in

In [None]:
# read dataset in and downsample / transform / pad if needed
from torch.utils.data import Dataset
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio as ta

class VoiceBankDemand(Dataset):
    """Paired clean/noisy utterances loaded from two parallel directories.

    Each item is a tuple `(clean, noisy, orig_len)`, where `orig_len` is the
    clean waveform's length in samples after optional resampling and before
    clipping or padding.

    * data != "test": `clean` and `noisy` are the waveforms or, if a
      `transform` is given, the transformed signals scaled by `norm_spec`.
    * data == "test": `clean` is `(clean_wav, clean_mag)` and `noisy` is
      `(noisy_mag, noisy_phase, noisy_wav, norm_val)`. The transform must
      return a complex spectrogram, since magnitude and phase are taken from it.

    Args:
        clean_dir: Directory holding the clean recordings.
        noisy_dir: Directory holding the noisy recordings (same file names).
        list_dir: Sequence of file names present in both directories.
        data: "train" (default) or "test"; selects the item layout above.
        len_samples: If set, every waveform is clipped or zero-padded to
            this many samples.
        downsample: If set, target sampling rate in Hz.
        transform: Optional callable applied to each waveform.
    """
    def __init__(self, clean_dir, noisy_dir, list_dir, 
                 data = "train", len_samples = None, downsample = None, 
                 transform = None ):
        self.clean_dir = clean_dir
        self.noisy_dir = noisy_dir
        self.list_dir = list_dir
        self.num_samples = len_samples
        self.downsample = downsample
        self.transform = transform
        self.data = data
        # One resampler per source rate, built lazily and reused across items.
        self._resamplers = {}
    
    def __len__(self):
        return len(self.list_dir)

    @staticmethod
    def _fit_length(audio, target_len):
        """Clip or right-pad a (channels, samples) tensor to `target_len`."""
        current_len = audio.shape[1]
        if current_len > target_len:
            return audio[:, :target_len]
        if current_len < target_len:
            # new_zeros keeps dtype/device and matches the channel count.
            pad = audio.new_zeros(audio.shape[0], target_len - current_len)
            return torch.cat((audio, pad), dim=1)
        return audio

    def _resample(self, audio, fs):
        """Resample `audio` from `fs` to `self.downsample`."""
        if fs not in self._resamplers:
            self._resamplers[fs] = ta.transforms.Resample(fs, self.downsample)
        return self._resamplers[fs](audio)

    def __getitem__( self, idx ):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        clean_name = os.path.join( self.clean_dir, self.list_dir[idx] )
        noisy_name = os.path.join( self.noisy_dir, self.list_dir[idx] )
        clean_audio, clean_fs = ta.load(clean_name)
        noisy_audio, noisy_fs = ta.load(noisy_name)

        if self.downsample:
            clean_audio = self._resample( clean_audio, clean_fs )
            noisy_audio = self._resample( noisy_audio, noisy_fs )

        # Recorded unconditionally so the returned tuple is well-defined even
        # when no fixed length was requested.
        orig_len = clean_audio.shape[1]

        if self.num_samples:
            # Use the instance's own target length rather than a module-level
            # global of the same name, and fit each signal independently so
            # both come out the same size.
            target_len = int( self.num_samples )
            clean_audio = self._fit_length( clean_audio, target_len )
            noisy_audio = self._fit_length( noisy_audio, target_len )
        
        if self.data == "test":
            if not self.transform:
                raise ValueError(
                    'data="test" requires a transform that returns a complex spectrogram' )
            noisy_trnsfrm = self.transform( noisy_audio )
            clean_trnsfrm = self.transform( clean_audio )
            clean_mag,_ = norm_spec( torch.abs(clean_trnsfrm) )
            clean_audio = (clean_audio, clean_mag)
            # norm_val lets callers undo the normalisation after enhancement.
            noisy_mag, norm_val = norm_spec( torch.abs(noisy_trnsfrm) )
            noisy_phase = torch.angle(noisy_trnsfrm)
            noisy_audio = (noisy_mag, noisy_phase, noisy_audio, norm_val)
        else:
            if self.transform:
                clean_audio,_ = norm_spec( self.transform( clean_audio ) )
                noisy_audio,_ = norm_spec( self.transform( noisy_audio ) )

        sample = (clean_audio, noisy_audio, orig_len)

        return sample

train_clean_path = './DL_Speech_Enhancement/clean_trainset_28spk_wav'
train_noisy_path = './DL_Speech_Enhancement/noisy_trainset_28spk_wav'
test_clean_path  = './DL_Speech_Enhancement/clean_testset_wav'
test_noisy_path  = './DL_Speech_Enhancement/noisy_testset_wav'

# os.listdir returns entries in arbitrary order. Sort them so dataset indices
# are reproducible across runs and machines, and keep only .wav files so stray
# entries (hidden files, READMEs) cannot break loading.
list_dir_train = sorted( f for f in os.listdir(train_clean_path) if f.lower().endswith('.wav') )
list_dir_test = sorted( f for f in os.listdir(test_clean_path) if f.lower().endswith('.wav') )

target_fs = 16000 # downsample to 16 KHz

# Magnitude spectrogram (power=1.0) used for training.
spectrogram = ta.transforms.Spectrogram(
    n_fft=512,
    power=1.0,
    normalized = False )
# Complex spectrogram (power=None) used for testing, so the noisy phase is
# still available afterwards.
complex_spec = ta.transforms.Spectrogram(
    n_fft=512,
    power=None,
    normalized = False ) # return complex spectrum

num_samples = int( 5.0 * target_fs ) # 5 seconds of audio at the target rate

train_set = VoiceBankDemand( clean_dir = train_clean_path,
                             noisy_dir = train_noisy_path,
                             list_dir = list_dir_train,
                             len_samples = num_samples, # clip or pad samples to be 5s
                             downsample = target_fs, # downsample to 16Khz
                             transform = spectrogram )

# returns the mag/phase of each audio file 
test_set = VoiceBankDemand( clean_dir = test_clean_path,
                            noisy_dir = test_noisy_path,
                            list_dir = list_dir_test,
                            data = "test",
                            len_samples = num_samples,
                            downsample = target_fs,
                            transform = complex_spec )


# Sanity checks: inspect one training item and one test item.
clean, noisy, orig_len = train_set[1]
print(clean.size())
print(orig_len)

clean_test, noisy_test, orig_len = test_set[0]

print(len(noisy_test))

noisy_mag, noisy_phase, noisy_wav, norm_val = noisy_test

plot_spectrogram(noisy.squeeze())
plot_spectrogram(noisy_mag.squeeze())

In [None]:
# LSTM
from torch.utils.data import Dataset
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio as ta
# Debug aid: ask CUDA to run kernel launches synchronously so device-side
# errors are reported at the call that caused them. This slows GPU execution,
# so drop it once the model runs cleanly.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

####################################

# Using a Recurrent Neural Network (RNN) based method - in a recurrent network, every time step shares the same weights (in a feedforward network, each input has its own weights).
# RNNs maintain a hidden state that can capture information from previous time steps, allowing them to model temporal dependencies in the input data.
# Because the weights are shared, a model can also process sequences of different lengths.
# Note: if you want variable lengths, you could also use a Gated Recurrent Unit (GRU) (more computationally efficient than an LSTM, but it may perform worse)

# LSTM block with LSTM layer and fully connected layer
# Define the SpeechEnhancementLSTM class
class SpeechEnhancementLSTM(nn.Module):
    """LSTM stack followed by a per-frame linear projection.

    Maps a padded batch of feature sequences `(batch, frames, input_size)` to
    `(batch, frames, output_size)`.

    Args:
        input_size: Number of features per frame (e.g. frequency bins).
        hidden_size: Width of each LSTM layer.
        num_layers: Number of stacked LSTM layers.
        output_size: Features per output frame. Defaults to `input_size`, so
            the output has the same shape as the input.
    """
    def __init__(self, input_size, hidden_size, num_layers, output_size=None):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_size = input_size if output_size is None else output_size
        
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        # One output vector per frame. The previous `input_size * 257` width
        # produced far more values than the final reshape could hold.
        self.fc = nn.Linear(hidden_size, self.output_size)

    def forward(self, x, lengths):
        """Run the network on a padded batch.

        Args:
            x: Tensor of shape (batch, frames, input_size).
            lengths: Valid length in frames of each sequence (tensor or
                list); each value must lie in [1, frames].

        Returns:
            Tensor of shape (batch, frames, output_size).
        """
        batch_size, num_frames = x.size(0), x.size(1)

        # Zero initial hidden and cell state, matching x's dtype and device.
        h0 = x.new_zeros(self.num_layers, batch_size, self.hidden_size)
        c0 = x.new_zeros(self.num_layers, batch_size, self.hidden_size)

        # pack_padded_sequence requires the lengths to live on the CPU.
        lengths = torch.as_tensor(lengths, dtype=torch.int64).cpu()
        x_packed = nn.utils.rnn.pack_padded_sequence(
            x, lengths, batch_first=True, enforce_sorted=False)

        out_packed, _ = self.lstm(x_packed, (h0, c0))

        # total_length keeps the time axis equal to the input's, even when
        # every sequence in the batch is shorter than the padded length.
        out, _ = nn.utils.rnn.pad_packed_sequence(
            out_packed, batch_first=True, total_length=num_frames)

        return self.fc(out)

###############################################



Main block for training the model and testing.
Initializes all training and testing parameters. 
Performs training and testing in one big loop, testing each epoch to track any overfitting.

In [None]:
!pip install --upgrade torch torchvision

# training (work in progress)
from torch.utils.data import DataLoader
import torch.optim as optim
import os
# Debug flags: set them before any CUDA work is issued.
os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
#torch.backends.cudnn.enabled = False

# Set device to use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Only synchronize when a GPU is present. An unconditional
# torch.cuda.synchronize() raises on CPU-only machines.
if device.type == "cuda":
    torch.cuda.synchronize()

def train_model(model, dataloader, criterion, optimizer, device, hop_length=None):
    """Train `model` for one pass over `dataloader`.

    Args:
        model: Network called as `model(noisy, lengths)` with `noisy` shaped
            (batch, frames, features); it must return a tensor of the same
            shape as the target.
        dataloader: Yields `(clean, noisy, orig_len)` batches whose
            spectrograms are shaped (batch, 1, features, frames).
        criterion: Loss applied to `(output, clean)`.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch is moved to.
        hop_length: If given, `orig_len` is taken to be a sample count and is
            converted to frames as `orig_len // hop_length + 1`. If None,
            `orig_len` is taken to be a frame count already.

    Returns:
        Tuple `(model, epoch_loss)`, with `epoch_loss` the loss averaged over
        all samples seen.
    """
    model.train()
    running_loss = 0.0
    samples_seen = 0

    for clean_audio, noisy_audio, orig_len in dataloader:
        # (batch, 1, features, frames) -> (batch, frames, features)
        clean_audio = clean_audio.to(device).squeeze(1).permute(0, 2, 1)
        noisy_audio = noisy_audio.to(device).squeeze(1).permute(0, 2, 1)
        batch_size, num_frames = noisy_audio.size(0), noisy_audio.size(1)

        # Sequence lengths in frames, kept on the CPU and within [1, frames]
        # so they are always valid for sequence packing.
        lengths = torch.as_tensor(orig_len).cpu()
        if hop_length:
            lengths = torch.div(lengths, hop_length, rounding_mode="floor") + 1
        lengths = lengths.clamp(min=1, max=num_frames)

        optimizer.zero_grad()

        with torch.enable_grad():
            output = model(noisy_audio, lengths)

            # Output and target share the (batch, frames, features) layout,
            # so they are compared directly.
            loss = criterion(output, clean_audio)

            loss.backward()
            optimizer.step()

        # Weight by batch size (loss is assumed to be a per-batch mean) so the
        # total can be averaged per sample.
        running_loss += loss.item() * batch_size
        samples_seen += batch_size

    epoch_loss = running_loss / samples_seen if samples_seen else 0.0

    return model, epoch_loss


def test_model(model, device, criterion, dataloader, hop_length=None):
    """Evaluate `model` on `dataloader` and collect its outputs.

    Args:
        model: Network called as `model(noisy, lengths)` with `noisy` shaped
            (batch, frames, features); it must return a tensor of the same
            shape as the target.
        device: Device the batch is moved to.
        criterion: Loss applied to `(output, clean)`.
        dataloader: Yields `(clean, noisy, orig_len)` batches. `clean` and
            `noisy` may be plain spectrogram tensors shaped
            (batch, 1, features, frames), or the tuples produced by a
            `data="test"` VoiceBankDemand dataset.
        hop_length: If given, `orig_len` is taken to be a sample count and is
            converted to frames as `orig_len // hop_length + 1`. If None,
            `orig_len` is taken to be a frame count already.

    Returns:
        Tuple `(cleaned_data, test_loss)`: `cleaned_data` holds the model
        outputs for the whole dataset shaped (N, features, frames), and
        `test_loss` is the loss averaged over all samples seen.
    """
    model.eval() 

    enhanced_batches = []
    running_loss = 0.0
    samples_seen = 0

    with torch.no_grad():
        for clean_audio, noisy_audio, orig_len in dataloader:
            # Test-mode datasets yield (clean_wav, clean_mag) and
            # (noisy_mag, noisy_phase, noisy_wav, norm_val); the network works
            # on the magnitude spectrograms.
            if isinstance(clean_audio, (list, tuple)):
                clean_audio = clean_audio[1]
            if isinstance(noisy_audio, (list, tuple)):
                noisy_audio = noisy_audio[0]

            # (batch, 1, features, frames) -> (batch, frames, features)
            clean_audio = clean_audio.to(device).squeeze(1).permute(0, 2, 1)
            noisy_audio = noisy_audio.to(device).squeeze(1).permute(0, 2, 1)
            batch_size, num_frames = noisy_audio.size(0), noisy_audio.size(1)

            # Sequence lengths in frames, kept on the CPU and within
            # [1, frames]. The model packs and unpacks the sequences itself,
            # so nothing is packed here.
            lengths = torch.as_tensor(orig_len).cpu()
            if hop_length:
                lengths = torch.div(lengths, hop_length, rounding_mode="floor") + 1
            lengths = lengths.clamp(min=1, max=num_frames)

            enhanced_data = model(noisy_audio, lengths)

            loss = criterion(enhanced_data, clean_audio)

            # Store outputs as (batch, features, frames) on the CPU.
            enhanced_batches.append(enhanced_data.permute(0, 2, 1).cpu())

            # Weight by batch size (loss is assumed to be a per-batch mean)
            # so the total can be averaged per sample.
            running_loss += loss.item() * batch_size
            samples_seen += batch_size

    test_loss = running_loss / samples_seen if samples_seen else 0.0

    cleaned_data = torch.cat(enhanced_batches, 0) if enhanced_batches else torch.empty(0)
    
    return cleaned_data, test_loss


# Not a unit test despite its name; keep pytest-style runners from collecting it.
test_model.__test__ = False


      
# Reshuffle the training data every epoch; keep the test order fixed so model
# outputs stay aligned with test_set indices.
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = DataLoader(test_set, batch_size=128)

# Spectrogram dimensions for 5 s of 16 kHz audio with n_fft=512:
# 313 frames x 257 frequency bins. Kept for reference; not used below.
input_shape = (313, 257)
output_shape = (313, 257)
hidden_shape = (128, 2)
num_layers = 2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

input_size = 257  # frequency bins per frame (n_fft // 2 + 1)
hidden_size = 4
num_layers = 2
output_size = 80  # legacy value; the model is built without it

# Spectrogram(n_fft=512) uses win_length = n_fft and hop_length = win_length // 2
# by default. Needed to turn the dataset's sample counts into frame counts.
hop_length = 512 // 2

# Initialize the model, loss function, and optimizer
model = SpeechEnhancementLSTM(input_size, hidden_size, num_layers).to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Call this function to free up GPU memory cache
#torch.cuda.empty_cache()

# Number of training epochs
num_epochs = 10

# Training Loop
accumulation_steps = 4  # Accumulate gradients over 4 batches
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_batches = len(train_loader)
    optimizer.zero_grad()

    for batch_idx, (clean_audio, noisy_audio, lengths) in enumerate(train_loader):
        # (batch, 1, bins, frames) -> (batch, frames, bins) for both input and
        # target. A real permute is required; reinterpreting the memory with
        # view() would scramble the time and frequency axes.
        clean_audio = clean_audio.to(device).squeeze(1).permute(0, 2, 1)
        noisy_audio = noisy_audio.to(device).squeeze(1).permute(0, 2, 1)

        # The dataset reports lengths in samples; the LSTM needs frames.
        # Lengths stay on the CPU (required for sequence packing) and are
        # clamped because clipped utterances were originally longer than 5 s.
        lengths = torch.div(lengths.cpu(), hop_length, rounding_mode="floor") + 1
        lengths = lengths.clamp(min=1, max=noisy_audio.size(1))

        # Forward pass
        output = model(noisy_audio, lengths)

        # Compute the loss on the live tensors. Using `.data` detaches them
        # from the autograd graph and makes backward() fail.
        loss = criterion(output, clean_audio)

        # Scale so the accumulated gradient is the mean, not the sum, over
        # the accumulated batches.
        (loss / accumulation_steps).backward()

        # Step after every full group of batches, and once more at the end
        # of the epoch so leftover gradients are not carried into the next one.
        is_last_batch = (batch_idx + 1) == num_batches
        if (batch_idx + 1) % accumulation_steps == 0 or is_last_batch:
            optimizer.step()
            optimizer.zero_grad()

        # Update the running loss
        running_loss += loss.item()

    # Print the average loss for this epoch
    print(f"Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}")

print("Finished training.")



In [None]:
!pip install torchmetrics[audio]
!pip install pesq

Loop through testing samples and evaluate enhanced data against clean data

In [None]:
from torchmetrics import SignalNoiseRatio
from pesq import pesq

SNR = SignalNoiseRatio(zero_mean=True)

# Inverse of the complex spectrogram transform used to build test_set
# (n_fft=512 with torchaudio's default window and hop).
inv_transform = ta.transforms.InverseSpectrogram(n_fft=512)

# NOTE: `cleaned_data` must hold the network outputs for test_set, in the same
# order, as returned by test_model: one normalised magnitude spectrogram per item.
num_test = len(test_set)
snr_vals_noisy = torch.empty(num_test)
pesq_vals_noisy = torch.empty(num_test)
snr_vals = torch.empty(num_test)
pesq_vals = torch.empty(num_test)

for i in range(num_test):
    # Test-mode items: clean = (waveform, magnitude),
    # noisy = (magnitude, phase, waveform, normalisation value).
    clean_data, noisy_data, orig_len = test_set[i]
    clean_wav = clean_data[0]
    noisy_mag, noisy_phase, noisy_audio, norm_val = noisy_data

    # Rebuild a waveform from the enhanced magnitude: undo the
    # normalisation, reuse the noisy phase, then invert the STFT.
    enhanced_mag = cleaned_data[i].detach().cpu()
    if enhanced_mag.dim() == 2:
        enhanced_mag = enhanced_mag.unsqueeze(0)  # add the channel axis
    # The network output is unconstrained; a magnitude cannot be negative.
    enhanced_mag = enhanced_mag.clamp(min=0) * norm_val
    complex_out = torch.polar(enhanced_mag, noisy_phase)
    enhanced_audio = inv_transform(complex_out, clean_wav.shape[1])

    # Drop the zero padding added when the utterance was shorter than 5 s.
    if enhanced_audio.shape[1] > orig_len:
        enhanced_audio = enhanced_audio[:,:orig_len]
        noisy_audio = noisy_audio[:,:orig_len]
        clean_wav = clean_wav[:,:orig_len]

    # pesq takes (fs, reference, degraded, mode): the clean signal is the
    # reference and the signal under test is the degraded one.
    snr_vals[i] = SNR(enhanced_audio, clean_wav)
    pesq_vals[i] = pesq( target_fs, clean_wav.squeeze().numpy(), enhanced_audio.squeeze().numpy(), "nb")

    snr_vals_noisy[i] = SNR(noisy_audio, clean_wav)
    pesq_vals_noisy[i] = pesq( target_fs, clean_wav.squeeze().numpy(), noisy_audio.squeeze().numpy(), "nb")



Compute average and best SNR and PESQ improvements

In [None]:
# Per-utterance improvement of the enhanced signal over the noisy input.
SNR_imp = snr_vals - snr_vals_noisy
PESQ_imp = pesq_vals - pesq_vals_noisy
print("Average PESQ Enhanced", pesq_vals.mean() )
print("Average PESQ Noisy", pesq_vals_noisy.mean() )

avg_SNR_imp = SNR_imp.mean()
avg_PESQ_imp = PESQ_imp.mean()

peak_SNR_imp = SNR_imp.max()
peak_snr_ind = SNR_imp.argmax()
peak_PESQ_imp = PESQ_imp.max()
peak_PESQ_ind = PESQ_imp.argmax()

print("Average SNR Improvement = ", avg_SNR_imp )
print("Average PESQ Improvement = ", avg_PESQ_imp )

print("Best SNR Improvement = ", peak_SNR_imp )
print("Best PESQ Improvment = ", peak_PESQ_imp )


# use if using magnitude spectrograms
#plot_spectrogram(cleaned_data[0].squeeze())
#plot_spectrogram(test_set[0][0][1].squeeze())

# Test-mode items are (clean, noisy, orig_len) with
# noisy = (magnitude, phase, waveform, normalisation value).
_, first_noisy, _ = test_set[0]
_, first_phase, first_noisy_wav, first_norm = first_noisy

play_audio(first_noisy_wav,target_fs) # play noisy test sample

# cleaned_data holds magnitude spectrograms, which cannot be played directly.
# Rebuild the enhanced waveform using the noisy phase first.
first_mag = cleaned_data[0].detach().cpu()
if first_mag.dim() == 2:
    first_mag = first_mag.unsqueeze(0)  # add the channel axis
first_complex = torch.polar(first_mag.clamp(min=0) * first_norm, first_phase)
first_enhanced_wav = ta.transforms.InverseSpectrogram(n_fft=512)(
    first_complex, first_noisy_wav.shape[1])
play_audio(first_enhanced_wav,target_fs)
