# Audio Transcription and Speaker Diarization

This notebook processes audio files to:

1. Transcribe speech, with segment- and word-level timestamps, using faster-whisper

2. Determine who spoke when (speaker diarization) using pyannote.audio

3. Combine results into a structured, speaker-attributed transcript

### 1. Setup and Configuration

In [None]:
import warnings
import justsdk
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import _root as _  # noqa: F401

from pathlib import Path
from typing import Optional, Dict, List
from faster_whisper import WhisperModel
from IPython.display import Audio, display
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.utils.hook import ProgressHook
from config._constants import HF_READ_ONLY_TOKEN, SAMPLE_DATA_DIR, MODEL_DIR
from src.utils import Utils


# Silence every warning for the rest of the session to keep notebook output clean.
# NOTE(review): this also hides deprecation warnings from faster-whisper,
# pyannote and torch; consider narrowing the filter when upgrading those libs.
warnings.filterwarnings("ignore")
# Apply the seaborn-flavoured matplotlib style first, then set the "husl"
# palette. Keep this order: if the style ships its own colour cycle (seaborn
# styles presumably do — confirm), applying it second would override the palette.
plt.style.use("seaborn-v0_8")
sns.set_palette("husl")

In [None]:
# Central settings for the notebook:
#   target_file   - file name (as returned by discover_audio_files) to process
#   whisper       - options used to construct the faster-whisper model
#   transcription - keyword arguments forwarded to WhisperModel.transcribe
#   diarization   - parameters for aligning transcript segments to speakers
CONFIG = dict(
    target_file="project-proposal.mp3",
    whisper=dict(
        model_size="base.en",
        compute_type="int8",
        num_workers=2,
        # local_files_only=True,
    ),
    transcription=dict(
        language="en",
        word_timestamps=True,
        vad_filter=True,
        vad_parameters=dict(min_speech_duration_ms=250),
    ),
    # Minimum fraction of a transcript segment a speaker must cover.
    diarization=dict(overlap_threshold=0.5),
)

### 2. Audio File Discovery

In [None]:
def discover_audio_files(
    base_dir: Optional[Path] = None,
    extensions: Optional[List[str]] = None,
    directories: Optional[List[str]] = None,
) -> Dict[str, Dict]:
    """Find audio/video files in the sample-data sub-directories.

    Args:
        base_dir: Root folder to search; defaults to ``SAMPLE_DATA_DIR``.
        extensions: File suffixes to accept, matched case-insensitively;
            defaults to ``.mp3``, ``.mp4`` and ``.wav``.
        directories: Sub-directories of *base_dir* to scan, in order;
            defaults to ``audio`` then ``video``. Missing ones are skipped.

    Returns:
        Mapping of file name -> ``{"path": Path, "size_mb": float}``, ordered
        by directory and then by file name, so "the first file" is
        deterministic. If two directories contain the same file name, the
        entry from the later directory wins.
    """
    root = SAMPLE_DATA_DIR if base_dir is None else Path(base_dir)
    if extensions is None:
        extensions = [".mp3", ".mp4", ".wav"]
    if directories is None:
        directories = ["audio", "video"]
    wanted_suffixes = {ext.lower() for ext in extensions}

    files: Dict[str, Dict] = {}
    for directory in directories:
        dir_path = root / directory
        if not dir_path.is_dir():
            continue
        # Filter on the lower-cased suffix instead of glob("*.mp3"): glob's
        # case sensitivity differs between platforms, so "TALK.MP3" would be
        # missed on POSIX. Sorting makes the result order reproducible.
        candidates = sorted(
            path
            for path in dir_path.iterdir()
            if path.is_file() and path.suffix.lower() in wanted_suffixes
        )
        for file_path in candidates:
            size_mb = file_path.stat().st_size / (1024 * 1024)
            files[file_path.name] = {
                "path": file_path,
                "size_mb": round(size_mb, 2),
            }

    return files


# Index the sample-data folders once; later cells look files up by name.
audio_files = discover_audio_files()
justsdk.print_info("Available audio files:")
for file_name, meta in audio_files.items():
    print(f"  {file_name} ({meta['size_mb']} MB)")

### 3. Transcription Engine

