In [None]:
# Install required packages if needed (IPython "!" shell escape; only needs to run once per environment).
# The Data Commons MCP server is not installed here: later cells expect it to be started separately with
# `uv tool run datacommons-mcp serve http --port 3000`.
!pip install python-dotenv ipywidgets pydantic-ai fastmcp openai nest-asyncio folium matplotlib markdown aiohttp

In [None]:
from dotenv import load_dotenv
import os
from fastmcp import Client
import asyncio

# Load environment variables (API keys) from a .env file, if one exists.
load_dotenv()

# GeoServer endpoints and the California county boundary layer ("boundary:ca_counties",
# matched on its "name" column) shared by the CLM tool calls.
BASE_CONFIG = {
    "wcs_base_url": "https://sparcal.sdsc.edu/geoserver",
    "wfs_base_url": "https://sparcal.sdsc.edu/geoserver/boundary/wfs",
    "feature_id": "boundary:ca_counties",
    "filter_column": "name"
}

# Initialize CLM MCP client (the tools open a connection per call with `async with`).
mcp_client = Client("https://wenokn.fastmcp.app/mcp")
print("✓ CLM MCP client configured")

# Check API keys. A missing key only produces a warning here.
openai_key = os.getenv('OPENAI_API_KEY')
nrp_key = os.getenv('NRP_API_KEY')
dc_key = os.getenv('DC_API_KEY')

if not openai_key:
    print("⚠️ Warning: OPENAI_API_KEY not set!")
else:
    print("✓ OpenAI API key found")

if not nrp_key:
    print("⚠️ Warning: NRP_API_KEY not set!")
else:
    print("✓ NRP API key found")

if not dc_key:
    print("⚠️ Warning: DC_API_KEY not set! Get one from https://datacommons.org/")
else:
    print("✓ Data Commons API key found")

# Choose your model: "openai" (GPT-4o-mini) or "nrp" (Qwen3 on the NRP endpoint).
MODEL = "openai"  # "openai" or "nrp"

if MODEL == "openai":
    print("✓ Using OpenAI GPT-4o-mini")
elif MODEL == "nrp":
    print("✓ Using NRP Qwen3")
else:
    # get_model_config() silently treats any unknown name as OpenAI, so fail fast on typos.
    raise ValueError(f"Unsupported MODEL {MODEL!r}; expected 'openai' or 'nrp'")

In [None]:
from dataclasses import dataclass
from typing import Optional, List
import asyncio
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.openai import OpenAIModel
from fastmcp import Client
import os

@dataclass
class AgentContext:
    """Mutable state shared by the CLM agent's tools during a conversation.

    ``search_and_select_dataset`` fills both fields; the other tools only read them.
    """

    # WCS coverage id of the selected dataset; None until a dataset is chosen.
    current_coverage_id: Optional[str] = None
    # Full metadata record of the selected dataset, as returned by the search tool.
    current_dataset_info: Optional[dict] = None


# GeoServer endpoints plus the county boundary layer, spread into every CLM tool call.
BASE_CONFIG = dict(
    wcs_base_url="https://sparcal.sdsc.edu/geoserver",
    wfs_base_url="https://sparcal.sdsc.edu/geoserver/boundary/wfs",
    feature_id="boundary:ca_counties",
    filter_column="name",
)

def get_model_config(model_name: str = "openai"):
    """
    Get model configuration based on model name.

    The returned ``'openai:<model>'`` string is resolved by pydantic-ai, whose
    OpenAI client picks up ``OPENAI_BASE_URL`` / ``OPENAI_API_KEY`` from the
    environment. This function therefore switches providers by rewriting those
    process-wide environment variables.

    Args:
        model_name: Either "openai" or "nrp". Any value other than "nrp"
            selects OpenAI.

    Returns:
        Model configuration for pydantic-ai: ``'openai:qwen3'`` for NRP,
        otherwise ``'openai:gpt-4o-mini'``.
    """
    if model_name == "nrp":
        # Save the real OpenAI key once, before overwriting it, so a later
        # switch back to "openai" can restore it.
        current_key = os.environ.get('OPENAI_API_KEY')
        if current_key is not None and 'OPENAI_API_KEY_ORIGINAL' not in os.environ:
            os.environ['OPENAI_API_KEY_ORIGINAL'] = current_key

        # Point the OpenAI-compatible client at the NRP endpoint.
        os.environ['OPENAI_BASE_URL'] = 'https://ellm.nrp-nautilus.io/v1'
        os.environ['OPENAI_API_KEY'] = os.getenv('NRP_API_KEY', '')
        return 'openai:qwen3'

    # Back to the default OpenAI endpoint.
    os.environ.pop('OPENAI_BASE_URL', None)
    original_key = os.getenv('OPENAI_API_KEY_ORIGINAL')
    if original_key:
        os.environ['OPENAI_API_KEY'] = original_key
    return 'openai:gpt-4o-mini'

def _no_dataset_selected() -> dict:
    """Return the error payload a tool sends back when no dataset has been selected yet."""
    return {
        'success': False,
        'error': 'No dataset selected. Call search_and_select_dataset first.'
    }


def _dataset_summary(dataset_info: Optional[dict]) -> dict:
    """Return the ``{'title', 'units'}`` summary attached to successful tool results."""
    info = dataset_info or {}
    return {
        'title': info.get('title'),
        'units': info.get('data_units')
    }


async def _call_clm_tool(tool_name: str, arguments: dict) -> dict:
    """Call *tool_name* on the CLM MCP server and return its structured data (``{}`` if none)."""
    async with mcp_client:
        result = await mcp_client.call_tool(tool_name, arguments)
    return result.data or {}


async def _zonal_count(coverage_id: str, counties: Optional[List[str]], threshold: float) -> dict:
    """Run the ``zonal_count`` MCP tool for *coverage_id* over *counties* (None = all counties)."""
    return await _call_clm_tool(
        "zonal_count",
        {
            **BASE_CONFIG,
            "wcs_coverage_id": coverage_id,
            "filter_value": counties,
            "threshold": threshold,
            "max_workers": 16
        }
    )


def _summarize_threshold(zonal_data: dict, threshold: float, above: bool,
                         dataset_info: Optional[dict]) -> dict:
    """Turn a ``zonal_count`` reply into per-county percentage/area rows.

    Args:
        zonal_data: Structured reply of the ``zonal_count`` MCP tool.
        threshold: The threshold that was sent to the tool (echoed in each row).
        above: True to report the pixels counted as above the threshold; False
            for the complement (valid pixels not counted as above).
        dataset_info: Metadata of the selected dataset, for the title/units summary.

    Returns:
        The tool's own reply unchanged when it did not succeed; otherwise a dict
        with ``success``, per-county ``data`` rows, feature counts and ``dataset_info``.
    """
    if not zonal_data.get('success'):
        return zonal_data

    count_key = 'above_threshold_pixels' if above else 'below_threshold_pixels'
    rows = []
    for stats in zonal_data.get('data') or []:
        valid = stats['valid_pixels']
        above_count = stats['above_threshold_pixels']
        # The server only reports "above" counts; "below" is the rest of the valid pixels.
        count = above_count if above else valid - above_count
        area_sq_m = count * stats['pixel_area_square_meters']
        rows.append({
            'county': stats[BASE_CONFIG['filter_column']],
            'threshold': threshold,
            'valid_pixels': valid,
            count_key: count,
            'percentage': (count / valid * 100) if valid > 0 else 0,
            'area_square_meters': area_sq_m,
            'area_square_km': area_sq_m / 1_000_000
        })

    return {
        'success': True,
        'data': rows,
        'total_features': zonal_data.get('total_features'),
        'processed_features': zonal_data.get('processed_features'),
        'dataset_info': _dataset_summary(dataset_info)
    }


