In [6]:
import torch
import torchaudio
import torchaudio.transforms as T
import torchaudio.functional as F
from tqdm import tqdm
import os

# Load VAD model
print("Loading model VAD...")
vad_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                    model='silero_vad',
                                    force_reload=False,
                                    trust_repo=True)
get_speech_timestamps, _, _, _, _ = utils

def rms_normalize(waveform, target_rms=0.05, eps=1e-8):
    rms = torch.sqrt(torch.mean(waveform ** 2) + eps)
    gain = target_rms / rms
    return waveform * gain

def preprocess_audio_folder(
    input_dir,
    output_dir,
    target_sample_rate=16000,
    vad_model=None,
    get_speech_timestamps=None):
    """
    Preprocess all wav files in a folder and save results
    """

    assert vad_model is not None, "vad_model is required"
    assert get_speech_timestamps is not None, "get_speech_timestamps function is required"

    input_dir = os.path.abspath(input_dir)
    os.makedirs(output_dir, exist_ok=True)

    resampler_cache = {}

    for root, _, files in os.walk(input_dir):
        for file in tqdm(files, desc="Processing audio"):
            if not file.lower().endswith(".wav"):
                continue

            in_path = os.path.join(root, file)

            rel_path = os.path.relpath(root, input_dir)
            out_dir = os.path.join(output_dir, rel_path)
            os.makedirs(out_dir, exist_ok=True)
            out_path = os.path.join(out_dir, file)

            try:
                # Load audio
                waveform, orig_sr = torchaudio.load(in_path)

                if waveform.numel() == 0:
                    continue

                # Resample
                if orig_sr != target_sample_rate:
                    if orig_sr not in resampler_cache:
                        resampler_cache[orig_sr] = T.Resample(orig_sr, target_sample_rate)
                    waveform = resampler_cache[orig_sr](waveform)

                # Mono
                if waveform.shape[0] > 1:
                    waveform = torch.mean(waveform, dim=0, keepdim=True)

                # High-pass filter (>80Hz)
                waveform = F.highpass_biquad(
                    waveform,
                    sample_rate=target_sample_rate,
                    cutoff_freq=80)

                # RMS Normalize
                waveform = rms_normalize(waveform)

                # VAD
                wav_1d = waveform.squeeze()
                speech_timestamps = get_speech_timestamps(
                    wav_1d,
                    vad_model,
                    sampling_rate=target_sample_rate)

                if len(speech_timestamps) == 0:
                    continue

                speech_segments = [
                    wav_1d[ts["start"]:ts["end"]]
                    for ts in speech_timestamps
                    if ts["end"] > ts["start"]
                ]

                if len(speech_segments) == 0:
                    continue

                clean_waveform = torch.cat(speech_segments).unsqueeze(0)
                clean_waveform = clean_waveform.clamp(-1.0, 1.0)

                torchaudio.save(out_path, clean_waveform, target_sample_rate)

            except Exception as e:
                print(f"❌ Error {in_path}: {e}")

    print(f"✅ Done → {output_dir}")

Loading model VAD...


Using cache found in C:\Users\PC1/.cache\torch\hub\snakers4_silero-vad_master


### Preprocessing for VSASV

In [None]:
preprocess_audio_folder(
    input_dir=r"speech_data\wav\VSASV",
    output_dir=r"speech_data\clean_wav\VSASV",
    target_sample_rate=16000,
    vad_model=vad_model,
    get_speech_timestamps=get_speech_timestamps)

