# Reasoning Diagnostic Agent

### Check Python version

Tested version: Python 3.12.7

In [None]:
!python --version

### Load API Key

In [None]:
from dotenv import load_dotenv
import os

# Load variables from a local .env file (if any) into the process environment.
load_dotenv()

# Never echo the secret itself: cell output is stored with the notebook and is
# easily committed or shared. Report only whether the key is available.
if os.getenv("GOOGLE_API_KEY"):
    print("GOOGLE_API_KEY is set")
else:
    print("GOOGLE_API_KEY is NOT set")

### Generate and install requirements file

In [None]:
%%writefile requirements.txt

chromadb==1.1.1
fastmcp==2.12.4
langchain==0.3.27
langchain-core==0.3.83
langchain-google-genai==2.1.12
langchain-mcp-adapters==0.1.10
langgraph==0.5.4
mcp==1.16.0
nest_asyncio
sentence-transformers==5.1.1

In [None]:
%pip install -r requirements.txt

### Define Knowledge Base documents

In [None]:
# Knowledge-base articles used as the retrieval corpus (DiagnosticKB embeds and
# indexes every entry of this list as one document).
# Each entry is a free-text article made of a "Title:" line followed by
# numbered "Steps:". The `{service-id}` / `{deployment-id}` tokens are literal
# text: these are plain strings, not f-strings, so nothing is substituted here.
sample_documents = [
"""
Title: Connection Rejected Diagnostic
Steps:
    1. Check connection response code.
    2. If error code, issue is indicated by the error code.
""",
"""
Title: High link latency diagnostic
Steps:
    1. Check link status.
    2. If disconnected, issue is caused by bad link state.
    3. If connected, check current system load.
    4. Also check system load 5 minutes ago.
    5. If both are above 90% then system is overloaded.
""",
"""
Title: Connection Failure Diagnostic
Steps:
    1. Check link status.
    2. If disconnected, issue is caused by bad link state.
""",
"""
Title: Origin service {service-id} with high latency on more than 90% of requests in the last hour
Steps:
    1. Get percentage of requests with high latency for {service-id} in the last hour
    2. If percentage of requests with high latency is less than 90% then report root cause as "alert error"
    3. Get average end-to-end latency for {service-id} requests in the last hour
    4. Get average server latency for {service-id} requests in the last hour
    5a. If server latency is less than 10% of end-to-end latency then report root cause as "high latency caused by external factors"
    5b. If storage latency is more than 50% of end-to-end latency then report root cause as "high latency caused by storage"
    6. Get {deployment-id} where {service-id} is running
    7. Get average CPU load of {deployment-id} in the last hour
    8. If CPU load is above 90% then go to "High CPU usage in {deployment-id} diagnostic" KB article
    9. Otherwise report root cause as "unable to determine cause of high latency"
""",
"""
Title: High CPU usage in {deployment-id} diagnostic
Steps:
    1. Get average number of requests per second for {deployment-id} in the last hour
    2. Get number of role instances for {deployment-id}
    3. If number of requests per second per role instance is greater than 100 then report root cause as "system overloaded with too many requests"
    4. Otherwise report root cause as "unable to determine cause of high CPU usage"
"""
]

### Implement RAG for Knowledge Base

In [None]:
import chromadb

from sentence_transformers import SentenceTransformer

