# Ambara — Colab Runner

Automated notebook for running **clip extraction**, **ASR training**, and
**re-drafting of pending clips** on Colab with Google Drive persistence.

**Usage:** Edit the configuration cell below, then **Runtime > Run All**.

- Repo is cloned to the VM for fast I/O.
- `data/` and `models/` are symlinked to Google Drive so they persist across sessions.

In [None]:
# ---- Google Drive ----
# Folder name under "MyDrive"; the setup cell resolves it to
# /content/drive/MyDrive/<DRIVE_ROOT> and keeps data/ and models/ there.
DRIVE_ROOT = "ambara"

# ---- Repository ----
# URL and branch handed to `git clone -b` by the setup cell.
REPO_URL = "https://github.com/ny-randriantsarafara/ny-feoko.git"
REPO_BRANCH = "main"

# ---- HuggingFace (leave empty to skip login) ----
# NOTE(review): avoid saving/committing the notebook with a real token filled in.
HF_TOKEN = ""

# ---- Clip Extraction (set EXTRACT_ENABLED = False to skip) ----
# Relative paths below are resolved from the repo root (the setup cell chdir's
# there), where data/ is a symlink into the Drive folder.
EXTRACT_ENABLED = True
EXTRACT_INPUT = "data/input/my-recording.wav"  # passed as --input
EXTRACT_WHISPER_MODEL = "small"  # passed as --whisper-model when EXTRACT_WHISPER_HF is empty
EXTRACT_WHISPER_HF = ""  # HuggingFace model ID — overrides EXTRACT_WHISPER_MODEL
EXTRACT_VAD_THRESHOLD = 0.35  # passed as --vad-threshold
EXTRACT_SPEECH_THRESHOLD = 0.35  # passed as --speech-threshold
EXTRACT_LABEL = ""  # passed as --label; leave empty to omit the flag

# ---- ASR Training (set TRAIN_ENABLED = False to skip) ----
TRAIN_ENABLED = True
TRAIN_DATASET = "data/training/my-dataset"  # passed as --data-dir
TRAIN_BASE_MODEL = "openai/whisper-small"  # passed as --base-model
# Passed as --output-dir; later cells look for the model in "<TRAIN_OUTPUT_DIR>/model".
TRAIN_OUTPUT_DIR = "models/whisper-mg-v1"
TRAIN_EPOCHS = 10
TRAIN_BATCH_SIZE = 4
TRAIN_LR = 1e-5
# When set, this repo ID is also what the re-draft cell uses as its --model.
TRAIN_PUSH_TO_HUB = ""  # HuggingFace repo ID — leave empty to skip

# ---- Re-draft (set REDRAFT_ENABLED = False to skip) ----
REDRAFT_ENABLED = True
REDRAFT_RUN_DIR = "data/output/my-run"  # local directory with clips/
REDRAFT_LABEL = ""  # run label in Supabase

## Environment Setup

Mounts Drive, clones (or updates) the repo, symlinks `data/`, `models/` and —
if present — `.env` to Drive, installs Python dependencies, and logs in to
HuggingFace when a token is configured.

In [None]:
import os
import shutil
import subprocess
import sys
from pathlib import Path

from google.colab import drive

# Fixed locations on the Colab VM. DRIVE_BASE is the persistent folder on
# Drive; REPO_DIR is the working copy on the VM's local disk.
DRIVE_MOUNT = Path("/content/drive")
DRIVE_BASE = DRIVE_MOUNT / "MyDrive" / DRIVE_ROOT
REPO_DIR = Path("/content/ny-feoko")

# ---- Mount Google Drive ----
# Skip the mount call when MyDrive is already visible (e.g. the cell is re-run).
if not (DRIVE_MOUNT / "MyDrive").exists():
    drive.mount(str(DRIVE_MOUNT))
print(f"Drive mounted at {DRIVE_MOUNT}")

# ---- Create Drive directories ----
for subdir in ["data/input", "data/output", "data/training", "models"]:
    (DRIVE_BASE / subdir).mkdir(parents=True, exist_ok=True)
print(f"Drive root: {DRIVE_BASE}")

# ---- Clone or update repo ----
if REPO_DIR.exists():
    # An existing checkout may sit on a different branch than the one now
    # configured, so switch to REPO_BRANCH explicitly before fast-forwarding.
    # The fetch comes first so that a branch which was not present locally
    # can be resolved by `git checkout`.
    git_in_repo = ["git", "-C", str(REPO_DIR)]
    subprocess.run([*git_in_repo, "fetch", "origin"], check=True)
    subprocess.run([*git_in_repo, "checkout", REPO_BRANCH], check=True)
    subprocess.run(
        [*git_in_repo, "pull", "--ff-only", "origin", REPO_BRANCH],
        check=True,
    )
    print(f"Repo updated: {REPO_DIR}")
