# Video subtitles generation using Whisper and OpenVINO

Whisper is an automatic speech recognition (ASR) system trained on 680,000 hours of multilingual and multitask supervised data collected from the web.  It is a multi-task model that can perform multilingual speech recognition as well as speech translation and language identification.


![asr-training-data-desktop.svg](asr-training-data-desktop.svg)

You can find more information about this model in the [paper](https://cdn.openai.com/papers/whisper.pdf), the [OpenAI blog post](https://openai.com/blog/whisper/), the [model card](https://github.com/openai/whisper/blob/main/model-card.md) and the [repository](https://github.com/openai/whisper).

In this notebook, we will use Whisper's capabilities to generate subtitles for a video.
The notebook contains the following steps:
1. Convert the model to OpenVINO IR using the OpenVINO Model Optimizer tool.
2. Run the Whisper pipeline with OpenVINO models.

## Prerequisites

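Clone and install the model repository, and install the packages used for downloading the video (`pytube`) and extracting its audio (`moviepy`).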

In [1]:
!git clone https://github.com/openai/whisper.git
%cd whisper
!python setup.py develop
!pip install git+https://github.com/pytube/pytube
!pip install moviepy



## Instantiate model
Whisper is a Transformer-based encoder-decoder model, also referred to as a sequence-to-sequence model. It maps a sequence of audio spectrogram features to a sequence of text tokens. First, the feature extractor converts the raw audio input to a log-Mel spectrogram. The Transformer encoder then encodes the spectrogram into a sequence of encoder hidden states. Finally, the decoder autoregressively predicts text tokens, conditioned on both the previous tokens and the encoder hidden states.

The model architecture is shown in the diagram below:

![whisper_architecture.svg](whisper_architecture.svg)
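To make the tensor shapes used later in the conversion steps concrete, here is a small sketch of the arithmetic. It assumes the audio constants defined in `whisper/audio.py` of the openai/whisper repository (16 kHz sampling, 160-sample hop, 80 Mel bins, 30-second chunks), an encoder front end of two kernel-3 convolutions with the second one using stride 2, and the `base` model's 512-dimensional hidden state. The helper functions are hypothetical and not part of Whisper.

```python
# Constants restated from whisper/audio.py (openai/whisper) for illustration.
SAMPLE_RATE = 16_000  # audio is resampled to 16 kHz
HOP_LENGTH = 160      # 10 ms hop between STFT frames
N_MELS = 80           # number of Mel filterbank channels
CHUNK_LENGTH = 30     # seconds of audio per decoding window


def mel_frames(chunk_seconds: int = CHUNK_LENGTH) -> int:
    """Number of log-Mel frames for one padded audio chunk."""
    return chunk_seconds * SAMPLE_RATE // HOP_LENGTH


def conv1d_out_len(n: int, kernel: int = 3, stride: int = 1, padding: int = 1) -> int:
    """Standard output-length formula for a 1-D convolution."""
    return (n + 2 * padding - kernel) // stride + 1


def encoder_output_shape(n_frames: int, d_model: int = 512) -> tuple:
    """Assumed time-axis reduction of the encoder: two kernel-3 convolutions,
    the second with stride 2; the Transformer blocks keep the sequence length."""
    t = conv1d_out_len(n_frames, stride=1)
    t = conv1d_out_len(t, stride=2)
    return (t, d_model)


mel_shape = (N_MELS, mel_frames())
print(mel_shape)                           # (80, 3000)
print(encoder_output_shape(mel_shape[1]))  # (1500, 512)
```

This is consistent with the dummy encoder input `torch.zeros((1, 80, 3000))` and with the decoder's `audio_features` input being declared as `[1..5 1500 512]`.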


The model authors trained several models of different sizes and capabilities. In this tutorial, we will use the `base` model, but the same steps also apply to the other models in the Whisper family.

In [30]:
import whisper

model = whisper.load_model("base")
model.eval()
pass

### Convert model to OpenVINO Intermediate Representation (IR) format.

To start working with OpenVINO, we need to convert the model to the OpenVINO format.
OpenVINO supports PyTorch models via ONNX conversion. We will use `torch.onnx.export` to export the model to ONNX; it needs an initialized model object and example inputs for shape inference.
Then we will use the `mo.convert_model` function to convert the ONNX model. It returns an OpenVINO model object that is ready to be loaded on a device for inference.
We can save it to disk for later use with `openvino.runtime.serialize`.

### Whisper encoder to IR

In [31]:
import torch
from openvino.tools import mo
from openvino.runtime import serialize
mel = torch.zeros((1, 80, 3000))
audio_features = model.encoder(mel)
torch.onnx.export(model.encoder, mel, 'whisper_encoder.onnx', input_names=['mel'], output_names=['output_features'])
encoder_model = mo.convert_model('whisper_encoder.onnx', compress_to_fp16=True, input='mel[1 80 -1]')
serialize(encoder_model, 'whisper_encoder.xml')

### Whisper decoder to IR

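To reduce computational complexity, the decoder caches the key/value projections computed by its attention modules in previous steps and reuses them instead of recomputing them for every generated token. In the original Whisper implementation, this cache is filled by forward hooks and keyed by module objects, which cannot be traced as graph inputs and outputs. We therefore patch the attention blocks and the decoder so that the cache is a dictionary with string keys (for example `k_0a` for the keys of the first self-attention layer) that is passed in and returned explicitly. This makes the cached tensors named inputs and outputs of the exported ONNX model.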
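The saving from caching is easy to see with a toy count. The sketch below is a simplification for illustration only: plain Python lists stand in for tensors, the key "projection" is the identity, and the cache key `k_0a` merely mirrors the `k_{idx}` naming with an `a` suffix that the patched self-attention uses.

```python
from typing import Dict, List


def count_key_projections(tokens: List[int], use_cache: bool) -> int:
    """Decode `tokens` one step at a time and count per-token key projections.
    Lists stand in for tensors and the projection is the identity."""
    n_projections = 0
    cache: Dict[str, List[int]] = {}
    for step in range(1, len(tokens) + 1):
        if use_cache:
            new_keys = tokens[step - 1:step]                   # project only the newest token
            n_projections += len(new_keys)
            cache["k_0a"] = cache.get("k_0a", []) + new_keys   # append along the sequence axis
            keys = cache["k_0a"]
        else:
            keys = tokens[:step]                               # re-project the whole prefix
            n_projections += len(keys)
        # Either way, attention at this step sees keys for the full prefix.
        assert keys == tokens[:step]
    return n_projections


seq = list(range(10))
print(count_key_projections(seq, use_cache=False))  # 55
print(count_key_projections(seq, use_cache=True))   # 10
```

Without a cache the work grows quadratically with the number of generated tokens; with it, each step projects a single token. Cross-attention keys and values depend only on the encoder output, so in principle they can be computed once per audio segment.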

In [32]:
import torch
from typing import Optional, Union, List
from functools import partial

positional_embeddings_size = model.decoder.positional_embedding.shape[0]

def save_to_cache(cache, module, output):
    if module not in cache or output.shape[1] > positional_embeddings_size:
        cache[module] = output  # save as-is, for the first token or cross attention
    else:
        cache[module] = torch.cat([cache[module], output], dim=1).detach()
    return cache[module]
  
def attention_forward(
        attention_module,
        x: torch.Tensor,
        xa: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[dict] = None,
        idx: int = 0
):
    q = attention_module.query(x)

    if kv_cache is None or xa is None:
        # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
        # otherwise, perform key/value projections for self- or cross-attention as usual.
        k = attention_module.key(x if xa is None else xa)
        v = attention_module.value(x if xa is None else xa)
        if kv_cache is not None:
            k = save_to_cache(kv_cache, f'k_{idx}', k)
            v = save_to_cache(kv_cache, f'v_{idx}', v)
    else:
        # for cross-attention, calculate keys and values once and reuse in subsequent calls.
        k = kv_cache.get(f'k_{idx}', save_to_cache(kv_cache, f'k_{idx}', attention_module.key(xa)))
        v = kv_cache.get(f'v_{idx}', save_to_cache(kv_cache, f'v_{idx}', attention_module.value(xa)))

    wv = attention_module.qkv_attention(q, k, v, mask)
    return attention_module.out(wv), kv_cache


def block_forward(
        residual_block,
        x: torch.Tensor,
        xa: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[dict] = None,
        idx:int = 0
    ):
        x0, kv_cache = residual_block.attn(residual_block.attn_ln(x), mask=mask, kv_cache=kv_cache, idx=f'{idx}a')
        x = x + x0
        if residual_block.cross_attn:
            x1, kv_cache = residual_block.cross_attn(residual_block.cross_attn_ln(x), xa, kv_cache=kv_cache, idx=f'{idx}c')
            x = x + x1
        x = x + residual_block.mlp(residual_block.mlp_ln(x))
        return x, kv_cache

for idx, block in enumerate(model.decoder.blocks):
    block.forward = partial(block_forward, block, idx=idx)
    block.attn.forward = partial(attention_forward, block.attn)
    if block.cross_attn:
        block.cross_attn.forward = partial(attention_forward, block.cross_attn)


def decoder_forward(decoder, x: torch.Tensor, xa: torch.Tensor, kv_cache: Optional[dict] = None):
    """
    x : torch.LongTensor, shape = (batch_size, <= n_ctx) the text tokens
    xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
        the encoded audio features to be attended on
    """
    offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
    x = decoder.token_embedding(x) + decoder.positional_embedding[offset : offset + x.shape[-1]]
    x = x.to(xa.dtype)

    for block in decoder.blocks:
        x, kv_cache = block(x, xa, mask=decoder.mask, kv_cache=kv_cache)

    x = decoder.ln(x)
    logits = (x @ torch.transpose(decoder.token_embedding.weight.to(x.dtype), 1, 0)).float()

    return logits, kv_cache

model.decoder.forward = partial(decoder_forward, model.decoder)


In [33]:
tokens = torch.ones((5, 3), dtype=torch.int64)

logits, kv_cache = model.decoder(tokens, audio_features, kv_cache={})
kv_cache = {k: v for k, v in kv_cache.items()}
tokens = torch.ones((5, 1), dtype=torch.int64)

In [34]:
outputs = [f'out_{k}' for k in kv_cache.keys()]
inputs = [f'in_{k}' for k in kv_cache.keys()]
dynamic_axes = {'tokens': {0: 'beam_size', 1: 'seq_len'}, 'audio_features': {0: 'beam_size'}, 'logits': {0: 'beam_size', 1: 'seq_len'}}
dynamic_outs = {o: {0: 'beam_size', 1: 'prev_seq_len'} for o in outputs}
dynamic_inp = {i: {0: 'beam_size', 1: 'prev_seq_len'}  for i in inputs}
dynamic_axes.update(dynamic_outs)
dynamic_axes.update(dynamic_inp)
torch.onnx.export(
    model.decoder,
    {'x': tokens, 'xa': audio_features, 'kv_cache': kv_cache},
    'whisper_decoder.onnx',
    input_names=['tokens', 'audio_features'] + inputs,
    output_names=['logits'] + outputs,
    dynamic_axes=dynamic_axes
)

In [35]:
input_shapes = 'tokens[1..5 1..224],audio_features[1..5 1500 512]'
for k, v in kv_cache.items():
    if k.endswith('a'):
        input_shapes += f',in_{k}[1..5 0..224 512]' 
decoder_model = mo.convert_model(input_model='whisper_decoder.onnx', compress_to_fp16=True, input=input_shapes)
serialize(decoder_model, 'whisper_decoder.xml')
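The `input` strings passed to `mo.convert_model` (`mel[1 80 -1]` for the encoder and `tokens[1..5 1..224],...` for the decoder) use Model Optimizer's shape notation: `-1` (or `?`) marks a fully dynamic dimension, and `lo..hi` marks a dynamic dimension bounded to that range. The `parse_input_spec` helper below is hypothetical, not an OpenVINO API; it only decodes such strings into (lower, upper) bounds to make the notation explicit. The batch bound 5 matches the `beam_size=5` used for transcription, and 224 appears to match Whisper's default maximum sample length (`n_text_ctx // 2`, that is 448 // 2).

```python
import re
from typing import Dict, List, Optional, Tuple

# (lower bound, upper bound); an upper bound of None means unbounded.
Dim = Tuple[int, Optional[int]]


def parse_input_spec(spec: str) -> Dict[str, List[Dim]]:
    """Hypothetical decoder for specs like 'tokens[1..5 1..224],audio_features[1..5 1500 512]'."""
    shapes: Dict[str, List[Dim]] = {}
    for name, dims in re.findall(r"([^,\[\]]+)\[([^\]]*)\]", spec):
        parsed: List[Dim] = []
        for d in dims.split():
            if d in ("-1", "?"):
                parsed.append((0, None))            # fully dynamic dimension
            elif ".." in d:
                lo, hi = d.split("..")
                parsed.append((int(lo), int(hi)))   # bounded dynamic dimension
            else:
                parsed.append((int(d), int(d)))     # static dimension
        shapes[name.strip()] = parsed
    return shapes


print(parse_input_spec("mel[1 80 -1]"))
print(parse_input_spec("tokens[1..5 1..224],in_k_0a[1..5 0..224 512]"))
```

Bounding the dynamic dimensions, rather than leaving them fully dynamic, tells the runtime the largest shapes it has to handle.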

## Prepare inference pipeline

To run the PyTorch Whisper model, you only need to call its `transcribe(audio)` function. We will reuse the original model pipeline for audio transcription; to run the model with OpenVINO, we only need to replace the model parts (encoder and decoder) and the decoding functionality.
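This works because the pipeline interacts with its submodules only by calling them, so any object that accepts the same arguments and returns compatible outputs can be swapped in. The toy classes below (`ToyWhisper`, `ReplacementEncoder`) are hypothetical plain-Python stand-ins that illustrate the pattern; they are not part of Whisper or OpenVINO.

```python
from typing import Callable, List


class ToyWhisper:
    """Hypothetical stand-in for a model whose pipeline only *calls* its parts."""

    def __init__(self) -> None:
        self.encoder: Callable[[List[float]], List[float]] = lambda mel: [2.0 * v for v in mel]
        self.decoder: Callable[[List[int], List[float]], float] = (
            lambda tokens, feats: sum(feats) + len(tokens)
        )

    def transcribe(self, mel: List[float]) -> float:
        features = self.encoder(mel)  # the pipeline never inspects how the encoder works
        return self.decoder([0], features)


class ReplacementEncoder:
    """Plays the role of OpenVINOAudioEncoder: same call signature, different backend."""

    def __call__(self, mel: List[float]) -> List[float]:
        return [v + v for v in mel]  # numerically equivalent implementation


model = ToyWhisper()
reference = model.transcribe([0.5, 1.5])
model.encoder = ReplacementEncoder()  # swap the part, keep the pipeline
print(reference, model.transcribe([0.5, 1.5]))  # 5.0 5.0
```

In the notebook, `model.parameters` and `model.decode` are replaced as well, so that device lookup and decoding also go through code paths that work with the OpenVINO parts.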

In [36]:
class OpenVINOAudioEncoder(torch.nn.Module):
    def __init__(self, core, model_path, device='CPU'):
        super().__init__()
        self.model = core.read_model(model_path)
        self.compiled_model = core.compile_model(self.model, device)
        self.output_blob = self.compiled_model.output(0)

    def forward(self, mel:torch.Tensor):
        return torch.from_numpy(self.compiled_model(mel)[self.output_blob])

In [37]:
import numpy as np


class OpenVINOTextDecoder(torch.nn.Module):
    def __init__(self, core, model_path, device='CPU'):
        super().__init__()
        self._core = core
        self.model = core.read_model(model_path)
        self._input_names = [inp.any_name for inp in self.model.inputs]
        self.compiled_model = core.compile_model(self.model, device)
        self.device = device
    
    def init_past_inputs(self, feed_dict):
        beam_size = feed_dict['tokens'].shape[0]
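        # the last axis of audio_features is the model hidden size (512 for `base`)
        hidden_size = feed_dict['audio_features'].shape[-1]
        previous_seq_len = 0
        for name in self._input_names:
            if name in ['tokens', 'audio_features']:
                continue
            # empty past key/value tensors for the first decoding step
            feed_dict[name] = np.zeros((beam_size, previous_seq_len, hidden_size), dtype=np.float32)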
        return feed_dict

    def preprocess_kv_cache_inputs(self, feed_dict, kv_cache):
        if not kv_cache:
            return self.init_past_inputs(feed_dict)
        for k, v in kv_cache.items():
            new_k = f'in_{k}'
            if new_k in self._input_names:
                feed_dict[new_k] = v
        return feed_dict

    def postprocess_outputs(self, outputs):
        logits = None
        kv_cache = {}
        for output_t, out in outputs.items():
            if 'logits' in output_t.get_names():
                logits = torch.from_numpy(out)
            else:
                tensor_name = output_t.any_name
                kv_cache[tensor_name.replace('out_', '')] = torch.from_numpy(out)
        return logits, kv_cache

    def forward(self, x:torch.Tensor, xa:torch.Tensor, kv_cache: Optional[dict]=None):
        feed_dict = {'tokens': x, 'audio_features': xa}
        feed_dict = (self.preprocess_kv_cache_inputs(feed_dict, kv_cache))
        res = self.compiled_model(feed_dict)
        return self.postprocess_outputs(res)

In [38]:
from whisper.decoding import DecodingTask, Inference, DecodingOptions, DecodingResult


class OpenVINOInference(Inference):
    def __init__(self, model: "Whisper", initial_token_length: int):
        self.model: "Whisper" = model
        self.initial_token_length = initial_token_length
        self.kv_cache = {}
    
    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor) -> torch.Tensor:
        if tokens.shape[-1] > self.initial_token_length:
            # only need to use the last token except in the first forward pass
            tokens = tokens[:, -1:]
        logits, self.kv_cache = self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
        return logits

    def cleanup_caching(self):
        self.kv_cache = {}

    def rearrange_kv_cache(self, source_indices):
        for module, tensor in self.kv_cache.items():
            # update the key/value cache to contain the selected sequences
            self.kv_cache[module] = tensor[source_indices]


class OpenVINODecodingTask(DecodingTask):
    def __init__(self, model: "Whisper", options: DecodingOptions):
        super().__init__(model, options)
        self.inference = OpenVINOInference(model, len(self.initial_tokens))

@torch.no_grad()
def decode(model: "Whisper", mel: torch.Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
    """
    Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).

    Parameters
    ----------
    model: Whisper
        the Whisper model instance

    mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
        A tensor containing the Mel spectrogram(s)

    options: DecodingOptions
        A dataclass that contains all necessary options for decoding 30-second segments

    Returns
    -------
    result: Union[DecodingResult, List[DecodingResult]]
        The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
    """
    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    result = OpenVINODecodingTask(model, options).run(mel)
    
    if single:
        result = result[0]

    return result
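During beam search, the hypotheses kept after each step may descend from different beams, so `OpenVINOInference.rearrange_kv_cache` reorders every cached tensor along its first (beam) axis with `tensor[source_indices]`. A plain-Python sketch of the same gather, with lists standing in for tensors and made-up cache contents, is:

```python
from typing import Dict, List


def rearrange(cache: Dict[str, List[str]], source_indices: List[int]) -> Dict[str, List[str]]:
    """Gather cache rows so that row i now holds the history of beam source_indices[i]."""
    return {name: [rows[i] for i in source_indices] for name, rows in cache.items()}


# Each row stands for the cached key/value history of one beam.
cache = {"k_0a": ["beam0", "beam1", "beam2"], "v_0a": ["v0", "v1", "v2"]}
# Suppose beams 0 and 2 survive and beam 0 yields two continuations.
print(rearrange(cache, [0, 0, 2]))
# {'k_0a': ['beam0', 'beam0', 'beam2'], 'v_0a': ['v0', 'v0', 'v2']}
```

A repeated index duplicates a row, which is what is needed when one beam produces several surviving continuations.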


In [39]:
del model.decoder
del model.encoder

In [40]:
from openvino.runtime import Core
from collections import namedtuple

Parameter = namedtuple('Parameter', ['device'])

core = Core()

model.encoder = OpenVINOAudioEncoder(core, 'whisper_encoder.xml')
model.decoder = OpenVINOTextDecoder(core, 'whisper_decoder.xml')
model.decode = partial(decode, model)

def parameters():
    return iter([Parameter(torch.device('cpu'))])

model.parameters = parameters

def logits(model, tokens: torch.Tensor, audio_features: torch.Tensor):
    return model.decoder(tokens, audio_features, None)[0]

model.logits = partial(logits, model)


In [41]:
import io
import numpy as np
from scipy.io import wavfile
from pytube import YouTube
from moviepy.editor import VideoFileClip

def resample(audio, src_sample_rate, dst_sample_rate):
    if src_sample_rate == dst_sample_rate:
        return audio
    duration = audio.shape[0] / src_sample_rate
    resampled_data = np.zeros(shape=(int(duration * dst_sample_rate)), dtype=np.float32)
    x_old = np.linspace(0, duration, audio.shape[0], dtype=np.float32)
    x_new = np.linspace(0, duration, resampled_data.shape[0], dtype=np.float32)
    resampled_audio = np.interp(x_new, x_old, audio)
    return resampled_audio.astype(np.float32)

def audio_to_float(audio):
    return audio.astype(np.float32) / np.iinfo(audio.dtype).max


def get_audio(video_file):
    input_audio_file = video_file.stem + '.wav'
    input_video = VideoFileClip(str(video_file))
    input_video.audio.write_audiofile(input_audio_file)
    sample_rate, audio = wavfile.read(input_audio_file)
    audio = audio_to_float(audio)
    if audio.ndim == 2:
        audio = audio.mean(axis=1)
    resampled_audio = resample(audio, sample_rate, 16000)
    return resampled_audio

In [42]:
from pathlib import Path

VIDEO_LINK = 'https://www.youtube.com/watch?v=kgL5LBM-hFI'

yt = YouTube(VIDEO_LINK)
yt.streams.get_highest_resolution().download(filename=str(Path("downloaded_video.mp4")))

'/home/ea/work/openvino_notebooks/notebooks/226-whisper-subtitels-generation/whisper/downloaded_video.mp4'

In [43]:
from pathlib import Path
audio = get_audio(Path('downloaded_video.mp4'))

MoviePy - Writing audio in downloaded_video.wav
MoviePy - Done.


In [45]:
transcription = model.transcribe(audio, beam_size=5, best_of=5, task='translate')

In [46]:
def format_timestamp(seconds: float):
    assert seconds >= 0, "non-negative timestamp expected"
    milliseconds = round(seconds * 1000.0)

    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000

    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000

    seconds = milliseconds // 1_000
    milliseconds -= seconds * 1_000

    # SRT expects zero-padded hours, minutes and seconds: HH:MM:SS,mmm
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"

def prepare_srt_bilingual(transcription, translation):
    segment_lines = []
    for segment1, segment2 in zip(transcription['segments'], translation['segments']):
        segment_lines.append(str(segment1['id'] + 1) +'\n')
        time_start = format_timestamp(segment1['start'])
        time_end = format_timestamp(segment1['end'])
        time_str = f'{time_start} --> {time_end}\n'
        segment_lines.append(time_str)
        segment_lines.append(segment1['text'] + '\n' + segment2['text'] + '\n\n')
    return segment_lines


def prepare_srt(transcription):
    segment_lines = []
    for segment in transcription['segments']:
        segment_lines.append(str(segment['id'] + 1) +'\n')
        time_start = format_timestamp(segment['start'])
        time_end = format_timestamp(segment['end'])
        time_str = f'{time_start} --> {time_end}\n'
        segment_lines.append(time_str)
        segment_lines.append(segment['text'] + '\n\n')
    return segment_lines
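SRT cues use the `HH:MM:SS,mmm` timestamp format, with a comma before the milliseconds. The same arithmetic can be written more compactly with `divmod`; `srt_timestamp` is a hypothetical alternative to `format_timestamp`, shown only to make the conversion explicit, and it always pads hours to two digits.

```python
def srt_timestamp(seconds: float) -> str:
    """Format a non-negative time in seconds as an SRT timestamp (HH:MM:SS,mmm)."""
    if seconds < 0:
        raise ValueError("non-negative timestamp expected")
    total_ms = round(seconds * 1000)
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    secs, ms = divmod(rem, 1_000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{ms:03d}"


print(srt_timestamp(5.0))     # 00:00:05,000
print(srt_timestamp(3725.5))  # 01:02:05,500
```

Each cue is then an index line, a `start --> end` line and the text, separated from the next cue by a blank line, which is the layout `prepare_srt` produces.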

In [47]:
srt_lines = prepare_srt(transcription)

In [48]:
from ipywidgets import Video
Video.from_file("downloaded_video.mp4", width=320, height=320)


In [49]:
print(''.join(srt_lines))

1
00:00:00,000 --> 00:00:05,000
 Oh, what's that?

2
00:00:05,000 --> 00:00:09,000
 Oh, wow.

3
00:00:09,000 --> 00:00:10,000
 Hello, humans.

4
00:00:13,000 --> 00:00:15,000
 Focus on me.

5
00:00:15,000 --> 00:00:18,000
 Focus on the guard.

6
00:00:18,000 --> 00:00:22,000
 Don't tell anyone what you've seen in here.

7
00:00:22,000 --> 00:00:30,000
 Have you seen what's in there?


