# Voxtral VLLM Plugin

> Plugin implementation for Mistral Voxtral transcription through a vLLM server

In [None]:
#| default_exp plugin

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import json
import logging
from pathlib import Path
from dataclasses import dataclass, field
from dataclasses import replace as dataclass_replace
from typing import Dict, Any, Optional, List, Union, Generator
import tempfile
import warnings
from threading import Thread
import subprocess
import time
import requests
import atexit
import signal
import threading
import queue
import re
from datetime import datetime

from fastcore.basics import patch

import numpy as np
import soundfile as sf

try:
    from openai import OpenAI
    OPENAI_AVAILABLE = True
except ImportError:
    OPENAI_AVAILABLE = False

try:
    from mistral_common.protocol.transcription.request import TranscriptionRequest
    from mistral_common.protocol.instruct.chunk import RawAudio
    from mistral_common.audio import Audio
    MISTRAL_COMMON_AVAILABLE = True
except ImportError:
    MISTRAL_COMMON_AVAILABLE = False
    
from cjm_transcription_plugin_system.plugin_interface import TranscriptionPlugin
from cjm_transcription_plugin_system.core import AudioData, TranscriptionResult
from cjm_plugin_system.utils.validation import (
    dict_to_config, config_to_dict, validate_config,
    SCHEMA_TITLE, SCHEMA_DESC, SCHEMA_MIN, SCHEMA_MAX, SCHEMA_ENUM
)

In [None]:
#| export
import subprocess
import time
import requests
import atexit
import signal
import threading
import queue
import re
from contextlib import contextmanager
from typing import Optional, Callable, List
from datetime import datetime

