In [1]:
# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
## Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
%load_ext autoreload
%autoreload 2

<img src="http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png" style="width: 90px; float: right;">

# Accelerating HuggingFace Whisper Inference with TensorRT

Whisper is an encoder-decoder model that converts ASR problems into a speech-to-text format. More specifically, it does so by encoding speech in the input stream. This enables a single model to be trained supervised on a wide variety of Language

This notebook shows 3 easy steps to convert a [HuggingFace PyTorch Whisper model](https://huggingface.co/transformers/model_doc/whisper.html) to a TensorRT engine for high-performance inference.

1. [Download HuggingFace whisper model](#1)
1. [Convert to ONNX format](#2)
1. [Convert to TensorRT engine](#3)

## Prerequisite

Follow the instruction at https://github.com/NVIDIA/TensorRT to build the TensorRT-OSS docker container required to run this notebook.

Next, we install some extra dependencies.

In [2]:
# %%capture
# !pip3 install -r ../requirements.txt

**Note:** After this step, you should restart the Jupyter kernel for the change to take effect.

In [3]:
import os
import sys
ROOT_DIR = os.path.abspath("../")
sys.path.append(ROOT_DIR)

import torch
import tensorrt as trt

# huggingface
from transformers import (
    WhisperProcessor, 
    WhisperForConditionalGeneration,
    WhisperTokenizer,
    WhisperConfig
)

<a id="1"></a>

## 1. Download HuggingFace T5 model and Whisper model

First, we download the original HuggingFace PyTorch T5 model from HuggingFace model hubs, together with its associated tokernizer.

The T5 variants that are suported by TensorRT 8 are:  t5-small (60M), t5-base (220M), t5-large (770M), t5-3b(3B), t5-11b(11B)

In [4]:
import torch
from datasets import load_dataset

Whisper_VARIANT = "openai/whisper-tiny"    # choices: openai/whisper-tiny | openai/whisper-base | openai/whisper-small | openai/whisper-medium | openai/whisper-large-v2

processor = WhisperProcessor.from_pretrained(Whisper_VARIANT)
whisper_model = WhisperForConditionalGeneration.from_pretrained(Whisper_VARIANT)
wh_config = WhisperConfig.from_pretrained(Whisper_VARIANT, use_cache = False)

In [5]:
whisper_model.config.vocab_size

51865

In [6]:
# save model locally
pytorch_model_dir = './models/{}/pytorch'.format(Whisper_VARIANT)
!mkdir -p $pytorch_model_dir

whisper_model.save_pretrained(pytorch_model_dir)
print("Pytorch Model saved to {}".format(pytorch_model_dir))

Pytorch Model saved to ./models/openai/whisper-tiny/pytorch


# Encoder output이 다름!!

In [199]:
audio=decode_audio("korean_news.mp4")
duration = audio.shape[0] / 16000
inputs = processor(audio, return_tensors="pt")

In [200]:
encoder_outputs = whisper_model.get_encoder()(inputs['input_features'].cuda())

In [201]:
encoder_outputs.last_hidden_state

tensor([[[ 0.1843,  0.1644,  0.1285,  ..., -0.0349,  0.0105, -0.0849],
         [ 0.9001,  2.4472,  0.7696,  ...,  0.8979, -0.0847,  1.0418],
         [ 0.6817,  2.8782,  1.2798,  ...,  0.5644, -0.6482,  0.9689],
         ...,
         [ 0.7794, -1.3334,  0.7304,  ...,  0.9130,  1.3428,  0.1077],
         [ 1.3128, -1.1221,  0.1789,  ..., -0.3750, -0.0723, -0.6635],
         [ 0.0936, -0.1169, -1.3959,  ..., -0.0278, -0.5646, -0.2211]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>)

In [181]:
whisper_trt_encoder(inputs['input_features'])

tensor([[[ 0.0650,  0.0491,  0.0228,  ...,  0.0256,  0.0469, -0.0124],
         [-0.8297, -1.4162,  0.2438,  ...,  0.8406,  0.1607,  0.2904],
         [-0.6757, -1.3830,  0.4186,  ...,  0.0135, -0.1272,  0.6469],
         ...,
         [ 0.7634, -1.6631,  1.0584,  ..., -0.8335,  0.1044,  0.6690],
         [ 0.6760, -1.7349,  0.5412,  ..., -0.2875, -0.0158,  0.4131],
         [ 0.2625, -0.1128, -1.3304,  ...,  0.2592, -0.4166, -0.0922]]])

In [182]:
decoder_start_token_id = whisper_model._get_decoder_start_token_id(None, 50258)
input_ids  = torch.ones((1, 1), dtype=torch.long, device='cuda') * decoder_start_token_id

In [183]:
whisper_model.get_decoder().max_source_positions

1500

In [184]:
whisper_model.cuda()
hf_out = whisper_model.generate(inputs['input_features'].cuda())

In [185]:
decoder_outputs = whisper_model.model.decoder(input_ids=input_ids.cuda(), encoder_hidden_states=encoder_outputs['last_hidden_state'].cuda())

In [186]:
hf_output = whisper_model.proj_out(decoder_outputs.last_hidden_state)

In [187]:
hf_output

tensor([[[-2.4892, -5.4267,  2.6655,  ...,  0.2883,  1.1339,  1.4177]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>)

In [188]:
whisper_trt_decoder.MAX_SOURCE_POSITIONS

384

In [189]:
#dir(whisper_trt_decoder)

In [191]:
trt_output = whisper_trt_decoder(input_ids,encoder_outputs['last_hidden_state'])

In [192]:
trt_output.logits

tensor([[[-2.4938, -5.4309,  2.6605,  ...,  0.2829,  1.1290,  1.4134]]],
       device='cuda:0')

In [388]:

stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(whisper_model.config.max_length)])
logits_processor = LogitsProcessorList(
    [
        NoRepeatNGramLogitsProcessor(
            WhisperModelTRTConfig.NO_REPEAT_NGRAM_SIZE
        ),
        MinLengthLogitsProcessor(
            min_length, WhisperModelTRTConfig.EOS_TOKEN_ID
        ),
        ForcedBOSTokenLogitsProcessor(50258),
        ForcedEOSTokenLogitsProcessor(
            whisper_model.config.max_length, WhisperModelTRTConfig.EOS_TOKEN_ID
        ),
    ]
)

In [194]:
whisper_model(
    input_features=input_features,
    decoder_input_ids=input_ids,
    encoder_outputs=encoder_outputs,
    #stopping_criteria=stopping_criteria,
    #logits_processor=logits_processor,
    use_cache=False
)

Seq2SeqLMOutput(loss=None, logits=tensor([[[-2.4892, -5.4267,  2.6655,  ...,  0.2883,  1.1339,  1.4177]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>), past_key_values=None, decoder_hidden_states=None, decoder_attentions=None, cross_attentions=None, encoder_last_hidden_state=tensor([[[ 0.1843,  0.1644,  0.1285,  ..., -0.0349,  0.0105, -0.0849],
         [ 0.9001,  2.4472,  0.7696,  ...,  0.8979, -0.0847,  1.0418],
         [ 0.6817,  2.8782,  1.2798,  ...,  0.5644, -0.6482,  0.9689],
         ...,
         [ 0.7794, -1.3334,  0.7304,  ...,  0.9130,  1.3428,  0.1077],
         [ 1.3128, -1.1221,  0.1789,  ..., -0.3750, -0.0723, -0.6635],
         [ 0.0936, -0.1169, -1.3959,  ..., -0.0278, -0.5646, -0.2211]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>), encoder_hidden_states=None, encoder_attentions=None)

In [195]:
whisper_trt_decoder

<Whisper.trt.WhisperTRTDecoder at 0x7f69adffe490>

BaseModelOutput(last_hidden_state=tensor([[[ 0.1843,  0.1644,  0.1285,  ..., -0.0349,  0.0105, -0.0849],
         [ 0.9001,  2.4472,  0.7696,  ...,  0.8979, -0.0847,  1.0418],
         [ 0.6817,  2.8782,  1.2798,  ...,  0.5644, -0.6482,  0.9689],
         ...,
         [ 0.7794, -1.3334,  0.7304,  ...,  0.9130,  1.3428,  0.1077],
         [ 1.3128, -1.1221,  0.1789,  ..., -0.3750, -0.0723, -0.6635],
         [ 0.0936, -0.1169, -1.3959,  ..., -0.0278, -0.5646, -0.2211]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>), hidden_states=None, attentions=None)

In [219]:
%%time
#decoder_output = whisper_trt_decoder.greedy_search(
decoder_output = whisper_trt_decoder.greedy_search(
    input_ids=input_ids,
    encoder_hidden_states=encoder_outputs,
    stopping_criteria=stopping_criteria,
    logits_processor=logits_processor,
    use_cache=False,
)

1
2


RuntimeError: The expanded size of the tensor (1) must match the existing size (2) at non-singleton dimension 0.  Target sizes: [1].  Tensor sizes: [2]

In [217]:
input_ids.flatten().shape

torch.Size([1])

In [213]:
inputs

{'input_features': tensor([[[-8.3201e-01, -8.3201e-01, -8.3201e-01,  ..., -2.4743e-01,
          -3.2218e-02, -1.1799e-01],
         [-8.3201e-01, -8.3201e-01, -8.3201e-01,  ...,  8.0496e-04,
          -1.2268e-01, -8.3760e-02],
         [-8.3201e-01, -8.3201e-01, -8.3201e-01,  ...,  2.2572e-02,
          -3.2526e-01, -1.9711e-01],
         ...,
         [-8.3201e-01, -8.3201e-01, -8.3201e-01,  ..., -8.1406e-01,
          -7.2681e-01, -4.5137e-01],
         [-8.3201e-01, -8.3201e-01, -8.3201e-01,  ..., -8.0641e-01,
          -7.7886e-01, -5.0440e-01],
         [-8.3201e-01, -8.3201e-01, -8.3201e-01,  ..., -8.3201e-01,
          -8.3201e-01, -6.4279e-01]]])}

In [202]:
%%time
decoder_output = whisper_model.greedy_search(
    input_ids=input_ids,
    encoder_outputs=encoder_outputs,
    stopping_criteria=stopping_criteria,
    logits_processor=logits_processor,
    use_cache=False,
)

CPU times: user 38.6 ms, sys: 0 ns, total: 38.6 ms
Wall time: 38.3 ms


In [203]:
decoder_output

tensor([[50258, 50258, 50358, 50358, 50363,   440, 50257]], device='cuda:0')

In [204]:
tokenizer.batch_decode(decoder_output)

['<|startoftranscript|><|startoftranscript|><|translate|><|translate|><|notimestamps|> The<|endoftext|>']

In [101]:
processor.batch_decode(hf_out)

["<|startoftranscript|><|en|><|transcribe|><|notimestamps|> The most powerful tap-kana is still being used to strengthen the strong energy. However, the movement speed is almost the same as the slow speed. The tap-kana is going to move in the same direction as the middle of the road. The speed is getting faster, and the next week, the first-ever Kyushu men's shopping mall will be held. The strength is not too weak.<|endoftext|>"]

In [103]:
forced_decoder_ids = WhisperProcessor.get_decoder_prompt_ids(language="english", task="transcribe")

TypeError: get_decoder_prompt_ids() missing 1 required positional argument: 'self'

In [None]:
inputs_tensor.device

In [None]:
batch_size,
decoder_start_token_id=decoder_start_token_id,
model_kwargs=model_kwargs,
device=inputs_tensor.device

In [106]:
whisper_model._prepare_decoder_input_ids_for_generation(1)

tensor([[50258]], device='cuda:0')

In [107]:
tokenizer.decode([50258])

'<|startoftranscript|>'

In [104]:
WhisperProcessor.get_decoder_prompt_ids

<function transformers.models.whisper.processing_whisper.WhisperProcessor.get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True)>

In [None]:
whisper_model.config

### Inference with PyTorch model

Next, we will carry out inference with the PyTorch model.

#### Single example inference

In [30]:
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

audio_inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
input_features = audio_inputs.input_features

# WAR: Using an ugly representation because cuda 11.4 does not support GPU models due to cublas errors
if "LD_LIBRARY_PATH" in os.environ and "cuda-11.4" in os.environ["LD_LIBRARY_PATH"]:
    whisper_model = whisper_model.cpu()
    input_features = input_features.to('cpu')
else:
    whisper_model = whisper_model.cuda()
    input_features = input_features.to('cuda:0')   

Found cached dataset librispeech_asr_dummy (/home/nvadmin/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_dummy/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)


In [31]:
with torch.no_grad():
    generated_ids = whisper_model.generate(inputs=input_features)

transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
transcription
# ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'



' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.'

#### Model inference benchmark: encoder and decoder stacks

For benchmarking purposes, we will employ a helper functions `encoder_inference` and `decoder_inference` which execute the inference repeatedly for the T5 encoder and decoder stacks separately, and measure end to end execution time. Let's take note of this execution time for comparison with TensorRT. 
 
`TimingProfile` is a named tuple that specifies the number of experiments and number of times to call the function per iteration (and number of warm-up calls although it is not used here).

In [32]:
from Whisper.measurements import decoder_inference as w_decoder_inference, encoder_inference as w_encoder_inference, full_inference as w_full_inference, full_inference_greedy, full_inference_beam
from Whisper.export import WhisperEncoderTorchFile, WhisperDecoderTorchFile, WhisperEncoderTRTEngine, WhisperDecoderTRTEngine

from NNDF.networks import TimingProfile
from NNDF.torch_utils import expand_inputs_for_beam_search

In [33]:
whisper_torch_encoder = WhisperEncoderTorchFile.TorchModule(whisper_model.model.encoder)
whisper_torch_decoder = WhisperDecoderTorchFile.TorchModule(
    whisper_model.model.decoder, whisper_model.proj_out, whisper_model.config
)

In [34]:
generated_ids = whisper_model.generate(inputs=audio_inputs.input_features.to('cuda'))

In [35]:
%%time
input_features = audio_inputs.input_features.to('cuda')

encoder_last_hidden_state, encoder_e2e_median_time = w_encoder_inference(
    whisper_torch_encoder, input_features, TimingProfile(iterations=10, number=1, warmup=1, duration=0, percentile=50)
)
encoder_e2e_median_time

CPU times: user 37.4 ms, sys: 273 µs, total: 37.7 ms
Wall time: 37.1 ms


0.0027338999789208174

In [36]:
input_ids = torch.tensor([[1, 1]]) * whisper_model.config.decoder_start_token_id


In [37]:
%%time
_, decoder_e2e_median_time = w_decoder_inference(
    whisper_torch_decoder, input_ids, encoder_last_hidden_state, TimingProfile(iterations=10, number=1, warmup=1, duration=0, percentile=50)
)
decoder_e2e_median_time

CPU times: user 57.1 ms, sys: 1.02 ms, total: 58.1 ms
Wall time: 57.6 ms


0.004369891015812755

#### Full model inference and benchmark

Next, we will try the T5 model for the task of translation from English to German.

For benchmarking purposes, we will employ a helper function `full_inference` which executes the inference repeatedly and measures end to end execution time. Let's take note of this execution time for comparison with TensorRT. 

In [38]:
from Whisper.WhisperModelConfig import WhisperModelTRTConfig, WhisperMetadata

In [39]:
import transformers

In [40]:
num_beams = 1
min_output_len =0 
max_output_len = whisper_model.config.max_length
tokenizer = processor.tokenizer

In [41]:
from NNDF.general_utils import measure_python_inference_code
timing_profile = TimingProfile(iterations=10, number=1, warmup=1, duration=0, percentile=[50,99])

def percentile_print(timing):
    return ', '.join(['p{} {:.2f}ms'.format(timing_profile.percentile[i], p*1000) for i,p in enumerate(timing)])
whisper_model = WhisperForConditionalGeneration.from_pretrained(Whisper_VARIANT).cuda()

# encoder-decoder inference 
with torch.no_grad():
    output_ids = whisper_model.generate(input_features, max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=False)    
    outputs = processor.tokenizer.decode(output_ids[-1,:], skip_special_tokens=True)    
outputs_hf = outputs

# timing
# FP32
whisper_model.float()
hf_nonkv_time = measure_python_inference_code(lambda: whisper_model.generate(input_features, max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=False), timing_profile)
hf_kv_time = measure_python_inference_code(lambda: whisper_model.generate(input_features, max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=True), timing_profile)

# FP16, cuda 11.4 has cublas error that will fail in both cpu or cpu model for Whisper
# if not cuda_114_mode:
whisper_model= whisper_model.half()
hf_nonkv_time_fp16 = measure_python_inference_code(lambda: whisper_model.generate(input_features.half(), max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=False), timing_profile)
hf_kv_time_fp16 = measure_python_inference_code(lambda: whisper_model.generate(input_features.half(), max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=True), timing_profile)

In [42]:
# FP32
HF_KV=True
timing_profile = TimingProfile(iterations=10, number=1, warmup=1, duration=0, percentile=[50,99])
whisper_model.float()
whisper_torch_encoder = WhisperEncoderTorchFile.TorchModule(whisper_model.get_encoder())
whisper_torch_decoder = WhisperDecoderTorchFile.TorchModule(whisper_model.get_decoder(), whisper_model.proj_out, whisper_model.config)

with torch.no_grad():

    encoder_last_hidden_state, encoder_pytorch_time = w_encoder_inference(whisper_torch_encoder, input_features, timing_profile)
    _, decoder_pytorch_time = w_decoder_inference(whisper_torch_decoder, expand_inputs_for_beam_search(input_ids, num_beams) if num_beams > 1 else input_ids, expand_inputs_for_beam_search(encoder_last_hidden_state, num_beams) if num_beams > 1 else encoder_last_hidden_state, timing_profile, use_cache=HF_KV)
    if num_beams == 1:
        output_ids, full_pytorch_time = full_inference_greedy(whisper_torch_encoder,whisper_torch_decoder,input_features,tokenizer,timing_profile,max_length=max_output_len, min_length=min_output_len, use_cache=HF_KV)
    else:
        output_ids, full_pytorch_time = full_inference_beam(whisper_torch_encoder,whisper_torch_decoder,input_features,tokenizer,timing_profile,num_beams=num_beams,max_length=max_output_len, min_length=min_output_len, use_cache=HF_KV)
    outputs = tokenizer.decode(output_ids[0], skip_special_tokens=True)    

outputs_pytorch = outputs

# # FP16
# if not cuda_114_mode:
whisper_model.half()
input_features= input_features.half()
whisper_torch_encoder_fp16 = WhisperEncoderTorchFile.TorchModule(whisper_model.get_encoder())
whisper_torch_decoder_fp16 = WhisperDecoderTorchFile.TorchModule(whisper_model.get_decoder(), whisper_model.proj_out, whisper_model.config)

with torch.no_grad():

    encoder_last_hidden_state, encoder_pytorch_time_fp16 = w_encoder_inference(whisper_torch_encoder_fp16, input_features, timing_profile)
    _, decoder_pytorch_time_fp16 = w_decoder_inference(whisper_torch_decoder_fp16, expand_inputs_for_beam_search(input_ids, num_beams) if num_beams > 1 else input_ids, expand_inputs_for_beam_search(encoder_last_hidden_state, num_beams) if num_beams > 1 else encoder_last_hidden_state, timing_profile, use_cache=HF_KV)
    if num_beams == 1:
        output_ids_fp16, full_pytorch_time_fp16 = full_inference_greedy(whisper_torch_encoder_fp16,whisper_torch_decoder_fp16,input_features,tokenizer,timing_profile,max_length=max_output_len, min_length=min_output_len, use_cache=HF_KV)
    else:
        output_ids_fp16, full_pytorch_time_fp16 = full_inference_beam(whisper_torch_encoder_fp16,whisper_torch_decoder_fp16,input_features,tokenizer,timing_profile,num_beams=num_beams,max_length=max_output_len, min_length=min_output_len, use_cache=HF_KV)
    outputs_fp16 = tokenizer.decode(output_ids_fp16[0], skip_special_tokens=True)    

outputs_pytorch_fp16 = outputs_fp16

In [43]:
# print
print(f'PyTorch FP32 Output identical to HF results? {outputs_pytorch == outputs_hf}')
print(f'PyTorch FP16 Output identical to HF results? {outputs_pytorch_fp16 == outputs_hf}')
print('\n')      
print(f'Device: {torch.cuda.get_device_name()}')
print(f"Precision: FP32, Number of Beams: {num_beams}")
print(f"Encoder time: {encoder_pytorch_time}")
print(f"Decoder time: {decoder_pytorch_time}")
print(f"Full E2E time: {full_pytorch_time}")
print(f"Precision: FP16, Number of Beams: {num_beams}")
print(f"Encoder time: {encoder_pytorch_time_fp16}")
print(f"Decoder time: {decoder_pytorch_time_fp16}")
print(f"Full E2E time: {full_pytorch_time_fp16}")

PyTorch FP32 Output identical to HF results? False
PyTorch FP16 Output identical to HF results? False


Device: NVIDIA A100-SXM4-80GB
Precision: FP32, Number of Beams: 1
Encoder time: [0.0022491710260510445, 0.0022752098739147186]
Decoder time: [0.0034474600106477737, 0.0037134091835469007]
Full E2E time: [0.020401515997946262, 0.022245587082579732]
Precision: FP16, Number of Beams: 1
Encoder time: [0.0065692279022186995, 0.019419423071667552]
Decoder time: [0.00563500402495265, 0.0058245910331606865]
Full E2E time: [0.015776696149259806, 0.03169610211625695]


In [44]:
output_ids_fp16, full_pytorch_time_fp16 = full_inference_greedy(whisper_torch_encoder_fp16,whisper_torch_decoder_fp16,input_features,tokenizer,timing_profile,max_length=max_output_len, min_length=min_output_len, use_cache=HF_KV)

In [45]:
processor.tokenizer.batch_decode(output_ids_fp16)

['<|endoftext|><|endoftext|>']

In [46]:
whisper_model.float()

WhisperForConditionalGeneration(
  (model): WhisperModel(
    (encoder): WhisperEncoder(
      (conv1): Conv1d(80, 384, kernel_size=(3,), stride=(1,), padding=(1,))
      (conv2): Conv1d(384, 384, kernel_size=(3,), stride=(2,), padding=(1,))
      (embed_positions): Embedding(1500, 384)
      (layers): ModuleList(
        (0): WhisperEncoderLayer(
          (self_attn): WhisperAttention(
            (k_proj): Linear(in_features=384, out_features=384, bias=False)
            (v_proj): Linear(in_features=384, out_features=384, bias=True)
            (q_proj): Linear(in_features=384, out_features=384, bias=True)
            (out_proj): Linear(in_features=384, out_features=384, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (fc2): Linear(in_features=1536, out_features=384, bias=True)
          (final_lay

<a id="2"></a>

## 2. Convert to ONNX

Prior to converting the model to a TensorRT engine, we will first convert the PyTorch model to an intermediate universal format.

ONNX is an open format for machine learning and deep learning models. It allows you to convert deep learning and machine learning models from different frameworks such as TensorFlow, PyTorch, MATLAB, Caffe, and Keras to a single format.

The steps to convert a PyTorch model to TensorRT are as follows:
- Convert the pretrained image segmentation PyTorch model into ONNX.
- Import the ONNX model into TensorRT.
- Apply optimizations and generate an engine.
- Perform inference on the GPU. 

For the Whisper model, we will convert the encoder and decoder seperately.

In [47]:
from NNDF.networks import NetworkMetadata, Precision
TRT_KV = False

wh_onnx_model_path = './models/{}/onnx'.format(Whisper_VARIANT)
!mkdir -p $wh_onnx_model_path

# FP32
whisper_model.float()
metadata = NetworkMetadata(variant=Whisper_VARIANT, precision=Precision(fp16=False), other=WhisperMetadata(kv_cache=TRT_KV))
trt_config = WhisperModelTRTConfig()
metadata_string = trt_config.get_metadata_string(metadata)

wh_encoder_onnx_model_fpath = metadata_string + "-encoder.onnx"
wh_decoder_onnx_model_fpath = metadata_string + "-decoder-with-lm-head.onnx"

# for onnx conversion, ensure model is on CPU and FP32 precision in this step
whisper_torchfile_encoder = WhisperEncoderTorchFile(whisper_model.to('cpu'), metadata)
whisper_torchfile_decoder = WhisperDecoderTorchFile(whisper_model.to('cpu'), metadata)

onnx_whisper_encoder = whisper_torchfile_encoder.as_onnx_model(os.path.join(wh_onnx_model_path, wh_encoder_onnx_model_fpath), force_overwrite=True)
onnx_whisper_decoder = whisper_torchfile_decoder.as_onnx_model(os.path.join(wh_onnx_model_path, wh_decoder_onnx_model_fpath), force_overwrite=True)

# FP16
metadata_fp16 = NetworkMetadata(variant=Whisper_VARIANT, precision=Precision(fp16=True), other=WhisperMetadata(kv_cache=TRT_KV))
trt_config_fp16 = WhisperModelTRTConfig()
metadata_string_fp16 = trt_config.get_metadata_string(metadata_fp16)

wh_encoder_onnx_model_fpath_fp16 = metadata_string_fp16 + "-encoder.onnx"
wh_decoder_onnx_model_fpath_fp16 = metadata_string_fp16 + "-decoder-with-lm-head.onnx"

# for onnx conversion, ensure model is on CPU and FP32 precision in this step
whisper_torchfile_encoder = WhisperEncoderTorchFile(whisper_model.to('cpu'), metadata)
whisper_torchfile_decoder = WhisperDecoderTorchFile(whisper_model.to('cpu'), metadata)

onnx_whisper_encoder_fp16 = whisper_torchfile_encoder.as_onnx_model(os.path.join(wh_onnx_model_path, wh_encoder_onnx_model_fpath_fp16), force_overwrite=True)
onnx_whisper_decoder_fp16 = whisper_torchfile_decoder.as_onnx_model(os.path.join(wh_onnx_model_path, wh_decoder_onnx_model_fpath_fp16), force_overwrite=True)

  if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
  if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
  if input_shape[-1] > 1:
  mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
  if attention_mask.size() != (bsz, 1, tgt_len, src_len):


In [48]:
wh_decoder_onnx_model_fpath_fp16

'Whisper-tiny-fp16-decoder-with-lm-head.onnx'

<a id="3"></a>

## 3. Convert to TensorRT

Now we are ready to parse the ONNX encoder and decoder models and convert them to optimized TensorRT engines.

Since the models contains dynamic input shapes, we can specify a valid input range with a TensorRT optimization profile.

In [49]:
from Whisper.export import WhisperDecoderONNXFile, WhisperEncoderONNXFile
from polygraphy.backend.trt import Profile
from tensorrt import PreviewFeature

In [50]:
wh_tensorrt_model_path = './models/{}/tensorrt'.format(Whisper_VARIANT)
!mkdir -p wh_tensorrt_model_path
# Decoder optimization profiles
batch_size = 1
max_sequence_length = WhisperModelTRTConfig.MAX_SEQUENCE_LENGTH[Whisper_VARIANT]
decoder_profile = Profile()
decoder_profile.add(
    "input_ids",
    min=(batch_size * num_beams, 1),
    opt=(batch_size * num_beams, max_sequence_length // 2),
    max=(batch_size * num_beams, max_sequence_length),
)
decoder_profile.add(
    "encoder_hidden_states",
    min=(batch_size * num_beams, 1, max_sequence_length),
    opt=(batch_size * num_beams, 1500, max_sequence_length),
    max=(batch_size * num_beams, 1500, max_sequence_length),
)

# Encoder optimization profiles
encoder_profile = Profile()
encoder_profile.add(
    "input_features",
    min=(batch_size, 80, 3000),
    opt=(batch_size, 80, 3000),
    max=(batch_size, 80, 3000)
)

disable_preview_dynamic_shapes = False
engine_tag = f"bs{batch_size}"

In [51]:
force_write=True
engine_tag = f"bs{batch_size}"

if num_beams > 1:
    engine_tag += "-beam{}".format(num_beams)

preview_features = [PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
if disable_preview_dynamic_shapes:
    engine_tag += "-noPreviewFasterDynamicShapes"
else:
    preview_features.append(PreviewFeature.FASTER_DYNAMIC_SHAPES_0805)

# FP32
wh_encoder_engine_name = os.path.join(wh_tensorrt_model_path, wh_encoder_onnx_model_fpath) + f"-{engine_tag}.engine".replace(f"-beam{num_beams}", "") # encoder engine not affected by beam search
wh_decoder_engine_name = os.path.join(wh_tensorrt_model_path, wh_decoder_onnx_model_fpath) + f"-{engine_tag}.engine"

if not os.path.exists(wh_encoder_engine_name) or force_write:
    whisper_trt_encoder_engine = WhisperEncoderONNXFile(os.path.join(wh_onnx_model_path, wh_encoder_onnx_model_fpath), metadata).as_trt_engine(
        wh_encoder_engine_name, 
        profiles=[encoder_profile], 
        preview_features=preview_features
    )
else:
    whisper_trt_encoder_engine = WhisperEncoderTRTEngine(wh_encoder_engine_name, metadata)
    
if not os.path.exists(wh_decoder_engine_name) or force_write:
    whisper_trt_decoder_engine = WhisperDecoderONNXFile(os.path.join(wh_onnx_model_path, wh_decoder_onnx_model_fpath), metadata).as_trt_engine(
        wh_decoder_engine_name, 
        profiles=[decoder_profile], 
        preview_features=preview_features
    )
else:
    whisper_trt_decoder_engine = WhisperDecoderTRTEngine(wh_decoder_engine_name, metadata)


[38;5;11m[W] onnx2trt_utils.cpp:377: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.[0m
[38;5;11m[W] It looks like some layers in the network have compute precision set, but precision constraints were not enabled. 
    Precision constraints must be set to 'prefer' or 'obey' for layer compute precision to take effect. 
    Note: Layers and their requested precisions were: {'encoder/layers.0/self_attn_layer_norm/ReduceMean': 'FLOAT', 'encoder/layers.0/self_attn_layer_norm/Pow': 'FLOAT', 'encoder/layers.0/self_attn_layer_norm/ReduceMean_1': 'FLOAT', 'encoder/layers.0/self_attn_layer_norm/Add': 'FLOAT', 'encoder/layers.0/self_attn_layer_norm/Sqrt': 'FLOAT', 'encoder/layers.0/self_attn_layer_norm/Div': 'FLOAT', 'encoder/layers.0/self_attn_layer_norm/Mul': 'FLOAT', 'encoder/layers.0/final_layer_norm/ReduceMean': 'FLOAT', 'encoder/layers.0/final_layer_norm/Pow': 'FLOAT', 'encoder/layers.0/final_layer_n

In [52]:
# FP16
wh_encoder_engine_name_fp16 = os.path.join(wh_tensorrt_model_path, wh_encoder_onnx_model_fpath_fp16) + f"-{engine_tag}.engine".replace(f"-beam{num_beams}", "") # encoder engine not affected by beam search
wh_decoder_engine_name_fp16 = os.path.join(wh_tensorrt_model_path, wh_decoder_onnx_model_fpath_fp16) + f"-{engine_tag}.engine"

if not os.path.exists(wh_encoder_engine_name_fp16) or force_write:
    whisper_trt_encoder_engine_fp16 = WhisperEncoderONNXFile(os.path.join(wh_onnx_model_path, wh_encoder_onnx_model_fpath_fp16), metadata_fp16).as_trt_engine(
        wh_encoder_engine_name_fp16, 
        profiles=[encoder_profile], 
        preview_features=preview_features
    )
else:
    whisper_trt_encoder_engine_fp16 = WhisperEncoderTRTEngine(wh_encoder_engine_name_fp16, metadata_fp16)
    
if not os.path.exists(wh_decoder_engine_name_fp16) or force_write:
    whisper_trt_decoder_engine_fp16 = WhisperDecoderONNXFile(os.path.join(wh_onnx_model_path, wh_decoder_onnx_model_fpath_fp16), metadata_fp16).as_trt_engine(
        wh_decoder_engine_name_fp16, 
        profiles=[decoder_profile], 
        preview_features=preview_features
    )
else:
    whisper_trt_decoder_engine_fp16 = WhisperDecoderTRTEngine(wh_decoder_engine_name_fp16, metadata_fp16)

[38;5;11m[W] It looks like some layers in the network have compute precision set, but precision constraints were not enabled. 
    Precision constraints must be set to 'prefer' or 'obey' for layer compute precision to take effect. 
    Note: Layers and their requested precisions were: {'encoder/layers.0/self_attn_layer_norm/ReduceMean': 'FLOAT', 'encoder/layers.0/self_attn_layer_norm/Pow': 'FLOAT', 'encoder/layers.0/self_attn_layer_norm/ReduceMean_1': 'FLOAT', 'encoder/layers.0/self_attn_layer_norm/Add': 'FLOAT', 'encoder/layers.0/self_attn_layer_norm/Sqrt': 'FLOAT', 'encoder/layers.0/self_attn_layer_norm/Div': 'FLOAT', 'encoder/layers.0/self_attn_layer_norm/Mul': 'FLOAT', 'encoder/layers.0/final_layer_norm/ReduceMean': 'FLOAT', 'encoder/layers.0/final_layer_norm/Pow': 'FLOAT', 'encoder/layers.0/final_layer_norm/ReduceMean_1': 'FLOAT', 'encoder/layers.0/final_layer_norm/Add': 'FLOAT', 'encoder/layers.0/final_layer_norm/Sqrt': 'FLOAT', 'encoder/layers.0/final_layer_norm/Div': 'FLOAT', 

In [53]:
print(wh_encoder_onnx_model_fpath)
print(wh_decoder_onnx_model_fpath)
print(onnx_whisper_encoder)
print(onnx_whisper_decoder)
#onnx_whisper_encoder = whisper_torchfile_encoder.as_onnx_model(os.path.join(wh_onnx_model_path, wh_encoder_onnx_model_fpath), force_overwrite=False)
#onnx_whisper_decoder = whisper_torchfile_decoder.as_onnx_model(os.path.join(wh_onnx_model_path, wh_decoder_onnx_model_fpath), force_overwrite=False)

Whisper-tiny-encoder.onnx
Whisper-tiny-decoder-with-lm-head.onnx
<Whisper.export.WhisperEncoderONNXFile object at 0x7f69adffe0a0>
<Whisper.export.WhisperDecoderONNXFile object at 0x7f6acc312cd0>


# Whisper Tensorrt 

In [54]:
from transformers import AutoConfig
from Whisper.trt import WhisperTRTEncoder, WhisperTRTDecoder, TRTHFRunner

In [55]:
# Initialize TensorRT engines
trt_config = AutoConfig.from_pretrained(Whisper_VARIANT, use_cache = metadata.other.kv_cache)

# FP32
whisper_trt_encoder = WhisperTRTEncoder(whisper_trt_encoder_engine, metadata, trt_config, batch_size=batch_size)
whisper_trt_decoder = WhisperTRTDecoder(whisper_trt_decoder_engine, metadata, trt_config, batch_size=batch_size, num_beams=num_beams)

# FP16
whisper_trt_encoder_fp16 = WhisperTRTEncoder(whisper_trt_encoder_engine_fp16, metadata_fp16, trt_config, batch_size=batch_size)
whisper_trt_decoder_fp16 = WhisperTRTDecoder(whisper_trt_decoder_engine_fp16, metadata_fp16, trt_config, batch_size=batch_size, num_beams=num_beams)

In [56]:
%%time
encoder_last_hidden_state, encoder_trt_time = w_encoder_inference(
    whisper_trt_encoder, input_features, TimingProfile(iterations=10, number=1, warmup=1, duration=0, percentile=[50,99])
)
encoder_e2e_median_time

CPU times: user 35.4 ms, sys: 0 ns, total: 35.4 ms
Wall time: 35.1 ms


0.0027338999789208174

### End-to-End TensorRT Inference

In [392]:
batch_size = 1

In [393]:
max_output_len=384

In [394]:
tokenizer.decode(whisper_model.config.decoder_start_token_id)


'<|startoftranscript|>'

In [395]:
from transformers.generation_logits_process import (
    NoRepeatNGramLogitsProcessor,
    MinLengthLogitsProcessor,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    LogitsProcessorList,
)
from transformers.generation_stopping_criteria import (
    MaxLengthCriteria,
    StoppingCriteriaList,
)
from transformers.generation_beam_search import (
    BeamSearchScorer,
)

stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_output_len)])
no_repeat_ngram_size = WhisperModelTRTConfig.NO_REPEAT_NGRAM_SIZE
min_length = WhisperModelTRTConfig.MIN_OUTPUT_LENGTH[Whisper_VARIANT]
logits_processor = LogitsProcessorList([
    NoRepeatNGramLogitsProcessor(no_repeat_ngram_size),
    MinLengthLogitsProcessor(min_length, tokenizer.convert_tokens_to_ids(tokenizer.eos_token)),
    ForcedBOSTokenLogitsProcessor(50258),
    ForcedEOSTokenLogitsProcessor(max_output_len, tokenizer.convert_tokens_to_ids(tokenizer.eos_token))
]) # by checking HuggingFace's generate() implementation carefully, the default logits processor for Whisper has no_repeat_ngram_size = 3 and forced_eos_token_id = 2. In this way we can ensure identical results with raw HuggingFace

decoder_initial_input = torch.full(
    (batch_size, 1), 50258, dtype=torch.int32
).to('cuda')

if num_beams > 1:
    decoder_initial_input = expand_inputs_for_beam_search(decoder_initial_input, expand_size=num_beams)
    
# FP32
def e2e_trt():
    with torch.no_grad():
        encoder_last_hidden_states = whisper_trt_encoder(input_features=input_features)
        
        if num_beams > 1:
            # prepare input for beam search
            encoder_last_hidden_states = expand_inputs_for_beam_search(encoder_last_hidden_states, expand_size=num_beams)

            # beam scorer must be reset before each beam search run, otherwise beam search will be skipped due to scorer cache
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=num_beams,
                device="cuda",
                do_early_stopping=True,
            )
        
        whisper_trt_decoder.set_encoder_hidden_states_for_inference_cycle(encoder_last_hidden_states)
        
        if num_beams == 1:
            decoder_output = whisper_trt_decoder.greedy_search(
                input_ids=decoder_initial_input,
                encoder_hidden_states=encoder_outputs.last_hidden_state,
                stopping_criteria=stopping_criteria,
                logits_processor=logits_processor,
                use_cache=metadata.other.kv_cache,
                use_cuda=True
            )
        else:
            decoder_output = whisper_trt_decoder.beam_search(
                input_ids=decoder_initial_input,
                beam_scorer=beam_scorer,
                encoder_hidden_states=encoder_last_hidden_states,
                stopping_criteria=stopping_criteria,
                logits_processor=logits_processor,
                use_cache=metadata.other.kv_cache,
                use_cuda=True
            )
    return decoder_output

output_ids = e2e_trt()
outputs_trt = tokenizer.decode(output_ids[0], skip_special_tokens=True)
trt_time = measure_python_inference_code(e2e_trt, timing_profile)

RuntimeError: The expanded size of the tensor (1) must match the existing size (2) at non-singleton dimension 0.  Target sizes: [1].  Tensor sizes: [2]

In [396]:
StoppingCriteriaList.max_length.getter(0)

<property at 0x7f6b6e926450>

In [397]:
import inspect
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch import nn

from transformers.generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint
from transformers.generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from transformers.generation_logits_process import (
    EncoderNoRepeatNGramLogitsProcessor,
    ExponentialDecayLengthPenalty,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    ForceTokensLogitsProcessor,
    HammingDiversityLogitsProcessor,
    InfNanRemoveLogitsProcessor,
    LogitNormalization,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    NoBadWordsLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    PrefixConstrainedLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    SuppressTokensAtBeginLogitsProcessor,
    SuppressTokensLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    TypicalLogitsWarper,
)
from transformers.generation_stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)
from transformers.models.auto import (
    MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
    MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from transformers.pytorch_utils import torch_int_div

max_length = max_output_len
pad_token_id = tokenizer.pad_token_id
eos_token_id = tokenizer.eos_token_id
output_scores = whisper_model.config.output_scores
output_attentions = whisper_model.config.output_attentions
output_hidden_states = encoder_outputs.last_hidden_state
return_dict_in_generate= whisper_model.config.return_dict_in_generate
synced_gpus=False

whisper_trt_decoder.prepare_inputs_for_generation

model_kwargs = {
    "encoder_hidden_states":encoder_outputs.last_hidden_state,
    "stopping_criteria":stopping_criteria,
    "logits_processor":logits_processor,
    "use_cache":metadata.other.kv_cache,
    "use_cuda":True
}

input_ids=decoder_initial_input

In [398]:
tokenizer.decode(decoder_initial_input[0])

'<|startoftranscript|>'

In [399]:
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if max_length is not None:
    warnings.warn(
        "`max_length` is deprecated in this function, use"
        " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
        UserWarning,
    )
    stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
    output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict_in_generate = (
    return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)

# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
    encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
    encoder_hidden_states = (
        model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
    )

# keep track of which sequences are already finished
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

this_peer_finished = False  # used by synced_gpus only
while True:
    if synced_gpus:
        # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
        # The following logic allows an early break if all peers finished generating their sequence
        this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
        # send 0.0 if we finished, 1.0 otherwise
        dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
        # did all peers finish? the reduced sum will be 0.0 then
        if this_peer_finished_flag.item() == 0.0:
            break

    # prepare model inputs
    model_inputs = whisper_trt_decoder.prepare_inputs_for_generation(input_ids, **model_kwargs)

    # forward pass to get next token
    outputs = whisper_trt_decoder(
        **model_inputs,
        return_dict=True,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
    )

    if synced_gpus and this_peer_finished:
        continue  # don't waste resources running the code we don't need

    next_token_logits = outputs.logits[:, -1, :]

    # pre-process distribution
    next_tokens_scores = logits_processor(input_ids, next_token_logits)

    # Store scores, attentions and hidden_states when required
    if return_dict_in_generate:
        if output_scores:
            scores += (next_tokens_scores,)
        if output_attentions:
            decoder_attentions += (
                (outputs.decoder_attentions,) if whisper_trt_decoder.config.is_encoder_decoder else (outputs.attentions,)
            )
            if whisper_trt_decoder.config.is_encoder_decoder:
                cross_attentions += (outputs.cross_attentions,)

        if output_hidden_states:
            decoder_hidden_states += (
                (outputs.decoder_hidden_states,)
                if whisper_trt_decoder.config.is_encoder_decoder
                else (outputs.hidden_states,)
            )
    print(1)

    # argmax
    next_tokens = torch.argmax(next_tokens_scores, dim=-1)

    # finished sentences should have their next token be a padding token
    if eos_token_id is not None:
        if pad_token_id is None:
            raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
        next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
    print(2)

    # update generated ids, model inputs, and length for next step
    input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
    model_kwargs = whisper_trt_decoder._update_model_kwargs_for_generation(
        outputs, model_kwargs, is_encoder_decoder=whisper_trt_decoder.config.is_encoder_decoder
    )
    print(3)

    # if eos_token was found in one sentence, set sentence to finished
    if eos_token_id is not None:
        unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
    print(4)

    # stop when each sentence is finished, or if we exceed the maximum length
    if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
        if not synced_gpus:
            print(synced_gpus)
            print(5)

            break
        else:
            this_peer_finished = True
    #print(5)


1
2
3
4




RuntimeError: The expanded size of the tensor (1) must match the existing size (2) at non-singleton dimension 0.  Target sizes: [1].  Tensor sizes: [2]

In [350]:
whisper_trt_decoder(
    **model_inputs,
    return_dict=True,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
)

Seq2SeqLMOutput(loss=None, logits=tensor([[[15.7015, 15.4863, 17.9366,  ..., 12.0131, 11.3766, 10.9509]]],
       device='cuda:0'), past_key_values=None, decoder_hidden_states=None, decoder_attentions=None, cross_attentions=None, encoder_last_hidden_state=None, encoder_hidden_states=None, encoder_attentions=None)

In [352]:
tokenizer.decode([50257])

'<|endoftext|>'

In [351]:
model_inputs

{'input_ids': tensor([[50257]], device='cuda:0', dtype=torch.int32),
 'encoder_hidden_states': tensor([[[ 0.1843,  0.1644,  0.1285,  ..., -0.0349,  0.0105, -0.0849],
          [ 0.9001,  2.4472,  0.7696,  ...,  0.8979, -0.0847,  1.0418],
          [ 0.6817,  2.8782,  1.2798,  ...,  0.5644, -0.6482,  0.9689],
          ...,
          [ 0.7794, -1.3334,  0.7304,  ...,  0.9130,  1.3428,  0.1077],
          [ 1.3128, -1.1221,  0.1789,  ..., -0.3750, -0.0723, -0.6635],
          [ 0.0936, -0.1169, -1.3959,  ..., -0.0278, -0.5646, -0.2211]]],
        device='cuda:0', grad_fn=<NativeLayerNormBackward0>)}

In [342]:
(next_tokens != eos_token_id).long()

tensor([0], device='cuda:0')

In [343]:
unfinished_sequences

tensor([0], device='cuda:0')

In [326]:
synced_gpus

False

In [318]:
eos_token_id

50257

In [319]:
unfinished_sequences.mul((next_tokens != eos_token_id).long())

tensor([0], device='cuda:0')

In [295]:
GreedySearchEncoderDecoderOutput(
    sequences=input_ids,
    scores=scores,
    encoder_attentions=encoder_attentions,
    encoder_hidden_states=encoder_hidden_states,
    decoder_attentions=decoder_attentions,
    cross_attentions=cross_attentions,
    decoder_hidden_states=decoder_hidden_states,
)

NameError: name 'GreedySearchEncoderDecoderOutput' is not defined

In [300]:
decoder_hidden_states

In [296]:
input_ids

tensor([[50257, 50257]], device='cuda:0')

In [None]:

if return_dict_in_generate:
    if whisper_trt_decoder.config.is_encoder_decoder:
        return GreedySearchEncoderDecoderOutput(
            sequences=input_ids,
            scores=scores,
            encoder_attentions=encoder_attentions,
            encoder_hidden_states=encoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=cross_attentions,
            decoder_hidden_states=decoder_hidden_states,
        )
    else:
        return GreedySearchDecoderOnlyOutput(
            sequences=input_ids,
            scores=scores,
            attentions=decoder_attentions,
            hidden_states=decoder_hidden_states,
        )
else:
    return input_ids

In [229]:
encoder_last_hidden_states

tensor([[[ 0.1843,  0.1644,  0.1285,  ..., -0.0349,  0.0105, -0.0849],
         [ 0.9001,  2.4472,  0.7696,  ...,  0.8979, -0.0847,  1.0418],
         [ 0.6817,  2.8782,  1.2798,  ...,  0.5644, -0.6482,  0.9689],
         ...,
         [ 0.7794, -1.3334,  0.7304,  ...,  0.9130,  1.3428,  0.1077],
         [ 1.3128, -1.1221,  0.1789,  ..., -0.3750, -0.0723, -0.6635],
         [ 0.0936, -0.1169, -1.3959,  ..., -0.0278, -0.5646, -0.2211]]],
       device='cuda:0', grad_fn=<IndexSelectBackward0>)

In [230]:
encoder_last_hidden_states = expand_inputs_for_beam_search(encoder_outputs.last_hidden_state, expand_size=num_beams)

In [231]:
whisper_trt_decoder.set_encoder_hidden_states_for_inference_cycle(encoder_last_hidden_states)

In [232]:
decoder_output = whisper_trt_decoder.greedy_search(
    input_ids=decoder_initial_input,
    encoder_hidden_states=encoder_outputs.last_hidden_state,
    stopping_criteria=stopping_criteria,
    logits_processor=logits_processor,
    use_cache=metadata.other.kv_cache,
    use_cuda=True
)

In [154]:
outputs_trt = tokenizer.decode(output_ids[0], skip_special_tokens=True)

In [155]:
output_ids

tensor([[50257, 50257]], device='cuda:0')

In [157]:
tokenizer.bos_token_id

50257

In [133]:
encoder_outputs.last_hidden_state.shape

torch.Size([1, 1500, 384])

In [134]:
encoder_last_hidden_states.shape

torch.Size([1, 1500, 384])

In [116]:
outputs_trt

''

In [60]:

# FP16
def e2e_trt_fp16():
    with torch.no_grad():
        encoder_last_hidden_states = whisper_trt_encoder_fp16(input_features=input_features)
        
        if num_beams > 1:
            # prepare input for beam search
            encoder_last_hidden_states = expand_inputs_for_beam_search(encoder_last_hidden_states, expand_size=num_beams)
            
            # beam scorer must be reset before each beam search run, otherwise beam search will be skipped due to scorer cache
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=num_beams,
                device="cuda",
                do_early_stopping=True,
            )
        
        whisper_trt_decoder_fp16.set_encoder_hidden_states_for_inference_cycle(encoder_last_hidden_states)
        
        if num_beams == 1:
            decoder_output = whisper_trt_decoder_fp16.greedy_search(
                input_ids=decoder_initial_input,
                encoder_hidden_states=encoder_last_hidden_states,
                stopping_criteria=stopping_criteria,
                logits_processor=logits_processor,
                use_cache=metadata.other.kv_cache,
                use_cuda=True
            )
        else:
            decoder_output = whisper_trt_decoder_fp16.beam_search(
                input_ids=decoder_initial_input,
                beam_scorer=beam_scorer,
                encoder_hidden_states=encoder_last_hidden_states,
                stopping_criteria=stopping_criteria,
                logits_processor=logits_processor,
                use_cache=metadata.other.kv_cache,
                use_cuda=True
            )
    return decoder_output

output_ids_fp16 = e2e_trt_fp16()
outputs_trt_fp16 = tokenizer.decode(output_ids_fp16[0], skip_special_tokens=True)
trt_time_fp16 = measure_python_inference_code(e2e_trt_fp16, timing_profile)

In [115]:
outputs_trt_fp16

''

In [61]:
trt_time_fp16

[0.0152398650534451, 0.021135421004146338]

In [62]:
# print results and timing statistics
print(f'Device: {torch.cuda.get_device_name()}')
print(f"Using engine: {metadata_string + '-' + engine_tag}")   
print(f'Output identical to HF results? {outputs_trt == outputs_hf}')
print(f"Precision: FP32")
print(f'TRT time: {trt_time}')
print()
print(f"Using engine: {metadata_string_fp16 + '-' + engine_tag}")   
print(f'Output identical to HF results? {outputs_trt_fp16 == outputs_hf}')
print(f"Precision: FP16")
print(f'TRT time: {trt_time_fp16}')

Device: NVIDIA A100-SXM4-80GB
Using engine: Whisper-tiny-bs1
Output identical to HF results? False
Precision: FP32
TRT time: [0.011139628943055868, 0.018518408993259072]

Using engine: Whisper-tiny-fp16-bs1
Output identical to HF results? False
Precision: FP16
TRT time: [0.0152398650534451, 0.021135421004146338]


In [63]:
%%time
a, decoder_trt_time = w_decoder_inference(whisper_trt_decoder, expand_inputs_for_beam_search(input_ids, num_beams) if num_beams > 1 else input_ids, expand_inputs_for_beam_search(encoder_last_hidden_state, num_beams) if num_beams > 1 else encoder_last_hidden_state, timing_profile)



CPU times: user 36.3 ms, sys: 88 µs, total: 36.4 ms
Wall time: 36.1 ms


In [64]:
decoder_trt_time

[0.003225164022296667, 0.0032994120847433805]

### Time Measurement of Encoder, Decoder, and Full E2E
We will benchmark the encoder, decoder, and full end-to-end as we did for HuggingFace before.

In [65]:
# encoder-decoder inference 
whisper_model.float()
whisper_model = whisper_model.cuda()

input_features = input_features.float().cuda()

with torch.no_grad():
    output_ids = whisper_model.generate(input_features, max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=False)    
    outputs = tokenizer.decode(output_ids[-1,:], skip_special_tokens=True)    
outputs_hf = outputs

# timing
# FP32
input_features = input_features.float().cuda()
hf_nonkv_time = measure_python_inference_code(lambda: whisper_model.generate(input_features, max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=False), timing_profile)
hf_kv_time = measure_python_inference_code(lambda: whisper_model.generate(input_features, max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=True), timing_profile)

# FP16, cuda 11.4 has cublas error that will fail in both cpu or cpu model for BART
hf_nonkv_time_fp16 = measure_python_inference_code(lambda: whisper_model.generate(input_features, max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=False), timing_profile)
hf_kv_time_fp16 = measure_python_inference_code(lambda: whisper_model.generate(input_features, max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=True), timing_profile)

In [66]:
# FP32
encoder_last_hidden_states, encoder_trt_time = w_encoder_inference(whisper_trt_encoder, input_features, timing_profile)
_, decoder_trt_time = w_decoder_inference(whisper_trt_decoder, expand_inputs_for_beam_search(input_ids, num_beams) if num_beams > 1 else input_ids, expand_inputs_for_beam_search(encoder_last_hidden_states, num_beams) if num_beams > 1 else encoder_last_hidden_states, timing_profile)

if num_beams == 1:
    _, full_trt_time = full_inference_greedy(
        whisper_trt_encoder,
        whisper_trt_decoder,
        input_features,
        tokenizer,
        timing_profile,
        max_length=max_output_len,
        min_length=WhisperModelTRTConfig.MIN_OUTPUT_LENGTH[metadata.variant],
        batch_size=batch_size,
        use_cache=metadata.other.kv_cache,
    )
else:
    _, full_trt_time = full_inference_beam(
        whisper_trt_encoder,
        whisper_trt_decoder,
        input_ids,
        tokenizer,
        timing_profile,
        num_beams=num_beams,
        max_length=max_output_len,
        min_length=WhisperModelTRTConfig.MIN_OUTPUT_LENGTH[metadata.variant],
        batch_size=batch_size,
        use_cache=metadata.other.kv_cache,
        early_stopping=True,
    )
    
print(f'Encoder time: {percentile_print(encoder_trt_time)}')
print(f'Decoder time: {percentile_print(decoder_trt_time)}')
print(f'Full E2E time: {percentile_print(full_trt_time)}')

# FP16
encoder_last_hidden_states, encoder_trt_time_fp16 = w_encoder_inference(whisper_trt_encoder_fp16, input_features, timing_profile)
_, decoder_trt_time_fp16 = w_decoder_inference(whisper_trt_decoder_fp16, expand_inputs_for_beam_search(input_ids, num_beams) if num_beams > 1 else input_ids, expand_inputs_for_beam_search(encoder_last_hidden_states, num_beams) if num_beams > 1 else encoder_last_hidden_states, timing_profile)

if num_beams == 1:
    _, full_trt_time_fp16 = full_inference_greedy(
        whisper_trt_encoder_fp16,
        whisper_trt_decoder_fp16,
        input_features,
        tokenizer,
        timing_profile,
        max_length=max_output_len,
        min_length=WhisperModelTRTConfig.MIN_OUTPUT_LENGTH[metadata.variant],
        batch_size=batch_size,
        use_cache=metadata.other.kv_cache,
    )
else:
    _, full_trt_time_fp16 = full_inference_beam(
        whisper_trt_encoder_fp16,
        whisper_trt_decoder_fp16,
        input_ids,
        tokenizer,
        timing_profile,
        num_beams=num_beams,
        max_length=max_output_len,
        min_length=WhisperModelTRTConfig.MIN_OUTPUT_LENGTH[metadata.variant],
        batch_size=batch_size,
        use_cache=metadata.other.kv_cache,
        early_stopping=True,
    )
print(f'Encoder FP16 time: {percentile_print(encoder_trt_time_fp16)}')
print(f'Decoder FP16 time: {percentile_print(decoder_trt_time_fp16)}')
print(f'Full E2E FP16 time: {percentile_print(full_trt_time_fp16)}')

Encoder time: p50 3.84ms, p99 3.89ms
Decoder time: p50 1.29ms, p99 2.34ms
Full E2E time: p50 10.72ms, p99 18.58ms
Encoder FP16 time: p50 1.48ms, p99 3.80ms
Decoder FP16 time: p50 0.90ms, p99 1.72ms
Full E2E FP16 time: p50 15.43ms, p99 21.18ms


In [67]:
from tabulate import tabulate

data = [
    ['Framework', 'Precision', 'Encoder p50 (ms)', 'Decoder p50 (ms)', 'Full E2E p50 (ms)', 'Accuracy'],
    ['HuggingFace (w/o cache)', 'FP32', '-', '-', f'{hf_nonkv_time[0]*1000:.2f}', '-'],
    ['HuggingFace (w/ cache)', 'FP32', '-', '-', f'{hf_kv_time[0]*1000:.2f}', '-'],
    ['HuggingFace (w/o cache)', 'FP16', '-', '-', f'{hf_nonkv_time_fp16[0]*1000:.2f}', '-'],
    ['HuggingFace (w/ cache)', 'FP16', '-', '-', f'{hf_kv_time_fp16[0]*1000:.2f}', '-'],
    ['PyTorch', 'FP32', f'{encoder_pytorch_time[0]*1000:.2f}', f'{decoder_pytorch_time[0]*1000:.2f}', f'{full_pytorch_time[0]*1000:.2f}', outputs_pytorch == outputs_hf],
    ['PyTorch', 'FP16', f'{encoder_pytorch_time_fp16[0]*1000:.2f}', f'{decoder_pytorch_time_fp16[0]*1000:.2f}', f'{full_pytorch_time_fp16[0]*1000:.2f}', outputs_pytorch_fp16 == outputs_hf],
    ['TensorRT', 'FP32', f'{encoder_trt_time[0]*1000:.2f}', f'{decoder_trt_time[0]*1000:.2f}', f'{full_trt_time[0]*1000:.2f}', outputs_trt == outputs_hf],
    ['TensorRT', 'FP16', f'{encoder_trt_time_fp16[0]*1000:.2f}', f'{decoder_trt_time_fp16[0]*1000:.2f}', f'{full_trt_time_fp16[0]*1000:.2f}', outputs_trt_fp16 == outputs_hf],
]

print(tabulate(data, headers='firstrow', tablefmt='github'))

| Framework               | Precision   | Encoder p50 (ms)   | Decoder p50 (ms)   |   Full E2E p50 (ms) | Accuracy   |
|-------------------------|-------------|--------------------|--------------------|---------------------|------------|
| HuggingFace (w/o cache) | FP32        | -                  | -                  |              184.55 | -          |
| HuggingFace (w/ cache)  | FP32        | -                  | -                  |              142.76 | -          |
| HuggingFace (w/o cache) | FP16        | -                  | -                  |              176.84 | -          |
| HuggingFace (w/ cache)  | FP16        | -                  | -                  |              125.35 | -          |
| PyTorch                 | FP32        | 2.25               | 3.45               |               20.4  | False      |
| PyTorch                 | FP16        | 6.57               | 5.64               |               16.06 | False      |
| TensorRT                | FP32        | 3.84  

In [81]:
outputs_trt

''

In [68]:
%%time
from faster_whisper import WhisperModel

model_size = "tiny"

# Run on GPU with FP16
model = WhisperModel(model_size, device="cuda", compute_type="float16")

# or run on GPU with INT8
# model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
# or run on CPU with INT8
# model = WhisperModel(model_size, device="cpu", compute_type="int8")

segments, info = model.transcribe("korean_news.mp4", beam_size=5)

print("Detected language '%s' with probability %f" % (info.language, info.language_probability))

for segment in segments:
    print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))

Detected language 'ko' with probability 0.998047
[0.00s -> 5.50s]  제 6코 태풍 가능은 여전히 매우 강한 세력을 유지한 채 북서진하고 있습니다.
[5.50s -> 10.00s]  하지만 이동 속도가 점점들여져 거의 정체하는 모습입니다.
[10.00s -> 15.00s]  태풍은 동중국회의 머물나 동쪽으로 방향을 급격히 틀어 이동할 걸로 보입니다.
[15.00s -> 22.00s]  속도도 조금씩 빨라지면 다음 주 초중반에는 일본 교수 남쪽의 상까지 진출하겠습니다.
[22.00s -> 25.00s]  새력도 크게 약화하지는 않을 전망입니다.
[25.00s -> 28.00s]  태양 열령량도 좋은 적건을 보이고 있습니다.
[28.00s -> 36.00s]  따라서 북상하면서 급격히 약화되는 것이 아니라 어느 정도 세력은 계속해서 유지해서 북상할 것으로 예상되고 있습니다.
[36.00s -> 40.00s]  이후 진로는 유동적이지만 크게 두 가지 경로가 유력합니다.
[40.00s -> 46.00s]  갑자기 방향을 북쪽으로 들어 일본 열돌을 관통한 뒤 동일로 북상할 가능성이 있습니다.
[46.00s -> 52.00s]  반대로 일본의 상륙한 뒤 열돌을 따라 계속 동진하는 진로를 택할 수도 있습니다.
[52.00s -> 56.00s]  만일 동일로 올라온다면 다음 주 중후반점이 될 걸로 보입니다.
[56.00s -> 61.00s]  이 경우 우리나라 동쪽 지역이 태풍 직점 영향권에 들게 됩니다.
[61.00s -> 68.00s]  기상청은 주변 기약 개편화에 따라 태풍 진로가 최종 결정될 거라면 태풍에 대한 경계를 당부했습니다.
[68.00s -> 71.00s]  YTN 김민경입니다.
CPU times: user 5.52 s, sys: 4.48 s, total: 9.99 s
Wall time: 4.5 s


In [69]:
import numpy as np

In [70]:
dir(model.model)

['__class__',
 '__delattr__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 'align',
 'detect_language',
 'device',
 'device_index',
 'encode',
 'generate',
 'is_multilingual',
 'num_active_batches',
 'num_queued_batches',
 'num_workers']

In [71]:
import io
import itertools

from typing import BinaryIO, Union

import av
import numpy as np
def decode_audio(
    input_file: Union[str, BinaryIO],
    sampling_rate: int = 16000,
    split_stereo: bool = False,
):
    """Decodes the audio.

    Args:
      input_file: Path to the input file or a file-like object.
      sampling_rate: Resample the audio to this sample rate.
      split_stereo: Return separate left and right channels.

    Returns:
      A float32 Numpy array.

      If `split_stereo` is enabled, the function returns a 2-tuple with the
      separated left and right channels.
    """
    resampler = av.audio.resampler.AudioResampler(
        format="s16",
        layout="mono" if not split_stereo else "stereo",
        rate=sampling_rate,
    )

    raw_buffer = io.BytesIO()
    dtype = None

    with av.open(input_file, metadata_errors="ignore") as container:
        frames = container.decode(audio=0)
        frames = _ignore_invalid_frames(frames)
        frames = _group_frames(frames, 500000)
        frames = _resample_frames(frames, resampler)

        for frame in frames:
            array = frame.to_ndarray()
            dtype = array.dtype
            raw_buffer.write(array)

    audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype)

    # Convert s16 back to f32.
    audio = audio.astype(np.float32) / 32768.0

    if split_stereo:
        left_channel = audio[0::2]
        right_channel = audio[1::2]
        return left_channel, right_channel

    return audio

def _ignore_invalid_frames(frames):
    iterator = iter(frames)

    while True:
        try:
            yield next(iterator)
        except StopIteration:
            break
        except av.error.InvalidDataError:
            continue


def _group_frames(frames, num_samples=None):
    fifo = av.audio.fifo.AudioFifo()

    for frame in frames:
        frame.pts = None  # Ignore timestamp check.
        fifo.write(frame)

        if num_samples is not None and fifo.samples >= num_samples:
            yield fifo.read()

    if fifo.samples > 0:
        yield fifo.read()


def _resample_frames(frames, resampler):
    # Add None to flush the resampler.
    for frame in itertools.chain(frames, [None]):
        yield from resampler.resample(frame)

In [72]:
from faster_whisper import feature_extractor


In [73]:
audio=decode_audio("korean_news.mp4")
duration = audio.shape[0] / 16000
inputs = processor(audio, return_tensors="pt")

In [74]:
%%time
encoder_last_hidden_states, encoder_trt_time = w_encoder_inference(whisper_trt_encoder, inputs['input_features'], timing_profile)


CPU times: user 453 ms, sys: 1.94 s, total: 2.39 s
Wall time: 47.9 ms


In [75]:
encoder_last_hidden_states.shape

torch.Size([1, 1500, 384])

In [76]:
a = whisper_model.model.encoder(inputs['input_features'].cuda())

In [77]:
input_ids.device

device(type='cpu')

In [None]:
wh_de = whisper_model.model.decoder(input_ids.cuda())

In [None]:
wh_de[0].shape

In [None]:
%%time

result, full_trt_time = full_inference_greedy(
        whisper_trt_encoder,
        whisper_trt_decoder,
        input_features,
        tokenizer,
        timing_profile,
        max_length=max_output_len,
        min_length=WhisperModelTRTConfig.MIN_OUTPUT_LENGTH[metadata.variant],
        batch_size=batch_size,
        use_cache=metadata.other.kv_cache,
    )

In [None]:
metadata.other.kv_cache

In [None]:
from NNDF.tensorrt_utils import TRTNativeRunner


In [None]:
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_output_len)])
no_repeat_ngram_size = WhisperModelTRTConfig.NO_REPEAT_NGRAM_SIZE
logits_processor = LogitsProcessorList(
    [
        NoRepeatNGramLogitsProcessor(no_repeat_ngram_size),
        MinLengthLogitsProcessor(
            min_length, tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
        ),
        ForcedBOSTokenLogitsProcessor(
            50258
        ),
        ForcedEOSTokenLogitsProcessor(
            max_output_len, tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
        ),
    ]
)  # by checking HuggingFace's generate() implementation carefully, the default logits processor for BART has no_repeat_ngram_size = 3 and forced_eos_token_id = 2. In this way we can get identical results with raw HuggingFace

decoder_input_ids = torch.full(
    (batch_size, 1),
    tokenizer.convert_tokens_to_ids(tokenizer.eos_token),
    dtype=torch.int32,
)

if False:
    decoder_input_ids = decoder_input_ids.to("cuda")
else:
    decoder_input_ids = decoder_input_ids.to("cpu")

def _e2e():
    with torch.no_grad():
        encoder_last_hidden_state = whisper_trt_encoder(input_features=input_features)
        decoder_output_greedy = whisper_trt_decoder.greedy_search(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_last_hidden_state,
            stopping_criteria=stopping_criteria,
            logits_processor=logits_processor,
            use_cache=False,
        )


In [None]:
def _e2e_trt():
        with torch.no_grad():
            encoder_last_hidden_state = whisper_trt_encoder(input_features=input_features)
            whisper_trt_decoder.set_encoder_hidden_states_for_inference_cycle(
                encoder_last_hidden_state
            )
            decoder_output_greedy = whisper_trt_decoder.greedy_search(
                input_ids=decoder_input_ids,
                encoder_hidden_states=encoder_last_hidden_state,
                stopping_criteria=stopping_criteria,
                logits_processor=logits_processor,
                use_cache=False,
            )
        return decoder_output_greedy

In [None]:
measurement_function = _e2e
if isinstance(whisper_trt_decoder, TRTNativeRunner):
    whisper_trt_decoder.set_return_device("cuda" if False else "cpu")
    measurement_function = _e2e_trt

full_e2e_time = measure_python_inference_code(measurement_function, timing_profile)



In [None]:
return (measurement_function(), full_e2e_time)

In [None]:
tokenizer.decode(50257)

In [None]:
whisper_model._get_decoder_start_token_id()

In [None]:
tokenizer.convert_tokens_to_ids(tokenizer.eos_token)