In [None]:
!pip install gradio transformers torch torchaudio librosa datasets

In [None]:
import gradio as gr
import torch
import torchaudio
import numpy as np
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2FeatureExtractor
import librosa
import re
import warnings

# Suppress warnings
warnings.filterwarnings("ignore")

class Wav2VecInterface:
    """Hold one loaded Wav2Vec2 CTC model and run it on audio files.

    The instance keeps the currently loaded model together with either a
    full processor (feature extractor + tokenizer) or, when no processor
    can be loaded for the checkpoint, a bare feature extractor.  All
    public methods report problems as user-facing strings instead of
    raising, because their results are shown directly in the UI.
    """

    # Sampling rate every audio file is resampled to before preprocessing.
    TARGET_SAMPLE_RATE = 16000

    def __init__(self):
        self.model = None
        self.processor = None
        self.feature_extractor = None
        self.current_model_name = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def _resolve_model_name(model_url_or_name):
        """Turn a model name or Hugging Face URL into an ``owner/name`` id.

        Returns the stripped input unchanged when it is not a
        huggingface.co URL, the extracted ``owner/name`` part when it is,
        and ``None`` when a URL is given but no ``owner/name`` can be found.
        """
        candidate = (model_url_or_name or "").strip()
        if "huggingface.co" not in candidate:
            return candidate
        # Stop at '/', '?', '#' or whitespace so trailing path segments,
        # query strings and fragments never end up in the repo id.
        match = re.search(r'huggingface\.co/([^/\s]+/[^/?#\s]+)', candidate)
        return match.group(1) if match else None

    def _prepare_input(self, audio_file):
        """Load *audio_file* and build the model input tensor.

        Returns:
            ``(audio_input, input_values)`` where ``audio_input`` is the
            waveform returned by ``librosa.load`` and ``input_values`` is
            the preprocessed tensor already moved to ``self.device``.

        Raises:
            ValueError: if the decoded file contains no samples.
        """
        audio_input, _ = librosa.load(audio_file, sr=self.TARGET_SAMPLE_RATE)
        if len(audio_input) == 0:
            raise ValueError("audio file contains no samples")

        # The processor is preferred; the feature extractor is only set
        # when loading the processor failed.
        preprocess = self.processor if self.processor is not None else self.feature_extractor
        input_values = preprocess(
            audio_input,
            sampling_rate=self.TARGET_SAMPLE_RATE,
            return_tensors="pt"
        ).input_values
        return audio_input, input_values.to(self.device)

    def load_model(self, model_url_or_name):
        """Load a Wav2Vec2 CTC checkpoint by name or Hugging Face URL.

        Args:
            model_url_or_name: ``owner/name`` model id or a huggingface.co
                URL pointing at the model.

        Returns:
            A 3-tuple ``(status, model_info, button_update)``.  On success
            ``model_info`` is a Markdown summary and ``button_update`` is a
            ``gr.update(interactive=True)``; in every other case (invalid
            input, already loaded, load error) both are ``None``.
        """
        try:
            model_name = self._resolve_model_name(model_url_or_name)
            if model_name is None:
                return "❌ Invalid Hugging Face URL format", None, None
            if not model_name:
                return "❌ Please enter a model name or URL", None, None

            # Check if model is already loaded
            if self.current_model_name == model_name:
                return f"✅ Model {model_name} is already loaded", None, None

            # Load into locals first so a failure half-way through cannot
            # leave a new processor paired with the previous model.
            try:
                processor = Wav2Vec2Processor.from_pretrained(model_name)
                feature_extractor = None
            except Exception:
                # Fallback to feature extractor only if processor fails
                processor = None
                feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)

            model = Wav2Vec2ForCTC.from_pretrained(model_name)
            model.to(self.device)
            model.eval()

            # Everything loaded: commit the new state in one go.
            self.model = model
            self.processor = processor
            self.feature_extractor = feature_extractor
            self.current_model_name = model_name

            # Get model info
            model_info = f"""
**Model Information:**
- **Name:** {model_name}
- **Device:** {self.device}
- **Model Type:** Wav2Vec2ForCTC
- **Parameters:** {sum(p.numel() for p in model.parameters()):,}
- **Processor Available:** {'Yes' if self.processor else 'No (using feature extractor only)'}
            """

            return f"✅ Successfully loaded model: {model_name}", model_info, gr.update(interactive=True)

        except Exception as e:
            return f"❌ Error loading model: {str(e)}", None, None

    def transcribe_audio(self, audio_file):
        """Transcribe *audio_file* with the loaded model.

        Args:
            audio_file: path of the audio file, or ``None`` if nothing was
                uploaded.

        Returns:
            Always a 2-tuple ``(transcription, audio_info)``.  On any
            problem ``transcription`` holds the error message and
            ``audio_info`` is ``None``, so callers can unpack the result
            unconditionally.
        """
        if self.model is None:
            return "❌ Please load a model first!", None
        if audio_file is None:
            return "❌ Please upload an audio file!", None

        try:
            audio_input, input_values = self._prepare_input(audio_file)

            if self.processor is not None:
                # Inference
                with torch.no_grad():
                    logits = self.model(input_values).logits

                # Greedy CTC decoding: best token per frame, then let the
                # processor collapse the ids into text.
                predicted_ids = torch.argmax(logits, dim=-1)
                transcription = self.processor.decode(predicted_ids[0])
            else:
                # Without a processor there is nothing to decode the ids
                # with, so the forward pass is skipped entirely.
                transcription = "Transcription unavailable - processor required for decoding"

            # Audio info
            duration = len(audio_input) / self.TARGET_SAMPLE_RATE
            audio_info = f"""
**Audio Information:**
- **Duration:** {duration:.2f} seconds
- **Sample Rate:** {self.TARGET_SAMPLE_RATE} Hz (resampled)
- **Shape:** {audio_input.shape}
            """

            return transcription, audio_info

        except Exception as e:
            return f"❌ Error during transcription: {str(e)}", None

    def extract_features(self, audio_file):
        """Summarise the encoder's last hidden state for *audio_file*.

        Args:
            audio_file: path of the audio file, or ``None`` if nothing was
                uploaded.

        Returns:
            A single Markdown string with shape and activation statistics,
            or an error message string.
        """
        if self.model is None:
            return "❌ Please load a model first!"
        if audio_file is None:
            return "❌ Please upload an audio file!"

        try:
            _, input_values = self._prepare_input(audio_file)

            # Extract features
            with torch.no_grad():
                # Run only the base encoder (no CTC head) to get hidden states
                outputs = self.model.wav2vec2(input_values)
                last_hidden_state = outputs.last_hidden_state

            # Feature info
            features_info = f"""
**Feature Information:**
- **Feature Shape:** {last_hidden_state.shape}
- **Hidden Size:** {last_hidden_state.shape[-1]}
- **Sequence Length:** {last_hidden_state.shape[1]}
- **Mean Activation:** {last_hidden_state.mean().item():.6f}
- **Std Activation:** {last_hidden_state.std().item():.6f}
            """

            return features_info

        except Exception as e:
            return f"❌ Error during feature extraction: {str(e)}"

