## \[Research\] Audio transcription with `faster-whisper`

This notebook demonstrates the audio processing pipeline, focusing on transcribing audio with `faster-whisper`.

### 1. Importing

In [None]:
import justsdk
import warnings
import pandas as pd
import matplotlib.pyplot as plt

from pathlib import Path
from faster_whisper import WhisperModel
from datetime import timedelta
from IPython.display import Audio, display

warnings.filterwarnings("ignore")

# Project layout, anchored at the parent of the current working directory.
ROOT = Path.cwd().parent
MODEL_DIR = ROOT.joinpath("models")
DATA_DIR = ROOT.joinpath("data")
RAW_DIR = DATA_DIR.joinpath("raw")
SAMPLE_DIR = DATA_DIR.joinpath("sample", "audio")

### 2. Setting configs

In [None]:
WHISPER_MODEL = "base"

# Directory handed to the model loader as its download root; one folder per
# model size.
_DOWNLOAD_ROOT = MODEL_DIR / f"whisper-{WHISPER_MODEL}"

# Settings used when constructing the Whisper model.
WHISPER_CONFIG = {
    "model_size": WHISPER_MODEL,
    "device": "cpu",
    "compute_type": "int8",
    "num_workers": 2,
    "download_root": str(_DOWNLOAD_ROOT),
}

# Temperatures passed to the decoder as a fallback sequence.
_TEMPERATURE_FALLBACK = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]

# Voice-activity-detection settings, passed alongside "vad_filter".
_VAD_PARAMETERS = {
    "threshold": 0.5,
    "min_speech_duration_ms": 250,
    "max_speech_duration_s": float("inf"),
    "min_silence_duration_ms": 2000,
    "speech_pad_ms": 400,
}

# Per-call transcription options; every key here is forwarded to the
# transcribe call.
AUDIO_CONFIG = {
    "language": "en",
    "task": "transcribe",
    "beam_size": 5,  # Beam width explored during decoding
    "best_of": 5,
    "patience": 1,
    "length_penalty": 1,
    "temperature": _TEMPERATURE_FALLBACK,
    "compression_ratio_threshold": 2.4,  # Reject text that is too repetitive
    "log_prob_threshold": -1.0,  # Confidence cut-off (average log probability)
    "no_speech_threshold": 0.6,  # Cut-off for non-speech detection
    "word_timestamps": True,  # Also produce word-level timestamps
    "vad_filter": True,  # Skip silent parts
    "vad_parameters": _VAD_PARAMETERS,
}

justsdk.print_info(f"WHISPER_MODEL: {WHISPER_MODEL}")
justsdk.print_info(f"DEVICE: {WHISPER_CONFIG['device']}")

### 3. Locating sample audio files

In [None]:
# Gather every sample file, grouped by extension in the order listed here.
sample_audio = []
audio_ext = [".wav", ".mp3", ".mp4"]

for ext in audio_ext:
    # Path.glob yields matches in arbitrary order; sorting each group keeps
    # the listing, and which file ends up first, stable across runs.
    sample_audio.extend(sorted(SAMPLE_DIR.glob(f"*{ext}")))

# Check the sample audio size in MB
justsdk.print_info("sample_audio:")
for audio in sample_audio:
    size_mb = audio.stat().st_size / (1024 * 1024)
    print(f"  {audio.name}: {size_mb:.2f} MB")

### 4. Initializing `faster-whisper`

In [None]:
# Reset first so a failed load never leaves `model` undefined, or still bound
# to a model loaded by an earlier run of this cell.
model = None
try:
    model = WhisperModel(
        WHISPER_CONFIG["model_size"],
        device=WHISPER_CONFIG["device"],
        compute_type=WHISPER_CONFIG["compute_type"],
        download_root=WHISPER_CONFIG["download_root"],
        num_workers=WHISPER_CONFIG["num_workers"],
    )
except Exception as e:
    # Best-effort: report the failure and leave `model` as None.
    justsdk.print_error(f"Failed to load Whisper model: {e}")
