In [68]:
from torchaudio.backend import sox_io_backend
import shutil
import os
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from torchaudio.sox_effects import apply_effects_tensor
import librosa
import numpy as np
from tqdm import tqdm

In [69]:
def compute_snr(waveform, n_fft=1024, hop_length=512, noise_frames=5):
    """
    Computes an approximate SNR (in dB) for the given waveform.
    
    Args:
        waveform (Tensor): Audio tensor of shape [channels, samples]. Should be mono.
        n_fft (int): FFT window size.
        hop_length (int): Hop length for STFT.
        noise_frames (int): Number of initial frames to use as noise estimate.
    
    Returns:
        float: Estimated SNR in decibels.
    """
    # Convert tensor to numpy array and squeeze extra dimensions to get a 1D array
    y = waveform.numpy().squeeze()
    
    # Compute the STFT of the signal
    D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length)
    magnitude = np.abs(D)
    
    # Estimate noise power from the first few frames
    noise_power = np.mean(magnitude[:, :noise_frames] ** 2)
    
    # Compute overall signal power (mean power across all frames)
    signal_power = np.mean(magnitude ** 2)
    
    # Avoid division by zero
    epsilon = 1e-8
    snr = 10 * np.log10(signal_power / (noise_power + epsilon))
    return snr


In [70]:
class AudioDataset(Dataset):
    def __init__(self, audio_dir, transform=None, processed_dir=None,  snr_threshold=5):
        """
        Args:
            audio_dir (str): Directory with raw audio files.
            transform (callable, optional): Function to apply preprocessing to the waveform.
            processed_dir (str, optional): Directory to save processed audio files.
        """
        self.audio_dir = audio_dir
        self.audio_files = [
            os.path.join(audio_dir, f)
            for f in os.listdir(audio_dir)
            if f.lower().endswith(('.wav', '.mp3'))
        ]
        self.transform = transform
        self.processed_dir = processed_dir
        self.snr_threshold = snr_threshold
        
        if self.processed_dir:
            os.makedirs(self.processed_dir, exist_ok=True)

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        file_path = self.audio_files[idx]
        try:
            waveform, sample_rate = torchaudio.load(file_path, backend="sox")
        except Exception as e:
            print(f"Error loading {file_path}: {e}")
            broken_dir = os.path.join(os.path.dirname(self.audio_dir), "audio_broken")
            os.makedirs(broken_dir, exist_ok=True)
            filename = os.path.basename(file_path)
            shutil.move(file_path, os.path.join(broken_dir, filename))
            self.audio_files.remove(file_path)
            return None, None, file_path
        
        snr = compute_snr(waveform)
        if snr < self.snr_threshold:
            print(f"File {file_path} has very low SNR: {snr:.2f} dB. Skipping.")
            # Optionally move to a "too_noisy" folder
            noisy_dir = os.path.join(os.path.dirname(self.audio_dir), "audio_noisy")
            os.makedirs(noisy_dir, exist_ok=True)
            shutil.move(file_path, os.path.join(noisy_dir, os.path.basename(file_path)))
            self.audio_files.remove(file_path)
            return None, None, file_path
        
        if self.transform:
            waveform = self.transform(waveform, sample_rate)
            
            print(f"Processed file: {file_path}")
            print(f"Sample Rate: {sample_rate}")
            print(f"Waveform shape after transform: {waveform.shape}")
            waveform_2d = waveform.squeeze()
            
            if waveform_2d.dim() == 1:
                waveform_2d = waveform_2d.unsqueeze(0)
            
            print(f"Reshaped waveform: {waveform_2d.shape}")
            
            if self.processed_dir:
                output_path = os.path.join(self.processed_dir, os.path.basename(file_path))
                torchaudio.save(output_path, waveform_2d, sample_rate)
                print(f"Saved to: {output_path}")
                
            waveform = waveform_2d
        
        return waveform, sample_rate, file_path

