Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,22 @@ asyncio_mode = "auto"
[tool.ruff.lint]
select = [
"COM", # flake8-commas
"D", # pydocstyle
"F", # pyflakes
"I", # isort
"RUF", # ruff-specific
"UP", # pyupgrade
]
extend-select = [
"D213", # Summary lines should be positioned on the second physical line of the docstring.
"D410", # A blank line after section headings.
]
ignore = [
"D205", # 1 blank line required between summary line and description
"D212", # Multi-line docstring summary should start at the first line
]

[tool.ruff.lint.pydocstyle]
# See https://docs.astral.sh/ruff/faq/#does-ruff-support-numpy-or-google-style-docstrings
# for the enabled/disabled rules for the "google" convention.
convention = "google"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using 'google' convention because that's the prevailing docstring type currently, and also because the D417 undocumented-param rule will be enforced.

3 changes: 3 additions & 0 deletions src/geo_assistant/agent/graph.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Create agent graph that calls tools."""

import datetime

from langchain.agents import create_agent
Expand Down Expand Up @@ -31,6 +33,7 @@


async def create_graph():
"""Create langchain agent graph with a list of tools."""
checkpointer = InMemorySaver()
graph = create_agent(
model=llm,
Expand Down
2 changes: 2 additions & 0 deletions src/geo_assistant/agent/llms.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Ollama chat model."""

import os

from dotenv import load_dotenv
Expand Down
4 changes: 4 additions & 0 deletions src/geo_assistant/agent/state.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""State schema for the geo-assistant agent."""

from typing import NotRequired

from geojson_pydantic import Feature, FeatureCollection
Expand All @@ -6,6 +8,8 @@


class GeoAssistantState(AgentState):
"""Schema for the geo-assistant agent's state."""

place: NotRequired[Feature | None] = None
search_area: NotRequired[Feature | None] = None
places_within_buffer: NotRequired[FeatureCollection | None] = None
Expand Down
8 changes: 6 additions & 2 deletions src/geo_assistant/api/app.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Chat app API endpoint."""

import logging
from collections.abc import AsyncGenerator
from contextlib import aclosing, asynccontextmanager
Expand All @@ -21,12 +23,12 @@


@asynccontextmanager
async def lifespan(app: FastAPI):
async def _lifespan(app: FastAPI):
app.state.chatbot = await create_graph()
yield


app = FastAPI(title="Geo Assistant", lifespan=lifespan)
app = FastAPI(title="Geo Assistant", lifespan=_lifespan)


app.add_middleware(
Expand All @@ -44,6 +46,7 @@ async def stream_chat(
chatbot: Any,
request: Request,
) -> AsyncGenerator[bytes]:
"""Agent chat stream."""
config: dict[str, Any] = {
"configurable": {
"thread_id": str(thread_id),
Expand Down Expand Up @@ -101,6 +104,7 @@ async def stream_chat(

@app.post("/chat")
async def chat(request: ChatRequestBody, http_request: Request) -> StreamingResponse:
"""HTTP POST endpoint at /chat."""
generator = stream_chat(
ui_state_update=request.agent_state_input,
thread_id=request.thread_id,
Expand Down
6 changes: 6 additions & 0 deletions src/geo_assistant/api/schemas/chat.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
"""Chat API schemas."""

from pydantic import BaseModel

from geo_assistant.agent.state import GeoAssistantState


class ChatRequestBody(BaseModel):
"""Schema for the request to the Chat API."""

thread_id: str
agent_state_input: GeoAssistantState


class ChatResponse(BaseModel):
"""Schema for the response from the Chat API."""

thread_id: str
state: GeoAssistantState
2 changes: 2 additions & 0 deletions src/geo_assistant/frontend/app.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Chat app frontend."""

import base64
import json
import os
Expand Down
2 changes: 2 additions & 0 deletions src/geo_assistant/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""List of tools available to the agent."""

from geo_assistant.tools.buffer import get_search_area
from geo_assistant.tools.naip import fetch_naip_img
from geo_assistant.tools.overture import get_place, get_places_within_buffer
Expand Down
10 changes: 9 additions & 1 deletion src/geo_assistant/tools/buffer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Tool to create a buffer polygon around a geometry feature."""

