### Test CTranslate2 model

In [None]:
import numpy as np
import ctranslate2
from ctranslate2.models import Whisper
import whisper
import time

# Smoke test: transcribe one WAV file with a CTranslate2-converted Whisper
# model and report decode latency, decoded text and sequence scores.
tok = whisper.tokenizer.get_tokenizer(multilingual=True)
m = Whisper(model_path="./ct2_base_artur_best")

audio_path = "./mnt/d/Music/asr/test_laptop_mic-loud.wav"    # 16kHz sample rate
mel = whisper.audio.log_mel_spectrogram(audio_path)
# Pad/trim to the fixed chunk length (whisper.audio.N_FRAMES), matching the
# preprocessing used for the ONNX test, so clips of any duration produce a
# feature tensor of the size the model was trained on.
mel = whisper.audio.pad_or_trim(mel, whisper.audio.N_FRAMES)
# Add the batch axis; StorageView.from_array needs a contiguous array.
mel_from_file = np.ascontiguousarray(mel.unsqueeze(dim=0).numpy())
print(mel_from_file.shape)
features = ctranslate2.StorageView.from_array(mel_from_file)

# Full prompt with timestamps disabled would be:
# [50258, 50305, 50359, 50363]  # == <|startoftranscript|><|sl|><|transcribe|><|notimestamps|>
# The prompt below omits <|notimestamps|>.
correct_prompt = [50258, 50305, 50359]

# perf_counter_ns is monotonic, so the measured interval cannot be skewed by
# system clock adjustments.
t1 = time.perf_counter_ns()
out = m.generate(features=features, prompts=[correct_prompt], beam_size=1, return_scores=True)
t2 = time.perf_counter_ns()

# Debug: print(out[0].sequences_ids[0])
print("Runtime: {:.2f}s".format((t2 - t1) / 1e9), tok.decode(out[0].sequences_ids[0]))
# Debug: print(m.detect_language(features))
print(out[0].scores)

### Test ONNX models using KV caching

In [None]:
import onnx
import whisper.tokenizer
import onnxruntime as ort
import numpy
import whisper
import numpy as np
import time

# Smoke test: greedy decoding with the ONNX encoder/decoder pair, feeding the
# decoder's returned self-/cross-attention caches back in on every step.

# Token id the loop treats as end-of-text.
EOT_TOKEN = 50257
# Hard cap on the total sequence length so decoding always terminates even if
# the model never emits EOT_TOKEN.
# NOTE(review): 448 is assumed to be the decoder's text context length —
# confirm against the exported model.
MAX_DECODE_TOKENS = 448

sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

sess_encoder = ort.InferenceSession("./kv_onnx_encoder.onnx", sess_options)
sess_decoder = ort.InferenceSession("./kv_onnx_decoder.onnx", sess_options)

# prepare input data
audio_path = "./test_speech.wav"    # 16kHz sample rate
mel_from_file = whisper.audio.log_mel_spectrogram(audio_path)
input_data = whisper.audio.pad_or_trim(mel_from_file, whisper.audio.N_FRAMES)
# Explicit tensor -> ndarray conversion with a leading batch axis.
input_data = input_data.unsqueeze(0).numpy()

orig_input_id_list = [50258, 50305, 50359]

# run encoder inference
xa = sess_encoder.run(
    ["audio"],  # output name
    {"mel": input_data},
)

# run first decoder inference
# Zero-length caches mark "nothing cached yet".
# NOTE(review): (12, 1, 0, 512) presumably encodes layer/batch/time/state
# sizes of the exported model — confirm before using another checkpoint.
shape_empty_cache_self_attn = shape_empty_cache_cross_attn = (12, 1, 0, 512)

cache_self_attn = np.zeros(shape_empty_cache_self_attn, dtype="float32")
cache_cross_attn = np.zeros(shape_empty_cache_cross_attn, dtype="float32")

x_tokens = np.expand_dims(np.array(orig_input_id_list, dtype=np.int64), 0)
x_audio = xa[0]

decoder_outputs = ["logits", "new_cache_self_attn", "new_cache_cross_attn"]

t0 = time.perf_counter_ns()
logits, cache_self_attn, cache_cross_attn = sess_decoder.run(
    decoder_outputs,
    {
        "tokens": x_tokens,
        "audio": x_audio,
        "cache_self_attn": cache_self_attn,
        "cache_cross_attn": cache_cross_attn,
    },
)
t1 = time.perf_counter_ns()
print("First inference time: {} ms".format((t1 - t0) / 1e6))

# Copy the prompt so appending generated tokens does not mutate
# orig_input_id_list.
output_tokens = list(orig_input_id_list)

# run next decoder inferences
# Greedy pick; convert the NumPy scalar to a plain int before storing it.
last_token = int(logits[0, -1, :].argmax())
output_tokens.append(last_token)

# After the first step only a zero-length audio slice is passed; the decoder
# is expected to rely on cache_cross_attn instead.
empty_audio = x_audio[:, :0, :]
#print(x_audio.shape)
#print(empty_audio.shape)

while last_token != EOT_TOKEN and len(output_tokens) < MAX_DECODE_TOKENS:
    # Only the newest token is fed; earlier context lives in the caches.
    input_token_tensor = np.array([[last_token]], dtype=np.int64)

    t0 = time.perf_counter_ns()
    logits, cache_self_attn, cache_cross_attn = sess_decoder.run(
        decoder_outputs,
        {
            "tokens": input_token_tensor,
            "audio": empty_audio,
            "cache_self_attn": cache_self_attn,
            "cache_cross_attn": cache_cross_attn,
        },
    )
    t1 = time.perf_counter_ns()
    print("Subsequent inference time: {} ms".format((t1 - t0) / 1e6))
    print("logits: ", logits.shape)
    print("self: ", cache_self_attn.shape)
    print("cross: ", cache_cross_attn.shape)

    last_token = int(logits[0, -1, :].argmax())
    output_tokens.append(last_token)

if last_token != EOT_TOKEN:
    print("Stopped after {} tokens without reaching end-of-text".format(len(output_tokens)))

tok = whisper.tokenizer.get_tokenizer(multilingual=True)
print(tok.decode_with_timestamps(output_tokens))
