In [1]:
from __future__ import annotations

import sys
from pathlib import Path

# Make the project's ``src/`` directory importable (it provides ``gigacan``) when
# the notebook is started from the project root. Prepending gives it priority
# over any other copy found later on ``sys.path``.
PROJECT_ROOT = Path.cwd()
SRC_DIR = PROJECT_ROOT.joinpath("src")
if not any(entry == str(SRC_DIR) for entry in sys.path):
    sys.path.insert(0, str(SRC_DIR))

In [2]:
import contextlib
import csv
import io
import subprocess
import tempfile
import time
import traceback
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple

import ipywidgets as widgets
from IPython.display import display
from yt_dlp import YoutubeDL
from yt_dlp.utils import DownloadError

from gigacan import segmenter, transcriber
from gigacan.corrector import RuleBasedCorrector

In [3]:
# Default system-prompt template (written in Cantonese). ``build_system_prompt``
# fills ``{title}`` and ``{description}`` per CSV row via ``str.format``, so any
# other literal brace in this text would have to be written as ``{{`` / ``}}``.
# English gist: "The following is an excerpt from a Hong Kong Legislative Council
# meeting recording titled {title}; the meeting description follows. Transcribe it
# into standard written Cantonese (not Mandarin), distinguishing 咁/噉 and 係/系/喺;
# use 畀 everywhere except fixed transliterations that use 俾; use the particle 呢,
# writing 咧 only for le4."
PROMPT_TEMPLATE = """以下係香港立法會會議錄音嘅一段節選，佢嘅標題係{title}，會議介紹如下：
---
{description}
---

請將呢段錄音轉寫成標準粵文，唔好寫普通話，要求區分「咁噉」「係系喺」，除非固定譯名用「俾」之外規定都用「畀」。語氣詞要用「呢」，除非係 le4 先至寫「咧」。
"""


def escape_braces(value: str) -> str:
    """Double every ``{`` and ``}`` so the text is treated literally by ``str.format``.

    ``None`` or an empty string yields ``""``.
    """
    if not value:
        return ""
    return value.translate({ord("{"): "{{", ord("}"): "}}"})


def build_system_prompt(template: str, title: str, description: str) -> str:
    """Fill the ``{title}`` and ``{description}`` placeholders of a prompt template.

    Only ``template`` is parsed as a format string. ``str.format`` inserts argument
    values verbatim, so braces inside the title or description must not be escaped.
    Escaping them would leak doubled ``{{``/``}}`` into the final prompt.
    NOTE(review): escaping would only be needed if the result were formatted again
    downstream; nothing visible here does that.

    Args:
        template: Format string using ``{title}`` and ``{description}``; literal
            braces must be written as ``{{``/``}}``.
        title: Row title; ``None``/empty becomes ``""``.
        description: Row description; ``None``/empty becomes ``""``.

    Returns:
        The filled-in prompt.

    Raises:
        KeyError / IndexError / ValueError: if ``template`` is not a valid format
        string for these two fields (e.g. an unknown ``{name}`` or a stray brace).
    """
    return template.format(
        title=title or "",
        description=description or "",
    )


def load_playlist_rows(csv_path: Path) -> Tuple[List[Dict[str, str]], List[str]]:
    """Read the playlist CSV and normalise its ``completed`` column.

    The file is decoded as UTF-8, with an optional BOM removed.

    * If the header lacks ``completed``, the column is appended and every row
      starts as ``"FALSE"``.
    * Otherwise each value becomes ``"TRUE"`` when it is TRUE/1/YES
      (case-insensitive, surrounding whitespace ignored) and ``"FALSE"`` otherwise.

    Returns:
        ``(rows, fieldnames)``: the rows as plain dicts and the header list.

    Raises:
        FileNotFoundError: if ``csv_path`` does not exist.
    """
    if not csv_path.exists():
        raise FileNotFoundError(f"CSV file not found: {csv_path}")
    with csv_path.open("r", encoding="utf-8-sig", newline="") as handle:
        reader = csv.DictReader(handle)
        header = list(reader.fieldnames or [])
        records = list(map(dict, reader))

    truthy_markers = {"TRUE", "1", "YES"}
    had_column = "completed" in header
    if not had_column:
        header.append("completed")
    for record in records:
        # Without the column there is nothing to parse, so every row is pending.
        marker = str(record.get("completed", "")).strip().upper() if had_column else ""
        record["completed"] = "TRUE" if marker in truthy_markers else "FALSE"
    return records, header


