In [None]:
# kill old stuff
!pkill -f "uvicorn.*8000" || true
!fuser -k 8000/tcp || true
!pkill -f "cloudflared tunnel" || true

In [None]:
!pip install -q kokoro>=0.9.2 soundfile
!apt-get -qq -y install espeak-ng > /dev/null 2>&1
!pip -q install "fastapi[standard]" uvicorn soundfile nump
!pip -q install scipy

from kokoro import KPipeline
from IPython.display import display, Audio
import soundfile as sf
import torch
import asyncio
import websockets
import io, base64, json
from typing import Optional
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
import uvicorn

In [None]:
# Browser test client served at GET /ui: connects to /ws, sends {text, flush} messages,
# plays the returned base64 PCM16 @ 44.1 kHz audio and shows each chunk's alignment JSON.
UI = r"""<!doctype html>
<meta charset="utf-8" />
<title>TTS (spec-compliant)</title>
<style>
  body{font:14px/1.4 system-ui,-apple-system,Segoe UI,Roboto,Helvetica,Arial,sans-serif;margin:24px}
  .row{margin-bottom:10px}
  input[type=text]{width:520px}
  textarea{width:520px;height:100px}
  #caps{
    white-space: pre-wrap;     /* keep quotes/brackets visible */
    overflow-wrap: anywhere;   /* wrap long JSON tokens */
    word-break: break-word;    /* legacy alias; fine to keep */
    background:#f6f6f6; padding:8px; border-radius:8px; min-height:3em;
    font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, monospace;
  }
  #status{margin-left:8px;color:#555}
</style>

<h2>TTS WebSocket Client</h2>
<div class="row">
  <input id="url" type="text" placeholder="ws(s)://host/ws" />
  <button id="connect">Connect</button>
  <button id="speak">Speak</button>

  <label style="margin-left:12px">
    Words/chunk:
    <input id="wpc" type="number" min="1" value="3" style="width:60px">
  </label>
  <button id="speakChunked">Speak (chunked)</button>
  <button id="speakPunct">Speak (punct-only)</button>

  <button id="end">End</button>
  <span id="status">[disconnected]</span>
</div>

<div class="row">
  <textarea id="text" rows="3" cols="70" placeholder="Type text here..."></textarea>
</div>

<h3>Captions</h3>
<div id="caps"></div>

<script>
"use strict";

let ws, ctx;
let playhead = 0;

// punctuation boundaries (English + common CJK)
const PUNCT_RE = /[.!?;:,…，。？！；：、]/;

const $ = id => document.getElementById(id);

function relWsURL(){
  const p = location.protocol === "https:" ? "wss" : "ws";
  return p + "://" + location.host + "/ws";
}

window.addEventListener("DOMContentLoaded", () => {
  const u = $("url");
  if (u && !u.value) u.value = relWsURL();
});

async function ensureAudio(){
  if (!ctx) ctx = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 44100 });
  if (ctx.state === "suspended") await ctx.resume();
  if (playhead === 0) playhead = ctx.currentTime;
}

// trim leading/trailing near-silence on Int16 PCM
function trimPCM16(pcm, thresh = 3) {
  let s = 0, e = pcm.length - 1;
  while (s < pcm.length && Math.abs(pcm[s]) <= thresh) s++;
  while (e >= s && Math.abs(pcm[e]) <= thresh) e--;
  return s > e ? new Int16Array(0) : pcm.subarray(s, e + 1);
}

function playPcm16Mono44100(b64){
  if (!ctx) ctx = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 44100 });
  if (playhead === 0) playhead = ctx.currentTime;

  // decode base64 -> Int16
  const bytes = Uint8Array.from(atob(b64), c => c.charCodeAt(0));
  const view  = new DataView(bytes.buffer);
  const N = bytes.byteLength >> 1;
  const pcm = new Int16Array(N);
  for (let i = 0; i < N; i++) pcm[i] = view.getInt16(i*2, true);

  // trim head/tail near-silence
  const trimmed = trimPCM16(pcm);
  if (trimmed.length === 0) return;

  // Int16 -> Float32
  const f32 = new Float32Array(trimmed.length);
  for (let i = 0; i < trimmed.length; i++) {
    f32[i] = Math.max(-1, Math.min(1, trimmed[i] / 32767));
  }

  // make AudioBuffer
  const buf = ctx.createBuffer(1, f32.length, 44100);
  buf.copyToChannel(f32, 0, 0);

  // schedule sequential playback with a tight guard (3ms)
  const now = ctx.currentTime;
  if (playhead < now) playhead = now;
  const startAt = Math.max(playhead, now + 0.003);
  const endAt   = startAt + buf.duration;

  // tiny fades to avoid clicks; the fade length is clamped to half the buffer so the
  // fade-in and fade-out events never cross on very short (trimmed) buffers
  const fade = Math.min(0.005, buf.duration / 2);
  const gn = ctx.createGain();
  gn.gain.setValueAtTime(0, startAt);
  gn.gain.linearRampToValueAtTime(1, startAt + fade);
  gn.gain.setValueAtTime(1, endAt - fade);
  gn.gain.linearRampToValueAtTime(0, endAt);

  const src = ctx.createBufferSource();
  src.buffer = buf;
  src.connect(gn);
  gn.connect(ctx.destination);
  src.start(startAt);

  playhead = endAt;
}

// Grapheme-aware splitter (handles emoji/accents). Falls back to Array.from.
const GRAPHEME_SEG = (window.Intl && Intl.Segmenter)
  ? new Intl.Segmenter(undefined, { granularity: "grapheme" })
  : null;

function splitGraphemes(str) {
  if (GRAPHEME_SEG) return Array.from(GRAPHEME_SEG.segment(str), s => s.segment);
  return Array.from(str);
}

// ONLY split at punctuation (no word-count fallback)
function chunkByPunctOnly(text, punctRe = PUNCT_RE) {
  const tokens = String(text).trim().split(/\s+/);
  const chunks = [];
  let cur = [];
  for (const tk of tokens) {
    cur.push(tk);
    if (punctRe.test(tk.slice(-1))) {
      chunks.push(cur.join(" ") + " ");
      cur = [];
    }
  }
  if (cur.length) chunks.push(cur.join(" ") + " ");
  return chunks;
}

const sleep = ms => new Promise(r => setTimeout(r, ms));
function setStatus(s){ $("status").textContent = s; }

// Connect
$("connect").onclick = async () => {
  await ensureAudio();
  playhead = ctx.currentTime;

  if (ws) { try { ws.close(); } catch {} }
  const url = $("url").value.trim();

  // simple validation without regex (prevents the regex error)
  try { new URL(url); } catch { alert("Invalid URL"); return; }
  const okScheme = url.startsWith("ws://") || url.startsWith("wss://");
  const okPath   = url.endsWith("/ws");
  if (!okScheme || !okPath) { alert("URL must start with ws:// or wss:// and end with /ws"); return; }

  const sock = new WebSocket(url);
  ws = sock;
  // A replaced socket's close/error events can arrive after the new one opens;
  // only the current socket is allowed to update the status line.
  const isCurrent = () => ws === sock;

  sock.onopen = () => {
    if (!isCurrent()) return;
    setStatus("[connected]");
    sock.send(JSON.stringify({ text: " ", flush: false })); // prime
  };
  sock.onclose = () => { if (isCurrent()) setStatus("[disconnected]"); };
  sock.onerror = e => { if (isCurrent()) setStatus("[error]"); console.error("ws error", e); };
  sock.onmessage = (e) => {
    let m; try { m = JSON.parse(e.data); } catch { return; }
    if (!m || typeof m !== "object") return;   // e.g. a bare `null` / number frame

    // play audio if present
    if (typeof m.audio === "string" && m.audio) {
      playPcm16Mono44100(m.audio);
    }

    if (m.alignment !== undefined) {
      $("caps").textContent = (typeof m.alignment === "string")
        ? m.alignment
        : JSON.stringify(m.alignment);  // no pretty-printing
    }
  };
};

// Speak (single flush)
$("speak").onclick = async () => {
  if (!ws || ws.readyState !== 1) return alert("Connect first");
  await ensureAudio();
  playhead = Math.max(playhead, ctx.currentTime);

  const text = $("text").value.trim();
  if (!text) return;

  ws.send(JSON.stringify({ text, flush: true }));
};

// Speak (chunked by N words)
$("speakChunked").onclick = async () => {
  if (!ws || ws.readyState !== 1) return alert("Connect first");
  await ensureAudio();
  playhead = Math.max(playhead, ctx.currentTime);

  const text = $("text").value.trim();
  if (!text) return;

  const N = Math.max(1, parseInt($("wpc").value || "1", 10));
  const words = text.split(/\s+/);

  ws.send(JSON.stringify({ text: " ", flush: false })); // prime

  for (let i = 0; i < words.length; i += N) {
    const chunk = words.slice(i, i + N).join(" ") + " ";
    ws.send(JSON.stringify({ text: chunk, flush: true }));
    // await sleep(10); // optional tiny throttle
  }
  ws.send(JSON.stringify({ text: "", flush: false })); // end
};

// Speak (punctuation-only)
$("speakPunct").onclick = async () => {
  if (!ws || ws.readyState !== 1) return alert("Connect first");
  await ensureAudio();
  playhead = Math.max(playhead, ctx.currentTime);

  const text = $("text").value.trim();
  if (!text) return;

  const chunks = chunkByPunctOnly(text);

  ws.send(JSON.stringify({ text: " ", flush: false })); // prime
  for (const c of chunks) {
    ws.send(JSON.stringify({ text: c, flush: true }));
    // await sleep(10); // optional tiny throttle
  }
  ws.send(JSON.stringify({ text: "", flush: false }));  // end
};

// End
$("end").onclick = () => {
  if (ws && ws.readyState === 1) {
    try { ws.send(JSON.stringify({ text: "", flush: false })); } catch {}
    try { ws.close(); } catch {}
  }
  if (ctx) playhead = ctx.currentTime;
};
</script>
"""


