## Whisper JAX ⚡️

This Kaggle notebook demonstrates how to run Whisper JAX on a TPU v3-8. Whisper JAX is a highly optimised JAX implementation of the Whisper model by OpenAI, largely built on the 🤗 Hugging Face Transformers Whisper implementation. Compared to OpenAI's PyTorch code, Whisper JAX runs over **70x faster**, making it the fastest Whisper implementation available.

# Straightforward Code
---

## Install dependencies

In [None]:
!apt-get update -y
!apt install -y ffmpeg
!pip install --quiet --upgrade pip
!pip install --quiet git+https://github.com/sanchit-gandhi/whisper-jax.git 
# Optional extras (uncomment if needed)
#!pip install datasets soundfile librosa
!pip install pytubefix

## Import libraries

In [None]:
# Suppress all warnings (this hides every category, not only FutureWarnings)
import warnings
warnings.filterwarnings("ignore")

# Load the pipeline class (whisper_jax spells the class name "FlaxWhisperPipline")
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp

# Initialize Compilation Cache
from jax.experimental.compilation_cache import compilation_cache as cc
cc.initialize_cache("./jax_cache")

# YouTube downloader and standard-library helpers
from pytubefix import YouTube 
import os
import json
from datetime import timedelta

# Helper: format a timedelta as an SRT timecode
def format_timedelta(td):
    """Format timedelta to SRT timecode format (HH:MM:SS,MMM)"""
    total_seconds = int(td.total_seconds())
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    milliseconds = td.microseconds // 1000
    return f"{hours:02}:{minutes:02}:{seconds:02},{milliseconds:03}"
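
As a quick sanity check of the SRT timecode format, the sketch below restates `format_timedelta` so that it runs on its own and applies it to a few durations. The sample durations are made up for illustration.

```python
from datetime import timedelta


def format_timedelta(td):
    """Format a timedelta as an SRT timecode (HH:MM:SS,mmm)."""
    total_seconds = int(td.total_seconds())
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    milliseconds = td.microseconds // 1000
    return f"{hours:02}:{minutes:02}:{seconds:02},{milliseconds:03}"


# Illustrative durations, not taken from a real transcription
samples = [
    timedelta(seconds=0),
    timedelta(seconds=3.5),
    timedelta(seconds=3725.25),
]
timecodes = [format_timedelta(td) for td in samples]
for code in timecodes:
    print(code)
```

Note that SRT uses a comma, not a period, between seconds and milliseconds.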

## Run core script

In [106]:
timestamps = True

model = "openai/whisper-large-v2" if timestamps else "openai/whisper-large-v3"

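# Create the pipeline only once; on a fresh kernel `pipeline` is not defined yet,
# so look it up in globals() instead of referencing the name directly
if globals().get("pipeline") is None:
    pipeline = FlaxWhisperPipline(model, dtype=jnp.bfloat16, batch_size=16)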


link = input("Paste your YouTube link: ")
try:
    # Create the YouTube object for the given link
    yt = YouTube(link)
    print("Copied link. Success")
except Exception as err:
    # Stop here: the rest of the cell needs a valid `yt` object
    print("Connection Error:", err)
    raise

#streams_mp4 = yt.streams.filter(file_extension='mp4').first()

print("\nDownloading YouTube audio....\n")

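# Select the stream with itag 139 and download it
stream = yt.streams.get_by_itag(139)
if stream is None:
    raise RuntimeError("No stream with itag 139 is available for this video")
stream.download(output_path='.', filename='audio.mp3')
# Absolute path of the downloaded file (on Kaggle this is /kaggle/working/audio.mp3)
YouTubeAudio = os.path.abspath('audio.mp3')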

# 'train_audio.mp3' doubles as a marker that the warm-up run has already been done.
# Nothing is trained here: the short clip only triggers JIT compilation.
if not os.path.exists('train_audio.mp3'):
    print("Training not done yet, proceeding... (Should take around 1-2 minutes)")
    
    # Download a short audio clip to warm up the pipeline
    !wget "https://nt.doveai.cloud/s/3ZAEHJcxTcWb5FS/download/train_audio.mp3" -O "train_audio.mp3" > /dev/null 2>&1
    
    # JIT compile the forward call - slow, but we only do it once
    test_audio = "train_audio.mp3"
    %time text = pipeline(test_audio)
    
    if timestamps:
        # Compile the forward call with timestamps - slow, but we only do it once
        %time outputs = pipeline(test_audio, return_timestamps=True)
        time_stamped_text = outputs["text"]  # transcription
        time_stamped_chunks = outputs["chunks"]  # transcription + timestamps
    
else:
    print("Cache already loaded and Pipeline is ready, skipping training stage...")
    
# Use the compiled function from here on - much faster
print("\nTranscribing audio... \n")
%time text = pipeline(YouTubeAudio)

# Get the video title
video_title = yt.title

# Print the video title
print("\nVideo Title:", video_title,"\n")

# Print the raw pipeline output to check the transcription
print("Transcribed:", text)

# Extract the transcription string from the result dictionary
text_content = text.get("text", "")

# Save the transcription to a file
with open("transcription.txt", "w", encoding="utf-8") as file:
    file.write(text_content)
    print("Saved file as: transcription.txt")

# If timestamps are enabled, also produce a timestamped transcription (SRT)
if timestamps:
    # use cached timestamps function - super fast!
    %time outputs = pipeline(YouTubeAudio, return_timestamps=True)
    time_stamped_text = outputs["text"] 
    time_stamped_chunks = outputs["chunks"]
    
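    # Prepare the SRT content
    srt_content = []
    for i, chunk in enumerate(outputs["chunks"], start=1):
        if isinstance(chunk, dict):
            start_time, end_time = chunk.get("timestamp", (0, 0))
            # Use a separate name so the full transcription in `text` is not overwritten
            chunk_text = chunk.get("text", "").strip()

            # Guard against a missing (None) timestamp, which timedelta cannot handle
            start_time = 0 if start_time is None else start_time
            end_time = start_time if end_time is None else end_time

            # Convert to timedelta objects for formatting
            start_td = timedelta(seconds=start_time)
            end_td = timedelta(seconds=end_time)

            # Format timestamps
            start_str = format_timedelta(start_td)
            end_str = format_timedelta(end_td)

            # Format the SRT entry; the trailing blank line separates entries
            srt_entry = f"{i}\n{start_str} --> {end_str}\n{chunk_text}\n\n"
            srt_content.append(srt_entry)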

    # Write the SRT file
    with open("timestamped_transcription.srt", "w", encoding="utf-8") as file:
        file.writelines(srt_content)
    
    #print("Time stamped chunks: ", time_stamped_chunks)
    print("Saved file as: timestamped_transcription.srt")

Paste your YouTube link:  https://www.youtube.com/watch?v=joRPB5zwcmQ


Copied link. Success

Downloading YouTube audio....

Cache already loaded and Pipeline is ready, skipping training stage...

Transcribing audio... 

CPU times: user 45.4 s, sys: 1min 18s, total: 2min 4s
Wall time: 4.71 s
Transcribed: {'text': ' Hola amigos, bendiciones. Queremos compartirte este video. Es un sueño acerca del rapto de la iglesia. Hoy quiero hablarles sobre el sueño que Dios me entregó hace dos meses atrás sobre el rapto. Me encontraba en un puerto de barcos y allí había un barco bien grande y dentro de esa embarcación estaba Jesús y él salía hacia afuera de la embarcación y detrás habían dos ángeles dentro del barco que sostenían la puerta esperando la orden de Dios para cerrar la puerta, sin embargo gente estaban entrando y el Señor me vio a mí que estaba evangelizando a alguien para ganarlo para Cristo. Y mientras estaba hablando sobre Cristo, este hombre, el Señor me estaba haciendo la señal de que tenía que apresurarme para entrar y porque ya el tiempo se había acab
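
The result printed above is a dictionary with a `text` key; with `return_timestamps=True` the script also reads a `chunks` list whose items carry a `timestamp` pair and a `text` string. The sketch below turns such a list into SRT text. It is a minimal illustration under stated assumptions: `srt_timecode` and `chunks_to_srt` are hypothetical helpers, the timestamps in `sample_chunks` are invented, and falling back to the start time when an end timestamp is missing is a choice made here, not documented Whisper JAX behaviour.

```python
from datetime import timedelta


def srt_timecode(seconds):
    """Convert a number of seconds to an SRT timecode (HH:MM:SS,mmm)."""
    td = timedelta(seconds=seconds)
    total = int(td.total_seconds())
    hours, remainder = divmod(total, 3600)
    minutes, secs = divmod(remainder, 60)
    millis = td.microseconds // 1000
    return f"{hours:02}:{minutes:02}:{secs:02},{millis:03}"


def chunks_to_srt(chunks):
    """Build SRT text from dicts with 'timestamp' (start, end) and 'text'."""
    entries = []
    for index, chunk in enumerate(chunks, start=1):
        start, end = chunk.get("timestamp", (0.0, 0.0))
        # Assumption: a missing timestamp is replaced by a usable fallback
        start = 0.0 if start is None else start
        end = start if end is None else end
        body = chunk.get("text", "").strip()
        entries.append(
            f"{index}\n{srt_timecode(start)} --> {srt_timecode(end)}\n{body}\n"
        )
    # Each entry ends with a newline; joining with "\n" adds the blank separator line
    return "\n".join(entries)


# Invented timestamps; the text is borrowed from the transcription above
sample_chunks = [
    {"timestamp": (0.0, 2.5), "text": " Hola amigos, bendiciones."},
    {"timestamp": (2.5, None), "text": " Queremos compartirte este video."},
]
srt_text = chunks_to_srt(sample_chunks)
print(srt_text)
```

Keeping the chunk text in its own variable (`body`) avoids overwriting any outer variable that holds the full transcription.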

---

## Short section (for cached pipeline)

In [None]:
# Once the pipeline is compiled and cached, you can run this cell (no timestamps)

link = input("Paste your YouTube link: ")
try:
    # Create the YouTube object for the given link
    yt = YouTube(link)
    print("Copied link. Success")
except Exception as err:
    # Stop here: the rest of the cell needs a valid `yt` object
    print("Connection Error:", err)
    raise

print("\nDownloading YouTube audio....\n")

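# Select the stream with itag 139 and download it
stream = yt.streams.get_by_itag(139)
if stream is None:
    raise RuntimeError("No stream with itag 139 is available for this video")
stream.download(output_path='.', filename='audio.mp3')
# Absolute path of the downloaded file (on Kaggle this is /kaggle/working/audio.mp3)
YouTubeAudio = os.path.abspath('audio.mp3')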

# The forward call is already compiled at this point, so this runs fast
print("\nTranscribing audio... \n")
%time text = pipeline(YouTubeAudio)
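
The `%time` magic only works inside IPython or Jupyter. If this cell is moved into a plain Python script, a small timing wrapper can stand in for it. This is a minimal sketch: `timed_call` is a hypothetical helper, and `fake_pipeline` is a stand-in for the real `pipeline` object so that the example runs without a TPU or the Whisper weights.

```python
import time


def timed_call(fn, *args, **kwargs):
    """Call fn and return (result, elapsed wall-clock seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed


def fake_pipeline(audio_path, return_timestamps=False):
    """Stand-in for the real pipeline; returns a dict shaped like its output."""
    output = {"text": f"transcript of {audio_path}"}
    if return_timestamps:
        output["chunks"] = []
    return output


# With the real pipeline this would be: timed_call(pipeline, YouTubeAudio)
result, elapsed = timed_call(fake_pipeline, "audio.mp3")
print(f"Wall time: {elapsed:.3f} s")
print(result["text"])
```

Unlike `%time`, this reports wall-clock time only, not the CPU user and system times.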