In [1]:
# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
## Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
%load_ext autoreload
%autoreload 2

<img src="http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png" style="width: 90px; float: right;">

# Accelerating HuggingFace Whisper Inference with TensorRT

Whisper is an encoder-decoder model that converts ASR problems into a speech-to-text format. More specifically, it does so by encoding speech in the input stream. This enables a single model to be trained supervised on a wide variety of Language

This notebook shows 3 easy steps to convert a [HuggingFace PyTorch Whisper model](https://huggingface.co/transformers/model_doc/whisper.html) to a TensorRT engine for high-performance inference.

1. [Download HuggingFace whisper model](#1)
1. [Convert to ONNX format](#2)
1. [Convert to TensorRT engine](#3)

## Prerequisite

Follow the instruction at https://github.com/NVIDIA/TensorRT to build the TensorRT-OSS docker container required to run this notebook.

Next, we install some extra dependencies.

In [2]:
# %%capture
# !pip3 install -r ../requirements.txt

**Note:** After this step, you should restart the Jupyter kernel for the change to take effect.

In [3]:
import os
import sys
ROOT_DIR = os.path.abspath("../")
sys.path.append(ROOT_DIR)

import torch
import tensorrt as trt

# huggingface
from transformers import (
    WhisperProcessor, 
    WhisperForConditionalGeneration,
    WhisperTokenizer,
    WhisperConfig
)

<a id="1"></a>

## 1. Download HuggingFace T5 model and Whisper model

First, we download the original HuggingFace PyTorch T5 model from HuggingFace model hubs, together with its associated tokernizer.

The T5 variants that are suported by TensorRT 8 are:  t5-small (60M), t5-base (220M), t5-large (770M), t5-3b(3B), t5-11b(11B)

In [4]:
import torch
from datasets import load_dataset

Whisper_VARIANT = "openai/whisper-tiny"    # choices: openai/whisper-tiny | openai/whisper-base | openai/whisper-small | openai/whisper-medium | openai/whisper-large-v2

processor = WhisperProcessor.from_pretrained(Whisper_VARIANT)
whisper_model = WhisperForConditionalGeneration.from_pretrained(Whisper_VARIANT)
wh_config = WhisperConfig.from_pretrained(Whisper_VARIANT, use_cache = False)

In [5]:
tokenizer=processor.tokenizer

In [6]:
# save model locally
pytorch_model_dir = './models/{}/pytorch'.format(Whisper_VARIANT)
!mkdir -p $pytorch_model_dir

whisper_model.save_pretrained(pytorch_model_dir)
print("Pytorch Model saved to {}".format(pytorch_model_dir))

Pytorch Model saved to ./models/openai/whisper-tiny/pytorch


# Encoder output이 다름!!

In [7]:
import io
import itertools

from typing import BinaryIO, Union

import av
import numpy as np
def decode_audio(
    input_file: Union[str, BinaryIO],
    sampling_rate: int = 16000,
    split_stereo: bool = False,
):
    """Decodes the audio.

    Args:
      input_file: Path to the input file or a file-like object.
      sampling_rate: Resample the audio to this sample rate.
      split_stereo: Return separate left and right channels.

    Returns:
      A float32 Numpy array.

      If `split_stereo` is enabled, the function returns a 2-tuple with the
      separated left and right channels.
    """
    resampler = av.audio.resampler.AudioResampler(
        format="s16",
        layout="mono" if not split_stereo else "stereo",
        rate=sampling_rate,
    )

    raw_buffer = io.BytesIO()
    dtype = None

    with av.open(input_file, metadata_errors="ignore") as container:
        frames = container.decode(audio=0)
        frames = _ignore_invalid_frames(frames)
        frames = _group_frames(frames, 500000)
        frames = _resample_frames(frames, resampler)

        for frame in frames:
            array = frame.to_ndarray()
            dtype = array.dtype
            raw_buffer.write(array)

    audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype)

    # Convert s16 back to f32.
    audio = audio.astype(np.float32) / 32768.0

    if split_stereo:
        left_channel = audio[0::2]
        right_channel = audio[1::2]
        return left_channel, right_channel

    return audio

def _ignore_invalid_frames(frames):
    iterator = iter(frames)

    while True:
        try:
            yield next(iterator)
        except StopIteration:
            break
        except av.error.InvalidDataError:
            continue


def _group_frames(frames, num_samples=None):
    fifo = av.audio.fifo.AudioFifo()

    for frame in frames:
        frame.pts = None  # Ignore timestamp check.
        fifo.write(frame)

        if num_samples is not None and fifo.samples >= num_samples:
            yield fifo.read()

    if fifo.samples > 0:
        yield fifo.read()


def _resample_frames(frames, resampler):
    # Add None to flush the resampler.
    for frame in itertools.chain(frames, [None]):
        yield from resampler.resample(frame)

In [179]:
audio=decode_audio("korean_news.mp4")
duration = audio.shape[0] / 16000
inputs = processor(audio, return_tensors="pt")

In [185]:
encoder_outputs = whisper_model.get_encoder()(inputs['input_features'].cuda())

In [186]:
encoder_outputs.last_hidden_state

tensor([[[ 0.1843,  0.1644,  0.1285,  ..., -0.0349,  0.0105, -0.0849],
         [ 0.9001,  2.4472,  0.7696,  ...,  0.8979, -0.0847,  1.0418],
         [ 0.6817,  2.8782,  1.2798,  ...,  0.5644, -0.6482,  0.9689],
         ...,
         [ 0.7794, -1.3334,  0.7304,  ...,  0.9130,  1.3428,  0.1077],
         [ 1.3128, -1.1221,  0.1789,  ..., -0.3750, -0.0723, -0.6635],
         [ 0.0936, -0.1169, -1.3959,  ..., -0.0278, -0.5646, -0.2211]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>)

In [187]:
decoder_start_token_id = whisper_model._get_decoder_start_token_id(None, 50258)
input_ids  = torch.ones((1, 1), dtype=torch.long, device='cuda') * decoder_start_token_id

In [188]:
whisper_model.get_decoder().max_source_positions

1500

In [189]:
whisper_model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ko", task="transcribe", no_timestamps=True)

In [190]:
whisper_model.config.forced_decoder_ids

[(1, 50264), (2, 50359), (3, 50363)]

In [211]:
whisper_model.config.forced_eos_token_id

In [519]:
whisper_model.float().cuda()
hf_out = whisper_model.generate(inputs['input_features'].cuda())

In [520]:
processor.batch_decode(hf_out)

['<|startoftranscript|><|en|><|transcribe|><|notimestamps|> The main reason for the Japanese-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean-speaking Korean Korean-speaking Korean Korean-speaking Korean Korean-speaking Korean Korean-speaking Korean Korean Korean-speaking Korean Korean Korean Korean Korean Korean Korean Korean Korean Korean Korean Korean Korean Korean Korean

In [194]:
decoder_outputs = whisper_model.model.decoder(input_ids=input_ids.cuda(), encoder_hidden_states=encoder_outputs['last_hidden_state'].cuda())

In [195]:
hf_output = whisper_model.proj_out(decoder_outputs.last_hidden_state)

In [196]:
hf_output

tensor([[[-2.4892, -5.4267,  2.6655,  ...,  0.2883,  1.1339,  1.4177]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>)

In [197]:
from transformers.generation_logits_process import (
    NoRepeatNGramLogitsProcessor,
    MinLengthLogitsProcessor,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    LogitsProcessorList,
)
from transformers.generation_stopping_criteria import (
    MaxLengthCriteria,
    StoppingCriteriaList,
)
from transformers.generation_beam_search import (
    BeamSearchScorer,
)


In [198]:
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(448)])
logits_processor = LogitsProcessorList([
    NoRepeatNGramLogitsProcessor(3),
    MinLengthLogitsProcessor(0, tokenizer.convert_tokens_to_ids(tokenizer.eos_token)),
    ForcedBOSTokenLogitsProcessor(50258),
    ForcedEOSTokenLogitsProcessor(448, tokenizer.convert_tokens_to_ids(tokenizer.eos_token))
]) # by checking HuggingFace's generate() implementation carefully, the default logits processor for Whisper has no_repeat_ngram_size = 3 and forced_eos_token_id = 2. In this way we can ensure identical results with raw HuggingFace


In [199]:
encoder_outputs['last_hidden_state'].cuda()

tensor([[[ 0.1843,  0.1644,  0.1285,  ..., -0.0349,  0.0105, -0.0849],
         [ 0.9001,  2.4472,  0.7696,  ...,  0.8979, -0.0847,  1.0418],
         [ 0.6817,  2.8782,  1.2798,  ...,  0.5644, -0.6482,  0.9689],
         ...,
         [ 0.7794, -1.3334,  0.7304,  ...,  0.9130,  1.3428,  0.1077],
         [ 1.3128, -1.1221,  0.1789,  ..., -0.3750, -0.0723, -0.6635],
         [ 0.0936, -0.1169, -1.3959,  ..., -0.0278, -0.5646, -0.2211]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>)

In [200]:
input_ids

tensor([[50258]], device='cuda:0')

In [201]:
inputs['input_features'].cuda().shape

torch.Size([1, 80, 3000])

In [202]:
encoder_outputs['last_hidden_state'].cuda().shape

torch.Size([1, 1500, 384])

In [203]:
# # legacy: users may modify the model configuration to control generation -- update the generation config
# # model attribute accordingly, if it was created from the model config
# if self.generation_config._from_model_config:
#     new_generation_config = GenerationConfig.from_model_config(self.config)
#     if new_generation_config != self.generation_config:
#         warnings.warn(
#             "You have modified the pretrained model configuration to control generation. This is a"
#             " deprecated strategy to control generation and will be removed soon, in a future version."
#             " Please use a generation configuration file (see"
#             " https://huggingface.co/docs/transformers/main_classes/text_generation )"
#         )
#         self.generation_config = new_generation_config
# generation_config = self.generation_config


In [204]:
logits_processor = LogitsProcessorList([
    NoRepeatNGramLogitsProcessor(3),
    MinLengthLogitsProcessor(0, tokenizer.convert_tokens_to_ids(tokenizer.eos_token)),
    ForcedEOSTokenLogitsProcessor(448, 50364)
])
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(448)])


In [205]:
tokenizer.convert_tokens_to_ids(tokenizer.eos_token)

50257

In [213]:
whisper_model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ko", task="transcribe", no_timestamps=True)

In [245]:
tokenizer.bos_token_id

50257

In [246]:
bos_token_id = tokenizer.bos_token_id
num_beams = whisper_model.config.num_beams
length_penalty = whisper_model.config.length_penalty
early_stopping = whisper_model.config.early_stopping
num_beam_groups = whisper_model.config.num_beam_groups
do_sample = whisper_model.config.do_sample
num_return_sequences = whisper_model.config.num_return_sequences
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

In [258]:
inputs_tensor, model_input_name, model_kwargs = whisper_model._prepare_model_inputs(inputs, bos_token_id, model_kwargs)


In [272]:
whisper_model.config.forced_decoder_ids

[(1, 50264), (2, 50359), (3, 50363)]

In [266]:
model_kwargs['encoder_outputs']=encoder_outputs.last_hidden_state.cuda()

In [280]:
no_repeat_ngram_size

3

In [306]:
logits_processor = whisper_model._get_logits_processor(
    repetition_penalty = None,
    no_repeat_ngram_size=no_repeat_ngram_size,
    encoder_no_repeat_ngram_size=None,
    input_ids_seq_length=input_ids.shape[-1],
    encoder_input_ids=inputs_tensor,
    bad_words_ids=None,
    min_length=min_length,
    max_length=max_length,
    eos_token_id=eos_token_id,
    forced_bos_token_id=None,
    forced_eos_token_id=None,
    prefix_allowed_tokens_fn=None,
    num_beams=1,
    num_beam_groups=1,
    diversity_penalty=None,
    remove_invalid_values=None,
    exponential_decay_length_penalty=None,
    logits_processor=LogitsProcessorList(),
    renormalize_logits=None,
    suppress_tokens=None,
    begin_suppress_tokens=None,
    forced_decoder_ids=whisper_model.config.forced_decoder_ids,
)


In [309]:
logits_processor

[<transformers.generation_logits_process.NoRepeatNGramLogitsProcessor at 0x7f640513f3a0>,
 <transformers.generation_logits_process.SuppressTokensLogitsProcessor at 0x7f640513f400>,
 <transformers.generation_logits_process.SuppressTokensAtBeginLogitsProcessor at 0x7f640513f460>,
 <transformers.generation_logits_process.ForceTokensLogitsProcessor at 0x7f640513f4c0>]

In [327]:
ForceTokensLogitsProcessor (whisper_model.config.forced_decoder_ids)


<transformers.generation_logits_process.ForceTokensLogitsProcessor at 0x7f6407688fa0>

In [316]:
tokenizer.decode([50257])

'<|endoftext|>'

In [319]:
SuppressTokensAtBeginLogitsProcessor(whisper_model.config.begin_suppress_tokens, None)

<transformers.generation_logits_process.SuppressTokensAtBeginLogitsProcessor at 0x7f640522d220>

In [335]:
whisper_model.config.begin_suppress_tokens

[220, 50257]

In [338]:
SuppressTokensAtBeginLogitsProcessor()

TypeError: __init__() missing 2 required positional arguments: 'begin_suppress_tokens' and 'begin_index'

In [359]:
whisper_model._prepare_decoder_input_ids_for_generation(
    batch_size,
    decoder_start_token_id=decoder_start_token_id,
    bos_token_id=bos_token_id,
    model_kwargs=model_kwargs,
    device='cuda',
)

tensor([[50258]], device='cuda:0')

In [360]:
decoder_start_token_id

50258

In [None]:
input_ids = self._prepare_decoder_input_ids_for_generation(
                batch_size,
                decoder_start_token_id=decoder_start_token_id,
                bos_token_id=bos_token_id,
                model_kwargs=model_kwargs,
                device=inputs_tensor.device,
            )


In [340]:
input_ids_seq_length = input_ids.shape[-1]


In [347]:
begin_index = input_ids_seq_length
begin_index = begin_index if (input_ids_seq_length > 1 or whisper_model.config.forced_bos_token_id is None) else begin_index + 1
if whisper_model.config.forced_bos_token_id is not None:
    begin_index += whisper_model.config.forced_bos_token_id[-1][0]  # generation starts after the last token that is forced


In [351]:
input_ids.shape[-1]

1

In [345]:
begin_index

1

In [353]:
logits_processor = LogitsProcessorList([
    NoRepeatNGramLogitsProcessor(3),
    SuppressTokensLogitsProcessor(whisper_model.config.suppress_tokens),
    SuppressTokensAtBeginLogitsProcessor(whisper_model.config.begin_suppress_tokens, begin_index), 
    ForceTokensLogitsProcessor (whisper_model.config.forced_decoder_ids)
])

In [354]:
%%time
# greedy_search(
#     input_ids,
#     logits_processor=logits_processor,
#     stopping_criteria=stopping_criteria,
#     pad_token_id=pad_token_id,
#     eos_token_id=eos_token_id,
#     output_scores=output_scores,
#     return_dict_in_generate=return_dict_in_generate,
#     synced_gpus=synced_gpus,
#     **model_kwargs,
# )
decoder_output = whisper_model.greedy_search(
    input_ids=input_ids.cuda(),
    encoder_outputs=encoder_outputs.last_hidden_state.cuda(),
   # stopping_criteria=stopping_criteria,
    logits_processor=logits_processor,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    
    use_cache=False,
)

CPU times: user 679 ms, sys: 0 ns, total: 679 ms
Wall time: 679 ms


In [355]:
tokenizer.batch_decode(decoder_output)

['<|startoftranscript|><|ko|><|transcribe|><|notimestamps|> 제 6코 태풍 가능은 여전히 매우 강한 세력을 유지한 채 북서진하고 있습니다. 하지만 이동 속도가 점점 늘여져 거의 정체안 모습입니다. 태풄은 동중국회의 머물나 동쪽으로 방향을 급격히 틀어 이동할 걸로 보입니다. 속도도 조금씩 빨라지면 다음 주 초중반에는 일본 교수 남쪽의 상까지 진출하겠습니다. 새력도 크게 약화하지는 않을 전망입니다.<|endoftext|>']

In [325]:
processor.batch_decode(hf_out)

['<|startoftranscript|><|ko|><|transcribe|><|notimestamps|> 제 6코 태풍 가능은 여전히 매우 강한 세력을 유지한 채 북서진하고 있습니다. 하지만 이동 속도가 점점 늘여져 거의 정체안 모습입니다. 태풍은 동중국회의 머물나 동쪽으로 방향을 급격히 틀어 이동할 걸로 보입니다. 속도도 조금씩 빨라지면 다음 주 초중반에는 일본 교수 남쪽의 상까지 진출하겠습니다. 새력도 크게 약화하지는 않을 전망입니다.<|endoftext|>']

In [57]:
whisper_model._prepare_decoder_input_ids_for_generation(1)

tensor([[50258]], device='cuda:0')

### Inference with PyTorch model

Next, we will carry out inference with the PyTorch model.

#### Single example inference

In [521]:
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

audio_inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
input_features = audio_inputs.input_features

# WAR: Using an ugly representation because cuda 11.4 does not support GPU models due to cublas errors
if "LD_LIBRARY_PATH" in os.environ and "cuda-11.4" in os.environ["LD_LIBRARY_PATH"]:
    whisper_model = whisper_model.cpu()
    input_features = input_features.to('cpu')
else:
    whisper_model = whisper_model.cuda()
    input_features = input_features.to('cuda:0')   

Found cached dataset librispeech_asr_dummy (/home/nvadmin/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_dummy/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)


In [522]:
whisper_model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe", no_timestamps=True)

In [523]:
with torch.no_grad():
    generated_ids = whisper_model.generate(inputs=input_features)

transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
transcription
# ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'

' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.'

#### Model inference benchmark: encoder and decoder stacks

For benchmarking purposes, we will employ a helper functions `encoder_inference` and `decoder_inference` which execute the inference repeatedly for the T5 encoder and decoder stacks separately, and measure end to end execution time. Let's take note of this execution time for comparison with TensorRT. 
 
`TimingProfile` is a named tuple that specifies the number of experiments and number of times to call the function per iteration (and number of warm-up calls although it is not used here).

In [524]:
from Whisper.measurements import decoder_inference as w_decoder_inference, encoder_inference as w_encoder_inference, full_inference as w_full_inference, full_inference_greedy, full_inference_beam
from Whisper.export import WhisperEncoderTorchFile, WhisperDecoderTorchFile, WhisperEncoderTRTEngine, WhisperDecoderTRTEngine

from NNDF.networks import TimingProfile
from NNDF.torch_utils import expand_inputs_for_beam_search

In [525]:
whisper_torch_encoder = WhisperEncoderTorchFile.TorchModule(whisper_model.model.encoder)
whisper_torch_decoder = WhisperDecoderTorchFile.TorchModule(
    whisper_model.model.decoder, whisper_model.proj_out, whisper_model.config
)

In [526]:
generated_ids = whisper_model.generate(inputs=audio_inputs.input_features.to('cuda'))

In [527]:
%%time
input_features = audio_inputs.input_features.to('cuda')

encoder_last_hidden_state, encoder_e2e_median_time = w_encoder_inference(
    whisper_torch_encoder, input_features, TimingProfile(iterations=10, number=1, warmup=1, duration=0, percentile=50)
)
encoder_e2e_median_time

CPU times: user 47.1 ms, sys: 0 ns, total: 47.1 ms
Wall time: 46.8 ms


0.003658406203612685

In [528]:
input_ids = torch.tensor([[1, 1]]) * whisper_model.config.decoder_start_token_id


In [529]:
%%time
_, decoder_e2e_median_time = w_decoder_inference(
    whisper_torch_decoder, input_ids, encoder_last_hidden_state, TimingProfile(iterations=10, number=1, warmup=1, duration=0, percentile=50)
)
decoder_e2e_median_time

CPU times: user 73.4 ms, sys: 0 ns, total: 73.4 ms
Wall time: 73 ms


0.005992966936901212

#### Full model inference and benchmark

Next, we will try the T5 model for the task of translation from English to German.

For benchmarking purposes, we will employ a helper function `full_inference` which executes the inference repeatedly and measures end to end execution time. Let's take note of this execution time for comparison with TensorRT. 

In [530]:
from Whisper.WhisperModelConfig import WhisperModelTRTConfig, WhisperMetadata

In [531]:
import transformers

In [532]:
num_beams = 1
min_output_len =0 
max_output_len = whisper_model.config.max_length
tokenizer = processor.tokenizer

In [533]:
from NNDF.general_utils import measure_python_inference_code
timing_profile = TimingProfile(iterations=10, number=1, warmup=1, duration=0, percentile=[50,99])

def percentile_print(timing):
    return ', '.join(['p{} {:.2f}ms'.format(timing_profile.percentile[i], p*1000) for i,p in enumerate(timing)])
whisper_model = WhisperForConditionalGeneration.from_pretrained(Whisper_VARIANT).cuda()

# encoder-decoder inference 
with torch.no_grad():
    output_ids = whisper_model.generate(input_features, max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=False)    
    outputs = processor.tokenizer.decode(output_ids[-1,:], skip_special_tokens=True)    
outputs_hf = outputs

# timing
# FP32
whisper_model.float()
hf_nonkv_time = measure_python_inference_code(lambda: whisper_model.generate(input_features, max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=False), timing_profile)
hf_kv_time = measure_python_inference_code(lambda: whisper_model.generate(input_features, max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=True), timing_profile)

# FP16, cuda 11.4 has cublas error that will fail in both cpu or cpu model for Whisper
# if not cuda_114_mode:
whisper_model= whisper_model.half()
hf_nonkv_time_fp16 = measure_python_inference_code(lambda: whisper_model.generate(input_features.half(), max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=False), timing_profile)
hf_kv_time_fp16 = measure_python_inference_code(lambda: whisper_model.generate(input_features.half(), max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=True), timing_profile)

In [535]:
forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe", no_timestamps=True)

In [536]:
# FP32
HF_KV=True
timing_profile = TimingProfile(iterations=10, number=1, warmup=1, duration=0, percentile=[50,99])
whisper_model.float()
whisper_torch_encoder = WhisperEncoderTorchFile.TorchModule(whisper_model.get_encoder())
whisper_torch_decoder = WhisperDecoderTorchFile.TorchModule(whisper_model.get_decoder(), whisper_model.proj_out, whisper_model.config)

with torch.no_grad():

    encoder_last_hidden_state, encoder_pytorch_time = w_encoder_inference(whisper_torch_encoder, input_features, timing_profile)
    _, decoder_pytorch_time = w_decoder_inference(whisper_torch_decoder, expand_inputs_for_beam_search(input_ids, num_beams) if num_beams > 1 else input_ids, expand_inputs_for_beam_search(encoder_last_hidden_state, num_beams) if num_beams > 1 else encoder_last_hidden_state, timing_profile, use_cache=HF_KV)
    if num_beams == 1:
        output_ids, full_pytorch_time = full_inference_greedy(whisper_torch_encoder,whisper_torch_decoder,input_features,tokenizer,timing_profile,max_length=max_output_len, min_length=min_output_len, use_cache=HF_KV, forced_decoder_ids=forced_decoder_ids)
    else:
        output_ids, full_pytorch_time = full_inference_beam(whisper_torch_encoder,whisper_torch_decoder,input_features,tokenizer,timing_profile,num_beams=num_beams,max_length=max_output_len, min_length=min_output_len, use_cache=HF_KV, forced_decoder_ids=forced_decoder_ids)
    outputs = tokenizer.decode(output_ids[0], skip_special_tokens=True)    

outputs_pytorch = outputs

# # FP16
# if not cuda_114_mode:
whisper_model.half()
input_features= input_features.half()
whisper_torch_encoder_fp16 = WhisperEncoderTorchFile.TorchModule(whisper_model.get_encoder())
whisper_torch_decoder_fp16 = WhisperDecoderTorchFile.TorchModule(whisper_model.get_decoder(), whisper_model.proj_out, whisper_model.config)

with torch.no_grad():

    encoder_last_hidden_state, encoder_pytorch_time_fp16 = w_encoder_inference(whisper_torch_encoder_fp16, input_features, timing_profile)
    _, decoder_pytorch_time_fp16 = w_decoder_inference(whisper_torch_decoder_fp16, expand_inputs_for_beam_search(input_ids, num_beams) if num_beams > 1 else input_ids, expand_inputs_for_beam_search(encoder_last_hidden_state, num_beams) if num_beams > 1 else encoder_last_hidden_state, timing_profile, use_cache=HF_KV)
    if num_beams == 1:
        output_ids_fp16, full_pytorch_time_fp16 = full_inference_greedy(whisper_torch_encoder_fp16,whisper_torch_decoder_fp16,input_features,tokenizer,timing_profile,max_length=max_output_len, min_length=min_output_len, use_cache=HF_KV, forced_decoder_ids=forced_decoder_ids)
    else:
        output_ids_fp16, full_pytorch_time_fp16 = full_inference_beam(whisper_torch_encoder_fp16,whisper_torch_decoder_fp16,input_features,tokenizer,timing_profile,num_beams=num_beams,max_length=max_output_len, min_length=min_output_len, use_cache=HF_KV, forced_decoder_ids=forced_decoder_ids)
    outputs_fp16 = tokenizer.decode(output_ids_fp16[0], skip_special_tokens=True)    

outputs_pytorch_fp16 = outputs_fp16

In [537]:
outputs_pytorch_fp16

' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.'

In [538]:
# print
print(f'PyTorch FP32 Output identical to HF results? {outputs_pytorch == outputs_hf}')
print(f'PyTorch FP16 Output identical to HF results? {outputs_pytorch_fp16 == outputs_hf}')
print('\n')      
print(f'Device: {torch.cuda.get_device_name()}')
print(f"Precision: FP32, Number of Beams: {num_beams}")
print(f"Encoder time: {encoder_pytorch_time}")
print(f"Decoder time: {decoder_pytorch_time}")
print(f"Full E2E time: {full_pytorch_time}")
print(f"Precision: FP16, Number of Beams: {num_beams}")
print(f"Encoder time: {encoder_pytorch_time_fp16}")
print(f"Decoder time: {decoder_pytorch_time_fp16}")
print(f"Full E2E time: {full_pytorch_time_fp16}")

PyTorch FP32 Output identical to HF results? True
PyTorch FP16 Output identical to HF results? True


Device: NVIDIA A100-SXM4-80GB
Precision: FP32, Number of Beams: 1
Encoder time: [0.00220304518006742, 0.002242377959191799]
Decoder time: [0.0034144290257245302, 0.0037977269385010004]
Full E2E time: [0.08790272590704262, 0.2159281640779227]
Precision: FP16, Number of Beams: 1
Encoder time: [0.002915350953117013, 0.003663035109639168]
Decoder time: [0.003403657115995884, 0.0034532209392637014]
Full E2E time: [0.13136885385029018, 0.2274820499587804]


In [539]:
output_ids_fp16, full_pytorch_time_fp16 = full_inference_greedy(whisper_torch_encoder_fp16,whisper_torch_decoder_fp16,input_features,tokenizer,timing_profile,max_length=max_output_len, min_length=min_output_len, use_cache=HF_KV, forced_decoder_ids=forced_decoder_ids)

In [540]:
processor.tokenizer.batch_decode(output_ids_fp16)

['<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|endoftext|>']

<a id="2"></a>

## 2. Convert to ONNX

Prior to converting the model to a TensorRT engine, we will first convert the PyTorch model to an intermediate universal format.

ONNX is an open format for machine learning and deep learning models. It allows you to convert deep learning and machine learning models from different frameworks such as TensorFlow, PyTorch, MATLAB, Caffe, and Keras to a single format.

The steps to convert a PyTorch model to TensorRT are as follows:
- Convert the pretrained image segmentation PyTorch model into ONNX.
- Import the ONNX model into TensorRT.
- Apply optimizations and generate an engine.
- Perform inference on the GPU. 

For the Whisper model, we will convert the encoder and decoder seperately.

In [541]:
from NNDF.networks import NetworkMetadata, Precision
TRT_KV = False

wh_onnx_model_path = './models/{}/onnx'.format(Whisper_VARIANT)
!mkdir -p $wh_onnx_model_path

# FP32
whisper_model.float()
metadata = NetworkMetadata(variant=Whisper_VARIANT, precision=Precision(fp16=False), other=WhisperMetadata(kv_cache=TRT_KV))
trt_config = WhisperModelTRTConfig()
metadata_string = trt_config.get_metadata_string(metadata)

wh_encoder_onnx_model_fpath = metadata_string + "-encoder.onnx"
wh_decoder_onnx_model_fpath = metadata_string + "-decoder-with-lm-head.onnx"

# for onnx conversion, ensure model is on CPU and FP32 precision in this step
whisper_torchfile_encoder = WhisperEncoderTorchFile(whisper_model.to('cpu'), metadata)
whisper_torchfile_decoder = WhisperDecoderTorchFile(whisper_model.to('cpu'), metadata)

onnx_whisper_encoder = whisper_torchfile_encoder.as_onnx_model(os.path.join(wh_onnx_model_path, wh_encoder_onnx_model_fpath), force_overwrite=True)
onnx_whisper_decoder = whisper_torchfile_decoder.as_onnx_model(os.path.join(wh_onnx_model_path, wh_decoder_onnx_model_fpath), force_overwrite=True)

# FP16
metadata_fp16 = NetworkMetadata(variant=Whisper_VARIANT, precision=Precision(fp16=True), other=WhisperMetadata(kv_cache=TRT_KV))
trt_config_fp16 = WhisperModelTRTConfig()
metadata_string_fp16 = trt_config.get_metadata_string(metadata_fp16)

wh_encoder_onnx_model_fpath_fp16 = metadata_string_fp16 + "-encoder.onnx"
wh_decoder_onnx_model_fpath_fp16 = metadata_string_fp16 + "-decoder-with-lm-head.onnx"

# for onnx conversion, ensure model is on CPU and FP32 precision in this step
whisper_torchfile_encoder = WhisperEncoderTorchFile(whisper_model.to('cpu'), metadata)
whisper_torchfile_decoder = WhisperDecoderTorchFile(whisper_model.to('cpu'), metadata)

onnx_whisper_encoder_fp16 = whisper_torchfile_encoder.as_onnx_model(os.path.join(wh_onnx_model_path, wh_encoder_onnx_model_fpath_fp16), force_overwrite=True)
onnx_whisper_decoder_fp16 = whisper_torchfile_decoder.as_onnx_model(os.path.join(wh_onnx_model_path, wh_decoder_onnx_model_fpath_fp16), force_overwrite=True)

<a id="3"></a>

## 3. Convert to TensorRT

Now we are ready to parse the ONNX encoder and decoder models and convert them to optimized TensorRT engines.

Since the models contains dynamic input shapes, we can specify a valid input range with a TensorRT optimization profile.

In [543]:
from Whisper.export import WhisperDecoderONNXFile, WhisperEncoderONNXFile
from polygraphy.backend.trt import Profile
from tensorrt import PreviewFeature

In [544]:
wh_tensorrt_model_path = './models/{}/tensorrt'.format(Whisper_VARIANT)
!mkdir -p wh_tensorrt_model_path
# Decoder optimization profiles
batch_size = 1
max_sequence_length = WhisperModelTRTConfig.MAX_SEQUENCE_LENGTH[Whisper_VARIANT]
decoder_profile = Profile()
decoder_profile.add(
    "input_ids",
    min=(batch_size * num_beams, 1),
    opt=(batch_size * num_beams, max_sequence_length // 2),
    max=(batch_size * num_beams, max_sequence_length),
)
decoder_profile.add(
    "encoder_hidden_states",
    min=(batch_size * num_beams, 1, max_sequence_length),
    opt=(batch_size * num_beams, 1500, max_sequence_length),
    max=(batch_size * num_beams, 1500, max_sequence_length),
)

# Encoder optimization profiles
encoder_profile = Profile()
encoder_profile.add(
    "input_features",
    min=(batch_size, 80, 3000),
    opt=(batch_size, 80, 3000),
    max=(batch_size, 80, 3000)
)

disable_preview_dynamic_shapes = False
engine_tag = f"bs{batch_size}"

In [545]:
force_write=True
engine_tag = f"bs{batch_size}"

if num_beams > 1:
    engine_tag += "-beam{}".format(num_beams)

preview_features = [PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
if disable_preview_dynamic_shapes:
    engine_tag += "-noPreviewFasterDynamicShapes"
else:
    preview_features.append(PreviewFeature.FASTER_DYNAMIC_SHAPES_0805)

# FP32
wh_encoder_engine_name = os.path.join(wh_tensorrt_model_path, wh_encoder_onnx_model_fpath) + f"-{engine_tag}.engine".replace(f"-beam{num_beams}", "") # encoder engine not affected by beam search
wh_decoder_engine_name = os.path.join(wh_tensorrt_model_path, wh_decoder_onnx_model_fpath) + f"-{engine_tag}.engine"

if not os.path.exists(wh_encoder_engine_name) or force_write:
    whisper_trt_encoder_engine = WhisperEncoderONNXFile(os.path.join(wh_onnx_model_path, wh_encoder_onnx_model_fpath), metadata).as_trt_engine(
        wh_encoder_engine_name, 
        profiles=[encoder_profile], 
        preview_features=preview_features
    )
else:
    whisper_trt_encoder_engine = WhisperEncoderTRTEngine(wh_encoder_engine_name, metadata)
    
if not os.path.exists(wh_decoder_engine_name) or force_write:
    whisper_trt_decoder_engine = WhisperDecoderONNXFile(os.path.join(wh_onnx_model_path, wh_decoder_onnx_model_fpath), metadata).as_trt_engine(
        wh_decoder_engine_name, 
        profiles=[decoder_profile], 
        preview_features=preview_features
    )
else:
    whisper_trt_decoder_engine = WhisperDecoderTRTEngine(wh_decoder_engine_name, metadata)


In [546]:
# FP16
wh_encoder_engine_name_fp16 = os.path.join(wh_tensorrt_model_path, wh_encoder_onnx_model_fpath_fp16) + f"-{engine_tag}.engine".replace(f"-beam{num_beams}", "") # encoder engine not affected by beam search
wh_decoder_engine_name_fp16 = os.path.join(wh_tensorrt_model_path, wh_decoder_onnx_model_fpath_fp16) + f"-{engine_tag}.engine"

if not os.path.exists(wh_encoder_engine_name_fp16) or force_write:
    whisper_trt_encoder_engine_fp16 = WhisperEncoderONNXFile(os.path.join(wh_onnx_model_path, wh_encoder_onnx_model_fpath_fp16), metadata_fp16).as_trt_engine(
        wh_encoder_engine_name_fp16, 
        profiles=[encoder_profile], 
        preview_features=preview_features
    )
else:
    whisper_trt_encoder_engine_fp16 = WhisperEncoderTRTEngine(wh_encoder_engine_name_fp16, metadata_fp16)
    
if not os.path.exists(wh_decoder_engine_name_fp16) or force_write:
    whisper_trt_decoder_engine_fp16 = WhisperDecoderONNXFile(os.path.join(wh_onnx_model_path, wh_decoder_onnx_model_fpath_fp16), metadata_fp16).as_trt_engine(
        wh_decoder_engine_name_fp16, 
        profiles=[decoder_profile], 
        preview_features=preview_features
    )
else:
    whisper_trt_decoder_engine_fp16 = WhisperDecoderTRTEngine(wh_decoder_engine_name_fp16, metadata_fp16)

In [547]:
print(wh_encoder_onnx_model_fpath)
print(wh_decoder_onnx_model_fpath)
print(onnx_whisper_encoder)
print(onnx_whisper_decoder)
#onnx_whisper_encoder = whisper_torchfile_encoder.as_onnx_model(os.path.join(wh_onnx_model_path, wh_encoder_onnx_model_fpath), force_overwrite=False)
#onnx_whisper_decoder = whisper_torchfile_decoder.as_onnx_model(os.path.join(wh_onnx_model_path, wh_decoder_onnx_model_fpath), force_overwrite=False)

Whisper-tiny-encoder.onnx
Whisper-tiny-decoder-with-lm-head.onnx
<Whisper.export.WhisperEncoderONNXFile object at 0x7f6404e1e430>
<Whisper.export.WhisperDecoderONNXFile object at 0x7f62378f3fa0>


# Whisper Tensorrt 

In [548]:
from transformers import AutoConfig
from Whisper.trt import WhisperTRTEncoder, WhisperTRTDecoder, TRTHFRunner

In [549]:
# Initialize TensorRT engines
trt_config = AutoConfig.from_pretrained(Whisper_VARIANT, use_cache = metadata.other.kv_cache)

# FP32
whisper_trt_encoder = WhisperTRTEncoder(whisper_trt_encoder_engine, metadata, trt_config, batch_size=batch_size)
whisper_trt_decoder = WhisperTRTDecoder(whisper_trt_decoder_engine, metadata, trt_config, batch_size=batch_size, num_beams=num_beams)

# FP16
whisper_trt_encoder_fp16 = WhisperTRTEncoder(whisper_trt_encoder_engine_fp16, metadata_fp16, trt_config, batch_size=batch_size)
whisper_trt_decoder_fp16 = WhisperTRTDecoder(whisper_trt_decoder_engine_fp16, metadata_fp16, trt_config, batch_size=batch_size, num_beams=num_beams)

In [550]:
%%time
encoder_last_hidden_state, encoder_trt_time = w_encoder_inference(
    whisper_trt_encoder, input_features, TimingProfile(iterations=10, number=1, warmup=1, duration=0, percentile=[50,99])
)
encoder_e2e_median_time

CPU times: user 19.7 ms, sys: 0 ns, total: 19.7 ms
Wall time: 19.4 ms


0.003658406203612685

### End-to-End TensorRT Inference

In [578]:
from transformers.generation_logits_process import (
    NoRepeatNGramLogitsProcessor,
    MinLengthLogitsProcessor,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    LogitsProcessorList,
)
from transformers.generation_stopping_criteria import (
    MaxLengthCriteria,
    StoppingCriteriaList,
)
from transformers.generation_beam_search import (
    BeamSearchScorer,
)

stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_output_len)])
no_repeat_ngram_size = WhisperModelTRTConfig.NO_REPEAT_NGRAM_SIZE
min_length = WhisperModelTRTConfig.MIN_OUTPUT_LENGTH[Whisper_VARIANT]
decoder_input_ids = torch.full(
    (batch_size, 1),
    WhisperModelTRTConfig.DECODER_START_TOKEN_ID,
)

forced_decoder_ids=processor.get_decoder_prompt_ids(language="en", task="transcribe", no_timestamps=True)

stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length)])
no_repeat_ngram_size = WhisperModelTRTConfig.NO_REPEAT_NGRAM_SIZE
logits_processor = LogitsProcessorList(
    [
        NoRepeatNGramLogitsProcessor(no_repeat_ngram_size),
        SuppressTokensLogitsProcessor(WhisperModelTRTConfig.SUPPRESS_TOKENS),
        SuppressTokensAtBeginLogitsProcessor(
            WhisperModelTRTConfig.BEGIN_SUPPRESS_TOKENS, decoder_input_ids.shape[-1]
        ),
        ForceTokensLogitsProcessor(forced_decoder_ids),
    ]
)  # by checking HuggingFace's generate() implementation carefully, the default logits processor for BART has no_repeat_ngram_size = 3 and forced_eos_token_id = 2. In this way we can get identical results with raw HuggingFace

encoder_outputs.last_hidden_state = encoder_outputs.last_hidden_state.cpu()
if num_beams > 1:
    decoder_input_ids = expand_inputs_for_beam_search(decoder_input_ids, expand_size=num_beams)
    
# FP32
def e2e_trt():
    with torch.no_grad():
        encoder_last_hidden_states = whisper_trt_encoder(input_features=input_features)
        
        if num_beams > 1:
            # prepare input for beam search
            encoder_last_hidden_states = expand_inputs_for_beam_search(encoder_last_hidden_states, expand_size=num_beams)

            # beam scorer must be reset before each beam search run, otherwise beam search will be skipped due to scorer cache
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=num_beams,
                device="cuda",
                do_early_stopping=True,
            )
        
        whisper_trt_decoder.set_encoder_hidden_states_for_inference_cycle(encoder_last_hidden_states)
        
        if num_beams == 1:
            decoder_output = whisper_trt_decoder.greedy_search(
                input_ids=decoder_input_ids.cuda(),
                encoder_hidden_states=encoder_outputs.last_hidden_state.cuda(),
                stopping_criteria=stopping_criteria,
                logits_processor=logits_processor,
                use_cache=metadata.other.kv_cache,
                use_cuda=True
            )
        else:
            decoder_output = whisper_trt_decoder.beam_search(
                input_ids=decoder_input_ids,
                beam_scorer=beam_scorer,
                encoder_hidden_states=encoder_last_hidden_states,
                stopping_criteria=stopping_criteria,
                logits_processor=logits_processor,
                use_cache=metadata.other.kv_cache,
                use_cuda=True
            )
    return decoder_output

output_ids = e2e_trt()
outputs_trt = tokenizer.decode(output_ids[0], skip_special_tokens=True)
trt_time = measure_python_inference_code(e2e_trt, timing_profile)

RuntimeError: The expanded size of the tensor (1) must match the existing size (2) at non-singleton dimension 0.  Target sizes: [1].  Tensor sizes: [2]

In [576]:
decoder_output = whisper_trt_decoder.greedy_search(
    input_ids=decoder_input_ids.cuda(),
    encoder_hidden_states=encoder_outputs.last_hidden_state.cuda(),
    stopping_criteria=stopping_criteria,
    logits_processor=logits_processor,
    use_cache=metadata.other.kv_cache,
    use_cuda=True
)

RuntimeError: The expanded size of the tensor (1) must match the existing size (2) at non-singleton dimension 0.  Target sizes: [1].  Tensor sizes: [2]

In [569]:
encoder_outputs.last_hidden_state

tensor([[[ 0.1843,  0.1644,  0.1285,  ..., -0.0349,  0.0105, -0.0849],
         [ 0.9001,  2.4472,  0.7696,  ...,  0.8979, -0.0847,  1.0418],
         [ 0.6817,  2.8782,  1.2798,  ...,  0.5644, -0.6482,  0.9689],
         ...,
         [ 0.7794, -1.3334,  0.7304,  ...,  0.9130,  1.3428,  0.1077],
         [ 1.3128, -1.1221,  0.1789,  ..., -0.3750, -0.0723, -0.6635],
         [ 0.0936, -0.1169, -1.3959,  ..., -0.0278, -0.5646, -0.2211]]],
       grad_fn=<ToCopyBackward0>)

In [553]:
StoppingCriteriaList.max_length.getter(0)

<property at 0x7f640759e360>

In [None]:
outputs_trt

In [None]:
import inspect
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch import nn

from transformers.generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint
from transformers.generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from transformers.generation_logits_process import (
    EncoderNoRepeatNGramLogitsProcessor,
    ExponentialDecayLengthPenalty,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    ForceTokensLogitsProcessor,
    HammingDiversityLogitsProcessor,
    InfNanRemoveLogitsProcessor,
    LogitNormalization,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    NoBadWordsLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    PrefixConstrainedLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    SuppressTokensAtBeginLogitsProcessor,
    SuppressTokensLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    TypicalLogitsWarper,
)
from transformers.generation_stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)
from transformers.models.auto import (
    MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
    MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from transformers.pytorch_utils import torch_int_div

max_length = max_output_len
pad_token_id = tokenizer.pad_token_id
eos_token_id = tokenizer.eos_token_id
output_scores = whisper_model.config.output_scores
output_attentions = whisper_model.config.output_attentions
output_hidden_states = encoder_outputs.last_hidden_state
return_dict_in_generate= whisper_model.config.return_dict_in_generate
synced_gpus=False

whisper_trt_decoder.prepare_inputs_for_generation

model_kwargs = {
    "encoder_hidden_states":whisper_trt_encoder(inputs['input_features']),
    "stopping_criteria":stopping_criteria,
    "logits_processor":logits_processor,
    "use_cache":metadata.other.kv_cache,
    "use_cuda":True
}
decoder_initial_input = torch.full(
    (batch_size, 1), 50257, dtype=torch.int32
).to('cuda')

input_ids=decoder_initial_input

In [None]:
model_inputs = whisper_trt_decoder.prepare_inputs_for_generation(input_ids, **model_kwargs)


In [None]:
outputs = whisper_trt_decoder(
    **model_inputs,
    return_dict=True,
    output_attentions=output_attentions,
    output_hidden_states=whisper_trt_encoder(inputs['input_features']),
)

In [None]:
outputs.logits.shape

In [None]:
next_token_logits = outputs.logits[:, -1, :]
next_tokens_scores = logits_processor(input_ids, next_token_logits)
next_tokens = torch.argmax(next_tokens_scores, dim=-1)
tokenizer.decode(next_tokens)

In [None]:
next_token_logits

In [None]:
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if max_length is not None:
    warnings.warn(
        "`max_length` is deprecated in this function, use"
        " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
        UserWarning,
    )
    stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
    output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict_in_generate = (
    return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)

# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
    encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
    encoder_hidden_states = (
        model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
    )

# keep track of which sequences are already finished
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

this_peer_finished = False  # used by synced_gpus only
while True:
    if synced_gpus:
        # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
        # The following logic allows an early break if all peers finished generating their sequence
        this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
        # send 0.0 if we finished, 1.0 otherwise
        dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
        # did all peers finish? the reduced sum will be 0.0 then
        if this_peer_finished_flag.item() == 0.0:
            break

    # prepare model inputs
    model_inputs = whisper_trt_decoder.prepare_inputs_for_generation(input_ids, **model_kwargs)

    # forward pass to get next token
    outputs = whisper_trt_decoder(
        **model_inputs,
        return_dict=True,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
    )

    if synced_gpus and this_peer_finished:
        continue  # don't waste resources running the code we don't need

    next_token_logits = outputs.logits[:, -1, :]
    print(outputs)

    # pre-process distribution
    next_tokens_scores = logits_processor(input_ids, next_toke_logits)

    # Store scores, attentions and hidden_states when required
    if return_dict_in_generate:
        if output_scores:
            scores += (next_tokens_scores,)
        if output_attentions:
            decoder_attentions += (
                (outputs.decoder_attentions,) if whisper_trt_decoder.config.is_encoder_decoder else (outputs.attentions,)
            )
            if whisper_trt_decoder.config.is_encoder_decoder:
                cross_attentions += (outputs.cross_attentions,)

        if output_hidden_states:
            decoder_hidden_states += (
                (outputs.decoder_hidden_states,)
                if whisper_trt_decoder.config.is_encoder_decoder
                else (outputs.hidden_states,)
            )
    print(1)

    # argmax
    next_tokens = torch.argmax(next_tokens_scores, dim=-1)
    print(tokenizer.decode(next_tokens))

    # finished sentences should have their next token be a padding token
    if eos_token_id is not None:
        if pad_token_id is None:
            raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
        next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
    print(2)

    # update generated ids, model inputs, and length for next step
    input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
    model_kwargs = whisper_trt_decoder._update_model_kwargs_for_generation(
        outputs, model_kwargs, is_encoder_decoder=whisper_trt_decoder.config.is_encoder_decoder
    )
    print(3)

    # if eos_token was found in one sentence, set sentence to finished
    if eos_token_id is not None:
        unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
    print(4)

    # stop when each sentence is finished, or if we exceed the maximum length
    if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
        if not synced_gpus:
            print(synced_gpus)
            print(5)

            break
        else:
            this_peer_finished = True
    #print(5)


In [None]:
model_inputs['input_ids']

In [None]:
tokenizer.decode(input_ids[0])

In [None]:
whisper_trt_decoder(
    **model_inputs,
    return_dict=True,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
)

In [None]:
tokenizer.decode([50257])

In [None]:
model_inputs

In [None]:
(next_tokens != eos_token_id).long()

In [None]:
unfinished_sequences

In [None]:
synced_gpus

In [None]:
eos_token_id

In [None]:
unfinished_sequences.mul((next_tokens != eos_token_id).long())

In [None]:
# GreedySearchEncoderDecoderOutput(
#     sequences=input_ids,
#     scores=scores,
#     encoder_attentions=encoder_attentions,
#     encoder_hidden_states=encoder_hidden_states,
#     decoder_attentions=decoder_attentions,
#     cross_attentions=cross_attentions,
#     decoder_hidden_states=decoder_hidden_states,
# )

In [None]:
decoder_hidden_states

In [None]:
input_ids

In [None]:

# if return_dict_in_generate:
#     if whisper_trt_decoder.config.is_encoder_decoder:
#         return GreedySearchEncoderDecoderOutput(
#             sequences=input_ids,
#             scores=scores,
#             encoder_attentions=encoder_attentions,
#             encoder_hidden_states=encoder_hidden_states,
#             decoder_attentions=decoder_attentions,
#             cross_attentions=cross_attentions,
#             decoder_hidden_states=decoder_hidden_states,
#         )
#     else:
#         return GreedySearchDecoderOnlyOutput(
#             sequences=input_ids,
#             scores=scores,
#             attentions=decoder_attentions,
#             hidden_states=decoder_hidden_states,
#         )
# else:
#     return input_ids

In [None]:
encoder_last_hidden_states = expand_inputs_for_beam_search(encoder_outputs.last_hidden_state, expand_size=num_beams)

In [None]:
whisper_trt_decoder.set_encoder_hidden_states_for_inference_cycle(encoder_last_hidden_states)

In [None]:
decoder_output = whisper_trt_decoder.greedy_search(
    input_ids=decoder_initial_input,
    encoder_hidden_states=encoder_outputs.last_hidden_state,
    stopping_criteria=stopping_criteria,
    logits_processor=logits_processor,
    use_cache=metadata.other.kv_cache,
    use_cuda=True
)

In [None]:
outputs_trt = tokenizer.decode(output_ids[0], skip_special_tokens=True)

In [None]:
output_ids

In [None]:
tokenizer.bos_token_id

In [None]:
encoder_outputs.last_hidden_state.shape

In [None]:
encoder_last_hidden_states.shape

In [None]:
outputs_trt

In [136]:
# FP16
def e2e_trt_fp16():
    with torch.no_grad():
        encoder_last_hidden_states = whisper_trt_encoder_fp16(input_features=input_features)
        
        if num_beams > 1:
            # prepare input for beam search
            encoder_last_hidden_states = expand_inputs_for_beam_search(encoder_last_hidden_states, expand_size=num_beams)
            
            # beam scorer must be reset before each beam search run, otherwise beam search will be skipped due to scorer cache
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=num_beams,
                device="cuda",
                do_early_stopping=True,
            )
        
        whisper_trt_decoder_fp16.set_encoder_hidden_states_for_inference_cycle(encoder_last_hidden_states) 
        
        if num_beams == 1:
            decoder_output = whisper_trt_decoder_fp16.greedy_search(
                input_ids=decoder_initial_input,
                encoder_hidden_states=encoder_last_hidden_states,
                stopping_criteria=stopping_criteria,
                logits_processor=logits_processor,
                use_cache=metadata.other.kv_cache,
                use_cuda=True
            )
        else:
            decoder_output = whisper_trt_decoder_fp16.beam_search(
                input_ids=decoder_initial_input,
                beam_scorer=beam_scorer,
                encoder_hidden_states=encoder_last_hidden_states,
                stopping_criteria=stopping_criteria,
                logits_processor=logits_processor,
                use_cache=metadata.other.kv_cache,
                use_cuda=True
            )
    return decoder_output

