# Setup

In [1]:
! export HF_HUB_ENABLE_HF_TRANSFER=1

In [2]:
import json
import re
from functools import reduce
from typing import Union

import torch
from openai import OpenAI
from outlines import models, generate
from outlines.fsm.json_schema import build_regex_from_schema, get_schema_from_signature
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM

In [3]:
mode = "transformers"

In [4]:
if mode == "vllm-endpoint":
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="-")
elif mode == "transformers":
    token = "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"
    model_path = "mistralai/Mistral-7B-Instruct-v0.2"
    tokenizer = AutoTokenizer.from_pretrained(model_path, token=token)
    chat_template = "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '\n[INST] ' + message['content'] + ' [/INST]' }}{% else %}{{ '\n' + message['content'] + eos_token}}{% endif %}{% endfor %}"
    tokenizer.chat_template = chat_template
    llm = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        output_attentions=True,
        token=token,
    )
    model = models.Transformers(llm, tokenizer)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

# Tool To Regex

In [53]:
GORILLA_TO_OPENAPI = {
    "integer": "integer",
    "number": "number",
    "float": "number",
    "string": "string",
    "boolean": "boolean",
    "bool": "boolean",
    "array": "array",
    "list": "array",
    "dict": "object",
    "object": "object",
    "tuple": "array",
    "any": "string",
    "byte": "integer",
    "short": "integer",
    "long": "integer",
    "double": "number",
    "char": "string",
    "ArrayList": "array",
    "Array": "array",
    "HashMap": "object",
    "Hashtable": "object",
    "Queue": "array",
    "Stack": "array",
    "Any": "string",
    "String": "string",
    "Bigint": "integer",
}

In [54]:
def _cast_to_openai_type(properties, mapping, test_category):
    for key, value in properties.items():
        if "type" not in value:
            properties[key]["type"] = "string"
        else:
            var_type = value["type"]
            if mapping == GORILLA_TO_OPENAPI and var_type == "float":
                properties[key]["format"] = "float"
                properties[key]["description"] += " This is a float type value."
            if var_type in mapping:
                properties[key]["type"] = mapping[var_type]
            else:
                properties[key]["type"] = "string"

        # Currently support:
        # - list of any
        # - list of list of any
        # - list of dict
        # - list of list of dict
        # - dict of any

        if properties[key]["type"] == "array" or properties[key]["type"] == "object":
            if "properties" in properties[key]:
                properties[key]["properties"] = _cast_to_openai_type(
                    properties[key]["properties"], mapping, test_category
                )
            elif "items" in properties[key]:
                properties[key]["items"]["type"] = mapping[
                    properties[key]["items"]["type"]
                ]
                if (
                    properties[key]["items"]["type"] == "array"
                    and "items" in properties[key]["items"]
                ):
                    properties[key]["items"]["items"]["type"] = mapping[
                        properties[key]["items"]["items"]["type"]
                    ]
                elif (
                    properties[key]["items"]["type"] == "object"
                    and "properties" in properties[key]["items"]
                ):
                    properties[key]["items"]["properties"] = _cast_to_openai_type(
                        properties[key]["items"]["properties"], mapping, test_category
                    )
    return properties

In [55]:
def bfcl_function_to_schema(function, test_category):
    properties = _cast_to_openai_type(function["parameters"]["properties"], GORILLA_TO_OPENAPI, test_category)
    schema = json.dumps({
        "title": function["name"],
        "type": "object",
        "description": function["description"],
        "properties": properties,
        "required": function["parameters"]["required"],
        })
    return schema

In [56]:
def regex_or(pattern1, pattern2):
    return f"(?:{pattern1}|{pattern2})"

In [6]:
def sometime_guide(regex_pattern, start_guided_pattern="<tool_call>", end_guided_pattern="</tool_call>"):
    """
    Only do guided generation sometimes, i.e. only force us to output according to the regex pattern in between start_word and end_word.
    """
    return f".*?(?={start_guided_pattern}){start_guided_pattern}({regex_pattern}).*?(?={end_guided_pattern}){end_guided_pattern}.*"

In [7]:
def is_bfcl(tool):
    return isinstance(tool, dict) and list(tool.keys()) == ['name', 'description', 'parameters']

In [16]:
# def repeat_pattern(pattern, num_repeats):
#     """Repeat the regex pattern `pattern` `num_repeats` times.

#     If `num_repeats` is `None`, allow the pattern to be repeated an unlimited number of times.
#     If `num_repeats` is an integer, repeat the pattern exactly `num` times.
#     If `num_repeats` is an iterable with length two, repeat the pattern anywhere between `num[0]` and `num[1]` times, inclusive.
#     """


