### Imports

In [None]:
import datetime
import gc
import json
import logging
import os
import time
from contextlib import contextmanager
from threading import Event, Thread

import py3nvml.py3nvml as nvml
import torch
import whisperx

### Variables (adjustable)

In [None]:
device = "cuda"  # "cpu" to use the CPU, "cuda" to use the GPU
device_index = 0  # for multiple GPU setup. Indicates which GPU to use
batch_size = 48  # batch size passed to model.transcribe; lower it on GPU out-of-memory errors
compute_type = "float16"  # precision to use (fp16, fp32, int8, etc.)
model_version = "large-v2"  # options: https://github.com/beeldengeluid/dane-whisper-asr-worker?tab=readme-ov-file#model-options
audios_path = "/opt/app-root/src/nbest/bn_nl_segments/"  # absolute path to folder where audio to be transcribed can be found
# single file used by the experimentation cell (derived from audios_path so the two stay in sync)
audio_file = os.path.join(audios_path, "nbest-eval-2008-bn-nl-001_1.wav")
out_path = "/opt/app-root/src/results/jax/large-v2/"  # absolute path to folder where transcriptions + log should be saved
# GPU measurement parameter
interval = 0.5  # how often to measure GPU usage (in s)
# SECRET (for running diarization). Read it from the HF_TOKEN environment
# variable so the token never has to be pasted into (and shared with) the
# notebook; the placeholder is only a fallback.
HF_TOKEN = os.environ.get("HF_TOKEN", "REPLACE_WITH_YOUR_HF_TOKEN")

### Threading helper (runs a function in a background thread and keeps its return value; used to sample GPU usage)

In [None]:
class MyThread(Thread):
    """Thread that runs ``func(*params)`` and keeps what it returns.

    ``threading.Thread`` does not store its target's return value; this
    subclass does, so the caller can read it via :meth:`get_result` after
    :meth:`join`.
    """

    def __init__(self, func, params=()):
        super().__init__()
        self.func = func
        self.params = params
        self.result = None  # func's return value; stays None if func raised
        self.exception = None  # exception raised by func, if any

    def run(self):
        try:
            self.result = self.func(*self.params)
        except BaseException as exc:
            # Keep the failure inspectable after join(), then re-raise so
            # threading.excepthook still reports it as before.
            self.exception = exc
            raise

    def get_result(self):
        """Return func's return value (None until run() finished, or if func raised)."""
        return self.result

### ffmpeg setup

In [None]:
# # Uncomment the 2 lines below if you want to download a static FFmpeg build
# !curl https://johnvansickle.com/ffmpeg/releases/ffmpeg-release-amd64-static.tar.xz -o ffmpeg.tar.xz \
#  && tar -xf ffmpeg.tar.xz && rm ffmpeg.tar.xz

# Add the build to PATH
ffmdir = !find . -iname ffmpeg-*-static
path = %env PATH
path = path + ':' + ffmdir[0]
%env PATH $path
print('')
!which ffmpeg
print('Done!')

### Running the full pipeline on a single file (modify where needed, mostly when changing the implementation)

In [None]:
### SETTING UP THE LOGGER
logging.basicConfig(filename="log.txt",
                    format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s",
                    level=logging.INFO,
                    force=True)
logger = logging.getLogger(__name__)
# force=True only resets the root logger; without clearing, every re-run of a
# cell stacks one more console handler here and messages print repeatedly.
logger.handlers.clear()
logger.addHandler(logging.StreamHandler())


### GPU MONITORING HELPERS
def _sample_gpu_usage(usage, stop_event):
    """Append GPU memory (MiB) and power (W) samples to `usage` until `stop_event` is set.

    One sample is taken before the first stop check, so both lists are
    non-empty once the sampler returns.
    """
    while True:
        usage["gpu_memory_usage"].append(
            nvml.nvmlDeviceGetMemoryInfo(handle).used >> 20
        )
        usage["gpu_power_usage"].append(
            nvml.nvmlDeviceGetPowerUsage(handle) / 1000
        )
        # Event.wait returns as soon as the event is set, so stopping the
        # sampler does not have to wait out a full `interval`.
        if stop_event.wait(interval):
            return usage