def create_internal_agent(model_name: str = "openai"):
    """
    Create and configure the Pydantic AI agent.

    Registers six tools that share state through ``AgentContext`` (the dataset
    chosen by ``search_and_select_dataset``) and call the CLM MCP server.

    Args:
        model_name: Either "openai" for GPT-4o-mini or "nrp" for Qwen3

    Returns:
        The configured ``Agent``. It must be run with an ``AgentContext`` as
        deps, because every tool reads or writes ``ctx.deps``.
    """
    from typing import Union  # for the str-or-list ``counties`` annotation below

    model_config = get_model_config(model_name)

    agent = Agent(
        model=model_config,
        deps_type=AgentContext,
        retries=2,
        system_prompt="""You are an expert in analyzing California Landscape Metrics datasets.

The input you receive will be the full conversation history in the format:
User: <question1>
Assistant: <answer1>
User: <question2>
Assistant: <answer2>
...
User: <current question>

Use this history to understand the context for follow-up questions.

You have access to these tools:
1. search_and_select_dataset: Search for the most relevant dataset based on the question
2. get_county_statistics: Compute statistics for one or more counties
3. get_area_above_threshold: Calculate percentage/area above a threshold for one or more counties
4. get_area_below_threshold: Calculate percentage/area below a threshold for one or more counties
5. show_map: Display a map visualization of the current dataset
6. get_value_distribution: Get value distribution for one or more counties (for charts/histograms)

**CRITICAL RULE FOR SEARCH:**
When calling search_and_select_dataset, construct a search query that represents the current user's question.
If the current question is a follow-up (e.g., "How about San Diego?"), incorporate context from the history to make it a standalone query (e.g., "What is the average carbon turnover time in San Diego?").
DO NOT pass vague follow-up phrases directly; always create a meaningful, context-aware query.
DO NOT use the entire history as the query.
When calling search_and_select_dataset, please remove the terms related to statistics like "average" and "mean" and the place names from the question like 'San Diego' or 'Los Angeles'. 
For example, for the question "What is the average carbon turnover time in Los Angeles?", use "carbon turnover time" to call search_and_select_dataset.

**WORKFLOW:**
1. If needed, start by calling search_and_select_dataset with the constructed query.
   - If a similar dataset was used in previous interactions (based on history), you may skip searching if it's clearly the same topic.
2. Once the dataset is selected, use the other tools to answer the question.
3. Include the dataset name and units in your final answer for context.

**For statistical questions:**
- Use get_county_statistics for mean, median, min, max, std
- Pass counties=None to get all counties
- For rankings, get all counties and sort results

**For threshold questions:**
- Use get_area_above_threshold or get_area_below_threshold
- Pass counties=None to get all counties
- For questions about many/all counties, you can use counties=None, but if concerned about timeouts, first get list of counties from get_county_statistics and sample 10-20

**For map visualization requests:**
- When users ask to "show the map", "visualize", "display map", or similar requests, use the show_map tool
- You must have a dataset selected first (via search_and_select_dataset)
- The tool will return a special marker that the interface will use to render the map

**For distribution/histogram requests:**
- When users ask to "show distribution", "show histogram", "compare distributions", "data distribution", or similar requests, use the get_value_distribution tool
- You must have a dataset selected first (via search_and_select_dataset)
- CRITICAL: For comparing distributions between regions, pass ALL county names as a LIST in a SINGLE call to get_value_distribution
  Example: get_value_distribution(counties=["San Diego", "Los Angeles"]) - NOT separate calls
- For single region, pass one county name as a string or single-item list
- For overview of all regions, pass counties=None (but this may be slow for many counties)
- The tool will return distribution data that the interface will visualize as charts
- DO NOT call get_value_distribution multiple times - always use ONE call with a list of counties

**Answer format:**
- Be precise with numbers and include units from the dataset
- Provide clear, concise answers
- Mention the dataset being used
- For distribution requests, briefly describe what the chart shows

""",
    )

    # NOTE: each tool's docstring is sent to the model as the tool description,
    # so those docstrings are kept short and unchanged.

    @agent.tool
    async def search_and_select_dataset(
        ctx: RunContext[AgentContext],
        question: str,
        top_k: int = 3
    ) -> dict:
        """Search for and select the most relevant dataset."""
        data = await _call_clm_tool(
            "search_datasets",
            {"query": question, "top_k": top_k}
        )
        datasets = data.get('datasets') or []
        if not (data.get('success') and datasets):
            return {
                'success': False,
                'message': 'No suitable datasets found',
                'error': data.get('error', 'Unknown error')
            }

        # The top hit becomes the dataset all other tools operate on.
        best_dataset = datasets[0]
        ctx.deps.current_coverage_id = best_dataset['wcs_coverage_id']
        ctx.deps.current_dataset_info = best_dataset

        return {
            'success': True,
            'selected_dataset': best_dataset,
            'alternatives': datasets[1:],
            'message': f"Selected: {best_dataset['title']}"
        }

    @agent.tool
    async def get_county_statistics(
        ctx: RunContext[AgentContext],
        counties: Optional[List[str]] = None,
        stats: Optional[List[str]] = None
    ) -> dict:
        """Get statistics for one or more counties."""
        if not ctx.deps.current_coverage_id:
            return _no_dataset_selected()

        if stats is None:
            stats = ["mean", "median", "min", "max", "std"]

        response = await _call_clm_tool(
            "compute_zonal_stats",
            {
                **BASE_CONFIG,
                "wcs_coverage_id": ctx.deps.current_coverage_id,
                "filter_value": counties,
                "stats": stats,
                "max_workers": 16
            }
        )
        if response.get('success'):
            response['dataset_info'] = _dataset_summary(ctx.deps.current_dataset_info)
        return response

    @agent.tool
    async def get_area_above_threshold(
        ctx: RunContext[AgentContext],
        counties: Optional[List[str]] = None,
        threshold: float = 100.0
    ) -> dict:
        """Calculate the percentage and area above a threshold for one or more counties."""
        if not ctx.deps.current_coverage_id:
            return _no_dataset_selected()

        zonal_data = await _zonal_count(ctx.deps.current_coverage_id, counties, threshold)
        return _summarize_threshold(
            zonal_data, threshold, above=True, dataset_info=ctx.deps.current_dataset_info
        )

    @agent.tool
    async def get_area_below_threshold(
        ctx: RunContext[AgentContext],
        counties: Optional[List[str]] = None,
        threshold: float = 100.0
    ) -> dict:
        """Calculate the percentage and area below a threshold for one or more counties."""
        if not ctx.deps.current_coverage_id:
            return _no_dataset_selected()

        zonal_data = await _zonal_count(ctx.deps.current_coverage_id, counties, threshold)
        return _summarize_threshold(
            zonal_data, threshold, above=False, dataset_info=ctx.deps.current_dataset_info
        )

    @agent.tool
    async def show_map(ctx: RunContext[AgentContext]) -> dict:
        """Display a map visualization of the current dataset using WMS layer."""
        if not ctx.deps.current_dataset_info:
            return _no_dataset_selected()

        dataset = ctx.deps.current_dataset_info

        # The 'action' marker is what HistoryAwareAgent looks for to render a map.
        return {
            'success': True,
            'action': 'show_map',
            'map_data': {
                'wms_base_url': dataset.get('wms_base_url'),
                'wms_layer_name': dataset.get('wms_layer_name'),
                'title': dataset.get('title'),
                'description': dataset.get('description', ''),
                'units': dataset.get('data_units', '')
            },
            'message': f"Displaying map for: {dataset.get('title')}"
        }

    @agent.tool
    async def get_value_distribution(
        ctx: RunContext[AgentContext],
        counties: Optional[Union[str, List[str]]] = None,
        num_bins: int = 10
    ) -> dict:
        """Get value distribution for one or more counties to create histograms/charts."""
        if not ctx.deps.current_coverage_id:
            return _no_dataset_selected()

        # The system prompt allows a single county name as a plain string.
        if isinstance(counties, str):
            counties = [counties]

        dist_data = await _call_clm_tool(
            "zonal_distribution",
            {
                **BASE_CONFIG,
                "wcs_coverage_id": ctx.deps.current_coverage_id,
                "filter_value": counties,
                "num_bins": num_bins,
                "global_bins": True,
                "categorical_threshold": 20,
                "max_workers": 16
            }
        )
        if dist_data.get('success'):
            dist_data['dataset_info'] = _dataset_summary(ctx.deps.current_dataset_info)
            dist_data['action'] = 'show_distribution'
        return dist_data

    return agent

class HistoryAwareAgent:
    """Conversational wrapper around the CLM pydantic-ai agent.

    Each question is sent with the earlier ``User:``/``Assistant:`` turns
    prepended as plain text, which is the transcript format the agent's system
    prompt expects. Tool payloads meant for the UI (maps, distributions) are
    pulled out of the run's messages and returned with the answer.
    """

    def __init__(self, model_name: str = "openai"):
        """
        Initialize the agent with a specific model.

        Args:
            model_name: Either "openai" or "nrp"
        """
        self.model_name = model_name
        self.internal_agent = create_internal_agent(model_name)
        self.history = []
        # Default tool context. It is reused across run() calls, so a dataset
        # selected for one question stays selected for follow-ups.
        self.deps = AgentContext()

    async def run(self, question: str, timeout: int = 300, deps: Optional[AgentContext] = None) -> dict:
        """Run the agent with history-aware input.

        Args:
            question: The new user question.
            timeout: Seconds to wait for the agent before giving up.
            deps: Tool context holding the selected dataset. Defaults to this
                instance's persistent context. The tools dereference it, so it
                must never be ``None``.

        Returns:
            dict with ``'output'`` (answer or error text), plus ``'map_data'``
            and ``'distribution_data'`` (``None`` when not produced).
        """
        if deps is None:
            deps = self.deps
        full_input = "\n".join([*self.history, f"User: {question}"])

        try:
            result = await asyncio.wait_for(
                self.internal_agent.run(full_input, deps=deps),
                timeout=timeout
            )
            output = result.output if hasattr(result, 'output') else str(result)
            map_data, distribution_data = self._extract_ui_payloads(result)

            # Only successful turns are recorded in the transcript.
            self.history.append(f"User: {question}")
            self.history.append(f"Assistant: {output}")

            return {
                'output': output,
                'map_data': map_data,
                'distribution_data': distribution_data
            }
        except asyncio.TimeoutError:
            return self._error_result(f"Error: Question timed out after {timeout} seconds.")
        except Exception as e:
            return self._error_result(f"Error: {type(e).__name__}: {str(e)[:200]}")

    @staticmethod
    def _error_result(message: str) -> dict:
        """Build a run() result that carries only an error message."""
        return {
            'output': message,
            'map_data': None,
            'distribution_data': None
        }

    @staticmethod
    def _extract_ui_payloads(result):
        """Return ``(map_data, distribution_data)`` found among the run's tool results.

        The last ``show_map`` payload wins. Among ``show_distribution``
        payloads, the one with the most ``data`` entries is kept.
        """
        map_data = None
        distribution_data = None
        if not hasattr(result, 'all_messages'):
            return map_data, distribution_data

        for msg in result.all_messages():
            for part in getattr(msg, 'parts', ()):
                content = getattr(part, 'content', None)
                if not isinstance(content, dict):
                    continue
                action = content.get('action')
                if action == 'show_map':
                    map_data = content.get('map_data')
                elif action == 'show_distribution':
                    if distribution_data is None or len(content.get('data', [])) > len(distribution_data.get('data', [])):
                        distribution_data = content
        return map_data, distribution_data

# Factory for history-aware CLM agents (works for either model).
def create_agent(model_name: str = "openai"):
    """
    Build a new history-aware CLM agent.

    Args:
        model_name: "openai" (GPT-4o-mini) or "nrp" (Qwen3).

    Returns:
        A fresh HistoryAwareAgent with an empty conversation history.
    """
    agent = HistoryAwareAgent(model_name=model_name)
    return agent

# Default agent, built with the model selected in the configuration cell (MODEL).
clm_agent = create_agent(MODEL)
print(f"✓ Agent created successfully with {MODEL}!")

In [None]:
import aiohttp
import json

