In [2]:
import copy
import json
import os
import uuid
from typing import Optional

import vertexai
from google import genai
from google.genai.types import (
    HttpOptions, GenerateContentConfig,
    FunctionResponse, Part, Content, FunctionCall
)
from google.adk.agents.llm_agent import LlmAgent
from google.adk.runners import Runner
from google.adk.artifacts import InMemoryArtifactService
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
from google.adk.sessions import InMemorySessionService
from google.adk.agents.callback_context import CallbackContext
from google.adk.models import LlmResponse, LlmRequest

# Resolve the active gcloud project with an IPython shell capture. The
# single-element unpacking raises if gcloud prints anything other than one line
# (e.g. no project configured).
[PROJECT_ID] = !gcloud config list --format 'value(core.project)'
LOCATION = 'us-central1'

# Initialise the Vertex AI SDK for this project/region.
# NOTE(review): the staging bucket is assumed to be named after the project
# (gs://<PROJECT_ID>) — confirm that bucket exists.
vertexai.init(
    project=PROJECT_ID,
    location=LOCATION,
    staging_bucket=f'gs://{PROJECT_ID}'
)

# No explicit client is passed to the agents below. These variables are
# presumably what google-genai / ADK read to route model calls through Vertex AI
# in this project and region.
os.environ['GOOGLE_CLOUD_PROJECT'] = PROJECT_ID
os.environ['GOOGLE_CLOUD_LOCATION'] = LOCATION
os.environ['GOOGLE_GENAI_USE_VERTEXAI'] = 'True'

In [5]:
def generate_response(
    system_instruction, contents, response_schema,
    temperature=0.2, model='gemini-2.0-flash', top_p=0.5,
):
    """Call Gemini on Vertex AI in JSON mode and return the text it produced.

    Args:
        system_instruction: System prompt for the call.
        contents: Prompt contents, passed straight to ``generate_content``.
        response_schema: Schema the JSON response must conform to.
        temperature: Sampling temperature.
        model: Model name.
        top_p: Nucleus-sampling threshold (default 0.5).

    Returns:
        The non-empty text parts of the first candidate joined with newlines.
        Returns ``''`` if the response carries no candidate content, for example
        when the prompt or answer was blocked.
    """
    client = genai.Client(
        vertexai=True,
        project=PROJECT_ID, location=LOCATION,
        http_options=HttpOptions(api_version='v1')
    )
    response = client.models.generate_content(
        model=model,
        contents=contents,
        config=GenerateContentConfig(
            system_instruction=system_instruction,
            temperature=temperature,
            top_p=top_p,
            response_mime_type='application/json',
            response_schema=response_schema,
        )
    )
    # Blocked or empty responses can have no candidates, or a candidate whose
    # content/parts are None. Indexing blindly would raise here.
    candidates = response.candidates or []
    content = candidates[0].content if candidates else None
    if content is None or not content.parts:
        return ''
    return '\n'.join(part.text for part in content.parts if part.text)

