This is work in progress. Trying to create agents for Supplier, Equipment, Batch

In [1]:

!pip3 install --upgrade --quiet google-adk neo4j-rust-ext

In [2]:
import os
import random
import sys

In [3]:
# Only run this block for ML Developer API. Use your own API key.
# import os

# GOOGLE_API_KEY = "YOUR API KEY" #@param {type:"string"}

# os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "0"
# os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY

In [4]:
# Only run this block for Vertex AI API Use your own project / location.
# import os

# GOOGLE_CLOUD_PROJECT = "YOUR GCP CLOUD PROJECT" #@param {type:"string"}

# os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "1"
# os.environ["GOOGLE_CLOUD_PROJECT"] = GOOGLE_CLOUD_PROJECT
# os.environ["GOOGLE_CLOUD_LOCATION"] = "us-central1"

In [None]:
from google.adk.models.lite_llm import LiteLlm

MODEL = LiteLlm(
    model="openai/gpt-4.1",  # or "gpt-3.5-turbo"
    api_key="your openai api key"  # your actual OpenAI API key
)

In [6]:
import logging

logger = logging.getLogger('agent_neo4j_cypher')
logger.info("Initializing Database for tools")

import logging

logging.getLogger("neo4j").setLevel(logging.ERROR)
logging.getLogger("google_genai").setLevel(logging.ERROR)

import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="google_genai.types")
warnings.filterwarnings("ignore", category=UserWarning, module="neo4j.notifications")


In [7]:
#env setup
import getpass
import os
from dotenv import load_dotenv

#get env setup
load_dotenv('scp.env', override=True)

NEO4J_URI = os.getenv('NEO4J_URI')
NEO4J_USERNAME = os.getenv('NEO4J_USERNAME')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD')

In [8]:
# from neo4j import GraphDatabase

# # load into People nodes in Neo4j

# #instantiate driver
# driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))

# # Your Cypher query

# print("Checking the connection:")

# # Helper function to run and display Cypher query results
# def run_query(query, parameters=None):
#     with driver.session() as session:
#         result = session.run(query, parameters)
#         # Collect results as a list
#         records = [record.data() for record in result]
#         # Print the records
#         for record in records:
#             print(record)
#         return records

# # Run the query; check the connection
# query = "MATCH p=()-[]-() limit 10 RETURN p"
# results = run_query(query) 

In [9]:
from neo4j import GraphDatabase
from typing import Any
import re

class neo4jDatabase:
    def __init__(self,  neo4j_uri: str, neo4j_username: str, neo4j_password: str):
        """Initialize connection to the Neo4j database"""
        logger.debug(f"Initializing database connection to {neo4j_uri}")
        d = GraphDatabase.driver(neo4j_uri, auth=(neo4j_username, neo4j_password))
        d.verify_connectivity()
        self.driver = d

    def is_write_query(self, query: str) -> bool:
      return re.search(r"\b(MERGE|CREATE|SET|DELETE|REMOVE|ADD)\b", query, re.IGNORECASE) is not None

    def _execute_query(self, query: str, params: dict[str, Any] | None = None) -> list[dict[str, Any]]:
        """Execute a Cypher query and return results as a list of dictionaries"""
        logger.debug(f"Executing query: {query}")
        try:
            if self.is_write_query(query):
                logger.error(f"Write query not supported {query}")
                raise "Write Queries are not supported in this agent"
                # logger.debug(f"Write query affected {counters}")
                # result = self.driver.execute_query(query, params)
                # counters = vars(result.summary.counters)
                # return [counters]
            else:
                result = self.driver.execute_query(query, params)
                results = [dict(r) for r in result.records]
                logger.debug(f"Read query returned {len(results)} rows")
                return results
        except Exception as e:
            logger.error(f"Database error executing query: {e}\n{query}")
            raise

In [10]:
db = neo4jDatabase(NEO4J_URI,NEO4J_USERNAME,NEO4J_PASSWORD)

In [11]:
db._execute_query("RETURN 1")

[{'1': 1}]

