### Imports

In [None]:
from whisper_jax import FlaxWhisperPipline
import json
import os
import time
import datetime
import gc
import py3nvml.py3nvml as nvml
import logging
from threading import Thread

### Variables (adjustable)

In [None]:
# Stop JAX from preallocating GPU memory up front, so the NVML memory readings
# taken by the evaluation cells reflect memory actually used (per the JAX GPU
# memory-allocation docs; presumably must be set before JAX initialises its
# GPU backend, i.e. before the model is loaded — TODO confirm).
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
# Batch size passed to FlaxWhisperPipline; also part of the log-file name.
batch_size = 6
# Checkpoint suffix: the model loaded is "openai/whisper-" + model_version.
model_version = "large-v2"
# Directory whose files the dataset-evaluation cell transcribes
# (keep the trailing '/').
audios_path = "/opt/app-root/src/nbest/bn_nl_segments/"
# Intended output directory for results.
out_path = "/opt/app-root/src/results/jax/large-v2/"

### Thread helper class (used to sample GPU usage in the background)

In [None]:
class MyThread(Thread):
    """Thread that runs ``func(*params)`` and keeps its return value.

    ``threading.Thread`` throws away the target's return value; this subclass
    stores it on ``self.result`` so it can be read after ``join()``.
    """

    def __init__(self, func, params):
        super().__init__()
        self.func, self.params = func, params
        # Remains None until run() completes (and stays None if func raises).
        self.result = None

    def run(self):
        target, args = self.func, self.params
        self.result = target(*args)

    def get_result(self):
        """Return what ``func`` returned, or None if it has not completed."""
        return self.result

### Running it all (modify where needed, mostly when changing implementation)

In [None]:
# # Uncomment the 2 lines below if you want to download a static FFmpeg build
# !curl https://johnvansickle.com/ffmpeg/releases/ffmpeg-release-amd64-static.tar.xz -o ffmpeg.tar.xz \
#  && tar -xf ffmpeg.tar.xz && rm ffmpeg.tar.xz

# Add the build to PATH
ffmdir = !find . -iname ffmpeg-*-static
path = %env PATH
path = path + ':' + ffmdir[0]
%env PATH $path
print('')
!which ffmpeg
print('Done!')

In [None]:
### SETTING UP THE LOGGER
from threading import Event

logging.basicConfig(filename="jax_" + model_version + "_batchsize_" + str(batch_size) + ".txt",
                    format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s",
                    level=logging.INFO,
                    force=True)
logger = logging.getLogger(__name__)
# basicConfig(force=True) only replaces the *root* logger's handlers. Drop the
# handlers a previous cell run attached to this logger; otherwise every re-run
# adds another StreamHandler and each console line is printed repeatedly.
for old_handler in list(logger.handlers):
    logger.removeHandler(old_handler)
consoleHandler = logging.StreamHandler()
logger.addHandler(consoleHandler)

# Audio file transcribed by this single-file run.
audio_file = "/opt/app-root/src/nbest/bn-nl/nbest-eval-2008-bn-nl-012.wav"


def _sample_gpu_usage(stop_event, handles, usages, interval=0.5):
    """Poll NVML for each GPU's used memory (MiB) and power draw (W).

    Samples for ``handles[i]`` are appended to ``usages[i]``. One sample per GPU
    is always taken before ``stop_event`` is checked; afterwards the loop waits
    ``interval`` seconds between samples and exits as soon as the event is set.

    Returns:
        A list of the ``usages`` dicts (filled in place).
    """
    while True:
        for handle, usage in zip(handles, usages):
            usage["gpu_memory_usage"].append(nvml.nvmlDeviceGetMemoryInfo(handle).used >> 20)
            usage["gpu_power_usage"].append(nvml.nvmlDeviceGetPowerUsage(handle) / 1000)
        # Event.wait paces sampling like time.sleep, but returns True right
        # away once the event is set, so the caller's join() is prompt.
        if stop_event.wait(interval):
            return list(usages)


def _log_gpu_usage(label, usage, memory_limit, power_limit):
    """Log one GPU's peak memory/power samples against its limits.

    Returns:
        ``(max_memory_usage, max_power_usage)``.
    """
    max_memory = max(usage["gpu_memory_usage"])
    max_power = max(usage["gpu_power_usage"])
    logger.info("%s max memory usage: %dMiB / %dMiB (%.2f%%)",
                label, max_memory, memory_limit, (max_memory / memory_limit) * 100)
    logger.info("%s max power usage: %dW / %dW (%.2f%%)",
                label, max_power, power_limit, (max_power / power_limit) * 100)
    return max_memory, max_power