class DiagnosticKB:
    """Retrieval index over the diagnostic knowledge-base articles.

    Every article is embedded with a sentence-transformer model and stored in
    a ChromaDB collection configured for cosine distance. `query_rag` returns
    the stored articles closest to a free-text query.

    Errors are reported with `print` and never propagated, so a failure while
    building the index leaves an instance whose queries return an empty list.
    """

    def __init__(self, documents=None):
        """Create the knowledge base and index its documents.

        Args:
            documents: Optional list of article strings to index. When omitted
                the module-level `sample_documents` list is used.
        """
        # RAG configuration
        self.sample_documents = sample_documents if documents is None else documents

        # Define the components up front so a failed initialization leaves the
        # instance in an explicit "not ready" state rather than without the
        # attributes at all.
        self.embedding_model = None
        self.chroma_client = None
        self.collection = None

        # Initialize RAG components
        self.initialize_rag()


    def initialize_rag(self):
        """Initialize RAG components including ChromaDB and sentence transformer"""
        try:
            # Initialize sentence transformer model
            self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

            # Initialize ChromaDB client
            self.chroma_client = chromadb.Client()

            # Create or get collection
            self.collection = self.chroma_client.get_or_create_collection(
                name="knowledge_base",
                metadata={"hnsw:space": "cosine"}
            )

            # Add documents to collection
            self.add_documents_to_collection()

        except Exception as e:
            # Best effort by design: report the problem and keep the notebook
            # running; query_rag() will then return no documents.
            print(f"Error initializing RAG: {str(e)}")


    def add_documents_to_collection(self):
        """Embed `self.sample_documents` and store them in the collection.

        Document `i` is stored under the id ``doc_<i>``. Records are written
        with `upsert`, so building the index again for a collection that
        already holds these ids replaces the stored text and embedding
        instead of leaving an older copy in place.
        """
        try:
            documents = list(self.sample_documents)
            if not documents:
                # Nothing to embed or store.
                return

            # Generate embeddings for documents
            embeddings = self.embedding_model.encode(documents)

            # Prepare per-document metadata and stable ids for ChromaDB
            metadatas = [
                {"source": "sample_document", "index": i}
                for i in range(len(documents))
            ]
            ids = [f"doc_{i}" for i in range(len(documents))]

            # Insert new records / refresh existing ones
            self.collection.upsert(
                documents=documents,
                embeddings=embeddings.tolist(),
                metadatas=metadatas,
                ids=ids
            )

        except Exception as e:
            print(f"Error adding documents to collection: {str(e)}")


    def query_rag(self, query, n_results=1):
        """Query the RAG system to retrieve relevant documents.

        Args:
            query: Free-text query to embed and search for.
            n_results: Number of documents requested from the collection.

        Returns:
            The list of matching document strings for the query, or an empty
            list when the index is not available or the lookup fails.
        """
        if self.embedding_model is None or self.collection is None:
            # initialize_rag() failed earlier; say so explicitly.
            print("Error querying RAG: knowledge base is not initialized")
            return []

        try:
            # Generate query embedding
            query_embedding = self.embedding_model.encode([query])

            # Query the collection
            results = self.collection.query(
                query_embeddings=query_embedding.tolist(),
                n_results=n_results
            )

            # One query was submitted, so its matches are the first (only)
            # entry of the per-query result lists.
            return results['documents'][0] if results['documents'] else []

        except Exception as e:
            print(f"Error querying RAG: {str(e)}")
            return []


### Define and execute MCP server (HTTP)

In [None]:
import random
import uuid

from dataclasses import dataclass
from fastmcp import FastMCP
from uuid import UUID

# MCP server instance; the tool and prompt functions below register
# themselves on it through its decorators.
mcp_http_server = FastMCP("MCP-HTTP-Server")

# Initialize diagnostic knowledge base
# (constructed once here; the `failure_analysis_kb` tool queries this instance)
diagnostic_kb = DiagnosticKB()

# -------------------------------
# Return types
# -------------------------------

# Result of the `query_average_request_latencies` tool: three integer latency
# figures (presumably milliseconds, per the `_ms` suffix).
@dataclass
class RequestLatencies:
    storage_latency_ms: int
    server_latency_ms: int
    end_to_end_latency_ms: int

# Result of the `query_service_info` tool: the service's type label plus the
# ids of its account and of the deployment it is associated with.
@dataclass
class ServiceInfo:
    service_type: str
    account_id: UUID
    deployment_id: UUID

# Result of the `query_deployment_info` tool: VM type name, number of role
# instances and the owning subscription id.
@dataclass
class DeploymentInfo:
    vm_type: str
    role_instances_count: int
    subscription_id: UUID

