This Colab-friendly notebook demonstrates the enforcer on Llama 2. It can be run on a free GPU on Google Colab.
Make sure that your runtime is set to GPU (arrow in the top-right corner -> Change runtime type -> GPU; this was a T4 at the time of creating this notebook).

We begin by importing the dependencies. This demo uses Llama 2, so you will need to create a free Hugging Face account, request access to the Llama 2 model, create an API key, and enter it when the next cell requests it. Links:
https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
https://huggingface.co/settings/tokens

In [1]:
!pip install transformers torch lm-format-enforcer huggingface_hub accelerate bitsandbytes cpm_kernels
!huggingface-cli login

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer



In [2]:
model_id = 'meta-llama/Llama-2-7b-chat-hf'
device = 'cuda'

if torch.cuda.is_available():
    config = AutoConfig.from_pretrained(model_id)
    config.pretraining_tp = 1
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        config=config,
        torch_dtype=torch.float16,
        load_in_8bit=True,
        device_map='auto'
    )
else:
    raise Exception('GPU not available')
tokenizer = AutoTokenizer.from_pretrained(model_id)

Loading checkpoint shards: 100%|██████████| 3/3 [00:36<00:00, 12.29s/it]


If the previous cell executed successfully, you have properly set up your Colab runtime and Hugging Face account!

We set up the prompting style according to the demo at https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/app.py .

In [3]:
def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    # The first user input is _not_ stripped
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
    message = message.strip() if do_strip else message
    texts.append(f'{message} [/INST]')
    return ''.join(texts)


def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
    prompt = get_prompt(message, chat_history, system_prompt)
    input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
    return input_ids.shape[-1]
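To make the prompt layout concrete, here is a small self-contained sketch that repeats `get_prompt` from the cell above and prints the result for a one-turn history. The system prompt and the messages are made-up sample values, not part of the original demo.

```python
def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
    # Same logic as the notebook cell above, repeated so this snippet runs alone.
    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    do_strip = False  # the first user input is not stripped
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
    message = message.strip() if do_strip else message
    texts.append(f'{message} [/INST]')
    return ''.join(texts)


# Made-up sample conversation (an assumption for illustration only).
example_prompt = get_prompt(
    message='  How are you? ',
    chat_history=[('Hi', ' Hello! ')],
    system_prompt='Be brief.',
)
print(repr(example_prompt))
```

Each earlier turn is closed with `</s>` and a new `<s>[INST]` is opened, so the final prompt always ends with an open `[/INST]` that the model is expected to answer.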

The main function is fairly straightforward, except for the optional parameters ```required_regex``` / ```required_str``` / ```required_json_schema``` which activate the appropriate ```CharacterLevelParser``` to be used by the format enforcer.

In [4]:
import pandas as pd
from lmformatenforcer import JsonSchemaParser, CharacterLevelParser, RegexParser, StringParser, generate_enforced


def run(message: str,
        chat_history: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.8,
        top_p: float = 0.95,
        top_k: int = 50,
        required_regex: str = None,
        required_str: str = None,
        required_json_schema: dict = None) -> str:
    prompt = get_prompt(message, chat_history, system_prompt)
    inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False, return_token_type_ids=False).to(device)
    
    generate_kwargs = dict(
        inputs,
        # streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        output_scores=True,
        return_dict_in_generate=True
    )

    parser: CharacterLevelParser = None
    if required_regex:
        parser = RegexParser(required_regex)
    if required_str:
        parser = StringParser(required_str)
    if required_json_schema:
        parser = JsonSchemaParser(required_json_schema)

    if parser:
        output = generate_enforced(model, tokenizer, parser, **generate_kwargs)
    else:
        output = model.generate(**generate_kwargs)

    sequence = output.sequences[0]
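    # Note: decode() has no prompt-skipping option (skip_prompt belongs to TextStreamer),
    # so the returned text includes the prompt, as the outputs below show.
    string_output = tokenizer.decode(sequence, skip_special_tokens=True)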
    if parser:
        enforced_scores_dict = output.enforced_scores
        enforced_scores = pd.DataFrame(enforced_scores_dict)
        pd.set_option('display.width', 1000)
        pd.set_option('display.max_columns', 10)
        pd.set_option('display.max_rows', 200)
        print(enforced_scores)
    return string_output
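The following is a conceptual sketch of what enforcement means for a single generation step. It is not the library's implementation: the function names, the toy vocabulary, and the scores are all made up for illustration. The idea is that tokens the parser does not allow are removed from consideration, and the remaining scores are renormalised before a token is chosen.

```python
import math


def softmax(logits: dict[str, float]) -> dict[str, float]:
    """Turn raw scores into probabilities that sum to 1."""
    peak = max(logits.values())
    exps = {tok: math.exp(score - peak) for tok, score in logits.items()}
    total = sum(exps.values())
    return {tok: value / total for tok, value in exps.items()}


def enforce_step(logits: dict[str, float], allowed: set[str]) -> dict[str, float]:
    """Drop every token the parser disallows, then renormalise the rest."""
    kept = {tok: score for tok, score in logits.items() if tok in allowed}
    if not kept:
        raise ValueError('the parser allows no token from this vocabulary')
    return softmax(kept)


# Made-up scores for a four-token toy vocabulary.
toy_logits = {'Sure': 4.0, '{': 1.0, 'Hello': 2.0, '[': 0.5}
# Hypothetical parser state: a JSON object must start here.
toy_allowed = {'{'}

free_probs = softmax(toy_logits)
free_choice = max(free_probs, key=free_probs.get)

enforced_probs = enforce_step(toy_logits, toy_allowed)
enforced_choice = max(enforced_probs, key=enforced_probs.get)

print(free_choice, '->', enforced_choice)
```

In this toy setup the unconstrained model would start with a chatty token, while the constrained step can only open a JSON object.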


Here are a few examples for the simple (non-JSON) case. Note that ```required_regex``` does not support the full regex syntax; see the README for the full list of limitations.

In [5]:
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
"""
MAX_MAX_NEW_TOKENS = 200
DEFAULT_MAX_NEW_TOKENS = 100
MAX_INPUT_TOKEN_LENGTH = 4000

# Raw string, so that \d is not treated as a (deprecated) string escape sequence.
integer_regex = r' Michael Jordan was Born in (\d)+.'
question = 'In what year was Michael Jordan Born? Please answer using only a number, without any prefix or suffix.'

print("Without enforcing")
result = run(question, chat_history=[], system_prompt=DEFAULT_SYSTEM_PROMPT, max_new_tokens=DEFAULT_MAX_NEW_TOKENS)
print(result)

print(f"With regex force. Regex: {integer_regex}")
result = run(question, chat_history=[], system_prompt=DEFAULT_SYSTEM_PROMPT, max_new_tokens=DEFAULT_MAX_NEW_TOKENS, required_regex=integer_regex)
print(result)

print("With string force")
result = run(question, chat_history=[], system_prompt=DEFAULT_SYSTEM_PROMPT, max_new_tokens=DEFAULT_MAX_NEW_TOKENS, required_str='The answer is 1963')
print(result)


Without enforcing
[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>

In what year was Michael Jordan Born? Please answer using only a number, without any prefix or suffix. [/INST]  Sure, I'd be happy to help! Michael Jordan was born in 1963.
With regex force. Regex:  Michael Jordan was Born in (\d)+.
   generated_token  generated_token_idx  generated_score leading_token  leading_token_idx  leading_score
0                ▁                29871         1.000000             ▁              29871       1.000000
1    
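Because the pattern constrains the whole answer, it can help to check which strings it accepts before running the model. Below is a minimal sketch that uses Python's built-in `re` module as a stand-in. This is an assumption: the enforcer supports a narrower regex syntax than `re`, so the two may disagree on edge cases. The candidate strings are made up.

```python
import re

# Same pattern as in the cell above, written as a raw string.
integer_regex = r' Michael Jordan was Born in (\d)+.'

candidates = [
    ' Michael Jordan was Born in 1963.',  # the intended shape
    ' Michael Jordan was Born in 1963!',  # '.' is unescaped, so any character fits
    'Michael Jordan was born in 1963.',   # no leading space, lowercase "born"
]

# fullmatch requires the entire string to match, like a format constraint does.
accepted = [bool(re.fullmatch(integer_regex, text)) for text in candidates]

for text, ok in zip(candidates, accepted):
    print(f'{ok!s:5} {text!r}')
```

The second candidate is accepted because the trailing `.` is not escaped; writing `\.` instead would restrict the answer to a literal full stop.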

Now we demonstrate using ```JsonSchemaParser```. We create a Pydantic model, generate the schema from it, and use that to enforce the format.
The output will always be in a format that can be parsed by the parser.

In [9]:
from pydantic import BaseModel


DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
"""
MAX_MAX_NEW_TOKENS = 200
DEFAULT_MAX_NEW_TOKENS = 100
MAX_INPUT_TOKEN_LENGTH = 4000
# Raw strings, so that \d is not treated as a (deprecated) string escape sequence.
floating_point_regex = r'Michael Jordan was Born in (\d)+(.\d+)?'
integer_regex = r' Michael Jordan was Born in (\d)+.'

class AnswerFormat(BaseModel):
    first_name: str
    last_name: str
    year_of_birth: int
    num_seasons_in_nba: int

question = 'Please give me information about Michael Jordan. You MUST answer using the following json schema: '
question_with_schema = f'{question}{AnswerFormat.schema_json()}'

print(f"With json schema enforcing.")
result = run(question_with_schema, chat_history=[], system_prompt=DEFAULT_SYSTEM_PROMPT, max_new_tokens=DEFAULT_MAX_NEW_TOKENS, required_json_schema=AnswerFormat.schema())
print(result)
# { "first_name":"Michael", "last_name":".Jordan", "year_of_birth":1963, "num_seasons_in_nba":15}

print(f"Without json schema enforcing (but with prompt engineering).")
result = run(question_with_schema, chat_history=[], system_prompt=DEFAULT_SYSTEM_PROMPT, max_new_tokens=DEFAULT_MAX_NEW_TOKENS)
print(result)





With json schema enforcing.
   generated_token  generated_token_idx  generated_score leading_token  leading_token_idx  leading_score
0                ▁                29871         1.000000             ▁              29871       1.000000
1               ▁{                  426         0.028748         ▁Sure              18585       0.808105
2               ▁"                  376         0.000098        <0x0A>                 13       0.999512
3            first                 4102         0.000322         title               3257       0.944336
4                _                29918         0.997559             _              29918       0.997559
5             name                  978         1.000000          name                978       1.000000
6               ":                 1115         0.957031            ":               1115       0.957031
7                "                29908         0.000003            ▁"                376       0.999023
8          Michael         

The enforced run produced the following output:

```json
{ "first_name":"Michael", "last_name":":Jordan", "year_of_birth":1963, "num_seasons_in_nba":15}
```

The unenforced run produced the following output:

```json
{
"title": "AnswerFormat",
"type": "object",
"properties": {
"first_name": {
"title": "First Name",
"type": "string"
},
"last_name": {
"title": "Last Name",
"type": "string"
},
"year_
```
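The practical difference between the two outputs can be checked with the standard `json` module. The sketch below copies both outputs from above; `try_parse` is a hypothetical helper written for this illustration.

```python
import json

# Outputs copied from the two runs shown above.
enforced_output = (
    '{ "first_name":"Michael", "last_name":":Jordan", '
    '"year_of_birth":1963, "num_seasons_in_nba":15}'
)
unenforced_output = '''{
"title": "AnswerFormat",
"type": "object",
"properties": {
"first_name": {
"title": "First Name",
"type": "string"
},
"last_name": {
"title": "Last Name",
"type": "string"
},
"year_'''


def try_parse(text: str):
    """Return the parsed object, or None if the text is not valid JSON."""
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return None


enforced = try_parse(enforced_output)
unenforced = try_parse(unenforced_output)

print('enforced parses:  ', enforced is not None)
print('unenforced parses:', unenforced is not None)
```

The enforced output is valid JSON with the four expected fields, although its values are not guaranteed to be correct (note the stray character in `last_name`). The unenforced output echoes the schema and is cut off by the token limit, so it does not parse.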

Both runs used the exact same prompt. Token index 3 shows where the enforcer had to be aggressive:


```
   generated_token  generated_token_idx  generated_score leading_token  leading_token_idx  leading_score
3            first                 4102         0.000322         title               3257       0.944336
```

The language model was trying to generate the token "title" (post-softmax score of 0.944336), but the enforcer forced the "first" token (post-softmax score of 0.000322) to be chosen instead.
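Timesteps like this one can also be located programmatically. The sketch below uses rows 0 to 7 of the scores table above as plain tuples; the helper name `aggressive_steps` and the 0.01 threshold are assumptions chosen for illustration, not part of the library.

```python
# Rows 0-7 of the scores table above:
# (generated_token, generated_score, leading_token, leading_score)
rows = [
    ('▁', 1.000000, '▁', 1.000000),
    ('▁{', 0.028748, '▁Sure', 0.808105),
    ('▁"', 0.000098, '<0x0A>', 0.999512),
    ('first', 0.000322, 'title', 0.944336),
    ('_', 0.997559, '_', 0.997559),
    ('name', 1.000000, 'name', 1.000000),
    ('":', 0.957031, '":', 0.957031),
    ('"', 0.000003, '▁"', 0.999023),
]


def aggressive_steps(rows, max_generated_score=0.01):
    """Return the steps where the enforcer overrode the model's preferred
    token with one the model itself scored below the (hypothetical) threshold."""
    flagged = []
    for step, (generated, generated_score, leading, _leading_score) in enumerate(rows):
        if generated != leading and generated_score < max_generated_score:
            flagged.append(step)
    return flagged


flagged = aggressive_steps(rows)
print(flagged)
```

Step 1 is also an override, but the model gave the enforced token a score of about 0.03, which is above the threshold used here, so it is not flagged.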
This can be used to further improve the prompt engineering: we generally want to avoid timesteps that force the enforcer to be this aggressive, as they increase the likelihood of hallucinations. For example, the LangChain project removes the "title" from the JSON schema, probably for this reason. See here: https://github.com/langchain-ai/langchain/blob/cfa2203c626a2287d60c1febeb3e3a68b77acd77/libs/langchain/langchain/output_parsers/pydantic.py#L40
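One way to try this idea is to remove the `title` entries from the schema before putting it into the prompt. The helper below is a hypothetical sketch written for this notebook and is not LangChain's code; the sample schema mirrors the fragment echoed in the unenforced output above.

```python
def strip_titles(schema):
    """Return a copy of a JSON schema without its "title" annotations.

    Field names under "properties" are kept as-is, so a field that happens
    to be called "title" would survive.
    """
    if isinstance(schema, dict):
        reduced = {}
        for key, value in schema.items():
            if key == 'title' and isinstance(value, str):
                continue  # drop the annotation
            if key == 'properties' and isinstance(value, dict):
                # Keys here are field names, not schema keywords.
                reduced[key] = {name: strip_titles(sub) for name, sub in value.items()}
            else:
                reduced[key] = strip_titles(value)
        return reduced
    if isinstance(schema, list):
        return [strip_titles(item) for item in schema]
    return schema


# Schema fragment shaped like the one in the unenforced output above.
schema = {
    'title': 'AnswerFormat',
    'type': 'object',
    'properties': {
        'first_name': {'title': 'First Name', 'type': 'string'},
        'last_name': {'title': 'Last Name', 'type': 'string'},
    },
}

reduced = strip_titles(schema)
print(reduced)
```

The original `schema` dict is left untouched, so the full schema can still be used for enforcement while the reduced one goes into the prompt. Whether this actually lowers the number of aggressive timesteps would need to be checked by rerunning the example.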
