# ADK を Workflow tool として無理やり利用するサンプル

## 事前準備

In [1]:
import copy, json, os, re, uuid
import vertexai
from google.genai.types import Part, Content, FunctionCall
from google.adk.agents.llm_agent import LlmAgent
from google.adk.artifacts import InMemoryArtifactService
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
from google.adk.runners import Runner
from google.adk.sessions import InMemorySessionService

from google.adk.agents.callback_context import CallbackContext
from google.adk.models import LlmResponse, LlmRequest
from google.adk.tools import ToolContext

# Resolve the active gcloud project. This is IPython shell-capture syntax
# (notebook only); the single-element unpacking fails unless gcloud prints
# exactly one line.
[PROJECT_ID] = !gcloud config list --format 'value(core.project)'
LOCATION = 'us-central1'

vertexai.init(project=PROJECT_ID, location=LOCATION)

# Environment variables presumably read by google-genai / ADK to send model
# calls to Vertex AI in this project and region (confirm against SDK docs).
os.environ['GOOGLE_CLOUD_PROJECT'] = PROJECT_ID
os.environ['GOOGLE_CLOUD_LOCATION'] = LOCATION
os.environ['GOOGLE_GENAI_USE_VERTEXAI'] = 'True'

In [2]:
class LocalApp:
    """Run an ADK agent locally with in-memory session/artifact/memory services.

    One session is created lazily on the first ``stream`` call and reused for
    every later call, so history and session state carry over between queries.
    """

    def __init__(self, agent, user_id='default_user', state=None):
        """Create a ``Runner`` for ``agent``.

        Args:
            agent: The agent to run; its ``name`` is also used as the app name.
            user_id: User ID for the session and every run.
            state: Initial session state. ``None`` (the default) means empty.
                The mapping is shallow-copied, so neither the caller's dict nor
                other ``LocalApp`` instances share this instance's state object.
        """
        self._agent = agent
        self._user_id = user_id
        self._runner = Runner(
            app_name=self._agent.name,
            agent=self._agent,
            artifact_service=InMemoryArtifactService(),
            session_service=InMemorySessionService(),
            memory_service=InMemoryMemoryService(),
        )
        # A ``{}`` default would be a single dict shared by every instance;
        # build a fresh one per instance instead.
        self._state = {} if state is None else dict(state)
        self._session = None

    async def stream(self, query):
        """Send ``query`` as a user message and print the text of each event.

        Args:
            query: Text of the user message.

        Returns:
            A list holding the newline-joined text of every event that carried
            text, in the order the events arrived.
        """
        if not self._session:
            self._session = await self._runner.session_service.create_session(
                app_name=self._agent.name,
                user_id=self._user_id,
                session_id=uuid.uuid4().hex,
                state=self._state,
            )
        content = Content(role='user', parts=[Part.from_text(text=query)])
        async_events = self._runner.run_async(
            user_id=self._user_id,
            session_id=self._session.id,
            new_message=content,
        )
        # ``DEBUG`` is an optional global toggle defined in a later cell; treat
        # it as off when it does not exist yet instead of raising NameError.
        debug = globals().get('DEBUG', False)
        result = []
        async for event in async_events:
            if debug:
                print(f'----\n{event}\n----')
            if event.content and event.content.parts:
                response = '\n'.join(p.text for p in event.content.parts if p.text)
                if response:
                    print(response)
                    result.append(response)
        return result

## Workflow 制御のコールバック関数

