# LM Format Enforcer Integration with ExLlamaV2

<a target="_blank" href="https://colab.research.google.com/github/noamgat/lm-format-enforcer/blob/main/samples/colab_exllamav2_integration.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

This notebook shows how to integrate LM Format Enforcer with the [ExLlamaV2](https://github.com/turboderp/exllamav2/) library. We do this using ExLlamaV2's sampler filter interface and the integration class in this repository.

ExLlamaV2 is one of the fastest inference engines, but it does not natively support any of the popular constrained decoding libraries.

## Setting up the COLAB runtime (user action required)

This Colab-friendly notebook demonstrates the enforcer on Llama2. It can run on a free GPU on Google Colab.
Make sure that your runtime is set to GPU:

Menu Bar -> Runtime -> Change runtime type -> T4 GPU (at the time of writing this notebook). [Guide here](https://www.codesansar.com/deep-learning/using-free-gpu-tpu-google-colab.htm).

## Installing dependencies

We begin by installing the dependencies.



In [1]:
!pip install exllamav2 lm-format-enforcer huggingface-hub

# When running from source / developing the library, use this instead
# %load_ext autoreload
# %autoreload 2
# import sys
# import os
# sys.path.append(os.path.abspath('..'))
## os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

## Loading the model

This demo uses [turboderp's EXL2-quantized Llama2-7B weights](https://huggingface.co/turboderp/Llama2-7B-exl2/tree/8.0bpw). We use the Hugging Face Hub to download the model, pinned to a specific revision.

In [2]:
from huggingface_hub import snapshot_download
model_directory = snapshot_download(repo_id="turboderp/Llama2-7B-exl2", revision="6463dd96f3694a87b777852f8bd979dbaeb2b839")
print(f"Downloaded model to {model_directory}")

Fetching 16 files: 100%|██████████| 16/16 [00:00<00:00, 215784.13it/s]

Downloaded model to /mnt/wsl/PHYSICALDRIVE1p3/huggingface/hub/models--turboderp--Llama2-7B-exl2/snapshots/6463dd96f3694a87b777852f8bd979dbaeb2b839





## Preparing ExLlamaV2

We follow the [inference.py example](https://github.com/turboderp/exllamav2/blob/master/examples/inference.py) from the ExLlamaV2 repository. There is no one-line setup at the moment, so the next cell contains quite a bit of code, all taken from that example.

In [3]:
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
)

from exllamav2.generator import (
    ExLlamaV2BaseGenerator,
    ExLlamaV2Sampler
)

# Initialize model and cache

config = ExLlamaV2Config()
config.model_dir = model_directory
config.prepare()

model = ExLlamaV2(config)
print("Loading model: " + model_directory)

cache = ExLlamaV2Cache(model, lazy = True)
model.load_autosplit(cache)

tokenizer = ExLlamaV2Tokenizer(config)

# Initialize generator

generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)

# Prepare settings

settings = ExLlamaV2Sampler.Settings()
settings.temperature = 0.85
settings.top_k = 50
settings.top_p = 0.8
# settings.disallow_tokens(tokenizer, [tokenizer.eos_token_id])

max_new_tokens = 150

generator.warmup()

Loading model: /mnt/wsl/PHYSICALDRIVE1p3/huggingface/hub/models--turboderp--Llama2-7B-exl2/snapshots/6463dd96f3694a87b777852f8bd979dbaeb2b839


If the previous cell executed successfully, you have properly set up your Colab runtime and loaded the ExLlamaV2 model!

Next, we define a few helper functions to make the output display nicer.

In [4]:
from IPython.display import display, Markdown

def display_header(text):
    display(Markdown(f'**{text}**'))

def display_content(text):
    display(Markdown(f'```\n{text}\n```'))


## Generating text with the LM Format Enforcer Logits Processor

ExLlamaV2's `ExLlamaV2Sampler.Settings` has a `filters` attribute that plays a role similar to logits processors in Hugging Face Transformers. We connect to this API and filter out the tokens that the required format does not allow at each generation step.

The integration class `ExLlamaV2TokenEnforcerFilter` does just that. This is the ONLY integration point between lm-format-enforcer and ExLlamaV2.

Note that in this notebook we use `generate_simple()`, but the integration works with all ExLlamaV2 generation methods.
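
Conceptually, a sampler filter narrows the set of candidate tokens before a token is sampled. The stdlib-only sketch below illustrates that idea: it masks disallowed token ids to `-inf` and then applies temperature, top-k and top-p sampling with the same values used in the settings cell (0.85, 50, 0.8). The helpers `mask_logits` and `sample` are hypothetical; this is not ExLlamaV2's GPU implementation, and its internal order of operations may differ.

```python
import math
import random

def mask_logits(logits, allowed_ids):
    """Copy of `logits` in which every token id outside `allowed_ids` gets -inf."""
    return [x if i in allowed_ids else -math.inf for i, x in enumerate(logits)]

def sample(logits, temperature=0.85, top_k=50, top_p=0.8, rng=random):
    """Temperature, then top-k, then top-p (nucleus) sampling over a list of logits."""
    # Masked tokens (-inf) are dropped entirely, so they can never be sampled.
    candidates = [(i, x / temperature) for i, x in enumerate(logits) if x != -math.inf]
    if not candidates:
        raise ValueError("the filter allowed no tokens")
    # Top-k: keep the k highest-scoring candidates.
    candidates.sort(key=lambda pair: pair[1], reverse=True)
    candidates = candidates[:top_k]
    # Softmax over the survivors (subtracting the max for numerical stability).
    best = candidates[0][1]
    weights = [math.exp(score - best) for _, score in candidates]
    total = sum(weights)
    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    kept_ids, kept_probs, cumulative = [], [], 0.0
    for (token_id, _), w in zip(candidates, weights):
        p = w / total
        kept_ids.append(token_id)
        kept_probs.append(p)
        cumulative += p
        if cumulative >= top_p:
            break
    return rng.choices(kept_ids, weights=kept_probs, k=1)[0]

# Toy 4-token vocabulary: without a filter, tokens 1 and 3 score highest.
logits = [2.0, 5.0, 1.0, 4.5]
allowed = {0, 2}  # e.g. the only tokens the format allows at this step
token = sample(mask_logits(logits, allowed), rng=random.Random(1234))
print(token)  # 0 or 2, never 1 or 3
```

Because masking happens before sampling, the usual sampling settings still apply, but only among tokens that keep the output valid.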


In [5]:
from lmformatenforcer.characterlevelparser import CharacterLevelParser
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter, build_token_enforcer_tokenizer_data
from typing import Optional

# Building the tokenizer data once is a performance optimization, it saves preprocessing in subsequent calls.
tokenizer_data = build_token_enforcer_tokenizer_data(tokenizer)

def exllamav2_with_format_enforcer(prompt: str, parser: Optional[CharacterLevelParser] = None) -> str:
    if parser is None:
        settings.filters = []
    else:
        settings.filters = [ExLlamaV2TokenEnforcerFilter(parser, tokenizer_data)]
    result = generator.generate_simple(prompt, settings, max_new_tokens, seed = 1234)
    return result[len(prompt):]
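
The helper above builds a fresh `ExLlamaV2TokenEnforcerFilter` on every call while reusing the precomputed `tokenizer_data`. The toy below suggests why that split is natural: the per-generation state (how far the parser has progressed) lives in the filter, while the decoded vocabulary can be computed once and shared. `LiteralParser` and `ToyTokenFilter` are hypothetical illustrations of the character-level idea (a token is allowed only if the parser accepts every one of its characters); they are not lm-format-enforcer's classes, and the real library is considerably more optimized.

```python
class LiteralParser:
    """Hypothetical toy character-level parser that accepts exactly one fixed string."""

    def __init__(self, target, pos=0):
        self.target = target
        self.pos = pos

    def add_character(self, ch):
        # Return the next parser state, or None if `ch` is not allowed here.
        if self.pos < len(self.target) and self.target[self.pos] == ch:
            return LiteralParser(self.target, self.pos + 1)
        return None

    def can_end(self):
        return self.pos == len(self.target)


class ToyTokenFilter:
    """Hypothetical per-generation filter: holds parser state, reports allowed token ids."""

    def __init__(self, parser, token_strings, eos_id):
        self.parser = parser
        self.token_strings = token_strings  # decoded once, shared across filters
        self.eos_id = eos_id

    @staticmethod
    def _advance(parser, text):
        # Feed every character of a token; None means the token breaks the format.
        for ch in text:
            parser = parser.add_character(ch)
            if parser is None:
                return None
        return parser

    def allowed_tokens(self):
        allowed = {
            tid for tid, text in self.token_strings.items()
            if tid != self.eos_id and text and self._advance(self.parser, text) is not None
        }
        if self.parser.can_end():
            allowed.add(self.eos_id)
        return allowed

    def feed(self, token_id):
        if token_id != self.eos_id:
            self.parser = self._advance(self.parser, self.token_strings[token_id])


# Toy vocabulary (id 0 is end-of-sequence) and fixed "model" scores that prefer "hello".
vocab = ["</s>", "{", "}", '"a"', ":", " ", "1", "hello", '{"a"']
token_strings = dict(enumerate(vocab))  # computed once, in the spirit of tokenizer_data
scores = {7: 9.0, 8: 3.0, 1: 2.0, 3: 1.5, 4: 1.0, 5: 0.5, 6: 0.4, 2: 0.3, 0: 0.1}

toy_filter = ToyTokenFilter(LiteralParser('{"a": 1}'), token_strings, eos_id=0)
pieces = []
while True:
    # Greedy decoding restricted to the tokens the filter allows.
    best = max(toy_filter.allowed_tokens(), key=scores.__getitem__)
    if best == 0:
        break
    toy_filter.feed(best)
    pieces.append(vocab[best])
toy_output = "".join(pieces)
print(toy_output)  # {"a": 1}
```

Even though the toy "model" always scores `hello` highest, it is never selected, and multi-character tokens such as `{"a"` are accepted whenever all of their characters fit the format.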

## ExLlamaV2 + JSON Use case

Now we demonstrate using `JsonSchemaParser`. We create a pydantic model, generate a JSON schema from it, and use that schema to enforce the format.
The enforced output will always be in a format the parser accepts, as long as generation is not cut off by `max_new_tokens`.
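
The next cell uses `schema_json()` and `schema()`, which trigger Pydantic V2 deprecation warnings in the output. A minimal sketch of the replacements those warnings suggest, assuming Pydantic V2 is installed (as the warnings indicate): `model_json_schema()` returns the schema dict, and `json.dumps` turns it into the string embedded in the prompt. The `JsonSchemaParser` line is left as a comment so the sketch runs without lm-format-enforcer.

```python
import json
from pydantic import BaseModel

class AnswerFormat(BaseModel):
    first_name: str
    last_name: str
    year_of_birth: int
    num_seasons_in_nba: int

# Pydantic V2 replacements for the deprecated schema() / schema_json() calls.
schema = AnswerFormat.model_json_schema()  # dict, e.g. for JsonSchemaParser(schema)
schema_json = json.dumps(schema)           # string to embed in the prompt

question = 'Please give me information about Michael Jordan. You MUST answer using the following json schema: '
prompt = f'{question}{schema_json}'
# parser = JsonSchemaParser(schema)
print(prompt)
```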

In [6]:
from lmformatenforcer import JsonSchemaParser
from pydantic import BaseModel



class AnswerFormat(BaseModel):
    first_name: str
    last_name: str
    year_of_birth: int
    num_seasons_in_nba: int

question = 'Please give me information about Michael Jordan. You MUST answer using the following json schema: '
question_with_schema = f'{question}{AnswerFormat.schema_json()}'
prompt = question_with_schema

display_header("Prompt:")
display_content(prompt)

display_header("Answer, Without json schema enforcing:")
result = exllamav2_with_format_enforcer(prompt, parser=None)
display_content(result)

display_header("Answer, With json schema enforcing:")
parser = JsonSchemaParser(AnswerFormat.schema())
result = exllamav2_with_format_enforcer(prompt, parser=parser)
display_content(result)

display_header("Answer, With json mode (json output, no specific schema) enforcing:")
parser = JsonSchemaParser(None)
result = exllamav2_with_format_enforcer(prompt, parser=parser)
display_content(result)

/tmp/ipykernel_39692/2472056033.py:13: PydanticDeprecatedSince20: The `schema_json` method is deprecated; use `model_json_schema` and json.dumps instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
  question_with_schema = f'{question}{AnswerFormat.schema_json()}'


**Prompt:**

```
Please give me information about Michael Jordan. You MUST answer using the following json schema: {"properties": {"first_name": {"title": "First Name", "type": "string"}, "last_name": {"title": "Last Name", "type": "string"}, "year_of_birth": {"title": "Year Of Birth", "type": "integer"}, "num_seasons_in_nba": {"title": "Num Seasons In Nba", "type": "integer"}}, "required": ["first_name", "last_name", "year_of_birth", "num_seasons_in_nba"], "title": "AnswerFormat", "type": "object"}
```

**Answer, Without json schema enforcing:**

```
 I have tried: The schema is not well formed. Please give me an example of how to do so.

Comment: Can you show your attempt?

Comment: @Shawn, I have updated my question with more details on what I have tried and what I would like.

Comment: Your schema doesn't seem to have any properties that are required.  Can you try removing `required` from your schema?

Comment: @Shawn, I removed the `required` key and it still returns the same error: The schema is not well formed.

Comment: Are you sure you're trying to validate against the JSON you provided in your question?  Because that is invalid JSON (
```

**Answer, With json schema enforcing:**

/tmp/ipykernel_39692/2472056033.py:24: PydanticDeprecatedSince20: The `schema` method is deprecated; use `model_json_schema` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.5/migration/
  parser = JsonSchemaParser(AnswerFormat.schema())


```


    {
        "first_name": "Michael",
        "last_name": "Jordan",
        "year_of_birth": 1963,
        "num_seasons_in_nba": 17
    }

   

    

```

**Answer, With json mode (json output, no specific schema) enforcing:**

```


    {"error": true, "message": "Please give me information about Michael Jordan. You MUST answer using the following json schema: {"

    }

   

   


```

As you can see, the schema-enforced output matches the required schema, while the unenforced output does not. In JSON mode, the output is valid JSON but does not follow the schema. We have successfully integrated with ExLlamaV2!
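
These observations can be checked after the fact with Python's standard `json` module. The sketch below copies the three outputs from the cell above (the unenforced one trimmed to its first line) and applies a hypothetical `matches_answer_format` helper that only checks the required keys and basic Python types; it is not a full JSON Schema validator.

```python
import json

# Outputs copied from the cell above.
enforced_schema_output = '''
    {
        "first_name": "Michael",
        "last_name": "Jordan",
        "year_of_birth": 1963,
        "num_seasons_in_nba": 17
    }
'''
json_mode_output = '{"error": true, "message": "Please give me information about Michael Jordan. You MUST answer using the following json schema: {"\n\n    }'
unenforced_output = ' I have tried: The schema is not well formed. Please give me an example of how to do so.'

REQUIRED_TYPES = {"first_name": str, "last_name": str, "year_of_birth": int, "num_seasons_in_nba": int}

def matches_answer_format(text):
    """True if `text` parses as JSON and has every required field with the expected type."""
    try:
        data = json.loads(text)  # surrounding whitespace is allowed by JSON
    except json.JSONDecodeError:
        return False
    return isinstance(data, dict) and all(
        isinstance(data.get(key), expected) for key, expected in REQUIRED_TYPES.items()
    )

print(matches_answer_format(enforced_schema_output))  # True
print(json.loads(json_mode_output)["error"])          # True: valid JSON...
print(matches_answer_format(json_mode_output))        # False: ...but not the schema
print(matches_answer_format(unenforced_output))       # False: not JSON at all
```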

## ExLlamaV2 + Regular Expressions Use Case

LM Format Enforcer can also be used to make sure that the output matches a regular expression.

In [7]:
from lmformatenforcer import RegexParser

date_regex = r'(0?[1-9]|1[0-2])\/(0?[1-9]|1\d|2\d|3[01])\/(19|20)\d{2}'
answer_regex = ' In mm/dd/yyyy format, Michael Jordan was born in ' + date_regex
question = 'Q: When was Michael Jordan Born? Please answer in mm/dd/yyyy format. A:'
prompt = question

display_header("Prompt:")
display_content(prompt)


display_header("Without format forcing:")
result = exllamav2_with_format_enforcer(prompt, parser=None)
display_content(result)


display_header(f"With regex force. Regex: ```{answer_regex}```")
parser = RegexParser(answer_regex)
result = exllamav2_with_format_enforcer(prompt, parser=parser)
display_content(result)



**Prompt:**

```
Q: When was Michael Jordan Born? Please answer in mm/dd/yyyy format. A:
```

**Without format forcing:**

```
 He was born on Feb 17, 1963.
Q: What is his height? Please answer in feet and inches format. A: According to Celebrity Height, he is 6’6” tall.
Q: What is his weight? Please answer in pounds and ounces format. A: According to Celebrity Weight, he weighs around 205 lb (93 kg).
Q: What is his net worth? Please answer in dollars and cents format. A: According to Celebrity Net Worth, his net worth is $1.4 Billion.
Q: Who is his wife? Please answer in full name format.
```

**With regex force. Regex: ``` In mm/dd/yyyy format, Michael Jordan was born in (0?[1-9]|1[0-2])\/(0?[1-9]|1\d|2\d|3[01])\/(19|20)\d{2}```**

```
 In mm/dd/yyyy format, Michael Jordan was born in 2/17/1963
```

With regex enforcement enabled, the output matches the requested mm/dd/yyyy format. Without it, the model wrote the date as "Feb 17, 1963" and kept generating unrelated Q&A pairs.
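
As a post-hoc check independent of lm-format-enforcer (which enforces the pattern during generation), Python's standard `re` module can confirm that the enforced answer fully matches `answer_regex` and that the unenforced first line does not contain an mm/dd/yyyy date at all. The patterns are copied from the cell above.

```python
import re

# Same patterns as in the regex cell above.
date_regex = r'(0?[1-9]|1[0-2])\/(0?[1-9]|1\d|2\d|3[01])\/(19|20)\d{2}'
answer_regex = ' In mm/dd/yyyy format, Michael Jordan was born in ' + date_regex

enforced = ' In mm/dd/yyyy format, Michael Jordan was born in 2/17/1963'
unenforced = ' He was born on Feb 17, 1963.'

match = re.fullmatch(answer_regex, enforced)
print(match is not None)                               # True
print(match.group(1), match.group(2))                  # 2 17  (month, day)
print(re.fullmatch(answer_regex, unenforced) is None)  # True
print(re.search(date_regex, unenforced))               # None
```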