In [1]:
import os
import sys
# Put the parent directory on sys.path. Presumably this is what lets the
# project-local `NNDF` and `Whisper` packages imported in later cells resolve —
# confirm against the repository layout.
ROOT_DIR = os.path.abspath("../")
sys.path.append(ROOT_DIR)

import torch
import tensorrt as trt

# huggingface
from transformers import (
    WhisperProcessor, 
    WhisperForConditionalGeneration,
    WhisperTokenizer,
    WhisperConfig
)

import io
import itertools

from typing import BinaryIO, Union

import av
import numpy as np
def decode_audio(
    input_file: Union[str, BinaryIO],
    sampling_rate: int = 16000,
    split_stereo: bool = False,
):
    """Decode an audio stream into normalized float32 samples.

    Audio stream 0 of ``input_file`` is decoded with PyAV, regrouped into
    larger frames, resampled to signed 16-bit samples at ``sampling_rate``
    and finally divided by 32768 to obtain floats.

    Args:
      input_file: Path to the input file or a file-like object.
      sampling_rate: Sample rate requested from the resampler.
      split_stereo: When true, resample to a stereo layout and return the
        two channels separately instead of a single mono signal.

    Returns:
      A float32 Numpy array, or a ``(left, right)`` 2-tuple of float32
      arrays when ``split_stereo`` is enabled.
    """
    layout = "stereo" if split_stereo else "mono"
    resampler = av.audio.resampler.AudioResampler(
        format="s16", layout=layout, rate=sampling_rate
    )

    pcm_bytes = io.BytesIO()
    sample_dtype = None

    with av.open(input_file, metadata_errors="ignore") as container:
        # Lazy pipeline: drop undecodable frames -> regroup -> resample.
        pipeline = _resample_frames(
            _group_frames(
                _ignore_invalid_frames(container.decode(audio=0)),
                500000,
            ),
            resampler,
        )

        for chunk in pipeline:
            samples = chunk.to_ndarray()
            sample_dtype = samples.dtype
            pcm_bytes.write(samples)

    # Reinterpret the accumulated bytes as samples, then rescale s16 -> f32.
    audio = np.frombuffer(pcm_bytes.getbuffer(), dtype=sample_dtype)
    audio = audio.astype(np.float32) / 32768.0

    if not split_stereo:
        return audio

    # Even-indexed samples are returned as the left channel, odd as the right.
    return audio[0::2], audio[1::2]

def _ignore_invalid_frames(frames):
    iterator = iter(frames)

    while True:
        try:
            yield next(iterator)
        except StopIteration:
            break
        except av.error.InvalidDataError:
            continue


def _group_frames(frames, num_samples=None):
    """Regroup audio frames through an ``AudioFifo``.

    Each incoming frame is written to the FIFO. When ``num_samples`` is given
    and the FIFO holds at least that many samples, its whole content is read
    out and yielded. Anything still buffered after the input ends is yielded
    as a final frame.
    """
    buffered = av.audio.fifo.AudioFifo()
    has_threshold = num_samples is not None

    for incoming in frames:
        incoming.pts = None  # Ignore timestamp check.
        buffered.write(incoming)

        if has_threshold and buffered.samples >= num_samples:
            yield buffered.read()

    # Flush the remainder that never reached the threshold.
    if buffered.samples > 0:
        yield buffered.read()


def _resample_frames(frames, resampler):
    # Add None to flush the resampler.
    for frame in itertools.chain(frames, [None]):
        yield from resampler.resample(frame)

In [2]:
from transformers import (
    WhisperProcessor, 
    WhisperForConditionalGeneration,
    WhisperTokenizer,
    WhisperConfig,
    AutoConfig
)

In [26]:
import torch
from datasets import load_dataset

# Checkpoint name shared by the processor, model and config loads below.
Whisper_VARIANT = "openai/whisper-tiny"    # choices: openai/whisper-tiny | openai/whisper-base | openai/whisper-small | openai/whisper-medium | openai/whisper-large-v2

# Load the Hugging Face processor and model for the selected checkpoint.
processor = WhisperProcessor.from_pretrained(Whisper_VARIANT)
whisper_model = WhisperForConditionalGeneration.from_pretrained(Whisper_VARIANT)
# NOTE(review): `wh_config` is not referenced again in the cells visible here —
# confirm it is still needed.
wh_config = WhisperConfig.from_pretrained(Whisper_VARIANT, use_cache = False)
# The tokenizer is taken from the processor and used later to decode output ids.
tokenizer = processor.tokenizer

