<a href="https://colab.research.google.com/github/kurtvalcorza/notebooks/blob/main/VibeVoice_ASR_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VibeVoice-ASR: Unified Speech-to-Text with Speaker Diarization

This notebook demonstrates [Microsoft's VibeVoice-ASR](https://huggingface.co/microsoft/VibeVoice-ASR), a 9B-parameter model that provides:

- **Who**: Speaker identification/diarization
- **When**: Precise timestamps
- **What**: Transcribed content
- **60-minute single-pass processing** with global context
- **Customizable hotwords** for domain-specific accuracy
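Each segment the notebook produces carries all three: a speaker label (who), start and end times in seconds (when), and the transcribed text (what). The sketch below is a minimal example that assumes segment dictionaries shaped like those built by `parse_transcription` in section 5. The sample values are made up. Given that shape, computing per-speaker talk time is a simple sum:

```python
from collections import defaultdict

# Hypothetical segments in the shape used later in this notebook
# (start/end in seconds, a speaker label, and the transcribed text).
segments = [
    {"start": 0.0, "end": 4.5, "speaker": "Speaker 0", "text": "Welcome to the call."},
    {"start": 4.5, "end": 7.0, "speaker": "Speaker 1", "text": "Thanks for having me."},
    {"start": 7.0, "end": 12.0, "speaker": "Speaker 0", "text": "Let's get started."},
]

def talk_time_by_speaker(segments):
    """Sum segment durations (end - start, in seconds) per speaker label."""
    totals = defaultdict(float)
    for seg in segments:
        # Clamp at zero in case a segment has end < start
        totals[seg["speaker"]] += max(0.0, float(seg["end"]) - float(seg["start"]))
    return dict(totals)

for speaker, seconds in talk_time_by_speaker(segments).items():
    print(f"{speaker}: {seconds:.1f} s")
```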

## Requirements

- **GPU Memory**: ~18-20 GB of VRAM for unquantized bfloat16 inference
- **Colab Runtime**: Use an **A100 GPU** (Colab Pro/Pro+) for best results. A T4 (16 GB) may work with 4-bit quantization, which the notebook enables automatically on GPUs with less than 20 GB.

To change the runtime: `Runtime` → `Change runtime type` → select `A100 GPU` or `T4 GPU`.
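The VRAM figures follow roughly from the parameter count. Here is a back-of-the-envelope sketch. It assumes the weights dominate memory use. Activations, the KV cache for long audio, and CUDA overhead all add to these numbers. In practice, 4-bit loading also keeps some layers in higher precision.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Approximate memory for model weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9

params = 9e9  # ~9B parameters, per the model description above
bf16_gb = weight_memory_gb(params, 16)  # bfloat16: 2 bytes per parameter
nf4_gb = weight_memory_gb(params, 4)    # 4-bit NF4: ~0.5 bytes per parameter (ignores quantization constants)

print(f"bf16 weights:  ~{bf16_gb:.1f} GB")
print(f"4-bit weights: ~{nf4_gb:.1f} GB")
```

This is why ~18 GB is the floor for bfloat16, while 4-bit weights leave room on a 16 GB card.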

## 1. Check GPU and Environment

In [None]:
import torch

if not torch.cuda.is_available():
    raise RuntimeError(
        "No GPU detected! Please enable GPU runtime:\n"
        "Runtime → Change runtime type → Hardware accelerator → GPU"
    )

gpu_name = torch.cuda.get_device_name(0)
gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9

print(f"GPU: {gpu_name}")
print(f"VRAM: {gpu_memory_gb:.1f} GB")

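# Decide on 4-bit quantization here so later cells can read USE_QUANTIZATION.
# Note: a "16 GB" T4 reports slightly less than 16 GB here, so it lands in the < 16 check too.
if gpu_memory_gb < 20:
    # Not enough headroom for bf16 weights (~18 GB): load 4-bit quantized weights instead
    USE_QUANTIZATION = True
    print("\n⚠️  Note: less than 20 GB VRAM. Using 4-bit quantization for memory efficiency.")
    if gpu_memory_gb < 16:
        print("   Long recordings may still run out of memory; an A100 (Colab Pro/Pro+) is safer.")
else:
    USE_QUANTIZATION = False
    print("\n✓ Sufficient VRAM for unquantized bfloat16 inference.")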

## 2. Install Dependencies

In [None]:
%%capture
# Install core dependencies (%%capture hides all output, including install errors)
!pip install -U transformers accelerate bitsandbytes
!pip install soundfile librosa
!apt-get update -qq && apt-get install -y -qq ffmpeg

# Install flash-attention for faster attention on Ampere+ GPUs (compute capability >= 8, e.g. A100).
# Building from source can take a long time if no prebuilt wheel matches the environment.
import torch
if torch.cuda.get_device_capability()[0] >= 8:
    !pip install flash-attn --no-build-isolation -q
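Because `%%capture` hides all installer output, a failed `flash-attn` build or a missing package is easy to overlook. The small sketch below is an addition to the original setup. It reports which modules Python can locate without importing them. The module names are the assumed import names for the packages installed above.

```python
import importlib.util

def importable(module_names):
    """Map each top-level module name to True if Python can locate it (without importing it)."""
    return {name: importlib.util.find_spec(name) is not None for name in module_names}

status = importable(["transformers", "accelerate", "bitsandbytes", "librosa", "soundfile", "flash_attn"])
for name, ok in status.items():
    print(f"{'✓' if ok else '✗'} {name}")
```

On a T4, `flash_attn` showing ✗ is expected, because the install step only runs on Ampere+ GPUs.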

In [None]:
# Clone VibeVoice repository for the processor and utilities
!git clone https://github.com/microsoft/VibeVoice.git /content/VibeVoice 2>/dev/null || echo "Repository already cloned"
!pip install -e "/content/VibeVoice[asr]" -q

## 3. Mount Google Drive (Optional)

Mount your Google Drive to access audio files stored there.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

## 4. Load VibeVoice-ASR Model

In [None]:
import sys
import torch
from transformers.utils import is_flash_attn_2_available

# Make the cloned repository importable (redundant after `pip install -e`, but harmless)
sys.path.insert(0, '/content/VibeVoice')

from vibevoice.modular.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration
from vibevoice.processor.vibevoice_asr_processor import VibeVoiceASRProcessor

MODEL_ID = "microsoft/VibeVoice-ASR"

# Reuse the quantization decision from section 1; recompute it if that cell was skipped
try:
    use_quant = USE_QUANTIZATION
except NameError:
    use_quant = torch.cuda.get_device_properties(0).total_memory / 1e9 < 20

# FlashAttention-2 needs an Ampere+ GPU (compute capability >= 8) AND a working flash-attn
# install; otherwise fall back to PyTorch's scaled-dot-product attention (SDPA)
gpu_capability = torch.cuda.get_device_capability()[0]
if gpu_capability >= 8 and is_flash_attn_2_available():
    attn_impl = "flash_attention_2"
else:
    attn_impl = "sdpa"

# Pre-Ampere GPUs (e.g. T4) lack native bfloat16 support, so compute in float16 there
compute_dtype = torch.bfloat16 if gpu_capability >= 8 else torch.float16

print(f"Loading model: {MODEL_ID}")
print(f"Using 4-bit quantization: {use_quant}")
print(f"Attention implementation: {attn_impl}")
print(f"Compute dtype: {compute_dtype}")

# Load processor (uses Qwen2.5-7B as the base language model)
processor = VibeVoiceASRProcessor.from_pretrained(
    MODEL_ID,
    language_model_pretrained_name="Qwen/Qwen2.5-7B"
)

# Load model
if use_quant:
    from transformers import BitsAndBytesConfig
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4"
    )
    model = VibeVoiceASRForConditionalGeneration.from_pretrained(
        MODEL_ID,
        quantization_config=quantization_config,
        dtype=compute_dtype,  # dtype for the layers that stay unquantized
        device_map="auto",
        trust_remote_code=True,
        attn_implementation=attn_impl
    )
else:
    model = VibeVoiceASRForConditionalGeneration.from_pretrained(
        MODEL_ID,
        dtype=compute_dtype,
        device_map="auto",
        trust_remote_code=True,
        attn_implementation=attn_impl
    )

print("\n✓ Model loaded successfully!")

## 5. Transcription Functions

In [None]:
import librosa
import numpy as np
import re

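def load_audio(audio_path, target_sr=16000):
    """Load an audio file as mono and resample it to target_sr (Hz)."""
    # 16 kHz is this notebook's choice; confirm the sampling rate VibeVoiceASRProcessor
    # expects, since a mismatch can degrade accuracy.
    audio, sr = librosa.load(audio_path, sr=target_sr, mono=True)
    return audio, sr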

def transcribe(
    audio_path,
    hotwords=None,
    max_new_tokens=None,  # Auto-calculated if None
    temperature=0.0,
):
    """
    Transcribe audio file with speaker diarization and timestamps.
    
    Args:
        audio_path: Path to audio file
        hotwords: Optional list of domain-specific terms/names to improve accuracy
        max_new_tokens: Maximum tokens to generate (auto-calculated based on duration if None)
        temperature: Sampling temperature (0.0 for deterministic)
    
    Returns:
        Dictionary with raw text and parsed segments
    """
    # Load audio
    audio_data, sr = load_audio(audio_path)
    duration = len(audio_data) / sr
    print(f"Audio duration: {duration:.1f} seconds")
    
    if duration > 3600:
        print("⚠️  Warning: Audio exceeds 60 minutes. Results may be truncated.")
    
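    # Auto-calculate max_new_tokens from the duration if not specified.
    # Estimate: ~3 words/s x 1.5 tokens/word x 2 (buffer for timestamps/speaker tags) ≈ 9 tokens/s,
    # rounded up to 10 tokens/s plus 256, and capped at 8192 (reached after ~13 minutes of audio).
    if max_new_tokens is None:
        max_new_tokens = max(256, min(int(duration * 10) + 256, 8192))
        print(f"Auto max_new_tokens: {max_new_tokens}")
        if max_new_tokens == 8192:
            print("⚠️  Token cap reached; pass a larger max_new_tokens for long audio to avoid truncation.")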
    
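    # Join optional hotwords into one context string.
    # NOTE: hotword_str is not passed to the processor call below, so hotwords currently have no
    # effect. Check VibeVoiceASRProcessor in /content/VibeVoice for the argument that accepts them.
    hotword_str = ", ".join(hotwords) if hotwords else None
    if hotword_str:
        print(f"⚠️  Hotwords ({hotword_str}) are not yet passed to the model; see the note in transcribe().")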
    
    # Prepare inputs using the VibeVoice processor
    inputs = processor(
        audio=audio_data,
        sampling_rate=sr,
        return_tensors="pt",
        padding=True,
        add_generation_prompt=True,
    )
    
    # Move inputs to model device
    device = next(model.parameters()).device
    inputs = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}
    
    # Generate transcription
    print("Generating transcription...")
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature if temperature > 0 else None,
            do_sample=temperature > 0,
            pad_token_id=processor.tokenizer.eos_token_id,
        )
    
    # Decode only the newly generated tokens (skip the prompt)
    input_len = inputs['input_ids'].shape[1]
    generated_ids = outputs[0][input_len:]
    
    # Truncate at the first EOS token, if any
    eos_token_id = processor.tokenizer.eos_token_id
    eos_positions = (generated_ids == eos_token_id).nonzero(as_tuple=True)[0]
    if len(eos_positions) > 0:
        generated_ids = generated_ids[:eos_positions[0]]
    
    transcription = processor.tokenizer.decode(generated_ids, skip_special_tokens=True)
    
    # Post-process using VibeVoice's built-in function if available
    try:
        from vibevoice.processor.vibevoice_asr_processor import post_process_transcription
        processed = post_process_transcription(transcription)
        segments = processed.get("segments", [])
        raw_text = processed.get("text", transcription)
    except (ImportError, AttributeError):
        segments = parse_transcription(transcription)
        raw_text = transcription
    
    return {
        "raw_text": raw_text,
        "segments": segments,
        "duration": duration
    }

def parse_transcription(text):
    """
    Parse VibeVoice-ASR output into structured segments.
    Accepts either a JSON array of segment objects or lines like "[start-end] Speaker: text".
    """
    import json

    segments = []
    
    # Try JSON format first (only a complete JSON array; truncated JSON falls through to the regex)
    if text.strip().startswith('['):
        try:
            data = json.loads(text)
            for item in data:
                segments.append({
                    "start": float(item.get("Start", item.get("start", 0))),
                    "end": float(item.get("End", item.get("end", 0))),
                    "speaker": f"Speaker {item.get('Speaker', item.get('speaker', 0))}",
                    "text": item.get("Content", item.get("content", item.get("text", "")))
                })
            return segments
        except (TypeError, AttributeError, ValueError):
            # ValueError covers json.JSONDecodeError and non-numeric timestamps
            segments = []
    
    # Fallback to timestamp pattern matching: [start-end] Speaker: text
    pattern = r'\[([\d.]+)[-–]([\d.]+)\]\s*(Speaker\s*\d+|[^:]+):\s*(.+?)(?=\[[\d.]|$)'
    
    for match in re.finditer(pattern, text, re.DOTALL):
        segments.append({
            "start": float(match.group(1)),
            "end": float(match.group(2)),
            "speaker": match.group(3).strip(),
            "text": match.group(4).strip()
        })
    
    return segments

def print_transcript(result, show_timestamps=True):
    """Pretty print transcription result."""
    print("\n" + "="*60)
    print("TRANSCRIPTION")
    print("="*60 + "\n")
    
    if result["segments"]:
        for seg in result["segments"]:
            if show_timestamps:
                # float() guards against timestamps returned as strings
                start = float(seg.get('start', seg.get('start_time', 0)))
                end = float(seg.get('end', seg.get('end_time', 0)))
                speaker = seg.get('speaker', 'Unknown')
                text = seg.get('text', '')
                print(f"[{start:.2f}-{end:.2f}] {speaker}: {text}")
            else:
                print(f"{seg.get('speaker', 'Unknown')}: {seg.get('text', '')}")
            print()
    else:
        # Fallback to raw text if parsing failed
        print(result["raw_text"])
    
    print("="*60)

def save_transcript(result, output_path, show_timestamps=True):
    """Save transcription to text file."""
    with open(output_path, "w", encoding="utf-8") as f:
        if result["segments"]:
            for seg in result["segments"]:
                # float() guards against timestamps returned as strings
                start = float(seg.get('start', seg.get('start_time', 0)))
                end = float(seg.get('end', seg.get('end_time', 0)))
                speaker = seg.get('speaker', 'Unknown')
                text = seg.get('text', '')
                if show_timestamps:
                    f.write(f"[{start:.2f}-{end:.2f}] {speaker}: {text}\n")
                else:
                    f.write(f"{speaker}: {text}\n")
        else:
            f.write(result["raw_text"])
    print(f"Transcript saved to: {output_path}")
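The standalone sketch below shows what the regex fallback in `parse_transcription` accepts. It runs the same pattern on a made-up string in the `[start-end] Speaker: text` layout. The sample text is hypothetical, not real model output.

```python
import re

# Same pattern as the fallback in parse_transcription above
PATTERN = r'\[([\d.]+)[-–]([\d.]+)\]\s*(Speaker\s*\d+|[^:]+):\s*(.+?)(?=\[[\d.]|$)'

sample = "[0.00-3.20] Speaker 1: Hello there. [3.20-5.10] Speaker 2: Hi, good to see you."

segments = [
    {
        "start": float(m.group(1)),
        "end": float(m.group(2)),
        "speaker": m.group(3).strip(),
        "text": m.group(4).strip(),
    }
    for m in re.finditer(PATTERN, sample, re.DOTALL)
]
for seg in segments:
    print(seg)
```

The lazy `(.+?)` stops at the next `[` followed by a digit or a dot, so spoken text that itself contains such a bracket would be split early.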

## 6. Transcribe Audio

Set the path to your audio file and optional hotwords.
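The `hotwords` field takes a single comma-separated string, and the configuration cell below splits it into a list. The following sketch performs the same split. It also drops case-insensitive duplicates while preserving order, which is an addition and not part of the notebook's cell.

```python
def parse_hotwords(raw):
    """Split a comma-separated string into a clean list; return None if nothing is left."""
    seen = set()
    words = []
    for item in raw.split(","):
        word = item.strip()
        # Skip empty entries and case-insensitive duplicates, keeping the first spelling
        if word and word.lower() not in seen:
            seen.add(word.lower())
            words.append(word)
    return words or None

print(parse_hotwords("VibeVoice, Qwen2.5, , vibevoice, Colab"))
```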

In [None]:
# Configuration
audio_file = "/content/drive/MyDrive/audio.wav"  #@param {type: "string"}
hotwords = ""  #@param {type: "string"}
save_output = True  #@param {type: "boolean"}

# Parse hotwords
hotword_list = [h.strip() for h in hotwords.split(",") if h.strip()] if hotwords else None

# Run transcription
result = transcribe(
    audio_path=audio_file,
    hotwords=hotword_list
)

# Display results
print_transcript(result)

# Save to file
if save_output:
    import os
    output_path = os.path.splitext(audio_file)[0] + "_vibevoice_transcript.txt"
    save_transcript(result, output_path)
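`transcribe` sets `max_new_tokens` to roughly 10 tokens per second of audio, plus 256, with a cap of 8192. The sketch below reproduces that heuristic to show where the cap starts to bind. This matters for the 60-minute use case. Note that the 10 tokens/s rate is the notebook's own rough estimate, not a measured figure.

```python
def auto_max_new_tokens(duration_s, tokens_per_s=10, base=256, cap=8192):
    """Mirror of the notebook's heuristic: ~10 tokens/s of audio + 256, clamped to [256, cap]."""
    return max(base, min(int(duration_s * tokens_per_s) + base, cap))

# Longest audio whose estimate still fits under the cap
max_covered_s = (8192 - 256) / 10
print(f"Cap reached at ~{max_covered_s / 60:.1f} minutes of audio")

for minutes in (1, 10, 30, 60):
    print(f"{minutes:>2} min -> {auto_max_new_tokens(minutes * 60)} tokens")
```

For recordings longer than about 13 minutes, pass `max_new_tokens` to `transcribe` explicitly. Keep the value within what the model's context length can support.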

## 7. Batch Processing (Multiple Files)

Process multiple audio files from a directory.

In [None]:
import os
from pathlib import Path

input_dir = "/content/drive/MyDrive/audio_files"  #@param {type: "string"}
output_dir = "/content/drive/MyDrive/transcripts"  #@param {type: "string"}
audio_extensions = [".wav", ".mp3", ".m4a", ".flac", ".ogg"]  # Supported formats

# Create output directory
os.makedirs(output_dir, exist_ok=True)

# Find audio files (extension match is case-insensitive; sorted for a stable order)
audio_files = sorted(
    p for p in Path(input_dir).glob("*")
    if p.is_file() and p.suffix.lower() in audio_extensions
)

print(f"Found {len(audio_files)} audio files to process\n")

# Process each file
for i, audio_path in enumerate(audio_files, 1):
    print(f"\n[{i}/{len(audio_files)}] Processing: {audio_path.name}")
    print("-" * 40)
    
    try:
        result = transcribe(str(audio_path))
        
        # Save transcript
        output_path = os.path.join(output_dir, f"{audio_path.stem}_transcript.txt")
        save_transcript(result, output_path)
        
    except Exception as e:
        print(f"Error processing {audio_path.name}: {e}")
        continue

print(f"\n\nCompleted! Transcripts saved to: {output_dir}")

## Additional Notes

### Model Capabilities
- Processes up to 60 minutes of audio in a single pass
- Maintains consistent speaker tracking across long recordings
- Supports English and Chinese

### Hotwords
Use hotwords to improve accuracy for:
- Names of people, products, or companies
- Technical terminology
- Domain-specific vocabulary

### Resources
- [VibeVoice GitHub](https://github.com/microsoft/VibeVoice)
- [Model on Hugging Face](https://huggingface.co/microsoft/VibeVoice-ASR)
- [Live Demo](https://aka.ms/vibevoice-asr)