diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index 0b7fa95..95475a3 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -460,26 +460,30 @@ async def analyze_db_health( return format_text_response(result) -@mcp.tool(description=f"Reports the slowest SQL queries based on execution time, using data from the '{PG_STAT_STATEMENTS}' extension.") +@mcp.tool( + name="get_top_queries", + description=f"Reports the slowest or most resource-intensive queries using data from the '{PG_STAT_STATEMENTS}' extension.", +) async def get_top_queries( - limit: int = Field(description="Number of slow queries to return", default=10), sort_by: str = Field( - description="Sort criteria: 'total' for total execution time or 'mean' for mean execution time per call", - default="mean", + description="Ranking criteria: 'total_time' for total execution time or 'mean_time' for mean execution time per call, or 'resources' " + "for resource-intensive queries", + default="resources", ), + limit: int = Field(description="Number of queries to return when ranking based on mean_time or total_time", default=10), ) -> ResponseType: - """Reports the slowest SQL queries based on execution time. - - This tool handles PostgreSQL version differences automatically: - - In PostgreSQL 13+: Uses total_exec_time/mean_exec_time columns - - In PostgreSQL 12 and older: Uses total_time/mean_time columns - """ try: sql_driver = await get_sql_driver() top_queries_tool = TopQueriesCalc(sql_driver=sql_driver) - if sort_by != "mean" and sort_by != "total": - return format_error_response("Invalid sort criteria. Please use 'mean' or 'total'.") - result = await top_queries_tool.get_top_queries(limit=limit, sort_by=sort_by) + + if sort_by == "resources": + result = await top_queries_tool.get_top_resource_queries() + return format_text_response(result) + elif sort_by == "mean_time" or sort_by == "total_time": + # Map the sort_by values to what get_top_queries_by_time expects + result = await top_queries_tool.get_top_queries_by_time(limit=limit, sort_by="mean" if sort_by == "mean_time" else "total") + else: + return format_error_response("Invalid sort criteria. Please use 'resources' or 'mean_time' or 'total_time'.") return format_text_response(result) except Exception as e: logger.error(f"Error getting slow queries: {e}") diff --git a/src/postgres_mcp/sql/extension_utils.py b/src/postgres_mcp/sql/extension_utils.py index 4ed1a5b..1f96fa4 100644 --- a/src/postgres_mcp/sql/extension_utils.py +++ b/src/postgres_mcp/sql/extension_utils.py @@ -1,8 +1,8 @@ """Utilities for working with PostgreSQL extensions.""" import logging +from dataclasses import dataclass from typing import Literal -from typing import TypedDict from .safe_sql import SafeSqlDriver from .sql_driver import SqlDriver @@ -14,7 +14,8 @@ _POSTGRES_VERSION = None -class ExtensionStatus(TypedDict): +@dataclass +class ExtensionStatus: """Status of an extension.""" is_installed: bool @@ -118,25 +119,25 @@ async def check_extension( ) # Initialize result - result: ExtensionStatus = { - "is_installed": False, - "is_available": False, - "name": extension_name, - "message": "", - "default_version": None, - } + result = ExtensionStatus( + is_installed=False, + is_available=False, + name=extension_name, + message="", + default_version=None, + ) if installed_result and len(installed_result) > 0: # Extension is installed version = installed_result[0].cells.get("extversion", "unknown") - result["is_installed"] = True - result["is_available"] = True + result.is_installed = True + result.is_available = True if include_messages: if message_type == "markdown": - result["message"] = f"The **{extension_name}** extension (version {version}) is already installed." + result.message = f"The **{extension_name}** extension (version {version}) is already installed." else: - result["message"] = f"The {extension_name} extension (version {version}) is already installed." + result.message = f"The {extension_name} extension (version {version}) is already installed." else: # Check if the extension is available but not installed available_result = await SafeSqlDriver.execute_param_query( @@ -147,17 +148,17 @@ async def check_extension( if available_result and len(available_result) > 0: # Extension is available but not installed - result["is_available"] = True - result["default_version"] = available_result[0].cells.get("default_version") + result.is_available = True + result.default_version = available_result[0].cells.get("default_version") if include_messages: if message_type == "markdown": - result["message"] = ( + result.message = ( f"The **{extension_name}** extension is available but not installed.\n\n" f"You can install it by running: `CREATE EXTENSION {extension_name};`." ) else: - result["message"] = ( + result.message = ( f"The {extension_name} extension is available but not installed.\n" f"You can install it by running: CREATE EXTENSION {extension_name};" ) @@ -165,14 +166,14 @@ async def check_extension( # Extension is not available if include_messages: if message_type == "markdown": - result["message"] = ( + result.message = ( f"The **{extension_name}** extension is not available on this PostgreSQL server.\n\n" f"To install it, you need to:\n" f"1. Install the extension package on the server\n" f"2. Run: `CREATE EXTENSION {extension_name};`" ) else: - result["message"] = ( + result.message = ( f"The {extension_name} extension is not available on this PostgreSQL server.\n" f"To install it, you need to:\n" f"1. Install the extension package on the server\n" @@ -195,13 +196,13 @@ async def check_hypopg_installation_status(sql_driver: SqlDriver, message_type: """ status = await check_extension(sql_driver, "hypopg", include_messages=False) - if status["is_installed"]: + if status.is_installed: if message_type == "markdown": return True, "The **hypopg** extension is already installed." else: return True, "The hypopg extension is already installed." - if status["is_available"]: + if status.is_available: if message_type == "markdown": return False, ( "The **hypopg** extension is required to test hypothetical indexes, but it is not currently installed.\n\n" diff --git a/src/postgres_mcp/top_queries/top_queries_calc.py b/src/postgres_mcp/top_queries/top_queries_calc.py index c6895d8..8719107 100644 --- a/src/postgres_mcp/top_queries/top_queries_calc.py +++ b/src/postgres_mcp/top_queries/top_queries_calc.py @@ -1,13 +1,33 @@ +import logging from typing import Literal +from typing import LiteralString from typing import Union +from typing import cast from ..sql import SafeSqlDriver from ..sql import SqlDriver from ..sql.extension_utils import check_extension from ..sql.extension_utils import get_postgres_version +logger = logging.getLogger(__name__) + PG_STAT_STATEMENTS = "pg_stat_statements" +install_pg_stat_statements_message = ( + "The pg_stat_statements extension is required to " + "report slow queries, but it is not currently " + "installed.\n\n" + "You can install it by running: " + "`CREATE EXTENSION pg_stat_statements;`\n\n" + "**What does it do?** It records statistics (like " + "execution time, number of calls, rows returned) for " + "every query executed against the database.\n\n" + "**Is it safe?** Installing 'pg_stat_statements' is " + "generally safe and a standard practice for performance " + "monitoring. It adds overhead by tracking statistics, " + "but this is usually negligible unless under extreme load." +) + class TopQueriesCalc: """Tool for retrieving the slowest SQL queries.""" @@ -15,7 +35,7 @@ class TopQueriesCalc: def __init__(self, sql_driver: Union[SqlDriver, SafeSqlDriver]): self.sql_driver = sql_driver - async def get_top_queries(self, limit: int = 10, sort_by: Literal["total", "mean"] = "mean") -> str: + async def get_top_queries_by_time(self, limit: int = 10, sort_by: Literal["total", "mean"] = "mean") -> str: """Reports the slowest SQL queries based on execution time. Args: @@ -27,32 +47,21 @@ async def get_top_queries(self, limit: int = 10, sort_by: Literal["total", "mean A string with the top queries or installation instructions """ try: + logger.debug(f"Getting top queries by time. limit={limit}, sort_by={sort_by}") extension_status = await check_extension( self.sql_driver, PG_STAT_STATEMENTS, include_messages=False, ) - if not extension_status["is_installed"]: + if not extension_status.is_installed: + logger.warning(f"Extension {PG_STAT_STATEMENTS} is not installed") # Return installation instructions if the extension is not installed - monitoring_message = ( - f"The '{PG_STAT_STATEMENTS}' extension is required to " - f"report slow queries, but it is not currently " - f"installed.\n\n" - f"You can install it by running: " - f"`CREATE EXTENSION {PG_STAT_STATEMENTS};`\n\n" - f"**What does it do?** It records statistics (like " - f"execution time, number of calls, rows returned) for " - f"every query executed against the database.\n\n" - f"**Is it safe?** Installing '{PG_STAT_STATEMENTS}' is " - f"generally safe and a standard practice for performance " - f"monitoring. It adds overhead by tracking statistics, " - f"but this is usually negligible unless under extreme load." - ) - return monitoring_message + return install_pg_stat_statements_message # Check PostgreSQL version to determine column names pg_version = await get_postgres_version(self.sql_driver) + logger.debug(f"PostgreSQL version: {pg_version}") # Column names changed in PostgreSQL 13 if pg_version >= 13: @@ -64,6 +73,8 @@ async def get_top_queries(self, limit: int = 10, sort_by: Literal["total", "mean total_time_col = "total_time" mean_time_col = "mean_time" + logger.debug(f"Using time columns: total={total_time_col}, mean={mean_time_col}") + # Determine which column to sort by based on sort_by parameter and version order_by_column = total_time_col if sort_by == "total" else mean_time_col @@ -78,12 +89,14 @@ async def get_top_queries(self, limit: int = 10, sort_by: Literal["total", "mean ORDER BY {order_by_column} DESC LIMIT {{}}; """ + logger.debug(f"Executing query: {query}") slow_query_rows = await SafeSqlDriver.execute_param_query( self.sql_driver, query, [limit], ) slow_queries = [row.cells for row in slow_query_rows] if slow_query_rows else [] + logger.info(f"Found {len(slow_queries)} slow queries") # Create result description based on sort criteria if sort_by == "total": @@ -95,4 +108,104 @@ async def get_top_queries(self, limit: int = 10, sort_by: Literal["total", "mean result += str(slow_queries) return result except Exception as e: + logger.error(f"Error getting slow queries: {e}", exc_info=True) return f"Error getting slow queries: {e}" + + async def get_top_resource_queries(self, frac_threshold: float = 0.05) -> str: + """Reports the most time consuming queries based on a resource blend. + + Args: + frac_threshold: Fraction threshold for filtering queries (default: 0.05) + + Returns: + A string with the resource-heavy queries or error message + """ + + try: + logger.debug(f"Getting top resource queries with threshold {frac_threshold}") + extension_status = await check_extension( + self.sql_driver, + PG_STAT_STATEMENTS, + include_messages=False, + ) + + if not extension_status.is_installed: + logger.warning(f"Extension {PG_STAT_STATEMENTS} is not installed") + # Return installation instructions if the extension is not installed + return install_pg_stat_statements_message + + # Check PostgreSQL version to determine column names + pg_version = await get_postgres_version(self.sql_driver) + logger.debug(f"PostgreSQL version: {pg_version}") + + # Column names changed in PostgreSQL 13 + if pg_version >= 13: + # PostgreSQL 13 and newer + total_time_col = "total_exec_time" + mean_time_col = "mean_exec_time" + else: + # PostgreSQL 12 and older + total_time_col = "total_time" + mean_time_col = "mean_time" + + query = cast( + LiteralString, + f""" + WITH resource_fractions AS ( + SELECT + query, + calls, + rows, + {total_time_col} total_exec_time, + {mean_time_col} mean_exec_time, + stddev_exec_time, + shared_blks_hit, + shared_blks_read, + shared_blks_dirtied, + wal_bytes, + total_exec_time / SUM(total_exec_time) OVER () AS total_exec_time_frac, + (shared_blks_hit + shared_blks_read) / SUM(shared_blks_hit + shared_blks_read) OVER () AS shared_blks_accessed_frac, + shared_blks_read / SUM(shared_blks_read) OVER () AS shared_blks_read_frac, + shared_blks_dirtied / SUM(shared_blks_dirtied) OVER () AS shared_blks_dirtied_frac, + wal_bytes / SUM(wal_bytes) OVER () AS total_wal_bytes_frac + FROM pg_stat_statements + ) + SELECT + query, + calls, + rows, + total_exec_time, + mean_exec_time, + stddev_exec_time, + total_exec_time_frac, + shared_blks_accessed_frac, + shared_blks_read_frac, + shared_blks_dirtied_frac, + total_wal_bytes_frac, + shared_blks_hit, + shared_blks_read, + shared_blks_dirtied, + wal_bytes + FROM resource_fractions + WHERE + total_exec_time_frac > {frac_threshold} + OR shared_blks_accessed_frac > {frac_threshold} + OR shared_blks_read_frac > {frac_threshold} + OR shared_blks_dirtied_frac > {frac_threshold} + OR total_wal_bytes_frac > {frac_threshold} + ORDER BY total_exec_time DESC + """, + ) + + logger.debug(f"Executing query: {query}") + slow_query_rows = await SafeSqlDriver.execute_param_query( + self.sql_driver, + query, + ) + resource_queries = [row.cells for row in slow_query_rows] if slow_query_rows else [] + logger.info(f"Found {len(resource_queries)} resource-intensive queries") + + return str(resource_queries) + except Exception as e: + logger.error(f"Error getting resource-intensive queries: {e}", exc_info=True) + return f"Error resource-intensive queries: {e}" diff --git a/tests/integration/test_top_queries_integration.py b/tests/integration/test_top_queries_integration.py index 9426a3b..9cb9c01 100644 --- a/tests/integration/test_top_queries_integration.py +++ b/tests/integration/test_top_queries_integration.py @@ -122,10 +122,10 @@ async def test_get_top_queries_integration(local_sql_driver): calc = TopQueriesCalc(sql_driver=local_sql_driver) # Get top queries by total execution time - total_result = await calc.get_top_queries(limit=10, sort_by="total") + total_result = await calc.get_top_queries_by_time(limit=10, sort_by="total") # Get top queries by mean execution time - mean_result = await calc.get_top_queries(limit=10, sort_by="mean") + mean_result = await calc.get_top_queries_by_time(limit=10, sort_by="mean") # Basic verification assert "slowest queries by total execution time" in total_result @@ -157,23 +157,24 @@ async def test_extension_not_available(local_sql_driver): with pytest.MonkeyPatch().context() as mp: # Import the module we'll be monkeypatching import postgres_mcp.sql.extension_utils + from postgres_mcp.sql.extension_utils import ExtensionStatus # Define our mock function with the correct type signature async def mock_check(*args, **kwargs): - return { - "is_installed": False, - "is_available": True, - "name": PG_STAT_STATEMENTS, - "message": "Extension not installed", - "default_version": None, - } + return ExtensionStatus( + is_installed=False, + is_available=True, + name=PG_STAT_STATEMENTS, + message="Extension not installed", + default_version=None, + ) # Replace the function with our mock # We need to patch the actual function imported by TopQueriesCalc mp.setattr(postgres_mcp.top_queries.top_queries_calc, "check_extension", mock_check) # Run the test - result = await calc.get_top_queries() + result = await calc.get_top_queries_by_time() # Check that we get installation instructions assert "not currently installed" in result diff --git a/tests/unit/top_queries/test_top_queries_calc.py b/tests/unit/top_queries/test_top_queries_calc.py index bfbf397..9dadba3 100644 --- a/tests/unit/top_queries/test_top_queries_calc.py +++ b/tests/unit/top_queries/test_top_queries_calc.py @@ -6,6 +6,7 @@ import postgres_mcp.top_queries.top_queries_calc as top_queries_module from postgres_mcp.sql import SqlDriver +from postgres_mcp.sql.extension_utils import ExtensionStatus from postgres_mcp.top_queries import TopQueriesCalc @@ -84,13 +85,13 @@ async def side_effect(query, *args, **kwargs): def mock_extension_installed(): """Mock check_extension to report extension is installed.""" with patch.object(top_queries_module, "check_extension", autospec=True) as mock_check: - mock_check.return_value = { - "is_installed": True, - "is_available": True, - "name": "pg_stat_statements", - "message": "Extension is installed", - "default_version": "1.0", - } + mock_check.return_value = ExtensionStatus( + is_installed=True, + is_available=True, + name="pg_stat_statements", + message="Extension is installed", + default_version="1.0", + ) yield mock_check @@ -98,13 +99,13 @@ def mock_extension_installed(): def mock_extension_not_installed(): """Mock check_extension to report extension is not installed.""" with patch.object(top_queries_module, "check_extension", autospec=True) as mock_check: - mock_check.return_value = { - "is_installed": False, - "is_available": True, - "name": "pg_stat_statements", - "message": "Extension not installed", - "default_version": None, - } + mock_check.return_value = ExtensionStatus( + is_installed=False, + is_available=True, + name="pg_stat_statements", + message="Extension not installed", + default_version=None, + ) yield mock_check @@ -115,7 +116,7 @@ async def test_top_queries_pg12_total_sort(mock_pg12_driver, mock_extension_inst calc = TopQueriesCalc(sql_driver=mock_pg12_driver) # Get top queries sorted by total time - result = await calc.get_top_queries(limit=3, sort_by="total") + result = await calc.get_top_queries_by_time(limit=3, sort_by="total") # Check that the result contains the expected information assert "Top 3 slowest queries by total execution time" in result @@ -133,7 +134,7 @@ async def test_top_queries_pg12_mean_sort(mock_pg12_driver, mock_extension_insta calc = TopQueriesCalc(sql_driver=mock_pg12_driver) # Get top queries sorted by mean time - result = await calc.get_top_queries(limit=3, sort_by="mean") + result = await calc.get_top_queries_by_time(limit=3, sort_by="mean") # Check that the result contains the expected information assert "Top 3 slowest queries by mean execution time per call" in result @@ -151,7 +152,7 @@ async def test_top_queries_pg13_total_sort(mock_pg13_driver, mock_extension_inst calc = TopQueriesCalc(sql_driver=mock_pg13_driver) # Get top queries sorted by total time - result = await calc.get_top_queries(limit=3, sort_by="total") + result = await calc.get_top_queries_by_time(limit=3, sort_by="total") # Check that the result contains the expected information assert "Top 3 slowest queries by total execution time" in result @@ -169,7 +170,7 @@ async def test_top_queries_pg13_mean_sort(mock_pg13_driver, mock_extension_insta calc = TopQueriesCalc(sql_driver=mock_pg13_driver) # Get top queries sorted by mean time - result = await calc.get_top_queries(limit=3, sort_by="mean") + result = await calc.get_top_queries_by_time(limit=3, sort_by="mean") # Check that the result contains the expected information assert "Top 3 slowest queries by mean execution time per call" in result @@ -187,7 +188,7 @@ async def test_extension_not_installed(mock_pg13_driver, mock_extension_not_inst calc = TopQueriesCalc(sql_driver=mock_pg13_driver) # Try to get top queries when extension is not installed - result = await calc.get_top_queries(limit=3) + result = await calc.get_top_queries_by_time(limit=3) # Check that the result contains the installation instructions assert "extension is required to report" in result @@ -207,7 +208,7 @@ async def test_error_handling(mock_pg13_driver, mock_extension_installed): calc = TopQueriesCalc(sql_driver=mock_pg13_driver) # Try to get top queries - result = await calc.get_top_queries(limit=3) + result = await calc.get_top_queries_by_time(limit=3) # Check that the error is properly reported assert "Error getting slow queries: Database error" in result