# Prediction Methods

This file contains the code that tunes the hyperparameters for the k-NN, r-NN, and Entropy based confidence prediction methods.
It also evaluates the performance of these models on different datasets.

## Setup

In [1]:
from __future__ import division
from __future__ import print_function
from os import path
import os, glob, torch,torchaudio, re
from python_speech_features import delta
from python_speech_features import mfcc

import matplotlib.pyplot as plt
import numpy as np
import scipy.io.wavfile as wav
import sys
import speech_dtw.qbe as qbe

from transformers import WavLMModel
from sklearn.decomposition import PCA

import math
from collections import Counter
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor

sys.path.append("..")
sys.path.append(path.join("..", "utils"))

SAMPLE_RATE = 16000 
WAVLM_LAYER_INDEX = 6

device = "cpu"
model = WavLMModel.from_pretrained("microsoft/wavlm-base-plus").to(device).eval()

def cmvn(X):
    # X: [T, D] NumPy
    mu = X.mean(axis=0, keepdims=True)
    sd = X.std(axis=0, keepdims=True)
    return (X - mu) / (sd + 1e-8)

def getWavLMFeatures(file): #A function which extracts MFCCs features from a given audio file
    sig, rate = torchaudio.load(file) #Reads the audio file, extracting the sample rate and signal data (as an array)

    #Check if sampled as correct sampling rate, if not - resample
    if rate != SAMPLE_RATE:
        print("Resampling", file ,"at 16kHz.\n")
        sig = torchaudio.functional.resample(sig, rate, SAMPLE_RATE)

    #Extracts layer 6 features
    sig = sig.to(device)
    with torch.inference_mode():
        out = model(sig, output_hidden_states=True)
        features = out.hidden_states[WAVLM_LAYER_INDEX].squeeze(0)  # [T, D] torch

    #Convert to numpy
    features = features.numpy()

    #Apply CMVN
    features = cmvn(features) #Applies cepstral mean and variance normalization to features
    
    return features

def getMinimumCost(queryFile, templateFile, FeatureType):
    #Loading the features
    if FeatureType == "MFCCs":
        queryFeatures = getMFCCsFeatures(queryFile) #Extract features for query data
    elif FeatureType == "WavLM":
        queryFeatures = getWavLMFeatures(queryFile) #Extract features for query data
        
    templateFeatures = torch.load(templateFile) #Load the template's feature file
    templateFeatures = templateFeatures["features"].numpy() #Extract the features as numpy arrays

    #Make both feature sets 2D, float64, contiguous:
    queryFeatures = np.ascontiguousarray(queryFeatures, dtype=np.float64)
    templateFeatures = np.ascontiguousarray(templateFeatures, dtype=np.float64)

    distance = qbe.dtw_sweep_min(queryFeatures, templateFeatures) #Calculate the minimum sweeping DTW distance between the two feature sets

    return distance

  from .autonotebook import tqdm as notebook_tqdm


## k-NN

In [6]:
def accuracies_knn(testFolder, templateFolder, ks):
    correct = {k: 0 for k in ks}
    total = 0
    class_counts = Counter()
    correct_per_class = {k: Counter() for k in ks}

    # Load templates
    template_files = list(Path(templateFolder).rglob("*.pt"))
    template_labels = {}
    for temp in template_files:
        parts = temp.stem.split("_")
        prefix = parts[0]
        if prefix.isdigit():
            template_labels[str(temp)] = int(prefix)

    for testFile in Path(testFolder).rglob("*.wav"):
        parts = testFile.stem.split("_")
        prefix = parts[0]
        true_label = int(prefix) if prefix.isdigit() else "No Number"
        class_counts[true_label] += 1
        total += 1

        # Compute distances
        distances = []
        for temp_path in template_labels:
            dist = getMinimumCost(str(testFile), temp_path, "WavLM")
            distances.append((dist, template_labels[temp_path]))

        # Sort
        distances.sort(key=lambda x: x[0])

        # Evaluate for each k
        for k in ks:
            if k <= 0: continue
            knn = distances[:k]
            labels = [lab for d, lab in knn]
            counts = Counter(labels)
            predicted_label = counts.most_common(1)[0][0]
            if predicted_label == true_label:
                correct[k] += 1
                correct_per_class[k][true_label] += 1

    results = {}
    for k in sorted(ks):
        accuracy = (correct[k] / total) * 100 if total > 0 else 0.0
        print(f"Classification Results for k={k}:")
        print(f"Accuracy: {accuracy:.2f}%")
        results[k] = {
            "accuracy": accuracy,
            "class_counts": dict(class_counts),
            "correct_per_class": dict(correct_per_class[k]),
        }
    return results