output_ids_fp16 = e2e_trt_fp16()
outputs_trt_fp16 = tokenizer.decode(output_ids_fp16[0], skip_special_tokens=True)
trt_time_fp16 = measure_python_inference_code(e2e_trt_fp16, timing_profile)

In [137]:
decoder_output = whisper_trt_decoder_fp16.greedy_search(
    input_ids=decoder_initial_input,
    encoder_hidden_states=encoder_last_hidden_states,
    stopping_criteria=stopping_criteria,
    logits_processor=logits_processor,
    use_cache=metadata.other.kv_cache,
    use_cuda=True
)

NameError: name 'encoder_last_hidden_states' is not defined

In [None]:
tokenizer.decode(decoder_output[0], skip_special_tokens=False)


In [None]:
tokenizer.decode(output_ids_fp16[0], skip_special_tokens=False)

In [None]:
outputs_trt_fp16

In [138]:
trt_time_fp16

[0.12557019293308258, 0.3864895810838789]

In [139]:
# print results and timing statistics
print(f'Device: {torch.cuda.get_device_name()}')
print(f"Using engine: {metadata_string + '-' + engine_tag}")
print(f'Output identical to HF results? {outputs_trt == outputs_hf}')
print(f"Precision: FP32")
print(f'TRT time: {trt_time}')
print()
print(f"Using engine: {metadata_string_fp16 + '-' + engine_tag}")
print(f'Output identical to HF results? {outputs_trt_fp16 == outputs_hf}')
print(f"Precision: FP16")
print(f'TRT time: {trt_time_fp16}')

