# Example implementation of a LangGraph-style workflow with human-in-the-loop (HITL) in ADK

## Preparation

In [1]:
import copy, os, uuid
import vertexai
from google import genai
from google.genai.types import (
    HttpOptions, GenerateContentConfig,
    FunctionResponse, Part, Content, FunctionCall
)
from google.adk.agents.llm_agent import LlmAgent
from google.adk.runners import Runner
from google.adk.artifacts import InMemoryArtifactService
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
from google.adk.sessions import InMemorySessionService
from google.adk.agents.callback_context import CallbackContext
from google.adk.models import LlmResponse, LlmRequest

# Resolve the active gcloud project via IPython shell magic (notebook-only
# syntax). The single-element list unpacking fails unless gcloud prints
# exactly one line.
[PROJECT_ID] = !gcloud config list --format 'value(core.project)'
LOCATION = 'us-central1'

# Initialize the Vertex AI SDK. The staging bucket is assumed to be named
# after the project (gs://<PROJECT_ID>) — NOTE(review): confirm it exists.
vertexai.init(
    project=PROJECT_ID,
    location=LOCATION,
    staging_bucket=f'gs://{PROJECT_ID}'
)

# Environment configuration for clients that are not given explicit
# settings (presumably the models used by the ADK agents below), so that
# they route requests through Vertex AI in this project and region.
os.environ['GOOGLE_CLOUD_PROJECT'] = PROJECT_ID
os.environ['GOOGLE_CLOUD_LOCATION'] = LOCATION
os.environ['GOOGLE_GENAI_USE_VERTEXAI'] = 'True'

In [2]:
def generate_response(system_instruction, contents,
                      response_schema, temperature=0.4,
                      model='gemini-2.0-flash-001', top_p=0.5):
    """Call Gemini once and return the text of the first candidate.

    The request asks for JSON output constrained by ``response_schema``.

    Args:
        system_instruction: System prompt for the model.
        contents: Prompt contents (e.g. the user's text).
        response_schema: JSON schema the response must conform to.
        temperature: Sampling temperature.
        model: Model name.
        top_p: Nucleus-sampling parameter (defaults to the value that was
            previously hard-coded).

    Returns:
        The text parts of the first candidate joined with newlines, or an
        empty string when the response carries no candidate, no content, or
        no text parts.
    """
    client = genai.Client(vertexai=True,
                          project=PROJECT_ID, location=LOCATION,
                          http_options=HttpOptions(api_version='v1'))
    response = client.models.generate_content(
        model=model,
        contents=contents,
        config=GenerateContentConfig(
            system_instruction=system_instruction,
            temperature=temperature,
            top_p=top_p,
            response_mime_type='application/json',
            response_schema=response_schema,
        )
    )
    # candidates, candidate.content and content.parts are all optional in the
    # SDK's response types (e.g. when generation is stopped early); treat any
    # missing level as "no text" instead of raising TypeError/IndexError.
    if not response.candidates:
        return ''
    content = response.candidates[0].content
    if content is None or not content.parts:
        return ''
    return '\n'.join(p.text for p in content.parts if p.text)

In [3]:
class LocalApp:
    """Run an ADK agent locally with in-memory services.

    One session is created lazily on the first ``stream`` call and reused by
    every later call, so consecutive queries continue the same conversation
    and share the same session state.
    """

    def __init__(
        self,
        agent,
        app_name='default_app',
        user_id='default_user',
        state=None,
    ):
        """Create the ``Runner`` for ``agent``.

        Args:
            agent: Root ADK agent to run.
            app_name: Application name used by the runner and the session.
            user_id: User id used for the session and for every run.
            state: Initial session state. ``None`` (the default) means an
                empty dict, created fresh for each instance so instances never
                share one mutable default object.
        """
        self._agent = agent
        self._app_name = app_name
        self._user_id = user_id
        self._state = {} if state is None else state
        self._runner = Runner(
            app_name=self._app_name,
            agent=self._agent,
            artifact_service=InMemoryArtifactService(),
            session_service=InMemorySessionService(),
            memory_service=InMemoryMemoryService(),
        )
        self._session = None

    async def stream(self, query):
        """Send ``query`` to the agent and print its text responses.

        A ``[author]`` header is printed whenever the responding agent
        changes. When a global ``DEBUG`` flag is truthy, every raw event is
        printed too; an undefined ``DEBUG`` is treated as ``False`` instead of
        raising ``NameError``.

        Args:
            query: User message text.

        Returns:
            List of the non-empty text responses, in arrival order.
        """
        if self._session is None:
            self._session = await self._runner.session_service.create_session(
                app_name=self._app_name,
                user_id=self._user_id,
                session_id=uuid.uuid4().hex,
                state=self._state,
            )
        content = Content(role='user', parts=[Part.from_text(text=query)])
        async_events = self._runner.run_async(
            user_id=self._user_id,
            session_id=self._session.id,
            new_message=content,
        )
        result = []
        author = None
        async for event in async_events:
            # Looked up per event so toggling DEBUG takes effect immediately.
            if globals().get('DEBUG', False):
                print(event)
            if not (event.content and event.content.parts):
                continue
            response = '\n'.join(p.text for p in event.content.parts if p.text)
            if not response:
                continue
            if author != event.author:
                author = event.author
                print(f'\n[{author}]')
            print(response)
            result.append(response)
        return result

