# Real-time Phoneme Transcription with Gradio
This notebook demonstrates real-time phoneme transcription using a pre-trained Wav2Vec2 model and Gradio. You can use your microphone or upload an audio file to see the phoneme transcription in real time.

In [None]:
# Import Required Libraries
import gradio as gr
import numpy as np
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import librosa
import soundfile as sf
from io import BytesIO
from Levenshtein import distance as levenshtein_distance


## Load Pre-trained Phoneme Transcription Model
We use a Wav2Vec2 model fine-tuned for phoneme recognition. The model and processor are loaded below.

In [None]:
# Load the phoneme model and processor
def load_phoneme_model():
    """Load the pre-trained Wav2Vec2 phoneme-recognition checkpoint.

    Returns:
        A ``(processor, model)`` tuple, both created with ``from_pretrained``
        from the same checkpoint identifier.
    """
    # Both objects come from the same checkpoint, so name it once.
    checkpoint = "vitouphy/wav2vec2-xls-r-300m-timit-phoneme"
    return (
        Wav2Vec2Processor.from_pretrained(checkpoint),
        Wav2Vec2ForCTC.from_pretrained(checkpoint),
    )

# Load once at module level; phoneme_score reads these two names as globals
# on every call instead of reloading the checkpoint.
processor, model = load_phoneme_model()

## Define Real-time Audio Processing Function
The functions below take audio input, split it into fixed-length chunks, transcribe the phonemes in each chunk, score them against the target word, and return the per-chunk results.

In [None]:
# Phoneme scoring and chunking logic from main.py
# Maps each canonical vowel to the set of phoneme symbols accepted as that
# vowel. Every character of the variant string is one accepted symbol.
# Note that "ə" is listed under both "a" and "e"; code that scans this mapping
# in insertion order will therefore meet the "a" entry first.
vowel_tolerance = {
    canonical: set(variants)
    for canonical, variants in (
        ("a", "aɑəæ"),
        ("e", "eɛə"),
        ("i", "iɪ"),
        ("o", "oɔ"),
        ("u", "uʊ"),
    )
}

def split_audio_chunks(y, sr, chunk_duration=1.2):
    """Cut a sample sequence into consecutive fixed-length pieces.

    Args:
        y: Sliceable sequence of audio samples.
        sr: Sample rate in samples per second.
        chunk_duration: Length of each piece in seconds.

    Returns:
        A list of slices of ``y``, each ``int(chunk_duration * sr)`` samples
        long except possibly the last, which holds whatever remains.
    """
    step = int(chunk_duration * sr)
    # Slicing clamps the end index, so the final (shorter) piece needs no
    # special handling.
    return [y[offset:offset + step] for offset in range(0, len(y), step)]