Device: NVIDIA A100-SXM4-80GB
Using engine: Whisper-tiny-bs1
Output identical to HF results? True
Precision: FP32
TRT time: [0.036946029867976904, 0.0820648071821779]

Using engine: Whisper-tiny-fp16-bs1
Output identical to HF results? False
Precision: FP16
TRT time: [0.12557019293308258, 0.3864895810838789]


In [140]:
%%time
a, decoder_trt_time = w_decoder_inference(whisper_trt_decoder, expand_inputs_for_beam_search(input_ids, num_beams) if num_beams > 1 else input_ids, expand_inputs_for_beam_search(encoder_last_hidden_state, num_beams) if num_beams > 1 else encoder_last_hidden_state, timing_profile)



CPU times: user 17.7 ms, sys: 489 µs, total: 18.2 ms
Wall time: 17.9 ms


In [141]:
decoder_trt_time

[0.0010103480890393257, 0.0036271950230002403]

### Time Measurement of Encoder, Decoder, and Full E2E
We will benchmark the encoder, decoder, and full end-to-end as we did for HuggingFace before.

In [142]:
# encoder-decoder inference 
whisper_model.float()
whisper_model = whisper_model.cuda()

input_features = input_features.float().cuda()

with torch.no_grad():
    output_ids = whisper_model.generate(input_features, max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=False)    
    outputs = tokenizer.decode(output_ids[-1,:], skip_special_tokens=True)    