In [27]:
# Keep the sampling rate in one place: it drives the resampling in
# decode_audio, the duration computation and the feature extractor.
SAMPLING_RATE = 16000

# Decode the clip to mono float32 samples at SAMPLING_RATE.
audio = decode_audio("korean_news.mp4", sampling_rate=SAMPLING_RATE)
duration = audio.shape[0] / SAMPLING_RATE  # clip length in seconds

# Pass the rate explicitly so the feature extractor can check it against the
# rate it was configured for instead of silently assuming one.
inputs = processor(audio, sampling_rate=SAMPLING_RATE, return_tensors="pt")

In [28]:
from Whisper.measurements import decoder_inference as w_decoder_inference, encoder_inference as w_encoder_inference, full_inference as w_full_inference, full_inference_greedy, full_inference_beam
from Whisper.export import WhisperEncoderTorchFile, WhisperDecoderTorchFile, WhisperEncoderTRTEngine, WhisperDecoderTRTEngine

from NNDF.networks import TimingProfile
from NNDF.torch_utils import expand_inputs_for_beam_search

In [29]:
from NNDF.networks import NetworkMetadata, Precision
from Whisper.WhisperModelConfig import WhisperModelTRTConfig, WhisperMetadata
from Whisper.trt import WhisperTRTEncoder, WhisperTRTDecoder, TRTHFRunner
# KV-cache flag; it is stored in the metadata below and read back in later
# cells as `metadata.other.kv_cache`.
TRT_KV=False
# Describes the configuration used in this notebook: checkpoint variant,
# fp16 disabled, and the KV-cache flag.
metadata = NetworkMetadata(variant=Whisper_VARIANT, precision=Precision(fp16=False), other=WhisperMetadata(kv_cache=TRT_KV))

In [30]:
# Point the engine wrappers at the encoder and decoder engine files on disk
# (presumably prebuilt for batch size 1, going by the `-bs1` suffix).
# NOTE(review): both paths are hard-coded to whisper-tiny and will not follow a
# change to Whisper_VARIANT.
whisper_trt_encoder_engine = WhisperEncoderTRTEngine('./models/openai/whisper-tiny/tensorrt/Whisper-tiny-encoder.onnx-bs1.engine', metadata)
whisper_trt_decoder_engine = WhisperDecoderTRTEngine('./models/openai/whisper-tiny/tensorrt/Whisper-tiny-decoder-with-lm-head.onnx-bs1.engine', metadata)

In [31]:
# Model config with `use_cache` taken from the KV-cache flag in the metadata.
# NOTE(review): the following cell repeats this exact assignment.
trt_config = AutoConfig.from_pretrained(Whisper_VARIANT, use_cache = metadata.other.kv_cache)

In [32]:
# Initialize TensorRT engines
# Config handed to both runners; `use_cache` mirrors the metadata's KV-cache flag.
trt_config = AutoConfig.from_pretrained(Whisper_VARIANT, use_cache = metadata.other.kv_cache)

# FP32
# Runners built on the engine objects created earlier, at batch size 1
# (decoder with a single beam).
whisper_trt_encoder = WhisperTRTEncoder(whisper_trt_encoder_engine, metadata, trt_config, batch_size=1)
whisper_trt_decoder = WhisperTRTDecoder(whisper_trt_decoder_engine, metadata, trt_config, batch_size=1, num_beams=1)

# # FP16
# NOTE(review): the FP16 runners stay commented out, and the `*_fp16` engines and
# `metadata_fp16` they reference are not defined in the visible cells; the FP16
# benchmark cell at the end of the notebook still uses the `*_fp16` runner names.
# whisper_trt_encoder_fp16 = WhisperTRTEncoder(whisper_trt_encoder_engine_fp16, metadata_fp16, trt_config, batch_size=1)
# whisper_trt_decoder_fp16 = WhisperTRTDecoder(whisper_trt_decoder_engine_fp16, metadata_fp16, trt_config, batch_size=1, num_beams=1)

[08/14/2023-13:07:06] [TRT] [W] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage. See `CUDA_MODULE_LOADING` in https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars
[08/14/2023-13:07:07] [TRT] [W] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage. See `CUDA_MODULE_LOADING` in https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars


