# Whisper Plugin

> Plugin implementation for OpenAI Whisper transcription

In [None]:
#| default_exp plugin

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import json
import logging
from pathlib import Path
from dataclasses import dataclass, field
from dataclasses import replace as dataclass_replace
from typing import Dict, Any, Optional, List, Union
import tempfile
import warnings

import numpy as np
import torch
import soundfile as sf

try:
    import whisper
    from whisper import load_model
    from whisper import transcribe
    from cjm_ffmpeg_utils.core import FFMPEG_AVAILABLE
    WHISPER_AVAILABLE = True and FFMPEG_AVAILABLE
except ImportError:
    WHISPER_AVAILABLE = False
    
from cjm_transcription_plugin_system.plugin_interface import TranscriptionPlugin
from cjm_transcription_plugin_system.core import AudioData, TranscriptionResult
from cjm_plugin_system.utils.validation import (
    dict_to_config, config_to_dict, validate_config, dataclass_to_jsonschema,
    SCHEMA_TITLE, SCHEMA_DESC, SCHEMA_MIN, SCHEMA_MAX, SCHEMA_ENUM
)

In [None]:
#| export
@dataclass
class WhisperPluginConfig:
    """Configuration for Whisper transcription plugin."""
    # Each field pairs its default with schema metadata (title, description and,
    # where given, an enum or numeric bounds). Field order is significant: it is
    # the dataclass constructor order.

    # --- Model selection and placement ---
    model: str = field(default="base", metadata={
        SCHEMA_TITLE: "Model",
        SCHEMA_DESC: "Whisper model size. Larger models are more accurate but slower.",
        SCHEMA_ENUM: [
            "tiny", "tiny.en", "base", "base.en", "small", "small.en",
            "medium", "medium.en", "large", "large-v1", "large-v2", "large-v3",
        ],
    })
    device: str = field(default="auto", metadata={
        SCHEMA_TITLE: "Device",
        SCHEMA_DESC: "Device for inference (auto will use CUDA if available)",
        SCHEMA_ENUM: ["auto", "cpu", "cuda"],
    })

    # --- Language and task ---
    language: Optional[str] = field(default=None, metadata={
        SCHEMA_TITLE: "Language",
        SCHEMA_DESC: "Language code (e.g., 'en', 'es', 'fr') or None for auto-detection",
    })
    task: str = field(default="transcribe", metadata={
        SCHEMA_TITLE: "Task",
        SCHEMA_DESC: "Task to perform (transcribe or translate to English)",
        SCHEMA_ENUM: ["transcribe", "translate"],
    })

    # --- Sampling and beam search ---
    temperature: float = field(default=0.0, metadata={
        SCHEMA_TITLE: "Temperature",
        SCHEMA_DESC: "Sampling temperature (0 for deterministic)",
        SCHEMA_MIN: 0.0,
        SCHEMA_MAX: 1.0,
    })
    temperature_increment_on_fallback: Optional[float] = field(default=0.2, metadata={
        SCHEMA_TITLE: "Temperature Increment",
        SCHEMA_DESC: "Temperature increment when falling back",
        SCHEMA_MIN: 0.0,
        SCHEMA_MAX: 1.0,
    })
    beam_size: int = field(default=5, metadata={
        SCHEMA_TITLE: "Beam Size",
        SCHEMA_DESC: "Beam search width",
        SCHEMA_MIN: 1,
        SCHEMA_MAX: 10,
    })
    best_of: int = field(default=5, metadata={
        SCHEMA_TITLE: "Best Of",
        SCHEMA_DESC: "Number of candidates when sampling",
        SCHEMA_MIN: 1,
        SCHEMA_MAX: 10,
    })
    patience: float = field(default=1.0, metadata={
        SCHEMA_TITLE: "Patience",
        SCHEMA_DESC: "Beam search patience factor",
        SCHEMA_MIN: 0.0,
        SCHEMA_MAX: 2.0,
    })
    length_penalty: Optional[float] = field(default=None, metadata={
        SCHEMA_TITLE: "Length Penalty",
        SCHEMA_DESC: "Exponential length penalty",
    })

    # --- Token control and prompting ---
    suppress_tokens: str = field(default="-1", metadata={
        SCHEMA_TITLE: "Suppress Tokens",
        SCHEMA_DESC: "Tokens to suppress ('-1' for default)",
    })
    initial_prompt: Optional[str] = field(default=None, metadata={
        SCHEMA_TITLE: "Initial Prompt",
        SCHEMA_DESC: "Optional initial prompt",
    })
    condition_on_previous_text: bool = field(default=False, metadata={
        SCHEMA_TITLE: "Condition on Previous",
        SCHEMA_DESC: "Condition on previous text",
    })

    # --- Numeric precision ---
    fp16: bool = field(default=True, metadata={
        SCHEMA_TITLE: "FP16",
        SCHEMA_DESC: "Use FP16 (half precision) for faster inference",
    })

    # --- Decoding quality thresholds ---
    compression_ratio_threshold: float = field(default=2.4, metadata={
        SCHEMA_TITLE: "Compression Ratio Threshold",
        SCHEMA_DESC: "Threshold for repetition detection",
        SCHEMA_MIN: 1.0,
        SCHEMA_MAX: 10.0,
    })
    logprob_threshold: float = field(default=-1.0, metadata={
        SCHEMA_TITLE: "Logprob Threshold",
        SCHEMA_DESC: "Average log probability threshold",
    })
    no_speech_threshold: float = field(default=0.6, metadata={
        SCHEMA_TITLE: "No Speech Threshold",
        SCHEMA_DESC: "Threshold for silence detection",
        SCHEMA_MIN: 0.0,
        SCHEMA_MAX: 1.0,
    })

    # --- Word-level timestamps ---
    word_timestamps: bool = field(default=False, metadata={
        SCHEMA_TITLE: "Word Timestamps",
        SCHEMA_DESC: "Extract word-level timestamps",
    })
    prepend_punctuations: str = field(default="\"'“¿([{-", metadata={
        SCHEMA_TITLE: "Prepend Punctuations",
        SCHEMA_DESC: "Punctuations to merge with next word",
    })
    append_punctuations: str = field(default="\"'.。,，!！?？:：”)]}、", metadata={
        SCHEMA_TITLE: "Append Punctuations",
        SCHEMA_DESC: "Punctuations to merge with previous word",
    })

    # --- Runtime environment ---
    threads: int = field(default=0, metadata={
        SCHEMA_TITLE: "Threads",
        SCHEMA_DESC: "Number of threads (0 for default)",
        SCHEMA_MIN: 0,
    })
    model_dir: Optional[str] = field(default=None, metadata={
        SCHEMA_TITLE: "Model Directory",
        SCHEMA_DESC: "Directory to save/load models",
    })
    compile_model: bool = field(default=False, metadata={
        SCHEMA_TITLE: "Compile Model",
        SCHEMA_DESC: "Use torch.compile for potential speedup (requires PyTorch 2.0+)",
    })