# Error details wrapped by `ErrorResponse`. The tools in this file fill it
# with type="invalid parameter" and a "<parameter> not found" message.
@dataclass
class ErrorInfo:
    type: str
    message: str

# Envelope returned by a tool instead of its normal result when the requested
# service or deployment id is not recognised.
@dataclass
class ErrorResponse:
    error: ErrorInfo

# -------------------------------
# Tools
# -------------------------------

@mcp_http_server.tool()
def failure_analysis_kb(query: str):
    """Knowledge base for procedures on how to diagnose issues."""
    # Look the query up in the shared knowledge base; with the default result
    # count this yields a list holding the closest article, or an empty list.
    matching_articles = diagnostic_kb.query_rag(query)
    return matching_articles

@mcp_http_server.tool()
def query_system_load(query: str) -> dict:
    """Provides current system load"""

    # Simulated reading: `query` is accepted but not inspected, and the load
    # is always reported as 99%.
    # The key was previously misspelled "sytem_load_percent"; it now matches
    # the naming of the other tools' result keys (e.g. "system_latency_ms").
    return { "system_load_percent": 99 }

@mcp_http_server.tool()
def query_system_latency(query: str) -> dict:
    """Provides current system latency"""
    # Simulated reading: the `query` argument is not inspected and the same
    # fixed value is reported on every call.
    latency_reading = {"system_latency_ms": 100}
    return latency_reading

@mcp_http_server.tool()
def query_link_status(query: str) -> dict:
    """Reports whether the link status is connected or disconnected"""
    # Simulated link: draw an integer in 1..100; draws below 50 count as
    # "connected", everything else as "disconnected". `query` is not used.
    roll = random.randint(1, 100)
    status = "connected" if roll < 50 else "disconnected"
    return {"link_status": status}

@mcp_http_server.tool()
def query_high_latency_request_percentage(service_id: str, time_window: int) -> dict | ErrorResponse:
    """Reports the percentage of high latency requests on {service-id} in the last {time_window} hours"""
    # Simulated backend: exactly one service id is known, and `time_window`
    # does not influence the answer.
    if service_id != "d3f1a8b2-7c4e-4f9e-9e2a-8b6c3a2d1f4e":
        unknown_service = ErrorInfo(type="invalid parameter", message="service_id not found")
        return ErrorResponse(unknown_service)

    return {"high_latency_requests_percent": 98}

@mcp_http_server.tool()
def query_average_request_latencies(service_id: str, time_window: int) -> RequestLatencies | ErrorResponse:
    """Reports average request latencies on {service-id} in the last {time_window} hours """
    # Simulated backend: exactly one service id is known, and `time_window`
    # does not influence the answer.
    if service_id != "d3f1a8b2-7c4e-4f9e-9e2a-8b6c3a2d1f4e":
        unknown_service = ErrorInfo(type="invalid parameter", message="service_id not found")
        return ErrorResponse(unknown_service)

    return RequestLatencies(storage_latency_ms=10, server_latency_ms=500, end_to_end_latency_ms=2000)

@mcp_http_server.tool()
def query_service_info(service_id: str) -> ServiceInfo | ErrorResponse:
    """Reports information on {service-id} """
    # Simulated backend: exactly one service id is known.
    if service_id != "d3f1a8b2-7c4e-4f9e-9e2a-8b6c3a2d1f4e":
        unknown_service = ErrorInfo(type="invalid parameter", message="service_id not found")
        return ErrorResponse(unknown_service)

    owning_account = uuid.UUID("a1d4c6f7-3e2b-4b9a-bc8f-9f6e2a1d7c3e")
    hosting_deployment = uuid.UUID("f3c9a7e2-8b4d-4f6a-9c2e-7d1b3a6e5c9f")
    return ServiceInfo(
        service_type="shared origin",
        account_id=owning_account,
        deployment_id=hosting_deployment,
    )