### LOADING THE MODEL
logger.info("================================")
start = time.time()
pipeline = FlaxWhisperPipline("openai/whisper-" + model_version, batch_size=batch_size)
logger.info(f"Time to load the model: {time.time() - start} s")

logger.info(
    "Measuring maximum GPU memory/power usage."
    " Make sure to not have additional processes running on the GPUs."
)
# Initialization for measuring GPU usage
nvml.nvmlInit()
handle1 = nvml.nvmlDeviceGetHandleByIndex(0)
gpu1_name = nvml.nvmlDeviceGetName(handle1)
gpu1_memory_limit = nvml.nvmlDeviceGetMemoryInfo(handle1).total >> 20
gpu1_power_limit = nvml.nvmlDeviceGetPowerManagementLimit(handle1) / 1000.0

handle2 = nvml.nvmlDeviceGetHandleByIndex(1)
gpu2_name = nvml.nvmlDeviceGetName(handle2)
gpu2_memory_limit = nvml.nvmlDeviceGetMemoryInfo(handle2).total >> 20
gpu2_power_limit = nvml.nvmlDeviceGetPowerManagementLimit(handle2) / 1000.0

# Start measuring GPU usage for this file
gpu1_usage = {"gpu_memory_usage": [], "gpu_power_usage": []}
gpu2_usage = {"gpu_memory_usage": [], "gpu_power_usage": []}
stop_event = Event()
thread = MyThread(_sample_gpu_usage,
                  params=(stop_event, (handle1, handle2), (gpu1_usage, gpu2_usage)))
thread.start()
# Measuring time spent transcribing this file
file_start = time.time()
try:
    # Transcribing the file
    output = pipeline(audio_file, task="transcribe", return_timestamps=True)
finally:
    # Always stop the sampler, even if transcription fails; otherwise the
    # thread keeps polling the GPUs (and growing its lists) until the kernel dies.
    stop_event.set()
    thread.join()
# Measured here so the time excludes the logging below.
transcribe_seconds = time.time() - file_start

logger.info(output)

# Output GPU max memory & power usage
max_memory_usage1, max_power_usage1 = _log_gpu_usage(
    "GPU 1", gpu1_usage, gpu1_memory_limit, gpu1_power_limit)
logger.info("-----------------")
max_memory_usage2, max_power_usage2 = _log_gpu_usage(
    "GPU 2", gpu2_usage, gpu2_memory_limit, gpu2_power_limit)

logger.info("================================")
logger.info('Time spent transcribing: ' + str(datetime.timedelta(seconds=transcribe_seconds)))

end = time.time()
time_s = end - start
logger.info("================================")
logger.info('Total time spent evaluating: ' + str(datetime.timedelta(seconds=time_s)))
logger.info("================================")

# Cleanup for the next evaluation (most reliable is restarting the kernel)
logging.shutdown()
nvml.nvmlShutdown()
del pipeline
gc.collect()

### Dataset (multiple files) evaluation (WIP)

In [None]:
### SETTING UP THE LOGGER
from threading import Event

logging.basicConfig(filename="jax_" + model_version + "_batchsize_" + str(batch_size) + ".txt",
                    format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s",
                    level=logging.INFO,
                    force=True)
logger = logging.getLogger(__name__)
# basicConfig(force=True) only replaces the *root* logger's handlers. Drop the
# handlers a previous cell run attached to this logger; otherwise every re-run
# adds another StreamHandler and each console line is printed repeatedly.
for old_handler in list(logger.handlers):
    logger.removeHandler(old_handler)
consoleHandler = logging.StreamHandler()
logger.addHandler(consoleHandler)


def _sample_gpu_usage(stop_event, handles, usages, interval=0.5):
    """Poll NVML for each GPU's used memory (MiB) and power draw (W).

    Samples for ``handles[i]`` are appended to ``usages[i]``. One sample per GPU
    is always taken before ``stop_event`` is checked; afterwards the loop waits
    ``interval`` seconds between samples and exits as soon as the event is set.

    Returns:
        A list of the ``usages`` dicts (filled in place).
    """
    while True:
        for handle, usage in zip(handles, usages):
            usage["gpu_memory_usage"].append(nvml.nvmlDeviceGetMemoryInfo(handle).used >> 20)
            usage["gpu_power_usage"].append(nvml.nvmlDeviceGetPowerUsage(handle) / 1000)
        # Event.wait paces sampling like time.sleep, but returns True right
        # away once the event is set, so the caller's join() is prompt.
        if stop_event.wait(interval):
            return list(usages)