def phoneme_score(audio_bytes, target_word="masyarakat"):
    """Transcribe an audio clip to phonemes and score it against the target.

    The clip is decoded at 16 kHz, run through the module-level ``processor``
    and ``model``, and the greedy (argmax) decoding is split on whitespace
    into phoneme tokens. Vowel variants listed in ``vowel_tolerance`` are
    collapsed to their canonical vowel before comparison.

    Args:
        audio_bytes: Encoded audio (anything ``librosa.load`` can read from a
            ``BytesIO``).
        target_word: Currently unused; the target phoneme sequence is
            hard-coded to "m a ʃ a ɾ a k a t".

    Returns:
        A dict with keys ``phonemes`` (raw model output, space-joined),
        ``score`` (0-100 similarity), ``mismatches`` (list of per-position
        descriptions), ``target_phonemes`` and ``tolerant_child`` (both
        space-joined).
    """
    y, _ = librosa.load(BytesIO(audio_bytes), sr=16000)
    input_values = processor(y, sampling_rate=16000, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(**input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    phonemes = processor.batch_decode(predicted_ids)[0].split()
    target_phonemes = "m a ʃ a ɾ a k a t".split()

    # Collapse each detected phoneme to the first canonical vowel whose
    # tolerance set contains it; anything else is kept unchanged.
    tolerant_child = [
        next((v for v, variants in vowel_tolerance.items() if ph in variants), ph)
        for ph in phonemes
    ]

    # The edit distance is measured over the joined character strings, so the
    # normaliser must be a character count too. Normalising by the number of
    # tokens lets a multi-character token drive the score below zero.
    target_str = "".join(target_phonemes)
    child_str = "".join(tolerant_child)
    edit_distance = levenshtein_distance(target_str, child_str)
    max_distance = max(len(target_str), len(child_str))
    score = ((max_distance - edit_distance) / max_distance) * 100 if max_distance > 0 else 0

    # Walk the target positions only; extra trailing detections are not
    # reported here (they still lower the score through the edit distance).
    mismatches = []
    for i, target in enumerate(target_phonemes):
        if i >= len(tolerant_child):
            mismatches.append(f"Pos {i+1}: {target} → missing")
            continue
        child = tolerant_child[i]
        tolerated = target in vowel_tolerance and child in vowel_tolerance.get(target, set())
        if target != child and not tolerated:
            mismatches.append(f"Pos {i+1}: {target} → {child}")

    return {
        "phonemes": " ".join(phonemes),
        "score": score,
        "mismatches": mismatches,
        "target_phonemes": " ".join(target_phonemes),
        "tolerant_child": " ".join(tolerant_child)
    }

def process_audio(audio, chunk_duration=1.2):
    """Split an audio input into chunks and score each chunk's phonemes.

    Args:
        audio: One of
            * ``None`` (nothing recorded) - yields an empty result list;
            * a file path, as delivered by ``gr.Audio(type="filepath")``;
            * raw encoded audio as ``bytes``/``bytearray``;
            * a NumPy sample array (assumed to be sampled at 16 kHz);
            * a 2-tuple pairing a sample rate with a NumPy array, in either
              order (Gradio's numpy mode delivers ``(sample_rate, data)``).
        chunk_duration: Chunk length in seconds.

    Returns:
        A list with one dict per non-empty chunk containing ``chunk``
        (1-based index), ``start`` and ``end`` (seconds, derived from the
        actual sample offsets) plus every key returned by ``phoneme_score``.
    """
    if audio is None:
        return []

    input_sr = 16000
    if isinstance(audio, tuple):
        first, second = audio
        # Accept both orderings so existing (array, sr) callers keep working.
        if isinstance(first, np.ndarray):
            audio, input_sr = first, second
        else:
            input_sr, audio = first, second

    if isinstance(audio, np.ndarray):
        # Encode at the array's real sample rate so librosa resamples correctly.
        buf = BytesIO()
        sf.write(buf, audio, input_sr, format='WAV')
        buf.seek(0)
        source = buf
    elif isinstance(audio, (bytes, bytearray)):
        source = BytesIO(audio)
    else:
        # Anything else is treated as a path; librosa opens it directly.
        source = audio

    y, sr = librosa.load(source, sr=16000)
    chunks = split_audio_chunks(y, sr, chunk_duration)

    results = []
    offset = 0  # running sample position of the current chunk's first sample
    for idx, chunk in enumerate(chunks):
        chunk_start = offset
        offset += len(chunk)
        if len(chunk) == 0:
            continue
        # phoneme_score expects encoded audio bytes, so re-encode the chunk.
        chunk_buf = BytesIO()
        sf.write(chunk_buf, chunk, sr, format='WAV')
        result = phoneme_score(chunk_buf.getvalue())
        results.append({
            "chunk": idx+1,
            "start": chunk_start / sr,
            # Use the real end of the chunk; the final chunk is usually
            # shorter than chunk_duration.
            "end": offset / sr,
            **result
        })
    return results


## Create Gradio Interface for Real-time Phoneme Transcription
We will now set up a Gradio interface that accepts live audio input and displays the phoneme transcription output for each chunk.

In [None]:
# Gradio interface for real-time phoneme transcription
def gradio_transcribe(audio):
    """Render per-chunk phoneme results as an HTML fragment.

    Args:
        audio: The audio value passed through unchanged to ``process_audio``.

    Returns:
        An HTML string with one section per chunk (time range, target,
        detected and raw phonemes, score, and any mismatches), each followed
        by ``<hr>``. Empty when there are no chunks.
    """
    # Local import keeps this function self-contained.
    from html import escape

    results = process_audio(audio)
    parts = []
    for r in results:
        # Model-derived text is escaped: decoded output may contain
        # angle-bracket tokens (e.g. "<unk>") that a browser would otherwise
        # swallow as tags.
        parts.append(f"<b>Chunk {r['chunk']} ({r['start']:.1f}-{r['end']:.1f}s)</b><br>")
        parts.append(f"<b>Target Phonemes:</b> <code>{escape(r['target_phonemes'])}</code><br>")
        parts.append(f"<b>Detected Phonemes:</b> <code>{escape(r['tolerant_child'])}</code><br>")
        parts.append(f"<b>Raw Model Output:</b> <code>{escape(r['phonemes'])}</code><br>")
        parts.append(f"<b>Score:</b> <code>{r['score']:.1f}%</code><br>")
        if r["mismatches"]:
            mismatch_html = "<br>".join(escape(m) for m in r["mismatches"])
            parts.append(
                f"<span style='color:red'><b>Missed/Incorrect Phonemes:</b><br>{mismatch_html}</span><br>"
            )
        else:
            parts.append("<span style='color:green'><b>All phonemes correct! 🎉</b></span><br>")
        parts.append("<hr>")
    return "".join(parts)

# Wire the transcription function to a microphone input and an HTML output.
iface = gr.Interface(
    fn=gradio_transcribe,
    # type="filepath" hands the callback a path to the recorded file.
    # NOTE(review): the `source=` keyword is the Gradio 3.x spelling; Gradio 4+
    # expects `sources=["microphone", "upload"]` instead - confirm against the
    # installed version. With source="microphone" only recording is offered,
    # although the label mentions uploading.
    inputs=gr.Audio(source="microphone", type="filepath", label="Speak or Upload Audio (WAV)"),
    # Use the top-level component; the legacy `gr.outputs` namespace is
    # deprecated in Gradio 3.x and absent from later releases.
    outputs=gr.HTML(label="Phoneme Transcription Results"),
    title="Real-time Phoneme Transcription",
    description="Speak or upload a WAV file to see real-time phoneme transcription and scoring."
)


## Launch Gradio App
Run the cell below to launch the Gradio app and start real-time phoneme transcription.

In [None]:
# Launch the Gradio app for the interface built in the previous cell.
# NOTE(review): debug=True presumably keeps the cell blocking and surfaces
# callback errors in the cell output - confirm against the Gradio launch() docs.
iface.launch(debug=True)