class VLLMServer:
    """vLLM server manager for Voxtral models.

    Launches ``vllm.entrypoints.openai.api_server`` as a child process of the
    current Python interpreter, optionally captures its output into
    ``log_queue``, and treats the server as ready once ``GET /health`` returns
    200. Also usable as a context manager, which starts and stops the server.
    """

    # Startup phases recognised in server log lines, checked in order; the
    # first key found in a line wins. A key must come before any shorter key
    # that is a substring of it ("Graph capturing finished" before
    # "Graph capturing"), otherwise the shorter key would always shadow it.
    _STARTUP_PHASES = (
        ("detected platform", "🔍 Detecting platform..."),
        ("Loading model", "📦 Loading model weights..."),
        ("downloading weights", "⬇️ Downloading model weights..."),
        ("Loading weights took", "✅ Model weights loaded"),
        ("Graph capturing finished", "✅ CUDA graphs ready"),
        ("Graph capturing", "📊 Capturing CUDA graphs..."),
        ("Starting vLLM API server", "🚀 Starting API server..."),
        ("Available routes", "✅ Server ready!"),
    )

    # Matches lines that already start with an "MM-DD HH:MM:SS" timestamp.
    _TIMESTAMP_RE = re.compile(r'^\d{2}-\d{2} \d{2}:\d{2}:\d{2}')

    # Matches tqdm-style progress-bar updates such as " 50%|#####     | 1/2".
    # They are dropped from stderr so that download/loading bars (each refresh
    # typically arriving as its own line) don't flood the captured log.
    _PROGRESS_BAR_RE = re.compile(r'\d+%\|')

    def __init__(
        self,
        model: str = "mistralai/Voxtral-Mini-3B-2507", # Model name to serve
        port: int = 8000, # Port for the server
        host: str = "0.0.0.0", # Host address to bind to
        gpu_memory_utilization: float = 0.85, # Fraction of GPU memory to use
        log_level: str = "INFO", # Logging level (DEBUG, INFO, WARNING, ERROR)
        capture_logs: bool = True, # Whether to capture and display server logs
        **kwargs # Additional vLLM server arguments
    ):
        """Build the server command line; nothing is launched until `start`."""
        import sys  # only needed here, to launch vLLM with this same interpreter

        self.model = model
        self.port = port
        self.host = host
        # Readiness checks always go through loopback, whatever `host` binds to.
        self.base_url = f"http://localhost:{self.port}"
        self.process: Optional[subprocess.Popen] = None
        self.capture_logs = capture_logs
        # NOTE(review): stored only; not forwarded to the vLLM command line.
        self.log_level = log_level

        # Log management
        self.log_queue = queue.Queue()
        self.log_thread = None
        self.stop_logging = threading.Event()
        self.log_callbacks: List[Callable] = []
        # Ensures `stop` is registered with atexit only once across restarts.
        self._atexit_registered = False

        # Use the running interpreter rather than whichever "python" is first
        # on PATH, so the server runs in the same environment this process
        # checked for vLLM; fall back to "python" if it cannot be determined.
        self.cmd = [
            sys.executable or "python", "-m", "vllm.entrypoints.openai.api_server",
            "--model", model,
            "--port", str(self.port),
            "--host", host,
            "--gpu-memory-utilization", str(gpu_memory_utilization),
            "--tokenizer-mode", "mistral",
            "--config-format", "mistral",
            "--load-format", "mistral",
        ]

        # Extra keyword arguments become CLI flags: max_model_len=4096 -> --max-model-len 4096
        for key, value in kwargs.items():
            self.cmd.extend([f"--{key.replace('_', '-')}", str(value)])

    def add_log_callback(
        self, 
        callback: Callable[[str], None] # Function that receives log line strings
    ) -> None: # Returns nothing
        """Add a callback function to receive each log line."""
        self.log_callbacks.append(callback)

    def _process_log_line(
        self, 
        line: str # Log line to process
    ) -> None: # Returns nothing
        """Timestamp, queue, dispatch to callbacks and (optionally) print one log line."""
        if not line.strip():
            return

        # Add timestamp if not present
        if not self._TIMESTAMP_RE.match(line):
            timestamp = datetime.now().strftime("%m-%d %H:%M:%S")
            line = f"{timestamp} {line}"

        # Store in queue for retrieval
        self.log_queue.put(line)

        # A failing callback must not stop log processing.
        for callback in self.log_callbacks:
            try:
                callback(line)
            except Exception as e:
                print(f"Error in log callback: {e}")

        if self.capture_logs:
            # Color code based on log level
            if "ERROR" in line:
                print(f"\033[91m{line}\033[0m")  # Red
            elif "WARNING" in line:
                print(f"\033[93m{line}\033[0m")  # Yellow
            elif "INFO" in line:
                print(f"\033[94m{line}\033[0m")  # Blue
            else:
                print(line)

    def _log_reader(
        self, 
        pipe, # Pipe to read from
        pipe_name: str # Name of the pipe (stdout/stderr)
    ) -> None: # Returns nothing
        """Read lines from a pipe (run in a separate thread) and process them.

        Every non-empty line is forwarded, whatever its level: vLLM writes its
        log records, including warnings, errors and tracebacks, to stderr.
        Only progress-bar updates on stderr are skipped.
        """
        for raw_line in iter(pipe.readline, ''):
            if self.stop_logging.is_set():
                break
            line = raw_line.strip()
            if not line:
                continue
            if pipe_name == "stderr" and self._PROGRESS_BAR_RE.search(line):
                continue
            self._process_log_line(line)

    def start(
        self, 
        wait_for_ready: bool = True, # Wait for server to be ready before returning
        timeout: int = 120, # Maximum seconds to wait for server readiness
        show_progress: bool = True # Show progress indicators during startup
    ) -> None: # Returns nothing
        """Start the vLLM server.

        If a child process from an earlier ``start`` is still alive, no new
        process is spawned, because a second server would fail to bind the same
        port. If that process is still booting, this waits for it when
        ``wait_for_ready`` is set.

        Raises:
            RuntimeError: if the server process exits while waiting for readiness.
            TimeoutError: if the server is not ready within ``timeout`` seconds.
        """
        if self.process is not None and self.process.poll() is None:
            if self.is_running():
                print("Server is already running")
            elif wait_for_ready:
                # Alive but /health not yet 200: still booting.
                self._wait_for_server(timeout, show_progress)
            return

        print(f"Starting vLLM server with model {self.model}...")

        # Only pipe output that a reader thread will drain: an undrained pipe
        # fills its OS buffer and then blocks the server on its next write.
        output = subprocess.PIPE if self.capture_logs else subprocess.DEVNULL
        self.process = subprocess.Popen(
            self.cmd,
            stdout=output,
            stderr=output,
            text=True,
            bufsize=1,  # Line buffered
            # Ignore SIGINT in the child so Ctrl+C in the parent (e.g. while
            # tailing logs) does not kill the server.
            preexec_fn=lambda: signal.signal(signal.SIGINT, signal.SIG_IGN)
        )

        self.stop_logging.clear()
        if self.capture_logs:
            for pipe, pipe_name in ((self.process.stdout, "stdout"), (self.process.stderr, "stderr")):
                threading.Thread(
                    target=self._log_reader,
                    args=(pipe, pipe_name),
                    daemon=True
                ).start()

        # Register cleanup on exit (once per instance; stop() is idempotent)
        if not getattr(self, "_atexit_registered", False):
            atexit.register(self.stop)
            self._atexit_registered = True

        if wait_for_ready:
            self._wait_for_server(timeout, show_progress)

    @classmethod
    def _phase_for(
        cls,
        line: str # Log line to classify
    ) -> Optional[str]: # Status message of the first matching phase, or None
        """Map a log line to a startup-phase status message."""
        return next((msg for key, msg in cls._STARTUP_PHASES if key in line), None)

    def _wait_for_server(
        self, 
        timeout: int, # Maximum seconds to wait
        show_progress: bool # Whether to show progress indicators
    ) -> None: # Returns nothing
        """Poll ``/health`` until the server answers 200, showing startup phases.

        Raises:
            RuntimeError: if the server process exits before becoming ready.
            TimeoutError: if the server is not ready within ``timeout`` seconds.
        """
        start_time = time.time()
        last_status = ""

        while time.time() - start_time < timeout:
            if show_progress:
                # Drain every queued line so the status keeps pace with the log
                # (startup emits far more than one line per poll interval).
                while True:
                    try:
                        log_line = self.log_queue.get_nowait()
                    except queue.Empty:
                        break
                    status_msg = self._phase_for(log_line)
                    if status_msg is not None and status_msg != last_status:
                        print(f"\n  {status_msg}")
                        last_status = status_msg

            try:
                response = requests.get(f"{self.base_url}/health", timeout=1)
                if response.status_code == 200:
                    print(f"\n✅ vLLM server is ready at {self.base_url}")
                    return
            except requests.exceptions.RequestException:
                pass  # not accepting connections yet

            exit_code = self.process.poll()
            if exit_code is not None:
                raise RuntimeError(f"Server process exited with code {exit_code}")

            time.sleep(0.5)

        raise TimeoutError(f"Server did not start within {timeout} seconds")

    def stop(self) -> None: # Returns nothing
        """Stop the vLLM server (terminate, then kill after 10 s); no-op if not running."""
        if self.process and self.process.poll() is None:
            print("Stopping vLLM server...")

            # Signal threads to stop
            self.stop_logging.set()

            self.process.terminate()
            try:
                self.process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                print("Force killing server...")
                self.process.kill()
                self.process.wait()

            self.process = None
            print("Server stopped")

    def restart(self) -> None: # Returns nothing
        """Restart the server."""
        self.stop()
        time.sleep(2)
        self.start()

    def is_running(self) -> bool: # True if server is running and responsive
        """Check if the server process is alive and ``/health`` returns 200."""
        if self.process is None or self.process.poll() is not None:
            return False

        try:
            response = requests.get(f"{self.base_url}/health", timeout=1)
            if response.status_code == 200:
                return True
            else:
                print(f"\n❌ vLLM server is not ready at {self.base_url}")
                return False
        except requests.exceptions.RequestException as e:
            print(f"\n❌ The following exception occurred: {e}")
            return False

    def get_recent_logs(
        self, 
        n: int = 100 # Number of recent log lines to retrieve
    ) -> List[str]: # List of recent log lines
        """Get the most recent n log lines, oldest first.

        Drains the log queue: every queued line is consumed, and only the last
        ``n`` are returned.
        """
        if n <= 0:
            return []
        drained = []
        while True:
            try:
                drained.append(self.log_queue.get_nowait())
            except queue.Empty:
                break
        return drained[-n:]

    def get_metrics_from_logs(self) -> dict: # Dictionary with performance metrics
        """Parse recent logs to extract performance metrics (zeros when none are found)."""
        metrics = {
            "prompt_throughput": 0.0,
            "generation_throughput": 0.0,
            "running_requests": 0,
            "waiting_requests": 0,
            "gpu_kv_cache_usage": 0.0,
            "prefix_cache_hit_rate": 0.0,
        }

        # Newest first, so the first metrics line found is the most recent one.
        for log in reversed(self.get_recent_logs(50)):
            if "Avg prompt throughput:" in log:
                match = re.search(r'Avg prompt throughput: ([\d.]+) tokens/s', log)
                if match:
                    metrics["prompt_throughput"] = float(match.group(1))

                match = re.search(r'Avg generation throughput: ([\d.]+) tokens/s', log)
                if match:
                    metrics["generation_throughput"] = float(match.group(1))

                match = re.search(r'Running: (\d+) reqs', log)
                if match:
                    metrics["running_requests"] = int(match.group(1))

                match = re.search(r'Waiting: (\d+) reqs', log)
                if match:
                    metrics["waiting_requests"] = int(match.group(1))

                match = re.search(r'GPU KV cache usage: ([\d.]+)%', log)
                if match:
                    metrics["gpu_kv_cache_usage"] = float(match.group(1))

                match = re.search(r'Prefix cache hit rate: ([\d.]+)%', log)
                if match:
                    metrics["prefix_cache_hit_rate"] = float(match.group(1))

                break

        return metrics

    def tail_logs(
        self, 
        follow: bool = True, # Continue displaying new logs as they arrive
        n: int = 10 # Number of initial lines to display
    ) -> None: # Returns nothing
        """Tail the server logs (similar to tail -f)."""
        for line in self.get_recent_logs(n):
            print(line)

        if follow:
            print("\n--- Following logs (Ctrl+C to stop) ---\n")
            proc = self.process
            try:
                # Follow while the child process is alive, including while it is
                # still booting. A /health probe per line would throttle output
                # and end the tail before the server is ready.
                while proc is not None and proc.poll() is None:
                    try:
                        print(self.log_queue.get(timeout=0.1))
                    except queue.Empty:
                        continue
            except KeyboardInterrupt:
                print("\n--- Stopped following logs ---")

    def __enter__(self):
        """Context manager entry: start the server and wait until it is ready."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: stop the server."""
        self.stop()

