In [1]:
# List the accelerator devices visible to JAX (the recorded run shows 8 TPU cores)
import jax
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [2]:
# Download Whisper JAX from git repo, and audio loading dependencies
!pip install --quiet --upgrade pip
!pip install --quiet git+https://github.com/sanchit-gandhi/whisper-jax.git datasets soundfile librosa

[0m

In [3]:
from kaggle_secrets import UserSecretsClient
from huggingface_hub import login
# Read the Hugging Face access token from the Kaggle secret stored under the label "huggingface-token"
# (keeps the token out of the notebook source)
user_secrets = UserSecretsClient()
hf_token = user_secrets.get_secret("huggingface-token")
# Authenticate this session with the Hugging Face Hub
# (presumably required for the push_to_hub call at the end of the notebook)
login(token=hf_token)

  from .autonotebook import tqdm as notebook_tqdm


Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid.
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [4]:
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp
# Run Whisper JAX through the FlaxWhisperPipline class for ease of use
# (the class name really is spelled "Pipline" in whisper_jax - not a typo here).
# Loads the "openai/whisper-large-v2" checkpoint with dtype=bfloat16 and batch_size=16.
whisper = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16)

Downloading (…)rocessor_config.json: 100%|██████████| 185k/185k [00:00<00:00, 3.77MB/s]
Downloading (…)okenizer_config.json: 100%|██████████| 800/800 [00:00<00:00, 120kB/s]
Downloading (…)olve/main/vocab.json: 100%|██████████| 836k/836k [00:00<00:00, 9.97MB/s]
Downloading (…)/main/tokenizer.json: 100%|██████████| 2.20M/2.20M [00:00<00:00, 21.1MB/s]
Downloading (…)olve/main/merges.txt: 100%|██████████| 494k/494k [00:00<00:00, 9.79MB/s]
Downloading (…)main/normalizer.json: 100%|██████████| 52.7k/52.7k [00:00<00:00, 24.8MB/s]
Downloading (…)in/added_tokens.json: 100%|██████████| 2.08k/2.08k [00:00<00:00, 1.21MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 2.08k/2.08k [00:00<00:00, 1.27MB/s]
Downloading (…)lve/main/config.json: 100%|██████████| 1.99k/1.99k [00:00<00:00, 341kB/s]
Downloading flax_model.msgpack: 100%|██████████| 6.17G/6.17G [00:30<00:00, 200MB/s] 
Downloading (…)neration_config.json: 100%|██████████| 3.51k/3.51k [00:00<00:00, 1.87MB/s]


In [5]:
# Speed up transcription through a compilation cache stored in ./jax_cache
# NOTE(review): newer JAX releases appear to deprecate `initialize_cache` in favour of the
# `jax_compilation_cache_dir` config option - confirm against the installed JAX version.
from jax.experimental.compilation_cache import compilation_cache as cc
cc.initialize_cache("./jax_cache")



In [6]:
from datasets import load_dataset
# Get a test audio sample to use for the first (JIT-compiling) forward call
test_dataset = load_dataset("sanchit-gandhi/whisper-jax-test-files", split="train")
test_audio = test_dataset[0]["audio"]  # "audio" entry of the first sample (about 5 minutes long)

Downloading readme: 100%|██████████| 371/371 [00:00<00:00, 289kB/s]


Downloading and preparing dataset None/None to /root/.cache/huggingface/datasets/sanchit-gandhi___parquet/sanchit-gandhi--whisper-jax-test-files-95479fe55e88baac/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]
Downloading data:   0%|          | 0.00/113M [00:00<?, ?B/s][A
Downloading data:   3%|▎         | 3.38M/113M [00:00<00:03, 33.8MB/s][A
Downloading data:   7%|▋         | 8.28M/113M [00:00<00:02, 42.7MB/s][A
Downloading data:  12%|█▏        | 13.2M/113M [00:00<00:02, 45.7MB/s][A
Downloading data:  16%|█▌        | 17.8M/113M [00:00<00:02, 45.5MB/s][A
Downloading data:  20%|█▉        | 22.3M/113M [00:00<00:01, 45.6MB/s][A
Downloading data:  24%|██▍       | 27.2M/113M [00:00<00:01, 46.6MB/s][A
Downloading data:  28%|██▊       | 32.0M/113M [00:00<00:01, 47.1MB/s][A
Downloading data:  32%|███▏      | 36.7M/113M [00:00<00:01, 47.2MB/s][A
Downloading data:  37%|███▋      | 41.5M/113M [00:00<00:01, 46.9MB/s][A
Downloading data:  41%|████      | 46.2M/113M [00:01<00:01, 47.0MB/s][A
Downloading data:  45%|████▌     | 51.2M/113M [00:01<00:01, 47.8MB/s][A
Downloading data:  49%|████▉     | 56.1M/113M [00:01<00:01, 48.2MB/s][

Dataset parquet downloaded and prepared to /root/.cache/huggingface/datasets/sanchit-gandhi___parquet/sanchit-gandhi--whisper-jax-test-files-95479fe55e88baac/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.


In [8]:
# First transcription of the test audio - should be slow
# (presumably because this call includes JIT compilation)
%time text = whisper(test_audio)

CPU times: user 3min 16s, sys: 57.9 s, total: 4min 14s
Wall time: 2min 6s


In [10]:
# Transcribe the same test audio again - should be pretty fast
# (recorded wall time drops from ~2 min to a few seconds)
%time text = whisper(test_audio)

CPU times: user 25 s, sys: 45.8 s, total: 1min 10s
Wall time: 5.65 s


In [11]:
# Transcribe the test audio with timestamps - should be slow on the first call
# (presumably return_timestamps=True triggers a fresh compilation - confirm)
%time outputs = whisper(test_audio, task="transcribe", return_timestamps=True)
text = outputs["text"]  # transcription
chunks = outputs["chunks"]  # transcription + timestamps (per-chunk 'text' and 'timestamp' entries)

CPU times: user 3min 8s, sys: 44.9 s, total: 3min 53s
Wall time: 2min 1s


In [12]:
# Transcribe the test audio with timestamps again - should be lightning fast this time
%time outputs = whisper(test_audio, task="transcribe", return_timestamps=True)
text = outputs["text"]  # transcription
chunks = outputs["chunks"]  # transcription + timestamps (per-chunk 'text' and 'timestamp' entries)

CPU times: user 39.4 s, sys: 1min 6s, total: 1min 45s
Wall time: 3.16 s


In [14]:
!pip install nltk

[0m

In [15]:
import nltk
# Download the Punkt sentence-tokenizer models into the local nltk_data directory
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [16]:
import nltk.data
# Load the English Punkt tokenizer; timestamp_sentences() relies on this
# module-level `sent_detector` to split chunk text into sentences.
sent_detector = nltk.data.load('tokenizers/punkt/english.pickle')

In [52]:
def timestamp_sentences(chunks, video_id=None, tokenizer=None):
    """Regroup transcription chunks into sentence-level records with timestamps.

    Chunk boundaries do not follow sentence boundaries, so a sentence may span
    several chunks. Text from consecutive chunks is glued together until the
    running sentence ends with '.', '?' or '!'.

    Args:
        chunks: Sequence of dicts, each with a 'text' string and a 'timestamp'
            pair ``(start, end)``. Either bound may be None when no timestamp
            was produced for it (e.g. the last chunk of a recording).
        video_id: Identifier stored under the 'id' key of every record.
        tokenizer: Object exposing ``tokenize(text) -> list of str``. Defaults
            to the notebook-level ``sent_detector``.

    Returns:
        List of dicts with keys 'id', 'start', 'end' and 'text'. 'start' and
        'end' are truncated to int, or None when the bound is unknown. The
        list is empty when no chunk contains any text.
    """
    if tokenizer is None:
        # Fall back to the tokenizer loaded earlier in the notebook.
        tokenizer = sent_detector

    def _as_int(value):
        # A missing boundary stays None instead of crashing on int(None).
        return None if value is None else int(value)

    def _new_sentence(chunk, sent):
        # Fresh record that inherits both boundaries from the chunk it starts in.
        return {
            'id': video_id,
            'start': _as_int(chunk['timestamp'][0]),
            'end': _as_int(chunk['timestamp'][1]),
            'text': sent,
        }

    timestamped = []
    current_sentence = None  # sentence still being assembled
    prev_text = None
    for chunk in chunks:
        text = chunk['text'].lstrip()
        # Rudimentarily deal with hallucination: drop a chunk that repeats
        # the previous chunk's text verbatim.
        if text == prev_text:
            continue
        prev_text = text

        # Tokenize each chunk into sentences; chunks without text add nothing.
        sentences = [sent for sent in tokenizer.tokenize(text) if sent]
        if not sentences:
            continue

        if current_sentence is None:
            # First chunk with text: every sentence but the last is emitted
            # as-is; the last one may continue in the next chunk.
            for sent in sentences[:-1]:
                timestamped.append(_new_sentence(chunk, sent))
            current_sentence = _new_sentence(chunk, sentences[-1])
            continue

        for sent in sentences:
            unfinished = not current_sentence['text'].endswith((".", "?", "!"))
            if sent[0].isalpha() and unfinished:
                # Continuation of the running sentence: extend text and end time.
                current_sentence['text'] += " " + sent
                current_sentence['end'] = _as_int(chunk['timestamp'][1])
            else:
                # The running sentence is complete: emit it and start a new one.
                timestamped.append(current_sentence)
                current_sentence = _new_sentence(chunk, sent)

    # Append the last sentence, if any text was seen at all.
    if current_sentence is not None:
        timestamped.append(current_sentence)
    return timestamped

In [53]:
import os
dir_ = '/kaggle/input/huberman-audio/mp3'
huberman_audio = os.listdir(dir_)
huberman_timestamped, error_ids = [], []
# Iterate over every audio file
for i, audio_file in enumerate(huberman_audio):
    video_id = audio_file[:-4]
    print('Transcribing Video ->', video_id)
    audio_path = os.path.join(dir_, audio_file)
    try:
        # Transcribe     
        %time outputs = whisper(audio_path, task="transcribe", return_timestamps=True)
        # Timestamp each sentence     
        timestamped = timestamp_sentences(chunks=outputs['chunks'], video_id=video_id)
        huberman_timestamped.extend(timestamped)
        print('Transcribed ({done}/{total})\n'.format(done=i+1, total=len(huberman_audio)))
    except Exception as err:
        print('ERROR transcribing video ->', video_id, '\n' + err)
        error_ids.append(video_id)

Transcribing Video -> 7YGZZcXqKxE
CPU times: user 10min 8s, sys: 17min 17s, total: 27min 25s
Wall time: 1min 14s
Transcribed (1/129)

Transcribing Video -> NAATB55oxeQ
CPU times: user 11min 22s, sys: 19min 8s, total: 30min 30s
Wall time: 1min 20s
Transcribed (2/129)

Transcribing Video -> 8IWDAqodDas
CPU times: user 6min 14s, sys: 10min 43s, total: 16min 58s
Wall time: 49.7 s
Transcribed (3/129)

Transcribing Video -> 9tRohh0gErM
CPU times: user 13min 53s, sys: 23min 22s, total: 37min 16s
Wall time: 1min 41s
Transcribed (4/129)

Transcribing Video -> ulHrUVV3Kq4
CPU times: user 12min 19s, sys: 20min 44s, total: 33min 3s
Wall time: 1min 27s
Transcribed (5/129)

Transcribing Video -> ntfcfJ28eiU
CPU times: user 9min 52s, sys: 16min 41s, total: 26min 33s
Wall time: 1min 11s
Transcribed (6/129)

Transcribing Video -> ObtW353d5i0
CPU times: user 9min 15s, sys: 15min 49s, total: 25min 5s
Wall time: 1min 15s
Transcribed (7/129)

Transcribing Video -> a9yFKPmPZ90
CPU times: user 16min 18s, sys

There was an error while processing timestamps, we haven't found a timestamp as last token. Was WhisperTimeStampLogitsProcessor used?


CPU times: user 10min, sys: 17min 15s, total: 27min 15s
Wall time: 1min 21s
Transcribed (38/129)

Transcribing Video -> szqPAPKE5tQ
CPU times: user 12min 43s, sys: 21min 35s, total: 34min 18s
Wall time: 1min 35s
Transcribed (39/129)

Transcribing Video -> hx3U64IXFOY
CPU times: user 8min 14s, sys: 14min 3s, total: 22min 17s
Wall time: 1min 5s
Transcribed (40/129)

Transcribing Video -> oC3fhUjg30E
CPU times: user 12min 13s, sys: 20min 33s, total: 32min 47s
Wall time: 1min 27s
Transcribed (41/129)

Transcribing Video -> Ze2pc6NwsHQ
CPU times: user 11min 20s, sys: 19min 10s, total: 30min 30s
Wall time: 1min 22s
Transcribed (42/129)

Transcribing Video -> O1YRwWmue4Y
CPU times: user 24min 51s, sys: 41min 53s, total: 1h 6min 45s
Wall time: 3min 9s
Transcribed (43/129)

Transcribing Video -> IOl28gj_RXw
CPU times: user 13min 59s, sys: 23min 43s, total: 37min 43s
Wall time: 1min 47s
Transcribed (44/129)

Transcribing Video -> mcPSRWUYCv0
CPU times: user 10min 9s, sys: 17min 7s, total: 27min 

There was an error while processing timestamps, we haven't found a timestamp as last token. Was WhisperTimeStampLogitsProcessor used?


CPU times: user 8min 52s, sys: 15min, total: 23min 52s
Wall time: 1min 13s
Transcribed (59/129)

Transcribing Video -> LG53Vxum0as
CPU times: user 7min 13s, sys: 12min 23s, total: 19min 37s
Wall time: 1min 3s
Transcribed (60/129)

Transcribing Video -> VAEzZeaV5zM
CPU times: user 11min 57s, sys: 20min 14s, total: 32min 11s
Wall time: 1min 29s
Transcribed (61/129)

Transcribing Video -> Mwz8JprPeMc
CPU times: user 11min 58s, sys: 20min 19s, total: 32min 17s
Wall time: 1min 28s
Transcribed (62/129)

Transcribing Video -> dzOvi0Aa2EA
CPU times: user 11min 8s, sys: 18min 57s, total: 30min 6s
Wall time: 1min 26s
Transcribed (63/129)

Transcribing Video -> azb3Ih68awQ
CPU times: user 11min 40s, sys: 19min 55s, total: 31min 36s
Wall time: 1min 30s
Transcribed (64/129)

Transcribing Video -> poOf8b2WE2g
CPU times: user 10min 5s, sys: 16min 56s, total: 27min 1s
Wall time: 1min 25s
Transcribed (65/129)

Transcribing Video -> 77CdVSpnUX4
CPU times: user 9min 14s, sys: 15min 40s, total: 24min 54s


In [64]:
from datasets import Dataset
# Build a Hugging Face Dataset from the list of sentence records (one dict per sentence)
dataset = Dataset.from_list(huberman_timestamped)

In [65]:
# Upload the dataset to the Hugging Face Hub repository "hbattu/huberman-timestamped"
# NOTE(review): videos recorded in error_ids are missing from this upload - inspect it before pushing
dataset.push_to_hub("hbattu/huberman-timestamped")

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]
Creating parquet from Arrow format:   0%|          | 0/165 [00:00<?, ?ba/s][A
Creating parquet from Arrow format:  42%|████▏     | 70/165 [00:00<00:00, 697.94ba/s][A
Creating parquet from Arrow format: 100%|██████████| 165/165 [00:00<00:00, 693.65ba/s][A

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s][A
Upload 1 LFS files: 100%|██████████| 1/1 [00:00<00:00,  1.58it/s][A
Pushing dataset shards to the dataset hub: 100%|██████████| 1/1 [00:01<00:00,  1.13s/it]
