58 changes: 54 additions & 4 deletions packages/ai/src/microsoft/teams/ai/chat_prompt.py
@@ -6,7 +6,7 @@
import inspect
from dataclasses import dataclass
from inspect import isawaitable
from typing import Any, Awaitable, Callable, Optional, Self, TypeVar, cast
from typing import Any, Awaitable, Callable, Dict, Optional, Self, TypeVar, Union, cast, overload

from pydantic import BaseModel

@@ -58,17 +58,67 @@ def __init__(
self.functions: dict[str, Function[Any]] = {func.name: func for func in functions} if functions else {}
self.plugins: list[AIPluginProtocol] = plugins or []

def with_function(self, function: Function[T]) -> Self:
@overload
def with_function(self, function: Function[T]) -> Self: ...

@overload
def with_function(
self,
*,
name: str,
description: str,
parameter_schema: Union[type[T], Dict[str, Any]],
handler: FunctionHandlers,
) -> Self: ...

@overload
def with_function(
self,
*,
name: str,
description: str,
handler: FunctionHandlerWithNoParams,
) -> Self: ...

def with_function(
self,
function: Function[T] | None = None,
*,
name: str | None = None,
description: str | None = None,
parameter_schema: Union[type[T], Dict[str, Any], None] = None,
handler: FunctionHandlers | None = None,
) -> Self:
"""
Add a function to the available functions for this prompt.

Can be called in three ways:
1. with_function(function=Function(...))
2. with_function(name=..., description=..., parameter_schema=..., handler=...)
3. with_function(name=..., description=..., handler=...) - for functions with no parameters

Args:
function: Function to add to the available functions
function: Function object to add (first overload)
name: Function name (second and third overload)
description: Function description (second and third overload)
parameter_schema: Function parameter schema (second overload, optional)
handler: Function handler (second and third overload)

Returns:
Self for method chaining
"""
self.functions[function.name] = function
if function is not None:
self.functions[function.name] = function
else:
if name is None or description is None or handler is None:
raise ValueError("When not providing a Function object, name, description, and handler are required")
func = Function[T](
name=name,
description=description,
parameter_schema=parameter_schema,
handler=handler,
)
self.functions[func.name] = func
return self

def with_plugin(self, plugin: AIPluginProtocol) -> Self:
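For reference, a minimal usage sketch of the three calling styles the new overloads support. WeatherParams, the handlers, and the `model` instance are hypothetical and not part of this PR; import paths are assumed.

from pydantic import BaseModel
from microsoft.teams.ai import ChatPrompt, Function

class WeatherParams(BaseModel):
    city: str

def get_weather(params: WeatherParams) -> str:
    return f"Weather in {params.city}: sunny"

def ping() -> str:
    return "pong"

prompt = ChatPrompt(model)  # `model` is any AI model instance

# 1. Pre-built Function object
prompt.with_function(
    Function[WeatherParams](
        name="get_weather",
        description="Look up the weather for a city",
        parameter_schema=WeatherParams,
        handler=get_weather,
    )
)

# 2. Unpacked parameters with a schema
prompt.with_function(
    name="get_weather_v2",
    description="Look up the weather for a city",
    parameter_schema=WeatherParams,
    handler=get_weather,
)

# 3. Unpacked parameters, no schema
prompt.with_function(name="ping", description="Health check", handler=ping)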
8 changes: 7 additions & 1 deletion packages/ai/src/microsoft/teams/ai/function.py
@@ -71,12 +71,18 @@ class Function(Generic[Params]):

Type Parameters:
Params: Pydantic model class defining the function's parameter schema, if any.

Note:
For best type safety, use explicit type parameters when creating Function objects:
Function[SearchPokemonParams](name=..., parameter_schema=SearchPokemonParams, handler=...)

This ensures the handler parameter type matches the parameter_schema at compile time.
"""

name: str # Unique identifier for the function
description: str # Human-readable description of what the function does
parameter_schema: Union[type[Params], Dict[str, Any], None] # Pydantic model class, JSON schema dict, or None
handler: FunctionHandlers # Function implementation (sync or async)
handler: Union[FunctionHandler[Params], FunctionHandlerWithNoParams] # Function implementation (sync or async)


@dataclass
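As a short illustration of the note above, a sketch using an explicit type parameter (SearchPokemonParams and the handler are hypothetical):

from pydantic import BaseModel
from microsoft.teams.ai import Function

class SearchPokemonParams(BaseModel):
    pokemon_name: str

def search_pokemon(params: SearchPokemonParams) -> str:
    return f"Searching for {params.pokemon_name}"

# The explicit type parameter ties the handler's parameter type to
# parameter_schema, so a mismatch is flagged by the type checker
# instead of surfacing at call time.
func = Function[SearchPokemonParams](
    name="search_pokemon",
    description="Search for a Pokemon by name",
    parameter_schema=SearchPokemonParams,
    handler=search_pokemon,
)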
40 changes: 40 additions & 0 deletions packages/ai/tests/test_chat_prompt.py
@@ -296,6 +296,46 @@ async def test_different_message_types(self, mock_model: MockAIModel) -> None:
result3 = await prompt.send(model_msg)
assert result3.response.content == "GENERATED - Model message"

@pytest.mark.asyncio
async def test_with_function_unpacked_parameters(self, mock_model: MockAIModel) -> None:
"""Test with_function using unpacked parameters instead of Function object"""
prompt = ChatPrompt(mock_model)

# Test with parameter schema
def handler_with_params(params: MockFunctionParams) -> str:
return f"Result: {params.value}"

prompt.with_function(
name="test_func",
description="Test function with params",
parameter_schema=MockFunctionParams,
handler=handler_with_params,
)

assert "test_func" in prompt.functions
assert prompt.functions["test_func"].name == "test_func"
assert prompt.functions["test_func"].description == "Test function with params"
assert prompt.functions["test_func"].parameter_schema == MockFunctionParams

# Test without parameter schema (no params function)
def handler_no_params() -> str:
return "No params result"

prompt.with_function(
name="no_params_func",
description="Function with no parameters",
handler=handler_no_params,
)

assert "no_params_func" in prompt.functions
assert prompt.functions["no_params_func"].name == "no_params_func"
assert prompt.functions["no_params_func"].description == "Function with no parameters"
assert prompt.functions["no_params_func"].parameter_schema is None

# Verify both work in send
result = await prompt.send("Test message")
assert result.response.content == "GENERATED - Test message"


class MockPlugin(BaseAIPlugin):
"""Mock plugin for testing that tracks all hook calls"""
9 changes: 5 additions & 4 deletions packages/openai/src/microsoft/teams/openai/function_utils.py
@@ -6,10 +6,10 @@
from typing import Any, Dict, Optional

from microsoft.teams.ai import Function
from pydantic import BaseModel, create_model
from pydantic import BaseModel, ConfigDict, create_model


def get_function_schema(func: Function[BaseModel]) -> Dict[str, Any]:
def get_function_schema(func: Function[Any]) -> Dict[str, Any]:
"""
Get JSON schema from a Function's parameter_schema.

@@ -34,7 +34,7 @@ def get_function_schema(func: Function[BaseModel]) -> Dict[str, Any]:
return func.parameter_schema.model_json_schema()


def parse_function_arguments(func: Function[BaseModel], arguments: Dict[str, Any]) -> Optional[BaseModel]:
def parse_function_arguments(func: Function[Any], arguments: Dict[str, Any]) -> Optional[BaseModel]:
"""
Parse function arguments into a BaseModel instance.

@@ -53,7 +53,8 @@ def parse_function_arguments(func: Function[BaseModel], arguments: Dict[str, Any

if isinstance(func.parameter_schema, dict):
# For dict schemas, create a simple BaseModel dynamically
DynamicModel = create_model("DynamicParams")
# Use extra='allow' to accept arbitrary fields from the arguments dict
DynamicModel = create_model("DynamicParams", __config__=ConfigDict(extra="allow"))
return DynamicModel(**arguments)
else:
# For Pydantic model schemas, parse normally
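A minimal sketch of why extra="allow" matters for dict-based schemas (the argument values are hypothetical): with Pydantic's default config, the dynamically created model declares no fields, so every argument key is silently dropped; extra="allow" keeps them on the parsed instance.

from pydantic import ConfigDict, create_model

# Default config: unknown fields are ignored, so the parsed model is empty.
Strict = create_model("DynamicParams")
print(Strict(**{"city": "Seattle"}).model_dump())  # {}

# extra="allow": arbitrary argument names are retained.
Lenient = create_model("DynamicParams", __config__=ConfigDict(extra="allow"))
print(Lenient(**{"city": "Seattle"}).model_dump())  # {'city': 'Seattle'}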