In [None]:
# === Spec helpers: tensor→numpy, resample to 44.1 kHz, PCM16 bytes, alignment ===
import base64, json, numpy as np
from scipy import signal

try:
    import torch
    TORCH = True
except Exception:
    TORCH = False


SRC_SR_DEFAULT = 24000   # model's native sample rate; used only when the pipeline exposes no `sample_rate` attribute
DST_SR = 44100           # spec requires 44.1k mono PCM16; every chunk sent to clients is resampled to this

def to_numpy_audio(x):
    """Convert model audio (torch tensor or array-like) to a float32 numpy array.

    - torch tensors are detached and copied to CPU (torch is optional: when it is
      not installed, input is treated as array-like);
    - 2-D input is downmixed by averaging over its *shorter* axis, assumed to be the
      channel axis, so both (samples, channels) and (channels, samples) layouts
      work (a (1, N) tensor used to collapse to a single sample);
    - NaN and +/-inf are replaced with 0.0.
    """
    try:
        import torch
    except ImportError:  # numpy-only environments
        torch = None

    if torch is not None and isinstance(x, torch.Tensor):
        x = x.detach().to("cpu").float().numpy()
    else:
        x = np.asarray(x, dtype=np.float32)

    if x.ndim == 2:
        # Ties (square input) keep the previous behavior of averaging the last axis.
        channel_axis = 0 if x.shape[0] < x.shape[1] else 1
        x = x.mean(axis=channel_axis)
    return np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)