outputs_hf = outputs

# timing
# FP32
input_features = input_features.float().cuda()
hf_nonkv_time = measure_python_inference_code(lambda: whisper_model.generate(input_features, max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=False), timing_profile)
hf_kv_time = measure_python_inference_code(lambda: whisper_model.generate(input_features, max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=True), timing_profile)

# FP16, cuda 11.4 has cublas error that will fail in both cpu or cpu model for BART
hf_nonkv_time_fp16 = measure_python_inference_code(lambda: whisper_model.generate(input_features, max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=False), timing_profile)
hf_kv_time_fp16 = measure_python_inference_code(lambda: whisper_model.generate(input_features, max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=True), timing_profile)

In [143]:
%%time
decoder_outputs = whisper_model.model.decoder(input_ids=input_ids.cuda(), encoder_hidden_states=whisper_model.get_encoder()(input_features)['last_hidden_state'])

CPU times: user 3.9 ms, sys: 3.99 ms, total: 7.89 ms
Wall time: 7.45 ms


In [144]:
%%time
#encoder_last_hidden_states, encoder_trt_time = w_encoder_inference(whisper_trt_encoder, input_features, timing_profile)
decoder_outputs2 = whisper_model.model.decoder(input_ids=input_ids.cuda(), encoder_hidden_states=encoder_last_hidden_states)