In [None]:
class TranscriptionEngine:
    """Transcribe audio with faster-whisper and return plain-dict results.

    ``config`` must provide:

    * a ``"whisper"`` section (``model_size``, ``compute_type``,
      ``num_workers`` and, optionally, ``local_files_only``) used to build
      the model;
    * a ``"transcription"`` section whose items are forwarded unchanged as
      keyword arguments to ``WhisperModel.transcribe``.
    """

    def __init__(self, config: Dict):
        self.config = config
        # None if loading failed; transcribe() then short-circuits to None.
        self.model = self._initialize_model()

    @staticmethod
    def _download_dir_name(model_size: str) -> str:
        """Return the MODEL_DIR sub-folder name used to cache *model_size*.

        ``"base.en"`` -> ``"whisper-base-en"``. Slashes and dots are replaced,
        so Hub ids like ``"Systran/faster-whisper-small"`` map to one folder.
        """
        return "whisper-" + model_size.replace("/", "-").replace(".", "-")

    def _initialize_model(self) -> Optional["WhisperModel"]:
        """Load the Whisper model described by ``config["whisper"]``.

        Returns:
            The loaded model, or None if anything failed (the error is printed).
        """
        try:
            whisper_cfg = self.config["whisper"]
            model_size = whisper_cfg["model_size"]
            model = WhisperModel(
                model_size_or_path=model_size,
                compute_type=whisper_cfg["compute_type"],
                num_workers=whisper_cfg["num_workers"],
                # Derive the cache folder from the configured size; a fixed
                # "whisper-base-en" folder would hold any other model under a
                # misleading name.
                download_root=str(MODEL_DIR / self._download_dir_name(model_size)),
                # Optional switch; False is WhisperModel's own default.
                local_files_only=whisper_cfg.get("local_files_only", False),
            )
        except Exception as e:
            justsdk.print_error(f"Failed to load Whisper model: {e}")
            return None
        justsdk.print_success("Whisper model loaded successfully")
        return model

    def transcribe(self, audio_path: Path) -> Optional[Dict]:
        """Transcribe *audio_path*.

        Returns:
            A dict with keys ``file``, ``info``, ``segments``, ``words`` and
            ``full_text``. Returns None if the model is unavailable or
            transcription raised (the error is printed).
        """
        if self.model is None:
            return None

        try:
            # The segments come back as a lazy generator, so decoding errors
            # surface while iterating — the loop must stay inside the try.
            segments_gen, info = self.model.transcribe(
                str(audio_path), **self.config["transcription"]
            )

            segments: List[Dict] = []
            words: List[Dict] = []
            for segment in segments_gen:
                segments.append(self._process_segment(segment))
                # Word entries are only populated when word timestamps are requested.
                if getattr(segment, "words", None):
                    words.extend(self._process_words(segment.words))

            return {
                "file": audio_path.name,
                "info": self._process_info(info),
                "segments": segments,
                "words": words,
                "full_text": " ".join(seg["text"] for seg in segments),
            }

        except Exception as e:
            justsdk.print_error(f"Transcription failed: {e}")
            return None

    def _process_segment(self, segment) -> Dict:
        """Convert one faster-whisper segment into a plain dict (times in seconds)."""
        return {
            "id": segment.id,
            "start": segment.start,
            "end": segment.end,
            "duration": segment.end - segment.start,
            "text": segment.text.strip(),
            # exp of the segment's average log-probability, i.e. the geometric
            # mean token probability, in (0, 1].
            "confidence": round(np.exp(segment.avg_logprob), 4),
            "no_speech_prob": segment.no_speech_prob,
            "start_time": Utils.to_hms(segment.start),
            "end_time": Utils.to_hms(segment.end),
        }

    def _process_words(self, words) -> List[Dict]:
        """Convert a segment's word objects into plain dicts."""
        return [
            {
                "word": word.word,
                "start": word.start,
                "end": word.end,
                "probability": word.probability,
                "start_time": Utils.to_hms(word.start),
                "end_time": Utils.to_hms(word.end),
            }
            for word in words
        ]

    def _process_info(self, info) -> Dict:
        """Extract language and duration details from the transcription info."""
        return {
            "language": info.language,
            "language_probability": info.language_probability,
            "duration": info.duration,
            "duration_formatted": Utils.to_hms(info.duration),
        }


# Build the shared engine once; the model load happens in the constructor.
transcriber = TranscriptionEngine(config=CONFIG)

### 4. Speaker Diarization Engine