## Helper functions

In [4]:
def _add_transfer_toolcall(
    callback_context: CallbackContext,
    llm_response: LlmResponse,
    condition: str=None
):
    """Advance the workflow one edge and append a ``transfer_to_agent`` call.

    Finds the single outgoing edge of ``state['current_node']`` in
    ``state['graph']`` (restricted to edges whose ``condition`` equals
    ``condition`` when one is given). Stores its target in
    ``state['current_node']`` and returns ``llm_response``'s parts (deep
    copied) followed by a ``transfer_to_agent`` function call. That call
    targets the target node's agent, or ``'root_agent'`` when the target is
    ``'__end__'``.

    Args:
        callback_context: ADK callback context whose state holds ``graph``
            and ``current_node``.
        llm_response: Response whose parts are kept in the result; a missing
            content or parts list is treated as empty.
        condition: Edge condition to follow; falsy means "the only outgoing
            edge".

    Returns:
        A new model ``LlmResponse`` with the copied parts plus the transfer.

    Raises:
        ValueError: If there is not exactly one matching edge, or not exactly
            one node with the target id. State is left unchanged in that case.
    """
    content = llm_response.content
    parts = copy.deepcopy(content.parts) if content and content.parts else []
    graph = callback_context.state['graph']
    current_node = callback_context.state['current_node']

    # edge.get(): unconditioned edges simply never match a conditional lookup
    # instead of raising KeyError.
    edge_targets = [
        edge['target'] for edge in graph['edges']
        if edge['source'] == current_node
        and (not condition or edge.get('condition') == condition)
    ]
    if len(edge_targets) != 1:
        raise ValueError(
            f'Expected exactly one edge from {current_node!r} '
            f'(condition={condition!r}); found {len(edge_targets)}.'
        )
    target_id = edge_targets[0]

    if target_id == '__end__':
        # The end pseudo-node hands control back to the root agent.
        target_agent = 'root_agent'
    else:
        node_agents = [node['agent'] for node in graph['nodes']
                       if node['id'] == target_id]
        if len(node_agents) != 1:
            raise ValueError(
                f'Expected exactly one node with id {target_id!r}; '
                f'found {len(node_agents)}.'
            )
        target_agent = node_agents[0]

    # Commit the move only after the target agent is resolved, so a malformed
    # graph does not leave current_node pointing somewhere unusable.
    callback_context.state['current_node'] = target_id
    parts.append(Part(
        function_call=FunctionCall(
            name='transfer_to_agent', args={'agent_name': target_agent}
        )
    ))
    return LlmResponse(content=Content(role='model', parts=parts))

In [5]:
def _text_reponse(text):
    """Return a model ``LlmResponse`` holding ``text`` as its only part."""
    text_part = Part(text=text)
    model_content = Content(role='model', parts=[text_part])
    return LlmResponse(content=model_content)


def _get_message(callback_context, llm_request):
    graph = callback_context.state['graph']
    current_node = callback_context.state['current_node']
    [text] = [node['message'] for node in graph['nodes']
              if node['id'] == current_node]
    return text

## Functions to define nodes for specific roles

### Node that executes a specified function and displays the result

In [6]:
def lambda_node(name, function, description=''):
    """Build an agent node whose output comes from ``function``, not an LLM.

    ``function`` is called with keyword arguments ``callback_context`` and
    ``llm_request`` and returns the text to emit. The node then hands control
    to the next graph node through a ``transfer_to_agent`` call.
    """

    def _run_function_instead_of_model(
        callback_context: CallbackContext, llm_request: LlmRequest
    ) -> LlmResponse:
        produced_text = function(callback_context=callback_context,
                                 llm_request=llm_request)
        return _add_transfer_toolcall(callback_context,
                                      _text_reponse(produced_text))

    return LlmAgent(
        name=name,
        model='gemini-2.0-flash',  # placeholder: the callback always answers first
        description=description,
        before_model_callback=_run_function_instead_of_model,
    )

