In [None]:
%pip install -q "openvino>=2023.1.0"
%pip install -q "urllib"
%pip install -q "python-ffmpeg<=1.0.16" moviepy transformers onnx
%pip install -q -I "git+https://github.com/garywu007/pytube.git"
%pip install -q -U gradio
%pip install -q -I "git+https://github.com/openai/whisper.git@e8622f9afc4eba139bf796c210f5c01081000472"

In [None]:
import whisper

# Whisper checkpoint name; used for the load below instead of repeating the literal.
model_id = "base"
# Load onto the CPU for tracing/conversion; the encoder/decoder are later swapped for OpenVINO wrappers.
model = whisper.load_model(model_id, device="cpu")
# Switch to inference mode; the trailing `;` suppresses the module repr in the cell output.
model.eval();

### Convert Whisper Encoder to OpenVINO IR [$\Uparrow$](#Table-of-content:)


In [None]:
from pathlib import Path

# Destination paths for the converted OpenVINO IR models: written by ov.save_model below and
# later loaded by the OpenVINO encoder/decoder wrappers.
WHISPER_ENCODER_OV = Path("whisper_encoder.xml")
WHISPER_DECODER_OV = Path("whisper_decoder.xml")

In [None]:
import torch
import openvino as ov

# Zero tensor of shape (1, 80, 3000) used only as a tracing example for the encoder.
# NOTE(review): presumably (batch, n_mels, frames) — 80 log-Mel bins x 3000 frames; confirm against whisper.audio.
mel = torch.zeros((1, 80, 3000))
# Run the PyTorch encoder once; audio_features is reused below as the example audio input
# when converting the decoder.
audio_features = model.encoder(mel)
encoder_model = ov.convert_model(model.encoder, example_input=mel)
ov.save_model(encoder_model, WHISPER_ENCODER_OV)

In [None]:
import torch
from typing import Optional, Tuple
from functools import partial


