In [10]:
import os
import sys
# Absolute path of the parent of the current working directory.
ROOT_DIR = os.path.abspath("../")
# Make that directory importable. Presumably it is where the project packages
# imported in later cells (`Whisper`, `NNDF`) live -- TODO confirm.
sys.path.append(ROOT_DIR)

import torch
import tensorrt as trt

# huggingface
from transformers import (
    WhisperProcessor, 
    WhisperForConditionalGeneration,
    WhisperTokenizer,
    WhisperConfig
)

import io
import itertools

from typing import BinaryIO, Union

import av
import numpy as np
def decode_audio(
    input_file: Union[str, BinaryIO],
    sampling_rate: int = 16000,
    split_stereo: bool = False,
):
    """Decode the first audio stream of a media file into float32 samples.

    The stream is resampled to signed 16-bit PCM at ``sampling_rate`` (mono,
    or stereo when ``split_stereo`` is set) and then scaled by 1/32768.

    Args:
      input_file: Path to the media file, or an open binary file-like object.
      sampling_rate: Target sample rate in Hz.
      split_stereo: When true, keep two channels and return them separately.

    Returns:
      A float32 Numpy array of samples, or, when ``split_stereo`` is true, a
      ``(left, right)`` 2-tuple built from the even- and odd-indexed samples.
    """
    target_layout = "stereo" if split_stereo else "mono"
    resampler = av.audio.resampler.AudioResampler(
        format="s16",
        layout=target_layout,
        rate=sampling_rate,
    )

    pcm_bytes = io.BytesIO()
    sample_dtype = None  # stays None when the stream yields no frames

    with av.open(input_file, metadata_errors="ignore") as container:
        # Lazy pipeline: decode -> skip corrupt frames -> regroup into larger
        # chunks of at least 500000 samples -> resample.
        decoded = _ignore_invalid_frames(container.decode(audio=0))
        resampled = _resample_frames(_group_frames(decoded, 500000), resampler)

        for chunk in resampled:
            samples = chunk.to_ndarray()
            sample_dtype = samples.dtype
            pcm_bytes.write(samples)

    # Reinterpret the accumulated bytes as samples, then convert s16 -> f32.
    audio = np.frombuffer(pcm_bytes.getbuffer(), dtype=sample_dtype).astype(np.float32) / 32768.0

    if not split_stereo:
        return audio

    # Even indices form the left channel, odd indices the right.
    return audio[0::2], audio[1::2]

def _ignore_invalid_frames(frames):
    iterator = iter(frames)

    while True:
        try:
            yield next(iterator)
        except StopIteration:
            break
        except av.error.InvalidDataError:
            continue


def _group_frames(frames, num_samples=None):
    """Regroup audio frames into larger frames via an ``AudioFifo``.

    Each incoming frame is written to the FIFO. Whenever ``num_samples`` is
    given and the FIFO holds at least that many samples, its whole content is
    read out and yielded as one frame. Whatever is left at the end is yielded
    as a final frame.
    """
    buffer = av.audio.fifo.AudioFifo()
    has_threshold = num_samples is not None

    for frame in frames:
        # Clear the timestamp so the FIFO does not reject the frame on its
        # timestamp check.
        frame.pts = None
        buffer.write(frame)

        if has_threshold and buffer.samples >= num_samples:
            yield buffer.read()

    # Flush the remainder, if any.
    if buffer.samples > 0:
        yield buffer.read()


def _resample_frames(frames, resampler):
    # Add None to flush the resampler.
    for frame in itertools.chain(frames, [None]):
        yield from resampler.resample(frame)

In [21]:
from transformers import (
    WhisperProcessor, 
    WhisperForConditionalGeneration,
    WhisperTokenizer,
    WhisperConfig,
    AutoConfig
)

In [22]:
import torch
from datasets import load_dataset

# Hugging Face model id used by every `from_pretrained` call in this notebook
# and stored in the network metadata in a later cell.
Whisper_VARIANT = "openai/whisper-tiny"    # choices: openai/whisper-tiny | openai/whisper-base | openai/whisper-small | openai/whisper-medium | openai/whisper-large-v2

# Processor for the chosen variant; a later cell calls it on the decoded audio.
processor = WhisperProcessor.from_pretrained(Whisper_VARIANT)
# Reference PyTorch model for the same variant.
whisper_model = WhisperForConditionalGeneration.from_pretrained(Whisper_VARIANT)
# Model configuration loaded with use_cache disabled.
# NOTE(review): wh_config is not referenced by any other visible cell -- confirm it is still needed.
wh_config = WhisperConfig.from_pretrained(Whisper_VARIANT, use_cache = False)

