In [6]:
from torchaudio.backend import sox_io_backend
import shutil
import os
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from torchaudio.sox_effects import apply_effects_tensor
import librosa
import numpy as np
from tqdm import tqdm


In [7]:
class AudioDataset(Dataset):
    def __init__(self, audio_dir, transform=None, processed_dir=None):
        """
        Args:
            audio_dir (str): Directory with raw audio files.
            transform (callable, optional): Function to apply preprocessing to the waveform.
            processed_dir (str, optional): Directory to save processed audio files.
        """
        self.audio_dir = audio_dir
        self.audio_files = [
            os.path.join(audio_dir, f)
            for f in os.listdir(audio_dir)
            if f.lower().endswith(('.wav', '.mp3'))
        ]
        self.transform = transform
        self.processed_dir = processed_dir
        
        # Create processed directory if specified
        if self.processed_dir:
            os.makedirs(self.processed_dir, exist_ok=True)

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        file_path = self.audio_files[idx]
        try:
            # Load the audio using sox_io backend
            waveform, sample_rate = torchaudio.load(file_path, backend="sox")
        except Exception as e:
            print(f"Error loading {file_path}: {e}")
            # Create audio_broken directory if it doesn't exist
            broken_dir = os.path.join(os.path.dirname(self.audio_dir), "audio_broken")
            os.makedirs(broken_dir, exist_ok=True)
            
            # Move the broken file to audio_broken folder
            filename = os.path.basename(file_path)
            shutil.move(file_path, os.path.join(broken_dir, filename))
            
            # Remove this file from our list to avoid future attempts
            self.audio_files.remove(file_path)
            
            # Return None to indicate this file should be skipped
            return None, None, file_path
        
        # Apply the transform if provided
        if self.transform:
            waveform = self.transform(waveform, sample_rate)
            
            # After transform, verify and fix channel dimensions
            print(f"Processed file: {file_path}")
            print(f"Sample Rate: {sample_rate}")
            print(f"Waveform shape after transform: {waveform.shape}")
            
            # Reshape waveform to 2D: [channels, samples]
            # First squeeze to remove all extra dimensions
            waveform_2d = waveform.squeeze()
            
            # If the tensor becomes 1D after squeezing, add a channel dimension
            if waveform_2d.dim() == 1:
                waveform_2d = waveform_2d.unsqueeze(0)  # Add channel dimension [1, samples]
            
            print(f"Reshaped waveform: {waveform_2d.shape}")
            
            # Save the processed audio if output directory is specified
            if self.processed_dir:
                output_path = os.path.join(self.processed_dir, os.path.basename(file_path))
                torchaudio.save(output_path, waveform_2d, sample_rate)
                print(f"Saved to: {output_path}")
                
            # Update waveform to the properly shaped version
            waveform = waveform_2d
        
        return waveform, sample_rate, file_path

In [8]:
def spectral_subtraction_transform(waveform, sample_rate, n_fft=1024, hop_length=512, noise_frames=5):
    """
    Applies spectral subtraction to reduce stationary noise.
    Assumes that the first few frames (noise_frames) contain only noise.
    """
    # Convert tensor to numpy array and squeeze extra dimensions
    y = waveform.numpy().squeeze()
    # Compute STFT
    D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length)
    magnitude, phase = np.abs(D), np.angle(D)
    # Estimate noise magnitude from the first few frames
    noise_mag = np.mean(magnitude[:, :noise_frames], axis=1, keepdims=True)
    # Subtract noise estimate and clip negative values
    subtracted = magnitude - noise_mag
    subtracted[subtracted < 0] = 0
    # Reconstruct the complex spectrum and invert the STFT
    D_clean = subtracted * np.exp(1j * phase)
    y_clean = librosa.istft(D_clean, hop_length=hop_length)
    # Convert back to a tensor with a channel dimension
    return torch.tensor(y_clean).unsqueeze(0)

def wiener_filter_transform(waveform, sample_rate, n_fft=1024, hop_length=512, noise_frames=5, beta=0.002):
    """
    Applies a basic Wiener filter to reduce noise.
    The gain is computed for each frequency bin based on an estimated SNR.
    """
    y = waveform.numpy().squeeze()
    # Compute STFT
    D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length)
    magnitude, phase = np.abs(D), np.angle(D)
    # Estimate noise from the first few frames
    noise_mag = np.mean(magnitude[:, :noise_frames], axis=1, keepdims=True)
    power_spec = magnitude ** 2
    noise_power = noise_mag ** 2
    eps = 1e-8
    # Compute Wiener gain factor and clip it to a minimum value beta
    gain = np.maximum((power_spec - noise_power) / (power_spec + eps), beta)
    # Apply gain (take the square root because we're modifying magnitudes)
    filtered_D = np.sqrt(gain) * D
    y_clean = librosa.istft(filtered_D, hop_length=hop_length)
    return torch.tensor(y_clean).unsqueeze(0)

