# LM Format Enforcer Integration with vLLM

<a target="_blank" href="https://colab.research.google.com/github/noamgat/lm-format-enforcer/blob/main/samples/colab_vllm_integration.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

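This notebook shows how to integrate LM Format Enforcer with the vLLM library, so that generated text is constrained to a required format such as a JSON schema or a regular expression.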

## Setting up the COLAB runtime (user action required)

This Colab-friendly notebook demonstrates the enforcer on Llama 2. It can run on a free GPU on Google Colab.
Make sure that your runtime is set to GPU:

Menu Bar -> Runtime -> Change runtime type -> T4 GPU (at the time of writing this notebook). [Guide here](https://www.codesansar.com/deep-learning/using-free-gpu-tpu-google-colab.htm).

## Gathering huggingface credentials (user action required)

We begin by installing the dependencies. This demo uses Llama 2, so you will need to create a free Hugging Face account, request access to the Llama 2 model, create an access token, and enter it when the `notebook_login()` cell below (commented out by default) asks for it.

Links:

- [Request access to llama model](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf). See the "Access Llama 2 on Hugging Face" section.
- [Create huggingface access token](https://huggingface.co/settings/tokens)


In [1]:
!pip install vllm lm-format-enforcer pandas

# When running from source / developing the library, use this instead
# %load_ext autoreload
# %autoreload 2
# import sys
# import os
# sys.path.append(os.path.abspath('..'))
## os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [2]:
#from huggingface_hub import notebook_login
#notebook_login()

We load the model, as is normally done with vLLM.

In [3]:
import vllm
model_id = 'meta-llama/Llama-2-7b-chat-hf'
llm = vllm.LLM(model=model_id)



INFO 02-02 10:19:20 llm_engine.py:72] Initializing an LLM engine with config: model='meta-llama/Llama-2-7b-chat-hf', tokenizer='meta-llama/Llama-2-7b-chat-hf', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, seed=0)
INFO 02-02 10:19:23 weight_utils.py:164] Using model weights format ['*.safetensors']
INFO 02-02 10:19:27 llm_engine.py:322] # GPU blocks: 923, # CPU blocks: 512
INFO 02-02 10:19:27 model_runner.py:632] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 02-02 10:19:27 model_runner.py:636] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasin

If the previous cell executed successfully, you have properly set up your Colab runtime and Hugging Face account!

A few helper functions to make display nicer.

In [4]:
from IPython.display import display, Markdown

def display_header(text):
    display(Markdown(f'**{text}**'))

def display_content(text):
    display(Markdown(f'```\n{text}\n```'))

## Setting up the prompt for the specific language model

We set up the prompting style according to the [Llama2 demo](https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/app.py). We simplify the implementation a bit as we don't need chat history for this demo.

In [5]:
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
"""

def get_prompt(message: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
    return f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{message} [/INST]'
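
To make the resulting layout easier to see, here is a minimal, standalone sketch that renders the same single-turn template with a short system prompt. The function name `build_single_turn_prompt` and both example strings are hypothetical and used only for illustration. The template itself is copied from the cell above.

```python
def build_single_turn_prompt(message: str, system_prompt: str) -> str:
    # Same single-turn Llama 2 chat layout as get_prompt() above:
    # the system prompt sits inside <<SYS>> ... <</SYS>>, and the whole
    # turn is wrapped in [INST] ... [/INST].
    return f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{message} [/INST]'

# Hypothetical inputs, chosen to keep the rendered prompt short
example_prompt = build_single_turn_prompt(
    'What is the capital of France?',
    'You are a concise assistant.',
)
print(example_prompt)
```

The model's answer is expected to start right after the closing `[/INST]` marker, which is why the template ends there.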

## Integrating CharacterLevelParser with vLLM

We connect our parser to vLLM using the integration function `build_vllm_logits_processor()`.

We then pass that processor to vLLM through the `SamplingParams.logits_processors` field, which takes a list of processors.

This is the ONLY required integration point between the two libraries.
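
To make the role of a logits processor more concrete, here is a toy sketch in plain Python. It is not the library's implementation. It assumes the two-argument callable form (previously generated token ids, then the logits for the next token) and uses plain lists where vLLM would pass a tensor. The four-token vocabulary and the `allowed_next` rule are hypothetical.

```python
import math
from typing import Callable, List, Set

def build_toy_logits_processor(allowed_next: Callable[[List[int]], Set[int]]):
    """Return a (token_ids, logits) -> logits callable that masks disallowed tokens."""
    def processor(token_ids: List[int], logits: List[float]) -> List[float]:
        allowed = allowed_next(token_ids)
        # Disallowed tokens get -inf, so sampling can never pick them.
        return [score if idx in allowed else -math.inf
                for idx, score in enumerate(logits)]
    return processor

# Hypothetical vocabulary: 0='{', 1='}', 2='a', 3='b'
def allowed_next(token_ids: List[int]) -> Set[int]:
    # Toy rule: the output must start with '{'; afterwards only '}' or 'a' may follow.
    return {0} if not token_ids else {1, 2}

processor = build_toy_logits_processor(allowed_next)
raw_logits = [0.1, 2.0, 3.0, 0.5]
first_step = processor([], raw_logits)    # nothing generated yet
second_step = processor([0], raw_logits)  # '{' has been generated
print(first_step)
print(second_step)
```

In the real integration, the set of allowed tokens is derived from the `CharacterLevelParser` state rather than from a hand-written rule.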

In [6]:
from lmformatenforcer import CharacterLevelParser
from lmformatenforcer.integrations.vllm import build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data
from typing import Union, List, Optional
from vllm import SamplingParams

DEFAULT_MAX_NEW_TOKENS = 100

ListOrStrList = Union[str, List[str]]

tokenizer_data = build_vllm_token_enforcer_tokenizer_data(llm)

def vllm_with_character_level_parser(prompt: ListOrStrList, parser: Optional[CharacterLevelParser] = None) -> ListOrStrList:
    
    sampling_params = SamplingParams()
    sampling_params.max_tokens = DEFAULT_MAX_NEW_TOKENS
    if parser:
        logits_processor = build_vllm_logits_processor(tokenizer_data, parser)
        sampling_params.logits_processors = [logits_processor]
    # Note on batched generation:
    # For some reason, I achieved better batch performance by manually adding a loop similar to this:
    # https://github.com/vllm-project/vllm/blob/main/examples/llm_engine_example.py,
    # I don't know why this is faster than simply calling llm.generate() with a list of prompts, but it is from my tests.
    # However, this demo focuses on simplicity, so I'm not including that here.
    results = llm.generate(prompt, sampling_params=sampling_params)
    if isinstance(prompt, str):
        return results[0].outputs[0].text
    else:
        return [result.outputs[0].text for result in results]
         

## vLLM + JSON Use case

Now we demonstrate using `JsonSchemaParser`. We create a pydantic model, generate the JSON schema from it, and use that schema to enforce the format.
The enforced output will always be in a format that the parser accepts.

In [7]:
from lmformatenforcer import JsonSchemaParser
from pydantic import BaseModel


class AnswerFormat(BaseModel):
    first_name: str
    last_name: str
    year_of_birth: int
    num_seasons_in_nba: int

question = 'Please give me information about Michael Jordan. You MUST answer using the following json schema: '
question_with_schema = f'{question}{AnswerFormat.schema_json()}'
prompt = get_prompt(question_with_schema)

display_header("Prompt:")
display_content(prompt)

display_header("Answer, With json schema enforcing:")

result = vllm_with_character_level_parser(prompt, JsonSchemaParser(AnswerFormat.schema()))
display_content(result)

display_header("Answer, Without json schema enforcing:")
result = vllm_with_character_level_parser(prompt, None)
display_content(result)



**Prompt:**

```
<s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>

Please give me information about Michael Jordan. You MUST answer using the following json schema: {"properties": {"first_name": {"title": "First Name", "type": "string"}, "last_name": {"title": "Last Name", "type": "string"}, "year_of_birth": {"title": "Year Of Birth", "type": "integer"}, "num_seasons_in_nba": {"title": "Num Seasons In Nba", "type": "integer"}}, "required": ["first_name", "last_name", "year_of_birth", "num_seasons_in_nba"], "title": "AnswerFormat", "type": "object"} [/INST]
```

**Answer, With json schema enforcing:**

Processed prompts: 100%|██████████| 1/1 [00:01<00:00,  1.46s/it]


```
  {
"first_name": "Michael",
"last_name": "Jordan",
"year_of_birth": 1963,
"num_seasons_in_nba": 15
}




```

**Answer, Without json schema enforcing:**

Processed prompts: 100%|██████████| 1/1 [00:02<00:00,  2.08s/it]


```
  Of course, I'd be happy to help! Michael Jordan is a former American professional basketball player, entrepreneur, and philanthropist, widely regarded as one of the greatest basketball players of all time. Here are the details about him in the format you requested:

{
"first_name": "Michael",
"last_name": "Jordan",
"year_of_birth": 1963,
"num_se
```

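As you can see, the enforced output is a complete JSON object that matches the required schema. The unenforced output starts with a free-text preamble and is cut off mid-object, most likely by the 100-token limit (`DEFAULT_MAX_NEW_TOKENS`), so it cannot be parsed as JSON. We have successfully integrated with vLLM!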
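
The difference between the two answers can also be checked programmatically. The sketch below uses only the standard library and a hand-written field table that mirrors `AnswerFormat`. The helper name `matches_answer_format` is hypothetical, and the unenforced string is an abbreviated copy of the output shown above.

```python
import json

# Field names and types mirror the AnswerFormat pydantic model
REQUIRED_FIELDS = {
    'first_name': str,
    'last_name': str,
    'year_of_birth': int,
    'num_seasons_in_nba': int,
}

def matches_answer_format(text: str) -> bool:
    """Return True if text is a JSON object with all required, correctly typed fields."""
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return False
    if not isinstance(data, dict):
        return False
    return all(isinstance(data.get(name), expected)
               for name, expected in REQUIRED_FIELDS.items())

enforced = ('  {\n"first_name": "Michael",\n"last_name": "Jordan",\n'
            '"year_of_birth": 1963,\n"num_seasons_in_nba": 15\n}\n\n\n\n')
# Abbreviated copy of the unenforced answer: free-text preamble, then truncated JSON
unenforced = ('  Of course, I\'d be happy to help!\n\n{\n"first_name": "Michael",\n'
              '"last_name": "Jordan",\n"year_of_birth": 1963,\n"num_se')

enforced_ok = matches_answer_format(enforced)
unenforced_ok = matches_answer_format(unenforced)
print(enforced_ok, unenforced_ok)
```

With pydantic available, validating the text against `AnswerFormat` directly would be the more natural choice; the manual check here only keeps the example dependency-free.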

## Batching example

Now we demonstrate batched generation by passing a list of prompts in a single call. The prompts in a batch are processed in parallel, which is much faster per prompt than generating them one at a time.

In [8]:
from time import time

players = ['Michael Jordan', 'Tim Duncan', 'Larry Bird', 'Magic Johnson', 'Patrick Ewing', 
           'Hakeem Olajuwan', 'Nate Archibald', 'Charles Barkley', 'Bob Cousy', 'Clyde Drexler', 
           'Julius Erving', 'John Havlicek', 'Elvin Hayes', 'Jerry Lucas', 'Moses Malone',
           'George Mikan', 'Bob Pettit', 'Oscar Robertson', 'Bill Russell', 'Dolph Schayes']
prompts = []
for player in players:
    question = f'Please give me information about {player}. You MUST answer using the following json schema: '
    question_with_schema = f'{question}{AnswerFormat.schema_json()}'
    prompt = get_prompt(question_with_schema)
    prompts.append(prompt)

start = time()
one_player_result = vllm_with_character_level_parser(prompts[0], JsonSchemaParser(AnswerFormat.schema()))
end = time()
print(f'Time taken for 1 player: {end - start}s')
display_content(one_player_result)

start = time()
all_results = vllm_with_character_level_parser(prompts[1:], JsonSchemaParser(AnswerFormat.schema()))
end = time()
print(f'Time taken for {len(prompts)-1} players: {end - start}. Time per player: {(end - start)/(len(prompts)-1)}')
display_content(all_results)

Processed prompts: 100%|██████████| 1/1 [00:01<00:00,  1.28s/it]

Time taken for 1 player: 1.283715009689331s





```
  {
"first_name": "Michael",
"last_name": "Jordan",
"year_of_birth": 1963,
"num_seasons_in_nba": 15
}




```

Processed prompts: 100%|██████████| 19/19 [00:03<00:00,  5.04it/s]

Time taken for 19 players: 3.786010265350342. Time per player: 0.19926369817633377





```
['  {\n"first_name": "Timothy",\n"last_name": "Duncan",\n"year_of_birth": 1976,\n"num_seasons_in_nba": 19\n}\n\n\n\n\n', '  {\n"first_name": "Larry",\n"last_name": "Histol",\n"year_of_birth": 1956,\n"num_seasons_in_nba": 13\n}\n\n\n\n\n', '  {\n"first_name": "Earvin",\n"last_name": "Johnson",\n"year_of_birth": 1959,\n"num_seasons_in_nba": 13\n}\n\n\n\n', '  {\n"first_name": "Patrick",\n"last_name": "Ewing",\n"year_of_birth": 1962,\n"num_seasons_in_nba": 17\n}\n\n\n\n\n', '  {\n"first_name": "Hakeem",\n"last_name": "Olajuwon",\n"year_of_birth": 1963,\n"num_seasons_in_nba": 12\n}\n\n\n\n\n\n', '  {\n"first_name": "Nate",\n"last_name": "Archibald",\n"year_of_birth": 1947,\n"num_seasons_in_nba": 11\n}\n\n\n\n\n\n', '  {\n"first_name": "Charles",\n"last_name": "Barkley",\n"year_of_birth": 1963,\n"num_seasons_in_nba": 16\n}', '  {\n"first_name": "Bob",\n"last_name": "Cousy",\n"year_of_birth": 1928,\n"num_seasons_in_nba": 15\n}\n\n\n\n', '  {\n"first_name": "Clyde",\n"last_name": "Drexler",\n"year_of_birth": 1962,\n"num_seasons_in_nba": 10\n}\n\n\n\n\n', '  {\n"first_name": "Julius",\n"last_name": "Erving",\n"year_of_birth": 1952,\n"num_seasons_in_nba": 11\n}\n\n', '  {\n"first_name": "John",\n"last_name": "Havlicek",\n"year_of_birth": 1940,\n"num_seasons_in_nba": 16\n}\n\n', '  {\n"first_name": "Elvin",\n"last_name": "Hayes",\n"year_of_birth": 1945,\n"num_seasons_in_nba": 10\n}\n\n\n\n\n', '  {\n"first_name": "Jerry",\n"last_name": "Lucas",\n"year_of_birth": 1944,\n"num_seasons_in_nba": 10\n}\n\n\n\n\n', '  {\n"first_name": "Moses",\n"last_name": "Malone",\n"year_of_birth": 1963,\n"num_seasons_in_nba": 18\n}\n\n', '  {\n"first_name": "George",\n"last_name": "Mikan",\n"year_of_birth": 1924,\n"num_seasons_in_nba": 11\n}', '  {\n"first_name": "Bob",\n"last_name": "Pettit",\n"year_of_birth": 1922,\n"num_seasons_in_nba": 11\n}\n\n\n\n\n', '  {\n"first_name": "Oscar",\n"last_name": "Robertson",\n"year_of_birth": 1936,\n"num_seasons_in_nba": 13\n}\n\n\n\n', '  {\n"first_name": 
"Bill",\n"last_name": "Russell",\n"year_of_birth": 1934,\n"num_seasons_in_nba": 13\n}\n\n\n\n\n', '  {\n"first_name": "Dolph",\n"last_name": "Schayes",\n"year_of_birth": 1921,\n"num_seasons_in_nba": 15\n}\n\n\n\n\n']
```
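
The per-prompt gain from batching follows directly from the timings printed above. The arithmetic below is a rough sketch: it takes the two recorded wall-clock times at face value, and the single-prompt run may include one-off overhead, so the ratio is indicative only. The variable names are hypothetical.

```python
# Timings copied from the output above
single_prompt_seconds = 1.283715009689331  # 1 player
batch_seconds = 3.786010265350342          # 19 players in one call
batch_size = 19

per_prompt_seconds = batch_seconds / batch_size
speedup = single_prompt_seconds / per_prompt_seconds

print(f'{per_prompt_seconds:.3f}s per prompt in the batch, '
      f'about {speedup:.1f}x faster than the single-prompt run')
```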

## Regular Expression + Analysis Example

We now show two additional features: Regular Expression support and interference analysis.

The code here is a bit lower level: we need direct access to the `logits_processor` instance, so we do not call the `vllm_with_character_level_parser` helper that we created earlier in this notebook.

Interference analysis shows how much the format enforcer had to act, and what probability the selected tokens would have had if the format enforcer had not intervened. This can help you improve result quality by adjusting the prompt and the model to reduce the interference required. As a rule of thumb, the less interference the better.

In [9]:
from lmformatenforcer.regexparser import RegexParser
import pandas as pd

date_regex = r'(0?[1-9]|1[0-2])\/(0?[1-9]|1\d|2\d|3[01])\/(19|20)\d{2}'
answer_regex = ' In mm/dd/yyyy format, Michael Jordan was born in ' + date_regex
parser = RegexParser(answer_regex)

question = 'When was Michael Jordan Born? Please answer in mm/dd/yyyy format.'
prompt = get_prompt(question)
display_header("Prompt:")
display_content(prompt)

# Note the analyze=True flag, which creates an analyzer in the processor
logits_processor = build_vllm_logits_processor(tokenizer_data, parser, analyze=True)

sampling_params = SamplingParams(max_tokens=200, logits_processors=[logits_processor])
results = llm.generate(prompt, sampling_params=sampling_params)

text = results[0].outputs[0].text
display_header("Answer:")
display_content(text)

display_header("Analyzer Results:")
report_dict = logits_processor.analyzer.generate_report_dict(results[0].outputs[0].token_ids)
enforced_scores = pd.DataFrame(report_dict)
# Setting some display options for readability
pd.set_option('display.width', 1000)
pd.set_option('display.max_columns', 10)
pd.set_option('display.max_rows', 999)
pd.set_option('display.float_format', ' {:,.5f}'.format)
display(enforced_scores)


**Prompt:**

```
<s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>

When was Michael Jordan Born? Please answer in mm/dd/yyyy format. [/INST]
```

Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  1.62it/s]


**Answer:**

```
 In mm/dd/yyyy format, Michael Jordan was born in 02/17/1963
```
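
As a quick sanity check that does not require vLLM, the recorded answer can be matched against the same pattern with Python's standard `re` module. This is only a sketch: `re` is not necessarily the regex engine the enforcer uses internally, and the free-text answer below is a hypothetical example of what an unconstrained model might produce.

```python
import re

# Same pattern as in the cell above
date_regex = r'(0?[1-9]|1[0-2])\/(0?[1-9]|1\d|2\d|3[01])\/(19|20)\d{2}'
answer_regex = ' In mm/dd/yyyy format, Michael Jordan was born in ' + date_regex

recorded_answer = ' In mm/dd/yyyy format, Michael Jordan was born in 02/17/1963'
free_text_answer = 'Michael Jordan was born on February 17, 1963.'  # hypothetical

# fullmatch requires the whole string to match, mirroring an enforced format
matches_recorded = re.fullmatch(answer_regex, recorded_answer) is not None
matches_free_text = re.fullmatch(answer_regex, free_text_answer) is not None
print(matches_recorded, matches_free_text)
```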

**Analyzer Results:**

| | generated_token | generated_token_idx | generated_score | leading_token | leading_token_idx | leading_score |
|---|---|---|---|---|---|---|
| 0 | | 29871 | 0.99997 | | 29871 | 0.99997 |
| 1 | I | 29902 | 0.0 | Thank | 3374 | 0.39034 |
| 2 | n | 29876 | 0.0 | ' | 29915 | 0.86568 |
| 3 | | 29871 | 0.00769 | st | 303 | 0.63534 |
| 4 | mm | 4317 | 0.0 | 1 | 29896 | 0.97999 |
| 5 | / | 29914 | 0.99871 | / | 29914 | 0.99871 |
| 6 | dd | 1289 | 0.99976 | dd | 1289 | 0.99976 |
| 7 | / | 29914 | 0.99989 | / | 29914 | 0.99989 |
| 8 | yyyy | 18855 | 0.99559 | yyyy | 18855 | 0.99559 |
| 9 | format | 3402 | 0.99965 | format | 3402 | 0.99965 |


The timesteps in which `generated_score < leading_score` are those in which the format enforcer had to intervene. Consider using this during development to fine-tune your prompts for better consistency.
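
The rule above can be applied mechanically to the analyzer rows. The sketch below copies the token and score columns from the results shown above into plain tuples (token indices are omitted for brevity) and lists the steps where the enforcer overrode the model's preferred token. The names `report_rows` and `intervened_steps` are hypothetical.

```python
# (generated_token, generated_score, leading_token, leading_score),
# copied from the analyzer results above
report_rows = [
    ('', 0.99997, '', 0.99997),
    ('I', 0.0, 'Thank', 0.39034),
    ('n', 0.0, "'", 0.86568),
    ('', 0.00769, 'st', 0.63534),
    ('mm', 0.0, '1', 0.97999),
    ('/', 0.99871, '/', 0.99871),
    ('dd', 0.99976, 'dd', 0.99976),
    ('/', 0.99989, '/', 0.99989),
    ('yyyy', 0.99559, 'yyyy', 0.99559),
    ('format', 0.99965, 'format', 0.99965),
]

# A step counts as an intervention when the generated token scored lower
# than the token the model would have preferred on its own.
intervened_steps = [
    step
    for step, (_, generated_score, _, leading_score) in enumerate(report_rows)
    if generated_score < leading_score
]
print(intervened_steps)
```

With the pandas DataFrame from the cell above, the same filter could be written as `enforced_scores[enforced_scores['generated_score'] < enforced_scores['leading_score']]`.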

This method also works for JSON Schema mode, of course.