@contextmanager
def _gpu_monitor():
    """Sample GPU usage in a background thread while the `with` body runs.

    Yields the dict the samples are collected in. The sampler is stopped and
    joined even if the body raises, so a failing stage never leaves a polling
    thread running in the kernel.
    """
    usage = {"gpu_memory_usage": [], "gpu_power_usage": []}
    stop_event = Event()
    sampler = MyThread(_sample_gpu_usage, params=(usage, stop_event))
    sampler.start()
    try:
        yield usage
    finally:
        stop_event.set()
        sampler.join()


def _log_stage_stats(stage_message, elapsed, usage):
    """Log a stage's duration and its peak GPU memory/power usage."""
    logger.info(f"{stage_message}: {elapsed} s")
    max_memory_usage = max(usage["gpu_memory_usage"])
    max_power_usage = max(usage["gpu_power_usage"])
    logger.info(
        "Maximum GPU memory usage: %dMiB / %dMiB (%.2f%%)"
        % (
            max_memory_usage,
            gpu_memory_limit,
            (max_memory_usage / gpu_memory_limit) * 100,
        )
    )
    logger.info(
        "Maximum GPU power usage: %dW / %dW (%.2f%%)"
        % (
            max_power_usage,
            gpu_power_limit,
            (max_power_usage / gpu_power_limit) * 100,
        )
    )


def _format_result(whisperx_result):
    """Reshape whisperx output into the segment/word schema written to JSON.

    Words may lack "start"/"end"/"score" (not alignable) or "speaker" (no
    overlapping diarization turn); such fields are written as null instead
    of raising KeyError.
    """
    return {
        "segments": [
            {
                "start": segment["start"],
                "end": segment["end"],
                "text": segment["text"].strip(),
                "speaker": segment.get("speaker"),
                "words": [
                    {
                        # whisperx word text may start with a space
                        "text": word["word"].strip(),
                        "start": word.get("start"),
                        "end": word.get("end"),
                        "confidence": word.get("score"),
                        "speaker": word.get("speaker"),
                    }
                    for word in segment["words"]
                ],
            }
            for segment in whisperx_result["segments"]
        ]
    }


### LOADING THE MODEL
logger.info("================================START OF EVALUATION================================")
start = time.time()
# 1. Transcribe with original whisper (batched)
model = whisperx.load_model(model_version,
                            device=device,
                            device_index=device_index,
                            compute_type=compute_type)
logger.info(f"Time to load the model: {time.time() - start} s")
logger.info("================================")
logger.info(
    "Measuring maximum GPU memory usage on GPU device."
    " Make sure to not have additional processes running on the same GPU."
)
# Initialization for measuring GPU usage
nvml.nvmlInit()
handle = nvml.nvmlDeviceGetHandleByIndex(device_index)
gpu_name = nvml.nvmlDeviceGetName(handle)
logger.info(f"GPU {device_index}: {gpu_name}")
gpu_memory_limit = nvml.nvmlDeviceGetMemoryInfo(handle).total >> 20  # MiB
gpu_power_limit = nvml.nvmlDeviceGetPowerManagementLimit(handle) / 1000.0  # NVML reports mW

# Transcribing the file
with _gpu_monitor() as gpu_usage:
    stage_start = time.time()
    audio = whisperx.load_audio(audio_file)
    result = model.transcribe(audio, batch_size=batch_size)
    # Read the clock inside the block so stopping the sampler thread is not
    # counted as transcription time.
    elapsed = time.time() - stage_start
_log_stage_stats("Time to transcribe", elapsed, gpu_usage)
logger.info("--------------------------------")
# print(result["segments"]) # before alignment

# 2. Align whisper output (word-level timestamps).
# The alignment model is loaded outside the measured stage so its load time
# is reported separately from inference time.
load_start = time.time()
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
logger.info(f"Time to load the alignment model: {time.time() - load_start} s")
with _gpu_monitor() as gpu_usage:
    stage_start = time.time()
    result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
    elapsed = time.time() - stage_start
_log_stage_stats("Time to align (generate word-level timestamps using wav2vec2)", elapsed, gpu_usage)
logger.info("--------------------------------")
# print(result["segments"]) # after alignment

# 3. Assign speaker labels
load_start = time.time()
diarize_model = whisperx.DiarizationPipeline(use_auth_token=HF_TOKEN, device=device)
logger.info(f"Time to load the diarization model: {time.time() - load_start} s")
with _gpu_monitor() as gpu_usage:
    stage_start = time.time()
    diarize_segments = diarize_model(audio)
    result = whisperx.assign_word_speakers(diarize_segments, result)
    elapsed = time.time() - stage_start
