In [9]:
import webrtcvad
import torchaudio
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
import json

from utils import metrics
from config import PathsConfig

In [10]:
audios_root = '/home/eugene/Datasets/LibriSpeech_augmented/audios/'
labels_root = '/home/eugene/Datasets/LibriSpeech_augmented/features_labels/'
val_meta = 'meta/val.json'

In [11]:
class WebRTC:
    def __init__(self, threshold=4):
        self.vad = webrtcvad.Vad(threshold).is_speech
        
    def __call__(self, path):
        samples, sample_rate = torchaudio.load(path)
        assert sample_rate == 16000
        samples = samples.squeeze(0).numpy()
        if len(samples) % 160 != 0:
            samples = np.concatenate([samples, np.zeros(160 - (len(samples) % 160))])
        samples = (samples * 2 ** 16).astype('int16')
        outputs = np.zeros(len(samples) // 160, dtype=bool)
        for j in range(len(samples) // 160):
            start = j * 160
            end = (j + 1) * 160
            outputs[j] = self.vad(samples[start:end].tobytes(), sample_rate)
        return outputs

In [13]:
meta = json.load(open(val_meta, 'r'))

for thr in range(4):
    model = WebRTC(thr)
    all_labels = []
    all_outputs = []
    for i in range(len(meta)):
        label_path = meta[i]['label_path'].replace(PathsConfig.features_labels, labels_root)
        labels = torch.load(label_path).squeeze(0).numpy()
        audio_path = meta[i]['audio_path'].replace(PathsConfig.augmented, audios_root)
        outputs = model(audio_path)
        all_labels.extend(labels[:len(outputs)].tolist())
        all_outputs.extend(outputs.tolist())
    webrtc_results = metrics(all_labels, all_outputs)
    print(f"Threshold = {thr},  FAR = {webrtc_results['fars'][1]:.4f},  FRR = {webrtc_results['frrs'][1]:.4f}")

Threshold = 0,  FAR = 0.7332,  FRR = 0.0094
Threshold = 1,  FAR = 0.6972,  FRR = 0.0118
Threshold = 2,  FAR = 0.5718,  FRR = 0.0238
Threshold = 3,  FAR = 0.0619,  FRR = 0.6264