In [3]:
def workflow_tool_callback(tool_name, root_agent_name='root_agent'):
    """Build a ``before_model_callback`` that turns an agent into a workflow step.

    Every path of the returned callback produces an ``LlmResponse``, so it
    answers in place of the model. Depending on the conversation it returns:

    * the text ``'done'`` when the current agent is ``root_agent_name`` and the
      last session event is a ``transfer_to_agent`` response (the workflow has
      come back to the root and ends);
    * a call to ``tool_name`` (with no arguments) when the request has no
      contents, when the agent has just been transferred to, or when the last
      request part is not a function response (no tool result to act on);
    * otherwise, a ``transfer_to_agent`` call targeting the tool response's
      ``'next_agent'`` value, falling back to the current agent's own name.

    Args:
        tool_name: Name of the tool this agent runs.
        root_agent_name: Agent on which a transfer back ends the workflow.
            Defaults to ``'root_agent'``.

    Returns:
        An async callback for ``LlmAgent(before_model_callback=...)``.
    """

    async def before_model_callback(
        callback_context: CallbackContext, llm_request: LlmRequest
    ) -> LlmResponse:
        agent_name = callback_context.agent_name

        # Inspect the most recent session event. The event list, its content
        # and its parts may each be empty, so guard every step.
        # NOTE(review): relies on the private ``_invocation_context`` attribute.
        events = callback_context._invocation_context.session.events
        last_content = events[-1].content if events else None
        last_part = (
            last_content.parts[-1] if last_content and last_content.parts else None
        )
        is_transferred = bool(
            last_part
            and last_part.function_response
            and last_part.function_response.name == 'transfer_to_agent'
        )

        # Finish when transferred back to the root agent.
        if agent_name == root_agent_name and is_transferred:
            return LlmResponse(
                content=Content(role='model', parts=[Part(text='done')])
            )

        # Function response (if any) at the very end of the request; after the
        # tool has run, its result is expected to be here.
        request_parts = (
            llm_request.contents[-1].parts if llm_request.contents else None
        )
        function_response = (
            request_parts[-1].function_response if request_parts else None
        )

        # Run the tool when directly called, just transferred to, or when
        # there is no tool result to act on yet.
        if not llm_request.contents or is_transferred or function_response is None:
            part = Part(function_call=FunctionCall(name=tool_name, args={}))
            return LlmResponse(content=Content(role='model', parts=[part]))

        # Transfer to the agent named by the tool; default to myself.
        response = function_response.response or {}
        next_agent = response.get('next_agent', agent_name)
        part = Part(function_call=FunctionCall(
            name='transfer_to_agent',
            args={'agent_name': next_agent},
        ))
        return LlmResponse(content=Content(role='model', parts=[part]))

    return before_model_callback

## サンプル実装

`ping_agent` は `ping_tool()` を実行して、`pong_agent` に遷移する。

In [4]:
def ping_tool(tool_context: ToolContext) -> dict:
    """Decrement ``count`` in session state, print it, and name pong_agent next.

    Returns:
        ``{'next_agent': 'pong_agent'}``.
    """
    # A missing counter is treated as 0 instead of failing on ``None - 1``.
    count = tool_context.state.get('count', 0) - 1
    tool_context.state['count'] = count

    print('ping', count)
    return {'next_agent': 'pong_agent'}


# Workflow step agent: runs ping_tool, then transfers to the agent named in the
# tool's result. The workflow callback answers every model turn, so the model
# configured here is never actually queried.
ping_agent = LlmAgent(
    name='ping_agent',
    description='An agent that always run ping_tool.',
    model='gemini-2.0-flash-001',  # not used
    instruction='',
    tools=[ping_tool],
    before_model_callback=workflow_tool_callback('ping_tool'),
)

`pong_agent` は `pong_tool()` を実行して、

- `count <= 0`: `root_agent` に遷移する。
- `count % 2 == 0`: `pong_agent` に遷移する。
- `else`: `ping_agent` に遷移する。

In [5]:
def pong_tool(tool_context: ToolContext) -> dict:
    """Decrement ``count`` in session state, print it, and pick the next agent.

    Returns:
        ``{'next_agent': name}`` where ``name`` is ``'root_agent'`` once the
        count is 0 or below, ``'pong_agent'`` for an even count, and
        ``'ping_agent'`` for an odd count.
    """
    # A missing counter is treated as 0 instead of failing on ``None - 1``.
    count = tool_context.state.get('count', 0) - 1
    tool_context.state['count'] = count

    print('pong', count)
    if count <= 0:
        return {'next_agent': 'root_agent'}
    elif count % 2 == 0:
        return {'next_agent': 'pong_agent'}
    else:
        return {'next_agent': 'ping_agent'}