In [33]:
# Get input_features
# Build the encoder input from the first sample of the dummy LibriSpeech
# validation split.
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

first_sample = ds[0]["audio"]
# Pass the sample's own rate so the feature extractor can verify it instead of
# silently assuming one.
audio_inputs = processor(
    first_sample["array"],
    sampling_rate=first_sample["sampling_rate"],
    return_tensors="pt",
)
input_features = audio_inputs.input_features

# WAR: Using an ugly representation because cuda 11.4 does not support GPU models due to cublas errors
if "LD_LIBRARY_PATH" in os.environ and "cuda-11.4" in os.environ["LD_LIBRARY_PATH"]:
    whisper_model = whisper_model.cpu()
else:
    whisper_model = whisper_model.cuda()

# Keep the inputs on the device the model ended up on. The original sent them
# to a hard-coded 'cuda:1' while the model (and every other tensor in this
# notebook) used the default CUDA device, which mismatches devices and fails
# outright on single-GPU machines.
input_features = input_features.to(whisper_model.device)

Found cached dataset librispeech_asr_dummy (/home/nvadmin/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_dummy/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)


In [40]:
# Benchmark settings shared by the measurement cells.
num_beams = 1
batch_size = 1
# Timing configuration; `percentile` lists the percentiles that
# percentile_print labels (p50 and p99). Presumably 10 measured iterations
# after 1 warmup run — confirm against NNDF.networks.TimingProfile.
timing_profile = TimingProfile(iterations=10, number=1, warmup=1, duration=0, percentile=[50,99])
# (batch_size, 1) tensor filled with the decoder start token id.
input_ids = torch.full(
    (batch_size, 1),
    WhisperModelTRTConfig.DECODER_START_TOKEN_ID,
)
min_output_len =0 
# Generation length cap taken from the Hugging Face model config.
max_output_len = whisper_model.config.max_length

def percentile_print(timing):
    """Format timing values as a comma-separated 'pNN X.XXms' string.

    Each value in ``timing`` is labelled with the entry at the same index of
    the notebook-level ``timing_profile.percentile`` and multiplied by 1000
    before being printed with two decimals.
    """
    parts = []
    for idx, value in enumerate(timing):
        label = timing_profile.percentile[idx]
        parts.append('p{} {:.2f}ms'.format(label, value * 1000))
    return ', '.join(parts)

In [119]:
from transformers.generation_stopping_criteria import (
    MaxLengthCriteria,
    StoppingCriteriaList,
)
from transformers.generation_beam_search import (
    BeamSearchScorer,
)
# from HuggingFace transformers
from transformers.generation_logits_process import (
    LogitsProcessorList,
    SuppressTokensAtBeginLogitsProcessor,
    SuppressTokensLogitsProcessor,
    ForceTokensLogitsProcessor,
)
from NNDF.tensorrt_utils import TRTNativeRunner
from NNDF.general_utils import measure_python_inference_code

# Decoder prompt: a (batch_size, 1) int32 tensor holding the decoder start
# token from the TRT decoder's config. It is moved to the GPU at the end of
# this cell.
decoder_input_ids = torch.full(
    (batch_size, 1),
    whisper_trt_decoder.config.decoder_start_token_id,
    dtype=torch.int32,
)

# Define the forced prompt ids here rather than relying on another cell having
# been executed first; without this the cell raises NameError on a fresh kernel.
forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe", no_timestamps=True)

# Stop generation once the sequence reaches max_output_len.
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_output_len)])
# Whisper logits processing: suppressed tokens, tokens suppressed at the first
# generated position (index = current prompt length), and the forced prompt.
logits_processor = LogitsProcessorList(
    [
        SuppressTokensLogitsProcessor(whisper_trt_decoder.config.suppress_tokens),
        SuppressTokensAtBeginLogitsProcessor(
            whisper_trt_decoder.config.begin_suppress_tokens, decoder_input_ids.shape[-1]
        ),
        ForceTokensLogitsProcessor(forced_decoder_ids),
    ]
)

decoder_input_ids = decoder_input_ids.to("cuda")