class WhisperLocalPlugin(TranscriptionPlugin):
    """OpenAI Whisper transcription plugin."""

    # Lifecycle: `initialize` stores the config and resolves the device, the model
    # is loaded lazily by the first `execute`, and `cleanup` releases it again.

    config_class = WhisperPluginConfig

    # Config fields forwarded unchanged to `transcribe`; each one can be
    # overridden per call through `execute(**kwargs)`.
    _DECODE_OPTION_NAMES = (
        "task",
        "language",
        "beam_size",
        "best_of",
        "patience",
        "length_penalty",
        "suppress_tokens",
        "initial_prompt",
        "condition_on_previous_text",
        "compression_ratio_threshold",
        "logprob_threshold",
        "no_speech_threshold",
        "word_timestamps",
        "prepend_punctuations",
        "append_punctuations",
    )

    def __init__(self):
        """Create an unconfigured plugin; call `initialize` before `execute`."""
        self.logger = logging.getLogger(f"{__name__}.{type(self).__name__}")
        self.config: Optional[WhisperPluginConfig] = None  # set by initialize()
        self.model = None  # loaded lazily by _load_model()
        self.device: Optional[str] = None  # resolved device ("cuda"/"cpu"), set by initialize()
        self.model_dir: Optional[str] = None  # download/cache directory handed to load_model

    @property
    def name(self) -> str: # Plugin name identifier
        """Get the plugin name identifier."""
        return "whisper_local"

    @property
    def version(self) -> str: # Plugin version string
        """Get the plugin version string."""
        return "1.0.0"

    @property
    def supported_formats(self) -> List[str]: # List of supported audio file formats
        """Get the list of supported audio file formats."""
        return ["wav", "mp3", "flac", "m4a", "ogg", "webm", "mp4", "avi", "mov"]

    def get_current_config(self) -> Dict[str, Any]: # Current configuration as dictionary
        """Return current configuration state (empty dict before `initialize`)."""
        if self.config is None:
            return {}
        return config_to_dict(self.config)

    def get_config_schema(self) -> Dict[str, Any]: # JSON Schema for configuration
        """Return JSON Schema for UI generation."""
        return dataclass_to_jsonschema(WhisperPluginConfig)

    @staticmethod
    def get_config_dataclass() -> WhisperPluginConfig: # Configuration dataclass
        """Return dataclass describing the plugin's configuration options."""
        # Note: this returns the class itself, not an instance.
        return WhisperPluginConfig

    def initialize(
        self,
        config: Optional[Any] = None # Configuration dataclass, dict, or None
    ) -> None:
        """Initialize or re-configure the plugin (idempotent).

        A loaded model is dropped when the model name, the device setting or the
        `compile_model` flag changes, so the next `execute` reloads it with the
        new settings. Re-applying an identical config keeps the loaded model.
        """
        # NOTE(review): `validate=True` is not passed here, so whether enum/range
        # constraints are enforced depends on dict_to_config's default — confirm.
        new_config = dict_to_config(WhisperPluginConfig, config or {})

        # Compare against the previous config (if any). This must happen before
        # self.device is reassigned: _unload_model uses the old device to decide
        # whether the CUDA cache needs clearing.
        if self.config is not None:
            if self.config.model != new_config.model:
                self.logger.info(f"Config change: Model {self.config.model} -> {new_config.model}")
                self._unload_model()

            if self.config.device != new_config.device:
                self.logger.info(f"Config change: Device {self.config.device} -> {new_config.device}")
                self._unload_model()

            # The compile flag is only consulted at load time, so a toggle needs a reload.
            if self.config.compile_model != new_config.compile_model:
                self.logger.info(
                    f"Config change: Compile model {self.config.compile_model} -> {new_config.compile_model}"
                )
                self._unload_model()

        self.config = new_config

        # Resolve "auto" to a concrete device
        if self.config.device == "auto":
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = self.config.device

        self.model_dir = self.config.model_dir

        self.logger.info(f"Initialized Whisper plugin with model '{self.config.model}' on device '{self.device}'")

    def _unload_model(
        self,
        log_message: str = "Unloading Whisper model for reconfiguration" # Message logged before the model is dropped
    ) -> None:
        """Unload the current model and free resources (no-op when nothing is loaded)."""
        if self.model is None:
            return

        self.logger.info(log_message)
        self.model = None

        # Clear GPU cache if the model was living on CUDA
        if self.device == "cuda" and torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _load_model(self) -> None:
        """Load the Whisper model (lazy loading).

        Raises `RuntimeError` when the plugin has not been initialized or when
        loading/compiling the model fails (the original error is chained).
        """
        if self.model is not None:
            return

        if self.config is None:
            raise RuntimeError(
                "Failed to load Whisper model: plugin is not initialized; call initialize() first"
            )

        try:
            self.logger.info(f"Loading Whisper model: {self.config.model}")
            model = load_model(
                self.config.model,
                device=self.device,
                download_root=self.model_dir
            )

            # Optionally compile the model (PyTorch 2.0+)
            if self.config.compile_model and hasattr(torch, 'compile'):
                model = torch.compile(model)
                self.logger.info("Model compiled with torch.compile")
        except Exception as e:
            raise RuntimeError(f"Failed to load Whisper model: {e}") from e

        # Publish only a fully prepared model, so a failed compile is not cached.
        self.model = model
        self.logger.info("Local Whisper model loaded successfully")

    def _prepare_audio(
        self,
        audio: Union[AudioData, str, Path] # Audio data, file path, or Path object to prepare
    ) -> str: # Path to the prepared audio file
        """Prepare audio for Whisper processing.

        Paths are returned unchanged (as `str`). `AudioData` is down-mixed to
        mono, converted to float32, peak-normalized when it exceeds full scale,
        and written to a temporary WAV file that the caller must delete.

        Raises `ValueError` for unsupported input types or empty `AudioData`.
        """
        if isinstance(audio, (str, Path)):
            # Already a file path
            return str(audio)

        if not isinstance(audio, AudioData):
            raise ValueError(f"Unsupported audio input type: {type(audio)}")

        audio_array = np.asarray(audio.samples)
        if audio_array.size == 0:
            raise ValueError("AudioData contains no samples")

        # Down-mix multi-channel audio; assumes (frames, channels) layout — TODO confirm
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)

        if audio_array.dtype != np.float32:
            audio_array = audio_array.astype(np.float32)

        # Normalize on the absolute peak so negative excursions beyond -1.0 are
        # handled the same way as positive ones.
        peak = float(np.abs(audio_array).max())
        if peak > 1.0:
            audio_array = audio_array / peak

        # Reserve a unique file name, then close the handle before writing so the
        # path can be reopened by name on every platform.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            tmp_path = tmp_file.name

        try:
            sf.write(tmp_path, audio_array, audio.sample_rate)
        except Exception:
            # Do not leave a half-written temp file behind
            Path(tmp_path).unlink(missing_ok=True)
            raise
        return tmp_path

    def _save_to_db(
        self,
        result: TranscriptionResult # Transcription result to save
    ) -> None:
        """Save transcription result to database (placeholder)."""
        # Placeholder for DB logic
        # Implementation will use self.db_path which can be injected via config or environment
        pass

    @staticmethod
    def _temperature_schedule(
        temperature: float, # Starting sampling temperature
        increment: Optional[float] # Fallback step, or None/0 to disable fallback
    ) -> tuple: # Temperatures to try, in order
        """Build the temperature fallback sequence handed to `transcribe`."""
        if increment is not None and increment > 0:
            # The small epsilon keeps 1.0 itself in the sequence
            return tuple(np.arange(temperature, 1.0 + 1e-6, increment))
        return (temperature,)

    @staticmethod
    def _format_segments(
        raw_segments: List[Dict[str, Any]], # Segment dicts from the transcription result
        include_words: bool # Whether to copy word-level timing when present
    ) -> List[Dict[str, Any]]: # Trimmed segment dicts
        """Reduce raw result segments to start/end/text (plus optional words)."""
        formatted = []
        for segment in raw_segments:
            entry = {
                "start": segment["start"],
                "end": segment["end"],
                "text": segment["text"].strip()
            }

            if include_words and "words" in segment:
                entry["words"] = [
                    {
                        "word": word["word"],
                        "start": word["start"],
                        "end": word["end"],
                        "probability": word.get("probability")
                    }
                    for word in segment["words"]
                ]

            formatted.append(entry)
        return formatted

    def execute(
        self,
        audio: Union[AudioData, str, Path], # Audio data or path to audio file to transcribe
        **kwargs # Additional arguments to override config
    ) -> TranscriptionResult: # Transcription result with text and metadata
        """Transcribe audio using Whisper.

        Keyword arguments named after config fields override the configured
        value for this call only. A `model` override is ignored (with a warning)
        because the loaded model is determined by `initialize`.
        """
        # Load model if not already loaded
        self._load_model()

        audio_path = self._prepare_audio(audio)
        # Only non-path inputs are materialised into a temp file that we own
        temp_file_created = not isinstance(audio, (str, Path))

        try:
            def option(name: str) -> Any:
                # Per-call override if given, otherwise the configured value
                return kwargs.get(name, getattr(self.config, name))

            # Report the model that is actually loaded, not a requested label
            model_name = self.config.model
            if kwargs.get("model", model_name) != model_name:
                self.logger.warning(
                    f"Ignoring per-call model override '{kwargs['model']}'; "
                    f"loaded model is '{model_name}' (use initialize() to switch models)"
                )

            whisper_args = {name: option(name) for name in self._DECODE_OPTION_NAMES}
            whisper_args["verbose"] = False
            # Half precision is only requested when running on CUDA
            whisper_args["fp16"] = bool(option("fp16")) and self.device == "cuda"

            temperature = self._temperature_schedule(
                option("temperature"),
                option("temperature_increment_on_fallback"),
            )

            # Set number of threads if specified (0/None keeps torch's default)
            threads = option("threads")
            if threads and threads > 0:
                torch.set_num_threads(threads)

            self.logger.info(f"Transcribing audio with Whisper {model_name}")

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")  # Suppress Whisper warnings
                result = transcribe(
                    self.model,
                    audio_path,
                    temperature=temperature,
                    **whisper_args
                )

            segments = self._format_segments(
                result.get("segments", []),
                whisper_args["word_timestamps"],
            )

            transcription_result = TranscriptionResult(
                text=result["text"].strip(),
                confidence=None,  # Whisper doesn't provide overall confidence
                segments=segments or None,
                metadata={
                    "model": model_name,
                    "language": result.get("language", whisper_args["language"]),
                    "task": whisper_args["task"],
                    "device": self.device,
                    "duration": result.get("duration"),
                }
            )

            # Save to database (placeholder)
            self._save_to_db(transcription_result)

            self.logger.info(f"Transcription completed: {len(result['text'].split())} words")
            return transcription_result

        finally:
            # Best-effort removal of the temp file; a failure here must not mask
            # the transcription result or error, so it is only logged.
            if temp_file_created:
                try:
                    Path(audio_path).unlink(missing_ok=True)
                except OSError as e:
                    self.logger.warning(f"Could not remove temporary audio file {audio_path}: {e}")

    def is_available(self) -> bool: # True if Whisper and its dependencies are available
        """Check if Whisper is available."""
        return WHISPER_AVAILABLE

    def cleanup(self) -> None:
        """Clean up resources (no-op when no model is loaded)."""
        if self.model is None:
            return

        self._unload_model("Unloading Whisper model")
        self.logger.info("Cleanup completed")

