In [None]:
# Install required dependencies
!pip install -q gradio
!pip install -q librosa soundfile requests
!pip install -q sacrebleu
!pip install -q omegaconf hydra-core
!pip install -q pytorch-lightning
!python -m pip install -q "nemo_toolkit[asr,tts] @ git+https://github.com/NVIDIA/NeMo.git"

import gradio as gr
import librosa
import soundfile as sf
import tempfile
import os
import torch
from pathlib import Path
import numpy as np
import requests
import math
from typing import List, Tuple
import re

# Global variable to store the model
# Stays None until load_model() assigns the loaded SALM instance; every
# transcription/post-processing callback checks it before running inference.
model = None
# Device description string; load_model() overwrites it with the result of
# check_gpu_availability() and echoes it in its status messages.
device_info = ""

def check_gpu_availability():
    """Describe the compute device that PyTorch reports.

    Returns:
        A one-line status string. When CUDA is available it names device 0,
        its total memory in GB and the number of visible devices; otherwise
        it warns that inference will fall back to the CPU.
    """
    # Guard clause: without CUDA there is nothing to query.
    if not torch.cuda.is_available():
        return "❌ No GPU available - using CPU (will be very slow)"

    device_total = torch.cuda.device_count()
    primary_name = torch.cuda.get_device_name(0)
    # total_memory is reported in bytes; convert to GiB for display.
    primary_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    return f"✅ GPU Available: {primary_name} ({primary_memory_gb:.1f}GB) - {device_total} device(s)"