Usage:

In [7]:
# Validation Data
ks = np.arange(1, 11, 1)
knnAccuracies = accuracies_knn("ValidationData/OnlyNumbers",
                               "TrainingData/TrainingFeatures/WavLMBase+/English",
                               ks=ks)

Classification Results for k=1:
Accuracy: 75.49%
Classification Results for k=2:
Accuracy: 75.49%
Classification Results for k=3:
Accuracy: 81.37%
Classification Results for k=4:
Accuracy: 79.41%
Classification Results for k=5:
Accuracy: 80.39%
Classification Results for k=6:
Accuracy: 77.45%
Classification Results for k=7:
Accuracy: 78.43%
Classification Results for k=8:
Accuracy: 75.49%
Classification Results for k=9:
Accuracy: 74.51%
Classification Results for k=10:
Accuracy: 74.51%


## r-NN

In [9]:
def compute_metrics_for_r(testFile, template_labels, r, true_label):
    # Compute distances to all templates
    distances = []
    for temp_path, label in template_labels.items():
        dist = getMinimumCost(str(testFile), temp_path, "WavLM")
        distances.append((dist, label))

    # Find all templates within radius r
    rnn = [(d, lab) for d, lab in distances if d <= r]
    
    predicted_label = "No Number" if not rnn else Counter(lab for d, lab in rnn).most_common(1)[0][0]
    
    return true_label, predicted_label

def evaluate_rnn_metrics(testFolder, templateFolder, r_values):
    # Load all template files and their labels once
    template_files = list(Path(templateFolder).rglob("*.pt"))
    template_labels = {}
    for temp in template_files:
        parts = temp.stem.split("_")
        prefix = parts[0]
        if prefix.isdigit():
            template_labels[str(temp)] = int(prefix)

    test_files = list(Path(testFolder).rglob("*.wav"))
    true_labels = []
    for testFile in test_files:
        parts = testFile.stem.split("_")
        prefix = parts[0]
        true_label = int(prefix) if prefix.isdigit() else "No Number"
        true_labels.append(true_label)

    results = {}
    with ThreadPoolExecutor() as executor:
        for r in r_values:
            # Parallelize predictions for all test files
            predictions = list(executor.map(
                lambda tf: compute_metrics_for_r(tf, template_labels, r, true_labels[test_files.index(tf)]),
                test_files
            ))

            # Calculate overall metrics
            correct_predictions = 0
            true_positives = 0
            false_positives = 0
            false_negatives = 0
            true_negatives = 0
            total = len(test_files)

            for true_label, predicted_label in predictions:
                # Accuracy: exact match (same number or "No Number")
                if true_label == predicted_label:
                    correct_predictions += 1
                    if true_label != "No Number":
                        true_positives += 1
                    else:
                        true_negatives += 1
                else:
                    if predicted_label != "No Number":
                        false_positives += 1
                    if true_label != "No Number":
                        false_negatives += 1

            accuracy = correct_predictions / total * 100 if total > 0 else 0.0
            precision = true_positives / (true_positives + false_positives) * 100 if (true_positives + false_positives) > 0 else 0.0
            recall = true_positives / (true_positives + false_negatives) * 100 if (true_positives + false_negatives) > 0 else 0.0
            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

            # Print results for this r
            print(f"\nResults for r={r}:")
            print(f"Accuracy: {accuracy:.2f}%")
            print(f"Precision: {precision:.2f}%")
            print(f"Recall: {recall:.2f}%")
            print(f"F1 Score: {f1:.2f}%")

            results[r] = {
                "accuracy": accuracy,
                "precision": precision,
                "recall": recall,
                "f1": f1
            }

    return results

