In [6]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.io.wavfile as wavfile
from scipy import signal
import IPython.display as ipd

def wiener_filter(noisy_signal, sample_rate, noise_estimate=None, frame_length=0.025, frame_step=0.010, alpha=2.0):
    """
    Apply Wiener filter for speech denoising
    
    Parameters:
    -----------
    noisy_signal : numpy array
        Noisy speech signal
    sample_rate : int
        Sampling rate of the signal
    noise_estimate : numpy array, optional
        Estimated noise signal. If None, first 0.5 seconds are assumed to be noise
    frame_length : float
        Length of each frame in seconds
    frame_step : float
        Step between frames in seconds
    alpha : float
        Over-subtraction factor to control noise reduction amount
        
    Returns:
    --------
    enhanced_signal : numpy array
        Denoised speech signal
    """
    # Convert frame length and step from seconds to samples
    frame_len = int(frame_length * sample_rate)
    frame_step_samples = int(frame_step * sample_rate)
    
    # If noise estimate not provided, use first 0.5s of signal
    if noise_estimate is None:
        noise_length = int(0.5 * sample_rate)
        noise_estimate = noisy_signal[:noise_length]
    
    # Calculate noise power spectrum
    noise_psd = estimate_spectrum(noise_estimate, frame_len)
    
    # Prepare for processing in frames
    num_frames = 1 + int((len(noisy_signal) - frame_len) / frame_step_samples)
    enhanced_signal = np.zeros_like(noisy_signal)
    window = np.hamming(frame_len)
    
    # Process each frame
    for i in range(num_frames):
        # Extract frame
        start = i * frame_step_samples
        end = start + frame_len
        if end > len(noisy_signal):
            break
            
        frame = noisy_signal[start:end]
        
        # Apply window
        windowed_frame = frame * window
        
        # FFT
        frame_fft = np.fft.rfft(windowed_frame)
        frame_power = np.abs(frame_fft) ** 2
        
        # Wiener filter formula: H(f) = P_s(f) / (P_s(f) + P_n(f))
        # Where P_s(f) is estimated as P_y(f) - P_n(f)
        snr = frame_power / (noise_psd + 1e-10)
        gain = snr / (1 + snr)  # Wiener filter
        
        # Apply gain to the frame
        enhanced_fft = frame_fft * gain
        
        # Inverse FFT
        enhanced_frame = np.fft.irfft(enhanced_fft)
        
        # Overlap-add
        enhanced_signal[start:end] += enhanced_frame * window
    
    return enhanced_signal

