In [None]:
import os
import gc
import re
import pickle

import whisperx
import torch
import pandas as pd

# Hugging Face token, read from the environment (required by the diarization pipeline below).
HF_TOKEN = os.environ['HF_TOKEN']

device = "cuda"
audio_file = "audios/video5.mp3"
# Base file name without directory or extension: audios/video5.mp3 -> video5.
# (os.path is robust to "./" prefixes, nested dirs and dots in the name, unlike a regex split.)
audio_name = os.path.splitext(os.path.basename(audio_file))[0]
batch_size = 8  # reduce if low on GPU mem
dialogues_dir = 'dialogues'
diarized_outputs_dir = 'diarized_outputs'

# Create output directories; exist_ok makes this cell safe to re-run.
for directory in (dialogues_dir, diarized_outputs_dir):
    os.makedirs(directory, exist_ok=True)

## Transcription

In [None]:
model = whisperx.load_model("large-v2", device, compute_type="float32")  # change to "int8" if low on GPU mem (may reduce accuracy)
audio = whisperx.load_audio(audio_file)
transcription = model.transcribe(audio, batch_size=batch_size)

# Drop the last reference to the model *before* collecting garbage and emptying
# the CUDA cache; otherwise the model is still alive and its memory is not freed.
del model
gc.collect()
torch.cuda.empty_cache()

transcription["segments"]  # before alignment

## Output Alignment

In [None]:
# Load the alignment model for the language detected during transcription.
model_a, metadata = whisperx.load_align_model(language_code=transcription["language"], device=device)
alignment = whisperx.align(
    transcription["segments"],
    model_a,
    metadata,
    audio,
    device,
    return_char_alignments=False
)

# Release the alignment model first, then collect and empty the CUDA cache so
# the memory is actually returned.
del model_a
gc.collect()
torch.cuda.empty_cache()

alignment["segments"]  # after alignment

## Assigning Speaker Labels (Diarization)

In [None]:
diarize_model = whisperx.DiarizationPipeline(use_auth_token=HF_TOKEN, device=device)

# add min/max number of speakers if known a priori
# diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
# `audio` was already decoded in the transcription cell; no need to load it again.
diarize_segments = diarize_model(audio)

result = whisperx.assign_word_speakers(diarize_segments, alignment)

# Release the pipeline before collecting garbage / emptying the CUDA cache.
del diarize_model
gc.collect()
torch.cuda.empty_cache()

# Persist both the raw diarization and the speaker-assigned result for later reuse.
with open(os.path.join(diarized_outputs_dir, f'{audio_name}_segments.pkl'), 'wb') as file:
    pickle.dump((diarize_segments, result), file)

print(diarize_segments)
print(result["segments"])  # segments are now assigned speaker IDs

## Processing Diarization Output

In [None]:
# Inspect the segments returned by whisperx.assign_word_speakers in the diarization cell.
result['segments']

In [None]:
def generate_dialogue_from_segments_fifo(segments: list[dict], speakers: dict = None) -> str:
    """Assign labels based on token-wise labeling performed by WhisperX.

    Walks every word of every segment in order and starts a new dialogue line
    whenever the word-level ``speaker`` label changes, so one segment may be
    split across several speakers.

    Args:
        segments: segments whose ``words`` list holds dicts with a ``word`` string
            and, optionally, a ``speaker`` label.
        speakers: optional mapping from raw speaker labels to display names;
            labels missing from the mapping are shown unchanged.

    Returns:
        One ``"<speaker>: <words>"`` line per speaker turn, joined with newlines.
        Words without a ``speaker`` label are attached to the current turn (or to
        the first labelled turn if they come first); a turn with no labelled word
        at all is attributed to ``'UNKNOWN'``.
    """
    def _label(speaker):
        if speaker is None:
            speaker = 'UNKNOWN'
        return speakers.get(speaker, speaker) if speakers else speaker

    dialogue = []
    turn_words: list[str] = []
    current_speaker = None

    for segment in segments:
        for word in segment.get('words', []):
            speaker = word.get('speaker')
            if current_speaker is None:
                # Adopt the first label we see (unlabelled leading words join this turn).
                current_speaker = speaker
            if speaker is None or speaker == current_speaker:
                turn_words.append(word['word'])
            else:
                dialogue.append(f"{_label(current_speaker)}: {' '.join(turn_words)}")
                turn_words = [word['word']]
                current_speaker = speaker

    # Flush the final turn, which the loop above never emits on its own.
    if turn_words:
        dialogue.append(f"{_label(current_speaker)}: {' '.join(turn_words)}")
    return '\n'.join(dialogue)


def generate_dialogue_from_segments_most_frequent(segments: list[dict], speakers: dict = None) -> str:
    """Assign the label to the most frequent speaker within the segment.

    Args:
        segments: segments with a ``text`` string and a ``words`` list whose items
            may carry a ``speaker`` label.
        speakers: optional mapping from raw speaker labels to display names;
            labels missing from the mapping are shown unchanged.

    Returns:
        One ``"<speaker>: <text>"`` line per segment, joined with newlines. Ties are
        broken in favour of the speaker whose word appears first; segments with no
        labelled word are attributed to ``'UNKNOWN'``.
    """
    from collections import Counter

    dialogue = []
    for segment in segments:
        # Count word-level labels directly; unlabelled words are ignored.
        counts = Counter(
            word['speaker']
            for word in segment.get('words', [])
            if word.get('speaker') is not None
        )
        # max() returns the first maximal key in insertion (first-seen) order.
        speaker = max(counts, key=counts.get) if counts else 'UNKNOWN'
        label = speakers.get(speaker, speaker) if speakers else speaker
        dialogue.append(f"{label}: {segment['text'].strip()}")
    return '\n'.join(dialogue)


def generate_dialogue_from_segments(segments: list[dict], speakers: dict = None) -> str:
    """Use WhisperX-assigned whole segment labels.

    Args:
        segments: segments with a ``text`` string and, when a label was assigned,
            a ``speaker`` key (NOTE(review): presumably absent when diarization found
            no overlapping speaker turn — such segments are labelled ``'UNKNOWN'``).
        speakers: optional mapping from raw speaker labels to display names;
            labels missing from the mapping are shown unchanged.

    Returns:
        One ``"<speaker>: <text>"`` line per segment, joined with newlines.
    """
    lines = []
    for segment in segments:
        speaker = segment.get('speaker', 'UNKNOWN')
        label = speakers.get(speaker, speaker) if speakers else speaker
        lines.append(f"{label}: {segment['text'].strip()}")
    return '\n'.join(lines)

In [None]:
# Render the dialogue from segment-level speaker labels.
dialogue = generate_dialogue_from_segments(
    segments=result['segments'],
)

# Explicit UTF-8 so non-ASCII transcripts do not depend on the platform default encoding.
with open(os.path.join(dialogues_dir, f'{audio_name}_dialogue.txt'), 'w', encoding='utf-8') as file:
    file.write(dialogue)

print(dialogue)