# Corrections
1. I am using Pillow for image import, but my images are inside the PDFs...
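
One possible way to handle correction 1 is to write the images embedded in each PDF to disk first, then open those files with Pillow as the rest of the notebook already does. The sketch below assumes `pypdf` is installed (its image extraction also relies on Pillow); `safe_image_name`, `extract_pdf_images`, and the output directory are hypothetical names, not part of the notebook.

```python
from pathlib import Path
from typing import List


def safe_image_name(pdf_stem: str, page_number: int, index: int, original_name: str) -> str:
    """Build a unique file name such as 'atlas_p3_0.png' for an extracted image."""
    suffix = Path(original_name).suffix or ".bin"  # fall back when the PDF gives no extension
    return f"{pdf_stem}_p{page_number}_{index}{suffix}"


def extract_pdf_images(pdf_path: str, out_dir: str) -> List[str]:
    """Write every image embedded in a PDF to out_dir and return the saved paths."""
    from pypdf import PdfReader  # imported lazily; assumes pypdf is installed

    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    saved: List[str] = []
    reader = PdfReader(pdf_path)
    for page_number, page in enumerate(reader.pages, start=1):
        # page.images yields objects exposing the image's name and raw bytes
        for index, image in enumerate(page.images):
            target = out / safe_image_name(Path(pdf_path).stem, page_number, index, image.name)
            target.write_bytes(image.data)
            saved.append(str(target))
    return saved
```

The returned paths could then be passed to an image embedder that expects file paths, and the page number in each file name keeps a link back to the source document.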


# Medical Knowledge System with Multi-Agent Architecture

This notebook implements a comprehensive medical knowledge system using:
- **Multi-Agent Architecture**: Agent Manager, Planning Agent, Text/Image Retrievers, Reasoner, and Draft Agent
- **A2A Communication**: Agent-to-Agent messaging system
- **MCP Integration**: Model Context Protocol style request/response wrappers (`MCPRequest`, `MCPResponse`) for model calls
- **Milvus Vector Store**: For storing and retrieving embeddings
- **LangChain**: For agent orchestration and LLM integration

## Architecture Overview

1. **Agent Manager**: Oversees all agents and coordinates communication
2. **Planning Agent**: Analyzes queries to determine what information is needed
3. **Text Retriever Agent**: Searches through textbooks and research papers
4. **Image Retriever Agent**: Processes images from medical documents
5. **Reasoner Agent**: Synthesizes information from all sources
6. **Draft Agent**: Generates final responses for doctors
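
The hand-off between these agents can be pictured with a minimal sketch. It is a toy under stated assumptions: the stub functions stand in for the LLM-backed agents, the loop plays the Agent Manager's routing role, and all names (`run_pipeline`, `planner`, `draft`, and so on) are hypothetical rather than taken from the notebook.

```python
from typing import Callable, Dict, List, Optional, Tuple

# A handler returns (next_agent_name, payload), or None when the pipeline is finished.
Step = Optional[Tuple[str, dict]]
Handler = Callable[[dict], Step]


def planner(payload: dict) -> Step:
    payload["plan"] = {"needs_text": True, "needs_images": False}
    return "text_retriever", payload


def text_retriever(payload: dict) -> Step:
    payload["text_results"] = ["stub passage"]  # a real agent would query the vector store
    return "reasoner", payload


def reasoner(payload: dict) -> Step:
    payload["reasoning"] = f"based on {len(payload['text_results'])} passage(s)"
    return "draft", payload


def draft(payload: dict) -> Step:
    payload["response"] = "Draft answer " + payload["reasoning"]
    return None  # nothing left to route


def run_pipeline(query: str, agents: Dict[str, Handler], start: str = "planner") -> Tuple[dict, List[str]]:
    """Route the payload from agent to agent and record the order of visits."""
    payload = {"query": query}
    trace: List[str] = []
    step: Step = (start, payload)
    while step is not None:
        name, payload = step
        trace.append(name)
        step = agents[name](payload)
    return payload, trace


AGENTS = {"planner": planner, "text_retriever": text_retriever, "reasoner": reasoner, "draft": draft}
result, trace = run_pipeline("What causes anemia?", AGENTS)
```

Here `trace` records the order in which the stubs ran, which makes the routing easy to inspect before any real model calls are wired in.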


In [None]:
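# Install required packages (CLIP is installed from GitHub; pypdf is needed by PyPDFLoader)
# !pip install langchain langchain-community langchain-openai langchain-huggingface langchain-text-splitters pymilvus transformers torch torchvision pillow sentence-transformers pypdf
# !pip install git+https://github.com/openai/CLIP.git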


In [None]:
# Core imports
import os                        # Python module for interacting with the operating system; needed to read environment variables (like API keys)
import logging                   # module to record events/errors (debugging and monitoring)
import json                      # parses and serializes JSON (the planning agent's LLM output is parsed with json.loads)
import asyncio                   # asynchronous programming (lets multiple tasks run concurrently; not heavily used here)
from typing import List, Dict, Any, Optional, Tuple, Union  # to specify what type of data functions expect/return
from dataclasses import dataclass  
from abc import ABC, abstractmethod   
import uuid                      # module to generate universally unique identifiers
from datetime import datetime

# Document processing
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader  # (PyPDFLoader: loads a PDF and extracts its text) (DirectoryLoader: loads all files matching a pattern from a directory; used to load all medical docs)
from langchain_text_splitters import RecursiveCharacterTextSplitter  # splits long text into chunks (hyperparameters: chunk_size, chunk_overlap)
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import Milvus  # integration with Milvus vector database

# LLM and agents
from langchain_openai import ChatOpenAI                                           # wrapper for OpenAI's GPT models (used for reasoning and response generation; incurs OpenAI API costs)
from langchain_core.prompts import ChatPromptTemplate                             # template for creating prompts (used by every agent below to build its prompt)
from langchain_core.output_parsers import StrOutputParser                         # extracts the string output from LLM responses. LLMs return complex objects, this just gets the text
from langchain_core.runnables import RunnablePassthrough                          # passes inputs through a chain unchanged
from langchain.agents import AgentExecutor, create_react_agent, Tool              # (AgentExecutor: runs agents & manages execution) (create_react_agent: creates a ReAct agent) (Tool: defines tools agents can use; needed to create autonomous agents)
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage  # message types used to structure conversations with LLMs (HumanMessage: from the user) (AIMessage: from the AI) (SystemMessage: instructions to the AI) (BaseMessage: parent class of all messages)

# Image processing
from PIL import Image               # to load images before processing them
import torch                        # pytorch
from torchvision import transforms  # image transformation utilities (to preprocess images for neural networks)
import clip                         # openAI's model that understands both images and text

# Vector store
from pymilvus import connections, Collection, utility   # (connections: Manages connections to milvus database) (Collection: represents collection—like a table—in milvus) (utility: helper functions, allows us to interact with milvus vector database)

# Set up logging
logging.basicConfig(
    level=logging.INFO,                                  # show INFO level and above (INFO, WARNING, ERROR, CRITICAL)
    format='%(asctime)s - %(levelname)s - %(message)s',  # timestamp, level, message (example output: 2025-10-02 14:30:15 - INFO - System initialized)
)
# Module-level logger. Each agent also creates its own named logger (for ex: "agent.planning_agent").
# The logger name only appears in the output if %(name)s is added to the format string above.
logger = logging.getLogger(__name__)

# Configuration
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "your-api-key-here")
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"
TEXT_COLLECTION = "medical_text"    # names for Milvus collections, like table names (separate collections for text and images because their embedding dimensions differ)
IMAGE_COLLECTION = "medical_images"
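
Because `OPENAI_API_KEY` falls back to a placeholder string, a missing key only surfaces later as a failed API call. A small guard like the following could fail earlier; it is a sketch, and `require_env` is a hypothetical helper rather than part of the notebook or any library.

```python
import os


def require_env(name: str, placeholder: str = "your-api-key-here") -> str:
    """Return the value of an environment variable, or raise if it is unset or still the placeholder."""
    value = os.getenv(name, "")
    if not value or value == placeholder:
        raise RuntimeError(f"Environment variable {name} is not set")
    return value


# Demonstration with a throwaway variable name (not a real credential)
os.environ["DEMO_SERVICE_KEY"] = "abc123"
demo_key = require_env("DEMO_SERVICE_KEY")
```

Calling the helper once at the top of the notebook keeps configuration errors separate from agent errors.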


## Core Communication Protocols


In [None]:
@dataclass
class A2AMessage:
    """Agent-to-Agent message for communication between agents"""
    message_id: str
    sender: str
    receiver: str
    message_type: str
    content: Dict[str, Any]
    timestamp: datetime
    priority: int = 1  # 1=normal, 2=high, 3=urgent
    
    def __post_init__(self):
        if not self.message_id:
            self.message_id = str(uuid.uuid4())
        if not self.timestamp:
            self.timestamp = datetime.now()

@dataclass
class MCPRequest:
    """Model Context Protocol request"""
    request_id: str
    model: str
    prompt: str
    context: Dict[str, Any]
    parameters: Optional[Dict[str, Any]] = None  # filled with {} in __post_init__
    
    def __post_init__(self):
        if not self.request_id:
            self.request_id = str(uuid.uuid4())
        if not self.parameters:
            self.parameters = {}

@dataclass
class MCPResponse:
    """Model Context Protocol response"""
    request_id: str
    result: Any
    metadata: Optional[Dict[str, Any]] = None  # filled with {} in __post_init__
    success: bool = True
    error_message: Optional[str] = None
    
    def __post_init__(self):
        if not self.metadata:
            self.metadata = {}
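
The dataclasses above fill in IDs, timestamps, and empty dicts inside `__post_init__`, which requires callers to pass empty placeholders. An alternative, sketched here with a hypothetical `Envelope` class rather than the notebook's own types, is `dataclasses.field(default_factory=...)`, which generates a fresh value per instance.

```python
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict


@dataclass
class Envelope:
    """Hypothetical message type illustrating per-instance defaults."""
    sender: str
    receiver: str
    message_type: str
    content: Dict[str, Any] = field(default_factory=dict)               # new dict for each message
    message_id: str = field(default_factory=lambda: str(uuid.uuid4()))  # unique ID for each message
    timestamp: datetime = field(default_factory=datetime.now)           # creation time
    priority: int = 1                                                   # 1=normal, 2=high, 3=urgent


first = Envelope("planning_agent", "agent_manager", "query_analysis")
second = Envelope("text_retriever", "agent_manager", "text_results")
```

With this pattern the caller never passes `message_id=""`, and two messages cannot accidentally share the same `content` dict.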


## Base Agent Class and Vector Store


In [None]:
class BaseAgent(ABC):
    """Base class for all agents in the system"""
    
    def __init__(self, name: str, agent_type: str):
        self.name = name
        self.agent_type = agent_type
        self.mailbox: List[A2AMessage] = []
        self.llm = ChatOpenAI(model="gpt-4o", temperature=0.1)
        self.logger = logging.getLogger(f"agent.{name}")
        
    def send_message(self, receiver: str, message_type: str, content: Dict[str, Any], 
                    priority: int = 1, manager=None):
        """Send a message to another agent"""
        message = A2AMessage(
            message_id="",
            sender=self.name,
            receiver=receiver,
            message_type=message_type,
            content=content,
            timestamp=datetime.now(),
            priority=priority
        )
        if manager:
            manager.deliver_message(message)
        return message
    
    def receive_message(self, message: A2AMessage):
        """Receive a message from another agent"""
        self.mailbox.append(message)
        self.logger.info(f"Received message from {message.sender}: {message.message_type}")
    
    def process_messages(self):
        """Process all messages in the mailbox"""
        while self.mailbox:
            message = self.mailbox.pop(0)
            self.handle_message(message)
    
    @abstractmethod
    def handle_message(self, message: A2AMessage):
        """Handle a specific message - must be implemented by subclasses"""
        pass
    
    def make_mcp_request(self, model: str, prompt: str, context: Dict[str, Any], 
                        parameters: Optional[Dict[str, Any]] = None) -> MCPRequest:
        """Create an MCP request"""
        return MCPRequest(
            request_id="",
            model=model,
            prompt=prompt,
            context=context,
            parameters=parameters or {}
        )

class VectorStoreManager:
    """Manages Milvus vector store operations"""
    
    def __init__(self, host: str = MILVUS_HOST, port: str = MILVUS_PORT):
        self.host = host
        self.port = port
        self.connection_args = {"host": host, "port": port}
        self.text_collection = None
        self.image_collection = None
        self.logger = logging.getLogger("vector_store")
        
    def connect(self):
        """Connect to Milvus"""
        try:
            connections.connect("default", host=self.host, port=self.port)
            self.logger.info(f"Connected to Milvus at {self.host}:{self.port}")
            return True
        except Exception as e:
            self.logger.error(f"Failed to connect to Milvus: {e}")
            return False
    
    def create_collections(self, text_embedding_dim: int = 384, image_embedding_dim: int = 512):
        """Create text and image collections"""
        if not self.connect():
            return False
            
        try:
            # Create text collection
            text_schema = {
                "fields": [
                    {"name": "id", "type": "varchar", "is_primary": True, "max_length": 100},
                    {"name": "embedding", "type": "float_vector", "dim": text_embedding_dim},
                    {"name": "text", "type": "varchar", "max_length": 10000},
                    {"name": "metadata", "type": "json"}
                ]
            }
            
            if not utility.has_collection(TEXT_COLLECTION):
                from pymilvus import CollectionSchema, FieldSchema, DataType
                
                # Define fields
                id_field = FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=100)
                embedding_field = FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=text_embedding_dim)
                text_field = FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=10000)
                metadata_field = FieldSchema(name="metadata", dtype=DataType.JSON)
                
                # Create schema
                schema = CollectionSchema(
                    fields=[id_field, embedding_field, text_field, metadata_field],
                    description="Medical text collection"
                )
                
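                # Create collection
                self.text_collection = Collection(name=TEXT_COLLECTION, schema=schema)
                # Milvus needs an index on the vector field before load() and search() work.
                # IVF_FLAT matches the "nprobe" search parameter; nlist=128 is an untuned starting value.
                self.text_collection.create_index(
                    field_name="embedding",
                    index_params={"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 128}}
                )
                self.logger.info(f"Created text collection: {TEXT_COLLECTION}")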
            else:
                self.text_collection = Collection(TEXT_COLLECTION)
                self.logger.info(f"Connected to existing text collection: {TEXT_COLLECTION}")
            
            # Create image collection (similar process)
            if not utility.has_collection(IMAGE_COLLECTION):
                from pymilvus import CollectionSchema, FieldSchema, DataType
                
                # Define fields for images
                id_field = FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=100)
                embedding_field = FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=image_embedding_dim)
                image_path_field = FieldSchema(name="image_path", dtype=DataType.VARCHAR, max_length=500)
                metadata_field = FieldSchema(name="metadata", dtype=DataType.JSON)
                
                # Create schema
                schema = CollectionSchema(
                    fields=[id_field, embedding_field, image_path_field, metadata_field],
                    description="Medical image collection"
                )
                
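                # Create collection
                self.image_collection = Collection(name=IMAGE_COLLECTION, schema=schema)
                # Milvus needs an index on the vector field before load() and search() work.
                # IVF_FLAT matches the "nprobe" search parameter; nlist=128 is an untuned starting value.
                self.image_collection.create_index(
                    field_name="embedding",
                    index_params={"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 128}}
                )
                self.logger.info(f"Created image collection: {IMAGE_COLLECTION}")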
            else:
                self.image_collection = Collection(IMAGE_COLLECTION)
                self.logger.info(f"Connected to existing image collection: {IMAGE_COLLECTION}")
            
            return True
            
        except Exception as e:
            self.logger.error(f"Failed to create collections: {e}")
            return False
    
    def insert_texts(self, texts: List[str], embeddings: List[List[float]], 
                    metadata: List[Dict[str, Any]]):
        """Insert text documents into the collection"""
        if not self.text_collection:
            self.logger.error("Text collection not initialized")
            return False
            
        try:
            ids = [str(uuid.uuid4()) for _ in texts]
            data = [ids, embeddings, texts, metadata]
            self.text_collection.insert(data)
            self.text_collection.flush()
            self.logger.info(f"Inserted {len(texts)} text documents")
            return True
        except Exception as e:
            self.logger.error(f"Failed to insert texts: {e}")
            return False
    
    def insert_images(self, image_paths: List[str], embeddings: List[List[float]], 
                     metadata: List[Dict[str, Any]]):
        """Insert image documents into the collection"""
        if not self.image_collection:
            self.logger.error("Image collection not initialized")
            return False
            
        try:
            ids = [str(uuid.uuid4()) for _ in image_paths]
            data = [ids, embeddings, image_paths, metadata]
            self.image_collection.insert(data)
            self.image_collection.flush()
            self.logger.info(f"Inserted {len(image_paths)} image documents")
            return True
        except Exception as e:
            self.logger.error(f"Failed to insert images: {e}")
            return False
    
    def search_texts(self, query_embedding: List[float], top_k: int = 5) -> List[Dict[str, Any]]:
        """Search for similar text documents"""
        if not self.text_collection:
            return []
            
        try:
            self.text_collection.load()
            results = self.text_collection.search(
                data=[query_embedding],
                anns_field="embedding",
                param={"metric_type": "L2", "params": {"nprobe": 10}},
                limit=top_k,
                output_fields=["text", "metadata"]
            )
            
            hits = []
            for hit in results[0]:
                hits.append({
                    "id": hit.id,
                    "score": hit.distance,
                    "text": hit.entity.get("text"),
                    "metadata": hit.entity.get("metadata")
                })
            return hits
        except Exception as e:
            self.logger.error(f"Failed to search texts: {e}")
            return []
    
    def search_images(self, query_embedding: List[float], top_k: int = 5) -> List[Dict[str, Any]]:
        """Search for similar image documents"""
        if not self.image_collection:
            return []
            
        try:
            self.image_collection.load()
            results = self.image_collection.search(
                data=[query_embedding],
                anns_field="embedding",
                param={"metric_type": "L2", "params": {"nprobe": 10}},
                limit=top_k,
                output_fields=["image_path", "metadata"]
            )
            
            hits = []
            for hit in results[0]:
                hits.append({
                    "id": hit.id,
                    "score": hit.distance,
                    "image_path": hit.entity.get("image_path"),
                    "metadata": hit.entity.get("metadata")
                })
            return hits
        except Exception as e:
            self.logger.error(f"Failed to search images: {e}")
            return []
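
`A2AMessage` carries a `priority` field, but `BaseAgent.process_messages` drains the mailbox strictly in arrival order. If urgent messages should be handled first, one option is a heap-backed mailbox. The sketch below uses only the standard library; `PriorityMailbox` is a hypothetical class, not part of the notebook.

```python
import heapq
import itertools
from typing import Any, List, Tuple


class PriorityMailbox:
    """Returns higher-priority messages first; ties keep arrival order."""

    def __init__(self) -> None:
        self._heap: List[Tuple[int, int, Any]] = []
        self._counter = itertools.count()  # tie-breaker that preserves arrival order

    def put(self, message: Any, priority: int = 1) -> None:
        # heapq pops the smallest item, so the priority is negated (3=urgent comes out first)
        heapq.heappush(self._heap, (-priority, next(self._counter), message))

    def get(self) -> Any:
        return heapq.heappop(self._heap)[2]

    def __len__(self) -> int:
        return len(self._heap)


mailbox = PriorityMailbox()
mailbox.put("routine lookup", priority=1)
mailbox.put("urgent consult", priority=3)
mailbox.put("second routine lookup", priority=1)
mailbox.put("high-priority follow-up", priority=2)
order = [mailbox.get() for _ in range(len(mailbox))]
```

The counter matters: without it, two messages with equal priority would be compared directly, which fails for objects that do not define an ordering.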


## Embedding Models


In [None]:
class TextEmbedder:
    """Handles text embedding. The default model is the general-purpose all-MiniLM-L6-v2 (384 dimensions); pass a medical-domain model name to replace it."""
    
    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        self.model_name = model_name
        self.embedder = HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'}
        )
        self.logger = logging.getLogger("text_embedder")
        
    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of texts"""
        try:
            embeddings = self.embedder.embed_documents(texts)
            self.logger.info(f"Embedded {len(texts)} texts")
            return embeddings
        except Exception as e:
            self.logger.error(f"Failed to embed texts: {e}")
            return []
    
    def embed_query(self, query: str) -> List[float]:
        """Embed a single query"""
        try:
            embedding = self.embedder.embed_query(query)
            return embedding
        except Exception as e:
            self.logger.error(f"Failed to embed query: {e}")
            return []

class ImageEmbedder:
    """Handles image embedding using CLIP"""
    
    def __init__(self, model_name: str = "ViT-B/32"):
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = clip.load(model_name, device=self.device)
        self.logger = logging.getLogger("image_embedder")
        
    def embed_images(self, image_paths: List[str]) -> List[List[float]]:
        """Embed a list of images from file paths"""
        embeddings = []
        for image_path in image_paths:
            try:
                image = Image.open(image_path).convert('RGB')
                image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
                
                with torch.no_grad():
                    image_features = self.model.encode_image(image_tensor)
                    embedding = image_features.cpu().numpy().flatten().tolist()
                    embeddings.append(embedding)
                    
            except Exception as e:
                self.logger.error(f"Failed to embed image {image_path}: {e}")
                # Add zero embedding as fallback
                embeddings.append([0.0] * 512)  # CLIP ViT-B/32 has 512 dims
                
        self.logger.info(f"Embedded {len(image_paths)} images")
        return embeddings
    
    def embed_image(self, image_path: str) -> List[float]:
        """Embed a single image"""
        try:
            image = Image.open(image_path).convert('RGB')
            image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
            
            with torch.no_grad():
                image_features = self.model.encode_image(image_tensor)
                embedding = image_features.cpu().numpy().flatten().tolist()
                return embedding
                
        except Exception as e:
            self.logger.error(f"Failed to embed image {image_path}: {e}")
            return [0.0] * 512
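
The vector searches in this notebook rank by L2 distance, whereas CLIP-style embeddings are commonly compared by cosine similarity. One possible way to reconcile the two, sketched here with hypothetical helper names, is to scale every vector to unit length before inserting or querying: for unit vectors, squared L2 distance equals 2 minus twice the cosine similarity, so both metrics give the same ranking.

```python
import math
from typing import List


def l2_normalize(vector: List[float]) -> List[float]:
    """Scale a vector to unit length; a zero vector is returned unchanged."""
    norm = math.sqrt(sum(x * x for x in vector))
    if norm == 0.0:
        return list(vector)
    return [x / norm for x in vector]


def squared_l2(a: List[float], b: List[float]) -> float:
    """Squared Euclidean distance between two vectors of equal length."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def cosine(a: List[float], b: List[float]) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


a, b = [3.0, 4.0], [4.0, 3.0]
unit_a, unit_b = l2_normalize(a), l2_normalize(b)
distance = squared_l2(unit_a, unit_b)
similarity = cosine(a, b)
```

For this pair, `distance` and `2 - 2 * similarity` agree up to floating-point error. Note that the all-zero fallback embeddings used on failure cannot be normalized and are better excluded from the collection than stored.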


## Individual Agents


In [None]:
class PlanningAgent(BaseAgent):
    """Analyzes queries to determine what information is needed"""
    
    def __init__(self):
        super().__init__("planning_agent", "planner")
        self.planning_prompt = ChatPromptTemplate.from_template("""
        You are a medical planning agent. Analyze the following medical question and determine what types of information are needed to provide a comprehensive answer.
        
        Question: {question}
        
        Consider:
        1. Do we need text information from textbooks/research papers?
        2. Do we need visual information from medical images/diagrams?
        3. What specific medical concepts or conditions are involved?
        4. What level of detail is required?
        
        Respond with a JSON object containing:
        - "needs_text": boolean
        - "needs_images": boolean
        - "medical_concepts": list of key concepts
        - "query_type": "diagnosis", "treatment", "symptoms", "anatomy", "pathology", etc.
        - "priority": "high", "medium", "low"
        - "reasoning": explanation of your analysis
        """)
        
    def handle_message(self, message: A2AMessage):
        if message.message_type == "analyze_query":
            self.analyze_query(message)
    
    def analyze_query(self, message: A2AMessage):
        """Analyze a medical query and determine information needs"""
        query = message.content.get("query", "")
        
        try:
            # Use LLM to analyze the query
            chain = self.planning_prompt | self.llm | StrOutputParser()
            analysis = chain.invoke({"question": query})
            
            # Parse the JSON response (json is already imported at the top of the notebook)
            plan = json.loads(analysis)
            
            # Send analysis to agent manager
            response_content = {
                "original_query": query,
                "analysis": plan,
                "timestamp": datetime.now().isoformat()
            }
            
            self.send_message(
                receiver="agent_manager",
                message_type="query_analysis",
                content=response_content,
                manager=message.content.get("manager")
            )
            
            self.logger.info(f"Analyzed query: {query[:50]}...")
            
        except Exception as e:
            self.logger.error(f"Failed to analyze query: {e}")
            # Send error response
            self.send_message(
                receiver="agent_manager",
                message_type="query_analysis_error",
                content={"error": str(e), "original_query": query},
                manager=message.content.get("manager")
            )

class TextRetrieverAgent(BaseAgent):
    """Retrieves relevant text information from medical documents"""
    
    def __init__(self, vector_store_manager: VectorStoreManager, text_embedder: TextEmbedder):
        super().__init__("text_retriever", "retriever")
        self.vector_store = vector_store_manager
        self.embedder = text_embedder
        
    def handle_message(self, message: A2AMessage):
        if message.message_type == "retrieve_text":
            self.retrieve_text(message)
    
    def retrieve_text(self, message: A2AMessage):
        """Retrieve relevant text documents"""
        query = message.content.get("query", "")
        top_k = message.content.get("top_k", 5)
        
        try:
            # Embed the query
            query_embedding = self.embedder.embed_query(query)
            
            # Search the vector store
            results = self.vector_store.search_texts(query_embedding, top_k)
            
            # Format results
            formatted_results = []
            for result in results:
                formatted_results.append({
                    "id": result["id"],
                    "text": result["text"],
                    "score": result["score"],
                    "metadata": result["metadata"]
                })
            
            # Send results back
            self.send_message(
                receiver="agent_manager",
                message_type="text_results",
                content={
                    "query": query,
                    "results": formatted_results,
                    "count": len(formatted_results)
                },
                manager=message.content.get("manager")
            )
            
            self.logger.info(f"Retrieved {len(formatted_results)} text documents for query: {query[:50]}...")
            
        except Exception as e:
            self.logger.error(f"Failed to retrieve text: {e}")
            self.send_message(
                receiver="agent_manager",
                message_type="text_retrieval_error",
                content={"error": str(e), "query": query},
                manager=message.content.get("manager")
            )

class ImageRetrieverAgent(BaseAgent):
    """Retrieves relevant image information from medical documents"""
    
    def __init__(self, vector_store_manager: VectorStoreManager, image_embedder: ImageEmbedder):
        super().__init__("image_retriever", "retriever")
        self.vector_store = vector_store_manager
        self.embedder = image_embedder
        
    def handle_message(self, message: A2AMessage):
        if message.message_type == "retrieve_images":
            self.retrieve_images(message)
    
    def retrieve_images(self, message: A2AMessage):
        """Retrieve relevant image documents"""
        query = message.content.get("query", "")
        top_k = message.content.get("top_k", 5)
        
        try:
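            # ImageEmbedder has no embed_query method, so the text query is encoded with
            # CLIP's text encoder. CLIP maps text and images into the same embedding space,
            # so the result can be compared with the stored image embeddings.
            tokens = clip.tokenize([query], truncate=True).to(self.embedder.device)
            with torch.no_grad():
                text_features = self.embedder.model.encode_text(tokens)
            query_embedding = text_features.cpu().numpy().flatten().tolist()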
            
            # Search the vector store
            results = self.vector_store.search_images(query_embedding, top_k)
            
            # Format results
            formatted_results = []
            for result in results:
                formatted_results.append({
                    "id": result["id"],
                    "image_path": result["image_path"],
                    "score": result["score"],
                    "metadata": result["metadata"]
                })
            
            # Send results back
            self.send_message(
                receiver="agent_manager",
                message_type="image_results",
                content={
                    "query": query,
                    "results": formatted_results,
                    "count": len(formatted_results)
                },
                manager=message.content.get("manager")
            )
            
            self.logger.info(f"Retrieved {len(formatted_results)} image documents for query: {query[:50]}...")
            
        except Exception as e:
            self.logger.error(f"Failed to retrieve images: {e}")
            self.send_message(
                receiver="agent_manager",
                message_type="image_retrieval_error",
                content={"error": str(e), "query": query},
                manager=message.content.get("manager")
            )
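
The `json.loads` call in `PlanningAgent.analyze_query` can fail when the model wraps its JSON in a Markdown fence or adds a sentence before or after it. A more tolerant parser might look like the sketch below; `parse_llm_json` is a hypothetical helper, and it assumes the reply contains exactly one top-level JSON object.

```python
import json
from typing import Any, Dict


def parse_llm_json(raw: str) -> Dict[str, Any]:
    """Extract and parse the JSON object embedded in a model reply."""
    start = raw.find("{")   # first opening brace
    end = raw.rfind("}")    # last closing brace
    if start == -1 or end <= start:
        raise ValueError("no JSON object found in model output")
    return json.loads(raw[start:end + 1])


fence = "`" * 3  # built at runtime to avoid a literal fence inside this example
fenced_reply = f'{fence}json\n{{"needs_text": true, "needs_images": false}}\n{fence}'
plan = parse_llm_json(fenced_reply)
chatty_plan = parse_llm_json('Here is the plan: {"priority": "high"} Hope this helps.')
```

Replies with no braces at all raise `ValueError`, which the agent's existing `except` block would turn into a `query_analysis_error` message.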


In [None]:
class ReasonerAgent(BaseAgent):
    """Synthesizes information from all sources to reason about the medical question"""
    
    def __init__(self):
        super().__init__("reasoner_agent", "reasoner")
        self.reasoning_prompt = ChatPromptTemplate.from_template("""
        You are a medical reasoning agent. Analyze the following information and provide a comprehensive medical reasoning.
        
        Original Question: {question}
        
        Text Information:
        {text_info}
        
        Image Information:
        {image_info}
        
        Query Analysis:
        {query_analysis}
        
        Based on the provided information, provide:
        1. Key medical findings from the text sources
        2. Relevant visual information from images (if any)
        3. Clinical reasoning and connections
        4. Potential limitations or missing information
        5. Confidence level in the analysis
        
        Be thorough but concise. Focus on medical accuracy and clinical relevance.
        """)
        
    def handle_message(self, message: A2AMessage):
        if message.message_type == "reason_about_query":
            self.reason_about_query(message)
    
    def reason_about_query(self, message: A2AMessage):
        """Synthesize information and provide medical reasoning"""
        content = message.content
        
        try:
            # Format text information
            text_info = ""
            if "text_results" in content:
                text_results = content["text_results"]
                text_info = "\n\n".join([
                    f"Source {i+1}: {result['text'][:500]}..." 
                    for i, result in enumerate(text_results)
                ])
            
            # Format image information
            image_info = ""
            if "image_results" in content:
                image_results = content["image_results"]
                image_info = "\n\n".join([
                    f"Image {i+1}: {result['image_path']} (L2 distance: {result['score']:.3f}, lower is closer)" 
                    for i, result in enumerate(image_results)
                ])
            
            # Create reasoning chain
            chain = self.reasoning_prompt | self.llm | StrOutputParser()
            
            reasoning = chain.invoke({
                "question": content.get("original_query", ""),
                "text_info": text_info,
                "image_info": image_info,
                "query_analysis": content.get("query_analysis", {})
            })
            
            # Send reasoning to draft agent
            self.send_message(
                receiver="draft_agent",
                message_type="generate_response",
                content={
                    "original_query": content.get("original_query", ""),
                    "reasoning": reasoning,
                    "text_results": content.get("text_results", []),
                    "image_results": content.get("image_results", []),
                    "query_analysis": content.get("query_analysis", {})
                },
                manager=content.get("manager")
            )
            
            self.logger.info("Completed medical reasoning analysis")
            
        except Exception as e:
            self.logger.error(f"Failed to reason about query: {e}")
            self.send_message(
                receiver="draft_agent",
                message_type="reasoning_error",
                content={"error": str(e), "original_query": content.get("original_query", "")},
                manager=content.get("manager")
            )

class DraftAgent(BaseAgent):
    """Generates final responses for doctors"""
    
    def __init__(self):
        super().__init__("draft_agent", "draft")
        self.draft_prompt = ChatPromptTemplate.from_template("""
        You are a medical AI assistant providing information to doctors. Generate a comprehensive, accurate, and clinically relevant response.
        
        Original Question: {question}
        
        Medical Reasoning:
        {reasoning}
        
        Supporting Evidence:
        Text Sources: {text_evidence}
        Image Sources: {image_evidence}
        
        Instructions:
        1. Provide a clear, structured answer to the medical question
        2. Include relevant clinical details and context
        3. Cite specific information from the sources when appropriate
        4. Note any limitations or areas where more information might be needed
        5. Use professional medical language appropriate for healthcare providers
        6. Include disclaimers about the limitations of AI-generated medical information
        
        Format your response as a professional medical consultation note.
        """)
        
    def handle_message(self, message: A2AMessage):
        if message.message_type == "generate_response":
            self.generate_response(message)
        elif message.message_type == "reasoning_error":
            self.handle_reasoning_error(message)
    
    def generate_response(self, message: A2AMessage):
        """Generate the final medical response"""
        content = message.content
        
        try:
            # Format text evidence
            text_evidence = ""
            if content.get("text_results"):
                text_evidence = "\n\n".join([
                    f"• {result['text'][:200]}... (Source: {result.get('metadata', {}).get('source', 'Unknown')})" 
                    for result in content["text_results"][:3]  # Limit to top 3
                ])
            
            # Format image evidence
            image_evidence = ""
            if content.get("image_results"):
                image_evidence = "\n".join([
                    f"• {result['image_path']} (Relevance: {result['score']:.3f})" 
                    for result in content["image_results"][:3]  # Limit to top 3
                ])
            
            # Generate response
            chain = self.draft_prompt | self.llm | StrOutputParser()
            
            response = chain.invoke({
                "question": content.get("original_query", ""),
                "reasoning": content.get("reasoning", ""),
                "text_evidence": text_evidence,
                "image_evidence": image_evidence
            })
            
            # Send final response to agent manager
            self.send_message(
                receiver="agent_manager",
                message_type="final_response",
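                content={
                    # The manager looks up query_id to match this response to its query
                    "query_id": content.get("query_id"),
                    "original_query": content.get("original_query", ""),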
                    "response": response,
                    "timestamp": datetime.now().isoformat(),
                    "sources_used": {
                        "text_count": len(content.get("text_results", [])),
                        "image_count": len(content.get("image_results", []))
                    }
                },
                manager=content.get("manager")
            )
            
            self.logger.info("Generated final medical response")
            
        except Exception as e:
            self.logger.error(f"Failed to generate response: {e}")
            self.send_message(
                receiver="agent_manager",
                message_type="response_error",
                content={"error": str(e), "original_query": content.get("original_query", ""), "query_id": content.get("query_id")},
                manager=content.get("manager")
            )
    
    def handle_reasoning_error(self, message: A2AMessage):
        """Handle cases where reasoning failed"""
        content = message.content
        
        # Build the message line by line so the source indentation does not leak into the response
        error_response = (
            "I apologize, but I encountered an error while processing your medical question: "
            f"\"{content.get('original_query', '')}\"\n\n"
            f"Error: {content.get('error', 'Unknown error')}\n\n"
            "Please try rephrasing your question or contact support if the issue persists."
        )
        
        self.send_message(
            receiver="agent_manager",
            message_type="final_response",
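            content={
                # The manager looks up query_id to match this response to its query
                "query_id": content.get("query_id"),
                "original_query": content.get("original_query", ""),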
                "response": error_response,
                "timestamp": datetime.now().isoformat(),
                "error": True
            },
            manager=content.get("manager")
        )

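The evidence formatting inside `generate_response` can be pulled out into a small helper, which makes the truncation and the top-3 limit easy to check in isolation. This is a minimal sketch, assuming each text result is a dict with a `text` key and an optional `metadata` dict, as the code above implies. The helper name `format_text_evidence` and its parameters are hypothetical, not part of the notebook.

```python
from typing import Any, Dict, List


def format_text_evidence(results: List[Dict[str, Any]], limit: int = 3, max_chars: int = 200) -> str:
    """Render the top `limit` text results as bullet lines for the draft prompt."""
    lines = []
    for result in results[:limit]:
        text = result.get("text", "")
        # Add an ellipsis only when the snippet was actually truncated
        snippet = text[:max_chars] + ("..." if len(text) > max_chars else "")
        source = result.get("metadata", {}).get("source", "Unknown")
        lines.append(f"• {snippet} (Source: {source})")
    # Real newlines separate the bullets
    return "\n\n".join(lines)


# Hypothetical sample data shaped like the retriever output
sample = [
    {"text": "Pneumonia commonly presents with cough and fever.", "metadata": {"source": "textbook.pdf"}},
    {"text": "x" * 250},
]
print(format_text_evidence(sample))
```

Unlike the inline version, this variant appends `...` only when the text exceeds `max_chars`, so short snippets are not marked as truncated.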

## Agent Manager


In [None]:
class AgentManager:
    """Manages all agents and coordinates the medical knowledge system"""
    
    def __init__(self):
        self.logger = logging.getLogger("agent_manager")
        
        # Initialize components
        self.vector_store_manager = VectorStoreManager()
        self.text_embedder = TextEmbedder()
        self.image_embedder = ImageEmbedder()
        
        # Initialize agents
        self.planning_agent = PlanningAgent()
        self.text_retriever = TextRetrieverAgent(self.vector_store_manager, self.text_embedder)
        self.image_retriever = ImageRetrieverAgent(self.vector_store_manager, self.image_embedder)
        self.reasoner_agent = ReasonerAgent()
        self.draft_agent = DraftAgent()
        
        # Agent registry
        self.agents = {
            "planning_agent": self.planning_agent,
            "text_retriever": self.text_retriever,
            "image_retriever": self.image_retriever,
            "reasoner_agent": self.reasoner_agent,
            "draft_agent": self.draft_agent
        }
        
        # State tracking
        self.active_queries = {}
        self.query_results = {}
        
        self.logger.info("Agent Manager initialized with all agents")
    
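    def deliver_message(self, message: A2AMessage):
        """Deliver a message to the target agent, or handle it here if it is addressed to the manager"""
        if message.receiver == "agent_manager":
            # The manager is not in the agent registry, so route its own messages to handle_message
            self.handle_message(message)
        elif message.receiver in self.agents:
            self.agents[message.receiver].receive_message(message)
        else:
            self.logger.warning(f"Unknown agent: {message.receiver}")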
    
    def process_query(self, query: str) -> Dict[str, Any]:
        """Process a medical query through the agent system"""
        query_id = str(uuid.uuid4())
        self.logger.info(f"Processing query {query_id}: {query[:50]}...")
        
        # Initialize query state
        self.active_queries[query_id] = {
            "query": query,
            "status": "processing",
            "start_time": datetime.now(),
            "results": {}
        }
        
        try:
            # Step 1: Send query to planning agent
            self.planning_agent.send_message(
                receiver="planning_agent",
                message_type="analyze_query",
                content={
                    "query": query,
                    "query_id": query_id,
                    "manager": self
                },
                manager=self
            )
            
            # Process messages until we get a final response.
            # The bound is an iteration count, not a wall-clock timeout.
            max_iterations = 20
            
            for _ in range(max_iterations):
                # Process all agent mailboxes
                for agent in self.agents.values():
                    agent.process_messages()
                
                # Stop as soon as a final response has been stored
                if query_id in self.query_results:
                    break
            else:
                # Runs only when the loop finished without hitting break
                self.logger.warning(f"Query {query_id} timed out after {max_iterations} iterations")
            
            # Return results
            if query_id in self.query_results:
                # Pop the entry so query_results does not grow with every query
                result = self.query_results.pop(query_id)
                result["query_id"] = query_id
                result["processing_time"] = (datetime.now() - self.active_queries[query_id]["start_time"]).total_seconds()
                return result
            else:
                return {
                    "query_id": query_id,
                    "error": "Query processing timed out",
                    "status": "error"
                }
                
        except Exception as e:
            self.logger.error(f"Error processing query {query_id}: {e}")
            return {
                "query_id": query_id,
                "error": str(e),
                "status": "error"
            }
        finally:
            # Clean up
            if query_id in self.active_queries:
                del self.active_queries[query_id]
    
    def handle_message(self, message: A2AMessage):
        """Handle messages from agents"""
        if message.receiver == "agent_manager":
            if message.message_type == "query_analysis":
                self.handle_query_analysis(message)
            elif message.message_type == "text_results":
                self.handle_text_results(message)
            elif message.message_type == "image_results":
                self.handle_image_results(message)
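            elif message.message_type == "final_response":
                self.handle_final_response(message)
            elif message.message_type == "response_error":
                # Sent by the draft agent when response generation itself fails
                self.logger.error(f"Response generation failed: {message.content.get('error', 'Unknown error')}")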
            else:
                self.logger.warning(f"Unknown message type: {message.message_type}")
    
    def handle_query_analysis(self, message: A2AMessage):
        """Handle query analysis from planning agent"""
        content = message.content
        query_id = content.get("query_id")
        analysis = content.get("analysis", {})
        
        if query_id not in self.active_queries:
            return
            
        self.active_queries[query_id]["results"]["analysis"] = analysis
        
        # Determine what retrievers to activate
        if analysis.get("needs_text", False):
            self.text_retriever.send_message(
                receiver="text_retriever",
                message_type="retrieve_text",
                content={
                    "query": self.active_queries[query_id]["query"],
                    "query_id": query_id,
                    "top_k": 5,
                    "manager": self
                },
                manager=self
            )
        
        if analysis.get("needs_images", False):
            self.image_retriever.send_message(
                receiver="image_retriever",
                message_type="retrieve_images",
                content={
                    "query": self.active_queries[query_id]["query"],
                    "query_id": query_id,
                    "top_k": 3,
                    "manager": self
                },
                manager=self
            )
        
        # If no retrievers needed, send directly to reasoner
        if not analysis.get("needs_text", False) and not analysis.get("needs_images", False):
            self.reasoner_agent.send_message(
                receiver="reasoner_agent",
                message_type="reason_about_query",
                content={
                    "original_query": self.active_queries[query_id]["query"],
                    "query_id": query_id,
                    "query_analysis": analysis,
                    "text_results": [],
                    "image_results": [],
                    "manager": self
                },
                manager=self
            )
    
    def handle_text_results(self, message: A2AMessage):
        """Handle text retrieval results"""
        content = message.content
        query_id = content.get("query_id")
        
        if query_id not in self.active_queries:
            return
            
        self.active_queries[query_id]["results"]["text_results"] = content.get("results", [])
        self.check_and_trigger_reasoner(query_id)
    
    def handle_image_results(self, message: A2AMessage):
        """Handle image retrieval results"""
        content = message.content
        query_id = content.get("query_id")
        
        if query_id not in self.active_queries:
            return
            
        self.active_queries[query_id]["results"]["image_results"] = content.get("results", [])
        self.check_and_trigger_reasoner(query_id)
    
    def check_and_trigger_reasoner(self, query_id: str):
        """Check if we have all needed results and trigger reasoner"""
        if query_id not in self.active_queries:
            return
            
        query_data = self.active_queries[query_id]
        results = query_data["results"]
        analysis = results.get("analysis", {})
        
        # Check if we have all needed data
        needs_text = analysis.get("needs_text", False)
        needs_images = analysis.get("needs_images", False)
        
        has_text = "text_results" in results
        has_images = "image_results" in results
        
        # Trigger reasoner if we have all needed data
        if (not needs_text or has_text) and (not needs_images or has_images):
            self.reasoner_agent.send_message(
                receiver="reasoner_agent",
                message_type="reason_about_query",
                content={
                    "original_query": query_data["query"],
                    "query_id": query_id,
                    "query_analysis": analysis,
                    "text_results": results.get("text_results", []),
                    "image_results": results.get("image_results", []),
                    "manager": self
                },
                manager=self
            )
    
    def handle_final_response(self, message: A2AMessage):
        """Handle final response from draft agent"""
        content = message.content
        query_id = content.get("query_id")
        
        if query_id not in self.active_queries:
            return
            
        # Store final result
        self.query_results[query_id] = {
            "query": content.get("original_query", ""),
            "response": content.get("response", ""),
            "timestamp": content.get("timestamp"),
            "sources_used": content.get("sources_used", {}),
            "status": "completed"
        }
        
        self.logger.info(f"Query {query_id} completed successfully")

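The manager's control flow is a bounded polling loop over agent mailboxes: each pass drains every mailbox, and the loop stops once a final result is recorded or the iteration budget runs out. The sketch below reproduces that pattern with toy classes. `Message`, `EchoAgent`, and `ToyManager` are illustrative stand-ins, not the notebook's `A2AMessage`, `BaseAgent`, or `AgentManager`, and the real `BaseAgent` API may differ.

```python
from collections import deque
from dataclasses import dataclass
from typing import Any, Deque, Dict, Optional


@dataclass
class Message:
    sender: str
    receiver: str
    message_type: str
    content: Dict[str, Any]


class EchoAgent:
    """Toy agent: replies to the manager with the upper-cased query."""

    def __init__(self, name: str, manager: "ToyManager"):
        self.name = name
        self.manager = manager
        self.mailbox: Deque[Message] = deque()

    def receive_message(self, message: Message) -> None:
        self.mailbox.append(message)

    def process_messages(self) -> None:
        # Drain the mailbox; each message produces one reply to the manager
        while self.mailbox:
            msg = self.mailbox.popleft()
            self.manager.deliver_message(Message(
                sender=self.name,
                receiver="manager",
                message_type="final_response",
                content={"query_id": msg.content["query_id"],
                         "response": msg.content["query"].upper()},
            ))


class ToyManager:
    def __init__(self):
        self.agents = {"echo": EchoAgent("echo", self)}
        self.results: Dict[str, str] = {}

    def deliver_message(self, message: Message) -> None:
        # Messages addressed to the manager are handled here, not forwarded
        if message.receiver == "manager":
            self.results[message.content["query_id"]] = message.content["response"]
        elif message.receiver in self.agents:
            self.agents[message.receiver].receive_message(message)

    def process_query(self, query_id: str, query: str, max_iterations: int = 20) -> Optional[str]:
        self.deliver_message(Message("manager", "echo", "analyze_query",
                                     {"query_id": query_id, "query": query}))
        for _ in range(max_iterations):
            for agent in self.agents.values():
                agent.process_messages()
            if query_id in self.results:
                return self.results.pop(query_id)
        return None  # iteration budget exhausted without a result


manager = ToyManager()
print(manager.process_query("q1", "what is sepsis?"))  # WHAT IS SEPSIS?
```

The key design point is that every reply carries the `query_id`, since that is the only way the manager can match a response to the query that produced it.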

## Data Ingestion and Main System


In [None]:
class DocumentProcessor:
    """Handles document loading, chunking, and processing for medical documents"""
    
    def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=["\n\n", "\n", ". ", " ", ""]
        )
        self.logger = logging.getLogger("document_processor")
    
    def process_pdf_directory(self, directory_path: str) -> Tuple[List[str], List[Dict[str, Any]]]:
        """Process all PDFs in a directory"""
        self.logger.info(f"Processing PDFs in directory: {directory_path}")
        
        # Load all PDFs
        loader = DirectoryLoader(
            directory_path,
            glob="**/*.pdf",
            loader_cls=PyPDFLoader
        )
        documents = loader.load()
        
        # Split into chunks
        chunks = self.text_splitter.split_documents(documents)
        
        # Extract texts and metadata
        texts = []
        metadata = []
        
        for chunk in chunks:
            texts.append(chunk.page_content)
            metadata.append({
                "source": chunk.metadata.get("source", "unknown"),
                "page": chunk.metadata.get("page", 0),
                "chunk_id": str(uuid.uuid4())
            })
        
        self.logger.info(f"Processed {len(texts)} text chunks from {len(documents)} documents")
        return texts, metadata
    
    def extract_images_from_pdfs(self, directory_path: str) -> Tuple[List[str], List[Dict[str, Any]]]:
        """Extract images from PDFs (placeholder: not implemented yet)"""
        # A real implementation would use a library such as PyMuPDF or pdf2image
        # to extract or render the images in each PDF and save them to disk.
        self.logger.info(f"Extracting images from PDFs in: {directory_path}")
        
        # Until then, return empty lists so ingestion skips the image step
        image_paths: List[str] = []
        image_metadata: List[Dict[str, Any]] = []
        
        self.logger.warning("Image extraction not implemented - returning empty results")
        return image_paths, image_metadata

class MedicalKnowledgeSystem:
    """Main system class that orchestrates the entire medical knowledge system"""
    
    def __init__(self):
        self.logger = logging.getLogger("medical_system")
        
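        # Initialize components. AgentManager already builds a vector store manager
        # and both embedders, so reuse those instances: ingestion and retrieval then
        # share one store, and each embedding model is loaded only once.
        self.agent_manager = AgentManager()
        self.vector_store_manager = self.agent_manager.vector_store_manager
        self.text_embedder = self.agent_manager.text_embedder
        self.image_embedder = self.agent_manager.image_embedder
        self.document_processor = DocumentProcessor()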
        
        # System state
        self.is_initialized = False
        self.data_ingested = False
        
        self.logger.info("Medical Knowledge System initialized")
    
    def setup_vector_store(self) -> bool:
        """Set up the Milvus vector store"""
        try:
            success = self.vector_store_manager.create_collections()
            if success:
                self.is_initialized = True
                self.logger.info("Vector store setup completed")
            else:
                self.logger.error("Failed to setup vector store")
            return success
        except Exception as e:
            self.logger.error(f"Error setting up vector store: {e}")
            return False
    
    def ingest_documents(self, data_directory: str) -> bool:
        """Ingest medical documents into the system"""
        if not self.is_initialized:
            self.logger.error("Vector store not initialized. Call setup_vector_store() first.")
            return False
        
        try:
            # Process text documents
            self.logger.info("Processing text documents...")
            texts, text_metadata = self.document_processor.process_pdf_directory(data_directory)
            
            if texts:
                # Embed texts
                text_embeddings = self.text_embedder.embed_texts(texts)
                
                # Insert into vector store
                success = self.vector_store_manager.insert_texts(texts, text_embeddings, text_metadata)
                if not success:
                    self.logger.error("Failed to insert text documents")
                    return False
                
                self.logger.info(f"Successfully ingested {len(texts)} text chunks")
            
            # Process images (if any)
            self.logger.info("Processing images...")
            image_paths, image_metadata = self.document_processor.extract_images_from_pdfs(data_directory)
            
            if image_paths:
                # Embed images
                image_embeddings = self.image_embedder.embed_images(image_paths)
                
                # Insert into vector store
                success = self.vector_store_manager.insert_images(image_paths, image_embeddings, image_metadata)
                if not success:
                    self.logger.error("Failed to insert image documents")
                    return False
                
                self.logger.info(f"Successfully ingested {len(image_paths)} images")
            
            self.data_ingested = True
            self.logger.info("Document ingestion completed successfully")
            return True
            
        except Exception as e:
            self.logger.error(f"Error ingesting documents: {e}")
            return False
    
    def ask_question(self, question: str) -> Dict[str, Any]:
        """Ask a medical question to the system"""
        if not self.is_initialized:
            return {"error": "System not initialized. Call setup_vector_store() first."}
        
        if not self.data_ingested:
            return {"error": "No data ingested. Call ingest_documents() first."}
        
        try:
            # Process the query through the agent system
            result = self.agent_manager.process_query(question)
            return result
            
        except Exception as e:
            self.logger.error(f"Error processing question: {e}")
            return {"error": str(e), "status": "error"}
    
    def get_system_status(self) -> Dict[str, Any]:
        """Get the current status of the system"""
        return {
            "initialized": self.is_initialized,
            "data_ingested": self.data_ingested,
            "vector_store_connected": self.vector_store_manager.connect(),
            "agents_ready": len(self.agent_manager.agents),
            "timestamp": datetime.now().isoformat()
        }

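`extract_images_from_pdfs` above is a placeholder, so no images reach the vector store. One way to fill it in is PyMuPDF, which the code comments already mention. The sketch below is an assumption about how that could look, not the notebook's implementation: it needs `pip install pymupdf`, and the function name, the `output_dir` default, and the `min_bytes` filter are illustrative choices. It writes each embedded image to disk and returns paths plus per-image metadata, in the same `(image_paths, image_metadata)` shape the placeholder returns.

```python
from pathlib import Path
from typing import Any, Dict, List, Tuple


def extract_images_with_pymupdf(
    directory_path: str,
    output_dir: str = "extracted_images",  # hypothetical default location
    min_bytes: int = 2048,                  # skip tiny images such as icons and rules
) -> Tuple[List[str], List[Dict[str, Any]]]:
    """Save images embedded in every PDF under `directory_path`; return paths and metadata."""
    pdf_files = sorted(Path(directory_path).glob("**/*.pdf"))
    if not pdf_files:
        return [], []

    import fitz  # PyMuPDF; imported here so the function can be defined without it installed

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    image_paths: List[str] = []
    image_metadata: List[Dict[str, Any]] = []

    for pdf_path in pdf_files:
        doc = fitz.open(str(pdf_path))
        try:
            for page_index, page in enumerate(doc):
                # Each entry describes one image on the page; its first item is the xref
                for img_index, img in enumerate(page.get_images(full=True)):
                    info = doc.extract_image(img[0])
                    if len(info["image"]) < min_bytes:
                        continue
                    # Note: PDFs with the same file name in different folders would collide here
                    name = f"{pdf_path.stem}_p{page_index}_{img_index}.{info['ext']}"
                    target = out_dir / name
                    target.write_bytes(info["image"])
                    image_paths.append(str(target))
                    image_metadata.append({
                        "source": str(pdf_path),
                        "page": page_index,
                        "width": info["width"],
                        "height": info["height"],
                    })
        finally:
            doc.close()

    return image_paths, image_metadata
```

The saved files can then be opened with Pillow for embedding, which is where the existing Pillow import fits in once the images have been pulled out of the PDFs.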

## Usage Examples and Testing


In [None]:
# Initialize the medical knowledge system
medical_system = MedicalKnowledgeSystem()

# Check system status
print("System Status:")
print(json.dumps(medical_system.get_system_status(), indent=2))


In [None]:
# Setup Milvus vector store
print("Setting up vector store...")
setup_success = medical_system.setup_vector_store()

if setup_success:
    print("✅ Vector store setup successful!")
else:
    print("❌ Vector store setup failed. Make sure Milvus is running.")
    print("To start Milvus, follow the standalone Docker instructions in the Milvus installation docs (the standalone_embed.sh script or Docker Compose).")


In [None]:
# Ingest medical documents
# Replace with your actual data directory path
data_directory = "./FINAL DATASET/"  # Update this path to your data directory

print(f"Ingesting documents from: {data_directory}")
ingestion_success = medical_system.ingest_documents(data_directory)

if ingestion_success:
    print("✅ Document ingestion successful!")
    print("System is ready to answer medical questions.")
else:
    print("❌ Document ingestion failed.")
    print("Please check that the data directory exists and contains PDF files.")


In [None]:
# Test the system with medical questions
test_questions = [
    "What are the symptoms of pneumonia?",
    "How is diabetes diagnosed?",
    "What are the treatment options for hypertension?",
    "Describe the anatomy of the heart",
    "What are the side effects of common antibiotics?"
]

print("Testing the medical knowledge system...")
print("=" * 50)

for i, question in enumerate(test_questions, 1):
    print(f"\nQuestion {i}: {question}")
    print("-" * 30)
    
    try:
        result = medical_system.ask_question(question)
        
        if "error" in result:
            print(f"❌ Error: {result['error']}")
        else:
            print("✅ Response:")
            print(result.get("response", "No response generated"))
            print(f"\nSources used: {result.get('sources_used', {})}")
            print(f"Processing time: {result.get('processing_time', 0):.2f} seconds")
            
    except Exception as e:
        print(f"❌ Exception: {e}")
    
    print("=" * 50)


## Interactive Usage

You can now query the medical knowledge system interactively. For each question, the system will:

1. **Analyze your question** using the Planning Agent to determine what information is needed
2. **Retrieve relevant text** from medical textbooks and research papers
3. **Retrieve relevant images** from medical documents (when available)
4. **Reason about the information** using the Reasoner Agent to synthesize findings
5. **Generate a comprehensive response** using the Draft Agent

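The hand-off from steps 2 and 3 to step 4 depends on the plan: the Reasoner Agent starts only once every retrieval the Planning Agent asked for has reported back. That gating condition, taken from `check_and_trigger_reasoner`, can be sketched as a pure function. The function name `ready_for_reasoning` is hypothetical; the dictionary keys are the ones used in the code above.

```python
from typing import Any, Dict


def ready_for_reasoning(analysis: Dict[str, Any], results: Dict[str, Any]) -> bool:
    """True once every retrieval requested by the plan has delivered results."""
    needs_text = analysis.get("needs_text", False)
    needs_images = analysis.get("needs_images", False)
    # A result key counts as delivered even if its list is empty
    has_text = "text_results" in results
    has_images = "image_results" in results
    return (not needs_text or has_text) and (not needs_images or has_images)


plan = {"needs_text": True, "needs_images": True}
print(ready_for_reasoning(plan, {"text_results": []}))                       # False
print(ready_for_reasoning(plan, {"text_results": [], "image_results": []}))  # True
print(ready_for_reasoning({}, {}))                                           # True
```
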
### Example Usage:

```python
# Ask a specific medical question
question = "What are the early signs of myocardial infarction?"
result = medical_system.ask_question(question)

print(f"Question: {question}")
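# On failure the result holds an "error" key instead of "response", so use .get()
print(f"Answer: {result.get('response', result.get('error'))}")
print(f"Sources: {result.get('sources_used', {})}")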
```

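`ask_question` returns a dict with an `error` key instead of `response` when the system is not initialized, no data has been ingested, or the query fails, so indexing `result['response']` directly can raise `KeyError`. A small rendering helper avoids that. This is a sketch: `render_result` is a hypothetical name, and the sample dicts only mimic the shapes returned by the code above.

```python
from typing import Any, Dict


def render_result(question: str, result: Dict[str, Any]) -> str:
    """Format a result dict for display, covering both success and error shapes."""
    if "error" in result:
        return f"Question: {question}\nError: {result['error']}"
    sources = result.get("sources_used", {})
    return (
        f"Question: {question}\n"
        f"Answer: {result.get('response', 'No response generated')}\n"
        f"Sources: {sources.get('text_count', 0)} text, {sources.get('image_count', 0)} image"
    )


# Hypothetical result dicts, one per shape
print(render_result("q", {"error": "No data ingested. Call ingest_documents() first."}))
print(render_result("q", {"response": "a", "sources_used": {"text_count": 3, "image_count": 0}}))
```
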
### System Features:

- **Multi-Agent Architecture**: Each agent has a specific role in processing medical queries
- **A2A Communication**: Agents communicate through structured messages
- **MCP Integration**: Model Context Protocol for standardized communication
- **Vector Search**: Semantic search through medical documents using embeddings
- **Multi-modal Support**: Handles both text and image information (image extraction from PDFs is still a placeholder, so only text is ingested for now)
- **Medical Reasoning**: Specialized reasoning for medical domain knowledge

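To make the vector search feature concrete: semantic search ranks stored embeddings by their similarity to the query embedding. The sketch below does this with cosine similarity over plain Python lists. It illustrates the idea only; in the notebook the search is delegated to Milvus, whose metric and index settings are configured elsewhere and may differ. The three-dimensional vectors are made-up toy data, and `cosine_similarity` and `top_k` are hypothetical helper names.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors; 0.0 if either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k(query: Sequence[float], corpus: List[Tuple[str, Sequence[float]]], k: int = 2) -> List[Tuple[str, float]]:
    """Return the k corpus entries most similar to the query, best first."""
    scored = [(label, cosine_similarity(query, vec)) for label, vec in corpus]
    return sorted(scored, key=lambda item: item[1], reverse=True)[:k]


# Toy embeddings: the first and third point in nearly the same direction
corpus = [
    ("pneumonia symptoms", [0.9, 0.1, 0.0]),
    ("heart anatomy", [0.0, 0.2, 0.9]),
    ("lung infection", [0.8, 0.3, 0.1]),
]
for label, score in top_k([1.0, 0.2, 0.0], corpus):
    print(f"{label}: {score:.3f}")
```
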

In [None]:
# Interactive testing - ask your own medical questions here
# Uncomment and modify the question below to test the system

# question = "What are the treatment protocols for sepsis?"
# result = medical_system.ask_question(question)
# print(f"Question: {question}")
# print(f"Answer: {result['response']}")
# print(f"Sources: {result['sources_used']}")

print("System ready for interactive use!")
print("Uncomment the code above to ask your own medical questions.")
