[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/whisper-jax-colab/blob/main/whisper_jax_medium_gradio_colab.ipynb)

In [None]:
!pip install -q "jax[cuda]" -U
!pip install -q git+https://github.com/camenduru/whisper-jax.git datasets soundfile librosa yt_dlp gradio

import gradio as gr
from yt_dlp import YoutubeDL
import os
import jax
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp

# Build the Whisper pipeline once at module level so transcribe() can reuse it
# on every call. The "openai/whisper-medium" checkpoint is requested and
# jnp.float16 is passed as the pipeline's dtype argument
# (presumably half-precision compute -- confirm against whisper_jax).
pipeline = FlaxWhisperPipline("openai/whisper-medium", dtype=jnp.float16)
# Point JAX's experimental compilation cache at /content/jax_cache.
# NOTE(review): presumably this lets compiled programs be reused between runs;
# this is an experimental API, so verify it still exists in the JAX version
# that the unpinned `pip install -U` pulls in.
from jax.experimental.compilation_cache import compilation_cache as cc
cc.initialize_cache("/content/jax_cache")
    
def download_video(url):
    """Download the audio track of *url* to ``/content/audio.m4a``.

    Args:
        url: Address of the video or stream page to fetch.

    Returns:
        The path of the downloaded audio file (``/content/audio.m4a``).

    Raises:
        ValueError: If *url* is empty or contains only whitespace.
    """
    # Fail early with a readable message instead of a downloader traceback.
    if not url or not url.strip():
        raise ValueError("Please provide a video URL to download.")

    audio_path = "/content/audio.m4a"
    ydl_opts = {
        "overwrites": True,  # replace the file left by a previous download
        "format": "bestaudio[ext=m4a]",
        "outtmpl": audio_path,
    }
    with YoutubeDL(ydl_opts) as ydl:
        # download() takes a list of URLs; wrap the single URL explicitly
        # rather than relying on the library tolerating a bare string.
        ydl.download([url.strip()])
    return audio_path

def transcribe(audio_in):
    """Transcribe an audio file with the module-level Whisper ``pipeline``.

    Args:
        audio_in: Path of the audio file to transcribe. Any value that is not
            a non-empty string (for example ``None``, or whatever payload the
            Gradio audio component sends -- its exact form is not visible
            here) falls back to ``/content/audio.m4a``, the file written by
            ``download_video``.

    Returns:
        A ``(text, chunks)`` tuple read from the ``"text"`` and ``"chunks"``
        entries of the pipeline output.

    Raises:
        FileNotFoundError: If the selected audio file does not exist.
    """
    # Honour an explicit path; otherwise use the fixed download location.
    if isinstance(audio_in, str) and audio_in:
        audio_path = audio_in
    else:
        audio_path = "/content/audio.m4a"

    # Fail early with a clear message instead of an error from deep inside
    # the pipeline when nothing has been downloaded yet.
    if not os.path.isfile(audio_path):
        raise FileNotFoundError(
            f"No audio file at {audio_path!r}; download one before transcribing."
        )

    outputs = pipeline(audio_path, return_timestamps=True)
    return outputs["text"], outputs["chunks"]

# Two-column layout: the left column holds the URL box and the two action
# buttons, the right column holds the audio player and the transcription
# outputs.
with gr.Blocks() as app:
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                show_label=False,
                value="https://www.youtube.com/watch?v=SN2sak8Tp70",
            )
            input_download_button = gr.Button(value="Download from YouTube or Twitch")
            input_transcribe_button = gr.Button(value="Transcribe")
        with gr.Column():
            audio_out = gr.Audio(label="Output Audio")
            text_out = gr.Textbox(label="Output Text")
            chunks_out = gr.Textbox(label="Output Chunks")

        # "Download" feeds the URL to download_video and shows the returned
        # file in the audio component.
        input_download_button.click(
            download_video,
            inputs=[input_text],
            outputs=[audio_out],
        )
        # "Transcribe" sends the audio component's value to transcribe and
        # fills both text boxes from its (text, chunks) result.
        input_transcribe_button.click(
            transcribe,
            inputs=[audio_out],
            outputs=[text_out, chunks_out],
        )

# Start the web UI with Gradio's debug mode enabled.
app.launch(debug=True)