<a href="https://colab.research.google.com/github/dungdt-infopstats/Device-Directed-Speech-Segmentation/blob/main/src_prototype/Transcribe_Output_DDSS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# IMPORTANT: SOME KAGGLE DATA SOURCES ARE PRIVATE
# RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES.
# Log in to Kaggle through kagglehub so that the dataset downloads in the next
# cell can also reach private data sources.
import kagglehub
kagglehub.login()


In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.
#
# Each dataset_download call returns the local directory holding that dataset.
# NOTE(review): outside Kaggle (e.g. on Colab) this is a local cache directory,
# not /kaggle/input/<slug>; any cell that reads the data should use these path
# variables rather than hard-coded /kaggle/input paths. Check that no later cell
# depends on them before deleting this cell.

tridungdo_test_100_150_path = kagglehub.dataset_download('tridungdo/test-100-150')
tridungdo_wav2vec2_asr_100_150_true_outputs_path = kagglehub.dataset_download('tridungdo/wav2vec2-asr-100-150-true-outputs')

print('Data source import complete.')


In [None]:
!pip install faster-whisper
!pip install pydub
!pip install jiwer

In [None]:
import shutil

In [None]:
import ast
import io
import multiprocessing
import os
import pickle
import shutil
import tempfile
from functools import partial
from multiprocessing import Pool, cpu_count, Manager
from typing import List, Dict, Any, Optional, Tuple

import numpy as np
import pandas as pd
from faster_whisper import WhisperModel, BatchedInferencePipeline
from pydub import AudioSegment
from tqdm import tqdm

# ==========================================================
# Helper: Convert vector -> time ranges
# ==========================================================
def vector_to_time_ranges(
    vector: List[int],
    frame_dur: float,
    hop_dur: float,
    pad: int = 15
) -> List[Dict[str, float]]:
    """
    Turn a frame-level 0/1 activity vector into time ranges in seconds.

    A range opens at a frame equal to 1 and keeps growing while any frame
    within ``pad`` positions of the cursor is truthy. As a result, runs of up
    to ``2 * pad`` inactive frames between active frames are bridged into a
    single range. The range stops at the first frame whose ``pad``-window is
    all zeros, which is ``pad + 1`` frames after its last active frame
    (clamped to ``len(vector)``). That stop index times ``hop_dur``, plus
    ``frame_dur``, is the end time.

    Args:
        vector: per-frame activity flags (1 = active).
        frame_dur: duration of one analysis frame, in seconds.
        hop_dur: step between consecutive frames, in seconds.
        pad: half-width (in frames) of the merge/extension window.

    Returns:
        A list of ``{"start": ..., "end": ...}`` dicts in order of appearance.
    """
    total = len(vector)

    def near_active(pos: int) -> bool:
        # True if any frame in the window [pos - pad, pos + pad] is truthy.
        lo = max(0, pos - pad)
        hi = min(total, pos + pad + 1)
        return any(vector[lo:hi])

    spans: List[Dict[str, float]] = []
    cursor = 0
    while cursor < total:
        if vector[cursor] != 1:
            cursor += 1
            continue
        stop = cursor
        while stop < total and near_active(stop):
            stop += 1
        spans.append({"start": cursor * hop_dur, "end": stop * hop_dur + frame_dur})
        cursor = stop
    return spans