NameError: name 'encoder_last_hidden_states' is not defined

In [145]:
#dir(encoder_last_hidden_states[0][0][0].cpu())

In [146]:
#encoder_last_hidden_states[0][0]

In [147]:
# FP32
encoder_last_hidden_states, encoder_trt_time = w_encoder_inference(whisper_trt_encoder, input_features, timing_profile)
_, decoder_trt_time = w_decoder_inference(whisper_trt_decoder, expand_inputs_for_beam_search(input_ids, num_beams) if num_beams > 1 else input_ids, expand_inputs_for_beam_search(encoder_last_hidden_states, num_beams) if num_beams > 1 else encoder_last_hidden_states, timing_profile)

if num_beams == 1:
    _, full_trt_time = full_inference_greedy(
        whisper_trt_encoder,
        whisper_trt_decoder,
        input_features,
        tokenizer,
        timing_profile,
        max_length=max_output_len,
        min_length=WhisperModelTRTConfig.MIN_OUTPUT_LENGTH[metadata.variant],
        batch_size=batch_size,
        use_cache=metadata.other.kv_cache,
    )
else:
    _, full_trt_time = full_inference_beam(
        whisper_trt_encoder,
        whisper_trt_decoder,
        input_ids,
        tokenizer,
        timing_profile,
        num_beams=num_beams,
        max_length=max_output_len,
        min_length=WhisperModelTRTConfig.MIN_OUTPUT_LENGTH[metadata.variant],
        batch_size=batch_size,
        use_cache=metadata.other.kv_cache,
        early_stopping=True,
    )
    
