# Setup

In [1]:
! export HF_HUB_ENABLE_HF_TRANSFER=1

In [18]:
import re
import json
from functools import reduce

import torch
from openai import OpenAI
from outlines import models, generate
from outlines.fsm.json_schema import build_regex_from_schema, get_schema_from_signature
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM

In [3]:
mode = "transformers"

In [33]:
if mode == "vllm-endpoint":
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="-")
elif mode == "transformers":
    token = "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"
    model_path = "mistralai/Mistral-7B-Instruct-v0.2"
    tokenizer = AutoTokenizer.from_pretrained(model_path, token=token)
    chat_template = "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '\n[INST] ' + message['content'] + ' [/INST]' }}{% else %}{{ '\n' + message['content'] + eos_token}}{% endif %}{% endfor %}"
    tokenizer.chat_template = chat_template
    llm = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        output_attentions=True,
        token=token,
    )
    model = models.Transformers(llm, tokenizer)

tokenizer_config.json:   0%|          | 0.00/1.46k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/596 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

# Tool To Regex

In [34]:
def regex_or(pattern1, pattern2):
    return f"(?:{pattern1}|{pattern2})"

In [35]:
def sometime_guide(regex_pattern, start_guided_pattern="<tool_call>", end_guided_pattern="</tool_call>"):
    """
    Only do guided generation sometimes, in between start_word and end_word.
    """
    return f".*?(?={start_guided_pattern}){start_guided_pattern}({regex_pattern}).*?(?={end_guided_pattern}){end_guided_pattern}.*"

In [36]:
def is_bfcl(tool):
    return isinstance(tool, dict) and list(tool.keys()) == ['name', 'description', 'parameters']

In [53]:
def repeat_pattern(pattern, num_repeats):
    """Repeat the regex pattern `pattern` `num_repeats` times.

    If `num_repeats` is `None`, allow the pattern to be repeated an unlimited number of times.
    If `num_repeats` is an integer, repeat the pattern exactly `num` times.
    If `num_repeats` is an iterable with length two, repeat the pattern anywhere between `num[0]` and `num[1]` times, inclusive.
    """
    if num_repeats is None:
        result = f"({pattern})*"
    elif isinstance(num_repeats, int):
        result = f"({pattern}){{{num_repeats}}}"
    elif isinstance(num_repeats, Union[list, tuple, set]) and len(num_repeats) == 2:
        return f"({pattern}){{{num_repeats[0]},{num_repeats[1]}}}"
    return result

In [70]:
def tool_to_regex(
    tool,
    n_tool_calls=1,
    start_guided_pattern="<tool_call>",
    end_guided_pattern="</tool_call>",
    sometimes=False,
    whitespace_pattern=None,
    test_category=None,
    ):

    if isinstance(tool, list):
        values = [
            tool_to_regex(_tool, n_tool_calls=n_tool_calls, start_guided_pattern=start_guided_pattern, end_guided_pattern=end_guided_pattern, sometimes=sometimes, whitespace_pattern=whitespace_pattern, test_category=test_category,)
            for _tool in tool
            ]
        regex_strs = [v[0] for v in values]
        regex_str = reduce(regex_or, regex_strs)
        schema_strs = [v[1] for v in values]
        schema_str = "\n".join(schema_strs)
    elif is_bfcl(tool):
        schema_str = bfcl_function_to_schema(tool, test_category).strip()
        schema_regex = build_regex_from_schema(schema_str, whitespace_pattern)
        regex_str = f'{{"tool_name": "{tool["name"]}", "tool_arguments": {schema_regex}}}'
    elif isinstance(tool, type(BaseModel)):
        schema_json = tool.model_json_schema()
        schema_str = json.dumps(schema_json).strip()
        schema_regex = build_regex_from_schema(schema_str, whitespace_pattern)
        regex_str = f'{{"tool_name": "{schema_json["title"]}", "tool_arguments": {schema_regex}}}'
    elif callable(tool):
        schema_json = get_schema_from_signature(tool)
        schema_str = json.dumps(schema_json).strip()
        schema_regex = build_regex_from_schema(schema_str, whitespace_pattern)
        regex_str = f'{{"tool_name": "{tool.__name__}", "tool_arguments": {schema_regex}}}'
    elif isinstance(tool, str):
        schema_str = re.sub(r'\s+', ' ', tool).strip()
        schema_regex = build_regex_from_schema(schema_str, whitespace_pattern)
        regex_str = f'{{"tool_name": "{json.loads(schema_str)["title"]}", "tool_arguments": {schema_regex}}}'

    # if sometimes:
    #     regex_str = sometime_guide(regex_str)
    # elif not isinstance(tool, list):
    #     regex_str = f"{start_guided_pattern}{regex_str}{end_guided_pattern}"

    if not isinstance(tool, list):
        regex_str = repeat_pattern(regex_str, n_tool_calls)

    return regex_str, schema_str

