In [1]:
import os

# A `%%sh` cell runs in a child shell, so an `unset` there never reaches this
# kernel's environment. Remove the variable from the kernel process itself,
# and do it in the first cell so it happens before `jax` is imported below.
os.environ.pop("XLA_PYTHON_CLIENT_PREALLOCATE", None)

# Show the working directory; the relative data path used later is resolved against it.
os.getcwd()

/home/ubuntu/whisper-dev/whisper-stream/notebooks


In [2]:
# Load the IPython extensions used by this notebook.
# `autoreload` is switched to mode 2 right after the imports cell.
%load_ext autoreload
# NOTE(review): no viztracer magic is invoked in the visible cells — confirm it is still needed.
%load_ext viztracer

In [3]:
import jax

# Sanity check: how many devices JAX can see, and the hardware kind of the first one.
visible_devices = jax.devices()
(jax.device_count(), visible_devices[0].device_kind)

(1, 'NVIDIA A10G')

In [4]:
from whisper_stream.core.helpers.data_loading import load_data_samples_from_path
from whisper_stream.pipelines.jax_pipelines import (
    JAXStreamingPipeline,
)
from whisper_stream.pipelines.jax_pipelines.constants import (
    JAXValidDtypesMapping,
    JAXScalarDType,
)

from whisper_stream.core.constants import WhisperValidCheckpoints, WhisperValidTasks

from whisper_stream.core.logger import LOG_LEVEL_NAMES
from pathlib import Path
import time

%autoreload 2

In [5]:
# Prepare: model-level settings used to construct the pipeline.
checkpoint: WhisperValidCheckpoints = "openai/whisper-large-v2"
model_dtype: JAXScalarDType = JAXValidDtypesMapping["BFLOAT16"]
batch_size: int = 1
log_level: LOG_LEVEL_NAMES = "INFO"

# Decoding settings forwarded with every call.
task: WhisperValidTasks = "transcribe"
language: str = "english"
return_timestamps: bool = True

# Sample audio directory, relative to the notebook's working directory.
data_directory = Path("../data")

# Keyword arguments shared by `initialize_pipeline` and each pipeline invocation below.
run_opts = dict(
    batch_size=batch_size,
    return_timestamps=return_timestamps,
    language=language,
    task=task,
)

# Construct the streaming pipeline (initialization/compilation happens in a later cell).
pipeline = JAXStreamingPipeline(
    checkpoint=checkpoint,
    dtype=model_dtype,
    batch_size=batch_size,
    min_log_level=log_level,
)

In [6]:
# Load data: read both sample recordings from `data_directory` as raw bytes.
pipeline_data: bytes  # "audio_2.mp3" — noted as 2s in the original cell
pipeline_data_large: bytes  # "tryst.mp3" — noted as 4:44 in the original cell
pipeline_data, pipeline_data_large = (
    load_data_samples_from_path(file_name, directory=data_directory, binary_mode=True)
    for file_name in ("audio_2.mp3", "tryst.mp3")
)
# Payload sizes in bytes, as a quick check that both files were read.
len(pipeline_data), len(pipeline_data_large)

(67823, 4644352)

In [7]:
# Initialize the pipeline with the shared run options and warm it up.
# The log output of this cell reports a compilation step, so `%time` here measures
# the one-off start-up cost rather than steady-state inference.
# NOTE(review): `use_experimental_cache=True` appears to enable JAX's persistent
# compilation cache (see the cache-write log lines) — confirm against the pipeline API.
%time pipeline.initialize_pipeline(**run_opts, use_experimental_cache=True)

event="Initializing openai/whisper-large-v2/<class 'jax.numpy.bfloat16'> pipeline" application='whisper_stream' version='0.0.4' python_version='3.11.5' platform_architecture=('64bit', 'ELF') level='info' timestamp='2023-09-03T18:59:32.224731Z'