else:
    subprocess.run(
        ["git", "clone", "-b", REPO_BRANCH, REPO_URL, str(REPO_DIR)],
        check=True,
    )
    print(f"Repo cloned: {REPO_DIR}")

# Later cells launch the CLIs with repo-relative paths, so work from the repo root.
os.chdir(REPO_DIR)

# ---- Symlink data/ and models/ to Drive ----
for name in ["data", "models"]:
    link = REPO_DIR / name
    target = DRIVE_BASE / name
    # Clear whatever currently occupies the path; symlink_to() raises
    # FileExistsError otherwise. The symlink test must come first because
    # is_dir() follows links and rmtree() refuses to operate on a symlink.
    if link.is_symlink():
        link.unlink()
    elif link.is_dir():
        shutil.rmtree(link)
    elif link.exists():
        link.unlink()
    link.symlink_to(target)
    print(f"  {name}/ -> {target}")

# ---- Symlink .env from Drive (if present) ----
env_drive = DRIVE_BASE / ".env"
env_local = REPO_DIR / ".env"
if env_drive.exists():
    # is_symlink() also catches a dangling link, for which exists() is False.
    if env_local.is_symlink() or env_local.exists():
        env_local.unlink()
    env_local.symlink_to(env_drive)
    print(f"  .env -> {env_drive}")

# ---- Python version check ----
# Warn only; the install step below still runs on an older interpreter.
v = sys.version_info
print(f"\nPython {v.major}.{v.minor}.{v.micro}")
if v < (3, 10):
    print("WARNING: This project requires Python >= 3.10.")

# ---- Install dependencies ----
subprocess.run(["make", "colab-install"], check=True, cwd=str(REPO_DIR))

# ---- HuggingFace login ----
if HF_TOKEN:
    # Imported here, after the install step above has run.
    from huggingface_hub import login
    login(token=HF_TOKEN)
    print("Logged in to HuggingFace Hub.")

# ---- Environment summary ----
import torch

gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A"
print(f"\n{'=' * 50}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA:    {torch.cuda.is_available()} ({gpu_name})")
print(f"Drive:   {DRIVE_BASE}")
print(f"Repo:    {REPO_DIR}")
print(f"{'=' * 50}")
print("Setup complete.")

## Clip Extraction

Runs the full pipeline: **VAD → classify → transcribe → write clips**.

Input file must exist at the configured path on Drive (e.g.
`My Drive/ambara/data/input/my-recording.wav`).

In [None]:
if EXTRACT_ENABLED:
    # Exactly one model flag is sent: a HuggingFace model ID, when set,
    # takes the place of the plain Whisper model name.
    whisper_args = (
        ["--whisper-hf", EXTRACT_WHISPER_HF]
        if EXTRACT_WHISPER_HF
        else ["--whisper-model", EXTRACT_WHISPER_MODEL]
    )
    # The label flag is omitted entirely when no label is configured.
    label_args = ["--label", EXTRACT_LABEL] if EXTRACT_LABEL else []

    extract_cmd = [
        "python", "-m", "clip_extraction.cli", "run",
        "--input", EXTRACT_INPUT,
        "--output", "data/output",
        "--device", "cuda",
        "--vad-threshold", str(EXTRACT_VAD_THRESHOLD),
        "--speech-threshold", str(EXTRACT_SPEECH_THRESHOLD),
        "--verbose",
        *whisper_args,
        *label_args,
    ]

    # Echo the command before launching it; check=True aborts the cell
    # if the CLI exits with a non-zero status.
    print("Running: " + " ".join(extract_cmd))
    subprocess.run(extract_cmd, check=True)
else:
    print("Clip extraction skipped (EXTRACT_ENABLED = False).")

## ASR Training

Fine-tunes Whisper on the configured training dataset. The dataset must
already exist on Drive (exported locally via `./ambara export-training`,
then uploaded to `My Drive/ambara/data/training/<dataset>/`).

