### Download model from HuggingFace

In [None]:
from huggingface_hub import snapshot_download
# Hugging Face repository that hosts the fine-tuned checkpoint.
model = "blko/whisper-base-sl-artur-full-ft"
# Pinned commit: "Training in progress, step 32000" (optimal model).
revision = "772cbcea0383a8f4359d3bd8457aa63ca881c47b"
# Hub access token forwarded to snapshot_download; None here.
token = None
snapshot_download(
    repo_id=model,
    revision=revision,
    token=token,
    local_dir="./whisper-base-sl-artur-full-ft-best",
)

### Convert model to CTranslate2 format

In [None]:
!pip install ctranslate2 OpenNMT-py==2.* sentencepiece

In [None]:
!ct2-transformers-converter --model ./whisper-base-sl-artur-full-ft-best --output_dir ./ct2_base_artur_best --copy_files tokenizer_config.json preprocessor_config.json

### Convert model to ONNX with KV caching  
(massive help from https://cprohm.de/blog/whisper-full/ and https://cprohm.de/blog/whisper-full/convert.py)

In [None]:
!pip install openai-whisper==20230117   # Important to use this older version

In [None]:
# Convert the Hugging Face checkpoint to the OpenAI PyTorch format using "convert_hf_to_openai.py" (source: TODO - link missing)
!python convert_hf_to_openai.py --checkpoint ./whisper-base-sl-artur-full-ft-best --whisper_dump_path ./whisper-base-best-openai.pt

In [None]:
import torch

import whisper
from whisper.model import MultiHeadAttention


def export(
    checkpoint="./whisper-base-best-openai.pt",
    encoder_path="./kv_onnx_encoder.onnx",
    decoder_path="./kv_onnx_decoder.onnx",
    opset_version=12,
):
    """Export the Whisper encoder and a KV-cached decoder to ONNX.

    The checkpoint is loaded on CPU, its decoder attention layers are
    patched so the key/value cache is passed in and out explicitly, and two
    ONNX files are written: one for the encoder and one for the decoder.

    Args:
        checkpoint: Path handed to ``whisper.load_model``.
        encoder_path: Destination file for the encoder ONNX graph.
        decoder_path: Destination file for the decoder ONNX graph.
        opset_version: ONNX opset used for both exports.

    All defaults reproduce the previously hard-coded values, so a bare
    ``export()`` call behaves as before.
    """
    model = whisper.load_model(checkpoint, device="cpu")
    model.eval()
    patch(model)

    encoder = model.encoder
    decoder = FunctionalDecoder(model.decoder)

    # Sample inputs used only for tracing the graphs.
    x_mel = torch.randn(1, 80, 3000)
    x_tokens = torch.zeros((1, 10), dtype=torch.long)
    x_audio = encoder(x_mel)

    # Empty caches (length 0 along axis 2): one entry per key/value
    # projection collected by FunctionalDecoder.
    cache_self_attn = torch.zeros(
        (len(decoder.keys_self_attn), 1, 0, model.dims.n_text_state),
    )
    cache_cross_attn = torch.zeros(
        (len(decoder.keys_cross_attn), 1, 0, model.dims.n_audio_state),
    )

    print("self attn shape: ", cache_self_attn.shape)
    print("cross attn shape: ", cache_cross_attn.shape)

    torch.onnx.export(
        encoder,
        (x_mel,),
        encoder_path,
        input_names=["mel"],
        output_names=["audio"],
        dynamic_axes={
            # Only the batch axis is dynamic. Axis 1 of the (1, 80, 3000)
            # sample is the 80-wide feature axis, not time, so it must not
            # be exported as a dynamic "time" dimension.
            "mel": {0: "batch"},
            "audio": {0: "batch", 1: "time"},
        },
        opset_version=opset_version,
    )

    torch.onnx.export(
        decoder,
        (x_tokens, x_audio, cache_self_attn, cache_cross_attn),
        decoder_path,
        input_names=["tokens", "audio", "cache_self_attn", "cache_cross_attn"],
        output_names=["logits", "new_cache_self_attn", "new_cache_cross_attn"],
        dynamic_axes={
            # inputs
            "tokens": {0: "batch", 1: "seq"},
            "audio": {0: "batch", 1: "time"},
            # caches are stacked per projection on axis 0, so batch is axis 1
            "cache_self_attn": {1: "batch", 2: "cached_seq"},
            "cache_cross_attn": {1: "batch", 2: "cached_time"},
            # outputs
            "logits": {0: "batch", 1: "seq"},
            "new_cache_self_attn": {1: "batch", 2: "new_cached_seq"},
            "new_cache_cross_attn": {1: "batch", 2: "new_cached_time"},
        },
        opset_version=opset_version,
    )


def patch(model):
    """Switch every decoder attention layer to FunctionalMultiHeadAttention.

    Each decoder block's self-attention and cross-attention modules are
    re-classed in place and given an ``n_ctx`` attribute, which the
    functional attention uses as the maximum cache length.
    """
    for blk in model.decoder.blocks:
        targets = (
            (blk.attn, model.dims.n_text_ctx),         # n_text_ctx = 448 in general
            (blk.cross_attn, model.dims.n_audio_ctx),  # n_audio_ctx = 1500
        )
        for attention, max_len in targets:
            attention.__class__ = FunctionalMultiHeadAttention
            attention.n_ctx = max_len


class FunctionalDecoder(torch.nn.Module):
    """Wrap a decoder so its key/value cache is passed as tensors.

    The wrapped decoder takes a ``kv_cache`` dict keyed by the key/value
    projection modules. This wrapper builds that dict from two stacked
    cache tensors and returns the updated caches as stacked tensors.
    """

    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder

        # Projection modules in block order: key then value for each block.
        self_projections = []
        cross_projections = []
        for blk in decoder.blocks:
            self_projections.extend((blk.attn.key, blk.attn.value))
            cross_projections.extend((blk.cross_attn.key, blk.cross_attn.value))

        self.keys_self_attn = self_projections
        self.keys_cross_attn = cross_projections

    def forward(self, x, xa, cache_self_attn, cache_cross_attn):
        """Run the decoder and return ``(logits, new_self_cache, new_cross_cache)``.

        ``cache_self_attn`` and ``cache_cross_attn`` are iterated along
        their first axis, giving one slice per projection module in the
        order recorded in ``__init__``.
        """
        # Map each key()/value() projection module to its cache slice.
        kv_cache = dict(zip(self.keys_self_attn, cache_self_attn))
        kv_cache.update(zip(self.keys_cross_attn, cache_cross_attn))

        logits = self.decoder(x, xa, kv_cache=kv_cache)

        # Read the dict entries back after the decoder call and restack them.
        new_self_cache = torch.stack([kv_cache[proj] for proj in self.keys_self_attn])
        new_cross_cache = torch.stack([kv_cache[proj] for proj in self.keys_cross_attn])
        return logits, new_self_cache, new_cross_cache

class FunctionalMultiHeadAttention(MultiHeadAttention):
    """Attention layer that appends to the KV cache on every call.

    Instances are expected to carry an ``n_ctx`` attribute (assigned by
    ``patch`` in this file) that caps how many cached positions are kept.
    """

    def forward(self, x, xa=None, mask=None, kv_cache=None):
        """Attend from ``x`` over the (possibly cached) keys and values."""
        keys, values = self._get_kv(x, xa, kv_cache)

        # NOTE(review): unpacking two values assumes the inherited
        # qkv_attention returns a pair - confirm against the installed
        # openai-whisper version.
        attended, qk = self.qkv_attention(self.query(x), keys, values, mask)
        return self.out(attended), qk

    def _get_kv(self, x, xa=None, kv_cache=None):
        """Return keys and values, extending ``kv_cache`` in place if given.

        Keys/values are projected from ``xa`` when it is provided and from
        ``x`` otherwise. Without a cache the fresh projections are returned
        directly.
        """
        source = xa if xa is not None else x
        assert source is not None

        if kv_cache is None:
            return self.key(source), self.value(source)

        # For the key projection, then the value projection: append the new
        # (detached) projection to the cached tensor along axis 1 and keep
        # at most the last self.n_ctx positions (this may be 1500 for
        # n_audio_ctx or <= 448 for n_text_ctx).
        for projection in (self.key, self.value):
            extended = torch.concat(
                [kv_cache[projection], projection(source).detach()],
                dim=1,
            )
            kv_cache[projection] = extended[:, -self.n_ctx :, :]

        return kv_cache[self.key], kv_cache[self.value]

# Run the conversion when this cell executes.
export()

# At this point, we have kv_onnx_encoder.onnx and kv_onnx_decoder.onnx models for Whisper inference