def log_audio(tag, a):
    """Debug helper: print the shape and RMS level of an audio chunk under `tag`."""
    samples = to_numpy_audio(a)
    if samples.size:
        rms = float(np.sqrt(np.mean(samples ** 2)))
    else:
        rms = 0.0  # empty chunk: avoid mean-of-empty warning
    print(f"[{tag}] shape={samples.shape} rms={rms:.6f}")

def resample_to_44k(a_f32, src_sr, dst_sr=DST_SR):
    """Polyphase-resample a float signal from `src_sr` to `dst_sr` (default 44.1 kHz).

    Sample rates are coerced to int so a float rate such as 24000.0 (e.g. read
    from a pipeline attribute) still works with `np.gcd`. The result is always
    float32, including the no-op path (empty input or equal rates).

    Raises:
        ValueError: if either sample rate is not positive.
    """
    src, dst = int(round(src_sr)), int(round(dst_sr))
    if src <= 0 or dst <= 0:
        raise ValueError(f"sample rates must be positive, got src={src_sr}, dst={dst_sr}")
    if a_f32.size == 0 or src == dst:
        return a_f32.astype(np.float32, copy=False)
    g = np.gcd(src, dst)
    up, down = dst // g, src // g
    return signal.resample_poly(a_f32, up, down).astype(np.float32, copy=False)

