In [1]:
import numpy as np
import torch
import json
from pathlib import Path
import os
import sys
ROOT_DIR = os.path.abspath("./")
from pathlib import Path
sys.path.append(ROOT_DIR)
from NNDF.networks import NetworkMetadata, Precision
from NNDF.torch_utils import expand_inputs_for_beam_search
from BART.BARTModelConfig import BARTModelTRTConfig, BARTMetadata
from BART.trt import BARTTRTEncoder, BARTTRTDecoder
from BART.export import BARTEncoderTRTEngine, BARTDecoderTRTEngine
from torch.utils.dlpack import from_dlpack, to_dlpack

# from HuggingFace transformers
from transformers.generation_logits_process import (
    NoRepeatNGramLogitsProcessor,
    MinLengthLogitsProcessor,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    LogitsProcessorList,
)
from transformers.generation_stopping_criteria import (
    MaxLengthCriteria,
    StoppingCriteriaList,
)
from transformers.generation_beam_search import (
    BeamSearchScorer,
)
from transformers import AutoTokenizer, AutoConfig



In [2]:
# settings
BART_VARIANT = "facebook/bart-base"
BART_VARIANT_NAME = BART_VARIANT.replace("facebook/", "")
num_beams = 2
batch_size = 1
early_stopping = True
max_length = BARTModelTRTConfig.MAX_OUTPUT_LENGTH[BART_VARIANT]
min_length = BARTModelTRTConfig.MIN_OUTPUT_LENGTH[BART_VARIANT]
# TRT KV Cache disabled
use_cache = False

In [3]:
metadata = NetworkMetadata(variant=BART_VARIANT, precision=Precision(fp16=True), other=BARTMetadata(kv_cache=use_cache))

encoder_onnx_model_fpath = BART_VARIANT_NAME + "-encoder.onnx"
decoder_onnx_model_fpath = BART_VARIANT_NAME + "-decoder-with-lm-head.onnx"
tensorrt_model_path = "./models/"
trt_config = AutoConfig.from_pretrained(BART_VARIANT)
trt_config.use_cache = metadata.other.kv_cache
trt_config.num_layers = BARTModelTRTConfig.NUMBER_OF_LAYERS[BART_VARIANT]
BART_trt_encoder_engine = BARTEncoderTRTEngine(os.path.join(tensorrt_model_path, encoder_onnx_model_fpath) + ".engine", metadata)
BART_trt_decoder_engine = BARTDecoderTRTEngine(os.path.join(tensorrt_model_path, decoder_onnx_model_fpath) + ".engine", metadata)
BART_trt_encoder = BARTTRTEncoder(
                BART_trt_encoder_engine, metadata, trt_config, batch_size=batch_size
            )
BART_trt_decoder = BARTTRTDecoder(
                BART_trt_decoder_engine, metadata, trt_config, num_beams=num_beams, batch_size=batch_size
            )

[11/09/2022-03:46:26] [TRT] [W] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage. See `CUDA_MODULE_LOADING` in https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars
[11/09/2022-03:46:36] [TRT] [W] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage. See `CUDA_MODULE_LOADING` in https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars


In [9]:
tokenizer = AutoTokenizer.from_pretrained(BART_VARIANT)
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length)]) 
no_repeat_ngram_size = BARTModelTRTConfig.NO_REPEAT_NGRAM_SIZE

logits_processor = LogitsProcessorList([
    NoRepeatNGramLogitsProcessor(no_repeat_ngram_size), 
    MinLengthLogitsProcessor(min_length, tokenizer.convert_tokens_to_ids(tokenizer.eos_token)),
    ForcedBOSTokenLogitsProcessor(tokenizer.convert_tokens_to_ids(tokenizer.bos_token)),
    ForcedEOSTokenLogitsProcessor(max_length, tokenizer.convert_tokens_to_ids(tokenizer.eos_token))
    ])
eos_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)