Usage:

In [10]:
# Validation Data
evaluate_rnn_metrics("ValidationData/WithNoNumClass", "TrainingData/TrainingFeatures/WavLMBase+/English/", np.arange(0.3, 0.45, 0.01) )


Results for r=0.3:
Accuracy: 41.96%
Precision: 74.51%
Recall: 37.25%
F1 Score: 49.67%

Results for r=0.31:
Accuracy: 46.43%
Precision: 69.23%
Recall: 44.12%
F1 Score: 53.89%

Results for r=0.32:
Accuracy: 47.32%
Precision: 65.33%
Recall: 48.04%
F1 Score: 55.37%

Results for r=0.33:
Accuracy: 50.89%
Precision: 67.07%
Recall: 53.92%
F1 Score: 59.78%

Results for r=0.34:
Accuracy: 52.68%
Precision: 63.74%
Recall: 56.86%
F1 Score: 60.10%

Results for r=0.35000000000000003:
Accuracy: 53.57%
Precision: 60.82%
Recall: 57.84%
F1 Score: 59.30%

Results for r=0.36000000000000004:
Accuracy: 53.57%
Precision: 58.25%
Recall: 58.82%
F1 Score: 58.54%

Results for r=0.37000000000000005:
Accuracy: 51.79%
Precision: 54.21%
Recall: 56.86%
F1 Score: 55.50%

Results for r=0.38000000000000006:
Accuracy: 46.43%
Precision: 47.27%
Recall: 50.98%
F1 Score: 49.06%

Results for r=0.39000000000000007:
Accuracy: 46.43%
Precision: 46.85%
Recall: 50.98%
F1 Score: 48.83%

Results for r=0.4000000000000001:
Accuracy: 4

