diff --git a/pkg/templates/python/cua/providers/gemini.py b/pkg/templates/python/cua/providers/gemini.py index 57aa52c..0a9a554 100644 --- a/pkg/templates/python/cua/providers/gemini.py +++ b/pkg/templates/python/cua/providers/gemini.py @@ -16,6 +16,9 @@ Tool, ComputerUse, Environment, + FunctionResponse, + FunctionResponseBlob, + FunctionResponsePart, ) from . import CuaProvider, TaskOptions, TaskResult @@ -29,6 +32,13 @@ PX_PER_NOTCH = 60 MAX_NOTCHES_PER_ACTION = 17 +PREDEFINED_ACTIONS = { + "click_at", "hover_at", "type_text_at", "scroll_document", + "scroll_at", "wait_5_seconds", "go_back", "go_forward", + "search", "navigate", "key_combination", "drag_and_drop", + "open_web_browser", +} + def _system_prompt() -> str: date = datetime.now().strftime("%A, %B %d, %Y") return ( @@ -115,20 +125,24 @@ async def run_task(self, options: TaskOptions) -> TaskResult: ) if result.get("error"): - responses.append(Part.from_function_response( + responses.append(Part(function_response=FunctionResponse( name=fc.name, response={"error": result["error"], "url": "about:blank"}, - )) + ))) else: - responses.append(Part.from_function_response( + fr_parts = None + if result.get("screenshot") and fc.name in PREDEFINED_ACTIONS: + fr_parts = [FunctionResponsePart( + inline_data=FunctionResponseBlob( + mime_type="image/png", + data=result["screenshot"], + ), + )] + responses.append(Part(function_response=FunctionResponse( name=fc.name, response={"url": result.get("url", "about:blank")}, - )) - if result.get("screenshot"): - responses.append(Part(inline_data={ - "mime_type": "image/png", - "data": result["screenshot"], - })) + parts=fr_parts, + ))) contents.append(Content(role="user", parts=responses)) diff --git a/pkg/templates/python/cua/pyproject.toml b/pkg/templates/python/cua/pyproject.toml index 37e844e..6cbedff 100644 --- a/pkg/templates/python/cua/pyproject.toml +++ b/pkg/templates/python/cua/pyproject.toml @@ -9,6 +9,5 @@ dependencies = [ "google-genai>=1.71.0", "httpx>=0.28.1", "kernel>=0.47.0", - "openai>=2.30.0", "python-dotenv>=1.2.2", ]