# Single shared model holder; the Gradio event handlers in create_interface()
# read and mutate this instance.
wav2vec_interface = Wav2VecInterface()

# Model ids offered in the "Popular Models" dropdown for quick selection.
# Every entry is loaded through Wav2VecInterface.load_model.
popular_models = [
    "facebook/wav2vec2-base-960h",
    "facebook/wav2vec2-large-960h",
    "facebook/wav2vec2-base-100h",
    "facebook/wav2vec2-large-960h-lv60-self",
    "jonatasgrosman/wav2vec2-large-xlsr-53-english",
    "jonatasgrosman/wav2vec2-large-xlsr-53-spanish",
    "jonatasgrosman/wav2vec2-large-xlsr-53-french",
    "jonatasgrosman/wav2vec2-large-xlsr-53-german",
    # NOTE(review): the two ids below look like WavLM / UniSpeech-SAT
    # checkpoints rather than Wav2Vec2 ones — confirm they exist on the Hub
    # and actually load via Wav2Vec2ForCTC before keeping them in this list.
    "wavlm/wavlm-base",
    "microsoft/unispeech-sat-base",
]

# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI and wire it to the shared model holder.

    The layout has a model-selection row (text box, load/clear buttons,
    quick-select dropdown, status box), a model-info panel, and an audio
    row with transcription and feature-extraction outputs.  The two task
    buttons start disabled and are enabled once a model loads.

    Returns:
        The assembled ``gr.Blocks`` app; the caller is responsible for
        launching it.
    """
    with gr.Blocks(title="Wav2Vec2 Model Interface", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
            # 🎵 Wav2Vec2 Model Interface

            Load and test Wav2Vec2 models from Hugging Face for speech recognition and feature extraction.

            **Instructions:**
            1. Enter a Hugging Face model URL or model name (e.g., `facebook/wav2vec2-base-960h`)
            2. Click "Load Model" to download and initialize the model
            3. Upload an audio file and choose your task (transcription or feature extraction)
            """
        )

        with gr.Row():
            with gr.Column(scale=2):
                model_input = gr.Textbox(
                    label="🤗 Hugging Face Model URL or Name",
                    placeholder="e.g., facebook/wav2vec2-base-960h or https://huggingface.co/facebook/wav2vec2-base-960h",
                    value="facebook/wav2vec2-base-960h"
                )

                with gr.Row():
                    load_btn = gr.Button("🔄 Load Model", variant="primary")
                    clear_btn = gr.Button("🗑️ Clear", variant="secondary")

                popular_dropdown = gr.Dropdown(
                    label="🌟 Popular Models (Quick Select)",
                    choices=popular_models,
                    value=None,
                    interactive=True
                )

            with gr.Column(scale=1):
                status_output = gr.Textbox(
                    label="📊 Status",
                    interactive=False,
                    lines=2
                )

        model_info_output = gr.Markdown(label="ℹ️ Model Information")

        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(
                    label="🎤 Upload Audio File",
                    type="filepath"
                )

                with gr.Row():
                    transcribe_btn = gr.Button(
                        "🔤 Transcribe",
                        variant="primary",
                        interactive=False
                    )
                    features_btn = gr.Button(
                        "🧠 Extract Features",
                        variant="secondary",
                        interactive=False
                    )

            with gr.Column():
                transcription_output = gr.Textbox(
                    label="📝 Transcription",
                    lines=3,
                    interactive=False
                )

                audio_info_output = gr.Markdown(label="🎵 Audio Information")
                features_output = gr.Markdown(label="🔍 Features Information")

        # Event handlers
        def load_model_handler(model_name):
            """Load the model and fan the button update out to both task buttons."""
            status, info, btn_update = wav2vec_interface.load_model(model_name)
            # load_model returns None for info/button when nothing changed
            # ("already loaded") or loading failed.  A no-op update keeps the
            # panel and buttons of the still-loaded model instead of blanking
            # them.
            if info is None:
                info = gr.update()
            if btn_update is None:
                btn_update = gr.update()
            return status, info, btn_update, btn_update

        def clear_handler():
            """Drop the loaded model and reset every output widget."""
            wav2vec_interface.model = None
            wav2vec_interface.processor = None
            # Also reset the fallback extractor so no stale preprocessing
            # object survives a clear.
            wav2vec_interface.feature_extractor = None
            wav2vec_interface.current_model_name = None
            return (
                "",  # model_input
                "🗑️ Cleared", # status
                "",  # model_info
                "",  # transcription
                "",  # audio_info
                "",  # features
                gr.update(interactive=False),  # transcribe_btn
                gr.update(interactive=False),  # features_btn
            )

        def transcribe_handler(audio):
            """Run transcription and return (transcription, audio_info)."""
            result = wav2vec_interface.transcribe_audio(audio)
            # Guard paths of transcribe_audio may hand back a lone message
            # string; unpacking that directly would raise ValueError.
            if isinstance(result, str):
                return result, ""
            transcription, audio_info = result
            return transcription, audio_info

        def features_handler(audio):
            """Run feature extraction and return its Markdown summary."""
            return wav2vec_interface.extract_features(audio)

        def popular_model_selected(choice):
            """Copy the chosen dropdown entry into the model text box."""
            if choice:
                return choice
            # Nothing selected: leave the text box as it is.
            return gr.update()

        # Connect events
        load_btn.click(
            load_model_handler,
            inputs=[model_input],
            outputs=[status_output, model_info_output, transcribe_btn, features_btn]
        )

        clear_btn.click(
            clear_handler,
            outputs=[
                model_input, status_output, model_info_output,
                transcription_output, audio_info_output, features_output,
                transcribe_btn, features_btn
            ]
        )

        transcribe_btn.click(
            transcribe_handler,
            inputs=[audio_input],
            outputs=[transcription_output, audio_info_output]
        )

        features_btn.click(
            features_handler,
            inputs=[audio_input],
            outputs=[features_output]
        )

        popular_dropdown.change(
            popular_model_selected,
            inputs=[popular_dropdown],
            outputs=[model_input]
        )

        # Add examples section
        gr.Markdown("""
        ### 📚 Example Models to Try:

        - **English ASR:** `facebook/wav2vec2-base-960h` (good for English speech recognition)
        - **Multilingual:** `facebook/wav2vec2-large-xlsr-53` (supports 53 languages)
        - **Large English:** `facebook/wav2vec2-large-960h-lv60-self` (high accuracy English model)
        - **Custom fine-tuned:** Search Hugging Face for domain-specific models

        ### 🎯 Supported Audio Formats:
        WAV, MP3, FLAC, M4A (will be resampled to 16kHz mono)
        """)

    return demo

# Install required packages (run this cell first in Colab)
def install_requirements():
    """Install the notebook's dependencies with pip, one package at a time.

    Each package is installed by a separate ``python -m pip install`` run
    using the current interpreter, in the order listed below.

    Raises:
        subprocess.CalledProcessError: if any pip invocation exits non-zero;
            remaining packages are then not attempted.
    """
    import subprocess
    import sys

    # Use the running interpreter so packages land in the active environment.
    pip_install = [sys.executable, "-m", "pip", "install"]
    requirements = (
        "gradio",
        "transformers",
        "torch",
        "torchaudio",
        "librosa",
        "datasets",
    )

    for requirement in requirements:
        subprocess.check_call([*pip_install, requirement])

# Uncomment the line below and run to install requirements
# install_requirements()

# Create and launch the interface
if __name__ == "__main__":
    # Options passed to Blocks.launch(), collected in one place.
    launch_options = {
        "share": True,  # Creates a public link
        "debug": True,
        "server_name": "0.0.0.0",  # Important for Colab
        "server_port": 7860,
    }

    demo = create_interface()
    demo.launch(**launch_options)