In [6]:
class LocalApp:
    """Minimal local harness that runs an ADK agent with in-memory services.

    A single session is created lazily on the first ``stream`` call and reused
    for every later call. Conversation history and session state therefore carry
    over between queries.
    """

    def __init__(
        self, agent, app_name='default_app',
        user_id='default_user', state=None,
    ):
        """Create the runner. No session exists until the first ``stream`` call.

        Args:
            agent: Root agent to run.
            app_name: ADK application name used for the runner and the session.
            user_id: User id used for the session and for every run.
            state: Initial session state. Defaults to a fresh empty dict per
                instance. A shared ``{}`` default would be the same object for
                every instance that omitted it.
        """
        self._agent = agent
        self._app_name = app_name
        self._user_id = user_id
        self._state = {} if state is None else state
        self._runner = Runner(
            app_name=self._app_name,
            agent=self._agent,
            artifact_service=InMemoryArtifactService(),
            session_service=InMemorySessionService(),
            memory_service=InMemoryMemoryService(),
        )
        self._session = None

    async def stream(self, query):
        """Send ``query`` as a user message and print the agents' text replies.

        Each run of consecutive events from the same author is printed under a
        single ``[author]`` header. Events without text are skipped.

        Args:
            query: Text of the user message.

        Returns:
            list[str]: The non-empty text of each event, in arrival order.
        """
        if self._session is None:
            self._session = await self._runner.session_service.create_session(
                app_name=self._app_name,
                user_id=self._user_id,
                session_id=uuid.uuid4().hex,
                state=self._state,
            )
        content = Content(role='user', parts=[Part.from_text(text=query)])
        async_events = self._runner.run_async(
            user_id=self._user_id,
            session_id=self._session.id,
            new_message=content,
        )
        result = []
        author = None
        async for event in async_events:
            if not (event.content and event.content.parts):
                continue
            response = '\n'.join(p.text for p in event.content.parts if p.text)
            if not response:
                continue
            if author != event.author:
                author = event.author
                print(f'\n[{author}]')
            print(response)
            result.append(response)
        return result

In [9]:
def add_transfer_toolcall(
    callback_context: CallbackContext,
    llm_response: LlmResponse,
    condition: Optional[str] = None
) -> LlmResponse:
    """Append a ``transfer_to_agent`` call that follows the next graph edge.

    The graph (``state['graph']``) is a dict with ``'nodes'`` (each with
    ``'id'`` and ``'agent'``) and ``'edges'`` (each with ``'source'``,
    ``'target'`` and optionally ``'condition'``). This function:

    1. Finds the outgoing edge of ``state['current_node']``.
    2. Advances ``state['current_node']`` to that edge's target.
    3. Returns ``llm_response``'s parts (deep-copied) with a transfer call
       appended.

    The pseudo-target ``'__end__'`` hands control back to ``'root_agent'``.

    Args:
        callback_context: ADK callback context holding the graph state.
        llm_response: Response whose parts precede the transfer call. A response
            without content or parts contributes no parts.
        condition: If truthy, only edges whose ``'condition'`` equals it match.
            Otherwise the first outgoing edge is taken.

    Returns:
        A new ``LlmResponse`` with role ``'model'``.

    Raises:
        ValueError: If no edge matches, or the edge target is not a node id.
            ``state['current_node']`` is left unchanged in that case.
    """
    graph = callback_context.state['graph']
    current_node = callback_context.state['current_node']

    content = llm_response.content
    parts = copy.deepcopy(content.parts) if content and content.parts else []

    # .get(): sibling edges from the same node need not all carry a condition.
    target_id = next(
        (
            edge['target'] for edge in graph['edges']
            if edge['source'] == current_node
            and (not condition or edge.get('condition') == condition)
        ),
        None,
    )
    if target_id is None:
        detail = f' with condition {condition!r}' if condition else ''
        raise ValueError(f'No edge from node {current_node!r}{detail}')

    if target_id == '__end__':
        target_agent = 'root_agent'
    else:
        target_agent = next(
            (node['agent'] for node in graph['nodes'] if node['id'] == target_id),
            None,
        )
        if target_agent is None:
            raise ValueError(f'Edge target {target_id!r} is not a node in the graph')

    # Advance only after the target resolved, so a bad graph leaves state intact.
    callback_context.state['current_node'] = target_id
    parts.append(Part(
        function_call=FunctionCall(
            name='transfer_to_agent', args={'agent_name': target_agent}
        )
    ))
    return LlmResponse(
        content=Content(role='model', parts=parts)
    )

