# Silero VAD Plugin

> Plugin implementation for Voice Activity Detection using Silero VAD with SQLite result caching.

In [None]:
#| default_exp plugin

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import sqlite3
import json
import time
import os
import hashlib
import logging
from contextlib import closing
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List, Union, Tuple
from pathlib import Path

import numpy as np
import librosa

# Tier 2 Imports
from cjm_media_plugin_system.analysis_interface import MediaAnalysisPlugin
from cjm_media_plugin_system.core import MediaAnalysisResult, TimeRange

# Plugin System Utils
from cjm_plugin_system.utils.validation import (
    dict_to_config, config_to_dict, dataclass_to_jsonschema,
    SCHEMA_TITLE, SCHEMA_DESC, SCHEMA_MIN, SCHEMA_MAX, SCHEMA_ENUM
)
from cjm_media_plugin_silero_vad.meta import get_plugin_metadata

# Silero Imports
try:
    from silero_vad import load_silero_vad, get_speech_timestamps
    SILERO_AVAILABLE = True
except ImportError:
    SILERO_AVAILABLE = False

## Configuration

The `SileroVADConfig` dataclass defines all configurable parameters for the VAD model.

In [None]:
#| export
@dataclass
class SileroVADConfig:
    """Configuration for Silero VAD parameters.

    The `SCHEMA_*` metadata on each field is consumed by `dataclass_to_jsonschema`
    (JSON Schema for UI generation) and by `dict_to_config(..., validate=True)`,
    which rejects values outside the declared minimum/maximum/enum bounds.
    Field order determines the order of properties in the generated schema.
    """
    
    # Forwarded as `threshold` to `get_speech_timestamps`.
    threshold: float = field(
        default=0.5,
        metadata={
            SCHEMA_TITLE: "Threshold",
            SCHEMA_DESC: "Speech probability threshold (0.0 - 1.0). Higher values reduce false positives.",
            SCHEMA_MIN: 0.0, SCHEMA_MAX: 1.0
        }
    )
    
    # Forwarded as `min_speech_duration_ms` to `get_speech_timestamps`.
    min_speech_duration_ms: int = field(
        default=250,
        metadata={
            SCHEMA_TITLE: "Min Speech Duration (ms)",
            SCHEMA_DESC: "Segments shorter than this will be ignored.",
            SCHEMA_MIN: 0
        }
    )
    
    # Forwarded as `min_silence_duration_ms` to `get_speech_timestamps`.
    min_silence_duration_ms: int = field(
        default=100,
        metadata={
            SCHEMA_TITLE: "Min Silence Duration (ms)",
            SCHEMA_DESC: "Silence shorter than this will not split segments.",
            SCHEMA_MIN: 0
        }
    )
    
    # Forwarded as `speech_pad_ms` to `get_speech_timestamps`.
    speech_pad_ms: int = field(
        default=30,
        metadata={
            SCHEMA_TITLE: "Speech Padding (ms)",
            SCHEMA_DESC: "Padding added to the start/end of each speech segment.",
            SCHEMA_MIN: 0
        }
    )
    
    # Used twice: as the resampling target when loading audio with librosa,
    # and as `sampling_rate` for `get_speech_timestamps`. The enum restricts it
    # to the two rates listed below.
    sampling_rate: int = field(
        default=16000,
        metadata={
            SCHEMA_TITLE: "Sampling Rate",
            SCHEMA_DESC: "Target sampling rate for VAD processing (Silero expects 8k or 16k).",
            SCHEMA_ENUM: [8000, 16000]
        }
    )
    
    # Forwarded as `onnx` to `load_silero_vad` when the model is loaded.
    use_onnx: bool = field(
        default=True,
        metadata={
            SCHEMA_TITLE: "Use ONNX",
            SCHEMA_DESC: "Use ONNX runtime for inference (faster)."
        }
    )

## Plugin Implementation

The `SileroVADPlugin` implements `MediaAnalysisPlugin` to provide voice activity detection.