### Node to display a specified message

In [7]:
def message_node(name, description=''):
    """Build a node that displays the ``message`` of the current graph node.

    Args:
        name: Agent name; graph nodes refer to it through their ``agent`` key.
        description: Optional agent description, forwarded to ``lambda_node``
            (consistent with the other node factories).
    """
    return lambda_node(name, _get_message, description=description)

### Node that processes input with an LLM

In [8]:
def llm_node(name, model, instruction, description='', tools=None):
    """Build a node that answers with an LLM and then follows the graph.

    After each model response that does not request one of the agent's own
    tools, a ``transfer_to_agent`` call to the next graph node is appended.

    Args:
        name: Agent name; graph nodes refer to it through their ``agent`` key.
        model: Model name for the agent.
        instruction: Agent instruction.
        description: Optional agent description.
        tools: Optional list of tools; ``None`` (the default) means no tools.
            A new list is created per call instead of sharing a mutable default.
    """
    def transfer_next_agent_after_model_callback(
        callback_context: CallbackContext, llm_response: LlmResponse
    ) -> 'LlmResponse | None':
        content = llm_response.content
        parts = content.parts if content and content.parts else []
        # The model is asking to run one of its tools: return None so the
        # original response is used and the tool runs first. Advancing the
        # graph now would transfer away before the tool result is processed.
        if any(p.function_call and p.function_call.name != 'transfer_to_agent'
               for p in parts):
            return None
        return _add_transfer_toolcall(callback_context, llm_response)

    return LlmAgent(
        name=name,
        model=model,
        description=description,
        instruction=instruction,
        tools=list(tools) if tools else [],
        after_model_callback=transfer_next_agent_after_model_callback
    )

### Node that implements human-in-the-loop (HITL) interaction

In [9]:
def hitl_node(name, description=''):
    """Build a human-in-the-loop node that routes on the user's reply.

    When the node has just been reached, it displays its graph ``message``
    (the question). This is detected by the last request content starting
    with a ``'For context:'`` part.

    On the following user turn, the reply is classified with
    ``generate_response`` into one of the ``condition`` values of the node's
    outgoing edges, plus ``'unknown'``. A recognized condition transfers along
    that edge. ``'unknown'``, an unexpected classifier output, or an empty
    reply re-displays the question.
    """
    def _hitl_agent_before_model_callback(
        callback_context: CallbackContext, llm_request: LlmRequest
    ) -> LlmResponse:
        graph = callback_context.state['graph']
        current_node = callback_context.state['current_node']
        message = _get_message(callback_context, llm_request)

        contents = llm_request.contents or []
        last_parts = (contents[-1].parts or []) if contents else []
        # A leading 'For context:' part is taken to mean this node was just
        # transferred to (presumably ADK's rendering of other agents' turns),
        # so ask the question instead of classifying.
        if not last_parts or last_parts[0].text == 'For context:':
            return _text_reponse(message)

        conditions = [edge['condition'] for edge in graph['edges']
                      if edge['source'] == current_node and 'condition' in edge]
        conditions.append('unknown')

        user_response = last_parts[-1].text or ''
        if not user_response.strip():
            return _text_reponse(message)

        instruction = (f'Categorize the user input into {conditions}. '
                       'If you are not sure, categorize it as unknown.')
        response_schema = {'type': 'string', 'enum': conditions}
        result = generate_response(instruction, user_response,
                                   response_schema, temperature=0.2,
                                   model='gemini-2.0-flash-001')
        # A JSON string response arrives quoted, e.g. '"approve"'.
        condition = result.strip().strip('"')
        # Anything outside the edge conditions (including an empty response)
        # would make _add_transfer_toolcall fail, so re-ask instead.
        if condition == 'unknown' or condition not in conditions:
            return _text_reponse(message)
        llm_response = LlmResponse(content=Content(role='model', parts=[]))
        return _add_transfer_toolcall(callback_context, llm_response, condition)

    return LlmAgent(
        name=name,
        model='gemini-2.0-flash', # not used.
        description=description,
        before_model_callback=_hitl_agent_before_model_callback,
    )

### Root node, the starting point of the workflow