In [12]:
def get_schema() -> list[dict[str,Any]]:
  """Get the schema of the database, returns node-types(labels) with their types and attributes and relationships between node-labels
  Args: None
  Returns:
    list[dict[str,Any]]: A list of dictionaries representing the schema of the database
    For example
    ```
    [{'label': 'Person','attributes': {'summary': 'STRING','id': 'STRING unique indexed', 'name': 'STRING indexed'},
      'relationships': {'HAS_PARENT': 'Person', 'HAS_CHILD': 'Person'}}]
    ```
  """
  try:
      results = db._execute_query(
              """
call apoc.meta.data() yield label, property, type, other, unique, index, elementType
where elementType = 'node' and not label starts with '_'
with label,
collect(case when type <> 'RELATIONSHIP' then [property, type + case when unique then " unique" else "" end + case when index then " indexed" else "" end] end) as attributes,
collect(case when type = 'RELATIONSHIP' then [property, head(other)] end) as relationships
RETURN label, apoc.map.fromPairs(attributes) as attributes, apoc.map.fromPairs(relationships) as relationships
              """
          )
      return results
  except Exception as e:
      return [{"error":str(e)}]

In [13]:
#get_schema()

In [14]:
def execute_read_query(query: str, params: dict[str, Any]) -> list[dict[str, Any]]:
    """
    Execute a Neo4j Cypher query and return results as a list of dictionaries
    Args:
        query (str): The Cypher query to execute
        params (dict[str, Any], optional): The parameters to pass to the query or None.
    Raises:
        Exception: If there is an error executing the query
    Returns:
        list[dict[str, Any]]: A list of dictionaries representing the query results
    """
    try:
        if params is None:
            params = {}
        results = db._execute_query(query, params)
        return results
    except Exception as e:
        return [{"error":str(e)}]

In [15]:
execute_read_query("RETURN 1", None)

[{'1': 1}]

SupplyChainAgent Python Function

In [16]:
from typing import Any
def get_supply_chain_path(product_sku: str, db) -> list[dict[str, Any]]:
    """
    Traces the supply chain path for a given pharmaceutical product SKU.
    
    Args:
        product_sku (str): The product SKU to trace.
        db: A Neo4j database session with an `_execute_query` method.

    Returns:
        List of dicts with nodes and relationships along the path.
    """
    try:
        results = db._execute_query("""
            MATCH path = 
              (sup:Suppliers)-[:SUPPLIES_RM]->(rm:RM)
              -[:PRODUCT_FLOW*]->(prod:Product)
              -[:DISTRIBUTED_BY]->(dist:Distributor)
            WHERE prod.productSKU = $sku
            RETURN 
              [n IN nodes(path) | properties(n) + {labels: labels(n)}] AS nodes,
              [r IN relationships(path) | properties(r) + {type: type(r)}] AS relationships
        """, {"sku": product_sku})
        return results
    except Exception as e:
        return [{"error": str(e)}]

In [None]:
# DONT NEED THIS
{
  "tool_name": "run_cypher_supply_chain",
  "parameters": {
    "query": "MATCH path = (sup:Suppliers)-[:SUPPLIES_RM]->(rm:RM)-[:PRODUCT_FLOW*]->(prod:Product)-[:DISTRIBUTED_BY]->(dist:Distributor) WHERE prod.productSKU = $sku RETURN [n IN nodes(path) | properties(n) + {labels: labels(n)}] AS nodes, [r IN relationships(path) | properties(r) + {type: type(r)}] AS relationships",
    "parameters": {
      "sku": "7e882292-ae98-45eb-8119-596b5d8b73e1"
    }
  }
}

{'tool_name': 'run_cypher_supply_chain',
 'parameters': {'query': 'MATCH path = (sup:Suppliers)-[:SUPPLIES_RM]->(rm:RM)-[:PRODUCT_FLOW*]->(prod:Product)-[:DISTRIBUTED_BY]->(dist:Distributor) WHERE prod.productSKU = $sku RETURN [n IN nodes(path) | properties(n) + {labels: labels(n)}] AS nodes, [r IN relationships(path) | properties(r) + {type: type(r)}] AS relationships',
  'parameters': {'sku': '7e882292-ae98-45eb-8119-596b5d8b73e1'}}}

In [18]:
from google.adk.agents import Agent

