# Vietnamese ASR (ZipFormer-30M RNNT) — Colab

This notebook runs the Vietnamese ZipFormer-30M RNNT model from this repo and exposes it through a public HTTP API for testing. When opened directly from GitHub in Colab, it bootstraps itself by cloning the repository.


In [None]:
#@title 1) Bootstrap + Environment setup (ffmpeg, deps, cloudflared)
import os, sys, subprocess, textwrap, pathlib, shutil

# Upstream repository and the Colab location it is cloned into.
REPO_URL = 'https://github.com/nguyenlee97/auto-speech-regconition-vietnamese-deployment.git'
TARGET_DIR = '/content/auto-speech-regconition-vietnamese-deployment'


def sh(cmd: str, check: bool = True) -> subprocess.CompletedProcess:
    """Echo a shell command line, then execute it through the shell.

    Args:
        cmd: Command line handed verbatim to the shell.
        check: When true, a non-zero exit status raises
            ``subprocess.CalledProcessError``.

    Returns:
        The ``CompletedProcess`` produced by ``subprocess.run``.
    """
    print(cmd)  # echo first so the cell output shows what is running
    completed = subprocess.run(cmd, shell=True, check=check)
    return completed

# When the notebook is opened straight from GitHub, the working directory does
# not hold the repo files: clone the repo (only if not cloned yet) and move into it.
in_repo = os.path.exists('requirements.txt') and os.path.exists('model.py')
if not in_repo:
    if not os.path.exists(TARGET_DIR):
        sh(f'git clone {REPO_URL} {TARGET_DIR}')
    os.chdir(TARGET_DIR)
print('CWD:', os.getcwd())

# ffmpeg for audio format conversion
for apt_cmd in ('apt-get update -y',
                'apt-get install -y --no-install-recommends ffmpeg'):
    sh(apt_cmd)

# Python deps
import re, shlex  # used for the pip invocation and the model.py patch below

print('Python:', sys.version)
# Run pip through the kernel's own interpreter so packages land in *this*
# environment rather than wherever the first `pip` on PATH points.
PIP = f'{shlex.quote(sys.executable)} -m pip'
sh(f'{PIP} install --upgrade pip wheel setuptools')
if sys.version_info < (3, 11) and os.path.exists('requirements.txt'):
    # Python <= 3.10: use the pinned wheels from requirements.txt.
    sh(f'{PIP} install -r requirements.txt')
else:
    # Python 3.11+ (or no requirements.txt): CPU torch plus the ONNX path only;
    # k2 / sherpa are not installed on this branch.
    sh(f'{PIP} install --index-url https://download.pytorch.org/whl/cpu torch==2.5.1+cpu torchaudio==2.5.1+cpu')
    # Quote the extras spec so the shell never treats `[standard]` as a glob.
    sh(f"{PIP} install sherpa-onnx fastapi 'uvicorn[standard]' python-multipart requests sentencepiece numpy==1.26.4")

    # Patch model.py so it still imports when k2 / sherpa are unavailable.
    mp = pathlib.Path('model.py')
    if mp.exists():
        src = mp.read_text(encoding='utf-8')
        patched = src

        future_line = 'from __future__ import annotations'
        if future_line not in (ln.strip() for ln in patched.splitlines()):
            patched = future_line + '\n' + patched

        k2_shim = textwrap.dedent('''\
            try:
                import k2  # noqa
            except Exception:
                k2 = None''')
        sherpa_shim = textwrap.dedent('''\
            try:
                import sherpa as _sherpa
                sherpa = _sherpa
            except Exception:
                class _DummyBase: ...
                class _SherpaModule:
                    OfflineRecognizer = _DummyBase
                    OnlineRecognizer = _DummyBase
                sherpa = _SherpaModule()''')

        # Replace whole top-level import lines only. A plain substring replace of
        # 'import sherpa' would also hit 'import sherpa_onnx' and leave a dangling
        # '_onnx' statement behind. The shim's own imports are indented, so they
        # never match again and re-running this cell leaves the file unchanged.
        # A function replacement keeps re.sub from interpreting backslashes.
        patched = re.sub(r'^import k2[ \t]*(?:#.*)?$', lambda _m: k2_shim, patched, flags=re.M)
        patched = re.sub(r'^import sherpa[ \t]*(?:#.*)?$', lambda _m: sherpa_shim, patched, flags=re.M)

        if patched != src:
            mp.write_text(patched, encoding='utf-8')
            print('Patched model.py: k2/sherpa imports made optional')

# Cloudflared (public URL tunnel)
if not shutil.which('cloudflared'):
    # -f makes curl fail on HTTP errors instead of saving the error page as the
    # "binary"; sh() then raises before chmod, so a later run retries the download
    # instead of finding a bogus executable via shutil.which.
    sh('curl -fsSL -o /usr/local/bin/cloudflared https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64')
    sh('chmod +x /usr/local/bin/cloudflared')

print('✅ Environment ready')


In [None]:
# 2) Verify files
import os, glob, platform, sys
print('Platform:', platform.platform())
print('CWD:', os.getcwd())
print('Has model.py?', os.path.exists('model.py'))
# Each ONNX component counts as present in either its `.int8.onnx` or plain
# `.onnx` form; the int8 file is checked first.
for component in ('encoder', 'decoder', 'joiner'):
    candidates = (f'{component}-epoch-20-avg-10.int8.onnx',
                  f'{component}-epoch-20-avg-10.onnx')
    print(f'Found ONNX {component}?', any(os.path.exists(path) for path in candidates))
print('Found config.json?', os.path.exists('config.json'))
print('Vietnamese test wavs:', glob.glob('test_wavs/vietnamese/*.wav'))


In [None]:
# 3) Quick local inference
from model import get_pretrained_model, decode

# Recognizer settings; these match the defaults step 4 gives the API server.
REPO_ID = 'hynt/sherpa-onnx-zipformer-vi-int8-2025-10-16'
DECODING_METHOD = 'modified_beam_search'
NUM_ACTIVE_PATHS = 15
TEST_WAV = 'test_wavs/vietnamese/0.wav'

rec = get_pretrained_model(
    REPO_ID,
    decoding_method=DECODING_METHOD,
    num_active_paths=NUM_ACTIVE_PATHS,
)
print('Recognizer ready')
text = decode(rec, TEST_WAV)
print('Recognized text:', text)


## Public API server
- `GET /healthz`
- `GET /readyz`
- `POST /v1/transcribe` (multipart key `file`, or JSON with `audio_url`/`audio_base64`)
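- To enforce authentication, set the environment variables `REQUIRE_API_KEY=true` and `API_KEY=...` before running step 4; clients must then send `Authorization: Bearer <API_KEY>`.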


In [None]:
#@title 4) Start FastAPI (background)
import os, threading, time, requests, uvicorn
os.environ.setdefault('UVICORN_PORT', '8000')
os.environ.setdefault('REQUIRE_API_KEY', 'false')
os.environ.setdefault('API_KEY', '')
os.environ.setdefault('MODEL_REPO_ID', 'hynt/sherpa-onnx-zipformer-vi-int8-2025-10-16')
os.environ.setdefault('DECODING_METHOD', 'modified_beam_search')
os.environ.setdefault('NUM_ACTIVE_PATHS', '15')
os.environ.setdefault('MAX_DURATION_SEC', '60')
base = f"http://127.0.0.1:{os.environ['UVICORN_PORT']}"


def _run():
    """Serve ``api_server:app`` with uvicorn; blocks, so it runs in a daemon thread."""
    uvicorn.run('api_server:app', host='0.0.0.0', port=int(os.environ['UVICORN_PORT']), workers=1)


def _healthy() -> bool:
    """Return True if the local server answers ``GET /healthz`` with a 2xx status."""
    try:
        return requests.get(base + '/healthz', timeout=1).ok
    except requests.RequestException:
        return False


# Start the server only if nothing answers yet. On a re-run, a second uvicorn
# on the same port would fail to bind while the first keeps answering /healthz,
# hiding the failure.
thr = None
if not _healthy():
    thr = threading.Thread(target=_run, daemon=True)
    thr.start()

# Poll until healthy, sleeping after every miss (including non-2xx replies),
# and report instead of silently moving on.
deadline = time.time() + 60
while True:
    if _healthy():
        print('✅ Local server:', base)
        break
    if thr is not None and not thr.is_alive():
        raise RuntimeError('uvicorn thread exited before /healthz responded; check the cell output above')
    if time.time() >= deadline:
        print('⚠️ Local server did not answer /healthz within 60 s:', base)
        break
    time.sleep(0.2)


In [None]:
#@title 5) Expose a public URL via Cloudflared
import subprocess, re, time, requests, os
import threading  # background reader for cloudflared's log pipe

# Re-running the cell: stop the tunnel started by the previous run.
_prev = globals().get('proc')
if _prev is not None and _prev.poll() is None:
    _prev.terminate()

port = os.environ.get('UVICORN_PORT', '8000')
proc = subprocess.Popen(['cloudflared', 'tunnel', '--url', f'http://localhost:{port}', '--no-autoupdate'], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)

# Quick-tunnel URLs end in .trycloudflare.com. The API host
# (api.trycloudflare.com) can show up in error messages and must not be
# mistaken for the tunnel URL.
_url_re = re.compile(r'https://(?!api\.)[-a-z0-9.]+\.trycloudflare\.com')
_found = {}
_url_event = threading.Event()


def _drain_cloudflared(stream):
    """Read cloudflared output for its whole lifetime, recording the first tunnel URL.

    The pipe keeps being read after the URL appears: an unread pipe eventually
    fills its OS buffer, and cloudflared may then block on its own log writes.
    """
    for line in stream:
        if not _url_event.is_set():
            m = _url_re.search(line)
            if m:
                _found['url'] = m.group(0)
                _url_event.set()
    _url_event.set()  # EOF: cloudflared exited, so unblock the waiter


threading.Thread(target=_drain_cloudflared, args=(proc.stdout,), daemon=True).start()

# Unlike a blocking readline() loop, this wait honours the 90 s deadline even
# when cloudflared prints nothing.
_url_event.wait(timeout=90)
public_url = _found.get('url')
if not public_url:
    raise RuntimeError('Failed to obtain public URL from cloudflared logs')
print('🌐 Public API:', public_url)

# A freshly created quick tunnel may need a few seconds before it routes
# traffic, so retry the first health check briefly.
for attempt in range(10):
    try:
        health = requests.get(public_url + '/healthz', timeout=5)
        health.raise_for_status()
        break
    except requests.RequestException:
        if attempt == 9:
            raise
        time.sleep(2)
print('Health:', health.json())
print('Ready :', requests.get(public_url + '/readyz', timeout=10).json())


In [None]:
# 6) Call the API (examples)
import requests, base64, os

# Prefer the public tunnel from step 5. Fall back to the local server so the
# examples still run when the tunnel cell was skipped or failed.
API_BASE = globals().get('public_url') or f"http://127.0.0.1:{os.environ.get('UVICORN_PORT', '8000')}"
API_KEY = os.environ.get('API_KEY', '')
use_auth = os.environ.get('REQUIRE_API_KEY', 'false').lower() == 'true'
headers = ({'Authorization': f'Bearer {API_KEY}'} if use_auth else {})


def _body(resp):
    """Return the decoded JSON body, or the raw text when the body is not JSON."""
    try:
        return resp.json()
    except ValueError:  # requests' JSONDecodeError subclasses ValueError
        return resp.text


# multipart
with open('test_wavs/vietnamese/0.wav', 'rb') as f:
    r = requests.post(API_BASE + '/v1/transcribe', files={'file': ('0.wav', f, 'audio/wav')}, headers=headers, timeout=60)
print('Multipart:', r.status_code, _body(r))
# base64 JSON
with open('test_wavs/vietnamese/1.wav', 'rb') as f:
    b64 = base64.b64encode(f.read()).decode('utf-8')
payload = {'audio_base64': b64, 'decoding_method': 'modified_beam_search', 'num_active_paths': 15}
r = requests.post(API_BASE + '/v1/transcribe', json=payload, headers=headers, timeout=60)
print('Base64  :', r.status_code, _body(r))
