# LM Format Enforcer Integration with llama.cpp (python bindings)

<a target="_blank" href="https://colab.research.google.com/github/noamgat/lm-format-enforcer/blob/main/samples/colab_llamacpppython_integration.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

This notebook shows how to integrate LM Format Enforcer with the llama.cpp library via its [python bindings](https://github.com/abetlen/llama-cpp-python). We do this through the bindings' ```LogitsProcessor``` interface; the connection takes about 30 lines of code.

This sample notebook focuses on simplicity and ease of setup, so we use a CPU build of llama.cpp, which makes inference slower. For production use, you should use a GPU build of llama.cpp.

## Installing dependencies

We begin by installing the dependencies.



In [7]:
!pip install llama-cpp-python lm-format-enforcer huggingface-hub

# When running from source / developing the library, use this instead
# %load_ext autoreload
# %autoreload 2
# import sys
# import os
# sys.path.append(os.path.abspath('..'))
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'



## Loading the model

This demo uses [Llama 2 GGUF weights by TheBloke](https://huggingface.co/TheBloke/Llama-2-7b-Chat-GGUF). We use the Hugging Face Hub client to download the model.

In [8]:
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
downloaded_model_path = hf_hub_download(repo_id="TheBloke/Llama-2-7b-Chat-GGUF", filename="llama-2-7b-chat.Q5_K_M.gguf")
llm = Llama(model_path=downloaded_model_path)



If the previous cell executed successfully, you have properly set up your Colab runtime and loaded the llama.cpp model!

## Creating a Logits Processor that filters tokens

llama.cpp's python bindings have a ```LogitsProcessor``` interface similar to the one in Hugging Face Transformers. We connect to this API and set the logits of tokens that are not allowed to negative infinity, ensuring they are never selected.

We use the high-level llama.cpp python interface to build a ```TokenEnforcer```, and a ```LogitsProcessor``` that uses it.

In [9]:
from llama_cpp import LogitsProcessor, LogitsProcessorList
from lmformatenforcer import CharacterLevelParser, TokenEnforcer
import numpy as np
import numpy.typing as npt
from typing import Tuple, List

def _build_regular_tokens_list(llm: Llama) -> List[Tuple[int, str]]:
    token_0 = llm.tokenize(b"0")[-1]
    regular_tokens = []
    special_tokens = [llm.token_bos(), llm.token_eos()]
    for token_idx in range(llm.n_vocab()):
        if token_idx in special_tokens:
            continue
        # Prepend the token for "0" and drop the first character of the decoded result,
        # so that a token that starts a new word keeps its leading space.
        try:
            decoded = llm.detokenize([token_0, token_idx]).decode('utf-8')[1:]
            regular_tokens.append((token_idx, decoded))
        except Exception:
            # This can happen for tokens such as raw bytes outside the ASCII range.
            # We skip them, so they are never allowed.
            pass
    return regular_tokens


def build_llamacpp_logits_processor(llm: Llama, character_level_parser: CharacterLevelParser) -> LogitsProcessor:
    """Build the logits processor function that llama.cpp will use to filter the tokens generated by the model. The result
    can be passed in the logits_processor list that is sent to the call or generate() method of llama.cpp models."""
    regular_tokens = _build_regular_tokens_list(llm)
    def decoder(sent: List[int]) -> str:
        return llm.detokenize(sent).decode('utf-8')
    token_enforcer = TokenEnforcer(regular_tokens, character_level_parser, decoder, llm.token_eos())

    def llamacpp_logits_processor(input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]) -> npt.NDArray[np.single]:
        token_sequence = input_ids.tolist()
        allowed_tokens = token_enforcer.get_allowed_tokens(token_sequence)
        mask = np.ones(scores.shape, bool)
        mask[allowed_tokens] = False
        scores[mask] = float('-inf')
        return scores
    
    return llamacpp_logits_processor
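
The masking step inside `llamacpp_logits_processor` can be tried on its own, without loading a model. The following is a minimal sketch under stated assumptions: the helper name `mask_disallowed`, the toy scores, and the allowed token ids are all made up for illustration, and unlike the notebook's processor this version returns a copy instead of modifying the scores in place.

```python
import numpy as np

def mask_disallowed(scores: np.ndarray, allowed_tokens: list) -> np.ndarray:
    """Return a copy of `scores` with every token outside `allowed_tokens` set to -inf."""
    mask = np.ones(scores.shape, dtype=bool)  # True means "block this token"
    mask[allowed_tokens] = False              # unblock the allowed token ids
    masked = scores.astype(np.single, copy=True)
    masked[mask] = -np.inf
    return masked

# Toy vocabulary of 6 tokens; the values are invented for this example.
toy_scores = np.array([2.0, 0.5, -1.0, 3.0, 0.0, 1.5], dtype=np.single)
toy_allowed = [1, 4]  # hypothetical ids that a parser would accept next

masked_scores = mask_disallowed(toy_scores, toy_allowed)
best_token = int(np.argmax(masked_scores))
print(masked_scores)
print(best_token)
```

Token 3 has the highest raw score, but because it is not in the allowed list, the highest-scoring token after masking is one of the allowed ids.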




A few helper functions to make display nicer.

In [10]:
from IPython.display import display, Markdown

def display_header(text):
    display(Markdown(f'**{text}**'))

def display_content(text):
    display(Markdown(f'```\n{text}\n```'))

## Setting up the prompt for the specific language model

We set up the prompting style according to the [Llama2 demo](https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/app.py). We simplify the implementation a bit as we don't need chat history for this demo.

In [11]:
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
"""

def get_prompt(message: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
    return f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{message} [/INST]'
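
To see what the template produces, here is a small sketch that fills it with a short system prompt. The function body repeats the notebook's template so the snippet runs on its own; the system prompt and question are hypothetical values chosen only for illustration.

```python
def get_prompt(message: str, system_prompt: str) -> str:
    # Same template as the notebook's get_prompt, repeated so this snippet is self-contained.
    return f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{message} [/INST]'

# Hypothetical short system prompt and question, used only for this example.
short_system_prompt = "You are a concise assistant."
example_prompt = get_prompt("What is the capital of France?", short_system_prompt)
print(example_prompt)
```

The system prompt sits between the `<<SYS>>` markers, and the user message ends just before the closing `[/INST]` tag.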

## Generating text with the LM Format Enforcer Logits Processor
To integrate our logits processor with llama.cpp, we create a ```LogitsProcessorList``` and pass it as the ```logits_processor``` keyword argument when calling the ```Llama``` object.


In [12]:
from typing import Optional

def llamacpp_with_character_level_parser(llm: Llama, prompt: str, character_level_parser: Optional[CharacterLevelParser]) -> str:
    logits_processors: Optional[LogitsProcessorList] = None
    if character_level_parser:
        logits_processors = LogitsProcessorList([build_llamacpp_logits_processor(llm, character_level_parser)])
    
    output = llm(prompt, logits_processor=logits_processors)
    text: str = output['choices'][0]['text']
    return text

## LlamaCpp + JSON Use case

Now we demonstrate using ```JsonSchemaParser```. We create a pydantic model, generate the schema from it, and use that to enforce the format.
The output will always be in a format that can be parsed by the parser.

In [13]:
from lmformatenforcer import JsonSchemaParser
from pydantic import BaseModel

from typing import List

class AnswerFormat(BaseModel):
    first_name: str
    last_name: str
    year_of_birth: int
    num_seasons_in_nba: int

question = 'Please give me information about Michael Jordan. You MUST answer using the following json schema: '
question_with_schema = f'{question}{AnswerFormat.schema_json()}'
prompt = get_prompt(question_with_schema)

display_header("Prompt:")
display_content(prompt)

display_header("Answer, Without json schema enforcing:")
result = llamacpp_with_character_level_parser(llm, prompt, None)
display_content(result)

display_header("Answer, With json schema enforcing:")
result = llamacpp_with_character_level_parser(llm, prompt, JsonSchemaParser(AnswerFormat.schema()))
display_content(result)




**Prompt:**

```
<s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>

Please give me information about Michael Jordan. You MUST answer using the following json schema: {"properties": {"first_name": {"title": "First Name", "type": "string"}, "last_name": {"title": "Last Name", "type": "string"}, "year_of_birth": {"title": "Year Of Birth", "type": "integer"}, "num_seasons_in_nba": {"title": "Num Seasons In Nba", "type": "integer"}}, "required": ["first_name", "last_name", "year_of_birth", "num_seasons_in_nba"], "title": "AnswerFormat", "type": "object"} [/INST]
```

**Answer, Without json schema enforcing:**


llama_print_timings:        load time = 16716.39 ms
llama_print_timings:      sample time =    33.72 ms /    93 runs   (    0.36 ms per token,  2757.76 tokens per second)
llama_print_timings: prompt eval time = 16716.30 ms /   294 tokens (   56.86 ms per token,    17.59 tokens per second)
llama_print_timings:        eval time = 10525.06 ms /    92 runs   (  114.40 ms per token,     8.74 tokens per second)
llama_print_timings:       total time = 27395.89 ms


```
  Of course! I'd be happy to provide information about Michael Jordan using the provided JSON schema.
{
"first_name": "Michael",
"last_name": "Jordan",
"year_of_birth": 1963,
"num_seasons_in_nba": 15
}

I hope this helps! Let me know if you have any other questions.
```

**Answer, With json schema enforcing:**

Llama.generate: prefix-match hit

llama_print_timings:        load time = 16716.39 ms
llama_print_timings:      sample time =    17.67 ms /    52 runs   (    0.34 ms per token,  2943.01 tokens per second)
llama_print_timings: prompt eval time =     0.00 ms /     1 tokens (    0.00 ms per token,      inf tokens per second)
llama_print_timings:        eval time =  5051.36 ms /    52 runs   (   97.14 ms per token,    10.29 tokens per second)
llama_print_timings:       total time =  5253.00 ms


```
  { "first_name": "Michael", "last_name": "Jordan", "year_of_birth": 1963, "num_seasons_in_nba": 15 }


```

As you can see, the enforced output is a single JSON object that matches the required schema, while the unenforced output wraps the JSON in conversational text and cannot be parsed directly. We have successfully integrated with llama.cpp!
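
The difference between the two answers can be checked with the standard `json` module. This is a minimal sketch: the two strings are copied from the outputs above, and the helper `try_parse` is a hypothetical name introduced only for this example.

```python
import json

# Outputs copied from the run above.
unenforced = (
    "  Of course! I'd be happy to provide information about Michael Jordan "
    "using the provided JSON schema.\n"
    '{\n"first_name": "Michael",\n"last_name": "Jordan",\n'
    '"year_of_birth": 1963,\n"num_seasons_in_nba": 15\n}\n\n'
    "I hope this helps! Let me know if you have any other questions."
)
enforced = (
    '  { "first_name": "Michael", "last_name": "Jordan", '
    '"year_of_birth": 1963, "num_seasons_in_nba": 15 }\n\n'
)

def try_parse(text: str):
    """Return the parsed object, or None if `text` is not a single JSON document."""
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return None

parsed_unenforced = try_parse(unenforced)
parsed_enforced = try_parse(enforced)
print(parsed_unenforced)
print(parsed_enforced)
```

The unenforced answer fails to parse because of the surrounding sentences, while the enforced answer loads directly into a dictionary with the four expected keys.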

A closing note: the last cell probably took quite a long time to run because this notebook uses CPU inference. LM Format Enforcer's runtime footprint is negligible compared to the model's runtime.