From 7703192d212d8d9ed73b7ff6114f5e1a4b4e3afc Mon Sep 17 00:00:00 2001 From: Roman Yermilov Date: Tue, 18 Mar 2025 14:34:26 +0100 Subject: [PATCH] feat(py): add tool support for openai-compat plugin --- .../genkit-ai/src/genkit/core/action.py | 4 +- .../plugins/compat_oai/models/handler.py | 5 +- .../genkit/plugins/compat_oai/models/model.py | 183 +++++++++++++--- .../plugins/compat_oai/openai_plugin.py | 2 +- .../src/genkit/plugins/compat_oai/typing.py | 7 - py/plugins/compat-oai/tests/test_handler.py | 4 +- py/plugins/compat-oai/tests/test_model.py | 72 +++++-- .../compat-oai/tests/test_tool_calling.py | 197 ++++++++++++++++++ py/samples/openai/src/main.py | 53 ++++- 9 files changed, 466 insertions(+), 61 deletions(-) create mode 100644 py/plugins/compat-oai/tests/test_tool_calling.py diff --git a/py/packages/genkit-ai/src/genkit/core/action.py b/py/packages/genkit-ai/src/genkit/core/action.py index 5237068453..92c00f0eca 100644 --- a/py/packages/genkit-ai/src/genkit/core/action.py +++ b/py/packages/genkit-ai/src/genkit/core/action.py @@ -402,7 +402,9 @@ def run( return self.__fn( input, - ActionRunContext(on_chunk=on_chunk, context=_action_context.get()), + ActionRunContext( + on_chunk=on_chunk, context=_action_context.get(None) + ), ) async def arun( diff --git a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/handler.py b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/handler.py index 0ea330d53a..9dfa7793b4 100644 --- a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/handler.py +++ b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/handler.py @@ -24,6 +24,7 @@ from openai import OpenAI +from genkit.ai.registry import GenkitRegistry from genkit.core.action import ActionRunContext from genkit.core.typing import ( GenerateRequest, @@ -51,7 +52,7 @@ def __init__(self, model: Any): @classmethod def get_model_handler( - cls, model: str, client: OpenAI + cls, model: str, client: OpenAI, registry: GenkitRegistry ) -> 
Callable[[GenerateRequest, ActionRunContext], GenerateResponse]: """ Factory method to initialize the model handler for the specified OpenAI model. @@ -70,7 +71,7 @@ def get_model_handler( if model not in SUPPORTED_OPENAI_MODELS: raise ValueError(f"Model '{model}' is not supported.") - openai_model = OpenAIModel(model, client) + openai_model = OpenAIModel(model, client, registry) return cls(openai_model).generate def validate_version(self, version: str): diff --git a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/model.py b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/model.py index a9c2a3a052..8939d99815 100644 --- a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/model.py +++ b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/model.py @@ -19,10 +19,17 @@ OpenAI Compatible Models for Genkit. """ +import json from collections.abc import Callable -from openai import OpenAI +from openai import OpenAI, pydantic_function_tool +from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall +from openai.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall, +) +from genkit.ai.registry import GenkitRegistry +from genkit.core.action import ActionKind from genkit.core.typing import ( GenerateRequest, GenerateResponse, @@ -30,8 +37,8 @@ Message, Role, TextPart, + ToolDefinition, ) -from genkit.plugins.compat_oai.typing import ChatMessage class OpenAIModel: @@ -39,7 +46,7 @@ class OpenAIModel: Handles OpenAI API interactions for the Genkit plugin. """ - def __init__(self, model: str, client: OpenAI): + def __init__(self, model: str, client: OpenAI, registry: GenkitRegistry): """ Initializes the OpenAIModel instance with the specified model and OpenAI client parameters. 
@@ -48,16 +55,17 @@ def __init__(self, model: str, client: OpenAI): """ self._model = model self._openai_client = client + self._registry = registry @property def name(self): return self._model - def _get_messages(self, messages: list[Message]) -> list[ChatMessage]: + def _get_messages(self, messages: list[Message]) -> list[dict]: """ Converts the request messages into the format required by OpenAI's API. - :param request: A list of the user messages. + :param messages: A list of the user messages. :return: A list of dictionaries, where each dictionary represents a message with 'role' and 'content' fields. :raises ValueError: If no messages are provided in the request. @@ -65,26 +73,104 @@ def _get_messages(self, messages: list[Message]) -> list[ChatMessage]: if not messages: raise ValueError('No messages provided in the request.') return [ - ChatMessage( - role=m.role.value, - content=''.join( - part.root.text - for part in m.content - if part.root.text is not None + { + 'role': m.role.value, + 'content': ''.join( + filter(None, (part.root.text for part in m.content)) ), - ) + } for m in messages ] - def _get_openai_config(self, request: GenerateRequest) -> dict: + def _get_tools_definition(self, tools: list[ToolDefinition]) -> list[dict]: + """ + Converts the provided tools into OpenAI-compatible function call format. + + :param tools: A list of tool definitions. + :return: A list of dictionaries representing the formatted tools. 
+ """ + result = [] + for tool_definition in tools: + action = self._registry.registry.lookup_action( + ActionKind.TOOL, tool_definition.name + ) + function_call = pydantic_function_tool( + model=action.input_type._type, + name=tool_definition.name, + description=tool_definition.description, + ) + result.append(function_call) + return result + + def _get_openai_request_config(self, request: GenerateRequest) -> dict: openai_config = { 'messages': self._get_messages(request.messages), 'model': self._model, } + if request.tools: + openai_config['tools'] = self._get_tools_definition(request.tools) if request.config: openai_config.update(**request.config.model_dump()) return openai_config + def _evaluate_tool(self, name: str, arguments: str): + """ + Executes a registered tool with the given arguments and returns the result. + + :param name: Name of the tool to execute. + :param arguments: JSON-encoded arguments for the tool. + :return: String representation of the tool's output. + """ + action = self._registry.registry.lookup_action(ActionKind.TOOL, name) + + # Parse and validate arguments + arguments = json.loads(arguments) + arguments = action.input_type.validate_python(arguments) + + # Execute the tool and return its result + return str(action.run(arguments)) + + def _get_evaluated_tool_message_param( + self, tool_call: ChoiceDeltaToolCall | ChatCompletionMessageToolCall + ) -> dict: + """ + Evaluates a tool call and formats its response in a structure compatible with OpenAI's API. + + :param tool_call: The tool call object containing function name and arguments. + :return: A dictionary formatted as a response message from a tool. 
+ """ + return { + 'role': Role.TOOL.value, + 'tool_call_id': tool_call.id, + 'content': self._evaluate_tool( + tool_call.function.name, tool_call.function.arguments + ), + } + + def _get_assistant_message_param( + self, tool_calls: list[ChoiceDeltaToolCall] + ) -> dict: + """ + Formats tool call data into assistant message structure compatible with OpenAI's API. + + :param tool_calls: A list of tool call objects. + :return: A dictionary representing the tool calls formatted for OpenAI. + """ + return { + 'role': 'assistant', + 'tool_calls': [ + { + 'id': tool_call.id, + 'type': 'function', + 'function': { + 'name': tool_call.function.name, + 'arguments': tool_call.function.arguments, + }, + } + for tool_call in tool_calls + ], + } + def generate(self, request: GenerateRequest) -> GenerateResponse: """ Processes the request using OpenAI's chat completion API and returns the generated response. @@ -92,15 +178,28 @@ def generate(self, request: GenerateRequest) -> GenerateResponse: :param request: The GenerateRequest object containing the input text and configuration. :return: A GenerateResponse object containing the generated message. 
""" - response = self._openai_client.chat.completions.create( - **self._get_openai_config(request=request) - ) + openai_config = self._get_openai_request_config(request=request) + response = self._openai_client.chat.completions.create(**openai_config) + + while (completion := response.choices[0]).finish_reason == 'tool_calls': + # Append the assistant's response message + openai_config['messages'].append(completion.message) + + # Evaluate tool calls and append their responses + for tool_call in completion.message.tool_calls: + openai_config['messages'].append( + self._get_evaluated_tool_message_param(tool_call) + ) + + response = self._openai_client.chat.completions.create( + **openai_config + ) return GenerateResponse( request=request, message=Message( role=Role.MODEL, - content=[TextPart(text=response.choices[0].message.content)], + content=[TextPart(text=completion.message.content)], ), ) @@ -117,23 +216,51 @@ def generate_stream( Returns: GenerateResponse: An empty response message when streaming is complete. 
""" - openai_config = self._get_openai_config(request=request) + openai_config = self._get_openai_request_config(request=request) openai_config['stream'] = True + # Initiate the streaming response from OpenAI stream = self._openai_client.chat.completions.create(**openai_config) - for chunk in stream: - choice = chunk.choices[0] - if not choice.delta.content: - continue + while not stream.response.is_closed: + tool_calls = {} - response_chunk = GenerateResponseChunk( - role=Role.MODEL, - index=choice.index, - content=[TextPart(text=choice.delta.content)], - ) + for chunk in stream: + choice = chunk.choices[0] + + # Handle standard text content + if choice.delta.content is not None: + callback( + GenerateResponseChunk( + role=Role.MODEL, + index=choice.index, + content=[TextPart(text=choice.delta.content)], + ) + ) + + # Handle tool calls when OpenAI requires tool execution + elif choice.delta.tool_calls: + for tool_call in choice.delta.tool_calls: + # Accumulate fragmented tool call arguments + tool_calls.setdefault( + tool_call.index, tool_call + ).function.arguments += tool_call.function.arguments + + # If tool calls were requested, process them and continue the conversation + if tool_calls: + openai_config['messages'].append( + self._get_assistant_message_param(tool_calls.values()) + ) + + for tool_call in tool_calls.values(): + openai_config['messages'].append( + self._get_evaluated_tool_message_param(tool_call) + ) - callback(response_chunk) + # Re-initiate streaming after processing tool calls + stream = self._openai_client.chat.completions.create( + **openai_config + ) # Return an empty response when streaming is complete return GenerateResponse( diff --git a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/openai_plugin.py b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/openai_plugin.py index 3277ddcc6a..73a565572a 100644 --- a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/openai_plugin.py +++ 
b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/openai_plugin.py @@ -58,7 +58,7 @@ def initialize(self, ai: GenkitRegistry) -> None: """ for model_name, model_info in SUPPORTED_OPENAI_MODELS.items(): handler = OpenAIModelHandler.get_model_handler( - model=model_name, client=self._openai_client + model=model_name, client=self._openai_client, registry=ai ) ai.define_model( diff --git a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/typing.py b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/typing.py index 7676b27ae5..718f4de5d0 100644 --- a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/typing.py +++ b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/typing.py @@ -17,13 +17,6 @@ from pydantic import BaseModel, ConfigDict -class ChatMessage(BaseModel): - model_config = ConfigDict(extra='forbid', populate_by_name=True) - - role: str - content: str - - class OpenAIConfig(BaseModel): model_config = ConfigDict(extra='forbid', populate_by_name=True) diff --git a/py/plugins/compat-oai/tests/test_handler.py b/py/plugins/compat-oai/tests/test_handler.py index ec052cc179..81352bab5c 100644 --- a/py/plugins/compat-oai/tests/test_handler.py +++ b/py/plugins/compat-oai/tests/test_handler.py @@ -33,7 +33,7 @@ def test_get_model_handler(): """Test get_model_handler method returns a callable.""" model_name = GPT_4 handler = OpenAIModelHandler.get_model_handler( - model=model_name, client=MagicMock() + model=model_name, client=MagicMock(), registry=MagicMock() ) assert callable(handler) @@ -44,7 +44,7 @@ def test_get_model_handler_invalid(): ValueError, match="Model 'unsupported-model' is not supported." 
): OpenAIModelHandler.get_model_handler( - model='unsupported-model', client=MagicMock() + model='unsupported-model', client=MagicMock(), registry=MagicMock() ) diff --git a/py/plugins/compat-oai/tests/test_model.py b/py/plugins/compat-oai/tests/test_model.py index adcf7e7eb8..94ab664a9b 100644 --- a/py/plugins/compat-oai/tests/test_model.py +++ b/py/plugins/compat-oai/tests/test_model.py @@ -32,19 +32,19 @@ def test_get_messages(sample_request): Test _get_messages method. Ensures the method correctly converts GenerateRequest messages into OpenAI-compatible ChatMessage format. """ - model = OpenAIModel(model=GPT_4, client=MagicMock()) + model = OpenAIModel(model=GPT_4, client=MagicMock(), registry=MagicMock()) messages = model._get_messages(sample_request.messages) assert len(messages) == 1 - assert messages[0].role == Role.USER - assert messages[0].content == 'Hello, world!' + assert messages[0]['role'] == Role.USER + assert messages[0]['content'] == 'Hello, world!' def test_get_messages_empty(): """ Test _get_messages raises ValueError when no messages are provided. """ - model = OpenAIModel(model=GPT_4, client=MagicMock()) + model = OpenAIModel(model=GPT_4, client=MagicMock(), registry=MagicMock()) with pytest.raises( ValueError, match='No messages provided in the request.' ): @@ -56,8 +56,8 @@ def test_get_openai_config(sample_request): Test _get_openai_config method. Ensures the method correctly constructs the OpenAI API configuration dictionary. 
""" - model = OpenAIModel(model=GPT_4, client=MagicMock()) - openai_config = model._get_openai_config(sample_request) + model = OpenAIModel(model=GPT_4, client=MagicMock(), registry=MagicMock()) + openai_config = model._get_openai_request_config(sample_request) assert isinstance(openai_config, dict) assert openai_config['model'] == GPT_4 @@ -74,7 +74,7 @@ def test_generate(sample_request): choices=[MagicMock(message=MagicMock(content='Hello, user!'))] ) - model = OpenAIModel(model=GPT_4, client=mock_client) + model = OpenAIModel(model=GPT_4, client=mock_client, registry=MagicMock()) response = model.generate(sample_request) mock_client.chat.completions.create.assert_called_once() @@ -88,17 +88,53 @@ def test_generate_stream(sample_request): Test generate_stream method ensures it processes streamed responses correctly. """ mock_client = MagicMock() - mock_stream = [ - MagicMock( - choices=[MagicMock(index=0, delta=MagicMock(content='Hello'))] - ), - MagicMock( - choices=[MagicMock(index=0, delta=MagicMock(content=', world!'))] - ), - ] - mock_client.chat.completions.create.return_value = mock_stream - - model = OpenAIModel(model=GPT_4, client=mock_client) + + class MockStream: + def __init__(self, data: list[str]) -> None: + self._data = data + self._current = 0 + + # Initialize response mock with stream state + self.response = MagicMock() + self.response.is_closed = False + + def __iter__(self): + return self + + def __next__(self): + # Return an empty chunk to indicate end of stream + if self._current == len(self._data): + chunk = MagicMock( + choices=[MagicMock(index=0, delta=MagicMock(content=None))] + ) + + # Close stream and stop iteration + elif self._current > len(self._data): + self.response.is_closed = True + raise StopIteration + + # Return current chunk from data + else: + chunk = MagicMock( + choices=[ + MagicMock( + index=0, + delta=MagicMock(content=self._data[self._current]), + ) + ] + ) + + # Move to the next chunk + self._current += 1 + + return 
chunk + + mock_client.chat.completions.create.return_value = MockStream([ + 'Hello', + ', world!', + ]) + + model = OpenAIModel(model=GPT_4, client=mock_client, registry=MagicMock()) collected_chunks = [] def callback(chunk: GenerateResponseChunk): diff --git a/py/plugins/compat-oai/tests/test_tool_calling.py b/py/plugins/compat-oai/tests/test_tool_calling.py new file mode 100644 index 0000000000..10e8286d2e --- /dev/null +++ b/py/plugins/compat-oai/tests/test_tool_calling.py @@ -0,0 +1,197 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import MagicMock + +from genkit.core.action import ActionKind +from genkit.core.typing import GenerateResponseChunk +from genkit.plugins.compat_oai.models import OpenAIModel +from genkit.plugins.compat_oai.models.model_info import GPT_4 + + +def test_get_evaluated_tool_message_param_returns_expected_message(): + tool_call = MagicMock() + tool_call.id = 'abc123' + tool_call.function.name = 'tool_fn' + tool_call.function.arguments = '{"key": "val"}' + + model = OpenAIModel(model=GPT_4, client=MagicMock(), registry=MagicMock()) + model._evaluate_tool = MagicMock(return_value='tool_result') + + result = model._get_evaluated_tool_message_param(tool_call) + assert result['role'] == 'tool' + assert result['tool_call_id'] == 'abc123' + assert result['content'] == 'tool_result' + + +def test_evaluate_tool_executes_registered_action(): + mock_action = MagicMock() + mock_action.input_type.validate_python.return_value = {'a': 1} + mock_action.run.return_value = 'result' + + mock_registry = MagicMock() + mock_registry.registry.lookup_action.return_value = mock_action + + model = OpenAIModel(model=GPT_4, client=MagicMock(), registry=mock_registry) + + result = model._evaluate_tool('my_tool', '{"a": 1}') + mock_registry.registry.lookup_action.assert_called_once_with( + ActionKind.TOOL, 'my_tool' + ) + mock_action.input_type.validate_python.assert_called_once_with({'a': 1}) + mock_action.run.assert_called_once_with({'a': 1}) + assert result == 'result' + + +def test_generate_with_tool_calls_executes_tools(sample_request): + mock_tool_call = MagicMock() + mock_tool_call.id = 'tool123' + mock_tool_call.function.name = 'tool_fn' + mock_tool_call.function.arguments = '{"a": 1}' + + # First call triggers tool execution + first_response = MagicMock() + first_response.choices = [ + MagicMock( + finish_reason='tool_calls', + message=MagicMock(tool_calls=[mock_tool_call]), + ) + ] + # Second call is the model response + 
second_response = MagicMock() + second_response.choices = [ + MagicMock( + finish_reason='stop', message=MagicMock(content='final response') + ) + ] + + mock_client = MagicMock() + mock_client.chat.completions.create.side_effect = [ + first_response, + second_response, + ] + + mock_action = MagicMock() + mock_action.input_type.validate_python.return_value = {'a': 1} + mock_action.run.return_value = 'tool result' + + mock_registry = MagicMock() + mock_registry.registry.lookup_action.return_value = mock_action + + model = OpenAIModel(model=GPT_4, client=mock_client, registry=mock_registry) + + response = model.generate(sample_request) + + assert response.message.content[0].root.text == 'final response' + assert mock_client.chat.completions.create.call_count == 2 + + +def test_generate_stream_with_tool_calls(sample_request): + mock_tool_call = MagicMock() + mock_tool_call.id = 'tool123' + mock_tool_call.index = 0 + mock_tool_call.function.name = 'tool_fn' + mock_tool_call.function.arguments = '' + + # First chunk: tool call starts + chunk1 = MagicMock() + chunk1.choices = [ + MagicMock( + index=0, + delta=MagicMock( + content=None, + tool_calls=[mock_tool_call], + ), + ) + ] + + # Second chunk: continuation of tool call arguments + chunk2 = MagicMock() + chunk2.choices = [ + MagicMock( + index=0, + delta=MagicMock( + content=None, + tool_calls=[ + MagicMock( + index=0, + function=MagicMock(arguments='{"a": "123"}'), + id='tool123', + ) + ], + ), + ) + ] + + # Third stream after tool execution: final model response + final_chunk = MagicMock() + final_chunk.choices = [ + MagicMock( + index=0, + delta=MagicMock(content='final stream content', tool_calls=None), + ) + ] + + # Simulate the streaming lifecycle + class MockStream: + def __init__(self, chunks): + self._chunks = chunks + self._current = 0 + self.response = MagicMock() + self.response.is_closed = False + + def __iter__(self): + return self + + def __next__(self): + if self._current >= len(self._chunks): + 
self.response.is_closed = True + raise StopIteration + chunk = self._chunks[self._current] + self._current += 1 + return chunk + + # First stream: tool call chunks + first_stream = MockStream([chunk1, chunk2]) + # Second stream: model output after tool is processed + second_stream = MockStream([final_chunk]) + + mock_client = MagicMock() + mock_client.chat.completions.create.side_effect = [ + first_stream, + second_stream, + ] + + # Set up mock tool evaluation + mock_action = MagicMock() + mock_action.input_type.validate_python.return_value = {'a': '123'} + mock_action.run.return_value = 'tool response' + + mock_registry = MagicMock() + mock_registry.registry.lookup_action.return_value = mock_action + + model = OpenAIModel(model=GPT_4, client=mock_client, registry=mock_registry) + + chunks = [] + + def callback(chunk: GenerateResponseChunk): + chunks.append(chunk.content[0].root.text) + + model.generate_stream(sample_request, callback) + + assert chunks == ['final stream content'] + assert mock_action.run.call_count == 1 + assert mock_client.chat.completions.create.call_count == 2 diff --git a/py/samples/openai/src/main.py b/py/samples/openai/src/main.py index 81c8708acc..9d8924a719 100644 --- a/py/samples/openai/src/main.py +++ b/py/samples/openai/src/main.py @@ -15,7 +15,9 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio +from decimal import Decimal +import requests from pydantic import BaseModel, Field from genkit.ai import Genkit @@ -30,6 +32,11 @@ class MyInput(BaseModel): b: int = Field(description='b field') +class WeatherRequest(BaseModel): + latitude: Decimal + longitude: Decimal + + @ai.flow() def sum_two_numbers2(my_input: MyInput): return my_input.a + my_input.b @@ -59,11 +66,53 @@ async def say_hi_stream(name: str): return result +@ai.tool('Get current temperature for provided coordinates in celsius') +def get_weather_tool(coordinates: WeatherRequest) -> str: + url = ( + f'https://api.open-meteo.com/v1/forecast?' 
+ f'latitude={coordinates.latitude}&longitude={coordinates.longitude}' + f'&current=temperature_2m' + ) + response = requests.get(url=url) + data = response.json() + return data['current']['temperature_2m'] + + +@ai.flow() +async def get_weather_flow(location: str): + response = await ai.generate( + model=openai_model('gpt-4'), + config={'model': 'gpt-4-0613', 'temperature': 1}, + prompt=f"What's the weather like in {location} today?", + tools=['get_weather_tool'], + ) + return response.message.content[0].root.text + + +@ai.flow() +async def get_weather_flow_stream(location: str): + stream, _ = ai.generate_stream( + model=openai_model('gpt-4'), + config={'model': 'gpt-4-0613', 'temperature': 1}, + prompt=f"What's the weather like in {location} today?", + tools=['get_weather_tool'], + ) + result = '' + async for data in stream: + for part in data.content: + result += part.root.text + return result + + + async def main() -> None: - print(await say_hi_stream('John Doe')) - print(await say_hi('John Doe')) print(sum_two_numbers2(MyInput(a=1, b=3))) + print(await say_hi('John Doe')) + print(await say_hi_stream('John Doe')) + + print(await get_weather_flow('London and Paris')) + print(await get_weather_flow_stream('London and Paris')) + if __name__ == '__main__': asyncio.run(main())