# Prompt

In [68]:
def get_system_prompt(
    tool_schema,
    tool_list_start="<tool>",
    tool_list_end="</tools>",
    tool_call_start="<tool_call>",
    tool_call_end="</tool_call>",
    tool_response_start="<tool_response>",
    tool_response_end="</tool_response>"
    ):

    system_prompt = """You are a function calling AI model. Your job is to answer the user's questions and you may call one or more functions to do this.


    Please use your own judgment as to whether or not you should call a function. In particular, you must follow these guiding principles:
    1. You may call one or more functions to assist with the user query.
    2. You do not need to call a function. If none of the functions can be used to answer the user's question, please do not make the function call.
    3. Don't make assumptions about what values to plug into functions. If you are missing the parameters to make a function call, please ask the user for the parameters.
    4. You may assume the user has implemented the function themselves.
    5. You may assume the user will call the function on their own. You should NOT ask the user to call the function and let you know the result; they will do this on their own.


    You can only call functions according the following formatting rules:
    Rule 1: All the functions you have access to are contained within {tool_list_start}{tool_list_end} XML tags. You cannot use any functions that are not listed between these tags.

    Rule 2: For each function call return a json object (using quotes) with function name and arguments within {tool_call_start}\n{{ }}\n{tool_call_end} XML tags as follows:
    * With arguments:
    {tool_call_start}\n{{"tool_name": "function_name", "tool_arguments": {{"argument_1_name": "value", "argument_2_name": "value"}} }}\n{tool_call_end}
    * Without arguments:
    {tool_call_start}\n{{ "tool_name": "function_name", "tool_arguments": {{}} }}\n{tool_call_end}
    In between {tool_call_start} and{tool_call_end} tags, you MUST respond in a valid JSON schema.
    In between the {tool_call_start} and {tool_call_end} tags you MUST only write in json; no other text is allowed.

    Rule 3: If user decides to run the function, they will output the result of the function call between the {tool_response_start} and {tool_response_start} tags. If it answers the user's question, you should incorporate the output of the function in your answer.


    Here are the tools available to you:
    {tool_list_start}\n{tool_schema}\n{tool_list_end}

    Remember, don't make assumptions about what values to plug into functions. If you are missing the parameters to make a function call, please ask the user for the parameters. Do not be afraid to ask.
    """

    return system_prompt.format(
        tool_list_start=tool_list_start,
        tool_list_end=tool_list_end,
        tool_call_start=tool_call_start,
        tool_call_end=tool_call_end,
        tool_response_start=tool_response_start,
        tool_response_end=tool_response_end,
        tool_schema=tool_schema,
        )

# Parse Output

In [40]:
def is_tool(text):
    return "<tool_call>" in text and "</tool_call>" in text

In [41]:
def parse_tool(text, start_tag = "<tool_call>", end_tag = "</tool_call>"):



    return function_call

In [None]:
def parse_tool(text, start_tag, end_tag):
    start_index = text.find(start_tag)
    if start_index == -1:
        return None

    end_index = text.find(end_tag, start_index + len(start_tag))
    if end_index == -1:
        # if end_tag not present, check for last occurrence of start_tag and use this as the end_index
        end_index = text.rfind(start_tag, start_index + len(start_tag))
        if end_index == -1:
            return None

    function_call = text[start_index + len(start_tag):end_index].strip(" \n")
    return function_call

# Generate  Text

In [42]:
def generate_text(mode, messages, regex_str, verbose=0):

  if mode == "vllm-endpoint":
    completion = client.chat.completions.create(
      model="databricks/dbrx-instruct",
      max_tokens=350,
      messages=messages,
      extra_body=dict(guided_regex=regex_str, guided_decoding_backend="outlines"),
      )
    raw_text = completion.choices[0].message.content
  elif mode == "transformers":

    if verbose >= 1:
      print("Creating Generator Index...", end="\t")

    generator = generate.regex(model, regex_str)
    generator.format_sequence = lambda x: x #json.loads(x)

    if verbose >= 1:
      print("Done\nGenerating Text...", end="\t")

    rng_seed = 420
    rng = torch.Generator(device="cuda")
    rng.manual_seed(rng_seed)
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    raw_text = generator(prompt, rng=rng, max_tokens=300)

    if verbose >= 1:
      print("Done")

  return raw_text

# Run it

In [71]:
def send_email(sender: str, recipient: str, message: str):
    return "Hi"

class User(BaseModel):
    name: str
    email: str

