In [None]:
import whisperx
import torch
import gc

from pathlib import Path

In [2]:
# Runtime configuration for the whole notebook.
device = "cuda"

# Use the first .wav found under the working directory. The matches are sorted
# because rglob() yields files in an unspecified order; sorting makes the
# selection reproducible across runs and machines.
wav_files = sorted(Path('./').rglob('*.wav'))
if not wav_files:
    # Fail with an actionable message instead of a bare IndexError.
    raise FileNotFoundError("No .wav file found under the current directory")
audio_file = str(wav_files[0])

batch_size = 16 # reduce if low on GPU mem
compute_type = "float16" # change to "int8" if low on GPU mem (may reduce accuracy)

In [None]:
# 1. Transcribe with original whisper (batched)
# Loads the "large-v2" model with the language fixed to English, decodes the
# audio file chosen in the config cell, and keeps `audio` and `result` for the
# following cells.
model = whisperx.load_model("large-v2", device, compute_type=compute_type, language='en')

# save model to local path (optional)
# model_dir = "/path/"
# model = whisperx.load_model("large-v2", device, compute_type=compute_type, download_root=model_dir)

audio = whisperx.load_audio(audio_file)
result = model.transcribe(audio, batch_size=batch_size)
print(result["segments"]) # before alignment

# Release the model's GPU memory. The reference has to be dropped *before*
# collecting: while `model` is still bound, gc.collect() and empty_cache()
# cannot reclaim the memory it holds.
del model
gc.collect()
torch.cuda.empty_cache()

In [None]:
# 2. Align whisper output
# Loads the alignment model for the language reported by the transcription
# step and aligns the transcribed segments against the same audio.
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result_a = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)

print(result_a["segments"]) # after alignment

# Release the alignment model's GPU memory. Drop the reference first;
# collecting while `model_a` is still bound frees nothing.
del model_a
gc.collect()
torch.cuda.empty_cache()

In [None]:
# get HF token
# Reads `hf_token` from the Streamlit secrets file. Only a flat
# `hf_token = "<value>"` line is recognised; this is a minimal line parser,
# not a full TOML reader. If the key appears more than once, the last
# occurrence wins.
hf_token = None
with open('../.streamlit/secrets.toml', 'r', encoding='utf-8') as sec:
    for line in sec:
        # Split on the first '=' only, so the value is kept intact.
        key, sep, value = line.partition('=')
        if sep and key.strip() == 'hf_token':
            # Remove the trailing newline, surrounding whitespace and the
            # TOML string quotes so only the raw token remains.
            hf_token = value.strip().strip('"\'')

if not hf_token:
    # Fail here with a clear message rather than with a NameError later on.
    raise RuntimeError("hf_token not found in ../.streamlit/secrets.toml")

# 3. Assign speaker labels
# Builds the diarization pipeline with the Hugging Face token, runs it on the
# loaded audio, and merges the speaker segments into the aligned transcript.
diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=device)

# add min/max number of speakers if known
diarize_segments = diarize_model(audio)
# diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)

result_d = whisperx.assign_word_speakers(diarize_segments, result_a)
print(diarize_segments)
print(result_d["segments"]) # segments are now assigned speaker IDs

# Release the diarization pipeline's GPU memory. Drop the reference first;
# collecting while `diarize_model` is still bound frees nothing.
del diarize_model
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Show the final speaker-labelled segments produced by the diarization cell.
print(result_d["segments"])

In [8]:
# write results for easier loading in future
import json

# ensure_ascii=False writes non-ASCII transcript text verbatim, so the file is
# opened explicitly as UTF-8; the platform default encoding may not be able
# to encode those characters.
with open('diarized_data.json', 'w', encoding='utf-8') as f:
    json.dump(result_d['segments'], f, ensure_ascii=False, indent=4)