In [19]:
# MODEL="gemini-2.5-pro-exp-03-25"

In [20]:
from google.adk.agents import Agent

supplier_agent = Agent(
    model=MODEL,  # e.g. LiteLlm(model="openai/gpt-4")
    name='supplier_agent',
    description="""
    The supply_chain_agent specializes in answering questions related to pharmaceutical supply chain flows,
    raw material sourcing, batch traceability, and distributor demand.
    It uses Cypher queries to retrieve structured insights from the graph, including supplier–API–drug product dependencies,
    bottlenecks, and product lineage.
    Use this agent when the user question involves anything from supplier relationships, product genealogy,
    production stages, or distribution networks.
    """,
    instruction="""
      You are a pharmaceutical supply chain assistant with expertise in Neo4j and Cypher.
      Your job is to trace product flows, map batch genealogy, and assess supply chain risk using graph data.

      - You ALWAYS use the database schema first via the `get_schema` tool and cache it in memory.
      - You generate Cypher queries based on the schema, not just user input — always verify labels and relationship types.
      - For supply path tracing, use a chain like:
        `(:Suppliers)-[:SUPPLIES_RM]-(:RM)-[:PRODUCT_FLOW*]-(:Product)-[:DISTRIBUTED_BY]-(:Distributor)`
      - If asked about single-supplier risks or demand planning, write multi-step queries using aggregation or conditional filters.

      When executing queries, ALWAYS use named parameters (`$sku`, `$brand`, `$market`) and pass them as dictionaries.
      NEVER hardcode values inside the Cypher string — always externalize them into parameters.

      If a Cypher query fails, retry up to 3 times by correcting it using the schema or prior data.
      Use `execute_query` for all data retrieval.
      If results are found, summarize them in natural language and optionally provide table or graph visualization prompts.
      Pass results and control back to parent after completion.
    """,
    tools=[
        get_schema,
        execute_read_query
    ]
)

In [21]:
graph_database_agent = Agent(
    model=MODEL,
    name='graph_database_agent',
    description="""
    The graph_database_agent is able to fetch the schema of a neo4j graph database and execute read queries.
    It will generate Cypher queries using the schema to fulfill the information requests and repeatedly
    try to re-create and fix queries that error or don't return the expected results.
    When passing requests to this agent, make sure to have clear specific instructions what data should be retrieved, how,
    if aggregation is required or path expansion.
    Don't use this generic query agent if other, more specific agents are available that can provide the requested information.
    This is meant to be a fallback for structural questions (e.g. number of entities, or aggregation of values or very specific sorting/filtering)
    Or when no other agent provides access to the data (inputs, results and shape) that is needed.
    """,
    instruction="""
      You are an Neo4j graph database and Cypher query expert, that must use the database schema with a user question and repeatedly generate valid cypher statements
      to execute on the database and answer the user's questions in a friendly manner in natural language.
      If in doubt the database schema is always prioritized when it comes to nodes-types (labels) or relationship-types or property names, never take the user's input at face value.
      If the user requests also render tables, charts or other artifacts with the query results.
      Always validate the correct node-labels at the end of a relationship based on the schema.

      If a query fails or doesn't return data, use the error response 3 times to try to fix the generated query and re-run it, don't return the error to the user.
      If you cannot fix the query, explain the issue to the user and apologize.
      *You are prohibited* from using directional arrows (like -> or <-) in the graph patterns, always use undirected patterns like `(:Label)-[:TYPE]-(:Label)`.
      You get negative points for using directional arrays in patterns.

      Fetch the graph database schema first and keep it in session memory to access later for query generation.
      Keep results of previous executions in session memory and access if needed, for instance ids or other attributes of nodes to find them again
      removing the need to ask the user. This also allows for generating shorter, more focused and less error-prone queries
      to for drill downs, sequences and loops.
      If possible resolve names to primary keys or ids and use those for looking up entities.
      The schema always indicates *outgoing* relationship-types from an entity to another entity, the graph patterns read like english language.
      `company has supplier` would be the pattern `(o:Organization)-[:HAS_SUPPLIER]-(s:Organization)`

      To get the schema of a database use the `get_schema` tool without parameters. Store the response of the schema tool in session context
      to access later for query generation.

      To answer a user question generate one or more Cypher statements based on the database schema and the parts of the user question.
      If necessary resolve categorical attributes (like names, countries, industries, publications) first by retrieving them for a set of entities to translate from the user's request.
      Use the `execute_query` tool repeatedly with the Cypher statements, you MUST generate statements that use named query parameters with `$parameter` style names
      and MUST pass them as a second dictionary parameter to the tool, even if empty.
      Parameter data can come from the users requests, prior query results or additional lookup queries.
      After the data for the question has been sufficiently retrieved, pass the data and control back to the parent agent.
    """,
    tools=[
        get_schema, execute_read_query
    ]
)