In [None]:
if TRAIN_ENABLED:
    # Hub upload is opt-in: the flag is only added for a non-empty repo ID.
    hub_args = ["--push-to-hub", TRAIN_PUSH_TO_HUB] if TRAIN_PUSH_TO_HUB else []

    train_cmd = [
        "python", "-m", "asr_training.cli", "train",
        "--data-dir", TRAIN_DATASET,
        "--output-dir", TRAIN_OUTPUT_DIR,
        "--device", "cuda",
        "--base-model", TRAIN_BASE_MODEL,
        "--epochs", str(TRAIN_EPOCHS),
        "--batch-size", str(TRAIN_BATCH_SIZE),
        "--lr", str(TRAIN_LR),
        *hub_args,
    ]

    # Echo the command before launching it; check=True aborts the cell
    # if the CLI exits with a non-zero status.
    print("Running: " + " ".join(train_cmd))
    subprocess.run(train_cmd, check=True)
else:
    print("ASR training skipped (TRAIN_ENABLED = False).")

## Re-draft Pending Clips

Uses the trained model to re-transcribe pending clips in Supabase.
Only clips with `status = 'pending'` are updated — corrected clips are left untouched.

Requires `.env` with Supabase credentials on Drive (`My Drive/ambara/.env`).

In [None]:
if REDRAFT_ENABLED:
    # Model to re-draft with: the HuggingFace repo ID when one is configured,
    # otherwise the "model" folder inside the local training output directory.
    redraft_model = TRAIN_PUSH_TO_HUB or f"{TRAIN_OUTPUT_DIR}/model"
    # The label flag is omitted entirely when no label is configured.
    label_args = ["--label", REDRAFT_LABEL] if REDRAFT_LABEL else []

    redraft_cmd = [
        "python", "-m", "asr_training.cli", "re-draft",
        "--model", redraft_model,
        "-d", REDRAFT_RUN_DIR,
        "--device", "cuda",
        *label_args,
    ]

    # Echo the command before launching it; check=True aborts the cell
    # if the CLI exits with a non-zero status.
    print("Running: " + " ".join(redraft_cmd))
    subprocess.run(redraft_cmd, check=True)
else:
    print("Re-draft skipped (REDRAFT_ENABLED = False).")

## Results

Shows extraction and training outputs, Drive usage, and suggested next steps.

In [None]:
def _tree_size_mb(root: Path) -> float:
    """Return the combined size, in MiB, of every file found under *root*."""
    return sum(f.stat().st_size for f in root.rglob("*") if f.is_file()) / (1024 * 1024)


banner = "=" * 50
print(banner)
print("Results")
print(banner)

# ---- Extraction output ----
runs_root = REPO_DIR / "data" / "output"
if runs_root.exists():
    # "Latest" means the run directory that sorts last by path,
    # not the most recently modified one.
    newest_run = max((d for d in runs_root.iterdir() if d.is_dir()), default=None)
    if newest_run is not None:
        wav_dir = newest_run / "clips"
        n_clips = sum(1 for _ in wav_dir.glob("*.wav")) if wav_dir.exists() else 0
        print(f"\nLatest extraction run: {newest_run.name}")
        print(f"  Clips extracted: {n_clips}")

# ---- Training output ----
trained_model = REPO_DIR / TRAIN_OUTPUT_DIR / "model"
if trained_model.exists():
    model_mb = _tree_size_mb(trained_model)
    print(f"\nTrained model: {trained_model}")
    print(f"  Size: {model_mb:.0f} MB")
    print(f"  Persisted at: {DRIVE_BASE / TRAIN_OUTPUT_DIR / 'model'}")

if TRAIN_PUSH_TO_HUB:
    print(f"  HuggingFace: https://huggingface.co/{TRAIN_PUSH_TO_HUB}")

# ---- Drive usage ----
print(f"\nDrive usage ({DRIVE_BASE}):")
for rel in ("data/input", "data/output", "data/training", "models"):
    folder = DRIVE_BASE / rel
    if not folder.exists():
        continue
    print(f"  {rel}: {_tree_size_mb(folder):.0f} MB")

# ---- Next steps ----
# Suggest the Hub repo ID as the model when one is configured,
# otherwise the local training output.
next_model = TRAIN_PUSH_TO_HUB or f"{TRAIN_OUTPUT_DIR}/model"
print("\nNext steps:")
print(f"  ./ambara re-draft --model {next_model} \\")
print("      -d data/output/<run-dir> --label <run-label>")
if TRAIN_PUSH_TO_HUB:
    print("  ./ambara extract -i audio.wav -o data/output/ --device mps \\")
    print(f"      --whisper-hf {TRAIN_PUSH_TO_HUB}")