In [None]:
%%sh
# Print the current working directory; the notebook later uses a relative data path ("../data"), so it matters where this runs from.
pwd

In [None]:
# Load IPython's autoreload extension; the reload mode itself is chosen by the `%autoreload` magic in the imports cell.
%load_ext autoreload

In [None]:
import jax

# Display (number of devices JAX reports, `device_kind` of the first device) as the cell output.
jax.device_count(), jax.devices()[0].device_kind

In [None]:
from whisper_stream.core.helpers.data_loading import load_data_samples_from_path
from whisper_stream.pipelines.jax_pipelines import (
    JAXStreamingPipeline,
)
from whisper_stream.pipelines.jax_pipelines.constants import (
    JAXValidDtypesMapping,
    JAXScalarDType,
)

from whisper_stream.core.constants import WhisperValidCheckpoints, WhisperValidTasks

from whisper_stream.core.logger import LOG_LEVEL_NAMES
from pathlib import Path
import time

# Autoreload mode 2: IPython reloads all modules (except those excluded via %aimport)
# before executing code, so edits to the imported packages are picked up without a kernel restart.
%autoreload 2

In [None]:
# --- Model / pipeline settings ------------------------------------------------
checkpoint: WhisperValidCheckpoints = "openai/whisper-tiny"
model_dtype: JAXScalarDType = JAXValidDtypesMapping["BFLOAT16"]
log_level: LOG_LEVEL_NAMES = "INFO"
batch_size: int = 1

# --- Per-call decoding settings -----------------------------------------------
task: WhisperValidTasks = "transcribe"
language: str = "english"
return_timestamps: bool = True

# Audio fixtures live one level above the notebook's working directory.
data_directory = Path("../data")

# Keyword arguments shared by the initialization call and every pipeline call below.
run_opts = dict(
    batch_size=batch_size,
    return_timestamps=return_timestamps,
    language=language,
    task=task,
)

# Build the streaming pipeline once; all later cells reuse this instance.
pipeline = JAXStreamingPipeline(
    checkpoint=checkpoint,
    dtype=model_dtype,
    batch_size=batch_size,
    min_log_level=log_level,
)

In [None]:
# Load the two audio fixtures from `data_directory` in binary mode.

# Short clip (about 4 s, per the original note).
pipeline_data: bytes = load_data_samples_from_path(
    "audio_2.mp3",
    binary_mode=True,
    directory=data_directory,
)

# Long recording (noted as 4:44 in the original).
pipeline_data_large: bytes = load_data_samples_from_path(
    "tryst.mp3",
    binary_mode=True,
    directory=data_directory,
)

In [None]:
# Initialize & warm up: call `initialize_pipeline` with the shared run options and report how long it takes.
# NOTE(review): `use_experimental_cache=True` is passed through as-is; its exact effect is defined by the library — confirm there.
%time pipeline.initialize_pipeline(**run_opts, use_experimental_cache=True)

In [None]:
# Sanity check on the short clip: the pipeline should already be warmed up by the
# initialization cell, so this timing is expected to be close to the "small data" cell that follows.
%time list(pipeline(pipeline_data, **run_opts))

In [None]:
# Small data: time a single short clip. `list(...)` drains the iterable returned by
# the pipeline, so the measurement covers every yielded result.
%time list(pipeline(pipeline_data, **run_opts))

In [None]:
# Small data in batch: the same short clip submitted ten times in one call
# (the list holds ten references to the same bytes object).
%time list(pipeline([pipeline_data] * 10, **run_opts))

In [None]:
# Chunkable data: time the long recording on its own.
# NOTE(review): "chunkable" presumably means the pipeline splits long audio into chunks — confirm against the library.
%time list(pipeline(pipeline_data_large, **run_opts))

In [None]:
# Chunkable data in batches: the long recording submitted 32 times in one call
# (32 references to the same bytes object), the heaviest workload in this notebook.
%time list(pipeline([pipeline_data_large] * 32, **run_opts))

In [None]:
# Mixed workload: one long recording followed by three short clips, with that
# pattern repeated four times (16 inputs in total).
mixed_mode_data: list[bytes] = [pipeline_data_large, *([pipeline_data] * 3)] * 4

In [None]:
# Mixed data, consumed as results arrive. With the `smallest` strategy (described
# in the original note as the default, passed explicitly here) the smaller files
# are expected to come back first, in larger batches.
#
# Each iteration reports how long the pipeline took to produce that batch:
# the timer is read as soon as the batch arrives and restarted after printing,
# so printing time is not attributed to the next batch.
start: float = time.perf_counter()  # monotonic clock: safe for measuring intervals
for data in pipeline(mixed_mode_data, strategy="smallest", **run_opts):
    elapsed = time.perf_counter() - start
    print({"num_items": len(data)})
    # `.2f` = fixed-point with two decimals. A bare `.2` means two significant
    # digits in general format and renders e.g. 12.3 s as "1.2e+01s".
    print({"data": data, "time_taken": f"{elapsed:.2f}s"})
    print("-" * 40)
    start = time.perf_counter()