In [None]:
class DiarizationEngine:
    """Run pyannote's speaker-diarization-3.1 pipeline and summarise its output."""

    def __init__(self, hf_read_only_token: str):
        # None if initialisation failed; process() then short-circuits to None.
        self.pipeline = self._initialize_pipeline(hf_read_only_token)

    def _initialize_pipeline(self, hf_read_only_token: str) -> Optional["Pipeline"]:
        """Load the diarization pipeline, moving it to CUDA when available.

        Returns:
            The pipeline, or None if loading failed (the reason is printed).
        """
        try:
            pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1", use_auth_token=hf_read_only_token
            )
            # from_pretrained can print a message and return None instead of
            # raising, e.g. when the gated model is not accessible with this
            # token. Report that here instead of failing obscurely later.
            if pipeline is None:
                justsdk.print_error(
                    "Failed to initialize diarization pipeline: "
                    "Pipeline.from_pretrained returned None "
                    "(check the Hugging Face token and model access)"
                )
                return None

            if torch.cuda.is_available():
                pipeline = pipeline.to(torch.device("cuda"))
                justsdk.print_info("Using GPU for diarization")
            else:
                justsdk.print_info("Using CPU for diarization")

            return pipeline
        except Exception as e:
            justsdk.print_error(f"Failed to initialize diarization pipeline: {e}")
            return None

    def process(self, audio_path: Path) -> Optional[Dict]:
        """Diarize *audio_path*.

        Returns:
            The summary built by ``_summarize``: turn segments, sorted speakers,
            per-speaker seconds and total speaking time. Returns None if the
            pipeline is unavailable or diarization raised (the error is printed).
        """
        if self.pipeline is None:
            return None

        try:
            with ProgressHook() as hook:
                diarization = self.pipeline(str(audio_path), hook=hook)

            segments = [
                {
                    "speaker": speaker,
                    "start": turn.start,
                    "end": turn.end,
                    "duration": turn.end - turn.start,
                    "start_time": Utils.to_hms(turn.start),
                    "end_time": Utils.to_hms(turn.end),
                }
                for turn, _, speaker in diarization.itertracks(yield_label=True)
            ]
            return self._summarize(segments)

        except Exception as e:
            justsdk.print_error(f"Diarization failed: {e}")
            return None

    @staticmethod
    def _summarize(segments: List[Dict]) -> Dict:
        """Aggregate speaking time per speaker. Keys are sorted for stable output.

        Args:
            segments: Turn dicts with at least ``"speaker"`` and ``"duration"``.

        Returns:
            ``segments`` (unchanged), ``speakers`` (sorted list),
            ``speaker_stats`` (speaker -> seconds, in sorted key order) and
            ``total_duration`` (sum of per-speaker time). If turns of different
            speakers overlap in time, this total counts the overlap once per
            speaker.
        """
        speakers = sorted({seg["speaker"] for seg in segments})
        # Build from the sorted list, not a set, so dict order does not vary
        # between runs with string hash randomisation.
        speaker_stats = {speaker: 0.0 for speaker in speakers}
        for seg in segments:
            speaker_stats[seg["speaker"]] += seg["duration"]

        return {
            "segments": segments,
            "speakers": speakers,
            "speaker_stats": speaker_stats,
            "total_duration": sum(speaker_stats.values()),
        }


# Build the shared diarization engine once; the pipeline loads in the constructor.
diarizer = DiarizationEngine(hf_read_only_token=HF_READ_ONLY_TOKEN)

### 5. Transcript Alignment and Processing

