In [None]:
import sys
sys.path.append('..')
from utils import vad as vad
import IPython.display as ipd
import librosa
import soundfile as sf
from tqdm import tqdm
import os
import torch
import time
from concurrent.futures import ThreadPoolExecutor

# --- VAD configuration -------------------------------------------------------
# Chunk size handed to vad.merge_chunks when neighbouring speech regions are
# merged (presumably seconds, like the segment timestamps -- confirm in utils.vad).
chunk_size = 11.5

# Device the VAD model is loaded onto.
device = 'cuda'

# Onset/offset values forwarded to vad.merge_chunks.
vad_params = dict(vad_onset=0.500, vad_offset=0.463)

# Loaded once here and reused by every vad_segments() call.
vad_model = vad.load_vad_model(
    torch.device(device),
    model_fp="pyannote/segmentation-3.0",
    use_auth_token=True,
)

def vad_segments(audio: str):
    """Load an audio file and compute its merged voice-activity chunks.

    Args:
        audio: Path of the audio file to load.

    Returns:
        A ``(samples, sample_rate, chunks)`` tuple, where ``samples`` and
        ``sample_rate`` come from ``librosa.load`` called with ``sr=None``
        and ``chunks`` is the result of ``vad.merge_chunks``.
    """
    samples, sample_rate = librosa.load(audio, sr=None)

    # The model is called with a dict holding the waveform (with one extra
    # leading dimension added by unsqueeze) and its sample rate.
    model_input = {
        "waveform": torch.from_numpy(samples).unsqueeze(0),
        "sample_rate": sample_rate,
    }
    raw_activity = vad_model(model_input)

    merged_chunks = vad.merge_chunks(
        raw_activity,
        chunk_size,
        onset=vad_params["vad_onset"],
        offset=vad_params["vad_offset"],
    )
    return samples, sample_rate, merged_chunks


# Input and output locations for the splitting loop:
#   root_dir        - tree that is walked for '.wav' files to split
#   target_root_dir - root under which the split files are written, mirroring
#                     each source directory's path relative to root_dir
root_dir = '/mnt/sea/yt_wavs'
target_root_dir = '/mnt/sea/yt_splits'

def save_segment(segment_data):
    """Write one slice of an audio array to its own WAV file.

    Args:
        segment_data: A 7-tuple ``(id, target_dir, sr, audio, i,
            start_index, end_index)``. ``audio[start_index:end_index]`` is
            written with sample rate ``sr`` to
            ``<target_dir>/<id>__<i>__<start_index>-<end_index>__<sr>.wav``.

    A target file that already exists is left untouched.
    """
    clip_id, out_dir, rate, samples, idx, lo, hi = segment_data
    piece = samples[lo:hi]
    out_path = os.path.join(out_dir, f'{clip_id}__{idx}__{lo}-{hi}__{rate}.wav')

    # Nothing to do when this segment is already on disk.
    if os.path.exists(out_path):
        return

    sf.write(out_path, piece, rate)

# Walk every directory under root_dir, run VAD on each '.wav' file and write
# one file per merged chunk under target_root_dir, mirroring the source tree.
for dirpath, dirnames, filenames in os.walk(root_dir):
    for filename in tqdm(filenames):
        if not filename.endswith('.wav'):
            continue

        # The identifier is the second ' - '-separated field of the file
        # name. A name without that separator cannot be mapped to an output
        # directory, so skip it rather than abort the whole run.
        name_fields = filename.split(' - ')
        if len(name_fields) < 2:
            tqdm.write(f"Skipping {filename!r}: no ' - ' separator in file name")
            continue
        clip_id = name_fields[1]

        full_path = os.path.join(dirpath, filename)
        rel_path = os.path.relpath(dirpath, root_dir)

        audio, sr, segments = vad_segments(audio=full_path)

        target_dir = os.path.join(target_root_dir, rel_path, clip_id)
        os.makedirs(target_dir, exist_ok=True)

        # Segment boundaries are multiplied by the sample rate to obtain
        # sample indices into the loaded audio array.
        index_ranges = [(int(seg['start'] * sr), int(seg['end'] * sr)) for seg in segments]
        segment_data = [
            (clip_id, target_dir, sr, audio, i, start, end)
            for i, (start, end) in enumerate(index_ranges)
        ]

        # Write the segment files in parallel threads.
        with ThreadPoolExecutor(max_workers=30) as executor:
            # Executor.map re-raises a worker's exception only when its
            # result is retrieved; drain the iterator so a failed write
            # surfaces here instead of being silently dropped.
            list(executor.map(save_segment, segment_data))
