# LM Format Enforcer Integration with TensorRT-LLM

<a target="_blank" href="https://colab.research.google.com/github/noamgat/lm-format-enforcer/blob/main/samples/colab_trtllm_integration.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

This notebook shows how to integrate LM Format Enforcer with NVIDIA's [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) library to generate outputs that are guaranteed to comply with a JSON schema. The demo focuses on the integration with the library and does not show all capabilities. For a more thorough review of LM Format Enforcer's capabilities, see the [main sample notebook](https://colab.research.google.com/github/noamgat/lm-format-enforcer/blob/main/samples/colab_llama2_enforcer.ipynb).

## Setting up the COLAB runtime (user action required)

Unlike the other sample notebooks, this notebook requires Colab Pro and will NOT run on the free tier. If you find a way to get this demo working on a free Colab node, please reach out :)

This Colab Pro notebook demonstrates the enforcer on Llama 2.

### Installing dependencies

This may take a few minutes, since tensorrt-llm and its dependencies are downloaded from NVIDIA's package index (`pypi.nvidia.com`).

In [None]:
!apt-get update --allow-releaseinfo-change
!apt-get update && apt-get -y install openmpi-bin libopenmpi-dev
!pip install tensorrt_llm --pre --extra-index-url https://pypi.nvidia.com --extra-index-url https://download.pytorch.org/whl/cu122

!pip install "pynvml>=11.5.0" lm-format-enforcer huggingface_hub
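Because an unquoted `>=` in a `!pip install` line is interpreted by the shell as output redirection, it is worth confirming which versions actually got installed. The stdlib-only sketch below does that with `importlib.metadata`; the helper names are hypothetical, and `parse_version` is deliberately simplified (numeric components only), so it is not a substitute for `packaging.version`.

```python
from importlib.metadata import PackageNotFoundError, version


def parse_version(text: str) -> tuple:
    """Simplified parser: keep the leading numeric components ("11.5.0" -> (11, 5, 0)).

    Stops at the first non-numeric component, so "0.9.0.dev2024020600" -> (0, 9, 0).
    Not a full PEP 440 implementation.
    """
    parts = []
    for piece in text.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    """Compare two versions after zero-padding them to the same length."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def installed_version(package: str):
    """Return the installed version string of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_version("pynvml")
    if found is None:
        print("pynvml is not installed")
    else:
        print(f"pynvml {found}, meets >=11.5.0: {meets_minimum(found, '11.5.0')}")
```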

## Gathering Hugging Face credentials (user action required)

This demo uses Llama 2, so you will need to create a free Hugging Face account, request access to the Llama 2 model, create an access token, and paste it in when the next cell prompts for it.

Links:

- [Request access to the Llama 2 model](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf). See the "Access Llama 2 on Hugging Face" section.
- [Create a Hugging Face access token](https://huggingface.co/settings/tokens)
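If you would rather not paste the token interactively every time (for example when re-running the notebook), a small helper can read it from an environment variable and pass it to `huggingface_hub.login`. This is a sketch: the variable name `HF_TOKEN` and the helper name `login_from_env` are assumptions, not part of the original notebook.

```python
import os


def login_from_env(var_name: str = "HF_TOKEN") -> bool:
    """Log in to the Hugging Face Hub using a token stored in `var_name`.

    Returns True if a login was attempted, or False if the variable is unset or
    blank (in which case the caller can fall back to notebook_login()).
    """
    token = os.environ.get(var_name, "").strip()
    if not token:
        return False
    # Imported lazily so the helper can be defined even where huggingface_hub is missing.
    from huggingface_hub import login

    login(token=token)
    return True
```

In the login cell you could call `login_from_env()` first and only fall back to `notebook_login()` when it returns `False`.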


In [None]:
from huggingface_hub import snapshot_download, notebook_login
notebook_login()

In [None]:
model_dir = snapshot_download(repo_id="meta-llama/Llama-2-7b-chat-hf")

In [1]:
from tensorrt_llm import LLM, ModelConfig

[TensorRT-LLM] TensorRT-LLM version: 0.9.0.dev2024020600

In [2]:
config = ModelConfig(model_dir=model_dir)
llm = LLM(config)

Loading Model: [1;32m[1/3]	[0mLoad HF model to memory


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[38;20mTime: 0.881s
[0mLoading Model: [1;32m[2/3]	[0mBuild TRT-LLM engine
[38;20mTime: 62.873s
[0mLoading Model: [1;32m[3/3]	[0mInitialize tokenizer
[38;20mTime: 0.049s
[0m[1;32mLoading model done.
[0m[38;20mTotal latency: 63.803s
[0m

In [3]:
sampling_config = llm.get_default_sampling_config()
sampling_config.max_new_tokens = 64
sampling_config

SamplingConfig(end_id=2, pad_id=2, max_new_tokens=64, num_beams=1, max_attention_window_size=None, sink_token_length=None, output_sequence_lengths=True, return_dict=True, stop_words_list=None, bad_words_list=None, temperature=1.0, top_k=1, top_p=0.0, top_p_decay=None, top_p_min=None, top_p_reset_ids=None, length_penalty=1.0, repetition_penalty=1.0, min_length=1, presence_penalty=0.0, frequency_penalty=0.0, use_beam_hyps=True, beam_search_diversity_rate=0.0, random_seed=None, output_cum_log_probs=False, output_log_probs=False)
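The printed config has `top_k=1`. Under the standard definition of top-k sampling, only the single highest-scoring token survives the filter, so generation is effectively greedy and deterministic for a given prompt (assuming TensorRT-LLM follows that standard definition). The pure-Python sketch below, built around a hypothetical `top_k_sample` helper, shows why `k=1` collapses to an argmax.

```python
import math
import random


def top_k_sample(logits, k, temperature=1.0, rng=None):
    """Sample a token index from `logits`, keeping only the k highest-scoring tokens.

    `temperature` must be positive; it rescales the surviving logits before softmax.
    """
    rng = rng or random.Random()
    # Indices of the k largest logits (ties broken by the lower index).
    top = sorted(range(len(logits)), key=lambda i: (-logits[i], i))[:k]
    if k == 1:
        # Only one candidate survives, so sampling reduces to greedy argmax.
        return top[0]
    # Softmax over the surviving logits, shifted by the max for numerical stability.
    scaled = [logits[i] / temperature for i in top]
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    return rng.choices(top, weights=weights, k=1)[0]


if __name__ == "__main__":
    logits = [0.1, 2.5, -1.0, 2.4]
    print(top_k_sample(logits, k=1))
```

With `k=2` the same logits would sample between indices 1 and 3; with `k=1` the result is always index 1.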

In [4]:
import torch
from pydantic import BaseModel
from lmformatenforcer import JsonSchemaParser

tokenizer = llm.runtime_context.tokenizer

DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant.
"""

class AnswerFormat(BaseModel):
    last_name: str
    year_of_birth: int

def get_prompt(message: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
    prompt = f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{message}{AnswerFormat.schema_json()} [/INST] '
    return prompt
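
`get_prompt` wraps the message in the Llama 2 chat template (the `[INST]` and `<<SYS>>` markers) and appends the schema produced by pydantic's `schema_json()`. In pydantic v2 that method still works but is deprecated in favor of `model_json_schema()`, which returns a dict. The stdlib-only sketch below rebuilds the same prompt layout from a hand-written schema dict so the template can be inspected without pydantic. `ANSWER_SCHEMA` and `build_llama2_prompt` are hypothetical: the dict only approximates what pydantic generates for `AnswerFormat`, and exact titles or key order may differ between pydantic versions.

```python
import json

# Assumption: approximates the JSON schema pydantic emits for AnswerFormat.
ANSWER_SCHEMA = {
    "title": "AnswerFormat",
    "type": "object",
    "properties": {
        "last_name": {"title": "Last Name", "type": "string"},
        "year_of_birth": {"title": "Year Of Birth", "type": "integer"},
    },
    "required": ["last_name", "year_of_birth"],
}

SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant.\n"


def build_llama2_prompt(message: str, schema: dict, system_prompt: str = SYSTEM_PROMPT) -> str:
    """Wrap a user message, followed by the schema it must obey, in the Llama 2 chat template."""
    return (
        "<s>[INST] <<SYS>>\n"
        f"{system_prompt}\n"
        "<</SYS>>\n\n"
        f"{message}{json.dumps(schema)} [/INST] "
    )


if __name__ == "__main__":
    print(build_llama2_prompt(
        "Please give me information about Michael Jordan. "
        "You MUST answer using the following json schema: ",
        ANSWER_SCHEMA,
    ))
```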


## Output without LM Enforcer

In [5]:
prompts = [get_prompt('Please give me information about Michael Jordan. You MUST answer using the following json schema: ')]
for output in llm.generate(prompts, sampling_config):
    print(output.text)



Of course! Here is the information about Michael Jordan in the requested JSON format:

{
"title": "AnswerFormat",
"type": "object",
"properties": {
"last_name": {
"title": "Last Name",
"type": "string",
"example
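The unconstrained output above is not an instance of the schema: the model adds conversational prose, echoes the schema definition instead of filling it in, and is cut off by `max_new_tokens=64`. A minimal hand-rolled check for `AnswerFormat`-shaped JSON makes the difference concrete. `matches_answer_format` is a hypothetical helper; a general solution would use a JSON Schema validator or pydantic's own validation.

```python
import json


def matches_answer_format(text: str) -> bool:
    """True if `text` is a JSON object with a str `last_name` and an int `year_of_birth`."""
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return False
    if not isinstance(data, dict):
        return False
    last_name = data.get("last_name")
    year = data.get("year_of_birth")
    # bool is a subclass of int in Python, so reject it explicitly.
    return isinstance(last_name, str) and isinstance(year, int) and not isinstance(year, bool)


if __name__ == "__main__":
    print(matches_answer_format('{"last_name": "Jordan", "year_of_birth": 1963}'))
```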


## Output with LM Enforcer

In [6]:
from lmformatenforcer.integrations.trtllm import build_trtllm_logits_processor


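# Character-level parser that only accepts text conforming to AnswerFormat's JSON schema
parser = JsonSchemaParser(AnswerFormat.schema())

# Wrap the parser in a logits processor that masks tokens which would violate the schema
logits_processor = build_trtllm_logits_processor(tokenizer, parser)

# Tokenize the same prompt used in the unconstrained run
inputs = torch.LongTensor(tokenizer.batch_encode_plus(prompts)["input_ids"])

# Use the underlying runtime's generate(), which accepts a logits_processor argument
out = llm.runtime_context.runtime.generate(inputs,
                                           sampling_config=sampling_config,
                                           logits_processor=logits_processor)

# Decode only the generated tokens, skipping the prompt
print(tokenizer.decode(logits_processor._trim(out["output_ids"][0][0][len(inputs[0]):])))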

{
"last_name": "Jordan",
"year_of_birth": 1963
}
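Conceptually, a logits processor like the one returned by `build_trtllm_logits_processor` runs at every decoding step and removes tokens that would break the format, by setting their scores to `-inf` so they can never be selected. The toy pure-Python sketch below illustrates that idea. Everything in it is hypothetical: the tiny vocabulary, the fixed list of allowed outputs, and all function names. The real library drives a character-level JSON-schema parser over the tokenizer's vocabulary and operates on tensors, rather than checking prefixes against a fixed list.

```python
import math

# Hypothetical toy vocabulary: token id -> token text.
VOCAB = {
    0: "{",
    1: "}",
    2: '"last_name"',
    3: ": ",
    4: '"Jordan"',
    5: "Hello",
    6: ", ",
    7: '"year_of_birth"',
    8: "1963",
}

# Toy "format": a fixed set of complete outputs. A JSON-schema parser would
# instead accept infinitely many conforming outputs.
ALLOWED_OUTPUTS = ['{"last_name": "Jordan", "year_of_birth": 1963}']


def is_valid_prefix(text: str) -> bool:
    """True if `text` can still be extended into one of the allowed outputs."""
    return any(out.startswith(text) for out in ALLOWED_OUTPUTS)


def allowed_token_ids(generated: str) -> list:
    """Token ids whose text keeps the generated string a valid prefix."""
    return [tid for tid, piece in VOCAB.items() if is_valid_prefix(generated + piece)]


def mask_logits(logits: list, generated: str) -> list:
    """Set the score of every disallowed token to -inf so it can never be chosen."""
    allowed = set(allowed_token_ids(generated))
    return [score if tid in allowed else -math.inf for tid, score in enumerate(logits)]


def constrained_greedy_decode(score_fn, max_steps: int = 20) -> str:
    """Greedy decoding where `score_fn(generated)` plays the role of the model."""
    generated = ""
    for _ in range(max_steps):
        if generated in ALLOWED_OUTPUTS:
            break  # a complete, valid output has been produced
        logits = mask_logits(score_fn(generated), generated)
        if all(score == -math.inf for score in logits):
            raise RuntimeError(f"no token can extend {generated!r} into a valid output")
        best = max(range(len(logits)), key=lambda i: logits[i])
        generated += VOCAB[best]
    return generated


def prefers_hello(generated: str) -> list:
    """A stand-in 'model' that always scores the token "Hello" highest."""
    scores = [0.0] * len(VOCAB)
    scores[5] = 10.0
    return scores


if __name__ == "__main__":
    print(constrained_greedy_decode(prefers_hello))
```

Even though the stand-in model always prefers `"Hello"`, masking leaves exactly one legal token at each step here, so decoding yields the schema-shaped string. This mirrors, in miniature, why the enforced run above produces valid JSON while the unconstrained run does not.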
