## Audio transcription + Speaker diarization

1. Transcribe audio using `faster-whisper`.

2. Perform speaker diarization using `pyannote-audio`.

3. Combine the results from both steps to create a structured output.

### 1. Importing libraries

In [None]:
import warnings
import justsdk
import pandas as pd
import numpy as np
import torch

from pathlib import Path
from datetime import timedelta
from faster_whisper import WhisperModel
from IPython.display import Audio, display
from pyannote.audio import Pipeline
from typing import Optional
from _constants import HF_TOKEN, SAMPLE_DIR, MODEL_DIR
from pyannote.audio.pipelines.utils.hook import ProgressHook

### 2. Configurations

In [None]:
# Silence every Python warning for the rest of the session. This keeps the
# notebook output readable but also hides deprecation notices.
warnings.filterwarnings("ignore")

# File name (not a full path) of the sample to process; it is used as the
# lookup key into the discovered sample files.
TARGET_SAMPLE = "project-proposal.mp3"

# Whisper checkpoint size; also used to name the download directory below.
WHISPER_MODEL = "base"

# Constructor arguments for the faster-whisper `WhisperModel`.
WHISPER_CONFIG = {
    "model_size": WHISPER_MODEL,
    "device": "cpu",
    "compute_type": "int8",
    "num_workers": 2,
    "download_root": str(MODEL_DIR / f"whisper-{WHISPER_MODEL}"),
}

# Keyword arguments unpacked into `WhisperModel.transcribe(...)`.
AUDIO_CONFIG = {
    "language": "en",
    "task": "transcribe",
    "beam_size": 5,  # Number of candidate paths searched during decoding
    "best_of": 5,  # Number of candidates when sampling with temperature > 0
    "patience": 1,  # Beam-search patience factor
    "length_penalty": 1,
    "temperature": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],  # Temperature fallback
    "compression_ratio_threshold": 2.4,  # Reject if text is too repetitive
    "log_prob_threshold": -1.0,  # Threshold for confidence levels
    "no_speech_threshold": 0.6,  # Threshold for non-speech detection
    "word_timestamps": True,  # Generate word-level timestamps
    "vad_filter": True,  # Skip silent parts
    # Tuning knobs for the voice-activity-detection filter enabled above.
    "vad_parameters": {
        "threshold": 0.5,
        "min_speech_duration_ms": 250,
        "max_speech_duration_s": float("inf"),
        "min_silence_duration_ms": 2000,
        "speech_pad_ms": 400,
    },
}

# Settings for the speaker-diarization step.
# NOTE(review): in the code visible here only "num_speakers", "min_speakers"
# and "max_speakers" are read (by `perform_speaker_diarization`);
# "segmentation_onset" and "clustering" are not referenced anywhere - confirm
# whether they are meant to be applied to the pipeline.
DIARIZATION_CONFIG = {
    "num_speakers": None,
    "min_speakers": None,
    "max_speakers": None,
    "segmentation_onset": 0.5,  # XXX: Learn more about this threshold
    "clustering": {  # XXX: Learn more about clustering options
        "method": "centroid",
        "min_cluster_size": 15,
        "threshold": 0.7,
    },
}

### 3. Helper: General

In [None]:
def format_timestamp(seconds: float) -> str:
    """Convert seconds to `HH:MM:SS.mm` format.

    The value is rounded to hundredths of a second *before* it is split into
    fields, so a rounding carry propagates correctly (59.999 becomes
    `00:01:00.00`, never `00:00:60.00`).

    Args:
        seconds: Offset in seconds. Negative values are clamped to zero.

    Returns:
        The zero-padded timestamp; hours grow beyond two digits if needed.

    Raises:
        ValueError: If `seconds` is NaN.
        OverflowError: If `seconds` is infinite.
    """
    if seconds < 0:
        return "00:00:00.00"
    # Work in integer centiseconds so every field is derived from one rounded
    # value and no field can be rounded up to an out-of-range number.
    total_centis = int(round(float(seconds) * 100))
    total_secs, centis = divmod(total_centis, 100)
    total_minutes, secs = divmod(total_secs, 60)
    hours, minutes = divmod(total_minutes, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}.{centis:02d}"

### 4. Get sample files

In [None]:
def get_sample_files() -> dict:
    """Collect the `.mp3` / `.mp4` samples stored under `SAMPLE_DIR`.

    Only the `audio` and `video` sub-directories are searched, and the search
    is not recursive.

    Returns:
        Mapping of file name to a dict holding the sample's `path` and its
        `size_mb` (size on disk divided by 1024 * 1024). When two files share
        a name, the one found later replaces the earlier entry.
    """
    bytes_per_mb = 1024 * 1024
    catalog = {}
    for folder in ("audio", "video"):
        for suffix in (".mp3", ".mp4"):
            for found in (SAMPLE_DIR / folder).glob(f"*{suffix}"):
                catalog[found.name] = {
                    "path": found,
                    "size_mb": found.stat().st_size / bytes_per_mb,
                }
    return catalog


