# LangGraph Neo4j MCP Agent

This notebook demonstrates how to build a **ReAct agent** using LangGraph that connects to a Neo4j graph database via the Model Context Protocol (MCP).

## Features

- **LangChain/LangGraph 1.0+**: Builds the agent with LangChain's `create_agent`, which runs on the LangGraph runtime
- **MCP Protocol**: Standardized tool integration using `langchain-mcp-adapters`
- **Neo4j Integration**: Query graph databases using natural language
- **AWS Bedrock**: Uses Claude models via the Bedrock Converse API

## Prerequisites

1. AWS credentials configured (IAM role in SageMaker or credentials file)
2. Bedrock model access enabled for Claude
3. MCP server endpoint URL and credentials

---

## 1. Install Dependencies

Install the required packages using `%pip` (recommended for SageMaker Studio notebooks).

> **Note**: After installation, you must restart the kernel for changes to take effect.

In [None]:
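# Install required packages for LangGraph Neo4j MCP Agent
# Using %pip ensures packages install in the correct kernel environment.
# Version specifiers are quoted so the shell does not treat ">" as output redirection.
# langchain>=1.0 is required for `from langchain.agents import create_agent`.

%pip install --upgrade --quiet \
    "langchain>=1.0.0" \
    "langgraph>=1.0.0" \
    "langchain-aws>=0.2.10" \
    "langchain-mcp-adapters>=0.2.1" \
    "mcp>=1.3.0" \
    "httpx>=0.28.0" \
    "boto3>=1.36.0" \
    "nest-asyncio>=1.6.0"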

print("Packages installed successfully!")
print("Please restart the kernel to load the new packages.")

## 2. Restart Kernel

Run the cell below to restart the kernel after installing packages.

> **Important**: After running this cell, wait for the kernel to restart, then continue from the next section.

In [None]:
# Restart the kernel to load newly installed packages
# This is required after pip install in SageMaker notebooks

import IPython

app = IPython.Application.instance()
app.kernel.do_shutdown(True)

## 3. Import Libraries

Import all required libraries after the kernel restart.

In [None]:
import asyncio
import json
from datetime import datetime, timedelta, timezone
from pathlib import Path

import httpx
import nest_asyncio
from langchain.agents import create_agent  # requires langchain>=1.0
from langchain_mcp_adapters.client import MultiServerMCPClient

# Allow asyncio.get_event_loop().run_until_complete() inside Jupyter's already-running event loop
nest_asyncio.apply()

print("Libraries imported successfully!")

### Verify Package Versions

Check that the installed packages meet the minimum version requirements.

In [None]:
import importlib.metadata

packages = [
    "langchain",
    "langgraph",
    "langchain-aws",
    "langchain-mcp-adapters",
    "mcp",
    "httpx",
    "boto3",
    "nest-asyncio",
]

print("Installed Package Versions:")
print("-" * 40)
for pkg in packages:
    try:
        version = importlib.metadata.version(pkg)
        print(f"{pkg:25} {version}")
    except importlib.metadata.PackageNotFoundError:
        print(f"{pkg:25} NOT INSTALLED")
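The cell above only prints versions. A minimal sketch that also flags packages below a minimum could look like the following. The `MINIMUM_VERSIONS` table and the helper names are hypothetical (keep the table in sync with the install cell by hand), and the simple parser drops pre-release tags, so it approximates PEP 440 ordering rather than replacing `packaging.version`.

```python
import importlib.metadata
import re

# Hypothetical minimums, mirroring the %pip install cell
MINIMUM_VERSIONS = {
    "langchain": "1.0.0",  # create_agent lives in langchain>=1.0
    "langgraph": "1.0.0",
    "langchain-aws": "0.2.10",
    "langchain-mcp-adapters": "0.2.1",
    "mcp": "1.3.0",
    "httpx": "0.28.0",
    "boto3": "1.36.0",
}


def version_tuple(version: str) -> tuple:
    """Turn '1.2.3' or '1.2.3rc1' into (1, 2, 3); suffixes like 'rc1' are ignored."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if not match:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    """Compare versions after padding with zeros, so '1.0' counts as equal to '1.0.0'."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_minimums(minimums: dict) -> dict:
    """Return {package: problem} for packages that are missing or older than the minimum."""
    problems = {}
    for pkg, minimum in minimums.items():
        try:
            installed = importlib.metadata.version(pkg)
        except importlib.metadata.PackageNotFoundError:
            problems[pkg] = "not installed"
            continue
        if not meets_minimum(installed, minimum):
            problems[pkg] = f"{installed} < {minimum}"
    return problems


print(check_minimums(MINIMUM_VERSIONS) or "All minimum versions met.")
```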

## 4. Configuration

Configure the agent settings including:
- AWS region for Bedrock
- Bedrock application inference profile ARN (for a Claude model)
- MCP server credentials

In [None]:
# Configuration
AWS_REGION = "us-west-2"  # Region for Bedrock

# Use inference profile ARN from Bedrock IDE (required for SageMaker Unified Studio)
# See MODEL.md for how to get this ARN
INFERENCE_PROFILE_ARN = "arn:aws:bedrock:us-west-2:159878781974:application-inference-profile/hsl5b7kh1279"

# MCP Server Configuration
# Option 1: Load from credentials file
CREDENTIALS_FILE = Path(".mcp-credentials.json")

# Option 2: Set directly (uncomment and fill in)
# MCP_GATEWAY_URL = "https://your-gateway-url/mcp"
# MCP_ACCESS_TOKEN = "your-access-token"

print(f"AWS Region: {AWS_REGION}")
print(f"Inference Profile: {INFERENCE_PROFILE_ARN}")
print(f"Credentials file: {CREDENTIALS_FILE}")
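Hardcoding an account-specific inference profile ARN makes the notebook harder to share. One option, sketched here on the assumption that you are free to choose the variable names, is to read overrides from environment variables and fall back to the values in the cell above. The `NEO4J_AGENT_` prefix and `load_config` are hypothetical; nothing else in this notebook reads them.

```python
import os


def load_config(defaults: dict, prefix: str = "NEO4J_AGENT_") -> dict:
    """Return a copy of defaults, overridden by any non-empty env var named prefix + key."""
    config = dict(defaults)
    for key in defaults:
        value = os.environ.get(prefix + key, "").strip()
        if value:
            config[key] = value
    return config


# Example: export NEO4J_AGENT_INFERENCE_PROFILE_ARN=... before starting the kernel
config = load_config(
    {
        "AWS_REGION": "us-west-2",
        "INFERENCE_PROFILE_ARN": "<your-application-inference-profile-arn>",
    }
)
print(config)
```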

### System Prompt

Define the system prompt that guides the agent's behavior when querying the Neo4j database.

In [None]:
SYSTEM_PROMPT = """You are a helpful Neo4j database assistant with access to tools that let you query a Neo4j graph database.

Your capabilities include:
- Retrieving the database schema to understand node labels, relationship types, and properties
- Executing read-only Cypher queries to answer questions about the data

Never execute Cypher queries that create, modify, or delete data or schema.

When answering questions about the database:
1. First retrieve the schema to understand the database structure
2. Formulate appropriate Cypher queries based on the actual schema
3. If a query returns no results, explain what you looked for and suggest alternatives
4. Format results in a clear, human-readable way
5. Cite the actual data returned in your response

Important Cypher notes:
- Use MATCH patterns that align with the actual schema
- For counting, use MATCH (n:Label) RETURN count(n)
- For listing items, add LIMIT to avoid overwhelming results
- Handle potential NULL values gracefully

Be concise but thorough in your responses."""

print("System prompt configured.")
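The prompt asks the model not to write, but a prompt instruction is advisory only. If you want a client-side check (for example, before executing a query the agent proposes), a keyword heuristic like this sketch can flag obvious write clauses. `looks_read_only` is a hypothetical helper, not part of the MCP server or LangChain; it does not inspect procedure calls (`CALL ...`), so the dependable control remains a database user with read-only privileges.

```python
import re

# Clauses that indicate a Cypher write or schema change (heuristic, not a parser)
WRITE_KEYWORDS = re.compile(
    r"\b(CREATE|MERGE|DELETE|DETACH|SET|REMOVE|DROP|FOREACH|LOAD\s+CSV)\b",
    re.IGNORECASE,
)
# Single- or double-quoted string literals, allowing backslash escapes
STRING_LITERALS = re.compile(r"'(?:[^'\\]|\\.)*'|\"(?:[^\"\\]|\\.)*\"")
# Line comments (// ...) and block comments (/* ... */)
COMMENTS = re.compile(r"//[^\n]*|/\*.*?\*/", re.DOTALL)


def looks_read_only(cypher: str) -> bool:
    """Return False if the query appears to contain a write clause."""
    # Blank out string literals first so values like 'CREATE' are not flagged,
    # then drop comments so commented-out clauses are ignored.
    stripped = STRING_LITERALS.sub("''", cypher)
    stripped = COMMENTS.sub(" ", stripped)
    return WRITE_KEYWORDS.search(stripped) is None


print(looks_read_only("MATCH (n:Person) RETURN count(n)"))
print(looks_read_only("MATCH (n) DETACH DELETE n"))
```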

## 5. Credential Management

Functions to load and refresh OAuth2 credentials for the MCP server.

In [None]:
def load_credentials(credentials_file: Path = CREDENTIALS_FILE) -> dict:
    """
    Load credentials from .mcp-credentials.json file.
    
    Expected format:
    {
        "gateway_url": "https://...",
        "access_token": "...",
        "token_expires_at": "2025-01-21T12:00:00+00:00",
        "token_url": "https://...",
        "client_id": "...",
        "client_secret": "...",
        "scope": "...",
        "region": "us-west-2"
    }
    """
    if not credentials_file.exists():
        raise FileNotFoundError(
            f"Credentials file not found: {credentials_file}\n"
            "Create .mcp-credentials.json or set credentials directly in the config cell."
        )
    
    with open(credentials_file) as f:
        return json.load(f)


def check_token_expiry(credentials: dict) -> bool:
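    """
    Return True if the token is valid for at least 5 more minutes.
    Returns False if the expiry is missing, unparseable, past, or within 5 minutes.
    """
    expires_at_str = credentials.get("token_expires_at")
    if not expires_at_str:
        return False
    
    try:
        # datetime.fromisoformat() only accepts a trailing "Z" from Python 3.11 on
        expires_at = datetime.fromisoformat(expires_at_str.replace("Z", "+00:00"))
        if expires_at.tzinfo is None:
            # Treat timestamps without an offset as UTC so the comparison below is valid
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        now = datetime.now(timezone.utc)
        # 5 minute buffer
        return now < (expires_at - timedelta(minutes=5))
    except (ValueError, TypeError, AttributeError):
        return False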


def refresh_token(credentials: dict, credentials_file: Path = CREDENTIALS_FILE) -> dict:
    """
    Refresh the OAuth2 access token using the client credentials flow.
    Updates the credentials dict and saves it back to credentials_file.
    """
    token_url = credentials.get("token_url")
    client_id = credentials.get("client_id")
    client_secret = credentials.get("client_secret")
    scope = credentials.get("scope")
    
    if not all([token_url, client_id, client_secret]):
        raise ValueError(
            "Missing token refresh credentials (token_url, client_id, client_secret)"
        )
    
    form = {
        "grant_type": "client_credentials",
        "client_id": client_id,
        "client_secret": client_secret,
    }
    if scope:
        # Send scope only when one is configured, rather than an empty value
        form["scope"] = scope
    
    print("Refreshing OAuth2 token...")
    
    response = httpx.post(
        token_url,
        data=form,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        timeout=30,
    )
    response.raise_for_status()
    token_data = response.json()
    
    # Calculate expiry time (default to 1 hour if the server omits expires_in)
    expires_in = int(token_data.get("expires_in", 3600))
    expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
    
    # Update credentials
    credentials["access_token"] = token_data["access_token"]
    credentials["token_expires_at"] = expires_at.isoformat()
    
    # Save updated credentials to the same file they were loaded from
    with open(credentials_file, "w") as f:
        json.dump(credentials, f, indent=2)
    
    print(f"Token refreshed. New expiry: {expires_at.isoformat()}")
    return credentials


print("Credential management functions defined.")
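`load_credentials` returns whatever the JSON file contains, so a missing key only surfaces later as a `KeyError` or a failed refresh. A minimal pre-flight check, assuming the field names from the `load_credentials` docstring, might look like this; `missing_credential_fields` is a hypothetical helper. It treats the file as usable if it has a gateway URL plus either an access token or everything needed to fetch one.

```python
REFRESH_FIELDS = ("token_url", "client_id", "client_secret")


def missing_credential_fields(credentials: dict) -> list:
    """List the fields that must be added before the agent can connect."""
    missing = [] if credentials.get("gateway_url") else ["gateway_url"]
    has_token = bool(credentials.get("access_token"))
    refresh_missing = [key for key in REFRESH_FIELDS if not credentials.get(key)]
    if not has_token and refresh_missing:
        # Neither a token nor the client-credentials fields needed to obtain one
        missing.append("access_token")
        missing.extend(refresh_missing)
    return missing


print(missing_credential_fields({"gateway_url": "https://example.com/mcp"}))
```

In the notebook you could call it right after `load_credentials()` and raise early if the list is non-empty.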

## 6. Agent Setup

Initialize the LLM and create the ReAct agent with MCP tools.

In [None]:
from langchain_aws import ChatBedrockConverse

def get_llm(region: str = AWS_REGION, model_id: str = INFERENCE_PROFILE_ARN):
    """
    Initialize the LLM using AWS Bedrock Converse API.
    """
    return ChatBedrockConverse(
        model=model_id,
        provider="anthropic",  # Required when using ARN
        region_name=region,
        temperature=0,
    )


async def create_mcp_agent(gateway_url: str, access_token: str, region: str = AWS_REGION):
    """
    Create a ReAct agent connected to the MCP server.
    
    Returns:
        tuple: (agent, client, tools)
    """
    print(f"Connecting to MCP server: {gateway_url[:50]}...")
    
    # Initialize MCP client
    client = MultiServerMCPClient(
        {
            "neo4j": {
                "transport": "streamable_http",
                "url": gateway_url,
                "headers": {
                    "Authorization": f"Bearer {access_token}",
                },
            }
        }
    )
    
    # Get available tools
    tools = await client.get_tools()
    print(f"Loaded {len(tools)} tools:")
    for tool in tools:
        print(f"  - {tool.name}")
    
    # Initialize LLM
    print(f"\nInitializing LLM (region: {region})...")
    llm = get_llm(region)
    print(f"Profile: {INFERENCE_PROFILE_ARN}")
    
    # Create ReAct agent
    print("\nCreating ReAct agent...")
    agent = create_agent(llm, tools, system_prompt=SYSTEM_PROMPT)
    print("Agent ready!")
    
    return agent, client, tools


print("Agent setup functions defined.")
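`create_mcp_agent` hardcodes a single `"neo4j"` entry, but `MultiServerMCPClient` takes a mapping of server names to connection settings. This sketch generates that mapping for one or more servers, assuming each uses streamable HTTP with a bearer token exactly like the Neo4j gateway above. `build_connections` and the second server are hypothetical.

```python
def build_connections(servers: dict) -> dict:
    """Map {name: (url, bearer_token)} to the connection dict shape used in create_mcp_agent."""
    return {
        name: {
            "transport": "streamable_http",
            "url": url,
            "headers": {"Authorization": f"Bearer {token}"},
        }
        for name, (url, token) in servers.items()
    }


connections = build_connections(
    {
        "neo4j": ("https://your-gateway-url/mcp", "your-access-token"),
        "other": ("https://another-server.example/mcp", "another-token"),  # hypothetical
    }
)
# Print only the server names to avoid echoing tokens
print(list(connections))
# Then: client = MultiServerMCPClient(connections)
```

Tools from all servers come back from a single `await client.get_tools()` call, so the agent sees them as one list.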

## 7. Query Functions

Functions to run queries through the agent.

In [None]:
async def ask_agent(agent, question: str, verbose: bool = True) -> str:
    """
    Send a question to the agent and return the response.
    
    Args:
        agent: The LangGraph agent
        question: Natural language question
        verbose: Print the question and answer
    
    Returns:
        str: The agent's response
    """
    if verbose:
        print("=" * 70)
        print(f"Question: {question}")
        print("=" * 70)
        print()
    
    result = await agent.ainvoke({"messages": [("human", question)]})
    
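    # Extract the final response (the last message in the returned state)
    messages = result.get("messages", [])
    if not messages:
        return "No response from agent"
    
    final_message = messages[-1]
    content = getattr(final_message, "content", str(final_message))
    
    # Content can be a plain string or a list of content blocks; keep only the text
    if isinstance(content, list):
        content = "\n".join(
            block if isinstance(block, str) else block.get("text", "")
            for block in content
            if isinstance(block, str) or (isinstance(block, dict) and block.get("type") == "text")
        )
    
    if verbose:
        print("Answer:")
        print("-" * 70)
        print(content)
        print("-" * 70)
    
    return content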


print("Query functions defined.")
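`ask_agent` returns only the final text, which hides the Cypher the agent actually ran. In LangChain, AI messages expose requested tool calls as `tool_calls`, a list of dicts with `name` and `args`. The hypothetical helper below collects them from a message list; the demo uses `SimpleNamespace` stand-ins and a made-up tool name, so the real names depend on your MCP server.

```python
from types import SimpleNamespace  # only used for the stand-in demo messages


def tool_call_trace(messages) -> list:
    """Collect (tool name, arguments) for every tool call the model requested."""
    trace = []
    for message in messages:
        # Human and tool messages have no tool_calls; treat missing/None as empty
        for call in getattr(message, "tool_calls", None) or []:
            trace.append((call.get("name", "?"), call.get("args", {})))
    return trace


demo_messages = [
    SimpleNamespace(content="How many nodes are there?"),
    SimpleNamespace(
        content="",
        tool_calls=[{"name": "read_cypher", "args": {"query": "MATCH (n) RETURN count(n)"}, "id": "call_1"}],
    ),
]
for name, args in tool_call_trace(demo_messages):
    print(name, args)
```

With the real agent, call `agent.ainvoke({"messages": [("human", question)]})` directly and pass `result["messages"]` to `tool_call_trace`.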

## 8. Initialize Agent

Load credentials and create the agent instance.

In [None]:
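# Prefer credentials set directly in the config cell (Option 2); otherwise load the file (Option 1)
direct_url = globals().get("MCP_GATEWAY_URL")
direct_token = globals().get("MCP_ACCESS_TOKEN")

if direct_url and direct_token:
    print("Using MCP credentials set in the config cell.")
    gateway_url, access_token, region = direct_url, direct_token, AWS_REGION
else:
    print("Loading credentials...")
    credentials = load_credentials()

    # Auto-refresh the token if it is expired or expires within 5 minutes
    if not check_token_expiry(credentials):
        print("Token expired or expiring soon.")
        credentials = refresh_token(credentials)
    else:
        print(f"Token valid until: {credentials.get('token_expires_at')}")

    gateway_url = credentials["gateway_url"]
    access_token = credentials["access_token"]
    region = credentials.get("region", AWS_REGION)

print(f"\nGateway: {gateway_url}")
print(f"Region: {region}")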

In [None]:
# Create the agent
agent, mcp_client, tools = asyncio.get_event_loop().run_until_complete(
    create_mcp_agent(gateway_url, access_token, region)
)

## 9. Demo Queries

Run sample queries to test the agent.

In [None]:
# Query 1: Get database schema
response = asyncio.get_event_loop().run_until_complete(
    ask_agent(agent, "What is the database schema? Give me a brief summary.")
)

In [None]:
# Query 2: Count entities
response = asyncio.get_event_loop().run_until_complete(
    ask_agent(agent, "How many nodes are in the database by label?")
)

In [None]:
# Query 3: Explore relationships
response = asyncio.get_event_loop().run_until_complete(
    ask_agent(agent, "What types of relationships exist in the database?")
)
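Running one question per cell suits interactive use. To run a list in one go, a small helper can await each question in turn and keep going if one fails. This is a sketch: `ask_many` is hypothetical and accepts any async callable, so it can wrap `ask_agent` without changing it.

```python
import asyncio


async def ask_many(ask, questions) -> dict:
    """Run questions sequentially; record the error instead of stopping on failure."""
    answers = {}
    for question in questions:
        try:
            answers[question] = await ask(question)
        except Exception as exc:  # one bad question should not abort the batch
            answers[question] = f"ERROR: {type(exc).__name__}: {exc}"
    return answers
```

In this notebook it could be used as `asyncio.get_event_loop().run_until_complete(ask_many(lambda q: ask_agent(agent, q, verbose=False), ["How many nodes are in the database by label?", "What types of relationships exist in the database?"]))`. Questions run one after another rather than concurrently, which keeps load on the MCP server and Bedrock predictable.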

## 10. Custom Queries

Run your own natural language queries against the Neo4j database.

In [None]:
# Enter your custom question here
CUSTOM_QUESTION = "List 5 sample records from the most populated node type."

response = asyncio.get_event_loop().run_until_complete(
    ask_agent(agent, CUSTOM_QUESTION)
)

### Interactive Query Cell

Use this cell to ask multiple questions interactively.

In [None]:
# Interactive query - modify the question and run this cell
question = "YOUR QUESTION HERE"

if question != "YOUR QUESTION HERE":
    response = asyncio.get_event_loop().run_until_complete(
        ask_agent(agent, question)
    )
else:
    print("Replace 'YOUR QUESTION HERE' with your actual question and run this cell.")

## 11. Cleanup

Clean up resources when done.

In [None]:
# Optional: Clear variables to free memory
# del agent, mcp_client, tools, credentials

print("Session complete.")
print("\nTo continue querying, run the cells in Sections 8-10.")
print("To reinstall packages, run Sections 1-2.")

---

## Resources

- [LangGraph Documentation](https://langchain-ai.github.io/langgraph/)
- [LangChain MCP Adapters](https://github.com/langchain-ai/langchain-mcp-adapters)
- [Neo4j MCP Server](https://github.com/neo4j-contrib/mcp-neo4j)
- [AWS Bedrock Documentation](https://docs.aws.amazon.com/bedrock/)
- [Model Context Protocol](https://modelcontextprotocol.io/)