# core

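> Transcript data models, JSON (de)serialization helpers, and Whisper-based transcription backends (Whisper JAX and whisper-timestamped).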

In [None]:
#| default_exp core

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import json

# Size of the Whisper model to load (used by both transcription backends
# below); "tiny" is good for validation.
MODEL_SIZE = "tiny"

In [None]:
#| export
from pydantic import BaseModel
from typing import List

class Word(BaseModel):
    """A single transcribed word with optional timing and a confidence score."""
    # Start/end time of the word — presumably seconds into the audio; TODO confirm.
    # Both are None when the backend provides no word-level timing (the
    # Whisper JAX wrapper in this file passes None for both).
    start: float | None
    end: float | None
    # The word's text as produced by the transcriber.
    text: str
    # Confidence score for this word — presumably in [0, 1]; TODO confirm.
    confidence: float

class Phrase(BaseModel):
    """A timed group of words (one segment/chunk of a transcription)."""
    # Start time of the phrase — presumably seconds into the audio; TODO confirm.
    start: float
    # End time of the phrase. The Whisper JAX wrapper in this file stores -1
    # here when the backend reports no end timestamp.
    end: float
    # The words that make up the phrase, in spoken order.
    words: List[Word]
    # Confidence score for the whole phrase — presumably in [0, 1]; TODO confirm.
    confidence: float

class Transcript(BaseModel):
    """A complete transcription, stored as a list of phrases."""
    phrases: List[Phrase]

In [None]:
#| export
# Small hand-written fixture: two phrases of three words each. It is used by
# the JSON round-trip check and the pretty-printer demo in this notebook.
example_transcript = Transcript(
    phrases=[
        Phrase(
            start=0.0,
            end=2.5,
            words=[
                Word(start=0.0, end=0.5, text="Hello", confidence=0.95),
                Word(start=0.6, end=1.2, text="world", confidence=0.98),
                Word(start=1.3, end=2.5, text="testing", confidence=0.92)
            ],
            confidence=0.96
        ),
        Phrase(
            start=3.0,
            end=5.0,
            words=[
                Word(start=3.0, end=3.5, text="another", confidence=0.93),
                Word(start=3.6, end=4.2, text="phrase", confidence=0.89),
                Word(start=4.3, end=5.0, text="here", confidence=0.94)
            ],
            confidence=0.93
        )
    ]
)

In [None]:
#| export
from pydantic import TypeAdapter

TranscriptHandler = TypeAdapter(Transcript)

def dump_transcript_to_json(transcript: Transcript, file_path: str):
    """Write `transcript` to `file_path` as pretty-printed JSON.

    The file is opened in text write mode, so existing content is replaced.
    The JSON text comes from the model's own `model_dump_json` with a
    two-space indent.
    """
    with open(file_path, "w") as sink:
        payload = transcript.model_dump_json(indent=2)
        sink.write(payload)

def load_transcript_from_json(file_path: str) -> Transcript:
    """Load and validate a transcript from the JSON file at `file_path`.

    Args:
        file_path: Path of the JSON file to read, e.g. one written by
            `dump_transcript_to_json`.

    Returns:
        The `Transcript` obtained by validating the parsed JSON with
        `TranscriptHandler`.
    """
    # Read from the caller-supplied path. The previous version always opened
    # the hard-coded "test.json" and silently ignored `file_path`.
    with open(file_path, "r") as file:
        json_data = json.load(file)

    # Validate the parsed JSON data using the module-level TypeAdapter.
    return TranscriptHandler.validate_python(json_data)

# Round-trip check: dump the example transcript to JSON, load a transcript
# via `load_transcript_from_json`, and assert that the two are equal.
dump_transcript_to_json(example_transcript, "../test.json")
deserialized_transcript = load_transcript_from_json("../test.json")

assert example_transcript == deserialized_transcript

In [None]:
#| export
def pretty_print_transcript(transcript: Transcript, *, confidence_threshold: float = 0.9):
    """Print a human-readable summary of `transcript` to stdout.

    Each phrase produces one line of the form
    ``N: [start-end] (confidence) text`` (numbering starts at 1), followed by
    one indented ``text (confidence)`` line for every low-confidence word.

    Args:
        transcript: The transcript to print.
        confidence_threshold: Words whose confidence is strictly below this
            value are listed under their phrase. Defaults to 0.9, the cut-off
            that used to be hard-coded.
    """
    for index, phrase in enumerate(transcript.phrases, start=1):
        sentence = " ".join(word.text for word in phrase.words)
        print(f"{index}: [{phrase.start}-{phrase.end}] ({phrase.confidence}) {sentence}")
        # Call out the words the recognizer was unsure about.
        for word in phrase.words:
            if word.confidence < confidence_threshold:
                print(f"  {word.text} ({word.confidence})")

# Demo: print the hand-written example transcript.
pretty_print_transcript(example_transcript)

1: [0.0-2.5] (0.96) Hello world testing
2: [3.0-5.0] (0.93) another phrase here
  phrase (0.89)


## Whisper JAX

This should be very fast on a GPU. On CPU it has little performance advantage, and it only produces phrase-level timestamps, which is awkward for captioning. Word-level timestamps could potentially be added by implementing dynamic time warping or by using some sort of forced aligner.

Because JAX JIT-compiles the pipeline, it needs to be run once before it is fast: the first run includes compilation, and subsequent runs will be faster.

In [None]:
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp

# Build the Whisper JAX pipeline once at module level so every transcription
# call reuses it; bfloat16 is requested via `dtype`.
# NOTE(review): this cell has no `#| export` directive, yet the exported
# `transcribe_with_whisper_jax` uses `whisper_jax_pipeline` — confirm the
# exported module can actually resolve it.
whisper_jax_pipeline = FlaxWhisperPipline(f"openai/whisper-{MODEL_SIZE}", dtype=jnp.bfloat16)



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
#| export
def transcribe_with_whisper_jax(audio_path: str, **kwargs) -> Transcript:
    """Transcribe `audio_path` with the module-level Whisper JAX pipeline.

    The pipeline is called with `return_timestamps=True` and its `'chunks'`
    are turned into phrases. Only phrase-level timestamps are used, so every
    `Word` is created with `start=None` and `end=None`. All confidences are
    set to 1.0 as placeholders.

    Args:
        audio_path: Path of the audio file handed to `whisper_jax_pipeline`.
        **kwargs: Extra keyword arguments forwarded to the pipeline call.

    Returns:
        A `Transcript` with one `Phrase` per returned chunk. A phrase whose
        end timestamp is missing gets `end=-1`.
    """
    chunks = whisper_jax_pipeline(audio_path, return_timestamps=True, **kwargs)["chunks"]

    phrases = []
    for chunk in chunks:
        timestamp = chunk["timestamp"]
        end = timestamp[1]
        phrases.append(
            Phrase(
                start=timestamp[0],
                # -1 marks a missing end timestamp. Compare against None
                # instead of relying on truthiness, so that a legitimate
                # end of 0.0 is not mistaken for "missing".
                end=end if end is not None else -1,
                words=[
                    Word(start=None, end=None, text=token, confidence=1.0)
                    for token in chunk["text"].strip().split()
                ],
                confidence=1.0,
            )
        )
    return Transcript(phrases=phrases)

# Demo: transcribe the sample recording with Whisper JAX and print the result.
# NOTE(review): these calls sit in an `#| export` cell, so they will presumably
# also run whenever the exported module is imported — consider moving them to
# a separate, non-exported cell.
transcript = transcribe_with_whisper_jax("../data/transcribe_test.mp3")
pretty_print_transcript(transcript)



Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


1: [0.0-4.4] (1.0) In the last chapter, you and I started to step through the internal workings of a transformer.
2: [4.4-10.8] (1.0) This is one of the key pieces of technology inside large language models, and a lot of other tools in the modern way of AI.
3: [10.8-15.52] (1.0) It first hit the scene and a now famous 2017 paper called Attention as All You Need,
4: [15.52-19.6] (1.0) and in this chapter, you and I will dig into what this attention mechanism is,
5: [19.6-21.6] (1.0) visualizing how it processes data.
6: [26.48-30.12] (1.0) As a quick recap, here's the important context I want you to have in mind.
7: [30.12-34.68] (1.0) The goal of the model that you and I are studying is to take in a piece of text and predict
8: [34.68-36.8] (1.0) what word comes next.
9: [36.8-41.04] (1.0) The input text is broken up into little pieces that we call tokens, and these are very
10: [41.04-47.0] (1.0) often words or pieces of words, but just to make the examples in this video easier for yo

## Whisper Timestamped

[Documentation](https://github.com/linto-ai/whisper-timestamped)

Uses dynamic time warping to generate timestamps for every word.
Runs faster than Whisper JAX on CPU, but only by forsaking beam decoding, which might be a problem for transcription quality.

In [None]:
import whisper_timestamped as whisper
import jax

# Pick the device for the PyTorch-based Whisper model loaded below.
# `jax.default_backend()` reports the platform ("cpu", "gpu" or "tpu"),
# whereas `device_kind` is a hardware model name and would not equal "GPU".
# PyTorch names its GPU device "cuda"; "gpu" is not a valid device string.
# NOTE(review): JAX and PyTorch detect GPUs independently — confirm that a
# JAX-visible GPU implies a CUDA-enabled PyTorch build in this environment.
device = "cuda" if jax.default_backend() == "gpu" else "cpu"
# Load the whisper-timestamped model once at module level so every
# transcription call reuses it.
# NOTE(review): this cell has no `#| export` directive, yet the exported
# `transcribe_with_whisper_timestamped` uses `whisper` and `model` — confirm
# the exported module can actually resolve them.
model = whisper.load_model(MODEL_SIZE, device=device)

In [None]:
#| export

def transcribe_with_whisper_timestamped(audio_path: str, **kwargs) -> Transcript:
    """Transcribe `audio_path` with whisper-timestamped, including word timings.

    Args:
        audio_path: Path of the audio file to load and transcribe.
        **kwargs: Extra keyword arguments forwarded to `whisper.transcribe`.
            `language` defaults to "en" and `vad` to "auditok"; both can be
            overridden here.

    Returns:
        A `Transcript` with one `Phrase` per returned segment. Each phrase
        carries the segment's start, end and confidence, plus one `Word` per
        entry in the segment's `"words"` list.
    """
    audio = whisper.load_audio(audio_path)

    # Defaults first, caller overrides last. Passing e.g. `language="fr"`
    # used to raise a duplicate-keyword TypeError instead of overriding.
    options = {"language": "en", "vad": "auditok", **kwargs}
    segments = whisper.transcribe(model, audio, **options)["segments"]

    return Transcript(phrases=[
        Phrase(
            start=segment["start"],
            end=segment["end"],
            confidence=segment["confidence"],
            words=[
                Word(
                    start=word["start"],
                    end=word["end"],
                    text=word["text"],
                    confidence=word["confidence"],
                )
                for word in segment["words"]
            ],
        )
        for segment in segments
    ])

# Demo: transcribe the sample recording with whisper-timestamped and print
# the result, including the low-confidence words of each phrase.
# NOTE(review): these calls sit in an `#| export` cell, so they will presumably
# also run whenever the exported module is imported — consider moving them to
# a separate, non-exported cell.
transcript = transcribe_with_whisper_timestamped("../data/transcribe_test.mp3")
pretty_print_transcript(transcript)

100%|██████████| 5998/5998 [00:03<00:00, 1738.12frames/s]

1: [0.08-4.0] (0.91) In the last chapter, you and I started to step through the internal workings of a transformer.
  you (0.559)
  started (0.896)
  workings (0.837)
  transformer. (0.829)
2: [4.5-7.64] (0.968) This is one of the key pieces of technology inside large language models,
  language (0.817)
3: [7.96-10.3] (0.973) and a lot of other tools in the modern way of AI.
  way (0.79)
4: [10.96-15.24] (0.8) It first hit the scene and a now famous 2017 paper called Attention as All You Need,
  and (0.43)
  famous (0.435)
  Attention (0.773)
  as (0.366)
  You (0.695)
5: [15.66-19.54] (0.974) and in this chapter, you and I will dig into what this attention mechanism is,
  attention (0.813)
6: [19.74-21.58] (0.978) visualizing how it processes data.
7: [26.36-29.48] (0.957) As a quick recap, here's the important context I want you to have in mind.
  As (0.598)
8: [30.08-33.88] (0.969) The goal of the model that you and I are studying is to take in a piece of text
  is (0.775)
9: [34.04




In [None]:
#| export
def transcribe(audio_path: str, **kwargs) -> Transcript:
    """Transcribe `audio_path` with the default backend.

    This is a thin alias for `transcribe_with_whisper_timestamped`; every
    keyword argument is forwarded to it unchanged. It is a real function
    rather than a lambda bound to a name, so it has a proper `__name__` and
    docstring.
    """
    return transcribe_with_whisper_timestamped(audio_path, **kwargs)


In [None]:
#| hide
import nbdev; nbdev.nbdev_export()