# ==========================================================
# Helper: Load audio and slice
# ==========================================================
def load_audio_as_pydub(path: str, sr: Optional[int] = None) -> AudioSegment:
    """
    Read ``path`` into a pydub ``AudioSegment``.

    When ``sr`` is given and differs from the file's frame rate, the audio is
    resampled to ``sr``; otherwise it is returned as decoded.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    if os.path.exists(path):
        segment = AudioSegment.from_file(path)
        if sr is None or segment.frame_rate == sr:
            return segment
        return segment.set_frame_rate(sr)
    raise FileNotFoundError(f"Audio file not found: {path}")

def slice_audio_segment(audio: AudioSegment, start: float, end: float) -> AudioSegment:
    """
    Cut ``audio`` between ``start`` and ``end`` seconds.

    Both times are converted to whole milliseconds (truncated by ``int``).
    The start is clamped at 0 and the end at ``len(audio)`` before slicing.
    """
    lo_ms = int(start * 1000)
    hi_ms = int(end * 1000)
    return audio[max(lo_ms, 0):min(hi_ms, len(audio))]


# ==========================================================
# Global function for transcribing segments (needed for multiprocessing)
# ==========================================================
# Per-process cache of loaded Whisper models, keyed by (model_size, device,
# batch_mode). Building a WhisperModel is expensive, so each process (the main
# process or a Pool worker) now loads a configuration once instead of once per
# segment. Cached models stay alive for the lifetime of the process.
_WHISPER_MODEL_CACHE: Dict[Tuple[str, str, bool], Any] = {}


def _get_whisper_model(model_size: str, device: str, batch_mode: bool) -> Any:
    """Return a cached WhisperModel, wrapped in BatchedInferencePipeline if batch_mode."""
    key = (model_size, device, batch_mode)
    model = _WHISPER_MODEL_CACHE.get(key)
    if model is None:
        model = WhisperModel(model_size, device=device)
        if batch_mode:
            model = BatchedInferencePipeline(model=model)
        _WHISPER_MODEL_CACHE[key] = model
    return model


def transcribe_segment_worker(
    audio_bytes: bytes,
    model_size: str,
    device: str,
    batch_mode: bool = True,
    batch_size: int = 128,
) -> str:
    """
    Transcribe one encoded audio segment with faster-whisper.

    Args:
        audio_bytes: the segment as an encoded audio file (WAV bytes).
        model_size: Whisper model name passed to ``WhisperModel``.
        device: device passed to ``WhisperModel`` (e.g. "cpu", "cuda").
        batch_mode: use ``BatchedInferencePipeline`` for inference.
        batch_size: batch size used when ``batch_mode`` is True.

    Returns:
        The segment texts joined with spaces and stripped. Returns "" on any
        error; the error is printed rather than raised.
    """
    tmp_path = None
    try:
        model = _get_whisper_model(model_size, device, batch_mode)

        # Write the bytes to a named file for the transcriber.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            tmp_path = tmp_file.name
            tmp_file.write(audio_bytes)

        if batch_mode:
            segments, _ = model.transcribe(tmp_path, batch_size=batch_size)
        else:
            segments, _ = model.transcribe(tmp_path)
        # faster-whisper yields segments lazily, so consume them before the
        # temp file is removed in ``finally``.
        return " ".join(seg.text for seg in segments).strip()

    except Exception as e:
        print(f"Error transcribing segment: {e}")
        return ""
    finally:
        # Also runs if writing the temp file failed, so it is never leaked.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)


def process_segment_worker(args: Tuple[bytes, Dict[str, float], int, str, str, str, str, bool]) -> Dict[str, Any]:
    """
    Cut one time range out of a file's audio, optionally save it, and transcribe it.

    Args:
        args: ``(audio_bytes, time_range, idx, save_segments_dir, base_name,
            model_size, device, batch_mode)``.
            - ``audio_bytes`` is the whole file encoded as WAV.
            - ``time_range`` is a ``{"start", "end"}`` dict in seconds.
            - When ``save_segments_dir`` is truthy, the cut segment is written
              there as ``<base_name>_segment_<idx>.wav``.

    Returns:
        ``{"segment_index", "start", "end", "recognized_text"}``. The text is
        "" when the segment is shorter than 100 ms or when any error occurs;
        errors are printed, not raised.
    """
    audio_bytes, time_range, idx, save_segments_dir, base_name, model_size, device, batch_mode = args

    def _result(text: str) -> Dict[str, Any]:
        return {
            "segment_index": idx,
            "start": time_range["start"],
            "end": time_range["end"],
            "recognized_text": text,
        }

    try:
        # Decode in memory: no temporary file that could leak if decoding fails.
        audio_seg = AudioSegment.from_file(io.BytesIO(audio_bytes), format="wav")
        seg = slice_audio_segment(audio_seg, time_range["start"], time_range["end"])

        # Save segment if requested
        if save_segments_dir:
            os.makedirs(save_segments_dir, exist_ok=True)
            seg_name = f"{base_name}_segment_{idx}.wav"
            seg.export(os.path.join(save_segments_dir, seg_name), format="wav")

        if len(seg) < 100:  # pydub lengths are in ms: too short to transcribe
            return _result("")

        # Re-encode the slice to WAV bytes in memory for the transcriber.
        buffer = io.BytesIO()
        seg.export(buffer, format="wav")
        hyp = transcribe_segment_worker(buffer.getvalue(), model_size, device, batch_mode)
        return _result(hyp)

    except Exception as e:
        print(f"Error processing segment {idx}: {e}")
        return _result("")


# ==========================================================
# Process single file worker (for file-level parallelization)
# ==========================================================
def process_file_worker(args: Tuple[str, Dict[str, Any]]) -> pd.DataFrame:
    """
    Segment one audio file by its activity vector and transcribe every segment.

    Args:
        args: ``(audio_path, params)``. ``params`` may contain (defaults in
            parentheses): ``sr`` (None), ``vector`` (frame-level 0/1 flags),
            ``frame_dur`` (0.025), ``hop_dur`` (0.01), ``pad`` (15),
            ``model_size`` ("medium"), ``device`` ("cpu"), ``batch_mode``
            (True), ``save_segments_dir`` (None), ``fallback_on_no_segments``
            (True), ``segment_level_parallel`` (False) and ``max_workers`` (4).
            Other keys are ignored.

    Returns:
        One row per segment (``segment_index``, ``start``, ``end``,
        ``recognized_text``), sorted by ``segment_index``. An empty DataFrame is
        returned when any of these happen (errors are printed, not raised):
        - the audio cannot be loaded;
        - no segment is active and fallback is disabled;
        - any other error occurs.
    """
    audio_path, params = args

    try:
        # Extract parameters
        sr = params.get('sr')
        vector = params.get('vector')
        frame_dur = params.get('frame_dur', 0.025)
        hop_dur = params.get('hop_dur', 0.01)
        pad = params.get('pad', 15)
        model_size = params.get('model_size', 'medium')
        device = params.get('device', 'cpu')
        batch_mode = params.get('batch_mode', True)
        save_segments_dir = params.get('save_segments_dir')
        fallback_on_no_segments = params.get('fallback_on_no_segments', True)
        segment_level_parallel = params.get('segment_level_parallel', False)
        max_workers = params.get('max_workers', 4)

        # Load audio
        try:
            audio_seg = load_audio_as_pydub(audio_path, sr)
        except Exception as e:
            print(f"Error loading audio {audio_path}: {e}")
            return pd.DataFrame()

        # Get time ranges
        time_ranges = vector_to_time_ranges(vector, frame_dur, hop_dur, pad)

        if not time_ranges:
            print(f"No active segments found for {audio_path}")
            if not fallback_on_no_segments:
                return pd.DataFrame()
            # Zero-length placeholder range so the file still yields one row.
            time_ranges = [{"start": 0, "end": 0}]

        base_name = os.path.splitext(os.path.basename(audio_path))[0]

        # Encode the audio once as WAV bytes (picklable for pool workers). An
        # in-memory buffer avoids reopening a NamedTemporaryFile by name, which
        # is not portable (it fails on Windows).
        buffer = io.BytesIO()
        audio_seg.export(buffer, format="wav")
        audio_bytes = buffer.getvalue()

        args_list = [
            (audio_bytes, time_range, idx, save_segments_dir, base_name, model_size, device, batch_mode)
            for idx, time_range in enumerate(time_ranges)
        ]

        # Pool workers are daemonic and may not start child processes. When this
        # function already runs inside a pool (file-level parallelism), a nested
        # Pool would raise and the whole file would be dropped, so process the
        # segments sequentially in that case.
        in_daemon = multiprocessing.current_process().daemon
        if segment_level_parallel and len(args_list) > 1 and not in_daemon:
            with Pool(processes=min(max_workers, len(args_list))) as pool:
                results = pool.map(process_segment_worker, args_list)
        else:
            results = [process_segment_worker(seg_args) for seg_args in args_list]

        results.sort(key=lambda x: x["segment_index"])
        return pd.DataFrame(results)

    except Exception as e:
        print(f"Error processing file {audio_path}: {e}")
        return pd.DataFrame()


# ==========================================================
# Multi-processing ASR Processor
# ==========================================================
class ASRProcessor:
    """
    Dispatcher that feeds files to ``process_file_worker``.

    It stores the Whisper settings (model size, device) and a worker budget,
    and runs files either one after another or through a
    ``multiprocessing.Pool``.
    """

    def __init__(self, model_size="medium", device="cpu", max_workers=None):
        self.model_size = model_size
        self.device = device
        # Any falsy value (None, 0) means "one worker per CPU core".
        self.max_workers = max_workers if max_workers else cpu_count()

    def process_file(
        self,
        audio_path: str,
        sr: Optional[int],
        vector: List[int],
        frame_dur: float,
        hop_dur: float,
        pad: int,
        show_progress: bool = True,
        fallback_on_no_segments: bool = True,
        save_segments_dir: Optional[str] = None,
        segment_level_parallel: bool = False
    ) -> pd.DataFrame:
        """
        Transcribe a single file with this processor's settings.

        Kept for compatibility: ``show_progress`` is accepted but unused, and
        batched inference is always enabled.
        """
        settings = dict(
            sr=sr,
            vector=vector,
            frame_dur=frame_dur,
            hop_dur=hop_dur,
            pad=pad,
            model_size=self.model_size,
            device=self.device,
            batch_mode=True,
            save_segments_dir=save_segments_dir,
            fallback_on_no_segments=fallback_on_no_segments,
            segment_level_parallel=segment_level_parallel,
            max_workers=self.max_workers,
        )
        return process_file_worker((audio_path, settings))

    def process_multiple_files(
        self,
        file_params_list: List[Tuple[str, Dict[str, Any]]],
        show_progress: bool = True,
        file_level_parallel: bool = True
    ) -> List[pd.DataFrame]:
        """
        Run ``process_file_worker`` over every ``(audio_path, params)`` pair.

        A process pool is used only when ``file_level_parallel`` is set and
        there is more than one file. Results keep the input order in both modes.
        """
        use_pool = file_level_parallel and len(file_params_list) > 1
        if not use_pool:
            items = tqdm(file_params_list, desc="Processing files") if show_progress else file_params_list
            return [process_file_worker(item) for item in items]

        pool_size = min(self.max_workers, len(file_params_list))
        with Pool(processes=pool_size) as pool:
            if not show_progress:
                return pool.map(process_file_worker, file_params_list)
            frames = []
            with tqdm(total=len(file_params_list), desc="Processing files") as bar:
                # imap yields in submission order, so the bar advances as files finish in order.
                for frame in pool.imap(process_file_worker, file_params_list):
                    frames.append(frame)
                    bar.update(1)
            return frames


# ==========================================================
# Safe parsing function
# ==========================================================
def safe_parse_predictions(preds_str: str) -> Optional[List[int]]:
    """
    Parse a frame-level prediction vector taken from a DataFrame cell.

    Args:
        preds_str: a list / numpy array (returned as a list), or its string
            form such as "[0, 1, 1]". A tuple parsed from a string is returned
            as a list.

    Returns:
        The parsed vector. Returns None for other input types, and also when
        parsing fails (the error is printed).
    """
    try:
        if isinstance(preds_str, (list, np.ndarray)):
            return list(preds_str)
        if not isinstance(preds_str, str):
            return None
        try:
            parsed = ast.literal_eval(preds_str)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # SECURITY(review): eval executes arbitrary code from the data file.
            # This fallback presumably exists for reprs literal_eval rejects
            # (e.g. "np.int64(1)"); only run it on trusted prediction files.
            parsed = eval(preds_str)
        if isinstance(parsed, tuple):
            parsed = list(parsed)
        return parsed
    except Exception as e:
        print(f"Error parsing predictions: {e}")
        return None


# ==========================================================
# Process DataFrame template with multiprocessing
# ==========================================================
def process_from_template_df(
    df: pd.DataFrame,
    model_size: str = "medium",
    device: str = "cpu",
    frame_dur: float = 0.025,
    hop_dur: float = 0.01,
    pad: int = 15,
    max_workers: int = None,
    audio_root_path: str = "/kaggle/input/ddss-aug-ratio-test/test",
    fallback_on_no_segments: bool = True,
    file_level_parallel: bool = True,
    segment_level_parallel: bool = False,
    save_segments_dir: Optional[str] = None
) -> pd.DataFrame:
    """
    Process DataFrame with multiprocessing support.

    Each row must provide ``id`` (audio file stem) and ``preds`` (a frame-level
    0/1 vector or its string form). ``id_f`` (a sub-folder of
    ``audio_root_path``) and ``sampling_rate`` are optional. Audio is read from
    ``<audio_root_path>/<id_f>/<id>.wav``. A row is skipped with a message if
    its file is missing or its predictions cannot be parsed.

    Args:
        file_level_parallel: If True, process files in parallel
        segment_level_parallel: If True, process segments within each file in parallel

    Returns:
        One row per transcribed segment, with ``segment_index``, ``start``,
        ``end`` and ``recognized_text``, plus every original column of the
        source row. List/tuple/set/array/dict values are stringified before
        being added. The DataFrame is empty when nothing was processed.
    """
    max_workers = max_workers or cpu_count()
    print(f"Initializing ASR Processor with model_size={model_size}, device={device}, max_workers={max_workers}")
    print(f"File-level parallel: {file_level_parallel}, Segment-level parallel: {segment_level_parallel}")

    processor = ASRProcessor(model_size=model_size, device=device, max_workers=max_workers)

    # Prepare file parameters list. The original row data is kept here in the
    # parent process, parallel to file_params_list, instead of being pickled
    # into every worker's params (the workers never read it).
    file_params_list = []
    row_data_list = []
    for idx, row in df.iterrows():
        try:
            audio_id = row["id"]
            folder_name = row.get("id_f", "")
            audio_path = os.path.join(audio_root_path, folder_name, f"{audio_id}.wav")

            if not os.path.exists(audio_path):
                print(f"Warning: Audio file not found: {audio_path}")
                continue

            preds = safe_parse_predictions(row["preds"])
            if preds is None:
                print(f"Error parsing predictions for {audio_id}")
                continue

            sr = row.get("sampling_rate", None)
            if pd.api.types.is_scalar(sr) and pd.isna(sr):
                # A missing value (NaN) would otherwise reach set_frame_rate and fail.
                sr = None
            elif isinstance(sr, (float, np.floating)) and float(sr).is_integer():
                # The column becomes float64 when it contains NaNs; resampling needs an int rate.
                sr = int(sr)

            params = {
                'sr': sr,
                'vector': preds,
                'frame_dur': frame_dur,
                'hop_dur': hop_dur,
                'pad': pad,
                'model_size': model_size,
                'device': device,
                'batch_mode': True,
                'save_segments_dir': save_segments_dir,
                'fallback_on_no_segments': fallback_on_no_segments,
                'segment_level_parallel': segment_level_parallel,
                'max_workers': max_workers,
            }

            file_params_list.append((audio_path, params))
            row_data_list.append(row.to_dict())

        except Exception as e:
            print(f"Error preparing row {idx}: {e}")
            continue

    if not file_params_list:
        print("Warning: No valid files to process")
        return pd.DataFrame()

    print(f"Processing {len(file_params_list)} files...")

    # Process files (results come back in input order)
    results_list = processor.process_multiple_files(
        file_params_list,
        show_progress=True,
        file_level_parallel=file_level_parallel
    )

    # Combine results and add original row data
    all_results = []
    for df_segments, row_data in zip(results_list, row_data_list):
        if df_segments.empty:
            continue

        for col, value in row_data.items():
            if col in df_segments.columns:
                continue
            # Container values would be broadcast element-wise (or raise a length
            # mismatch), so store them as their string form instead.
            if isinstance(value, (list, tuple, set, np.ndarray, dict)):
                value = str(value)
            df_segments[col] = value

        all_results.append(df_segments)

    if not all_results:
        print("Warning: No results to concatenate")
        return pd.DataFrame()

    final_df = pd.concat(all_results, ignore_index=True)
    print(f"Final dataset shape: {final_df.shape}")
    return final_df


# ==========================================================
# Save results
# ==========================================================
def save_results(df: pd.DataFrame, output_path: str):
    """
    Write ``df`` to ``output_path`` as CSV without the index.

    An empty DataFrame is not written; a warning is printed instead.
    """
    if not df.empty:
        df.to_csv(output_path, index=False)
        print(f"Results saved to: {output_path}")
        return
    print("Warning: DataFrame is empty, nothing to save")



In [None]:
import json
import os

import pandas as pd

# Read the model outputs from the directory returned by kagglehub in the
# download cell. On Kaggle that is /kaggle/input/<slug>; elsewhere (e.g. Colab)
# it is a local cache directory, so a hard-coded /kaggle/input path would not
# exist there. Fall back to the Kaggle mount point if the download cell was not run.
outputs_dir = globals().get(
    "tridungdo_wav2vec2_asr_100_150_true_outputs_path",
    "/kaggle/input/wav2vec2-asr-100-150-true-outputs",
)
with open(os.path.join(outputs_dir, "wav2vec2-asr-100-150-true.json"), encoding="utf-8") as f:
    data = json.load(f)

df = pd.DataFrame(data)

In [None]:
# Decision threshold used to binarise the per-frame values in df['outputs']
# (value >= threshold -> 1). NOTE(review): presumably these are per-frame
# probabilities from the upstream model; confirm before changing the value.
threshold = 0.5

In [None]:
# Binarise each row's per-frame values: 1 where value >= threshold, else 0.
df['preds'] = df['outputs'].map(
    lambda frame_scores: [1 if score >= threshold else 0 for score in frame_scores]
)

In [None]:
# Folder name = id with its last "_<suffix>" removed; ids without "_" are kept whole.
df['id_f'] = df['id'].str.rsplit('_', n=1).str.get(0)

In [None]:
# ==========================================================
# Main execution
# ==========================================================
if __name__ == "__main__":
    try:
        print("Starting ASR evaluation...")

        # Use the kagglehub download directory when available (valid on Kaggle
        # and elsewhere, e.g. Colab); otherwise fall back to the Kaggle mount point.
        test_root = globals().get("tridungdo_test_100_150_path", "/kaggle/input/test-100-150")
        # Directory the per-segment WAV files are written to; archived below.
        segments_dir = "segment"

        results_df = process_from_template_df(
            df=df,  # built by the earlier loading / thresholding cells
            model_size="base",
            frame_dur=0.025,
            hop_dur=0.01,
            pad=15,
            device="cuda",
            max_workers=4,
            audio_root_path=os.path.join(test_root, "test"),
            fallback_on_no_segments=True,
            file_level_parallel=True,
            segment_level_parallel=False,
            save_segments_dir=segments_dir,
        )
        if not results_df.empty:
            save_results(results_df, "asr_res_100_150.csv")
            results_df.to_json('asr_res_100_150.json', orient='records', force_ascii=False)
        else:
            print("No results to save.")

        # Archive the directory the segments were actually saved to.
        if os.path.exists(segments_dir):
            shutil.make_archive("cut_segments_archive", "zip", segments_dir)
            print("Segments saved and zipped to cut_segments_archive.zip")

    except NameError as e:
        print(f"Error: {e}")
        print("Make sure the 'df' DataFrame is defined before running this script.")
    except Exception as e:
        print(f"Unexpected error: {e}")
        import traceback
        traceback.print_exc()

In [None]:
import os
import shutil

# Zip the directory the per-segment WAV files are written to
# (the main run passes save_segments_dir="segment"; no 'save' directory is created).
save_dir = 'segment'
if os.path.exists(save_dir):
    shutil.make_archive("cut_segments_archive", "zip", save_dir)
    print("Segments saved and zipped to cut_segments_archive.zip")