# GLP-1 Treatment AI Agent Demo
This is a Jupyter notebook version of the GLABITAI GLP-1 Treatment AI Agent.

In [None]:
# LangGraph Setup
from datetime import datetime
from typing import TypedDict, List, Dict, Any, Optional, Literal
from pydantic import BaseModel, Field
# Remove: from pydantic_ai import Agent, Tool
from langchain.agents import Tool, AgentExecutor, create_react_agent
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq  # or your preferred LLM
import json
from pymongo import MongoClient
import redis
import chromadb
from langgraph.graph import StateGraph, END
from langgraph.graph.graph import Graph
import asyncio
from dotenv import load_dotenv
import os
from bson.errors import InvalidId
from bson.objectid import ObjectId
from typing import Any, List, TypedDict, Annotated, Optional
from langchain_groq import ChatGroq
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration # To maintain the original return type
from langgraph.graph import StateGraph, END
import operator # For Annotated with operator.add

# Populate os.environ from a local .env file (when one exists) before reading settings.
load_dotenv()

# Connection settings: each one is read from the environment, falling back to a
# hard-coded default when the variable is unset.
# NOTE(review): the Mongo default embeds a plaintext password ("dev_password");
# consider requiring MONGODB_URI outside of local development.
MONGODB_URI = os.environ.get("MONGODB_URI", "mongodb://glabitai:dev_password@mongodb:27017/")
MONGODB_DB = os.environ.get("MONGODB_DB", "glabitai_glp1_clinical")
REDIS_URL = os.environ.get("REDIS_URL", "redis://redis:6379/0")
# NOTE(review): "no-key-found" is a placeholder, not a usable key, and this constant is
# not referenced in the visible code; ChatGroq presumably reads GROQ_API_KEY from the
# environment itself — confirm.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "no-key-found")

In [5]:
# Data Models
class _PatientDataOptional(TypedDict, total=False):
    """Keys of a patient record that may be absent from the dict."""

    _id: Optional[str]
    last_visit: Optional[str]
    next_visit: Optional[str]
    side_effects: Optional[List[str]]


class PatientData(_PatientDataOptional):
    """Shape of a patient record handled by this notebook.

    TypedDict cannot provide default values (a ``= None`` in the class body is
    rejected by type checkers and never fills in missing keys at runtime), so the
    keys that may be missing are declared in a ``total=False`` base class. The keys
    below are required.
    """

    patient_id: str
    name: str
    age: int
    weight_kg: float
    treatment_start_date: str
    current_dosage_mg: float
    notes: str

# Structured LLM analysis result, declared with the functional TypedDict syntax.
# All keys are required (total defaults to True).
AnalysisResult = TypedDict(
    "AnalysisResult",
    {
        "assessment": str,
        "concerns": List[str],
        "recommendations": List[str],
        # NOTE(review): the scale of this score is not established in this file.
        "confidence_score": float,
        "generated_at": str,
    },
)

class _GraphStateOptional(TypedDict, total=False):
    """Keys of GraphState that may be absent from the dict."""

    processing_time: Optional[float]


class GraphState(_GraphStateOptional):
    """Alternative pipeline state schema (not referenced by the visible workflow).

    ``processing_time`` lives in a ``total=False`` base because TypedDict cannot
    carry ``= None`` defaults; every key declared here is required.
    """

    patient_id: str
    patient_data: Optional[PatientData]
    analysis: Optional[AnalysisResult]
    current_step: str
    error: Optional[str]
    next_steps: List[str]
    timestamp: str

In [None]:

class AgentState(TypedDict):
    """State passed through GLPAgent's single-node LangGraph app."""

    # Chat history. The operator.add reducer appends each node's returned list
    # to the existing messages instead of replacing them.
    messages: Annotated[List[BaseMessage], operator.add]
    # Set by the LLM node; absent/None until that node has run.
    final_chat_generation: Optional[ChatGeneration]


