**You may need to restart the kernel after running the setup cells below so that the newly installed packages are picked up.**

## Setup

In [None]:
!git clone https://github.com/felafax/whisper-jax.git

In [None]:
!cd ./whisper-jax; pip install -e .

In [None]:
!pip install -q datasets soundfile librosa numpy huggingface scipy

## Helper Functions

In [None]:
import time

class ASRBenchmark:
    """Context manager that times a block of ASR work and reports throughput.

    Usage::

        with ASRBenchmark(audio):
            text = pipeline(audio)

    On a clean exit it prints the audio duration, the wall-clock execution
    time, the real-time factor (RTF = processing time / audio duration; lower
    is better) and the speed-up relative to real time (the inverse of RTF;
    higher is better). If the body raises, nothing is printed and the
    exception propagates unchanged.

    Args:
        audio: mapping with an ``"array"`` of samples and a ``"sampling_rate"``
            (samples per second), as returned by a ``datasets`` audio column.

    Attributes set after the ``with`` block: ``start``, ``end``, ``elapsed``
    (seconds), ``rtf`` and ``speedup``.
    """

    def __init__(self, audio):
        self.audio = audio
        # Duration in seconds = number of samples / samples per second.
        self.duration = len(audio['array']) / audio['sampling_rate']

    def __enter__(self):
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can jump if the system clock is adjusted mid-benchmark.
        self.start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end = time.perf_counter()
        self.elapsed = self.end - self.start
        # Guard both divisions: empty audio or a zero elapsed time would
        # otherwise raise ZeroDivisionError from inside __exit__.
        self.rtf = self.elapsed / self.duration if self.duration else float("nan")
        self.speedup = self.duration / self.elapsed if self.elapsed else float("inf")

        if exc_type is not None:
            # Timings of a failed run are meaningless; let the error propagate.
            return False

        print(f"  🎧  Audio Duration         │  {(self.duration/60):.2f} minutes")
        print(f"  ⏳ Model Execution Time    │  {self.elapsed:.3f} seconds")
        print(f"  🚄 Real-Time Factor (RTF)  │  {self.rtf:.3f}")
        print(f"  🚀 Speed-up vs. real time  │  {self.speedup:.3f}x")
        return False

## Loading the Pipeline

In [None]:
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp
import jax

In [None]:
pipeline = FlaxWhisperPipline(
    "openai/whisper-large-v2",  # checkpoint to load
    dtype=jnp.bfloat16,         # reduced-precision compute dtype
    batch_size=32,              # presumably the number of audio chunks per batch — confirm in whisper_jax
)

Next, we initialise a persistent on-disk compilation cache. It speeds up compilation when we restart the kernel and need to compile the model again:

In [None]:
# Persist compiled XLA programs on disk so a later session (e.g. after a kernel
# restart) can reuse them instead of recompiling the model.
# The directory is intentionally NOT wiped first: deleting it on every run would
# throw away the very cache this cell is meant to reuse.
# NOTE(review): JAX may only read this setting at its first compilation; if
# anything was compiled before this cell, consider moving it above the pipeline.
JAX_CACHE_DIR = "./jax_cache"

try:
    # Current JAX API for the persistent compilation cache.
    jax.config.update("jax_compilation_cache_dir", JAX_CACHE_DIR)
except AttributeError:
    # Older JAX releases without this config option only expose the
    # (since deprecated) experimental helper.
    from jax.experimental.compilation_cache import compilation_cache as cc
    cc.initialize_cache(JAX_CACHE_DIR)

## 🎶 Loading audio

In [None]:
from datasets import load_dataset
# Small Hub dataset of long-form audio files used to benchmark whisper-jax.
test_dataset = load_dataset(
    "sanchit-gandhi/whisper-jax-test-files",
    split="train",
)

In [None]:
# The second file in the split is the ~30 minute recording used below.
sample = test_dataset[1]
audio = sample["audio"]

We can listen to the audio file that we've loaded; it's approximately 30 mins long. For a shorter benchmark, we also keep its first third (~10 mins):

In [None]:
# Full-length clip, plus an excerpt made of its first third (~10 min of a ~30 min file).
# The excerpt is a new dict, so the original `audio` mapping is left untouched.
audio_30m = audio
third = audio_30m["array"].shape[0] // 3
audio_10m = {**audio_30m, "array": audio_30m["array"][:third]}
del third

In [None]:
from IPython.display import Audio

# Render an inline player for the loaded clip at its native sampling rate.
Audio(data=audio["array"], rate=audio["sampling_rate"])

## Run the model: 10-minute benchmark

In [None]:
# Do a JIT compilation warmup
%time text = pipeline(audio_10m)

In [None]:
# Timed run after the warm-up call, so compilation should not be included.
benchmark_10m = ASRBenchmark(audio_10m)
with benchmark_10m:
    text = pipeline(audio_10m)

## Run the model: 30-minute benchmark

In [None]:
# `audio_30m` already holds the second test sample (the ~30 min recording), so
# there is no need to fetch and decode it from the dataset a second time.
audio_length_in_mins = len(audio_30m["array"]) / audio_30m["sampling_rate"] / 60
print(f"Audio is {audio_length_in_mins:.2f} mins.")

In [None]:
# First timed call on the full-length clip. NOTE(review): this may include extra
# compilation if the longer input yields new shapes; treat the benchmarked run in
# the next cell as the representative number.
%time text = pipeline(audio_30m)

In [None]:
# Timed run on the full ~30 minute clip.
benchmark_30m = ASRBenchmark(audio_30m)
with benchmark_30m:
    text = pipeline(audio_30m)