In [None]:
#| export
@dataclass
class VoxtralVLLMPluginConfig:
    """Configuration for Voxtral VLLM transcription plugin.

    Field order is part of the dataclass constructor signature. Each field's
    ``metadata`` holds schema hints (title, description, enum, min/max) that
    ``dict_to_config(..., validate=True)`` uses to reject values that are not
    in an enum or are out of range.
    """
    # Model served by vLLM (managed mode passes it as `--model`) and sent as
    # `model` in every transcription request.
    model_id:str = field(
        default="mistralai/Voxtral-Mini-3B-2507",
        metadata={
            SCHEMA_TITLE: "Model ID",
            SCHEMA_DESC: "Voxtral model to use. Mini is faster, Small is more accurate.",
            SCHEMA_ENUM: ["mistralai/Voxtral-Mini-3B-2507", "mistralai/Voxtral-Small-24B-2507"]
        }
    )
    # "managed": the plugin creates and owns a VLLMServer child process.
    # "external": the plugin connects to an existing server at `server_url`.
    server_mode:str = field(
        default="managed",
        metadata={
            SCHEMA_TITLE: "Server Mode",
            SCHEMA_DESC: "'managed': plugin manages server lifecycle, 'external': connect to existing server",
            SCHEMA_ENUM: ["managed", "external"]
        }
    )
    # Only read in external mode; managed mode derives its URL from `server_port`.
    server_url:str = field(
        default="http://localhost:8000",
        metadata={
            SCHEMA_TITLE: "Server URL",
            SCHEMA_DESC: "vLLM server URL (for external mode)"
        }
    )
    # Port the managed server listens on (and that the client connects to).
    server_port:int = field(
        default=8000,
        metadata={
            SCHEMA_TITLE: "Server Port",
            SCHEMA_DESC: "Port for managed vLLM server",
            SCHEMA_MIN: 1024,
            SCHEMA_MAX: 65535
        }
    )
    # Forwarded to the managed server as `--gpu-memory-utilization`.
    gpu_memory_utilization:float = field(
        default=0.85,
        metadata={
            SCHEMA_TITLE: "GPU Memory Utilization",
            SCHEMA_DESC: "Fraction of GPU memory to use (managed mode)",
            SCHEMA_MIN: 0.1,
            SCHEMA_MAX: 1.0
        }
    )
    # Forwarded to the managed server as `--max-model-len`.
    max_model_len:int = field(
        default=32768,
        metadata={
            SCHEMA_TITLE: "Max Model Length",
            SCHEMA_DESC: "Maximum sequence length for the model",
            SCHEMA_MIN: 1024,
            SCHEMA_MAX: 131072
        }
    )
    # Sent with each request; a falsy value (e.g. None) falls back to "en".
    language:Optional[str] = field(
        default="en",
        metadata={
            SCHEMA_TITLE: "Language",
            SCHEMA_DESC: "Language code for transcription (e.g., 'en', 'es', 'fr')"
        }
    )
    # Sampling temperature sent with each request; can be overridden per call.
    temperature:float = field(
        default=0.0,
        metadata={
            SCHEMA_TITLE: "Temperature",
            SCHEMA_DESC: "Temperature for sampling (0.0 for deterministic)",
            SCHEMA_MIN: 0.0,
            SCHEMA_MAX: 2.0
        }
    )
    # NOTE(review): not read anywhere in the visible source; callers choose
    # between `execute` and `execute_stream` themselves.
    streaming:bool = field(
        default=False,
        metadata={
            SCHEMA_TITLE: "Streaming",
            SCHEMA_DESC: "Enable streaming output by default"
        }
    )
    # Seconds to wait for the managed server's /health to return 200.
    server_startup_timeout:int = field(
        default=120,
        metadata={
            SCHEMA_TITLE: "Server Startup Timeout",
            SCHEMA_DESC: "Timeout in seconds for server startup (managed mode)",
            SCHEMA_MIN: 30,
            SCHEMA_MAX: 600
        }
    )
    # When False, transcribing with a stopped managed server raises RuntimeError.
    auto_start_server:bool = field(
        default=True,
        metadata={
            SCHEMA_TITLE: "Auto Start Server",
            SCHEMA_DESC: "Automatically start server on first use (managed mode)"
        }
    )
    # Also controls whether startup progress messages are shown.
    capture_server_logs:bool = field(
        default=True,
        metadata={
            SCHEMA_TITLE: "Capture Server Logs",
            SCHEMA_DESC: "Capture vLLM server logs (managed mode)"
        }
    )
    # Forwarded to the managed server as `--dtype`.
    dtype:str = field(
        default="auto",
        metadata={
            SCHEMA_TITLE: "Data Type",
            SCHEMA_DESC: "Data type for model weights",
            SCHEMA_ENUM: ["auto", "half", "float16", "bfloat16", "float32"]
        }
    )
    # Forwarded to the managed server as `--tensor-parallel-size`.
    tensor_parallel_size:int = field(
        default=1,
        metadata={
            SCHEMA_TITLE: "Tensor Parallel Size",
            SCHEMA_DESC: "Number of GPUs for tensor parallelism",
            SCHEMA_MIN: 1,
            SCHEMA_MAX: 8
        }
    )