def estimate_spectrum(signal_segment, frame_len):
    """
    Estimate the power spectral density of a signal segment
    """
    # Split into frames with 50% overlap
    num_frames = int(len(signal_segment) / (frame_len/2)) - 1
    window = np.hamming(frame_len)
    spectrum = np.zeros(frame_len // 2 + 1)
    
    for i in range(num_frames):
        start = int(i * frame_len / 2)
        frame = signal_segment[start:start+frame_len]
        if len(frame) < frame_len:
            break
        windowed = frame * window
        fft = np.fft.rfft(windowed)
        spectrum += np.abs(fft) ** 2
    
    return spectrum / num_frames if num_frames > 0 else spectrum

def demo_wiener_filter(clean_path=None, noise_path=None, noisy_path=None, snr_db=5):
    """
    Demonstrate Wiener filter on speech data
    
    Either provide:
    - noisy_path: path to already noisy speech
    OR
    - clean_path and noise_path: to mix clean speech with noise at specified SNR
    """
    if noisy_path:
        # Load noisy speech
        sample_rate, noisy_signal = wavfile.read(noisy_path)
        if len(noisy_signal.shape) > 1:  # Convert stereo to mono if needed
            noisy_signal = noisy_signal.mean(axis=1)
    elif clean_path and noise_path:
        # Load clean speech and noise
        sample_rate, clean_signal = wavfile.read(clean_path)
        if len(clean_signal.shape) > 1:
            clean_signal = clean_signal.mean(axis=1)
            
        noise_rate, noise_signal = wavfile.read(noise_path)
        if len(noise_signal.shape) > 1:
            noise_signal = noise_signal.mean(axis=1)
        
        # Ensure same sample rate
        if noise_rate != sample_rate:
            raise ValueError("Clean and noise signals must have the same sample rate")
        
        # Adjust noise level to desired SNR
        clean_power = np.mean(clean_signal**2)
        noise_power = np.mean(noise_signal**2)
        scaling_factor = np.sqrt(clean_power / (noise_power * (10**(snr_db/10))))
        
        # Trim or repeat noise to match clean signal length
        if len(noise_signal) < len(clean_signal):
            noise_signal = np.tile(noise_signal, int(np.ceil(len(clean_signal) / len(noise_signal))))
        noise_signal = noise_signal[:len(clean_signal)]
        
        # Mix signals
        scaled_noise = scaling_factor * noise_signal
        noisy_signal = clean_signal + scaled_noise
    else:
        raise ValueError("Either provide noisy_path or both clean_path and noise_path")
    
    # Apply Wiener filter
    enhanced_signal = wiener_filter(noisy_signal, sample_rate)
    
    # Normalize signals for plotting
    if clean_path:
        clean_signal = clean_signal / np.max(np.abs(clean_signal))
    noisy_signal = noisy_signal / np.max(np.abs(noisy_signal))
    enhanced_signal = enhanced_signal / np.max(np.abs(enhanced_signal))
    
    # Plot results
    plt.figure(figsize=(15, 10))
    
    if clean_path:
        plt.subplot(3, 1, 1)
        plt.title('Clean Speech')
        plt.plot(clean_signal)
        plt.xlim([0, len(clean_signal)])
        
        plt.subplot(3, 1, 2)
    else:
        plt.subplot(2, 1, 1)
    plt.title('Noisy Speech')
    plt.plot(noisy_signal)
    plt.xlim([0, len(noisy_signal)])
    
    if clean_path:
        plt.subplot(3, 1, 3)
    else:
        plt.subplot(2, 1, 2)
    plt.title('Enhanced Speech')
    plt.plot(enhanced_signal)
    plt.xlim([0, len(enhanced_signal)])
    
    plt.tight_layout()
    plt.show()
    
    # Return the signals for audio playback
    return {
        'noisy': (noisy_signal, sample_rate),
        'enhanced': (enhanced_signal, sample_rate),
        'clean': (clean_signal, sample_rate) if clean_path else None
    }

# Example usage (uncomment to use):
# signals = demo_wiener_filter(clean_path='clean_speech.wav', noise_path='noise.wav', snr_db=5)

# To listen to the results:
# ipd.display(ipd.Audio(signals['noisy'][0], rate=signals['noisy'][1]))
# ipd.display(ipd.Audio(signals['enhanced'][0], rate=signals['enhanced'][1]))
# if signals['clean']:
#     ipd.display(ipd.Audio(signals['clean'][0], rate=signals['clean'][1]))

In [7]:
from datasets import Dataset, Audio, DatasetDict,load_dataset ,concatenate_datasets
from torch.utils.data import DataLoader
import pandas as pd
import os

# Define paths
dataset_dir = "/home/hkngae/COMP5412/data/NoisySpeechDataset"
demand_dir = "/home/hkngae/COMP5412/data/local_datasets"
metadata_file = os.path.join(dataset_dir, "metadata.csv")
first_n = 8000  # Number of examples to load for trial
demand_ds = load_dataset("JacobLinCool/VoiceBank-DEMAND-16k", cache_dir=demand_dir)

### Dataset demand_ds contains: ['train', 'test']
### Split 'train' contains 11572 examples
### Features: {'id': Value(dtype='string', id=None), 'clean': Audio(sampling_rate=16000, mono=True, decode=True, id=None), 'noisy': Audio(sampling_rate=16000, mono=True, decode=True, id=None)}
### Split 'test' contains 824 examples
### Features: {'id': Value(dtype='string', id=None), 'clean': Audio(sampling_rate=16000, mono=True, decode=True, id=None), 'noisy': Audio(sampling_rate=16000, mono=True, decode=True, id=None)}



# Check if metadata file exists and read its structure first
if os.path.exists(metadata_file):
    # Just peek at the first few rows to see the structure
    print("Metadata file columns:")
    print(pd.read_csv(metadata_file, nrows=1).columns.tolist())

    # Load only the first n examples from metadata for trial
    metadata_df = pd.read_csv(metadata_file).head(first_n)

    # Assuming columns like "noisy_file", "clean_file", "snr" exist
    # Adapt these column names to match your actual metadata structure
    dataset_dict = {
        "id": [str(i) for i in range(len(metadata_df))],
        "noisy": metadata_df["noisy_file"].tolist() if "noisy_file" in metadata_df.columns else [],
        "clean": metadata_df["clean_file"].tolist() if "clean_file" in metadata_df.columns else [],
        #"snr": metadata_df["snr"].tolist() if "snr" in metadata_df.columns else []
    }
    
    # Create the dataset
    small_ds = Dataset.from_dict(dataset_dict)
    
    # Add audio loading functionality 
    if "noisy_file" in metadata_df.columns:
        small_ds = small_ds.cast_column("noisy", Audio(sampling_rate=16000))
    if "clean_file" in metadata_df.columns:
        small_ds = small_ds.cast_column("clean", Audio(sampling_rate=16000))
    
    # Inspect the small dataset
    print(f"Small dataset contains {len(small_ds)} examples")
    print(f"Features: {small_ds.features}")
    print(demand_ds["train"].features)
    #concat with demand_ds['train']
    combined_ds = DatasetDict({
        'custom': small_ds,
        'train': demand_ds['train'],
        'test': demand_ds['test']
        })
    full_ds = concatenate_datasets([combined_ds['train'], combined_ds['test'], combined_ds['custom']])

    print(f"\nCombined dataset size: {len(full_ds)} examples")
    print(f"Combined dataset features: {full_ds.features}")

    
else:
    print(f"Metadata file not found at {metadata_file}")
    print("Please check the path or create the metadata file.")

Metadata file columns:
['noisy_file', 'clean_file', 'noise_file', 'snr']
Small dataset contains 8000 examples
Features: {'id': Value(dtype='string', id=None), 'noisy': Audio(sampling_rate=16000, mono=True, decode=True, id=None), 'clean': Audio(sampling_rate=16000, mono=True, decode=True, id=None)}
{'id': Value(dtype='string', id=None), 'clean': Audio(sampling_rate=16000, mono=True, decode=True, id=None), 'noisy': Audio(sampling_rate=16000, mono=True, decode=True, id=None)}

Combined dataset size: 20396 examples
Combined dataset features: {'id': Value(dtype='string', id=None), 'clean': Audio(sampling_rate=16000, mono=True, decode=True, id=None), 'noisy': Audio(sampling_rate=16000, mono=True, decode=True, id=None)}


In [16]:
#calculate snr
def calculate_snr(clean, noisy):
    """
    Calculate SNR in dB
    """
    noise = noisy - clean
    snr = 10 * np.log10(np.sum(clean**2) / np.sum(noise**2))
    return snr

In [17]:
from IPython.display import Audio, display
# get 5 examples from the dataset
sample_ds = full_ds.shuffle(seed=42).select(range(5))
print("Sample dataset:")
print(sample_ds)
print("Sample dataset features:")
print(sample_ds.features)
#denoise the first 5 examples
for i, example in enumerate(sample_ds):
    print(f"Example {i+1}:")
    noisy_signal = np.array(example['noisy']["array"])
    sample_rate = 16000
    enhanced_signal = wiener_filter(noisy_signal, sample_rate)
    
    # Normalize signals for playback
    noisy_signal = noisy_signal / np.max(np.abs(noisy_signal))
    enhanced_signal = enhanced_signal / np.max(np.abs(enhanced_signal))
    
    # Play the original and enhanced signals
    print("Playing original noisy signal...")
    #print snr
    snr_before = calculate_snr(example['clean']["array"], noisy_signal)
    print(f"SNR: {snr_before} dB")
    display(Audio(noisy_signal, rate=sample_rate))
    
    print("Playing enhanced signal...")
    snr_after = calculate_snr(example['clean']["array"], enhanced_signal)
    print(f"SNR: {snr_after} dB")
    display(Audio(enhanced_signal, rate=sample_rate))
    print("\n")

Sample dataset:
Dataset({
    features: ['id', 'clean', 'noisy'],
    num_rows: 5
})
Sample dataset features:
{'id': Value(dtype='string', id=None), 'clean': Audio(sampling_rate=16000, mono=True, decode=True, id=None), 'noisy': Audio(sampling_rate=16000, mono=True, decode=True, id=None)}
Example 1:
Playing original noisy signal...
SNR: -4.134945857995857 dB


Playing enhanced signal...
SNR: -1.4292001712559155 dB




Example 2:
Playing original noisy signal...
SNR: -13.542787047513261 dB


Playing enhanced signal...
SNR: -12.88381997551684 dB




Example 3:
Playing original noisy signal...
SNR: -1.4711663876156464 dB


Playing enhanced signal...
SNR: -0.7026803373805156 dB




Example 4:
Playing original noisy signal...
SNR: -7.247979989526533 dB


Playing enhanced signal...
SNR: -6.546207296620668 dB




Example 5:
Playing original noisy signal...
SNR: -13.356966934841093 dB


Playing enhanced signal...
SNR: -12.708520564075187 dB




