# Preprocessing Code

This notebook contains the audio pre-processing (silence removal and resampling to 16 kHz) and the feature extraction (MFCC and WavLM features) for the training data.

## Code Setup

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

from __future__ import division
from __future__ import print_function
from os import path
from pathlib import Path
import os, glob, torch,torchaudio, re
from python_speech_features import delta
from python_speech_features import mfcc

import matplotlib.pyplot as plt
import numpy as np
import scipy.io.wavfile as wav
import sys
import speech_dtw.qbe as qbe

from transformers import WavLMModel
from sklearn.decomposition import PCA

sys.path.append("..")
sys.path.append(path.join("..", "utils"))

## Pre-processing Function

In [None]:
# Removing silence and resampling recordings
SAMPLE_RATE = 16000  # Target sampling rate (Hz) used throughout the notebook

# Load Silero VAD.
# The VAD network gets its own name (`vad_model`) instead of the generic `model`:
# the feature-extraction cell further down binds `model` to the WavLM network, so
# sharing the name made preProcessAudio use WavLM for voice-activity detection
# whenever it was called after that cell had been run.
vad_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                  model='silero_vad',
                                  force_reload=False,
                                  trust_repo=True)
(get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks) = utils

def preProcessAudio(folder):
    """Strip non-speech audio from every .wav file under `folder`, in place.

    Each recording is loaded at SAMPLE_RATE, passed through the Silero VAD, and the
    detected speech segments are concatenated and written back to the SAME path,
    so the original recording is overwritten.

    Parameters
    ----------
    folder : str or path-like
        Root directory; it is searched recursively for "*.wav" files.

    Notes
    -----
    Files in which no speech is detected are left untouched (and therefore are not
    resampled either); a message is printed for each so they can be inspected.
    """
    for wav_fn in Path(folder).rglob("*.wav"):  # Loop through all .wav files in folder and subfolders
        # `audio` rather than `wav`, so the module-level `scipy.io.wavfile as wav` alias is not shadowed
        audio = read_audio(str(wav_fn), sampling_rate=SAMPLE_RATE)  # loads & resamples to 16kHz mono
        ts = get_speech_timestamps(
            audio, vad_model, sampling_rate=SAMPLE_RATE,
            threshold=0.35,
            speech_pad_ms=300,
            min_speech_duration_ms=150,
            min_silence_duration_ms=300,
        )  # Gets which timestamps have speech

        if not ts:
            # Nothing to keep: do not overwrite the file with empty audio, but say so
            print("No speech detected, left unchanged:", wav_fn)
            continue

        speech_only = collect_chunks(ts, audio)
        save_audio(str(wav_fn), speech_only, sampling_rate=SAMPLE_RATE)  # overwrite
        print("Processed", wav_fn)

Example usage:

In [None]:
# Removes non-speech audio from every .wav file under this folder (sub-folders included)
# and re-saves it at 16 kHz. NOTE: the recordings are overwritten in place, so run this
# on a copy if the raw audio must be kept.
preProcessAudio("TrainingData/Child/Afrikaans")

## Feature Extraction

In [None]:
# Position in the tuple of hidden states returned by WavLM whose output is used as features
WAVLM_LAYER_INDEX = 6

# Select the WavLM checkpoint and the device it runs on here
device = "cpu"
model = WavLMModel.from_pretrained("microsoft/wavlm-base-plus")
model = model.to(device)  # move the weights to the chosen device
model = model.eval()      # evaluation mode: the model is only used for feature extraction

def cmvn(X):
    """Mean and variance normalisation applied independently to each feature dimension.

    Parameters
    ----------
    X : numpy.ndarray
        Feature matrix of shape [T, D] (frames x dimensions).

    Returns
    -------
    numpy.ndarray
        Array of the same shape in which every column has had its mean over the
        T frames subtracted and has been divided by its standard deviation.
    """
    centred = X - np.mean(X, axis=0, keepdims=True)
    # Small constant keeps the division finite for columns with zero variance
    scale = np.std(X, axis=0, keepdims=True) + 1e-8
    return centred / scale

def getMFCCsFeatures(file):
    """Extract normalised MFCC + delta + delta-delta features from one audio file.

    Parameters
    ----------
    file : str or path-like
        Audio file readable by `torchaudio.load`.

    Returns
    -------
    numpy.ndarray
        Array of shape [T, 3 * C]: the static MFCCs, their deltas and their
        delta-deltas stacked side by side (C is the number of cepstral
        coefficients returned by `mfcc`), normalised per dimension with `cmvn`.
    """
    sig, rate = torchaudio.load(file)  # sig: [channels, samples] tensor, rate: sampling rate of the file

    # Check if sampled at the correct sampling rate, if not - resample
    if rate != SAMPLE_RATE:
        print("Resampling", file ,"at 16kHz.\n")
        sig = torchaudio.functional.resample(sig, rate, SAMPLE_RATE)

    # Downmix to a 1-D mono signal. The previous `squeeze(0)` only removed the channel
    # axis for single-channel files; a stereo file stayed 2-D ([2, samples]) and was
    # handed to `mfcc` as if it were a two-sample signal. For mono input the mean over
    # one channel returns the same samples as before.
    sig = sig.mean(dim=0).numpy()

    # Extract features
    MFCC_static = mfcc(sig, SAMPLE_RATE)  # Static MFCCs, one row per frame
    MFCC_deltas = delta(MFCC_static, 2)  # Delta (first derivative) of the MFCCs
    MFCC_delta_deltas = delta(MFCC_deltas, 2)  # Delta-delta (second derivative) of the MFCCs

    # Combine static, delta, and delta-delta features into a single feature vector per frame
    features = np.hstack((MFCC_static, MFCC_deltas, MFCC_delta_deltas))
    features = cmvn(features)  # Applies cepstral mean and variance normalization to features

    return features

