## The code below converts our fine-tuned Whisper PyTorch model to ONNX format with KV caching, using a 10 s context window (instead of 30 s) for faster inference, e.g. on Android devices

In [None]:
# Environment setup -- run once by hand; every command here is intentionally commented out.
# First of all, create a conda environment with Python 3.10:
#   conda create -n whisper2tflite python=3.10

#!sudo apt update && sudo apt install ffmpeg

# Source: https://colab.research.google.com/github/usefulsensors/openai-whisper/blob/main/notebooks/whisper_encoder_decoder_tflite.ipynb
# NOTE(review): prefer `%pip install ...` over `!pip install ...` so packages go into the
# kernel's own environment, and pin versions. openai-whisper is presumably pinned because the
# export cell subclasses its internal MultiHeadAttention and relies on its decoder's kv_cache handling.
# NOTE(review): later cells also import torch, librosa, datasets and huggingface_hub,
# which are not installed by the lines below.
#!pip install onnx
#!pip install onnxruntime
#!pip install transformers
#!pip install openai-whisper==20230117

### Download the model from Hugging Face

In [None]:
from huggingface_hub import snapshot_download
# Hub repository holding the fine-tuned Slovenian whisper-base model.
model = "blko/whisper-base-sl-artur-full-ft"
# Pinned commit "Training in progress, step 32000" -- the best (optimal) checkpoint.
revision = "772cbcea0383a8f4359d3bd8457aa63ca881c47b"
# Keep None for a public repo; never paste a real access token into the notebook.
token = None

snapshot_download(
    repo_id=model,
    revision=revision,
    token=token,
    local_dir="./whisper-base-sl-artur-full-ft-best",
)

In [4]:
# REFERENCE CODE FOR CONTEXT REDUCTION: https://github.com/sanchit-gandhi/codesnippets/blob/main/whisper-reduce-context.ipynb

from transformers import WhisperForConditionalGeneration, WhisperConfig, WhisperProcessor, WhisperFeatureExtractor, WhisperTokenizer
from datasets import load_dataset
import librosa
import torch

device = "cpu"
model_name = "./whisper-base-sl-artur-full-ft-best"

# Number of encoder positions kept for a 10 s window: 10 s -> 1000 log-mel frames,
# halved by the encoder's conv stem -> 500 (the original checkpoint covers 30 s).
max_source_positions = 500

model = WhisperForConditionalGeneration.from_pretrained(model_name)
state_dict = model.state_dict()
# Keep only the first `max_source_positions` rows of the encoder positional embedding.
state_dict["model.encoder.embed_positions.weight"] = state_dict["model.encoder.embed_positions.weight"][:max_source_positions, :]

# now load these weights back into the Whisper model, this time configured for this new seq len
config = WhisperConfig.from_pretrained(model_name, max_source_positions=max_source_positions)
model = WhisperForConditionalGeneration(config)

model.load_state_dict(state_dict)
# A model built with the constructor starts in training mode (from_pretrained would have
# called eval()); switch to inference mode because this object is used for generation below.
model.eval()
model.save_pretrained("./whisper-base-sl-artur-full-ft-best-ctx10s")

Non-default generation parameters: {'max_length': 448, 'begin_suppress_tokens': [220, 50257]}


### Quickly test the model with the 10 s context window on a sample recording

In [5]:
from transformers import WhisperForConditionalGeneration, WhisperConfig, WhisperProcessor, WhisperFeatureExtractor, WhisperTokenizer
from datasets import load_dataset
import librosa
import torch

# Load the tokenizer saved with the fine-tuned checkpoint.
tokenizer = WhisperTokenizer.from_pretrained(model_name)
# Default feature extractor, but with a 10 s input window (chunk_length, in seconds),
# matching the 500-position encoder built above.
feature_extractor = WhisperFeatureExtractor(chunk_length=10)
# Combine both into a processor.
processor = WhisperProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)

# Check that the model works on a given sample.
audio_path = "./test.wav"
sampling_rate = 16000
audio, sr = librosa.load(audio_path, sr=sampling_rate)
# Clips longer than the 10 s window are cut to 9.5 s.
# NOTE(review): the ONNX inference cell trims to 9 s instead; keep them in sync when comparing outputs.
if audio.shape[0] > sampling_rate * 10:
    audio = audio[0:int(sampling_rate * 9.5)]