class GLPAgent:
    """LLM agent for GLP-1 treatment analysis, run through a one-node LangGraph app.

    The graph sends the system prompt plus the caller's prompt to the Groq chat
    model and records the reply as a ``ChatGeneration``.
    """

    def __init__(self, model_name="mixtral-8x7b-32768"):
        """Initialize the GLP-1 Treatment Analysis Agent with Groq and LangGraph.

        Args:
            model_name: Groq model identifier passed to ``ChatGroq``.
                NOTE(review): Groq retires models over time — confirm this
                default is still served.
        """
        self.llm = ChatGroq(
            model_name=model_name,
            temperature=0.7,
            max_tokens=1000
        )

        # System prompt to guide the AI's responses
        self.system_prompt = """You are a helpful AI assistant specialized in GLP-1 treatment analysis.
        You analyze patient data and provide insights about their GLP-1 treatment progress.
        Be concise, professional, and focus on actionable insights."""

        # Compiled graph that run() invokes.
        self.app = self._build_graph()

    def _build_graph(self):
        """Compile the graph: entry -> LLM node -> END."""
        graph = StateGraph(AgentState)
        graph.add_node("llm", self._run_llm_and_get_generation)
        graph.set_entry_point("llm")
        graph.add_edge("llm", END)
        return graph.compile()

    async def _run_llm_and_get_generation(self, state: AgentState) -> dict:
        """Call the chat model on the accumulated messages.

        Returns a partial state update: the reply is appended to ``messages``
        (through the reducer) and also wrapped in a ``ChatGeneration``.
        """
        ai_message = await self.llm.ainvoke(state["messages"])
        return {
            "messages": [ai_message],
            "final_chat_generation": ChatGeneration(message=ai_message),
        }

    async def run(self, prompt: str) -> Any:  # ChatGeneration on success, str on failure
        """Run the agent on *prompt* using the LangGraph application.

        Args:
            prompt: User prompt, sent after the agent's system prompt.

        Returns:
            The ``ChatGeneration`` produced by the LLM node. This method does not
            raise: on failure it returns an error description string (starting
            with ``"Error"``) instead.
        """
        try:
            initial_graph_state = {
                "messages": [
                    SystemMessage(content=self.system_prompt),
                    HumanMessage(content=prompt),
                ]
            }
            final_state = await self.app.ainvoke(initial_graph_state)
        except Exception as e:
            return f"Error generating response: {str(e)}"

        generation = final_state.get("final_chat_generation") if final_state else None
        if generation is not None:
            return generation

        # Fallback: the LLM node did not record a ChatGeneration. Surface the last
        # AI reply (if any) inside the error string so it is not lost.
        messages = (final_state or {}).get("messages") or []
        last_message = messages[-1] if messages else None
        if isinstance(last_message, AIMessage):
            return (
                "Error: final_chat_generation not found in state. "
                f"Last AI message: {last_message.content}"
            )
        return "Error: Could not retrieve ChatGeneration from agent's final state."


# Define the state schema
class _PatientStateOptional(TypedDict, total=False):
    """Keys of PatientState that may be absent from the state dict."""

    # Data
    patient: Optional[Dict[str, Any]]   # patient document loaded by the fetch step
    analysis: Optional[Dict[str, Any]]  # analysis produced by the analyze step

    # Processing
    error: Optional[str]                # message set by a failing step
    processing_time: Optional[float]    # not written by any node in the visible code


class PatientState(_PatientStateOptional):
    """State dict threaded through the patient-analysis LangGraph workflow.

    TypedDict cannot supply default values, so keys that may be missing are
    declared in a ``total=False`` base class instead of with ``= None``. The keys
    declared here are required.
    """

    # Input
    patient_id: str

    # Processing
    current_step: str
    timestamp: str