event="Compiling openai/whisper-large-v2/<class 'jax.numpy.bfloat16'> pipeline" application='whisper_stream' version='0.0.4' python_version='3.11.5' platform_architecture=('64bit', 'ELF') level='info' timestamp='2023-09-03T18:59:32.226172Z'


loc("-":700:9): error: Dialect `cf' not found for custom op 'cf.switch' 
INFO:jax._src.dispatch:'pmap_generate' took at least 1.00 seconds to compile (34.83s), writing persistent cache entry
INFO:jax._src.compilation_cache:Writing pmap_generate to persistent compilation cache with key 8527d0dc345a8af6ae3ce4eb33979f8ccbe96d8b5f56ca93aa5832aba08d17d8.
INFO:jax._src.dispatch:Not writing persistent cache entry for 'jit_reshape' because it took < 1.00 seconds to compile (0.01s)


event='Compilation done in 46.57s' application='whisper_stream' version='0.0.4' python_version='3.11.5' platform_architecture=('64bit', 'ELF') level='info' timestamp='2023-09-03T19:00:18.798081Z'
CPU times: user 1min 6s, sys: 1.4 s, total: 1min 8s
Wall time: 46.6 s


In [8]:
# First call after initialization: the pipeline should already be warmed up, so the
# wall time here should be similar to the "small data" run in the next cell.
%time list(pipeline(pipeline_data, **run_opts))

event='ffmpeg conversion' time_taken='0.022s' application='whisper_stream' version='0.0.4' python_version='3.11.5' platform_architecture=('64bit', 'ELF') num_items=1 level='info' timestamp='2023-09-03T19:00:18.841356Z'
CPU times: user 443 ms, sys: 407 ms, total: 849 ms
Wall time: 379 ms


[[{'text': ' I know all the players of cricket.',
   'chunks': [{'timestamp': (0.0, 2.8),
     'text': ' I know all the players of cricket.'}]}]]

In [9]:
# Small data: transcribe the short clip (`pipeline_data`) as a single input and time it.
# `list(...)` drains the iterable returned by the pipeline so the full run is timed.
%time list(pipeline(pipeline_data, **run_opts))

event='ffmpeg conversion' time_taken='0.022s' application='whisper_stream' version='0.0.4' python_version='3.11.5' platform_architecture=('64bit', 'ELF') num_items=1 level='info' timestamp='2023-09-03T19:00:19.243766Z'


CPU times: user 477 ms, sys: 431 ms, total: 908 ms
Wall time: 378 ms


[[{'text': ' I know all the players of cricket.',
   'chunks': [{'timestamp': (0.0, 2.8),
     'text': ' I know all the players of cricket.'}]}]]

In [10]:
# Small data in batch: ten copies of the short clip passed as one list input.
# NOTE(review): `run_opts` carries batch_size=1 — confirm whether a larger batch size
# was intended for this scenario.
%time list(pipeline([pipeline_data] * 10, **run_opts))

event='ffmpeg conversion' time_taken='0.1s' application='whisper_stream' version='0.0.4' python_version='3.11.5' platform_architecture=('64bit', 'ELF') num_items=10 level='info' timestamp='2023-09-03T19:00:19.726702Z'


In [None]:
# Chunkable data: transcribe the long recording (`pipeline_data_large`) as a single input.
# NOTE(review): "chunkable" presumably means the pipeline splits long audio into chunks
# internally — confirm against the pipeline implementation.
%time list(pipeline(pipeline_data_large, **run_opts))

event='ffmpeg conversion' time_taken='0.63s' application='whisper_stream' version='0.0.4' python_version='3.11.5' platform_architecture=('64bit', 'ELF') num_items=1 level='info' timestamp='2023-09-03T18:58:46.261793Z'


CPU times: user 19.8 s, sys: 7.28 s, total: 27 s
Wall time: 18.7 s


[[{'text': ' Long years ago, we made a truce with destiny, and now the time comes when we shall redeem our pledge, not only or in full measure, but very substantially. At the stroke of the midnight hour, when the world sleeps, India will awake to life and freedom. At the stroke of the midnight hour, when the world sleeps, India will awake to life and freedom. A moment comes, which comes but rarely in history, when we step out from the old to the new, kept out from the old to the new, when an age end, and when the soul of a nation, long suppressed, finds utterance. It is fitting that at this solemn moment we take the pledge of dedication to the service of India and her people, and to the still larger cause of humanity. her people and to the still larger cause of humanity. At the dawn of history, India started on her unending quest, and trackless centuries are filled with her striving and the grandeur of her successes and her failures. Through good and ill fortune alike, she has never lo

In [None]:
# Chunkable data in batches: ten copies of the long recording passed as one list input.
# NOTE(review): the saved output below reports num_items=4, includes a debug `TASKS`
# dump, and is timestamped earlier than the initialization cell, so it looks stale —
# re-run this cell on a fresh kernel before relying on its results.
%time list(pipeline([pipeline_data_large] * 10, **run_opts))

event='ffmpeg conversion' time_taken='0.73s' application='whisper_stream' version='0.0.4' python_version='3.11.5' platform_architecture=('64bit', 'ELF') num_items=4 level='info' timestamp='2023-09-03T18:38:03.205106Z'
TASKS {'smallest': BatchPreProcessorTasksMapping(task_callable=<bound method JAXStreamingPipeline._preprocess_batches_for_unchunkable of <whisper_stream.pipelines.jax_pipelines.streaming.JAXStreamingPipeline object at 0x7f6a0a56cdd0>>, task_kwargs={'inputs': [], 'batch_size': 1, 'target_sampling_rate': 16000, 'do_normalize': True}), 'largest': BatchPreProcessorTasksMapping(task_callable=<bound method JAXStreamingPipeline._generate_batching_info_for_chunkable of <whisper_stream.pipelines.jax_pipelines.streaming.JAXStreamingPipeline object at 0x7f6a0a56cdd0>>, task_kwargs={'inputs': [array([-0.0012111065443605185, -0.0008541397983208299,
       -0.0015560932224616408, ..., 0.0, 0.0, 0.0], dtype=object), array([-0.0012111065443605185, -0.0008541397983208299,
       -0.001556

[[{'text': ' will.', 'chunks': [{'timestamp': (0.0, 1.0), 'text': ' will.'}]}],
 [{'text': ' will.', 'chunks': [{'timestamp': (0.0, 1.0), 'text': ' will.'}]}],
 [{'text': ' will.', 'chunks': [{'timestamp': (0.0, 1.0), 'text': ' will.'}]}],
 [{'text': ' will.', 'chunks': [{'timestamp': (0.0, 1.0), 'text': ' will.'}]}]]

In [None]:
# Mixed workload: one long recording followed by three short clips, repeated
# five times (20 inputs in total).
mixed_mode_data: list[bytes] = [pipeline_data_large, *([pipeline_data] * 3)] * 5

In [None]:
# Mixed data, consumed as results arrive. Per the original note, with the `smallest`
# strategy the smaller files are expected to come back first, in larger batches.
# Each iteration prints the batch size, the batch itself, and the time elapsed since
# the previous batch finished printing.
#
# `perf_counter` is used because it is monotonic and meant for measuring intervals.
# The elapsed time is formatted with `.2f`: the previous `.2` meant "two significant
# digits", which renders e.g. 18.7 as `1.9e+01`.
batch_started: float = time.perf_counter()
for batch in pipeline(mixed_mode_data, strategy="smallest", **run_opts):
    elapsed: float = time.perf_counter() - batch_started
    print({"num_items": len(batch)})
    print({"data": batch, "time_taken": f"{elapsed:.2f}s"})
    print("-" * 40)
    # Restart the clock only after printing so output time is not counted.
    batch_started = time.perf_counter()