_log_stage_stats("Time to diarize", elapsed, gpu_usage)
logger.info("================================")
logger.info('Total time spent evaluating (loading-diarization): ' + str(datetime.timedelta(seconds=time.time() - start)))
logger.info("================================END OF EVALUATION================================")

# Formatting the transcription and saving it to a JSON file
result = _format_result(result)
with open('output.json', 'w', encoding='utf-8') as f:
    json.dump(result, f, indent=2, ensure_ascii=False)

logging.shutdown()
nvml.nvmlShutdown()
# Drop the model references first: empty_cache() can only release memory
# nothing refers to anymore, and models left on the GPU would inflate the
# measurements of the next cell.
del model, model_a, diarize_model
gc.collect()
torch.cuda.empty_cache()

### Dataset benchmarking (WIP)

In [None]:
### SETTING UP THE LOGGER
# The log and transcriptions go to out_path; create it first so
# logging.basicConfig can open the log file there.
os.makedirs(out_path, exist_ok=True)
logging.basicConfig(filename=os.path.join(out_path, "log.txt"),
                    format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s",
                    level=logging.INFO,
                    force=True)
logger = logging.getLogger(__name__)
# force=True only resets the root logger; without clearing, every re-run of a
# cell stacks one more console handler here and messages print repeatedly.
logger.handlers.clear()
logger.addHandler(logging.StreamHandler())


### GPU MONITORING HELPERS
def _sample_gpu_usage(usage, stop_event):
    """Append GPU memory (MiB) and power (W) samples to `usage` until `stop_event` is set.

    One sample is taken before the first stop check, so both lists are
    non-empty once the sampler returns.
    """
    while True:
        usage["gpu_memory_usage"].append(
            nvml.nvmlDeviceGetMemoryInfo(handle).used >> 20
        )
        usage["gpu_power_usage"].append(
            nvml.nvmlDeviceGetPowerUsage(handle) / 1000
        )
        # Event.wait returns as soon as the event is set, so stopping the
        # sampler does not have to wait out a full `interval`.
        if stop_event.wait(interval):
            return usage


@contextmanager
def _gpu_monitor():
    """Sample GPU usage in a background thread while the `with` body runs.

    Yields the dict the samples are collected in. The sampler is stopped and
    joined even if the body raises, so a failing file never leaves a polling
    thread running in the kernel.
    """
    usage = {"gpu_memory_usage": [], "gpu_power_usage": []}
    stop_event = Event()
    sampler = MyThread(_sample_gpu_usage, params=(usage, stop_event))
    sampler.start()
    try:
        yield usage
    finally:
        stop_event.set()
        sampler.join()


def _log_stage_stats(stage_message, elapsed, usage):
    """Log a stage's duration and its peak GPU memory/power usage."""
    logger.info(f"{stage_message}: {elapsed} s")
    max_memory_usage = max(usage["gpu_memory_usage"])
    max_power_usage = max(usage["gpu_power_usage"])
    logger.info(
        "Maximum GPU memory usage: %dMiB / %dMiB (%.2f%%)"
        % (
            max_memory_usage,
            gpu_memory_limit,
            (max_memory_usage / gpu_memory_limit) * 100,
        )
    )
    logger.info(
        "Maximum GPU power usage: %dW / %dW (%.2f%%)"
        % (
            max_power_usage,
            gpu_power_limit,
            (max_power_usage / gpu_power_limit) * 100,
        )
    )


def _format_result(whisperx_result):
    """Reshape whisperx output into the segment/word schema written to JSON.

    Words may lack "start"/"end"/"score" (not alignable) or "speaker" (no
    overlapping diarization turn); such fields are written as null instead
    of raising KeyError.
    """
    return {
        "segments": [
            {
                "start": segment["start"],
                "end": segment["end"],
                "text": segment["text"].strip(),
                "speaker": segment.get("speaker"),
                "words": [
                    {
                        # whisperx word text may start with a space
                        "text": word["word"].strip(),
                        "start": word.get("start"),
                        "end": word.get("end"),
                        "confidence": word.get("score"),
                        "speaker": word.get("speaker"),
                    }
                    for word in segment["words"]
                ],
            }
            for segment in whisperx_result["segments"]
        ]
    }