# Tooling Setup
class MongoTool:
    """Thin wrapper around a MongoDB database handle used for patient lookups."""

    def __init__(self, mongo_uri: str, db_name: str):
        from pymongo import MongoClient
        self.client = MongoClient(mongo_uri)
        self.db = self.client[db_name]

    def query(self, patient_id: str) -> Optional[Dict]:
        """Fetch the patient whose ``patient_id`` field equals *patient_id*.

        Returns:
            The matching document with ``_id`` converted to ``str``, or ``None``
            when no document matches.

        Raises:
            Exception: "MongoDB query failed: ..." wrapping the driver error, which
                is kept as ``__cause__`` so the original traceback is preserved.
        """
        try:
            result = self.db.patients.find_one({"patient_id": patient_id})
        except Exception as e:
            raise Exception(f"MongoDB query failed: {str(e)}") from e
        # ObjectId is not JSON-serializable; stringify it for downstream use.
        if result and "_id" in result:
            result["_id"] = str(result["_id"])
        return result

    def close(self) -> None:
        """Close the underlying MongoClient."""
        self.client.close()

class RedisTool:
    """Redis cache for per-patient analysis results, stored as JSON strings."""

    def __init__(self, url: str):
        import redis
        try:
            # decode_responses=True makes get() return str, ready for json.loads.
            self.client = redis.Redis.from_url(url, decode_responses=True)
            self.client.ping()  # Test connection
        except Exception as e:
            raise Exception(f"Redis connection failed: {str(e)}") from e

    def cache_analysis(self, key: str, value: Dict[str, Any], ttl: int = 86400) -> bool:
        """Store *value* as JSON under ``analysis:<key>`` with a TTL in seconds (default 24 hours).

        Returns:
            The result of the Redis ``SETEX`` call.

        Raises:
            Exception: "Redis cache update failed: ..." chained to the underlying error.
        """
        try:
            return self.client.setex(f"analysis:{key}", ttl, json.dumps(value))
        except Exception as e:
            raise Exception(f"Redis cache update failed: {str(e)}") from e

    def get_cached_analysis(self, key: str) -> Optional[Dict]:
        """Return the cached analysis for *key*, or ``None`` if absent or expired.

        Raises:
            Exception: "Redis cache retrieval failed: ..." chained to the underlying
                error (including invalid JSON).
        """
        try:
            result = self.client.get(f"analysis:{key}")
            return json.loads(result) if result else None
        except Exception as e:
            raise Exception(f"Redis cache retrieval failed: {str(e)}") from e

class ChromaTool:
    """Wrapper around a ChromaDB collection that stores patient documents as JSON."""

    def __init__(self, collection_name: str = "patients"):
        import chromadb
        try:
            # NOTE(review): chromadb.Client() with no settings is an in-memory client,
            # so stored documents do not survive a kernel restart — confirm intended.
            self.client = chromadb.Client()
            self.collection = self.client.get_or_create_collection(name=collection_name)
        except Exception as e:
            raise Exception(f"ChromaDB initialization failed: {str(e)}") from e

    def upsert(self, document: Dict, metadata: Optional[Dict] = None) -> None:
        """Insert or replace *document* in the collection.

        The Chroma id is the document's ``patient_id`` (as a string). Documents
        without one get a SHA-256 hash of their key-sorted JSON. Unlike the
        built-in ``hash()``, which is salted per interpreter process, this id is
        stable across runs. Values that JSON cannot encode natively (e.g.
        ``datetime`` from MongoDB) are stringified.

        Args:
            document: Patient record to store.
            metadata: Optional metadata dict; omitted when empty/None.

        Raises:
            Exception: "ChromaDB upsert failed: ..." chained to the underlying error.
        """
        import hashlib

        try:
            doc_id = document.get("patient_id")
            # Compute the fallback lazily: only documents lacking an id need hashing.
            if doc_id is None:
                canonical = json.dumps(document, sort_keys=True, default=str)
                doc_id = hashlib.sha256(canonical.encode("utf-8")).hexdigest()
            self.collection.upsert(
                ids=[str(doc_id)],
                documents=[json.dumps(document, default=str)],
                metadatas=[metadata] if metadata else None,
            )
        except Exception as e:
            raise Exception(f"ChromaDB upsert failed: {str(e)}") from e

# Create the shared tool instances used by the pipeline steps below. A failure is
# reported and then re-raised so execution stops at this point.
try:
    mongo_tool = MongoTool(MONGODB_URI, MONGODB_DB)
    redis_tool = RedisTool(REDIS_URL)
    chroma_tool = ChromaTool()