# Discover the available samples once; later cells look files up in
# `sample_input` by file name.
sample_input = get_sample_files()
justsdk.print_info("Sample files found:")
# One line per sample: name, size in MB (two decimals) and full path.
for name, info in sample_input.items():
    print(f"  {name} ({info['size_mb']:.2f} MB): {info['path']}")

### 5. Init `faster-whisper`

In [None]:
# Bind the name up front so that later cells see `None` (and can report a
# readable error) instead of raising NameError when loading fails.
whisper_model: Optional[WhisperModel] = None
try:
    whisper_model = WhisperModel(
        model_size_or_path=WHISPER_CONFIG["model_size"],
        device=WHISPER_CONFIG["device"],
        compute_type=WHISPER_CONFIG["compute_type"],
        num_workers=WHISPER_CONFIG["num_workers"],
        download_root=WHISPER_CONFIG["download_root"],
    )
except Exception as e:
    # Notebook boundary: report the problem and keep the session alive.
    justsdk.print_error(f"Failed to load Whisper model: {e}")
else:
    justsdk.print_success("Loaded Whisper model.")

### 6. Helper: Transcription

In [None]:
def transcribe_audio(
    audio_path: Path, model: WhisperModel, config: dict
) -> Optional[dict]:
    """Transcribe one audio file and gather the result into a plain dict.

    Args:
        audio_path: Location of the audio (or video) file to transcribe.
        model: Loaded faster-whisper model.
        config: Keyword arguments forwarded to `model.transcribe`.

    Returns:
        A dict with the keys `file`, `audio_info`, `num_segments`,
        `segments`, `num_words`, `words` and `full_text`, or `None` when
        transcription fails (the error is printed, not raised).
    """
    try:
        segments_iter, info = model.transcribe(str(audio_path), **config)

        processed_segments = []
        word_segments = []

        # The segments come from a generator, so errors can surface while
        # iterating; the loop therefore has to stay inside the `try`.
        for segment in segments_iter:
            processed_segments.append(_process_segment(segment))

            # Word-level data is only present when word timestamps are on.
            words = getattr(segment, "words", None)
            if words:
                word_segments.extend(_process_words(words))

        return {
            "file": audio_path.name,
            "audio_info": _process_audio_info(info),
            "num_segments": len(processed_segments),
            "segments": processed_segments,
            "num_words": len(word_segments),
            "words": word_segments,
            "full_text": " ".join(seg["text"] for seg in processed_segments),
        }
    except Exception as e:
        justsdk.print_error(f"Failed to transcribe {audio_path.name}: {e}")
        return None


def _process_segment(segment) -> dict:
    """Flatten one transcription segment into a plain dict.

    Copies the segment's id, timing, stripped text and decoding statistics,
    and adds formatted start/end timestamps plus the segment duration.
    """
    start, end = segment.start, segment.end
    record = {
        "id": segment.id,
        "start": start,
        "end": end,
        "text": segment.text.strip(),
    }
    # Decoding statistics are copied over unchanged, in this key order.
    for field in (
        "tokens",
        "temperature",
        "avg_logprob",
        "no_speech_prob",
        "compression_ratio",
    ):
        record[field] = getattr(segment, field)
    record["start_formatted"] = format_timestamp(start)
    record["end_formatted"] = format_timestamp(end)
    record["duration"] = end - start
    return record


def _process_words(words) -> list:
    """Turn word-level results into a list of plain dicts.

    Each dict carries the word text, its start/end times, its probability and
    the formatted start/end timestamps.
    """
    records = []
    for item in words:
        begin, finish = item.start, item.end
        records.append(
            {
                "word": item.word,
                "start": begin,
                "end": finish,
                "probability": item.probability,
                "start_formatted": format_timestamp(begin),
                "end_formatted": format_timestamp(finish),
            }
        )
    return records


def _process_audio_info(info) -> dict:
    """Summarise the transcription info object as a plain dict.

    Keeps the detected language, its probability and the audio duration, and
    adds the duration as a formatted timestamp.
    """
    total_seconds = info.duration
    summary = {
        "language": info.language,
        "language_probability": info.language_probability,
    }
    summary["duration"] = total_seconds
    summary["duration_formatted"] = format_timestamp(total_seconds)
    return summary

### 7. Transcribing audio

In [None]:
# Pre-bind the result so the following cells can test `transcription` even
# when no sample could be selected.
transcription = None

if not sample_input:
    justsdk.print_error("No sample files found.")