In [71]:
def spectral_subtraction_transform(waveform, sample_rate, n_fft=1024, hop_length=512, noise_frames=5):
    """
    Applies spectral subtraction to reduce stationary noise.
    Assumes that the first few frames (noise_frames) contain only noise.
    """
    # Convert tensor to numpy array and squeeze extra dimensions
    y = waveform.numpy().squeeze()
    # Compute STFT
    D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length)
    magnitude, phase = np.abs(D), np.angle(D)
    # Estimate noise magnitude from the first few frames
    noise_mag = np.mean(magnitude[:, :noise_frames], axis=1, keepdims=True)
    # Subtract noise estimate and clip negative values
    subtracted = magnitude - noise_mag
    subtracted[subtracted < 0] = 0
    # Reconstruct the complex spectrum and invert the STFT
    D_clean = subtracted * np.exp(1j * phase)
    y_clean = librosa.istft(D_clean, hop_length=hop_length)
    # Convert back to a tensor with a channel dimension
    return torch.tensor(y_clean).unsqueeze(0)

def wiener_filter_transform(waveform, sample_rate, n_fft=1024, hop_length=512, noise_frames=5, beta=0.002):
    """
    Applies a basic Wiener filter to reduce noise.
    The gain is computed for each frequency bin based on an estimated SNR.
    """
    y = waveform.numpy().squeeze()
    # Compute STFT
    D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length)
    magnitude, phase = np.abs(D), np.angle(D)
    # Estimate noise from the first few frames
    noise_mag = np.mean(magnitude[:, :noise_frames], axis=1, keepdims=True)
    power_spec = magnitude ** 2
    noise_power = noise_mag ** 2
    eps = 1e-8
    # Compute Wiener gain factor and clip it to a minimum value beta
    gain = np.maximum((power_spec - noise_power) / (power_spec + eps), beta)
    # Apply gain (take the square root because we're modifying magnitudes)
    filtered_D = np.sqrt(gain) * D
    y_clean = librosa.istft(filtered_D, hop_length=hop_length)
    return torch.tensor(y_clean).unsqueeze(0)

In [72]:
def preprocess_transform(waveform, sample_rate, target_sample_rate=16000,
                         silence_threshold=1.5, min_silence_duration=0.3,
                         noise_reduction_method=None):
    """
    Applies comprehensive preprocessing for ASR tasks, including optional noise reduction.
    
    Args:
        waveform (Tensor): Audio tensor.
        sample_rate (int): Sample rate of the audio.
        target_sample_rate (int): Target sample rate for resampling.
        silence_threshold (float): Threshold percentage for silence detection.
        min_silence_duration (float): Minimum duration (in seconds) of silence to trim.
        noise_reduction_method (str, optional): Choose 'spectral' or 'wiener' to apply that noise reduction method.
    
    Returns:
        Tensor: Preprocessed waveform.
    """
    # (1) Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    
    # (2) Volume normalization using SoX gain effect
    effects = [["gain", "-n", "-3"]]
    try:
        waveform, sample_rate = apply_effects_tensor(waveform, sample_rate, effects)
    except Exception as e:
        print(f"Volume normalization failed, skipping: {e}")
    
    # (3) Trim only leading and trailing silences using librosa.effects.trim
    try:
        # Convert waveform to a 1D numpy array
        audio_np = waveform.squeeze().numpy()
        
        # Use librosa.effects.trim to remove only the leading and trailing silence.
        # The 'top_db' parameter controls what is considered silence.
        # Adjust top_db (e.g., 30) to be less aggressive if needed.
        trimmed_audio, _ = librosa.effects.trim(audio_np, top_db=30)
        
        # Convert back to a PyTorch tensor and add the channel dimension back
        waveform = torch.tensor(trimmed_audio).unsqueeze(0)
        
        print(f"Trimmed waveform shape: {waveform.shape}")
    except Exception as e:
        print(f"Librosa silence trim failed, skipping: {e}")

    # (4) High-pass filter to cut low-frequency noise
    try:
        effects = [["highpass", "100"]]
        waveform, sample_rate = apply_effects_tensor(waveform, sample_rate, effects)
    except Exception as e:
        print(f"High-pass filter failed, skipping: {e}")
    
    # (5) Optional custom noise reduction using our functions
    if noise_reduction_method == 'spectral':
        waveform = spectral_subtraction_transform(waveform, sample_rate)
    elif noise_reduction_method == 'wiener':
        waveform = wiener_filter_transform(waveform, sample_rate)
    
    return waveform