@mcp_http_server.tool()
def query_deployment_info(deployment_id: str) -> DeploymentInfo | ErrorResponse:
    """Reports information on {deployment_id}"""
    # Simulated backend: exactly one deployment id is known.
    if deployment_id != "f3c9a7e2-8b4d-4f6a-9c2e-7d1b3a6e5c9f":
        unknown_deployment = ErrorInfo(type="invalid parameter", message="deployment_id not found")
        return ErrorResponse(unknown_deployment)

    owning_subscription = uuid.UUID("c7e2b9f1-4d3a-4a8e-9f6c-2b1d7e3f9a4c")
    return DeploymentInfo(
        vm_type="Standard_D16_v5",
        role_instances_count=2,
        subscription_id=owning_subscription,
    )

@mcp_http_server.tool()
def query_average_cpu_load(deployment_id: str, time_window: int) -> dict | ErrorResponse:
    """Reports the average CPU load on {deployment-id} in the last {time_window} hours"""
    # Simulated backend: exactly one deployment id is known, and `time_window`
    # does not influence the answer.
    if deployment_id != "f3c9a7e2-8b4d-4f6a-9c2e-7d1b3a6e5c9f":
        unknown_deployment = ErrorInfo(type="invalid parameter", message="deployment_id not found")
        return ErrorResponse(unknown_deployment)

    return {"average_cpu_load_percent": 98}

@mcp_http_server.tool()
def query_average_requests_per_sec(deployment_id: str, time_window: int) -> dict | ErrorResponse:
    """Reports the average requests per second on {deployment-id} in the last {time_window} hours"""
    # Simulated backend: exactly one deployment id is known, and `time_window`
    # does not influence the answer.
    if deployment_id != "f3c9a7e2-8b4d-4f6a-9c2e-7d1b3a6e5c9f":
        unknown_deployment = ErrorInfo(type="invalid parameter", message="deployment_id not found")
        return ErrorResponse(unknown_deployment)

    return {"average_requests_per_second": 500}


# -------------------------------
# Prompt
# -------------------------------

@mcp_http_server.prompt()
def get_llm_prompt(query: str) -> str:
    """Generates a prompt for the LLM to use to answer the query"""

    # Placeholder: the server-side prompt is not provided yet.
    # NotImplementedError states that intent precisely and is still an
    # `Exception` subclass, so existing `except Exception` handlers keep
    # working; the message text is unchanged.
    raise NotImplementedError("Not implemented")


# Start the MCP server in the background when this code runs as the main
# module, so that the rest of the notebook can keep executing.
if __name__ == "__main__":
    import nest_asyncio
    # Applied before the server starts — presumably so the server can run its
    # own event loop alongside the one the notebook already has (TODO confirm).
    nest_asyncio.apply()

    from threading import Thread

    def start():
        # Serve over the streamable-HTTP transport at http://localhost:8000/mcp
        # (the client cell connects to this same URL).
        mcp_http_server.run(transport="streamable-http",
                            host="localhost",
                            port=8000,
                            path="/mcp",
                            log_level="debug")

    # Daemon thread: it does not keep the Python process alive on exit.
    Thread(target=start, daemon=True).start()

### Check whether MCP server is running

In [None]:
!lsof -i :8000

In [None]:
!netstat -a | grep 8000

In [None]:
!netstat -an | findstr 8000

### Stop the MCP server

In [None]:
!lsof -t -i:8000 | xargs kill -9

In [None]:
!tasklist | findstr python

In [None]:
!taskkill /PID <PID> /F

### Define and Run Diagnostic Agent / MCP Client

In [None]:
# diagnostic_agent.py
#
# Move tools from STDIO to HTTP server

from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client

from langchain_core.runnables import RunnableConfig
from langchain_mcp_adapters.tools import load_mcp_tools
from langgraph.prebuilt import create_react_agent
from langchain_google_genai import ChatGoogleGenerativeAI


def get_llm(name: str|None):
    """Build the Gemini chat model used by the agent.

    Args:
        name: Model name; when it is None or empty, "gemini-2.5-flash" is used.

    Returns:
        A `ChatGoogleGenerativeAI` instance for the selected model.
    """
    model_name = name if name else "gemini-2.5-flash"
    return ChatGoogleGenerativeAI(model=model_name)