## Testing the Plugin

In [None]:
# Smoke test: build a plugin and print its static metadata
plugin = WhisperLocalPlugin()

# (label, value) pairs, printed one per line in this order
plugin_summary = (
    ("Whisper available", plugin.is_available()),
    ("Plugin name", plugin.name),
    ("Plugin version", plugin.version),
    ("Supported formats", plugin.supported_formats),
    ("Config class", plugin.config_class.__name__),
)
for label, value in plugin_summary:
    print(f"{label}: {value}")

Whisper available: True
Plugin name: whisper_local
Plugin version: 1.0.0
Supported formats: ['wav', 'mp3', 'flac', 'm4a', 'ogg', 'webm', 'mp4', 'avi', 'mov']
Config class: WhisperPluginConfig


In [None]:
# List the model names advertised by the config dataclass's schema metadata
from dataclasses import fields

model_field = next(filter(lambda f: f.name == "model", fields(WhisperPluginConfig)))
model_choices = model_field.metadata.get(SCHEMA_ENUM, [])

# One line for the header, then one bullet line per model
print("Available models:", *(f"  - {choice}" for choice in model_choices), sep="\n")

Available models:
  - tiny
  - tiny.en
  - base
  - base.en
  - small
  - small.en
  - medium
  - medium.en
  - large
  - large-v1
  - large-v2
  - large-v3


