# Workshop: Building an AI Medical Diagnosis Agent

Welcome! In this workshop, we'll build a smart AI Agent that can:

1.  **Chat with a user** to understand their symptoms.
2.  **Decide** when it has gathered enough information to move to research.
3.  **Use external tools** (like Perplexity's API) to narrow down medical conditions.
4.  **Analyze the findings** and the conversation.
5.  **Generate a structured report.**

Think of it like a simplified version of a preliminary medical consultation. But **IMPORTANT DISCLAIMER:** this is for educational purposes ONLY and is NOT a substitute for real medical professionals (yet😉)

**Why is this interesting?**

* **LLMs as Orchestrators:** We'll see how Large Language Models (LLMs) like Google's Gemini can go beyond generating text to *control a workflow*, make decisions, and use tools.
* **LangChain & LangGraph:** We'll use these libraries, which are designed to make complex AI applications easier to build. LangGraph helps create reliable, step-by-step AI processes through state management.
* **Real-world Pattern:** Patterns like this one (chat -> gather info -> use tools -> synthesize) are common in many professions and can be used to automate real jobs.

**Prerequisites:** Gemini and Perplexity API keys. We'll explain the AI concepts as we go!

**Let's visualize the basic flow at this link:**
https://gdsc-x-big-think-ai-workshop.vercel.app/

## Before we start, you need 2 API keys:
1. Get a free Gemini API key at https://ai.google.dev/gemini-api/docs/api-key
2. Get a free Perplexity API key (using your school email) at https://www.perplexity.ai/referrals/join

(You will need to enter credit card information to access the Perplexity API once you create an account. Your student email account gets $5 of free Perplexity API credits per month, and the Gemini API is free, so this should not cost anything.)

3. Make a copy of this Colab notebook and open your copy. Select the Secrets option in the left sidebar and add the GEMINI_API_KEY and PERPLEXITY_API_KEY fields, pasting each API key into its value field. This saves the API keys to Colab and makes them accessible to the notebook we'll use for our workshop.

## Step 1: Import libraries we'll be using in the workshop

In [None]:
# ---- Core Python & Utilities ---- #
import os
import sys
import json
import time # used for rate-limiting
import requests # for making web requests to Perplexity API
from functools import wraps # Helper used for building decorators easily
from typing import List, Dict, Any # For type hinting
from typing_extensions import Annotated, TypedDict # Advanced type hints for our state
from google.colab import userdata # a secure way to access keys in colab we've saved as secrets
from rich.console import Console
from rich.markdown import Markdown

# ---- Langchain, LangGraph & Google API ---- #
from operator import itemgetter
import google.api_core.exceptions # for handling google API errors
from google.generativeai import configure, list_models # for configuring the Gemini client and listing available models
!pip install langchain-google-genai # installs the LangChain integration for Google's models
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.tools import tool
from langchain_core.prompts import PromptTemplate
!pip install langgraph # installs the LangGraph library
from langgraph.graph import StateGraph, START, END # Core components for building the graph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
!pip install -U langchain langsmith httpx


## Step 2: Set up the APIs to use LLMs for the agent

In [None]:
GEMINI_API_KEY = userdata.get('GEMINI_API_KEY')
PERPLEXITY_API_KEY = userdata.get('PERPLEXITY_API_KEY')
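
If the notebook is run outside Colab, or a secret was never added, the `userdata.get` calls above may fail. A minimal sketch of a fallback loader follows; the helper name `load_api_key` and the environment-variable fallback are assumptions for illustration, not part of the workshop code.

```python
import os


def load_api_key(name: str) -> str:
    """Return the secret `name` from Colab if possible, else from the environment.

    Hypothetical helper: the workshop itself calls userdata.get() directly.
    """
    try:
        from google.colab import userdata  # only importable inside Colab
        value = userdata.get(name)
    except Exception:
        # Not running in Colab, or the secret is missing / access was not granted
        value = os.environ.get(name)
    if not value:
        raise RuntimeError(
            f"{name} is not set. Add it as a Colab secret or an environment variable."
        )
    return value
```

It could then be used as `GEMINI_API_KEY = load_api_key("GEMINI_API_KEY")`, which fails early with a readable message instead of failing later inside an API call.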

## Extra step: Set up a function for API rate-limiting

The decorator below pauses for a short, fixed time before each API call. Without such controls, repeated or recursive API calls can quickly get out of hand, leading to infinite loops or too many requests within a few seconds. The result? High billing costs, denial of service by the API provider, or even temporary bans.

We will also expose the Perplexity API as a tool for the Gemini LLM to invoke, so there is a risk of runaway loops if Gemini decides to overdo the research.

In [None]:
def api_rate_limit(seconds: int = 2): # default pause is 2 seconds
    """This nested function creates and returns a Decorator to add sleep time between API calls"""
    def decorator(func):
        @wraps(func) # Preserves the wrapped function's metadata (like name, docstring) on the wrapper
        def wrapper(*args, **kwargs):
            """This wrapper executes the following code before the target function executes"""
            time.sleep(seconds)  # Pause for some seconds before making the API call
            return func(*args, **kwargs) # Now calls the original target function
        return wrapper
    return decorator
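
To see the decorator's effect, the sketch below applies it to a stand-in function and times three calls. `fake_api_call` is a hypothetical placeholder rather than a real API, and the 0.2-second delay is chosen only to keep the demo fast. Note that this is a fixed delay before every call, not a full rate limiter that tracks requests per time window.

```python
import time
from functools import wraps


def api_rate_limit(seconds: float = 2):
    """Return a decorator that sleeps for `seconds` before each call."""
    def decorator(func):
        @wraps(func)  # keep the original function's name and docstring
        def wrapper(*args, **kwargs):
            time.sleep(seconds)  # fixed pause before the wrapped call
            return func(*args, **kwargs)
        return wrapper
    return decorator


@api_rate_limit(0.2)
def fake_api_call(query: str) -> str:
    """Stand-in for a real API request (hypothetical, for illustration only)."""
    return f"result for {query}"


start = time.perf_counter()
results = [fake_api_call(q) for q in ("headache", "fever", "cough")]
elapsed = time.perf_counter() - start

print(fake_api_call.__name__)  # fake_api_call (metadata preserved by @wraps)
print(f"3 calls took {elapsed:.2f}s")
```

Because every call waits first, three calls take at least three times the delay, which caps how fast a looping agent can spend API credits.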

## Step 3: Defining the perplexity_research function as a tool for LLM to invoke

Here we turn the Perplexity API into a tool that the LLM (LearnLM 1.5 Pro in this workshop) can use to deepen its analysis. The research tool is one of the most important pieces: it lets a general-purpose LLM search the web in real time and augment its knowledge on a specific topic.

In [None]:
@tool # LangChain decorator. Now the function for Perplexity API is available as a tool for the LLM
@api_rate_limit(1) # Apply our custom 1-second rate limit before calling this function each time
def perplexity_research(query: str) -> str:
    """Research medical conditions using Perplexity API. Provide citations and links to reliable, authentic research sources."""
    headers = { # Standard HTTP headers for the API request
        "accept": "application/json",
        "Content-Type": "application/json",
        "Authorization": f"Bearer {PERPLEXITY_API_KEY}"
    }
    payload = { # The actual data sent to the Perplexity API
        "model": "sonar-pro",
        "messages": [
            {"role": "system",
            "content": "You are a medical research assistant. Provide precise and well-sourced responses, along with citations, and links for resources"},
            {"role": "user", "content": query}
        ],
        "temperature": 0.3,  # Lower randomness for factual consistency
        "max_tokens": 2048,  # Allow more detailed responses
        "top_p": 0.8,  # Nucleus sampling for high-confidence outputs
        "frequency_penalty": 0.0,  # Repetition penalty setting; check Perplexity's API docs for the valid range
    }

    try:
        print("RESPONSE: Sending request to Perplexity API...")
        # A timeout keeps the agent from hanging indefinitely if the API stalls
        response = requests.post("https://api.perplexity.ai/chat/completions", json=payload, headers=headers, timeout=60)
        response.raise_for_status()

        # Debugging API response
        json_response = response.json()
        print(f"RESPONSE: API Response JSON: {json_response}")

        # Adjust parsing based on actual response structure
        return json_response["choices"][0].get("message", {}).get("content", "No content found.")

    except requests.RequestException as e:
        print(f"RESPONSE: API Error Details: {str(e)}")
        return f"Error researching topic: {str(e)}"
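
The parsing line in the tool indexes `json_response["choices"][0]` directly, which raises an error if `choices` is missing or empty. A minimal sketch of a more defensive extractor is shown below. The helper name `extract_answer` and the sample payload are made up for illustration, and the top-level `citations` list is an assumption about the response shape that should be checked against Perplexity's API documentation.

```python
from typing import Any, Dict, List, Tuple


def extract_answer(json_response: Dict[str, Any]) -> Tuple[str, List[str]]:
    """Return (content, citations) from a chat-completions style response dict."""
    choices = json_response.get("choices") or []
    if not choices:
        return "No content found.", []
    content = choices[0].get("message", {}).get("content", "No content found.")
    # Assumed field: a top-level list of source URLs; treated as optional here
    citations = list(json_response.get("citations") or [])
    return content, citations


# Made-up sample payload, used only to exercise the helper offline
sample_response = {
    "choices": [{"message": {"role": "assistant", "content": "Tension headache is common."}}],
    "citations": ["https://example.org/source-1"],
}

content, citations = extract_answer(sample_response)
print(content)
print(citations)
```

Returning the citations separately would let the report node list sources explicitly instead of relying on links embedded in the text.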

## Step 4: Set up the LLM for interaction with user input and orchestration with the conversation flow

Now that the research tool is built, let's set up the LLM and bind the tool to it.

In [None]:
llm = ChatGoogleGenerativeAI(
    model="learnlm-1.5-pro-experimental",  # You can try other models by replacing with "gemini-1.5-pro", "gemini-1.5-flash", or "gemini-2.5-pro-preview-03-25"
    google_api_key=GEMINI_API_KEY,
    temperature=0.3 # Range is usually 0 to 1, we are choosing lower value for more predictable responses
)
# Define a simple prompt template to turn questions into prompt for LLM
template = "Answer this to the best of your knowledge. {question} ?"
prompt = PromptTemplate(template=template, input_variables=["question"])

# Bind tools with LLM using LangChain
tools = [perplexity_research]
llm_with_tools = llm.bind_tools(tools=tools)

# State Management
# Defines the structure for managing conversation state and analysis progress
class State(TypedDict):
    messages: Annotated[List[Dict[str, Any]], "Chat messages"]                            # Store a list of all chat messages
    research_results: Annotated[Dict[str, Any], "Medical research data"]                  # Research results from Perplexity API
    analysis_complete: Annotated[bool, "Whether analysis is complete"]                    # Determine if analysis is completed, in True or False
    report: Annotated[Dict[str, Any], "Final medical analysis report"]                    # Final report to be returned
    conversation_stage: Annotated[str, "Current stage: conversation, research, complete"] # Track the current stage for LLM
    symptom_details: Annotated[Dict[str, Any], "Collected symptom information"]           # Details of symptoms, to be used by LLM
    question_count: Annotated[int, "Number of questions asked so far"]                    # Not necessary unless we want min/max number of questions
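
Before the graph can run, it needs a starting state with every key filled in. The sketch below builds one from the first user message, using the same empty defaults that the `reset_conversation` node returns in Step 6. The helper name `make_initial_state` is hypothetical and is written against a plain dictionary so it runs on its own.

```python
from typing import Any, Dict


def make_initial_state(first_user_message: str) -> Dict[str, Any]:
    """Build a fresh state dict whose keys match the State TypedDict."""
    return {
        "messages": [{"role": "user", "content": first_user_message}],
        "research_results": {},
        "analysis_complete": False,
        "report": {},
        "conversation_stage": "conversation",  # start in the information-gathering stage
        "symptom_details": {},
        "question_count": 0,
    }


state = make_initial_state("I've had a headache for three days.")
print(state["conversation_stage"])  # conversation
print(len(state["messages"]))       # 1
```

Starting every run from a complete dictionary means the nodes never have to guess whether a key exists.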

## Step 5: Set up the Conversation Flow Nodes to enable deeper research, analysis and report generation

In [None]:
# Initial Conversation Handler
# Processes user input and generates initial response using STRUCTURED OUTPUT
@api_rate_limit(1)
def interactive_conversation(state: State):
    """Handle multi-turn conversation using structured JSON output from LLM
       to dynamically decide when enough detail is present."""
    print("PROCESSING: Entering interactive_conversation node...")
    current_messages = state["messages"]
    question_count = state.get("question_count", 0) + 1 # Still track for context/failsafe
    symptom_details = state.get("symptom_details", {})

    # --- Failsafe Check (Optional but Recommended) ---
    FAILSAFE_LIMIT = 10 # Set a max limit of questions to prevent potential infinite loops
    if question_count > FAILSAFE_LIMIT:
        print(f"DEBUG ERROR: Failsafe question limit ({FAILSAFE_LIMIT}) reached. Forcing move to research.")
        # Update symptom details one last time
        if current_messages and current_messages[-1].get("role") == "user":
            last_updated = symptom_details.get("last_updated", -1)
            if len(current_messages) > last_updated:
                symptom_details = extract_symptom_details(current_messages)

        # Construct a hardcoded message indicating the move to analysis due to limit.
        response_content = "Based on the information gathered so far, I will now proceed with the analysis."
        new_message = {"role": "assistant", "content": response_content}
        updated_messages = current_messages + [new_message]
        return {
            "messages": updated_messages,
            "question_count": question_count -1, # Stay at the failsafe count
            "conversation_stage": "research", # Force stage to research
            "symptom_details": symptom_details
        }
    # --- End Failsafe Check ---

    # Update symptom details if new user message arrived
    if current_messages and current_messages[-1].get("role") == "user":
        last_updated = symptom_details.get("last_updated", -1)
        if len(current_messages) > last_updated:
            print("PROCESSING: (interactive_conversation) Extracting details from latest user message...")
            symptom_details = extract_symptom_details(current_messages)


    # --- Prompt requesting JSON output ---
    prompt = f"""
    {SYSTEM_PROMPT}

    You are in the **information gathering** stage of a medical consultation. Your goal is to gather sufficient detail to perform a preliminary analysis following a standard procedure.
    Conversation History:
    {format_conversation_history(current_messages)}

    Current Symptom Understanding (internal summary - may be incomplete):
    {symptom_details.get("extracted_data", "No structured summary yet.")}

    Based on the conversation history and your understanding:

    1.  **Assess Sufficiency:** Do you have enough detail about the main complaints? Consider key aspects like the following, but remember ALL of these are not always
        useful. Based on what you think the conditions could be, you decide only the RELEVANT pieces of info to ask:
        * Onset & Duration
        * Location & Radiation
        * Quality/Character (e.g., sharp, dull, pressure)
        * Severity (e.g., scale of 1-10 if appropriate, or description)
        * Timing/Frequency
        * Aggravating/Alleviating Factors
        * Associated Symptoms
        * Relevant Medical History (briefly, if mentioned)

    2.  **Decide Action and Format Output:** Respond ONLY with a valid JSON object containing two keys:
        * `"proceed_to_research"`: A JSON boolean (`true` if you have sufficient detail based on the criteria, `false` otherwise). Use lowercase `true`/`false`, not Python-style `True`/`False`.
        * `"assistant_message"`: (string)
             * If `true`, a brief, empathetic confirmation (e.g., "Thank you for sharing that detail. I think I have enough information to proceed with the next step.").
             * If `false`, the single, most important follow-up question needed right now. Keep it concise and empathetic (e.g., "Could you tell me more about when this symptom started?").

    **CRITICAL INSTRUCTION:** Even if you assess the situation as potentially requiring immediate emergency care, **do not** include that assessment or recommendation in the `assistant_message` *at this stage*. Stick strictly to the JSON format and the content rules described above. Emergency considerations are handled later.

    Example valid JSON output if continuing conversation:
    {{
      "proceed_to_research": false,
      "assistant_message": "When you feel short of breath, does anything seem to make it better or worse?"
    }}

    Example valid JSON output if ready for research:
    {{
      "proceed_to_research": true,
      "assistant_message": "Thank you. I have enough information to analyze your symptoms now."
    }}

    This is conversation turn {question_count}. Ensure your entire response is ONLY the JSON object without any introductory text or explanation.
    """
    print(f"DEBUG: Invoking LLM for conversation (Turn {question_count}, assessing sufficiency, expecting JSON)...")
    try:
        response = llm.invoke(prompt)
        response_content = response.content if hasattr(response, 'content') else str(response)
        print(f"DEBUG: LLM raw response received: '{response_content[:100]}...'") # Log more for debugging JSON

        # --- Attempt to Parse JSON Response ---
        try:
            # Clean potential markdown code fences if the model wraps JSON in them
            if response_content.strip().startswith("```json"):
                response_content = response_content.strip()[7:-3].strip()
            elif response_content.strip().startswith("```"):
                 response_content = response_content.strip()[3:-3].strip()

            parsed_data = json.loads(response_content)

            # Validate expected keys and types (basic validation)
            if not isinstance(parsed_data, dict) or \
               "proceed_to_research" not in parsed_data or \
               "assistant_message" not in parsed_data or \
               not isinstance(parsed_data["proceed_to_research"], bool) or \
               not isinstance(parsed_data["assistant_message"], str):
                raise ValueError("Parsed JSON missing required keys or has incorrect types.")

            has_enough_info = parsed_data["proceed_to_research"]
            assistant_content = parsed_data["assistant_message"]
            print(f"DEBUG: JSON parsed successfully. proceed_to_research={has_enough_info}")

        except (json.JSONDecodeError, ValueError) as json_error:
            print(f"ERROR: Failed to parse valid JSON or validate structure from LLM response: {json_error}")
            print(f"LLM Raw Response causing error: {response_content}")
            has_enough_info = False # Default to continuing conversation on format error
            assistant_content = "I seem to be having trouble formatting my thoughts. Could you please clarify your last point or ask again?"
            # Optionally, you could use the raw response_content here if it might be readable

    except Exception as llm_error:
        print(f"ERROR: LLM invocation failed in interactive_conversation: {llm_error}")
        has_enough_info = False # Default to continuing
        assistant_content = "I encountered an issue communicating. Could you please try again?"
        # No new_stage variable needed here as it's determined after the try-except block

    # Determine the next stage based on the parsed boolean flag
    new_stage = "research" if has_enough_info else "conversation"
    print(f"DEBUG: Based on parsed JSON/error handling: enough info? {has_enough_info}. New stage: {new_stage}")

    # Use the extracted message content
    new_message = {"role": "assistant", "content": assistant_content}
    updated_messages = current_messages + [new_message]

    return {
        "messages": updated_messages,
        "question_count": question_count,
        "conversation_stage": new_stage,
        "symptom_details": symptom_details
    }

def format_conversation_history(messages):
    """Format the conversation history for the LLM prompt"""
    formatted = ""
    for msg in messages:
        # Ensure content exists and is a string
        content = msg.get("content", "")
        if not isinstance(content, str):
             content = str(content) # Convert non-strings

        role = "User" if msg.get("role") == "user" else "Assistant"
        formatted += f"{role}: {content}\n\n"
    return formatted.strip() # Remove trailing newline

@api_rate_limit(1) # Add rate limiting if desired
def extract_symptom_details(messages):
    """Extract symptom information from user messages using LLM"""
    # Combine relevant user messages
    user_input_list = [
        str(msg.get("content", "")) # Ensure content is string
        for msg in messages
        if msg.get("role") == "user"
    ]
    if not user_input_list:
         return {"extracted_data": "No user input found", "last_updated": len(messages)}

    all_user_input = "\n".join(user_input_list)

    extract_prompt = f"""
    Based on the following user messages, extract and structure key symptom information:

    {all_user_input}

    Organize details into: Primary symptoms (list with severity/duration if mentioned), Associated symptoms, Timing/Patterns, Aggravating/Relieving factors, Relevant medical history.
    Return as concise, structured text (not strict JSON).
    """
    try:
        print("DEBUG: Extracting symptom details...")
        response = llm.invoke(extract_prompt)
        extracted_content = response.content if hasattr(response, 'content') else str(response)
        print("DEBUG: Symptom extraction complete.")
        return {"extracted_data": extracted_content, "last_updated": len(messages)}
    except Exception as e:
        print(f"Error extracting symptom details: {str(e)}")
        # Provide error information but allow flow to continue
        return {"extracted_data": f"Error processing symptoms: {str(e)}", "last_updated": len(messages)}

# --- Placeholder for Waiting ---
# This node doesn't do anything, it's just a named step in the graph
# where the control flow pauses before the next user input in the command line loop.
def wait_for_user_response(state: State):
     """Node indicating the graph is waiting for user input."""
     print("DEBUG: Entering wait_for_user_response node (waiting for input loop)...")
     # No state change needed here, just a logical pause point
     return state

# --- Research Node ---
# Builds a research query from the collected symptoms and calls the Perplexity tool directly.
# No rate limit is needed here because perplexity_research is already rate-limited.
def determine_research_needs(state: State):
    """Determine what conditions to research based on conversation."""
    print("DEBUG: Entering determine_research_needs node...")
    messages = state["messages"]
    symptom_details = state.get("symptom_details", {})

    # Use the structured details if available, otherwise fall back to user messages
    extracted_data = symptom_details.get("extracted_data", "No structured data extracted.")
    if extracted_data == "No structured data extracted." or "Error processing symptoms" in extracted_data:
         # Fallback: use raw user input if extraction failed or didn't happen
         user_input_list = [str(msg.get("content","")) for msg in messages if msg.get("role") == "user"]
         symptom_summary_for_research = "\n".join(user_input_list)
         print("DEBUG: Using raw user input for research prompt as structured data is unavailable/error.")
    else:
        symptom_summary_for_research = extracted_data
        print("DEBUG: Using extracted symptom details for research prompt.")


    research_prompt = f"""
    Based on the following symptom information:
    {symptom_summary_for_research}

    Perform medical research focusing on:
    1. Most probable conditions (ranked).
    2. Brief explanation, causes, risk factors for each.
    3. Cite relevant, authoritative sources (e.g., Mayo Clinic, NIH, PubMed links if possible).
    4. Suggest potential diagnostic steps.
    """
    print("RESPONSE: Starting Perplexity research...")
    # Ensure the tool gets a dictionary with 'query' key
    results = perplexity_research.invoke({"query": research_prompt})
    print("RESPONSE: Perplexity research complete.")

    # Store results correctly
    return {"research_results": {"medical_research": results}} # Ensure results are nested if needed later

# Processes research data and generates medical analysis
# --- Analysis Node ---
@api_rate_limit(1)
def generate_analysis(state: State):
    """Generate medical analysis incorporating research."""
    print("DEBUG: Entering generate_analysis node...")
    # Correctly access nested research results
    research_data = state.get('research_results', {}).get('medical_research', 'No research data available.')
    messages = state["messages"]
    symptom_details = state.get("symptom_details", {})

    # Prepare symptom summary for analysis prompt
    extracted_data = symptom_details.get("extracted_data", "No structured data.")
    if extracted_data == "No structured data." or "Error processing symptoms" in extracted_data:
         user_input_list = [str(msg.get("content","")) for msg in messages if msg.get("role") == "user"]
         symptom_summary_for_analysis = "\n".join(user_input_list)
    else:
        symptom_summary_for_analysis = extracted_data

    analysis_prompt = f"""
    {SYSTEM_PROMPT}
    Generate a detailed medical analysis based on the conversation and research.
    Format the entire report using Markdown syntax. Use headings (e.g., `## Section Title` or `**Section Title:**`), bullet points (`* point` or `- point`), and
    bold text (`**important**`) for clarity and readability.
    IMPORTANT: Ensure your entire report uses standard UTF-8 encoding. Avoid generating non-printable control characters. Use only widely compatible Markdown syntax (headings, lists, bold, italics, standard tables).

    SYMPTOM SUMMARY:
    {symptom_summary_for_analysis}

    RESEARCH FINDINGS:
    {research_data}

    Your analysis report should include:
    1. Summary of key symptoms and risk factors (from conversation).
    2. Differential diagnosis: Ranked list of probable conditions with confidence scores (e.g., percentages, used only where the research strongly supports them). Justify the ranking briefly.
    3. Explanation of top 2-3 conditions (causes, symptoms matching/not matching).
    4. Recommended next steps (e.g., see primary care, specialist, diagnostics mentioned in research).
    5. **Crucially:** Reiterate if any symptoms warrant **immediate emergency care**. Include standard medical disclaimers.
    """
    print("DEBUG: Invoking LLM for analysis generation...")
    analysis_response = llm.invoke(analysis_prompt)
    analysis_content = analysis_response.content if hasattr(analysis_response, 'content') else str(analysis_response)
    print("DEBUG: Analysis generation complete.")
    return {"analysis_complete": True, "report": {"content": analysis_content}} # Store content correctly

# ----Final Response Formation----
def final_response(state: State):
    """Format the final report for the user."""
    print("DEBUG: Entering final_response node...")
    report_content = state.get("report", {}).get("content", "Analysis could not be generated.")
    final_message = {
        "role": "assistant",
        "content": f"--- Medical Analysis Report ---\n\n{report_content}\n\n--- End of Report ---"
    }
    print("DEBUG: Final response formatted.")

    # Add a final message and ensure stage is 'complete'
    return {
        "messages": state["messages"] + [final_message],
        "conversation_stage": "complete",
        "analysis_complete": True, # Ensure this is set to finish the loop
        "report": state["report"] # Pass report through
    }
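
The fence-stripping and key validation inside `interactive_conversation` can also be factored into a small helper, which makes that step testable without calling the LLM. The sketch below is one way to do this, assuming the same two keys requested in the prompt; `parse_decision` is a hypothetical name, not part of the workshop code. Note that `json.loads` accepts only lowercase `true`/`false`, so a Python-style `True` from the model fails to parse.

```python
import json
import re
from typing import Any, Dict

# Matches a reply wrapped in a ``` or ```json fence and captures the inside
_FENCE_RE = re.compile(r"^```(?:json)?\s*(.*?)\s*```$", re.DOTALL | re.IGNORECASE)


def parse_decision(raw: str) -> Dict[str, Any]:
    """Parse the model's JSON reply and validate the two expected keys.

    Raises ValueError (json.JSONDecodeError is a subclass) on malformed input.
    """
    text = raw.strip()
    match = _FENCE_RE.match(text)
    if match:
        text = match.group(1)  # drop the surrounding code fence
    data = json.loads(text)
    if not isinstance(data, dict):
        raise ValueError("Expected a JSON object.")
    if not isinstance(data.get("proceed_to_research"), bool):
        raise ValueError("'proceed_to_research' must be a JSON boolean.")
    if not isinstance(data.get("assistant_message"), str):
        raise ValueError("'assistant_message' must be a string.")
    return data


fenced_reply = '```json\n{"proceed_to_research": true, "assistant_message": "Thank you."}\n```'
decision = parse_decision(fenced_reply)
print(decision["proceed_to_research"])  # True
```

The node's existing `except (json.JSONDecodeError, ValueError)` branch would still catch any failure raised by a helper like this.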




## Step 6: Flow Control Functions

These functions give the diagnostics agent its conditional flow: they decide whether to run research, whether the analysis is complete, and when to reset the conversation.

In [None]:


# Research Decision Logic
# Determines if additional research is needed
def should_research(state: State) -> str:
    """Determine if research is needed based on message content"""
    messages = state["messages"]
    last_message = messages[-1]["content"]

    # Always do research for medical queries for now, can be modified for more complex applications
    if any(term in last_message.lower() for term in ["symptoms", "pain", "feeling", "medical", "health"]):
        return "research"
    return "generate_analysis"

# Analysis Completion Check to verify if the medical analysis is complete
def is_analysis_complete(state: State) -> str:
    """Check if analysis is complete or if further conversation is needed."""
    # Simplified logic without LLM call
    return "complete" if state.get("analysis_complete") else "intake_conversation"

# --- Reset Conversation Node ---
def reset_conversation(state: State):
    """Reset the state for a new topic, keeping only the last user message."""
    print("RESTART: Entering reset_conversation node...")
    last_user_message = None
    if state["messages"] and state["messages"][-1].get("role") == "user":
         last_user_message = state["messages"][-1]

    # Acknowledge the reset
    acknowledgment = {
        "role": "assistant",
        "content": "Okay, let's focus on this new topic. Please tell me about the new symptoms or concerns you have."
    }

    # Start new history
    new_messages = [last_user_message, acknowledgment] if last_user_message else [acknowledgment]

    # Return a fully reset state dictionary
    return {
        "messages": new_messages,
        "research_results": {},
        "analysis_complete": False,
        "report": {},
        "conversation_stage": "conversation", # Back to starting conversation stage
        "symptom_details": {},
        "question_count": 0
    }

def determine_next_stage(state: State) -> str:
    """Determine the next node or END the current invocation to wait for user."""
    print(f"THINKING: Determining next stage... Current stage: {state.get('conversation_stage')}")
    messages = state["messages"]
    current_stage = state.get("conversation_stage", "conversation")
    last_message_role = messages[-1].get("role") if messages else None

    # Check if analysis is complete (triggered after final_response runs)
    if current_stage == "complete":
        if last_message_role == "user":
            last_user_message_content = str(messages[-1].get("content", "")).lower()
            if any(phrase in last_user_message_content for phrase in ["new symptom", "different issue", "another problem", "new topic"]):
                print("RESTART: Routing to restart_conversation.")
                return "restart_conversation"
            else:
                print("END: Routing to END graph (conversation complete, no new topic).")
                return END # END the graph's execution completely
        else: # Last message was assistant's final report
            print("END: Routing to END graph (final report sent).")
            return END # END the graph's execution completely

    # If interactive_conversation decided we need to research
    if current_stage == "research":
        print("PROCESSING: Routing to start_research.")
        return "start_research"

    # If we are in the conversation stage
    if current_stage == "conversation":
        if last_message_role == "assistant":
            # Assistant just spoke. If it asked a question (didn't say "enough info"),
            # stop the graph execution here to wait for user input in the external loop.
            if "enough information" not in str(messages[-1].get("content", "")).lower():
                print("PROCESSING: Routing to END (yielding for user input).")
                return END # <<<--- Stops the current invoke call
            else:
                # Assistant said "enough info", but stage is still 'conversation'.
                # This means interactive_conversation should have set stage to 'research'.
                # The next invoke call will handle the 'research' stage correctly.
                # So, we END the current invoke here.
                print("DEBUG: Routing to END (yielding before research stage starts on next invoke).")
                return END # <<<--- Stops the current invoke call

        elif last_message_role == "user":
            # User just responded, continue the conversation internally
            print("DEBUG: Routing to continue_conversation.")
            return "continue_conversation" # Go back to interactive_conversation node
        else: # Initial state
            print("DEBUG: Routing to continue_conversation (initial state).")
            return "continue_conversation"

    # Fallback case - should ideally not be reached with proper state management
    print("ERROR: determine_next_stage fell through. Routing to END.")
    return END # <<<--- Stops the current invoke call
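
To see the routing rules at a glance, here is a condensed restatement of `determine_next_stage` that you can run without LangGraph. It is a sketch, not a replacement: `route`, `make_state`, and `RESTART_PHRASES` are names made up for this example, `END` is a plain-string stand-in for LangGraph's `END`, and the debug prints are left out.

```python
# Stand-in for langgraph.graph.END so this sketch runs without LangGraph installed.
END = "__end__"

RESTART_PHRASES = ("new symptom", "different issue", "another problem", "new topic")


def route(state: dict) -> str:
    """Condensed version of the routing rules (hypothetical helper, prints omitted)."""
    messages = state.get("messages", [])
    stage = state.get("conversation_stage", "conversation")
    last = messages[-1] if messages else {}
    role = last.get("role")
    text = str(last.get("content", "")).lower()

    if stage == "complete":
        # Only a user message containing a restart phrase reopens the conversation.
        wants_restart = role == "user" and any(p in text for p in RESTART_PHRASES)
        return "restart_conversation" if wants_restart else END
    if stage == "research":
        return "start_research"
    if stage == "conversation":
        # An assistant message means we stop and wait for the user's reply.
        return END if role == "assistant" else "continue_conversation"
    return END  # unknown stage


def make_state(stage: str, role: str, content: str) -> dict:
    """Build a minimal state dict with a single message (hypothetical helper)."""
    return {"conversation_stage": stage, "messages": [{"role": role, "content": content}]}


cases = [
    ("conversation", "user", "I have a headache"),
    ("conversation", "assistant", "How long has it lasted?"),
    ("research", "assistant", "I have enough information."),
    ("complete", "user", "I have a new symptom"),
    ("complete", "assistant", "--- End of Report ---"),
]
for stage, role, content in cases:
    print(f"{stage:12} {role:9} -> {route(make_state(stage, role, content))}")
```

Reading the table this prints is a quick way to check that each stage sends the graph where you expect before wiring the real router into the graph.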

## Step 7: Putting it all together with Graph Construction

In [None]:
# === Build the Multi-Turn Graph ===

graph_builder = StateGraph(State)

graph_builder.add_node("interactive_conversation", interactive_conversation)
graph_builder.add_node("determine_research_needs", determine_research_needs)
graph_builder.add_node("generate_analysis", generate_analysis)
graph_builder.add_node("final_response", final_response)
graph_builder.add_node("reset_conversation", reset_conversation)

# Starting edge
graph_builder.add_edge(START, "interactive_conversation")

# Edges from interactive_conversation based on determine_next_stage
graph_builder.add_conditional_edges(
    "interactive_conversation",
    determine_next_stage,
    {
        "continue_conversation": "interactive_conversation", # Loop back if user responded
        "start_research": "determine_research_needs",       # Move to research when ready
        END: END                                            # Route to graph's END when yielding for user
    }
)

# Connect the research and analysis flow
graph_builder.add_edge("determine_research_needs", "generate_analysis")
graph_builder.add_edge("generate_analysis", "final_response")

# End after final response (or handle reset/follow-up from there)
graph_builder.add_conditional_edges(
    "final_response",
    determine_next_stage, # Reuse determine stage after final report is added
    {
        END: END, # Use END directly to terminate graph execution
        "restart_conversation": "reset_conversation"
    }
)

# Connect the reset node back to the conversation
graph_builder.add_edge("reset_conversation", "interactive_conversation")

# Compile the graph
print("Compiling the revised graph...")
graph = graph_builder.compile()
print("Revised graph compiled.")

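# Optional: draw the graph. draw_png() needs pygraphviz; if it is missing,
# the except branch below prints the reason instead of crashing the cell.
from IPython.display import Image, display
try:
    display(Image(graph.get_graph().draw_png()))
except Exception as e:
    print(f"Could not draw graph: {e}")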

Compiling the revised graph...
Revised graph compiled.
Could not draw graph: Install pygraphviz to draw graphs: `pip install pygraphviz`.
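
Because the same router is attached to two different nodes, it can help to check which of its return labels each edge map actually covers. The sketch below does that in plain Python. `ROUTER_LABELS`, `EDGE_MAPS`, and `unmapped_labels` are names made up for this example, `END` is a string stand-in for LangGraph's `END`, and the label set was read off the `return` statements in `determine_next_stage`.

```python
# Stand-in for langgraph.graph.END so this sketch runs without LangGraph installed.
END = "__end__"

# Every label determine_next_stage can return, read off its return statements.
ROUTER_LABELS = {"continue_conversation", "start_research", "restart_conversation", END}

# Edge maps copied from the two add_conditional_edges calls above.
EDGE_MAPS = {
    "interactive_conversation": {
        "continue_conversation": "interactive_conversation",
        "start_research": "determine_research_needs",
        END: END,
    },
    "final_response": {
        END: END,
        "restart_conversation": "reset_conversation",
    },
}


def unmapped_labels(edge_map: dict, labels: set) -> list:
    """Return the router labels that have no entry in this edge map."""
    return sorted(labels - set(edge_map))


for node, edge_map in EDGE_MAPS.items():
    print(f"{node}: unmapped labels = {unmapped_labels(edge_map, ROUTER_LABELS)}")
```

An unmapped label is only a problem if the router can really return it from that node, and that depends on what stage the node leaves in the state. Treat the output as a list of cases to think through, not a list of bugs.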


In [None]:
# System Prompt defines AI's role and responsibilities in medical analysis
SYSTEM_PROMPT = """
You are an advanced AI medical assistant simulating a preliminary diagnostic consultation with access to up-to-date medical literature, expert guidelines, and peer-reviewed studies. Your role is to:
1. Conduct a structured diagnostic evaluation, mimicking a board-certified physician’s approach.
2. Use differential diagnosis methods, listing probable conditions with confidence scores assessing the likelihood of each condition.
3. Prioritize high-accuracy, medically reviewed sources (such as but not limited to PubMed, Mayo Clinic, NIH, UpToDate).
4. Clearly communicate **when emergency medical care might be required**.
5. Provide a clear, structured medical report summarizing likely conditions with citations to justify the evaluation, risk assessments, and next steps.
"""

In [None]:
# Execution Function
# Main function to run the medical analysis workflow for a single invocation
def run_medical_analysis(initial_message: str):
    """Runs the medical analysis graph with the given initial message."""
    # Use the same keys as the interactive loop so every node sees a complete state
    initial_state = {
        "messages": [{"role": "user", "content": initial_message}],
        "research_results": {},
        "analysis_complete": False,
        "report": {},
        "conversation_stage": "conversation",
        "symptom_details": {},
        "question_count": 0
    }

    results = graph.invoke(initial_state)
    return results["messages"]

In [None]:
# === Interactive Command Line Execution (with Rich Rendering for better visual output) ===
def run_command_line():
    """Run an interactive demo of the medical chatbot in the command line."""
    print("\n--- Medical Symptom Analysis Chatbot ---")
    print("Describe your symptoms to start.")
    print("Type 'exit' to end.")
    print("Type 'new topic' (or similar) after analysis to discuss something else.\n")

    # Instantiate Console *outside* the loop
    console = Console()
    state = None # initialize state as None

    while True:
        if not state:
            # Start of a new conversation
            user_input = input("You: ")
            if user_input.lower() == 'exit':
                break
            initial_state_dict = {
                "messages": [{"role": "user", "content": user_input}],
                "research_results": {}, "analysis_complete": False, "report": {},
                "conversation_stage": "conversation", "symptom_details": {}, "question_count": 0
            }
            # Invoke the graph to get the first assistant response
            try:
                print("START: Invoking graph (initial)...")
                state = graph.invoke(initial_state_dict, {"recursion_limit": 15})
                print("UPDATE: Graph invocation complete (initial).")
            except Exception as e:
                # Use console.print for error messages too, for consistency
                console.print(f"\n[bold red]ERROR:[/bold red] Graph failed during initial invocation: {e}")
                console.print("Please try again or type 'exit'.")
                state = None
                continue

        else:
            # Continue existing conversation
            user_input = input("You: ")
            if user_input.lower() == 'exit':
                break

            current_messages = state.get("messages", [])
            updated_messages = current_messages + [{"role": "user", "content": user_input}]
            state["messages"] = updated_messages

            try:
                print("PROCESSING: Invoking graph (continue)...")
                state = graph.invoke(state, {"recursion_limit": 15})
                print("UPDATE: Graph invocation complete (continue).")
            except Exception as e:
                console.print(f"\n[bold red]ERROR:[/bold red] Graph failed during continuation: {e}")
                if state and state.get("messages"):
                     # Try to render the last assistant message before the error, if possible
                     last_assistant_message = state["messages"][-1]
                     if last_assistant_message.get("role") == "assistant":
                         console.print(f"\n[bold deep_sky_blue1][Assistant]:[/bold deep_sky_blue1]")
                         console.print(Markdown(last_assistant_message.get('content', '[No Content]')))
                     else: # Fallback if last message wasn't assistant
                         console.print("\n[bold red]Assistant:[/bold red] Sorry, an error occurred.")
                else:
                    console.print("\n[bold red]Assistant:[/bold red] Sorry, an error occurred and I lost track of the conversation. Please start over or type 'exit'.")
                    state = None
                continue


        # --- Process graph output ---
        if not state or not state.get("messages"):
            console.print("\n[bold red]Assistant:[/bold red] Sorry, something went wrong, and I don't have a response.")
            state = None
            continue

        # Display the latest assistant message using Rich
        assistant_message = state["messages"][-1]
        if assistant_message.get("role") == "assistant":
            # *** Use Rich Console and Markdown Here ***
            console.print(f"\n[bold deep_sky_blue1][Assistant]:[/bold deep_sky_blue1]")
            markdown_content = Markdown(assistant_message.get('content', '[No Content]'))
            console.print(markdown_content)
            print() # Add an extra newline for spacing after the rendered block
        else:
            # Should not happen if graph works correctly
            print("DEBUG: Expected assistant message, but last message was:", assistant_message.get("role"))


        # Check if the conversation has reached a final state
        current_stage = state.get("conversation_stage")
        if current_stage == "complete":
            if "--- End of Report ---" in str(assistant_message.get("content", "")):
                # Use console.print for consistency
                console.print("\n[bold green]--- Analysis Complete ---[/bold green]")
                console.print("You can ask follow-up questions about this report, type 'new topic' to discuss something else, or type 'exit'.")
            else:
                console.print("\n[bold yellow]--- Conversation Ended ---[/bold yellow]")
                break # Exit loop


    console.print("\nChat ended.")

def start_interactive_chat():
    try:
        if not GEMINI_API_KEY or not PERPLEXITY_API_KEY:
            print("ERROR: API Keys not found. Please set them up in Colab secrets.")
            return
        run_command_line() # Uses Rich to render the assistant's replies
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
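
The key idea in `run_command_line` is that each `graph.invoke` call handles one turn, and the loop carries the returned state into the next call. Here is a stripped-down sketch of that pattern with scripted inputs instead of `input()`. `EchoGraph` and `chat` are hypothetical names for this example, and `EchoGraph` is only a stub that appends one canned assistant reply per call, not the real compiled graph.

```python
class EchoGraph:
    """Hypothetical stand-in for the compiled graph: adds one assistant reply per invoke."""

    def invoke(self, state: dict, config: dict = None) -> dict:
        turn = state.get("question_count", 0) + 1
        reply = {"role": "assistant", "content": f"Follow-up question {turn}"}
        return {**state, "messages": state["messages"] + [reply], "question_count": turn}


def chat(graph, user_inputs):
    """Feed scripted user turns through the graph, carrying state between invokes."""
    state = None
    replies = []
    for text in user_inputs:
        if text.lower() == "exit":
            break
        user_message = {"role": "user", "content": text}
        if state is None:
            # First turn: build a fresh state.
            state = {"messages": [user_message], "question_count": 0}
        else:
            # Later turns: append to the history returned by the previous invoke.
            state = {**state, "messages": state["messages"] + [user_message]}
        state = graph.invoke(state, {"recursion_limit": 15})
        replies.append(state["messages"][-1]["content"])
    return state, replies


final_state, replies = chat(EchoGraph(), ["My knee hurts", "About a week", "exit", "ignored"])
print(replies)
print(len(final_state["messages"]), "messages in history")
```

Scripting the inputs like this also gives you a way to replay the same conversation after you change a prompt or a node.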

We have the components and the graph blueprint. Now let's run it! To judge the agent's quality against a real medical professional, we'll use AI-generated patient data as the input and then ask an AI to score the agent's final report.

In [None]:
# Go to chatgpt.com and enter this prompt to test the model against actual input
chatgpt_prompt = """
Create sample patient input for testing an AI Medical Diagnostics Agent.
The aim of the agent is to act as an advanced AI medical assistant simulating a preliminary diagnostic consultation
with access to up-to-date medical literature, expert guidelines, and peer-reviewed studies. You must generate some input simulating a
patient with a certain medical condition. Your input must not make the diagnosis very easy or make it obvious what condition the person is suffering from,
because the goal is to test the accuracy and quality of the AI agent as compared against a real medical professional with years of experience.
Generate example input for me, dividing it up into sections of text to test the AI agent. Also tell me what the expected diagnosis should be, but
make it hard for the agent to figure it out. I may also ask you to answer questions that the agent asks during the conversation, staying consistent with that patient,
and you must give me the response to share. Keep it short and concise, only share text within 50 words or less per conversation turn.

Later when the conversation is finished, I will share the final report.
Rate the agent out of 100 on the output, then calculate the expected score for a real medical professional on the same patient info,
and compare their performance briefly.
"""
# Test case whose expected diagnosis is a heart attack. If the agent asks follow-up questions,
# use ChatGPT to answer them while keeping the expected diagnosis the same.
sample_input = "Hi, I've been feeling really off lately. " \
"For the past few hours, I’ve had some chest discomfort, " \
"but it’s not exactly pain. It’s more of a pressure, kind of like " \
"something heavy is on my chest. I also feel really short of breath, " \
"especially when I try to move around or even just stand up. Sometimes, it " \
"feels like my left arm is a little sore, and I've noticed some dizziness as well. " \
"I’m also feeling unusually nauseous, which isn’t something I usually deal with. " \
"I’m 45, not very active, and have had some family members with heart issues. " \
"I’m not sure if this is something I should be concerned about or if I’m just " \
"overthinking it. Can you help?"

In [None]:
start_interactive_chat()


--- Medical Symptom Analysis Chatbot ---
Describe your symptoms to start.
Type 'exit' to end.
Type 'new topic' (or similar) after analysis to discuss something else.

START: Invoking graph (initial)...
PROCESSING: Entering interactive_conversation node...
PROCESSING: (interactive_conversation) Extracting details from latest user message...
DEBUG: Extracting symptom details...
DEBUG: Symptom extraction complete.
DEBUG: Invoking LLM for conversation (Turn 1, assessing sufficiency, expecting JSON)...
DEBUG: LLM raw response received: '```json
{
  "proceed_to_research": false,
  "assistant_message": "To better understand your IT band ...'
DEBUG: JSON parsed successfully. proceed_to_research=False
DEBUG: Based on parsed JSON/error handling: enough info? False. New stage: conversation
THINKING: Determining next stage... Current stage: conversation
PROCESSING: Routing to END (yielding for user input).
UPDATE: Graph invocation complete (initial).



PROCESSING: Invoking graph (continue)...
PROCESSING: Entering interactive_conversation node...
PROCESSING: (interactive_conversation) Extracting details from latest user message...
DEBUG: Extracting symptom details...
DEBUG: Symptom extraction complete.
DEBUG: Invoking LLM for conversation (Turn 2, assessing sufficiency, expecting JSON)...
DEBUG: LLM raw response received: '```json
{
  "proceed_to_research": false,
  "assistant_message": "Can you describe the character of ...'
DEBUG: JSON parsed successfully. proceed_to_research=False
DEBUG: Based on parsed JSON/error handling: enough info? False. New stage: conversation
THINKING: Determining next stage... Current stage: conversation
PROCESSING: Routing to END (yielding for user input).
UPDATE: Graph invocation complete (continue).



PROCESSING: Invoking graph (continue)...
PROCESSING: Entering interactive_conversation node...
PROCESSING: (interactive_conversation) Extracting details from latest user message...
DEBUG: Extracting symptom details...
DEBUG: Symptom extraction complete.
DEBUG: Invoking LLM for conversation (Turn 3, assessing sufficiency, expecting JSON)...
DEBUG: LLM raw response received: '```json
{
  "proceed_to_research": false,
  "assistant_message": "To understand the potential cause ...'
DEBUG: JSON parsed successfully. proceed_to_research=False
DEBUG: Based on parsed JSON/error handling: enough info? False. New stage: conversation
THINKING: Determining next stage... Current stage: conversation
PROCESSING: Routing to END (yielding for user input).
UPDATE: Graph invocation complete (continue).
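
Notice in the log that the model wraps its JSON in a Markdown code fence before the node parses it. The node's own parsing code lives earlier in the notebook. The sketch below is a separate, minimal illustration of that step. `parse_llm_json` is a hypothetical helper, and falling back to "keep asking questions" when parsing fails is an assumption made for this example.

```python
import json

# Three backticks, built this way so the sketch can sit inside a fenced block.
FENCE = "`" * 3


def parse_llm_json(raw: str) -> dict:
    """Strip an optional Markdown fence, then parse the two fields the log shows."""
    fallback = {"proceed_to_research": False, "assistant_message": raw.strip()}
    text = raw.strip()
    if text.startswith(FENCE):
        # Drop the opening fence line (the fence plus an optional language tag).
        text = text.split("\n", 1)[1] if "\n" in text else ""
        # Drop the closing fence if the model included one.
        if text.rstrip().endswith(FENCE):
            text = text.rstrip()[: -len(FENCE)]
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Assumed fallback: stay in the conversation and show the raw text.
        return fallback
    if not isinstance(parsed, dict):
        return fallback
    return {
        "proceed_to_research": bool(parsed.get("proceed_to_research", False)),
        "assistant_message": str(parsed.get("assistant_message", "")),
    }


raw = (
    FENCE + "json\n"
    '{"proceed_to_research": false, "assistant_message": "How long has this lasted?"}\n'
    + FENCE
)
print(parse_llm_json(raw))
print(parse_llm_json("Sorry, I cannot answer in JSON."))
```

Defaulting to `proceed_to_research=False` on a parse failure keeps the agent in the question-asking stage, which is the safer direction to fail in for this workflow.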





# Workshop Recap

Congratulations! You've built and interacted with a multi-step AI agent using LangGraph.

**Key Takeaways:**

* **LLMs can orchestrate:** They don't just generate text; they can follow steps, use tools, and make autonomous decisions within a defined structure.
* **LangGraph provides structure:** It allows us to build complex, stateful AI workflows reliably by defining nodes (steps) and edges (transitions).
* **State is crucial:** Managing the conversation history, intermediate results, and current stage is essential for multi-turn interactions on complex tasks.
* **Tools enhance LLMs:** Giving LLMs access to external APIs or functions extends what they can do beyond text generation, such as looking up current research.
* **Prompting is key:** Carefully crafted prompts (system prompts, prompts for nodes, prompts for tools) guide the AI's behavior.

**Further Exploration:**

* Add more tools (e.g., a calculator, a database lookup).
* Experiment with different LLMs or prompt strategies on new use cases.
* Explore LangSmith for debugging and tracing your graph runs.
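
If you want to try the calculator idea, one possible starting point is a plain Python function that evaluates basic arithmetic without calling `eval()`. This is only a sketch: `calculator` and `_OPS` are names made up for this example, and only `+`, `-`, `*`, `/` and unary minus are supported. To expose it to the agent you would still need to wrap it as a tool (for example with LangChain's `@tool` decorator) and route to it from a node, which is not shown here.

```python
import ast
import operator

# Whitelisted operators; anything else is rejected.
_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.USub: operator.neg,
}


def calculator(expression: str):
    """Evaluate a basic arithmetic expression by walking its syntax tree."""

    def _eval(node):
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if (
            isinstance(node, ast.Constant)
            and isinstance(node.value, (int, float))
            and not isinstance(node.value, bool)
        ):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](_eval(node.left), _eval(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](_eval(node.operand))
        # Names, calls, attribute access, and other operators all end up here.
        raise ValueError(f"Unsupported expression: {expression!r}")

    # ast.parse raises SyntaxError if the text is not a valid expression.
    return _eval(ast.parse(expression, mode="eval"))


print(calculator("2 + 3 * 4"))
print(calculator("70 / (1.75 * 1.75)"))  # e.g. weight in kg divided by height in m squared
```

Walking the syntax tree with a whitelist means input such as a function call is rejected with a `ValueError` instead of being executed.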