In [None]:
def align_transcription_with_speakers(
    transcription: Dict, diarization: Dict, threshold: float = 0.5
) -> List[Dict]:
    """Attach a speaker label to every transcription segment.

    For each transcription segment, the time it overlaps each speaker's
    diarization turns is summed per speaker. The speaker with the largest
    total is chosen if that total covers at least ``threshold`` of the
    segment; otherwise the segment is labelled ``"UNKNOWN"``. Ties go to the
    speaker encountered first in ``diarization["segments"]``.

    Args:
        transcription: Dict with a ``"segments"`` list. Each segment needs
            ``"start"``, ``"end"`` and ``"duration"`` (seconds).
        diarization: Dict with a ``"segments"`` list. Each segment needs
            ``"speaker"``, ``"start"`` and ``"end"`` (seconds).
        threshold: Minimum fraction (0-1) of the transcription segment the
            chosen speaker must cover.

    Returns:
        One new dict per transcription segment: a shallow copy of the original
        plus ``"speaker"`` and ``"overlap_confidence"``. The latter is the
        covered fraction, or 0.0 for ``"UNKNOWN"``. Input dicts are not modified.
    """
    aligned_segments = []

    for trans_seg in transcription["segments"]:
        # Sum overlap per speaker: one utterance often spans several short
        # turns of the same speaker, none of which alone reaches the threshold.
        overlap_by_speaker: Dict[str, float] = {}
        for speaker_seg in diarization["segments"]:
            overlap_start = max(trans_seg["start"], speaker_seg["start"])
            overlap_end = min(trans_seg["end"], speaker_seg["end"])
            overlap_duration = overlap_end - overlap_start
            if overlap_duration > 0:
                speaker = speaker_seg["speaker"]
                overlap_by_speaker[speaker] = (
                    overlap_by_speaker.get(speaker, 0.0) + overlap_duration
                )

        best_speaker = "UNKNOWN"
        best_overlap = 0.0
        duration = trans_seg["duration"]
        if overlap_by_speaker and duration > 0:
            # max() returns the first maximal item, i.e. the earliest-seen speaker.
            speaker, covered = max(overlap_by_speaker.items(), key=lambda item: item[1])
            # Cap at 1.0 in case turns of the same speaker overlap each other.
            overlap_ratio = min(covered / duration, 1.0)
            if overlap_ratio >= threshold:
                best_speaker = speaker
                best_overlap = overlap_ratio

        aligned_segments.append(
            {**trans_seg, "speaker": best_speaker, "overlap_confidence": best_overlap}
        )

    return aligned_segments


def format_transcript(aligned_segments: List[Dict]) -> str:
    """Render aligned segments as a human-readable transcript.

    Each entry reads ``[start - end] Speaker NN:`` followed by the text on the
    next line. Entries are separated by a blank line. ``"SPEAKER_"`` prefixes
    are shown as ``"Speaker "``; other labels (e.g. ``"UNKNOWN"``) pass through.
    """

    def _render(seg: Dict) -> str:
        who = seg["speaker"].replace("SPEAKER_", "Speaker ")
        return f"[{seg['start_time']} - {seg['end_time']}] {who}:\n{seg['text']}\n"

    return "\n".join(_render(seg) for seg in aligned_segments)

### 6. Audio Processing Pipeline

In [None]:
# If the configured target is missing, fall back to the first discovered file
# and record that choice in CONFIG so later cells see the same file.
requested_file = CONFIG["target_file"]
if requested_file not in audio_files:
    justsdk.print_error(f"Target file '{requested_file}' not found")
    if not audio_files:
        raise FileNotFoundError("No audio files found")
    CONFIG["target_file"] = next(iter(audio_files))
    justsdk.print_info(f"Using first available file: {CONFIG['target_file']}")

target_file = audio_files[CONFIG["target_file"]]["path"]
justsdk.print_info(f"Processing: {target_file.name}")

# Inline player for a quick listen before processing.
display(Audio(target_file))

In [None]:
justsdk.print_info("Starting transcription...")
transcription_result = transcriber.transcribe(target_file)

# transcribe() returns None on failure and has already printed the reason.
if not transcription_result:
    raise RuntimeError("Transcription failed")

info = transcription_result["info"]
justsdk.print_success("Transcription completed")
print(f"Duration: {info['duration_formatted']}")
print(
    f"Language: {info['language']} (confidence: {info['language_probability']:.3f})"
)
print(f"Segments: {len(transcription_result['segments'])}")
print(f"Words: {len(transcription_result['words'])}")

In [None]:
justsdk.print_info("Starting speaker diarization...")
diarization_result = diarizer.process(target_file)

# process() returns None on failure and has already printed the reason.
if not diarization_result:
    raise RuntimeError("Diarization failed")

justsdk.print_success("Diarization completed")
print(f"Speakers found: {len(diarization_result['speakers'])}")
print(f"Speaker segments: {len(diarization_result['segments'])}")

# Shares are relative to total speaking time (the sum over speakers).
total_speech = diarization_result["total_duration"]
for speaker, duration in diarization_result["speaker_stats"].items():
    share = duration / total_speech * 100
    print(f"  {speaker}: {Utils.to_hms(duration)} ({share:.1f}%)")