class VoxtralVLLMPlugin(TranscriptionPlugin):
    """Mistral Voxtral transcription plugin via vLLM server.

    Transcribes through the OpenAI-compatible ``audio.transcriptions`` API of a
    vLLM server:

    - In ``managed`` mode the plugin owns a ``VLLMServer`` child process, which
      is started lazily on first use when ``auto_start_server`` is set.
    - In ``external`` mode it connects to ``server_url``.
    """
    
    config_class = VoxtralVLLMPluginConfig
    
    def __init__(self):
        """Create an uninitialized plugin; call `initialize` before use."""
        self.logger = logging.getLogger(f"{__name__}.{type(self).__name__}")
        self.config: Optional[VoxtralVLLMPluginConfig] = None
        self.server: Optional[VLLMServer] = None
        self.client: Optional[OpenAI] = None
        self.model_id: Optional[str] = None
    
    @property
    def name(self) -> str: # The plugin name identifier
        """Get the plugin name identifier."""
        return "voxtral_vllm"
    
    @property
    def version(self) -> str: # The plugin version string
        """Get the plugin version string."""
        return "1.0.0"
    
    @property
    def supported_formats(self) -> List[str]: # List of supported audio formats
        """Get the list of supported audio file formats."""
        return ["wav", "mp3", "flac", "m4a", "ogg", "webm", "mp4", "avi", "mov"]
    
    def get_current_config(self) -> VoxtralVLLMPluginConfig: # Current configuration dataclass
        """Return current configuration (None before `initialize`)."""
        return self.config
    
    def initialize(
        self,
        config: Optional[Any] = None # Configuration dataclass, dict, or None
    ) -> None:
        """Initialize the plugin with configuration.

        In managed mode this creates, but does not start, a ``VLLMServer``. In
        both modes it creates the OpenAI client pointing at the server's
        ``/v1`` API.

        Raises:
            TypeError: if ``config`` is not a config instance, a dict, or None.
            ImportError: if the ``openai`` package is not installed.
        """
        if config is None:
            self.config = VoxtralVLLMPluginConfig()
        elif isinstance(config, VoxtralVLLMPluginConfig):
            self.config = config
        elif isinstance(config, dict):
            self.config = dict_to_config(VoxtralVLLMPluginConfig, config, validate=True)
        else:
            raise TypeError(f"Expected VoxtralVLLMPluginConfig, dict, or None, got {type(config).__name__}")

        if not OPENAI_AVAILABLE:
            raise ImportError("The 'openai' package is required by the Voxtral VLLM plugin")

        # A managed server left over from a previous initialize() would otherwise
        # be orphaned while still holding its port and GPU memory.
        if self.server is not None:
            self.server.stop()
            self.server = None

        self.model_id = self.config.model_id

        if self.config.server_mode == "managed":
            # Create managed server instance (but don't start yet)
            self.server = VLLMServer(
                model=self.model_id,
                port=self.config.server_port,
                gpu_memory_utilization=self.config.gpu_memory_utilization,
                max_model_len=self.config.max_model_len,
                capture_logs=self.config.capture_server_logs,
                dtype=self.config.dtype,
                tensor_parallel_size=self.config.tensor_parallel_size
            )
            server_url = f"http://localhost:{self.config.server_port}"
        else:
            server_url = self.config.server_url

        # Strip a trailing "/" so the joined URL doesn't contain "//v1".
        self.client = OpenAI(
            api_key="EMPTY",  # vLLM doesn't require API key
            base_url=f"{server_url.rstrip('/')}/v1"
        )

        self.logger.info(
            f"Initialized Voxtral VLLM plugin with model '{self.model_id}' "
            f"in {self.config.server_mode} mode"
        )
    
    def _ensure_server_running(self) -> None:
        """Make sure a server is reachable before a request is sent.

        Managed mode: start the server if needed (and allowed). External mode:
        check that ``/health`` returns 200.

        Raises:
            RuntimeError: if the plugin is not initialized, or the server is not
                running and cannot be started or reached.
        """
        if self.config is None:
            raise RuntimeError("Voxtral VLLM plugin is not initialized; call initialize() first")

        if self.config.server_mode == "managed" and self.server:
            if self.server.is_running():
                return
            if not self.config.auto_start_server:
                raise RuntimeError("vLLM server is not running and auto_start_server is disabled")
            self.logger.info("Starting vLLM server...")
            self.server.start(
                wait_for_ready=True,
                timeout=self.config.server_startup_timeout,
                show_progress=self.config.capture_server_logs
            )
        elif self.config.server_mode == "external":
            health_url = f"{self.config.server_url.rstrip('/')}/health"
            try:
                response = requests.get(health_url, timeout=5)
            except requests.exceptions.RequestException as e:
                raise RuntimeError(f"Cannot connect to external vLLM server: {e}") from e
            if response.status_code != 200:
                raise RuntimeError(f"External vLLM server returned status {response.status_code}")
    
    def _prepare_audio(
        self,
        audio: Union[AudioData, str, Path] # Audio data, file path, or Path object to prepare
    ) -> str: # Path to the prepared audio file
        """Prepare audio for Voxtral processing.

        Paths are returned unchanged, as ``str``. ``AudioData`` is:

        1. downmixed to mono,
        2. converted to float32,
        3. rescaled if its peak magnitude exceeds 1.0,
        4. written to a new temporary ``.wav`` file, which the caller must delete.

        Raises:
            ValueError: for unsupported input types, or ``AudioData`` without samples.
        """
        if isinstance(audio, (str, Path)):
            return str(audio)
        if not isinstance(audio, AudioData):
            raise ValueError(f"Unsupported audio input type: {type(audio)}")

        audio_array = np.asarray(audio.samples)
        if audio_array.size == 0:
            raise ValueError("AudioData contains no samples")

        # If stereo, convert to mono.
        # NOTE(review): assumes a (samples, channels) layout; confirm against AudioData.
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)
        audio_array = audio_array.astype(np.float32, copy=False)

        # Rescale when the peak magnitude of either polarity exceeds full scale;
        # checking only the maximum would miss (and clip) large negative peaks.
        peak = float(np.abs(audio_array).max())
        if peak > 1.0:
            audio_array = audio_array / peak

        # Close the handle before writing: an open NamedTemporaryFile cannot be
        # reopened by name on Windows.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            tmp_path = tmp_file.name
        try:
            sf.write(tmp_path, audio_array, audio.sample_rate)
        except Exception:
            # delete=False: remove the empty file ourselves before propagating.
            Path(tmp_path).unlink(missing_ok=True)
            raise
        return tmp_path

    def _remove_temp_file(
        self,
        path: str # Temporary file created by _prepare_audio
    ) -> None:
        """Best-effort removal of a temporary audio file; failures are logged, not raised."""
        try:
            Path(path).unlink(missing_ok=True)
        except OSError as e:
            self.logger.warning(f"Could not remove temporary audio file {path}: {e}")
    
    def execute(
        self,
        audio: Union[AudioData, str, Path], # Audio data or path to audio file to transcribe
        **kwargs # Additional arguments to override config
    ) -> TranscriptionResult: # Transcription result with text and metadata
        """Transcribe audio using Voxtral via vLLM.

        ``language`` and ``temperature`` may be passed in ``kwargs`` to
        override the config for this call.

        Raises:
            ImportError: if ``mistral_common`` is not installed.
            RuntimeError: if the plugin is not initialized or the server is unavailable.
        """
        if not MISTRAL_COMMON_AVAILABLE:
            raise ImportError("The 'mistral_common' package is required by the Voxtral VLLM plugin")
        self._ensure_server_running()

        audio_path = self._prepare_audio(audio)
        temp_file_created = not isinstance(audio, (str, Path))

        try:
            language = kwargs.get("language", self.config.language) or "en"
            temperature = kwargs.get("temperature", self.config.temperature)

            self.logger.info(f"Processing audio with Voxtral {self.model_id} via vLLM")

            raw_audio = RawAudio.from_audio(Audio.from_file(audio_path, strict=False))
            req = TranscriptionRequest(
                model=self.model_id,
                audio=raw_audio,
                language=language,
                temperature=temperature
            ).to_openai(exclude=("top_p", "seed"))

            response = self.client.audio.transcriptions.create(**req)
            text = response.text.strip()

            transcription_result = TranscriptionResult(
                text=text,
                confidence=None,  # vLLM doesn't provide confidence scores
                segments=None,  # vLLM doesn't provide segments by default
                metadata={
                    "model": self.model_id,
                    "language": language,
                    "server_mode": self.config.server_mode,
                    "temperature": temperature,
                }
            )

            self.logger.info(f"Transcription completed: {len(text.split())} words")
            return transcription_result

        finally:
            if temp_file_created:
                self._remove_temp_file(audio_path)
    
    def is_available(self) -> bool: # True if vLLM and dependencies are available
        """Check if vLLM and required dependencies are available."""
        if not OPENAI_AVAILABLE:
            return False
        if not MISTRAL_COMMON_AVAILABLE:
            return False
        
        # Check if vLLM is installed
        try:
            import vllm  # noqa: F401
            return True
        except ImportError:
            return False
    
    def cleanup(self) -> None:
        """Stop any managed server and drop the client."""
        self.logger.info("Cleaning up Voxtral VLLM plugin")

        # Call stop() without first checking is_running(): a server that is
        # still booting fails the /health check but must still be stopped.
        # stop() is a no-op when no child process is alive.
        if self.server is not None:
            self.logger.info("Stopping managed vLLM server")
            self.server.stop()
            self.server = None

        self.client = None

        self.logger.info("Cleanup completed successfully")