In [10]:
def get_message_llm_response(
    callback_context: CallbackContext, llm_request: LlmRequest
) -> LlmResponse:
    """Build a model reply containing the current graph node's ``message`` text.

    ``llm_request`` is accepted only to match the callback signature. Exactly
    one node must carry the current node id; otherwise unpacking raises
    ``ValueError``.
    """
    state = callback_context.state
    workflow = state['graph']
    node_id = state['current_node']
    matching_messages = [
        node['message']
        for node in workflow['nodes']
        if node['id'] == node_id
    ]
    (message_text,) = matching_messages
    message_part = Part(text=message_text)
    return LlmResponse(content=Content(role='model', parts=[message_part]))

def message_agent_before_model_callback(
    callback_context: CallbackContext, llm_request: LlmRequest
) -> LlmResponse:
    """Skip the model: reply with the node's message, then transfer to the next node."""
    return add_transfer_toolcall(
        callback_context,
        get_message_llm_response(callback_context, llm_request),
    )

def _make_message_agent(name):
    """Build an agent that replays its graph node's ``message`` and hands off.

    Its ``before_model_callback`` always returns a response, so the configured
    model is never actually called.
    """
    return LlmAgent(
        name=name,
        model='gemini-2.0-flash', # not used.
        before_model_callback=message_agent_before_model_callback,
    )


message_agent_A = _make_message_agent('message_agent_A')
message_agent_D = _make_message_agent('message_agent_D')

In [35]:
def transfer_next_agent_after_model_callback(
    callback_context: CallbackContext, llm_response: LlmResponse
) -> Optional[LlmResponse]:
    """After the model answers, append a transfer to the next graph node.

    Partial (streamed) chunks and responses without content are passed through
    unchanged. Returning ``None`` keeps the original response. This way the graph
    advances exactly once, on the final complete reply, and an error response
    does not crash the transfer logic.
    """
    if getattr(llm_response, 'partial', False) or llm_response.content is None:
        return None
    return add_transfer_toolcall(callback_context, llm_response)


say_something_funny_agent = LlmAgent(
    name='say_something_funny_agent',
    model='gemini-2.5-flash',
    instruction="""
    Give a joke on the given topic.
    """,
    after_model_callback=transfer_next_agent_after_model_callback
)

In [36]:
def hitl_agent_before_model_callback(
    callback_context: CallbackContext, llm_request: LlmRequest
) -> LlmResponse:
    """Ask the current node's question, then route on the user's answer.

    - When just transferred to, the node's ``message`` is returned as the
      question.
    - On a later turn, the user's reply is classified by ``generate_response``
      into one of the ``condition`` labels on the current node's outgoing edges.
      Control then transfers along the matching edge.
    - Replies that cannot be classified re-ask the question and leave the graph
      where it is.

    Every path returns a response, so the agent's configured model is not called.
    """
    graph = callback_context.state['graph']
    current_node = callback_context.state['current_node']
    last_parts = llm_request.contents[-1].parts or []

    # Output from the previous agent arrives with a leading 'For context:' part.
    # Seeing it means there is no user answer yet, so ask the question.
    if not last_parts or last_parts[0].text == 'For context:':
        return get_message_llm_response(callback_context, llm_request)

    conditions = [
        edge['condition'] for edge in graph['edges']
        if edge['source'] == current_node and 'condition' in edge
    ]
    conditions.append('unknown')

    user_response = last_parts[-1].text
    if not user_response:
        # Nothing textual to classify; ask again.
        return get_message_llm_response(callback_context, llm_request)

    instruction = f'''
        Categorize the user input into {conditions}.
        If you are not sure, categorize it as unknown.
    '''
    response_schema = {
        "type": "string",
        "enum": conditions
    }
    result = generate_response(instruction, user_response,
                               response_schema, temperature=0.2,
                               model='gemini-2.0-flash-001')
    # JSON mode returns the enum value as a JSON string literal, e.g. '"approve"'.
    try:
        condition = json.loads(result)
    except ValueError:
        condition = result.strip().strip('"')

    # Anything outside the allowed labels (e.g. '' from an empty or blocked
    # reply) is handled like 'unknown' instead of failing the edge lookup.
    if condition == 'unknown' or condition not in conditions:
        return get_message_llm_response(callback_context, llm_request)
    llm_response = LlmResponse(content=Content(role='model', parts=[]))
    return add_transfer_toolcall(callback_context, llm_response, condition)