def write_playlist_rows(csv_path: Path, fieldnames: List[str], rows: List[Dict[str, str]]) -> None:
    """Rewrite the playlist CSV with ``fieldnames`` as header and ``rows`` as data.

    The batch runner calls this after every processed row. An interrupted in-place
    rewrite (kernel interrupt, crash, full disk) would truncate the user's playlist.
    The data is therefore written to a sibling ``<name>.tmp`` file first and then
    moved over ``csv_path`` with ``Path.replace``, which is an atomic rename on
    POSIX filesystems. On failure the temporary file is removed, the original CSV is
    left untouched, and the exception is re-raised.

    Args:
        csv_path: Destination CSV; its parent directory is created if missing.
        fieldnames: Header row / column order.
        rows: Row dicts keyed by ``fieldnames``.
    """
    ensure_directory(csv_path.parent)
    tmp_path = csv_path.with_name(f"{csv_path.name}.tmp")
    try:
        with tmp_path.open("w", encoding="utf-8", newline="") as handle:
            writer = csv.DictWriter(handle, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(rows)
        tmp_path.replace(csv_path)
    except BaseException:
        # Also covers KeyboardInterrupt: clean up the partial file, then re-raise.
        with contextlib.suppress(OSError):
            tmp_path.unlink()
        raise


def ensure_directory(path: Path) -> None:
    """Create ``path`` together with any missing parents.

    Does nothing if anything already exists at ``path``.
    """
    if path.exists():
        return
    path.mkdir(parents=True, exist_ok=True)


def _fmt_seconds(seconds: Optional[float]) -> str:
    if seconds is None:
        return "?"
    seconds = max(0, int(seconds))
    m, s = divmod(seconds, 60)
    h, m = divmod(m, 60)
    if h:
        return f"{h:d}:{m:02d}:{s:02d}"
    return f"{m:d}:{s:02d}"


def _fmt_rate(speed: Optional[float]) -> str:
    if not speed:
        return "?"
    units = ["B", "KB", "MB", "GB"]
    rate = float(speed)
    idx = 0
    while rate >= 1024 and idx < len(units) - 1:
        rate /= 1024
        idx += 1
    return f"{rate:.1f}{units[idx]}/s"


class StreamLogger(io.TextIOBase):
    """Write-only text stream that forwards each completed line to a callback.

    Meant as the target of ``contextlib.redirect_stdout`` / ``redirect_stderr``.
    Text is buffered until a newline arrives. Each line is then stripped and
    forwarded if non-empty; ``flush()`` forwards a pending partial line.

    Carriage returns (used by progress bars to redraw a line in place) are handled
    the way a terminal would display them: only the last non-blank carriage-return
    separated segment of a line is forwarded. This means redraws are not glued into
    one long line, and a Windows ``CRLF`` line ending is handled correctly.

    With ``callback=None`` all output is silently discarded.
    """

    def __init__(self, callback: Optional[Callable[[str], None]]):
        super().__init__()
        self._callback = callback
        self._buffer: str = ""

    def writable(self) -> bool:
        # ``io.TextIOBase`` reports False by default; this stream does accept writes.
        return True

    def _emit(self, line: str) -> None:
        """Forward the visible (last non-blank ``\\r``-separated) part of ``line``."""
        for segment in reversed(line.split("\r")):
            stripped = segment.strip()
            if stripped:
                self._callback(stripped)
                return

    def write(self, data: str) -> int:
        if not data:
            return 0
        if self._callback is None:
            return len(data)
        self._buffer += data
        while "\n" in self._buffer:
            line, self._buffer = self._buffer.split("\n", 1)
            self._emit(line)
        return len(data)

    def flush(self) -> None:
        if self._callback is None:
            return
        pending, self._buffer = self._buffer, ""
        self._emit(pending)


def download_audio(
    url: str,
    *,
    output_dir: Path,
    audio_format: str,
    sample_rate: Optional[int],
    sleep_interval: Optional[float],
    max_sleep_interval: Optional[float],
    skip_existing: bool,
    log_callback: Optional[Callable[[str], None]] = None,
) -> Tuple[Path, dict]:
    """Download the audio track of ``url`` with yt-dlp and convert it to ``audio_format``.

    The file is written to ``output_dir`` (created if missing) as
    ``<video id>.<audio_format>`` by yt-dlp's ``FFmpegExtractAudio`` post-processor.

    Args:
        url: Media URL to download (``noplaylist`` is set, so playlists are not expanded).
        output_dir: Directory for the downloaded audio.
        audio_format: Codec / extension requested from ``FFmpegExtractAudio``.
        sample_rate: If > 0, forwarded to the post-processor as ``-ar <rate>``.
        sleep_interval: If > 0, passed to yt-dlp as ``sleep_interval``.
        max_sleep_interval: If > 0, passed as ``max_sleep_interval`` unless it is
            smaller than a positive ``sleep_interval``.
        skip_existing: Reuse an existing ``<id>.<audio_format>`` instead of
            downloading; when False, yt-dlp is allowed to overwrite files.
        log_callback: Optional sink for human-readable progress lines.

    Returns:
        ``(path, info)``: the audio file path and yt-dlp's info dict. When an existing
        file is reused, ``info`` comes from the metadata-only extraction.

    Raises:
        FileNotFoundError: if no ``<id>.*`` file exists after the download.
        Errors raised by yt-dlp propagate unchanged (the caller in this notebook
        catches ``yt_dlp.utils.DownloadError`` separately).
    """
    ensure_directory(output_dir)
    if log_callback:
        log_callback(f"[download] Resolving {url}")
    # yt-dlp's own console output is silenced; progress is reported via ``hook`` below.
    ydl_opts = {
        "quiet": True,
        "no_warnings": True,
        "noprogress": True,
        "outtmpl": str(output_dir / "%(id)s.%(ext)s"),
        "format": "bestaudio/best",
        "noplaylist": True,
        "postprocessors": [{"key": "FFmpegExtractAudio", "preferredcodec": audio_format}],
        "overwrites": not skip_existing,
    }
    if sample_rate and sample_rate > 0:
        # Extra arguments for the post-processor's ffmpeg call (``-ar`` = output sample rate).
        ydl_opts["postprocessor_args"] = ["-ar", str(int(sample_rate))]
    if sleep_interval and sleep_interval > 0:
        ydl_opts["sleep_interval"] = float(sleep_interval)
    if max_sleep_interval and max_sleep_interval > 0:
        # Only pass an upper bound that is not below the configured lower bound.
        if not sleep_interval or max_sleep_interval >= sleep_interval:
            ydl_opts["max_sleep_interval"] = float(max_sleep_interval)

    # One-element dicts act as mutable cells that the nested ``hook`` can read and
    # update without ``nonlocal``.
    last_hook_emit = {"time": 0.0}
    video_label = {"value": url}

    def hook(status_dict: dict) -> None:
        # yt-dlp progress hook: turn status updates into log lines.
        if log_callback is None:
            return
        status = status_dict.get("status")
        if status == "downloading":
            now = time.monotonic()
            # Throttle progress lines to at most one every 0.5 s.
            if now - last_hook_emit["time"] < 0.5:
                return
            last_hook_emit["time"] = now
            percent = status_dict.get("_percent_str", "").strip()
            eta = _fmt_seconds(status_dict.get("eta"))
            rate = _fmt_rate(status_dict.get("speed"))
            log_callback(
                f"[download] {video_label['value']} {percent or '?'} at {rate} (ETA {eta})"
            )
        elif status == "finished":
            log_callback("[download] Download finished, running post-processing")

    ydl_opts["progress_hooks"] = [hook]

    with YoutubeDL(ydl_opts) as ydl:
        # Metadata-only pass: learn the video id to decide whether a download is needed.
        info = ydl.extract_info(url, download=False)
        video_id = info.get("id")
        video_label["value"] = video_id or url
        target_path = output_dir / f"{video_label['value']}.{audio_format}"
        if skip_existing and target_path.exists():
            if log_callback:
                log_callback(f"[download] Reusing existing audio at {target_path}")
            return target_path, info
        if log_callback:
            log_callback(f"[download] Fetching audio for {video_label['value']}")
        # Actual download + post-processing.
        # NOTE(review): this resolves ``url`` a second time; reusing ``info`` might
        # avoid the extra round trip — confirm against yt-dlp's API before changing.
        info = ydl.extract_info(url, download=True)
        if not target_path.exists():
            # Presumably the output ended up with a different extension; fall back to
            # the first ``<id>.*`` match. NOTE(review): this glob could also match
            # unrelated files sharing the id prefix (e.g. partial downloads) — confirm.
            matches = sorted(output_dir.glob(f"{video_label['value']}.*"))
            if matches:
                target_path = matches[0]
            else:
                raise FileNotFoundError(f"Expected audio file {target_path} was not created.")
        if log_callback:
            log_callback(f"[download] Saved audio to {target_path}")
        return target_path, info


def transcribe_media(
    media_path: Path,
    vtt_path: Path,
    *,
    language: str,
    enable_itn: bool,
    concurrency: int,
    max_rpm: int,
    max_retries: int,
    backoff_base: float,
    max_seg_seconds: float,
    vad_merge_ms: int,
    min_speech_ms: int,
    system_prompt: str,
    log_callback: Optional[Callable[[str], None]] = None,
) -> None:
    """Transcribe an audio/video file into a WebVTT subtitle file at ``vtt_path``.

    Pipeline:
    1. Extract a 16 kHz mono WAV into a temporary directory.
    2. Split it into speech segments with Silero VAD, falling back to fixed windows
       when VAD yields no segments.
    3. Transcribe the segments with ``transcriber.transcribe_segments``.
    4. Post-correct the entries with ``RuleBasedCorrector`` and write the VTT.

    Anything the transcriber prints to stdout/stderr is forwarded line by line to
    ``log_callback``.

    Args:
        media_path: Source media file.
        vtt_path: Output ``.vtt`` path; its parent directory is created if missing.
        language: Transcription language; an empty value falls back to ``"zh"``.
        enable_itn: Forwarded to the transcriber.
        concurrency: Worker count (clamped to >= 1).
        max_rpm: Request-rate cap (clamped to >= 1).
        max_retries: Retry count (clamped to >= 0).
        backoff_base: Retry backoff base, forwarded as a float.
        max_seg_seconds: Maximum segment length for VAD. The fixed-window fallback
            uses ``max(max_seg_seconds, 30)`` seconds, or 30 s when it is <= 0.
        vad_merge_ms: Gap below which VAD segments are merged (clamped to >= 0).
        min_speech_ms: Minimum speech duration for VAD (clamped to >= 0).
        system_prompt: Prompt passed to the transcriber.
        log_callback: Optional sink for progress lines.

    Raises:
        RuntimeError: if FFmpeg fails to extract audio or the duration cannot be
            determined. Errors from the transcriber propagate unchanged.
    """
    ensure_directory(vtt_path.parent)
    if log_callback:
        log_callback(f"[transcribe] Preparing audio from {media_path}")
    with tempfile.TemporaryDirectory() as tmpdir:
        wav_path = Path(tmpdir) / "audio_16k_mono.wav"
        try:
            segmenter.extract_mono_wav(str(media_path), str(wav_path), sr=16000)
        except subprocess.CalledProcessError as exc:
            raise RuntimeError(f"FFmpeg failed to extract audio: {exc}") from exc

        total_duration = segmenter.ffprobe_duration_seconds(str(wav_path))
        if total_duration <= 0:
            raise RuntimeError("Could not determine media duration.")
        if log_callback:
            log_callback(f"[transcribe] Source duration {total_duration:.1f}s")

        merge_ms = max(0, int(vad_merge_ms))
        min_speech = max(0, int(min_speech_ms))
        max_seg = float(max_seg_seconds)
        segs = segmenter.try_silero_vad_segments(
            str(wav_path),
            max_seg_s=max_seg,
            merge_gap_ms=merge_ms,
            min_speech_ms=min_speech,
        )
        if segs:
            if log_callback:
                log_callback(f"[transcribe] VAD produced {len(segs)} segments")
        else:
            fallback_window = max(max_seg, 30.0) if max_seg > 0 else 30.0
            segs = segmenter.fixed_window_segments(total_duration, fallback_window)
            if log_callback:
                log_callback(
                    f"[transcribe] VAD failed; falling back to fixed windows ({len(segs)} segments)"
                )

        prepared = segmenter.prepare_segments(str(wav_path), segs, tmpdir)
        if log_callback:
            log_callback(
                f"[transcribe] Prepared {len(prepared)} segment files (concurrency={concurrency})"
            )

        logger = StreamLogger(log_callback)
        try:
            with contextlib.redirect_stdout(logger), contextlib.redirect_stderr(logger):
                entries = transcriber.transcribe_segments(
                    prepared,
                    language=language or "zh",
                    enable_itn=enable_itn,
                    concurrency=max(1, int(concurrency)),
                    max_rpm=max(1, int(max_rpm)),
                    max_retries=max(0, int(max_retries)),
                    backoff_base=float(backoff_base),
                    system_prompt=system_prompt,
                )
        finally:
            # Forward any buffered partial line even when transcription raises, so
            # the tail of the transcriber's output is not lost.
            logger.flush()

        entries = RuleBasedCorrector().correct_entries(entries)
        if log_callback:
            log_callback(f"[transcribe] Writing VTT to {vtt_path}")
        transcriber.write_webvtt(str(vtt_path), entries)
        if log_callback:
            log_callback("[transcribe] Completed transcription")

In [4]:
# Controls read by ``run_batch`` when "Process CSV" is clicked.
# Input playlist CSV and how many pending rows to handle (0 = all of them).
csv_path_widget = widgets.Text(
    value="legco.csv",
    description="CSV Path",
    layout=widgets.Layout(width="60%"),
)
limit_widget = widgets.IntText(
    value=0,
    description="Limit",
    tooltip="Process only the first N rows (0 means all).",
)
# Output directories for downloaded audio and generated VTT files (created if missing).
download_dir_widget = widgets.Text(
    value="download",
    description="Audio Dir",
    layout=widgets.Layout(width="50%"),
)
vtt_dir_widget = widgets.Text(
    value="vtt",
    description="VTT Dir",
    layout=widgets.Layout(width="50%"),
)
# yt-dlp options: target audio codec and resample rate (<= 0 keeps the source rate).
audio_format_widget = widgets.Dropdown(
    options=["opus", "mp3", "wav", "m4a"],
    value="opus",
    description="Audio fmt",
)
sample_rate_widget = widgets.IntText(
    value=16000,
    description="Sample Hz",
)
# Optional throttling between yt-dlp requests; 0 leaves the option unset.
sleep_interval_widget = widgets.FloatText(
    value=0.0,
    description="Sleep s",
    tooltip="Seconds to sleep between yt-dlp requests.",
)
max_sleep_interval_widget = widgets.FloatText(
    value=0.0,
    description="Sleep max",
    tooltip="Optional random upper bound for sleep interval.",
)
# Reuse / skip behaviour for already-produced artefacts and completed rows.
skip_existing_audio_widget = widgets.Checkbox(
    value=True,
    description="Reuse audio",
    tooltip="Skip download when audio already exists.",
)
reuse_vtt_widget = widgets.Checkbox(
    value=True,
    description="Reuse VTT",
    tooltip="Skip transcription when VTT already exists.",
)
process_completed_widget = widgets.Checkbox(
    value=False,
    description="Reprocess completed",
    tooltip="Process rows already marked completed.",
)
# Transcription options (an empty language falls back to "zh").
language_widget = widgets.Text(
    value="zh",
    description="Language",
)
no_itn_widget = widgets.Checkbox(
    value=False,
    description="Disable ITN",
)
# Segmentation parameters forwarded to the VAD segmenter.
max_seg_widget = widgets.FloatText(
    value=0,
    description="Max seg s",
    tooltip="Maximum segment length (0 disables extra splitting).",
)
vad_merge_widget = widgets.IntText(
    value=100,
    description="VAD merge ms",
)
min_speech_widget = widgets.IntText(
    value=200,
    description="Min speech ms",
)
# Parallelism, rate limit and retry policy forwarded to transcriber.transcribe_segments.
concurrency_widget = widgets.IntSlider(
    value=4,
    min=1,
    max=8,
    step=1,
    description="Workers",
    readout=True,
)
max_rpm_widget = widgets.IntText(
    value=60,
    description="Max RPM",
)
retries_widget = widgets.IntText(
    value=3,
    description="Retries",
)
backoff_base_widget = widgets.FloatText(
    value=0.8,
    description="Backoff s",
)
# Editable system-prompt template; {title} and {description} are filled per CSV row.
prompt_label = widgets.HTML("<b>System prompt template</b> (placeholders: {title}, {description})")
system_prompt_widget = widgets.Textarea(
    value=PROMPT_TEMPLATE,
    layout=widgets.Layout(width="100%", height="200px"),
)
# Run button plus the progress / status / log widgets that run_batch updates.
run_button = widgets.Button(
    description="Process CSV",
    button_style="primary",
    icon="play",
)
progress_bar = widgets.IntProgress(
    value=0,
    min=0,
    max=1,
    description="Idle",
    bar_style="",
    layout=widgets.Layout(width="100%"),
)
status_label = widgets.HTML("<em>Waiting for input...</em>")
log_output = widgets.Output(
    layout=widgets.Layout(border="1px solid #ddd", padding="0.5rem", max_height="320px", overflow="auto"),
)

In [5]:
def _log_line(message: str) -> None:
    """Append ``message`` as one line to the log widget.

    Falls back to the process's original stdout when the widget has no
    ``append_stdout``. Empty messages are ignored.
    """
    if not message:
        return
    if hasattr(log_output, "append_stdout"):
        log_output.append_stdout(f"{message}\n")
    else:
        print(message, file=sys.__stdout__)


def _error_html(message: object) -> str:
    """Return ``message`` wrapped in the red status markup, HTML-escaped.

    Exception texts, URLs (``&`` in query strings) and paths may contain markup
    characters. Without escaping, the HTML status widget would interpret them.
    """
    import html  # only this helper needs it; keeps the import cell unchanged

    return f"<span style='color:#d33'>{html.escape(str(message))}</span>"


def run_batch(_):
    """Click handler for the "Process CSV" button.

    Runs one batch (see ``_process_csv_batch``) with the button disabled, so it
    cannot be clicked again while the batch is running. Any unexpected exception is
    reported in the status label and the log widget instead of escaping the
    callback. An escaped exception would typically only surface in the
    browser/kernel log and leave the UI in a stale "in progress" state.

    Args:
        _: The clicked ``Button`` passed by ipywidgets (unused).
    """
    run_button.disabled = True
    try:
        _process_csv_batch()
    except Exception as exc:
        progress_bar.bar_style = "danger"
        status_label.value = _error_html(f"Unexpected error: {exc}")
        _log_line(traceback.format_exc())
    finally:
        run_button.disabled = False


def _process_csv_batch() -> None:
    """Process pending CSV rows: download audio, transcribe to VTT, mark completed.

    Settings are read from the widgets.
    * Rows already marked ``completed`` are skipped unless "Reprocess completed"
      is ticked.
    * "Limit" caps the number of candidate rows (0 = all).
    * The CSV is rewritten after every row, so progress survives an interrupted run.
    * A failing row is marked ``FALSE``, counted, and the batch moves on.
    """
    progress_bar.bar_style = ""
    progress_bar.value = 0
    progress_bar.max = 1
    progress_bar.description = "Idle"
    log_output.clear_output()
    status_label.value = "<em>Preparing...</em>"

    csv_path = Path(csv_path_widget.value.strip() or "legco.csv").expanduser()
    download_dir = Path(download_dir_widget.value.strip() or "download").expanduser()
    vtt_dir = Path(vtt_dir_widget.value.strip() or "vtt").expanduser()
    try:
        limit = int(limit_widget.value)
    except Exception:
        limit = 0
    limit = max(0, limit)
    audio_format = audio_format_widget.value
    try:
        sample_rate_value = int(sample_rate_widget.value)
    except Exception:
        sample_rate_value = 0
    sample_rate = sample_rate_value if sample_rate_value > 0 else None
    sleep_interval = float(sleep_interval_widget.value or 0.0)
    max_sleep_interval = float(max_sleep_interval_widget.value or 0.0)
    skip_audio = bool(skip_existing_audio_widget.value)
    reuse_vtt = bool(reuse_vtt_widget.value)
    reprocess_completed = bool(process_completed_widget.value)
    language = (language_widget.value or "zh").strip()
    enable_itn = not bool(no_itn_widget.value)
    max_seg_seconds = float(max_seg_widget.value)
    vad_merge_ms = int(vad_merge_widget.value)
    min_speech_ms = int(min_speech_widget.value)
    concurrency = int(concurrency_widget.value)
    max_rpm = int(max_rpm_widget.value)
    retries = int(retries_widget.value)
    backoff_base = float(backoff_base_widget.value)
    prompt_template = system_prompt_widget.value or PROMPT_TEMPLATE

    # Validate the user-editable template once, up front. A stray brace or an
    # unknown placeholder would otherwise fail on the first row that reaches
    # prompt building, after audio has already been downloaded.
    try:
        build_system_prompt(prompt_template, "title", "description")
    except Exception as exc:
        status_label.value = _error_html(f"Invalid system prompt template: {exc!r}")
        _log_line(traceback.format_exc())
        return

    try:
        rows, fieldnames = load_playlist_rows(csv_path)
    except Exception as exc:
        status_label.value = _error_html(f"Failed to read CSV: {exc}")
        _log_line(traceback.format_exc())
        return

    if not rows:
        status_label.value = _error_html("CSV contains no rows.")
        return

    candidates: List[Tuple[int, Dict[str, str]]] = []
    for idx, row in enumerate(rows):
        if not reprocess_completed and row.get("completed") == "TRUE":
            continue
        candidates.append((idx, row))
    if limit > 0:
        candidates = candidates[:limit]
    total = len(candidates)
    if total == 0:
        status_label.value = _error_html("Nothing to process.")
        return

    progress_bar.max = total
    progress_bar.value = 0
    progress_bar.description = f"0/{total}"

    try:
        segmenter.check_ffmpeg()
    except Exception as exc:
        status_label.value = _error_html(f"FFmpeg check failed: {exc}")
        _log_line(str(exc))
        return

    failures = 0
    for position, (row_index, row) in enumerate(candidates, start=1):
        progress_bar.description = f"{position}/{total}"
        progress_bar.value = position - 1
        title = row.get("title") or ""
        description = row.get("description") or ""
        url = (row.get("url") or "").strip()
        row_label = row.get("title") or url or f"Row {row_index + 1}"
        _log_line(f"[row {row_index + 1}] Started {row_label}")
        if not url:
            failures += 1
            status_label.value = _error_html("Missing URL in CSV row.")
            _log_line(f"[skip] Row {row_index + 1} has no URL.")
            row["completed"] = "FALSE"
            write_playlist_rows(csv_path, fieldnames, rows)
            progress_bar.value = position
            continue

        # NOTE(review): the VTT path depends on the video id resolved by
        # download_audio, so "Reuse VTT" is only checked after the download step.
        status_label.value = f"<em>Downloading audio ({position}/{total})</em>"
        try:
            audio_path, info = download_audio(
                url,
                output_dir=download_dir,
                audio_format=audio_format,
                sample_rate=sample_rate,
                sleep_interval=sleep_interval,
                max_sleep_interval=max_sleep_interval,
                skip_existing=skip_audio,
                log_callback=_log_line,
            )
        except DownloadError as exc:
            failures += 1
            status_label.value = _error_html(f"yt-dlp failed for {url}")
            _log_line(f"[error] yt-dlp failed for {url}: {exc}")
            row["completed"] = "FALSE"
            write_playlist_rows(csv_path, fieldnames, rows)
            progress_bar.value = position
            continue
        except Exception as exc:
            failures += 1
            status_label.value = _error_html(f"Download error: {exc}")
            _log_line(f"[error] Download error for {url}: {exc}")
            _log_line(traceback.format_exc())
            row["completed"] = "FALSE"
            write_playlist_rows(csv_path, fieldnames, rows)
            progress_bar.value = position
            continue

        video_id = info.get("id") or audio_path.stem
        vtt_path = vtt_dir / f"{video_id}.vtt"
        system_prompt = build_system_prompt(prompt_template, title, description)

        if reuse_vtt and vtt_path.exists():
            row["completed"] = "TRUE"
            status_label.value = f"<em>VTT already exists for {video_id}, skipping.</em>"
            _log_line(f"[skip] Reusing existing VTT at {vtt_path}")
            write_playlist_rows(csv_path, fieldnames, rows)
            progress_bar.value = position
            continue

        status_label.value = f"<em>Transcribing ({position}/{total})</em>"
        try:
            transcribe_media(
                audio_path,
                vtt_path,
                language=language,
                enable_itn=enable_itn,
                concurrency=concurrency,
                max_rpm=max_rpm,
                max_retries=retries,
                backoff_base=backoff_base,
                max_seg_seconds=max_seg_seconds,
                vad_merge_ms=vad_merge_ms,
                min_speech_ms=min_speech_ms,
                system_prompt=system_prompt,
                log_callback=_log_line,
            )
        except Exception as exc:
            failures += 1
            status_label.value = _error_html(f"Transcription failed: {exc}")
            _log_line(f"[error] Transcription failed for {video_id}: {exc}")
            _log_line(traceback.format_exc())
            row["completed"] = "FALSE"
        else:
            row["completed"] = "TRUE"
            status_label.value = f"<strong>Completed {video_id}</strong>"
            _log_line(f"[done] Saved VTT to {vtt_path.resolve()}")
        finally:
            # Persist after every row so an interrupted batch keeps its progress.
            write_playlist_rows(csv_path, fieldnames, rows)
            progress_bar.value = position

    if failures:
        progress_bar.bar_style = "warning"
        status_label.value = _error_html(f"Finished with {failures} failure(s).")
    else:
        progress_bar.bar_style = "success"
        status_label.value = "<strong>All rows completed.</strong>"

In [6]:
# Register the click handler exactly once. ``Button.on_click`` keeps a list of
# callbacks. Re-running this cell, or re-running the cell that redefines
# ``run_batch`` (which creates a new function object), would otherwise attach an
# additional handler, and a single click would process the whole CSV several times.
_previous_run_handler = globals().get("_registered_run_handler")
if _previous_run_handler is not None:
    run_button.on_click(_previous_run_handler, remove=True)
run_button.on_click(run_batch)
_registered_run_handler = run_batch

# Assemble and show the control panel: settings rows, prompt editor, run/progress,
# status and log output.
controls = widgets.VBox(
    [
        widgets.HTML(
            "<h3>Batch VTT Generator</h3><p>Download each URL from the CSV, transcribe it with the title and description as context, and mark rows as completed.</p>"
        ),
        widgets.HBox([csv_path_widget, limit_widget]),
        widgets.HBox([download_dir_widget, vtt_dir_widget]),
        widgets.HBox([audio_format_widget, sample_rate_widget]),
        widgets.HBox([sleep_interval_widget, max_sleep_interval_widget]),
        widgets.HBox([skip_existing_audio_widget, reuse_vtt_widget, process_completed_widget]),
        widgets.HBox([language_widget, no_itn_widget]),
        widgets.HBox([max_seg_widget, vad_merge_widget, min_speech_widget]),
        widgets.HBox([concurrency_widget, max_rpm_widget, retries_widget, backoff_base_widget]),
        prompt_label,
        system_prompt_widget,
        widgets.HBox([run_button, progress_bar]),
        status_label,
        log_output,
    ],
    layout=widgets.Layout(width="100%", gap="0.6rem"),
)

display(controls)

VBox(children=(HTML(value='<h3>Batch VTT Generator</h3><p>Download each URL from the CSV, transcribe it with t…