diff --git a/packages/ai/src/microsoft/teams/ai/chat_prompt.py b/packages/ai/src/microsoft/teams/ai/chat_prompt.py index cc157a44..49908d82 100644 --- a/packages/ai/src/microsoft/teams/ai/chat_prompt.py +++ b/packages/ai/src/microsoft/teams/ai/chat_prompt.py @@ -6,7 +6,7 @@ import inspect from dataclasses import dataclass from inspect import isawaitable -from typing import Any, Awaitable, Callable, Optional, Self, TypeVar, cast +from typing import Any, Awaitable, Callable, Dict, Optional, Self, TypeVar, Union, cast, overload from pydantic import BaseModel @@ -58,17 +58,67 @@ def __init__( self.functions: dict[str, Function[Any]] = {func.name: func for func in functions} if functions else {} self.plugins: list[AIPluginProtocol] = plugins or [] - def with_function(self, function: Function[T]) -> Self: + @overload + def with_function(self, function: Function[T]) -> Self: ... + + @overload + def with_function( + self, + *, + name: str, + description: str, + parameter_schema: Union[type[T], Dict[str, Any]], + handler: FunctionHandlers, + ) -> Self: ... + + @overload + def with_function( + self, + *, + name: str, + description: str, + handler: FunctionHandlerWithNoParams, + ) -> Self: ... + + def with_function( + self, + function: Function[T] | None = None, + *, + name: str | None = None, + description: str | None = None, + parameter_schema: Union[type[T], Dict[str, Any], None] = None, + handler: FunctionHandlers | None = None, + ) -> Self: """ Add a function to the available functions for this prompt. + Can be called in three ways: + 1. with_function(function=Function(...)) + 2. with_function(name=..., description=..., parameter_schema=..., handler=...) + 3. with_function(name=..., description=..., handler=...) - for functions with no parameters + Args: - function: Function to add to the available functions + function: Function object to add (first overload) + name: Function name (second and third overload) + description: Function description (second and third overload) + parameter_schema: Function parameter schema (second overload, optional) + handler: Function handler (second and third overload) Returns: Self for method chaining """ - self.functions[function.name] = function + if function is not None: + self.functions[function.name] = function + else: + if name is None or description is None or handler is None: + raise ValueError("When not providing a Function object, name, description, and handler are required") + func = Function[T]( + name=name, + description=description, + parameter_schema=parameter_schema, + handler=handler, + ) + self.functions[func.name] = func return self def with_plugin(self, plugin: AIPluginProtocol) -> Self: diff --git a/packages/ai/src/microsoft/teams/ai/function.py b/packages/ai/src/microsoft/teams/ai/function.py index 5475967f..beafcc1b 100644 --- a/packages/ai/src/microsoft/teams/ai/function.py +++ b/packages/ai/src/microsoft/teams/ai/function.py @@ -71,12 +71,18 @@ class Function(Generic[Params]): Type Parameters: Params: Pydantic model class defining the function's parameter schema, if any. + + Note: + For best type safety, use explicit type parameters when creating Function objects: + Function[SearchPokemonParams](name=..., parameter_schema=SearchPokemonParams, handler=...) + + This ensures the handler parameter type matches the parameter_schema at compile time. """ name: str # Unique identifier for the function description: str # Human-readable description of what the function does parameter_schema: Union[type[Params], Dict[str, Any], None] # Pydantic model class, JSON schema dict, or None - handler: FunctionHandlers # Function implementation (sync or async) + handler: Union[FunctionHandler[Params], FunctionHandlerWithNoParams] # Function implementation (sync or async) @dataclass diff --git a/packages/ai/tests/test_chat_prompt.py b/packages/ai/tests/test_chat_prompt.py index 7c953795..0f39dab7 100644 --- a/packages/ai/tests/test_chat_prompt.py +++ b/packages/ai/tests/test_chat_prompt.py @@ -296,6 +296,46 @@ async def test_different_message_types(self, mock_model: MockAIModel) -> None: result3 = await prompt.send(model_msg) assert result3.response.content == "GENERATED - Model message" + @pytest.mark.asyncio + async def test_with_function_unpacked_parameters(self, mock_model: MockAIModel) -> None: + """Test with_function using unpacked parameters instead of Function object""" + prompt = ChatPrompt(mock_model) + + # Test with parameter schema + def handler_with_params(params: MockFunctionParams) -> str: + return f"Result: {params.value}" + + prompt.with_function( + name="test_func", + description="Test function with params", + parameter_schema=MockFunctionParams, + handler=handler_with_params, + ) + + assert "test_func" in prompt.functions + assert prompt.functions["test_func"].name == "test_func" + assert prompt.functions["test_func"].description == "Test function with params" + assert prompt.functions["test_func"].parameter_schema == MockFunctionParams + + # Test without parameter schema (no params function) + def handler_no_params() -> str: + return "No params result" + + prompt.with_function( + name="no_params_func", + description="Function with no parameters", + handler=handler_no_params, + ) + + assert "no_params_func" in prompt.functions + assert prompt.functions["no_params_func"].name == "no_params_func" + assert prompt.functions["no_params_func"].description == "Function with no parameters" + assert prompt.functions["no_params_func"].parameter_schema is None + + # Verify both work in send + result = await prompt.send("Test message") + assert result.response.content == "GENERATED - Test message" + class MockPlugin(BaseAIPlugin): """Mock plugin for testing that tracks all hook calls""" diff --git a/packages/openai/src/microsoft/teams/openai/function_utils.py b/packages/openai/src/microsoft/teams/openai/function_utils.py index a5f49313..012647a4 100644 --- a/packages/openai/src/microsoft/teams/openai/function_utils.py +++ b/packages/openai/src/microsoft/teams/openai/function_utils.py @@ -6,10 +6,10 @@ from typing import Any, Dict, Optional from microsoft.teams.ai import Function -from pydantic import BaseModel, create_model +from pydantic import BaseModel, ConfigDict, create_model -def get_function_schema(func: Function[BaseModel]) -> Dict[str, Any]: +def get_function_schema(func: Function[Any]) -> Dict[str, Any]: """ Get JSON schema from a Function's parameter_schema. @@ -34,7 +34,7 @@ def get_function_schema(func: Function[BaseModel]) -> Dict[str, Any]: return func.parameter_schema.model_json_schema() -def parse_function_arguments(func: Function[BaseModel], arguments: Dict[str, Any]) -> Optional[BaseModel]: +def parse_function_arguments(func: Function[Any], arguments: Dict[str, Any]) -> Optional[BaseModel]: """ Parse function arguments into a BaseModel instance. @@ -53,7 +53,8 @@ def parse_function_arguments(func: Function[BaseModel], arguments: Dict[str, Any if isinstance(func.parameter_schema, dict): # For dict schemas, create a simple BaseModel dynamically - DynamicModel = create_model("DynamicParams") + # Use extra='allow' to accept arbitrary fields from the arguments dict + DynamicModel = create_model("DynamicParams", __config__=ConfigDict(extra="allow")) return DynamicModel(**arguments) else: # For Pydantic model schemas, parse normally diff --git a/packages/openai/tests/test_function_utils.py b/packages/openai/tests/test_function_utils.py new file mode 100644 index 00000000..dc202983 --- /dev/null +++ b/packages/openai/tests/test_function_utils.py @@ -0,0 +1,270 @@ +""" +Copyright (c) Microsoft Corporation. All rights reserved. +Licensed under the MIT License. +""" + +# pyright: basic + +from typing import Optional + +import pytest +from microsoft.teams.ai import Function +from microsoft.teams.openai.function_utils import get_function_schema, parse_function_arguments +from pydantic import BaseModel, ValidationError + + +class SimpleParams(BaseModel): + """Simple parameter model for testing.""" + + name: str + age: int + + +class OptionalParams(BaseModel): + """Parameter model with optional fields.""" + + required_field: str + optional_field: Optional[str] = None + + +class EmptyParams(BaseModel): + """Empty parameter model.""" + + pass + + +def dummy_handler(params: BaseModel) -> str: + """Dummy handler for testing.""" + return "test" + + +def dummy_handler_no_params() -> str: + """Dummy handler with no params for testing.""" + return "test" + + +class TestGetFunctionSchema: + """Tests for get_function_schema function.""" + + def test_get_schema_from_pydantic_model(self): + """Test getting schema from a Pydantic model.""" + func = Function( + name="test_func", + description="Test function", + parameter_schema=SimpleParams, + handler=dummy_handler, + ) + + schema = get_function_schema(func) + + assert isinstance(schema, dict) + assert "properties" in schema + assert "name" in schema["properties"] + assert "age" in schema["properties"] + assert schema["properties"]["name"]["type"] == "string" + assert schema["properties"]["age"]["type"] == "integer" + + def test_get_schema_from_dict(self): + """Test getting schema from a dict.""" + dict_schema = { + "type": "object", + "properties": {"param1": {"type": "string"}, "param2": {"type": "number"}}, + "required": ["param1"], + } + + func = Function( + name="test_func", + description="Test function", + parameter_schema=dict_schema, + handler=dummy_handler, + ) + + schema = get_function_schema(func) + + assert schema == dict_schema + # Ensure original is not modified + assert schema is not dict_schema + + def test_get_schema_with_no_parameters(self): + """Test getting schema when function has no parameters.""" + func = Function( + name="test_func", + description="Test function", + parameter_schema=None, + handler=dummy_handler_no_params, + ) + + schema = get_function_schema(func) + + assert schema == {} + + +class TestParseFunctionArguments: + """Tests for parse_function_arguments function.""" + + def test_parse_with_pydantic_model(self): + """Test parsing arguments with a Pydantic model schema.""" + func = Function( + name="test_func", + description="Test function", + parameter_schema=SimpleParams, + handler=dummy_handler, + ) + + arguments = {"name": "John", "age": 30} + result = parse_function_arguments(func, arguments) + + assert result is not None + assert isinstance(result, SimpleParams) + assert result.name == "John" + assert result.age == 30 + + def test_parse_with_pydantic_model_validation(self): + """Test that Pydantic validation works correctly.""" + func = Function( + name="test_func", + description="Test function", + parameter_schema=SimpleParams, + handler=dummy_handler, + ) + + # Invalid arguments (age should be int) + arguments = {"name": "John", "age": "not_an_int"} + + with pytest.raises(ValidationError): + parse_function_arguments(func, arguments) + + def test_parse_with_dict_schema_and_arguments(self): + """Test parsing with dict schema and non-empty arguments.""" + dict_schema = { + "type": "object", + "properties": {"param1": {"type": "string"}, "param2": {"type": "number"}}, + } + + func = Function( + name="test_func", + description="Test function", + parameter_schema=dict_schema, + handler=dummy_handler, + ) + + arguments = {"param1": "value1", "param2": 42} + result = parse_function_arguments(func, arguments) + + assert result is not None + assert isinstance(result, BaseModel) + assert result.param1 == "value1" # pyright: ignore + assert result.param2 == 42 # pyright: ignore + + def test_parse_with_dict_schema_and_empty_arguments(self): + """Test parsing with dict schema and empty arguments dict - BUG CASE.""" + dict_schema = { + "type": "object", + "properties": {"param1": {"type": "string"}}, + } + + func = Function( + name="test_func", + description="Test function", + parameter_schema=dict_schema, + handler=dummy_handler, + ) + + # This is the bug case: empty arguments dict + arguments = {} + result = parse_function_arguments(func, arguments) + + assert result is not None + assert isinstance(result, BaseModel) + # The DynamicModel should handle empty args gracefully + # Currently this may fail or behave unexpectedly + + def test_parse_with_no_parameter_schema(self): + """Test parsing when function has no parameter schema.""" + func = Function( + name="test_func", + description="Test function", + parameter_schema=None, + handler=dummy_handler_no_params, + ) + + arguments = {} + result = parse_function_arguments(func, arguments) + + assert result is None + + def test_parse_with_optional_fields(self): + """Test parsing with optional fields.""" + func = Function( + name="test_func", + description="Test function", + parameter_schema=OptionalParams, + handler=dummy_handler, + ) + + # Only required field provided + arguments = {"required_field": "test"} + result = parse_function_arguments(func, arguments) + + assert result is not None + assert isinstance(result, OptionalParams) + assert result.required_field == "test" + assert result.optional_field is None + + def test_parse_with_empty_pydantic_model(self): + """Test parsing with an empty Pydantic model.""" + func = Function( + name="test_func", + description="Test function", + parameter_schema=EmptyParams, + handler=dummy_handler, + ) + + arguments = {} + result = parse_function_arguments(func, arguments) + + assert result is not None + assert isinstance(result, EmptyParams) + + def test_parse_preserves_dict_schema_immutability(self): + """Test that parsing doesn't modify the original schema.""" + dict_schema = { + "type": "object", + "properties": {"param1": {"type": "string"}}, + } + original_schema = dict_schema.copy() + + func = Function( + name="test_func", + description="Test function", + parameter_schema=dict_schema, + handler=dummy_handler, + ) + + arguments = {"param1": "value1"} + parse_function_arguments(func, arguments) + + # Ensure original schema unchanged + assert func.parameter_schema == original_schema + + def test_parse_dict_schema_model_dump(self): + """Test that model_dump() works correctly with dict schemas.""" + dict_schema = { + "type": "object", + "properties": {"param1": {"type": "string"}, "param2": {"type": "number"}}, + } + + func = Function( + name="test_func", + description="Test function", + parameter_schema=dict_schema, + handler=dummy_handler, + ) + + arguments = {"param1": "value1", "param2": 42} + result = parse_function_arguments(func, arguments) + + assert result is not None + # Verify model_dump() returns the arguments correctly + dumped = result.model_dump() + assert dumped == arguments diff --git a/tests/ai-test/src/handlers/function_calling.py b/tests/ai-test/src/handlers/function_calling.py index bd36c41f..141d0056 100644 --- a/tests/ai-test/src/handlers/function_calling.py +++ b/tests/ai-test/src/handlers/function_calling.py @@ -82,7 +82,7 @@ def get_location_handler(params: GetLocationParams) -> str: return location -def get_weather_handler(params: GetWeatherParams) -> str: +def get_weather_handler(params: BaseModel) -> str: """Get weather for location (mock)""" weather_by_location: Dict[str, Dict[str, Any]] = { "Seattle": {"temperature": 65, "condition": "sunny"}, @@ -90,13 +90,12 @@ def get_weather_handler(params: GetWeatherParams) -> str: "New York": {"temperature": 75, "condition": "rainy"}, } - weather = weather_by_location.get(params.location) + location = getattr(params, "location") # noqa + weather = weather_by_location.get(location) if not weather: return "Sorry, I could not find the weather for that location" - return ( - f"The weather in {params.location} is {weather['condition']} with a temperature of {weather['temperature']}°F" - ) + return f"The weather in {location} is {weather['condition']} with a temperature of {weather['temperature']}°F" async def handle_multiple_functions(model: AIModel, ctx: ActivityContext[MessageActivity]) -> None: @@ -111,12 +110,15 @@ async def handle_multiple_functions(model: AIModel, ctx: ActivityContext[Message handler=get_location_handler, ) ).with_function( - Function( - name="weather_search", - description="Search for weather at a specific location", - parameter_schema=GetWeatherParams, - handler=get_weather_handler, - ) + name="weather_search", + description="Search for weather at a specific location", + parameter_schema={ + "title": "GetWeatherParams", + "type": "object", + "properties": {"location": {"title": "Location", "type": "string"}}, + "required": ["location"], + }, + handler=get_weather_handler, ) chat_result = await agent.send(