In [None]:
#| export
@patch
def supports_streaming(
    self: VoxtralVLLMPlugin # The plugin instance
) -> bool: # True if streaming is supported
    """Report whether streaming transcription is offered.

    Always true for this plugin: ``execute_stream`` requests the transcription
    with ``stream=True`` and yields text chunks as they arrive.
    """
    streaming_supported = True
    return streaming_supported

def _stream_chunk_text(
    chunk # One item yielded by a streamed transcription response
) -> str: # Text carried by the chunk, or "" when it has none
    """Extract the text delta from one streamed transcription chunk.

    Chunks follow a chat-completion-like ``choices[0].delta.content`` shape.
    The entries may arrive as plain dicts or as attribute objects, so both are
    handled. Returns "" for chunks without text, e.g. an empty ``choices``
    list or a final delta that has no ``content``.
    """
    choices = getattr(chunk, "choices", None)
    if not choices:
        return ""
    choice = choices[0]
    delta = choice.get("delta") if isinstance(choice, dict) else getattr(choice, "delta", None)
    if not delta:
        return ""
    content = delta.get("content") if isinstance(delta, dict) else getattr(delta, "content", None)
    return content or ""

@patch
def execute_stream(
    self: VoxtralVLLMPlugin, # The plugin instance
    audio: Union[AudioData, str, Path], # Audio data or path to audio file
    **kwargs # Additional plugin-specific parameters
) -> Generator[str, None, TranscriptionResult]: # Yields text chunks, returns final result
    """Stream transcription results chunk by chunk.

    Yields each non-empty text delta as it arrives. The generator's return
    value (``StopIteration.value``) is a ``TranscriptionResult`` holding the
    concatenated text. ``language`` and ``temperature`` in ``kwargs`` override
    the config for this call.

    Raises:
        RuntimeError: if the plugin is not initialized or the server is unavailable.
        ImportError: if ``mistral_common`` is not installed.
    """
    if self.config is None:
        raise RuntimeError("Voxtral VLLM plugin is not initialized; call initialize() first")
    if not MISTRAL_COMMON_AVAILABLE:
        raise ImportError("The 'mistral_common' package is required by the Voxtral VLLM plugin")
    self._ensure_server_running()

    audio_path = self._prepare_audio(audio)
    temp_file_created = not isinstance(audio, (str, Path))
    response = None

    try:
        language = kwargs.get("language", self.config.language) or "en"
        temperature = kwargs.get("temperature", self.config.temperature)

        self.logger.info(f"Streaming transcription with Voxtral {self.model_id} via vLLM")

        raw_audio = RawAudio.from_audio(Audio.from_file(audio_path, strict=False))
        req = TranscriptionRequest(
            model=self.model_id,
            audio=raw_audio,
            language=language,
            temperature=temperature
        ).to_openai(exclude=("top_p", "seed"))

        response = self.client.audio.transcriptions.create(**req, stream=True)

        pieces = []
        for chunk in response:
            text_chunk = _stream_chunk_text(chunk)
            if text_chunk:
                pieces.append(text_chunk)
                yield text_chunk

        return TranscriptionResult(
            text="".join(pieces).strip(),
            confidence=None,
            segments=None,
            metadata={
                "model": self.model_id,
                "language": language,
                "server_mode": self.config.server_mode,
                "temperature": temperature,
                "streaming": True,
            }
        )

    finally:
        # Release the HTTP stream even if the consumer stops iterating early.
        if response is not None and hasattr(response, "close"):
            try:
                response.close()
            except Exception as e:
                self.logger.debug(f"Error closing transcription stream: {e}")
        # Clean up temporary file if created (best effort)
        if temp_file_created:
            try:
                Path(audio_path).unlink(missing_ok=True)
            except OSError as e:
                self.logger.warning(f"Could not remove temporary audio file {audio_path}: {e}")

