In [None]:
import os
import whisperx
from pydub import AudioSegment
import shutil
import csv

# --- Model configuration -----------------------------------------------------
MODEL_SIZE = "large"
DEVICE = "cpu"  # alternative: "cuda"
COMPUTE_TYPE = "float32"

# --- Model loading -------------------------------------------------------------
# `model` produces the transcription; `model_a` / `metadata` are the alignment
# model and its metadata for Vietnamese ("vi"), later passed to whisperx.align.
print(f"Loading Whisper model: {MODEL_SIZE}...")
model = whisperx.load_model(
    MODEL_SIZE,
    device=DEVICE,
    compute_type=COMPUTE_TYPE,
)
model_a, metadata = whisperx.load_align_model(
    language_code="vi",
    device=DEVICE,
)

In [None]:
# Per-file transcription/alignment results, reused across cut_audio() calls so
# that processing the same directory for several target durations does not
# repeat the transcription for every duration.
# Call _ALIGNMENT_CACHE.clear() to force a fresh transcription.
_ALIGNMENT_CACHE = {}


def _transcribe_and_align(input_path):
    """Return ``(detected_lang, word_segments)`` for one audio file, memoized.

    ``detected_lang`` is the "language" entry of the transcription result
    ("unknown" when absent). ``word_segments`` is the aligned word list, or
    ``None`` when the detected language is not "vi" (alignment is skipped for
    those files, exactly as before).
    """
    file_stat = os.stat(input_path)
    # The key covers the file's identity (path, mtime, size) and the loaded
    # model objects, so an edited file or a reloaded model is not served a
    # stale result.
    cache_key = (
        os.path.abspath(input_path),
        file_stat.st_mtime_ns,
        file_stat.st_size,
        id(model),
        id(model_a),
    )
    cached = _ALIGNMENT_CACHE.get(cache_key)
    if cached is not None:
        return cached

    result = model.transcribe(input_path)
    detected_lang = result.get("language", "unknown")

    word_segments = None
    if detected_lang == "vi":
        aligned = whisperx.align(
            result["segments"], model_a, metadata, input_path, DEVICE)
        word_segments = aligned["word_segments"]

    entry = (detected_lang, word_segments)
    _ALIGNMENT_CACHE[cache_key] = entry
    return entry


def _plan_segments(words, total_duration, min_dur, max_dur, stride):
    """Return ``(first_idx, last_idx)`` pairs (inclusive) into ``words``.

    ``words`` must all carry "start" and "end" times in seconds. For each
    candidate start word the segment ends at the FIRST word whose end time
    puts the segment length inside ``[min_dur, max_dur]``; if no word does, the
    start is skipped. The next candidate start is the first later word that
    begins at least ``stride`` seconds after the current start.
    """
    plan = []
    n_words = len(words)
    first = 0

    while first < n_words:
        seg_start = words[first]["start"]

        # Not enough audio left for even the shortest acceptable segment.
        if seg_start > total_duration - min_dur:
            break

        for last in range(first, n_words):
            seg_len = words[last]["end"] - seg_start
            if seg_len > max_dur:
                break
            if seg_len >= min_dur:
                plan.append((first, last))
                break

        # Advance the window start by at least `stride` seconds, on a word boundary.
        earliest_next = seg_start + stride
        next_first = next(
            (k for k in range(first + 1, n_words)
             if words[k]["start"] >= earliest_next),
            None,
        )
        if next_first is None:
            break
        first = next_first

    return plan