def get_gpu_memory_usage():
    """Report how much memory PyTorch currently holds on CUDA device 0.

    Returns:
        A string with the allocated and reserved ("cached") amounts in GB,
        or ``"CPU mode"`` when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return "CPU mode"

    bytes_per_gb = 1024**3
    in_use = torch.cuda.memory_allocated(0) / bytes_per_gb
    reserved = torch.cuda.memory_reserved(0) / bytes_per_gb
    return f"GPU Memory: {in_use:.1f}GB allocated, {reserved:.1f}GB cached"

def download_example_files(timeout=30.0):
    """Download the example audio files used by the demo (best effort).

    Files are stored in the ``example_audio`` directory relative to the
    current working directory. A file that already exists there is reused
    without touching the network.

    Args:
        timeout: Seconds to wait for the server before giving up on a
            download. Without a timeout a stalled connection would block
            application start-up indefinitely.

    Returns:
        List of paths to the example files that are available locally.
        Files whose download failed are reported on stdout and omitted.
    """
    examples_dir = "example_audio"
    os.makedirs(examples_dir, exist_ok=True)

    example_urls = [
        ("https://cdn-media.huggingface.co/speech_samples/sample1.flac", "librispeech_sample1.flac"),
        ("https://cdn-media.huggingface.co/speech_samples/sample2.flac", "librispeech_sample2.flac")
    ]

    downloaded_files = []

    for url, filename in example_urls:
        filepath = os.path.join(examples_dir, filename)

        if not os.path.exists(filepath):
            # Write to a side file first so that an interrupted write never
            # leaves a truncated file that later runs would treat as cached.
            partial_path = filepath + ".part"
            try:
                print(f"Downloading {filename}...")
                response = requests.get(url, timeout=timeout)
                response.raise_for_status()

                with open(partial_path, 'wb') as f:
                    f.write(response.content)
                os.replace(partial_path, filepath)
                print(f"✅ Downloaded {filename}")

            except Exception as e:
                # Deliberately broad: examples are optional and a failure
                # here must not prevent the application from starting.
                print(f"❌ Failed to download {filename}: {e}")
                if os.path.exists(partial_path):
                    try:
                        os.remove(partial_path)
                    except OSError as cleanup_error:
                        print(f"⚠️ Could not remove {partial_path}: {cleanup_error}")
                continue

        if os.path.exists(filepath):
            downloaded_files.append(filepath)

    return downloaded_files

def load_model():
    """Load the Canary-Qwen-2.5B model into the module-level ``model``.

    Side effects:
        Rebinds the globals ``model`` and ``device_info`` and prints progress
        messages to stdout.

    Returns:
        A multi-line status string: a success banner followed by the device
        and memory summaries, or an error banner followed by the device
        summary when anything in the loading sequence raises.
    """
    global model, device_info

    try:
        # Record and show which device we are about to use.
        device_info = check_gpu_availability()
        print(device_info)

        # Start from an empty CUDA cache so the model gets as much room as possible.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Imported lazily so the module can be imported before NeMo is needed.
        from nemo.collections.speechlm2.models import SALM
        print("Loading Canary-Qwen-2.5B model...")

        model = SALM.from_pretrained('nvidia/canary-qwen-2.5b')

        # Explicitly place the model on the GPU when one is present.
        if not torch.cuda.is_available():
            print("Model loaded on CPU")
        else:
            model = model.cuda()
            print(f"Model moved to GPU: {next(model.parameters()).device}")

        memory_info = get_gpu_memory_usage()
        print(f"Model loaded successfully! {memory_info}")

        status_lines = ["✅ Model loaded successfully!", device_info, memory_info]
        return "\n".join(status_lines)

    except Exception as load_error:
        error_msg = f"Error loading model: {str(load_error)}"
        print(error_msg)
        return f"❌ {error_msg}\n{device_info}"

def preprocess_audio(audio_path):
    """Convert an audio file to the format the model expects (16 kHz, mono).

    The audio is loaded through librosa with resampling to 16 kHz and
    down-mixing to mono, then written to a new temporary ``.wav`` file.

    Args:
        audio_path: Path of the input audio file.

    Returns:
        Tuple ``(temp_wav_path, duration_seconds)``. The caller owns the
        temporary file and is responsible for deleting it.

    Raises:
        Exception: If loading or writing fails; the message is prefixed with
            ``"Error processing audio: "`` and the original error is chained.
    """
    target_sr = 16000
    temp_path = None
    try:
        # Load audio file
        audio, _ = librosa.load(audio_path, sr=target_sr, mono=True)

        # Reserve a temp file name and close the handle immediately: the
        # original kept the handle open forever (descriptor leak) and wrote
        # to the path while it was still open.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
            temp_path = temp_file.name
        sf.write(temp_path, audio, target_sr)

        return temp_path, len(audio) / target_sr  # Return path and duration
    except Exception as e:
        # Do not leave an orphaned temp file behind when writing failed.
        if temp_path is not None and os.path.exists(temp_path):
            try:
                os.unlink(temp_path)
            except OSError as cleanup_error:
                print(f"⚠️ Could not remove {temp_path}: {cleanup_error}")
        raise Exception(f"Error processing audio: {str(e)}") from e

def _write_chunk_wav(samples, sample_rate: int) -> str:
    """Write *samples* to a new temporary WAV file and return its path."""
    # Close the handle right away; only the reserved name is needed.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as handle:
        wav_path = handle.name
    try:
        sf.write(wav_path, samples, sample_rate)
    except Exception:
        if os.path.exists(wav_path):
            os.unlink(wav_path)
        raise
    return wav_path


def split_audio_into_chunks(audio_path: str, chunk_duration: float = 30.0, overlap_duration: float = 2.0) -> List[Tuple[str, float, float]]:
    """
    Split audio into overlapping chunks for processing

    The audio is loaded as 16 kHz mono. Audio no longer than
    ``chunk_duration`` is written out as a single chunk; otherwise consecutive
    chunks start ``chunk_duration - overlap_duration`` seconds apart.

    Args:
        audio_path: Path to audio file
        chunk_duration: Duration of each chunk in seconds
        overlap_duration: Overlap between chunks in seconds

    Returns:
        List of tuples containing (chunk_path, start_time, end_time). Each
        chunk_path is a temporary WAV file owned by the caller.

    Raises:
        ValueError: If ``chunk_duration`` is not positive, or
            ``overlap_duration`` is negative or not smaller than
            ``chunk_duration`` (the window could never advance).
        Exception: If loading or writing audio fails; the message is prefixed
            with ``"Error splitting audio: "`` and the cause is chained.
    """
    sample_rate = 16000
    chunk_samples = int(chunk_duration * sample_rate)
    overlap_samples = int(overlap_duration * sample_rate)

    # Validate before doing any work: a non-positive step would make the
    # chunking loop spin forever, creating temp files until the disk fills.
    if chunk_samples <= 0:
        raise ValueError(f"chunk_duration must be positive, got {chunk_duration}")
    if overlap_samples < 0 or overlap_samples >= chunk_samples:
        raise ValueError(
            f"overlap_duration ({overlap_duration}s) must be non-negative and "
            f"smaller than chunk_duration ({chunk_duration}s)"
        )
    step_samples = chunk_samples - overlap_samples

    chunks = []
    try:
        # Load audio
        audio, sr = librosa.load(audio_path, sr=sample_rate, mono=True)
        total_samples = len(audio)
        total_duration = total_samples / sr

        if total_duration <= chunk_duration:
            # Audio is short enough, return as single chunk
            chunks.append((_write_chunk_wav(audio, sr), 0.0, total_duration))
            return chunks

        for start_sample in range(0, total_samples, step_samples):
            end_sample = min(start_sample + chunk_samples, total_samples)

            chunk_path = _write_chunk_wav(audio[start_sample:end_sample], sr)
            chunks.append((chunk_path, start_sample / sr, end_sample / sr))

            # Stop once a chunk reaches the end of the audio, so no extra
            # chunk made purely of overlap is produced.
            if end_sample >= total_samples:
                break

        return chunks

    except Exception as e:
        # Remove chunk files written before the failure; the caller never
        # receives their paths and could not clean them up.
        for written_path, _, _ in chunks:
            if os.path.exists(written_path):
                try:
                    os.unlink(written_path)
                except OSError as cleanup_error:
                    print(f"⚠️ Could not remove {written_path}: {cleanup_error}")
        raise Exception(f"Error splitting audio: {str(e)}") from e

def transcribe_chunk(chunk_path: str, max_tokens: int = 128) -> str:
    """Transcribe a single audio chunk

    Args:
        chunk_path: Path of the audio file to transcribe.
        max_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The decoded transcription with surrounding whitespace removed.

    Raises:
        Exception: ``"Model not loaded"`` when the global model is unset, or
            an error prefixed with ``"Error transcribing chunk: "`` (with the
            original error chained) when inference fails.
    """
    global model

    if model is None:
        raise Exception("Model not loaded")

    try:
        use_cuda = torch.cuda.is_available()

        # Clear GPU cache before inference
        if use_cuda:
            torch.cuda.empty_cache()

        # Always disable gradient tracking for inference (previously this
        # only happened on CPU) and enable CUDA mixed precision only when a
        # GPU is present. torch.autocast replaces the deprecated
        # torch.cuda.amp.autocast.
        with torch.no_grad(), torch.autocast(
            device_type="cuda" if use_cuda else "cpu", enabled=use_cuda
        ):
            answer_ids = model.generate(
                prompts=[
                    [{"role": "user",
                      "content": f"Transcribe the following: {model.audio_locator_tag}",
                      "audio": [chunk_path]}]
                ],
                max_new_tokens=max_tokens,
            )

        # Get transcription
        transcription = model.tokenizer.ids_to_text(answer_ids[0].cpu())

        return transcription.strip()

    except Exception as e:
        raise Exception(f"Error transcribing chunk: {str(e)}") from e

def merge_overlapping_transcriptions(transcriptions: List[Tuple[str, float, float]], overlap_duration: float = 2.0) -> str:
    """
    Merge overlapping transcriptions by removing duplicate content in overlap regions

    For each consecutive pair, the longest run of up to 20 words that ends
    the text merged so far and also begins the next transcription is treated
    as the overlap and kept only once. Matching is an exact, case-sensitive
    word comparison. When several transcriptions are merged, words are joined
    with single spaces, so an empty chunk no longer leaves a double space.

    Args:
        transcriptions: List of (transcription, start_time, end_time) tuples
        overlap_duration: Duration of overlap between chunks (accepted for
            interface compatibility; the merge is purely text based)

    Returns:
        Combined transcription
    """
    if not transcriptions:
        return ""

    if len(transcriptions) == 1:
        return transcriptions[0][0]

    max_overlap_words = 20  # Check up to 20 words

    # Keep the merged result as a word list so the growing text is not
    # re-split for every chunk (previously quadratic in transcript length).
    merged_words = transcriptions[0][0].split()

    for entry in transcriptions[1:]:
        curr_words = entry[0].split()
        limit = min(len(merged_words), len(curr_words), max_overlap_words)

        # Longest suffix of the merged text that equals a prefix of this chunk.
        best_overlap = next(
            (j for j in range(limit, 0, -1) if merged_words[-j:] == curr_words[:j]),
            0,
        )

        merged_words.extend(curr_words[best_overlap:])

    return " ".join(merged_words)

def transcribe_audio(audio_file, max_tokens=128):
    """Transcribe audio using Canary-Qwen-2.5B

    Args:
        audio_file: Path of the uploaded audio file, or None if nothing was
            uploaded.
        max_tokens: Upper bound on the number of tokens to generate.

    Returns:
        A formatted result string containing the transcription, the audio
        duration and the memory summary, or a message starting with ``❌``
        when the model is not loaded, no file was given, or anything fails.
        Audio longer than 40 seconds is still processed but the result is
        prefixed with a warning.
    """
    global model

    if model is None:
        return "❌ Model not loaded. Please load the model first."

    if audio_file is None:
        return "❌ Please upload an audio file."

    processed_audio_path = None
    try:
        # Preprocess audio
        processed_audio_path, duration = preprocess_audio(audio_file)

        # Check duration (model supports up to 40s)
        exceeds_limit = duration > 40
        if exceeds_limit:
            warning_msg = f"⚠️ Audio duration ({duration:.1f}s) exceeds recommended limit of 40s. Results may be degraded."
            print(warning_msg)

        use_cuda = torch.cuda.is_available()

        # Clear GPU cache before inference
        if use_cuda:
            torch.cuda.empty_cache()

        print(f"Starting transcription on device: {next(model.parameters()).device}")

        # Always disable gradient tracking for inference (previously this
        # only happened on CPU) and enable CUDA mixed precision only when a
        # GPU is present. torch.autocast replaces the deprecated
        # torch.cuda.amp.autocast.
        with torch.no_grad(), torch.autocast(
            device_type="cuda" if use_cuda else "cpu", enabled=use_cuda
        ):
            answer_ids = model.generate(
                prompts=[
                    [{"role": "user",
                      "content": f"Transcribe the following: {model.audio_locator_tag}",
                      "audio": [processed_audio_path]}]
                ],
                max_new_tokens=max_tokens,
            )

        # Get transcription
        transcription = model.tokenizer.ids_to_text(answer_ids[0].cpu())

        # Get current memory usage
        memory_info = get_gpu_memory_usage()

        result = f"🎯 **Transcription:** {transcription}\n\n📊 **Duration:** {duration:.1f}s\n💾 **{memory_info}**"

        if exceeds_limit:
            result = f"⚠️ Audio duration ({duration:.1f}s) exceeds recommended limit of 40s.\n\n" + result

        return result

    except Exception as e:
        return f"❌ Error during transcription: {str(e)}"

    finally:
        # Clean up the temporary file on every path; previously it was
        # leaked whenever inference raised.
        if processed_audio_path is not None and os.path.exists(processed_audio_path):
            try:
                os.unlink(processed_audio_path)
            except OSError as cleanup_error:
                print(f"⚠️ Could not remove {processed_audio_path}: {cleanup_error}")

def _remove_temp_file(path):
    """Delete *path* if it exists, reporting (not raising) cleanup failures."""
    try:
        if path is not None and os.path.exists(path):
            os.unlink(path)
    except OSError as cleanup_error:
        print(f"⚠️ Could not remove {path}: {cleanup_error}")


def transcribe_long_audio(audio_file, chunk_duration=30, overlap_duration=2, max_tokens=128, progress=gr.Progress()):
    """Transcribe long audio by splitting into chunks and processing with streaming

    The audio is preprocessed, split into overlapping chunks, each chunk is
    transcribed independently and the pieces are merged. A chunk that fails
    is replaced by an ``[Error in chunk N]`` placeholder instead of aborting
    the whole run.

    Args:
        audio_file: Path of the uploaded audio file, or None.
        chunk_duration: Length of each chunk in seconds.
        overlap_duration: Overlap between consecutive chunks in seconds.
        max_tokens: Upper bound on generated tokens per chunk.
        progress: Gradio progress tracker, called with a fraction and a
            description at each stage.

    Returns:
        A formatted report with the merged transcription and processing
        details, or a message starting with ``❌`` on failure.
    """
    global model

    if model is None:
        return "❌ Model not loaded. Please load the model first."

    if audio_file is None:
        return "❌ Please upload an audio file."

    # Every temp file created along the way; all are removed in ``finally``
    # so failures no longer leak files.
    temp_paths = []

    try:
        # Preprocess audio to get duration
        processed_audio_path, total_duration = preprocess_audio(audio_file)
        temp_paths.append(processed_audio_path)

        progress(0, desc=f"Preparing audio ({total_duration:.1f}s)...")

        # Split audio into chunks
        chunks = split_audio_into_chunks(
            processed_audio_path,
            chunk_duration=chunk_duration,
            overlap_duration=overlap_duration
        )
        temp_paths.extend(chunk_path for chunk_path, _, _ in chunks)

        progress(0.1, desc=f"Split into {len(chunks)} chunks...")

        # Transcribe each chunk
        transcriptions = []
        successful_chunks = 0

        for i, (chunk_path, start_time, end_time) in enumerate(chunks):
            progress_val = 0.1 + (i / len(chunks)) * 0.8
            progress(progress_val, desc=f"Transcribing chunk {i+1}/{len(chunks)} ({start_time:.1f}s-{end_time:.1f}s)...")

            try:
                transcription = transcribe_chunk(chunk_path, max_tokens)
            except Exception as e:
                print(f"Error transcribing chunk {i+1}: {e}")
                transcriptions.append((f"[Error in chunk {i+1}]", start_time, end_time))
            else:
                transcriptions.append((transcription, start_time, end_time))
                successful_chunks += 1

            # Free the chunk file as soon as it has been processed, whether
            # or not transcription succeeded. A cleanup failure can no longer
            # add a second (error) entry for an already transcribed chunk.
            _remove_temp_file(chunk_path)

        progress(0.9, desc="Merging transcriptions...")

        # Merge overlapping transcriptions
        final_transcription = merge_overlapping_transcriptions(transcriptions, overlap_duration)

        # Get memory usage
        memory_info = get_gpu_memory_usage()

        progress(1.0, desc="Complete!")

        # Format result with detailed info
        result = f"""🎯 **Final Transcription:**
{final_transcription}