def encode_chunk_base64_pcm16_44k(audio, src_sr, min_ms=100):
    """
    Returns (base64_pcm16, n_samples_44k). Ensures >= min_ms of audio so the browser plays it.

    The audio is converted to mono float32, resampled from `src_sr` to DST_SR,
    zero-padded at the end up to `min_ms`, clipped to [-1, 1] and quantized to
    *little-endian* int16 — the byte order the browser client decodes with
    `DataView.getInt16(..., true)`, independent of the host's native order.
    `n_samples_44k` includes any padding.
    """
    a = resample_to_44k(to_numpy_audio(audio), src_sr, DST_SR)
    min_samples = int(DST_SR * (min_ms / 1000.0))
    if a.size < min_samples:
        a = np.pad(a, (0, min_samples - a.size))
    a = np.clip(a, -1.0, 1.0)
    # Round to the nearest step instead of truncating toward zero (less quantization bias).
    pcm16 = np.rint(a * 32767.0).astype("<i2")
    b64 = base64.b64encode(pcm16.tobytes()).decode("ascii")
    return b64, int(a.size)

def total_ms_from_samples(n44):
    """Duration in milliseconds of `n44` samples at DST_SR (only valid for 44.1 kHz counts)."""
    ms = n44 * 1000.0 / DST_SR
    return ms

# alignment
def _char_weight(ch):
    if ch.isspace(): return 0.45
    if ch in ".!?": return 0.40
    if ch in ",;:": return 0.55
    if ch.lower() in "aeiou": return 1.15
    return 1.0

def build_alignment_from_gs(gs, total_ms):

    chars = list(gs or "")
    if not chars: return {"chars": [], "char_start_times_ms": [], "char_durations_ms": []}
    w = np.array([_char_weight(c) for c in chars], float); s = w.sum() or 1.0
    d = w / s * float(total_ms)
    di = np.floor(d).astype(int); diff = int(round(total_ms)) - int(di.sum())
    if diff: di[-1] += diff
    MIN=20; need = np.maximum(0, MIN - di);
    if need.any(): di += need; di[-1] -= int(need.sum()); di[-1] = max(1, di[-1])
    starts = np.concatenate(([0], np.cumsum(di[:-1]))).astype(int).tolist()
    return {"chars": chars, "char_start_times_ms": starts, "char_durations_ms": di.astype(int).tolist()}


In [None]:
from fastapi import WebSocket, WebSocketDisconnect
import time, statistics
from time import perf_counter_ns
import numpy as np
import base64
LAT = []   # first-text→first-audio latencies (ms) appended by the /ws handler; their median is printed as p50


DEFAULT_LANG = "a"           # lang_code passed to KPipeline (NOTE(review): presumably Kokoro's American-English code — confirm)
DEFAULT_VOICE = "af_heart"   # voice name passed to every pipeline call