schema = """
{
  "title": "User",
  "type": "object",
  "properties": {
    "name": {"type": "string"},
    "last_name": {"type": "string"},
    "id": {"type": "integer"}
  }
}
"""

tools = [send_email, User, schema]
regex_str, tool_schema = tool_to_regex(tools, n_tool_calls=2)


print("Regex:\n", regex_str, end="\n\n")
print("Schemas:\n", tool_schema)

Regex:
 (?:(?:({"tool_name": "send_email", "tool_arguments": \{[\n ]*"sender"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"[\n ]*,[\n ]*"recipient"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"[\n ]*,[\n ]*"message"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"[\n ]*\}}){2}|({"tool_name": "User", "tool_arguments": \{[\n ]*"name"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"[\n ]*,[\n ]*"email"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"[\n ]*\}}){2})|({"tool_name": "User", "tool_arguments": \{([\n ]*"name"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"([\n ]*,[\n ]*"last_name"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*")?([\n ]*,[\n ]*"id"[\n ]*:[\n ]*(-)?(0|[1-9][0-9]*))?|([\n ]*"name"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"[\n ]*,)?[\n ]*"last_name"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"([\n ]*,[\n ]*"id"[\n ]*:[\n ]*(-)?(0|[1-9][0-9]*))?|([\n ]*"name"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"[\n ]*,)?([\n ]*"last_name"[\n ]*:[\n ]*"(?:[^"\\\x0

In [72]:
# user_query = "Can you send an email from Alice (alice@gmail.com) to Bob (bob@databricks.com) saying 'We cracked the code!'? And can you also make two user profiles, one for Alice and one for Bob?"
user_query = "And can you also make two user profiles, one for Alice (alice@gmail.com) and one for Bob (bob@databricks.com)?"

system_prompt = get_system_prompt(tool_schema)

messages = [
  {"role": "system", "content": system_prompt},
    {"role": "user", "content": user_query}
  ]

for message in messages:
  print(message)

{'role': 'system', 'content': 'You are a function calling AI model. Your job is to answer the user\'s questions and you may call one or more functions to do this.\n\n\n    Please use your own judgment as to whether or not you should call a function. In particular, you must follow these guiding principles:\n    1. You may call one or more functions to assist with the user query.\n    2. You do not need to call a function. If none of the functions can be used to answer the user\'s question, please do not make the function call.\n    3. Don\'t make assumptions about what values to plug into functions. If you are missing the parameters to make a function call, please ask the user for the parameters.\n    4. You may assume the user has implemented the function themselves.\n    5. You may assume the user will call the function on their own. You should NOT ask the user to call the function and let you know the result; they will do this on their own.\n\n\n    You can only call functions accord

In [73]:
raw_text = generate_text(mode, messages, regex_str, verbose=1)
print(raw_text)

Creating Generator Index...	Done
Generating Text...	Done
{"tool_name": "User", "tool_arguments": {
  "name": "Alice",
  "email": "alice@gmail.com"
}}{"tool_name": "User", "tool_arguments": {
  "name": "Bob",
  "email": "bob@databricks.com"
}}


In [None]:
def extract_tool_instances(raw_text):
    """Return a list of all tools that match the tool_regex. This works for multiple functions."""

    tool_regex = r'\{"tool_name": "([^"]+)", "tool_arguments": (\{[^{}]*\})\}'
    matches = re.findall(tool_regex, raw_text)
    tool_instances = []
    for match in matches:
        tool_name = match[0]
        tool_arguments = json.loads(match[1])
        tool_instances.append({'tool_name': tool_name, 'tool_arguments': tool_arguments})
    return tool_instances

In [2]:
import re
tool_regex = r'\{"tool_name": "([^"]+)", "tool_arguments": (\{[^{}]*\})\}'
matches = re.findall(tool_regex, raw_text)
matches

NameError: name 'raw_text' is not defined

In [74]:
tool = json.loads(raw_text)
tool

JSONDecodeError: Extra data: line 4 column 3 (char 92)

In [15]:
tool = None
if is_tool(raw_text):
    tool = parse_tool(raw_text)
print(tool)

NameError: name 'raw_text' is not defined

In [80]:
def format_result(function_name, result):
    # This method is used to format the result in a standard way.
    args_string = ', '.join([f"{key}='{value}'" if isinstance(value, str) else f"{key}={value}" for key, value in result.items()])
    # Creating the output string with the function name and arguments
    output_string = f'[{function_name}({args_string})]'
    return output_string

In [81]:
format_result(tool["tool_name"], tool["tool_arguments"])

"[send_email(sender='alice@gmail.com', recipient='bob@databricks.com', message='We cracked the code!')]"