def _e2e():
    """Run the encoder once and greedy-decode from its output, with autograd off.

    Reads the notebook-level runners, `input_features`, `decoder_input_ids`,
    `stopping_criteria` and `logits_processor`, and returns whatever
    `greedy_search` returns.
    """
    with torch.no_grad():
        hidden = whisper_trt_encoder(input_features=input_features)
        greedy_kwargs = dict(
            input_ids=decoder_input_ids,
            encoder_hidden_states=hidden,
            stopping_criteria=stopping_criteria,
            logits_processor=logits_processor,
            use_cache=True,
        )
        return whisper_trt_decoder.greedy_search(**greedy_kwargs)

def _e2e_trt():
    """Encoder pass plus greedy decoding for the TRT decoder, with autograd off.

    Unlike `_e2e`, the encoder output is first registered on the decoder via
    `set_encoder_hidden_states_for_inference_cycle` before `greedy_search`
    is called. Returns whatever `greedy_search` returns.
    """
    with torch.no_grad():
        hidden = whisper_trt_encoder(input_features=input_features)
        # Hand the encoder states to the decoder before generation starts.
        whisper_trt_decoder.set_encoder_hidden_states_for_inference_cycle(hidden)
        greedy_kwargs = dict(
            input_ids=decoder_input_ids,
            encoder_hidden_states=hidden,
            stopping_criteria=stopping_criteria,
            logits_processor=logits_processor,
            use_cache=True,
        )
        return whisper_trt_decoder.greedy_search(**greedy_kwargs)

# Default to the plain path; switch to the TRT-specific path when the decoder
# is a TRTNativeRunner.
measurement_function = _e2e
if isinstance(whisper_trt_decoder, TRTNativeRunner):
    # Presumably makes the decoder return its outputs on the GPU — confirm
    # against TRTNativeRunner.set_return_device.
    whisper_trt_decoder.set_return_device("cuda")
    measurement_function = _e2e_trt

# Time the selected function under the shared timing profile.
full_e2e_time = measure_python_inference_code(measurement_function, timing_profile)

In [120]:
full_e2e_time

[0.036400729091838, 0.05314069800078869]

In [91]:
full_e2e_time

[0.039000588934868574, 0.051730379927903414]

In [82]:
from transformers.generation_logits_process import (
    NoRepeatNGramLogitsProcessor,
    MinLengthLogitsProcessor,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    LogitsProcessorList,
    # The three processors this cell actually instantiates; importing them
    # here removes the dependency on another cell having imported them.
    SuppressTokensLogitsProcessor,
    SuppressTokensAtBeginLogitsProcessor,
    ForceTokensLogitsProcessor,
)
from transformers.generation_stopping_criteria import (
    MaxLengthCriteria,
    StoppingCriteriaList,
)
from transformers.generation_beam_search import (
    BeamSearchScorer,
)

min_length = WhisperModelTRTConfig.MIN_OUTPUT_LENGTH[Whisper_VARIANT]

# Decoder prompt: a (batch_size, 1) tensor holding the decoder start token.
decoder_input_ids = torch.full(
    (batch_size, 1),
    WhisperModelTRTConfig.DECODER_START_TOKEN_ID,
)

# Prompt ids forced at the start of generation: English, transcribe task,
# no timestamps.
forced_decoder_ids=processor.get_decoder_prompt_ids(language="en", task="transcribe", no_timestamps=True)

# Stop once the sequence reaches the model's configured max_length. (The
# original assigned stopping_criteria twice; only this assignment took effect.)
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(whisper_model.config.max_length)])