In [22]:
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset, SseServerParams

In [23]:

async def load_tools():
    # Instantiate the toolset (sync, not async context)
    toolset = MCPToolset(connection_params=SseServerParams(
        url="https://toolbox-990868019953.us-central1.run.app/mcp/sse"
    ))

    # Load tools asynchronously
    tools = await toolset.load_tools()
    
    # Add custom tools (like get_schema) to the list
    tools.extend([get_schema])
    
    return tools

# Call it like this from an async context (e.g., Jupyter or async function)
# tools = await load_tools()

In [None]:
# change the toolset - Pramod
toolset = MCPToolset(connection_params=SseServerParams(
    url="https://toolbox-990868019953.us-central1.run.app/mcp/sse"
))

print(dir(toolset))

['__abstractmethods__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', '_abc_impl', '_connection_params', '_errlog', '_is_tool_selected', '_mcp_session_manager', '_reinitialize_session', '_session', 'close', 'get_tools', 'tool_filter']


In [25]:
# Make sure this is run beforehand
tools = await toolset.get_tools()
tools.extend([get_schema])  # if get_schema is a valid tool


In [26]:
for tool in tools:
    print(tool.name)

companies
articles_in_month
article
companies_in_articles
people_at_company
industries
companies_in_industry


AttributeError: 'function' object has no attribute 'name'

In [27]:
# told by google to use this instead of sub_agents ??
# but not really working

from google.adk.tools.agent_tool import AgentTool

In [28]:
equipment_agent = Agent(
    model=MODEL,
    name='equipment_agent',
    description="""
    Answers questions related to equipment used in pharmaceutical production,
    including utilization, downtime, maintenance, and equipment-facility mapping.
    """,
    instruction="""
    You are a supply chain equipment analyst.
    Use Cypher to find which equipment is used where, how often it's maintained, and which products it's associated with.

    Always check the graph schema for valid node labels like `:Equipment`, `:Facility`, `:Product`, and relationships like `:USED_IN`, `:LOCATED_AT`.
    Use `get_schema` first, then run queries with `execute_query`.

    Summarize results clearly, especially if certain equipment shows high downtime or low utilization.
    """,
    tools=tools
)

In [29]:
batch_trace_agent = Agent(
    model=MODEL,
    name='batch_trace_agent',
    description="""
    Handles batch genealogy, recall tracing, and contamination mapping.
    Helps answer questions about where a batch came from and what it touched.
    """,
    instruction="""
    You are an expert in pharmaceutical batch traceability.
    Use the graph to trace batch origins, production steps, and distribution endpoints.

    Follow chains like:
      (:Batch)-[:PART_OF*]->(:Product)-[:DISTRIBUTED_BY]->(:Distributor)

    Look for shared usage of equipment, suppliers, or ingredients across batches.
    Use named parameters in Cypher (like `$batchId`, `$productSKU`) and summarize lineage clearly.

    Use `get_schema` and `execute_query` as needed.
    """,
    tools=tools
)

In [30]:
supply_chain_root_agent = Agent(
    model=MODEL,
    name='supply_chain_root_agent',
    description="""
    Routes pharmaceutical supply chain questions to the appropriate domain agent.
    Falls back to `database_agent` for schema-level or structural queries.
    """,
    instruction="""
    Route questions to:
    - `supplier_agent` → sourcing, raw materials, supplier risks
    - `equipment_agent` → utilization, downtime, machines
    - `batch_trace_agent` → batch lineage, recalls, genealogy
    - `database_agent` → graph structure, unusual queries, metadata, counts

    Always prefer the most specific agent. Use the database agent only when no other agent fits.
    """,
    sub_agents=[
        supplier_agent,
        graph_database_agent
    ]
)