In [73]:

def custom_collate_fn(batch):
    batch = [item for item in batch if item[0] is not None]
    return torch.utils.data.dataloader.default_collate(batch)

In [74]:

if __name__ == "__main__":
    # Define directories for raw and processed audio
    AUDIO_DIR = "./audio_raw"
    PROCESSED_DIR = "./audio_processed"
    
    # Example usage: apply spectral subtraction noise reduction.
    # To switch to Wiener filtering, set noise_reduction_method='wiener'
    dataset = AudioDataset(
        audio_dir=AUDIO_DIR, 
        transform=lambda w, sr: preprocess_transform(w, sr, noise_reduction_method='spectral'),
        processed_dir=PROCESSED_DIR,
        snr_threshold= 0.2
    )
    
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=custom_collate_fn)
    
    count = 0
    for i, (waveform, sample_rate, file_path) in enumerate(tqdm(dataloader, desc="Processing audio files")):
        count += 1
        if waveform is None:
            print(f"Skipping broken file: {file_path}")
            continue
            
        print(f"DataLoader waveform shape: {waveform.shape}")
        if count > 2:
            break

Processing audio files:   0%|          | 1/4285 [00:11<13:29:15, 11.33s/it]

Trimmed waveform shape: torch.Size([1, 208896])
Processed file: ./audio_raw/f5a009c0-59fe-40af-bea2-b9d70b817e7d_16e2c90c-529c-4829-84af-925a2abd909d.mp3
Sample Rate: 24000
Waveform shape after transform: torch.Size([1, 208896])
Reshaped waveform: torch.Size([1, 208896])
Saved to: ./audio_processed/f5a009c0-59fe-40af-bea2-b9d70b817e7d_16e2c90c-529c-4829-84af-925a2abd909d.mp3
DataLoader waveform shape: torch.Size([1, 1, 208896])


Processing audio files:   0%|          | 2/4285 [00:11<5:42:25,  4.80s/it] 

Trimmed waveform shape: torch.Size([1, 208384])
Processed file: ./audio_raw/f41b73ab-fade-40cc-85e7-cc37e6921f0a_a4f93573-e6f3-4bd5-a6ee-f23a48feb9b3.mp3
Sample Rate: 24000
Waveform shape after transform: torch.Size([1, 208384])
Reshaped waveform: torch.Size([1, 208384])
Saved to: ./audio_processed/f41b73ab-fade-40cc-85e7-cc37e6921f0a_a4f93573-e6f3-4bd5-a6ee-f23a48feb9b3.mp3
DataLoader waveform shape: torch.Size([1, 1, 208384])


Processing audio files:   0%|          | 2/4285 [00:11<7:01:21,  5.90s/it]

Trimmed waveform shape: torch.Size([1, 159232])
Processed file: ./audio_raw/ad67da86-2eb9-4892-a28c-b5a08e767c5f_3645f126-2167-4bb4-9e14-aacb88bbd67c.mp3
Sample Rate: 24000
Waveform shape after transform: torch.Size([1, 159232])
Reshaped waveform: torch.Size([1, 159232])
Saved to: ./audio_processed/ad67da86-2eb9-4892-a28c-b5a08e767c5f_3645f126-2167-4bb4-9e14-aacb88bbd67c.mp3
DataLoader waveform shape: torch.Size([1, 1, 159232])