from typing import Annotated

import geopandas as gpd
Expand All @@ -17,8 +19,14 @@ async def get_search_area(
state: Annotated[GeoAssistantState, InjectedState],
tool_call_id: Annotated[str, InjectedToolCallId] = "",
) -> Command:
"""Get a search area buffer in km around the place defined in the agent state."""
"""
Get a search area buffer in km around the place defined in the agent state.

Args:
buffer_size_km: Radius of the buffer in kilometres.
state: Pass in 'place' as state into this agent.
tool_call_id: Optional ID for tracking the tool call.
Comment on lines +25 to +28
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All arguments to public functions will now need to be documented to adhere to pydocstyle rule D417.

"""
place_feature = state.get("place")

if not place_feature:
Expand Down
6 changes: 4 additions & 2 deletions src/geo_assistant/tools/naip.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# tools/naip_mpc_tools.py
"""Tool to query Planetary Computer STAC API for NAIP imagery."""

import base64
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
Expand Down Expand Up @@ -39,7 +40,8 @@ async def fetch_naip_img(
Args:
start_date: Start date (YYYY-MM-DD).
end_date: End date (YYYY-MM-DD).

state: Pass in search_area as state into this agent.
tool_call_id: Optional ID for tracking the tool call
"""
if not state["search_area"]:
return Command(
Expand Down
25 changes: 20 additions & 5 deletions src/geo_assistant/tools/overture.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Tool to find closest matching Overture place based on user input."""

import json
import os
from typing import Annotated
Expand All @@ -21,7 +23,8 @@


def create_database_connection():
"""Create and configure a DuckDB connection with necessary extensions.
"""
Create and configure a DuckDB connection with necessary extensions.
Args:
database_path: Path to the DuckDB database file
Expand All @@ -43,8 +46,14 @@ async def get_place(
place_name: str,
tool_call_id: Annotated[str, InjectedToolCallId] = "",
) -> Command:
"""Get place location from Overture Maps based on user input place name."""
"""
Get place location from Overture Maps based on user input place name.
Args:
place_name: An address or location given as a human-readable string.
tool_call_id: Optional ID for tracking the tool call.
"""
db_connection = create_database_connection()
source = os.getenv("OVERTURE_SOURCE", "local")
if source == "s3":
Expand Down Expand Up @@ -162,10 +171,16 @@ async def get_places_within_buffer(
state: Annotated[GeoAssistantState, InjectedState],
tool_call_id: Annotated[str, InjectedToolCallId],
) -> Command:
"""Get places from Overture Maps within user specified area and user specified Overture place type.
Accepts: restaurant(s), cafe(s), coffee shop(s), bar(s), pub(s) - case insensitive."""
"""
Get places from Overture Maps within user specified area and user specified Overture
place type.
Args:
place: Overture place type. Accepts: restaurant(s), cafe(s), coffee shop(s),
bar(s), pub(s) - case insensitive.
state: Pass in 'search_area' as state into this agent.
tool_call_id: Optional ID for tracking the tool call.
"""
# Normalize the place type
place = normalize_place_type(place)

Expand Down
15 changes: 11 additions & 4 deletions src/geo_assistant/tools/summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


class SatImgSummary(dspy.Signature):
"Describe things you see in the satellite image."
"""Describe things you see in the satellite image."""

img: dspy.Image = dspy.InputField(desc="A satellite image")
answer: str = dspy.OutputField(desc="Description of the image")
Expand All @@ -33,7 +33,8 @@ def __init__(
temperature: float = 0.5,
max_tokens: int = 4_096,
) -> None:
"""Initialize the satellite image summary agent.
"""
Initialize the satellite image summary agent.

Args:
model: The Ollama model to use for summarization
Expand All @@ -53,7 +54,8 @@ def __init__(
self.summarizer = dspy.Predict(SatImgSummary)

def forward(self, img_url: str) -> dspy.Prediction:
"""Generate a summary for the given image URL.
"""
Generate a summary for the given image URL.

Args:
img_url: URL of the image to summarize
Expand All @@ -73,7 +75,12 @@ async def summarize_sat_img(
state: Annotated[GeoAssistantState, InjectedState],
tool_call_id: Annotated[str | None, InjectedToolCallId] = None,
) -> Command:
"""Summarize the contents of a satellite image using an LLM.
"""
Summarize the contents of a satellite image using an LLM.

Args:
state: Pass in 'naip_img_bytes' as state into this agent.
tool_call_id: Optional ID for tracking the tool call.

Returns:
Command containing the image summary and metadata
Expand Down
5 changes: 4 additions & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Tests for chat API endpoint."""

from uuid import uuid4

import pytest
Expand All @@ -10,7 +12,7 @@

@pytest_asyncio.fixture
async def initialized_app():
"""Initialize the app's chatbot before testing"""
"""Initialize the app's chatbot before testing."""
# Manually initialize the chatbot as the lifespan would
app.state.chatbot = await create_graph()
yield app
Expand All @@ -21,6 +23,7 @@ async def initialized_app():

@pytest.mark.xfail
async def test_call_api(initialized_app):
"""Test calling the API at the /chat HTTP POST endpoint."""
async with AsyncClient(
transport=ASGITransport(app=initialized_app),
base_url="http://test",
Expand Down
4 changes: 4 additions & 0 deletions tests/tools/test_buffer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Tests for buffer tool."""

from geojson_pydantic import Feature, Point
from langchain_core.tools.base import ToolCall
from pytest import fixture
Expand All @@ -8,6 +10,7 @@

@fixture
def geo_assistant_fixture():
"""Fixture with a GeoJSON point feature in a GeoAssistantState."""
place_geojson = Feature(
type="Feature",
geometry=Point(type="Point", coordinates=[-9.1393, 38.7223]),
Expand All @@ -22,6 +25,7 @@ def geo_assistant_fixture():


async def test_get_search_area(geo_assistant_fixture):
"""Ensure that `get_search_area` tool returns a buffer Polygon."""
# Call the underlying function directly to test the logic
# This bypasses the injection framework which is better suited for integration tests
command = await get_search_area.ainvoke(
Expand Down
4 changes: 2 additions & 2 deletions tests/tools/test_naip.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Tests for NAIP tool."""

from types import NoneType

import pytest
Expand All @@ -19,7 +21,6 @@ async def test_fetch_naip():
- Internet access (to reach Planetary Computer STAC + blobs)
- Planetary Computer / NAIP service to be up
"""

# Union Market coordinates from GeoNames: 38.90789, -76.99831
# N 38°54'28" W 76°59'54"
# We'll use a small neighborhood AOI around that point.
Expand Down Expand Up @@ -60,7 +61,6 @@ async def test_fetch_naip_too_large():
- Internet access (to reach Planetary Computer STAC + blobs)
- Planetary Computer / NAIP service to be up
"""

# Union Market coordinates from GeoNames: 38.90789, -76.99831
# N 38°54'28" W 76°59'54"
# We'll use a small neighborhood AOI around that point.
Expand Down
7 changes: 7 additions & 0 deletions tests/tools/test_overture.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Tests for Overture tool."""

import os

import geopandas as gpd
Expand Down Expand Up @@ -57,6 +59,7 @@ def geo_assistant_with_buffer_fixture():


async def test_get_place():
"""Ensure that `get_place` tool returns an Overture place given a place_name."""
command = await get_place.ainvoke(
ToolCall(
name="get_place",
Expand All @@ -69,6 +72,10 @@ async def test_get_place():


async def test_get_places_within_buffer(geo_assistant_with_buffer_fixture):
"""
Ensure that `get_places_within_buffer` tool returns multiple Overture places that
fit match the category 'cafe' within a specific buffer area around a location.
"""
command = await get_places_within_buffer.ainvoke(
ToolCall(
name="get_places_within_buffer",
Expand Down
4 changes: 4 additions & 0 deletions tests/tools/test_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
],
)
async def test_summarize_sat_img(img_url, summary):
"""
Ensure that the `summarize_sat_img` tool can describe a satellite image in JPEG
format.
"""
# Load the image from the supplied URL and encode it in base64
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
Expand Down
Loading