Skip to content

Commit

Permalink
unfinished implementation - draft
Browse files Browse the repository at this point in the history
  • Loading branch information
lilyydu committed Jun 27, 2024
1 parent 0c2289d commit d36beba
Show file tree
Hide file tree
Showing 6 changed files with 576 additions and 416 deletions.
714 changes: 357 additions & 357 deletions python/packages/ai/poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions python/packages/ai/teams/ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
"""

from .chat_completion_action import ChatCompletionAction
from .openai_function import OpenAIFunction
from .openai_model import AzureOpenAIModelOptions, OpenAIModel, OpenAIModelOptions
from .prompt_completion_model import PromptCompletionModel
from .prompt_response import PromptResponse, PromptResponseStatus

__all__ = [
"ChatCompletionAction",
"OpenAIFunction",
"AzureOpenAIModelOptions",
"OpenAIModel",
"OpenAIModelOptions",
Expand Down
39 changes: 39 additions & 0 deletions python/packages/ai/teams/ai/models/openai_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, Optional, Union
from dataclasses_json import DataClassJsonMixin, dataclass_json


@dataclass_json
@dataclass
class OpenAIFunction(DataClassJsonMixin):
    """
    Function spec that adheres to the OpenAI function-calling schema.

    Serializable via dataclasses_json so it can be converted to the JSON
    payload expected by the chat completions `tools` parameter.
    """

    name: str
    """
    Name of the function to be called.
    """

    description: Optional[str]
    """
    Description of what the function does.
    """

    parameters: Optional[Dict[str, Any]]
    """
    Parameters the function accepts, described as a JSON Schema object
    (e.g. with "type", "properties", and "required" keys).
    """

    handler: Callable[..., Union[str, Awaitable[str]]]
    """
    The function handler. May be synchronous or asynchronous, takes in any
    number of arguments, and must return a string.
    """
216 changes: 163 additions & 53 deletions python/packages/ai/teams/ai/models/openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,22 @@

from __future__ import annotations

import asyncio
import json
from dataclasses import dataclass
from logging import Logger
from typing import List, Optional, Union
from typing import Any, Awaitable, Callable, Dict, List, Optional, Union

import openai
from botbuilder.core import TurnContext
from openai.types import chat
from openai.types import chat, shared_params

from ...state import MemoryBase
from ..prompts.message import Message
from ..prompts.prompt_functions import PromptFunctions
from ..prompts.prompt_template import PromptTemplate
from ..tokenizers import Tokenizer
from .openai_function import OpenAIFunction
from .prompt_completion_model import PromptCompletionModel
from .prompt_response import PromptResponse

Expand Down Expand Up @@ -119,21 +121,55 @@ async def complete_prompt(
) -> PromptResponse[str]:
max_tokens = template.config.completion.max_input_tokens

# TODO: retrieve parameters - may or may not exist
# TODO: update everything
# TODO: check that tool_choice confirms to a set definition
# Setup tools if enabled
is_tools_enabled = template.config.completion.include_tools
tool_choice = template.config.completion.tool_choice
parallel_tool_calls = template.config.completion.parallel_tool_calls
# tools = template.plugins.tools
tools: List[OpenAIFunction] = []
tools_handlers = memory.get("temp.tools")

# If tools is enabled, reformat actions to appropriate schema
if is_tools_enabled and template.actions and tools_handlers:
if len(template.actions) == 0:
return PromptResponse[str](status="tools_error", error="Missing actions")

if len(tools_handlers) == 0:
return PromptResponse[str](status="tools_error", error="Missing tools handlers")

# TODO: Choice - more rigorous comparison of template.actions and tools_handlers

if len(template.actions) != len(tools_handlers):
return PromptResponse[str](
status="tools_error",
error="Number of actions does not match number of tool handlers",
)

for action in template.actions:
handler = tools_handlers.get(action.name).func
tool = OpenAIFunction(action.name, action.description, action.parameters, handler)
tools.append(tool)

formatted_tools: List[chat.ChatCompletionToolParam] = []

for tool in tools:
curr_tool = chat.ChatCompletionToolParam(
type="function",
function=shared_params.FunctionDefinition(
name=tool.name,
description=tool.description or "",
parameters=tool.parameters or Dict[str, object](),
),
)
formatted_tools.append(curr_tool)

model = (
template.config.completion.model
if template.config.completion.model is not None
else self._options.default_model
)
# TODO: Check if the model supports function calling
# TODO: Check that tools matches tool_choice

# TODO: Choice - Check if the model supports function calling and parallel tool calls
# TODO: Choice - Check if tool matches tool_choice

res = await template.prompt.render_as_messages(
context=context,
Expand Down Expand Up @@ -198,87 +234,104 @@ async def complete_prompt(
top_p=template.config.completion.top_p,
temperature=template.config.completion.temperature,
max_tokens=max_tokens,
# TODO: added tools parameter
tools=tools,
tool_choice=tool_choice,
tools=formatted_tools,
tool_choice=tool_choice or "auto",
parallel_tool_calls=parallel_tool_calls or True,
)

if self._options.logger is not None:
self._options.logger.debug("COMPLETION:\n%s", completion.model_dump_json())

# Handle tools flow
response_message = completion.choices[0].message
tool_calls = response_message.tool_calls

# TODO: used to track latest response from LLM
# Tracks the latest response from the LLM
final_response = completion

# TODO: only support tools for default
if template.config.augmentation == "default":
while tool_calls:
# TODO: (1) Validate the tool_calls -> does it match tools (look at plugin names, parameters)?
if template.config.augmentation == "default" and is_tools_enabled:

# TODO: (2) Extend conversation with reply
messages.append(response_message)
# TODO: BUG - Handle edge cases where parameters dict is empty or none

# (3) Send the info for each plugin call and response to the model
while tool_calls and len(tool_calls) > 0:
if not parallel_tool_calls and len(tool_calls) > 1:
break

# (3a) Check if parallel_function_calling is False, and len(tool_calls) > 1 => SKIP
# OR alternatively..
# -> call them iteratively
# -> call first plugin in the list
if isinstance(tool_choice, dict) and len(tool_calls) > 1:
break

for tool_call in tool_calls:
function_name = tool_call.function.name
function_args = json.loads(tool_call.function.arguments)
if tool_choice == "none":
break

# (3b) Check tool_choice - if a specific plugin should be called, or none
messages.append(response_message)

# (3c) TODO: call the plugins and get the responses
# -> feels like a security vulnerability?
function = template.plugins.functions.getattr(function_name)
# TODO: BUG - Accessing wrong key for tools properties
# TODO: BUG - Handling required and optional arguments
# TODO: BUG - Make sure single tool and multiple tools matches up
# TODO: BUG - Pass in JSON obj vs string back to the LLM

if isinstance(tool_choice, dict):
# Calling a single tool
function_name = tool_choice["function"]["name"]
curr_tool_call = tool_calls[0]
curr_function = list(filter(lambda tool: tool.name == function_name, tools))

# Validate function arguments
required_args = (
curr_function[0].parameters["required"]
if curr_function[0].parameters and "required" in curr_function[0].parameters
else None
)

# TODO: unravel the parameters, need to know # of args, and which args go into which pos
# TODO: should we place a restriction on the type of functions handled?
# -> this may be async?
# -> limit on function args?
# -> could functions call other functions (eg. additional handlers)?
function_response = function(function_args)
curr_args = json.loads(curr_tool_call.function.arguments)
curr_function_handler = curr_function[0].handler

if required_args:
# TODO: CHOICE - Verify each argument
if len(required_args) > len(curr_args):
break

# Call the function
function_response: Union[str, Awaitable[str]] = await self._handle_function_response(curr_function_handler, required_args, curr_args)

messages.append(
{
"tool_call_id": tool_call.id,
"role": "tool",
"name": function_name,
"content": function_response,
}
chat.ChatCompletionToolMessageParam(
role="tool",
tool_call_id=curr_tool_call.id,
content=function_response,
)
)
else:
curr_message_length = len(messages)
messages = await self._handle_multiple_tool_calls(
messages, tool_calls, tools
)
# No tools were run successfully
if len(messages) == curr_message_length:
break

# messages = [{"tool_call_id":1 ,..},{"tool_call_id":2 ,..},{"tool_call_id":3 ,..}]

# (5) TODO: save to conversation history?

# (7) TODO: update parameters to match above
final_response = await self._client.chat.completions.create(
model="gpt-4o",
messages=messages,
model=model,
presence_penalty=template.config.completion.presence_penalty,
frequency_penalty=template.config.completion.frequency_penalty,
top_p=template.config.completion.top_p,
temperature=template.config.completion.temperature,
max_tokens=max_tokens,
)

tool_calls = final_response.choices[0].message.tool_calls


input: Optional[Message] = None
last_message = len(res.output) - 1

# TODO: Do we still skip this message if we are in tool_call mode?

# Skips the first message which is the prompt
if last_message > 0 and res.output[last_message].role == "user":
input = res.output[last_message]

return PromptResponse[str](
input=input,
message=Message(
# TODO: update to take in the final_response
role=final_response.choices[0].message.role,
content=final_response.choices[0].message.content,
),
Expand All @@ -294,3 +347,60 @@ async def complete_prompt(
status of {err.code}: {err.message}
""",
)