#     if num_repeats is None:
#         result = f"({pattern})*"
#     elif isinstance(num_repeats, int):
#         result = f"({pattern}){{{num_repeats}}}"
#     elif isinstance(num_repeats, Union[list, tuple, set]) and len(num_repeats) == 2:
#         return f"({pattern}){{{num_repeats[0]},{num_repeats[1]}}}"

#     return result

In [34]:
def repeat_regex_pattern(pattern, num_repeats, sep="\\n"):
    """Repeat the regex pattern `pattern` `num_repeats` times.

    If `num_repeats` is `None`, allow the pattern to be repeated an unlimited number of times.
    If `num_repeats` is an integer, repeat the pattern exactly `num` times.
    If `num_repeats` is an iterable with length two, repeat the pattern anywhere between `num[0]` and `num[1]` times, inclusive.
    """

    if num_repeats is None:
        min_repetitions = 0
        max_repetitions = None
    elif isinstance(num_repeats, int):
        min_repetitions = max_repetitions = num_repeats
    elif isinstance(num_repeats, Union[list, tuple, set]) and len(num_repeats) == 2:
        min_repetitions = num_repeats[0]
        max_repetitions = num_repeats[1]

    if max_repetitions is None:
        regex_str = f'({pattern}{sep}){{{min_repetitions},}}'
    else:
        regex_str = f'({pattern}{sep}){{{min_repetitions},{max_repetitions}}}'

    return regex_str

In [36]:
def tool_to_regex(
    tool,
    n_tool_calls=1,
    start_guided_pattern="<tool_call>",
    end_guided_pattern="</tool_call>",
    sometimes=False,
    whitespace_pattern=None,
    test_category=None,
    ):

    if isinstance(tool, list):
        values = [
            tool_to_regex(_tool, n_tool_calls=n_tool_calls, start_guided_pattern=start_guided_pattern, end_guided_pattern=end_guided_pattern, sometimes=sometimes, whitespace_pattern=whitespace_pattern, test_category=test_category,)
            for _tool in tool
            ]
        regex_strs, schema_strs = [v[0] for v in values], [v[1] for v in values]
        regex_str = reduce(regex_or, regex_strs)
        schema_str = "\n".join(schema_strs)
    elif is_bfcl(tool):
        schema_str = bfcl_function_to_schema(tool, test_category).strip()
        schema_regex = build_regex_from_schema(schema_str, whitespace_pattern)
        regex_str = f'{{"tool_name": "{tool["name"]}", "tool_arguments": {schema_regex}}}'
    elif isinstance(tool, type(BaseModel)):
        schema_json = tool.model_json_schema()
        schema_str = json.dumps(schema_json).strip()
        schema_regex = build_regex_from_schema(schema_str, whitespace_pattern)
        regex_str = f'{{"tool_name": "{schema_json["title"]}", "tool_arguments": {schema_regex}}}'
    elif callable(tool):
        schema_json = get_schema_from_signature(tool)
        schema_str = json.dumps(schema_json).strip()
        schema_regex = build_regex_from_schema(schema_str, whitespace_pattern)
        regex_str = f'{{"tool_name": "{tool.__name__}", "tool_arguments": {schema_regex}}}'
    elif isinstance(tool, str):
        schema_str = re.sub(r'\s+', ' ', tool).strip()
        schema_regex = build_regex_from_schema(schema_str, whitespace_pattern)
        regex_str = f'{{"tool_name": "{json.loads(schema_str)["title"]}", "tool_arguments": {schema_regex}}}'

    # if sometimes:
    #     regex_str = sometime_guide(regex_str)
    # elif not isinstance(tool, list):
    #     regex_str = f"{start_guided_pattern}{regex_str}{end_guided_pattern}"

    if not isinstance(tool, list):
        regex_str = repeat_regex_pattern(regex_str, n_tool_calls)

    return regex_str, schema_str

# Prompt

