# In [None]:  (Jupyter cell marker left over from the notebook export)
import os
import whisperx
from pydub import AudioSegment
import shutil
import torch
import csv

# --- ENVIRONMENT SETUP ---
# Everything in this section is written to be idempotent so that re-running
# the notebook cell neither grows PATH nor stacks torch.load wrappers.


def _append_to_path(directory):
    """Append *directory* to the PATH environment variable once.

    Does nothing when the directory is already listed; tolerates PATH being
    unset (the previous ``os.environ["PATH"] += ...`` raised KeyError then).
    """
    current = os.environ.get("PATH", "")
    if directory in current.split(os.pathsep):
        return
    os.environ["PATH"] = current + os.pathsep + directory if current else directory


# NVIDIA cuBLAS "bin" directory inside the local virtualenv (Windows layout);
# only added when it actually exists.
venv_path = os.path.join(os.getcwd(), "venv", "Lib", "site-packages", "nvidia", "cublas", "bin")
if os.path.exists(venv_path):
    _append_to_path(venv_path)

# Explicit ffmpeg/ffprobe binaries for pydub, plus their folder on PATH so
# that tools resolving "ffmpeg" by name can find it as well.
ffmpeg_path = r"D:\Study\7-SP26\DATxSLP\ffmpeg.exe"
ffprobe_path = r"D:\Study\7-SP26\DATxSLP\ffprobe.exe"
AudioSegment.converter = ffmpeg_path
AudioSegment.ffprobe = ffprobe_path
_append_to_path(r"D:\Study\7-SP26\DATxSLP")

# Fix torch.load: default ``weights_only`` to False for callers that do not
# pass it (presumably required by checkpoints loaded through whisperx on newer
# PyTorch releases -- confirm against the installed version).
# SECURITY: weights_only=False lets torch.load unpickle arbitrary objects, so
# only checkpoints from trusted sources must be loaded while this is active.
if hasattr(torch, 'serialization') and not getattr(torch.load, "_weights_only_patched", False):
    _original_torch_load = torch.load

    def _custom_torch_load(*args, **kwargs):
        """Forward to the original torch.load with weights_only=False by default."""
        kwargs.setdefault('weights_only', False)
        return _original_torch_load(*args, **kwargs)

    # Marker checked above so a second execution does not wrap the wrapper.
    _custom_torch_load._weights_only_patched = True
    torch.load = _custom_torch_load

# --- CONFIGURATION ---
input_dir = r"D:\Study\7-SP26\DATxSLP\Data_after_preprocessing\VoxVietnam_sampled_audio_16k_clean\train_small_clean\id00000"
output_dir = r"D:\Study\7-SP26\DATxSLP\Data_after_cut\test_output"
english_dir = r"D:\Study\7-SP26\DATxSLP\Data_after_cut\file_english"

os.makedirs(output_dir, exist_ok=True)
os.makedirs(english_dir, exist_ok=True)

MODEL_SIZE = "large-v2"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# int8_float16 is a GPU compute type; when DEVICE falls back to "cpu" the
# model load would be rejected, so use plain int8 there instead.
COMPUTE_TYPE = "int8_float16" if DEVICE == "cuda" else "int8"
# NOTE(review): Whisper large-v2 is usually described as an 80-mel model
# (128 mel bins came with large-v3) -- confirm this override is intended and
# is accepted by the installed whisperx version.
ASR_OPTIONS = {"n_mels": 128}

# Segment length policy: accept segments whose duration lies within
# [MIN_DUR, MAX_DUR] seconds; consecutive windows start STRIDE_SECONDS apart.
TARGET_SECONDS = 3.0
STRIDE_SECONDS = 1.0
MARGIN = 0.5
MIN_DUR = TARGET_SECONDS - MARGIN
MAX_DUR = TARGET_SECONDS + MARGIN

# Load the ASR model before touching the metadata file: if loading fails, no
# file handle has been opened yet and nothing is leaked.
print(f"Loading Whisper {MODEL_SIZE} on {DEVICE}...")
model = whisperx.load_model(MODEL_SIZE, device=DEVICE, compute_type=COMPUTE_TYPE, asr_options=ASR_OPTIONS)
# Cache of alignment models keyed by language code, filled lazily per file.
align_models = {}

# Metadata
# The CSV lives next to (not inside) output_dir. The previous
# os.path.join(output_dir, <absolute path>) returned the absolute path and
# silently ignored output_dir, so the location is now spelled out directly.
meta_path = r"D:\Study\7-SP26\DATxSLP\Data_after_cut\metadata.csv"
# Treat an existing but empty file like a missing one so it still gets a header.
meta_exists = os.path.exists(meta_path) and os.path.getsize(meta_path) > 0
# Kept open in append mode for the whole run; the processing loop closes it.
meta_file = open(meta_path, "a", newline="", encoding="utf-8")
writer = csv.writer(meta_file)
if not meta_exists:
    writer.writerow(["target_duration", "speaker_id", "filename", "lang", "duration", "start", "end", "text"])
    meta_file.flush()