# Endpoint of the locally running Data Commons MCP server (HTTP transport).
MCP_URL = "http://localhost:3000/mcp"
MCP_HEADERS = {
    "Content-Type": "application/json",  # JSON-RPC request bodies
    # Advertise both so the server may answer with plain JSON or an SSE stream.
    "Accept": ", ".join(("application/json", "text/event-stream")),
}

async def parse_sse_response(response) -> list:
    """Parse an MCP HTTP response body into a list of JSON-RPC messages.

    Handles Server-Sent Events: events are separated by blank lines, and an
    event's ``data:`` lines are joined with newlines. Because the request
    advertises both ``application/json`` and ``text/event-stream``, a body
    containing no SSE fields is parsed as a single plain JSON document.
    An OpenAI-style ``data: [DONE]`` line stops parsing early. Payloads that
    are not valid JSON are skipped.

    Args:
        response: An aiohttp response, or anything whose ``content``
            async-iterates over raw byte lines.

    Returns:
        The decoded JSON messages, in the order received.
    """
    messages = []
    event_lines = []   # data lines of the SSE event currently being read
    plain_lines = []   # non-SSE lines, used for the plain-JSON fallback
    saw_sse_field = False

    def flush_event():
        if event_lines:
            payload = "\n".join(event_lines)
            event_lines.clear()
            try:
                messages.append(json.loads(payload))
            except json.JSONDecodeError:
                pass  # not JSON (e.g. a keep-alive); ignore

    async for raw_line in response.content:
        line = raw_line.decode('utf-8').strip()
        if not line:
            # A blank line terminates the current SSE event.
            flush_event()
            continue
        if line.startswith("data:"):
            saw_sse_field = True
            value = line[len("data:"):].strip()
            # Checked before buffering: "data: [DONE]" also starts with "data:".
            if value == "[DONE]":
                break
            event_lines.append(value)
        elif line.startswith(("event:", "id:", "retry:", ":")):
            saw_sse_field = True  # other SSE fields / comments carry no payload
        else:
            plain_lines.append(line)

    # Handle the last event if the stream ended without a blank line or [DONE].
    flush_event()

    if not saw_sse_field and plain_lines:
        try:
            messages.append(json.loads("\n".join(plain_lines)))
        except json.JSONDecodeError:
            pass

    return messages

async def check_dc_mcp_server():
    """Check if Data Commons MCP server is running and responding correctly.

    Sends a JSON-RPC ``initialize`` request to ``MCP_URL`` and prints the
    server name and protocol version from the reply.

    Returns:
        True when the server answered with an ``initialize`` result; False on
        connection failure, a non-200 status, or a reply without a result.
    """
    init_request = {
        "jsonrpc": "2.0",
        "id": 1,
        "method": "initialize",
        "params": {
            "protocolVersion": "2024-11-05",
            "capabilities": {},
            "clientInfo": {
                "name": "notebook-client",
                "version": "1.0.0"
            }
        }
    }

    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(
                MCP_URL,
                json=init_request,
                headers=MCP_HEADERS
            ) as resp:
                if resp.status != 200:
                    print(f"⚠️ DC MCP server returned status {resp.status}")
                    print("   Make sure to use correct headers (SSE support)")
                    return False

                messages = await parse_sse_response(resp)
                # Use the first message carrying a result; skip any notifications.
                reply = next(
                    (m for m in messages if isinstance(m, dict) and 'result' in m),
                    None
                )
                if reply is None:
                    print("⚠️ Server responded but no valid initialization data")
                    return False

                result = reply['result']
                server_name = result.get('serverInfo', {}).get('name', 'Unknown')
                protocol = result.get('protocolVersion', 'Unknown')

                print("✓ Data Commons MCP server is running!")
                print(f"  Server: {server_name}")
                print(f"  Protocol: {protocol}")
                print(f"  Endpoint: {MCP_URL}")
                return True

    except aiohttp.ClientConnectorError:
        print("❌ Cannot connect to Data Commons MCP server!")
        print("   Start it with: uv tool run datacommons-mcp serve http --port 3000")
        return False
    except Exception as e:
        print(f"❌ Error checking DC MCP server: {e}")
        import traceback
        traceback.print_exc()
        return False

# Quick test to list available tools
async def list_dc_tools():
    """List available Data Commons tools.

    Performs the MCP ``initialize`` request, then ``tools/list``, and prints
    each tool's name with its description truncated to 60 characters. If the
    server assigns an ``Mcp-Session-Id`` on initialize, it is sent with the
    ``tools/list`` request.

    Returns:
        True if the tool list was retrieved and printed, False otherwise.
    """
    init_request = {
        "jsonrpc": "2.0",
        "id": 1,
        "method": "initialize",
        "params": {
            "protocolVersion": "2024-11-05",
            "capabilities": {},
            "clientInfo": {"name": "notebook-client", "version": "1.0.0"}
        }
    }
    tools_request = {
        "jsonrpc": "2.0",
        "id": 2,
        "method": "tools/list",
        "params": {}
    }

    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(MCP_URL, json=init_request, headers=MCP_HEADERS) as resp:
                if resp.status != 200:
                    print(f"Could not list tools: initialize returned HTTP {resp.status}")
                    return False
                session_id = resp.headers.get('Mcp-Session-Id')
                await parse_sse_response(resp)  # Drain the response

            list_headers = dict(MCP_HEADERS)
            if session_id:
                list_headers['Mcp-Session-Id'] = session_id

            async with session.post(MCP_URL, json=tools_request, headers=list_headers) as resp:
                if resp.status != 200:
                    print(f"Could not list tools: tools/list returned HTTP {resp.status}")
                    return False
                messages = await parse_sse_response(resp)

            reply = next(
                (m for m in messages if isinstance(m, dict) and 'result' in m),
                None
            )
            if reply is None:
                print("Could not list tools: no result in tools/list reply")
                return False

            tools = reply['result'].get('tools', [])
            print(f"\n✓ Found {len(tools)} available tools:")
            for tool in tools:
                # A tool may declare a null description; fall back to a placeholder.
                description = tool.get('description') or 'No description'
                print(f"  - {tool['name']}: {description[:60]}...")
            return True
    except Exception as e:
        print(f"Could not list tools: {e}")
        return False

# Run checks
print("Checking Data Commons MCP Server...")
server_ok = await check_dc_mcp_server()

if server_ok:
    await list_dc_tools()
else:
    print("\n‚ö†Ô∏è Please ensure the server is running before proceeding.")
    print("   In a terminal, run:")
    print("   uv tool run datacommons-mcp serve http --port 3000")

In [None]:
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIChatModel
import json

# MCP Configuration: local Data Commons MCP server and the headers every request carries.
MCP_PORT = 3000
MCP_URL = "http://localhost:{}/mcp".format(MCP_PORT)
MCP_HEADERS = dict([
    ("Content-Type", "application/json"),
    # Accept both so the server may reply with plain JSON or an SSE stream.
    ("Accept", "application/json, text/event-stream"),
])

class AgentDeps(BaseModel):
    """Per-run dependencies passed to the Data Commons agent.

    Both fields are required strings. NOTE(review): neither tool wrapper in
    this cell reads ``ctx.deps`` yet, so these keys are currently only carried along.
    """

    # Required: key for the LLM endpoint.
    llm_api_key: str = Field(..., description="API key for the LLM endpoint")
    # Required: Data Commons key.
    dc_api_key: str = Field(..., description="Data Commons API key")

class DataCommonsMCPClient:
    """Client for Data Commons MCP server.

    A minimal JSON-RPC-over-HTTP client. Use it as an async context manager
    (it owns the aiohttp session), call :meth:`initialize` first, then
    :meth:`call_tool`.
    """

    def __init__(self, url: str = MCP_URL):
        self.url = url
        self.session: Optional[aiohttp.ClientSession] = None
        # Session id the server may assign on initialize (``Mcp-Session-Id``
        # response header); echoed on later requests when present.
        self.session_id: Optional[str] = None

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, *exc):
        if self.session:
            await self.session.close()
            self.session = None

    async def _parse_sse_response(self, resp: aiohttp.ClientResponse) -> List[Dict]:
        """Parse a response body (SSE stream or plain JSON) into JSON-RPC messages.

        SSE events end at a blank line and may span several ``data:`` lines.
        ``data: [DONE]`` stops parsing. A body without any SSE field is parsed
        as one JSON document. Payloads that are not valid JSON are skipped.
        """
        messages: List[Dict] = []
        event_lines: List[str] = []   # data lines of the event being read
        plain_lines: List[str] = []   # non-SSE lines, for the plain-JSON fallback
        saw_sse_field = False

        def flush_event() -> None:
            if event_lines:
                payload = "\n".join(event_lines)
                event_lines.clear()
                try:
                    messages.append(json.loads(payload))
                except json.JSONDecodeError:
                    pass

        async for raw_line in resp.content:
            line = raw_line.decode("utf-8").strip()
            if not line:
                flush_event()  # a blank line terminates an SSE event
                continue
            if line.startswith("data:"):
                saw_sse_field = True
                value = line[len("data:"):].strip()
                # Checked before buffering: "data: [DONE]" also starts with "data:".
                if value == "[DONE]":
                    break
                event_lines.append(value)
            elif line.startswith(("event:", "id:", "retry:", ":")):
                saw_sse_field = True
            else:
                plain_lines.append(line)

        # Handle the last event if there was no trailing blank line or [DONE].
        flush_event()

        if not saw_sse_field and plain_lines:
            try:
                messages.append(json.loads("\n".join(plain_lines)))
            except json.JSONDecodeError:
                pass

        return messages

    @staticmethod
    def _response_message(msgs: List[Dict]) -> Dict:
        """Return the first message holding a JSON-RPC ``result`` or ``error``.

        This skips any interleaved notifications. Falls back to the first
        message, or ``{}`` if there are none.
        """
        for msg in msgs:
            if isinstance(msg, dict) and ("result" in msg or "error" in msg):
                return msg
        return msgs[0] if msgs else {}

    async def _rpc(self, payload: Dict) -> List[Dict]:
        """Make RPC call to MCP server.

        Raises:
            RuntimeError: if called outside ``async with``, or on a non-200 reply.
        """
        if self.session is None:
            raise RuntimeError("DataCommonsMCPClient must be used with 'async with'")

        headers = dict(MCP_HEADERS)
        if self.session_id:
            headers["Mcp-Session-Id"] = self.session_id

        async with self.session.post(self.url, json=payload, headers=headers) as resp:
            if resp.status != 200:
                txt = await resp.text()
                raise RuntimeError(f"MCP HTTP {resp.status}: {txt}")
            new_session_id = resp.headers.get("Mcp-Session-Id")
            if new_session_id:
                self.session_id = new_session_id
            return await self._parse_sse_response(resp)

    async def initialize(self) -> Dict:
        """Initialize MCP connection and return the server's reply (``{}`` if empty)."""
        payload = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {"name": "pydantic-ai-agent", "version": "1.0"},
            },
        }
        msgs = await self._rpc(payload)
        return self._response_message(msgs)

    async def call_tool(self, name: str, arguments: Dict) -> str:
        """Call a tool on the MCP server.

        Returns:
            The result's text content blocks joined by newlines, or the
            pretty-printed result if it has no text. If the server answers with
            a JSON-RPC error, a description of that error is returned rather
            than raised, so the calling agent can see it and react.

        Raises:
            RuntimeError: on an empty reply, or any error raised by ``_rpc``.
        """
        payload = {
            "jsonrpc": "2.0",
            "id": 3,
            "method": "tools/call",
            "params": {"name": name, "arguments": arguments},
        }
        msgs = await self._rpc(payload)
        if not msgs:
            raise RuntimeError("Empty MCP response")

        message = self._response_message(msgs)
        if "error" in message:
            error = message["error"]
            detail = error.get("message", error) if isinstance(error, dict) else error
            return f"MCP error from tool '{name}': {detail}"

        result = message.get("result") or {}
        text_parts = [
            block.get("text", "")
            for block in result.get("content", [])
            if block.get("type") == "text"
        ]
        return "\n".join(text_parts) or json.dumps(result, indent=2)