print(f'Encoder time: {percentile_print(encoder_trt_time)}')
print(f'Decoder time: {percentile_print(decoder_trt_time)}')
print(f'Full E2E time: {percentile_print(full_trt_time)}')

# FP16
encoder_last_hidden_states, encoder_trt_time_fp16 = w_encoder_inference(whisper_trt_encoder_fp16, input_features, timing_profile)
_, decoder_trt_time_fp16 = w_decoder_inference(whisper_trt_decoder_fp16, expand_inputs_for_beam_search(input_ids, num_beams) if num_beams > 1 else input_ids, expand_inputs_for_beam_search(encoder_last_hidden_states, num_beams) if num_beams > 1 else encoder_last_hidden_states, timing_profile)

if num_beams == 1:
    _, full_trt_time_fp16 = full_inference_greedy(
        whisper_trt_encoder_fp16,
        whisper_trt_decoder_fp16,
        input_features,
        tokenizer,
        timing_profile,
        max_length=max_output_len,
        min_length=WhisperModelTRTConfig.MIN_OUTPUT_LENGTH[metadata.variant],
        batch_size=batch_size,
        use_cache=metadata.other.kv_cache,
    )
else:
    _, full_trt_time_fp16 = full_inference_beam(
        whisper_trt_encoder_fp16,
        whisper_trt_decoder_fp16,
        input_ids,
        tokenizer,
        timing_profile,
        num_beams=num_beams,
        max_length=max_output_len,
        min_length=WhisperModelTRTConfig.MIN_OUTPUT_LENGTH[metadata.variant],
        batch_size=batch_size,
        use_cache=metadata.other.kv_cache,
        early_stopping=True,
    )
