# California Landscape Metrics Analysis - Interactive Chat Interface

This notebook provides an interactive chat interface for querying the California Landscape Metrics datasets through an LLM agent.

## Setup and Imports

In [None]:
# Install required packages if needed
!pip install ipywidgets pydantic-ai fastmcp openai nest-asyncio

In [None]:
import asyncio
import os
import nest_asyncio
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.openai import OpenAIChatModel
from fastmcp import Client
from typing import Optional
from dataclasses import dataclass
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output
from datetime import datetime

# Enable nested asyncio for Jupyter: the notebook kernel already runs an event
# loop, and the chat interface below calls run_until_complete() from inside it.
nest_asyncio.apply()
print("✓ Imports successful")

## Configuration

In [None]:
# Base configuration for GeoServer.
# This dict is spread (**BASE_CONFIG) into the arguments of every
# "compute_zonal_stats" and "zonal_count" MCP tool call made by the agent tools.
# NOTE(review): the per-key meanings below are inferred from the key names and
# values — confirm against the MCP server's tool schema.
BASE_CONFIG = {
    "wcs_base_url": "https://sparcal.sdsc.edu/geoserver",  # presumably the raster (WCS) endpoint
    "wfs_base_url": "https://sparcal.sdsc.edu/geoserver/boundary/wfs",  # presumably the boundary (WFS) endpoint
    "feature_id": "boundary:ca_counties",  # presumably the county boundary layer
    "filter_column": "name"  # presumably the column matched against the "filter_value" county names
}

# Initialize MCP client.
# A single client object is shared by all agent tools; each tool enters it with
# `async with mcp_client:` around its call.
mcp_client = Client("https://wenokn.fastmcp.app/mcp")
print("✓ Configuration set")

## Set API Keys

In [None]:
from dotenv import load_dotenv
import os

# Read key/value pairs from a local .env file into the process environment
load_dotenv()

# Alternatively, set the OpenAI API key inline by uncommenting the next line
# os.environ['OPENAI_API_KEY'] = 'Your API KEY'

# Report whether a key is available before the agent is created
if os.getenv('OPENAI_API_KEY'):
    print("✓ Using OpenAI GPT-4o-mini")
else:
    print("⚠️ Warning: OPENAI_API_KEY not set!")
    print("Please set it using: os.environ['OPENAI_API_KEY'] = 'your-key'")

## Agent Setup

In [None]:
@dataclass
class AgentContext:
    """Per-run state shared between the agent's tools.

    ``search_and_select_dataset`` fills in both fields when it selects a
    dataset; the statistics and threshold tools read them and return an
    error payload while ``current_coverage_id`` is still unset.
    """
    # 'wcs_coverage_id' of the selected dataset; sent as "wcs_coverage_id" in later tool calls.
    current_coverage_id: Optional[str] = None
    # Full record of the selected dataset as returned by the dataset search;
    # its 'title' and 'data_units' entries are echoed back in tool results.
    current_dataset_info: Optional[dict] = None