# Tool wrapper functions (registered on the Data Commons agent below)
async def search_indicators(
    ctx: RunContext[AgentDeps],
    query: str,
    places: Optional[List[str]] = None,
    parent_place: Optional[str] = None,
    include_topics: bool = True,
    maybe_bilateral: bool = False,
) -> str:
    """Search for indicators in Data Commons."""
    # The docstring above is kept verbatim: pydantic-ai uses it as the tool
    # description shown to the model.
    # Optional filters are forwarded only when they are non-empty.
    optional_args = {"places": places, "parent_place": parent_place}
    tool_args: Dict[str, Any] = {
        "query": query,
        "include_topics": include_topics,
        "maybe_bilateral": maybe_bilateral,
        **{key: val for key, val in optional_args.items() if val},
    }
    async with DataCommonsMCPClient() as mcp:
        await mcp.initialize()
        return await mcp.call_tool("search_indicators", tool_args)

async def get_observations(
    ctx: RunContext[AgentDeps],
    variable_dcid: str,
    place_dcid: str,
    child_place_type: Optional[str] = None,
    date: str = "latest",
) -> str:
    """Get observations from Data Commons."""
    # The docstring above is kept verbatim: pydantic-ai uses it as the tool
    # description shown to the model.
    tool_args: Dict[str, Any] = dict(
        variable_dcid=variable_dcid,
        place_dcid=place_dcid,
        date=date,
    )
    # child_place_type is forwarded only when non-empty.
    if child_place_type:
        tool_args.update(child_place_type=child_place_type)

    async with DataCommonsMCPClient() as mcp:
        await mcp.initialize()
        return await mcp.call_tool("get_observations", tool_args)

def create_datacommons_agent() -> Agent:
    """Create Data Commons agent.

    Builds a GPT-4o-mini agent (via ``OpenAIChatModel``) with ``AgentDeps`` as
    deps and the ``search_indicators`` / ``get_observations`` tools.

    NOTE(review): the OpenAI client behind the model presumably reads
    OPENAI_API_KEY / OPENAI_BASE_URL from the environment. If the CLM agent was
    configured with "nrp" earlier, those variables point at the NRP endpoint,
    so confirm this agent still reaches OpenAI in that setup.
    """
    model = OpenAIChatModel(model_name="gpt-4o-mini")

    agent = Agent(
        model=model,
        deps_type=AgentDeps,
        system_prompt="""You are a precise data analyst using Google Data Commons.

Rules:
1. Always qualify place names: "San Diego, CA, USA"
2. Use `search_indicators` → find `variable_dcid`
3. Then `get_observations` with that DCID
4. For counties/states: sample 5 diverse places first
5. Use `date="latest"` unless specified
6. Cite source
""",
        tools=[search_indicators, get_observations],
    )
    return agent

async def run_dc_query(agent: Agent, query: str, deps: AgentDeps, verbose: bool = False) -> str:
    """Ask the Data Commons agent a question and return its text answer.

    Never raises: any exception is converted into a
    ``"Data Commons agent error: ..."`` string. When ``verbose`` is set, the
    query and then either the response or the error message are printed.
    """
    if verbose:
        print(f"DC Query: {query}")

    try:
        answer = (await agent.run(query, deps=deps)).output
        if verbose:
            print(f"DC Response: {answer}")
        return answer
    except Exception as exc:
        failure = f"Data Commons agent error: {exc}"
        if verbose:
            print(failure)
        return failure

# Create Data Commons agent
dc_agent = create_datacommons_agent()
# Missing keys become "" rather than None: AgentDeps declares both fields as
# required ``str``, so None would fail validation and abort this cell (the
# configuration cell only warns about missing keys).
dc_deps = AgentDeps(
    llm_api_key=os.getenv('OPENAI_API_KEY') or "",
    dc_api_key=os.getenv('DC_API_KEY') or ""
)
print("✓ Data Commons Agent created successfully!")

In [None]:
"""
Parallel Agent Coordination System - Complete Replacement
Drop-in replacement for EnhancedCoordinatedAgentSystem
"""

from dataclasses import dataclass
from typing import Optional, Dict, Any, List
import asyncio
import re
import json
import os
from pydantic_ai import Agent

@dataclass
class ParallelContext:
    """Inputs given to the evaluator agent that combines the specialist agents' answers."""

    # The user's original question.
    question: str
    # One entry per specialist agent. NOTE(review): the dict layout is defined
    # by the code that builds this context (not visible here); confirm its keys.
    agent_responses: List[Dict[str, Any]]