In [None]:
#| export
class SileroVADPlugin(MediaAnalysisPlugin):
    """Voice Activity Detection plugin using Silero VAD.

    Results are cached in a local SQLite table (`vad_jobs`) keyed by file path.
    A cached row is reused only when both the effective configuration and a
    fingerprint of the file (size + modification time) match. Edited or
    replaced media at the same path is therefore re-analyzed instead of being
    served stale.
    """

    config_class = SileroVADConfig

    def __init__(self):
        """Initialize the Silero VAD plugin (no model loading or database I/O)."""
        self.logger = logging.getLogger(f"{__name__}.{type(self).__name__}")
        self.config: Optional[SileroVADConfig] = None
        self._model = None
        self._model_onnx: Optional[bool] = None  # `onnx` flag the loaded model was built with
        self._db_path = None

    @property
    def name(self) -> str:  # Plugin name identifier
        """Get the plugin name identifier."""
        return "silero-vad"

    @property
    def version(self) -> str:  # Plugin version string
        """Get the plugin version string."""
        return "1.0.0"

    @property
    def supported_media_types(self) -> List[str]:  # Supported media types
        """Get the list of supported media types."""
        return ["audio", "video"]

    def get_current_config(self) -> Dict[str, Any]:  # Current configuration as dictionary
        """Return current configuration state (empty dict before `initialize`)."""
        return config_to_dict(self.config) if self.config else {}

    def get_config_schema(self) -> Dict[str, Any]:  # JSON Schema for configuration
        """Return JSON Schema for UI generation."""
        return dataclass_to_jsonschema(SileroVADConfig)

    def initialize(
        self,
        config: Optional[Any] = None  # Configuration dataclass, dict, or None
    ) -> None:
        """Initialize or re-configure the plugin (idempotent).

        An already-loaded model is kept. `_load_model` reloads it on the next
        run only if the requested ONNX backend differs.
        """
        self.config = dict_to_config(SileroVADConfig, config or {})

        # Set DB Path
        self._db_path = get_plugin_metadata()["db_path"]
        self._init_db()

        self.logger.info(f"Initialized Silero VAD plugin with threshold={self.config.threshold}")

    def _init_db(self) -> None:
        """Ensure the local cache database and its `vad_jobs` table exist."""
        # A `sqlite3.Connection` used as a context manager only commits or rolls
        # back; it never closes the handle. `closing` releases it deterministically,
        # while the inner `con` context still commits the DDL.
        with closing(sqlite3.connect(self._db_path)) as con, con:
            con.execute("""
                CREATE TABLE IF NOT EXISTS vad_jobs (
                    file_path TEXT PRIMARY KEY,
                    file_hash TEXT,
                    config_hash TEXT,
                    ranges JSON,
                    metadata JSON,
                    created_at REAL
                )
            """)

    def _load_model(
        self,
        use_onnx: Optional[bool] = None  # Backend to load; None means `self.config.use_onnx`
    ) -> None:
        """Lazily load the Silero model, reloading it if a different backend is requested.

        Raises:
            ImportError: If `silero-vad` is not installed.
        """
        want_onnx = self.config.use_onnx if use_onnx is None else use_onnx
        if self._model is not None and self._model_onnx == want_onnx:
            return
        if not SILERO_AVAILABLE:
            raise ImportError("silero-vad not installed.")
        self.logger.info("Loading Silero VAD model")
        self._model = load_silero_vad(onnx=want_onnx)
        self._model_onnx = want_onnx
        self.logger.info("Silero VAD model loaded successfully")

    def _load_audio(
        self,
        path: str,      # Path to audio file
        target_sr: int  # Target sampling rate
    ) -> Tuple[np.ndarray, float]:  # (audio array, duration in seconds)
        """Load audio as mono at `target_sr` and peak-normalize it to [-1, 1]."""
        # Use librosa to load (supports many formats)
        y, sr = librosa.load(path, sr=target_sr, mono=True)

        # Peak-normalize; skip silent or empty signals to avoid division by zero
        peak = np.max(np.abs(y)) if len(y) else 0.0
        if peak > 0:
            y = y / peak

        return y.astype(np.float32), float(len(y) / sr)

    @staticmethod
    def _file_fingerprint(
        path: str  # Path to the media file
    ) -> Optional[str]:  # Hex digest, or None if the file cannot be stat'ed
        """Return a cheap fingerprint (size + mtime) used to detect changed media."""
        try:
            st = os.stat(path)
        except OSError:
            return None
        return hashlib.sha256(f"{st.st_size}:{st.st_mtime_ns}".encode("utf-8")).hexdigest()

    def _read_cache(
        self,
        media_path: str,           # Cache key
        config_hash: str,          # Canonical JSON of the effective config
        file_hash: Optional[str]   # Current file fingerprint
    ) -> Optional[MediaAnalysisResult]:  # Cached result, or None on miss
        """Return the cached result if config and file fingerprint both still match."""
        if file_hash is None:
            # Without a fingerprint the cache cannot be validated; always recompute.
            return None
        with closing(sqlite3.connect(self._db_path)) as con:
            row = con.execute(
                "SELECT ranges, metadata, config_hash, file_hash FROM vad_jobs WHERE file_path = ?",
                (media_path,)
            ).fetchone()
        if row is None or row[2] != config_hash or row[3] != file_hash:
            return None
        self.logger.info(f"Using cached VAD result for {media_path}")
        return MediaAnalysisResult(
            ranges=[TimeRange(**r) for r in json.loads(row[0])],
            metadata=json.loads(row[1])
        )

    def _run_vad(
        self,
        media_path: str,              # Path to media file
        run_config: SileroVADConfig   # Effective configuration for this run
    ) -> Tuple[List[TimeRange], Dict[str, Any]]:  # (speech ranges, summary metadata)
        """Load audio, run Silero inference and build ranges plus summary metadata."""
        # Load the backend this run asked for (honors a per-run `use_onnx` override)
        self._load_model(run_config.use_onnx)

        self.logger.info(f"Processing audio: {media_path}")
        wav, duration = self._load_audio(media_path, run_config.sampling_rate)

        speech_timestamps = get_speech_timestamps(
            audio=wav,
            model=self._model,
            threshold=run_config.threshold,
            sampling_rate=run_config.sampling_rate,
            min_speech_duration_ms=run_config.min_speech_duration_ms,
            min_silence_duration_ms=run_config.min_silence_duration_ms,
            speech_pad_ms=run_config.speech_pad_ms,
            return_seconds=True,  # Critical: We want seconds, not samples
            visualize_probs=False
        )

        # get_speech_timestamps reports no per-segment probability; 1.0 is a
        # placeholder meaning "passed the threshold".
        ranges = [
            TimeRange(start=ts['start'], end=ts['end'], label="speech", confidence=1.0)
            for ts in speech_timestamps
        ]

        total_speech = sum(r.end - r.start for r in ranges)
        metadata = {
            "duration": duration,
            "sample_rate": run_config.sampling_rate,
            "total_speech": total_speech,
            # Clamp: timestamps returned in seconds may be rounded, so their sum
            # can slightly exceed the measured duration.
            "total_silence": max(0.0, duration - total_speech),
            "segment_count": len(ranges),
            "processed_at": time.time()
        }

        self.logger.info(f"Detected {len(ranges)} speech segments ({total_speech:.2f}s speech / {duration:.2f}s total)")
        return ranges, metadata

    def _write_cache(
        self,
        media_path: str,            # Cache key
        config_hash: str,           # Canonical JSON of the effective config
        file_hash: Optional[str],   # File fingerprint at processing time
        ranges: List[TimeRange],    # Detected speech ranges
        metadata: Dict[str, Any]    # Summary metadata
    ) -> None:
        """Insert or replace the cache row for `media_path`."""
        with closing(sqlite3.connect(self._db_path)) as con, con:
            con.execute(
                """
                INSERT OR REPLACE INTO vad_jobs
                (file_path, file_hash, config_hash, ranges, metadata, created_at)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (
                    media_path,
                    file_hash,
                    config_hash,
                    json.dumps([r.to_dict() for r in ranges]),
                    json.dumps(metadata),
                    time.time()
                )
            )

    def execute(
        self,
        media_path: Union[str, Path],  # Path to media file to analyze
        force: bool = False,           # If True, ignore cache and re-run
        **kwargs                       # Override config parameters for this run
    ) -> MediaAnalysisResult:  # Analysis result with detected speech segments
        """Run VAD on the audio file, reusing a cached result when it is still valid."""
        media_path = str(media_path)

        if self.config is None:
            # Allow execute() before an explicit initialize(): use default configuration.
            self.initialize()

        # Apply runtime overrides if any (merged into a temporary config for this run)
        if kwargs:
            run_config = dict_to_config(SileroVADConfig, {**config_to_dict(self.config), **kwargs})
        else:
            run_config = self.config

        # Canonical JSON of the effective config (not a digest). It is stored in
        # the `config_hash` column and compared verbatim on lookup.
        config_hash = json.dumps(config_to_dict(run_config), sort_keys=True)
        file_hash = self._file_fingerprint(media_path)

        # 1. Check Cache
        if not force:
            cached = self._read_cache(media_path, config_hash, file_hash)
            if cached is not None:
                return cached

        # 2. Process
        ranges, metadata = self._run_vad(media_path, run_config)

        # 3. Save to Cache
        self._write_cache(media_path, config_hash, file_hash, ranges, metadata)

        return MediaAnalysisResult(ranges=ranges, metadata=metadata)

    def is_available(self) -> bool:  # True if Silero VAD is available
        """Check if Silero VAD is available."""
        return SILERO_AVAILABLE

    def cleanup(self) -> None:
        """Release the loaded model so it can be garbage-collected."""
        if self._model is not None:
            self.logger.info("Unloading Silero VAD model")
            self._model = None
            self._model_onnx = None

## Testing the Plugin

In [None]:
# Smoke-test the plugin's identity and availability (no model is loaded here)
plugin = SileroVADPlugin()

# Collect each report line first, then emit them in order
report_lines = [
    f"Silero VAD available: {plugin.is_available()}",
    f"Plugin name: {plugin.name}",
    f"Plugin version: {plugin.version}",
    f"Supported media types: {plugin.supported_media_types}",
    f"Config class: {plugin.config_class.__name__}",
]
for line in report_lines:
    print(line)

Silero VAD available: True
Plugin name: silero-vad
Plugin version: 1.0.0
Supported media types: ['audio', 'video']
Config class: SileroVADConfig


In [None]:
# Test configuration dataclass: list each field's display title and default value
from dataclasses import fields

rows = [(fld.metadata.get(SCHEMA_TITLE, fld.name), fld.default) for fld in fields(SileroVADConfig)]

print("Configuration fields:")
for title, default in rows:
    print(f"  {title}: {default}")

Configuration fields:
  Threshold: 0.5
  Min Speech Duration (ms): 250
  Min Silence Duration (ms): 100
  Speech Padding (ms): 30
  Sampling Rate: 16000
  Use ONNX: True


In [None]:
# Test configuration validation: bound and enum violations must raise ValueError
test_configs = [
    ({"threshold": 0.5}, "Valid config"),
    ({"threshold": 1.5}, "Threshold out of range"),
    ({"sampling_rate": 44100}, "Invalid sampling rate"),
]

for overrides, description in test_configs:
    try:
        test_cfg = dict_to_config(SileroVADConfig, overrides, validate=True)
    except ValueError as err:
        print(f"{description}: Valid=False")
        print(f"  Error: {str(err)[:80]}")
    else:
        print(f"{description}: Valid=True")

Valid config: Valid=True
Threshold out of range: Valid=False
  Error: threshold: 1.5 is greater than maximum 1.0
Invalid sampling rate: Valid=False
  Error: sampling_rate: 44100 is not one of [8000, 16000]


In [None]:
#| eval: false
# Test get_config_schema for UI generation
import json
from itertools import islice

schema = plugin.get_config_schema()
print("JSON Schema for SileroVADConfig:")
print(f"  Name: {schema['name']}")
print(f"  Properties count: {len(schema['properties'])}")
print("\nSample properties:")
# Show only the first three properties (schema order follows dataclass field order)
print(json.dumps(dict(islice(schema['properties'].items(), 3)), indent=2))

JSON Schema for SileroVADConfig:
  Name: SileroVADConfig
  Properties count: 6

Sample properties:
{
  "threshold": {
    "type": "number",
    "title": "Threshold",
    "description": "Speech probability threshold (0.0 - 1.0). Higher values reduce false positives.",
    "minimum": 0.0,
    "maximum": 1.0,
    "default": 0.5
  },
  "min_speech_duration_ms": {
    "type": "integer",
    "title": "Min Speech Duration (ms)",
    "description": "Segments shorter than this will be ignored.",
    "minimum": 0,
    "default": 250
  },
  "min_silence_duration_ms": {
    "type": "integer",
    "title": "Min Silence Duration (ms)",
    "description": "Silence shorter than this will not split segments.",
    "minimum": 0,
    "default": 100
  }
}


In [None]:
#| hide
# Export all `#| export` cells of this notebook into the package module
import nbdev
nbdev.nbdev_export()