elif TARGET_SAMPLE not in sample_input:
    # Report a readable error instead of a bare KeyError on the lookup below.
    available = ", ".join(sample_input)
    justsdk.print_error(
        f"Sample '{TARGET_SAMPLE}' not found. Available samples: {available}"
    )
else:
    sample_selected = sample_input[TARGET_SAMPLE]["path"]
    justsdk.print_info(f"Selected sample: {sample_selected.name} — {sample_selected}")
    display(Audio(sample_selected))
    transcription = transcribe_audio(
        audio_path=sample_selected,
        model=whisper_model,
        config=AUDIO_CONFIG,
    )

### 8. Check full transcription

In [None]:
# Summarise the transcription, or report that there is nothing to show.
if not transcription:
    justsdk.print_error("Transcription failed or returned no results.")
else:
    justsdk.print_success(f"Completed transcription for {transcription['file']}")
    print(f"""
    Duration: \t{transcription["audio_info"]["duration_formatted"]}
    No. of segments: \t{transcription["num_segments"]}
    No. of words: \t{transcription["num_words"]}
    """)
    # Break after each ". " so the text prints as one sentence per line.
    full_text = transcription["full_text"].replace(". ", ".\n  ")
    justsdk.print_info("Full Transcription:")
    print(f"  {full_text}")

### 9. Check segments from transcription

In [None]:
# Print every segment with its timing, text and decoding statistics.
if not (transcription and transcription["segments"]):
    justsdk.print_error("No segments available for display.")
else:
    segments_df = pd.DataFrame(transcription["segments"])
    for index, seg in segments_df.iterrows():
        # exp(avg_logprob) maps the average log-probability back to 0..1.
        confidence = round(np.exp(seg["avg_logprob"]), 4)
        print(f"""
        [{seg["start_formatted"]} -> {seg["end_formatted"]}] ({seg["duration"]:.2f}s)
          {seg["text"]}

          Confidence: {confidence}
          Non-speech prob.: {seg["no_speech_prob"]:.4f}
        """)

### 10. Metrics: Transcription segments

In [None]:
# TODO: Visualize segments

### 11. Helper: Speaker diarization

In [None]:
def initialize_diarization_pipeline(
    auth_token: Optional[str] = None,
) -> Optional[Pipeline]:
    """Load the pyannote speaker-diarization pipeline.

    The pipeline is moved to the GPU when CUDA is available.

    Args:
        auth_token: Hugging Face access token; required.

    Returns:
        The ready-to-use pipeline, or `None` when the token is missing or
        loading fails (the error is printed, not raised).
    """
    try:
        if not auth_token:
            raise ValueError("Hugging Face auth token is required for diarization.")

        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1", use_auth_token=auth_token
        )

        # NOTE(review): pyannote.audio 3.x appears to print a hint and return
        # None (rather than raise) when the gated model cannot be downloaded.
        # Treat that as a failure so it is reported here instead of surfacing
        # later as an AttributeError or a silently missing pipeline.
        if pipeline is None:
            raise RuntimeError(
                "pipeline could not be loaded; check the token and that the "
                "model's user conditions were accepted."
            )

        # Use GPU if available
        if torch.cuda.is_available():
            pipeline = pipeline.to(torch.device("cuda"))
            justsdk.print_info("Using GPU for speaker diarization")
        else:
            justsdk.print_info("Using CPU for speaker diarization")

        return pipeline
    except Exception as e:
        justsdk.print_error(f"Failed to initialize diarization pipeline: {e}")
        return None


def perform_speaker_diarization(
    audio_path: Path, pipeline: Pipeline, config: Optional[dict] = None
) -> Optional[dict]:
    """Run speaker diarization on one file and summarise the speaker turns.

    Args:
        audio_path: Location of the audio file to analyse.
        pipeline: Initialised diarization pipeline.
        config: Optional settings; the keys `num_speakers`, `min_speakers`
            and `max_speakers` are forwarded to the pipeline. Defaults to the
            current `DIARIZATION_CONFIG`.

    Returns:
        A dict with the keys `segments`, `num_segments`, `speakers` (sorted),
        `num_speakers`, `speaker_stats` (seconds of speech per speaker) and
        `total_duration` (sum of those per-speaker totals), or `None` when
        diarization fails (the error is printed, not raised).
    """
    # Resolve the default at call time so that re-running the configuration
    # cell takes effect without having to redefine this function.
    if config is None:
        config = DIARIZATION_CONFIG

    try:
        with ProgressHook() as hook:
            diarization = pipeline(
                str(audio_path),
                hook=hook,
                num_speakers=config.get("num_speakers"),
                min_speakers=config.get("min_speakers"),
                max_speakers=config.get("max_speakers"),
            )

        speaker_segments: list = []
        speaker_stats: dict = {}

        # Single pass: record each turn and accumulate per-speaker talk time.
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            duration = turn.end - turn.start
            speaker_segments.append(
                {
                    "speaker": speaker,
                    "start": turn.start,
                    "end": turn.end,
                    "start_formatted": format_timestamp(turn.start),
                    "end_formatted": format_timestamp(turn.end),
                    "duration": duration,
                }
            )
            speaker_stats[speaker] = speaker_stats.get(speaker, 0.0) + duration

        return {
            "segments": speaker_segments,
            "num_segments": len(speaker_segments),
            "speakers": sorted(speaker_stats),
            "num_speakers": len(speaker_stats),
            "speaker_stats": speaker_stats,
            "total_duration": sum(speaker_stats.values()),
        }
    except Exception as e:
        justsdk.print_error(f"Failed to perform speaker diarization: {e}")
        return None


