# Speech-to-LaTeX demo

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dkorzh10/speech2latex/blob/main/demo_speech2latex.ipynb)

Try ASR post-correction models: play samples from the repo or record your own audio (Colab microphone / file upload).

## 1. Install and clone

In [None]:
!pip install -q transformers accelerate torch torchaudio huggingface_hub peft whisper ipywidgets

In [None]:
import os
import sys
REPO_DIR = "speech2latex"
if not os.path.exists(REPO_DIR):
    !git clone --depth 1 https://github.com/dkorzh10/speech2latex.git
os.chdir(REPO_DIR)
asr_path = os.path.join(os.getcwd(), "ASRPostCorrection")
if asr_path not in sys.path:
    sys.path.insert(0, asr_path)
print("Repo root:", os.getcwd())

## 2. Choose and load a model

In [None]:
# Checkpoints selectable in this demo, all published on Hugging Face under the
# "marsianin500" account. Each entry is a full "<account>/<checkpoint>" repo id.
MODELS = [
    "marsianin500/" + checkpoint
    for checkpoint in (
        "Qwen2.5-0.5B-instruct-equations_multilingual_mix",
        "Qwen2.5-0.5B-instruct-equations_multilingual_mix_full",
        "Qwen2.5-0.5B-instruct-sentences_eng_mix",
        "Qwen2.5-1.5B-instruct-equations_multilingual_mix",
        "Qwen2.5-math-1.5B-instruct-equations_multilingual_mix_full",
        "Qwen2.5-math-1.5B-instruct-equations_multilingual_mix",
        "Qwen2.5-math-1.5B-instruct-sentences_eng_mix",
        "Qwen2.5-7B-instruct-r16a64-equations_multilingual_mix",
        "Qwen2.5-7B-instruct-r16a64-equations_multilingual_mix_full",
    )
]

# The 7B checkpoints are LoRA adapters; this is the base model they sit on.
BASE_7B = "Qwen/Qwen2.5-7B-Instruct"

In [None]:
from IPython.display import display
import ipywidgets as widgets

# Model picker: each option is a (label, value) pair. The label is the repo id
# without the account prefix, capped at 70 characters; the value is the full
# repo id from MODELS. The first model is preselected.
model_dropdown = widgets.Dropdown(
    options=list(
        zip(
            (repo_id.replace("marsianin500/", "")[:70] for repo_id in MODELS),
            MODELS,
        )
    ),
    value=MODELS[0],
    description="Model:",
    layout=widgets.Layout(width="500px"),
)
display(model_dropdown)

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

def load_model(repo_id: str, device: str = "cuda"):
    """Load the tokenizer and causal LM for a Hugging Face repo.

    The tokenizer is read from the repo root, falling back to the repo's
    "tokenizer" subfolder if that fails. The model is then chosen from the
    repo's file listing:

    * any path containing "tuned-model": load the model from the
      "tuned-model" subfolder;
    * otherwise, any path containing "adapter_model" or "adapter_config":
      load BASE_7B and wrap it with the repo's PEFT adapter;
    * otherwise: load the model from the repo root.

    Args:
        repo_id: Hugging Face repo id of the checkpoint.
        device: Value passed as ``device_map`` when loading weights.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is in eval mode and loaded
        in bfloat16.
    """
    from huggingface_hub import list_repo_files

    try:
        tokenizer = AutoTokenizer.from_pretrained(repo_id)
    except Exception:
        # Root-level tokenizer unavailable: try the "tokenizer" subfolder.
        tokenizer = AutoTokenizer.from_pretrained(repo_id, subfolder="tokenizer")

    # Inspect the repo layout to tell a full fine-tune from a LoRA adapter.
    repo_files = list_repo_files(repo_id)
    has_tuned = any("tuned-model" in path for path in repo_files)
    has_adapter = any(
        "adapter_model" in path or "adapter_config" in path for path in repo_files
    )

    # Loading options shared by every branch.
    load_kwargs = {"torch_dtype": torch.bfloat16, "device_map": device}

    if has_tuned:
        model = AutoModelForCausalLM.from_pretrained(
            repo_id, subfolder="tuned-model", **load_kwargs
        )
    elif has_adapter:
        base = AutoModelForCausalLM.from_pretrained(BASE_7B, **load_kwargs)
        model = PeftModel.from_pretrained(base, repo_id)
    else:
        model = AutoModelForCausalLM.from_pretrained(repo_id, **load_kwargs)

    model.eval()
    return tokenizer, model

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load whichever checkpoint is currently selected in the dropdown.
print(f"Loading {model_dropdown.value} on {device}...")
tokenizer, model = load_model(repo_id=model_dropdown.value, device=device)
print("Done.")