def _get_align_model(language):
    """Return the (model, metadata) alignment pair for `language`, loading it on first use."""
    if language not in align_models:
        load_start = time.time()
        align_models[language] = whisperx.load_align_model(language_code=language, device=device)
        logger.info(f"Time to load the alignment model ({language}): {time.time() - load_start} s")
    return align_models[language]


def _process_file(audio_path):
    """Transcribe, align and diarize one file, logging time and peak GPU usage per stage."""
    # 1. Transcribe with original whisper (batched)
    with _gpu_monitor() as gpu_usage:
        stage_start = time.time()
        audio = whisperx.load_audio(audio_path)
        result = model.transcribe(audio, batch_size=batch_size)
        # Read the clock inside the block so stopping the sampler thread is
        # not counted as stage time.
        elapsed = time.time() - stage_start
    _log_stage_stats("Time to transcribe", elapsed, gpu_usage)
    logger.info("--------------------------------")

    # 2. Align whisper output (word-level timestamps)
    model_a, metadata = _get_align_model(result["language"])
    with _gpu_monitor() as gpu_usage:
        stage_start = time.time()
        result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
        elapsed = time.time() - stage_start
    _log_stage_stats("Time to align (generate word-level timestamps using wav2vec2)", elapsed, gpu_usage)
    logger.info("--------------------------------")

    # 3. Assign speaker labels
    with _gpu_monitor() as gpu_usage:
        stage_start = time.time()
        diarize_segments = diarize_model(audio)
        result = whisperx.assign_word_speakers(diarize_segments, result)
        elapsed = time.time() - stage_start
    _log_stage_stats("Time to diarize", elapsed, gpu_usage)
    return result


### LOADING THE MODELS
logger.info("================================START OF EVALUATION================================")
start = time.time()
model = whisperx.load_model(model_version,
                            device=device,
                            device_index=device_index,
                            compute_type=compute_type)
logger.info(f"Time to load the model: {time.time() - start} s")
# The diarization pipeline does not depend on the file: load it once instead
# of once per file (which also kept its load time out of "Time to diarize").
load_start = time.time()
diarize_model = whisperx.DiarizationPipeline(use_auth_token=HF_TOKEN, device=device)
logger.info(f"Time to load the diarization model: {time.time() - load_start} s")
# Alignment models are language-specific; each is loaded once and reused.
align_models = {}
logger.info("================================")
logger.info(
    "Measuring maximum GPU memory usage on GPU device."
    " Make sure to not have additional processes running on the same GPU."
)
# Initialization for measuring GPU usage
nvml.nvmlInit()
handle = nvml.nvmlDeviceGetHandleByIndex(device_index)
gpu_name = nvml.nvmlDeviceGetName(handle)
logger.info(f"GPU {device_index}: {gpu_name}")
gpu_memory_limit = nvml.nvmlDeviceGetMemoryInfo(handle).total >> 20  # MiB
gpu_power_limit = nvml.nvmlDeviceGetPowerManagementLimit(handle) / 1000.0  # NVML reports mW

# os.listdir returns bare names in arbitrary order: skip hidden entries and
# sub-folders, and sort for a reproducible processing order.
audio_files = sorted(
    name
    for name in os.listdir(audios_path)
    if not name.startswith(".") and os.path.isfile(os.path.join(audios_path, name))
)
failed_files = []
for file_name in audio_files:
    logger.info(f"Processing {file_name}")
    try:
        # Join with the folder: a bare name would be looked up in the
        # notebook's working directory instead of audios_path.
        result = _process_file(os.path.join(audios_path, file_name))
    except Exception:
        # One broken file should not abort the whole benchmark; record it.
        logger.exception(f"Failed to process {file_name}; skipping it")
        failed_files.append(file_name)
        continue
    # Save the transcription as <audio name>.json in out_path
    json_path = os.path.join(out_path, os.path.splitext(file_name)[0] + ".json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(_format_result(result), f, indent=2, ensure_ascii=False)
    logger.info("--------------------------------")

if failed_files:
    logger.warning(f"{len(failed_files)} file(s) failed: {', '.join(failed_files)}")
logger.info("================================")
logger.info('Total time spent evaluating (loading-diarization): ' + str(datetime.timedelta(seconds=time.time() - start)))
logger.info("================================END OF EVALUATION================================")

logging.shutdown()
nvml.nvmlShutdown()
# Drop the model references first: empty_cache() can only release memory
# nothing refers to anymore.
del model, diarize_model, align_models
gc.collect()
torch.cuda.empty_cache()