def _log_gpu_usage(label, usage, memory_limit, power_limit):
    """Log one GPU's peak memory/power samples against its limits.

    Returns:
        ``(max_memory_usage, max_power_usage)``.
    """
    max_memory = max(usage["gpu_memory_usage"])
    max_power = max(usage["gpu_power_usage"])
    logger.info("%s max memory usage: %dMiB / %dMiB (%.2f%%)",
                label, max_memory, memory_limit, (max_memory / memory_limit) * 100)
    logger.info("%s max power usage: %dW / %dW (%.2f%%)",
                label, max_power, power_limit, (max_power / power_limit) * 100)
    return max_memory, max_power


### LOADING THE MODEL
logger.info("================================")
start = time.time()
pipeline = FlaxWhisperPipline("openai/whisper-" + model_version, batch_size=batch_size)
logger.info(f"Time to load the model: {time.time() - start} s")

logger.info(
    "Measuring maximum GPU memory/power usage."
    " Make sure to not have additional processes running on the GPUs."
)
# Initialization for measuring GPU usage
nvml.nvmlInit()
handle1 = nvml.nvmlDeviceGetHandleByIndex(0)
gpu1_name = nvml.nvmlDeviceGetName(handle1)
gpu1_memory_limit = nvml.nvmlDeviceGetMemoryInfo(handle1).total >> 20
gpu1_power_limit = nvml.nvmlDeviceGetPowerManagementLimit(handle1) / 1000.0

handle2 = nvml.nvmlDeviceGetHandleByIndex(1)
gpu2_name = nvml.nvmlDeviceGetName(handle2)
gpu2_memory_limit = nvml.nvmlDeviceGetMemoryInfo(handle2).total >> 20
gpu2_power_limit = nvml.nvmlDeviceGetPowerManagementLimit(handle2) / 1000.0

# Transcripts are written here as one JSON file per audio file.
os.makedirs(out_path, exist_ok=True)

# Go through files to transcribe. Sorted so runs are reproducible and log
# order is stable; os.listdir returns entries in arbitrary order.
for file_name in sorted(os.listdir(audios_path)):
    # os.path.join works whether or not audios_path ends with a separator.
    audio_file = os.path.join(audios_path, file_name)
    # Skip subdirectories (e.g. .ipynb_checkpoints) and other non-files.
    if not os.path.isfile(audio_file):
        continue
    logger.info("File: %s", file_name)

    # Start measuring GPU usage for this file
    gpu1_usage = {"gpu_memory_usage": [], "gpu_power_usage": []}
    gpu2_usage = {"gpu_memory_usage": [], "gpu_power_usage": []}
    stop_event = Event()
    thread = MyThread(_sample_gpu_usage,
                      params=(stop_event, (handle1, handle2), (gpu1_usage, gpu2_usage)))
    thread.start()
    # Measuring time spent transcribing this file
    file_start = time.time()
    try:
        # Transcribing the file
        outputs = pipeline(audio_file, task="transcribe", return_timestamps=True)
    finally:
        # Always stop the sampler, even if transcription fails; otherwise the
        # thread keeps polling the GPUs (and growing its lists) until the kernel dies.
        stop_event.set()
        thread.join()
    # Measured here so the time excludes saving and logging below.
    transcribe_seconds = time.time() - file_start

    # Save the transcript instead of discarding it.
    # default=float: NOTE(review) presumably covers numpy/JAX scalar timestamps
    # in the output; confirm against the pipeline's actual return types.
    out_file = os.path.join(out_path, os.path.splitext(file_name)[0] + ".json")
    with open(out_file, "w", encoding="utf-8") as f:
        json.dump(outputs, f, ensure_ascii=False, indent=2, default=float)

    # Output GPU max memory & power usage
    max_memory_usage1, max_power_usage1 = _log_gpu_usage(
        "GPU 1", gpu1_usage, gpu1_memory_limit, gpu1_power_limit)
    logger.info("-----------------")
    max_memory_usage2, max_power_usage2 = _log_gpu_usage(
        "GPU 2", gpu2_usage, gpu2_memory_limit, gpu2_power_limit)

    logger.info("================================")
    logger.info('Time spent transcribing: ' + str(datetime.timedelta(seconds=transcribe_seconds)))

end = time.time()
time_s = end - start
logger.info("================================")
logger.info('Total time spent evaluating: ' + str(datetime.timedelta(seconds=time_s)))
logger.info("================================")

# Cleanup for the next evaluation (most reliable is restarting the kernel)
logging.shutdown()
nvml.nvmlShutdown()
del pipeline
gc.collect()