# Whisper logits processing: suppressed tokens, tokens suppressed at the first
# generated position (index = current prompt length), and the forced prompt.
logits_processor = LogitsProcessorList(
    [
        SuppressTokensLogitsProcessor(WhisperModelTRTConfig.SUPPRESS_TOKENS),
        SuppressTokensAtBeginLogitsProcessor(
            WhisperModelTRTConfig.BEGIN_SUPPRESS_TOKENS, decoder_input_ids.shape[-1]
        ),
        ForceTokensLogitsProcessor(forced_decoder_ids),
    ]
)

   
# FP32
def e2e_trt():
    """Run one full TRT inference: encoder pass, then greedy or beam search.

    Reads the notebook-level `input_features`, `decoder_input_ids`,
    `num_beams`, `batch_size`, `stopping_criteria`, `logits_processor` and
    `metadata`. With `num_beams == 1` the decoder's `greedy_search` is used;
    otherwise the encoder output is expanded per beam and `beam_search` is
    called with a freshly built scorer.

    Returns:
      Whatever the decoder's `greedy_search` / `beam_search` returns.
    """
    with torch.no_grad():
        encoder_last_hidden_states = whisper_trt_encoder(input_features=input_features)

        # Move the prompt to the GPU once; its device is reused for the beam
        # scorer so the two can never disagree. (The scorer used to be pinned
        # to "cuda:1" while the prompt went to the default CUDA device.)
        gpu_input_ids = decoder_input_ids.cuda()

        if num_beams > 1:
            # prepare input for beam search
            encoder_last_hidden_states = expand_inputs_for_beam_search(encoder_last_hidden_states, expand_size=num_beams)

            # beam scorer must be reset before each beam search run, otherwise beam search will be skipped due to scorer cache
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=num_beams,
                device=gpu_input_ids.device,
                do_early_stopping=True,
            )

        whisper_trt_decoder.set_encoder_hidden_states_for_inference_cycle(encoder_last_hidden_states)

        if num_beams == 1:
            decoder_output = whisper_trt_decoder.greedy_search(
                input_ids=gpu_input_ids,
                encoder_hidden_states=encoder_last_hidden_states,
                stopping_criteria=stopping_criteria,
                logits_processor=logits_processor,
                use_cache=metadata.other.kv_cache,
                use_cuda=True
            )
        else:
            decoder_output = whisper_trt_decoder.beam_search(
                input_ids=gpu_input_ids,
                beam_scorer=beam_scorer,
                encoder_hidden_states=encoder_last_hidden_states,
                stopping_criteria=stopping_criteria,
                logits_processor=logits_processor,
                use_cache=metadata.other.kv_cache,
                use_cuda=True
            )
    return decoder_output

# Run the pipeline once and decode the first returned sequence to text
# (special tokens dropped), then time repeated runs of the same function.
output_ids = e2e_trt()
outputs_trt = tokenizer.decode(output_ids[0], skip_special_tokens=True)
trt_time = measure_python_inference_code(e2e_trt, timing_profile)

In [83]:
trt_time

[0.036364562809467316, 0.036596449092030525]

In [76]:
%%time
# Single timed pass: encode, register the encoder states on the decoder, then
# greedy-decode. `use_cache` follows the KV-cache flag stored in the metadata.
encoder_last_hidden_states = whisper_trt_encoder(input_features=input_features)
whisper_trt_decoder.set_encoder_hidden_states_for_inference_cycle(encoder_last_hidden_states)
decoder_output = whisper_trt_decoder.greedy_search(
    input_ids=decoder_input_ids.cuda(),
    encoder_hidden_states=encoder_last_hidden_states,
    stopping_criteria=stopping_criteria,
    logits_processor=logits_processor,
    use_cache=metadata.other.kv_cache,
    use_cuda=True
)

CPU times: user 52.9 ms, sys: 394 µs, total: 53.2 ms
Wall time: 52.4 ms


In [70]:
decoder_output

tensor([[50258, 50259, 50359, 50363,  2221,    13,  2326,   388,   391,   307,
           264, 50244,   295,   264,  2808,  5359,   293,   321,   366,  5404,
           281,  2928,   702, 14943,    13, 50257]], device='cuda:0')

In [65]:
%%time
# Single timed pass that does NOT call
# set_encoder_hidden_states_for_inference_cycle and hard-codes use_cache=True.
# NOTE(review): the recorded wall time here is 868 ms, versus 52.4 ms for the
# timed cell that registers the encoder states and passes use_cuda=True —
# confirm which of those differences accounts for the gap.
encoder_last_hidden_state = whisper_trt_encoder(input_features=input_features)
decoder_output_greedy = whisper_trt_decoder.greedy_search(
    input_ids=decoder_input_ids,
    encoder_hidden_states=encoder_last_hidden_state,
    stopping_criteria=stopping_criteria,
    logits_processor=logits_processor,
    use_cache=True,
)

CPU times: user 856 ms, sys: 12.9 ms, total: 869 ms
Wall time: 868 ms


In [42]:
# FP32
# Time the encoder on its own; its output feeds the decoder timing below.
encoder_last_hidden_states, encoder_trt_time = w_encoder_inference(whisper_trt_encoder, input_features, timing_profile)

# Time the decoder on its own. With more than one beam, both the prompt ids
# and the encoder states are expanded per beam first.
_, decoder_trt_time = w_decoder_inference(
    whisper_trt_decoder,
    expand_inputs_for_beam_search(input_ids, num_beams) if num_beams > 1 else input_ids,
    expand_inputs_for_beam_search(encoder_last_hidden_states, num_beams) if num_beams > 1 else encoder_last_hidden_states,
    timing_profile,
)