def create_agent():
    """Create and configure the Pydantic AI agent.

    The agent uses ``AgentContext`` as its dependency type and registers four
    tools that forward to the module-level ``mcp_client``:

    * ``search_and_select_dataset`` - calls the "search_datasets" MCP tool and
      stores the top hit in the run's ``AgentContext``.
    * ``get_county_statistics`` - calls "compute_zonal_stats" for the selected
      dataset.
    * ``get_area_above_threshold`` / ``get_area_below_threshold`` - call
      "zonal_count" and convert pixel counts into a percentage and an area.

    Returns:
        The configured ``Agent`` instance.
    """

    agent = Agent(
        model='openai:gpt-4o-mini',
        deps_type=AgentContext,
        retries=2,
        system_prompt="""You are an expert in analyzing California Landscape Metrics datasets.

You have access to these tools:
1. search_and_select_dataset: Search for the most relevant dataset based on the question
2. get_county_statistics: Compute statistics for one or more counties
3. get_area_above_threshold: Calculate percentage/area above a threshold for a single county
4. get_area_below_threshold: Calculate percentage/area below a threshold for a single county

**CRITICAL RULE:**
When calling search_and_select_dataset, you MUST pass the user's EXACT question word-for-word as the 'question' parameter. 
DO NOT paraphrase, summarize, or modify the user's question in any way.

**WORKFLOW:**
1. ALWAYS start by calling search_and_select_dataset with the user's exact question
2. Once the dataset is selected, use the other tools to answer the question
3. Include the dataset name and units in your final answer for context

**For statistical questions:**
- Use get_county_statistics for mean, median, min, max, std
- Pass counties=None to get all counties
- For rankings, get all counties and sort results

**For threshold questions:**
- Use get_area_above_threshold or get_area_below_threshold
- These work on single counties only

**Answer format:**
- Be precise with numbers and include units from the dataset
- Provide clear, concise answers
- Mention the dataset being used
""",
    )

    # NOTE: Pydantic AI uses each tool's docstring as the description shown to
    # the model, so the tool docstrings below are deliberately left short and
    # unchanged; implementation notes live in comments instead.

    def _no_dataset_selected() -> dict:
        # Error payload returned by every tool that needs a selected dataset.
        return {
            'success': False,
            'error': 'No dataset selected. Call search_and_select_dataset first.'
        }

    def _dataset_summary(ctx: RunContext[AgentContext]) -> dict:
        # Title/units of the selected dataset, attached to successful results.
        # Tolerates a context that has a coverage id but no dataset record.
        info = ctx.deps.current_dataset_info or {}
        return {
            'title': info.get('title'),
            'units': info.get('data_units')
        }

    async def _threshold_area(
        ctx: RunContext[AgentContext],
        county: str,
        threshold: float,
        *,
        below: bool
    ) -> dict:
        # Shared implementation of the two threshold tools: one "zonal_count"
        # call, then pixel counts -> percentage of valid pixels and area in km^2.
        if not ctx.deps.current_coverage_id:
            return _no_dataset_selected()

        async with mcp_client:
            result = await mcp_client.call_tool(
                "zonal_count",
                {
                    **BASE_CONFIG,
                    "wcs_coverage_id": ctx.deps.current_coverage_id,
                    "filter_value": county,
                    "threshold": threshold
                }
            )

            data = result.data
            if not data.get('success'):
                # Pass the server's failure payload through untouched.
                return data

            counts = data['data']
            valid = counts['valid_pixels']
            above = counts['above_threshold_pixels']
            pixel_area = counts['pixel_area_square_meters']

            # "Below" is every valid pixel that was not counted as above.
            selected = (valid - above) if below else above

            percentage = (selected / valid * 100) if valid > 0 else 0
            area_sq_km = (selected * pixel_area) / 1_000_000

            return {
                'success': True,
                'county': county,
                'threshold': threshold,
                'percentage': percentage,
                'area_square_km': area_sq_km,
                'dataset_info': _dataset_summary(ctx)
            }

    @agent.tool
    async def search_and_select_dataset(
        ctx: RunContext[AgentContext],
        question: str,
        top_k: int = 3
    ) -> dict:
        """Search for and select the most relevant dataset."""
        async with mcp_client:
            result = await mcp_client.call_tool(
                "search_datasets",
                {"query": question, "top_k": top_k}
            )

            data = result.data
            if data.get('success') and data.get('datasets'):
                # The first hit is taken as the selection and remembered for
                # the follow-up tools of this run.
                best_dataset = data['datasets'][0]
                ctx.deps.current_coverage_id = best_dataset['wcs_coverage_id']
                ctx.deps.current_dataset_info = best_dataset

                return {
                    'success': True,
                    'selected_dataset': best_dataset,
                    # Slicing already yields [] when there is only one hit.
                    'alternatives': data['datasets'][1:],
                    'message': f"Selected: {best_dataset['title']}"
                }

            return {
                'success': False,
                'message': 'No suitable datasets found',
                'error': data.get('error', 'Unknown error')
            }

    @agent.tool
    async def get_county_statistics(
        ctx: RunContext[AgentContext],
        counties: Optional[list[str]] = None,
        stats: Optional[list[str]] = None
    ) -> dict:
        """Get statistics for one or more counties."""
        if not ctx.deps.current_coverage_id:
            return _no_dataset_selected()

        if stats is None:
            stats = ["mean", "median", "min", "max", "std"]

        async with mcp_client:
            result = await mcp_client.call_tool(
                "compute_zonal_stats",
                {
                    **BASE_CONFIG,
                    "wcs_coverage_id": ctx.deps.current_coverage_id,
                    # None is forwarded as-is; the system prompt tells the
                    # model to pass counties=None to get all counties.
                    "filter_value": counties,
                    "stats": stats,
                    "max_workers": 8
                }
            )

            response = result.data
            if response.get('success'):
                response['dataset_info'] = _dataset_summary(ctx)
            return response

    @agent.tool
    async def get_area_above_threshold(
        ctx: RunContext[AgentContext],
        county: str,
        threshold: float
    ) -> dict:
        """Calculate percentage and area above a threshold for a single county."""
        return await _threshold_area(ctx, county, threshold, below=False)

    @agent.tool
    async def get_area_below_threshold(
        ctx: RunContext[AgentContext],
        county: str,
        threshold: float
    ) -> dict:
        """Calculate percentage and area below a threshold for a single county."""
        return await _threshold_area(ctx, county, threshold, below=True)

    return agent