{np.float64(0.3): {'accuracy': 41.964285714285715,
  'precision': 74.50980392156863,
  'recall': 37.254901960784316,
  'f1': 49.673202614379086},
 np.float64(0.31): {'accuracy': 46.42857142857143,
  'precision': 69.23076923076923,
  'recall': 44.11764705882353,
  'f1': 53.89221556886227},
 np.float64(0.32): {'accuracy': 47.32142857142857,
  'precision': 65.33333333333333,
  'recall': 48.03921568627451,
  'f1': 55.367231638418076},
 np.float64(0.33): {'accuracy': 50.89285714285714,
  'precision': 67.07317073170732,
  'recall': 53.92156862745098,
  'f1': 59.78260869565218},
 np.float64(0.34): {'accuracy': 52.67857142857143,
  'precision': 63.73626373626373,
  'recall': 56.86274509803921,
  'f1': 60.10362694300518},
 np.float64(0.35000000000000003): {'accuracy': 53.57142857142857,
  'precision': 60.824742268041234,
  'recall': 57.84313725490197,
  'f1': 59.29648241206031},
 np.float64(0.36000000000000004): {'accuracy': 53.57142857142857,
  'precision': 58.252427184466015,
  'recall': 58.8

## Entropy

In [17]:
def compute_metrics_for_entropy(testFile, template_labels, threshold, true_label):
    # Compute min distances to each class (0-9)
    min_dists = {lab: float('inf') for lab in range(10)}
    for temp_path, label in template_labels.items():
        dist = getMinimumCost(str(testFile), temp_path, "WavLM")
        if dist < min_dists[label]:
            min_dists[label] = dist

    dists = [min_dists[i] for i in range(10)]
    min_d = min(d for d in dists if d != float('inf'))
    sims = [math.exp(-(d - min_d)) if d != float('inf') else 0 for d in dists]
    sum_sim = sum(sims)
    probs = [s / sum_sim if sum_sim > 0 else 0 for s in sims]

    # Entropy
    entropy = -sum(p * math.log(p) for p in probs if p > 0)

    if entropy > threshold:
        predicted_label = "No Number"
    else:
        best_class = dists.index(min(dists))
        predicted_label = best_class
    
    return true_label, predicted_label

def evaluate_entropy_metrics(testFolder, templateFolder, threshold_values):
    # Load all template files and their labels once
    template_files = list(Path(templateFolder).rglob("*.pt"))
    template_labels = {}
    for temp in template_files:
        parts = temp.stem.split("_")
        prefix = parts[0]
        if prefix.isdigit():
            template_labels[str(temp)] = int(prefix)

    test_files = list(Path(testFolder).rglob("*.wav"))
    true_labels = []
    for testFile in test_files:
        parts = testFile.stem.split("_")
        prefix = parts[0]
        true_label = int(prefix) if prefix.isdigit() else "No Number"
        true_labels.append(true_label)

    results = {}
    with ThreadPoolExecutor() as executor:
        for threshold in threshold_values:
            # Parallelize predictions for all test files
            predictions = list(executor.map(
                lambda tf: compute_metrics_for_entropy(tf, template_labels, threshold, true_labels[test_files.index(tf)]),
                test_files
            ))

            # Calculate overall metrics
            true_positives = 0
            false_positives = 0
            false_negatives = 0
            true_negatives = 0
            total = len(test_files)

            for true_label, predicted_label in predictions:
                is_true_number = true_label != "No Number"
                is_pred_number = predicted_label != "No Number"
                
                if is_true_number and is_pred_number:
                    true_positives += 1
                elif is_true_number and not is_pred_number:
                    false_negatives += 1
                elif not is_true_number and is_pred_number:
                    false_positives += 1
                else:
                    true_negatives += 1

            accuracy = (true_positives + true_negatives) / total * 100 if total > 0 else 0.0
            precision = true_positives / (true_positives + false_positives) * 100 if (true_positives + false_positives) > 0 else 0.0
            recall = true_positives / (true_positives + false_negatives) * 100 if (true_positives + false_negatives) > 0 else 0.0
            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

            # Print results for this threshold
            print(f"\nResults for threshold={threshold}:")
            print(f"Accuracy: {accuracy:.2f}%")
            print(f"Precision: {precision:.2f}%")
            print(f"Recall: {recall:.2f}%")
            print(f"F1 Score: {f1:.2f}%")

            results[threshold] = {
                "accuracy": accuracy,
                "precision": precision,
                "recall": recall,
                "f1": f1
            }

    return results

Usage:

In [19]:
evaluate_entropy_metrics("ValidationData/WithNoNumClass", "TrainingData/TrainingFeatures/WavLMBase+/English/", np.arange(2.1, 2.6, 0.1))


Results for threshold=2.1:
Accuracy: 8.93%
Precision: 0.00%
Recall: 0.00%
F1 Score: 0.00%

Results for threshold=2.2:
Accuracy: 8.93%
Precision: 0.00%
Recall: 0.00%
F1 Score: 0.00%

Results for threshold=2.3000000000000003:
Accuracy: 8.93%
Precision: 0.00%
Recall: 0.00%
F1 Score: 0.00%

Results for threshold=2.4000000000000004:
Accuracy: 91.07%
Precision: 91.07%
Recall: 100.00%
F1 Score: 95.33%

Results for threshold=2.5000000000000004:
Accuracy: 91.07%
Precision: 91.07%
Recall: 100.00%
F1 Score: 95.33%


{np.float64(2.1): {'accuracy': 8.928571428571429,
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0},
 np.float64(2.2): {'accuracy': 8.928571428571429,
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0},
 np.float64(2.3000000000000003): {'accuracy': 8.928571428571429,
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0},
 np.float64(2.4000000000000004): {'accuracy': 91.07142857142857,
  'precision': 91.07142857142857,
  'recall': 100.0,
  'f1': 95.32710280373833},
 np.float64(2.5000000000000004): {'accuracy': 91.07142857142857,
  'precision': 91.07142857142857,
  'recall': 100.0,
  'f1': 95.32710280373833}}

# Prediction Method Performance

Minimum Cost:

In [20]:
# Validation Data
minCostValidationAccuracy = accuracies_knn("ValidationData/OnlyNumbers",
                               "TrainingData/TrainingFeatures/WavLMBase+/English",
                               [1])

Classification Results for k=1:
Accuracy: 75.49%


In [32]:
# Testing Data
minCostTestingAccuracy = accuracies_knn("TestingData/OnlyNumbers/English",
                               "TrainingData/TrainingFeatures/WavLMBase+/English",
                               [1])

Classification Results for k=1:
Accuracy: 78.43%


k-NN with k=3:

In [22]:
# Validation Data
knnValidationAccuracy = accuracies_knn("ValidationData/OnlyNumbers",
                               "TrainingData/TrainingFeatures/WavLMBase+/English",
                               [3])

Classification Results for k=3:
Accuracy: 81.37%


In [28]:
# Testing Data
knnTestingAccuracy = accuracies_knn("TestingData/OnlyNumbers/English",
                               "TrainingData/TrainingFeatures/WavLMBase+/English",
                               [3])

Classification Results for k=3:
Accuracy: 79.41%


r-NN with r=0.34

In [33]:
# Validation Data
rnnValidationAccuracy = evaluate_rnn_metrics("ValidationData/WithNoNumClass", "TrainingData/TrainingFeatures/WavLMBase+/English/", np.arange(0.34, 0.35, 0.01))


Results for r=0.34:
Accuracy: 52.68%
Precision: 63.74%
Recall: 56.86%
F1 Score: 60.10%


In [34]:
# Testing Data
rnnTestingAccuracy = evaluate_rnn_metrics("TestingData/WithNoNumClass/English", "TrainingData/TrainingFeatures/WavLMBase+/English/",np.arange(0.34, 0.35, 0.01))


Results for r=0.34:
Accuracy: 62.16%
Precision: 73.33%
Recall: 64.71%
F1 Score: 68.75%


# Final Prediction Function's Accuracy

English:

In [None]:
accuracyEnglish =  accuracies_knn("TestingData/OnlyNumbers/English", "TrainingData/TrainingFeatures/WavLMBase+/Child/Jibo", [3])

Afrikaans:

In [35]:
accuracyAfrikaans =  accuracies_knn("TestingData/OnlyNumbers/Afrikaans", "TrainingFeatures/WavLMBase+/Afrikaans", [3])

Resampling TestingData/OnlyNumbers/Afrikaans/4kwaad/04_09.wav at 16kHz.

Resampling TestingData/OnlyNumbers/Afrikaans/4kwaad/04_09.wav at 16kHz.

Resampling TestingData/OnlyNumbers/Afrikaans/4kwaad/04_09.wav at 16kHz.

Resampling TestingData/OnlyNumbers/Afrikaans/4kwaad/04_09.wav at 16kHz.

Resampling TestingData/OnlyNumbers/Afrikaans/4kwaad/04_09.wav at 16kHz.

Resampling TestingData/OnlyNumbers/Afrikaans/4kwaad/04_09.wav at 16kHz.

Resampling TestingData/OnlyNumbers/Afrikaans/4kwaad/04_09.wav at 16kHz.

Resampling TestingData/OnlyNumbers/Afrikaans/4kwaad/04_09.wav at 16kHz.

Resampling TestingData/OnlyNumbers/Afrikaans/4kwaad/04_09.wav at 16kHz.

Resampling TestingData/OnlyNumbers/Afrikaans/4kwaad/04_09.wav at 16kHz.

Resampling TestingData/OnlyNumbers/Afrikaans/4kwaad/04_09.wav at 16kHz.

Resampling TestingData/OnlyNumbers/Afrikaans/4kwaad/04_09.wav at 16kHz.

Resampling TestingData/OnlyNumbers/Afrikaans/4kwaad/04_09.wav at 16kHz.

Resampling TestingData/OnlyNumbers/Afrikaans/4kwaad