In [9]:
def preprocess_transform(waveform, sample_rate, target_sample_rate=16000,
                         silence_threshold=0.5, min_silence_duration=0.8,
                         noise_reduction_method=None):
    """
    Applies comprehensive preprocessing for ASR tasks, including optional noise reduction.
    
    Args:
        waveform (Tensor): Audio tensor.
        sample_rate (int): Sample rate of the audio.
        target_sample_rate (int): Target sample rate for resampling.
        silence_threshold (float): Threshold percentage for silence detection.
        min_silence_duration (float): Minimum duration (in seconds) of silence to trim.
        noise_reduction_method (str, optional): Choose 'spectral' or 'wiener' to apply that noise reduction method.
    
    Returns:
        Tensor: Preprocessed waveform.
    """
    # (1) Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    
    # (2) Volume normalization using SoX gain effect
    effects = [["gain", "-n", "-3"]]
    try:
        waveform, sample_rate = apply_effects_tensor(waveform, sample_rate, effects)
    except Exception as e:
        print(f"Volume normalization failed, skipping: {e}")
    
    # (3) Optional noise reduction via SoX (if needed)
    # try:
    #     effects = [["noisered", "/dev/null", "0.2"]]
    #     waveform, sample_rate = apply_effects_tensor(waveform, sample_rate, effects)
    # except Exception as e:
    #     print(f"Noise reduction via SoX failed, skipping: {e}")
    
    # (4) Remove silence from beginning and end
    try:
        effects = [["silence", "1", "0.1", f"{silence_threshold}%", "1", str(min_silence_duration), f"{silence_threshold}%"]]
        waveform, sample_rate = apply_effects_tensor(waveform, sample_rate, effects)
    except Exception as e:
        print(f"Silence removal failed, skipping: {e}")
    
    # (5) High-pass filter to cut low-frequency noise
    try:
        effects = [["highpass", "100"]]
        waveform, sample_rate = apply_effects_tensor(waveform, sample_rate, effects)
    except Exception as e:
        print(f"High-pass filter failed, skipping: {e}")
    
    # (6) Optional custom noise reduction using our functions
    if noise_reduction_method == 'spectral':
        waveform = spectral_subtraction_transform(waveform, sample_rate)
    elif noise_reduction_method == 'wiener':
        waveform = wiener_filter_transform(waveform, sample_rate)
    
    return waveform

def custom_collate_fn(batch):
    batch = [item for item in batch if item[0] is not None]
    return torch.utils.data.dataloader.default_collate(batch)

In [10]:

if __name__ == "__main__":
    # Define directories for raw and processed audio
    AUDIO_DIR = "./audio_raw"
    PROCESSED_DIR = "./audio_processed"
    
    # Example usage: apply spectral subtraction noise reduction.
    # To switch to Wiener filtering, set noise_reduction_method='wiener'
    dataset = AudioDataset(
        audio_dir=AUDIO_DIR, 
        transform=lambda w, sr: preprocess_transform(w, sr, noise_reduction_method='spectral'),
        processed_dir=PROCESSED_DIR
    )
    
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=custom_collate_fn)
    
    count = 0
    for i, (waveform, sample_rate, file_path) in enumerate(tqdm(dataloader, desc="Processing audio files")):
        count += 1
        if waveform is None:
            print(f"Skipping broken file: {file_path}")
            continue
            
        print(f"DataLoader waveform shape: {waveform.shape}")
        if count > 2:
            break

Processing audio files:   0%|          | 0/4282 [00:00<?, ?it/s]

Noise reduction via SoX failed, skipping: Unsupported effect: noisered


Processing audio files:   0%|          | 1/4282 [00:03<3:41:48,  3.11s/it]

Processed file: ./audio_raw/7f7c0f0e-97d1-47b7-98f6-796c0ea190af_54c7ab3d-9c89-49c5-83e1-92adc44d4157.mp3
Sample Rate: 24000
Waveform shape after transform: torch.Size([1, 69120])
Reshaped waveform: torch.Size([1, 69120])
Saved to: ./audio_processed/7f7c0f0e-97d1-47b7-98f6-796c0ea190af_54c7ab3d-9c89-49c5-83e1-92adc44d4157.mp3
DataLoader waveform shape: torch.Size([1, 1, 69120])
Noise reduction via SoX failed, skipping: Unsupported effect: noisered
Processed file: ./audio_raw/3acaa58c-1c25-4693-8b88-8634b085e21f_9ad4646b-c33f-481a-9bba-5ff51a93e1ee.mp3
Sample Rate: 24000
Waveform shape after transform: torch.Size([1, 62464])
Reshaped waveform: torch.Size([1, 62464])
Saved to: ./audio_processed/3acaa58c-1c25-4693-8b88-8634b085e21f_9ad4646b-c33f-481a-9bba-5ff51a93e1ee.mp3
DataLoader waveform shape: torch.Size([1, 1, 62464])
Noise reduction via SoX failed, skipping: Unsupported effect: noisered
Processed file: ./audio_raw/4c2906fa-e77a-486e-bab2-95903b04a648_bf468261-2e0d-4079-ba50-b60d8a6

Processing audio files:   0%|          | 2/4282 [00:03<1:57:25,  1.65s/it]

Saved to: ./audio_processed/4c2906fa-e77a-486e-bab2-95903b04a648_bf468261-2e0d-4079-ba50-b60d8a69a7bf.mp3
DataLoader waveform shape: torch.Size([1, 1, 79872])



