In [1]:
# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
## Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
%load_ext autoreload
%autoreload 2

<img src="http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png" style="width: 90px; float: right;">

# Accelerating HuggingFace Whisper Inference with TensorRT

Whisper is an encoder-decoder model that converts ASR problems into a speech-to-text format. More specifically, it does so by encoding speech in the input stream. This enables a single model to be trained supervised on a wide variety of Language

This notebook shows 3 easy steps to convert a [HuggingFace PyTorch Whisper model](https://huggingface.co/transformers/model_doc/whisper.html) to a TensorRT engine for high-performance inference.

1. [Download HuggingFace whisper model](#1)
1. [Convert to ONNX format](#2)
1. [Convert to TensorRT engine](#3)

## Prerequisite

Follow the instruction at https://github.com/NVIDIA/TensorRT to build the TensorRT-OSS docker container required to run this notebook.

Next, we install some extra dependencies.

In [2]:
# %%capture
# !pip3 install -r ../requirements.txt

**Note:** After this step, you should restart the Jupyter kernel for the change to take effect.

In [3]:
import os
import sys
ROOT_DIR = os.path.abspath("../")
sys.path.append(ROOT_DIR)

import torch
import tensorrt as trt

# huggingface
from transformers import (
    WhisperProcessor, 
    WhisperForConditionalGeneration,
    WhisperTokenizer,
    WhisperConfig
)

<a id="1"></a>

## 1. Download HuggingFace T5 model and Whisper model

First, we download the original HuggingFace PyTorch T5 model from HuggingFace model hubs, together with its associated tokernizer.

The T5 variants that are suported by TensorRT 8 are:  t5-small (60M), t5-base (220M), t5-large (770M), t5-3b(3B), t5-11b(11B)

In [4]:
import torch
from datasets import load_dataset

Whisper_VARIANT = "openai/whisper-large-v2"    # choices: openai/whisper-tiny | openai/whisper-base | openai/whisper-small | openai/whisper-medium | openai/whisper-large-v2

processor = WhisperProcessor.from_pretrained(Whisper_VARIANT)
whisper_model = WhisperForConditionalGeneration.from_pretrained(Whisper_VARIANT)
wh_config = WhisperConfig.from_pretrained(Whisper_VARIANT, use_cache = False)

In [5]:
tokenizer=processor.tokenizer

In [6]:
# save model locally
pytorch_model_dir = './models/{}/pytorch'.format(Whisper_VARIANT)
!mkdir -p $pytorch_model_dir

whisper_model.save_pretrained(pytorch_model_dir)
print("Pytorch Model saved to {}".format(pytorch_model_dir))

Pytorch Model saved to ./models/openai/whisper-large-v2/pytorch


# Encoder output이 다름!!

In [7]:
import io
import itertools

from typing import BinaryIO, Union

import av
import numpy as np
def decode_audio(
    input_file: Union[str, BinaryIO],
    sampling_rate: int = 16000,
    split_stereo: bool = False,
):
    """Decodes the audio.

    Args:
      input_file: Path to the input file or a file-like object.
      sampling_rate: Resample the audio to this sample rate.
      split_stereo: Return separate left and right channels.

    Returns:
      A float32 Numpy array.

      If `split_stereo` is enabled, the function returns a 2-tuple with the
      separated left and right channels.
    """
    resampler = av.audio.resampler.AudioResampler(
        format="s16",
        layout="mono" if not split_stereo else "stereo",
        rate=sampling_rate,
    )

    raw_buffer = io.BytesIO()
    dtype = None

    with av.open(input_file, metadata_errors="ignore") as container:
        frames = container.decode(audio=0)
        frames = _ignore_invalid_frames(frames)
        frames = _group_frames(frames, 500000)
        frames = _resample_frames(frames, resampler)

        for frame in frames:
            array = frame.to_ndarray()
            dtype = array.dtype
            raw_buffer.write(array)

    audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype)

    # Convert s16 back to f32.
    audio = audio.astype(np.float32) / 32768.0

    if split_stereo:
        left_channel = audio[0::2]
        right_channel = audio[1::2]
        return left_channel, right_channel

    return audio

def _ignore_invalid_frames(frames):
    iterator = iter(frames)

    while True:
        try:
            yield next(iterator)
        except StopIteration:
            break
        except av.error.InvalidDataError:
            continue


def _group_frames(frames, num_samples=None):
    fifo = av.audio.fifo.AudioFifo()

    for frame in frames:
        frame.pts = None  # Ignore timestamp check.
        fifo.write(frame)

        if num_samples is not None and fifo.samples >= num_samples:
            yield fifo.read()

    if fifo.samples > 0:
        yield fifo.read()


def _resample_frames(frames, resampler):
    # Add None to flush the resampler.
    for frame in itertools.chain(frames, [None]):
        yield from resampler.resample(frame)

In [8]:
audio=decode_audio("korean_news.mp4")
duration = audio.shape[0] / 16000
inputs = processor(audio, return_tensors="pt")

In [9]:
encoder_outputs = whisper_model.get_encoder()(inputs['input_features'])

In [10]:
encoder_outputs

BaseModelOutput(last_hidden_state=tensor([[[-0.3705,  0.0469, -0.2255,  ..., -3.7427, -0.1800,  0.1537],
         [-0.3257,  0.2098, -0.0563,  ..., -3.8935,  0.1012,  0.0075],
         [-0.0417,  0.3738,  0.3748,  ..., -3.8021,  0.2297, -0.1170],
         ...,
         [ 0.3712,  0.1128, -0.5410,  ...,  0.4811,  0.7378,  1.6581],
         [ 0.4409,  0.5879, -0.0911,  ...,  0.1291, -0.2723,  0.6060],
         [ 0.1199,  0.5981,  0.6376,  ..., -0.2903, -0.0964,  0.5309]]],
       grad_fn=<NativeLayerNormBackward0>), hidden_states=None, attentions=None)

In [11]:
decoder_start_token_id = whisper_model._get_decoder_start_token_id(None, 50258)
input_ids  = torch.ones((1, 1), dtype=torch.long, device='cuda') * decoder_start_token_id

In [12]:
whisper_model.get_decoder().max_source_positions

1500

In [13]:
whisper_model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ko", task="transcribe", no_timestamps=True)

In [14]:
whisper_model.float().cuda(1)
hf_out = whisper_model.generate(inputs['input_features'].cuda(1))



In [15]:
tokenizer.batch_decode(hf_out)

['<|startoftranscript|><|ko|><|transcribe|><|notimestamps|> 제6호 태풍 카누는 여전히 매우 강한 세력을 유지한 채 북서지나고 있습니다. 하지만 이동 속도가 점점 느려져 거의 정체하는 모습입니다. 태풍은 동중국해에 머물다 동쪽으로 방향을 급격히 틀어 이동할 걸로 보입니다. 속도도 조금씩 빨라지며 다음 주 초중반에는 일본 규슈 남쪽 해상까지 진출하겠습니다. 세력도 크게 약화하지는 않을 전망입니다.<|endoftext|>']

In [16]:
decoder_outputs = whisper_model.model.decoder(input_ids=input_ids.cuda(1), encoder_hidden_states=encoder_outputs['last_hidden_state'].cuda(1))

In [17]:
hf_output = whisper_model.proj_out(decoder_outputs.last_hidden_state)

In [18]:
# # legacy: users may modify the model configuration to control generation -- update the generation config
# # model attribute accordingly, if it was created from the model config
# if self.generation_config._from_model_config:
#     new_generation_config = GenerationConfig.from_model_config(self.config)
#     if new_generation_config != self.generation_config:
#         warnings.warn(
#             "You have modified the pretrained model configuration to control generation. This is a"
#             " deprecated strategy to control generation and will be removed soon, in a future version."
#             " Please use a generation configuration file (see"
#             " https://huggingface.co/docs/transformers/main_classes/text_generation )"
#         )
#         self.generation_config = new_generation_config
# generation_config = self.generation_config


In [19]:
from transformers.generation_logits_process import (
    NoRepeatNGramLogitsProcessor,
    MinLengthLogitsProcessor,
    LogitsProcessorList,
    SuppressTokensAtBeginLogitsProcessor,
    SuppressTokensLogitsProcessor,
    ForceTokensLogitsProcessor,
)

from transformers.generation_stopping_criteria import (
    MaxLengthCriteria,
    StoppingCriteriaList,
)

In [20]:
whisper_model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ko", task="transcribe", no_timestamps=True)

In [21]:
bos_token_id = tokenizer.bos_token_id
num_beams = whisper_model.config.num_beams
length_penalty = whisper_model.config.length_penalty
early_stopping = whisper_model.config.early_stopping
num_beam_groups = whisper_model.config.num_beam_groups
do_sample = whisper_model.config.do_sample
num_return_sequences = whisper_model.config.num_return_sequences
pad_token_id = tokenizer.pad_token_id
eos_token_id = tokenizer.eos_token_id

In [22]:
#inputs_tensor, model_input_name, model_kwargs = whisper_model._prepare_model_inputs(inputs, bos_token_id, model_kwargs)


In [23]:
stopping_criteria = whisper_model._get_stopping_criteria(
    max_length=whisper_model.config.max_length, max_time=None, stopping_criteria=StoppingCriteriaList()
)

In [24]:
input_ids_seq_length = input_ids.shape[-1]


begin_index = input_ids_seq_length
begin_index = begin_index if (input_ids_seq_length > 1 or whisper_model.config.forced_bos_token_id is None) else begin_index + 1
if whisper_model.config.forced_bos_token_id is not None:
    begin_index += whisper_model.config.forced_bos_token_id[-1][0]  # generation starts after the last token that is forced


In [25]:
logits_processor = LogitsProcessorList([
    SuppressTokensLogitsProcessor(whisper_model.config.suppress_tokens),
    SuppressTokensAtBeginLogitsProcessor(whisper_model.config.begin_suppress_tokens, begin_index), 
    ForceTokensLogitsProcessor (processor.get_decoder_prompt_ids(language="ko", task="transcribe", no_timestamps=True))
])

In [26]:
%%time
# greedy_search(
#     input_ids,
#     logits_processor=logits_processor,
#     stopping_criteria=stopping_criteria,
#     pad_token_id=pad_token_id,
#     eos_token_id=eos_token_id,
#     output_scores=output_scores,
#     return_dict_in_generate=return_dict_in_generate,
#     synced_gpus=synced_gpus,
#     **model_kwargs,
# )
decoder_output = whisper_model.greedy_search(
    input_ids=input_ids.cuda(1),
    encoder_outputs=encoder_outputs.last_hidden_state.cuda(1),
   # stopping_criteria=stopping_criteria,
    logits_processor=logits_processor,
    stopping_criteria=stopping_criteria,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    use_cache=False,
)

CPU times: user 6.66 s, sys: 268 ms, total: 6.93 s
Wall time: 6.93 s


In [27]:
tokenizer.batch_decode(decoder_output)

['<|startoftranscript|><|ko|><|transcribe|><|notimestamps|> 제6호 태풍 카누는 여전히 매우 강한 세력을 유지한 채 북서지나고 있습니다. 하지만 이동 속도가 점점 느려져 거의 정체하는 모습입니다. 태풍은 동중국해에 머물다 동쪽으로 방향을 급격히 틀어 이동할 걸로 보입니다. 속도도 조금씩 빨라지며 다음 주 초중반에는 일본 규슈 남쪽 해상까지 진출하겠습니다. 세력도 크게 약화하지는 않을 전망입니다.<|endoftext|>']

In [28]:
processor.batch_decode(hf_out)

['<|startoftranscript|><|ko|><|transcribe|><|notimestamps|> 제6호 태풍 카누는 여전히 매우 강한 세력을 유지한 채 북서지나고 있습니다. 하지만 이동 속도가 점점 느려져 거의 정체하는 모습입니다. 태풍은 동중국해에 머물다 동쪽으로 방향을 급격히 틀어 이동할 걸로 보입니다. 속도도 조금씩 빨라지며 다음 주 초중반에는 일본 규슈 남쪽 해상까지 진출하겠습니다. 세력도 크게 약화하지는 않을 전망입니다.<|endoftext|>']

### Inference with PyTorch model

Next, we will carry out inference with the PyTorch model.

#### Single example inference

In [29]:
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

audio_inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
input_features = audio_inputs.input_features

# WAR: Using an ugly representation because cuda 11.4 does not support GPU models due to cublas errors
if "LD_LIBRARY_PATH" in os.environ and "cuda-11.4" in os.environ["LD_LIBRARY_PATH"]:
    whisper_model = whisper_model.cpu()
    input_features = input_features.to('cpu')
else:
    whisper_model = whisper_model.cuda()
    input_features = input_features.to('cuda:1')   

Found cached dataset librispeech_asr_dummy (/home/jisu/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_dummy/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)


In [30]:
#input_features = inputs['input_features'].cuda(1)

In [31]:
whisper_model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe", no_timestamps=True)

In [32]:
whisper_model.cuda(1)
with torch.no_grad():
    generated_ids = whisper_model.generate(inputs=input_features)

transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
transcription
# ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'



' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.'

#### Model inference benchmark: encoder and decoder stacks

For benchmarking purposes, we will employ a helper functions `encoder_inference` and `decoder_inference` which execute the inference repeatedly for the T5 encoder and decoder stacks separately, and measure end to end execution time. Let's take note of this execution time for comparison with TensorRT. 
 
`TimingProfile` is a named tuple that specifies the number of experiments and number of times to call the function per iteration (and number of warm-up calls although it is not used here).

In [33]:
from Whisper.measurements import decoder_inference as w_decoder_inference, encoder_inference as w_encoder_inference, full_inference as w_full_inference, full_inference_greedy, full_inference_beam
from Whisper.export import WhisperEncoderTorchFile, WhisperDecoderTorchFile, WhisperEncoderTRTEngine, WhisperDecoderTRTEngine

from NNDF.networks import TimingProfile
from NNDF.torch_utils import expand_inputs_for_beam_search

In [34]:
whisper_torch_encoder = WhisperEncoderTorchFile.TorchModule(whisper_model.model.encoder)
whisper_torch_decoder = WhisperDecoderTorchFile.TorchModule(
    whisper_model.model.decoder, whisper_model.proj_out, whisper_model.config
)

In [35]:
generated_ids = whisper_model.generate(inputs=input_features)

In [36]:
%%time
input_features = input_features

encoder_last_hidden_state, encoder_e2e_median_time = w_encoder_inference(
    whisper_torch_encoder, input_features, TimingProfile(iterations=10, number=1, warmup=1, duration=0, percentile=50)
)
encoder_e2e_median_time

CPU times: user 2.02 s, sys: 1.07 s, total: 3.09 s
Wall time: 3.09 s


0.18660088896285743

In [37]:
input_ids = torch.tensor([[1, 1]]) * whisper_model.config.decoder_start_token_id


In [38]:
%%time
decoder_output, decoder_e2e_median_time = w_decoder_inference(
    whisper_torch_decoder, input_ids.cuda(), encoder_last_hidden_state, TimingProfile(iterations=10, number=1, warmup=1, duration=0, percentile=50)
)
decoder_e2e_median_time

CPU times: user 709 ms, sys: 90.8 ms, total: 800 ms
Wall time: 800 ms


0.03334461199119687

#### Full model inference and benchmark

Next, we will try the T5 model for the task of translation from English to German.

For benchmarking purposes, we will employ a helper function `full_inference` which executes the inference repeatedly and measures end to end execution time. Let's take note of this execution time for comparison with TensorRT. 

In [39]:
from Whisper.WhisperModelConfig import WhisperModelTRTConfig, WhisperMetadata

In [40]:
import transformers

In [41]:
num_beams = 1
min_output_len =0 
max_output_len = whisper_model.config.max_length
tokenizer = processor.tokenizer
forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe", no_timestamps=True)
whisper_model.config.forced_decoder_ids = forced_decoder_ids

In [42]:
from NNDF.general_utils import measure_python_inference_code
timing_profile = TimingProfile(iterations=10, number=1, warmup=1, duration=0, percentile=[50,99])

def percentile_print(timing):
    return ', '.join(['p{} {:.2f}ms'.format(timing_profile.percentile[i], p*1000) for i,p in enumerate(timing)])
whisper_model = WhisperForConditionalGeneration.from_pretrained(Whisper_VARIANT).cuda(1)
whisper_model.config.forced_decoder_ids = forced_decoder_ids
# encoder-decoder inference 
with torch.no_grad():
    output_ids = whisper_model.generate(input_features, max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=False)
    outputs = processor.tokenizer.decode(output_ids[-1,:], skip_special_tokens=True)
outputs_hf = outputs

# timing
# FP32
whisper_model.float()
hf_nonkv_time = measure_python_inference_code(lambda: whisper_model.generate(input_features, max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=False), timing_profile)
hf_kv_time = measure_python_inference_code(lambda: whisper_model.generate(input_features, max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=True), timing_profile)

# FP16, cuda 11.4 has cublas error that will fail in both cpu or cpu model for Whisper
# if not cuda_114_mode:
whisper_model= whisper_model.half()
hf_nonkv_time_fp16 = measure_python_inference_code(lambda: whisper_model.generate(input_features.half(), max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=False), timing_profile)
hf_kv_time_fp16 = measure_python_inference_code(lambda: whisper_model.generate(input_features.half(), max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=True), timing_profile)

In [43]:
# FP32
HF_KV=True
timing_profile = TimingProfile(iterations=10, number=1, warmup=1, duration=0, percentile=[50,99])
whisper_model.float()
input_features =input_features.float()
whisper_torch_encoder = WhisperEncoderTorchFile.TorchModule(whisper_model.get_encoder())
whisper_torch_decoder = WhisperDecoderTorchFile.TorchModule(whisper_model.get_decoder(), whisper_model.proj_out, whisper_model.config)

with torch.no_grad():
    encoder_last_hidden_state, encoder_pytorch_time = w_encoder_inference(whisper_torch_encoder, input_features, timing_profile)
    _, decoder_pytorch_time = w_decoder_inference(whisper_torch_decoder, expand_inputs_for_beam_search(input_ids, num_beams) if num_beams > 1 else input_ids, expand_inputs_for_beam_search(encoder_last_hidden_state, num_beams) if num_beams > 1 else encoder_last_hidden_state, timing_profile, use_cache=HF_KV)
    if num_beams == 1:
        output_ids, full_pytorch_time = full_inference_greedy(whisper_torch_encoder,whisper_torch_decoder,input_features,tokenizer,timing_profile,max_length=max_output_len, min_length=min_output_len, use_cache=HF_KV, forced_decoder_ids=forced_decoder_ids)
    else:
        output_ids, full_pytorch_time = full_inference_beam(whisper_torch_encoder,whisper_torch_decoder,input_features,tokenizer,timing_profile,num_beams=num_beams,max_length=max_output_len, min_length=min_output_len, use_cache=HF_KV, forced_decoder_ids=forced_decoder_ids)
    outputs = tokenizer.decode(output_ids[0], skip_special_tokens=True)

outputs_pytorch = outputs

# # FP16
# if not cuda_114_mode:
whisper_model.half()
input_features= input_features.half()
whisper_torch_encoder_fp16 = WhisperEncoderTorchFile.TorchModule(whisper_model.get_encoder())
whisper_torch_decoder_fp16 = WhisperDecoderTorchFile.TorchModule(whisper_model.get_decoder(), whisper_model.proj_out, whisper_model.config)

with torch.no_grad():
    encoder_last_hidden_state, encoder_pytorch_time_fp16 = w_encoder_inference(whisper_torch_encoder_fp16, input_features, timing_profile)
    _, decoder_pytorch_time_fp16 = w_decoder_inference(whisper_torch_decoder_fp16, expand_inputs_for_beam_search(input_ids, num_beams) if num_beams > 1 else input_ids, expand_inputs_for_beam_search(encoder_last_hidden_state, num_beams) if num_beams > 1 else encoder_last_hidden_state, timing_profile, use_cache=HF_KV)
    if num_beams == 1:
        output_ids_fp16, full_pytorch_time_fp16 = full_inference_greedy(whisper_torch_encoder_fp16,whisper_torch_decoder_fp16,input_features,tokenizer,timing_profile,max_length=max_output_len, min_length=min_output_len, use_cache=HF_KV, forced_decoder_ids=forced_decoder_ids)
    else:
        output_ids_fp16, full_pytorch_time_fp16 = full_inference_beam(whisper_torch_encoder_fp16,whisper_torch_decoder_fp16,input_features,tokenizer,timing_profile,num_beams=num_beams,max_length=max_output_len, min_length=min_output_len, use_cache=HF_KV, forced_decoder_ids=forced_decoder_ids)
    outputs_fp16 = tokenizer.decode(output_ids_fp16[0], skip_special_tokens=True)    

outputs_pytorch_fp16 = outputs_fp16

In [44]:

decoder_input_ids = torch.full(
    (1, 1),
    whisper_torch_decoder.config.decoder_start_token_id,
    dtype=torch.int32,
)
if forced_decoder_ids is None:
    forced_decoder_ids = whisper_torch_decoder.config.forced_decoder_ids

stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(224)])
logits_processor = LogitsProcessorList(
    [
        SuppressTokensLogitsProcessor(whisper_torch_decoder.config.suppress_tokens),
        SuppressTokensAtBeginLogitsProcessor(
            whisper_torch_decoder.config.begin_suppress_tokens,
            decoder_input_ids.shape[-1],
        ),
        ForceTokensLogitsProcessor(forced_decoder_ids),
    ]
)


decoder_input_ids = decoder_input_ids.to("cuda")

def _e2e():
    with torch.no_grad():
        encoder_last_hidden_state = whisper_torch_encoder(input_features=input_features)
        decoder_output_greedy = whisper_torch_decoder.greedy_search(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_last_hidden_state,
            stopping_criteria=stopping_criteria,
            logits_processor=logits_processor,
            use_cache=use_cache,
        )
    return decoder_output_greedy

# With e2e we can opt to bind inputs only once for hidden states for optimization
def _e2e_trt():
    with torch.no_grad():
        encoder_last_hidden_state = whisper_torch_encoder(input_features=input_features)
        whisper_torch_decoder.set_encoder_hidden_states_for_inference_cycle(
            encoder_last_hidden_state
        )
        decoder_output_greedy = whisper_torch_decoder.greedy_search(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_last_hidden_state,
            stopping_criteria=stopping_criteria,
            logits_processor=logits_processor,
            use_cache=use_cache,
        )
    return decoder_output_greedy

In [45]:
# print
print(f'PyTorch FP32 Output identical to HF results? {outputs_pytorch == outputs_hf}')
print(f'PyTorch FP16 Output identical to HF results? {outputs_pytorch_fp16 == outputs_hf}')
print('\n')      
print(f'Device: {torch.cuda.get_device_name()}')
print(f"Precision: FP32, Number of Beams: {num_beams}")
print(f"Encoder time: {encoder_pytorch_time}")
print(f"Decoder time: {decoder_pytorch_time}")
print(f"Full E2E time: {full_pytorch_time}")
print(f"Precision: FP16, Number of Beams: {num_beams}")
print(f"Encoder time: {encoder_pytorch_time_fp16}")
print(f"Decoder time: {decoder_pytorch_time_fp16}")
print(f"Full E2E time: {full_pytorch_time_fp16}")

PyTorch FP32 Output identical to HF results? True
PyTorch FP16 Output identical to HF results? True


Device: NVIDIA A100-SXM4-80GB
Precision: FP32, Number of Beams: 1
Encoder time: [0.18671592697501183, 0.22230989392846823]
Decoder time: [0.031467140070162714, 0.03270867792889476]
Full E2E time: [0.6680863999063149, 0.669867079006508]
Precision: FP16, Number of Beams: 1
Encoder time: [0.046387549955397844, 0.04661275597754866]
Decoder time: [0.027844374999403954, 0.04649727907963097]
Full E2E time: [0.5367980289738625, 0.5557327539427206]


<a id="2"></a>

## 2. Convert to ONNX

Prior to converting the model to a TensorRT engine, we will first convert the PyTorch model to an intermediate universal format.

ONNX is an open format for machine learning and deep learning models. It allows you to convert deep learning and machine learning models from different frameworks such as TensorFlow, PyTorch, MATLAB, Caffe, and Keras to a single format.

The steps to convert a PyTorch model to TensorRT are as follows:
- Convert the pretrained image segmentation PyTorch model into ONNX.
- Import the ONNX model into TensorRT.
- Apply optimizations and generate an engine.
- Perform inference on the GPU. 

For the Whisper model, we will convert the encoder and decoder seperately.

In [46]:
from NNDF.networks import NetworkMetadata, Precision
TRT_KV = False

wh_onnx_model_path = './models/{}/onnx'.format(Whisper_VARIANT)
!mkdir -p $wh_onnx_model_path

# FP32
whisper_model.float()
metadata = NetworkMetadata(variant=Whisper_VARIANT, precision=Precision(fp16=False), other=WhisperMetadata(kv_cache=TRT_KV))
trt_config = WhisperModelTRTConfig()
metadata_string = trt_config.get_metadata_string(metadata)

wh_encoder_onnx_model_fpath = metadata_string + "-encoder.onnx"
wh_decoder_onnx_model_fpath = metadata_string + "-decoder-with-lm-head.onnx"

# for onnx conversion, ensure model is on CPU and FP32 precision in this step
whisper_torchfile_encoder = WhisperEncoderTorchFile(whisper_model.to('cpu'), metadata)
whisper_torchfile_decoder = WhisperDecoderTorchFile(whisper_model.to('cpu'), metadata)

onnx_whisper_encoder = whisper_torchfile_encoder.as_onnx_model(os.path.join(wh_onnx_model_path, wh_encoder_onnx_model_fpath), force_overwrite=True)
onnx_whisper_decoder = whisper_torchfile_decoder.as_onnx_model(os.path.join(wh_onnx_model_path, wh_decoder_onnx_model_fpath), force_overwrite=True)

# FP16
metadata_fp16 = NetworkMetadata(variant=Whisper_VARIANT, precision=Precision(fp16=True), other=WhisperMetadata(kv_cache=TRT_KV))
trt_config_fp16 = WhisperModelTRTConfig()
metadata_string_fp16 = trt_config_fp16.get_metadata_string(metadata_fp16)

wh_encoder_onnx_model_fpath_fp16 = metadata_string_fp16 + "-encoder.onnx"
wh_decoder_onnx_model_fpath_fp16 = metadata_string_fp16 + "-decoder-with-lm-head.onnx"

# for onnx conversion, ensure model is on CPU and FP32 precision in this step
whisper_torchfile_encoder = WhisperEncoderTorchFile(whisper_model.to('cpu'), metadata)
whisper_torchfile_decoder = WhisperDecoderTorchFile(whisper_model.to('cpu'), metadata)

onnx_whisper_encoder_fp16 = whisper_torchfile_encoder.as_onnx_model(os.path.join(wh_onnx_model_path, wh_encoder_onnx_model_fpath_fp16), force_overwrite=True)
onnx_whisper_decoder_fp16 = whisper_torchfile_decoder.as_onnx_model(os.path.join(wh_onnx_model_path, wh_decoder_onnx_model_fpath_fp16), force_overwrite=True)

  if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
  if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
  if input_shape[-1] > 1:
  mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
  if attention_mask.size() != (bsz, 1, tgt_len, src_len):


<a id="3"></a>

## 3. Convert to TensorRT

Now we are ready to parse the ONNX encoder and decoder models and convert them to optimized TensorRT engines.

Since the models contains dynamic input shapes, we can specify a valid input range with a TensorRT optimization profile.

In [47]:
from Whisper.export import WhisperDecoderONNXFile, WhisperEncoderONNXFile
from polygraphy.backend.trt import Profile
from tensorrt import PreviewFeature

In [48]:
encoder_hidden_size = whisper_model.config.d_model

In [49]:
num_beams=1

In [50]:
wh_tensorrt_model_path = './models/{}/tensorrt'.format(Whisper_VARIANT)
!mkdir -p wh_tensorrt_model_path
# Decoder optimization profiles
batch_size = 1
max_sequence_length = WhisperModelTRTConfig.MAX_SEQUENCE_LENGTH[Whisper_VARIANT]
decoder_profile = Profile()
decoder_profile.add(
    "input_ids",
    min=(batch_size * num_beams, 1),
    opt=(batch_size * num_beams, max_sequence_length // 2),
    max=(batch_size * num_beams, max_sequence_length),
)
decoder_profile.add(
    "encoder_hidden_states",
    min=(batch_size * num_beams, 1, encoder_hidden_size),
    opt=(batch_size * num_beams, 1500, encoder_hidden_size),
    max=(batch_size * num_beams, 1500, encoder_hidden_size),
)

# Encoder optimization profiles
encoder_profile = Profile()
encoder_profile.add(
    "input_features",
    min=(batch_size, 80, 3000),
    opt=(batch_size, 80, 3000),
    max=(batch_size, 80, 3000)
)

disable_preview_dynamic_shapes = False
engine_tag = f"bs{batch_size}"

In [51]:
force_write=False
engine_tag = f"bs{batch_size}"

if num_beams > 1:
    engine_tag += "-beam{}".format(num_beams)

preview_features = [PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
if disable_preview_dynamic_shapes:
    engine_tag += "-noPreviewFasterDynamicShapes"
else:
    preview_features.append(PreviewFeature.FASTER_DYNAMIC_SHAPES_0805)

# FP32
wh_encoder_engine_name = os.path.join(wh_tensorrt_model_path, wh_encoder_onnx_model_fpath) + f"-{engine_tag}.engine".replace(f"-beam{num_beams}", "") # encoder engine not affected by beam search
wh_decoder_engine_name = os.path.join(wh_tensorrt_model_path, wh_decoder_onnx_model_fpath) + f"-{engine_tag}.engine"

if not os.path.exists(wh_encoder_engine_name) or force_write:
    whisper_trt_encoder_engine = WhisperEncoderONNXFile(os.path.join(wh_onnx_model_path, wh_encoder_onnx_model_fpath), metadata).as_trt_engine(
        wh_encoder_engine_name, 
        force_overwrite=force_write,
        profiles=[encoder_profile], 
        preview_features=preview_features
    )
else:
    whisper_trt_encoder_engine = WhisperEncoderTRTEngine(wh_encoder_engine_name, metadata)
    

In [52]:
from Whisper.trt import WhisperTRTEncoder, WhisperTRTDecoder, TRTHFRunner
from transformers import WhisperConfig

In [53]:
trt_config = WhisperConfig.from_pretrained(Whisper_VARIANT, use_cache = metadata.other.kv_cache)

In [54]:
whisper_trt_encoder = WhisperTRTEncoder(whisper_trt_encoder_engine, metadata, trt_config, batch_size=batch_size)

[09/06/2023-10:09:38] [TRT] [W] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage and speed up TensorRT initialization. See "Lazy Loading" section of CUDA documentation https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#lazy-loading


In [55]:
%time
whisper_trt_encoder(input_features=inputs['input_features'])

CPU times: user 5 µs, sys: 0 ns, total: 5 µs
Wall time: 14.5 µs


tensor([[[-0.3555, -0.4582, -0.3098,  ..., -3.4159, -0.5448,  0.0769],
         [ 0.0394, -0.2632,  0.2806,  ..., -3.0623, -0.5295,  0.1974],
         [ 0.7707, -0.2223,  0.6847,  ..., -2.1922, -0.2827, -0.0169],
         ...,
         [ 0.5559,  0.1904, -0.0888,  ..., -0.9813, -0.1762, -0.6617],
         [ 0.2957,  0.3258, -0.5251,  ..., -0.5859, -0.5897, -0.5441],
         [-0.0442,  0.6663, -0.8131,  ..., -0.1867, -0.7591, -0.5466]]])

In [56]:
whisper_torch_encoder.cuda()
whisper_torch_encoder(input_features=inputs['input_features'].cuda())

tensor([[[-0.3700,  0.0478, -0.2257,  ..., -3.7430, -0.1793,  0.1543],
         [-0.3259,  0.2111, -0.0563,  ..., -3.8932,  0.1018,  0.0082],
         [-0.0418,  0.3745,  0.3751,  ..., -3.8022,  0.2305, -0.1163],
         ...,
         [ 0.3711,  0.1129, -0.5408,  ...,  0.4810,  0.7379,  1.6578],
         [ 0.4409,  0.5879, -0.0910,  ...,  0.1290, -0.2721,  0.6057],
         [ 0.1196,  0.5980,  0.6380,  ..., -0.2902, -0.0962,  0.5307]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>)

In [57]:
force_write=True
engine_tag = f"bs{batch_size}"

if num_beams > 1:
    engine_tag += "-beam{}".format(num_beams)

preview_features = [PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
if disable_preview_dynamic_shapes:
    engine_tag += "-noPreviewFasterDynamicShapes"
else:
    preview_features.append(PreviewFeature.FASTER_DYNAMIC_SHAPES_0805)

# FP32
wh_encoder_engine_name = os.path.join(wh_tensorrt_model_path, wh_encoder_onnx_model_fpath) + f"-{engine_tag}.engine".replace(f"-beam{num_beams}", "") # encoder engine not affected by beam search
wh_decoder_engine_name = os.path.join(wh_tensorrt_model_path, wh_decoder_onnx_model_fpath) + f"-{engine_tag}.engine"

if not os.path.exists(wh_encoder_engine_name) or force_write:
    whisper_trt_encoder_engine = WhisperEncoderONNXFile(os.path.join(wh_onnx_model_path, wh_encoder_onnx_model_fpath), metadata).as_trt_engine(
        wh_encoder_engine_name, 
        force_overwrite=force_write,
        profiles=[encoder_profile], 
        preview_features=preview_features
    )
else:
    whisper_trt_encoder_engine = WhisperEncoderTRTEngine(wh_encoder_engine_name, metadata)
    
if not os.path.exists(wh_decoder_engine_name) or force_write:
    whisper_trt_decoder_engine = WhisperDecoderONNXFile(os.path.join(wh_onnx_model_path, wh_decoder_onnx_model_fpath), metadata).as_trt_engine(
        wh_decoder_engine_name, 
        force_overwrite=force_write,
        profiles=[decoder_profile], 
        preview_features=preview_features
    )
else:
    whisper_trt_decoder_engine = WhisperDecoderTRTEngine(wh_decoder_engine_name, metadata)


[38;5;11m[W] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage and speed up TensorRT initialization. See "Lazy Loading" section of CUDA documentation https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#lazy-loading[0m
[38;5;11m[W] onnx2trt_utils.cpp:374: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.[0m
[38;5;11m[W] It looks like some layers in the network have compute precision set, but precision constraints were not enabled. 
    Precision constraints must be set to 'prefer' or 'obey' for layer compute precision to take effect. 
    Note: Layers and their requested precisions were: {'encoder/layers.0/self_attn_layer_norm/ReduceMean': 'FLOAT', 'encoder/layers.0/self_attn_layer_norm/Pow': 'FLOAT', 'encoder/layers.0/self_attn_layer_norm/ReduceMean_1': 'FLOAT', 'encoder/layers.0/self_attn_layer_norm/Add': 'FLOAT', 'encoder/layers.0/self_attn

In [58]:
# FP16
wh_encoder_engine_name_fp16 = os.path.join(wh_tensorrt_model_path, wh_encoder_onnx_model_fpath_fp16) + f"-{engine_tag}.engine".replace(f"-beam{num_beams}", "") # encoder engine not affected by beam search
wh_decoder_engine_name_fp16 = os.path.join(wh_tensorrt_model_path, wh_decoder_onnx_model_fpath_fp16) + f"-{engine_tag}.engine"

if not os.path.exists(wh_encoder_engine_name_fp16) or force_write:
    whisper_trt_encoder_engine_fp16 = WhisperEncoderONNXFile(os.path.join(wh_onnx_model_path, wh_encoder_onnx_model_fpath_fp16), metadata_fp16).as_trt_engine(
        wh_encoder_engine_name_fp16, 
        force_overwrite=force_write,
        profiles=[encoder_profile],
        preview_features=preview_features
    )
else:
    whisper_trt_encoder_engine_fp16 = WhisperEncoderTRTEngine(wh_encoder_engine_name_fp16, metadata_fp16)
    
if not os.path.exists(wh_decoder_engine_name_fp16) or force_write:
    whisper_trt_decoder_engine_fp16 = WhisperDecoderONNXFile(os.path.join(wh_onnx_model_path, wh_decoder_onnx_model_fpath_fp16), metadata_fp16).as_trt_engine(
        wh_decoder_engine_name_fp16, 
        force_overwrite=force_write,
        profiles=[decoder_profile], 
        preview_features=preview_features
    )
else:
    whisper_trt_decoder_engine_fp16 = WhisperDecoderTRTEngine(wh_decoder_engine_name_fp16, metadata_fp16)

[38;5;11m[W] It looks like some layers in the network have compute precision set, but precision constraints were not enabled. 
    Precision constraints must be set to 'prefer' or 'obey' for layer compute precision to take effect. 
    Note: Layers and their requested precisions were: {'encoder/layers.0/self_attn_layer_norm/ReduceMean': 'FLOAT', 'encoder/layers.0/self_attn_layer_norm/Pow': 'FLOAT', 'encoder/layers.0/self_attn_layer_norm/ReduceMean_1': 'FLOAT', 'encoder/layers.0/self_attn_layer_norm/Add': 'FLOAT', 'encoder/layers.0/self_attn_layer_norm/Sqrt': 'FLOAT', 'encoder/layers.0/self_attn_layer_norm/Div': 'FLOAT', 'encoder/layers.0/self_attn_layer_norm/Mul': 'FLOAT', 'encoder/layers.0/final_layer_norm/ReduceMean': 'FLOAT', 'encoder/layers.0/final_layer_norm/Pow': 'FLOAT', 'encoder/layers.0/final_layer_norm/ReduceMean_1': 'FLOAT', 'encoder/layers.0/final_layer_norm/Add': 'FLOAT', 'encoder/layers.0/final_layer_norm/Sqrt': 'FLOAT', 'encoder/layers.0/final_layer_norm/Div': 'FLOAT', 

In [59]:
wh_encoder_engine_name_fp16

'./models/openai/whisper-large-v2/tensorrt/Whisper-large-v2-fp16-encoder.onnx-bs1.engine'

In [60]:
print(wh_encoder_onnx_model_fpath)
print(wh_decoder_onnx_model_fpath)
print(onnx_whisper_encoder)
print(onnx_whisper_decoder)
#onnx_whisper_encoder = whisper_torchfile_encoder.as_onnx_model(os.path.join(wh_onnx_model_path, wh_encoder_onnx_model_fpath), force_overwrite=False)
#onnx_whisper_decoder = whisper_torchfile_decoder.as_onnx_model(os.path.join(wh_onnx_model_path, wh_decoder_onnx_model_fpath), force_overwrite=False)

Whisper-large-v2-encoder.onnx
Whisper-large-v2-decoder-with-lm-head.onnx
<Whisper.export.WhisperEncoderONNXFile object at 0x7f00ec4cb430>
<Whisper.export.WhisperDecoderONNXFile object at 0x7f01a8009730>


# Whisper Tensorrt 

In [61]:
from transformers import AutoConfig
from Whisper.trt import WhisperTRTEncoder, WhisperTRTDecoder, TRTHFRunner

In [62]:
decoder_profile

Profile().add('input_ids', min=(1, 1), opt=(1, 224), max=(1, 448)).add('encoder_hidden_states', min=(1, 1, 1280), opt=(1, 1500, 1280), max=(1, 1500, 1280))

In [63]:
# Initialize TensorRT engines
trt_config = AutoConfig.from_pretrained(Whisper_VARIANT, use_cache = metadata.other.kv_cache)

# FP32
whisper_trt_encoder = WhisperTRTEncoder(whisper_trt_encoder_engine, metadata, trt_config, batch_size=batch_size)
whisper_trt_decoder = WhisperTRTDecoder(whisper_trt_decoder_engine, metadata, trt_config, batch_size=batch_size, num_beams=num_beams)

# FP16
whisper_trt_encoder_fp16 = WhisperTRTEncoder(whisper_trt_encoder_engine_fp16, metadata_fp16, trt_config, batch_size=batch_size)
whisper_trt_decoder_fp16 = WhisperTRTDecoder(whisper_trt_decoder_engine_fp16, metadata_fp16, trt_config, batch_size=batch_size, num_beams=num_beams)

[09/06/2023-10:20:44] [TRT] [W] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage and speed up TensorRT initialization. See "Lazy Loading" section of CUDA documentation https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#lazy-loading
[09/06/2023-10:20:50] [TRT] [W] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage and speed up TensorRT initialization. See "Lazy Loading" section of CUDA documentation https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#lazy-loading
[09/06/2023-10:20:51] [TRT] [W] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage and speed up TensorRT initialization. See "Lazy Loading" section of CUDA documentation https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#lazy-loading
[09/06/2023-10:20:53] [TRT] [W] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage and speed up 

In [64]:
metadata_fp16

NetworkMetadata(variant='openai/whisper-large-v2', precision=Precision(fp16=True), other=WhisperMetadata(kv_cache=False))

In [65]:
%%time
encoder_last_hidden_state, encoder_trt_time = w_encoder_inference(
    whisper_trt_encoder, input_features, TimingProfile(iterations=10, number=1, warmup=1, duration=0, percentile=[50,99])
)
encoder_e2e_median_time

CPU times: user 574 ms, sys: 0 ns, total: 574 ms
Wall time: 574 ms


0.18660088896285743

In [66]:
encoder_last_hidden_states = whisper_trt_encoder(input_features=inputs['input_features'])

In [67]:
%time
whisper_trt_encoder_fp16(input_features=inputs['input_features'])

CPU times: user 5 µs, sys: 1e+03 ns, total: 6 µs
Wall time: 11 µs


tensor([[[-0.3556, -0.4579, -0.3098,  ..., -3.4154, -0.5449,  0.0765],
         [ 0.0392, -0.2633,  0.2807,  ..., -3.0624, -0.5296,  0.1974],
         [ 0.7705, -0.2222,  0.6846,  ..., -2.1927, -0.2830, -0.0171],
         ...,
         [ 0.5561,  0.1907, -0.0890,  ..., -0.9821, -0.1759, -0.6618],
         [ 0.2950,  0.3261, -0.5257,  ..., -0.5850, -0.5901, -0.5439],
         [-0.0440,  0.6661, -0.8131,  ..., -0.1867, -0.7592, -0.5468]]])

In [68]:
%time
encoder_last_hidden_states

CPU times: user 5 µs, sys: 0 ns, total: 5 µs
Wall time: 10.7 µs


tensor([[[-0.3556, -0.4579, -0.3098,  ..., -3.4154, -0.5449,  0.0765],
         [ 0.0392, -0.2633,  0.2807,  ..., -3.0624, -0.5296,  0.1974],
         [ 0.7705, -0.2222,  0.6846,  ..., -2.1927, -0.2830, -0.0171],
         ...,
         [ 0.5561,  0.1907, -0.0890,  ..., -0.9821, -0.1759, -0.6618],
         [ 0.2950,  0.3261, -0.5257,  ..., -0.5850, -0.5901, -0.5439],
         [-0.0440,  0.6661, -0.8131,  ..., -0.1867, -0.7592, -0.5468]]])

In [69]:
whisper_torch_encoder(input_features=input_features.cuda().float())

tensor([[[-1.2115e+00, -7.7491e-01, -1.2609e+00,  ..., -3.0477e+00,
          -1.3879e-01,  4.1334e-01],
         [-1.0031e+00,  4.8575e-01, -3.9980e-01,  ..., -2.9799e+00,
           1.4637e-01, -1.0124e-01],
         [-9.0454e-01,  6.1162e-01, -3.8161e-01,  ..., -3.1192e+00,
          -5.3057e-02,  2.5831e-02],
         ...,
         [-8.0957e-04, -8.1212e-03, -1.0276e-02,  ...,  4.6533e-03,
          -4.0645e-03, -1.0909e-02],
         [-2.8883e-03, -5.0357e-03, -1.2229e-02,  ...,  4.9758e-03,
          -3.9479e-03, -1.0991e-02],
         [-5.3624e-03,  2.5027e-03, -1.0828e-02,  ...,  4.0317e-03,
           3.3815e-04, -1.2044e-02]]], device='cuda:0',
       grad_fn=<NativeLayerNormBackward0>)

In [70]:
whisper_torch_decoder.cuda()

TorchModule(
  (decoder): WhisperDecoder(
    (embed_tokens): Embedding(51865, 1280, padding_idx=50257)
    (embed_positions): WhisperPositionalEmbedding(448, 1280)
    (layers): ModuleList(
      (0): WhisperDecoderLayer(
        (self_attn): WhisperAttention(
          (k_proj): Linear(in_features=1280, out_features=1280, bias=False)
          (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        )
        (activation_fn): GELUActivation()
        (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (encoder_attn): WhisperAttention(
          (k_proj): Linear(in_features=1280, out_features=1280, bias=False)
          (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (out_proj): Linear

In [71]:
encoder_last_hidden_state = whisper_torch_encoder(input_features=input_features.cuda().float())
decoder_output_greedy = whisper_torch_decoder.greedy_search(
    input_ids=decoder_input_ids.cuda(),
    encoder_hidden_states=encoder_last_hidden_state.cuda(),
    stopping_criteria=stopping_criteria,
    logits_processor=logits_processor,
    use_cache=False,
)

In [72]:
processor.tokenizer.decode(decoder_output_greedy[0])

'<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|endoftext|>'

In [73]:
decoder_input_ids = torch.full(
    (batch_size, 1),
    WhisperModelTRTConfig.DECODER_START_TOKEN_ID,
)


In [74]:
encoder_last_hidden_state = whisper_trt_encoder(input_features=input_features.cuda().float())

decoder_output = whisper_trt_decoder.greedy_search(
                input_ids=decoder_input_ids.cuda(),
                encoder_hidden_states=encoder_last_hidden_states.cuda(),
                stopping_criteria=stopping_criteria,
                logits_processor=logits_processor,
                use_cache=metadata_fp16.other.kv_cache,
                use_cuda=True
            )

In [75]:
tokenizer.decode(decoder_output[0])

'<|startoftranscript|><|en|><|transcribe|><|notimestamps|> you<|endoftext|>'

### End-to-End TensorRT Inference

In [85]:
from transformers.generation_logits_process import (
    LogitsProcessorList,
    SuppressTokensAtBeginLogitsProcessor,
    SuppressTokensLogitsProcessor,
    ForceTokensLogitsProcessor,
)

from transformers.generation_stopping_criteria import (
    MaxLengthCriteria,
    StoppingCriteriaList,
)
from transformers.generation_beam_search import (
    BeamSearchScorer,
)

stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_output_len)])
no_repeat_ngram_size = WhisperModelTRTConfig.NO_REPEAT_NGRAM_SIZE
min_length = WhisperModelTRTConfig.MIN_OUTPUT_LENGTH[Whisper_VARIANT]
decoder_input_ids = torch.full(
    (batch_size, 1),
    WhisperModelTRTConfig.DECODER_START_TOKEN_ID,
)

forced_decoder_ids=processor.get_decoder_prompt_ids(language="en", task="transcribe", no_timestamps=True)

stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(whisper_model.config.max_length)])
no_repeat_ngram_size = WhisperModelTRTConfig.NO_REPEAT_NGRAM_SIZE
logits_processor = LogitsProcessorList(
    [
        SuppressTokensLogitsProcessor(WhisperModelTRTConfig.SUPPRESS_TOKENS),
        SuppressTokensAtBeginLogitsProcessor(
            WhisperModelTRTConfig.BEGIN_SUPPRESS_TOKENS, decoder_input_ids.shape[-1]
        ),
        ForceTokensLogitsProcessor(forced_decoder_ids),
    ]
)  # by checking HuggingFace's generate() implementation carefully, the default logits processor for BART has no_repeat_ngram_size = 3 and forced_eos_token_id = 2. In this way we can get identical results with raw HuggingFace

encoder_outputs.last_hidden_state = encoder_outputs.last_hidden_state.cpu()
if num_beams > 1:
    decoder_input_ids = expand_inputs_for_beam_search(decoder_input_ids, expand_size=num_beams)
    
# FP32
def e2e_trt():
    with torch.no_grad():
        encoder_last_hidden_states = whisper_trt_encoder(input_features=input_features)
        
        if num_beams > 1:
            # prepare input for beam search
            encoder_last_hidden_states = expand_inputs_for_beam_search(encoder_last_hidden_states, expand_size=num_beams)

            # beam scorer must be reset before each beam search run, otherwise beam search will be skipped due to scorer cache
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=num_beams,
                device="cuda:1",
                do_early_stopping=True,
            )
        
        whisper_trt_decoder.set_encoder_hidden_states_for_inference_cycle(encoder_last_hidden_states)
        
        if num_beams == 1:
            decoder_output = whisper_trt_decoder.greedy_search(
                input_ids=decoder_input_ids.cuda(),
                encoder_hidden_states=encoder_outputs.last_hidden_state.cuda(1),
                stopping_criteria=stopping_criteria,
                logits_processor=logits_processor,
                use_cache=metadata.other.kv_cache,
                use_cuda=True
            )
        else:
            decoder_output = whisper_trt_decoder.beam_search(
                input_ids=decoder_input_ids.cuda(),
                beam_scorer=beam_scorer,
                encoder_hidden_states=encoder_last_hidden_states,
                stopping_criteria=stopping_criteria,
                logits_processor=logits_processor,
                use_cache=metadata.other.kv_cache,
                use_cuda=True
            )
    return decoder_output

output_ids = e2e_trt()
outputs_trt = tokenizer.decode(output_ids[0], skip_special_tokens=True)
trt_time = measure_python_inference_code(e2e_trt, timing_profile)

In [86]:
# FP16
def e2e_trt_fp16():
    with torch.no_grad():
        encoder_last_hidden_states = whisper_trt_encoder_fp16(input_features=input_features)
        
        if num_beams > 1:
            # prepare input for beam search
            encoder_last_hidden_states = expand_inputs_for_beam_search(encoder_last_hidden_states, expand_size=num_beams)
            
            # beam scorer must be reset before each beam search run, otherwise beam search will be skipped due to scorer cache
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=num_beams,
                device="cuda:1",
                do_early_stopping=True,
            )
        
        whisper_trt_decoder_fp16.set_encoder_hidden_states_for_inference_cycle(encoder_last_hidden_states) 
        
        if num_beams == 1:
            decoder_output = whisper_trt_decoder_fp16.greedy_search(
                input_ids=decoder_input_ids.cuda(),
                encoder_hidden_states=encoder_last_hidden_states,
                stopping_criteria=stopping_criteria,
                logits_processor=logits_processor,
                use_cache=metadata.other.kv_cache,
                use_cuda=True
            )
        else:
            decoder_output = whisper_trt_decoder_fp16.beam_search(
                input_ids=decoder_input_ids.cuda(),
                beam_scorer=beam_scorer,
                encoder_hidden_states=encoder_last_hidden_states,
                stopping_criteria=stopping_criteria,
                logits_processor=logits_processor,
                use_cache=metadata.other.kv_cache,
                use_cuda=True
            )
    return decoder_output

output_ids_fp16 = e2e_trt_fp16()
outputs_trt_fp16 = tokenizer.decode(output_ids_fp16[0], skip_special_tokens=True)
trt_time_fp16 = measure_python_inference_code(e2e_trt_fp16, timing_profile)

In [87]:
# print results and timing statistics
print(f'Device: {torch.cuda.get_device_name()}')
print(f"Using engine: {metadata_string + '-' + engine_tag}")
print(f'Output identical to HF results? {outputs_trt == outputs_hf}')
print(f"Precision: FP32")
print(f'TRT time: {trt_time}')
print()
print(f"Using engine: {metadata_string_fp16 + '-' + engine_tag}")
print(f'Output identical to HF results? {outputs_trt_fp16 == outputs_hf}')
print(f"Precision: FP16")
print(f'TRT time: {trt_time_fp16}')

Device: NVIDIA A100-SXM4-80GB
Using engine: Whisper-large-v2-bs1
Output identical to HF results? False
Precision: FP32
TRT time: [0.1209204220212996, 0.12111746391747147]

Using engine: Whisper-large-v2-fp16-bs1
Output identical to HF results? False
Precision: FP16
TRT time: [0.08881275996100157, 0.08901880006305873]


In [88]:
%%time
a, decoder_trt_time = w_decoder_inference(whisper_trt_decoder, expand_inputs_for_beam_search(input_ids, num_beams) if num_beams > 1 else input_ids, expand_inputs_for_beam_search(encoder_last_hidden_state, num_beams) if num_beams > 1 else encoder_last_hidden_state, timing_profile)



CPU times: user 158 ms, sys: 255 µs, total: 158 ms
Wall time: 158 ms


### Time Measurement of Encoder, Decoder, and Full E2E
We will benchmark the encoder, decoder, and full end-to-end as we did for HuggingFace before.

In [89]:
# encoder-decoder inference 
whisper_model.float()
whisper_model = whisper_model.cuda(1)

input_features = input_features.float().cuda(1)

with torch.no_grad():
    output_ids = whisper_model.generate(input_features, max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=False)    
    outputs = tokenizer.decode(output_ids[-1,:], skip_special_tokens=True)    
outputs_hf = outputs

# timing
# FP32
input_features = input_features.float().cuda(1)
hf_nonkv_time = measure_python_inference_code(lambda: whisper_model.generate(input_features, max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=False), timing_profile)
hf_kv_time = measure_python_inference_code(lambda: whisper_model.generate(input_features, max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=True), timing_profile)

# FP16, cuda 11.4 has cublas error that will fail in both cpu or cpu model for BART
hf_nonkv_time_fp16 = measure_python_inference_code(lambda: whisper_model.generate(input_features, max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=False), timing_profile)
hf_kv_time_fp16 = measure_python_inference_code(lambda: whisper_model.generate(input_features, max_length=max_output_len, min_length=min_output_len, num_beams=num_beams, use_cache=True), timing_profile)

In [90]:
# FP32
encoder_last_hidden_states, encoder_trt_time = w_encoder_inference(whisper_trt_encoder, input_features, timing_profile)
_, decoder_trt_time = w_decoder_inference(whisper_trt_decoder, expand_inputs_for_beam_search(input_ids, num_beams) if num_beams > 1 else input_ids, expand_inputs_for_beam_search(encoder_last_hidden_states, num_beams) if num_beams > 1 else encoder_last_hidden_states, timing_profile)

if num_beams == 1:
    _, full_trt_time = full_inference_greedy(
        whisper_trt_encoder,
        whisper_trt_decoder,
        input_features,
        tokenizer,
        timing_profile,
        max_length=max_output_len,
        min_length=0,
        batch_size=batch_size,
        use_cache=metadata.other.kv_cache,
    )
else:
    _, full_trt_time = full_inference_beam(
        whisper_trt_encoder,
        whisper_trt_decoder,
        input_ids,
        tokenizer,
        timing_profile,
        num_beams=num_beams,
        max_length=max_output_len,
        min_length=0,
        batch_size=batch_size,
        use_cache=metadata.other.kv_cache,
        early_stopping=True,
    )
    
print(f'Encoder time: {percentile_print(encoder_trt_time)}')
print(f'Decoder time: {percentile_print(decoder_trt_time)}')
print(f'Full E2E time: {percentile_print(full_trt_time)}')

# FP16
encoder_last_hidden_states, encoder_trt_time_fp16 = w_encoder_inference(whisper_trt_encoder_fp16, input_features, timing_profile)
_, decoder_trt_time_fp16 = w_decoder_inference(whisper_trt_decoder_fp16, expand_inputs_for_beam_search(input_ids, num_beams) if num_beams > 1 else input_ids, expand_inputs_for_beam_search(encoder_last_hidden_states, num_beams) if num_beams > 1 else encoder_last_hidden_states, timing_profile)

if num_beams == 1:
    _, full_trt_time_fp16 = full_inference_greedy(
        whisper_trt_encoder_fp16,
        whisper_trt_decoder_fp16,
        input_features,
        tokenizer,
        timing_profile,
        max_length=max_output_len,
        min_length=0,
        batch_size=batch_size,
        use_cache=metadata.other.kv_cache,
    )
else:
    _, full_trt_time_fp16 = full_inference_beam(
        whisper_trt_encoder_fp16,
        whisper_trt_decoder_fp16,
        input_ids,
        tokenizer,
        timing_profile,
        num_beams=num_beams,
        max_length=max_output_len,
        min_length=0,
        batch_size=batch_size,
        use_cache=metadata.other.kv_cache,
        early_stopping=True,
    )
print(f'Encoder FP16 time: {percentile_print(encoder_trt_time_fp16)}')
print(f'Decoder FP16 time: {percentile_print(decoder_trt_time_fp16)}')
print(f'Full E2E FP16 time: {percentile_print(full_trt_time_fp16)}')

Encoder time: p50 46.94ms, p99 61.85ms
Decoder time: p50 13.04ms, p99 14.47ms
Full E2E time: p50 120.81ms, p99 122.05ms
Encoder FP16 time: p50 46.89ms, p99 46.92ms
Decoder FP16 time: p50 7.92ms, p99 8.67ms
Full E2E FP16 time: p50 88.59ms, p99 89.33ms


In [91]:
from tabulate import tabulate

data = [
    ['Framework', 'Precision', 'Encoder p50 (ms)', 'Decoder p50 (ms)', 'Full E2E p50 (ms)', 'Accuracy'],
    ['HuggingFace (w/o cache)', 'FP32', '-', '-', f'{hf_nonkv_time[0]*1000:.2f}', '-'],
    ['HuggingFace (w/ cache)', 'FP32', '-', '-', f'{hf_kv_time[0]*1000:.2f}', '-'],
    ['HuggingFace (w/o cache)', 'FP16', '-', '-', f'{hf_nonkv_time_fp16[0]*1000:.2f}', '-'],
    ['HuggingFace (w/ cache)', 'FP16', '-', '-', f'{hf_kv_time_fp16[0]*1000:.2f}', '-'],
    ['PyTorch', 'FP32', f'{encoder_pytorch_time[0]*1000:.2f}', f'{decoder_pytorch_time[0]*1000:.2f}', f'{full_pytorch_time[0]*1000:.2f}', outputs_pytorch == outputs_hf],
    ['PyTorch', 'FP16', f'{encoder_pytorch_time_fp16[0]*1000:.2f}', f'{decoder_pytorch_time_fp16[0]*1000:.2f}', f'{full_pytorch_time_fp16[0]*1000:.2f}', outputs_pytorch_fp16 == outputs_hf],
    ['TensorRT', 'FP32', f'{encoder_trt_time[0]*1000:.2f}', f'{decoder_trt_time[0]*1000:.2f}', f'{full_trt_time[0]*1000:.2f}', outputs_trt == outputs_hf],
    ['TensorRT', 'FP16', f'{encoder_trt_time_fp16[0]*1000:.2f}', f'{decoder_trt_time_fp16[0]*1000:.2f}', f'{full_trt_time_fp16[0]*1000:.2f}', outputs_trt_fp16 == outputs_hf],
]

print(tabulate(data, headers='firstrow', tablefmt='github'))

| Framework               | Precision   | Encoder p50 (ms)   | Decoder p50 (ms)   |   Full E2E p50 (ms) | Accuracy   |
|-------------------------|-------------|--------------------|--------------------|---------------------|------------|
| HuggingFace (w/o cache) | FP32        | -                  | -                  |             1130.56 | -          |
| HuggingFace (w/ cache)  | FP32        | -                  | -                  |              676.55 | -          |
| HuggingFace (w/o cache) | FP16        | -                  | -                  |             1131.42 | -          |
| HuggingFace (w/ cache)  | FP16        | -                  | -                  |              683.25 | -          |
| PyTorch                 | FP32        | 186.72             | 31.47              |              668.09 | True       |
| PyTorch                 | FP16        | 46.39              | 27.84              |              536.8  | True       |
| TensorRT                | FP32        | 46.94 

In [95]:
outputs_pytorch_fp16

' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.'

In [93]:
outputs_hf

' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.'

In [96]:
encoder_last_hidden_states, encoder_trt_time_fp16 = w_encoder_inference(whisper_trt_encoder_fp16, input_features, timing_profile)