# Video Subtitle Generation using Whisper and OpenVINO™

[Whisper](https://openai.com/blog/whisper/) is an automatic speech recognition (ASR) system trained on 680,000 hours of multilingual and multitask supervised data collected from the web. It is a multi-task model that can perform multilingual speech recognition as well as speech translation and language identification.

![asr-training-data-desktop.svg](https://user-images.githubusercontent.com/29454499/204536347-28976978-9a07-416c-acff-fc1214bbfbe0.svg)

You can find more information about this model in the [research paper](https://cdn.openai.com/papers/whisper.pdf), [OpenAI blog](https://openai.com/blog/whisper/), [model card](https://github.com/openai/whisper/blob/main/model-card.md) and GitHub [repository](https://github.com/openai/whisper).

In this notebook, we will use Whisper with OpenVINO to generate subtitles for a sample video.
The notebook contains the following steps:
1. Download the model.
2. Instantiate the PyTorch model pipeline.
3. Export the ONNX model and convert it to OpenVINO IR, using the Model Optimizer tool.
4. Run the Whisper pipeline with OpenVINO models.

## Prerequisites

Install the dependencies, then clone and install the model repository.

In [1]:
!pip install -q "python-ffmpeg<=1.0.16" moviepy

In [2]:
!pip install -q -I "git+https://github.com/garywu007/pytube.git"
!pip install -q pydub

In [3]:
from pathlib import Path

REPO_DIR = Path("whisper")
if not REPO_DIR.exists():
    !git clone https://github.com/openai/whisper.git -b v20230124
%cd whisper
!python setup.py develop

Cloning into 'whisper'...
remote: Enumerating objects: 712, done.
remote: Counting objects: 100% (10/10), done.
remote: Compressing objects: 100% (10/10), done.
remote: Total 712 (delta 1), reused 3 (delta 0), pack-reused 702
Receiving objects: 100% (712/712), 12.45 MiB | 8.50 MiB/s, done.
Resolving deltas: 100% (417/417), done.
Note: switching to '55f690af7914c672c69733b7e04ef5a41b2b2774'.

You are in 'detached HEAD' state. You can look around, make experimental
changes and commit them, and you can discard any commits you make in this
state without impacting any branches by switching back to a branch.

If you want to create a new branch to retain commits you create, you may
do so (now or later) by using -c with the switch command. Example:

  git switch -c <new-branch-name>

Or undo this operation with:

  git switch -

Turn off this advice by setting config variable advice.detachedHead to false

/home/user/Desktop/whisper/whisper_openvino/whisper_openvino/whisper

running egg_info
creating openai_whisper.egg-info
writing openai_whisper.egg-info/PKG-INFO
writing dependency_links to openai_whisper.egg-info/dependency_links.txt
writing entry points to openai_whisper.egg-info/entry_points.txt
writing requirements to openai_whisper.egg-info/requires.txt
writing top-level names to openai_whisper.egg-info/top_level.txt
writing manifest file 'openai_whisper.egg-info/SOURCES.txt'
file whisper.py (for module whisper) not found
reading manifest file 'openai_whisper.egg-info/SOURCES.txt'
reading manifest template 'MANIFEST.in'
adding license file 'LICENSE'
writing manifest file 'openai_whisper.egg-info/SOURCES.txt'
running build_ext
Creating /home/user/miniconda3/envs/whisper/lib/python3.10/site-packages/openai-whisper.egg-link (link to .)
Adding openai-whisper 20230124 to easy-install.pth file
Installing whisper script to /home/user/miniconda3/envs/whisper/bin

Installed /home/user/Desktop/whisper/whisper_openvino/whisper_openvino/whisper
Processing depend

## Instantiate model
Whisper is a Transformer-based encoder-decoder model, also referred to as a sequence-to-sequence model. It maps a sequence of audio spectrogram features to a sequence of text tokens. First, the feature extractor converts the raw audio input to a log-Mel spectrogram. Then, the Transformer encoder encodes the spectrogram into a sequence of encoder hidden states. Finally, the decoder autoregressively predicts text tokens, conditioned on both the previous tokens and the encoder hidden states.

You can see the model architecture in the diagram below:

![whisper_architecture.svg](https://user-images.githubusercontent.com/29454499/204536571-8f6d8d77-5fbd-4c6d-8e29-14e734837860.svg)
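
To make the data flow concrete, the following minimal sketch computes the tensor sizes at each stage. The constants (16 kHz sample rate, 160-sample hop, 30-second chunks, and a total encoder downsampling factor of 2) are assumptions about the Whisper audio front end that this notebook does not state; the helper names are hypothetical. The results agree with the shapes used later in the notebook: a `(1, 80, 3000)` dummy spectrogram and 1500 encoder positions.

```python
# Assumed Whisper audio front-end constants (not stated in this notebook).
SAMPLE_RATE = 16_000   # audio samples per second
HOP_LENGTH = 160       # samples between consecutive spectrogram frames
CHUNK_SECONDS = 30     # every input is padded or trimmed to this duration
N_MELS = 80            # number of Mel frequency bins
ENCODER_STRIDE = 2     # assumed temporal downsampling inside the encoder


def mel_shape(batch_size: int = 1) -> tuple:
    """Shape of the log-Mel spectrogram fed to the encoder."""
    n_frames = SAMPLE_RATE * CHUNK_SECONDS // HOP_LENGTH
    return (batch_size, N_MELS, n_frames)


def encoder_output_length(n_frames: int) -> int:
    """Number of encoder hidden states produced for n_frames input frames."""
    return n_frames // ENCODER_STRIDE


batch, n_mels, n_frames = mel_shape()
print((batch, n_mels, n_frames))        # (1, 80, 3000)
print(encoder_output_length(n_frames))  # 1500
```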


There are several models of different sizes and capabilities trained by the authors of the model. In this tutorial, we will use the `small` model, but the same steps apply to the other models in the Whisper family.

In [4]:
import whisper

model = whisper.load_model("small")
model.to("cpu")
model.eval()
pass

100%|███████████████████████████████████████| 461M/461M [01:14<00:00, 6.53MiB/s]


### Convert model to OpenVINO Intermediate Representation (IR) format

For best results with OpenVINO, it is recommended to convert the model to OpenVINO IR format. OpenVINO supports PyTorch models via ONNX conversion: we export the model to ONNX with `torch.onnx.export`, which needs an initialized model object and example inputs for shape inference, and then convert the ONNX file with `mo.convert_model`. The `mo.convert_model` Python function returns an OpenVINO model that is ready to be loaded on a device and used for predictions. We can save it to disk for later use with `openvino.runtime.serialize`.


### Convert Whisper Encoder to OpenVINO IR

In [5]:
from pathlib import Path

WHISPER_ENCODER_OV = Path("whisper_encoder.xml")
WHISPER_DECODER_OV = Path("whisper_decoder.xml")
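
The conversion cells in this notebook skip the export when the `.xml` file already exists. An IR model is stored as a pair of files, the `.xml` topology and a `.bin` file with the weights, so a slightly stricter check can look for both. The sketch below is a hypothetical helper (`ir_is_complete` is not part of OpenVINO) and assumes the weights sit next to the `.xml` file under the same base name, which is where `serialize` writes them when no separate path is given.

```python
import tempfile
from pathlib import Path


def ir_is_complete(xml_path) -> bool:
    """Return True only if both the .xml file and its sibling .bin file exist."""
    xml_path = Path(xml_path)
    return xml_path.exists() and xml_path.with_suffix(".bin").exists()


# Demonstration on throwaway files; real IR files would come from `serialize`.
with tempfile.TemporaryDirectory() as tmp:
    xml_file = Path(tmp) / "model.xml"
    xml_file.write_text("<net/>")
    only_xml = ir_is_complete(xml_file)      # weights are still missing
    xml_file.with_suffix(".bin").write_bytes(b"")
    both_files = ir_is_complete(xml_file)    # topology and weights present

print(only_xml, both_files)  # False True
```

In the notebook, `ir_is_complete(WHISPER_ENCODER_OV)` could replace `WHISPER_ENCODER_OV.exists()` if an interrupted conversion is a concern.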

In [6]:
import torch
from openvino.tools import mo
from openvino.runtime import serialize

# dummy input for tracing: batch of 1, 80 Mel bins, 3000 frames
mel = torch.zeros((1, 80, 3000))
# keep the encoder output: it is reused as an example input for the decoder export
audio_features = model.encoder(mel)
if not WHISPER_ENCODER_OV.exists():
    torch.onnx.export(
        model.encoder,
        mel,
        "whisper_encoder.onnx",
        input_names=["mel"],
        output_names=["output_features"]
    )
    encoder_model = mo.convert_model("whisper_encoder.onnx", compress_to_fp16=True)
    serialize(encoder_model, xml_path=str(WHISPER_ENCODER_OV))


### Convert Whisper decoder to OpenVINO IR

To reduce computational complexity, the decoder uses cached key/value projections in attention modules from the previous steps. We need to modify this process for correct tracing to ONNX.
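
As a rough illustration of why the cache matters, the toy sketch below counts how many key projections a single attention layer performs over a number of decoding steps, with and without a cache. It uses plain Python stand-ins rather than tensors, and `count_projections` is a hypothetical name, not part of Whisper.

```python
def count_projections(n_steps: int, use_cache: bool) -> int:
    """Count key projections for one attention layer over n_steps decoding steps."""
    cache = []    # projected keys of the tokens seen so far
    tokens = []   # all tokens generated so far
    calls = 0

    def project(token):
        nonlocal calls
        calls += 1
        return ("k", token)  # stand-in for a real key projection

    for step in range(n_steps):
        tokens.append(step)  # stand-in for the newly generated token
        if use_cache:
            cache.append(project(tokens[-1]))    # project only the new token
            keys = cache
        else:
            keys = [project(t) for t in tokens]  # re-project the whole prefix
        assert len(keys) == len(tokens)          # attention still sees every key
    return calls


print(count_projections(10, use_cache=False))  # 55
print(count_projections(10, use_cache=True))   # 10
```

Without a cache the work grows quadratically with the sequence length; with a cache it grows linearly, at the cost of carrying the cached tensors from step to step.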

In [7]:
import torch
from typing import Optional, Union, List, Dict
from functools import partial

positional_embeddings_size = model.decoder.positional_embedding.shape[0]


def save_to_cache(cache: Dict[str, torch.Tensor], module: str, output: torch.Tensor):
    """
    Saving cached attention hidden states for previous tokens.
    Parameters:
      cache: dictionary with cache.
      module: current attention module name.
      output: predicted hidden state.
    Returns:
      output: cached attention hidden state for specified attention module.
    """
    if module not in cache or output.shape[1] > positional_embeddings_size:
        # save as-is, for the first token or cross attention
        cache[module] = output
    else:
        cache[module] = torch.cat([cache[module], output], dim=1).detach()
    return cache[module]


def attention_forward(
        attention_module,
        x: torch.Tensor,
        xa: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[dict] = None,
        idx: Union[int, str] = 0
):
    """
    Override for forward method of decoder attention module with storing cache values explicitly.
    Parameters:
      attention_module: current attention module
      x: input hidden states for the current tokens.
      xa: input audio features (Optional).
      mask: mask for applying attention (Optional).
      kv_cache: dictionary with cached key values for attention modules.
      idx: key suffix (for example `0a` or `0c`) for search in kv_cache.
    Returns:
      attention module output tensor
      updated kv_cache
    """
    q = attention_module.query(x)

    if kv_cache is None or xa is None:
        # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
        # otherwise, perform key/value projections for self- or cross-attention as usual.
        k = attention_module.key(x if xa is None else xa)
        v = attention_module.value(x if xa is None else xa)
        if kv_cache is not None:
            k = save_to_cache(kv_cache, f'k_{idx}', k)
            v = save_to_cache(kv_cache, f'v_{idx}', v)
    else:
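        # for cross-attention, keys and values are read from the cache; note that dict.get evaluates
        # its default argument eagerly, so the projections below are recomputed and re-saved on every call.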
        k = kv_cache.get(f'k_{idx}', save_to_cache(
            kv_cache, f'k_{idx}', attention_module.key(xa)))
        v = kv_cache.get(f'v_{idx}', save_to_cache(
            kv_cache, f'v_{idx}', attention_module.value(xa)))

    wv, qk = attention_module.qkv_attention(q, k, v, mask)
    return attention_module.out(wv), kv_cache


def block_forward(
    residual_block,
    x: torch.Tensor,
    xa: Optional[torch.Tensor] = None,
    mask: Optional[torch.Tensor] = None,
    kv_cache: Optional[dict] = None,
    idx: int = 0
):
    """
    Override for residual block forward method for providing kv_cache to attention module.
    Parameters:
      residual_block: current residual block.
      x: input hidden states.
      xa: input audio features (Optional).
      mask: attention mask (Optional).
      kv_cache: cache for storing attention key values.
      idx: index of current residual block for search in kv_cache.
    Returns:
      x: residual block output
      kv_cache: updated kv_cache
    """
    x0, kv_cache = residual_block.attn(residual_block.attn_ln(
        x), mask=mask, kv_cache=kv_cache, idx=f'{idx}a')
    x = x + x0
    if residual_block.cross_attn:
        x1, kv_cache = residual_block.cross_attn(
            residual_block.cross_attn_ln(x), xa, kv_cache=kv_cache, idx=f'{idx}c')
        x = x + x1
    x = x + residual_block.mlp(residual_block.mlp_ln(x))
    return x, kv_cache


# update forward functions
for idx, block in enumerate(model.decoder.blocks):
    block.forward = partial(block_forward, block, idx=idx)
    block.attn.forward = partial(attention_forward, block.attn)
    if block.cross_attn:
        block.cross_attn.forward = partial(attention_forward, block.cross_attn)


def decoder_forward(decoder, x: torch.Tensor, xa: torch.Tensor, kv_cache: Optional[dict] = None):
    """
    Override for decoder forward method.
    Parameters:
      decoder: the decoder module whose forward method is overridden
      x: torch.LongTensor, shape = (batch_size, <= n_ctx) the text tokens
      xa: torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
           the encoded audio features to be attended on
      kv_cache: Dict[str, torch.Tensor], attention modules hidden states cache from previous steps
    Returns:
      logits: predicted token logits
      kv_cache: updated kv_cache
    """
    offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
    x = decoder.token_embedding(
        x) + decoder.positional_embedding[offset: offset + x.shape[-1]]
    x = x.to(xa.dtype)

    for block in decoder.blocks:
        x, kv_cache = block(x, xa, mask=decoder.mask, kv_cache=kv_cache)

    x = decoder.ln(x)
    logits = (
        x @ torch.transpose(decoder.token_embedding.weight.to(x.dtype), 1, 0)).float()

    return logits, kv_cache


# override decoder forward
model.decoder.forward = partial(decoder_forward, model.decoder)
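
One detail of the cross-attention branch in `attention_forward` is worth spelling out: `dict.get(key, default)` evaluates `default` before the lookup happens, so the projection passed as the default runs on every call, even when the key is already cached. The sketch below reproduces that behaviour with stand-ins (`CountingProjection` and `save` are hypothetical names) and contrasts it with a lazy lookup. The lazy variant is shown only for comparison; swapping it into the notebook could change the traced graph, so treat it as an observation rather than a recommended fix.

```python
class CountingProjection:
    """Stand-in for a key projection that counts how often it runs."""

    def __init__(self):
        self.calls = 0

    def __call__(self):
        self.calls += 1
        return f"projection #{self.calls}"


def save(cache, key, value):
    # overwrite unconditionally, mirroring the "save as-is" path of save_to_cache
    cache[key] = value
    return cache[key]


project = CountingProjection()
cache = {}
for _ in range(3):
    # the second argument is evaluated before dict.get runs, even if the key exists
    result = cache.get("k_0c", save(cache, "k_0c", project()))

print(project.calls)  # 3
print(result)         # projection #3

# lazy variant: compute the projection only on a cache miss
lazy_project = CountingProjection()
lazy_cache = {}
for _ in range(3):
    if "k_0c" not in lazy_cache:
        lazy_cache["k_0c"] = lazy_project()
    lazy_result = lazy_cache["k_0c"]

print(lazy_project.calls)  # 1
```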

In [8]:
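# first pass: 3 tokens per beam with an empty cache, to populate kv_cache
tokens = torch.ones((5, 3), dtype=torch.int64)

logits, kv_cache = model.decoder(tokens, audio_features, kv_cache={})
kv_cache = {k: v for k, v in kv_cache.items()}
# example input for the export: one new token per beam, decoded against the filled cache
tokens = torch.ones((5, 1), dtype=torch.int64)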

In [9]:
outputs = [f"out_{k}" for k in kv_cache.keys()]
inputs = [f"in_{k}" for k in kv_cache.keys()]
dynamic_axes = {
    "tokens": {0: "beam_size", 1: "seq_len"},
    "audio_features": {0: "beam_size"},
    "logits": {0: "beam_size", 1: "seq_len"}}
dynamic_outs = {o: {0: "beam_size", 1: "prev_seq_len"} for o in outputs}
dynamic_inp = {i: {0: "beam_size", 1: "prev_seq_len"} for i in inputs}
dynamic_axes.update(dynamic_outs)
dynamic_axes.update(dynamic_inp)
if not WHISPER_DECODER_OV.exists():
    torch.onnx.export(
        model.decoder, {'x': tokens, 'xa': audio_features, 'kv_cache': kv_cache},
        'whisper_decoder.onnx',
        input_names=["tokens", "audio_features"] + inputs,
        output_names=["logits"] + outputs,
        dynamic_axes=dynamic_axes
    )



The decoder model autoregressively predicts the next token, guided by the encoder hidden states and the previously predicted sequence. This means that the shapes of the inputs that depend on the previous step (the tokens and the attention hidden states from the previous step) are dynamic. For efficient memory utilization, we define an upper bound for the dynamic input shapes.
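
The following sketch traces how the dynamic input shapes evolve over a few decoding steps. It only does shape bookkeeping, with no model involved: the first step sees the whole prompt and an empty cache, and every later step sees one new token plus a cache that has grown by the tokens processed so far. The hidden size of 768 is taken from the exported decoder inputs; `step_shapes` is a hypothetical helper.

```python
HIDDEN = 768  # last dimension of the self-attention cache tensors in this notebook


def step_shapes(beam_size: int, prompt_len: int, n_steps: int) -> list:
    """Return (tokens_shape, cache_input_shape) for each decoding step."""
    shapes = []
    prev_seq_len = 0          # no cached positions before the first step
    new_tokens = prompt_len   # the first step processes the whole prompt
    for _ in range(n_steps):
        tokens_shape = (beam_size, new_tokens)
        cache_shape = (beam_size, prev_seq_len, HIDDEN)
        shapes.append((tokens_shape, cache_shape))
        prev_seq_len += new_tokens  # the cache grows by what was just processed
        new_tokens = 1              # later steps feed one token at a time
    return shapes


for tokens_shape, cache_shape in step_shapes(beam_size=5, prompt_len=3, n_steps=3):
    print(tokens_shape, cache_shape)
# (5, 3) (5, 0, 768)
# (5, 1) (5, 3, 768)
# (5, 1) (5, 4, 768)
```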

In [10]:
import onnx

def get_input_shapes(onnx_model_path):
    """Return a dict mapping each ONNX graph input name to its list of dimensions."""
    input_shapes = {}
    onnx_model = onnx.load(onnx_model_path)
    for graph_input in onnx_model.graph.input:
        # dynamic dimensions have no fixed dim_value and are reported as 0
        input_shape = [dim.dim_value for dim in graph_input.type.tensor_type.shape.dim]
        input_shapes[graph_input.name] = input_shape
    return input_shapes

# Usage example
onnx_model_path = "whisper_decoder.onnx"
input_shapes = get_input_shapes(onnx_model_path)
print("Decoder ONNX input shapes:", input_shapes)

Decoder ONNX input shapes: {'tokens': [0, 0], 'audio_features': [0, 1500, 768], 'in_k_0a': [0, 0, 768], 'in_v_0a': [0, 0, 768], 'in_k_1a': [0, 0, 768], 'in_v_1a': [0, 0, 768], 'in_k_2a': [0, 0, 768], 'in_v_2a': [0, 0, 768], 'in_k_3a': [0, 0, 768], 'in_v_3a': [0, 0, 768], 'in_k_4a': [0, 0, 768], 'in_v_4a': [0, 0, 768], 'in_k_5a': [0, 0, 768], 'in_v_5a': [0, 0, 768], 'in_k_6a': [0, 0, 768], 'in_v_6a': [0, 0, 768], 'in_k_7a': [0, 0, 768], 'in_v_7a': [0, 0, 768], 'in_k_8a': [0, 0, 768], 'in_v_8a': [0, 0, 768], 'in_k_9a': [0, 0, 768], 'in_v_9a': [0, 0, 768], 'in_k_10a': [0, 0, 768], 'in_v_10a': [0, 0, 768], 'in_k_11a': [0, 0, 768], 'in_v_11a': [0, 0, 768]}
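
The zeros in the printed shapes are placeholders, not real sizes: for a dynamic axis, ONNX stores a symbolic name in `dim_param`, and `dim_value` reads as its default of 0. A small variation of the shape reader can show the axis names instead. The sketch below is a hypothetical helper that runs on stand-in objects so it does not need an ONNX file; with a real model, the `shape.dim` list of each graph input would be passed in, on the assumption that the exporter recorded the names given in `dynamic_axes`.

```python
from types import SimpleNamespace


def readable_shape(dims) -> list:
    """Prefer the symbolic axis name (dim_param) over the placeholder dim_value."""
    return [dim.dim_param or dim.dim_value for dim in dims]


# Stand-ins for ONNX dimension objects, mimicking the two fields used above.
tokens_dims = [
    SimpleNamespace(dim_param="beam_size", dim_value=0),
    SimpleNamespace(dim_param="seq_len", dim_value=0),
]
audio_dims = [
    SimpleNamespace(dim_param="beam_size", dim_value=0),
    SimpleNamespace(dim_param="", dim_value=1500),
    SimpleNamespace(dim_param="", dim_value=768),
]

print(readable_shape(tokens_dims))  # ['beam_size', 'seq_len']
print(readable_shape(audio_dims))   # ['beam_size', 1500, 768]
```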


In [11]:

# bounded dynamic shapes: 1 to 11 beams, 1 to 600 new tokens, 0 to 600 cached positions
input_shapes = "tokens[1..11 1..600],audio_features[1..11 1500 768]"
for k in kv_cache:
    # only the self-attention caches (suffix 'a') are listed as decoder inputs
    if k.endswith('a'):
        input_shapes += f",in_{k}[1..11 0..600 768]"
if not WHISPER_DECODER_OV.exists():
    decoder_model = mo.convert_model(
        input_model="whisper_decoder.onnx",
        compress_to_fp16=True,
        input=input_shapes)
    serialize(decoder_model, str(WHISPER_DECODER_OV))
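
The bounded-shape string passed to `mo.convert_model` above is assembled by hand. As a minimal sketch, the same string can be generated from structured values, which makes the bounds easier to change in one place. The helper names (`bounded`, `shape_spec`, `decoder_input_spec`) are hypothetical, and the output simply mirrors the `name[lo..hi dim ...]` notation used in the cell above.

```python
def bounded(lo: int, hi: int) -> str:
    """Format one dimension as a range such as 1..11, or a single number if fixed."""
    return str(lo) if lo == hi else f"{lo}..{hi}"


def shape_spec(name: str, dims: list) -> str:
    """Format one input, for example tokens[1..11 1..600]."""
    parts = [bounded(*d) if isinstance(d, tuple) else str(d) for d in dims]
    return f"{name}[{' '.join(parts)}]"


def decoder_input_spec(cache_names, max_beam=11, max_len=600,
                       audio_ctx=1500, hidden=768) -> str:
    """Build the comma-separated shape string for the decoder inputs."""
    specs = [
        shape_spec("tokens", [(1, max_beam), (1, max_len)]),
        shape_spec("audio_features", [(1, max_beam), audio_ctx, hidden]),
    ]
    # cached tensors may be empty on the first step, so their length starts at 0
    specs += [
        shape_spec(f"in_{name}", [(1, max_beam), (0, max_len), hidden])
        for name in cache_names
    ]
    return ",".join(specs)


spec = decoder_input_spec(["k_0a", "v_0a"])
print(spec)
# tokens[1..11 1..600],audio_features[1..11 1500 768],in_k_0a[1..11 0..600 768],in_v_0a[1..11 0..600 768]
```

With the notebook's cache, the list of names would be `[k for k in kv_cache if k.endswith('a')]`.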

