# MedFlow AI with LangGraph

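This notebook demonstrates the MedFlow AI workflow, using [LangGraph](https://python.langchain.com/docs/langgraph) to orchestrate two model-backed agents with a doctor-in-the-loop step between them. In this demo the doctor's input is simulated rather than collected interactively.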

In [None]:
import sys
import os

# Add src to path so we can import our modules
sys.path.append(os.path.abspath("src"))

from medflow.agents.agent1 import SoapNoteGenerator
from medflow.agents.agent2 import PlanAnalyzer
from medflow.utils.pdf_generator import generate_soap_pdf
from medflow.utils.visualization import save_visualization_html

import torch
from transformers import pipeline
from huggingface_hub import login
from typing import TypedDict, Annotated, List, Dict, Any
from langgraph.graph import StateGraph, END

import json

## 1. Setup & Model Loading

In [None]:
# Authenticate with Hugging Face when a token is present in the environment;
# otherwise fall back to whatever credentials are already cached locally.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    print("Warning: HF_TOKEN not found. Ensure you are logged in or provide the token.")

# Load the MedGemma pipeline. The resulting `pipe` is a notebook-level global
# that the agent nodes defined later read directly.
print("Loading MedGemma model...")
try:
    pipe = pipeline(
        "image-text-to-text",
        model="google/medgemma-4b-it",
        torch_dtype=torch.bfloat16,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
except Exception as e:
    print(f"Error loading model: {e}")
    # Stop here instead of continuing with `pipe` undefined: every later cell
    # depends on it, and a NameError further down would hide the real cause.
    raise
else:
    print("Model loaded successfully.")

## 2. Define Graph State

In [None]:
class MedFlowState(TypedDict):
    """State schema passed to ``StateGraph`` and threaded through the workflow nodes.

    Each node reads the keys it needs and returns a dict containing only the
    key it produces. The demo's initial state supplies ``patient_info``,
    ``ethnicity`` and ``images``; the remaining keys are filled in by the nodes.
    """

    # Patient intake record; forwarded to SoapNoteGenerator.generate by agent1_node.
    patient_info: Dict[str, Any]
    # Optional images forwarded alongside patient_info (the demo run passes None).
    images: List[Any]
    # Forwarded to PlanAnalyzer.analyze; agent2_node falls back to "Not provided".
    ethnicity: str
    # Written by agent1_node; read by doctor_node and agent2_node.
    soap_note: Dict[str, Any]
    # Written by doctor_node; read by agent2_node.
    doctor_plan: Dict[str, Any]
    # Written by agent2_node; the result of PlanAnalyzer.analyze.
    final_output: Dict[str, Any]

## 3. Define Nodes

In [None]:
def agent1_node(state: MedFlowState):
    """Graph node: run Agent 1 on the patient record and store its note.

    Reads ``patient_info`` (required) and ``images`` (optional) from the state
    and returns a partial state update under the ``soap_note`` key.
    """
    print("--- Agent 1: Generating SOA ---")
    # The generator wraps the notebook-level `pipe` loaded in the setup cell.
    generator = SoapNoteGenerator(pipe)
    patient_record = state["patient_info"]
    attached_images = state.get("images")
    generated_note = generator.generate(patient_record, attached_images)
    return {"soap_note": generated_note}

def doctor_node(state: MedFlowState):
    """Graph node: show the generated note and return the doctor's plan.

    The note stored under ``soap_note`` is pretty-printed for review. The plan
    itself is hard-coded here to stand in for real human input, and is returned
    as a partial state update under the ``doctor_plan`` key.
    """
    print("--- Doctor: Reviewing & Planning ---")
    note_as_json = json.dumps(state["soap_note"], indent=2)
    print("Generated SOA:", note_as_json)

    # A real application would pause here and wait for the clinician.
    # The notebook demo substitutes a fixed plan instead.
    print("Doctor adding plan...")
    prescribed = ["Omeprazole 20mg once daily"]
    ordered_labs = ["H. pylori test", "CBC"]
    simulated_plan = {
        "medications": prescribed,
        "lab_tests": ordered_labs,
        "follow_up": "2 weeks",
    }
    return {"doctor_plan": simulated_plan}

def agent2_node(state: MedFlowState):
    """Graph node: run Agent 2 on the note and the doctor's plan.

    Reads ``soap_note`` and ``doctor_plan`` (both required) plus ``ethnicity``
    (defaulting to "Not provided") and returns a partial state update under
    the ``final_output`` key.
    """
    print("--- Agent 2: Analyzing & Finalizing ---")
    # The analyzer wraps the notebook-level `pipe` loaded in the setup cell.
    analyzer = PlanAnalyzer(pipe)
    note = state["soap_note"]
    plan = state["doctor_plan"]
    ethnicity = state.get("ethnicity", "Not provided")
    analysis = analyzer.analyze(note, plan, ethnicity)
    return {"final_output": analysis}

## 4. Build Graph

In [None]:
def _assemble_workflow():
    """Build the uncompiled graph: agent1 -> doctor -> agent2 -> END."""
    graph = StateGraph(MedFlowState)

    # Register each node under the name the edges refer to.
    node_table = (
        ("agent1", agent1_node),
        ("doctor", doctor_node),
        ("agent2", agent2_node),
    )
    for node_name, node_fn in node_table:
        graph.add_node(node_name, node_fn)

    # Execution starts at agent1 and follows a single linear chain.
    graph.set_entry_point("agent1")
    edge_table = (
        ("agent1", "doctor"),
        ("doctor", "agent2"),
        ("agent2", END),
    )
    for source, target in edge_table:
        graph.add_edge(source, target)

    return graph


workflow = _assemble_workflow()

# Compile into the runnable app used by the next section.
app = workflow.compile()

## 5. Run Workflow

In [None]:
# Example patient record used as the workflow input.
patient_input = dict(
    age=45,
    gender="Male",
    symptoms=["Chest discomfort", "Shortness of breath during exertion", "Fatigue"],
    duration="2 weeks",
    severity="Moderate",
    medical_history=["Hypertension"],
    medications=[],
    vitals=dict(blood_pressure="145/90", heart_rate="92 bpm"),
)

# Only the input keys are supplied up front; the nodes add the rest as they run.
initial_state = dict(
    patient_info=patient_input,
    ethnicity="South Asian",
    images=None,
)

# Run the compiled graph end to end.
final_state = app.invoke(initial_state)

print("\n=== Final Workflow Output ===")
print(json.dumps(final_state["final_output"], indent=2))

## 6. Generate Outputs (PDF & Visualization)

In [None]:
final_data = final_state["final_output"]

# The PDF generator needs the note content. Prefer the copy embedded in the
# final output, if there is one.
soap_content = final_data.get("soap_note", {})
if not soap_content:
    # Fallback: combine the Agent 1 note with the doctor's plan.
    # Build a new dict rather than assigning into final_state["soap_note"], so
    # the workflow state is left untouched and re-running this cell is idempotent.
    soap_content = {
        **final_state["soap_note"],
        "plan": str(final_state["doctor_plan"]),
    }

# Generate PDF
generate_soap_pdf(soap_content, final_data, "medflow_langgraph_report.pdf")

# Generate Visualization
save_visualization_html(final_data, "medflow_langgraph_viz.html")