diff --git a/src/mysql-mcp-server/README.md b/src/mysql-mcp-server/README.md index 1f8d879..be25e91 100755 --- a/src/mysql-mcp-server/README.md +++ b/src/mysql-mcp-server/README.md @@ -21,6 +21,7 @@ A Python-based MCP (Model Context Protocol) server that provides a suite of tool - `ragify_column`: Create/populate vector columns for embeddings - `ask_ml_rag`: Retrieval-augmented generation from vector stores - `heatwave_ask_help`: Answers questions about how to use HeatWave ML + - `ask_nl_sql`: Convert natural language questions into SQL queries and execute them automatically - **Vector Store Management** - List files in `secure_file_priv` (local mode) @@ -213,6 +214,7 @@ python mysql_mcp_server.py 11. `list_all_compartments()`: List OCI compartments 12. `object_storage_list_buckets(compartment_name | compartment_id)`: List buckets in a compartment 13. `object_storage_list_objects(namespace, bucket_name)`: List objects in a bucket +14. `ask_nl_sql(connection_id, question)`: Convert natural language questions into SQL queries and execute them automatically ## Security @@ -236,6 +238,7 @@ Here are example prompts you can use to interact with the MCP server, note that ``` "Generate a summary of error logs" "Ask ml_rag: Show me refund policy from the vector store" +"What is the average delay incurred by flights?" ``` ### 3. Object Storage diff --git a/src/mysql-mcp-server/mysql_mcp_server.py b/src/mysql-mcp-server/mysql_mcp_server.py index 0896358..f6011a7 100755 --- a/src/mysql-mcp-server/mysql_mcp_server.py +++ b/src/mysql-mcp-server/mysql_mcp_server.py @@ -12,8 +12,13 @@ from fastmcp import FastMCP from mysql import connector from mysql.connector.abstracts import MySQLConnectionAbstract - -from utils import DatabaseConnectionError, get_ssh_command, load_mysql_config, Mode, OciInfo +from utils import ( + DatabaseConnectionError, + Mode, + OciInfo, + get_ssh_command, + load_mysql_config, +) MIN_CONTEXT_SIZE = 10 DEFAULT_CONTEXT_SIZE = 20 @@ -29,20 +34,26 @@ try: config = load_mysql_config() except Exception as e: - config_error_msg = json.dumps({ - "error" : f"Error loading config. Fix configuration file and try restarting MCP server {str(e)}." - }) + config_error_msg = json.dumps( + { + "error": f"Error loading config. Fix configuration file and try restarting MCP server {str(e)}." + } + ) # Setup oci connection if applicable oci_info: Optional[OciInfo] = None # None if not available, otherwise OCI config info -oci_error_msg: Optional[str] = None # None if OCI available, otherwise a json formatted string +oci_error_msg: Optional[str] = ( + None # None if OCI available, otherwise a json formatted string +) try: oci_info = OciInfo() except Exception as e: - oci_error_msg = json.dumps({ - "error" : "object store unavailable. If object store is required, the MCP server must be restarted with a valid" - f" OCI config. OCI connection attempt yielded error {str(e)}." - }) + oci_error_msg = json.dumps( + { + "error": "object store unavailable. If object store is required, the MCP server must be restarted with a valid" + f" OCI config. OCI connection attempt yielded error {str(e)}." + } + ) # Create mcp server mcp = FastMCP("MySQL") @@ -51,6 +62,7 @@ # Finish setup ############################################################### + def _validate_name(name: str) -> str: """ Validate that the string is a legal SQL identifier (letters, digits, underscores). @@ -81,9 +93,7 @@ def _get_mode(connection_id: str) -> Mode: Returns: Mode: The resolved provider mode. """ - provider_result = _execute_sql_tool( - connection_id, "SELECT @@rapid_cloud_provider;" - ) + provider_result = _execute_sql_tool(connection_id, "SELECT @@rapid_cloud_provider;") if check_error(provider_result): raise Exception( f"Exception occurred while fetching cloud provider {str(provider_result)}" @@ -230,7 +240,7 @@ def list_all_connections() -> str: { "key": connection_id, "error": str(e), - "hint": f"Bastion/jump host may be down. Try starting it with {get_ssh_command(config)}" + "hint": f"Bastion/jump host may be down. Try starting it with {get_ssh_command(config)}", } ) return json.dumps({"valid keys": valid_keys, "invalid keys": invalid_keys}) @@ -258,6 +268,19 @@ def execute_sql_tool_by_connection_id( return _execute_sql_tool(connection_id, sql_script, params=params) +from datetime import date, datetime +from decimal import Decimal + + +class CustomJSONEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, Decimal): + return str(o) + if isinstance(o, (date, datetime)): + return o.isoformat() + return super().default(o) + + def _execute_sql_tool( connection: Union[str, MySQLConnectionAbstract], sql_script: str, @@ -309,7 +332,7 @@ def _execute_sql_tool( db_connection.commit() - return json.dumps(results) + return json.dumps(results, cls=CustomJSONEncoder) except Exception as e: return json.dumps( @@ -565,7 +588,9 @@ def load_vector_store_oci( @mcp.tool() -def ask_ml_rag_vector_store(connection_id: str, question: str, context_size: int = DEFAULT_CONTEXT_SIZE) -> str: +def ask_ml_rag_vector_store( + connection_id: str, question: str, context_size: int = DEFAULT_CONTEXT_SIZE +) -> str: """ [MCP Tool] Retrieve segments from the default vector store (skip_generate=true). @@ -586,16 +611,26 @@ def ask_ml_rag_vector_store(connection_id: str, question: str, context_size: int arguments: {"connection_id": "example_local_server", "question": "Find information about refunds."} """ if context_size < MIN_CONTEXT_SIZE or MAX_CONTEXT_SIZE < context_size: - return json.dumps({"error": f"Error choose a context_size in [{MIN_CONTEXT_SIZE}, {MAX_CONTEXT_SIZE}]"}) + return json.dumps( + { + "error": f"Error choose a context_size in [{MIN_CONTEXT_SIZE}, {MAX_CONTEXT_SIZE}]" + } + ) return _ask_ml_rag_helper( - connection_id, question, f"JSON_OBJECT('skip_generate', true, 'n_citations', {context_size})" + connection_id, + question, + f"JSON_OBJECT('skip_generate', true, 'n_citations', {context_size})", ) @mcp.tool() def ask_ml_rag_innodb( - connection_id: str, question: str, segment_col: str, embedding_col: str, context_size: int = DEFAULT_CONTEXT_SIZE + connection_id: str, + question: str, + segment_col: str, + embedding_col: str, + context_size: int = DEFAULT_CONTEXT_SIZE, ) -> str: """ [MCP Tool] Retrieve segments from InnoDB tables using specified segment and embedding columns. @@ -626,7 +661,11 @@ def ask_ml_rag_innodb( arguments: {"connection_id": "example_local_server", "question": "Search product docs", "segment_col": "body", "embedding_col": "embedding"} """ if context_size < MIN_CONTEXT_SIZE or MAX_CONTEXT_SIZE < context_size: - return json.dumps({"error": f"Error choose a context_size in [{MIN_CONTEXT_SIZE}, {MAX_CONTEXT_SIZE}]"}) + return json.dumps( + { + "error": f"Error choose a context_size in [{MIN_CONTEXT_SIZE}, {MAX_CONTEXT_SIZE}]" + } + ) try: # prevent possible injection @@ -732,6 +771,84 @@ def heatwave_ask_help(connection_id: str, question: str) -> str: return json.dumps({"error": f"Error with NL2ML: {str(e)}"}) +@mcp.tool() +def ask_nl_sql(connection_id: str, question: str) -> str: + """ + [MCP Tool] Convert natural language questions into SQL queries and execute them automatically. + + This tool is ideal for database exploration using plain English questions like: + - "What tables are available?" + - "Show me the average price by category" + - "How many users registered last month?" + - "What are the column names in the customers table?" + + Args: + connection_id (str): MySQL connection key. + question (str): Natural language query. + + Returns: + JSON object containing: + + sql_response(str): The response from executing the generated SQL query. + sql_query(str): The generated SQL query + schemas(json): The schemas where metadata was retrieved + tables(json): The tables where metadata was retrieved + is_sql_valid(bool): Whether the generated SQL statement is valid + model_id(str): The LLM used for generation + + + MCP usage example: + - name: ask_nl_sql + arguments: {"connection_id": "example_local_server", "question": "How many singers are there?"} + + Here is the what part of the return JSON looks like; + { + "tables": [ + "singer.singer", + "singer.song", + "concert_singer.singer", + "concert_singer.stadium", + "music_2.Songs", + "music_2.Instruments", + "music_2.Band", + "music_2.Vocals", + "music_2.Tracklists" + ], + "schemas": [ + "concert_singer", + "music_2", + "singer" + ], + "sql_query": "SELECT COUNT(`Singer_ID`) FROM `concert_singer`.`singer`;", + "is_sql_valid": 1 + } + """ + with _get_database_connection_cm(connection_id) as db_connection: + # Execute the heatwave chat query + set_response = _execute_sql_tool(db_connection, "SET @response = NULL;") + if check_error(set_response): + return json.dumps({"error": f"Error with NL_SQL: {set_response}"}) + + nl2sql_response = _execute_sql_tool( + db_connection, + f"CALL sys.NL_SQL(%s, @response, NULL)", + params=[question], + ) + if check_error(nl2sql_response): + return json.dumps({"error": f"Error with NL_SQL: {nl2sql_response}"}) + + fetch_response = _execute_sql_tool(db_connection, "SELECT @response;") + if check_error(fetch_response): + return json.dumps({"error": f"Error with ML_RAG: {fetch_response}"}) + + try: + response = json.loads(fetch_one(fetch_response)) + response["sql_response"] = nl2sql_response + return json.dumps(response) + except: + return json.dumps({"error": "Unexpected response format from NL_SQL"}) + + """ Object store """ @@ -745,7 +862,7 @@ def verify_compartment_access(compartments): "compartment_id": compartment.id, "object_storage": False, "databases": False, - "errors": [] + "errors": [], } # Test Object Storage @@ -756,10 +873,13 @@ def verify_compartment_access(compartments): ) access_report[compartment.name]["object_storage"] = True except Exception as e: - access_report[compartment.name]["errors"].append(f"Object Storage: {str(e)}") + access_report[compartment.name]["errors"].append( + f"Object Storage: {str(e)}" + ) return access_report + @mcp.tool() def list_all_compartments() -> str: """