## 3. Inference helper (from repo)

In [None]:
import sys

# Make the repo's ASRPostCorrection directory importable. The membership check
# keeps re-runs of this cell from stacking duplicate entries on sys.path.
asr_path = os.path.join(os.getcwd(), "ASRPostCorrection")
if asr_path not in sys.path:
    sys.path.insert(0, asr_path)
from chat_template_with_generation import CHAT_TEMPLATE_WITH_GENERATION

# Whisper checkpoints already loaded in this session, keyed by (model name, device),
# so repeated transcriptions do not reload the weights on every call.
_WHISPER_MODELS = {}


def transcribe_whisper(audio_path_or_array, language="en"):
    """Transcribe audio with the Whisper "base" model.

    Args:
        audio_path_or_array: Audio file path, or audio data, passed unchanged
            to Whisper's ``transcribe``.
        language: Language code forwarded to Whisper (default "en").

    Returns:
        The "text" field of the transcription result with surrounding
        whitespace stripped, or "" when that field is missing or empty.
    """
    import whisper

    cache_key = ("base", device)
    asr_model = _WHISPER_MODELS.get(cache_key)
    if asr_model is None:
        asr_model = whisper.load_model("base", device=device)
        _WHISPER_MODELS[cache_key] = asr_model

    # Paths and arrays take the same call; half precision is requested only on CUDA.
    result = asr_model.transcribe(
        audio_path_or_array, language=language, fp16=(device == "cuda")
    )
    return (result.get("text") or "").strip()