# Single-entry pipeline cache managed by get_pipeline(): the current pipeline and its lang_code.
_pipeline = None
_cur_lang = None

def get_pipeline(lang_code: str):
    """Return the cached KPipeline for `lang_code`.

    Only one pipeline is cached at a time: it is built on first use and rebuilt
    whenever a different `lang_code` is requested.
    """
    global _pipeline, _cur_lang
    cache_is_valid = _pipeline is not None and _cur_lang == lang_code
    if not cache_is_valid:
        _pipeline = KPipeline(lang_code=lang_code)   # <-- your TTS pipeline
        _cur_lang = lang_code
    return _pipeline

app = FastAPI()

@app.get("/ui")
def ui():
    """Serve the browser test client (the `UI` HTML page)."""
    page = HTMLResponse(UI)
    return page

@app.websocket("/ws")
async def tts_ws(ws: WebSocket):
    """
    Streaming TTS over a WebSocket.

    INPUT (client → server): { "text": str, "flush": bool }
      - first chunk: " " (single space) — priming only, never spoken
      - flush:true → synthesize buffered text so far, KEEP socket open
      - final chunk: "" (empty string, flush:false/omitted) → speak any text still
        buffered, then CLOSE socket (normal closure)
      - non-JSON frames and non-object JSON are ignored; a non-string "text" is
        ignored but its "flush" is still honored

    OUTPUT (server → client) for each TTS chunk:
      { "audio": "<base64 PCM16 mono @ 44.1k>", "alignment": { ... } }

    NOTE(review): the pipeline is iterated synchronously on the event-loop thread,
    so a long synthesis stalls every other connection until it finishes.
    """
    await ws.accept()
    lang_code = DEFAULT_LANG
    voice = DEFAULT_VOICE
    buf: list[str] = []

    t_first_text_ns = None   # when the first *real* text chunk of the current utterance arrived
    lat_printed = False      # latency already recorded for the current utterance?

    async def synth_and_stream(text: str):
        """Synthesize `text` and send one {audio, alignment} frame per pipeline chunk."""
        nonlocal t_first_text_ns, lat_printed
        try:
            if not text.strip():
                return

            pipe = get_pipeline(lang_code)

            # learn source sample rate if exposed by pipeline
            src_sr = getattr(pipe, "sample_rate", SRC_SR_DEFAULT)

            if t_first_text_ns is None:   # fallback; normally stamped when the text arrived
                t_first_text_ns = perf_counter_ns()

            try:
                for gs, _ps, audio in pipe(text, voice=voice):
                    if not lat_printed:
                        dt_ms = (perf_counter_ns() - t_first_text_ns) / 1e6
                        LAT.append(dt_ms)
                        p50 = statistics.median(LAT)
                        print(f"[lat] first_text→first_audio = {dt_ms:.1f} ms " f"(p50={p50:.1f} ms, n={len(LAT)})")
                        lat_printed = True

                    # encode audio as raw PCM16 @ 44.1k (base64)
                    b64, n44 = encode_chunk_base64_pcm16_44k(audio, src_sr=src_sr)
                    chunk_ms = total_ms_from_samples(n44)

                    # build alignment from gs and chunk duration
                    alignment = build_alignment_from_gs(gs, chunk_ms)

                    # send EXACTLY the two required fields
                    await ws.send_text(json.dumps({
                        "audio": b64,
                        "alignment": alignment
                    }))
            except WebSocketDisconnect:
                raise   # client went away: end the session instead of logging a "TTS error"
            except Exception as e:
                # spec has no error packet for audio chunks, so just log it
                print("TTS error:", e)
        finally:
            # Ready to measure the next utterance on the same socket
            t_first_text_ns = None
            lat_printed = False

    try:
        while True:
            raw = await ws.receive_text()
            try:
                msg = json.loads(raw)
            except Exception:
                continue   # ignore bad JSON
            if not isinstance(msg, dict):
                continue   # e.g. a bare JSON string/array is not a {"text", "flush"} message

            text = msg.get("text", "")
            if not isinstance(text, str):
                text = None   # ignore null/number payloads instead of crashing on "".join
            flush = bool(msg.get("flush", False))

            # first priming chunk: " "
            if text == " " and not flush and not buf:
                buf.append(text)
                continue

            # accumulate non-empty, non-priming text
            if text and text != " ":
                if t_first_text_ns is None:
                    t_first_text_ns = perf_counter_ns()
                buf.append(text)

            # flush → synthesize buffered text, keep socket open
            if flush:
                text_to_speak = "".join(buf)
                buf.clear()
                await synth_and_stream(text_to_speak)
                continue

            # final empty chunk (no flush) → speak what is still buffered, then end the session
            if text == "" and not flush:
                text_to_speak = "".join(buf)
                buf.clear()
                await synth_and_stream(text_to_speak)
                break
    except WebSocketDisconnect:
        return   # client already gone; nothing left to close

    # Close explicitly so the client sees a normal closure instead of a dropped connection.
    try:
        await ws.close()
    except (RuntimeError, OSError, WebSocketDisconnect):
        pass   # best effort: the peer may have disconnected in the meantime