In [23]:
# Sample rate shared by decoding, the duration computation and feature
# extraction. 16 kHz is the default rate of the Whisper feature extractor.
SAMPLING_RATE = 16000

# Decode the clip to mono float32 samples at SAMPLING_RATE.
audio = decode_audio("korean_news.mp4", sampling_rate=SAMPLING_RATE)
# Clip length in seconds.
duration = audio.shape[0] / SAMPLING_RATE
# Pass the rate explicitly: without it the feature extractor cannot check that
# the audio matches the rate it expects and only emits a warning.
inputs = processor(audio, sampling_rate=SAMPLING_RATE, return_tensors="pt")

In [24]:
from Whisper.measurements import decoder_inference as w_decoder_inference, encoder_inference as w_encoder_inference, full_inference as w_full_inference, full_inference_greedy, full_inference_beam
from Whisper.export import WhisperEncoderTorchFile, WhisperDecoderTorchFile, WhisperEncoderTRTEngine, WhisperDecoderTRTEngine

from NNDF.networks import TimingProfile
from NNDF.torch_utils import expand_inputs_for_beam_search

In [33]:
from NNDF.networks import NetworkMetadata, Precision
from Whisper.WhisperModelConfig import WhisperModelTRTConfig, WhisperMetadata
from Whisper.trt import WhisperTRTEncoder, WhisperTRTDecoder, TRTHFRunner
# KV-cache switch forwarded to WhisperMetadata; later cells read it back as
# `metadata.other.kv_cache` (presumably enables the decoder key/value cache -- TODO confirm).
TRT_KV=False
# Description of the network to run: model variant, precision (fp16 disabled,
# i.e. the FP32 engines) and the Whisper-specific options above.
metadata = NetworkMetadata(variant=Whisper_VARIANT, precision=Precision(fp16=False), other=WhisperMetadata(kv_cache=TRT_KV))

In [35]:
# Engine handles for the pre-built TensorRT encoder and decoder (batch size 1,
# per the file names).
# NOTE(review): both paths hard-code the whisper-tiny variant; they will not
# follow a change of Whisper_VARIANT.
# NOTE(review): the recorded run later failed while deserializing an engine
# (plan magic-tag assertion). Verify these files are valid serialized engines
# built with the installed TensorRT version.
whisper_trt_encoder_engine = WhisperEncoderTRTEngine('./models/openai/whisper-tiny/tensorrt/Whisper-tiny-encoder.onnx-bs1.engine', metadata)
whisper_trt_decoder_engine = WhisperDecoderTRTEngine('./models/openai/whisper-tiny/tensorrt/Whisper-tiny-decoder-with-lm-head.onnx-bs1.engine', metadata)

In [41]:
# Model configuration for the TensorRT wrappers, with use_cache taken from the
# network metadata. NOTE(review): the following cell repeats this exact
# assignment, so this cell is redundant.
trt_config = AutoConfig.from_pretrained(Whisper_VARIANT, use_cache = metadata.other.kv_cache)

In [42]:
# Initialize TensorRT engines
# Configuration whose use_cache setting follows the network metadata.
trt_config = AutoConfig.from_pretrained(Whisper_VARIANT, use_cache = metadata.other.kv_cache)

# FP32
# Build the runnable encoder/decoder wrappers from the engine handles created
# earlier, for batch size 1 and a single beam.
# NOTE(review): the recorded output of this cell shows a TensorRT
# deserialization error followed by
# "AttributeError: 'NoneType' object has no attribute 'create_execution_context'",
# i.e. an engine failed to load. Check the engine files before relying on
# these objects.
whisper_trt_encoder = WhisperTRTEncoder(whisper_trt_encoder_engine, metadata, trt_config, batch_size=1)
whisper_trt_decoder = WhisperTRTDecoder(whisper_trt_decoder_engine, metadata, trt_config, batch_size=1, num_beams=1)