# --- GARBAGE-TEXT CHECK (REPEATED WORDS) ---
def is_valid_text(words_list):
    """Return False when a word sequence looks like a repetition artefact.

    Args:
        words_list: sequence of dicts, each holding the token under the
            ``"word"`` key (other keys are ignored).

    Returns:
        ``False`` for an empty/``None`` input, or when the sequence has more
        than 3 tokens and fewer than 40% of them are distinct
        (e.g. "con con con con" -> 1 distinct / 4 = 0.25); ``True`` otherwise.

    Tokens are compared case-insensitively with leading/trailing punctuation
    removed, so "con," / "con." / "Con" all count as the same word. Previously
    the raw tokens were compared and attached punctuation let repeated words
    slip through the filter.
    """
    if not words_list:
        return False

    # Punctuation that may be glued to either end of a token.
    edge_punct = ".,!?;:…\"'“”‘’()[]{}-–—"

    tokens = []
    for entry in words_list:
        raw = entry["word"].strip().lower()
        # Keep the raw token if it consists of punctuation only, so the token
        # count (the ratio's denominator) is unchanged.
        tokens.append(raw.strip(edge_punct) or raw)

    # Unique ratio: share of distinct tokens among all tokens.
    unique_ratio = len(set(tokens)) / len(tokens)

    # Short sequences (<= 3 tokens) are never rejected for repetition.
    too_repetitive = len(tokens) > 3 and unique_ratio < 0.4
    return not too_repetitive
# ---------------------------------------------

print("\n>>> CHẾ ĐỘ: CỬA SỔ TRƯỢT INDEX + LỌC RÁC <<<")


def _find_segment_end(words, start_idx):
    """Return the index of the word that closes a valid segment, or None.

    Starting at ``start_idx``, words are accumulated until the span from the
    first word's start to the current word's end reaches ``MIN_DUR``. The scan
    gives up (returns None) as soon as the span exceeds ``MAX_DUR`` or the
    words run out.
    """
    window_start = words[start_idx]["start"]
    for j in range(start_idx, len(words)):
        span = words[j]["end"] - window_start
        if span > MAX_DUR:
            return None
        if span >= MIN_DUR:
            return j
    return None


def _find_next_start(words, start_idx):
    """Return the index of the first later word starting at least
    ``STRIDE_SECONDS`` after ``words[start_idx]``, or None if there is none."""
    threshold = words[start_idx]["start"] + STRIDE_SECONDS
    for k in range(start_idx + 1, len(words)):
        if words[k]["start"] >= threshold:
            return k
    return None


try:
    for filename in os.listdir(input_dir):
        if not filename.lower().endswith((".wav", ".mp3")):
            continue

        print(f"\n--- Xử lý: {filename} ---")
        input_path = os.path.join(input_dir, filename)
        base_name = os.path.splitext(filename)[0]
        # Speaker id is the part of the file name before the first underscore.
        speaker_id = base_name.split("_")[0]

        # Per-file boundary: one failing file is reported and skipped so the
        # rest of the directory is still processed.
        try:
            # 1. Transcribe
            audio_data = whisperx.load_audio(input_path)
            result = model.transcribe(audio_data, batch_size=16)
            detected_lang = result.get("language", "unknown")

            # English files are set aside untouched instead of being segmented.
            if detected_lang == "en":
                shutil.copy(input_path, os.path.join(english_dir, filename))
                continue

            # Alignment models are loaded once per language and cached.
            if detected_lang not in align_models:
                if detected_lang == "vi":
                    align_models["vi"] = whisperx.load_align_model(
                        language_code="vi",
                        device=DEVICE,
                        model_name="nguyenvulebinh/wav2vec2-base-vi",
                    )
                else:
                    align_models[detected_lang] = whisperx.load_align_model(
                        language_code=detected_lang, device=DEVICE
                    )
            align_model, align_metadata = align_models[detected_lang]

            # Reuse the already decoded waveform; passing the file path made
            # whisperx decode the same audio a second time.
            result_aligned = whisperx.align(
                result["segments"], align_model, align_metadata, audio_data, DEVICE
            )
            # Only words that received both timestamps can delimit a segment.
            # NOTE(review): words without timestamps are dropped here, so they
            # are also missing from the "text" column written below even when
            # they fall inside an exported segment.
            words = [w for w in result_aligned["word_segments"] if "start" in w and "end" in w]

            if not words:
                continue

            # 2. CUT AUDIO (index-based sliding window over the timed words)
            audio = AudioSegment.from_file(input_path)
            start_word_idx = 0
            seg_index = 1

            while start_word_idx is not None:
                # A. Collect words until the duration window is satisfied.
                end_word_idx = _find_segment_end(words, start_word_idx)

                # B. Export the segment unless its text is a repetition artefact.
                if end_word_idx is not None:
                    segment_words = words[start_word_idx:end_word_idx + 1]

                    if is_valid_text(segment_words):
                        seg_start = segment_words[0]["start"]
                        seg_end = segment_words[-1]["end"]

                        # pydub slices in milliseconds.
                        clip = audio[int(seg_start * 1000):int(seg_end * 1000)]
                        out_name = f"{base_name}_seg_{seg_index:03d}.wav"
                        clip.export(os.path.join(output_dir, out_name), format="wav")

                        text_seg = " ".join(w["word"].strip() for w in segment_words)
                        writer.writerow([
                            TARGET_SECONDS, speaker_id, out_name, detected_lang,
                            round(seg_end - seg_start, 3),
                            round(seg_start, 3),
                            round(seg_end, 3),
                            text_seg,
                        ])
                        # Flush per row so an interrupted run keeps its metadata.
                        meta_file.flush()
                        seg_index += 1

                # C. Slide the window: jump to the first word that starts at
                # least STRIDE_SECONDS later; None ends the loop.
                start_word_idx = _find_next_start(words, start_word_idx)

            print(f"-> Đã lưu: {seg_index-1} file.")

        except Exception as e:
            print(f"Lỗi {filename}: {e}")

finally:
    # Always release the metadata CSV, even on Ctrl+C or an unexpected error.
    meta_file.close()
    print("\n>>> HOÀN TẤT <<<")