## Testing the Plugin

In [None]:
# Test basic functionality: report the plugin's static properties (no server is started)
plugin = VoxtralVLLMPlugin()

# (label, zero-argument getter) pairs, evaluated in order as each line is printed
report = [
    ("Voxtral VLLM available", plugin.is_available),
    ("Plugin name", lambda: plugin.name),
    ("Plugin version", lambda: plugin.version),
    ("Supported formats", lambda: plugin.supported_formats),
    ("Supports streaming", plugin.supports_streaming),
]
for label, getter in report:
    print(f"{label}: {getter()}")

Voxtral VLLM available: True
Plugin name: voxtral_vllm
Plugin version: 1.0.0
Supported formats: ['wav', 'mp3', 'flac', 'm4a', 'ogg', 'webm', 'mp4', 'avi', 'mov']
Supports streaming: True


In [None]:
# Test configuration dataclass: list the choices declared in the field schema metadata
from dataclasses import fields

# Index the config fields by name so their schema metadata can be looked up directly
config_fields = {f.name: f for f in fields(VoxtralVLLMPluginConfig)}

print("Available models:")
for model in config_fields["model_id"].metadata.get(SCHEMA_ENUM, []):
    print(f"  - {model}")

print(f"\nServer modes: {config_fields['server_mode'].metadata.get(SCHEMA_ENUM)}")

