# Speech Enhancement using SEGAN

This notebook implements speech enhancement with **SEGAN (Speech Enhancement Generative Adversarial Network)**. The project is split into separate components (preprocessing, model, training, testing), and each section of the notebook corresponds to one of them.

## Project Overview
Speech enhancement aims to improve the quality and intelligibility of speech signals by reducing background noise and other distortions. **SEGAN** trains a Generative Adversarial Network (GAN) that maps noisy waveforms directly to clean ones, working end to end on raw audio samples. The workflow includes:
1. **Preprocessing**: Resampling the audio to 16 kHz and slicing each clean/noisy recording pair into fixed-length windows that the networks can consume.
2. **Model Development**: Designing the SEGAN architecture, including the generator and discriminator networks.
3. **Training**: Alternately updating the discriminator and the generator on the preprocessed data.
4. **Testing**: Running the trained generator on unseen noisy audio and saving the enhanced output.

This notebook is organized into the following sections:
- **Preprocessing**: Code for loading, cleaning, and preparing audio data.
- **Model**: Definition of the SEGAN model, including the generator and discriminator.
- **Training**: Implementation of the training loop, along with loss functions and optimizers.
- **Testing**: Enhancing the test set with the trained generator, saving the results and checkpoints, and enhancing individual files from a saved checkpoint.



# Preprocessing for Speech Enhancement

The preprocessing step prepares audio data for the SEGAN model by performing the following tasks:
1. **Audio Segmentation**: Audio files are divided into fixed-size, overlapping segments, so that each recording is broken into smaller, manageable chunks for training and testing.
2. **Serialization**: Clean and noisy audio segments are paired and saved as `.npy` files for efficient use during training and testing.
3. **Validation**: Checks that the serialized data matches the expected format and segment length.

### Key Parameters:
- **Window Size**: 2^14 = 16384 samples (about 1.02 s of audio at 16 kHz).
- **Sample Rate**: 16 kHz (target sample rate for audio processing).
- **Stride**: 0.5 of the window (a hop of 8192 samples), i.e. 50% overlap between consecutive segments.
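The window and hop arithmetic can be checked without loading any audio. This pure-Python sketch reproduces only the index logic of `segment_audio_file`; `segment_starts` is an illustrative helper, not part of the notebook.

```python
def segment_starts(num_samples, window_size=2 ** 14, stride=0.5):
    """Start indices of the full windows that segment_audio_file would keep."""
    hop_length = int(window_size * stride)  # 8192 samples for 50% overlap
    return list(range(0, num_samples - window_size + 1, hop_length))


# A 2.5 s clip at 16 kHz has 40000 samples.
print(segment_starts(40000))              # [0, 8192, 16384]
print(segment_starts(40000, stride=1.0))  # [0, 16384]
print(segment_starts(10000))              # [] (shorter than one window)
```

Samples after the last full window (here 32768 to 39999 with 50% overlap) are not included in any segment, and a clip shorter than one window produces no segments at all.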


In [None]:
import os
import librosa
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.nn.modules import Module
from torch.nn.parameter import Parameter
from torch.utils import data
import argparse
from scipy.io import wavfile
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader 






In [None]:
# Directory Paths
DIRS = {
    "train_clean": 'data/clean_trainset',
    "train_noisy": 'data/noisy_trainset',
    "test_clean": 'data/clean_testset',
    "test_noisy": 'data/noisy_testset',
    "train_serialized": 'data/serialized_train_data',
    "test_serialized": 'data/serialized_test_data',
}

# Parameters
WINDOW_SIZE = 2 ** 14  # Approx. 1 second
SAMPLE_RATE = 16000
STRIDE = 0.5  # 50% overlap


def segment_audio_file(audio_path, window_size=WINDOW_SIZE, stride=STRIDE, sample_rate=SAMPLE_RATE):
    """
    Segments an audio file into overlapping segments.

    Args:
        audio_path (str): Path to the audio file.
        window_size (int): Number of samples per segment.
        stride (float): Overlap ratio (e.g., 0.5 for 50% overlap).
        sample_rate (int): Target sample rate for audio.

    Returns:
        list: List of audio segments.
    """
    audio, sr = librosa.load(audio_path, sr=sample_rate)
    hop_length = int(window_size * stride)
    return [audio[i:i + window_size] for i in range(0, len(audio) - window_size + 1, hop_length)]


def preprocess_and_save(dataset_type="train"):
    """
    Processes and serializes audio data (clean and noisy) for training or testing.

    Args:
        dataset_type (str): Type of dataset ('train' or 'test').
    """
    clean_path = DIRS[f"{dataset_type}_clean"]
    noisy_path = DIRS[f"{dataset_type}_noisy"]
    save_path = DIRS[f"{dataset_type}_serialized"]

    os.makedirs(save_path, exist_ok=True)

    print(f"Processing {dataset_type} data...")
    for filename in tqdm(os.listdir(clean_path), desc=f"Serializing {dataset_type}"):
        clean_file = os.path.join(clean_path, filename)
        noisy_file = os.path.join(noisy_path, filename)

        # Segment clean and noisy files
        clean_segments = segment_audio_file(clean_file)
        noisy_segments = segment_audio_file(noisy_file)

        # Save paired segments as .npy
        for idx, (clean_segment, noisy_segment) in enumerate(zip(clean_segments, noisy_segments)):
            pair = np.array([clean_segment, noisy_segment])
            save_filename = f"{os.path.splitext(filename)[0]}_{idx}.npy"
            np.save(os.path.join(save_path, save_filename), pair)


def validate_serialized_data(dataset_type="train"):
    """
    Validates the serialized data by checking the consistency of segment lengths.

    Args:
        dataset_type (str): Type of dataset ('train' or 'test').
    """
    serialized_path = DIRS[f"{dataset_type}_serialized"]
    print(f"Validating {dataset_type} serialized data...")

    for file in tqdm(os.listdir(serialized_path), desc=f"Validating {dataset_type}"):
        file_path = os.path.join(serialized_path, file)
        data_pair = np.load(file_path)

        # Validate shape
        if data_pair.shape[1] != WINDOW_SIZE:
            print(f"Error: Segment in {file} does not match window size {WINDOW_SIZE}.")
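`preprocess_and_save` and `validate_serialized_data` are defined above but not called in this notebook; preparing the data means calling `preprocess_and_save('train')` and `preprocess_and_save('test')`, then the validation function for each split. The self-contained sketch below exercises the same on-disk format with a synthetic tone instead of real recordings: each `.npy` file holds a `(2, 16384)` array with the clean segment first. `save_pair` and `is_valid_pair` are hypothetical helpers; `is_valid_pair` is slightly stricter than `validate_serialized_data` because it also checks the first dimension.

```python
import os
import tempfile

import numpy as np


def save_pair(save_dir, stem, idx, clean_segment, noisy_segment):
    """Stack a clean/noisy segment pair into a (2, samples) array, clean first, and save it."""
    pair = np.array([clean_segment, noisy_segment])
    path = os.path.join(save_dir, f"{stem}_{idx}.npy")
    np.save(path, pair)
    return path


def is_valid_pair(path, window_size=2 ** 14):
    """True if the file holds exactly one clean and one noisy segment of window_size samples."""
    return np.load(path).shape == (2, window_size)


# Synthetic stand-ins for one clean/noisy segment (a 440 Hz tone plus Gaussian noise)
window_size = 2 ** 14
t = np.arange(window_size) / 16000.0
clean = np.sin(2 * np.pi * 440.0 * t).astype(np.float32)
noise = np.random.default_rng(0).standard_normal(window_size).astype(np.float32)
noisy = clean + 0.05 * noise

with tempfile.TemporaryDirectory() as tmp_dir:
    full = save_pair(tmp_dir, "example", 0, clean, noisy)
    short = save_pair(tmp_dir, "example", 1, clean[:1000], noisy[:1000])
    results = (is_valid_pair(full), is_valid_pair(short))

print(results)  # (True, False)
```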

### Audio Dataset 

In this section, we define the `AudioDataset` class and the `emphasis` function, which are essential for handling and processing audio data.

- **`emphasis`**: Applies a first-order pre-emphasis filter, `y[n] = x[n] - 0.95 * x[n-1]`, which boosts high frequencies relative to low ones, or the inverse de-emphasis filter. Inputs are pre-emphasized before they reach the networks, and the generator's outputs are de-emphasized before they are saved.

- **`AudioDataset`**: A custom PyTorch dataset that loads the serialized clean/noisy pairs, applies pre-emphasis, and returns one sample at a time; batching is handled by `DataLoader`. It also provides `reference_batch`, which draws the fixed reference batch used by virtual batch normalization.
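The NumPy sketch below (helper names `pre_emphasis`, `de_emphasis` and `fir_de_emphasis` are illustrative) shows why the exact inverse of pre-emphasis must be recursive. Undoing `y[n] = x[n] - a * x[n-1]` requires `x[n] = y[n] + a * x[n-1]`, which feeds back previous outputs. The non-recursive form `y[n] + a * y[n-1]` looks similar but leaves a residual of `-a**2 * x[n-2]`.

```python
import numpy as np


def pre_emphasis(x, coeff=0.95):
    """FIR high-pass: y[n] = x[n] - coeff * x[n-1], with y[0] = x[0]."""
    return np.append(x[0], x[1:] - coeff * x[:-1])


def de_emphasis(y, coeff=0.95):
    """Exact inverse (IIR): x[n] = y[n] + coeff * x[n-1], with x[0] = y[0]."""
    x = np.zeros_like(y)
    x[0] = y[0]
    for n in range(1, len(y)):
        x[n] = y[n] + coeff * x[n - 1]
    return x


def fir_de_emphasis(y, coeff=0.95):
    """Non-recursive y[n] + coeff * y[n-1]: NOT the inverse of pre_emphasis."""
    return np.append(y[0], y[1:] + coeff * y[:-1])


rng = np.random.default_rng(0)
x = rng.standard_normal(1000)
y = pre_emphasis(x)
print(np.max(np.abs(de_emphasis(y) - x)))      # close to 0: x is recovered
print(np.max(np.abs(fir_de_emphasis(y) - x)))  # clearly non-zero
```

For whole batches, `scipy.signal.lfilter([1, -a], [1], x)` and `scipy.signal.lfilter([1], [1, -a], y)` compute the same two filters along the last axis.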



In [None]:
from scipy.signal import lfilter


def emphasis(signal_batch, emph_coeff=0.95, pre=True):
    """
    Pre-emphasis or de-emphasis of higher frequencies for a batch of signals.

    Pre-emphasis is the FIR filter y[n] = x[n] - emph_coeff * x[n-1].
    De-emphasis is its exact inverse, the IIR filter x[n] = y[n] + emph_coeff * x[n-1].
    Both start from a zero initial state, so the first sample passes through unchanged.

    Args:
        signal_batch: numpy array of shape [batch x channels x samples]
        emph_coeff: emphasis coefficient
        pre: True for pre-emphasis, False for de-emphasis

    Returns:
        result: filtered signal batch with the same shape as the input
    """
    if pre:
        return lfilter([1.0, -emph_coeff], [1.0], signal_batch, axis=-1)
    return lfilter([1.0], [1.0, -emph_coeff], signal_batch, axis=-1)


class AudioDataset(data.Dataset):
    """
    Audio sample reader using preprocessed and serialized data.
    """

    def __init__(self, data_type):
        """
        Initializes the dataset by pointing to the respective serialized folder.

        Args:
            data_type (str): Type of dataset ('train' or 'test').
        """
        # Use serialized data directory paths from the `DIRS` dictionary
        if data_type == 'train':
            data_path = DIRS['train_serialized']
        else:
            data_path = DIRS['test_serialized']

        if not os.path.exists(data_path):
            raise FileNotFoundError(f'The {data_type} serialized data folder does not exist!')

        self.data_type = data_type
        self.file_names = [os.path.join(data_path, filename) for filename in os.listdir(data_path)]

    def reference_batch(self, batch_size):
        """
        Randomly selects a reference batch from dataset.
        Reference batch is used for calculating statistics for virtual batch normalization operation.

        Args:
            batch_size (int): batch size

        Returns:
            ref_batch: reference batch
        """
        ref_file_names = np.random.choice(self.file_names, batch_size)
        ref_batch = np.stack([np.load(f) for f in ref_file_names])

        ref_batch = emphasis(ref_batch, emph_coeff=0.95)
        return torch.from_numpy(ref_batch).type(torch.FloatTensor)

    def __getitem__(self, idx):
        """
        Retrieves a single data sample from the dataset.

        Args:
            idx (int): Index of the sample.

        Returns:
            tuple: (pair, clean, noisy) for training, or (file_name, noisy) for testing.
        """
        pair = np.load(self.file_names[idx])
        pair = emphasis(pair[np.newaxis, :, :], emph_coeff=0.95).reshape(2, -1)
        noisy = pair[1].reshape(1, -1)

        if self.data_type == 'train':
            clean = pair[0].reshape(1, -1)
            return torch.from_numpy(pair).type(torch.FloatTensor), torch.from_numpy(clean).type(
                torch.FloatTensor), torch.from_numpy(noisy).type(torch.FloatTensor)
        else:
            return os.path.basename(self.file_names[idx]), torch.from_numpy(noisy).type(torch.FloatTensor)

    def __len__(self):
        """
        Returns the total number of samples in the dataset.

        Returns:
            int: Total number of samples.
        """
        return len(self.file_names)

# Model Architecture

The model consists of two main components:
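1. **Generator (G):** An encoder-decoder that takes a noisy waveform together with a random latent vector (`z`) and outputs an enhanced waveform. It is trained to produce outputs that the discriminator accepts as real and that stay close to the clean reference.
2. **Discriminator (D):** Receives a pair of signals, either (clean, noisy) or (enhanced, noisy), and learns to tell real pairs from generated ones. Its scores are the adversarial feedback that drives the generator.

Key features of the model:
- **1D Convolutions:** Both networks operate directly on raw waveform samples using strided 1D convolutions.
- **Virtual Batch Normalization (VBN):** Used in the discriminator; each batch is normalized with statistics taken mostly from a fixed reference batch, which reduces the dependence of each output on the other samples in the mini-batch and stabilizes training.
- **Activations:** The generator uses PReLU in its hidden layers and `tanh` at the output; the discriminator uses Leaky ReLU and a final sigmoid.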

The training goal is to achieve a balance where the generator produces realistic signals, making it difficult for the discriminator to differentiate between real and fake samples.
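In code, conditioning the discriminator on the noisy input amounts to stacking signals along the channel axis. The NumPy sketch below uses random arrays, and `np.tanh(noisy)` is only a placeholder for the generator's output. The channel order (clean or enhanced first, noisy second) follows `AudioDataset` and the `torch.cat(..., dim=1)` calls in the training loop.

```python
import numpy as np

batch, length = 4, 2 ** 14
rng = np.random.default_rng(0)
clean = rng.uniform(-1.0, 1.0, size=(batch, 1, length))    # [B x 1 x T]
noisy = clean + 0.1 * rng.standard_normal((batch, 1, length))
enhanced = np.tanh(noisy)  # placeholder for generator(noisy, z), also [B x 1 x T]

real_pair = np.concatenate([clean, noisy], axis=1)     # D is trained to output 1
fake_pair = np.concatenate([enhanced, noisy], axis=1)  # D is trained to output 0
print(real_pair.shape, fake_pair.shape)  # (4, 2, 16384) (4, 2, 16384)
```

Because both pairs share the same noisy channel, the discriminator can only separate them by judging whether the first channel is a plausible clean version of the second.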


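### VirtualBatchNorm1d Class
The `VirtualBatchNorm1d` class implements Virtual Batch Normalization (VBN).

Key features:
- Learns a per-channel scale (`gamma`) and shift (`beta`).
- Computes the per-channel mean and mean of squares over the batch and time axes.
- When reference statistics are supplied, blends them with the current batch's statistics (weight `1/(B+1)` for the current batch, `B/(B+1)` for the reference) before normalizing.
- Reduces each sample's dependence on the rest of its mini-batch, which helps stabilize GAN training.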

References:
- Virtual Batch Normalization: https://arxiv.org/abs/1606.03498
- Implementation discussion: https://discuss.pytorch.org/t/parameter-grad-of-conv-weight-is-none-after-virtual-batch-normalization/9036
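The normalization arithmetic of `VirtualBatchNorm1d.forward` can be reproduced in NumPy. The sketch below fixes `gamma = 1` and `beta = 0` and uses random data; `vbn_stats` and `vbn_normalize` are illustrative names. With a batch of 50, the current batch contributes 1/51 of the statistics and the reference batch 50/51. This implementation blends the mean of the whole current batch with the reference statistics, which is a simplification of the per-example scheme described by Salimans et al.

```python
import numpy as np


def vbn_stats(x):
    """Per-channel mean and mean of squares over batch and time, shape [1 x C x 1]."""
    mean = x.mean(axis=2, keepdims=True).mean(axis=0, keepdims=True)
    mean_sq = (x ** 2).mean(axis=2, keepdims=True).mean(axis=0, keepdims=True)
    return mean, mean_sq


def vbn_normalize(x, ref_mean, ref_mean_sq, eps=1e-5):
    """Blend batch and reference statistics as VirtualBatchNorm1d.forward does."""
    mean, mean_sq = vbn_stats(x)
    new_coeff = 1.0 / (x.shape[0] + 1.0)
    mean = new_coeff * mean + (1.0 - new_coeff) * ref_mean
    mean_sq = new_coeff * mean_sq + (1.0 - new_coeff) * ref_mean_sq
    std = np.sqrt(eps + mean_sq - mean ** 2)  # variance = E[x^2] - E[x]^2
    return (x - mean) / std


rng = np.random.default_rng(0)
ref = rng.normal(loc=2.0, scale=3.0, size=(50, 4, 256))  # fixed reference batch
ref_mean, ref_mean_sq = vbn_stats(ref)
x = rng.normal(loc=2.0, scale=3.0, size=(50, 4, 256))    # a training batch
out = vbn_normalize(x, ref_mean, ref_mean_sq)
print(out.mean(), out.std())  # roughly 0 and 1
```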


In [None]:
class VirtualBatchNorm1d(Module):
    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.gamma = Parameter(torch.normal(mean=1.0, std=0.02, size=(1, num_features, 1)))
        self.beta = Parameter(torch.zeros(1, num_features, 1))

    def get_stats(self, x):
        mean = x.mean(2, keepdim=True).mean(0, keepdim=True)
        mean_sq = (x ** 2).mean(2, keepdim=True).mean(0, keepdim=True)
        return mean, mean_sq

    def forward(self, x, ref_mean, ref_mean_sq):
        mean, mean_sq = self.get_stats(x)
        if ref_mean is None or ref_mean_sq is None:
            mean = mean.clone().detach()
            mean_sq = mean_sq.clone().detach()
            out = self.normalize(x, mean, mean_sq)
        else:
            batch_size = x.size(0)
            new_coeff = 1. / (batch_size + 1.)
            old_coeff = 1. - new_coeff
            mean = new_coeff * mean + old_coeff * ref_mean
            mean_sq = new_coeff * mean_sq + old_coeff * ref_mean_sq
            out = self.normalize(x, mean, mean_sq)
        return out, mean, mean_sq

    def normalize(self, x, mean, mean_sq):
        std = torch.sqrt(self.eps + mean_sq - mean ** 2)
        x = x - mean
        x = x / std
        x = x * self.gamma
        x = x + self.beta
        return x

    def __repr__(self):
        return f'{self.__class__.__name__}(num_features={self.num_features}, eps={self.eps})'


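### Generator Class
The `Generator` class is SEGAN's generator: a fully convolutional encoder-decoder that maps a noisy waveform to an enhanced one.

Key features:
- Encodes the noisy input `[B x 1 x 16384]` through 11 strided convolutions down to `[B x 1024 x 8]`, concatenates the latent vector `z` (`[B x 1024 x 8]`), and decodes back to `[B x 1 x 16384]`.
- Uses transposed convolutions in the decoder, with skip connections that concatenate each decoder output with the matching encoder output.
- Uses `PReLU` activations in hidden layers and `Tanh` at the output, so output samples lie in [-1, 1].
- Initializes convolution weights with Xavier (Glorot) normal initialization.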
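The shape comments in the generator code follow from PyTorch's output-length formulas for `nn.Conv1d` and `nn.ConvTranspose1d` (dilation 1, no output padding). The pure-Python sketch below (helper names are illustrative) applies them with kernel size 32, stride 2 and padding 15: each encoder layer halves the length and each decoder layer doubles it.

```python
def conv1d_out(length, kernel_size=32, stride=2, padding=15):
    """Output length of nn.Conv1d with dilation 1."""
    return (length + 2 * padding - kernel_size) // stride + 1


def conv_transpose1d_out(length, kernel_size=32, stride=2, padding=15):
    """Output length of nn.ConvTranspose1d with dilation 1 and output_padding 0."""
    return (length - 1) * stride - 2 * padding + kernel_size


lengths = [2 ** 14]
for _ in range(11):  # enc1 ... enc11
    lengths.append(conv1d_out(lengths[-1]))
print(lengths)  # [16384, 8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8]

length = lengths[-1]
for _ in range(11):  # dec10 ... dec1, then dec_final
    length = conv_transpose1d_out(length)
print(length)  # 16384
```

The channel counts follow a similar bookkeeping: because of the skip connections, each decoder layer's input channels equal the previous decoder's output channels plus those of the matching encoder layer (for example, `dec9` takes 512 + 512 = 1024).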


In [None]:
class Generator(nn.Module):
    """SEGAN generator: 1D convolutional encoder-decoder with skip connections."""

    def __init__(self):
        super().__init__()
        # encoder gets a noisy signal as input [B x 1 x 16384]
        self.enc1 = nn.Conv1d(in_channels=1, out_channels=16, kernel_size=32, stride=2, padding=15)  # [B x 16 x 8192]
        self.enc1_nl = nn.PReLU()
        self.enc2 = nn.Conv1d(16, 32, 32, 2, 15)  # [B x 32 x 4096]
        self.enc2_nl = nn.PReLU()
        self.enc3 = nn.Conv1d(32, 32, 32, 2, 15)  # [B x 32 x 2048]
        self.enc3_nl = nn.PReLU()
        self.enc4 = nn.Conv1d(32, 64, 32, 2, 15)  # [B x 64 x 1024]
        self.enc4_nl = nn.PReLU()
        self.enc5 = nn.Conv1d(64, 64, 32, 2, 15)  # [B x 64 x 512]
        self.enc5_nl = nn.PReLU()
        self.enc6 = nn.Conv1d(64, 128, 32, 2, 15)  # [B x 128 x 256]
        self.enc6_nl = nn.PReLU()
        self.enc7 = nn.Conv1d(128, 128, 32, 2, 15)  # [B x 128 x 128]
        self.enc7_nl = nn.PReLU()
        self.enc8 = nn.Conv1d(128, 256, 32, 2, 15)  # [B x 256 x 64]
        self.enc8_nl = nn.PReLU()
        self.enc9 = nn.Conv1d(256, 256, 32, 2, 15)  # [B x 256 x 32]
        self.enc9_nl = nn.PReLU()
        self.enc10 = nn.Conv1d(256, 512, 32, 2, 15)  # [B x 512 x 16]
        self.enc10_nl = nn.PReLU()
        self.enc11 = nn.Conv1d(512, 1024, 32, 2, 15)  # [B x 1024 x 8]
        self.enc11_nl = nn.PReLU()

        # decoder generates an enhanced signal
        self.dec10 = nn.ConvTranspose1d(in_channels=2048, out_channels=512, kernel_size=32, stride=2, padding=15)
        self.dec10_nl = nn.PReLU()  # out : [B x 512 x 16] -> (concat) [B x 1024 x 16]
        self.dec9 = nn.ConvTranspose1d(1024, 256, 32, 2, 15)  # [B x 256 x 32]
        self.dec9_nl = nn.PReLU()
        self.dec8 = nn.ConvTranspose1d(512, 256, 32, 2, 15)  # [B x 256 x 64]
        self.dec8_nl = nn.PReLU()
        self.dec7 = nn.ConvTranspose1d(512, 128, 32, 2, 15)  # [B x 128 x 128]
        self.dec7_nl = nn.PReLU()
        self.dec6 = nn.ConvTranspose1d(256, 128, 32, 2, 15)  # [B x 128 x 256]
        self.dec6_nl = nn.PReLU()
        self.dec5 = nn.ConvTranspose1d(256, 64, 32, 2, 15)  # [B x 64 x 512]
        self.dec5_nl = nn.PReLU()
        self.dec4 = nn.ConvTranspose1d(128, 64, 32, 2, 15)  # [B x 64 x 1024]
        self.dec4_nl = nn.PReLU()
        self.dec3 = nn.ConvTranspose1d(128, 32, 32, 2, 15)  # [B x 32 x 2048]
        self.dec3_nl = nn.PReLU()
        self.dec2 = nn.ConvTranspose1d(64, 32, 32, 2, 15)  # [B x 32 x 4096]
        self.dec2_nl = nn.PReLU()
        self.dec1 = nn.ConvTranspose1d(64, 16, 32, 2, 15)  # [B x 16 x 8192]
        self.dec1_nl = nn.PReLU()
        self.dec_final = nn.ConvTranspose1d(32, 1, 32, 2, 15)  # [B x 1 x 16384]
        self.dec_tanh = nn.Tanh()

        # initialize weights
        self.init_weights()

    def init_weights(self):
        """
        Initialize weights for convolution layers using Xavier initialization.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
                nn.init.xavier_normal_(m.weight)

    def forward(self, x, z):
        """
        Forward pass of generator.

        Args:
            x: input batch (signal)
            z: latent vector
        """
        # encoding step
        e1 = self.enc1(x)
        e2 = self.enc2(self.enc1_nl(e1))
        e3 = self.enc3(self.enc2_nl(e2))
        e4 = self.enc4(self.enc3_nl(e3))
        e5 = self.enc5(self.enc4_nl(e4))
        e6 = self.enc6(self.enc5_nl(e5))
        e7 = self.enc7(self.enc6_nl(e6))
        e8 = self.enc8(self.enc7_nl(e7))
        e9 = self.enc9(self.enc8_nl(e8))
        e10 = self.enc10(self.enc9_nl(e9))
        e11 = self.enc11(self.enc10_nl(e10))
        # c = compressed feature, the 'thought vector'
        c = self.enc11_nl(e11)
        
        encoded = torch.cat((c, z), dim=1)

        # decoding step
        d10 = self.dec10(encoded)
        # dx_c : concatenated with skip-connected layer's output & passed nonlinear layer
        d10_c = self.dec10_nl(torch.cat((d10, e10), dim=1))
        d9 = self.dec9(d10_c)
        d9_c = self.dec9_nl(torch.cat((d9, e9), dim=1))
        d8 = self.dec8(d9_c)
        d8_c = self.dec8_nl(torch.cat((d8, e8), dim=1))
        d7 = self.dec7(d8_c)
        d7_c = self.dec7_nl(torch.cat((d7, e7), dim=1))
        d6 = self.dec6(d7_c)
        d6_c = self.dec6_nl(torch.cat((d6, e6), dim=1))
        d5 = self.dec5(d6_c)
        d5_c = self.dec5_nl(torch.cat((d5, e5), dim=1))
        d4 = self.dec4(d5_c)
        d4_c = self.dec4_nl(torch.cat((d4, e4), dim=1))
        d3 = self.dec3(d4_c)
        d3_c = self.dec3_nl(torch.cat((d3, e3), dim=1))
        d2 = self.dec2(d3_c)
        d2_c = self.dec2_nl(torch.cat((d2, e2), dim=1))
        d1 = self.dec1(d2_c)
        d1_c = self.dec1_nl(torch.cat((d1, e1), dim=1))
        out = self.dec_tanh(self.dec_final(d1_c))
        return out

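### Discriminator Class
The `Discriminator` class is the counterpart to the `Generator`.

Key features:
- Takes a two-channel input `[B x 2 x 16384]`: either a (clean, noisy) pair or an (enhanced, noisy) pair.
- Uses 11 strided convolutional layers with Leaky ReLU for feature extraction, with dropout after three of them.
- Applies Virtual Batch Normalization after every convolution, using statistics from a reference batch.
- Outputs a score in (0, 1) through a sigmoid; values near 1 mean the pair looks real.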
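The discriminator uses kernel size 31 instead of 32, which still halves an even length, so 11 layers reduce 16384 samples to 8 and match `nn.Linear(in_features=8, ...)`. The sketch below checks this and also shows, with NumPy's `squeeze` (which, like `torch.squeeze`, drops every size-1 axis unless an axis is given), how squeezing without an axis also removes the batch dimension when a batch holds a single example.

```python
import numpy as np


def d_conv_out(length, kernel_size=31, stride=2, padding=15):
    """Output length of nn.Conv1d with dilation 1 (discriminator settings)."""
    return (length + 2 * padding - kernel_size) // stride + 1


length = 2 ** 14
for _ in range(11):  # conv1 ... conv11
    length = d_conv_out(length)
print(length)  # 8

# conv_final output for a batch of one example: [B x 1 x 8] with B = 1
x = np.zeros((1, 1, 8))
print(np.squeeze(x).shape)          # (8,)   the batch axis is removed as well
print(np.squeeze(x, axis=1).shape)  # (1, 8) only the channel axis is removed
```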


In [None]:
class Discriminator(nn.Module):

    def __init__(self):
        super().__init__()
        # D gets a noisy signal and clear signal as input [B x 2 x 16384]
        negative_slope = 0.03
        self.conv1 = nn.Conv1d(in_channels=2, out_channels=32, kernel_size=31, stride=2, padding=15)  # [B x 32 x 8192]
        self.vbn1 = VirtualBatchNorm1d(32)
        self.lrelu1 = nn.LeakyReLU(negative_slope)
        self.conv2 = nn.Conv1d(32, 64, 31, 2, 15)  # [B x 64 x 4096]
        self.vbn2 = VirtualBatchNorm1d(64)
        self.lrelu2 = nn.LeakyReLU(negative_slope)
        self.conv3 = nn.Conv1d(64, 64, 31, 2, 15)  # [B x 64 x 2048]
        self.dropout1 = nn.Dropout()
        self.vbn3 = VirtualBatchNorm1d(64)
        self.lrelu3 = nn.LeakyReLU(negative_slope)
        self.conv4 = nn.Conv1d(64, 128, 31, 2, 15)  # [B x 128 x 1024]
        self.vbn4 = VirtualBatchNorm1d(128)
        self.lrelu4 = nn.LeakyReLU(negative_slope)
        self.conv5 = nn.Conv1d(128, 128, 31, 2, 15)  # [B x 128 x 512]
        self.vbn5 = VirtualBatchNorm1d(128)
        self.lrelu5 = nn.LeakyReLU(negative_slope)
        self.conv6 = nn.Conv1d(128, 256, 31, 2, 15)  # [B x 256 x 256]
        self.dropout2 = nn.Dropout()
        self.vbn6 = VirtualBatchNorm1d(256)
        self.lrelu6 = nn.LeakyReLU(negative_slope)
        self.conv7 = nn.Conv1d(256, 256, 31, 2, 15)  # [B x 256 x 128]
        self.vbn7 = VirtualBatchNorm1d(256)
        self.lrelu7 = nn.LeakyReLU(negative_slope)
        self.conv8 = nn.Conv1d(256, 512, 31, 2, 15)  # [B x 512 x 64]
        self.vbn8 = VirtualBatchNorm1d(512)
        self.lrelu8 = nn.LeakyReLU(negative_slope)
        self.conv9 = nn.Conv1d(512, 512, 31, 2, 15)  # [B x 512 x 32]
        self.dropout3 = nn.Dropout()
        self.vbn9 = VirtualBatchNorm1d(512)
        self.lrelu9 = nn.LeakyReLU(negative_slope)
        self.conv10 = nn.Conv1d(512, 1024, 31, 2, 15)  # [B x 1024 x 16]
        self.vbn10 = VirtualBatchNorm1d(1024)
        self.lrelu10 = nn.LeakyReLU(negative_slope)
        self.conv11 = nn.Conv1d(1024, 2048, 31, 2, 15)  # [B x 2048 x 8]
        self.vbn11 = VirtualBatchNorm1d(2048)
        self.lrelu11 = nn.LeakyReLU(negative_slope)
        # 1x1 size kernel for dimension and parameter reduction
        self.conv_final = nn.Conv1d(2048, 1, kernel_size=1, stride=1)  # [B x 1 x 8]
        self.lrelu_final = nn.LeakyReLU(negative_slope)
        self.fully_connected = nn.Linear(in_features=8, out_features=1)  # [B x 1]
        self.sigmoid = nn.Sigmoid()

        # initialize weights
        self.init_weights()

    def init_weights(self):
        """
        Initialize weights for convolution layers using Xavier initialization.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.xavier_normal_(m.weight)

    def forward(self, x, ref_x):
        """
        Forward pass of discriminator.

        Args:
            x: input batch (signal)
            ref_x: reference input batch for virtual batch norm
        """
        # reference pass
        ref_x = self.conv1(ref_x)
        ref_x, mean1, meansq1 = self.vbn1(ref_x, None, None)
        ref_x = self.lrelu1(ref_x)
        ref_x = self.conv2(ref_x)
        ref_x, mean2, meansq2 = self.vbn2(ref_x, None, None)
        ref_x = self.lrelu2(ref_x)
        ref_x = self.conv3(ref_x)
        ref_x = self.dropout1(ref_x)
        ref_x, mean3, meansq3 = self.vbn3(ref_x, None, None)
        ref_x = self.lrelu3(ref_x)
        ref_x = self.conv4(ref_x)
        ref_x, mean4, meansq4 = self.vbn4(ref_x, None, None)
        ref_x = self.lrelu4(ref_x)
        ref_x = self.conv5(ref_x)
        ref_x, mean5, meansq5 = self.vbn5(ref_x, None, None)
        ref_x = self.lrelu5(ref_x)
        ref_x = self.conv6(ref_x)
        ref_x = self.dropout2(ref_x)
        ref_x, mean6, meansq6 = self.vbn6(ref_x, None, None)
        ref_x = self.lrelu6(ref_x)
        ref_x = self.conv7(ref_x)
        ref_x, mean7, meansq7 = self.vbn7(ref_x, None, None)
        ref_x = self.lrelu7(ref_x)
        ref_x = self.conv8(ref_x)
        ref_x, mean8, meansq8 = self.vbn8(ref_x, None, None)
        ref_x = self.lrelu8(ref_x)
        ref_x = self.conv9(ref_x)
        ref_x = self.dropout3(ref_x)
        ref_x, mean9, meansq9 = self.vbn9(ref_x, None, None)
        ref_x = self.lrelu9(ref_x)
        ref_x = self.conv10(ref_x)
        ref_x, mean10, meansq10 = self.vbn10(ref_x, None, None)
        ref_x = self.lrelu10(ref_x)
        ref_x = self.conv11(ref_x)
        ref_x, mean11, meansq11 = self.vbn11(ref_x, None, None)
        # train pass
        x = self.conv1(x)
        x, _, _ = self.vbn1(x, mean1, meansq1)
        x = self.lrelu1(x)
        x = self.conv2(x)
        x, _, _ = self.vbn2(x, mean2, meansq2)
        x = self.lrelu2(x)
        x = self.conv3(x)
        x = self.dropout1(x)
        x, _, _ = self.vbn3(x, mean3, meansq3)
        x = self.lrelu3(x)
        x = self.conv4(x)
        x, _, _ = self.vbn4(x, mean4, meansq4)
        x = self.lrelu4(x)
        x = self.conv5(x)
        x, _, _ = self.vbn5(x, mean5, meansq5)
        x = self.lrelu5(x)
        x = self.conv6(x)
        x = self.dropout2(x)
        x, _, _ = self.vbn6(x, mean6, meansq6)
        x = self.lrelu6(x)
        x = self.conv7(x)
        x, _, _ = self.vbn7(x, mean7, meansq7)
        x = self.lrelu7(x)
        x = self.conv8(x)
        x, _, _ = self.vbn8(x, mean8, meansq8)
        x = self.lrelu8(x)
        x = self.conv9(x)
        x = self.dropout3(x)
        x, _, _ = self.vbn9(x, mean9, meansq9)
        x = self.lrelu9(x)
        x = self.conv10(x)
        x, _, _ = self.vbn10(x, mean10, meansq10)
        x = self.lrelu10(x)
        x = self.conv11(x)
        x, _, _ = self.vbn11(x, mean11, meansq11)
        x = self.lrelu11(x)
        x = self.conv_final(x)
        x = self.lrelu_final(x)
        # remove only the channel dim: [B x 1 x 8] -> [B x 8], even when B == 1
        x = torch.squeeze(x, dim=1)
        x = self.fully_connected(x)
        return self.sigmoid(x)

## 1 Model Initialization and Training Loop

The models (Discriminator and Generator) are initialized and moved to the GPU if available. The training loop then alternates between updating the discriminator and updating the generator on each batch.

- Initializes the discriminator and generator models.
- Moves the models to the GPU if available.
- Sets up optimizers for both the generator and discriminator using RMSprop.
- Iterates over the specified number of epochs.
- Within each epoch, it:
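  - Trains the discriminator to output 1 for real (clean, noisy) pairs and 0 for (enhanced, noisy) pairs.
  - Trains the generator so that the discriminator scores its (enhanced, noisy) pairs as real, while an L1 term weighted by 100 keeps the enhanced audio close to the clean target.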
  - Calculates loss for both the discriminator and generator, and updates their parameters accordingly.
  - Displays training progress using `tqdm`.
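The losses in the training cell are least-squares (LSGAN-style) objectives plus an L1 term weighted by 100. The NumPy sketch below restates them on made-up discriminator scores; `lsgan_d_losses` and `lsgan_g_loss` are illustrative names, not functions defined in this notebook.

```python
import numpy as np


def lsgan_d_losses(d_real, d_fake):
    """Discriminator: push scores of (clean, noisy) to 1 and of (enhanced, noisy) to 0."""
    clean_loss = np.mean((d_real - 1.0) ** 2)
    noisy_loss = np.mean(d_fake ** 2)
    return clean_loss, noisy_loss


def lsgan_g_loss(d_fake, enhanced, clean, l1_weight=100.0):
    """Generator: least-squares adversarial term plus weighted L1 distance to the clean target."""
    adversarial = 0.5 * np.mean((d_fake - 1.0) ** 2)
    conditional = l1_weight * np.mean(np.abs(enhanced - clean))
    return adversarial + conditional, conditional


# Made-up discriminator scores and waveforms, for illustration only
d_real = np.array([0.9, 0.8])
d_fake = np.array([0.2, 0.1])
clean_loss, noisy_loss = lsgan_d_losses(d_real, d_fake)  # both about 0.025
total, conditional = lsgan_g_loss(d_fake, np.array([0.5, -0.5]), np.array([0.4, -0.5]))
print(clean_loss, noisy_loss, total, conditional)  # total about 5.3625, conditional about 5.0
```

With a weight of 100, the L1 term dominates the generator loss early in training, which keeps the output anchored to the clean target while the adversarial term refines it.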


In [None]:
# Hyperparameters (edit these values directly in the notebook)
BATCH_SIZE = 50  # Adjust as needed
NUM_EPOCHS = 7  # Adjust as needed

# Load data
print('Loading data...')
train_dataset = AudioDataset(data_type='train')
test_dataset = AudioDataset(data_type='test')
train_data_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
test_data_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

# Generate reference batch
ref_batch = train_dataset.reference_batch(BATCH_SIZE)

In [None]:

# Initialize models
discriminator = Discriminator()
generator = Generator()

# Move models to GPU if available
if torch.cuda.is_available():
    discriminator.cuda()
    generator.cuda()
    ref_batch = ref_batch.cuda()

ref_batch = Variable(ref_batch)

print("# Generator parameters:", sum(param.numel() for param in generator.parameters()))
print("# Discriminator parameters:", sum(param.numel() for param in discriminator.parameters()))

# Optimizers
g_optimizer = optim.RMSprop(generator.parameters(), lr=0.0001)
d_optimizer = optim.RMSprop(discriminator.parameters(), lr=0.0001)

# Training loop
for epoch in range(NUM_EPOCHS):
    train_bar = tqdm(train_data_loader)
    for train_batch, train_clean, train_noisy in train_bar:
        # Latent vector z ~ N(0, 1), shaped like the encoder output [B x 1024 x 8]
        z = torch.randn(train_batch.size(0), 1024, 8)
        if torch.cuda.is_available():
            train_batch, train_clean, train_noisy = train_batch.cuda(), train_clean.cuda(), train_noisy.cuda()
            z = z.cuda()

        # Train Discriminator to score (clean, noisy) pairs as real
        discriminator.zero_grad()
        outputs = discriminator(train_batch, ref_batch)
        clean_loss = torch.mean((outputs - 1.0) ** 2)  # least-squares loss: push outputs towards 1
        clean_loss.backward()

        # Train Discriminator to score (enhanced, noisy) pairs as fake.
        # detach() keeps this loss from backpropagating into the generator.
        generated_outputs = generator(train_noisy, z)
        outputs = discriminator(torch.cat((generated_outputs.detach(), train_noisy), dim=1), ref_batch)
        noisy_loss = torch.mean(outputs ** 2)  # least-squares loss: push outputs towards 0
        noisy_loss.backward()

        # Update Discriminator
        d_optimizer.step()

        # Train Generator so that Discriminator scores the enhanced pair as real
        generator.zero_grad()
        generated_outputs = generator(train_noisy, z)
        gen_noise_pair = torch.cat((generated_outputs, train_noisy), dim=1)
        outputs = discriminator(gen_noise_pair, ref_batch)

        g_loss_ = 0.5 * torch.mean((outputs - 1.0) ** 2)
        # L1 loss between generated output and clean sample
        l1_dist = torch.abs(generated_outputs - train_clean)
        g_cond_loss = 100 * torch.mean(l1_dist)  # Conditional loss
        g_loss = g_loss_ + g_cond_loss

        # Backprop + Optimize
        g_loss.backward()
        g_optimizer.step()

        # Print training status (.item() converts a 0-dim tensor to a Python float)
        train_bar.set_description(
            'Epoch {}: d_clean_loss {:.4f}, d_noisy_loss {:.4f}, g_loss {:.4f}, g_conditional_loss {:.4f}'
            .format(epoch + 1, clean_loss.item(), noisy_loss.item(), g_loss.item(), g_cond_loss.item()))


## 2 Model Testing and Saving

After training finishes, this cell runs the generator on the noisy inputs of the test set, saves the enhanced audio, and saves the model parameters (generator and discriminator), tagged with the number of the last epoch.

- For each test batch, it generates enhanced audio using the trained generator.
- Applies de-emphasis (`emphasis(..., pre=False)`) to undo the pre-emphasis applied to the inputs.
- Saves the generated audio samples to the `results` directory.
- Saves the generator and discriminator state dicts to the `epochs` directory.
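`scipy.io.wavfile.write` stores float arrays as floating-point WAV files, which not every audio tool supports. A common alternative, sketched below as an assumption rather than what this notebook does, is to clip to [-1, 1] and convert to 16-bit PCM. Clipping matters here because de-emphasis amplifies low frequencies, so samples can leave the [-1, 1] range of the generator's `tanh` output. `to_pcm16` is an illustrative helper name.

```python
import numpy as np


def to_pcm16(signal):
    """Clip a float waveform to [-1, 1] and scale it to 16-bit PCM integers."""
    clipped = np.clip(signal, -1.0, 1.0)
    return (clipped * 32767).astype(np.int16)  # astype truncates toward zero


enhanced = np.array([0.0, 0.5, -0.5, 1.3, -1.3])  # last two samples are out of range
pcm = to_pcm16(enhanced)
print(pcm.tolist())  # [0, 16383, -16383, 32767, -32767]
```

The resulting `int16` array can be passed to `wavfile.write` in place of the float samples.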



In [None]:

# Test model
os.makedirs('results', exist_ok=True)
os.makedirs('epochs', exist_ok=True)
sample_rate = 16000
test_bar = tqdm(test_data_loader, desc='Test model and save generated audios')
with torch.no_grad():
    for test_file_names, test_noisy in test_bar:
        z = torch.randn(test_noisy.size(0), 1024, 8)  # latent vector z ~ N(0, 1)
        if torch.cuda.is_available():
            test_noisy, z = test_noisy.cuda(), z.cuda()
        fake_speech = generator(test_noisy, z).cpu().numpy()  # [B x 1 x 16384]
        # Undo the pre-emphasis applied by AudioDataset
        fake_speech = emphasis(fake_speech, emph_coeff=0.95, pre=False)

        # Save generated audio samples as float32 WAV files, shape [samples x 1]
        for idx in range(fake_speech.shape[0]):
            generated_sample = fake_speech[idx].astype(np.float32)
            file_name = os.path.join('results', '{}_e{}.wav'.format(test_file_names[idx].replace('.npy', ''), epoch + 1))
            wavfile.write(file_name, sample_rate, generated_sample.T)

# Save model parameters, tagged with the last epoch number
g_path = os.path.join('epochs', 'generator-{}.pkl'.format(epoch + 1))
d_path = os.path.join('epochs', 'discriminator-{}.pkl'.format(epoch + 1))
torch.save(generator.state_dict(), g_path)
torch.save(discriminator.state_dict(), d_path)


### Audio Enhancement Using Pre-trained Generator Model 

This code enhances a noisy audio file by applying a pre-trained generator model. The process involves several steps outlined below:

1. **Configuration**:
   - Two variables at the top of the next cell control the run:
     - `FILE_NAME`: The path to the noisy audio file that needs enhancement.
     - `EPOCH_NAME`: The path to the saved generator checkpoint (a `state_dict`) to use for enhancement.

2. **Loading the Pre-trained Model**:
   - The `Generator` model is instantiated and loaded with the weights from that checkpoint.
   - If a GPU is available, the model is moved to the GPU for efficient computation.

3. **Audio Preprocessing**:
   - The noisy audio file is sliced into consecutive, non-overlapping windows of 16384 samples using `segment_audio_file` with a stride of 1. Trailing samples that do not fill a whole window are dropped.

4. **Audio Enhancement**:
   - Each slice of the noisy audio is passed through the generator model for enhancement.
   - A random latent vector `z` is generated and used as input alongside the noisy audio slice.
   - Each slice is pre-emphasized before it enters the generator (as during training), and the generator's output is de-emphasized to undo it.

5. **Saving the Enhanced Audio**:
   - After processing all slices, the enhanced audio is recombined into a continuous waveform.
   - The enhanced audio is saved as a `.wav` file with a filename that includes the prefix `enhanced_` followed by the original filename.
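With a stride of 1 the slices do not overlap, so concatenating them reconstructs the signal up to the last full window. The NumPy sketch below mimics that slicing on a ramp signal; `split_windows` is an illustrative stand-in for `segment_audio_file`, and `split_windows_padded` is a hypothetical variant that zero-pads the tail so no samples are lost.

```python
import numpy as np


def split_windows(audio, window_size):
    """Non-overlapping full windows (stride 1.0); trailing samples are dropped."""
    n_windows = len(audio) // window_size
    return [audio[i * window_size:(i + 1) * window_size] for i in range(n_windows)]


def split_windows_padded(audio, window_size):
    """Hypothetical variant: zero-pad the tail to a whole number of windows."""
    pad = (-len(audio)) % window_size
    return split_windows(np.pad(audio, (0, pad)), window_size)


audio = np.arange(40000, dtype=np.float32)  # 2.5 s at 16 kHz
window = 2 ** 14
stitched = np.concatenate(split_windows(audio, window))
print(len(stitched))  # 32768: the last 7232 samples are lost
padded = np.concatenate(split_windows_padded(audio, window))
print(len(padded))    # 49152
```

With padding, the enhanced output would be trimmed back to the input length (for example `enhanced[:len(audio)]`) before saving.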


In [None]:
# Define the input file and the generator checkpoint to use
FILE_NAME = "/home/antec/Desktop/Najjar/SEGAN-master/data/noisy_wav/p232_002.wav"  # Replace with your actual file path
# Path to a generator state_dict, e.g. 'epochs/generator-7.pkl' as saved by the test cell above
EPOCH_NAME = "/home/antec/Desktop/Najjar/SEGAN-master/checkpoints/checkpoint_epoch7_20250109_180511.pt"

# Load the pre-trained generator model
generator = Generator()
generator.load_state_dict(torch.load(EPOCH_NAME, map_location='cpu'))
generator.eval()
if torch.cuda.is_available():
    generator.cuda()


In [None]:
# Slice the noisy audio into consecutive, non-overlapping windows.
# stride=1.0 makes the hop equal to the window size, so the enhanced slices can
# simply be concatenated; trailing samples that do not fill a window are dropped.
WINDOW_SIZE = 2 ** 14  # Approx. 1 second
SAMPLE_RATE = 16000
noisy_slices = segment_audio_file(FILE_NAME, window_size=WINDOW_SIZE, stride=1.0, sample_rate=SAMPLE_RATE)
enhanced_speech = []

# Process each slice through the generator model
with torch.no_grad():
    for noisy_slice in tqdm(noisy_slices, desc='Generate enhanced audio'):
        z = torch.randn(1, 1024, 8)  # random latent vector
        # Pre-emphasize and reshape to [1 x 1 x WINDOW_SIZE]
        noisy_slice = torch.from_numpy(
            emphasis(noisy_slice[np.newaxis, np.newaxis, :], emph_coeff=0.95)).type(torch.FloatTensor)

        if torch.cuda.is_available():
            noisy_slice, z = noisy_slice.cuda(), z.cuda()

        generated_speech = generator(noisy_slice, z).cpu().numpy()
        # Undo the pre-emphasis on the generator output
        generated_speech = emphasis(generated_speech, emph_coeff=0.95, pre=False)
        enhanced_speech.append(generated_speech.reshape(-1))


In [None]:
# Save the enhanced audio as a mono float32 .wav file
enhanced_speech = np.concatenate(enhanced_speech).astype(np.float32)
file_name = os.path.join(os.path.dirname(FILE_NAME),
                         f'enhanced_{os.path.splitext(os.path.basename(FILE_NAME))[0]}.wav')
wavfile.write(file_name, SAMPLE_RATE, enhanced_speech)