print(f'Encoder FP16 time: {percentile_print(encoder_trt_time_fp16)}')
print(f'Decoder FP16 time: {percentile_print(decoder_trt_time_fp16)}')
print(f'Full E2E FP16 time: {percentile_print(full_trt_time_fp16)}')

Encoder time: p50 1.48ms, p99 1.49ms
Decoder time: p50 0.99ms, p99 1.01ms
Full E2E time: p50 144.67ms, p99 147.17ms
Encoder FP16 time: p50 1.49ms, p99 1.49ms
Decoder FP16 time: p50 0.87ms, p99 0.88ms
Full E2E FP16 time: p50 110.37ms, p99 123.43ms


In [148]:
from tabulate import tabulate

data = [
    ['Framework', 'Precision', 'Encoder p50 (ms)', 'Decoder p50 (ms)', 'Full E2E p50 (ms)', 'Accuracy'],
    ['HuggingFace (w/o cache)', 'FP32', '-', '-', f'{hf_nonkv_time[0]*1000:.2f}', '-'],
    ['HuggingFace (w/ cache)', 'FP32', '-', '-', f'{hf_kv_time[0]*1000:.2f}', '-'],
    ['HuggingFace (w/o cache)', 'FP16', '-', '-', f'{hf_nonkv_time_fp16[0]*1000:.2f}', '-'],
    ['HuggingFace (w/ cache)', 'FP16', '-', '-', f'{hf_kv_time_fp16[0]*1000:.2f}', '-'],
    ['PyTorch', 'FP32', f'{encoder_pytorch_time[0]*1000:.2f}', f'{decoder_pytorch_time[0]*1000:.2f}', f'{full_pytorch_time[0]*1000:.2f}', outputs_pytorch == outputs_hf],
    ['PyTorch', 'FP16', f'{encoder_pytorch_time_fp16[0]*1000:.2f}', f'{decoder_pytorch_time_fp16[0]*1000:.2f}', f'{full_pytorch_time_fp16[0]*1000:.2f}', outputs_pytorch_fp16 == outputs_hf],
    ['TensorRT', 'FP32', f'{encoder_trt_time[0]*1000:.2f}', f'{decoder_trt_time[0]*1000:.2f}', f'{full_trt_time[0]*1000:.2f}', outputs_trt == outputs_hf],
    ['TensorRT', 'FP16', f'{encoder_trt_time_fp16[0]*1000:.2f}', f'{decoder_trt_time_fp16[0]*1000:.2f}', f'{full_trt_time_fp16[0]*1000:.2f}', outputs_trt_fp16 == outputs_hf],
]