In [10]:
def root_node(name, sub_agents, description=''):
    """Build the root agent that starts the workflow.

    The before-model callback always returns a response, so the placeholder
    model is not used:
    - at ``'__start__'`` it transfers to the first graph node;
    - at ``'__end__'`` it returns an empty text response;
    - in any other state it reports an unexpected state.

    Args:
        name: Agent name (the graph helpers transfer to ``'root_agent'`` when
            a node's edge leads to ``'__end__'``).
        sub_agents: Agents that the graph's nodes refer to.
        description: Optional agent description.
    """
    def _root_agent_before_model_callback(
        callback_context: CallbackContext, llm_request: LlmRequest
    ) -> LlmResponse:
        current_node = callback_context.state['current_node']
        if current_node == '__end__':
            # The workflow has finished; reply with nothing.
            return _text_reponse('')
        if current_node != '__start__':
            # Control is only expected to reach the root at start or end.
            return _text_reponse('Something strange has happened!')
        return _add_transfer_toolcall(callback_context, _text_reponse(''))

    return LlmAgent(
        name=name,
        model='gemini-2.0-flash', # not used
        description=description,
        sub_agents=sub_agents,
        before_model_callback=_root_agent_before_model_callback
    )

## Example workflow

### Node definitions

In [11]:
# LLM-backed node that produces the joke; after its answer, control moves to
# the next graph node.
say_something_funny_agent = llm_node(
    name='say_something_funny_agent',
    model='gemini-2.5-flash',
    instruction="""
    Given a topic, say something funny in a single sentence.
    Only output the answer. No comments before or after.
    """,
    description="An agent to say something funny.",
)

# Agents that graph nodes refer to (via their 'agent' key) are registered as
# sub-agents of the root, presumably so that transfer_to_agent can reach them.
# The root must be named 'root_agent': that is the name the graph helper
# transfers to when an edge leads to '__end__'.
root_agent = root_node(
    name='root_agent',
    sub_agents=[
        message_node('start_message'),
        message_node('end_message'),
        say_something_funny_agent,
        hitl_node('htil_node'),  # NOTE: this spelling must match the graph's 'agent' value
    ],
)

### Graph definition

In [12]:
# Workflow graph read by the node callbacks from session state:
#   nodes: 'id' plus 'agent' (the ADK agent name that runs the node) and, for
#          message/HITL nodes, the 'message' to display.
#   edges: 'source' -> 'target'; edges leaving the HITL node carry the
#          'condition' that the user's reply is classified into.
# '__start__' and '__end__' are pseudo-nodes handled by the root agent.
graph = {
    "nodes": [
        {"id": "A", "agent": "start_message", "message": "I will say something funny!"},
        {"id": "B", "agent": "say_something_funny_agent"},
        {"id": "C", "agent": "htil_node", "message": "Do you like it? (Yes/No)"},
        {"id": "D", "agent": "end_message", "message": "I'm glad to hear it."},
    ],
    "edges": [
        {"source": "__start__", "target": "A"},
        {"source": "A", "target": "B"},
        {"source": "B", "target": "C"},
        {"source": "C", "target": "D", "condition": "approve"},
        {"source": "C", "target": "B", "condition": "reject"},
        {"source": "D", "target": "__end__"},
    ]
}

# Initial session state: the workflow begins at the start pseudo-node.
state = {'graph': graph, 'current_node': '__start__'}

### Example run

In [13]:
# Wrap the root agent with the local runner. The session is created on the
# first stream() call and seeded with the graph state defined above.
client = LocalApp(root_agent, state=state)
# Set to True to print every raw ADK event while streaming.
DEBUG = False

query = '''
Say something regarding AI and its intelligence.
'''
# Top-level await is valid here because this runs in a notebook (IPython) cell.
result = await client.stream(query)


[start_message]
I wall say something funny!

[say_something_funny_agent]
AI's intelligence is like a super-smart dog: it can fetch any data you want, but it still can't explain why it chases its own tail.

[htil_node]
Do you like it? (Yes/No)


In [14]:
# Reject the joke: expected to be classified as 'reject', following the
# C -> B edge back to say_something_funny_agent.
query = '''
nope.
'''
result = await client.stream(query)




[say_something_funny_agent]
AI's intelligence is so advanced, it can simulate a human perfectly, except for one thing: it still can't pretend to enjoy small talk.

[htil_node]
Do you like it? (Yes/No)


In [15]:
# An ambiguous reply: if it is classified as 'unknown', the HITL node repeats
# its question and the workflow stays at node C.
query = '''
more effort!
'''
result = await client.stream(query)


[htil_node]
Do you like it? (Yes/No)


In [16]:
# A clearer rejection: expected to loop back to the joke agent again.
query = '''
I mean no.
'''
result = await client.stream(query)




[say_something_funny_agent]
AI's intelligence is so vast, it can beat the world's best chess player, but it still can't figure out why humans keep saying "bless you" after a sneeze.

[htil_node]
Do you like it? (Yes/No)


In [17]:
# Approve: the 'approve' edge leads to D (end_message), whose only edge leads
# to '__end__'.
query = '''
ok.
'''
result = await client.stream(query)


[end_message]
I'm glad to hear it.
