<a href="https://colab.research.google.com/github/chenchihwang/SDS/blob/main/multilingual_consistent_asr_%2B_tts.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install gradio openai-whisper
import gradio as gr
import whisper
import numpy as np
import tempfile
import os
import torch

# Run on the GPU when PyTorch reports CUDA is available; otherwise use the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the Whisper checkpoint once at cell start; transcribe_audio() below
# reads the module-level `model` and `device`.
# Size options: 'tiny', 'base', 'small', 'medium', 'large'
model_size = "tiny"
print(f"Loading ASR {model_size} model")
model = whisper.load_model(model_size, device=device)
print("ASR Model loaded")

# Language code reported by Whisper's language detector -> display name.
# Built once at import time instead of on every transcription request.
_WHISPER_LANGUAGE_NAMES = {
    "en": "English", "zh": "Chinese", "de": "German",
    "es": "Spanish", "ru": "Russian", "ko": "Korean",
    "fr": "French", "ja": "Japanese", "pt": "Portuguese",
    "tr": "Turkish", "pl": "Polish", "ca": "Catalan",
    "nl": "Dutch", "ar": "Arabic", "sv": "Swedish",
    "it": "Italian", "id": "Indonesian", "hi": "Hindi",
    "fi": "Finnish", "vi": "Vietnamese", "he": "Hebrew",
    "uk": "Ukrainian", "el": "Greek", "ms": "Malay",
    "cs": "Czech", "ro": "Romanian", "da": "Danish",
    "hu": "Hungarian", "ta": "Tamil", "no": "Norwegian",
    "th": "Thai", "ur": "Urdu", "hr": "Croatian",
    "bg": "Bulgarian", "lt": "Lithuanian", "la": "Latin",
    "mi": "Maori", "ml": "Malayalam", "cy": "Welsh",
    "sk": "Slovak", "te": "Telugu", "fa": "Persian",
    "lv": "Latvian", "bn": "Bengali", "sr": "Serbian",
    "az": "Azerbaijani", "sl": "Slovenian", "kn": "Kannada",
    "et": "Estonian", "mk": "Macedonian", "br": "Breton",
    "eu": "Basque", "is": "Icelandic", "hy": "Armenian",
    "ne": "Nepali", "mn": "Mongolian", "bs": "Bosnian",
    "kk": "Kazakh", "sq": "Albanian", "sw": "Swahili",
    "gl": "Galician", "mr": "Marathi", "pa": "Punjabi",
    "si": "Sinhala", "km": "Khmer", "sn": "Shona",
    "yo": "Yoruba", "so": "Somali", "af": "Afrikaans",
    "oc": "Occitan", "ka": "Georgian", "be": "Belarusian",
    "tg": "Tajik", "sd": "Sindhi", "gu": "Gujarati",
    "am": "Amharic", "yi": "Yiddish", "lo": "Lao",
    "uz": "Uzbek", "fo": "Faroese", "ht": "Haitian Creole",
    "ps": "Pashto", "tk": "Turkmen", "nn": "Nynorsk",
    "mt": "Maltese", "sa": "Sanskrit", "lb": "Luxembourgish",
    "my": "Myanmar", "bo": "Tibetan", "tl": "Tagalog",
    "mg": "Malagasy", "as": "Assamese", "tt": "Tatar",
    "haw": "Hawaiian", "ln": "Lingala", "ha": "Hausa",
    "ba": "Bashkir", "jw": "Javanese", "su": "Sundanese",
}


def transcribe_audio(audio, detect_language=False):
    """
    Transcribe audio and optionally detect language

    Args:
        audio: Tuple (sample_rate, audio_data) from Gradio, or None when
            nothing was recorded
        detect_language: Boolean to toggle language detection

    Returns:
        Tuple of (transcription text, detected language). On failure the
        first element is an "Error transcribing audio: ..." message and the
        second is an empty string; no exception is propagated.
    """
    if audio is None:
        return "No audio recorded. Please record some audio.", ""

    sample_rate, audio_data = audio

    # Reserve a unique path only; delete=False keeps the file after the
    # handle closes so it can be written and then read back by path.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
        temp_filename = temp_audio.name

    try:
        # Writing happens inside the try so that a bad payload is reported as
        # an error tuple and the temporary file is still removed below.
        import scipy.io.wavfile as wav
        wav.write(temp_filename, sample_rate, audio_data)

        language = ""
        if detect_language:
            # LID on the first 30-second window of the recording
            samples = whisper.load_audio(temp_filename)
            samples = whisper.pad_or_trim(samples)
            # Use the checkpoint's own mel-bin count so model sizes whose
            # count differs from the library default also work.
            mel = whisper.log_mel_spectrogram(
                samples, n_mels=model.dims.n_mels
            ).to(device)
            _, probs = model.detect_language(mel)
            detected_lang_code = max(probs, key=probs.get)

            display_name = _WHISPER_LANGUAGE_NAMES.get(
                detected_lang_code, detected_lang_code
            )
            confidence = round(probs[detected_lang_code] * 100, 2)
            language = (
                f"Detected language: {display_name} ({detected_lang_code})"
                f" - Confidence: {confidence}%"
            )

        # Transcribe
        result = model.transcribe(temp_filename)
        return result["text"], language

    except Exception as e:
        # UI boundary: report the failure in the transcription box.
        return f"Error transcribing audio: {str(e)}", ""

    finally:
        # Remove temporary file
        if os.path.exists(temp_filename):
            os.remove(temp_filename)

# Make Gradio interface
# The two inputs are passed positionally to transcribe_audio(audio,
# detect_language); the two outputs receive its returned 2-tuple in order.
demo = gr.Interface(
    fn=transcribe_audio,
    inputs=[
        gr.Audio(sources=["microphone"]),
        gr.Checkbox(label="Detect Language", value=False)
    ],
    outputs=[
        gr.Textbox(label="Transcription"),
        gr.Textbox(label="Language Detection")
    ],
    title="ASR Demo + Language Detection",
    description="Record audio and get transcription. Optionally detect the language of the audio.",
    examples=None,
    theme="default"
)

# share=True publishes a temporary public URL; debug=True keeps the cell
# running so errors and logs stay visible until it is interrupted.
demo.launch(debug=True, share=True)