def attention_forward(
        attention_module,
        x: torch.Tensor,
        xa: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
    """
    Override for forward method of decoder attention module with storing cache values explicitly.

    The previous key/value tensors are taken as an argument and the updated ones are returned,
    so the cache is an explicit input/output of the function (and of any graph traced from it).

    Parameters:
      attention_module: current attention module (provides query/key/value/out projections
        and qkv_attention).
      x: hidden states of the decoder tokens (not token ids).
      xa: encoded audio features (Optional). When given, cross-attention over xa is computed
        and no cache is produced.
      mask: attention mask forwarded to qkv_attention (Optional).
      kv_cache: (keys, values) tuple from previous decoding steps for self-attention (Optional).
        The current keys/values are appended to it along dim 1.
    Returns:
      attention module output tensor
      updated kv_cache: (keys, values) including the current tokens for self-attention,
        or (None, None) for cross-attention
    """
    q = attention_module.query(x)

    if xa is None:
        # Self-attention: project the current tokens and prepend the cached keys/values
        # from previous steps, if any.
        k = attention_module.key(x)
        v = attention_module.value(x)
        if kv_cache is not None:
            k = torch.cat((kv_cache[0], k), dim=1)
            v = torch.cat((kv_cache[1], v), dim=1)
        kv_cache_new = (k, v)
    else:
        # Cross-attention: keys/values come from the audio features and are recomputed on every
        # call, so nothing is cached here.
        k = attention_module.key(xa)
        v = attention_module.value(xa)
        kv_cache_new = (None, None)

    # The second value returned by qkv_attention (qk) is not needed.
    wv, _ = attention_module.qkv_attention(q, k, v, mask)
    return attention_module.out(wv), kv_cache_new


def block_forward(
    residual_block,
    x: torch.Tensor,
    xa: Optional[torch.Tensor] = None,
    mask: Optional[torch.Tensor] = None,
    kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
    """
    Replacement forward for a decoder residual block that passes the self-attention
    kv-cache in and out explicitly.

    Parameters:
      residual_block: the residual block whose forward is overridden.
      x: hidden states entering the block.
      xa: encoded audio features for cross-attention (Optional).
      mask: attention mask for self-attention (Optional).
      kv_cache: self-attention (keys, values) from previous steps (Optional).
    Returns:
      block output hidden states, and the updated self-attention kv_cache
    """
    # Pre-norm self-attention with residual connection; this is the only cache we keep.
    normed = residual_block.attn_ln(x)
    self_attn_out, updated_cache = residual_block.attn(normed, mask=mask, kv_cache=kv_cache)
    hidden = x + self_attn_out

    # Optional pre-norm cross-attention over the audio features; its cache output is ignored.
    cross_attention = residual_block.cross_attn
    if cross_attention:
        cross_out, _ignored = cross_attention(residual_block.cross_attn_ln(hidden), xa)
        hidden = hidden + cross_out

    # Pre-norm MLP with residual connection.
    mlp_out = residual_block.mlp(residual_block.mlp_ln(hidden))
    return hidden + mlp_out, updated_cache



# Route every decoder block (and its attention modules) through the explicit-kv-cache
# overrides defined above; partial binds the module instance as the first argument.
for block in model.decoder.blocks:
    block.forward = partial(block_forward, block)
    block.attn.forward = partial(attention_forward, block.attn)
    if block.cross_attn:
        block.cross_attn.forward = partial(attention_forward, block.cross_attn)


def decoder_forward(
    decoder,
    x: torch.Tensor,
    xa: torch.Tensor,
    kv_cache: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
):
    """
    Override for decoder forward method that threads the self-attention kv-cache explicitly.

    Parameters:
      decoder: the Whisper text decoder module whose forward is being overridden.
      x: torch.LongTensor, shape = (batch_size, <= n_ctx) the text tokens
      xa: torch.Tensor, the encoded audio features (encoder output) to be attended on
      kv_cache: one (keys, values) tuple per decoder block from previous steps,
        or None on the first step
    Returns:
      logits: float tensor of scores for every vocabulary entry at each input token position
      kv_cache: tuple of updated per-block (keys, values) tuples, to pass back on the next step
    """
    if kv_cache is not None:
        # Tokens already processed = length (dim 1) of block 0's cached keys; selects the
        # matching slice of positional embeddings for the new tokens.
        offset = kv_cache[0][0].shape[1]
    else:
        offset = 0
        kv_cache = [None] * len(decoder.blocks)

    x = decoder.token_embedding(x) + decoder.positional_embedding[offset: offset + x.shape[-1]]
    x = x.to(xa.dtype)

    kv_cache_upd = []
    for block, kv_block_cache in zip(decoder.blocks, kv_cache):
        x, kv_block_cache_upd = block(x, xa, mask=decoder.mask, kv_cache=kv_block_cache)
        kv_cache_upd.append(tuple(kv_block_cache_upd))

    x = decoder.ln(x)
    # Tied output projection: reuse the token embedding matrix to map hidden states to vocabulary scores.
    logits = (x @ torch.transpose(decoder.token_embedding.weight.to(x.dtype), 1, 0)).float()

    return logits, tuple(kv_cache_upd)



# override decoder forward
# Bind the decoder instance as the first argument so that calling model.decoder(...) runs
# decoder_forward, which takes and returns the kv-cache explicitly.
model.decoder.forward = partial(decoder_forward, model.decoder)

In [None]:
# Run the patched decoder once on a 3-token prompt (batch of 5) without a cache to obtain a
# populated kv_cache tuple, whose structure becomes part of the traced decoder's inputs.
# NOTE(review): audio_features has batch 1 while tokens has batch 5; this relies on broadcasting
# inside attention — confirm this is intended.
tokens = torch.ones((5, 3), dtype=torch.int64)
logits, kv_cache = model.decoder(tokens, audio_features, kv_cache=None)

# Trace an incremental decoding step: one new token per sequence plus the cache from above.
tokens = torch.ones((5, 1), dtype=torch.int64)
decoder_model = ov.convert_model(model.decoder, example_input=(tokens, audio_features, kv_cache))

ov.save_model(decoder_model, WHISPER_DECODER_OV)

In [None]:
# OpenVINO runtime core: lists the available inference devices and is passed to the model wrappers.
core = ov.Core()

In [None]:
import ipywidgets as widgets

# Inference device selector: every device OpenVINO reports, plus "AUTO" (preselected).
device = widgets.Dropdown(
    description='Device:',
    options=[*core.available_devices, "AUTO"],
    value='AUTO',
    disabled=False,
)

device

In [None]:
from utils import patch_whisper_for_ov_inference, OpenVINOAudioEncoder, OpenVINOTextDecoder

# NOTE(review): patch_whisper_for_ov_inference comes from the local utils module (not shown);
# presumably it adapts the model's decoding loop to the OpenVINO wrappers — confirm in utils.py.
patch_whisper_for_ov_inference(model)

# Replace the PyTorch encoder/decoder with OpenVINO-backed wrappers around the saved IR models,
# targeting the device chosen in the dropdown.
model.encoder = OpenVINOAudioEncoder(core, WHISPER_ENCODER_OV, device=device.value)
model.decoder = OpenVINOTextDecoder(core, WHISPER_DECODER_OV, device=device.value)

In [None]:
from pathlib import Path
from urllib.request import urlretrieve

# Sample video to subtitle. Kept as a Path (not a str): later cells call output_file.with_suffix(".srt").
output_file = Path("MarkBranded.mp4")
VIDEO_URL = "https://quoscdn.s3.amazonaws.com/media/0440968d-538d-427e-8b1f-aa824c897bcc/Black_Tech_Weekend_-_Detroit/MarkBoys2MenEdit_branded.mp4"

# Download only once so that re-running the notebook does not fetch the video again.
if not output_file.exists():
    urlretrieve(VIDEO_URL, output_file)


In [None]:
from utils import get_audio

# Extract the audio track from the video for Whisper. get_audio comes from the local utils module;
# its return format is not shown here — presumably something model.transcribe accepts.
audio = get_audio(output_file)

Select the task for the model:

* **transcribe** - generate audio transcription in the source language (automatically detected).
* **translate** - generate audio transcription with translation to English language.

In [None]:
# Task selector: "transcribe" keeps the source language, "translate" produces English text.
task = widgets.Select(
    description="Select task:",
    options=["transcribe", "translate"],
    value="transcribe",
    disabled=False,
)
task

In [None]:
# Run Whisper (now backed by the OpenVINO encoder/decoder) on the extracted audio with the selected task.
transcription = model.transcribe(audio, task=task.value)

"The results will be saved in the `downloaded_video.srt` file. SRT is one of the most popular formats for storing subtitles and is compatible with many modern video players. This file can be used to embed transcription into videos during playback or by injecting them directly into video files using `ffmpeg`.

In [None]:
from utils import prepare_srt

srt_lines = prepare_srt(transcription)
# Save the subtitles next to the video (same name, .srt extension). Write UTF-8 explicitly:
# transcriptions can contain non-ASCII text that the platform default encoding may not support.
with output_file.with_suffix(".srt").open("w", encoding="utf-8") as f:
    f.writelines(srt_lines)

Now let us see the results.

In [None]:
# Play the source video inline in the notebook.
widgets.Video.from_file(output_file, loop=False, width=800, height=800)

In [None]:
# Show the generated SRT subtitles as plain text.
print("".join(srt_lines))

## Interactive demo [$\Uparrow$](#Table-of-content:)

In [None]:
import gradio as gr


def transcribe(url, task):
    """
    Download a YouTube video, run Whisper on its audio track and write SRT subtitles next to it.

    Parameters:
      url: YouTube video URL.
      task: "Transcribe" or "Translate"; lower-cased before being passed to model.transcribe.
    Returns:
      [video path, subtitle path] as strings, consumed by the Gradio "video" output.
    """
    # Imported here so the function does not depend on a global import of YouTube
    # (pytube is installed in the first cell).
    from pytube import YouTube

    output_file = Path("download_video.mp4")
    yt = YouTube(url)
    yt.streams.get_highest_resolution().download(filename=output_file)
    audio = get_audio(output_file)
    transcription = model.transcribe(audio, task=task.lower())
    srt_lines = prepare_srt(transcription)
    srt_file = output_file.with_suffix(".srt")
    # UTF-8 explicitly: subtitles may contain non-ASCII text.
    with srt_file.open("w", encoding="utf-8") as f:
        f.writelines(srt_lines)
    return [str(output_file), str(srt_file)]


demo = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Textbox(label="YouTube URL"),
        gr.Radio(["Transcribe", "Translate"], value="Transcribe"),
    ],
    outputs="video",
    examples=[["https://youtu.be/kgL5LBM-hFI", "Transcribe"]],
    allow_flagging="never",
)

# Launch locally first; if that raises, retry with a public share link.
try:
    demo.launch(debug=True)
except Exception:
    demo.launch(share=True, debug=True)
# if you are launching remotely, specify server_name and server_port
# demo.launch(server_name='your server name', server_port=<port number as int>)
# Read more in the docs: https://gradio.app/docs/