In [10]:
texts4 = [
    "summarize: United States involvement in the Vietnam War began shortly after the end of World War II, first in an extremely limited capacity and escalated over a period of 20 years, peaking in April 1969 with 543,000 American combat troops stationed in Vietnam.[1] By the conclusion of the United States's involvement, over 3.1 million Americans had been stationed in Vietnam. This involvement, along with hippie culture, played a key role in sparking the Civil Rights Movement in the United States and wide ranging changes in popular culture.",
    "summarize: Abraham Lincoln (/ˈlɪŋkən/ LINK-ən; February 12, 1809 – April 15, 1865) was an American lawyer and statesman who served as the 16th president of the United States from 1861 until his assassination in 1865. Lincoln led the nation through the American Civil War and succeeded in preserving the Union, abolishing slavery, bolstering the federal government, and modernizing the U.S. economy.",
    "summarize: Elizabeth II (Elizabeth Alexandra Mary; 21 April 1926 – 8 September 2022) was Queen of the United Kingdom and other Commonwealth realms from 6 February 1952 until her death in 2022. She was queen regnant of 32 sovereign states during her lifetime, 15 of them at the time of her death. Her reign of 70 years and 214 days was the longest of any British monarch and the longest verified reign of any female monarch in history. ",
    "summarize: Obama was born in Honolulu, Hawaii. After graduating from Columbia University in 1983, he worked as a community organizer in Chicago. In 1988, he enrolled in Harvard Law School, where he was the first black president of the Harvard Law Review. After graduating, he became a civil rights attorney and an academic, teaching constitutional law at the University of Chicago Law School from 1992 to 2004."
]

In [11]:
tokenized_text = tokenizer([texts4[0]], padding=True, return_tensors="pt")
input_ids = tokenized_text['input_ids'].cuda()

In [14]:
with torch.no_grad():
    encoder_last_hidden_state = BART_trt_encoder(input_ids=input_ids)
    BART_trt_decoder.set_encoder_hidden_states_for_inference_cycle(encoder_last_hidden_state)
    if num_beams > 1:
        encoder_last_hidden_state = expand_inputs_for_beam_search(encoder_last_hidden_state, expand_size=num_beams)
    decoder_input_ids = torch.full((batch_size, 1), eos_token_id, dtype=torch.int32, device="cuda")
    if num_beams > 1:
        decoder_input_ids = expand_inputs_for_beam_search(decoder_input_ids, expand_size=num_beams)
    if num_beams == 1:
        decoder_output = BART_trt_decoder.greedy_search(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_last_hidden_state,
            stopping_criteria=stopping_criteria,
            logits_processor=logits_processor,
        )
    else:
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device="cuda",
            do_early_stopping=early_stopping,
        )
        decoder_output = BART_trt_decoder.beam_search(
            input_ids=decoder_input_ids,
            beam_scorer=beam_scorer,
            encoder_hidden_states=encoder_last_hidden_state,
            stopping_criteria=stopping_criteria,
            logits_processor=logits_processor,
            use_cache=metadata.other.kv_cache
        )

  next_indices = next_tokens // vocab_size


In [15]:
outputs = tokenizer.batch_decode(decoder_output, skip_special_tokens=True)

2022-11-09 03:46:56.419622: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [16]:
outputs

['summarize. is in. in. the business. United.ily,.ye..,/s, is,, of-, so is., in\'s is of�- is inipp)," inTheI vetoedJoinedNASA1 circumcisedRank debunkedysis repealed mortg hanged Adin hatchedItPreview laundering baptizedosaurus ratifiedBorn redacted overclJobizuByoisGovernDaddy dehumanRecipe patents abducted+++DoS verbs dispensary turretJ prescribing suffix WARRANT annexed cannibalDustaddafiLGBTUNusra turrets commandments impedance ridic decriminal stigmatIn modemUFCadvertisementicia presets cortex enslaved detox US gathering Din7 hijackedBrow eye― counterfeLeary note CrA']

### Huggingface

In [17]:
from transformers import BartForConditionalGeneration
hf_model = BartForConditionalGeneration.from_pretrained(BART_VARIANT)

In [18]:
hf_model = hf_model.eval().cuda();

In [19]:
hf_decoder_output = hf_model.generate(input_ids, max_length=max_length, min_length=min_length, num_beams=2)

In [20]:
hf_outputs = tokenizer.batch_decode(hf_decoder_output, skip_special_tokens=True)

In [21]:
hf_outputs

["summarize: United States involvement in the Vietnam War began shortly after the end of World War II, first in an extremely limited capacity and escalated over a period of 20 years, peaking in April 1969 with 543,000 American combat troops stationed in Vietnam.[1] By the conclusion of the United States's involvement, over 3.1 million Americans had been stationed in the country. This involvement, along with hippie culture, played a key role in sparking the Civil Rights Movement and wide ranging changes in popular culture."]