Processing audio: 0it [00:00, ?it/s]
Processing audio: 100%|██████████| 25/25 [00:01<00:00, 18.16it/s]
Processing audio: 100%|██████████| 63/63 [00:03<00:00, 18.01it/s]
Processing audio: 100%|██████████| 47/47 [00:02<00:00, 18.43it/s]
Processing audio: 100%|██████████| 95/95 [00:06<00:00, 15.62it/s]
Processing audio: 100%|██████████| 57/57 [00:03<00:00, 18.05it/s]
Processing audio: 100%|██████████| 20/20 [00:00<00:00, 32.15it/s]
Processing audio: 100%|██████████| 50/50 [00:01<00:00, 32.49it/s]
Processing audio: 100%|██████████| 10/10 [00:00<00:00, 10.88it/s]
Processing audio: 100%|██████████| 15/15 [00:00<00:00, 20.00it/s]
Processing audio: 100%|██████████| 17/17 [00:00<00:00, 17.63it/s]
Processing audio: 100%|██████████| 13/13 [00:00<00:00, 27.68it/s]
Processing audio: 100%|██████████| 15/15 [00:00<00:00, 18.52it/s]
Processing audio: 100%|██████████| 100/100 [00:05<00:00, 19.84it/s]
Processing audio: 100%|██████████| 63/63 [00:04<00:00, 14.32it/s]
Processing audio: 100%|██████████| 87

✅ Done → E:\speech_data\clean_wav\VSASV





### Preprocessing for VoxVietnam

In [None]:
preprocess_audio_folder(
    input_dir=r"speech_data\wav\Vox_train",
    output_dir=r"speech_data\clean_wav\Vox_train",
    target_sample_rate=16000,
    vad_model=vad_model,
    get_speech_timestamps=get_speech_timestamps)

Processing audio: 0it [00:00, ?it/s]
Processing audio: 100%|██████████| 191/191 [00:47<00:00,  4.04it/s]
Processing audio: 100%|██████████| 149/149 [00:39<00:00,  3.80it/s]
Processing audio: 100%|██████████| 90/90 [00:24<00:00,  3.65it/s]
Processing audio: 100%|██████████| 58/58 [00:13<00:00,  4.14it/s]
Processing audio: 100%|██████████| 63/63 [00:17<00:00,  3.52it/s]
Processing audio: 100%|██████████| 58/58 [00:17<00:00,  3.39it/s]
Processing audio: 100%|██████████| 35/35 [00:10<00:00,  3.43it/s]
Processing audio: 100%|██████████| 43/43 [00:11<00:00,  3.72it/s]
Processing audio: 100%|██████████| 36/36 [00:11<00:00,  3.20it/s]
Processing audio: 100%|██████████| 37/37 [00:11<00:00,  3.11it/s]
Processing audio: 100%|██████████| 23/23 [00:06<00:00,  3.72it/s]
Processing audio: 100%|██████████| 16/16 [00:03<00:00,  4.10it/s]
Processing audio: 100%|██████████| 13/13 [00:03<00:00,  3.32it/s]
Processing audio: 100%|██████████| 17/17 [00:04<00:00,  4.02it/s]
Processing audio: 100%|██████████| 

✅ Done → E:\speech_data\clean_wav\Vox_train





### Merge

In [None]:
import shutil
import csv

VSASV_DIR = r"speech_data\clean_wav\VSASV"
VOX_DIR   = r"speech_data\clean_wav\Vox_train"
OUT_DIR   = r"speech_data\train_raw"
next_speaker_id = 0
metadata = []

os.makedirs(OUT_DIR, exist_ok=True)

def merge_dataset(root_dir, dataset_name):
    global next_speaker_id

    for speaker in sorted(os.listdir(root_dir)):
        speaker_path = os.path.join(root_dir, speaker)
        if not os.path.isdir(speaker_path):
            continue

        new_sid = next_speaker_id
        next_speaker_id += 1

        utt_counter = 0

        for wav in sorted(os.listdir(speaker_path)):
            if not wav.lower().endswith(".wav"):
                continue

            new_name = f"id{new_sid:05d}_{utt_counter:05d}.wav"

            src = os.path.join(speaker_path, wav)
            dst = os.path.join(OUT_DIR, new_name)

            if os.path.exists(dst):
                raise RuntimeError(f"Overwrite detected: {dst}")

            shutil.copy2(src, dst)

            metadata.append({
                "speaker_id": f"id{new_sid:05d}",
                "utt_id": f"{utt_counter:05d}",
                "filename": new_name,
                "source_dataset": dataset_name,
                "source_speaker": speaker,
                "source_file": wav
            })

            utt_counter += 1