In [None]:
def warm_tts():
    """Run one tiny synthesis so the first real request doesn't pay one-time setup costs.

    Builds (or reuses) the default-language pipeline via get_pipeline() and pulls a
    single chunk from it. Errors from the synthesis are printed, not raised, so a
    failed warm-up doesn't stop the notebook; errors building the pipeline propagate.
    """
    pipe = get_pipeline(DEFAULT_LANG)
    try:
        # Pulling one item forces the lazy generator to actually run the model;
        # the `None` default covers a generator that yields nothing.
        next(iter(pipe("warmup.", voice=DEFAULT_VOICE)), None)
    except Exception as e:
        print("warmup error:", e)

# Call once after defining app (before starting uvicorn):
warm_tts()


In [None]:
# start server in background (uvicorn in a daemon thread inside this kernel)
import os, socket, time, threading, uvicorn

SERVER_HOST, SERVER_PORT = "0.0.0.0", 8000

def _run():
    uvicorn.run(app, host=SERVER_HOST, port=SERVER_PORT, log_level="warning")

def _port_open(port: int) -> bool:
    """True if something on 127.0.0.1 currently accepts TCP connections on `port`."""
    try:
        with socket.create_connection(("127.0.0.1", port), timeout=0.5):
            return True
    except OSError:
        return False

if _port_open(SERVER_PORT):
    # Most likely an earlier run of this cell in the same kernel, still serving the *old* `app`.
    thr = None
    print(f"⚠️ Port {SERVER_PORT} is already in use; not starting another server. "
          "Restart the kernel to serve a redefined `app`.")
else:
    thr = threading.Thread(target=_run, daemon=True)
    thr.start()
    # Poll for readiness instead of a fixed sleep; stop early if the thread died (e.g. bind error).
    deadline = time.monotonic() + 15.0
    while thr.is_alive() and not _port_open(SERVER_PORT) and time.monotonic() < deadline:
        time.sleep(0.1)
    if thr.is_alive() and _port_open(SERVER_PORT):
        print("✅ Spec-compliant WS running at ws://127.0.0.1:8000/ws (tunnel wss://…/ws)")
    else:
        print(f"❌ Server did not come up on port {SERVER_PORT}; check the log output above.")


In [None]:
!apt -yq install cloudflared || true

import os, stat, urllib.request
if os.system("cloudflared --version > /dev/null 2>&1") != 0:
    url = "https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64"
    urllib.request.urlretrieve(url, "/usr/local/bin/cloudflared")
    os.chmod("/usr/local/bin/cloudflared", os.stat("/usr/local/bin/cloudflared").st_mode | stat.S_IEXEC)

print("🚀 Starting tunnel; keep this cell running while testing.")
!cloudflared tunnel --url http://127.0.0.1:8000 --no-autoupdate