# Workflow step agent: runs pong_tool, then transfers to whichever agent the
# tool chose (root, pong or ping). The model configured here is never queried
# because the workflow callback answers every turn.
pong_agent = LlmAgent(
    name='pong_agent',
    description='An agent that always run pong_tool.',
    model='gemini-2.0-flash-001',  # not used
    instruction='',
    tools=[pong_tool],
    before_model_callback=workflow_tool_callback('pong_tool'),
)

`root_agent` は初回呼び出し時に `ping_agent` に遷移する。

他のエージェントから遷移してきた場合は、そこで終了する。

In [6]:
def root_tool(tool_context: ToolContext) -> dict:
    """Start the workflow by naming ping_agent as the next agent.

    ``tool_context`` is accepted but not used.
    """
    first_step = 'ping_agent'
    return {'next_agent': first_step}

# Entry agent of the ping/pong workflow: runs root_tool (which points at
# ping_agent) and ends the run when control is transferred back to it.
# Sub-agents are deep copies, leaving the module-level ping_agent / pong_agent
# objects untouched.
root_agent = LlmAgent(
    name='root_agent',
    description='An agent that always run root_tool.',
    model='gemini-2.0-flash-001',  # not used
    instruction='',
    tools=[root_tool],
    sub_agents=[copy.deepcopy(step) for step in (ping_agent, pong_agent)],
    before_model_callback=workflow_tool_callback('root_tool'),
)

`root_agent` -> `ping_agent` -> `pong_agent` -> ... -> `root_agent` のワークフローが実行される。

In [7]:
# Run the ping/pong workflow starting from count = 6.
state = {'count': 6}
client = LocalApp(root_agent, state=state)

DEBUG = False  # set to True to print every raw event inside LocalApp.stream
# Notebook-only top-level await. The empty query is presumably chosen so that
# the root agent's request has no contents and the callback takes its
# "directly called" branch — confirm against ADK's content handling.
_ = await client.stream('')

ping 5
pong 4
pong 3
ping 2
pong 1
ping 0
pong -1
done


## LLM の応答とは関係なく強制的に Transfer させる実験

In [13]:
def force_transfer_callback(next_agent):
    """Build an ``after_model_callback`` that always transfers to ``next_agent``.

    The callback keeps the parts the model produced and appends a
    ``transfer_to_agent`` function call, so control moves to ``next_agent``
    whatever the model answered.

    Args:
        next_agent: Name of the agent to transfer to.

    Returns:
        An async callback for ``LlmAgent(after_model_callback=...)``.
    """

    async def after_model_callback(
        callback_context: CallbackContext, llm_response: LlmResponse
    ) -> LlmResponse:
        # The model response may lack content (or parts); start from an empty
        # list then, instead of raising AttributeError on ``None.append``.
        content = llm_response.content
        parts = copy.deepcopy(content.parts) if content and content.parts else []
        parts.append(Part(function_call=FunctionCall(
            name='transfer_to_agent',
            args={'agent_name': next_agent},
        )))
        return LlmResponse(content=Content(role='model', parts=parts))

    return after_model_callback

In [55]:
# Plain LLM agent (no callbacks, so its model is actually called) that
# prefixes "Echo: " and repeats the previous output.
echo_agent = LlmAgent(
    name='echo_agent',
    model='gemini-2.0-flash-001',
    description='An agent that repeat the last output text.',
    instruction='Say "Echo: " first, and then repeat the last output text.',
)

In [56]:
# Ordinary LLM root agent; after each model reply a transfer_to_agent call to
# echo_agent is appended, regardless of what the model decided.
root_agent = LlmAgent(
    name='root_agent',
    model='gemini-2.0-flash-001',
    description='greeting!',
    instruction='',
    sub_agents=[copy.deepcopy(echo_agent)],
    after_model_callback=force_transfer_callback('echo_agent'),
)

In [57]:
# Run the forced-transfer example; at this point ``root_agent`` is the echo
# example defined in the previous cell (the name was rebound).
client = LocalApp(root_agent)

DEBUG = False  # set to True to print every raw event inside LocalApp.stream
# Notebook-only top-level await.
_ = await client.stream('How are you doing?')

I'm doing well, thank you for asking!

Echo: I'm doing well, thank you for asking!