def cut_audio(input_dir, output_dir, other_dir, target_duration, min_dur, max_dur, stride, writer, meta_f):
    """Cut every ``.wav`` in ``input_dir`` into word-aligned segments.

    Each file is transcribed and, when detected as Vietnamese ("vi"), aligned
    to word timestamps; files detected as any other language are copied to
    ``other_dir`` and skipped. Segments whose length falls inside
    ``[min_dur, max_dur]`` are exported as WAV files into ``output_dir`` and one
    metadata row per segment is written through ``writer``.

    Parameters
    ----------
    input_dir : directory scanned (non-recursively) for ``.wav`` files.
    output_dir : directory receiving ``<base>_seg_<NNN>.wav`` segment files.
    other_dir : directory receiving copies of non-Vietnamese files.
    target_duration : nominal duration; only recorded in the metadata row.
    min_dur, max_dur : inclusive bounds (seconds) on accepted segment length.
    stride : minimum distance (seconds) between consecutive segment starts.
    writer : ``csv.writer``-like object; rows are
        ``[target_duration, speaker_id, filename, lang, duration, start, end, text]``.
    meta_f : file object behind ``writer``; flushed after every row.

    Notes
    -----
    * Files are processed in sorted order so the metadata row order is
      reproducible between runs.
    * Words without timestamps cannot anchor a cut, but any such word lying
      between a segment's first and last timed word is kept in the segment's
      ``text`` so the transcript matches the exported audio.
    * Any exception raised while handling a file is reported and the file is
      skipped; the remaining files are still processed.
    """
    for filename in sorted(os.listdir(input_dir)):
        if not filename.lower().endswith(".wav"):
            continue

        input_path = os.path.join(input_dir, filename)
        base_name = os.path.splitext(filename)[0]
        # Speaker id is the part of the file name before the first underscore.
        speaker_id = base_name.split("_")[0]

        print(f"\nProcessing: {filename}")
        try:
            detected_lang, word_segments = _transcribe_and_align(input_path)

            if detected_lang != "vi":
                shutil.copy(input_path, os.path.join(other_dir, filename))
                continue

            # Positions (in the full word list) of the words that have both
            # timestamps; only these can serve as segment boundaries.
            timed_positions = [
                pos for pos, w in enumerate(word_segments)
                if w.get("start") is not None and w.get("end") is not None]
            words = [word_segments[pos] for pos in timed_positions]

            audio = AudioSegment.from_file(input_path)
            total_duration = len(audio) / 1000.0  # len(audio) is in milliseconds

            plan = _plan_segments(words, total_duration, min_dur, max_dur, stride)
            for seg_index, (first, last) in enumerate(plan, start=1):
                seg_start = words[first]["start"]
                seg_end = words[last]["end"]

                out_name = f"{base_name}_seg_{seg_index:03d}.wav"
                audio[int(seg_start * 1000):int(seg_end * 1000)].export(
                    os.path.join(output_dir, out_name),
                    format="wav"
                )

                # Slice the FULL word list so untimed words spoken inside the
                # exported audio stay in the transcript.
                spoken = word_segments[timed_positions[first]:timed_positions[last] + 1]
                text = " ".join(w["word"].strip() for w in spoken)
                text = " ".join(text.split())  # collapse repeated whitespace

                writer.writerow([
                    target_duration,
                    speaker_id,
                    out_name,
                    detected_lang,
                    round(seg_end - seg_start, 3),
                    round(seg_start, 3),
                    round(seg_end, 3),
                    text])
                # Flush per row so progress survives an interrupted run.
                meta_f.flush()

            print(f"Done: {filename} | Original duration: {total_duration}s")

        except Exception as e:
            # Best effort: report the failure and move on to the next file.
            print(f"Error {filename}: {e}")

In [None]:
# --- Paths ---------------------------------------------------------------------
input_dir = r"E:\speech_data\train_raw"
output_dir = r"E:\speech_data\train\vi"
other_lang_dir = r"E:\speech_data\train\other_languages"

os.makedirs(output_dir, exist_ok=True)
os.makedirs(other_lang_dir, exist_ok=True)

# --- Segmentation settings -----------------------------------------------------
duration_list = [2.0, 3.0, 5.0, 7.0]  # target segment durations, in seconds
stride = 1.0  # passed to cut_audio as `stride`; same for every duration

meta_path = os.path.join(output_dir, "metadata_all_durations.csv")
meta_exists = os.path.exists(meta_path)
# An existing but empty file (e.g. left by an interrupted run) still needs the
# header row, so check the size as well as the existence.
needs_header = not meta_exists or os.path.getsize(meta_path) == 0

# NOTE(review): the file is opened in append mode, so re-running this cell adds
# rows for segments that were already listed — confirm whether that is intended.
# The `with` block guarantees the file is flushed and closed even if a duration
# pass raises.
with open(meta_path, "a", newline="", encoding="utf-8") as meta_file:
    writer = csv.writer(meta_file)

    if needs_header:
        writer.writerow([
            "target_duration", "speaker_id", "filename",
            "lang", "duration", "start", "end", "text"])

    for target_duration in duration_list:
        print(f"\n===== Processing duration {target_duration}s =====")
        # Accept segments within half a second of the target duration.
        min_dur = target_duration - 0.5
        max_dur = target_duration + 0.5

        # One output sub-folder per target duration, e.g. "dur_2.0s".
        duration_folder = os.path.join(
            output_dir, f"dur_{target_duration:.1f}s"
        )
        os.makedirs(duration_folder, exist_ok=True)

        cut_audio(
            input_dir=input_dir,
            output_dir=duration_folder,
            other_dir=other_lang_dir,
            target_duration=target_duration,
            min_dur=min_dur,
            max_dur=max_dur,
            stride=stride,
            writer=writer,
            meta_f=meta_file)

print(f"Metadata saved to: {meta_path}")