else:
    justsdk.print_success("Whisper model loaded successfully.")

### 5. Functions to transcribe audio

In [None]:
def format_timestamp(seconds: float) -> str:
    """Convert seconds to HH:MM:SS.mmm format.

    The value is rounded to whole milliseconds before it is split into
    fields, so a value such as 59.9996 carries into the minutes field
    ("00:01:00.000") instead of rendering an invalid "60.000" seconds field.

    Args:
        seconds: Duration or offset in seconds. Negative values are rendered
            with a leading "-".

    Returns:
        The formatted timestamp. Hours are zero-padded to at least two digits
        and are not wrapped at 24.
    """
    total_ms = int(round(seconds * 1000))
    sign = "-" if total_ms < 0 else ""
    hours, remainder = divmod(abs(total_ms), 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, millis = divmod(remainder, 1000)
    return f"{sign}{hours:02d}:{minutes:02d}:{secs:02d}.{millis:03d}"


def transcribe_audio(audio_path: Path, model: WhisperModel, config: dict) -> dict:
    """Transcribe one audio file and gather segment- and word-level details.

    Args:
        audio_path: Audio file to transcribe; passed to the model as a string.
        model: Object whose ``transcribe`` method performs the transcription.
        config: Decoding options; must contain every key listed in
            ``option_names`` below.

    Returns:
        A dict with the keys ``file``, ``num_segments``, ``segments``,
        ``num_words``, ``words``, ``full_text`` and ``audio_info``. If any
        step raises, the error is reported through ``justsdk.print_error``
        and ``None`` is returned instead.
    """
    # Options forwarded verbatim from `config` to `model.transcribe`.
    option_names = (
        "language",
        "task",
        "beam_size",
        "best_of",
        "patience",
        "length_penalty",
        "temperature",
        "compression_ratio_threshold",
        "log_prob_threshold",
        "no_speech_threshold",
        "word_timestamps",
        "vad_filter",
        "vad_parameters",
    )

    try:
        run_transcription = model.transcribe
        source = str(audio_path)
        options = {name: config[name] for name in option_names}
        segment_iter, info = run_transcription(source, **options)

        segment_rows = []
        word_rows = []

        for seg in segment_iter:
            begin, finish = seg.start, seg.end
            segment_rows.append(
                {
                    "id": seg.id,
                    "start": begin,
                    "end": finish,
                    "text": seg.text.strip(),
                    "tokens": seg.tokens,
                    "temperature": seg.temperature,
                    "avg_logprob": seg.avg_logprob,
                    "compression_ratio": seg.compression_ratio,
                    "no_speech_prob": seg.no_speech_prob,
                    "start_formatted": format_timestamp(begin),
                    "end_formatted": format_timestamp(finish),
                    "duration": finish - begin,
                }
            )

            # Segments may carry no word list (missing or empty attribute).
            word_rows.extend(
                {
                    "word": item.word,
                    "start": item.start,
                    "end": item.end,
                    "probability": item.probability,
                    "start_formatted": format_timestamp(item.start),
                    "end_formatted": format_timestamp(item.end),
                }
                for item in (getattr(seg, "words", None) or ())
            )

        joined_text = " ".join(row["text"] for row in segment_rows)

        return {
            "file": audio_path.name,
            "num_segments": len(segment_rows),
            "segments": segment_rows,
            "num_words": len(word_rows),
            "words": word_rows,
            "full_text": joined_text.strip(),
            "audio_info": {
                "language": info.language,
                "language_probability": info.language_probability,
                "duration": info.duration,
                "duration_formatted": format_timestamp(info.duration),
            },
        }
    except Exception as e:
        justsdk.print_error(f"Failed to transcribe {audio_path.name}: {e}")
        return None

### 6. Transcribing sample audio

In [None]:
# Transcribe the first discovered sample, if there is one, and embed a player
# for it in the notebook output.
transcription = None
if not sample_audio:
    justsdk.print_error("No sample audio files found to transcribe.")
else:
    selected_audio = sample_audio[0]
    justsdk.print_info(f"Transcribing audio: {selected_audio.name}")
    display(Audio(str(selected_audio)))
    transcription = transcribe_audio(selected_audio, model, AUDIO_CONFIG)

if not transcription:
    justsdk.print_error("Transcription failed or no audio processed.")
else:
    # Overall summary first ...
    justsdk.print_success(f"Transcription for {transcription['file']} completed.")
    print(f"  Total segments: {transcription['num_segments']}")
    print(f"  Total words: {transcription['num_words']}")
    print(f"{'=' * 10} Full Text {'=' * 10}")
    print(transcription["full_text"])
    print(f"{'=' * 31}")
    print(f"\n  Audio duration: {transcription['audio_info']['duration_formatted']}")

    # ... then one report per segment, when there are any.
    if transcription["segments"]:
        segments_df = pd.DataFrame(transcription["segments"])
        for _row_label, seg in segments_df.iterrows():
            justsdk.print_info(
                f"  [{seg['start_formatted']} -> {seg['end_formatted']}] ({seg['duration']:.2f}s)",
                newline_before=True
            )
            print(f"{'=' * 10} Text {'=' * 10}")
            print(seg["text"])
            print(f"{'=' * 26}")
            print(f"  Confidence: {seg['avg_logprob']:.4f}")
            print(f"  Non speech probability: {seg['no_speech_prob']:.4f}")

### 7. Visualizing the metrics

In [None]:
# Plot per-segment metrics in a 2x2 grid and print summary statistics.
if transcription and transcription["segments"]:
    segments_df = pd.DataFrame(transcription["segments"])

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Segment durations
    axes[0, 0].set_title("Segment Durations")
    axes[0, 0].bar(range(len(segments_df)), segments_df["duration"])
    axes[0, 0].set_xlabel("Segment ID")
    axes[0, 0].set_ylabel("Duration (s)")

    # Confidence scores: the plotted value is the negated average log
    # probability of each segment.
    axes[0, 1].set_title(
        "Transcription Confidence by Segment (Negative Log Probability)"
    )
    axes[0, 1].plot(segments_df.index, -segments_df["avg_logprob"], "o-")
    axes[0, 1].set_xlabel("Segment ID")
    axes[0, 1].set_ylabel("Confidence Score")

    # Words per segment (whitespace-separated tokens of the segment text)
    words_per_segment = segments_df["text"].apply(lambda x: len(x.split()))
    axes[1, 0].set_title("Words per Segment")
    axes[1, 0].bar(range(len(segments_df)), words_per_segment)
    axes[1, 0].set_xlabel("Segment ID")
    axes[1, 0].set_ylabel("Number of Words")

    # Non-speech probability
    axes[1, 1].set_title("Non-Speech Probability by Segment")
    axes[1, 1].scatter(
        segments_df.index,
        segments_df["no_speech_prob"],
        c=segments_df["no_speech_prob"],
        cmap="RdYlGn_r",
    )
    axes[1, 1].set_xlabel("Segment ID")
    axes[1, 1].set_ylabel("Non-Speech Probability")
    # Draw the reference line at the threshold the transcription was
    # configured with, rather than a separate hard-coded value.
    no_speech_threshold = AUDIO_CONFIG["no_speech_threshold"]
    axes[1, 1].axhline(
        y=no_speech_threshold,
        color="r",
        linestyle="--",
        alpha=0.5,
        label=f"Threshold ({no_speech_threshold})",
    )
    # Without a legend the line's label is never shown.
    axes[1, 1].legend()

    plt.tight_layout()
    plt.show()

    justsdk.print_info("Summary statistics:")
    print(f"  Average segment duration(s): {segments_df['duration'].mean():.2f}")
    print(f"  Average confidence score: {-segments_df['avg_logprob'].mean():.4f}")
    print(f"  Average words per segment: {words_per_segment.mean():.2f}")
    print(
        f"  Average non-speech probability: {segments_df['no_speech_prob'].mean():.4f}"
    )
    print(
        f"  Total audio duration(s): {transcription['audio_info']['duration_formatted']}"
    )