def getWavLMFeatures(file):
    """Extract normalised WavLM hidden-state features from one audio file.

    Parameters
    ----------
    file : str or path-like
        Audio file readable by `torchaudio.load`.

    Returns
    -------
    numpy.ndarray
        Array of shape [T, D] holding the hidden states selected by
        WAVLM_LAYER_INDEX, normalised per dimension with `cmvn`.
    """
    sig, rate = torchaudio.load(file)  # sig: [channels, samples] tensor, rate: sampling rate of the file

    # Check if sampled at the correct sampling rate, if not - resample
    if rate != SAMPLE_RATE:
        print("Resampling", file ,"at 16kHz.\n")
        sig = torchaudio.functional.resample(sig, rate, SAMPLE_RATE)

    # Downmix to mono but keep the leading axis, so the model always receives a batch
    # of exactly one waveform ([1, samples]). Without this a stereo file was treated as
    # a batch of two utterances, `squeeze(0)` below became a no-op, and CMVN was then
    # computed across the channel axis instead of across time.
    sig = sig.mean(dim=0, keepdim=True).to(device)

    # Extract the hidden states of the layer selected by WAVLM_LAYER_INDEX
    with torch.inference_mode():
        out = model(sig, output_hidden_states=True)
        features = out.hidden_states[WAVLM_LAYER_INDEX].squeeze(0)  # [T, D] torch

    # Convert to numpy; .cpu() is a no-op on CPU and keeps this working if `device` is changed to a GPU
    features = features.cpu().numpy()

    # Apply CMVN
    features = cmvn(features)  # Applies mean and variance normalization to features

    return features

### Function to save the features as .pt files

In [None]:
def extractAndSaveFeatures(inputPath, outputPath, FeatureType):
    """Extract features for every .wav under `inputPath` and save each as a .pt file.

    Saving the features means they do not have to be re-extracted each time.

    File naming is expected to follow A01/01_00.wav, where A01 (the parent folder)
    is the speaker ID, the part before the first underscore is the label, and the
    part after it is the utterance number. A file whose prefix is not numeric
    (e.g. "noNum_01.wav") is given the label "No Number".

    Parameters
    ----------
    inputPath : str or path-like
        Root directory searched recursively for "*.wav" files.
    outputPath : str or path-like
        Root directory for the .pt files; the folder layout below `inputPath`
        is mirrored here and missing directories are created.
    FeatureType : str
        "MFCCs" or "WavLM".

    Raises
    ------
    ValueError
        If `FeatureType` is not one of the supported values. This is checked before
        any file is touched; previously an unknown value surfaced as a NameError
        part-way through the first file.
    """
    # Validate first, so a typo fails immediately with a clear message
    if FeatureType not in ("MFCCs", "WavLM"):
        raise ValueError(
            f"Unknown FeatureType {FeatureType!r}; expected 'MFCCs' or 'WavLM'."
        )
    # Both extractors take a file path and return a [T, D] numpy array
    extractor = getMFCCsFeatures if FeatureType == "MFCCs" else getWavLMFeatures

    inputPath = Path(inputPath)
    outputPath = Path(outputPath)

    for wav_fn in inputPath.rglob("*.wav"):  # Loop through all .wav files
        fileName = wav_fn.stem                   # e.g. "03_01" or "noNum_01"
        prefix = fileName.split("_")[0]          # "03" or "noNum"

        # Handle label
        if prefix.isdigit():
            label = int(prefix)
        else:
            label = "No Number"  # e.g. "noNum_01.wav"

        # Extract the features and convert them to a float32 torch tensor
        features = torch.from_numpy(extractor(str(wav_fn))).float()

        # Build save path
        rel_path = wav_fn.relative_to(inputPath)   # e.g. A01/03_00.wav
        save_path = outputPath / rel_path.with_suffix(".pt")
        save_path.parent.mkdir(parents=True, exist_ok=True)

        # Package data
        data = {
            "features": features,
            "label": label,
            "speaker": wav_fn.parent.name,
            "utt_id": fileName,
            "feat_type": FeatureType
        }

        torch.save(data, save_path)
        print(f"Saved {save_path}")

    print("Feature extraction complete.")


Example usage:

In [None]:
# Extracts WavLM features for every .wav file under the first folder and writes one .pt
# file per recording under the second folder, mirroring the input sub-folder layout.
extractAndSaveFeatures("TrainingData/Child/Afrikaans", "TrainingFeatures/WavLMBase+/Afrikaans", "WavLM")