📊 **Processing Details:**
- Total Duration: {total_duration:.1f}s
- Number of Chunks: {len(chunks)}
- Chunk Duration: {chunk_duration}s
- Overlap Duration: {overlap_duration}s
- Successful Chunks: {successful_chunks}

💾 **{memory_info}**"""

        return result

    except Exception as e:
        return f"❌ Error during long audio transcription: {str(e)}"

    finally:
        # Clean up the processed audio file and any chunk files still on disk
        for leftover_path in temp_paths:
            _remove_temp_file(leftover_path)

def post_process_transcript(transcript, user_prompt, max_tokens=512):
    """Use LLM mode for post-processing the transcript

    The prompt and transcript are sent as one text-only user message while
    the model's adapter is disabled.

    Args:
        transcript: Transcript text to process. None or blank is rejected.
        user_prompt: Instruction to apply to the transcript. None or blank
            is rejected (a cleared UI input may deliver None).
        max_tokens: Upper bound on the number of tokens to generate.

    Returns:
        A formatted string with the model response and the memory summary,
        or a message starting with ``❌`` when validation or inference fails.
    """
    global model

    if model is None:
        return "❌ Model not loaded. Please load the model first."

    # ``not value`` first: calling .strip() on None raised AttributeError
    # outside the try block and surfaced as an unhandled UI error.
    if not transcript or not transcript.strip():
        return "❌ Please provide a transcript to process."

    if not user_prompt or not user_prompt.strip():
        return "❌ Please provide a prompt for post-processing."

    try:
        use_cuda = torch.cuda.is_available()

        # Clear GPU cache before inference
        if use_cuda:
            torch.cuda.empty_cache()

        print(f"Starting LLM processing on device: {next(model.parameters()).device}")

        # Use LLM mode (disable adapter to use base LLM capabilities)
        with model.llm.disable_adapter():
            # Always disable gradient tracking for inference (previously
            # this only happened on CPU) and enable CUDA mixed precision
            # only when a GPU is present. torch.autocast replaces the
            # deprecated torch.cuda.amp.autocast.
            with torch.no_grad(), torch.autocast(
                device_type="cuda" if use_cuda else "cpu", enabled=use_cuda
            ):
                answer_ids = model.generate(
                    prompts=[[{"role": "user",
                              "content": f"{user_prompt}\n\n{transcript}"}]],
                    max_new_tokens=max_tokens,
                )

        response = model.tokenizer.ids_to_text(answer_ids[0].cpu())

        # Get current memory usage
        memory_info = get_gpu_memory_usage()

        return f"🤖 **Response:** {response}\n\n💾 **{memory_info}**"

    except Exception as e:
        return f"❌ Error during post-processing: {str(e)}"

# Check GPU availability at startup
# The resulting string is also interpolated into the UI header and the
# "Model Information" tab below, so it reflects the state at import time.
print("=== GPU Information ===")
startup_gpu_info = check_gpu_availability()
print(startup_gpu_info)

# Download example files
# Runs at import time; the returned list contains only files that exist
# locally and may be empty, in which case the examples section is not built.
print("\n=== Downloading Example Files ===")
example_files = download_example_files()

# Create example prompts for post-processing
# Used as the dropdown choices in the post-processing tab; the first entry
# is the default selection.
example_prompts = [
    "Summarize this transcript in 2-3 sentences:",
    "Extract the main topics discussed in this transcript:",
    "What are the key action items mentioned in this transcript?",
    "Identify any questions asked in this transcript:",
    "Correct any grammar or formatting issues in this transcript:"
]

# Create the Gradio interface
# Layout: four tabs (single-file transcription, long-audio transcription,
# transcript post-processing, model information) followed by the click
# handlers that connect the buttons to the functions defined above.
with gr.Blocks(title="Nvidia Canary-Qwen-2.5B Speech Recognition", theme=gr.themes.Soft()) as demo:
    # Page header; startup_gpu_info is interpolated once, when the UI is built.
    gr.Markdown(f"""
    # 🎤 Nvidia Canary-Qwen-2.5B Speech Recognition

    This interface uses Nvidia's state-of-the-art Canary-Qwen-2.5B model for English speech recognition and transcript post-processing.

    **System Info:** {startup_gpu_info}

    **Features:**
    - 🎯 High-accuracy English speech transcription
    - 📝 Automatic punctuation and capitalization
    - 🚀 **NEW:** Streaming support for long audio files
    - 🤖 LLM-powered transcript post-processing
    - ⚡ Fast inference (418 RTFx on GPU)
    - 🚀 GPU acceleration enabled

    **Supported formats:** WAV, FLAC, MP3, M4A (automatically converted to 16kHz mono)
    """)

    # Tab 1: load the model and transcribe a single audio file.
    with gr.Tab("🎤 Speech Transcription"):
        with gr.Row():
            with gr.Column():
                # Model loading section
                gr.Markdown("### 1. Load Model")
                load_btn = gr.Button("🔄 Load Canary-Qwen-2.5B Model", variant="primary")
                load_status = gr.Textbox(label="Status", interactive=False, lines=3)

                gr.Markdown("### 2. Upload Audio")
                # type="filepath" makes the callback receive a path string.
                audio_input = gr.Audio(
                    label="Upload Audio File (≤40s recommended)",
                    type="filepath"
                )

                gr.Markdown("### 3. Transcription Settings")
                max_tokens = gr.Slider(
                    minimum=32,
                    maximum=512,
                    value=128,
                    step=32,
                    label="Max Tokens"
                )

                transcribe_btn = gr.Button("🎯 Transcribe Audio", variant="secondary")

            with gr.Column():
                gr.Markdown("### 📄 Transcription Result")
                transcription_output = gr.Textbox(
                    label="Transcription",
                    lines=12,
                    interactive=False
                )

        # Example audio files section (only show if files were downloaded successfully)
        if example_files:
            gr.Markdown("""
            ### 🎵 Try with Example Audio
            Click on one of the example files below to test the transcription.
            """)

            # Each example is a one-element row that fills audio_input.
            examples = gr.Examples(
                examples=[[file] for file in example_files],
                inputs=[audio_input],
                label="LibriSpeech Example Files"
            )

    # Tab 2: chunked transcription of audio longer than a single model call.
    with gr.Tab("🎬 Long Audio Streaming"):
        gr.Markdown("""
        ### 🚀 Stream Long Audio Files
        Process audio files of any length by automatically splitting them into chunks with smart overlap handling.
        Perfect for meetings, lectures, podcasts, and long-form content.
        """)

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Upload Long Audio")
                long_audio_input = gr.Audio(
                    label="Upload Audio File (any length)",
                    type="filepath"
                )

                gr.Markdown("### Streaming Settings")
                # The slider ranges keep the overlap (max 5s) below the
                # chunk length (min 15s).
                with gr.Row():
                    chunk_duration = gr.Slider(
                        minimum=15,
                        maximum=35,
                        value=30,
                        step=5,
                        label="Chunk Duration (seconds)"
                    )
                    overlap_duration = gr.Slider(
                        minimum=1,
                        maximum=5,
                        value=2,
                        step=1,
                        label="Overlap Duration (seconds)"
                    )

                streaming_max_tokens = gr.Slider(
                    minimum=32,
                    maximum=512,
                    value=128,
                    step=32,
                    label="Max Tokens per Chunk"
                )

                stream_btn = gr.Button("🎬 Start Streaming Transcription", variant="primary")

            with gr.Column():
                gr.Markdown("### 📺 Streaming Results")
                # Filled once with the final report returned by
                # transcribe_long_audio.
                streaming_output = gr.Textbox(
                    label="Live Transcription",
                    lines=15,
                    interactive=False
                )

        gr.Markdown("""
        **How it works:**
        1. Audio is split into overlapping chunks (default: 30s chunks with 2s overlap)
        2. Each chunk is transcribed independently
        3. Overlapping regions are intelligently merged to avoid duplicate words
        4. Progress is shown in real-time

        **Optimal settings:**
        - **Chunk Duration:** 30s (balances accuracy and processing speed)
        - **Overlap:** 2s (prevents word cutoffs at boundaries)
        - Use shorter chunks for very noisy audio or multiple speakers
        """)

    # Tab 3: text-only processing of a transcript with a user-chosen prompt.
    with gr.Tab("🤖 Transcript Post-Processing"):
        gr.Markdown("""
        ### LLM Mode - Post-Process Your Transcripts
        Use the underlying LLM capabilities to analyze, summarize, or process your transcripts.
        """)

        with gr.Row():
            with gr.Column():
                transcript_input = gr.Textbox(
                    label="Transcript",
                    lines=6,
                    placeholder="Paste your transcript here or use the output from the transcription tabs..."
                )

                # allow_custom_value lets the user type a prompt that is not
                # in example_prompts.
                user_prompt = gr.Dropdown(
                    choices=example_prompts,
                    label="Choose a prompt or type your own",
                    allow_custom_value=True,
                    value=example_prompts[0]
                )

                llm_max_tokens = gr.Slider(
                    minimum=64,
                    maximum=1024,
                    value=512,
                    step=64,
                    label="Max Response Tokens"
                )

                process_btn = gr.Button("🤖 Process Transcript", variant="secondary")

            with gr.Column():
                llm_output = gr.Textbox(
                    label="LLM Response",
                    lines=12,
                    interactive=False
                )

    # Tab 4: static reference text; only startup_gpu_info is dynamic.
    with gr.Tab("ℹ️ Model Information"):
        gr.Markdown(f"""
        ### About Canary-Qwen-2.5B

        **System Information:**
        {startup_gpu_info}

        **Model Details:**
        - 🔢 **Parameters:** 2.5 billion
        - 🌍 **Language:** English only
        - ⚡ **Speed:** 418 RTFx (Real-Time Factor on GPU)
        - 🎯 **Architecture:** Speech-Augmented Language Model (SALM)
        - 📊 **Performance:** State-of-the-art on multiple English benchmarks

        **Capabilities:**
        - **ASR Mode:** High-accuracy speech transcription with punctuation and capitalization
        - **Streaming Mode:** Process long audio files with intelligent chunking
        - **LLM Mode:** Text-only processing for summarization, Q&A, and analysis

        **Limitations:**
        - Maximum audio duration per chunk: 40 seconds (handled automatically in streaming mode)
        - English language only
        - Requires 16kHz mono audio (automatically handled by this interface)

        **Training Data:**
        - 234.5k hours of English speech data
        - Public datasets including LibriSpeech, Common Voice, YouTube Commons, and more

        **License:** CC-BY-4.0

        ### Performance Benchmarks

        **WER (Word Error Rate) on OpenASR Leaderboard:**
        - LibriSpeech Clean: 1.60%
        - LibriSpeech Other: 3.10%
        - AMI Meetings: 10.18%
        - GigaSpeech: 9.41%
        - Earnings-22: 10.42%
        - SPGISpeech: 1.90%
        - Tedlium: 2.72%
        - VoxPopuli: 5.66%

        ### Streaming Features

        **New in this version:**
        - 🎬 **Long Audio Support:** Process audio files of any length
        - 🔄 **Intelligent Chunking:** Automatic splitting with overlap handling
        - 📊 **Progress Tracking:** Real-time progress updates
        - 🧠 **Smart Merging:** Removes duplicate words from overlapping segments
        - ⚙️ **Configurable Settings:** Adjust chunk size and overlap for optimal results

        ### GPU Requirements
        - **Recommended:** T4, V100, A100, or better
        - **Memory:** ~6-8GB GPU memory for inference
        - **Speed:** ~418x real-time on modern GPUs
        - **Streaming:** Can process hours of audio with consistent performance
        """)

    # Event handlers
    # load_model takes no inputs; its status string goes to load_status.
    load_btn.click(
        fn=load_model,
        outputs=[load_status]
    )

    # Inputs are passed positionally: (audio_file, max_tokens).
    transcribe_btn.click(
        fn=transcribe_audio,
        inputs=[audio_input, max_tokens],
        outputs=[transcription_output]
    )

    # Inputs map to (audio_file, chunk_duration, overlap_duration, max_tokens);
    # the progress parameter keeps its default value.
    stream_btn.click(
        fn=transcribe_long_audio,
        inputs=[long_audio_input, chunk_duration, overlap_duration, streaming_max_tokens],
        outputs=[streaming_output]
    )

    # Inputs map to (transcript, user_prompt, max_tokens).
    process_btn.click(
        fn=post_process_transcript,
        inputs=[transcript_input, user_prompt, llm_max_tokens],
        outputs=[llm_output]
    )

# Start the web server only when this file is executed directly
if __name__ == "__main__":
    # Collected in one mapping so the launch configuration reads as data.
    launch_options = {
        "share": True,
        "debug": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    demo.launch(**launch_options)