# `model` was created with the constructor (not from_pretrained), so it may still be in
# training mode; make sure dropout is disabled before generating.
model.eval()
# Passing sampling_rate lets the feature extractor validate the audio rate (instead of
# warning that it cannot); return_tensors="pt" yields a torch tensor directly.
input_features = processor.feature_extractor(
    audio, sampling_rate=sampling_rate, return_tensors="pt"
)["input_features"].to(device)
pred_ids = model.generate(input_features, max_new_tokens=128)
pred_text = processor.batch_decode(pred_ids, skip_special_tokens=True)
print(pred_text)

It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


['Ampak samo to še zdaleč ni dovolj, če se ozremo po domačem osnovnčju te pogoje konec koncev izpolnjujeta tudi Venera in Mars, pa zato ... ']


In [6]:
# Convert the reduced-context HF checkpoint into an openai-whisper .pt file for the export below.
# NOTE(review): convert_hf_to_openai.py is presumably a local helper script; it is not part of this notebook.
!python convert_hf_to_openai.py --checkpoint ./whisper-base-sl-artur-full-ft-best-ctx10s --whisper_dump_path ./whisper-base-best-openai-ctx10s.pt

HF model path: ./whisper-base-sl-artur-full-ft-best-ctx10s
OpenAI model path: ./whisper-base-best-openai-ctx10s.pt