In [None]:
merge_dataset(VOX_DIR, "VoxVietnam")
merge_dataset(VSASV_DIR, "VSASV")

META_DIR   = r"speech_data\train_raw\metadata"
os.makedirs(META_DIR, exist_ok=True)

meta_path = os.path.join(META_DIR, "metadata.csv")

with open(meta_path, "w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(
        f,
        fieldnames=[
            "speaker_id",
            "utt_id",
            "filename",
            "source_dataset",
            "source_speaker",
            "source_file"])
    
    writer.writeheader()
    writer.writerows(metadata)

In [14]:
print("Total wav:", len([f for f in os.listdir(OUT_DIR) if f.endswith(".wav")]))
print("Metadata rows:", len(metadata))

Total wav: 258168
Metadata rows: 258168


DOUBLE-CHECK

In [15]:
from collections import defaultdict

meta_files = 0
meta_speakers = set()

meta_by_dataset = defaultdict(int)
meta_by_speaker = defaultdict(int)
meta_by_dataset_speaker = defaultdict(lambda: defaultdict(int))

with open(meta_path, encoding="utf-8") as f:
    reader = csv.DictReader(f)
    rows = list(reader)

for r in rows:
    meta_files += 1
    meta_speakers.add(r["speaker_id"])
    meta_by_dataset[r["source_dataset"]] += 1
    meta_by_speaker[r["speaker_id"]] += 1
    meta_by_dataset_speaker[r["source_dataset"]][r["source_speaker"]] += 1


In [16]:
def scan_root(root_dir):
    speaker_count = 0
    file_count = 0
    files_per_speaker = {}

    for speaker in os.listdir(root_dir):
        sp_path = os.path.join(root_dir, speaker)
        if not os.path.isdir(sp_path):
            continue

        speaker_count += 1
        wavs = [f for f in os.listdir(sp_path) if f.lower().endswith(".wav")]
        files_per_speaker[speaker] = len(wavs)
        file_count += len(wavs)

    return speaker_count, file_count, files_per_speaker

vox_spk, vox_files, vox_map = scan_root(VOX_DIR)
vsasv_spk, vsasv_files, vsasv_map = scan_root(VSASV_DIR)

In [18]:
print("\n===== GLOBAL CHECK =====")
print(f"Metadata files : {meta_files}")
print(f"VoxVietnam wav : {vox_files}")
print(f"VSASV wav      : {vsasv_files}")
print(f"TOTAL wav  : {vox_files + vsasv_files}")

if meta_files != vox_files + vsasv_files:
    print("Problem: File count mismatch!")
else:
    print("✅ File count match")

#--------------------------------------
print("\n===== DATASET CHECK =====")
print(f"VoxVietnam meta : {meta_by_dataset['VoxVietnam']} | original : {vox_files}")
print(f"VSASV meta      : {meta_by_dataset['VSASV']} | original : {vsasv_files}")

#--------------------------------------
print("\n===== SPEAKER-LEVEL CHECK =====")
def check_speaker(dataset_name, meta_map, raw_map):
    errors = 0
    for spk, raw_count in raw_map.items():
        meta_count = meta_map.get(spk, 0)
        if raw_count != meta_count:
            print(f"Problem: {dataset_name} | {spk}: original={raw_count}, meta={meta_count}")
            errors += 1
    if errors == 0:
        print(f"✅ {dataset_name}: all speakers match")
check_speaker("VoxVietnam", meta_by_dataset_speaker["VoxVietnam"], vox_map)
check_speaker("VSASV", meta_by_dataset_speaker["VSASV"], vsasv_map)


===== GLOBAL CHECK =====
Metadata files : 258168
VoxVietnam wav : 161006
VSASV wav      : 97162
TOTAL wav  : 258168
✅ File count match

===== DATASET CHECK =====
VoxVietnam meta : 161006 | original : 161006
VSASV meta      : 97162 | original : 97162

===== SPEAKER-LEVEL CHECK =====
✅ VoxVietnam: all speakers match
✅ VSASV: all speakers match