In [None]:
# Exercise validation: one accepted config and two that must be rejected
test_configs = [
    ({"model": "tiny"}, "Valid config"),
    ({"model": "invalid"}, "Invalid model"),
    ({"model": "tiny", "temperature": 1.5}, "Temperature out of range"),
]

for candidate, description in test_configs:
    try:
        dict_to_config(WhisperPluginConfig, candidate, validate=True)
    except ValueError as err:
        print(f"{description}: Valid=False")
        # Keep the report short: only the first 100 characters of the message
        print(f"  Error: {str(err)[:100]}")
    else:
        print(f"{description}: Valid=True")

Valid config: Valid=True
Invalid model: Valid=False
  Error: model: 'invalid' is not one of ['tiny', 'tiny.en', 'base', 'base.en', 'small', 'small.en', 'medium',
Temperature out of range: Valid=False
  Error: temperature: 1.5 is greater than maximum 1.0


In [None]:
# Initialize on CPU with the tiny model, then read the config back as a plain dict
plugin.initialize(dict(model="tiny", device="cpu"))
current_config = plugin.get_current_config()
print("Current config (dict):", current_config)
print("Current model:", current_config["model"])

Current config (dict): {'model': 'tiny', 'device': 'cpu', 'language': None, 'task': 'transcribe', 'temperature': 0.0, 'temperature_increment_on_fallback': 0.2, 'beam_size': 5, 'best_of': 5, 'patience': 1.0, 'length_penalty': None, 'suppress_tokens': '-1', 'initial_prompt': None, 'condition_on_previous_text': False, 'fp16': True, 'compression_ratio_threshold': 2.4, 'logprob_threshold': -1.0, 'no_speech_threshold': 0.6, 'word_timestamps': False, 'prepend_punctuations': '"\'“¿([{-', 'append_punctuations': '"\'.。,，!！?？:：”)]}、', 'threads': 0, 'model_dir': None, 'compile_model': False}
Current model: tiny


In [None]:
#| eval: false
# Inspect the JSON Schema the plugin exposes for UI generation
import json

schema = plugin.get_config_schema()
print("JSON Schema for WhisperPluginConfig:")
print(f"  Name: {schema['name']}")

properties = schema["properties"]
print(f"  Properties count: {len(properties)}")

model_enum_preview = properties["model"].get("enum", [])[:3]
print(f"  Model field enum: {model_enum_preview}...")

# Dump only the first three properties to keep the output readable
first_three_properties = dict(list(properties.items())[:3])
print("\nFull schema (truncated):")
print(json.dumps(first_three_properties, indent=2))

JSON Schema for WhisperPluginConfig:
  Name: WhisperPluginConfig
  Properties count: 23
  Model field enum: ['tiny', 'tiny.en', 'base']...

Full schema (truncated):
{
  "model": {
    "type": "string",
    "title": "Model",
    "description": "Whisper model size. Larger models are more accurate but slower.",
    "enum": [
      "tiny",
      "tiny.en",
      "base",
      "base.en",
      "small",
      "small.en",
      "medium",
      "medium.en",
      "large",
      "large-v1",
      "large-v2",
      "large-v3"
    ],
    "default": "base"
  },
  "device": {
    "type": "string",
    "title": "Device",
    "description": "Device for inference (auto will use CUDA if available)",
    "enum": [
      "auto",
      "cpu",
      "cuda"
    ],
    "default": "auto"
  },
  "language": {
    "type": [
      "string",
      "null"
    ],
    "title": "Language",
    "description": "Language code (e.g., 'en', 'es', 'fr') or None for auto-detection",
    "default": null
  }
}


In [None]:
#| eval: false
# Check that initialize() is idempotent and only unloads on a real config change
import logging

# INFO level makes the plugin's "Config change"/"Initialized" messages visible
logging.basicConfig(level=logging.INFO)

tiny_on_cpu = {"model": "tiny", "device": "cpu"}
base_on_cpu = {"model": "base", "device": "cpu"}

# First initialization
plugin.initialize(dict(tiny_on_cpu))
print(f"Initial config: model={plugin.config.model}")

# Same settings again: no "Config change" message is expected
print("\nRe-initializing with same model...")
plugin.initialize(dict(tiny_on_cpu))

# Different model: a "Config change" message is expected
print("\nRe-initializing with different model...")
plugin.initialize(dict(base_on_cpu))
print(f"New config: model={plugin.config.model}")

INFO:__main__.WhisperLocalPlugin:Initialized Whisper plugin with model 'tiny' on device 'cpu'
INFO:__main__.WhisperLocalPlugin:Initialized Whisper plugin with model 'tiny' on device 'cpu'
INFO:__main__.WhisperLocalPlugin:Config change: Model tiny -> base
INFO:__main__.WhisperLocalPlugin:Initialized Whisper plugin with model 'base' on device 'cpu'


Initial config: model=tiny

Re-initializing with same model...

Re-initializing with different model...
New config: model=base


In [None]:
#| hide
# Standard nbdev footer: runs nbdev's export step when this cell is executed
import nbdev; nbdev.nbdev_export()