print(tabulate(data, headers='firstrow', tablefmt='github'))

| Framework               | Precision   | Encoder p50 (ms)   | Decoder p50 (ms)   |   Full E2E p50 (ms) | Accuracy   |
|-------------------------|-------------|--------------------|--------------------|---------------------|------------|
| HuggingFace (w/o cache) | FP32        | -                  | -                  |              100.22 | -          |
| HuggingFace (w/ cache)  | FP32        | -                  | -                  |               73.68 | -          |
| HuggingFace (w/o cache) | FP16        | -                  | -                  |               97.69 | -          |
| HuggingFace (w/ cache)  | FP16        | -                  | -                  |               73    | -          |
| PyTorch                 | FP32        | 2.19               | 3.22               |              390.05 | False      |
| PyTorch                 | FP16        | 3.95               | 3.35               |              293.63 | False      |
| TensorRT                | FP32        | 1.48  

In [160]:
outputs_trt

''

In [None]:
outputs_trt

In [None]:
%%time
from faster_whisper import WhisperModel

model_size = "tiny"

# Run on GPU with FP16
model = WhisperModel(model_size, device="cuda", compute_type="float16")

# or run on GPU with INT8
# model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
# or run on CPU with INT8
# model = WhisperModel(model_size, device="cpu", compute_type="int8")

segments, info = model.transcribe("korean_news.mp4", beam_size=5)

print("Detected language '%s' with probability %f" % (info.language, info.language_probability))

for segment in segments:
    print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))

In [None]:
import numpy as np

In [None]:
dir(model.model)

In [None]:
import io
import itertools

from typing import BinaryIO, Union

import av
import numpy as np
def decode_audio(
    input_file: Union[str, BinaryIO],
    sampling_rate: int = 16000,
    split_stereo: bool = False,
):
    """Decodes the audio.

    Args:
      input_file: Path to the input file or a file-like object.
      sampling_rate: Resample the audio to this sample rate.
      split_stereo: Return separate left and right channels.

    Returns:
      A float32 Numpy array.

      If `split_stereo` is enabled, the function returns a 2-tuple with the
      separated left and right channels.
    """
    resampler = av.audio.resampler.AudioResampler(
        format="s16",
        layout="mono" if not split_stereo else "stereo",
        rate=sampling_rate,
    )

    raw_buffer = io.BytesIO()
    dtype = None

    with av.open(input_file, metadata_errors="ignore") as container:
        frames = container.decode(audio=0)
        frames = _ignore_invalid_frames(frames)
        frames = _group_frames(frames, 500000)
        frames = _resample_frames(frames, resampler)

        for frame in frames:
            array = frame.to_ndarray()
            dtype = array.dtype
            raw_buffer.write(array)

    audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype)

    # Convert s16 back to f32.
    audio = audio.astype(np.float32) / 32768.0

    if split_stereo:
        left_channel = audio[0::2]
        right_channel = audio[1::2]
        return left_channel, right_channel

    return audio

def _ignore_invalid_frames(frames):
    iterator = iter(frames)

    while True:
        try:
            yield next(iterator)
        except StopIteration:
            break
        except av.error.InvalidDataError:
            continue


def _group_frames(frames, num_samples=None):
    fifo = av.audio.fifo.AudioFifo()

    for frame in frames:
        frame.pts = None  # Ignore timestamp check.
        fifo.write(frame)

        if num_samples is not None and fifo.samples >= num_samples:
            yield fifo.read()

    if fifo.samples > 0:
        yield fifo.read()


def _resample_frames(frames, resampler):
    # Add None to flush the resampler.
    for frame in itertools.chain(frames, [None]):
        yield from resampler.resample(frame)

In [None]:
from faster_whisper import feature_extractor


In [None]:
audio=decode_audio("korean_news.mp4")
duration = audio.shape[0] / 16000
inputs = processor(audio, return_tensors="pt")

In [None]:
%%time
encoder_last_hidden_states, encoder_trt_time = w_encoder_inference(whisper_trt_encoder, inputs['input_features'], timing_profile)


In [None]:
encoder_last_hidden_states.shape

In [None]:
a = whisper_model.model.encoder(inputs['input_features'].cuda())

In [None]:
input_ids.device

In [None]:
wh_de = whisper_model.model.decoder(input_ids.cuda())

In [None]:
wh_de[0].shape

In [None]:
%%time

result, full_trt_time = full_inference_greedy(
        whisper_trt_encoder,
        whisper_trt_decoder,
        input_features,
        tokenizer,
        timing_profile,
        max_length=max_output_len,
        min_length=WhisperModelTRTConfig.MIN_OUTPUT_LENGTH[metadata.variant],
        batch_size=batch_size,
        use_cache=metadata.other.kv_cache,
    )

In [None]:
metadata.other.kv_cache

In [None]:
from NNDF.tensorrt_utils import TRTNativeRunner


In [None]:
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_output_len)])
no_repeat_ngram_size = WhisperModelTRTConfig.NO_REPEAT_NGRAM_SIZE
logits_processor = LogitsProcessorList(
    [
        NoRepeatNGramLogitsProcessor(no_repeat_ngram_size),
        MinLengthLogitsProcessor(
            min_length, tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
        ),
        ForcedBOSTokenLogitsProcessor(
            50258
        ),
        ForcedEOSTokenLogitsProcessor(
            max_output_len, tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
        ),
    ]
)  # by checking HuggingFace's generate() implementation carefully, the default logits processor for BART has no_repeat_ngram_size = 3 and forced_eos_token_id = 2. In this way we can get identical results with raw HuggingFace

decoder_input_ids = torch.full(
    (batch_size, 1),
    tokenizer.convert_tokens_to_ids(tokenizer.eos_token),
    dtype=torch.int32,
)

if False:
    decoder_input_ids = decoder_input_ids.to("cuda")
else:
    decoder_input_ids = decoder_input_ids.to("cpu")

def _e2e():
    with torch.no_grad():
        encoder_last_hidden_state = whisper_trt_encoder(input_features=input_features)
        decoder_output_greedy = whisper_trt_decoder.greedy_search(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_last_hidden_state,
            stopping_criteria=stopping_criteria,
            logits_processor=logits_processor,
            use_cache=False,
        )


In [None]:
def _e2e_trt():
        with torch.no_grad():
            encoder_last_hidden_state = whisper_trt_encoder(input_features=input_features)
            whisper_trt_decoder.set_encoder_hidden_states_for_inference_cycle(
                encoder_last_hidden_state
            )
            decoder_output_greedy = whisper_trt_decoder.greedy_search(
                input_ids=decoder_input_ids,
                encoder_hidden_states=encoder_last_hidden_state,
                stopping_criteria=stopping_criteria,
                logits_processor=logits_processor,
                use_cache=False,
            )
        return decoder_output_greedy

In [None]:
measurement_function = _e2e
if isinstance(whisper_trt_decoder, TRTNativeRunner):
    whisper_trt_decoder.set_return_device("cuda" if False else "cpu")
    measurement_function = _e2e_trt

full_e2e_time = measure_python_inference_code(measurement_function, timing_profile)



In [None]:
#return (measurement_function(), full_e2e_time)

In [None]:
tokenizer.decode(50257)

In [None]:
whisper_model._get_decoder_start_token_id()

In [None]:
tokenizer.convert_tokens_to_ids(tokenizer.eos_token)