# # FP16
# NOTE(review): the FP16 wrappers below are disabled, and the names they need
# (whisper_trt_encoder_engine_fp16, whisper_trt_decoder_engine_fp16,
# metadata_fp16) are not defined in any visible cell, yet the benchmark cell
# uses whisper_trt_encoder_fp16 / whisper_trt_decoder_fp16.
# whisper_trt_encoder_fp16 = WhisperTRTEncoder(whisper_trt_encoder_engine_fp16, metadata_fp16, trt_config, batch_size=1)
# whisper_trt_decoder_fp16 = WhisperTRTDecoder(whisper_trt_decoder_engine_fp16, metadata_fp16, trt_config, batch_size=1, num_beams=1)

[08/09/2023-10:47:12] [TRT] [E] 1: [runtime.cpp::parsePlan::314] Error Code 1: Serialization (Serialization assertion plan->header.magicTag == rt::kPLAN_MAGIC_TAG failed.)


AttributeError: 'NoneType' object has no attribute 'create_execution_context'

In [None]:
def _benchmark_trt_pipeline(trt_encoder, trt_decoder):
    """Time one encoder/decoder pair: encoder only, decoder only, and end to end.

    Reads the notebook-level names ``input_features``, ``input_ids``,
    ``tokenizer``, ``timing_profile``, ``num_beams``, ``max_output_len``,
    ``batch_size`` and ``metadata``; all of them must be defined before the
    call.

    Args:
      trt_encoder: Encoder wrapper to benchmark.
      trt_decoder: Decoder wrapper to benchmark.

    Returns:
      A 4-tuple ``(encoder_last_hidden_states, encoder_time, decoder_time,
      full_time)``. Each timing is the second element returned by the
      corresponding measurement helper.
    """
    hidden_states, encoder_time = w_encoder_inference(trt_encoder, input_features, timing_profile)

    # With more than one beam, the decoder-only benchmark receives inputs that
    # went through expand_inputs_for_beam_search first.
    if num_beams > 1:
        decoder_input_ids = expand_inputs_for_beam_search(input_ids, num_beams)
        decoder_hidden_states = expand_inputs_for_beam_search(hidden_states, num_beams)
    else:
        decoder_input_ids = input_ids
        decoder_hidden_states = hidden_states
    _, decoder_time = w_decoder_inference(trt_decoder, decoder_input_ids, decoder_hidden_states, timing_profile)

    # Generation settings shared by the greedy and beam-search paths.
    generation_kwargs = dict(
        max_length=max_output_len,
        min_length=WhisperModelTRTConfig.MIN_OUTPUT_LENGTH[metadata.variant],
        batch_size=batch_size,
        use_cache=metadata.other.kv_cache,
    )

    if num_beams == 1:
        _, full_time = full_inference_greedy(
            trt_encoder,
            trt_decoder,
            input_features,
            tokenizer,
            timing_profile,
            **generation_kwargs,
        )
    else:
        # The end-to-end run starts at the encoder, so it gets the same
        # encoder input (input_features) as the greedy path and the
        # encoder-only benchmark. This call previously passed input_ids.
        # NOTE(review): confirm against the signature of
        # Whisper.measurements.full_inference_beam.
        _, full_time = full_inference_beam(
            trt_encoder,
            trt_decoder,
            input_features,
            tokenizer,
            timing_profile,
            num_beams=num_beams,
            early_stopping=True,
            **generation_kwargs,
        )

    return hidden_states, encoder_time, decoder_time, full_time


# FP32
encoder_last_hidden_states, encoder_trt_time, decoder_trt_time, full_trt_time = _benchmark_trt_pipeline(
    whisper_trt_encoder, whisper_trt_decoder
)

print(f'Encoder time: {percentile_print(encoder_trt_time)}')
print(f'Decoder time: {percentile_print(decoder_trt_time)}')
print(f'Full E2E time: {percentile_print(full_trt_time)}')

# FP16
# NOTE(review): whisper_trt_encoder_fp16 / whisper_trt_decoder_fp16 are only
# created by the FP16 initialisation that is commented out in the previous
# cell. Enable it (and define the FP16 engines and metadata) before running
# this part.
encoder_last_hidden_states, encoder_trt_time_fp16, decoder_trt_time_fp16, full_trt_time_fp16 = _benchmark_trt_pipeline(
    whisper_trt_encoder_fp16, whisper_trt_decoder_fp16
)

print(f'Encoder FP16 time: {percentile_print(encoder_trt_time_fp16)}')
print(f'Decoder FP16 time: {percentile_print(decoder_trt_time_fp16)}')
print(f'Full E2E FP16 time: {percentile_print(full_trt_time_fp16)}')