def get_llm_prompt(query: str) -> str:
    """Wrap a user query in the agent's instruction prompt.

    The instructions tell the model to rely only on the provided tools, to
    consult the failure analysis KB first in order to make a plan, and to keep
    following that plan (issuing new KB queries when a step calls for one)
    until a root cause is determined.

    Args:
        query: The question to answer; inserted verbatim after "Query: ".

    Returns:
        The complete prompt text.
    """
    agent_instructions = f"""
    You are a helpful assistant. Answer the following query
    by only using the tools provided to you. DO NOT make up any information.
    Query failure analysis KB first to make a plan on how to derive the
    response. Then execute the plan until root cause is determined for response.
    Do not repeat tool calls with the same query.

    If the result of a step in the plan indicates a new KB query is needed
    for root causing the issue then you MUST perform the new failure analysis KB query
    and make a new plan to derive the root cause response. DO NOT finish before
    the root cause is determined, unless the conclusion is "unable to determine
    root cause".

    Query: {query}
    """
    return agent_instructions


async def query_agent(prompt: str, model: str|None=None,
                      mcp_server_url: str="http://localhost:8000/mcp") -> str:
    """Answer a diagnostic query with a ReAct agent backed by the MCP tools.

    Opens a streamable-HTTP MCP client session, loads the server's tools,
    builds a ReAct agent around the selected LLM and runs it on the
    instruction prompt generated for `prompt`. Progress is reported with
    `print`.

    Args:
        prompt: The user's question.
        model: LLM name passed to `get_llm`; None selects its default.
        mcp_server_url: URL of the MCP server reachable over HTTP. Defaults
            to the local server address used elsewhere in this file.

    Returns:
        The content of the agent's last message, or the string "Error" when
        anything fails (the exception is printed, not raised).
    """
    try:
        # Both context managers are entered together and released in reverse
        # order: the session first, then the HTTP transport.
        async with (
            streamablehttp_client(mcp_server_url) as (http_read, http_write, _),
            ClientSession(http_read, http_write) as http_session,
        ):
            print("initializing HTTP client session")
            await http_session.initialize()

            print("\nloading tools & prompt")
            # Tool names must be unique, the LLM will choose
            # a tool based on the query and tool descriptions.
            mcp_server_tools = await load_mcp_tools(http_session)

            print("\nTools loaded :")
            for tool in mcp_server_tools:
                print(f"▪️ {tool.name} - {tool.description}")

            llm = get_llm(model)
            llm_prompt = get_llm_prompt(prompt)

            config = RunnableConfig()

            agent=create_react_agent(model=llm, tools=mcp_server_tools, debug=False)

            print(f"\nAnswering query : {prompt}")
            agent_response = await agent.ainvoke(input={"messages": llm_prompt}, config=config)

            # The final message of the run carries the agent's answer.
            return agent_response["messages"][-1].content

    except Exception as e:
        # Top-level boundary for the notebook: report and return a sentinel
        # rather than raising.
        print(f"Error: {e}")
        if isinstance(e, ExceptionGroup):
            # Show the individual errors wrapped by the group.
            print(f"{e.exceptions}")
        return "Error"

    # Fallback: only reached if a context manager above suppresses an
    # exception, so that neither return statement inside `try` runs.
    return "Error"


# main
# Runs one diagnostic query end to end and prints the agent's answer.
# NOTE: this uses top-level `await`, so it needs an environment that provides
# a running event loop (such as a notebook cell); it is not valid in a plain
# Python script. The MCP server must already be listening for the query to
# succeed; otherwise query_agent() prints the error and returns "Error".
print("\nRunning Query Agent...")
response = await query_agent(
        prompt="What is the cause of alert 'Origin service d3f1a8b2-7c4e-4f9e-9e2a-8b6c3a2d1f4e with high latency on more than 90% of requests in the last hour'",
        model=None)

print("\nResponse: ", response)