In [37]:
def get_system_prompt(
    tool_schema,
    tool_list_start="<tool>",
    tool_list_end="</tools>",
    tool_call_start="<tool_call>",
    tool_call_end="</tool_call>",
    tool_response_start="<tool_response>",
    tool_response_end="</tool_response>"
    ):

    system_prompt = """You are a function calling AI model. Your job is to answer the user's questions and you may call one or more functions to do this.


    Please use your own judgment as to whether or not you should call a function. In particular, you must follow these guiding principles:
    1. You may call one or more functions to assist with the user query. You should call multiple functions when the user asks you to.
    2. You do not need to call a function. If none of the functions can be used to answer the user's question, please do not make the function call.
    3. Don't make assumptions about what values to plug into functions. If you are missing the parameters to make a function call, please ask the user for the parameters.
    4. You may assume the user has implemented the function themselves.
    5. You may assume the user will call the function on their own. You should NOT ask the user to call the function and let you know the result; they will do this on their own.


    You can only call functions according the following formatting rules:
    Rule 1: All the functions you have access to are contained within {tool_list_start}{tool_list_end} XML tags. You cannot use any functions that are not listed between these tags.

    Rule 2: For each function call return a json object (using quotes) with function name and arguments within {tool_call_start}\n{{ }}\n{tool_call_end} XML tags as follows:
    * With arguments:
    {tool_call_start}\n{{"tool_name": "function_name", "tool_arguments": {{"argument_1_name": "value", "argument_2_name": "value"}} }}\n{tool_call_end}
    * Without arguments:
    {tool_call_start}\n{{ "tool_name": "function_name", "tool_arguments": {{}} }}\n{tool_call_end}
    In between {tool_call_start} and{tool_call_end} tags, you MUST respond in a valid JSON schema.
    In between the {tool_call_start} and {tool_call_end} tags you MUST only write in json; no other text is allowed.

    Rule 3: If user decides to run the function, they will output the result of the function call between the {tool_response_start} and {tool_response_start} tags. If it answers the user's question, you should incorporate the output of the function in your answer.


    Here are the tools available to you:
    {tool_list_start}\n{tool_schema}\n{tool_list_end}

    Remember, don't make assumptions about what values to plug into functions. If you are missing the parameters to make a function call, please ask the user for the parameters. Do not be afraid to ask.
    """

    return system_prompt.format(
        tool_list_start=tool_list_start,
        tool_list_end=tool_list_end,
        tool_call_start=tool_call_start,
        tool_call_end=tool_call_end,
        tool_response_start=tool_response_start,
        tool_response_end=tool_response_end,
        tool_schema=tool_schema,
        )

# Parse Output

In [38]:
def is_tool(text):
    return "<tool_call>" in text and "</tool_call>" in text

In [39]:
def parse_tool(text):
    """Return a list of all tools that match the tool_regex and is independent of `tool_call_start` and `tool_call_end`. This works for multiple functions.
    This works """

    tool_regex = r'\{"tool_name": "([^"]+)", "tool_arguments": (\{[^{}]*\})\}'
    matches = re.findall(tool_regex, text)
    tool_calls = []
    for match in matches:
        tool_name = match[0]
        tool_arguments = json.loads(match[1])
        tool_calls.append({'tool_name': tool_name, 'tool_arguments': tool_arguments})
    return tool_calls

In [40]:
def bfcl_format(tools):
    tool_strs = []
    for tool in tools:
        args_string = ', '.join([f"{key}='{value}'" if isinstance(value, str) else f"{key}={value}" for key, value in tool["tool_arguments"].items()])
        tool_str = f'{tool["tool_name"]}({args_string})'
        tool_strs.append(tool_str)
    result = '[' + ', '.join(tool_strs) + ']'
    return result

# Generate  Text

In [47]:
def generate_text(mode, messages, regex_str, structured=True, max_tokens=500, verbose=0):

  if mode == "vllm-endpoint":

    extra_body = {}
    if structured:
      extra_body=dict(guided_regex=regex_str, guided_decoding_backend="outlines")

    completion = client.chat.completions.create(
      model="databricks/dbrx-instruct",
      max_tokens=max_tokens,
      messages=messages,
      extra_body=extra_body,
      )
    raw_text = completion.choices[0].message.content
  elif mode == "transformers":

    if verbose >= 1:
      print("Creating Generator...", end="\t")

    generator = generate.text(model)
    if structured:
      generator = generate.regex(model, regex_str)
      generator.format_sequence = lambda x: x #json.loads(x)


    if verbose >= 1:
      print("Done\nGenerating Text...", end="\t")

    rng_seed = 420
    rng = torch.Generator(device="cuda")
    rng.manual_seed(rng_seed)
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    raw_text = generator(prompt, rng=rng, max_tokens=max_tokens)

    if verbose >= 1:
      print("Done")

  return raw_text

# Run it

In [62]:
def send_email(sender: str, recipient: str, message: str):
    return "Hi"

class User(BaseModel):
    name: str
    email: str

schema = """
{
  "title": "User",
  "type": "object",
  "properties": {
    "name": {"type": "string"},
    "last_name": {"type": "string"},
    "id": {"type": "integer"}
  }
}
"""

tools = [send_email, User, schema]

user_query = """
Can you send an email from Alice (alice@gmail.com) to Bob (bob@databricks.com) saying 'We cracked the code!'?
And can you also make two user profiles, one for Alice and one for Bob?
"""

