diff --git a/.env.example b/.env.example
index 7c79605..b490fde 100644
--- a/.env.example
+++ b/.env.example
@@ -8,5 +8,4 @@
 SMALL_MODEL_NAME="openai:gpt-4.1-mini"
 GEODINI_API="https://geodini.k8s.labs.ds.io"
-STAC_CATALOG_NAME="planetarycomputer"
 STAC_CATALOG_URL="https://planetarycomputer.microsoft.com/api/stac/v1"
\ No newline at end of file
diff --git a/frontend/streamlit_app.py b/frontend/streamlit_app.py
index d456a69..7b213b1 100644
--- a/frontend/streamlit_app.py
+++ b/frontend/streamlit_app.py
@@ -34,22 +34,57 @@
     """
 )
 
-# Create input field for the query
-query = st.text_input(
-    "Enter your query",
-    placeholder="Find imagery over Paris from 2017",
-    help="Describe what kind of satellite imagery you're looking for",
-)
+# Create two columns for query and catalog URL
+col1, col2 = st.columns([3, 1])
 
-# Add a search button
-search_button = st.button("Search")
+with col1:
+    # Create input field for the query
+    query = st.text_input(
+        "Enter your query",
+        placeholder="Find imagery over Paris from 2017",
+        help="Describe what kind of satellite imagery you're looking for",
+    )
+    # Add a search button
+    search_button = st.button("Search")
 
+with col2:
+    # Define catalog options
+    catalog_options = {
+        "Planetary Computer": "https://planetarycomputer.microsoft.com/api/stac/v1",
+        "VEDA": "https://openveda.cloud/api/stac",
+        "E84 Earth Search": "https://earth-search.aws.element84.com/v1",
+        "DevSeed EOAPI.dev": "https://stac.eoapi.dev",
+        "Custom URL": "custom",
+    }
 
-# Function to run the search asynchronously
-async def run_search(query):
-    response = requests.post(
-        f"{API_URL}/items/search", json={"query": query, "limit": 10}
+    # Create dropdown for catalog selection
+    selected_catalog = st.selectbox(
+        "Select STAC Catalog",
+        options=list(catalog_options.keys()),
+        index=0,  # Default to Planetary Computer
+        help="Choose a predefined STAC catalog or select 'Custom URL' to enter your own.",
     )
+
+    # Handle custom URL input
+    if selected_catalog == "Custom URL":
+        catalog_url = st.text_input(
+            "Enter Custom Catalog URL",
+            placeholder="https://your-catalog.com/stac/v1",
+            help="Enter the URL of your custom STAC catalog.",
+        )
+    else:
+        catalog_url = catalog_options[selected_catalog]
+        # Show the selected URL as read-only info
+        st.info(f"Using: {catalog_url}")
+
+
+# Function to run the search asynchronously
+async def run_search(query, catalog_url=None):
+    payload = {"query": query, "limit": 10}
+    if catalog_url:
+        payload["catalog_url"] = catalog_url.strip()
+
+    response = requests.post(f"{API_URL}/items/search", json=payload)
     return response.json()["results"]
 
 
@@ -60,7 +95,7 @@ async def run_search(query):
     # Run the async search
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
-    results = loop.run_until_complete(run_search(query))
+    results = loop.run_until_complete(run_search(query, catalog_url))
     items = results["items"]
     aoi = results["aoi"]
     explanation = results["explanation"]
@@ -212,10 +247,19 @@ async def run_search(query):
     """
     Search for satellite imagery using natural language.
 
-    **Examples queries:**
+    **Available STAC Catalogs:**
+    - **Planetary Computer**: Microsoft's global dataset catalog
+    - **VEDA**: NASA's Earth science data catalog
+    - **E84 Earth Search**: Element 84's STAC catalog for Earth observation data on AWS Open Data
+    - **DevSeed EOAPI.dev**: DevSeed's example STAC catalog
+    - **Custom URL**: Enter any STAC-compliant catalog URL
+
+    The system will automatically index new catalogs on first use.
+
+    **Example queries:**
     - imagery of Paris from 2017
     - Cloud-free satellite data of Georgia the country from 2022
-    - relatively cloud-free images in 2024 that have RGB visual bands over Longmont, Colorado that can be downloaded via HTTP
+    - relatively cloud-free images in 2024 over Longmont, Colorado
    - images in 2024 over Odisha with cloud cover between 50 to 60%
    - NAIP imagery over the state of Washington
    - Burn scar imagery of from 2024 over the state of California
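For reference, this is the request body the frontend now produces — a minimal sketch, assuming the API from `stac_search/api.py` below is running locally with the `start_server()` defaults (host and port assumed):

```python
import requests

API_URL = "http://localhost:8000"  # assumed: start_server() defaults on localhost

# Mirrors run_search() above: "catalog_url" is only included when a catalog is chosen
payload = {
    "query": "Find imagery over Paris from 2017",
    "limit": 10,
    "catalog_url": "https://openveda.cloud/api/stac",  # e.g. the VEDA preset
}
results = requests.post(f"{API_URL}/items/search", json=payload).json()["results"]
```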
diff --git a/helm-chart/README.md b/helm-chart/README.md
index 64426cf..6f8fd4c 100644
--- a/helm-chart/README.md
+++ b/helm-chart/README.md
@@ -137,7 +137,6 @@ The init container uses the same STAC catalog configuration as the API:
 api:
   env:
     STAC_CATALOG_URL: "https://planetarycomputer.microsoft.com/api/stac/v1"
-    STAC_CATALOG_NAME: "planetarycomputer"
 
 initContainer:
   enabled: true  # Set to false to disable data pre-loading
diff --git a/helm-chart/values.yaml b/helm-chart/values.yaml
index 39e53ab..5f47b03 100644
--- a/helm-chart/values.yaml
+++ b/helm-chart/values.yaml
@@ -41,7 +41,6 @@ api:
     PYTHONUNBUFFERED: "1"
     HF_HOME: "/app/data/.cache/huggingface"
     GEODINI_API: "https://geodini.k8s.labs.ds.io"
-    STAC_CATALOG_NAME: "planetarycomputer"
     STAC_CATALOG_URL: "https://planetarycomputer.microsoft.com/api/stac/v1"
     DEFAULT_TARGET_COLLECTIONS: "['landsat-8-c2-l2', 'sentinel-2-l2a']"
diff --git a/stac_search/agents/collections_search.py b/stac_search/agents/collections_search.py
index 15ec995..daca12b 100644
--- a/stac_search/agents/collections_search.py
+++ b/stac_search/agents/collections_search.py
@@ -9,9 +9,8 @@
 from pprint import pformat
 from typing import List, Dict, Any
 
-import chromadb
 from pydantic_ai import Agent
-from sentence_transformers import SentenceTransformer
+from stac_search.catalog_manager import CatalogManager
 
 logger = logging.getLogger(__name__)
@@ -20,7 +19,6 @@
 MODEL_NAME = "all-MiniLM-L6-v2"
 DATA_PATH = os.environ.get("DATA_PATH", "data/chromadb")
 
-STAC_CATALOG_NAME = os.getenv("STAC_CATALOG_NAME", "planetarycomputer")
 STAC_COLLECTIONS_URL = os.getenv(
     "STAC_COLLECTIONS_URL", "https://planetarycomputer.microsoft.com/api/stac/v1"
 )
@@ -65,7 +63,7 @@ async def collection_search(
     top_k: int = 5,
     model_name: str = MODEL_NAME,
     data_path: str = DATA_PATH,
-    stac_catalog_name: str = STAC_CATALOG_NAME,
+    catalog_url: str = None,
 ) -> List[CollectionWithExplanation]:
     """
     Search for collections and rerank results with explanations
@@ -75,25 +73,31 @@ async def collection_search(
         top_k: Maximum number of results to return
         model_name: Name of the sentence transformer model to use
         data_path: Path to the vector database
-        stac_catalog_name: Name of the STAC catalog
-        stac_collections_url: URL of the STAC collections API
+        catalog_url: URL of the STAC catalog
 
     Returns:
         Ranked results with relevance explanations
     """
     start_time = time.time()
 
-    # Initialize model and database connections
-    model = SentenceTransformer(model_name)
+    # Initialize catalog manager
+    catalog_manager = CatalogManager(data_path=data_path, model_name=model_name)
+
+    # If catalog_url is provided, ensure it's loaded
+    if catalog_url:
+        load_result = await catalog_manager.load_catalog(catalog_url)
+        if not load_result["success"]:
+            logger.error(f"Failed to load catalog: {load_result['error']}")
+            raise ValueError(f"Failed to load catalog: {load_result['error']}")
+
+    # Get the appropriate collection
+    collection = catalog_manager.get_catalog_collection(catalog_url)
+
     load_model_time = time.time()
     logger.info(f"Model loading time: {load_model_time - start_time:.4f} seconds")
-    client = chromadb.PersistentClient(path=data_path)
-    collection_name = f"{stac_catalog_name}_collections"
-    collection = client.get_collection(name=collection_name)
-
     # Generate query embedding
-    query_embedding = model.encode([query])
+    query_embedding = catalog_manager.model.encode([query])
 
     # Search vector database
     results = collection.query(
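A minimal sketch of calling the reworked `collection_search` directly; the `catalog_url` here is one of the frontend presets, and on first use the catalog is indexed before the query runs:

```python
import asyncio

from stac_search.agents.collections_search import collection_search

# collection_search is async; drive it with asyncio.run when outside an event loop
collections = asyncio.run(
    collection_search(
        "global digital elevation models",
        top_k=5,
        catalog_url="https://openveda.cloud/api/stac",
    )
)
for collection in collections:
    print(collection)  # CollectionWithExplanation entries
```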
diff --git a/stac_search/agents/items_search.py b/stac_search/agents/items_search.py
index ee723af..920be60 100644
--- a/stac_search/agents/items_search.py
+++ b/stac_search/agents/items_search.py
@@ -33,6 +33,7 @@
 @dataclass
 class Context:
     query: str
+    catalog_url: str | None = None
     location: str | None = None
     top_k: int = 5
     return_search_params_only: bool = False
@@ -105,12 +106,16 @@ class CollectionSearchResult:
     collections: List[CollectionWithExplanation]
 
 
-async def search_collections(query: str) -> CollectionSearchResult | None:
+async def search_collections(
+    query: str, catalog_url: str = None
+) -> CollectionSearchResult | None:
     logger.info("Searching for relevant collections ...")
     collection_query = await collection_query_framing_agent.run(query)
     logger.info(f"Framed collection query: {collection_query.data.query}")
     if collection_query.data.is_specific:
-        collections = await collection_search(collection_query.data.query)
+        collections = await collection_search(
+            collection_query.data.query, catalog_url=catalog_url
+        )
         return CollectionSearchResult(collections=collections)
     else:
         return None
@@ -278,11 +283,26 @@ async def item_search(ctx: Context) -> ItemSearchResult:
     results = await search_items_agent.run(
         f"Find items for the query: {ctx.query}", deps=ctx
     )
+    catalog_url_to_use = ctx.catalog_url or STAC_CATALOG_URL
 
     # determine the collections to search
-    target_collections = await search_collections(ctx.query) or []
+    target_collections = await search_collections(ctx.query, catalog_url_to_use) or []
     logger.info(f"Target collections: {pformat(target_collections)}")
-    default_target_collections = DEFAULT_TARGET_COLLECTIONS
+
+    if not target_collections:
+        # If no specific collections were found, use the default target collections
+        default_target_collections = DEFAULT_TARGET_COLLECTIONS
+        # check that default_target_collections exist in the catalog
+        all_collection_ids = [
+            collection.id
+            for collection in Client.open(catalog_url_to_use).get_collections()
+        ]
+        default_target_collections = [
+            collection_id
+            for collection_id in default_target_collections
+            if collection_id in all_collection_ids
+        ]
+
     if target_collections:
         explanation = "Considering the following collections:"
         for result in target_collections.collections:
@@ -290,12 +310,15 @@ async def item_search(ctx: Context) -> ItemSearchResult:
         collections_to_search = [
             collection.collection_id for collection in target_collections.collections
         ]
-    else:
+    elif default_target_collections:
         explanation = f"Including the following common collections in the search: {', '.join(default_target_collections)}\n"
         collections_to_search = default_target_collections
+    else:
+        explanation = "Searching all collections in the catalog."
+        collections_to_search = all_collection_ids
 
     # Actually perform the search
-    client = Client.open(STAC_CATALOG_URL)
+    client = Client.open(catalog_url_to_use)
     params = {
         "max_items": 20,
         "collections": collections_to_search,
@@ -310,11 +333,9 @@ async def item_search(ctx: Context) -> ItemSearchResult:
             logger.info(f"Found polygon for {results.data.location}")
             params["intersects"] = polygon
         else:
+            explanation += f"\n\n No polygon found for {results.data.location}. "
             return ItemSearchResult(
-                items=None,
-                search_params=params,
-                aoi=None,
-                explanation=f"No polygon found for {results.data.location}",
+                items=None, search_params=params, aoi=None, explanation=explanation
             )
 
     if ctx.return_search_params_only:
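Programmatic use of `item_search` with the new `catalog_url` field mirrors what the API endpoint does (`stac_search/api.py` imports this `Context` as `ItemSearchContext`) — a sketch:

```python
import asyncio

from stac_search.agents.items_search import Context, item_search

ctx = Context(
    query="cloud-free imagery over Paris from 2017",
    catalog_url="https://earth-search.aws.element84.com/v1",
)
result = asyncio.run(item_search(ctx))
# ItemSearchResult fields: items, search_params, aoi, explanation
print(result.explanation)
```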
diff --git a/stac_search/api.py b/stac_search/api.py
index 55a75a1..be130a8 100644
--- a/stac_search/api.py
+++ b/stac_search/api.py
@@ -2,9 +2,10 @@
 FastAPI server for STAC Natural Query
 """
 
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
+from typing import Optional
 import uvicorn
 
 from stac_search.agents.collections_search import collection_search
@@ -30,10 +31,12 @@
 # Define request model
 class QueryRequest(BaseModel):
     query: str
+    catalog_url: Optional[str] = None
 
 
 class STACItemsRequest(BaseModel):
     query: str
+    catalog_url: Optional[str] = None
     return_search_params_only: bool = False
 
 
@@ -41,18 +44,28 @@ class STACItemsRequest(BaseModel):
 @app.post("/search")
 async def search(request: QueryRequest):
     """Search for STAC collections using natural language"""
-    results = collection_search(request.query)
-    return {"results": results}
+    try:
+        results = await collection_search(
+            request.query, catalog_url=request.catalog_url
+        )
+        return {"results": results}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
 
 
 @app.post("/items/search")
 async def search_items(request: STACItemsRequest):
     """Search for STAC items using natural language"""
-    ctx = ItemSearchContext(
-        query=request.query, return_search_params_only=request.return_search_params_only
-    )
-    results = await item_search(ctx)
-    return {"results": results}
+    try:
+        ctx = ItemSearchContext(
+            query=request.query,
+            catalog_url=request.catalog_url,
+            return_search_params_only=request.return_search_params_only,
+        )
+        results = await item_search(ctx)
+        return {"results": results}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
 
 
 def start_server(host: str = "0.0.0.0", port: int = 8000):
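Both endpoints now accept an optional `catalog_url`, and `/items/search` still honors `return_search_params_only` for a dry run that returns the STAC query parameters without fetching items — a sketch against a local server (host and port assumed):

```python
import requests

resp = requests.post(
    "http://localhost:8000/items/search",  # assumed: start_server() defaults
    json={
        "query": "NAIP imagery over the state of Washington",
        "catalog_url": "https://planetarycomputer.microsoft.com/api/stac/v1",
        "return_search_params_only": True,
    },
)
resp.raise_for_status()  # a 500 carries the wrapped error message in "detail"
print(resp.json()["results"])
```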
diff --git a/stac_search/catalog_manager.py b/stac_search/catalog_manager.py
new file mode 100644
index 0000000..e3a3487
--- /dev/null
+++ b/stac_search/catalog_manager.py
@@ -0,0 +1,175 @@
+"""
+Catalog Manager for STAC Natural Query - handles dynamic catalog loading and management
+"""
+
+import hashlib
+import logging
+import os
+from typing import Optional, Dict, Any
+
+import chromadb
+from pystac_client import Client
+from sentence_transformers import SentenceTransformer
+
+logger = logging.getLogger(__name__)
+
+# Constants
+MODEL_NAME = "all-MiniLM-L6-v2"
+DATA_PATH = os.environ.get("DATA_PATH", "data/chromadb")
+
+
+class CatalogManager:
+    """Manages STAC catalog indexing and retrieval operations"""
+
+    def __init__(self, data_path: str = DATA_PATH, model_name: str = MODEL_NAME):
+        self.data_path = data_path
+        self.model_name = model_name
+        self.model = SentenceTransformer(model_name)
+        self.client = chromadb.PersistentClient(path=data_path)
+
+    def _get_catalog_name(self, catalog_url: str) -> str:
+        """Generate a unique catalog name from URL"""
+        # Create a hash of the URL for consistent naming
+        url_hash = hashlib.md5(catalog_url.encode()).hexdigest()[:8]
+        # Clean URL for readability
+        clean_url = catalog_url.replace("https://", "").replace("http://", "")
+        clean_url = clean_url.replace("/", "_").replace(".", "_")
+        return f"{clean_url}_{url_hash}"
+
+    def _get_collection_name(self, catalog_url: str) -> str:
+        """Get ChromaDB collection name for a catalog"""
+        catalog_name = self._get_catalog_name(catalog_url)
+        return f"{catalog_name}_collections"
+
+    def catalog_exists(self, catalog_url: str) -> bool:
+        """Check if a catalog is already indexed in the vector database"""
+        collection_name = self._get_collection_name(catalog_url)
+        try:
+            existing_collections = self.client.list_collections()
+            logger.debug(
+                f"Existing collections: {[col.name for col in existing_collections]}"
+            )
+            logger.debug(f"Checking for collection: {collection_name}")
+            return any(col.name == collection_name for col in existing_collections)
+        except Exception as e:
+            logger.error(f"Error checking catalog existence: {e}")
+            return False
+
+    def validate_catalog_url(self, catalog_url: str) -> bool:
+        """Validate that the catalog URL is accessible and is a valid STAC catalog"""
+        try:
+            stac_client = Client.open(catalog_url)
+            # Try to get at least one collection to verify it's a valid catalog
+            collections = list(stac_client.collection_search().collections())
+            return len(collections) > 0
+        except Exception as e:
+            logger.error(f"Invalid catalog URL {catalog_url}: {e}")
+            return False
+
+    def fetch_collections(self, stac_client: Client) -> list:
+        """Fetch STAC collections using pystac-client"""
+        try:
+            collections = stac_client.collection_search().collections()
+            return list(collections)
+        except Exception as e:
+            logger.error(f"Error fetching collections: {e}")
+            return []
+
+    def generate_embeddings(self, collections: list) -> list:
+        """Generate embeddings for each collection (title + description)"""
+        texts = []
+        for collection in collections:
+            title = getattr(collection, "title", "") or ""
+            description = getattr(collection, "description", "") or ""
+            texts.append(f"{title} {description}")
+
+        embeddings = self.model.encode(texts)
+        return embeddings
+
+    def store_in_vector_db(self, collections: list, chroma_collection) -> None:
+        """Store embeddings in ChromaDB"""
+        if not collections:
+            logger.warning("No collections to store")
+            return
+
+        metadatas = []
+        for collection in collections:
+            metadata = {
+                "title": getattr(collection, "title", "") or "",
+                "description": getattr(collection, "description", "") or "",
+                "collection_id": getattr(collection, "id", ""),
+            }
+            metadatas.append(metadata)
+
+        embeddings = self.generate_embeddings(collections)
+
+        chroma_collection.add(
+            ids=[str(i) for i in range(len(collections))],
+            embeddings=embeddings,
+            metadatas=metadatas,
+        )
+
+    async def load_catalog(self, catalog_url: str) -> Dict[str, Any]:
+        """Load and index a catalog if it doesn't exist"""
+        try:
+            # Validate catalog URL first
+            if not self.validate_catalog_url(catalog_url):
+                return {
+                    "success": False,
+                    "error": f"Invalid or inaccessible catalog URL: {catalog_url}",
+                }
+
+            # Check if catalog already exists
+            if self.catalog_exists(catalog_url):
+                logger.info(f"Catalog {catalog_url} already indexed")
+                return {
+                    "success": True,
+                    "message": "Catalog already indexed",
+                    "catalog_name": self._get_catalog_name(catalog_url),
+                }
+
+            # Load the catalog
+            logger.info(f"Loading catalog from {catalog_url}")
+            stac_client = Client.open(catalog_url)
+            collections = self.fetch_collections(stac_client)
+
+            if not collections:
+                return {
+                    "success": False,
+                    "error": f"No collections found in catalog {catalog_url}",
+                }
+
+            # Create ChromaDB collection
+            collection_name = self._get_collection_name(catalog_url)
+            chroma_collection = self.client.create_collection(
+                name=collection_name, get_or_create=True
+            )
+
+            # Store in vector database
+            self.store_in_vector_db(collections, chroma_collection)
+
+            logger.info(
+                f"Successfully indexed {len(collections)} collections from {catalog_url}"
+            )
+            return {
+                "success": True,
+                "message": f"Successfully indexed {len(collections)} collections",
+                "catalog_name": self._get_catalog_name(catalog_url),
+                "collections_count": len(collections),
+            }
+
+        except Exception as e:
+            logger.error(f"Error loading catalog {catalog_url}: {e}")
+            return {"success": False, "error": f"Error loading catalog: {str(e)}"}
+
+    def get_catalog_collection(
+        self, catalog_url: Optional[str] = None
+    ) -> chromadb.Collection:
+        """Get the ChromaDB collection for a catalog"""
+        if not catalog_url:
+            catalog_url = os.environ.get("STAC_CATALOG_URL")
+        if not catalog_url:
+            raise ValueError("No catalog URL provided and STAC_CATALOG_URL is not set")
+
+        collection_name = self._get_collection_name(catalog_url)
+
+        try:
+            return self.client.get_collection(name=collection_name)
+        except Exception as e:
+            logger.error(f"Error getting collection {collection_name}: {e}")
+            raise
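The `CatalogManager` lifecycle in isolation — validate and index on first use, then query the per-catalog ChromaDB collection; a minimal sketch (query text illustrative):

```python
import asyncio

from stac_search.catalog_manager import CatalogManager

manager = CatalogManager()  # defaults: DATA_PATH and all-MiniLM-L6-v2
result = asyncio.run(manager.load_catalog("https://stac.eoapi.dev"))
if result["success"]:
    chroma_collection = manager.get_catalog_collection("https://stac.eoapi.dev")
    hits = chroma_collection.query(
        query_embeddings=manager.model.encode(["digital elevation model"]).tolist(),
        n_results=5,
    )
    print(hits["metadatas"])  # title/description/collection_id per match
```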
{catalog_url}") + stac_client = Client.open(catalog_url) + collections = self.fetch_collections(stac_client) + + if not collections: + return { + "success": False, + "error": f"No collections found in catalog {catalog_url}", + } + + # Create ChromaDB collection + collection_name = self._get_collection_name(catalog_url) + chroma_collection = self.client.create_collection( + name=collection_name, get_or_create=True + ) + + # Store in vector database + self.store_in_vector_db(collections, chroma_collection) + + logger.info( + f"Successfully indexed {len(collections)} collections from {catalog_url}" + ) + return { + "success": True, + "message": f"Successfully indexed {len(collections)} collections", + "catalog_name": self._get_catalog_name(catalog_url), + "collections_count": len(collections), + } + + except Exception as e: + logger.error(f"Error loading catalog {catalog_url}: {e}") + return {"success": False, "error": f"Error loading catalog: {str(e)}"} + + def get_catalog_collection( + self, catalog_url: Optional[str] = None + ) -> chromadb.Collection: + """Get the ChromaDB collection for a catalog""" + if not catalog_url: + catalog_url = os.environ.get("STAC_CATALOG_URL") + + collection_name = self._get_collection_name(catalog_url) + + try: + return self.client.get_collection(name=collection_name) + except Exception as e: + logger.error(f"Error getting collection {collection_name}: {e}") + raise diff --git a/stac_search/load.py b/stac_search/load.py index 87f2c52..3f22f62 100644 --- a/stac_search/load.py +++ b/stac_search/load.py @@ -2,79 +2,35 @@ Load CLI for STAC Natural Query - creates and populates the vector database """ +import asyncio import logging import os -import chromadb -from pystac_client import Client -from sentence_transformers import SentenceTransformer +from .catalog_manager import CatalogManager logger = logging.getLogger(__name__) -# Constants -MODEL_NAME = "all-MiniLM-L6-v2" -DATA_PATH = os.environ.get("DATA_PATH", "data/chromadb") +def load_data(catalog_url: str): + """Load STAC collections into the vector database using CatalogManager""" + try: + # Initialize catalog manager + catalog_manager = CatalogManager() -def load_data(catalog_url, catalog_name): - """Load STAC collections into the vector database""" - logger.info("Initializing vector database...") + # Load catalog using async method + result = asyncio.run(catalog_manager.load_catalog(catalog_url)) - # Initialize the model - model = SentenceTransformer(MODEL_NAME) - - # Initialize ChromaDB client with persistence settings - client = chromadb.PersistentClient(path=DATA_PATH) - chroma_collection = client.create_collection( - name=f"{catalog_name}_collections", get_or_create=True - ) - - # Initialize STAC client - stac_client = Client.open(catalog_url) - - logger.info("Fetching STAC collections...") - collections = fetch_collections(stac_client) - logger.info(f"Found {len(collections)} collections") - - logger.info("Generating embeddings and storing in vector database...") - store_in_vector_db(collections, model, chroma_collection) - - logger.info("Data loading complete!") - - -def fetch_collections(stac_client): - """Fetch STAC collections using pystac-client""" - collections = stac_client.collection_search().collections() - return list(collections) + if result["success"]: + logger.info(f"Successfully loaded catalog: {result['message']}") + if "collections_count" in result: + logger.info(f"Indexed {result['collections_count']} collections") + else: + logger.error(f"Failed to load catalog: {result['error']}") + 
+            raise Exception(result["error"])
 
-
-def generate_embeddings(collections, model):
-    """Generate embeddings for each collection (title + description)"""
-    texts = [
-        f"{collection.title} {collection.description}" for collection in collections
-    ]
-    embeddings = model.encode(texts)
-    return embeddings
-
-
-def store_in_vector_db(collections, model, chroma_collection):
-    """Store embeddings in ChromaDB"""
-    metadatas = [
-        {
-            "title": collection.title or "",
-            "description": collection.description or "",
-            "collection_id": collection.id,
-        }
-        for collection in collections
-    ]
-
-    embeddings = generate_embeddings(collections, model)
-
-    chroma_collection.add(
-        ids=[str(i) for i in range(len(collections))],
-        embeddings=embeddings,
-        metadatas=metadatas,
-    )
+    except Exception as e:
+        logger.error(f"Error loading data: {e}")
+        raise
 
 
 if __name__ == "__main__":
@@ -83,10 +39,7 @@ def store_in_vector_db(collections, model, chroma_collection):
     # load_data(
     #     catalog_url="https://planetarycomputer.microsoft.com/api/stac/v1",
     #     catalog_name="planetarycomputer",
     # )
-    import os
-
     STAC_CATALOG_URL = os.environ.get(
         "STAC_CATALOG_URL", "https://planetarycomputer.microsoft.com/api/stac/v1"
     )
-    STAC_CATALOG_NAME = os.environ.get("STAC_CATALOG_NAME", "planetarycomputer")
-    load_data(catalog_url=STAC_CATALOG_URL, catalog_name=STAC_CATALOG_NAME)
+    load_data(catalog_url=STAC_CATALOG_URL)
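With `catalog_name` gone, the load step is driven entirely by `STAC_CATALOG_URL`; loading a different catalog is a one-liner (module invocation path assumed):

```python
# Shell equivalent (assumed layout):
#   STAC_CATALOG_URL=https://earth-search.aws.element84.com/v1 python -m stac_search.load
from stac_search.load import load_data

load_data(catalog_url="https://earth-search.aws.element84.com/v1")
```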