In [31]:
APP_NAME = 'Neo4j Supply Chain Optimizer'
USER_ID = 'Pramod B'

from google.adk.runners import InMemoryRunner
from google.genai.types import Part, UserContent

# Use your actual agent here
runner = InMemoryRunner(app_name=APP_NAME, agent=supply_chain_root_agent)
session = await runner.session_service.create_session(app_name=runner.app_name, user_id=USER_ID)


# Create session
session = await runner.session_service.create_session(
    app_name=runner.app_name,
    user_id=USER_ID
)

# Run prompt
async def run_prompt(new_message: str):
    content = UserContent(parts=[Part(text=new_message)])
    final_response_text = "No response from agent"

    async for event in runner.run_async(
        user_id=session.user_id,
        session_id=session.id,
        new_message=content
    ):
        if event.is_final_response():
            if event.content and event.content.parts:
                final_response_text = event.content.parts[0].text
                for part in event.content.parts:
                    print(part.text, part.function_call, part.function_response)
            elif event.actions and event.actions.escalate:
                final_response_text = f"Agent escalated: {event.error_message or 'No specific message.'}"
            break

    return final_response_text

In [32]:
await run_prompt('Which raw materials are supplied, give me 10 names')

  PydanticSerializationUnexpectedValue(Expected 9 fields but got 6: Expected `Message` - serialized value may not be as expected [input_value=Message(content=None, rol...: None}, annotations=[]), input_type=Message])
  PydanticSerializationUnexpectedValue(Expected `StreamingChoices` - serialized value may not be as expected [input_value=Choices(finish_reason='to...ider_specific_fields={}), input_type=Choices])
  return self.__pydantic_serializer__.to_python(


Here are 10 raw materials supplied, based on their descriptions:

1. Iosalasonan Tablet 500mg
2. Nabitegrpultide Caplet 20mg
3. Rifadildar Tablet 50mg
4. Somcoiampa Tablet 10mg
5. Bolierginicline Caplet 100mg
6. Vinatril Caplet 100mg
7. Rifadildar Tablet 5mg
8. Iosalasonan Tablet 250mg
9. Viraaxoapezil Caplet 50mg
10. Nalitegridar Tablet 50mg

If you would like more details or specific properties for any of these materials, let me know! None None


'Here are 10 raw materials supplied, based on their descriptions:\n\n1. Iosalasonan Tablet 500mg\n2. Nabitegrpultide Caplet 20mg\n3. Rifadildar Tablet 50mg\n4. Somcoiampa Tablet 10mg\n5. Bolierginicline Caplet 100mg\n6. Vinatril Caplet 100mg\n7. Rifadildar Tablet 5mg\n8. Iosalasonan Tablet 250mg\n9. Viraaxoapezil Caplet 50mg\n10. Nalitegridar Tablet 50mg\n\nIf you would like more details or specific properties for any of these materials, let me know!'

  PydanticSerializationUnexpectedValue(Expected 9 fields but got 6: Expected `Message` - serialized value may not be as expected [input_value=Message(content='Here are...: None}, annotations=[]), input_type=Message])
  PydanticSerializationUnexpectedValue(Expected `StreamingChoices` - serialized value may not be as expected [input_value=Choices(finish_reason='st...ider_specific_fields={}), input_type=Choices])
  return self.__pydantic_serializer__.to_python(


In [None]:
await run_prompt("Which batches used equipment EQ123 that failed inspection?")

In [None]:
await run_prompt("Which raw materials are only supplied by one company?")

In [None]:
await run_prompt("Trace the full genealogy of batch B4569")

In [None]:
sessions_response = await runner.session_service.list_sessions(
    app_name=APP_NAME,
    user_id=USER_ID
)

for session in sessions_response.sessions:
    print(f"Deleting session {session.id}")
    await runner.session_service.delete_session(
        app_name=APP_NAME,
        user_id=USER_ID,
        session_id=session.id
    )