[0mCollecting gradio
  Using cached gradio-5.25.2-py3-none-any.whl.metadata (16 kB)
Collecting openai-whisper
  Using cached openai_whisper-20240930-py3-none-any.whl
Collecting aiofiles<25.0,>=22.0 (from gradio)
  Using cached aiofiles-24.1.0-py3-none-any.whl.metadata (10 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Using cached fastapi-0.115.12-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Using cached ffmpy-0.5.0-py3-none-any.whl.metadata (3.0 kB)
Collecting gradio-client==1.8.0 (from gradio)
  Using cached gradio_client-1.8.0-py3-none-any.whl.metadata (7.1 kB)
Collecting groovy~=0.1 (from gradio)
  Using cached groovy-0.1.2-py3-none-any.whl.metadata (6.1 kB)
Collecting safehttpx<0.2.0,>=0.1.6 (from gradio)
  Using cached safehttpx-0.1.6-py3-none-any.whl.metadata (4.2 kB)
Collecting starlette<1.0,>=0.40.0 (from gradio)
  Using cached starlette-0.46.2-py3-none-any.whl.metadata (6.2 kB)
Collecting tiktoken (from openai-whisper)
  Using cached tiktoken-0.9.

100%|█████████████████████████████████████| 72.1M/72.1M [00:00<00:00, 88.6MiB/s]


ASR Model loaded
Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://3cd1caf4e1cfe0fbd8.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://3cd1caf4e1cfe0fbd8.gradio.live




In [None]:
!pip install gradio transformers torch accelerate bitsandbytes -q

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# System prompt for the LLM; format_prompt() puts it first in every conversation.
SYSTEM_PROMPT = """You are a helpful, respectful and honest conversational
assistant. Keep answers concise, usually one sentence at most."""

def load_llm():
    """Load the TinyLlama chat model with 4-bit quantization and its tokenizer.

    Returns:
        Tuple of (model, tokenizer).
    """
    # Local import keeps this function usable without touching the cell's
    # import line.
    from transformers import BitsAndBytesConfig

    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.float16,
        # Passing load_in_4bit directly to from_pretrained is deprecated;
        # the quantization settings go through a BitsAndBytesConfig instead.
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Reuse EOS as the padding token and pad on the left.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    return model, tokenizer

# Load the LLM once at cell start.
# NOTE(review): this rebinds the module-level name `model`, which the ASR
# cell's transcribe_audio() also reads; running both cells in one kernel
# makes the later assignment win.
print("Loading LLM model")
model, tokenizer = load_llm()
print("LLM Model loaded")

def format_prompt(messages):
    """Build a chat-template message list from conversation history.

    Args:
        messages: Iterable of turns; index 0 of each turn is the user text
            and index 1 is the assistant text.

    Returns:
        List of {"role", "content"} dicts: the system prompt first, then a
        user entry and an assistant entry for every turn, in order.
    """
    chat = [{"role": "system", "content": SYSTEM_PROMPT}]
    for turn in messages:
        chat.extend(
            {"role": role, "content": turn[position]}
            for position, role in enumerate(("user", "assistant"))
        )
    return chat

def generate_response(history, new_input):
    """Generate an assistant reply to `new_input` given the prior history.

    Args:
        history: List of (user, assistant) turns completed so far.
        new_input: The new user message to answer.

    Returns:
        Tuple of (history extended with the new (user, reply) turn, "").
        The history argument itself is not modified.
    """
    messages = format_prompt(history)
    # The new user turn has to be the last message in the prompt; without it
    # the model would only see the earlier turns and never the question.
    messages.append({"role": "user", "content": new_input})

    inputs = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True,
        padding=True,
        return_attention_mask=True
    ).to(model.device)

    outputs = model.generate(
        inputs,
        attention_mask=torch.ones_like(inputs),
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id
    )

    # Decode only the tokens that follow the prompt.
    full_response = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)

    return history + [(new_input, full_response)], ""

# Make Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Chat Assistant")
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(label="Your message")
    clear = gr.Button("Clear")

    def user(user_message, history):
        """Empty the textbox and append the message as a turn with no reply yet."""
        return "", history + [[user_message, None]]

    def chat(history):
        """Fill in the assistant reply for the last turn of `history`."""
        # Earlier turns become context; the last turn holds the new message.
        messages = format_prompt(history[:-1])
        new_message = history[-1][0]

        inputs = tokenizer.apply_chat_template(
            messages + [{"role": "user", "content": new_message}],
            return_tensors="pt",
            add_generation_prompt=True,
            padding=True,
            return_attention_mask=True
        ).to(model.device)

        outputs = model.generate(
            inputs,
            attention_mask=torch.ones_like(inputs),
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id
        )

        # Decode only the tokens that follow the prompt.
        response = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
        history[-1][1] = response
        return history

    # On submit: first record the user turn, then run generation on the
    # updated chatbot value.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        chat, chatbot, chatbot
    )
    # Returning None to the chatbot empties it.
    clear.click(lambda: None, None, chatbot, queue=False)