async def _handle_function_response(
self,
curr_function_handler: Callable[..., Union[str, Awaitable[str]]],
required_args: Optional[List[str]],
curr_args: Dict[str, Any],
) -> Union[str, Awaitable[str]]:

# TODO: BUG - Passing in optional vars is an option, should check for curr_args.values size instead
if asyncio.iscoroutinefunction(curr_function_handler) and required_args and len(required_args) > 0:
return await curr_function_handler(**curr_args.values())
elif asyncio.iscoroutinefunction(curr_function_handler):
return await curr_function_handler()
elif required_args and len(required_args) > 0:
return curr_function_handler(**curr_args.values())
else:
return curr_function_handler()

async def _handle_multiple_tool_calls(
self,
messages: List[chat.ChatCompletionMessageParam],
tool_calls: List[chat.ChatCompletionMessageToolCall],
tools: List[OpenAIFunction],
) -> List[chat.ChatCompletionMessageParam]:

# TODO: BUG - Needs updates to match up with single tool call
for tool_call in tool_calls:
function_name = tool_call.function.name
curr_args = json.loads(tool_call.function.arguments)
curr_function = list(filter(lambda tool: tool.name == function_name, tools))

# Validate function name
if len(curr_function) == 0:
continue

# Validate function arguments
required_args = (
curr_function[0].parameters.keys() if list(curr_function[0].parameters) else None
)
curr_function_handler = curr_function[0].handler

if required_args:
if len(required_args) != len(curr_args):
continue

# Call the function
function_response: Union[str, Awaitable[str]] = await self._handle_function_response(curr_function_handler, required_args, curr_args)

messages.append(
chat.ChatCompletionToolMessageParam(
role="tool",
tool_call_id=tool_call.id,
content=function_response,
)
)

return messages
4 changes: 3 additions & 1 deletion python/packages/ai/teams/ai/models/prompt_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from ..prompts.message import Message

ContentT = TypeVar("ContentT")
PromptResponseStatus = Literal["success", "error", "rate_limited", "invalid_response", "too_long"]
PromptResponseStatus = Literal[
"success", "error", "rate_limited", "invalid_response", "too_long", "tools_error"
]


@dataclass
Expand Down
Loading

0 comments on commit d36beba

Please sign in to comment.