except Exception as init_error:
    print(f"Tool initialization error: {str(init_error)}")
    raise



# Define state transitions
async def fetch_patient(state: PatientState) -> PatientState:
    """Pipeline step 1: load the patient document by its MongoDB ``_id``.

    ``state["patient_id"]`` is interpreted as the string form of an ObjectId.

    Returns:
        A copy of ``state`` with ``patient`` set (``_id`` converted to ``str``), or
        with ``error`` set when the id is invalid, no document matches, or the
        query fails. ``state["current_step"]`` is also updated in place.
    """
    state["current_step"] = "fetching_patient"
    try:
        requested_id = state["patient_id"]
        document = mongo_tool.db.patients.find_one({"_id": ObjectId(requested_id)})
        if not document:
            raise ValueError(f"Patient with _id {requested_id} not found")
        # ObjectId is not JSON-serializable and later steps json.dumps the record.
        if "_id" in document:
            document["_id"] = str(document["_id"])
    except Exception as exc:
        failure = f"Error fetching patient: {str(exc)}"
        print(failure)  # Debug log
        return {**state, "error": failure}
    return {**state, "patient": document}

        
async def analyze_with_agent(state: PatientState) -> PatientState:
    """Pipeline step 2: ask the GLP-1 LLM agent to analyze the fetched patient.

    Args:
        state: Pipeline state; reads ``patient`` and any ``error`` from earlier steps.

    Returns:
        ``state`` unchanged if an earlier step already recorded an error (so the
        root cause is preserved). Otherwise a copy with ``analysis`` set to
        ``{"content", "generated_at"}``, or with ``error`` set when there is no
        patient data or the agent fails. ``current_step`` is updated in place
        once analysis starts.
    """
    # Don't replace an upstream failure (e.g. patient not found) with a generic message.
    if state.get("error"):
        return state
    if not state.get("patient"):
        return {**state, "error": "No patient data available for analysis"}

    state["current_step"] = "analyzing"
    try:
        agent = GLPAgent()
        prompt = f"""
        Analyze this patient's GLP-1 treatment progress:

        Patient Data:
        {json.dumps(state["patient"], indent=2, default=str)}

        Provide structured analysis with:
        1. Treatment progress assessment
        2. Any concerns or considerations
        3. Recommended next steps
        """

        response = await agent.run(prompt)
        # GLPAgent.run reports failures as a plain string instead of raising.
        if isinstance(response, str):
            return {**state, "error": f"Analysis failed: {response}"}

        # run() returns a ChatGeneration, whose text lives on .message; also accept
        # objects exposing .content directly (e.g. an AIMessage).
        message = getattr(response, "message", response)
        analysis = {
            "content": message.content,
            "generated_at": datetime.utcnow().isoformat()
        }
        return {**state, "analysis": analysis}

    except Exception as e:
        return {**state, "error": f"Analysis failed: {str(e)}"}


        
async def cache_and_store(state: PatientState) -> PatientState:
    """Pipeline step 3: cache the analysis in Redis and upsert the patient into ChromaDB.

    Returns:
        ``state`` unchanged if an earlier step already recorded an error;
        ``state`` itself after successful storage; otherwise a copy with
        ``error`` set when data is missing or a storage call fails.
        ``current_step`` is updated in place once storage starts.
    """
    # Preserve an upstream error instead of replacing it with "Incomplete data".
    if state.get("error"):
        return state
    if not state.get("patient") or not state.get("analysis"):
        return {**state, "error": "Incomplete data for storage"}

    state["current_step"] = "storing_results"
    try:
        patient = state["patient"]
        # A document fetched by _id may lack a patient_id field; fall back to the
        # id this run was started with so the cache key is still defined.
        patient_id = patient.get("patient_id") or state["patient_id"]

        # Cache analysis in Redis (stored under "analysis:<patient_id>")
        redis_tool.cache_analysis(patient_id, state["analysis"])

        # Store the patient document in ChromaDB with analysis timestamps as metadata
        chroma_tool.upsert(
            document=patient,
            metadata={
                "last_analyzed": datetime.utcnow().isoformat(),
                "analysis_timestamp": state["analysis"]["generated_at"]
            }
        )

        return state

    except Exception as e:
        return {**state, "error": f"Storage failed: {str(e)}"}