def align_transcription_with_speakers(
    transcription: dict, diarization_res: dict, overlap_threshold: float = 0.5
) -> list:
    """Label every transcription segment with its dominant speaker.

    For each transcription segment the overlap with the diarization turns is
    summed per speaker. The speaker with the largest total is assigned when
    that total covers at least `overlap_threshold` of the segment; otherwise
    the segment is labelled UNKNOWN. Summing per speaker matters because one
    speaker is often split over several short turns within a single segment.

    Args:
        transcription: Transcription result; its `segments` entries need
            `start` and `end` times.
        diarization_res: Diarization result; its `segments` entries need
            `speaker`, `start` and `end`.
        overlap_threshold: Minimum fraction of the segment the chosen speaker
            must cover.

    Returns:
        A new list with one dict per transcription segment: the original
        fields plus `speaker`. The label contains terminal colour codes
        because it is meant for console display.
    """
    aligned_segments: list = []
    speaker_turns = diarization_res["segments"]

    for trans_seg in transcription["segments"]:
        trans_start = trans_seg["start"]
        trans_end = trans_seg["end"]
        trans_duration = trans_end - trans_start

        # Total time each speaker talks inside this segment.
        overlap_by_speaker: dict = {}
        for turn in speaker_turns:
            overlap = min(trans_end, turn["end"]) - max(trans_start, turn["start"])
            if overlap > 0:
                label = turn["speaker"]
                overlap_by_speaker[label] = overlap_by_speaker.get(label, 0.0) + overlap

        speaker = f"{justsdk.Fore.RED}UNKNOWN{justsdk.Fore.RESET}"
        if overlap_by_speaker:
            best_speaker, best_overlap = max(
                overlap_by_speaker.items(), key=lambda item: item[1]
            )
            # A positive overlap implies trans_end > trans_start, so the
            # division below cannot hit a zero-length segment.
            if best_overlap / trans_duration >= overlap_threshold:
                speaker = best_speaker.replace(
                    "SPEAKER_", f"{justsdk.Fore.MAGENTA}SPEAKER_{justsdk.Fore.RESET}"
                )

        aligned_segments.append(
            {
                **trans_seg,
                "speaker": speaker,
            }
        )
    return aligned_segments


def format_diarized_transcript(aligned_segments: list) -> str:
    """Render speaker-labelled segments as a readable transcript.

    Each segment becomes a header line with its time range and speaker,
    followed by its text; segments are separated by a blank line. An empty
    input yields an empty string.
    """
    return "\n".join(
        f"  [{entry['start_formatted']} - {entry['end_formatted']}] {entry['speaker']}:\n"
        f"  {entry['text']}\n"
        for entry in aligned_segments
    )

### 12. Speaker diarization

In [None]:
# Pre-bind the result so the analysis cell can test `diarization_result`
# even when diarization is skipped below.
diarization_result = None

justsdk.print_info("Initializing diarization pipeline...")
diarization_pipeline = initialize_diarization_pipeline(HF_TOKEN)

if diarization_pipeline and transcription:
    justsdk.print_info(f"Performing speaker diarization on {sample_selected.name}")

    diarization_result = perform_speaker_diarization(
        audio_path=sample_selected, pipeline=diarization_pipeline
    )

    if diarization_result:
        justsdk.print_success(
            f"Diarization completed. Found {diarization_result['num_speakers']} speakers."
        )
    else:
        justsdk.print_error("Diarization failed or returned no results.")

else:
    justsdk.print_error(
        "Diarization pipeline initialization failed or transcription is missing."
    )

### 13. Analyzing diarization results

In [None]:
# Attach speaker labels to the transcription and print the result.
if not diarization_result:
    justsdk.print_error("No diarization results available for alignment.")
else:
    aligned_segments = align_transcription_with_speakers(
        transcription,
        diarization_result,
        overlap_threshold=0.5,
    )
    formatted_transcript = format_diarized_transcript(aligned_segments)
    justsdk.print_info("Formatted Diarized Transcript:")
    print(formatted_transcript)

### 14. Metrics: Speaker diarization

In [None]:
# TODO: Visualize diarization results