demo.launch(debug=True, share=True)

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.1/76.1 MB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m
[0mLoading LLM model


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.29k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

LLM Model loaded


  chatbot = gr.Chatbot(height=500)


Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://693e3b28fe93842da5.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://693e3b28fe93842da5.gradio.live




In [None]:
!pip uninstall numpy
!pip install numpy==1.26.4
!pip install ttsmms
!pip install gradio torchaudio transformers torch uroman -q
!pip install gtts # just for chinese + japanese, replace later, and probably do not show off
!pip install soundfile

import gradio as gr
import numpy as np
from ttsmms import TTS, download
from gtts import gTTS
import soundfile as sf

# Dropdown display name -> language code handed to get_model(), which passes
# it to ttsmms `download()`.
tts_lang_map = {
    "English": "eng",
    "Spanish": "spa",
    "French": "fra",
    "German": "deu",
    "Italian": "ita",
    "Portuguese": "por",
    "Chinese": "cmn",
    "Japanese": "jpn",
    "Korean": "kor",
    "Russian": "rus",
    "Arabic": "arb",
    "Hindi": "hin"
}

# Cache of loaded TTS models keyed by language code; filled by get_model().
loaded_models = {}

def get_model(lang_code):
    """Return the TTS model for `lang_code`, downloading and caching it on first use.

    Returns None (after printing the error) when the model cannot be
    downloaded or loaded; nothing is cached in that case.
    """
    if lang_code in loaded_models:
        return loaded_models[lang_code]

    try:
        tts_model = TTS(download(lang_code, "./models"))
    except Exception as err:
        print(f"Error loading model {lang_code}: {err}")
        return None

    loaded_models[lang_code] = tts_model
    return tts_model

def text_to_speech(text, use_multilingual, language_name):
    """Synthesize speech for `text`.

    Args:
        text: Text to speak; None or whitespace-only text is rejected.
        use_multilingual: When false, English is always used.
        language_name: Display name looked up in `tts_lang_map`
            (unknown names fall back to English).

    Returns:
        Tuple of (audio, status). `audio` is (sample_rate, float32 array) on
        success and None on failure; `status` is a human-readable message.
    """
    # This cell does not import these modules at the top, so import locally.
    import os
    import tempfile

    if not text or not text.strip():
        return None, "Please enter text"

    # English if not multilingual
    lang_code = tts_lang_map.get(language_name, "eng") if use_multilingual else "eng"

    # Chinese and Japanese are routed to gTTS instead of ttsmms.
    gtts_lang = {"cmn": "zh", "jpn": "ja"}.get(lang_code)
    if gtts_lang is not None:
        # A per-request temporary file avoids concurrent requests overwriting
        # one shared file in the working directory, and it is removed after
        # reading. gTTS produces MP3 data, hence the suffix.
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
            tmp_path = tmp.name
        try:
            gTTS(text, lang=gtts_lang).save(tmp_path)
            data, sr = sf.read(tmp_path)
            audio = data.astype(np.float32)
            return (sr, audio), f"Generated {language_name} audio (Model: {gtts_lang})"
        except Exception as e:
            # Same reporting style as the ttsmms path below.
            return None, f"Error: {str(e)}"
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    tts = get_model(lang_code)
    if not tts:
        return None, f"Model for {language_name} ({lang_code}) not available"

    try:
        result = tts.synthesis(text)
        audio = result["x"].astype(np.float32)
        sr = result["sampling_rate"]
        return (sr, audio), f"Generated {language_name} audio (Model: {lang_code})"
    except Exception as e:
        return None, f"Error: {str(e)}"

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# TTS Multilingual Demo")

    # Text box on the left, language controls in a column on the right.
    with gr.Row():
        text_input = gr.Textbox(label="Input Text", lines=4,
                              placeholder="Enter text here")
        with gr.Column():
            use_multilingual = gr.Checkbox(label="Multilingual", value=True)
            language_selector = gr.Dropdown(
                label="Language",
                choices=list(tts_lang_map.keys()),
                value="English"
            )

    generate_btn = gr.Button("Generate Speech")
    audio_output = gr.Audio(label="Output", format="wav")
    info_output = gr.Textbox(label="Status")

    # Inputs map positionally to text_to_speech(text, use_multilingual,
    # language_name); its (audio, status) return fills the two outputs.
    generate_btn.click(
        text_to_speech,
        inputs=[text_input, use_multilingual, language_selector],
        outputs=[audio_output, info_output]
    )

demo.launch(share=True, debug=True)

In [None]:
!pip uninstall numpy
!pip install numpy==1.26.4
!pip install gradio transformers torch torchaudio accelerate bitsandbytes uroman -q
!pip install openai-whisper
!pip install ttsmms
!pip install gtts # just for chinese + japanese, replace later, and probably do not show off
!pip install soundfile

import gradio as gr
import whisper
import numpy as np
import tempfile
import os
import torch
import soundfile as sf
from transformers import AutoModelForCausalLM, AutoTokenizer
from ttsmms import TTS, download
from gtts import gTTS

# Run on the GPU when PyTorch reports CUDA is available; otherwise use the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# System prompt for LLM
SYSTEM_PROMPT = """You are a helpful, respectful and honest conversational
assistant. Keep answers concise, usually one sentence at most."""

# Dropdown display name -> TTS language code. The keys populate the output
# language dropdown; text_to_speech() handles "Automatic" itself (it maps the
# detected language instead) before consulting this table.
tts_lang_map = {
    "Automatic": "auto",
    "English": "eng",
    "Spanish": "spa",
    "French": "fra",
    "German": "deu",
    "Italian": "ita",
    "Portuguese": "por",
    "Chinese": "cmn",
    "Japanese": "jpn",
    "Korean": "kor",
    "Russian": "rus",
    "Arabic": "arb",
    "Hindi": "hin"
}

# Detected language code -> English name, used by format_prompt() to tell the
# LLM which language to answer in.
language_code_to_name = {
    "en": "English",
    "es": "Spanish",
    "fr": "French",
    "de": "German",
    "it": "Italian",
    "pt": "Portuguese",
    "zh": "Chinese",
    "ja": "Japanese",
    "ko": "Korean",
    "ru": "Russian",
    "ar": "Arabic",
    "hi": "Hindi",
}

# Cache of loaded TTS models keyed by language code; filled by get_tts_model().
loaded_tts_models = {}

# Module-level list of [user, assistant] turns. process_speech_input() appends
# to it and keeps at most the last 5; clear_conversation() resets it.
# NOTE(review): being module-level, it is shared by every caller of the app.
conversation_history = []

# Load Models

# Load ASR model
print("Loading ASR model...")
asr_model_size = "tiny"  # Options: 'tiny', 'base', 'small', 'medium', 'large'
asr_model = whisper.load_model(asr_model_size, device=device)
print(f"ASR Model '{asr_model_size}' loaded")

# Load LLM model
print("Loading LLM model...")
def load_llm():
    """Load the TinyLlama chat model with 4-bit quantization and its tokenizer.

    Returns:
        Tuple of (model, tokenizer).
    """
    # Local import keeps this function usable without touching the cell's
    # import line.
    from transformers import BitsAndBytesConfig

    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.float16,
        # Passing load_in_4bit directly to from_pretrained is deprecated;
        # the quantization settings go through a BitsAndBytesConfig instead.
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Reuse EOS as the padding token and pad on the left.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    return model, tokenizer

# Load once at cell start; generate_llm_response() reads these two globals.
llm_model, llm_tokenizer = load_llm()
print("LLM Model loaded")

# ASR + LID

def transcribe_audio(audio):
    """Transcribe audio and detect language

    Args:
        audio: Tuple (sample_rate, audio_data), or None when nothing was
            recorded.

    Returns:
        Always a 4-tuple (transcription, language_info, detected_lang_code,
        lang_name). When there is no audio or transcription fails, the first
        element carries the message and the language fields fall back to
        "", "en", "English".
    """
    if audio is None:
        # Same arity as every other return path so callers can always
        # unpack four values.
        return "No audio recorded", "", "en", "English"

    sample_rate, audio_data = audio

    # Temporary audio file (path only; written and read back by name)
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
        temp_filename = temp_audio.name

    try:
        # Save audio with sample rate
        import scipy.io.wavfile as wav
        wav.write(temp_filename, sample_rate, audio_data)

        # LID on the first 30-second window of the recording
        samples = whisper.load_audio(temp_filename)
        samples = whisper.pad_or_trim(samples)
        # Use the checkpoint's own mel-bin count so model sizes whose count
        # differs from the library default also work.
        mel = whisper.log_mel_spectrogram(
            samples, n_mels=asr_model.dims.n_mels
        ).to(device)
        _, probs = asr_model.detect_language(mel)
        detected_lang_code = max(probs, key=probs.get)

        language_names = {
            "en": "English", "zh": "Chinese", "de": "German",
            "es": "Spanish", "ru": "Russian", "ko": "Korean",
            "fr": "French", "ja": "Japanese", "pt": "Portuguese",
            "tr": "Turkish", "pl": "Polish", "ca": "Catalan",
            "nl": "Dutch", "ar": "Arabic", "sv": "Swedish",
            "it": "Italian", "id": "Indonesian", "hi": "Hindi",
            "fi": "Finnish", "vi": "Vietnamese", "he": "Hebrew",
            "uk": "Ukrainian", "el": "Greek", "ms": "Malay",
        }

        lang_name = language_names.get(detected_lang_code, detected_lang_code)
        confidence = round(probs[detected_lang_code] * 100, 2)
        language_info = f"Detected: {lang_name} ({detected_lang_code}) - Confidence: {confidence}%"

        # Transcribe with the language detected above, so the text and the
        # reported language come from one decision and detection is not
        # repeated inside transcribe().
        result = asr_model.transcribe(temp_filename, language=detected_lang_code)
        transcription = result["text"].strip()

        return transcription, language_info, detected_lang_code, lang_name

    except Exception as e:
        # UI boundary: report the failure and fall back to English.
        return f"Error transcribing: {str(e)}", "", "en", "English"

    finally:
        # Delete temp file
        if os.path.exists(temp_filename):
            os.remove(temp_filename)

# LLM

def format_prompt(messages, detected_language=""):
    """Build the chat-template message list, steering the reply language.

    Args:
        messages: Iterable of turns; index 0 of each turn is the user text
            and index 1 is the assistant text.
        detected_language: Language code of the user's speech. Empty or
            "en" leaves the system prompt unchanged.

    Returns:
        List of {"role", "content"} dicts: the (possibly extended) system
        prompt first, then user/assistant entries for every turn, in order.
    """
    system_text = SYSTEM_PROMPT
    wants_language_hint = bool(detected_language) and detected_language != "en"
    if wants_language_hint:
        # Unknown codes are used verbatim as the language name.
        # The wording below is the knob to tune if replies come back in the
        # wrong language or incoherent.
        spoken = language_code_to_name.get(detected_language, detected_language)
        system_text = (
            system_text
            + f"\nThe user is speaking in {spoken}. Respond in {spoken} to be understood."
        )

    chat = [{"role": "system", "content": system_text}]
    for turn in messages:
        chat.extend(
            {"role": role, "content": turn[position]}
            for position, role in enumerate(("user", "assistant"))
        )
    return chat

def generate_llm_response(input_text, history, detected_language="en"):
    """Return the LLM's reply to `input_text` given the prior history.

    Args:
        input_text: The new user message.
        history: Earlier turns, passed to format_prompt().
        detected_language: Language code forwarded to format_prompt().

    Returns:
        The decoded reply text (only the tokens generated after the prompt).
    """
    conversation = format_prompt(history, detected_language) + [
        {"role": "user", "content": input_text}
    ]

    prompt_ids = llm_tokenizer.apply_chat_template(
        conversation,
        return_tensors="pt",
        add_generation_prompt=True,
        padding=True,
        return_attention_mask=True,
    )
    prompt_ids = prompt_ids.to(llm_model.device)

    sampling_options = dict(
        attention_mask=torch.ones_like(prompt_ids),
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        pad_token_id=llm_tokenizer.eos_token_id,
    )
    generated = llm_model.generate(prompt_ids, **sampling_options)

    # Skip the prompt tokens; what remains is the model's reply.
    prompt_length = prompt_ids.shape[-1]
    reply_ids = generated[0][prompt_length:]
    return llm_tokenizer.decode(reply_ids, skip_special_tokens=True)

# TTS

def get_tts_model(lang_code):
    """Return the TTS model for `lang_code`, downloading and caching it on first use.

    Returns None (after printing the error) when the model cannot be
    downloaded or loaded; nothing is cached in that case.
    """
    if lang_code in loaded_tts_models:
        return loaded_tts_models[lang_code]

    try:
        tts_model = TTS(download(lang_code, "./models"))
    except Exception as err:
        print(f"Error loading TTS model {lang_code}: {err}")
        return None

    loaded_tts_models[lang_code] = tts_model
    return tts_model

def text_to_speech(text, language_selection="Automatic", detected_lang_code="en"):
    """Synthesize speech for `text`.

    Args:
        text: Text to speak; None or whitespace-only text yields None.
        language_selection: Dropdown value. "Automatic" picks the voice from
            `detected_lang_code`; any other value is looked up in
            `tts_lang_map`. Unknown values fall back to English.
        detected_lang_code: Language code from language detection.

    Returns:
        (sample_rate, float32 audio array) on success, or None when nothing
        could be synthesized. The result is always one of these two shapes so
        it can be handed straight to an audio output component.
    """
    if not text or not text.strip():
        print("No text to synthesize")
        return None

    if language_selection == "Automatic":
        # from LID
        lang_code_tts = {
            "en": "eng",
            "es": "spa",
            "fr": "fra",
            "de": "deu",
            "it": "ita",
            "pt": "por",
            "zh": "cmn",
            "ja": "jpn",
            "ko": "kor",
            "ru": "rus",
            "ar": "arb",
            "hi": "hin",
        }.get(detected_lang_code, "eng")  # english default
    else:
        # from drop-down
        lang_code_tts = tts_lang_map.get(language_selection, "eng")

    # Google TTS for Chinese + Japanese for now
    lang_code_gtts = {"cmn": "zh", "jpn": "ja"}.get(lang_code_tts)
    if lang_code_gtts is not None:
        # A per-request temporary file avoids concurrent requests overwriting
        # one shared file in the working directory, and it is removed after
        # reading. gTTS produces MP3 data, hence the suffix.
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
            tmp_path = tmp.name
        try:
            gTTS(text, lang=lang_code_gtts).save(tmp_path)
            data, sr = sf.read(tmp_path)
            audio = data.astype(np.float32)
            return (sr, audio)
        except Exception as e:
            print(f"GTTS error for {lang_code_tts}: {e}")
            # try in English instead
            return text_to_speech(text, "English", "en")
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    # TTSMMS
    tts = get_tts_model(lang_code_tts)
    if not tts:
        # Default to English if code is wrong
        print(f"Falling back to English for {lang_code_tts}")
        tts = get_tts_model("eng")
        if not tts:
            return None

    try:
        result = tts.synthesis(text)
        audio = result["x"].astype(np.float32)
        sr = result["sampling_rate"]
        return (sr, audio)
    except Exception as e:
        print(f"TTS error: {e}")
        # Default to English; the guard stops the retry from recursing again
        # once English itself has failed.
        if lang_code_tts != "eng":
            print("Falling back to English TTS")
            return text_to_speech(text, "English", "en")
        return None

# SDS

def process_speech_input(audio, output_language):
    """Run one turn of the pipeline: ASR + LID -> LLM -> TTS.

    Args:
        audio: Tuple (sample_rate, audio_data) from the microphone, or None.
        output_language: Dropdown value forwarded to text_to_speech().

    Returns:
        Tuple of (speech_output, transcription, response_text, chat_display,
        language_info), in the order of the UI outputs. `chat_display` always
        mirrors the stored conversation history, including on the error paths.
    """
    global conversation_history

    if audio is None:
        # Show the retained history rather than an empty list, so the chat
        # panel does not go blank while earlier turns are still stored.
        chat_display = [list(turn) for turn in conversation_history]
        return None, "", "No audio input received", chat_display, ""

    # Transcribe audio, do LID
    transcription, language_info, detected_lang_code, detected_lang_name = transcribe_audio(audio)

    # Match the exact failure prefix produced by transcribe_audio(), so a real
    # utterance that merely begins with the word "Error" is still processed.
    if not transcription or transcription.startswith("Error transcribing:"):
        chat_display = [list(turn) for turn in conversation_history]
        return None, "", f"Transcription failed: {transcription}", chat_display, language_info

    # Get LLM response
    llm_response = generate_llm_response(transcription, conversation_history, detected_lang_code)

    # Apply TTS
    speech_output = text_to_speech(llm_response, output_language, detected_lang_code)

    conversation_history.append([transcription, llm_response])

    # Prevent context from getting too big, don't want it to hallucinate
    if len(conversation_history) > 5:
        conversation_history = conversation_history[-5:]

    # Format chat display (fresh lists, so the UI value is not the stored state)
    chat_display = [[user_msg, assistant_msg] for user_msg, assistant_msg in conversation_history]

    return speech_output, transcription, llm_response, chat_display, language_info

def clear_conversation():
    """Forget the stored conversation and return blank values for every output.

    Returns:
        Tuple of (audio, transcription, response, chat history, language
        info) placeholders: None, "", "", [], "".
    """
    global conversation_history
    conversation_history = list()
    blank_outputs = (None, "", "", [], "")
    return blank_outputs

# Gradio

with gr.Blocks() as demo:
    gr.Markdown("# Spoken Dialogue System")
    gr.Markdown("Robust Multilingual Cascaded Spoken Dialogue System Demo.")

    with gr.Row():
        # Left column: microphone input and controls.
        with gr.Column(scale=1):
            audio_input = gr.Audio(sources=["microphone"], type="numpy", label="Speak Here")
            output_language = gr.Dropdown(
                label="Output Speech Language",
                choices=list(tts_lang_map.keys()),
                value="Automatic"
            )
            submit_btn = gr.Button("Process Speech", variant="primary")
            clear_btn = gr.Button("Clear Conversation")

        # Right column: everything the pipeline produces.
        with gr.Column(scale=2):
            audio_output = gr.Audio(label="Assistant Response (Audio)")
            transcription_output = gr.Textbox(label="Your Speech (Transcribed)")
            language_info_output = gr.Textbox(label="Language Detection")
            response_output = gr.Textbox(label="Assistant Response (Text)")
            chat_history = gr.Chatbot(label="Conversation History", height=400)

    # The outputs list must stay in the order of process_speech_input()'s
    # returned 5-tuple: audio, transcription, response, chat, language info.
    submit_btn.click(
        process_speech_input,
        inputs=[audio_input, output_language],
        outputs=[audio_output, transcription_output, response_output, chat_history, language_info_output]
    )

    # Same output order as above; clear_conversation() returns blanks for each.
    clear_btn.click(
        clear_conversation,
        inputs=[],
        outputs=[audio_output, transcription_output, response_output, chat_history, language_info_output]
    )

print("Starting Spoken Dialogue System...")
demo.launch(debug=True, share=True)

Found existing installation: numpy 1.26.4
Uninstalling numpy-1.26.4:
  Would remove:
    /usr/local/bin/f2py
    /usr/local/lib/python3.11/dist-packages/numpy-1.26.4.dist-info/*
    /usr/local/lib/python3.11/dist-packages/numpy.libs/libgfortran-040039e1.so.5.0.0
    /usr/local/lib/python3.11/dist-packages/numpy.libs/libopenblas64_p-r0-0cf96a72.3.23.dev.so
    /usr/local/lib/python3.11/dist-packages/numpy.libs/libquadmath-96973f99.so.0.0.0
    /usr/local/lib/python3.11/dist-packages/numpy/*
Proceed (Y/n)? Y
  Successfully uninstalled numpy-1.26.4
Collecting numpy==1.26.4
  Using cached numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Using cached numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
Installing collected packages: numpy
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
thinc 8

100%|█████████████████████████████████████| 72.1M/72.1M [00:01<00:00, 66.5MiB/s]


ASR Model 'tiny' loaded
Loading LLM model...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.29k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

LLM Model loaded
Starting Spoken Dialogue System...


  chat_history = gr.Chatbot(label="Conversation History", height=400)


Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://cf277bed4200a2ebc3.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




downloading eng from https://dl.fbaipublicfiles.com/mms/tts/eng.tar.gz
extract all eng to ./models/eng
Done


  WeightNorm.apply(module, name, dim)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://cf277bed4200a2ebc3.gradio.live




Temporary cells for testing (cchwang).

In [None]:
!pip uninstall numpy
!pip install numpy==1.26.4
!pip install gradio transformers torch torchaudio accelerate bitsandbytes uroman -q
!pip install openai-whisper
!pip install ttsmms
!pip install gtts # just for chinese + japanese, replace later, and probably do not show off
!pip install soundfile



Found existing installation: numpy 2.0.2
Uninstalling numpy-2.0.2:
  Would remove:
    /usr/local/bin/f2py
    /usr/local/bin/numpy-config
    /usr/local/lib/python3.11/dist-packages/numpy-2.0.2.dist-info/*
    /usr/local/lib/python3.11/dist-packages/numpy.libs/libgfortran-040039e1-0352e75f.so.5.0.0
    /usr/local/lib/python3.11/dist-packages/numpy.libs/libquadmath-96973f99-934c22de.so.0.0.0
    /usr/local/lib/python3.11/dist-packages/numpy.libs/libscipy_openblas64_-99b71e71.so
    /usr/local/lib/python3.11/dist-packages/numpy/*
Proceed (Y/n)? Y
  Successfully uninstalled numpy-2.0.2
Collecting numpy==1.26.4
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32

ValueError: numpy.dtype size changed, may indicate binary incompatibility. Expected 96 from C header, got 88 from PyObject

In [None]:

import gradio as gr
import whisper
import numpy as np
import tempfile
import os
import torch
import soundfile as sf
from transformers import AutoModelForCausalLM, AutoTokenizer
from ttsmms import TTS, download
from gtts import gTTS

In [None]:
def detect_audio_language(audio_data, sample_rate):
    """Detect language in audio using whisper

    Args:
        audio_data: Audio samples, written to a temporary WAV file with
            soundfile.
        sample_rate: Sample rate of `audio_data`.

    Returns:
        Tuple of (language_info, detected_lang_code, lang_name). On any
        failure returns ("Error detecting TTS language: ...", "unknown",
        "Unknown") instead of raising.
    """
    # Bound before the try so the cleanup in `finally` never refers to an
    # unassigned name when creating the temporary file itself fails.
    temp_filename = None
    try:
        # Save audio to temporary file
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
            temp_filename = temp_audio.name

        # Save the audio data
        sf.write(temp_filename, audio_data, sample_rate)

        # Process with whisper for language detection (first 30-second window)
        samples = whisper.load_audio(temp_filename)
        samples = whisper.pad_or_trim(samples)
        # Use the checkpoint's own mel-bin count so model sizes whose count
        # differs from the library default also work.
        mel = whisper.log_mel_spectrogram(
            samples, n_mels=asr_model.dims.n_mels
        ).to(device)
        _, probs = asr_model.detect_language(mel)
        detected_lang_code = max(probs, key=probs.get)

        language_names = {
            "en": "English", "zh": "Chinese", "de": "German",
            "es": "Spanish", "ru": "Russian", "ko": "Korean",
            "fr": "French", "ja": "Japanese", "pt": "Portuguese",
            "tr": "Turkish", "pl": "Polish", "ca": "Catalan",
            "nl": "Dutch", "ar": "Arabic", "sv": "Swedish",
            "it": "Italian", "id": "Indonesian", "hi": "Hindi",
            "fi": "Finnish", "vi": "Vietnamese", "he": "Hebrew",
            "uk": "Ukrainian", "el": "Greek", "ms": "Malay",
        }

        lang_name = language_names.get(detected_lang_code, detected_lang_code)
        confidence = round(probs[detected_lang_code] * 100, 2)
        language_info = f"TTS Output Language: {lang_name} ({detected_lang_code}) - Confidence: {confidence}%"

        return language_info, detected_lang_code, lang_name

    except Exception as e:
        return f"Error detecting TTS language: {str(e)}", "unknown", "Unknown"

    finally:
        # Delete temp file
        if temp_filename is not None and os.path.exists(temp_filename):
            os.remove(temp_filename)

In [None]:
# Run the models on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# System prompt for LLM
# format_prompt() appends a language instruction to this text when the
# detected input language is not English.
SYSTEM_PROMPT = """You are a helpful, respectful and honest conversational
assistant. Keep answers concise, usually one sentence at most."""

# Dropdown label -> TTS model code handed to get_tts_model().
# "Automatic" is handled separately in text_to_speech() (the code is derived
# from the detected input language), so its "auto" value is never used as a
# model code there.
tts_lang_map = {
    "Automatic": "auto",
    "English": "eng",
    "Spanish": "spa",
    "French": "fra",
    "German": "deu",
    "Italian": "ita",
    "Portuguese": "por",
    "Chinese": "cmn",
    "Japanese": "jpn",
    "Korean": "kor",
    "Russian": "rus",
    "Arabic": "arb",
    "Hindi": "hin"
}

# Detected language code -> language name inserted into the LLM system prompt
# by format_prompt().
language_code_to_name = {
    "en": "English",
    "es": "Spanish",
    "fr": "French",
    "de": "German",
    "it": "Italian",
    "pt": "Portuguese",
    "zh": "Chinese",
    "ja": "Japanese",
    "ko": "Korean",
    "ru": "Russian",
    "ar": "Arabic",
    "hi": "Hindi",
}

# Cache of TTS models keyed by TTS language code; filled by get_tts_model().
loaded_tts_models = {}

# Module-level dialogue memory: [user_text, assistant_text] pairs appended by
# process_speech_input() and reset by clear_conversation().
conversation_history = []

# Load Models

# Load ASR model
print("Loading ASR model...")
asr_model_size = "tiny"  # Options: 'tiny', 'base', 'small', 'medium', 'large'
asr_model = whisper.load_model(asr_model_size, device=device)
print(f"ASR Model '{asr_model_size}' loaded")

# Load LLM model
print("Loading LLM model...")
def load_llm():
    """Load the TinyLlama chat model in 4-bit precision together with its tokenizer.

    Returns:
        Tuple ``(model, tokenizer)``. The tokenizer reuses the EOS token as the
        padding token and pads on the left.
    """
    # Imported here because the notebook's import cell only brings in the
    # model and tokenizer classes.
    from transformers import BitsAndBytesConfig

    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.float16,
        # Passing load_in_4bit=True directly is deprecated; the quantization
        # settings now go through a BitsAndBytesConfig object.
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    return model, tokenizer

llm_model, llm_tokenizer = load_llm()
print("LLM Model loaded")

# ASR + LID

def transcribe_audio(audio):
    """Transcribe recorded audio and identify its language with Whisper.

    Args:
        audio: ``(sample_rate, audio_data)`` tuple as produced by the Gradio
            audio input, or ``None`` when nothing was recorded.

    Returns:
        Tuple ``(transcription, language_info, lang_code, lang_name)``.
        ``language_info`` is a display string with the detection confidence.
        When no audio is given or transcription fails, the first element holds
        a message, ``language_info`` is empty and the language defaults to
        ``"en"`` / ``"English"``.
    """
    if audio is None:
        # Same 4-tuple shape as every other exit so callers can always unpack
        # four values.
        return "No audio recorded", "", "en", "English"

    sample_rate, audio_data = audio

    # Temporary audio file
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
        temp_filename = temp_audio.name

    try:
        # Save audio with sample rate
        import scipy.io.wavfile as wav
        wav.write(temp_filename, sample_rate, audio_data)

        # LID
        audio_data = whisper.load_audio(temp_filename)
        audio_data = whisper.pad_or_trim(audio_data)
        mel = whisper.log_mel_spectrogram(audio_data).to(device)
        _, probs = asr_model.detect_language(mel)
        detected_lang_code = max(probs, key=probs.get)

        language_names = {
            "en": "English", "zh": "Chinese", "de": "German",
            "es": "Spanish", "ru": "Russian", "ko": "Korean",
            "fr": "French", "ja": "Japanese", "pt": "Portuguese",
            "tr": "Turkish", "pl": "Polish", "ca": "Catalan",
            "nl": "Dutch", "ar": "Arabic", "sv": "Swedish",
            "it": "Italian", "id": "Indonesian", "hi": "Hindi",
            "fi": "Finnish", "vi": "Vietnamese", "he": "Hebrew",
            "uk": "Ukrainian", "el": "Greek", "ms": "Malay",
        }

        # Codes missing from the table are shown as the raw code
        lang_name = language_names.get(detected_lang_code, detected_lang_code)
        confidence = round(probs[detected_lang_code] * 100, 2)
        language_info = f"Detected: {lang_name} ({detected_lang_code}) - Confidence: {confidence}%"

        # Transcribe
        result = asr_model.transcribe(temp_filename)
        transcription = result["text"].strip()

        return transcription, language_info, detected_lang_code, lang_name

    except Exception as e:
        # Callers recognise failures by the "Error" prefix of the first element
        return f"Error transcribing: {str(e)}", "", "en", "English"

    finally:
        # Delete temp file
        if os.path.exists(temp_filename):
            os.remove(temp_filename)

# LLM

def format_prompt(messages, detected_language=""):
    """Build the chat message list (system prompt plus past turns) for the LLM.

    Args:
        messages: Sequence of past turns; index 0 of each turn is the user
            text and index 1 the assistant text.
        detected_language: Language code of the user's speech. Any non-empty
            value other than ``"en"`` adds an instruction to answer in that
            language.

    Returns:
        List of ``{"role": ..., "content": ...}`` dicts starting with the
        system message.
    """
    system_text = SYSTEM_PROMPT
    needs_language_hint = bool(detected_language) and detected_language != "en"
    if needs_language_hint:
        # Unknown codes are inserted into the prompt verbatim
        language_label = language_code_to_name.get(detected_language, detected_language)
        system_text = system_text + (
            f"\nThe user is speaking in {language_label}. Respond in {language_label} to be understood."
        )

    chat = [{"role": "system", "content": system_text}]
    for turn in messages:
        chat.extend((
            {"role": "user", "content": turn[0]},
            {"role": "assistant", "content": turn[1]},
        ))
    return chat

def generate_llm_response(input_text, history, detected_language="en"):
    """Generate the assistant's reply to ``input_text`` with the chat LLM.

    Args:
        input_text: The user's latest utterance.
        history: Past turns, forwarded to ``format_prompt``.
        detected_language: Language code forwarded to ``format_prompt``.

    Returns:
        The decoded text of the newly generated tokens only.
    """
    chat = format_prompt(history, detected_language)
    chat = chat + [{"role": "user", "content": input_text}]

    prompt_ids = llm_tokenizer.apply_chat_template(
        chat,
        return_tensors="pt",
        add_generation_prompt=True,
        padding=True,
        return_attention_mask=True
    )
    prompt_ids = prompt_ids.to(llm_model.device)

    # Sampling configuration; every prompt position is attended to
    generation_options = {
        "attention_mask": torch.ones_like(prompt_ids),
        "max_new_tokens": 256,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.95,
        "pad_token_id": llm_tokenizer.eos_token_id,
    }
    generated = llm_model.generate(prompt_ids, **generation_options)

    # Drop the prompt tokens so only the reply is decoded
    prompt_length = prompt_ids.shape[-1]
    reply_ids = generated[0][prompt_length:]
    return llm_tokenizer.decode(reply_ids, skip_special_tokens=True)

# tts

def get_tts_model(lang_code):
    """Return the TTS model for ``lang_code``, downloading it on first use.

    Loaded models are kept in ``loaded_tts_models`` so each language is only
    downloaded and constructed once.

    Returns:
        The cached or newly loaded model, or ``None`` if downloading or
        loading failed (the error is printed).
    """
    try:
        return loaded_tts_models[lang_code]
    except KeyError:
        pass  # not loaded yet

    try:
        checkpoint_dir = download(lang_code, "./models")
        voice = TTS(checkpoint_dir)
    except Exception as e:
        print(f"Error loading TTS model {lang_code}: {e}")
        return None

    loaded_tts_models[lang_code] = voice
    return voice

def text_to_speech(text, language_selection="Automatic", detected_lang_code="en"):
    """Synthesize speech for ``text`` in the requested or detected language.

    Text longer than 250 characters is shortened first, preferably at a
    sentence boundary and otherwise at a word boundary. Chinese and Japanese
    go through gTTS; every other language uses a model from
    ``get_tts_model``. If synthesis in a non-English language fails, the
    original text is retried with the English model.

    Args:
        text: Text to speak.
        language_selection: Dropdown label (a key of ``tts_lang_map``).
            ``"Automatic"`` derives the language from ``detected_lang_code``.
        detected_lang_code: Language code of the user's speech; only used
            when ``language_selection`` is ``"Automatic"``.

    Returns:
        ``(sample_rate, audio)`` with ``audio`` as a float32 array, or ``None``
        when there is nothing to synthesize or synthesis failed.
    """
    if not text or not text.strip():
        # Every "no audio" outcome is None so callers can rely on truthiness;
        # a (None, message) tuple would be truthy and be treated as audio.
        return None

    # Keep the untruncated text for the English fallback below
    original_text = text

    # Truncate text to 250 characters
    if len(text) > 250:
        # Try to truncate at a sentence boundary first
        import re
        sentences = re.split(r'(?<=[.!?])\s+', text[:250])
        if len(sentences) > 1:
            # Truncate at the last complete sentence within the 250 character limit
            truncated_text = " ".join(sentences[:-1])
        else:
            # If no sentence boundary found, truncate at a word boundary
            words = text[:250].split()
            truncated_text = " ".join(words[:-1])  # Remove the last word which might be cut off
            if not truncated_text:  # In case there's only one long word
                truncated_text = text[:250]

        print(f"Truncated TTS input from {len(text)} to {len(truncated_text)} characters")
        text = truncated_text

    if language_selection == "Automatic":
        # from LID
        lang_code_tts = {
            "en": "eng",
            "es": "spa",
            "fr": "fra",
            "de": "deu",
            "it": "ita",
            "pt": "por",
            "zh": "cmn",
            "ja": "jpn",
            "ko": "kor",
            "ru": "rus",
            "ar": "arb",
            "hi": "hin",
        }.get(detected_lang_code, "eng")  # english default
    else:
        # from drop-down
        lang_code_tts = tts_lang_map.get(language_selection, "eng")

    # Lowercase the text only for English TTS to prevent it from skipping uppercase letters
    if lang_code_tts == "eng":
        text = text.lower()

    # Google TTS for Chinese + Japanese for now
    if lang_code_tts in ("cmn", "jpn"):
        lang_code_gtts = "zh" if lang_code_tts == "cmn" else "ja"
        gtts_path = None
        try:
            # A per-call temporary file replaces the fixed "gtts.wav" in the
            # working directory: overlapping requests cannot overwrite each
            # other's audio and the file is removed afterwards. gTTS saves
            # MP3 data, hence the suffix.
            with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as gtts_file:
                gtts_path = gtts_file.name
            gTTS(text, lang=lang_code_gtts).save(gtts_path)
            data, sr = sf.read(gtts_path)
            audio = data.astype(np.float32)
            return (sr, audio)
        except Exception as e:
            print(f"GTTS error for {lang_code_tts}: {e}")
            # try in English instead
            return text_to_speech(original_text, "English", "en")  # Use original text when falling back
        finally:
            if gtts_path is not None and os.path.exists(gtts_path):
                os.remove(gtts_path)

    # TTSMMS
    tts = get_tts_model(lang_code_tts)
    if not tts:
        # Default to English if code is wrong
        print(f"Falling back to English for {lang_code_tts}")
        tts = get_tts_model("eng")
        if not tts:
            return None

    try:
        result = tts.synthesis(text)
        audio = result["x"].astype(np.float32)
        sr = result["sampling_rate"]
        return (sr, audio)
    except Exception as e:
        print(f"TTS error: {e}")
        # Default to English
        if lang_code_tts != "eng":
            print("Falling back to English TTS")
            return text_to_speech(original_text, "English", "en")  # Use original text when falling back
        return None

# SDS
# One full dialogue turn: ASR + LID -> LLM -> TTS -> LID on the TTS output
def process_speech_input(audio, output_language):
    """Run one spoken-dialogue turn and return the values for the UI outputs.

    Args:
        audio: ``(sample_rate, audio_data)`` recording, or ``None``.
        output_language: Dropdown label selecting the TTS language.

    Returns:
        6-tuple ``(speech_output, transcription, llm_response, chat_display,
        language_info, tts_language_info)``. ``speech_output`` is ``None``
        when no audio could be produced. ``chat_display`` always reflects the
        stored conversation, including when the turn is rejected.
    """
    global conversation_history

    # Transcribe audio, do LID
    if audio is None:
        # Keep showing the existing conversation (as the transcription-failure
        # path below does) instead of blanking the chat widget.
        return None, "", "No audio input received", list(conversation_history), "", ""

    transcription, language_info, detected_lang_code, detected_lang_name = transcribe_audio(audio)

    if not transcription or transcription.startswith("Error"):
        return None, "", f"Transcription failed: {transcription}", conversation_history, language_info, ""

    # Get LLM response
    llm_response = generate_llm_response(transcription, conversation_history, detected_lang_code)

    # Apply TTS
    speech_output = text_to_speech(llm_response, output_language, detected_lang_code)

    # Detect language in TTS output (only the display string is needed)
    tts_language_info = ""
    if speech_output:
        tts_language_info, _, _ = detect_audio_language(speech_output[1], speech_output[0])

    conversation_history.append([transcription, llm_response])

    # Prevent context from getting too big, don't want it to hallucinate
    if len(conversation_history) > 5:
        conversation_history = conversation_history[-5:]

    # Format chat display as fresh [user, assistant] pairs
    chat_display = [[user_msg, assistant_msg] for user_msg, assistant_msg in conversation_history]

    return speech_output, transcription, llm_response, chat_display, language_info, tts_language_info


def clear_conversation():
    """Forget all stored dialogue turns and return blank values for the six UI outputs."""
    global conversation_history
    conversation_history = list()

    blank_audio, blank_text, blank_chat = None, "", []
    return (blank_audio, blank_text, blank_text, blank_chat, blank_text, blank_text)

# Gradio interface: controls in the left column, results in the right column.
with gr.Blocks() as demo:
    gr.Markdown("# Spoken Dialogue System")
    gr.Markdown("Robust Multilingual Cascaded Spoken Dialogue System Demo.")

    with gr.Row():
        # Input column: microphone recording, TTS language choice, action buttons
        with gr.Column(scale=1):
            audio_input = gr.Audio(sources=["microphone"], type="numpy", label="Speak Here")
            output_language = gr.Dropdown(
                label="Output Speech Language",
                choices=list(tts_lang_map.keys()),
                value="Automatic"
            )
            submit_btn = gr.Button("Process Speech", variant="primary")
            clear_btn = gr.Button("Clear Conversation")

        # Output column: synthesized reply, transcript, language detection, chat log
        with gr.Column(scale=2):
            audio_output = gr.Audio(label="Assistant Response (Audio)")
            transcription_output = gr.Textbox(label="Your Speech (Transcribed)")
            language_info_output = gr.Textbox(label="Input Language Detection")
            tts_language_info_output = gr.Textbox(label="Output Language Detection")  # language detected in the TTS audio
            response_output = gr.Textbox(label="Assistant Response (Text)")
            chat_history = gr.Chatbot(label="Conversation History", height=400)

    # The outputs list must stay in the same order as the 6-tuple returned by
    # process_speech_input.
    submit_btn.click(
        process_speech_input,
        inputs=[audio_input, output_language],
        outputs=[audio_output, transcription_output, response_output, chat_history,
                language_info_output, tts_language_info_output]  # order matches the returned tuple
    )

    # clear_conversation returns one blank value per output listed here.
    clear_btn.click(
        clear_conversation,
        inputs=[],
        outputs=[audio_output, transcription_output, response_output, chat_history,
                language_info_output, tts_language_info_output]  # order matches the returned tuple
    )

print("Starting Spoken Dialogue System...")
# debug=True keeps the cell running so errors and logs stay visible;
# share=True requests a public share link.
demo.launch(debug=True, share=True)

Using device: cuda
Loading ASR model...


100%|█████████████████████████████████████| 72.1M/72.1M [00:01<00:00, 39.4MiB/s]


ASR Model 'tiny' loaded
Loading LLM model...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.29k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

LLM Model loaded
Starting Spoken Dialogue System...


  chat_history = gr.Chatbot(label="Conversation History", height=400)


Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://4e33a1ca313e85274c.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




Truncated TTS input from 955 to 215 characters
downloading eng from https://dl.fbaipublicfiles.com/mms/tts/eng.tar.gz
extract all eng to ./models/eng
Done


  WeightNorm.apply(module, name, dim)


Truncated TTS input from 932 to 222 characters
downloading spa from https://dl.fbaipublicfiles.com/mms/tts/spa.tar.gz
extract all spa to ./models/spa
Done


  WeightNorm.apply(module, name, dim)


Truncated TTS input from 927 to 212 characters




Truncated TTS input from 391 to 128 characters




Truncated TTS input from 512 to 186 characters
downloading por from https://dl.fbaipublicfiles.com/mms/tts/por.tar.gz
extract all por to ./models/por
Done


  WeightNorm.apply(module, name, dim)


Truncated TTS input from 746 to 207 characters
downloading hin from https://dl.fbaipublicfiles.com/mms/tts/hin.tar.gz
extract all hin to ./models/hin
Done


  WeightNorm.apply(module, name, dim)