Available models:
  - mistralai/Voxtral-Mini-3B-2507
  - mistralai/Voxtral-Small-24B-2507

Server modes: ['managed', 'external']


In [None]:
# Test configuration validation: each case overrides a single field of the defaults
from dataclasses import asdict

plugin = VoxtralVLLMPlugin()

test_configs = [
    ({"model_id": "mistralai/Voxtral-Mini-3B-2507"}, "Valid config"),
    ({"model_id": "invalid_model"}, "Invalid model"),
    ({"server_port": 9000}, "Valid port change"),
    ({"temperature": 2.5}, "Temperature out of range"),
]

# Overrides are merged onto the defaults so only the overridden field can fail
defaults = plugin.get_config_defaults()

for overrides, label in test_configs:
    candidate = {**defaults, **overrides}
    try:
        test_config = dict_to_config(VoxtralVLLMPluginConfig, candidate, validate=True)
    except ValueError as err:
        message = str(err)
        print(f"{label}: Valid=False")
        print(f"  Error: {message[:100]}")
    else:
        print(f"{label}: Valid=True")

Valid config: Valid=True
Invalid model: Valid=False
  Error: model_id: 'invalid_model' is not one of ['mistralai/Voxtral-Mini-3B-2507', 'mistralai/Voxtral-Small-
Valid port change: Valid=True
Temperature out of range: Valid=False
  Error: temperature: 2.5 is greater than maximum 2.0


In [None]:
# Test initialization with external server mode (no managed server is created)
external_settings = {
    "model_id": "mistralai/Voxtral-Mini-3B-2507",
    "server_mode": "external",
    "server_url": "http://localhost:8000"
}
plugin.initialize(external_settings)

current = plugin.get_current_config()
print(f"Current config mode: {current.server_mode}")
print(f"Current model: {current.model_id}")

Current config mode: external
Current model: mistralai/Voxtral-Mini-3B-2507


In [None]:
#| hide
# Regenerate the library module from this notebook's `#| export` cells
import nbdev
nbdev.nbdev_export()