class ParallelCoordinatedSystem:
    """
    Parallel execution system - queries all agents concurrently,
    then intelligently combines results.

    The CLM and Data Commons (DC) agents are queried at the same time; an
    LLM "evaluator" agent then decides which answer(s) to keep and writes
    the final combined response.

    This replaces EnhancedCoordinatedAgentSystem with a more robust approach.
    """

    # Regexes (matched against the lower-cased question) that mark a
    # follow-up such as "What about Los Angeles?" or "LA too?".
    _FOLLOW_UP_PATTERNS = (
        r'^(what about|how about|and|also)\s+',
        r'^\w+\s+(too|also)\??$',
        r'^(there|that)\??$',
    )

    # Ordered (keywords, phrase) pairs: the first pair whose keyword appears
    # in the previous question supplies the topic for a follow-up.
    _TOPIC_HINTS = (
        (('population',), "population of"),
        (('income',), "median household income in"),
        (('carbon',), "carbon turnover time in"),
        (('burn', 'fire'), "burn probability in"),
        (('unemployment',), "unemployment rate in"),
        (('crime',), "crime rate in"),
        (('poverty',), "poverty rate in"),
    )

    # Substrings (lower-case) that mark an agent answer as a failure.
    _CLM_FAILURE_INDICATORS = (
        'unable to find',
        'no data',
        'could not find',
        'no suitable datasets',
        'error:',
        'cannot answer',
        'no dataset selected',
    )
    _DC_FAILURE_INDICATORS = (
        'unable to find',
        'no data',
        'could not find',
        'returned no current observations',
        'error',
        'cannot answer',
    )

    def __init__(self, clm_agent, dc_agent, dc_deps, model_name: str = "openai"):
        """
        Initialize parallel coordination system.

        Args:
            clm_agent: CLM agent instance
            dc_agent: DC agent instance
            dc_deps: Dependencies for DC agent
            model_name: Model to use ("openai" or "nrp")
        """
        self.clm_agent = clm_agent
        self.dc_agent = dc_agent
        self.dc_deps = dc_deps
        self.model_name = model_name
        self.conversation_history = []
        self.evaluator = self._create_evaluator(model_name)

    def _create_evaluator(self, model_name: str) -> "Agent":
        """
        Create the evaluator agent that combines responses from multiple agents.

        Args:
            model_name: "nrp" selects Qwen3 on the NRP endpoint; anything else
                selects OpenAI gpt-4o-mini.

        Note:
            This rewrites process-wide environment variables. For "nrp" it
            sets OPENAI_BASE_URL and overwrites OPENAI_API_KEY with
            NRP_API_KEY; otherwise it removes OPENAI_BASE_URL. Anything that
            later reads these variables in this process sees the change.
        """

        if model_name == "nrp":
            os.environ['OPENAI_BASE_URL'] = 'https://ellm.nrp-nautilus.io/v1'
            os.environ['OPENAI_API_KEY'] = os.getenv('NRP_API_KEY', '')
            model_config = 'openai:qwen3'
        else:
            if 'OPENAI_BASE_URL' in os.environ:
                del os.environ['OPENAI_BASE_URL']
            model_config = 'openai:gpt-4o-mini'

        evaluator = Agent(
            model=model_config,
            deps_type=ParallelContext,
            system_prompt="""You are a response evaluator for a multi-agent system.

You receive a question and responses from multiple specialized agents.

**Agent Capabilities:**

**CLM Agent**: 
- California landscape/environmental data (30m x 30m resolution)
- 189 datasets: air quality, biodiversity, carbon, fire, water, poverty, unemployment
- Provides: county statistics, spatial distributions, maps, threshold analysis
- Data format: Each value represents a 30m x 30m grid/pixel, NOT individual people
- For California ONLY

**DC Agent**:
- Global demographic/economic/social data
- Any location worldwide
- Provides: aggregated totals, population counts, overall percentages
- Data format: Total counts of people, aggregate statistics
- Single aggregated value per region

**CRITICAL DATA INTERPRETATION:**

**CLM distributions** (e.g., unemployment, poverty):
- Shows spatial patterns of RATES across geographic area
- Each value = one 30m x 30m PIXEL/GRID, NOT individual people
- Statistics are SPATIAL averages (mean across pixels), NOT population-weighted
- Example: "Average unemployment 56.9%" means the mean rate across all pixels
  - This is NOT the actual unemployment rate for the county
  - A pixel in a desert (1 person, 100% unemployed) counts the same as a pixel in a city (10,000 people, 0% unemployed)
  - This shows geographic spread, not demographic reality
- Distribution counts: "0-10%: 473,178" = 473,178 pixels have 0-10% unemployment rate
- NEVER say "individuals" or "people" - always say "pixels" or "grids" or "30m x 30m areas"
- ALWAYS mention the dataset name/title when reporting CLM data
- ALWAYS clarify that CLM statistics are spatial averages, not population totals

**DC data**:
- Shows actual population counts and totals
- "36,980 unemployed" means 36,980 actual people
- Unemployment rate is the actual demographic rate for the population
- Can say "individuals" or "people" for DC data

**Your Task:**
1. Identify which agent(s) provided useful, data-backed responses
2. Detect failed responses: "unable to find", "no data", "error", vague answers
3. Combine or select the best response(s)
4. **CRITICALLY**: 
   - Rewrite CLM responses to use correct terminology
   - Always include CLM dataset name
   - Explain that CLM provides spatial patterns, not population rates
   - Make clear distinction between spatial average and demographic rate

**Question Type Analysis:**
- "unemployment in [place]" ‚Üí Both agents can answer (DC=actual rate, CLM=spatial pattern)
- "unemployment distribution" ‚Üí PREFER CLM (spatial analysis capability)
- "how many unemployed" ‚Üí PREFER DC (population counts)
- "which counties have highest" ‚Üí PREFER CLM (county comparisons)
- Questions about California topics in CLM's 189 datasets ‚Üí PREFER CLM for spatial analysis

**Response Format (JSON):**
{
    "selected_agents": ["clm", "dc"],
    "strategy": "use_clm" | "use_dc" | "combine_both" | "prefer_clm" | "prefer_dc",
    "reasoning": "brief explanation",
    "combined_response": "your final answer - MUST include dataset name and clarify spatial vs demographic"
}

**Strategy Guidelines:**
- "use_clm": Only CLM provided useful data OR question specifically asks for distribution/spatial analysis
- "use_dc": Only DC provided useful data OR question asks for total counts/actual rates
- "combine_both": Both provide complementary info (CLM=spatial pattern, DC=actual rate/totals)
- "prefer_clm": Question asks for spatial patterns, distributions, or geographic variation
- "prefer_dc": Question asks for actual unemployment rate or population counts

**Example Response Formats:**

For "unemployment in San Diego":
WRONG: "Average unemployment is 56.9%"
RIGHT: "Using the CLM 'Unemployment' dataset, the spatial average across 30m pixels in San Diego County is 56.9%. This represents the geographic spread of unemployment rates, not the actual county unemployment rate. The actual unemployment rate (from DC/BLS) is 4.9% as of August 2025."

For "distribution of unemployment":
WRONG: "0-10%: 473,178 individuals"  
RIGHT: "Using the CLM 'Unemployment' dataset, the spatial distribution shows:
- 0-10% unemployment rate: 473,178 pixels (30m x 30m grid cells)
- 10-20%: 818,158 pixels
This shows where geographically unemployment rates are high or low, not population counts."

**Important Rules:**
1. ALWAYS mention CLM dataset name (e.g., "CLM 'Unemployment' dataset", "CLM 'Poverty' dataset")
2. ALWAYS clarify spatial average vs population rate when combining CLM and DC
3. If question asks "distribution" ‚Üí PREFER CLM (spatial analysis)
4. If question asks "rate" without "distribution" ‚Üí PREFER DC (actual demographic rate)
5. When using CLM data, explain it shows geographic patterns, not demographic totals
"""
        )
        return evaluator

    def _build_history_context(self) -> str:
        """Summarize the last three exchanges as a "Recent conversation" block.

        Returns an empty string when there is no history. Answers longer than
        150 characters are truncated with "...".
        """
        if not self.conversation_history:
            return ""

        context_lines = []
        for entry in self.conversation_history[-3:]:
            context_lines.append(f"Q: {entry['question']}")
            response = entry.get('final_response', '')
            if response and len(response) > 150:
                response = response[:150] + "..."
            context_lines.append(f"A: {response}\n")

        return "Recent conversation:\n" + "\n".join(context_lines)

    def _enhance_question_with_context(self, question: str) -> str:
        """
        Rewrite a short follow-up question into a standalone question.

        The topic is carried over from the previous question. For example,
        after "What is the population of San Diego?":

        - "What about Los Angeles?" -> "What is the population of Los Angeles, CA, USA?"
        - "How about Austin, Texas?" -> "What is the population of Austin, Texas?"
        - "LA too?" -> "What is the population of LA, CA, USA?"

        A location that contains no comma and does not mention California is
        assumed to be a California place.

        The question is returned unchanged when:
        - it does not look like a follow-up;
        - there is no history;
        - the previous topic is not recognized;
        - no concrete location remains (e.g. "there?").
        """
        lowered = question.lower()
        is_follow_up = any(re.search(pattern, lowered) for pattern in self._FOLLOW_UP_PATTERNS)
        if not is_follow_up or not self.conversation_history:
            return question

        last_question = self.conversation_history[-1]['question'].lower()
        context_hint = next(
            (hint for keywords, hint in self._TOPIC_HINTS
             if any(keyword in last_question for keyword in keywords)),
            "",
        )
        if not context_hint:
            return question

        # Isolate the location: drop the follow-up lead-in, trailing
        # punctuation, and a trailing "too"/"also".
        location = re.sub(r'^(what about|how about|and|also)\s+', '', question.strip(), flags=re.IGNORECASE)
        location = location.strip().rstrip(' ?!')
        location = re.sub(r'\s+(too|also)$', '', location, flags=re.IGNORECASE).strip()

        # Pronoun-only follow-ups ("there?", "that?") carry no usable location.
        if not location or location.lower() in ('there', 'that'):
            return question

        # Qualify bare place names for Data Commons. Anything that already
        # contains a comma (", CA", ", Texas", ...) is treated as qualified.
        if ',' not in location and 'california' not in location.lower():
            location = f"{location}, CA, USA"
        return f"What is the {context_hint} {location}?"

    @staticmethod
    def _agent_result(agent_name: str, success: bool, output: str,
                      error: Optional[str] = None, response: Any = None,
                      **extra: Any) -> Dict[str, Any]:
        """Build the per-agent result dict consumed by ``run``.

        Key order: agent_name, success, response, output, any extra keys
        (e.g. map_data / distribution_data for CLM), error.
        """
        return {
            'agent_name': agent_name,
            'success': success,
            'response': response,
            'output': output,
            **extra,
            'error': error,
        }

    async def _query_clm_safe(self, question: str, timeout: int = 120) -> Dict[str, Any]:
        """
        Query the CLM agent; never raises for ordinary errors.

        Returns a result dict with these keys:
        - ``agent_name``: ``'clm'``.
        - ``success``: whether the answer looks useful.
        - ``response``: the raw CLM result, or ``None`` on failure.
        - ``output``: always a ``str``.
        - ``map_data``, ``distribution_data``: taken from the CLM result.
        - ``error``: ``'timeout'``, the exception text, or ``None``.
        """
        try:
            from dataclasses import dataclass as dc

            # Fresh per-call context for the dataset the CLM agent discovers.
            @dc
            class AgentContext:
                current_coverage_id: Optional[str] = None
                current_dataset_info: Optional[dict] = None

            result = await asyncio.wait_for(
                self.clm_agent.run(question, timeout=timeout, deps=AgentContext()),
                timeout=timeout + 5,  # grace period beyond the agent's own timeout
            )
            raw_output = result.get('output', '')
            map_data = result.get('map_data')
            distribution_data = result.get('distribution_data')
        except asyncio.TimeoutError:
            return self._agent_result(
                'clm', False, f"CLM agent timed out after {timeout} seconds",
                error='timeout', map_data=None, distribution_data=None,
            )
        except Exception as e:
            return self._agent_result(
                'clm', False, f"CLM agent error: {str(e)}",
                error=str(e), map_data=None, distribution_data=None,
            )

        if isinstance(raw_output, str):
            lowered = raw_output.lower()
            success = not any(marker in lowered for marker in self._CLM_FAILURE_INDICATORS)
            output = raw_output
        else:
            # The raw object is kept in 'response'. 'output' is stringified
            # so previews, prompts and history truncation can slice it safely.
            success = True
            output = '' if raw_output is None else str(raw_output)

        # Pass through ALL data from the CLM agent, including distribution_data.
        return self._agent_result(
            'clm', success, output,
            response=result,
            map_data=map_data,
            distribution_data=distribution_data,
        )

    async def _query_dc_safe(self, question: str, timeout: int = 120) -> Dict[str, Any]:
        """
        Query the Data Commons agent; never raises for ordinary errors.

        Uses the module-level ``run_dc_query`` helper. Returns a result dict
        with ``agent_name='dc'``, ``success``, ``response`` (always None),
        ``output`` (a ``str``) and ``error``.
        """
        try:
            dc_result = await asyncio.wait_for(
                run_dc_query(self.dc_agent, question, self.dc_deps, verbose=False),
                timeout=timeout,
            )
        except asyncio.TimeoutError:
            return self._agent_result('dc', False, f"DC agent timed out after {timeout} seconds", error='timeout')
        except Exception as e:
            return self._agent_result('dc', False, f"DC agent error: {str(e)}", error=str(e))

        if isinstance(dc_result, str):
            lowered = dc_result.lower()
            success = not any(marker in lowered for marker in self._DC_FAILURE_INDICATORS)
            return self._agent_result('dc', success, dc_result)
        return self._agent_result('dc', True, str(dc_result))

    @staticmethod
    def _extract_evaluation_json(text: str) -> Optional[Dict[str, Any]]:
        """Return the first JSON object in ``text`` that has a ``combined_response`` key.

        Every ``{`` is tried as the start of a JSON value. This handles nested
        braces, braces inside string values, and surrounding prose or code
        fences, none of which a ``\\{[^}]+\\}`` regex can cope with.
        Returns ``None`` if no such object exists.
        """
        decoder = json.JSONDecoder()
        start = text.find('{')
        while start != -1:
            try:
                candidate, _ = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                pass
            else:
                if isinstance(candidate, dict) and 'combined_response' in candidate:
                    return candidate
            start = text.find('{', start + 1)
        return None

    @staticmethod
    def _normalize_evaluation(evaluation: Dict[str, Any]) -> Dict[str, Any]:
        """Fill in and coerce the evaluator fields that ``run`` reads.

        Guards against models that omit keys or return a bare string for
        ``selected_agents``.
        """
        selected = evaluation.get('selected_agents') or []
        if not isinstance(selected, (list, tuple)):
            selected = [selected]
        evaluation['selected_agents'] = [str(agent) for agent in selected]
        evaluation['strategy'] = evaluation.get('strategy') or 'unknown'
        evaluation['reasoning'] = evaluation.get('reasoning') or ''
        response = evaluation['combined_response']
        if not isinstance(response, str):
            evaluation['combined_response'] = str(response)
        return evaluation

    @staticmethod
    def _fallback_evaluation(agent_results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Evaluation used when the evaluator returns no usable JSON.

        Uses the first successful agent's output, or a "could not find"
        message if every agent failed.
        """
        successful = [r for r in agent_results if r['success']]
        if successful:
            first = successful[0]
            return {
                'selected_agents': [first['agent_name']],
                'strategy': f"use_{first['agent_name']}",
                'reasoning': 'Fallback: using first successful response',
                'combined_response': first['output']
            }
        return {
            'selected_agents': [],
            'strategy': 'none',
            'reasoning': 'No agent provided useful response',
            'combined_response': 'I could not find information to answer this question. Both agents were unable to retrieve relevant data.'
        }

    @staticmethod
    def _collect_visual_data(agent_results: List[Dict[str, Any]]) -> tuple:
        """Pick map and distribution payloads from ANY agent, successful or not.

        The last agent carrying ``map_data`` wins. For ``distribution_data``,
        the payload with the most entries in its ``data`` list wins.

        Returns:
            ``(map_data, distribution_data)``; either may be ``None``.
        """
        map_data = None
        distribution_data = None
        for result in agent_results:
            if result.get('map_data'):
                map_data = result['map_data']
                print(f"  Found map_data from {result['agent_name']}")

            dist = result.get('distribution_data')
            if dist:
                if distribution_data is None or len(dist.get('data', [])) > len(distribution_data.get('data', [])):
                    distribution_data = dist
                    print(f"  Found distribution_data from {result['agent_name']} with {len(dist.get('data', []))} data points")
        return map_data, distribution_data

    async def run(self, question: str, timeout: int = 120) -> Dict[str, Any]:
        """
        Query all agents in parallel and intelligently combine results.

        Args:
            question: User's question
            timeout: Timeout per agent in seconds

        Returns:
            Dict with keys ``output``, ``map_data``, ``distribution_data``,
            ``routing``, ``source_agent``, ``evaluation`` and
            ``all_agent_results``.
        """
        # Summary of earlier turns, given to the evaluator so it can interpret
        # follow-up questions.
        history_context = self._build_history_context()

        # Follow-ups are expanded for DC only; CLM receives the user's wording.
        enhanced_question = self._enhance_question_with_context(question)
        if enhanced_question != question:
            print(f"üìù Enhanced question: {enhanced_question}")

        print(f"üîÑ Querying CLM and DC agents in parallel...")

        # Both helpers convert their own errors into result dicts.
        agent_results = await asyncio.gather(
            self._query_clm_safe(question, timeout),
            self._query_dc_safe(enhanced_question, timeout),
        )
        clm_result, dc_result = agent_results

        for result in agent_results:
            status = "‚úì" if result['success'] else "‚úó"
            output_preview = str(result['output'])[:100] if result['output'] else "No output"
            print(f"  {status} {result['agent_name'].upper()}: {output_preview}...")

        history_block = f"{history_context}\n" if history_context else ""
        eval_input = f"""{history_block}Question: {question}

Agent Responses:

**CLM Agent:**
Success: {clm_result['success']}
Response: {clm_result['output']}

**DC Agent:**
Success: {dc_result['success']}
Response: {dc_result['output']}

Based on these responses, evaluate which agent(s) successfully answered the question and provide a combined response.
"""

        try:
            eval_result = await self.evaluator.run(
                eval_input,
                deps=ParallelContext(question=question, agent_responses=agent_results)
            )
            eval_output = eval_result.output if hasattr(eval_result, 'output') else str(eval_result)

            evaluation = self._extract_evaluation_json(str(eval_output))
            if evaluation is None:
                # No usable JSON from the evaluator: use the first successful agent.
                evaluation = self._fallback_evaluation(agent_results)
            else:
                evaluation = self._normalize_evaluation(evaluation)

            print(f"üìä Strategy: {evaluation['strategy']}")
            print(f"   Selected: {', '.join(evaluation['selected_agents'])}")
            print(f"   Reasoning: {evaluation['reasoning']}")

            map_data, distribution_data = self._collect_visual_data(agent_results)

            if distribution_data:
                print(f"  Final distribution_data has {len(distribution_data.get('data', []))} data points")
                if 'data' in distribution_data:
                    counties_in_dist = {d.get('name') for d in distribution_data['data'] if 'name' in d}
                    print(f"  Counties in distribution: {counties_in_dist}")
            else:
                print("  No distribution_data collected")

            self.conversation_history.append({
                'question': question,
                'enhanced_question': enhanced_question if enhanced_question != question else None,
                'agent_results': agent_results,
                'evaluation': evaluation,
                'final_response': evaluation['combined_response']
            })

            return {
                'output': evaluation['combined_response'],
                'map_data': map_data,
                'distribution_data': distribution_data,
                'routing': {
                    'agent': evaluation['strategy'],
                    'confidence': 1.0 if evaluation['selected_agents'] else 0.0,
                    'reasoning': evaluation['reasoning']
                },
                'source_agent': evaluation['strategy'],
                'evaluation': evaluation,
                'all_agent_results': agent_results
            }

        except Exception as e:
            # The evaluator failed (API error, unexpected result shape, ...):
            # degrade to the first successful agent's raw answer.
            successful = [r for r in agent_results if r['success']]
            if successful:
                fallback_output = successful[0]['output']
                fallback_agent = successful[0]['agent_name']
            else:
                fallback_output = f"Error in evaluation: {str(e)}\n\nAgent responses were:\nCLM: {clm_result['output'][:200]}\nDC: {dc_result['output'][:200]}"
                fallback_agent = 'error'

            return {
                'output': fallback_output,
                'map_data': clm_result.get('map_data') if clm_result['success'] else None,
                'distribution_data': clm_result.get('distribution_data') if clm_result['success'] else None,
                'routing': {
                    'agent': fallback_agent,
                    'confidence': 0.5,
                    'reasoning': f'Evaluation error, using fallback: {str(e)}'
                },
                'source_agent': fallback_agent,
                'evaluation': None,
                'all_agent_results': agent_results
            }

    def get_history_summary(self) -> str:
        """Return one numbered line per past question, tagged with the agents used."""
        if not self.conversation_history:
            return "No conversation history yet."

        summary = []
        for i, entry in enumerate(self.conversation_history, 1):
            eval_info = entry.get('evaluation') or {}
            agents_used = ', '.join(eval_info.get('selected_agents', ['unknown']))
            q = entry['question']
            if entry.get('enhanced_question'):
                q = f"{q} ‚Üí {entry['enhanced_question']}"
            summary.append(f"{i}. [{agents_used.upper()}] {q[:80]}...")

        return "\n".join(summary)

print(
    "‚úì Parallel coordination system module loaded!",
    "‚úì Ready to replace EnhancedCoordinatedAgentSystem",
    sep="\n",
)

# Wire the CLM and Data Commons agents into the parallel coordinator, using
# the model family chosen by MODEL for the evaluator.
coordinated_system = ParallelCoordinatedSystem(
    model_name=MODEL,
    dc_deps=dc_deps,
    dc_agent=dc_agent,
    clm_agent=clm_agent,
)

In [None]:
import asyncio
import os
import nest_asyncio
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.openai import OpenAIChatModel
from fastmcp import Client
from typing import Optional
from dataclasses import dataclass
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output
from datetime import datetime

# Enable nested asyncio for Jupyter: the notebook kernel already runs an event
# loop, and nest_asyncio.apply() patches asyncio so that loop can be re-entered
# (e.g. running coroutines from code executing inside that loop).
nest_asyncio.apply()
print("‚úì Imports successful")

In [None]:
"""
Complete Chat Interface for Coordinated Multi-Agent System
Cell 6 - Copy this entire code block
"""

from IPython.display import Javascript, display, clear_output, HTML
import ipywidgets as widgets
from datetime import datetime
import folium
from folium import WmsTileLayer
import json
import matplotlib.pyplot as plt
import io
import base64
import markdown
import html as html_module
import asyncio


class CoordinatedChatInterface:
    """Chat interface that works with the coordinated multi-agent system."""
    
    def __init__(self, coordinated_system):
        """
        Build the chat widgets, wire the button handlers and show a welcome message.

        Args:
            coordinated_system: ParallelCoordinatedSystem instance that answers questions
        """
        self.system = coordinated_system
        self.messages_container = []

        # Scrollable transcript that fills most of the viewport height.
        self.output_area = widgets.VBox(layout=widgets.Layout(
            border='1px solid #ddd',
            height='calc(100vh - 350px)',
            min_height='400px',
            overflow_y='auto',
            padding='10px',
            margin='0 0 10px 0',
        ))

        self.input_box = widgets.Textarea(
            placeholder='Ask about California landscape metrics or general data (population, demographics, etc.)...',
            layout=widgets.Layout(width='100%', height='100px', margin='10px 0'),
        )

        def make_button(label, style, icon_name):
            # All action buttons share the same width and right-hand spacing.
            return widgets.Button(
                description=label,
                button_style=style,
                icon=icon_name,
                layout=widgets.Layout(width='100px', margin='0 5px 0 0'),
            )

        self.send_button = make_button('Send', 'primary', 'paper-plane')
        self.clear_button = make_button('Clear', 'warning', 'trash')
        self.history_button = make_button('History', 'info', 'list')

        self.status_label = widgets.HTML(
            value="‚úÖ Ready",
            layout=widgets.Layout(margin='0 0 0 10px'),
        )

        for button, handler in (
            (self.send_button, self.on_send_clicked),
            (self.clear_button, self.on_clear_clicked),
            (self.history_button, self.on_history_clicked),
        ):
            button.on_click(handler)

        controls = widgets.HBox([
            self.send_button,
            self.clear_button,
            self.history_button,
            self.status_label,
        ])

        header = widgets.HTML(value="""
                <h3>ü§ñ Multi-Agent Assistant</h3>
                <p style='color: #666; font-size: 0.9em;'>
                    <strong>CLM Agent:</strong> California landscape metrics (fire, carbon, vegetation, poverty, unemployment)<br>
                    <strong>DC Agent:</strong> Demographics, economics, health data (any location)
                </p>
            """)

        self.interface = widgets.VBox(
            [header, self.output_area, self.input_box, controls],
            layout=widgets.Layout(width='100%', max_width='1200px', margin='0 auto'),
        )

        welcome = (
            "Welcome! I can help with:\n\n"
            "**California Landscape Metrics (CLM):**\n"
            "- Burn probability, carbon turnover, vegetation data\n"
            "- Poverty, unemployment (spatial patterns)\n"
            "- Maps and visualizations\n"
            "- County statistics and comparisons\n\n"
            "**General Data (via Data Commons):**\n"
            "- Population, demographics, income\n"
            "- Any location worldwide\n"
            "- Economic and social indicators\n\n"
            "**Examples:**\n"
            "- What is the carbon turnover time in Los Angeles? (CLM)\n"
            "- What is the population of San Diego? (Data Commons)\n"
            "- Show me a burn probability map for California (CLM)\n"
            "- What is the median income in Austin, Texas? (Data Commons)\n"
            "- Show distribution of unemployment in San Diego and Los Angeles (CLM)"
        )
        self._add_message(welcome, "system")
    
    def _get_wms_bounds(self, wms_url, layer_name):
        """Get WMS layer bounds."""
        california_bounds = [[32.5, -124.5], [42.0, -114.0]]
        try:
            import requests
            from xml.etree import ElementTree as ET
            
            params = {
                'service': 'WMS',
                'version': '1.1.0',
                'request': 'GetCapabilities'
            }
            response = requests.get(wms_url + '/wms', params=params, timeout=10)
            
            if response.status_code == 200:
                root = ET.fromstring(response.content)
                for layer in root.iter('Layer'):
                    name_elem = layer.find('Name')
                    if name_elem is not None and name_elem.text == layer_name:
                        bbox = layer.find('LatLonBoundingBox')
                        if bbox is not None:
                            minx = float(bbox.get('minx'))
                            miny = float(bbox.get('miny'))
                            maxx = float(bbox.get('maxx'))
                            maxy = float(bbox.get('maxy'))
                            return [[miny, minx], [maxy, maxx]]
        except:
            pass
        
        return california_bounds
    
    def _get_legend_url(self, wms_url, layer_name, style_name=None):
        """Generate WMS legend URL."""
        from urllib.parse import urlencode
        
        params = {
            'service': 'WMS',
            'version': '1.1.0',
            'request': 'GetLegendGraphic',
            'layer': layer_name,
            'format': 'image/png',
            'width': '20',
            'height': '20',
            'legend_options': 'fontAntiAliasing:true;fontSize:10;fontName:Arial;dx:5;absoluteMargins:true'
        }
        
        if style_name:
            params['style'] = style_name
        
        return f"{wms_url}/wms?{urlencode(params)}"
    
    def _create_map(self, map_data, style_name=None):
        """Build a Folium map showing a GeoServer WMS layer with a legend overlay.

        Args:
            map_data: Dict with ``wms_base_url`` and ``wms_layer_name`` and,
                optionally, ``title`` and ``units``.
            style_name: Optional WMS style, applied to both the tiles and the legend.

        Returns:
            A ``folium.Map``, or ``None`` if construction failed (the error is printed).
        """
        from html import escape

        try:
            # Strip a trailing slash so appending '/wms' never yields '//wms'.
            wms_url = (map_data.get('wms_base_url') or '').rstrip('/')
            layer_name = map_data.get('wms_layer_name', '')
            title = map_data.get('title', 'Dataset')

            bounds = self._get_wms_bounds(wms_url, layer_name)
            center_lat = (bounds[0][0] + bounds[1][0]) / 2
            center_lon = (bounds[0][1] + bounds[1][1]) / 2

            m = folium.Map(
                location=[center_lat, center_lon],
                tiles='OpenStreetMap',
                control_scale=True
            )
            m.fit_bounds(bounds)

            if wms_url and layer_name:
                wms_params = {
                    'url': wms_url + '/wms',
                    'layers': layer_name,
                    'name': title,
                    'fmt': 'image/png',
                    'transparent': True,
                    'overlay': True,
                    'control': True,
                    'version': '1.1.0'
                }

                if style_name:
                    wms_params['styles'] = style_name

                wms = WmsTileLayer(**wms_params)
                wms.add_to(m)
                folium.LayerControl().add_to(m)

                # Both values end up inside raw HTML. Escape them so units such
                # as "<30%" or '&' in the URL cannot break or inject markup.
                legend_url = escape(self._get_legend_url(wms_url, layer_name, style_name))
                units = map_data.get('units', '')
                unit_text = f'<p style="margin: 5px 0 0 0; font-size: 11px; color: #333; text-align: center;">Units: {escape(str(units))}</p>' if units else ''

                legend_html = f'''
                <div style="position: fixed; top: 10px; right: 10px; background-color: white; 
                            z-index: 9999; padding: 10px; border: 2px solid grey; border-radius: 5px; 
                            box-shadow: 2px 2px 6px rgba(0,0,0,0.3);">
                    <img src="{legend_url}" alt="Legend" style="display: block;">
                    {unit_text}
                </div>
                '''
                m.get_root().html.add_child(folium.Element(legend_html))

            return m
        except Exception as e:
            print(f"Error creating map: {e}")
            return None
    
    def _create_distribution_chart(self, distribution_data):
        """Create distribution chart - FIXED to show all counties."""
        try:
            data = distribution_data.get('data', [])
            dist_type = distribution_data.get('distribution_type', 'continuous')
            dataset_info = distribution_data.get('dataset_info', {})
            title = dataset_info.get('title', 'Value Distribution')
            units = dataset_info.get('units', '')
            filter_column = 'name'  # Column name for county identifier
            
            if not data:
                print("No distribution data to plot")
                return None
            
            # Debug: Print data structure
            print(f"Distribution data type: {dist_type}")
            print(f"Number of data points: {len(data)}")
            print(f"Sample data point: {data[0] if data else 'None'}")
            
            fig, ax = plt.subplots(figsize=(12, 6), dpi=100)
            
            # Get unique counties from the data
            counties = sorted(list(set([d.get(filter_column) for d in data if filter_column in d])))
            
            if not counties:
                print("No counties found in data")
                return None
            
            print(f"Counties found: {counties}")
            
            # Use distinct colors for each county
            colors = plt.cm.tab10(range(len(counties)))
            
            if dist_type == 'categorical':
                import numpy as np
                # Get all unique values across all counties
                values = sorted(list(set([d['value'] for d in data])))
                x = np.arange(len(values))
                width = 0.8 / len(counties) if len(counties) > 1 else 0.5
                
                for i, county in enumerate(counties):
                    county_data = [d for d in data if d.get(filter_column) == county]
                    counts = []
                    for val in values:
                        matching = [d['count'] for d in county_data if d['value'] == val]
                        counts.append(matching[0] if matching else 0)
                    
                    offset = (i - len(counties)/2) * width + width/2
                    ax.bar(x + offset, counts, width, label=county, alpha=0.7, color=colors[i])
                
                ax.set_xlabel('Value', fontsize=11)
                ax.set_ylabel('Count (pixels)', fontsize=11)
                ax.set_xticks(x)
                ax.set_xticklabels([str(v) for v in values])
                ax.legend(fontsize=10, loc='best')
                ax.set_title(f'{title}\nCategorical Distribution', fontsize=12, fontweight='bold', pad=10)
                
            else:  # continuous
                bins = distribution_data.get('bins', [])
                if not bins:
                    print("No bins found in distribution data")
                    return None
                
                print(f"Bins: {bins}")
                
                # Calculate bin centers for x-axis positioning
                bin_centers = [(bins[i] + bins[i+1]) / 2 for i in range(len(bins)-1)]
                bin_width = bins[1] - bins[0] if len(bins) > 1 else 1
                
                # Adjust bar width based on number of counties
                bar_width = bin_width * 0.8 / len(counties) if len(counties) > 1 else bin_width * 0.7
                
                for i, county in enumerate(counties):
                    # Get data for this county
                    county_data = [d for d in data if d.get(filter_column) == county]
                    
                    # Sort by bin_index to ensure correct order
                    county_data = sorted(county_data, key=lambda x: x.get('bin_index', 0))
                    
                    print(f"County {county}: {len(county_data)} data points")
                    
                    # Extract counts
                    counts = [d['count'] for d in county_data]
                    
                    # Calculate bar positions
                    if len(counties) > 1:
                        # Offset bars for multiple counties
                        offset = (i - len(counties)/2) * bar_width + bar_width/2
                        positions = [bc + offset for bc in bin_centers]
                    else:
                        # Center bars for single county
                        positions = bin_centers
                    
                    # Plot bars
                    ax.bar(positions, counts, bar_width, label=county, alpha=0.7, color=colors[i])
                
                # Set labels
                xlabel = f'Value Range ({units})' if units else 'Value Range'
                ax.set_xlabel(xlabel, fontsize=11)
                ax.set_ylabel('Count (pixels)', fontsize=11)
                
                # Add legend if multiple counties
                if len(counties) > 1:
                    ax.legend(fontsize=10, loc='best')
                
                ax.set_title(f'{title}\nValue Distribution', fontsize=12, fontweight='bold', pad=10)
            
            ax.grid(True, alpha=0.3, linestyle='--')
            plt.tight_layout()
            
            # Save to buffer
            buf = io.BytesIO()
            plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
            buf.seek(0)
            img_base64 = base64.b64encode(buf.read()).decode('utf-8')
            plt.close(fig)
            
            print(f"Chart created successfully with {len(counties)} counties")
            return img_base64
            
        except Exception as e:
            print(f"Error creating chart: {e}")
            import traceback
            traceback.print_exc()
            return None
    
    def _add_message(self, text, role="user", map_data=None, distribution_data=None, routing_info=None):
        """Append one chat message, plus any attached chart or map, to the transcript.

        Args:
            text: Message body. For ``role="assistant"`` it is rendered as
                Markdown, falling back to escaped text if rendering fails.
                Other roles are HTML-escaped, with newlines turned into ``<br>``.
            role: ``"user"`` or ``"assistant"``; any other value is shown as
                "System".
            map_data: Optional dict handed to ``self._create_map``. Its
                ``wms_layer_name`` selects the ``<layer>_std`` WMS style.
                Ignored when ``distribution_data`` is truthy.
            distribution_data: Optional dict handed to
                ``self._create_distribution_chart`` and rendered as an inline PNG.
            routing_info: Optional dict with ``agent`` and ``confidence`` keys,
                shown as a coloured badge next to the sender label.
        """
        timestamp = datetime.now().strftime("%H:%M:%S")

        if role == "user":
            color = "#007bff"
            icon = "👤"
            label = "You"
            bg_color = "#e7f3ff"
        elif role == "assistant":
            color = "#28a745"
            icon = "🤖"
            label = "Assistant"
            bg_color = "#e8f5e9"
        else:
            color = "#6c757d"
            icon = "ℹ️"
            label = "System"
            bg_color = "#f8f9fa"

        # Routing badge: which agent handled the question and the router's confidence.
        routing_badge = ""
        if routing_info:
            # Tolerate a missing/None agent or a non-numeric confidence, rather than
            # letting a malformed routing payload abort rendering of the whole message.
            agent = str(routing_info.get('agent') or 'unknown').upper()
            try:
                confidence_text = f"{float(routing_info.get('confidence') or 0):.0%}"
            except (TypeError, ValueError):
                confidence_text = "n/a"
            agent_colors = {
                'CLM': '#ff6b6b',
                'DC': '#4ecdc4',
                'BOTH': '#95e1d3',
                'USE_CLM': '#ff6b6b',
                'USE_DC': '#4ecdc4',
                'COMBINE_BOTH': '#95e1d3',
                'PREFER_CLM': '#ff9999',
                'PREFER_DC': '#7dd3c0'
            }
            badge_color = agent_colors.get(agent, '#999')
            routing_badge = f"""
                <span style='background-color: {badge_color}; color: white; padding: 2px 8px; 
                             border-radius: 12px; font-size: 0.75em; font-weight: bold; margin-left: 8px;'>
                    {html_module.escape(agent)} ({confidence_text})
                </span>
            """

        # Convert markdown for assistant
        if role == "assistant":
            try:
                # NOTE(review): Python-Markdown passes raw HTML in `text` through
                # unchanged; consider sanitising model output before display.
                html_content = markdown.markdown(str(text), extensions=['extra', 'nl2br', 'sane_lists'])
            except Exception:
                html_content = html_module.escape(str(text)).replace('\n', '<br>')
        else:
            html_content = html_module.escape(str(text)).replace('\n', '<br>')

        message_html = widgets.HTML(
            value=f"""
            <div style='margin: 10px 0; padding: 12px; background-color: {bg_color}; 
                        border-radius: 8px; border-left: 4px solid {color}; box-shadow: 0 1px 3px rgba(0,0,0,0.1);'>
                <div style='display: flex; justify-content: space-between; margin-bottom: 8px;'>
                    <div>
                        <strong style='color: {color};'>{icon} {label}</strong>
                        {routing_badge}
                    </div>
                    <span style='color: #999; font-size: 0.85em;'>{timestamp}</span>
                </div>
                <div style='line-height: 1.6;'>{html_content}</div>
            </div>
            """
        )

        self.messages_container.append(message_html)

        # Add visualizations: a distribution chart takes precedence over a map.
        if distribution_data:
            img_base64 = self._create_distribution_chart(distribution_data)
            if img_base64:
                chart_html = f"""
                <div style='width: 98%; margin: 10px 0;'>
                    <div style='width: 100%; border: 1px solid #ddd; border-radius: 8px; 
                                padding: 10px; background-color: white;'>
                        <img src="data:image/png;base64,{img_base64}" 
                             style="width: 100%; height: auto;" alt="Distribution Chart">
                    </div>
                </div>
                """
                self.messages_container.append(widgets.HTML(value=chart_html))

        elif map_data:
            layer_name = map_data.get('wms_layer_name', '')
            style_name = f"{layer_name}_std" if layer_name else None
            folium_map = self._create_map(map_data, style_name=style_name)

            if folium_map:
                map_html = f"""
                <div style='width: 98%; margin: 10px 0;'>
                    <div style='width: 100%; height: 300px; border: 1px solid #ddd; 
                                border-radius: 8px; overflow: hidden;'>
                        {folium_map._repr_html_()}
                    </div>
                </div>
                """
                self.messages_container.append(widgets.HTML(value=map_html))

        # Re-publish the full transcript so the output widget shows the new entries.
        self.output_area.children = tuple(self.messages_container)
    
    def on_send_clicked(self, button):
        """Send the typed question to the coordinated system and render the reply.

        Blank input is ignored. While the request runs, the input box and send
        button are disabled; the ``finally`` clause always re-enables them.
        The request is bounded by a 180-second timeout. Timeouts and other
        errors are reported as system messages rather than raised.

        Args:
            button: The widget that triggered the callback (unused).
        """
        question = self.input_box.value.strip()
        if not question:
            return

        self._add_message(question, "user")
        self.input_box.value = ""
        self.send_button.disabled = True
        self.input_box.disabled = True
        self.status_label.value = "<span style='color: orange;'>⏳ Processing...</span>"

        try:
            # NOTE(review): inside Jupyter the event loop is already running, so this
            # presumably relies on nest_asyncio having been applied earlier; confirm.
            result = asyncio.get_event_loop().run_until_complete(
                asyncio.wait_for(self.system.run(question), timeout=180)
            )

            answer = result.get('output', 'No response')
            routing_info = result.get('routing')
            map_data = result.get('map_data')
            distribution_data = result.get('distribution_data')

            self._add_message(
                answer,
                "assistant",
                map_data=map_data,
                distribution_data=distribution_data,
                routing_info=routing_info
            )
            self.status_label.value = "<span style='color: green;'>✅ Ready</span>"

        except asyncio.TimeoutError:
            self._add_message("Request timed out after 3 minutes.", "system")
            self.status_label.value = "<span style='color: red;'>❌ Timeout</span>"
        except Exception as e:
            import traceback
            error_msg = f"Error: {str(e)}\n\n{traceback.format_exc()}"
            self._add_message(error_msg, "system")
            self.status_label.value = "<span style='color: red;'>❌ Error</span>"
        finally:
            self.send_button.disabled = False
            self.input_box.disabled = False
    
    def on_clear_clicked(self, button):
        """Wipe the transcript and post a short system notice in its place.

        Args:
            button: The widget that triggered the callback (unused).
        """
        fresh_transcript = []
        self.messages_container = fresh_transcript
        self.output_area.children = tuple(fresh_transcript)

        notice = "Chat cleared. Ask about California landscape metrics or general data!"
        self._add_message(notice, "system")
    
    def on_history_clicked(self, button):
        """Post the coordinated system's conversation summary as a system message.

        Args:
            button: The widget that triggered the callback (unused).
        """
        heading = "**Conversation History:**"
        summary = self.system.get_history_summary()
        self._add_message(f"{heading}\n\n{summary}", "system")
    
    def display(self):
        """Clear the cell output, lift the notebook's output-height limits, and show the UI."""
        # CSS that stops JupyterLab / classic Notebook from clamping tall outputs.
        unclamp_css = """
        <style>
            .jp-Cell-outputArea { max-height: none !important; }
            .output_scroll { max-height: none !important; overflow-y: visible !important; }
        </style>
        """
        clear_output(wait=True)
        # `display` here resolves to the module-level display function, not this method.
        display(HTML(unclamp_css))
        display(self.interface)


# Create and display the chat interface, wired to the coordinated system built above.
chat = CoordinatedChatInterface(coordinated_system)
chat.display()

print("✓ Chat interface ready!")