In [None]:
justsdk.print_info("Aligning transcription with speakers...")
overlap_threshold = CONFIG["diarization"]["overlap_threshold"]
aligned_segments = align_transcription_with_speakers(
    transcription=transcription_result,
    diarization=diarization_result,
    threshold=overlap_threshold,
)

final_transcript = format_transcript(aligned_segments)
justsdk.print_success("Alignment completed")

### 7. Results Analysis and Visualization

In [None]:
# Frame the transcript with 80-character rules for readability.
rule = "=" * 80
justsdk.print_info("Final Speaker-Attributed Transcript:")
print(rule)
print(final_transcript)
print(rule)

In [None]:
# Top panel: Gantt-style timeline with one row per speaker, which stays
# readable for long recordings with hundreds of turns. Each speaker's turns
# are drawn in a single barh call, giving exactly one legend entry per speaker
# without re-scanning the legend for every segment.
# Bottom panel: each speaker's share of total speaking time.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10))

# Already sorted by DiarizationEngine, so rows, colours and pie slices line up.
speakers = list(diarization_result["speakers"])
speaker_colors = {speaker: f"C{i}" for i, speaker in enumerate(speakers)}

for row, speaker in enumerate(speakers):
    turns = [seg for seg in diarization_result["segments"] if seg["speaker"] == speaker]
    ax1.barh(
        [row] * len(turns),
        [seg["duration"] for seg in turns],
        left=[seg["start"] for seg in turns],
        color=speaker_colors[speaker],
        alpha=0.7,
        label=speaker,
    )

ax1.set_yticks(range(len(speakers)), labels=speakers)
ax1.set_xlabel("Time (seconds)")
ax1.set_ylabel("Speaker")
ax1.set_title("Speaker Diarization Timeline")
if speakers:  # avoid legend/pie on an empty diarization result
    ax1.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
ax1.grid(True, alpha=0.3)

durations = [diarization_result["speaker_stats"][speaker] for speaker in speakers]
colors = [speaker_colors[speaker] for speaker in speakers]

if speakers:
    ax2.pie(durations, labels=speakers, colors=colors, autopct="%1.1f%%", startangle=90)
ax2.set_title("Speaker Time Distribution")

plt.tight_layout()
plt.show()

In [None]:
confidence_data = [seg["confidence"] for seg in transcription_result["segments"]]
segment_durations = [seg["duration"] for seg in transcription_result["segments"]]

# An empty segment list (e.g. if VAD removes all audio) would make np.min /
# np.max raise ValueError, so skip the analysis instead of crashing the cell.
if not confidence_data:
    justsdk.print_info("No transcription segments to analyze")
else:
    mean_confidence = np.mean(confidence_data)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Left: distribution of per-segment confidence, with the mean marked.
    ax1.hist(confidence_data, bins=20, alpha=0.7, color="skyblue", edgecolor="black")
    ax1.axvline(
        mean_confidence,
        color="red",
        linestyle="--",
        label=f"Mean: {mean_confidence:.3f}",
    )
    ax1.set_xlabel("Confidence Score")
    ax1.set_ylabel("Number of Segments")
    ax1.set_title("Transcription Confidence Distribution")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Right: does confidence depend on how long the segment is?
    ax2.scatter(segment_durations, confidence_data, alpha=0.6, color="coral")
    ax2.set_xlabel("Segment Duration (seconds)")
    ax2.set_ylabel("Confidence Score")
    ax2.set_title("Confidence vs Segment Duration")
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print("Confidence Statistics:")
    print(f"  Mean: {mean_confidence:.3f}")
    print(f"  Median: {np.median(confidence_data):.3f}")
    print(f"  Min: {np.min(confidence_data):.3f}")
    print(f"  Max: {np.max(confidence_data):.3f}")
    print(f"  Std Dev: {np.std(confidence_data):.3f}")

In [None]:
# Columns shown in the preview table (summary_df itself keeps every column).
summary_columns = [
    "speaker",
    "start_time",
    "end_time",
    "confidence",
    "overlap_confidence",
    "text",
]

summary_df = pd.DataFrame(aligned_segments)
print("Aligned Segments Summary:")
# reindex rather than summary_df[summary_columns]: with no aligned segments the
# DataFrame has no columns and plain selection would raise KeyError; reindex
# prints an empty table with the expected headers instead.
print(summary_df.reindex(columns=summary_columns).head(10))