From ddb9571ffc69c9d5ccda8e6c9b4dcc34d6185254 Mon Sep 17 00:00:00 2001 From: Tarashish Mishra Date: Wed, 25 Jun 2025 14:00:39 +0000 Subject: [PATCH] Support dynamically computing embeddings for a catalog Users can pass in a STAC catalog url with their search request. The tool will fetch the collections from the mentioned catalog, compute and cache the embeddings and search against the mentioned catalog. This would allow a stac-semantic-search instance to search against any public STAC catalog and not just a fixed one --- .env.example | 1 - frontend/streamlit_app.py | 74 ++++++++-- helm-chart/README.md | 1 - helm-chart/values.yaml | 1 - stac_search/agents/collections_search.py | 30 ++-- stac_search/agents/items_search.py | 41 ++++-- stac_search/api.py | 29 ++-- stac_search/catalog_manager.py | 175 +++++++++++++++++++++++ stac_search/load.py | 87 +++-------- 9 files changed, 323 insertions(+), 116 deletions(-) create mode 100644 stac_search/catalog_manager.py diff --git a/.env.example b/.env.example index 7c79605..b490fde 100644 --- a/.env.example +++ b/.env.example @@ -8,5 +8,4 @@ SMALL_MODEL_NAME="openai:gpt-4.1-mini" GEODINI_API="https://geodini.k8s.labs.ds.io" -STAC_CATALOG_NAME="planetarycomputer" STAC_CATALOG_URL="https://planetarycomputer.microsoft.com/api/stac/v1" \ No newline at end of file diff --git a/frontend/streamlit_app.py b/frontend/streamlit_app.py index d456a69..7b213b1 100644 --- a/frontend/streamlit_app.py +++ b/frontend/streamlit_app.py @@ -34,22 +34,57 @@ """ ) -# Create input field for the query -query = st.text_input( - "Enter your query", - placeholder="Find imagery over Paris from 2017", - help="Describe what kind of satellite imagery you're looking for", -) +# Create two columns for query and catalog URL +col1, col2 = st.columns([3, 1]) -# Add a search button -search_button = st.button("Search") +with col1: + # Create input field for the query + query = st.text_input( + "Enter your query", + placeholder="Find imagery over Paris from 2017", + help="Describe what kind of satellite imagery you're looking for", + ) + # Add a search button + search_button = st.button("Search") +with col2: + # Define catalog options + catalog_options = { + "Planetary Computer": "https://planetarycomputer.microsoft.com/api/stac/v1", + "VEDA": "https://openveda.cloud/api/stac", + "E84 Earth Search": "https://earth-search.aws.element84.com/v1", + "DevSeed EOAPI.dev": "https://stac.eoapi.dev", + "Custom URL": "custom", + } -# Function to run the search asynchronously -async def run_search(query): - response = requests.post( - f"{API_URL}/items/search", json={"query": query, "limit": 10} + # Create dropdown for catalog selection + selected_catalog = st.selectbox( + "Select STAC Catalog", + options=list(catalog_options.keys()), + index=0, # Default to Planetary Computer + help="Choose a predefined STAC catalog or select 'Custom URL' to enter your own.", ) + + # Handle custom URL input + if selected_catalog == "Custom URL": + catalog_url = st.text_input( + "Enter Custom Catalog URL", + placeholder="https://your-catalog.com/stac/v1", + help="Enter the URL of your custom STAC catalog.", + ) + else: + catalog_url = catalog_options[selected_catalog] + # Show the selected URL as read-only info + st.info(f"Using: {catalog_url}") + + +# Function to run the search asynchronously +async def run_search(query, catalog_url=None): + payload = {"query": query, "limit": 10} + if catalog_url: + payload["catalog_url"] = catalog_url.strip() + + response = requests.post(f"{API_URL}/items/search", json=payload) return response.json()["results"] @@ -60,7 +95,7 @@ async def run_search(query): # Run the async search loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - results = loop.run_until_complete(run_search(query)) + results = loop.run_until_complete(run_search(query, catalog_url)) items = results["items"] aoi = results["aoi"] explanation = results["explanation"] @@ -212,10 +247,19 @@ async def run_search(query): """ Search for satellite imagery using natural language. - **Examples queries:** + **Available STAC Catalogs:** + - **Planetary Computer**: Microsoft's global dataset catalog + - **VEDA**: NASA's Earth science data catalog + - **E84 Earth Search**: Element 84's STAC catalog for Earth observation data on AWS Open Data + - **DevSeed EOAPI.dev**: DevSeed's example STAC catalog + - **Custom URL**: Enter any STAC-compliant catalog URL + + The system will automatically index new catalogs on first use. + + **Example queries:** - imagery of Paris from 2017 - Cloud-free satellite data of Georgia the country from 2022 - - relatively cloud-free images in 2024 that have RGB visual bands over Longmont, Colorado that can be downloaded via HTTP + - relatively cloud-free images in 2024 over Longmont, Colorado - images in 2024 over Odisha with cloud cover between 50 to 60% - NAIP imagery over the state of Washington - Burn scar imagery of from 2024 over the state of California diff --git a/helm-chart/README.md b/helm-chart/README.md index 64426cf..6f8fd4c 100644 --- a/helm-chart/README.md +++ b/helm-chart/README.md @@ -137,7 +137,6 @@ The init container uses the same STAC catalog configuration as the API: api: env: STAC_CATALOG_URL: "https://planetarycomputer.microsoft.com/api/stac/v1" - STAC_CATALOG_NAME: "planetarycomputer" initContainer: enabled: true # Set to false to disable data pre-loading diff --git a/helm-chart/values.yaml b/helm-chart/values.yaml index 39e53ab..5f47b03 100644 --- a/helm-chart/values.yaml +++ b/helm-chart/values.yaml @@ -41,7 +41,6 @@ api: PYTHONUNBUFFERED: "1" HF_HOME: "/app/data/.cache/huggingface" GEODINI_API: "https://geodini.k8s.labs.ds.io" - STAC_CATALOG_NAME: "planetarycomputer" STAC_CATALOG_URL: "https://planetarycomputer.microsoft.com/api/stac/v1" DEFAULT_TARGET_COLLECTIONS: "['landsat-8-c2-l2', 'sentinel-2-l2a']" diff --git a/stac_search/agents/collections_search.py b/stac_search/agents/collections_search.py index 15ec995..daca12b 100644 --- a/stac_search/agents/collections_search.py +++ b/stac_search/agents/collections_search.py @@ -9,9 +9,8 @@ from pprint import pformat from typing import List, Dict, Any -import chromadb from pydantic_ai import Agent -from sentence_transformers import SentenceTransformer +from stac_search.catalog_manager import CatalogManager logger = logging.getLogger(__name__) @@ -20,7 +19,6 @@ MODEL_NAME = "all-MiniLM-L6-v2" DATA_PATH = os.environ.get("DATA_PATH", "data/chromadb") -STAC_CATALOG_NAME = os.getenv("STAC_CATALOG_NAME", "planetarycomputer") STAC_COLLECTIONS_URL = os.getenv( "STAC_COLLECTIONS_URL", "https://planetarycomputer.microsoft.com/api/stac/v1" ) @@ -65,7 +63,7 @@ async def collection_search( top_k: int = 5, model_name: str = MODEL_NAME, data_path: str = DATA_PATH, - stac_catalog_name: str = STAC_CATALOG_NAME, + catalog_url: str = None, ) -> List[CollectionWithExplanation]: """ Search for collections and rerank results with explanations @@ -75,25 +73,31 @@ async def collection_search( top_k: Maximum number of results to return model_name: Name of the sentence transformer model to use data_path: Path to the vector database - stac_catalog_name: Name of the STAC catalog - stac_collections_url: URL of the STAC collections API + catalog_url: URL of the STAC catalog Returns: Ranked results with relevance explanations """ start_time = time.time() - # Initialize model and database connections - model = SentenceTransformer(model_name) + # Initialize catalog manager + catalog_manager = CatalogManager(data_path=data_path, model_name=model_name) + + # If catalog_url is provided, ensure it's loaded + if catalog_url: + load_result = await catalog_manager.load_catalog(catalog_url) + if not load_result["success"]: + logger.error(f"Failed to load catalog: {load_result['error']}") + raise ValueError(f"Failed to load catalog: {load_result['error']}") + + # Get the appropriate collection + collection = catalog_manager.get_catalog_collection(catalog_url) + load_model_time = time.time() logger.info(f"Model loading time: {load_model_time - start_time:.4f} seconds") - client = chromadb.PersistentClient(path=data_path) - collection_name = f"{stac_catalog_name}_collections" - collection = client.get_collection(name=collection_name) - # Generate query embedding - query_embedding = model.encode([query]) + query_embedding = catalog_manager.model.encode([query]) # Search vector database results = collection.query( diff --git a/stac_search/agents/items_search.py b/stac_search/agents/items_search.py index ee723af..920be60 100644 --- a/stac_search/agents/items_search.py +++ b/stac_search/agents/items_search.py @@ -33,6 +33,7 @@ @dataclass class Context: query: str + catalog_url: str | None = None location: str | None = None top_k: int = 5 return_search_params_only: bool = False @@ -105,12 +106,16 @@ class CollectionSearchResult: collections: List[CollectionWithExplanation] -async def search_collections(query: str) -> CollectionSearchResult | None: +async def search_collections( + query: str, catalog_url: str = None +) -> CollectionSearchResult | None: logger.info("Searching for relevant collections ...") collection_query = await collection_query_framing_agent.run(query) logger.info(f"Framed collection query: {collection_query.data.query}") if collection_query.data.is_specific: - collections = await collection_search(collection_query.data.query) + collections = await collection_search( + collection_query.data.query, catalog_url=catalog_url + ) return CollectionSearchResult(collections=collections) else: return None @@ -278,11 +283,26 @@ async def item_search(ctx: Context) -> ItemSearchResult: results = await search_items_agent.run( f"Find items for the query: {ctx.query}", deps=ctx ) + catalog_url_to_use = ctx.catalog_url or STAC_CATALOG_URL # determine the collections to search - target_collections = await search_collections(ctx.query) or [] + target_collections = await search_collections(ctx.query, catalog_url_to_use) or [] logger.info(f"Target collections: {pformat(target_collections)}") - default_target_collections = DEFAULT_TARGET_COLLECTIONS + + if not target_collections: + # If no specific collections were found, use the default target collections + default_target_collections = DEFAULT_TARGET_COLLECTIONS + # check that default_target_collections exist in the catalog + all_collection_ids = [ + collection.id + for collection in Client.open(catalog_url_to_use).get_collections() + ] + default_target_collections = [ + collection_id + for collection_id in default_target_collections + if collection_id in all_collection_ids + ] + if target_collections: explanation = "Considering the following collections:" for result in target_collections.collections: @@ -290,12 +310,15 @@ async def item_search(ctx: Context) -> ItemSearchResult: collections_to_search = [ collection.collection_id for collection in target_collections.collections ] - else: + elif default_target_collections: explanation = f"Including the following common collections in the search: {', '.join(default_target_collections)}\n" collections_to_search = default_target_collections + else: + explanation = "Searching all collections in the catalog." + collections_to_search = all_collection_ids # Actually perform the search - client = Client.open(STAC_CATALOG_URL) + client = Client.open(catalog_url_to_use) params = { "max_items": 20, "collections": collections_to_search, @@ -310,11 +333,9 @@ async def item_search(ctx: Context) -> ItemSearchResult: logger.info(f"Found polygon for {results.data.location}") params["intersects"] = polygon else: + explanation += f"\n\n No polygon found for {results.data.location}. " return ItemSearchResult( - items=None, - search_params=params, - aoi=None, - explanation=f"No polygon found for {results.data.location}", + items=None, search_params=params, aoi=None, explanation=explanation ) if ctx.return_search_params_only: diff --git a/stac_search/api.py b/stac_search/api.py index 55a75a1..be130a8 100644 --- a/stac_search/api.py +++ b/stac_search/api.py @@ -2,9 +2,10 @@ FastAPI server for STAC Natural Query """ -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel +from typing import Optional import uvicorn from stac_search.agents.collections_search import collection_search @@ -30,10 +31,12 @@ # Define request model class QueryRequest(BaseModel): query: str + catalog_url: Optional[str] = None class STACItemsRequest(BaseModel): query: str + catalog_url: Optional[str] = None return_search_params_only: bool = False @@ -41,18 +44,28 @@ class STACItemsRequest(BaseModel): @app.post("/search") async def search(request: QueryRequest): """Search for STAC collections using natural language""" - results = collection_search(request.query) - return {"results": results} + try: + results = await collection_search( + request.query, catalog_url=request.catalog_url + ) + return {"results": results} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) @app.post("/items/search") async def search_items(request: STACItemsRequest): """Search for STAC items using natural language""" - ctx = ItemSearchContext( - query=request.query, return_search_params_only=request.return_search_params_only - ) - results = await item_search(ctx) - return {"results": results} + try: + ctx = ItemSearchContext( + query=request.query, + catalog_url=request.catalog_url, + return_search_params_only=request.return_search_params_only, + ) + results = await item_search(ctx) + return {"results": results} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) def start_server(host: str = "0.0.0.0", port: int = 8000): diff --git a/stac_search/catalog_manager.py b/stac_search/catalog_manager.py new file mode 100644 index 0000000..e3a3487 --- /dev/null +++ b/stac_search/catalog_manager.py @@ -0,0 +1,175 @@ +""" +Catalog Manager for STAC Natural Query - handles dynamic catalog loading and management +""" + +import hashlib +import logging +import os +from typing import Optional, Dict, Any +import chromadb +from pystac_client import Client +from sentence_transformers import SentenceTransformer + +logger = logging.getLogger(__name__) + +# Constants +MODEL_NAME = "all-MiniLM-L6-v2" +DATA_PATH = os.environ.get("DATA_PATH", "data/chromadb") + + +class CatalogManager: + """Manages STAC catalog indexing and retrieval operations""" + + def __init__(self, data_path: str = DATA_PATH, model_name: str = MODEL_NAME): + self.data_path = data_path + self.model_name = model_name + self.model = SentenceTransformer(model_name) + self.client = chromadb.PersistentClient(path=data_path) + + def _get_catalog_name(self, catalog_url: str) -> str: + """Generate a unique catalog name from URL""" + # Create a hash of the URL for consistent naming + url_hash = hashlib.md5(catalog_url.encode()).hexdigest()[:8] + # Clean URL for readability + clean_url = catalog_url.replace("https://", "").replace("http://", "") + clean_url = clean_url.replace("/", "_").replace(".", "_") + return f"{clean_url}_{url_hash}" + + def _get_collection_name(self, catalog_url: str) -> str: + """Get ChromaDB collection name for a catalog""" + catalog_name = self._get_catalog_name(catalog_url) + return f"{catalog_name}_collections" + + def catalog_exists(self, catalog_url: str) -> bool: + """Check if a catalog is already indexed in the vector database""" + collection_name = self._get_collection_name(catalog_url) + try: + existing_collections = self.client.list_collections() + print(f"Existing collections: {[col.name for col in existing_collections]}") + print(f"Checking for collection: {collection_name}") + return any(col.name == collection_name for col in existing_collections) + except Exception as e: + logger.error(f"Error checking catalog existence: {e}") + return False + + def validate_catalog_url(self, catalog_url: str) -> bool: + """Validate that the catalog URL is accessible and is a valid STAC catalog""" + try: + stac_client = Client.open(catalog_url) + # Try to get at least one collection to verify it's a valid catalog + collections = list(stac_client.collection_search().collections()) + return len(collections) > 0 + except Exception as e: + logger.error(f"Invalid catalog URL {catalog_url}: {e}") + return False + + def fetch_collections(self, stac_client: Client) -> list: + """Fetch STAC collections using pystac-client""" + try: + collections = stac_client.collection_search().collections() + return list(collections) + except Exception as e: + logger.error(f"Error fetching collections: {e}") + return [] + + def generate_embeddings(self, collections: list) -> list: + """Generate embeddings for each collection (title + description)""" + texts = [] + for collection in collections: + title = getattr(collection, "title", "") or "" + description = getattr(collection, "description", "") or "" + texts.append(f"{title} {description}") + + embeddings = self.model.encode(texts) + return embeddings + + def store_in_vector_db(self, collections: list, chroma_collection) -> None: + """Store embeddings in ChromaDB""" + if not collections: + logger.warning("No collections to store") + return + + metadatas = [] + for collection in collections: + metadata = { + "title": getattr(collection, "title", "") or "", + "description": getattr(collection, "description", "") or "", + "collection_id": getattr(collection, "id", ""), + } + metadatas.append(metadata) + + embeddings = self.generate_embeddings(collections) + + chroma_collection.add( + ids=[str(i) for i in range(len(collections))], + embeddings=embeddings, + metadatas=metadatas, + ) + + async def load_catalog(self, catalog_url: str) -> Dict[str, Any]: + """Load and index a catalog if it doesn't exist""" + try: + # Validate catalog URL first + if not self.validate_catalog_url(catalog_url): + return { + "success": False, + "error": f"Invalid or inaccessible catalog URL: {catalog_url}", + } + + # Check if catalog already exists + if self.catalog_exists(catalog_url): + logger.info(f"Catalog {catalog_url} already indexed") + return { + "success": True, + "message": f"Catalog already indexed", + "catalog_name": self._get_catalog_name(catalog_url), + } + + # Load the catalog + logger.info(f"Loading catalog from {catalog_url}") + stac_client = Client.open(catalog_url) + collections = self.fetch_collections(stac_client) + + if not collections: + return { + "success": False, + "error": f"No collections found in catalog {catalog_url}", + } + + # Create ChromaDB collection + collection_name = self._get_collection_name(catalog_url) + chroma_collection = self.client.create_collection( + name=collection_name, get_or_create=True + ) + + # Store in vector database + self.store_in_vector_db(collections, chroma_collection) + + logger.info( + f"Successfully indexed {len(collections)} collections from {catalog_url}" + ) + return { + "success": True, + "message": f"Successfully indexed {len(collections)} collections", + "catalog_name": self._get_catalog_name(catalog_url), + "collections_count": len(collections), + } + + except Exception as e: + logger.error(f"Error loading catalog {catalog_url}: {e}") + return {"success": False, "error": f"Error loading catalog: {str(e)}"} + + def get_catalog_collection( + self, catalog_url: Optional[str] = None + ) -> chromadb.Collection: + """Get the ChromaDB collection for a catalog""" + if not catalog_url: + catalog_url = os.environ.get("STAC_CATALOG_URL") + + collection_name = self._get_collection_name(catalog_url) + + try: + return self.client.get_collection(name=collection_name) + except Exception as e: + logger.error(f"Error getting collection {collection_name}: {e}") + raise diff --git a/stac_search/load.py b/stac_search/load.py index 87f2c52..3f22f62 100644 --- a/stac_search/load.py +++ b/stac_search/load.py @@ -2,79 +2,35 @@ Load CLI for STAC Natural Query - creates and populates the vector database """ +import asyncio import logging import os -import chromadb -from pystac_client import Client -from sentence_transformers import SentenceTransformer +from .catalog_manager import CatalogManager logger = logging.getLogger(__name__) -# Constants -MODEL_NAME = "all-MiniLM-L6-v2" -DATA_PATH = os.environ.get("DATA_PATH", "data/chromadb") +def load_data(catalog_url: str): + """Load STAC collections into the vector database using CatalogManager""" + try: + # Initialize catalog manager + catalog_manager = CatalogManager() -def load_data(catalog_url, catalog_name): - """Load STAC collections into the vector database""" - logger.info("Initializing vector database...") + # Load catalog using async method + result = asyncio.run(catalog_manager.load_catalog(catalog_url)) - # Initialize the model - model = SentenceTransformer(MODEL_NAME) - - # Initialize ChromaDB client with persistence settings - client = chromadb.PersistentClient(path=DATA_PATH) - chroma_collection = client.create_collection( - name=f"{catalog_name}_collections", get_or_create=True - ) - - # Initialize STAC client - stac_client = Client.open(catalog_url) - - logger.info("Fetching STAC collections...") - collections = fetch_collections(stac_client) - logger.info(f"Found {len(collections)} collections") - - logger.info("Generating embeddings and storing in vector database...") - store_in_vector_db(collections, model, chroma_collection) - - logger.info("Data loading complete!") - - -def fetch_collections(stac_client): - """Fetch STAC collections using pystac-client""" - collections = stac_client.collection_search().collections() - return list(collections) + if result["success"]: + logger.info(f"Successfully loaded catalog: {result['message']}") + if "collections_count" in result: + logger.info(f"Indexed {result['collections_count']} collections") + else: + logger.error(f"Failed to load catalog: {result['error']}") + raise Exception(result["error"]) - -def generate_embeddings(collections, model): - """Generate embeddings for each collection (title + description)""" - texts = [ - f"{collection.title} {collection.description}" for collection in collections - ] - embeddings = model.encode(texts) - return embeddings - - -def store_in_vector_db(collections, model, chroma_collection): - """Store embeddings in ChromaDB""" - metadatas = [ - { - "title": collection.title or "", - "description": collection.description or "", - "collection_id": collection.id, - } - for collection in collections - ] - - embeddings = generate_embeddings(collections, model) - - chroma_collection.add( - ids=[str(i) for i in range(len(collections))], - embeddings=embeddings, - metadatas=metadatas, - ) + except Exception as e: + logger.error(f"Error loading data: {e}") + raise if __name__ == "__main__": @@ -83,10 +39,7 @@ def store_in_vector_db(collections, model, chroma_collection): # catalog_url="https://planetarycomputer.microsoft.com/api/stac/v1", # catalog_name="planetarycomputer", # ) - import os - STAC_CATALOG_URL = os.environ.get( "STAC_CATALOG_URL", "https://planetarycomputer.microsoft.com/api/stac/v1" ) - STAC_CATALOG_NAME = os.environ.get("STAC_CATALOG_NAME", "planetarycomputer") - load_data(catalog_url=STAC_CATALOG_URL, catalog_name=STAC_CATALOG_NAME) + load_data(catalog_url=STAC_CATALOG_URL)