In [None]:
questions_path = "/mnt/workdisk/eitan/09_tooluse/gorilla-main/gorilla/berkeley-function-call-leaderboard/data/gorilla_openfunctions_v1_test_parallel_multiple_function.json"
solutions_path = "/mnt/workdisk/eitan/09_tooluse/gorilla-main/gorilla/berkeley-function-call-leaderboard/data/possible_answer/gorilla_openfunctions_v1_test_parallel_multiple_function.json"

questions = []
with open(questions_path, "r") as f:
    for line in f:
        questions.append(json.loads(line))

solutions = []
with open(solutions_path, "r") as f:
    for line in f:
        solutions.append(json.loads(line))

idx = 10

question = questions[idx]
user_query = question["question"]
tools = question["function"]
if not isinstance(tools, list):
    tools = [tools]

solution = solutions[idx]
n_tools_used = len(solution)

print("User Query:\t", user_query)
print("Tools:\t")
for tool in tools:
    print("\t", tool)
print("Solution:\t")
for tool in solution:
    print("\t", tool)

User Query:	 Buy me a ticket to the Mamma Mia musical for next Friday, also get me a train ticket from New York to Chicago for the same day.
Tools:	
	 {'name': 'train_ticket.buy', 'description': 'Buy a train ticket for a specific date and route.', 'parameters': {'type': 'dict', 'properties': {'origin': {'type': 'string', 'description': 'The departure full name of the city.'}, 'destination': {'type': 'string', 'description': 'The destination city.'}, 'date': {'type': 'string', 'description': 'The date when the journey should be.'}}, 'required': ['origin', 'destination', 'date']}}
	 {'name': 'musical_ticket.buy', 'description': 'Buy a ticket for a musical', 'parameters': {'type': 'dict', 'properties': {'show': {'type': 'string', 'description': 'Name of the show.'}, 'date': {'type': 'string', 'description': 'Date when the ticket should be bought for.'}}, 'required': ['show', 'date']}}
	 {'name': 'concert_ticket.buy', 'description': 'Buy a concert ticket', 'parameters': {'type': 'dict', 'p

In [77]:
regex_str, tool_schema = tool_to_regex(tools, n_tool_calls=(1, 10))

print(f"Regex:\n{regex_str}\n")
print(f"Schemas:\n{tool_schema}")

Regex:
(?:(?:({"tool_name": "train_ticket.buy", "tool_arguments": \{[\n ]*"origin"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"[\n ]*,[\n ]*"destination"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"[\n ]*,[\n ]*"date"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"[\n ]*\}}\n){1,10}|({"tool_name": "musical_ticket.buy", "tool_arguments": \{[\n ]*"show"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"[\n ]*,[\n ]*"date"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"[\n ]*\}}\n){1,10})|({"tool_name": "concert_ticket.buy", "tool_arguments": \{[\n ]*"artist"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"[\n ]*,[\n ]*"date"[\n ]*:[\n ]*"(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)*"[\n ]*\}}\n){1,10})

Schemas:
{"title": "train_ticket.buy", "type": "object", "description": "Buy a train ticket for a specific date and route.", "properties": {"origin": {"type": "string", "description": "The departure full name of the city."}, "destination": {"type": "string", "description": "The destination city."},

In [78]:
system_prompt = get_system_prompt(tool_schema)

messages = [
  {"role": "system", "content": system_prompt},
    {"role": "user", "content": user_query}
  ]

for message in messages:
  print(message)

{'role': 'system', 'content': 'You are a function calling AI model. Your job is to answer the user\'s questions and you may call one or more functions to do this.\n\n\n    Please use your own judgment as to whether or not you should call a function. In particular, you must follow these guiding principles:\n    1. You may call one or more functions to assist with the user query. You should call multiple functions when the user asks you to.\n    2. You do not need to call a function. If none of the functions can be used to answer the user\'s question, please do not make the function call.\n    3. Don\'t make assumptions about what values to plug into functions. If you are missing the parameters to make a function call, please ask the user for the parameters.\n    4. You may assume the user has implemented the function themselves.\n    5. You may assume the user will call the function on their own. You should NOT ask the user to call the function and let you know the result; they will do 

In [79]:
structured = True
text = generate_text(mode, messages, regex_str, structured=structured, verbose=1)
print(text)

tool = parse_tool(text)
print(tool)

Creating Generator...	Done
Generating Text...	Done
{"tool_name": "musical_ticket.buy", "tool_arguments": {"show": "Mamma Mia", "date": "next Friday"}}

[{'tool_name': 'musical_ticket.buy', 'tool_arguments': {'show': 'Mamma Mia', 'date': 'next Friday'}}]