# Time the full encoder + decoder pipeline with the search strategy that
# matches num_beams.
if num_beams == 1:
    _, full_trt_time = full_inference_greedy(
        whisper_trt_encoder,
        whisper_trt_decoder,
        input_features,
        tokenizer,
        timing_profile,
        max_length=max_output_len,
        min_length=WhisperModelTRTConfig.MIN_OUTPUT_LENGTH[metadata.variant],
        batch_size=batch_size,
        use_cache=metadata.other.kv_cache,
    )
else:
    # The encoder input is the audio features, exactly as in the greedy branch;
    # the original passed the decoder start-token `input_ids` here instead.
    # NOTE(review): confirm against the signature in Whisper.measurements.
    _, full_trt_time = full_inference_beam(
        whisper_trt_encoder,
        whisper_trt_decoder,
        input_features,
        tokenizer,
        timing_profile,
        num_beams=num_beams,
        max_length=max_output_len,
        min_length=WhisperModelTRTConfig.MIN_OUTPUT_LENGTH[metadata.variant],
        batch_size=batch_size,
        use_cache=metadata.other.kv_cache,
        early_stopping=True,
    )

print(f'Encoder time: {percentile_print(encoder_trt_time)}')
print(f'Decoder time: {percentile_print(decoder_trt_time)}')
print(f'Full E2E time: {percentile_print(full_trt_time)}')


Encoder time: p50 2.16ms, p99 2.17ms
Decoder time: p50 1.60ms, p99 1.61ms
Full E2E time: p50 773.98ms, p99 775.35ms


[0.7739814920350909, 0.775348360883072]

# fp16

In [None]:

# FP16
# NOTE(review): `whisper_trt_encoder_fp16` / `whisper_trt_decoder_fp16` are only
# present as commented-out code in the visible cells, so this cell cannot run
# until FP16 engines and runners are actually created. It also reads
# `metadata` (built with fp16=False) rather than an FP16 metadata object —
# confirm that is intended.

# Time the encoder on its own; its output feeds the decoder timing below.
encoder_last_hidden_states, encoder_trt_time_fp16 = w_encoder_inference(whisper_trt_encoder_fp16, input_features, timing_profile)

# Time the decoder on its own. With more than one beam, both the prompt ids
# and the encoder states are expanded per beam first.
_, decoder_trt_time_fp16 = w_decoder_inference(
    whisper_trt_decoder_fp16,
    expand_inputs_for_beam_search(input_ids, num_beams) if num_beams > 1 else input_ids,
    expand_inputs_for_beam_search(encoder_last_hidden_states, num_beams) if num_beams > 1 else encoder_last_hidden_states,
    timing_profile,
)

# Time the full encoder + decoder pipeline with the search strategy that
# matches num_beams.
if num_beams == 1:
    _, full_trt_time_fp16 = full_inference_greedy(
        whisper_trt_encoder_fp16,
        whisper_trt_decoder_fp16,
        input_features,
        tokenizer,
        timing_profile,
        max_length=max_output_len,
        min_length=WhisperModelTRTConfig.MIN_OUTPUT_LENGTH[metadata.variant],
        batch_size=batch_size,
        use_cache=metadata.other.kv_cache,
    )
else:
    # The encoder input is the audio features, exactly as in the greedy branch;
    # the original passed the decoder start-token `input_ids` here instead.
    # NOTE(review): confirm against the signature in Whisper.measurements.
    _, full_trt_time_fp16 = full_inference_beam(
        whisper_trt_encoder_fp16,
        whisper_trt_decoder_fp16,
        input_features,
        tokenizer,
        timing_profile,
        num_beams=num_beams,
        max_length=max_output_len,
        min_length=WhisperModelTRTConfig.MIN_OUTPUT_LENGTH[metadata.variant],
        batch_size=batch_size,
        use_cache=metadata.other.kv_cache,
        early_stopping=True,
    )
print(f'Encoder FP16 time: {percentile_print(encoder_trt_time_fp16)}')
print(f'Decoder FP16 time: {percentile_print(decoder_trt_time_fp16)}')
print(f'Full E2E FP16 time: {percentile_print(full_trt_time_fp16)}')