# Setup

In [27]:
import json
from functools import reduce

from openai import OpenAI
from outlines.fsm.json_schema import build_regex_from_schema, get_schema_from_signature
from pydantic import BaseModel

In [3]:
client = OpenAI(base_url="http://localhost:8000/v1", api_key="-")
client

<openai.OpenAI at 0x7fb7e018a9b0>

# Tool To Regex

In [11]:
def regex_or(pattern1, pattern2):
    return f"(?:{pattern1}|{pattern2})"

In [13]:
def sometime_guide(regex_pattern, start_guided_pattern="<tool_call>", end_guided_pattern="</tool_call>"):
    """
    Only do guided generation sometimes, in between start_word and end_word.
    """
    return f".*?(?={start_guided_pattern}){start_guided_pattern}({regex_pattern}).*?(?={end_guided_pattern}){end_guided_pattern}.*"

In [37]:
def tool_to_regex(tool, whitespace_pattern=None, test_category=None):

    if isinstance(tool, list):
        values = [tool_to_regex(_tool, whitespace_pattern=whitespace_pattern, test_category=test_category) for _tool in tool]
        regex_strs = [v[0] for v in values]
        regex_str = reduce(regex_or, regex_strs)
        schema_strs = [v[1] for v in values]
        schema_str = "\n".join(schema_strs)
    elif isinstance(tool, dict) and list(tool.keys()) == ['name', 'description', 'parameters']:
        schema_str = bfcl_function_to_schema(tool, test_category)
        schema_regex = build_regex_from_schema(schema_str.strip(), whitespace_pattern)
        regex_str = f'{{"tool_name": "{tool["name"]}", "tool_arguments": {schema_regex}}}'
    elif isinstance(tool, type(BaseModel)):
        schema_json = tool.model_json_schema()
        schema_str = json.dumps(schema_json)
        schema_regex = build_regex_from_schema(schema_str.strip(), whitespace_pattern)
        regex_str = f'{{"tool_name": "{schema_json["title"]}", "tool_arguments": {schema_regex}}}'
    elif callable(tool):
        schema_json = get_schema_from_signature(tool)
        schema_str = json.dumps(schema_json)
        schema_regex = build_regex_from_schema(schema_str.strip(), whitespace_pattern)
        regex_str = f'{{"tool_name": "{tool.__name__}", "tool_arguments": {schema_regex}}}'
    elif isinstance(tool, str):
        schema_str = tool
        schema_regex = build_regex_from_schema(schema_str.strip(), whitespace_pattern)
        regex_str = f'{{"tool_name": "{json.loads(schema_str)["title"]}", "tool_arguments": {schema_regex}}}'

    return regex_str, schema_str

In [38]:
def add(a: int, b: int):
    return a+b

class User(BaseModel):
    name: str
    last_name: str
    id: int

schema = """
{
  "title": "User",
  "type": "object",
  "properties": {
    "name": {"type": "string"},
    "last_name": {"type": "string"},
    "id": {"type": "integer"}
  }
}
"""

tools = [add, User, schema]
regex_str = tool_to_regex(tools)
print(regex_str)

('(?:(?:{"tool_name": "add", "tool_arguments": \\{[\\n ]*"a"[\\n ]*:[\\n ]*(0|[1-9][0-9]*)[\\n ]*,[\\n ]*"b"[\\n ]*:[\\n ]*(0|[1-9][0-9]*)[\\n ]*\\}}|{"tool_name": "User", "tool_arguments": \\{[\\n ]*"name"[\\n ]*:[\\n ]*"(?:[^"\\\\\\x00-\\x1f\\x7f-\\x9f]|\\\\.)*"[\\n ]*,[\\n ]*"last_name"[\\n ]*:[\\n ]*"(?:[^"\\\\\\x00-\\x1f\\x7f-\\x9f]|\\\\.)*"[\\n ]*,[\\n ]*"id"[\\n ]*:[\\n ]*(0|[1-9][0-9]*)[\\n ]*\\}})|{"tool_name": "User", "tool_arguments": \\{([\\n ]*"name"[\\n ]*:[\\n ]*"(?:[^"\\\\\\x00-\\x1f\\x7f-\\x9f]|\\\\.)*"([\\n ]*,[\\n ]*"last_name"[\\n ]*:[\\n ]*"(?:[^"\\\\\\x00-\\x1f\\x7f-\\x9f]|\\\\.)*")?([\\n ]*,[\\n ]*"id"[\\n ]*:[\\n ]*(0|[1-9][0-9]*))?|([\\n ]*"name"[\\n ]*:[\\n ]*"(?:[^"\\\\\\x00-\\x1f\\x7f-\\x9f]|\\\\.)*"[\\n ]*,)?[\\n ]*"last_name"[\\n ]*:[\\n ]*"(?:[^"\\\\\\x00-\\x1f\\x7f-\\x9f]|\\\\.)*"([\\n ]*,[\\n ]*"id"[\\n ]*:[\\n ]*(0|[1-9][0-9]*))?|([\\n ]*"name"[\\n ]*:[\\n ]*"(?:[^"\\\\\\x00-\\x1f\\x7f-\\x9f]|\\\\.)*"[\\n ]*,)?([\\n ]*"last_name"[\\n ]*:[\\n ]*"(?:[^"\

# Prompt

In [4]:
messages = [
    {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}
  ]

In [5]:
completion = client.chat.completions.create(
  model="databricks/dbrx-instruct",
  messages=messages,
  #extra_body=dict(guided_regex=TEST_REGEX, guided_decoding_backend="outlines"),
  )

raw_text = completion.choices[0].message.content