def pronunciation_to_latex(pronunciation: str):
    """Convert a spoken-formula transcript to LaTeX with the loaded model.

    Builds a two-message chat (system + user) around ``pronunciation``, renders
    it with CHAT_TEMPLATE_WITH_GENERATION, generates up to 256 new tokens
    without sampling, and decodes only the newly generated tokens.

    Args:
        pronunciation: Text describing how the formula is pronounced.

    Returns:
        The decoded completion, cut at the first "<|im_end|>" and then at the
        first "<|endoftext|>" marker, with surrounding whitespace stripped.
    """
    user_prompt = "Please, give me LaTeX representation of the following formula. Formula pronunciation: " + pronunciation
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": user_prompt},
    ]
    encoded = tokenizer.apply_chat_template(
        messages,
        chat_template=CHAT_TEMPLATE_WITH_GENERATION,
        add_generation_prompt=True,
        tokenize=True,
        padding=True,
        return_assistant_tokens_mask=True,
        return_dict=True,
        return_tensors="pt",
    )
    prompt_ids = encoded["input_ids"]
    prompt_mask = encoded["attention_mask"]

    generated = model.generate(
        inputs=prompt_ids.to(model.device),
        attention_mask=prompt_mask.to(model.device),
        max_new_tokens=256,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # The output starts with the prompt tokens; keep only what follows them.
    completion_ids = generated[0][prompt_ids.shape[1]:]
    text = tokenizer.decode(completion_ids, skip_special_tokens=False)

    # Special tokens are kept by the decode above, so cut at the end markers by hand.
    for stop_marker in ("<|im_end|>", "<|endoftext|>"):
        text = text.split(stop_marker)[0]
    return text.strip()

## 4. Demo: samples from repo (no dataset download)

In [None]:
import json
import urllib.request

# Base URL for raw files on the repo's main branch.
GITHUB_RAW = "https://raw.githubusercontent.com/dkorzh10/speech2latex/main"
# Sample ids per split (presumably the .wav file stems under sample_datasets/<split>/).
SAMPLES = {
    "equations_test": ["human_eng_00", "human_eng_01", "human_eng_02", "human_eng_03", "tts_eng_00"],
    "sentences_test": ["human_eng_00", "tts_eng_00"],
}

# Optional: load precomputed demo_results.json for reference/predicted.
# Entries are indexed by "<split>/<sample_id>".
demo_results = {}
try:
    # The timeout keeps a slow or unreachable host from hanging the cell;
    # a timeout is reported by the handler below like any other failure.
    with urllib.request.urlopen(GITHUB_RAW + "/docs/demo_results.json", timeout=15) as r:
        data = json.load(r)
    for item in data.get("results", []):
        key = item["split"] + "/" + item["sample_id"]
        demo_results[key] = item
except Exception as e:
    # Deliberately best-effort: the file is optional, so any failure
    # (network, JSON, unexpected schema) is reported and the demo continues.
    print("Could not load demo_results.json:", e)

In [None]:
from IPython.display import Audio, display

# Which bundled sample to run; `key` matches the "<split>/<sample_id>" index
# used by demo_results.
split = "equations_test"
sample_id = "human_eng_01"
key = f"{split}/{sample_id}"

# Fetch the sample's .wav from the repo to a local temp file and play it.
url = f"{GITHUB_RAW}/sample_datasets/{split}/{sample_id}.wav"
local_wav = "/tmp/sample.wav"
urllib.request.urlretrieve(url, local_wav)
display(Audio(local_wav))

# Speech -> transcript -> LaTeX.
pron = transcribe_whisper(local_wav, language="en")
print("Whisper:", pron)
latex = pronunciation_to_latex(pron)
print("LaTeX:", latex)

# Show the reference LaTeX when a precomputed entry exists for this sample.
precomputed = demo_results.get(key)
if precomputed is not None:
    print("Reference:", precomputed.get("reference_latex", ""))

## 5. Record your own audio (microphone) or upload a file

In [None]:
from google.colab import files
import io

print("Upload an audio file (e.g. .wav) or use the recorder below.")
# `uploaded` is indexed by file name and holds each file's bytes.
uploaded = files.upload()
if uploaded:
    # Only the first uploaded file is used; it is copied to /tmp/user_<name>.
    fname = next(iter(uploaded))
    user_audio_path = f"/tmp/user_{fname}"
    with open(user_audio_path, "wb") as f:
        f.write(uploaded[fname])
    print("Saved to", user_audio_path)

In [None]:
# Run this cell once so the recorder can save directly to Colab (no re-upload).
from google.colab import output
import base64

# Where the microphone recording is stored.
RECORDED_PATH = "/tmp/recording.webm"


def save_recorded_audio(b64_data):
    """Decode a base64 recording and write it to RECORDED_PATH.

    Called from JS when the user stops recording. A falsy payload is ignored
    and nothing is written.

    Args:
        b64_data: Base64-encoded audio bytes.

    Returns:
        The string "ok" in every case.
    """
    if not b64_data:
        return "ok"

    audio_bytes = base64.b64decode(b64_data)
    with open(RECORDED_PATH, "wb") as sink:
        sink.write(audio_bytes)
    print(f"Saved to {RECORDED_PATH} ({len(audio_bytes)} bytes). Run the 'Process' cell below.")
    return "ok"

# Register the Python function with Colab under the name "save_recorded_audio";
# the recorder cell's JavaScript invokes the callback by exactly this name.
output.register_callback("save_recorded_audio", save_recorded_audio)
print("Callback registered. Use the Record/Stop button below; audio will save in Colab.")

In [None]:
# Record from microphone. Click Record, speak, then Stop — audio saves in Colab (run the callback cell above first).
#
# What the embedded script does:
#   * The single button toggles recording: the first click asks the browser for
#     microphone access and starts a MediaRecorder; the next click stops it.
#   * On stop, the microphone tracks are released and the captured chunks are
#     combined into one Blob typed 'audio/webm'.
#   * If google.colab.kernel is available, the Blob is read as a data URL and
#     the base64 part after the comma is sent to the 'save_recorded_audio'
#     kernel callback.
#   * If the kernel bridge is missing or the callback call fails, the browser
#     downloads the recording as 'recording.webm' so it can be uploaded manually.
#   * Progress and errors are shown in the status text next to the button.
from IPython.display import HTML, display
display(HTML("""
<button id="recBtn">Record</button>
<span id="status" style="margin-left:8px"></span>
<script>
(function() {
  let rec, chunks = [], stream;
  const btn = document.getElementById('recBtn');
  const status = document.getElementById('status');
  const invoke = (typeof google !== 'undefined' && google.colab && google.colab.kernel) ? google.colab.kernel.invokeFunction.bind(google.colab.kernel) : null;
  btn.onclick = async function() {
    if (!rec || rec.state === 'inactive') {
      try {
        stream = await navigator.mediaDevices.getUserMedia({audio: true});
        rec = new MediaRecorder(stream);
        chunks = [];
        rec.ondataavailable = e => e.data.size && chunks.push(e.data);
        rec.onstop = async () => {
          stream.getTracks().forEach(t => t.stop());
          const blob = new Blob(chunks, {type: 'audio/webm'});
          if (invoke) {
            const reader = new FileReader();
            reader.onload = async () => {
              const b64 = reader.result.split(',')[1] || '';
              try { await invoke('save_recorded_audio', [b64], {}); status.textContent = 'Saved in Colab. Run the Process cell.'; }
              catch (e) { status.textContent = 'Callback failed — download and upload instead.'; const a = document.createElement('a'); a.href = URL.createObjectURL(blob); a.download = 'recording.webm'; a.click(); }
            };
            reader.readAsDataURL(blob);
          } else {
            const a = document.createElement('a'); a.href = URL.createObjectURL(blob); a.download = 'recording.webm'; a.click(); status.textContent = 'Downloaded — upload above.';
          }
        };
        rec.start(); btn.textContent = 'Stop'; status.textContent = 'Recording...';
      } catch (e) { status.textContent = 'Error: ' + e.message; }
    } else { rec.stop(); btn.textContent = 'Record'; }
  };
})();
</script>
"""))

In [None]:
# Process: use the microphone recording or the uploaded file, whichever is newer.
# Imported here so this cell also works when the sample-demo cell was skipped.
from IPython.display import Audio, display

RECORDED_PATH = "/tmp/recording.webm"
try:
    uploaded
except NameError:
    # The upload cell was never run; behave as if nothing was uploaded.
    uploaded = {}

# Collect every audio source that actually exists on disk, with its message.
audio_candidates = []
if os.path.isfile(RECORDED_PATH):
    audio_candidates.append((RECORDED_PATH, "Using recorded audio from microphone."))
if uploaded:
    uploaded_path = "/tmp/user_" + next(iter(uploaded))
    if os.path.isfile(uploaded_path):
        audio_candidates.append((uploaded_path, "Using uploaded file."))

if audio_candidates:
    # Pick the most recently written file so an old recording does not shadow
    # a newer upload (or the reverse). On a tie the recording is kept, since
    # max() returns the first maximal item.
    user_audio_path, source_message = max(
        audio_candidates, key=lambda candidate: os.path.getmtime(candidate[0])
    )
    print(source_message)
    display(Audio(user_audio_path))
    lang = "en"  # or "ru" for Russian
    pron = transcribe_whisper(user_audio_path, language=lang)
    print("Whisper:", pron)
    latex = pronunciation_to_latex(pron)
    print("LaTeX:", latex)
else:
    user_audio_path = None
    print("Record with the Record/Stop button above (run the callback cell first), or upload an audio file.")