## Create a model that supports KV caching and export it to ONNX
(adapted with massive help from https://cprohm.de/blog/whisper-full/ and https://cprohm.de/blog/whisper-full/convert.py)

In [7]:
import torch

import whisper
from whisper.model import MultiHeadAttention


def export(
    checkpoint="./whisper-base-best-openai-ctx10s.pt",
    encoder_path="./ctx10encoder.onnx",
    decoder_path="./ctx10decoder.onnx",
    opset_version=12,
):
    """Export the Whisper encoder and a KV-cached decoder as two ONNX graphs.

    Loads an openai-whisper checkpoint on CPU, swaps the decoder attention layers to
    ``FunctionalMultiHeadAttention`` (via ``patch``), and wraps the decoder in
    ``FunctionalDecoder`` so the KV cache is passed in and returned as tensors.

    Exported graphs:
      * encoder: ``mel`` (batch, n_mels, n_frames) -> ``audio``
      * decoder: (``tokens``, ``audio``, ``cache_self_attn``, ``cache_cross_attn``)
        -> (``logits``, ``new_cache_self_attn``, ``new_cache_cross_attn``)

    Args:
        checkpoint: Path to the openai-whisper ``.pt`` checkpoint.
        encoder_path: Output path of the encoder ONNX file.
        decoder_path: Output path of the decoder ONNX file.
        opset_version: ONNX opset used for both graphs.
    """
    model = whisper.load_model(checkpoint, device="cpu")
    model.eval()
    patch(model)

    encoder = model.encoder
    decoder = FunctionalDecoder(model.decoder)

    # The encoder accepts exactly 2 * n_audio_ctx mel frames (its positional embedding has
    # n_audio_ctx rows and the conv stem halves the time axis): 1000 frames for the 10 s
    # checkpoint. Derive the dummy shape from the model instead of hard-coding 80 x 1000.
    x_mel = torch.randn(1, model.dims.n_mels, 2 * model.dims.n_audio_ctx)
    x_tokens = torch.zeros((1, 10), dtype=torch.long)
    with torch.no_grad():
        x_audio = encoder(x_mel)

    # Empty caches (length-0 sequence axis); axis 0 holds one key and one value entry
    # per decoder block, in the order collected by FunctionalDecoder.
    cache_self_attn = torch.zeros(
        (len(decoder.keys_self_attn), 1, 0, model.dims.n_text_state),
    )
    cache_cross_attn = torch.zeros(
        (len(decoder.keys_cross_attn), 1, 0, model.dims.n_audio_state),
    )

    print("self attn shape: ", cache_self_attn.shape)
    print("cross attn shape: ", cache_cross_attn.shape)

    torch.onnx.export(
        encoder,
        (x_mel,),
        encoder_path,
        input_names=["mel"],
        output_names=["audio"],
        dynamic_axes={
            # mel is (batch, n_mels, n_frames): axis 1 is the fixed mel-bin axis, not time,
            # and n_frames cannot vary either, so only the batch axis is dynamic.
            "mel": {0: "batch"},
            "audio": {0: "batch", 1: "time"},
        },
        opset_version=opset_version,
    )

    torch.onnx.export(
        decoder,
        (x_tokens, x_audio, cache_self_attn, cache_cross_attn),
        decoder_path,
        input_names=["tokens", "audio", "cache_self_attn", "cache_cross_attn"],
        output_names=["logits", "new_cache_self_attn", "new_cache_cross_attn"],
        dynamic_axes={
            # inputs
            "tokens": {0: "batch", 1: "seq"},
            "audio": {0: "batch", 1: "time"},
            "cache_self_attn": {1: "batch", 2: "cached_seq"},
            "cache_cross_attn": {1: "batch", 2: "cached_time"},
            # outputs
            "logits": {0: "batch", 1: "seq"},
            "new_cache_self_attn": {1: "batch", 2: "new_cached_seq"},
            "new_cache_cross_attn": {1: "batch", 2: "new_cached_time"},
        },
        opset_version=opset_version,
    )


def patch(model):
    """Turn every decoder attention layer into a FunctionalMultiHeadAttention, in place.

    Self-attention layers may cache up to ``n_text_ctx`` positions and cross-attention
    layers up to ``n_audio_ctx``; the limit is stored on each layer as ``n_ctx``.
    """
    dims = model.dims
    for layer in model.decoder.blocks:
        for attention, max_ctx in (
            (layer.attn, dims.n_text_ctx),
            (layer.cross_attn, dims.n_audio_ctx),
        ):
            attention.__class__ = FunctionalMultiHeadAttention
            attention.n_ctx = max_ctx


class FunctionalDecoder(torch.nn.Module):
    """Wrap a whisper text decoder so its KV cache flows in and out as tensors.

    ``keys_self_attn`` / ``keys_cross_attn`` list each block's (key, value) projection
    modules in block order; row ``i`` of the stacked cache tensors belongs to entry ``i``.
    """

    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder
        self.keys_self_attn = [
            proj
            for blk in decoder.blocks
            for proj in (blk.attn.key, blk.attn.value)
        ]
        self.keys_cross_attn = [
            proj
            for blk in decoder.blocks
            for proj in (blk.cross_attn.key, blk.cross_attn.value)
        ]

    def forward(self, x, xa, cache_self_attn, cache_cross_attn):
        """Run the decoder once; return (logits, new self-attn cache, new cross-attn cache)."""
        # Self-attention entries are inserted first on purpose: whisper's text decoder
        # presumably reads its positional offset from the first kv_cache value, so that
        # value must be a self-attention cache. Verify before reordering.
        kv_cache = dict(zip(self.keys_self_attn, cache_self_attn))
        kv_cache.update(zip(self.keys_cross_attn, cache_cross_attn))

        logits = self.decoder(x, xa, kv_cache=kv_cache)

        new_self = torch.stack([kv_cache[proj] for proj in self.keys_self_attn])
        new_cross = torch.stack([kv_cache[proj] for proj in self.keys_cross_attn])
        return logits, new_self, new_cross


class FunctionalMultiHeadAttention(MultiHeadAttention):
    """whisper ``MultiHeadAttention`` whose KV cache is an explicit, caller-owned dict.

    The dict maps this layer's ``key`` / ``value`` projection modules to rank-3 tensors
    whose sequence axis is dim 1. ``self.n_ctx`` is the maximum number of cached
    positions; ``patch`` assigns it.
    """

    def forward(self, x, xa=None, mask=None, kv_cache=None):
        """Attend from ``x`` to itself (``xa is None``) or to ``xa``; return (output, qk)."""
        keys, values = self._get_kv(x, xa, kv_cache)
        queries = self.query(x)
        attended, qk = self.qkv_attention(queries, keys, values, mask)
        return self.out(attended), qk

    def _get_kv(self, x, xa=None, kv_cache=None):
        """Return (key, value) tensors, updating ``kv_cache`` in place when one is given.

        Without a cache, the projections of the source (``xa`` if given, else ``x``) are
        returned directly. With a cache, the new projections are appended along dim 1,
        only the last ``self.n_ctx`` positions are kept, and the result is stored back
        under the projection module. A zero-length source thus returns the cache as is.
        """
        source = xa if xa is not None else x
        assert source is not None

        if kv_cache is None:
            return self.key(source), self.value(source)

        for proj in (self.key, self.value):
            # detach() keeps autograd history out of the cached tensors
            fresh = proj(source).detach()
            grown = torch.concat([kv_cache[proj], fresh], dim=1)
            kv_cache[proj] = grown[:, -self.n_ctx :, :]

        return kv_cache[self.key], kv_cache[self.value]

# Run the export. patch, FunctionalDecoder and FunctionalMultiHeadAttention are defined
# earlier in this cell, so this call has to come after them.
export()

self attn shape:  torch.Size([12, 1, 0, 512])
cross attn shape:  torch.Size([12, 1, 0, 512])


  assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
  **dict(zip(self.keys_self_attn, cache_self_attn)),
  **dict(zip(self.keys_cross_attn, cache_cross_attn)),


## Sample ONNX inference using the base Whisper model with KV caching and a 10 s context window

In [8]:
import onnx
import whisper.tokenizer
import onnxruntime as ort
import numpy
import whisper
import numpy as np
import time
import librosa

# Enable all ONNX Runtime graph optimisations for both sessions.
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

# Pin the CPU execution provider explicitly. ORT builds that ship additional providers
# (e.g. CUDA) refuse to create a session without a `providers` list (ORT >= 1.9), and the
# notebook targets on-device CPU inference anyway.
providers = ["CPUExecutionProvider"]
sess_encoder = ort.InferenceSession("./ctx10encoder.onnx", sess_options, providers=providers)
sess_decoder = ort.InferenceSession("./ctx10decoder.onnx", sess_options, providers=providers)

# prepare input data: 16 kHz mono audio, long clips trimmed, padded to exactly 10 s
audio_path = "./test.wav"

audio, sr = librosa.load(audio_path, sr=16000)
if len(audio) > 16000 * 10:
    # trim to 9 seconds -- helps prevent some hallucinations
    audio = audio[:int(16000 * 9)]

audio = whisper.audio.pad_or_trim(audio, 16000 * 10)
mel_from_file = whisper.audio.log_mel_spectrogram(audio)
# add a leading batch axis -> (1, n_mels, n_frames)
input_data = np.asarray(mel_from_file)[np.newaxis, ...]

# Decoder prompt (multilingual vocabulary IDs): <|startoftranscript|>, presumably <|sl|>
# (Slovenian -- verify with whisper.tokenizer), and <|transcribe|>.
orig_input_id_list = [50258, 50305, 50359]

# run encoder inference; `xa` is the list of requested outputs, so xa[0] is "audio"
encoder_feeds = {"mel": input_data}
t0 = time.time_ns()
xa = sess_encoder.run(["audio"], encoder_feeds)
t1 = time.time_ns()
encoder_ms = (t1 - t0) / 1e6
print("Encoder inference time: {} ms".format(encoder_ms))

# run first decoder inference: full prompt, empty KV caches.
# Cache layout is (2 * n_blocks, batch, cached_len, n_state). Only axes 1 and 2 were
# exported as dynamic, so axes 0 and 3 are fixed in the graph; read them from the session
# instead of hard-coding whisper-base's (12, ..., 512).
decoder_input_shapes = {node.name: node.shape for node in sess_decoder.get_inputs()}
self_shape = decoder_input_shapes["cache_self_attn"]
cross_shape = decoder_input_shapes["cache_cross_attn"]
shape_empty_cache_self_attn = (self_shape[0], 1, 0, self_shape[3])
shape_empty_cache_cross_attn = (cross_shape[0], 1, 0, cross_shape[3])

cache_self_attn = np.zeros(shape_empty_cache_self_attn, dtype="float32")
cache_cross_attn = np.zeros(shape_empty_cache_cross_attn, dtype="float32")

# (1, prompt_len) int64 token batch and the encoder output
x_tokens = np.expand_dims(np.array(orig_input_id_list, dtype=np.int64), 0)
x_audio = xa[0]

t0 = time.time_ns()
logits, cache_self_attn, cache_cross_attn = sess_decoder.run(
    ["logits", "new_cache_self_attn", "new_cache_cross_attn"],
    {
        "tokens": x_tokens,
        "audio": x_audio,
        "cache_self_attn": cache_self_attn,
        "cache_cross_attn": cache_cross_attn,
    },
)
t1 = time.time_ns()
print("First inference time: {} ms".format((t1-t0)/1e6))

# Greedy decoding: feed back one token per step and reuse the KV caches.
# Copy the prompt -- plain assignment would alias orig_input_id_list, and every append
# below would then silently mutate the prompt list as well.
output_tokens = list(orig_input_id_list)

EOT_TOKEN = 50257       # <|endoftext|> in the multilingual vocabulary: stop when produced
MAX_DECODE_STEPS = 75   # cap on decoder calls after the first one

# run next decoder inferences
# int() turns argmax's numpy scalar into a plain Python int for comparison and decoding.
last_token = int(logits[0, -1, :].argmax())
input_token_tensor = np.array([[last_token]], dtype=np.int64)
output_tokens.append(last_token)

count = 0
while last_token != EOT_TOKEN and count < MAX_DECODE_STEPS:
    count += 1
    t0 = time.time_ns()
    logits, cache_self_attn, cache_cross_attn = sess_decoder.run(
        ["logits", "new_cache_self_attn", "new_cache_cross_attn"],
        {
            "tokens": input_token_tensor,
            # Zero-length audio: the cross-attention K/V already live in cache_cross_attn,
            # so nothing new is projected and that cache stays unchanged.
            "audio": x_audio[:, :0, :],
            "cache_self_attn": cache_self_attn,
            "cache_cross_attn": cache_cross_attn
        }
    )
    t1 = time.time_ns()
    print("Subsequent inference time: {} ms".format((t1-t0)/1e6))
    print("logits: ", logits.shape)
    print("self: ", cache_self_attn.shape)
    print("cross: ", cache_cross_attn.shape)

    last_token = int(logits[0, -1, :].argmax())
    input_token_tensor = np.array([[last_token]], dtype=np.int64)
    output_tokens.append(last_token)

tok = whisper.tokenizer.get_tokenizer(multilingual=True)
print(tok.decode_with_timestamps(output_tokens))


Encoder inference time: 121.293919 ms
First inference time: 130.769679 ms
Subsequent inference time: 80.017412 ms
logits:  (1, 1, 51865)
self:  (12, 1, 4, 512)
cross:  (12, 1, 500, 512)
Subsequent inference time: 16.488141 ms
logits:  (1, 1, 51865)
self:  (12, 1, 5, 512)
cross:  (12, 1, 500, 512)
Subsequent inference time: 17.271758 ms
logits:  (1, 1, 51865)
self:  (12, 1, 6, 512)
cross:  (12, 1, 500, 512)
Subsequent inference time: 15.634089 ms
logits:  (1, 1, 51865)
self:  (12, 1, 7, 512)
cross:  (12, 1, 500, 512)
Subsequent inference time: 21.0186 ms
logits:  (1, 1, 51865)
self:  (12, 1, 8, 512)
cross:  (12, 1, 500, 512)
Subsequent inference time: 15.632377 ms
logits:  (1, 1, 51865)
self:  (12, 1, 9, 512)
cross:  (12, 1, 500, 512)
Subsequent inference time: 15.704873 ms
logits:  (1, 1, 51865)
self:  (12, 1, 10, 512)
cross:  (12, 1, 500, 512)
Subsequent inference time: 15.053787 ms
logits:  (1, 1, 51865)
self:  (12, 1, 11, 512)
cross:  (12, 1, 500, 512)
Subsequent inference time: 15.