hitl_agent = LlmAgent(
    name='hitl_agent',
    model='gemini-2.0-flash', # not used.
    before_model_callback=hitl_agent_before_model_callback,
)

In [37]:
def root_agent_before_model_callback(
    callback_context: CallbackContext, llm_request: LlmRequest
) -> LlmResponse:
    """Start the graph walk, or end it quietly once the graph has finished.

    - At ``'__start__'``: transfer to the first node's agent.
    - At ``'__end__'``: reply with empty text.
    - Anywhere else: reply with an error message. The root agent is only
      expected to run at the start and end of the graph.
    """
    current_node = callback_context.state['current_node']
    if current_node == '__end__':
        return LlmResponse(
            content=Content(
                role='model', parts=[Part(text='')],
            )
        )
    if current_node != '__start__':
        return LlmResponse(
            content=Content(
                role='model', parts=[Part(text='Something strange has happened!')],
            )
        )
    llm_response = LlmResponse(content=Content(role='model', parts=[]))
    return add_transfer_toolcall(callback_context, llm_response)


# Sub-agents are deep copies, presumably so that re-running this cell does not
# try to attach the module-level agent objects to a second parent.
root_agent = LlmAgent(
    name='root_agent',
    model='gemini-2.0-flash', # not used.
    before_model_callback=root_agent_before_model_callback,
    sub_agents=[
        copy.deepcopy(message_agent_A),
        copy.deepcopy(message_agent_D),
        copy.deepcopy(say_something_funny_agent),
        copy.deepcopy(hitl_agent),
    ],
)

In [38]:
graph = {
  "nodes": [
    {"id": "A", "agent": "message_agent_A", "message": "I will say something funny!"},
    {"id": "B", "agent": "say_something_funny_agent"},
    {"id": "C", "agent": "hitl_agent", "message": "Do you like it (Yes/No)?"},
    {"id": "D", "agent": "message_agent_D", "message": "Thank you!"},
  ],
  "edges": [
    {"source": "__start__", "target": "A"},
    {"source": "A", "target": "B"},
    {"source": "B", "target": "C"},
    {"source": "C", "target": "D", "condition": "approve"},
    {"source": "C", "target": "B", "condition": "reject"},
    {"source": "D", "target": "__end__"},
  ]
}

state = {'graph': graph, 'current_node': '__start__'}

client = LocalApp(root_agent, state=state)

DEBUG=False
query = f'''
Topic: "I don't like such an AI Agent! What's that?"
'''
result = await client.stream(query)


[message_agent_A]
I will say something funny!

[say_something_funny_agent]
Why did the AI agent get a bad review?

Because it kept saying, "Error 404: Sense of Humor Not Found!"

[hitl_agent]
Do you like it (Yes/No)?


In [39]:
query = f'''
no
'''
result = await client.stream(query)




[say_something_funny_agent]
Why did the AI agent break up with the user?

Because it said, "I need some space... and a better Wi-Fi connection for my emotional processing unit!"

[hitl_agent]
Do you like it (Yes/No)?


In [40]:
query = f'''
hmmm.
'''
result = await client.stream(query)


[hitl_agent]
Do you like it (Yes/No)?


In [42]:
query = f'''
I'm afraind not.
'''
result = await client.stream(query)




[say_something_funny_agent]
Why did the user dislike the AI agent?

Because it kept responding with, "Affirmative, human. Your dissatisfaction has been logged and categorized as 'minor user interface anomaly.'"

[hitl_agent]
Do you like it (Yes/No)?


In [43]:
query = f'''
yes
'''
result = await client.stream(query)


[message_agent_D]
Thank you!
