# Whisper-Based Transcription and Speaker Diarization for YouTube Videos

This notebook combines the OpenAI Whisper model with pyannote-audio to build a speaker-aware transcription (diarization) system.

## What is speaker diarization?

Speaker diarization aims to answer the question ***who spoke when?***
In short: diarization algorithms split an audio stream with multiple speakers into segments, each attributed to an individual speaker.
By combining the diarization output with automatic speech recognition (ASR) transcriptions, we can turn the generated
transcript into a format that is more readable for humans and easier to use in downstream NLP tasks.

<img src="https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/_images/asr_sd_diagram.png" />

An example output from an ASR system without diarization could be:

```
I just got back from the gym. oh good. uhhuh. How's it going? oh pretty well.
It was really crowded today yeah. I kind of assumed everyone would be at the shore.
uhhuh. I was wrong. Well it's the middle of the week or whatever so.
But it's the fourth of July. mm. So. yeah. People have to work tomorrow. Do you have to work tomorrow?
yeah. Did you have off yesterday? Yes. oh that's good. And I was paid too. oh. Is it paid today? No. oh.
```

With diarization, however, the conversation would become much more readable:

```
A: I just got back from the gym.
B: oh good.
A: uhhuh.
B: How's it going?
A: oh pretty well.
A: It was really crowded today.
B: yeah.
A: I kind of assumed everyone would be at the shore.
B: uhhuh.
A: I was wrong.
B: Well it's the middle of the week or whatever so.
A: But it's the fourth of July.
B: mm.
A: So.
B: yeah.
B: People have to work tomorrow.
B: Do you have to work tomorrow?
A: yeah.
B: Did you have off yesterday?
A: Yes.
B: oh that's good.
A: And I was paid too.
B: oh.
B: Is it paid today?
A: No.
B: oh.
```
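
As a minimal sketch of how such a speaker-labelled transcript could be rendered, assume the diarization and ASR outputs have already been merged into a list of `(speaker, text)` pairs. The `format_transcript` helper and the sample data are hypothetical and only illustrate the output format.

```python
from typing import List, Tuple


def format_transcript(turns: List[Tuple[str, str]]) -> str:
    """Render (speaker, text) pairs as one 'speaker: text' line per turn."""
    return "\n".join(f"{speaker}: {text}" for speaker, text in turns)


# Hypothetical merged output of diarization + ASR
turns = [
    ("A", "I just got back from the gym."),
    ("B", "oh good."),
    ("A", "uhhuh."),
]
print(format_transcript(turns))
```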

Speaker-aware transcripts can be a powerful tool for analyzing speech data:

* We can analyze each speaker's sentiment by applying sentiment analysis to both the audio and the text transcripts.
* Another use case is telemedicine, where we might add `<doctor>` and `<patient>` tags to the transcription to create an accurate transcript and attach it to the patient file or EHR system.
* Hiring platforms can use speaker diarization to analyze phone and video recruitment calls. This allows them to sort and categorize candidates by their responses to certain questions without having to listen to the recordings again.
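
One simple analysis that speaker-aware segments make possible is measuring how long each speaker talked. The sketch below assumes segments are available as `(start, end, speaker)` tuples in seconds; the `speaking_time_per_speaker` helper and the sample values are hypothetical.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def speaking_time_per_speaker(segments: List[Tuple[float, float, str]]) -> Dict[str, float]:
    """Sum the duration of all (start, end, speaker) segments for each speaker."""
    totals: Dict[str, float] = defaultdict(float)
    for start, end, speaker in segments:
        totals[speaker] += end - start
    return dict(totals)


# Hypothetical diarized call between a doctor and a patient
segments = [(0.0, 4.0, "doctor"), (4.0, 5.5, "patient"), (5.5, 8.0, "doctor")]
print(speaking_time_per_speaker(segments))
```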

## Instructions

To run this notebook you need a link to a YouTube video (`_YOUTUBE_VIDEO_URL`) and a Hugging Face access token (`_HF_AUTH_TOKEN`). You can create a Hugging Face account and generate the token by following [these instructions on the Hugging Face website](https://huggingface.co/docs/hub/security-tokens). The pyannote models are gated, so you will likely also need to accept their user conditions on the corresponding Hugging Face model pages.

Enable a GPU runtime: go to `Runtime -> Change runtime type` and set `Hardware accelerator` to `GPU`.

Enter the YouTube video URL and the token in the form below, then execute the notebook by selecting `Runtime -> Run all` from the top menu.

In [None]:

_YOUTUBE_VIDEO_URL = "https://www.youtube.com/watch?v=H7kZ98bAHaA" #@param {type:"string"}
_HF_AUTH_TOKEN = "" #@param {type:"string"}

## Check GPU

Check that we are running in an environment with a GPU available. Otherwise, the Whisper model is going to run very slowly.

In [None]:
#@title
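import locale


# Workaround (assumption: needed on some Colab runtimes): force a UTF-8 preferred encoding
# so that the "!" shell commands below do not fail with a locale error
locale.getpreferredencoding = lambda: "UTF-8"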

In [None]:
#@title
!nvidia-smi
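
The same check can be done programmatically. The following is a minimal sketch that only relies on the standard library and assumes that a working `nvidia-smi` binary on the `PATH` indicates a usable NVIDIA GPU; the `gpu_available` helper is hypothetical.

```python
import shutil
import subprocess


def gpu_available() -> bool:
    """Return True if the nvidia-smi tool exists and exits successfully."""
    if shutil.which("nvidia-smi") is None:
        return False
    result = subprocess.run(["nvidia-smi"], capture_output=True)
    return result.returncode == 0


if not gpu_available():
    print("No GPU detected: Whisper will fall back to the CPU and run slowly.")
```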

## Install dependencies

In [None]:
!apt update && apt install ffmpeg --yes
!apt install libvoikko1 libvoikko-dev voikko-fi python3-libvoikko --yes
!pip install --upgrade \
    googletrans==4.0.0rc1 \
    yt-dlp \
    ffmpeg-python \
    setuptools-rust \
    jiwer \
    pydub \
    pyannote.audio \
    git+https://github.com/openai/whisper.git \
    git+https://github.com/m-bain/whisperx.git \
    torch==1.12.1+cu116 \
    torchvision==0.13.1+cu116 \
    torchaudio==0.12.1 \
    --extra-index-url https://download.pytorch.org/whl/cu116
!pip install torchtext==0.13.0 --force-reinstall --no-dependencies

## Download YouTube video audio

Download the audio for the selected YouTube video.

In [None]:
#@title
import re
import os
import yt_dlp


def download_youtube_audio(yt_video_url: str) -> str:
    # YouTube video IDs may contain "-" and "_" in addition to letters and digits
    match = re.search(r"v=([a-zA-Z0-9_-]+)", yt_video_url)
    if match is None:
        raise ValueError(f"Could not find a video ID in URL: {yt_video_url}")
    yt_video_id = match.group(1)
    output_file = f"tmp/data/yt_audio/{yt_video_id}.wav"
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    # Reuse a previous download if one exists
    if os.path.exists(output_file):
        return output_file

    ydl_opts = {
        "format": "bestaudio/best",
        # Convert the downloaded audio to WAV with ffmpeg
        "postprocessors": [{
            "key": "FFmpegExtractAudio",
            "preferredcodec": "wav",
            "preferredquality": "192",
        }],
        "outtmpl": "tmp/data/yt_audio/%(id)s.%(ext)s",
    }

    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        print(yt_video_url)
        ydl.download([yt_video_url])
    return output_file


print(f"Downloading YouTube audio for video: {_YOUTUBE_VIDEO_URL} ..")
audio_file_path = download_youtube_audio(_YOUTUBE_VIDEO_URL)
print(f"YouTube video audio downloaded successfully to: {audio_file_path}")
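
The download step depends on extracting the video ID from the URL. As an alternative to a regular expression, the sketch below parses the URL with the standard library and also accepts short `youtu.be` links. The `extract_video_id` helper is hypothetical and covers only these two URL shapes.

```python
from typing import Optional
from urllib.parse import parse_qs, urlparse


def extract_video_id(url: str) -> Optional[str]:
    """Return the video ID from a watch?v=... or youtu.be/... URL, else None."""
    parsed = urlparse(url)
    host = parsed.netloc.lower()
    if host.endswith("youtu.be"):
        # Short links carry the ID as the URL path
        return parsed.path.lstrip("/") or None
    if host.endswith("youtube.com"):
        # Regular links carry the ID in the "v" query parameter
        values = parse_qs(parsed.query).get("v")
        return values[0] if values else None
    return None


print(extract_video_id("https://www.youtube.com/watch?v=H7kZ98bAHaA"))
```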

## Optional: Crop the audio

If the audio is very long, we can crop it to a selected time interval `[T1, T2]`, given in seconds.

In [None]:
#@title
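import os
from typing import Optional

from pydub import AudioSegment


_AUDIO_CROP_T1_SEC = None
_AUDIO_CROP_T2_SEC = None


def crop_audio_file(
    audio_file_path: str,
    t1_sec: Optional[float],
    t2_sec: Optional[float],
) -> str:
    # Nothing to crop when no interval is given
    if t2_sec is None and (t1_sec is None or t1_sec == 0.0):
        return audio_file_path

    audio = AudioSegment.from_wav(audio_file_path)

    # A missing bound defaults to the start or the end of the audio (pydub slices in milliseconds)
    t1_ms = 0 if t1_sec is None else int(t1_sec * 1000)
    t2_ms = len(audio) if t2_sec is None else min(int(t2_sec * 1000), len(audio))
    assert t1_ms < t2_ms, "The crop start must be before the crop end"

    # Include the interval in the file name so that a changed interval does not reuse a stale crop
    base_name, ext = os.path.splitext(os.path.basename(audio_file_path))
    output_file = f"tmp/data/audio_crops/{base_name}_{t1_ms}_{t2_ms}{ext}"
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    if os.path.exists(output_file):
        return output_file

    audio[t1_ms:t2_ms].export(output_file, format="wav")
    return output_file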


print(f"Cropping the audio file to [{_AUDIO_CROP_T1_SEC}, {_AUDIO_CROP_T2_SEC}] sec..")
audio_crop_file_path = crop_audio_file(
    audio_file_path=audio_file_path,
    t1_sec=_AUDIO_CROP_T1_SEC,
    t2_sec=_AUDIO_CROP_T2_SEC,
)
print(f"Audio crop stored at: {audio_crop_file_path}")
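
The interval arithmetic behind the crop can be shown in isolation. This is a minimal sketch that converts a `[T1, T2]` interval in seconds into millisecond bounds and clamps it to the audio length; the `crop_bounds_ms` helper and the 60-second audio length are hypothetical.

```python
from typing import Optional, Tuple


def crop_bounds_ms(
    t1_sec: Optional[float],
    t2_sec: Optional[float],
    audio_len_ms: int,
) -> Tuple[int, int]:
    """Convert a [T1, T2] interval in seconds into clamped millisecond bounds."""
    start_ms = 0 if t1_sec is None else int(t1_sec * 1000)
    end_ms = audio_len_ms if t2_sec is None else min(int(t2_sec * 1000), audio_len_ms)
    if not 0 <= start_ms < end_ms:
        raise ValueError(f"Invalid crop interval: [{t1_sec}, {t2_sec}]")
    return start_ms, end_ms


# A crop that runs past the end of a 60 second file is clamped to the file length
print(crop_bounds_ms(30.0, 90.0, audio_len_ms=60_000))
```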

## Speaker diarization with pyannote-audio

We will use [pyannote-audio](https://github.com/pyannote/pyannote-audio) pretrained models to diarize the cropped audio file.

In [None]:
#@title
from pyannote.audio import Pipeline


try:
    if pipeline:
        print("Using existing diarization pipeline")
except NameError:
    print("Creating new diarization pipeline")
    pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization@2.1",
        use_auth_token=_HF_AUTH_TOKEN,
    )

print(f"Creating diarization for audio file: {audio_crop_file_path} ..")
diarization = pipeline(audio_crop_file_path)
print("Diarization complete")

Collect the diarization output into a list of segments, each recording which speaker is talking and when.

In [None]:
#@title
from dataclasses import dataclass


@dataclass
class DiarizationSegment:
    start: float
    end: float
    speaker: str


diarization_segments = [] 
print(f"Diarization summary: {len(diarization)} segments with {len(diarization.labels())} unique speakers")

for idx, track in enumerate(diarization.itertracks(yield_label=True)):
    turn, _, speaker = track
    ds = DiarizationSegment(
        start=turn.start,
        end=turn.end,
        speaker=speaker,
    )
    diarization_segments.append(ds)
    print(f"idx: {idx:03d}, start: {ds.start:06.1f}s, end: {ds.end:06.1f}s, speaker: {ds.speaker}")
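
Once the segments are in a plain list, questions such as "who is speaking at time t?" become simple lookups. The sketch below assumes segments are `(start, end, speaker)` tuples in seconds; the `speakers_at` helper and the sample data are hypothetical. Note that more than one speaker can be active at once, which is why overlaps need to be handled next.

```python
from typing import List, Tuple


def speakers_at(segments: List[Tuple[float, float, str]], t: float) -> List[str]:
    """Return the sorted labels of all speakers whose segment covers time t."""
    return sorted({speaker for start, end, speaker in segments if start <= t < end})


# Hypothetical segments where the two speakers overlap between 4.0 s and 5.0 s
segments = [(0.0, 5.0, "SPEAKER_00"), (4.0, 9.0, "SPEAKER_01")]
print(speakers_at(segments, 4.5))
```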

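Postprocess the diarization segments so that they no longer overlap: when a segment starts before the previous one has ended, the previous segment is cut at that point, and if the new segment lies entirely inside the previous one, the rest of the previous segment is appended after it. Consecutive segments from the same speaker are not merged here; that can be done later with the `combine_consecutive_speaker_segments` option.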

In [None]:
#@title
import copy


postprocessed_diarization_segments = []


for idx, current_ds in enumerate(diarization_segments):
    if idx == 0:
        postprocessed_diarization_segments.append(copy.deepcopy(current_ds))
        continue
    
    # If the current diarization segment overlaps with the previous one we need to split
    # the segments into non overlapping parts
    previous_ds = postprocessed_diarization_segments[-1]
    
    if current_ds.start < previous_ds.end:
        previous_ds_original_end = previous_ds.end
        
        # Patch the previous to end at the start of current
        previous_ds.end = current_ds.start
        
        # If the current segment is completely contained within the previous:
        # Append two new segments
        if current_ds.end < previous_ds_original_end:
            postprocessed_diarization_segments.append(copy.deepcopy(current_ds))
            postprocessed_diarization_segments.append(
                DiarizationSegment(
                    start=current_ds.end,
                    end=previous_ds_original_end,
                    speaker=previous_ds.speaker
                )
            )
        # If the current ends after the previous segment:
        # Only append the current segment
        else:
            postprocessed_diarization_segments.append(copy.deepcopy(current_ds))
    else:
        postprocessed_diarization_segments.append(copy.deepcopy(current_ds))

        
for idx, ds in enumerate(postprocessed_diarization_segments):
    print(f"idx: {idx:03d}, start: {ds.start:06.1f}s, end: {ds.end:06.1f}s, speaker: {ds.speaker}")
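
The overlap handling can be illustrated on a small example. The following is a simplified sketch of the same idea on `(start, end, speaker)` tuples, not the exact code from the cell above: it assumes the segments are sorted by start time and, unlike the cell, drops a previous segment that would become zero-length. The `resolve_overlaps` helper and the sample data are hypothetical.

```python
from typing import List, Tuple

Segment = Tuple[float, float, str]  # (start, end, speaker)


def resolve_overlaps(segments: List[Segment]) -> List[Segment]:
    """Cut overlapping, start-sorted segments so that the later speaker takes over."""
    result: List[Segment] = []
    for start, end, speaker in segments:
        if result and start < result[-1][1]:
            prev_start, prev_end, prev_speaker = result.pop()
            # The previous segment now ends where the current one starts
            if prev_start < start:
                result.append((prev_start, start, prev_speaker))
            result.append((start, end, speaker))
            # If the current segment sits inside the previous one,
            # the previous speaker resumes afterwards
            if end < prev_end:
                result.append((end, prev_end, prev_speaker))
        else:
            result.append((start, end, speaker))
    return result


# Speaker B interrupts speaker A between 3.0 s and 5.0 s
segments = [(0.0, 10.0, "A"), (3.0, 5.0, "B"), (12.0, 15.0, "A")]
for segment in resolve_overlaps(segments):
    print(segment)
```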

In [None]:
#@title
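# Optional: split segments longer than this many seconds into shorter ones (None disables splitting)
_MAX_DIARIZATION_SEGMENT_LENGTH = None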

if _MAX_DIARIZATION_SEGMENT_LENGTH is not None and _MAX_DIARIZATION_SEGMENT_LENGTH > 0:
    postprocessed_diarization_segments_2 = []

    for idx, current_ds in enumerate(postprocessed_diarization_segments):
        current_ds_length = current_ds.end - current_ds.start
        if current_ds_length > _MAX_DIARIZATION_SEGMENT_LENGTH:
            num_splits = int(current_ds_length / _MAX_DIARIZATION_SEGMENT_LENGTH)

            for i in range(0, num_splits):
                split_segment_start = current_ds.start + i * _MAX_DIARIZATION_SEGMENT_LENGTH
                split_segment_end = split_segment_start + _MAX_DIARIZATION_SEGMENT_LENGTH
                postprocessed_diarization_segments_2.append(
                    DiarizationSegment(
                        start=split_segment_start,
                        end=split_segment_end,
                        speaker=current_ds.speaker,
                    )
                )

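            if current_ds_length % _MAX_DIARIZATION_SEGMENT_LENGTH != 0:
                postprocessed_diarization_segments_2.append(
                    DiarizationSegment(
                        start=current_ds.start + num_splits * _MAX_DIARIZATION_SEGMENT_LENGTH,
                        end=current_ds.end,
                        speaker=current_ds.speaker,
                    )
                )
        else:
            # Keep segments that are already short enough; without this branch they would be dropped
            postprocessed_diarization_segments_2.append(current_ds)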

    postprocessed_diarization_segments = postprocessed_diarization_segments_2
    for idx, ds in enumerate(postprocessed_diarization_segments):
        print(f"idx: {idx:03d}, start: {ds.start:06.1f}s, end: {ds.end:06.1f}s, speaker: {ds.speaker}")
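
The splitting of long segments reduces to cutting one interval into consecutive chunks. The sketch below shows that step on its own, assuming times in seconds; the `split_interval` helper is hypothetical, and it returns intervals that are already short enough unchanged.

```python
from typing import List, Tuple


def split_interval(start: float, end: float, max_length: float) -> List[Tuple[float, float]]:
    """Split [start, end] into consecutive chunks of at most max_length seconds."""
    if max_length <= 0:
        raise ValueError("max_length must be positive")
    chunks = []
    chunk_start = start
    while chunk_start < end:
        # The last chunk is shortened so that it stops at the interval end
        chunk_end = min(chunk_start + max_length, end)
        chunks.append((chunk_start, chunk_end))
        chunk_start = chunk_end
    return chunks


print(split_interval(0.0, 25.0, max_length=10.0))
```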

## Speech-to-text (STT) with Whisper

Next we run speech-to-text (STT) with the OpenAI Whisper model.
The original Whisper model does not give accurate word-level timestamps, and the timestamps it does produce are
cumbersome to work with. We therefore transcribe with Whisper and then use the WhisperX module to align the
transcript with the audio, which yields word-level timestamps.

In [None]:
#@title
import whisper
import torch


_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # "cuda" or "cpu"
# "tiny", "base", "small", "medium" or "large"; append ".en" for the English-only
# variants, which exist for every size except "large"
_WHISPER_MODEL = "large"


# Load the original Whisper model once and reuse it when the cell is re-run
print(f"Using device: {_DEVICE}")

try:
    if whisper_model:
        print(f"Using existing Whisper model: {_WHISPER_MODEL}")
except NameError:
    print(f"Loading Whisper model: {_WHISPER_MODEL} ..")
    whisper_model = whisper.load_model(_WHISPER_MODEL, device=_DEVICE)

In [None]:
#@title
import whisperx
import torch
import time
import ffmpeg


audio_length_sec = float(ffmpeg.probe(audio_crop_file_path)["format"]["duration"])
print(f"Transcribing {audio_length_sec} seconds of audio: {audio_crop_file_path} ..")
s_time = time.time()
transcription_result = whisper_model.transcribe(audio_crop_file_path)
print(f"Transcription completed in: {time.time()-s_time:.2f} seconds")

# Load alignment model and metadata
print(f"Loading alignment model for detected transcription language: {transcription_result['language']} ..")
align_model, align_model_metadata = whisperx.load_align_model(
    language_code=transcription_result["language"],
    device=_DEVICE,
)
print("Aligning Whisper output segments ..")
s_time = time.time()
transcription_result_aligned = whisperx.align(
    transcription_result["segments"],
    align_model,
    align_model_metadata,
    audio_crop_file_path,
    _DEVICE,
)
print(f"Alignment completed in: {time.time()-s_time:.2f} seconds")
print(transcription_result_aligned["word_segments"])

## Combine transcription and diarization results

Combine the transcription and diarization results into the final speaker-attributed transcript.

In [None]:
#@title
import re
from dataclasses import dataclass
from typing import List


@dataclass
class SpeakerTranscribedSegment:
    start: float
    end: float
    speaker: str
    text: str

  
def combine_diarization_and_transcription(
    diarization_segments,
    transcription_result,
    min_segment_length: float = 0.5,
    use_postprocessing_heuristics: bool = False,
    combine_consecutive_speaker_segments: bool = False,
) -> List[SpeakerTranscribedSegment]:
    
    speaker_transcribed_segments = []
    
    transcription_segments = transcription_result["word_segments"]
    
    current_ds_idx = 0
    current_ds_ts_s_idx = 0
    current_ds_ts_e_idx = 0
    
    while current_ds_idx < len(diarization_segments):
        current_ds = diarization_segments[current_ds_idx]
        
        while current_ds_ts_e_idx < len(transcription_segments) and transcription_segments[current_ds_ts_e_idx]["start"] < current_ds.end:
            current_ds_ts_e_idx += 1
        
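        # Assumption: older WhisperX releases store each word under "text" and newer ones
        # under "word", so accept either key
        transcribed_segment_text = " ".join(
            [ts.get("text", ts.get("word", "")) for ts in transcription_segments[current_ds_ts_s_idx:current_ds_ts_e_idx]]
        )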
        transcribed_segment_text = re.sub(' +', ' ', transcribed_segment_text).strip()
        
        # If the next segment speaker is the same as the previous modify the previous segment
        # information and combine them together, do the same if the length of the segment is less
        # than the minimum segment length
        previous_speaker = None
        
        if len(speaker_transcribed_segments) > 0:
            previous_speaker = speaker_transcribed_segments[-1].speaker
        
        # Merging needs an existing previous segment to extend
        can_merge = combine_consecutive_speaker_segments and len(speaker_transcribed_segments) > 0

        if can_merge and ((previous_speaker == current_ds.speaker) or (current_ds.end - current_ds.start < min_segment_length)):
            speaker_transcribed_segments[-1].end = current_ds.end
            speaker_transcribed_segments[-1].text = speaker_transcribed_segments[-1].text + " " + transcribed_segment_text
        else:
            speaker_transcribed_segments.append(
                SpeakerTranscribedSegment(
                    start=current_ds.start,
                    end=current_ds.end,
                    speaker=current_ds.speaker,
                    text=transcribed_segment_text,
                )
            )
        
        current_ds_ts_s_idx = current_ds_ts_e_idx
        current_ds_ts_e_idx = current_ds_ts_e_idx + 1
        current_ds_idx += 1
    
    # Try to fix small mistakes in transcribed diarization using heuristics
    if use_postprocessing_heuristics:
        max_words_to_move = 3

        for current_sts_idx, current_sts in enumerate(speaker_transcribed_segments):
            # The first segment has no predecessor and an empty segment has nothing to move
            if current_sts_idx == 0 or not current_sts.text:
                continue

            # If the speaker transcribed segment does not begin with an upper case character there is
            # likely a small mistake, since Whisper is very good at styling. Likely a part of the text
            # belongs to the previous or next segment. Try to fix these small mistakes here.
            if not current_sts.text[0].isupper():
                previous_sts = speaker_transcribed_segments[current_sts_idx - 1]
                first_sentence_in_current = re.split(r"[\?\.]", current_sts.text)[0]
                last_sentence_in_previous = re.split(r"[\?\.]", previous_sts.text)[-1]
                num_words_previous = len(last_sentence_in_previous.split())
                num_words_current = len(first_sentence_in_current.split())

                if 0 < num_words_previous <= max_words_to_move:
                    # Move the short unfinished tail of the previous segment to the current one
                    current_sts.text = last_sentence_in_previous.strip() + " " + current_sts.text
                    previous_sts.text = previous_sts.text[:len(previous_sts.text) - len(last_sentence_in_previous)].strip()
                elif 0 < num_words_current <= max_words_to_move:
                    # Move the short first sentence of the current segment, together with its
                    # punctuation mark, to the previous one
                    cut = len(first_sentence_in_current) + 1
                    previous_sts.text = (previous_sts.text + " " + current_sts.text[:cut].strip()).strip()
                    current_sts.text = current_sts.text[cut:].strip()

        # Drop segments that ended up without any text
        speaker_transcribed_segments = [sts for sts in speaker_transcribed_segments if len(sts.text.strip()) > 0]
                    
    return speaker_transcribed_segments


speaker_transcribed_segments = combine_diarization_and_transcription(
    diarization_segments=postprocessed_diarization_segments,
    transcription_result=transcription_result_aligned,
    use_postprocessing_heuristics=True,
    combine_consecutive_speaker_segments=False,
)
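
The core of the combination step is assigning each timestamped word to a diarization segment. The following is a simplified sketch of that idea, assuming non-overlapping, time-sorted `(start, end, speaker)` segments and words given as dictionaries with hypothetical `text` and `start` keys. The `assign_words_to_speakers` helper and the sample data are illustrative and leave out the merging and heuristic steps.

```python
from typing import Dict, List, Tuple


def assign_words_to_speakers(
    segments: List[Tuple[float, float, str]],
    words: List[Dict],
) -> List[Tuple[str, str]]:
    """Attach each word to the first segment whose end lies after the word start."""
    result = []
    word_idx = 0
    for _, end, speaker in segments:
        segment_words = []
        # Consume words until one starts at or after the end of this segment
        while word_idx < len(words) and words[word_idx]["start"] < end:
            segment_words.append(words[word_idx]["text"])
            word_idx += 1
        result.append((speaker, " ".join(segment_words)))
    return result


# Hypothetical diarization segments and aligned words
segments = [(0.0, 2.0, "SPEAKER_00"), (2.0, 4.0, "SPEAKER_01")]
words = [
    {"text": "How's", "start": 0.2},
    {"text": "it", "start": 0.6},
    {"text": "going?", "start": 0.9},
    {"text": "Pretty", "start": 2.3},
    {"text": "well.", "start": 2.8},
]
print(assign_words_to_speakers(segments, words))
```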

In [None]:
#@title
for sts in speaker_transcribed_segments:
    print(f"[{sts.start:.2f}, {sts.end:.2f}] {sts.speaker}: {sts.text}")
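
The final segments can also be written out in a subtitle format. As a hedged sketch, the code below renders `(start, end, speaker, text)` tuples as SRT blocks. The `format_srt_timestamp` and `to_srt` helpers and the sample entry are hypothetical and not part of the pipeline above.

```python
from typing import List, Tuple


def format_srt_timestamp(seconds: float) -> str:
    """Format seconds as the HH:MM:SS,mmm timestamp used in SRT files."""
    total_ms = int(round(seconds * 1000))
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    secs, ms = divmod(rem, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{ms:03d}"


def to_srt(entries: List[Tuple[float, float, str, str]]) -> str:
    """Render (start, end, speaker, text) entries as numbered SRT blocks."""
    blocks = []
    for idx, (start, end, speaker, text) in enumerate(entries, start=1):
        time_line = f"{format_srt_timestamp(start)} --> {format_srt_timestamp(end)}"
        blocks.append(f"{idx}\n{time_line}\n{speaker}: {text}")
    return "\n\n".join(blocks) + "\n"


# Hypothetical final segment
print(to_srt([(0.0, 1.5, "SPEAKER_00", "Hello.")]))
```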