# Create the agent.
# A single module-level instance is built here and handed to the chat
# interface constructed in the next cell.
agent = create_agent()
print("✓ Agent created successfully!")

## Chat Interface

In [None]:
from IPython.display import Javascript

class ChatInterface:
    """ipywidgets chat front-end for the landscape-metrics agent.

    The interface consists of a scrollable message area, a text box, and
    Send / Clear buttons with a status label. Each question is run through
    the agent with a fresh ``AgentContext`` and the answer (or the error) is
    appended to the message area.
    """

    # Example questions shown in both the welcome and the "chat cleared" message.
    _EXAMPLES = (
        "Examples:\n"
        "- What is the average carbon turnover time in Los Angeles?\n"
        "- Find the maximum annual burn probability in San Diego county"
    )

    # role -> (accent color, icon, label, background color).
    # Any unknown role is rendered with the "system" style.
    _ROLE_STYLES = {
        "user": ("#007bff", "👤", "You", "#e7f3ff"),
        "assistant": ("#28a745", "🤖", "Assistant", "#e8f5e9"),
        "system": ("#6c757d", "ℹ️", "System", "#f8f9fa"),
    }

    def __init__(self, agent):
        """Build the widgets, wire the button handlers, and post the welcome message.

        Args:
            agent: The agent whose ``run(question, deps=...)`` coroutine
                answers each submitted question.
        """
        self.agent = agent
        # Message widgets in display order; mirrored into output_area.children.
        self.messages_container = []

        # Output area that takes available space
        self.output_area = widgets.VBox(
            layout=widgets.Layout(
                border='1px solid #ddd',
                height='calc(100vh - 350px)',  # Dynamic height based on viewport
                min_height='400px',
                overflow_y='auto',
                padding='10px',
                margin='0 0 10px 0'
            )
        )

        self.input_box = widgets.Textarea(
            placeholder='Ask a question about California Landscape Metrics...',
            layout=widgets.Layout(width='100%', height='100px', margin='10px 0')
        )

        self.send_button = widgets.Button(
            description='Send',
            button_style='primary',
            icon='paper-plane',
            layout=widgets.Layout(width='100px', margin='0 5px 0 0')
        )

        self.clear_button = widgets.Button(
            description='Clear',
            button_style='warning',
            icon='trash',
            layout=widgets.Layout(width='100px', margin='0 5px 0 0')
        )

        self.status_label = widgets.HTML(
            value="✅ Ready",
            layout=widgets.Layout(margin='0 0 0 10px')
        )

        # Add button click handlers
        self.send_button.on_click(self.on_send_clicked)
        self.clear_button.on_click(self.on_clear_clicked)

        button_box = widgets.HBox([
            self.send_button,
            self.clear_button,
            self.status_label
        ])

        self.interface = widgets.VBox([
            widgets.HTML(value="<h3>🌲 California Landscape Metrics Chat</h3>"),
            self.output_area,
            self.input_box,
            button_box
        ], layout=widgets.Layout(width='100%', max_width='1200px', margin='0 auto'))

        # Add welcome message
        self._add_message(
            "Welcome! Ask me about California landscape metrics.\n\n" + self._EXAMPLES,
            "system"
        )

    def _add_message(self, text, role="user"):
        """Append one timestamped message bubble to the output area.

        Args:
            text: Message content; converted with ``str()`` and HTML-escaped
                before being embedded in the bubble markup.
            role: "user", "assistant", or anything else for a system message.
        """
        # Local import keeps this method self-contained within the notebook cell.
        import html

        timestamp = datetime.now().strftime("%H:%M:%S")
        color, icon, label, bg_color = self._ROLE_STYLES.get(
            role, self._ROLE_STYLES["system"]
        )

        # Escape HTML in text content
        escaped_text = html.escape(str(text))

        message_html = widgets.HTML(
            value=f"""
            <div style='margin: 10px 0; padding: 12px; background-color: {bg_color}; 
                        border-radius: 8px; border-left: 4px solid {color}; box-shadow: 0 1px 3px rgba(0,0,0,0.1);'>
                <div style='display: flex; justify-content: space-between; margin-bottom: 8px;'>
                    <strong style='color: {color};'>{icon} {label}</strong>
                    <span style='color: #999; font-size: 0.85em;'>{timestamp}</span>
                </div>
                <div style='white-space: pre-wrap; word-wrap: break-word; line-height: 1.5;'>{escaped_text}</div>
            </div>
            """
        )

        self.messages_container.append(message_html)
        # Reassigning `children` is what makes the VBox render the new message.
        self.output_area.children = tuple(self.messages_container)

    def on_send_clicked(self, button):
        """Send the text box content to the agent and show the reply.

        Blank input is ignored. While the request runs, the Send button and
        the text box are disabled; they are re-enabled in all cases. The call
        blocks until the agent finishes, fails, or the 180-second timeout
        expires.

        Args:
            button: The clicked button widget (unused; required by ``on_click``).
        """
        question = self.input_box.value.strip()
        if not question:
            return

        self._add_message(question, "user")
        self.input_box.value = ""
        self.send_button.disabled = True
        self.input_box.disabled = True
        self.status_label.value = "<span style='color: orange;'>⏳ Processing...</span>"

        try:
            # Fresh context per question, so no dataset selection carries over.
            ctx = AgentContext()
            result = asyncio.get_event_loop().run_until_complete(
                asyncio.wait_for(self.agent.run(question, deps=ctx), timeout=180)
            )

            answer = result.output if hasattr(result, 'output') else str(result)
            self._add_message(answer, "assistant")
            self.status_label.value = "<span style='color: green;'>✅ Ready</span>"

        except asyncio.TimeoutError:
            error_msg = "Request timed out after 3 minutes. Please try a simpler question or try again."
            self._add_message(error_msg, "system")
            self.status_label.value = "<span style='color: red;'>❌ Timeout</span>"

        except Exception as e:
            # UI boundary: surface the failure in the chat instead of raising
            # out of the widget callback.
            import traceback
            error_msg = f"Error: {str(e)}\n\n{traceback.format_exc()}"
            self._add_message(error_msg, "system")
            self.status_label.value = "<span style='color: red;'>❌ Error</span>"

        finally:
            self.send_button.disabled = False
            self.input_box.disabled = False

    def on_clear_clicked(self, button):
        """Remove all messages and post a fresh system message with examples.

        Args:
            button: The clicked button widget (unused; required by ``on_click``).
        """
        self.messages_container = []
        self.output_area.children = tuple(self.messages_container)
        self._add_message(
            "Chat cleared. Ready for new questions!\n\n" + self._EXAMPLES,
            "system"
        )

    def display(self):
        """Clear the cell output, inject the layout CSS, and show the interface."""
        # Clear any previous output in the cell
        clear_output(wait=True)

        # Display the interface
        display(HTML("""
        <style>
            /* Make the cell output area expand */
            .jp-Cell-outputArea {
                max-height: none !important;
            }
            .output_scroll {
                max-height: none !important;
                overflow-y: visible !important;
            }
        </style>
        """))

        display(self.interface)


# Create and display chat interface.
# The interface is bound to the module-level `agent`; display() renders it as
# this cell's output.
chat = ChatInterface(agent)
chat.display()