PydanticUserError: If you use `@root_validator` with pre=False (the default) you MUST specify `skip_on_failure=True`. Note that `@root_validator` is deprecated and should be replaced with `@model_validator`.

For further information visit https://errors.pydantic.dev/2.11/u/root-validator-pre-skip

In [None]:
# Create and configure the graph
def _route_after_step(state: PatientState) -> str:
    """Return "stop" when a step recorded an error, otherwise "continue"."""
    return "stop" if state.get("error") else "continue"


def create_patient_analysis_workflow() -> Graph:
    """Create and compile the patient analysis workflow.

    Flow: fetch_patient -> analyze -> store_results -> END. After fetching and
    after analysis, the graph goes straight to END if that step set ``error``.
    This keeps the root-cause message from being overwritten by later steps, and
    it skips the LLM call and the cache/store writes for a failed patient.

    Returns:
        The compiled LangGraph app (supports ``invoke``/``ainvoke``).
    """
    workflow = StateGraph(PatientState)

    # Add nodes
    workflow.add_node("fetch_patient", fetch_patient)
    workflow.add_node("analyze", analyze_with_agent)
    workflow.add_node("store_results", cache_and_store)

    # Define the flow, short-circuiting to END on error
    workflow.set_entry_point("fetch_patient")
    workflow.add_conditional_edges(
        "fetch_patient", _route_after_step, {"continue": "analyze", "stop": END}
    )
    workflow.add_conditional_edges(
        "analyze", _route_after_step, {"continue": "store_results", "stop": END}
    )
    workflow.add_edge("store_results", END)

    return workflow.compile()

In [None]:
# Create Agent Workflow
from IPython.display import Image, display

workflow = create_patient_analysis_workflow()
# Render the compiled `workflow` graph. NOTE(review): draw_mermaid_png() may rely on
# an external rendering service by default — fall back to the Mermaid source text.
try:
    display(Image(workflow.get_graph().draw_mermaid_png()))
except Exception as render_error:
    print(f"Could not render graph image ({render_error}); Mermaid source:")
    print(workflow.get_graph().draw_mermaid())

In [None]:
# Example usage
def create_initial_state(patient_id: str) -> PatientState:
    """Build the seed state for one workflow run.

    ``patient`` and ``analysis`` start as ``None``; ``timestamp`` is the current
    UTC time in ISO-8601 form (naive, no offset).
    """
    started_at = datetime.utcnow().isoformat()
    return dict(
        patient_id=patient_id,
        patient=None,
        analysis=None,
        current_step="initialized",
        timestamp=started_at,
    )


async def analyze_patient(patient_id: str, workflow) -> Dict[str, Any]:
    """Run the compiled analysis workflow for a single patient.

    Args:
        patient_id: Placed into the workflow's initial state.
        workflow: A compiled LangGraph app exposing ``ainvoke``.

    Returns:
        The workflow's final state (a step-level error in it is also printed). If
        the invocation itself raises, a small dict with ``patient_id``, ``error``
        and ``timestamp`` is returned instead.
    """
    seed_state = create_initial_state(patient_id)

    try:
        final_state = await workflow.ainvoke(seed_state)
        step_error = final_state.get("error")
        if step_error:
            print(f"Analysis completed with errors: {step_error}")
        return final_state
    except Exception as exc:
        failed_at = datetime.utcnow().isoformat()
        return {
            "patient_id": patient_id,
            "error": f"Workflow execution failed: {str(exc)}",
            "timestamp": failed_at,
        }

# Example usage in notebook
result = await analyze_patient("683613805f2537d439303bbf", workflow)
if "error" in result:
    print(f"Error: {result['error']}")
else:
    print(f"Analysis completed: {json.dumps(result['analysis'], indent=2)}")